Common Pitfalls
Last updated: 05/22/2026.
Float32 precision loss in stored rollout latents
Symptom
Training metrics show a systematic negative bias at step 1 (before any weight update):
- actor/ratio_mean is consistently below 1.0 (e.g. 0.99996)
- actor/ppo_kl and actor/pg_clipfrac are inflated at step 1
- actor/pg_clipfrac_higher is zero (all clipping is on the lower side)

The effect is most visible with rollout correction (bypass_mode=True), but the same precision loss also degrades stored trajectories in standard training.
Root cause
FlowMatchSDEDiscreteScheduler.step() computes log_prob in float32
using the fp32 prev_sample, then casts prev_sample back to
model_output.dtype (bfloat16) before returning. The stored latents
lose precision, creating a mismatch with the log-prob computation.
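The mismatch can be reproduced without any model. The following is a minimal sketch in pure Python, not the scheduler's actual code: `to_bfloat16`, `mean_log_prob`, and `simulate` are hypothetical helpers, the Gaussian transition with a fixed `std` is a toy stand-in for the SDE step, and averaging the log-density over elements is an assumption about how log_prob is reduced. It records a log-prob from the float32 sample, rounds the stored sample to bfloat16, and recomputes the log-prob the way training would.

```python
import math
import random
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bfloat16(x: float) -> float:
    """Round to bfloat16 (round-to-nearest-even); NaN/inf are not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # rounding bias for nearest-even
    bits &= 0xFFFF0000                   # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def mean_log_prob(sample, mean, std):
    """Gaussian log-density under N(mean, std^2), averaged over elements."""
    const = -math.log(std) - 0.5 * math.log(2.0 * math.pi)
    total = sum(
        -((x - m) ** 2) / (2.0 * std * std) + const
        for x, m in zip(sample, mean)
    )
    return total / len(sample)


def simulate(num_elements=100_000, std=0.02, seed=0):
    """Return log(ratio) between the replayed and the recorded log-prob."""
    rng = random.Random(seed)
    mean = [to_float32(rng.gauss(0.0, 1.0)) for _ in range(num_elements)]
    prev_sample = [to_float32(m + std * rng.gauss(0.0, 1.0)) for m in mean]
    stored_bf16 = [to_bfloat16(x) for x in prev_sample]  # the lossy cast

    rollout_log_prob = mean_log_prob(prev_sample, mean, std)  # recorded at rollout
    replay_log_prob = mean_log_prob(stored_bf16, mean, std)   # recomputed in training
    return replay_log_prob - rollout_log_prob


log_ratio = simulate()
print(f"log ratio = {log_ratio:.3e}, ratio = {math.exp(log_ratio):.6f}")
```

The rounding error adds a small positive term to the squared residual on average, so the replayed log-prob comes out lower and the ratio falls below 1.0 before any weight update. The size of the gap depends on the toy noise scale chosen here and is not meant to match the figures in this section.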
Fix
Two changes in the scheduler, one in the rollout adapter. The training adapter is unchanged — it already uses fp32 correctly.
1. Scheduler — step() no longer truncates prev_sample to bfloat16,
and sample_previous_step() asserts model_output is float32 so callers
cannot accidentally pass lower precision.
2. Rollout adapter — latents are cast to the transformer’s native dtype before the forward pass (performance), noise_pred is cast to float32 before the scheduler (precision), and all stored latents are in float32.
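The dtype flow described in the two items above might look roughly like the sketch below. It assumes PyTorch; `toy_sde_step` and `rollout_step` are hypothetical names, the update rule is a toy Gaussian step rather than the real FlowMatchSDEDiscreteScheduler math, and a bfloat16 `torch.nn.Linear` stands in for the transformer. Only the casts and the float32 assertion are the point.

```python
import math

import torch


def toy_sde_step(model_output, sample, dt, std, generator=None):
    """Toy stand-in for a scheduler step; everything stays in float32."""
    # Refuse lower-precision inputs so callers cannot silently lose precision.
    assert model_output.dtype == torch.float32, "pass float32 noise_pred"
    sample = sample.float()
    prev_mean = sample + dt * model_output
    noise = torch.randn(prev_mean.shape, generator=generator, dtype=torch.float32)
    prev_sample = prev_mean + std * noise
    log_prob = (
        -((prev_sample - prev_mean) ** 2) / (2 * std**2)
        - math.log(std)
        - 0.5 * math.log(2 * math.pi)
    )
    # Average over all non-batch dimensions (an assumption of this sketch).
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
    # No cast back to model_output's original dtype: prev_sample stays float32.
    return prev_sample, log_prob


@torch.no_grad()
def rollout_step(transformer, latents_fp32, dt, std, generator=None):
    """One rollout step: native dtype for the forward pass, float32 elsewhere."""
    model_dtype = next(transformer.parameters()).dtype
    noise_pred = transformer(latents_fp32.to(model_dtype))  # fast forward pass
    return toy_sde_step(noise_pred.float(), latents_fp32, dt, std, generator)


transformer = torch.nn.Linear(8, 8).to(torch.bfloat16)  # stand-in model
latents = torch.randn(2, 8)                             # float32 latents
prev, logp = rollout_step(transformer, latents, dt=-0.1, std=0.05)
print(prev.dtype, logp.dtype)
```

Keeping the stored latents in float32 costs extra memory per trajectory step, but it means the sample that training replays is bit-identical to the one the log-prob was computed from.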
Verification
The fix eliminates the systematic precision-loss bias from the scheduler.
In non-bypass mode (no rollout correction) ratio_mean ≈ 1.0 at step 1.
In bypass mode a ~3×10⁻⁵ KL divergence remains due to the vLLM vs PyTorch
attention kernel difference, which is unavoidable when using different
inference backends.
| Metric | Before fix (bypass) | After fix (bypass) | No bypass |
|---|---|---|---|
| actor/ppo_kl | ~3.6×10⁻⁵ | ~3.3×10⁻⁵ | ~1×10⁻⁶ |
| actor/pg_clipfrac | ~12% | ~9% | ~1% |