Rollout Correction for Diffusion Training (Experimental)

Last updated: 05/19/2026

Status: Experimental. The API, default thresholds and recommended preset may change.

Why

A FlowGRPO training step has three log-probability sources:

  1. Rollout policy — vllm / vllm_omni sample with low-precision kernels (e.g. fp8 / bf16, tensor parallelism).

  2. Old policy recompute — the actor re-runs the same trajectories under its full-precision training graph to produce old_log_probs.

  3. Current policy — recomputed every actor mini-step to drive PPO ratios.

The recompute pass (step 2) typically costs ~20 % of the per-step time. Setting algorithm.rollout_correction.bypass_mode=True skips it and reuses the rollout backend’s log-probs directly as old_log_probs, which yields the largest single training-time saving but introduces an off-policy bias because the rollout and training stacks evaluate the same trajectory slightly differently.

Rollout Correction addresses the off-policy bias with two orthogonal mechanisms:

  • Importance Sampling (IS) — multiply per-sample loss by a clipped ratio clamp(exp(old_logp - rollout_logp), ...).

  • Rejection Sampling (RS) — zero out loss for samples whose log-ratio falls outside a configurable band, so the optimizer never sees extreme outliers.

The two are orthogonal and can be combined.

Quickstart

Enable on top of any FlowGRPO run by adding two blocks of overrides:

algorithm.rollout_correction.bypass_mode=True \
algorithm.rollout_correction.rollout_is=sequence \
algorithm.rollout_correction.rollout_rs=seq_mean_k1 \
algorithm.rollout_correction.rollout_rs_threshold="0.5_2.0"

Note on rollout_is in bypass mode: When bypass_mode=True, the PPO ratio exp(current rollout) already serves as the IS correction. The rollout_is setting is used for IS diagnostics only; weights are not applied to the loss. Only rollout_rs rejection sampling affects the gradient.

A runnable end-to-end example lives at examples/flowgrpo_trainer/run_qwen_image_ocr_lora_rollout_corr.sh.

Config reference

All config keys live under algorithm.rollout_correction and mirror the upstream verl schema exactly. See the upstream documentation for the full reference:

The only diffusion-specific notes are in the tuning guide below.

Logged metrics

Metric

Meaning

rollout_corr/rollout_is_mean / rollout_is_max / rollout_is_min

Post-clip IS weight stats.

rollout_corr/rollout_is_eff_sample_size

Effective sample size of IS weights.

rollout_corr/rollout_rs_masked_fraction

Token-level fraction of steps rejected by RS.

rollout_corr/rollout_rs_seq_masked_fraction

Sequence-level fraction rejected by RS.

rollout_corr/kl

KL(π_rollout ‖ π_old) — direct off-policy drift estimator.

rollout_corr/k3_kl

K3 KL estimator (more stable for small KL).

rollout_corr/log_ppl_diff

Mean per-sequence log-PPL difference (rollout − old).

rollout_corr/chi2_token / chi2_seq

χ² divergence at token- and sequence-level.

In bypass mode metrics are computed per SDE step inside diffusion_loss and appear under actor/rollout_corr/*. In decoupled mode they are emitted once per global batch under rollout_corr/*.

If rollout_corr/rollout_rs_seq_masked_fraction is consistently above ~5 %, the rollout backend is drifting too far — tighten the RS band or fall back to bypass_mode=False.

Gradient dilution note: RS rejection zeroes the per-element loss for rejected samples but does not remove them from the mean() denominator. At high sustained rejection rates the effective gradient magnitude decreases by the factor kept / total. Monitor rollout_corr/rollout_rs_seq_masked_fraction and widen the RS band if it exceeds ~10 % over several steps.

Hyperparameter notes

Defaults (rollout_is_threshold=2.0, loss_type=ppo_clip) transfer well because:

  • The helper operates on the log-ratio directly (unit-less).

  • Diffusion log-probs are mean-pooled across latent dimensions, so per-step variance is lower than per-token LLM log-probs.

Diffusion-specific tuning guide

The SDE window is short (sde_window_size is usually 2), which changes the statistical behaviour of several RS modes:

Concern

Recommendation

seq_mean_k1 with window=2 (decoupled mode only)

The LLM default "0.5_2.0" means the mean log-ratio over only 2 steps must lie in [−0.69, 0.69]. A single outlier step can reject the entire sample. If rollout_corr/rejected_ratio is high, widen to e.g. "0.3_3.0" or "0.2_5.0".

Bypass mode RS

In bypass mode, IS/RS is computed per SDE step with shape (B, 1), so all RS modes (token_k1, seq_mean_k1, etc.) are effectively equivalent — each step is evaluated independently. seq_mean_* is recommended for consistency if you plan to switch between modes.

Token-level RS (token_k1, etc.)

With only 2 tokens, token-level statistics have very low power — a single token cannot be rejected in isolation because the per-token stat is averaged from thousands of latent dims. Prefer seq_mean_* or seq_max_* modes (decoupled) or any mode (bypass — all are per-step).

rollout_is=sequence

The product of 2 per-step ratios. With diffusion’s low per-step variance this is usually well-behaved; the default threshold of 2.0 is generous.

First-run diagnostics

Always inspect rollout_corr/log_ppl_abs_diff and rollout_corr/rollout_is_max for the first 50 steps of a new recipe. If log_ppl_abs_diff > 1.0 or rollout_is_max is pinned at the clamp threshold, the rollout-training gap is larger than expected — consider lowering the rollout precision gap or falling back to bypass_mode=False.

How it plugs in

  1. Bypass entrypoint. apply_bypass_mode_to_diffusion_batch sets old_log_probs := rollout_log_probs (zero-cost). The trainer-side decoupled correction is skipped because old == rollout would be a no-op.

  2. Per-step correction. diffusion_loss reads config.rollout_correction and computes IS/RS per SDE step via compute_rollout_correction_and_rejection_mask. For ppo_clip only the RS mask is applied; the PPO ratio exp(current rollout) handles IS.

  3. Decoupled correction. apply_rollout_correction_to_diffusion_batch runs once per global batch using old_log_probs vs rollout_log_probs and stashes a combined rollout_is_weights tensor.

  4. Loss application. flow_grpo / grpo_guard multiply per-element loss by (detached) rollout_is_weights. Diffusion has no padding so rejection is weight=0 — no separate mask needed.

The config lives on DiffusionActorConfig.rollout_correction (imported from verl). No dedicated loss registration or engine modifications are required.