Rollout Correction for Diffusion Training (Experimental)
Last updated: 05/19/2026
Status: Experimental. The API, default thresholds and recommended preset may change.
Why
A FlowGRPO training step has three log-probability sources:
1. Rollout policy: vllm / vllm_omni sample trajectories with low-precision kernels (e.g. fp8 / bf16) and tensor parallelism.
2. Old policy recompute: the actor re-runs the same trajectories under its full-precision training graph to produce old_log_probs.
3. Current policy: recomputed every actor mini-step to drive PPO ratios.
The recompute pass (step 2) typically costs ~20 % of the per-step time. Setting
algorithm.rollout_correction.bypass_mode=True skips it and reuses the rollout backend’s
log-probs directly as old_log_probs, which yields the largest single training-time saving
but introduces an off-policy bias because the rollout and training stacks evaluate the same
trajectory slightly differently.
Rollout Correction addresses the off-policy bias with two orthogonal mechanisms that can be combined:
Importance Sampling (IS): multiply the per-sample loss by a clipped ratio clamp(exp(old_logp - rollout_logp), ...).
Rejection Sampling (RS): zero out the loss for samples whose log-ratio falls outside a configurable band, so the optimizer never sees extreme outliers.
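As a rough illustration, both mechanisms can be sketched on scalar per-sample log-probs in plain Python. This is a sketch under assumptions, not the verl helper: the function names are hypothetical, the IS clip is assumed to be an upper bound only (set to 2.0 to mirror the default rollout_is_threshold), and the RS band is assumed to be a [lower, upper] interval on the ratio exp(old_logp - rollout_logp).

```python
import math


def is_weight(old_logp: float, rollout_logp: float, clip_upper: float = 2.0) -> float:
    """Truncated importance weight exp(old_logp - rollout_logp).

    Assumption: only an upper clip is applied; the real clamp bounds may differ.
    """
    return min(math.exp(old_logp - rollout_logp), clip_upper)


def rs_keep(old_logp: float, rollout_logp: float,
            lower: float = 0.5, upper: float = 2.0) -> bool:
    """Keep a sample only if its ratio lies inside the assumed [lower, upper] band."""
    ratio = math.exp(old_logp - rollout_logp)
    return lower <= ratio <= upper


def corrected_losses(losses, old_logps, rollout_logps, use_is=True, use_rs=True):
    """Apply IS and/or RS to per-sample losses; rejected samples get weight 0."""
    out = []
    for loss, old, rollout in zip(losses, old_logps, rollout_logps):
        weight = is_weight(old, rollout) if use_is else 1.0
        if use_rs and not rs_keep(old, rollout):
            weight = 0.0  # RS: this sample contributes nothing to the gradient
        out.append(weight * loss)
    return out


print(corrected_losses([1.0, 1.0, 1.0, 1.0],
                       [0.0, 0.5, 1.0, 0.0],
                       [0.0, 0.0, 0.0, 1.0]))
```

With these inputs the ratios are 1.0, about 1.65, about 2.72 and about 0.37, so only the first two samples fall inside the band and contribute to the loss; the second is reweighted by roughly 1.65.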
Quickstart
Enable it on top of any FlowGRPO run by adding the following overrides:
algorithm.rollout_correction.bypass_mode=True \
algorithm.rollout_correction.rollout_is=sequence \
algorithm.rollout_correction.rollout_rs=seq_mean_k1 \
algorithm.rollout_correction.rollout_rs_threshold="0.5_2.0"
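The RS threshold is passed as a single string. Assuming "0.5_2.0" encodes a lower and an upper bound on the importance ratio (check the upstream verl reference for the exact semantics of each rollout_rs mode), a hypothetical parser makes the equivalent log-ratio band explicit:

```python
import math


def parse_rs_band(threshold: str) -> tuple[float, float]:
    """Split a "lower_upper" string into float bounds (hypothetical helper)."""
    lower_text, upper_text = threshold.split("_")
    lower, upper = float(lower_text), float(upper_text)
    # Sanity check (assumption): a sensible band contains the on-policy ratio 1.0.
    if not 0.0 < lower <= 1.0 <= upper:
        raise ValueError(f"expected 0 < lower <= 1 <= upper, got {threshold!r}")
    return lower, upper


lower, upper = parse_rs_band("0.5_2.0")
# The same band expressed on the log-ratio is symmetric: about [-0.693, +0.693].
log_band = (math.log(lower), math.log(upper))
print(lower, upper, log_band)
```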
Note on rollout_is in bypass mode: when bypass_mode=True, the PPO ratio exp(current − rollout) already serves as the IS correction. The rollout_is setting is then used for IS diagnostics only; its weights are not applied to the loss. Only rollout_rs rejection sampling affects the gradient.
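The bypass-mode ratio factors into the usual on-policy PPO ratio and the old-to-rollout IS weight, which is why no separate IS multiplication is needed. Here π_θ denotes the current policy; π_old and π_rollout follow the notation of the metrics table:

```latex
\frac{\pi_\theta}{\pi_{\text{rollout}}}
  = \exp\!\left(\log\pi_\theta - \log\pi_{\text{rollout}}\right)
  = \underbrace{\exp\!\left(\log\pi_\theta - \log\pi_{\text{old}}\right)}_{\text{on-policy PPO ratio}}
    \cdot
    \underbrace{\exp\!\left(\log\pi_{\text{old}} - \log\pi_{\text{rollout}}\right)}_{\text{IS weight}}
```

Under this reading, ppo_clip clips the combined ratio, so the IS factor is not truncated separately at rollout_is_threshold.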
A runnable end-to-end example lives at
examples/flowgrpo_trainer/run_qwen_image_ocr_lora_rollout_corr.sh.
Config reference
All config keys live under algorithm.rollout_correction and mirror the upstream
verl schema exactly. See the upstream documentation for the full reference:
Config keys & usage: Rollout Correction
Mathematical formulation: Rollout Correction Math
The only diffusion-specific notes are in the tuning guide below.
Logged metrics
| Metric | Meaning |
|---|---|
| | Post-clip IS weight stats. |
| | Effective sample size of IS weights. |
| | Token-level fraction of steps rejected by RS. |
| | Sequence-level fraction rejected by RS. |
| | KL(π_rollout ‖ π_old): direct off-policy drift estimator. |
| | K3 KL estimator (more stable for small KL). |
| | Mean per-sequence log-PPL difference (rollout − old). |
| | χ² divergence at token- and sequence-level. |
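To make the drift metrics concrete, here is a plain-Python sketch of how such statistics can be estimated from per-sample log-probs drawn under the rollout policy. It is not the logger's implementation: the function name, the normalised ESS in (0, 1], the use of unclipped ratios, and the χ²(π_old ‖ π_rollout) direction are all assumptions.

```python
import math
from statistics import fmean


def drift_metrics(old_logps, rollout_logps):
    """Estimate drift statistics from samples drawn under the rollout policy."""
    # log r = log(pi_old / pi_rollout) for each sample
    log_r = [old - rollout for old, rollout in zip(old_logps, rollout_logps)]
    r = [math.exp(x) for x in log_r]
    n = len(r)
    return {
        # k1: unbiased but noisy estimate of KL(pi_rollout || pi_old)
        "kl_k1": fmean(-x for x in log_r),
        # k3: (r - 1) - log r is non-negative per sample; lower variance for small KL
        "kl_k3": fmean((ri - 1.0) - x for ri, x in zip(r, log_r)),
        # chi^2(pi_old || pi_rollout) = E_rollout[r^2] - 1
        "chi2": fmean(ri * ri for ri in r) - 1.0,
        # normalised effective sample size of the (unclipped) ratios
        "ess": sum(r) ** 2 / (n * sum(ri * ri for ri in r)),
    }


print(drift_metrics([0.1, -0.1], [0.0, 0.0]))
```

When the two stacks agree exactly, every divergence is 0 and the ESS is 1; any drift pushes the divergences up and the ESS below 1.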
In bypass mode, metrics are computed per SDE step inside diffusion_loss
and appear under actor/rollout_corr/*. In decoupled mode, they are
emitted once per global batch under rollout_corr/*.
If rollout_corr/rollout_rs_seq_masked_fraction is consistently above ~5 %, the
rollout backend is drifting too far — tighten the RS band or fall back to
bypass_mode=False.
Gradient dilution note: RS rejection zeroes the per-element loss for rejected samples but does not remove them from the mean() denominator. At high sustained rejection rates, the effective gradient magnitude therefore shrinks by the factor kept / total. Monitor rollout_corr/rollout_rs_seq_masked_fraction and widen the RS band if it exceeds ~10 % over several steps.
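The dilution factor follows directly from the arithmetic of a masked mean. A small sketch with made-up loss values (the function name is hypothetical):

```python
def dilution_breakdown(per_element_loss, keep_mask):
    """Compare the masked mean (rejected samples stay in the denominator)
    with the mean over kept samples only."""
    masked = [loss * keep for loss, keep in zip(per_element_loss, keep_mask)]
    kept = sum(keep_mask)
    total = len(keep_mask)
    masked_mean = sum(masked) / total                # what a plain mean() over the batch gives
    kept_mean = sum(masked) / kept if kept else 0.0  # average over surviving samples only
    return masked_mean, kept_mean, kept / total


masked_mean, kept_mean, factor = dilution_breakdown([0.8, 1.2, 1.0, 0.6], [1, 1, 0, 1])
# One rejection out of four scales the loss (and its gradient) by kept / total = 0.75.
print(masked_mean, kept_mean, factor)
```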
Hyperparameter notes
The defaults (rollout_is_threshold=2.0, loss_type=ppo_clip) transfer well from the LLM setting because:
The helper operates directly on the log-ratio, which is unit-less.
Diffusion log-probs are mean-pooled across latent dimensions, so per-step variance is lower than for per-token LLM log-probs.
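A toy calculation shows why pooling matters for a unit-less threshold such as rollout_is_threshold=2.0. The latent size and per-dimension drift are made-up numbers, and the summed variant is shown only as a contrast, not as something the trainer does:

```python
import math

latent_dims = 4096                         # hypothetical number of latent dimensions
per_dim_log_ratio = [0.01] * latent_dims   # hypothetical small, uniform per-dim drift

summed = sum(per_dim_log_ratio)            # about 40.96: exp(40.96) is astronomically large
mean_pooled = summed / latent_dims         # about 0.01: exp(0.01) is roughly 1.01

print(math.exp(summed), math.exp(mean_pooled))
```

The same tiny per-dimension drift would blow far past a threshold of 2.0 if log-probs were summed, but stays close to a ratio of 1 once mean-pooled.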
Diffusion-specific tuning guide
The SDE window is short (sde_window_size is usually 2), which changes the
statistical behaviour of several RS modes:
| Concern | Recommendation |
|---|---|
| | The LLM default |
| Bypass mode RS | In bypass mode, IS/RS is computed per SDE step with shape |
| Token-level RS ( | With only 2 tokens, token-level statistics have very low power: a single token cannot be rejected in isolation because the per-token stat is averaged from thousands of latent dims. Prefer |
| | The product of 2 per-step ratios. With diffusion’s low per-step variance this is usually well-behaved; the default threshold of 2.0 is generous. |
| First-run diagnostics | Always inspect |
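With sde_window_size=2, the sequence-level quantities reduce to simple combinations of two per-step log-ratios. The sketch below uses a hypothetical function; it assumes sequence-level IS uses the product of per-step ratios, as the table describes, and reads seq_mean_k1 as the exponentiated mean per-step log-ratio, which is an assumption:

```python
import math


def window_ratios(step_log_ratios, is_threshold=2.0):
    """step_log_ratios: per-SDE-step log(pi_old / pi_rollout) for one trajectory."""
    per_step = [math.exp(x) for x in step_log_ratios]   # token-level (per-step) ratios
    seq_ratio = math.exp(sum(step_log_ratios))          # product of per-step ratios
    seq_is_weight = min(seq_ratio, is_threshold)        # truncated sequence-level IS weight
    seq_mean = math.exp(sum(step_log_ratios) / len(step_log_ratios))  # geometric mean
    return per_step, seq_ratio, seq_is_weight, seq_mean


per_step, seq_ratio, seq_w, seq_mean = window_ratios([0.3, 0.2])
print(per_step, seq_ratio, seq_w, seq_mean)
```

Even with per-step ratios of about 1.35 and 1.22, the two-step product (about 1.65) stays below 2.0, consistent with the note that the default threshold is generous for short windows.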
How it plugs in
Bypass entrypoint. apply_bypass_mode_to_diffusion_batch sets old_log_probs := rollout_log_probs at zero cost. The trainer-side decoupled correction is skipped because, with old == rollout, it would be a no-op.
Per-step correction. diffusion_loss reads config.rollout_correction and computes IS/RS per SDE step via compute_rollout_correction_and_rejection_mask. For ppo_clip, only the RS mask is applied; the PPO ratio exp(current − rollout) handles IS.
Decoupled correction. apply_rollout_correction_to_diffusion_batch runs once per global batch, comparing old_log_probs against rollout_log_probs, and stashes a combined rollout_is_weights tensor.
Loss application. flow_grpo/grpo_guard multiply the per-element loss by the (detached) rollout_is_weights. Diffusion has no padding, so rejection is expressed as weight = 0 and no separate mask is needed.
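The data flow above can be mimicked with plain dictionaries and lists. Everything here is hypothetical: bypass_sketch and weighted_mean_loss are not the real apply_bypass_mode_to_diffusion_batch or loss code, and plain floats stand in for tensors. The sketch only shows the two properties described: bypass reuses rollout log-probs as old log-probs, and a rejected sample is simply a zero weight inside the usual mean.

```python
def bypass_sketch(batch: dict) -> dict:
    """Reuse rollout log-probs as old log-probs instead of recomputing them."""
    out = dict(batch)  # shallow copy; the input batch is left untouched
    out["old_log_probs"] = batch["rollout_log_probs"]
    return out


def weighted_mean_loss(per_element_loss, rollout_is_weights):
    """Multiply per-element loss by precomputed weights, then average.

    In the real trainer the weights are detached tensors. A weight of 0.0
    encodes RS rejection; no separate mask is needed without padding.
    """
    weighted = [loss * w for loss, w in zip(per_element_loss, rollout_is_weights)]
    return sum(weighted) / len(weighted)


batch = bypass_sketch({"rollout_log_probs": [-1.0, -2.0]})
loss = weighted_mean_loss([1.0, 3.0], [1.0, 0.0])  # second sample rejected
print(batch, loss)
```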
The config lives on DiffusionActorConfig.rollout_correction (imported from
verl). No dedicated loss registration or engine modifications are required.