How to Add Continuous Batching (Stepwise) Support for a Diffusion Model
Last updated: 06/15/2026.
This guide explains how to extend an existing FlowGRPO (or MixGRPO) model
integration so that it supports continuous-batching (stepwise) rollout
in vllm-omni. You must already have a working full-forward integration
following
integrating_a_diffusion_model.md.
We use the Qwen-Image FlowGRPO stepwise adapter
(verl_omni/experimental/qwen_image_flow_grpo_stepwise/)
as the worked example.
TL;DR
A stepwise adapter needs one file in one experimental package plus one registry entry:
verl_omni/experimental/<model>_<algo>_stepwise/
├── __init__.py # re-exports the stepwise adapter
└── vllm_omni_rollout_adapter.py # subclass of the standard rollout adapter
The adapter is picked up by a _stepwise registry alias that the async server
selects when step_execution=True in the rollout config. Only the rollout side
changes — the training adapter is unchanged.
Mental Model
In standard (full-forward) mode, vllm-omni calls forward() once per request,
which runs the entire SDE diffusion loop internally. In stepwise (continuous-
batching) mode, the engine interleaves one denoising step from each in-flight
request rather than running one request’s full trajectory before starting the
next.
Full-forward mode                      Stepwise (CB) mode
────────────────                       ─────────────────
for each request:                      for each step:
    forward()  ← runs full SDE loop        for each in-flight request:
                                               denoise_step()
                                               step_scheduler()
                                       post_decode()  ← packages trajectory
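The difference in execution order can be sketched in a few lines of plain Python. This is a toy illustration, not vllm-omni code: the function names are hypothetical, and it assumes all requests are admitted at the same time, whereas a real continuous-batching engine admits and retires requests at different steps.

```python
def full_forward_order(requests, num_steps):
    """Each request runs its whole denoising loop before the next one starts."""
    return [(req, step) for req in requests for step in range(num_steps)]


def stepwise_order(requests, num_steps):
    """Every in-flight request advances by one step per engine round."""
    return [(req, step) for step in range(num_steps) for req in requests]


print(full_forward_order(["A", "B"], 2))  # [('A', 0), ('A', 1), ('B', 0), ('B', 1)]
print(stepwise_order(["A", "B"], 2))      # [('A', 0), ('B', 0), ('A', 1), ('B', 1)]
```

Both orders visit the same (request, step) pairs; only the interleaving changes, which is why per-request state must live on the request rather than inside one forward() call.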
The stepwise engine calls four methods in sequence per request lifetime:
| Phase | Method | Purpose |
|---|---|---|
| Setup | prepare_encode | Encode prompt, initialise latents, scheduler, SDE state |
| Per-step | denoise_step | Run transformer forward, return noise prediction |
| Per-step | step_scheduler | One scheduler step with SDE noise + log-prob bookkeeping |
| Finalise | post_decode | Decode final latents, package trajectory into the output |
Your stepwise adapter must override at least prepare_encode, step_scheduler,
and post_decode. denoise_step only needs an override if your model has
non-standard CFG or transformer kwargs.
Prerequisites
- A working full-forward integration with both diffusers_training_adapter.py and vllm_omni_rollout_adapter.py registered under the base algorithm (e.g. flow_grpo).
- Your standard rollout adapter’s forward() / diffuse() must already collect all_latents, all_log_probs, and all_timesteps; the stepwise overrides must replicate this bookkeeping.
Step 1 — Understand the Engine Contract
The stepwise engine (vllm-omni side) manages a pool of DiffusionRequestState
objects. Each denoising step it:
1. Calls denoise_step(input_batch) on the batched latents/timesteps. Your override receives a batched input_batch with .latents, .timesteps, .guidance, .prompt_embeds, etc.
2. Calls step_scheduler(state, noise_pred) per request with the noise prediction sliced back to per-request shape.
The engine feeds noise_pred from step 1 into each request’s step 2, then
re-gathers state.latents for the next round. After the final step it calls
post_decode(state) and ships the DiffusionOutput to the HTTP server.
Your prepare_encode must populate all fields on state that the
per-step methods consume — the engine never calls forward() so there is no
fallback initialisation.
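A minimal sketch of this contract, using a toy pipeline with scalar "latents" in place of tensors. Everything here (ToyPipeline, run_toy_engine, the arithmetic) is hypothetical and only mirrors the call order described above; the real engine and method signatures live in vllm-omni.

```python
from types import SimpleNamespace


class ToyPipeline:
    """Stand-in for a stepwise adapter; real methods operate on tensors."""

    def prepare_encode(self, state):
        # Everything the per-step methods read must be initialised here.
        state.latents = 1.0
        state.step_index = 0
        state.all_latents = []
        return state

    def denoise_step(self, batch_latents):
        # Batched "noise prediction": one value per in-flight request.
        return [0.5 * x for x in batch_latents]

    def step_scheduler(self, state, noise_pred):
        # Per-request update plus trajectory bookkeeping.
        state.latents = state.latents - noise_pred
        state.all_latents.append(state.latents)
        state.step_index += 1
        return state

    def post_decode(self, state):
        return {"all_latents": list(state.all_latents)}


def run_toy_engine(pipeline, num_requests, num_steps):
    states = [pipeline.prepare_encode(SimpleNamespace()) for _ in range(num_requests)]
    for _ in range(num_steps):
        # 1. One batched forward over the gathered latents.
        noise_preds = pipeline.denoise_step([s.latents for s in states])
        # 2. Per-request scheduler step with the sliced prediction.
        for state, noise_pred in zip(states, noise_preds):
            pipeline.step_scheduler(state, noise_pred)
    return [pipeline.post_decode(s) for s in states]


outputs = run_toy_engine(ToyPipeline(), num_requests=2, num_steps=3)
print(outputs[0]["all_latents"])  # [0.5, 0.25, 0.125]
```

If prepare_encode in this toy forgot to set all_latents, the first step_scheduler call would raise AttributeError, which is the same failure mode the paragraph above warns about.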
Step 2 — Scaffold the Experimental Package
Create the package under verl_omni/experimental/:
verl_omni/experimental/<model>_<algo>_stepwise/
├── __init__.py
└── vllm_omni_rollout_adapter.py
The __init__.py re-exports the stepwise class:
from .vllm_omni_rollout_adapter import MyModelPipelineWithLogProbStepwise
__all__ = ["MyModelPipelineWithLogProbStepwise"]
Register the package by importing it from
verl_omni/experimental/__init__.py:
from . import my_model_flow_grpo_stepwise
from .my_model_flow_grpo_stepwise import * # noqa: F401, F403
__all__ += list(my_model_flow_grpo_stepwise.__all__)
Note. verl_omni/__init__.py already imports verl_omni.experimental, so no additional wiring in __init__.py is needed.
Step 3 — Write the Stepwise Adapter
Subclass your standard rollout adapter and register with the _stepwise
algorithm suffix:
from verl_omni.pipelines.model_base import VllmOmniPipelineBase
from verl_omni.pipelines.my_model_flow_grpo.vllm_omni_rollout_adapter import (
    MyModelPipelineWithLogProb,
)
@VllmOmniPipelineBase.register("MyModelPipeline", algorithm="flow_grpo_stepwise")
class MyModelPipelineWithLogProbStepwise(MyModelPipelineWithLogProb):
    ...
The architecture string must match your model_index.json::_class_name exactly
(same as the standard adapter). The algorithm suffix is always
<base_algorithm>_stepwise.
3.1 prepare_encode
This is the most involved override. It must:
1. Accept pre-tokenized prompt_ids from state.prompts[0] (a dict with keys prompt_token_ids, prompt_mask, negative_prompt_ids, negative_prompt_mask). Provide a fallback that tokenizes raw text prompts for the engine’s dummy warm-up run.
2. Encode prompts via encode_prompt(), returning padded (B, L, D) embeddings plus (B, L) mask tensors.
3. Initialise latents (random noise in fp32).
4. Prepare timesteps from num_inference_steps.
5. Deep-copy the scheduler per request so concurrent requests don’t share mutable state.
6. Set RoPE sequence lengths from padded embed width, not from mask.sum(). See the RoPE text length note below.
7. Resolve SDE/log-prob knobs from sampling.extra_args (noise_level, sde_window_size, sde_window_range, sde_type, logprobs).
8. Populate state with every field that step_scheduler / denoise_step / post_decode will read: prompt_embeds, prompt_embeds_mask, negative_prompt_embeds, negative_prompt_embeds_mask, latents, timesteps, step_index, scheduler, do_true_cfg, guidance, img_shapes, txt_seq_lens, negative_txt_seq_lens, sde_window, noise_level, sde_type, logprobs, and empty lists for all_latents, all_log_probs, all_timesteps.
9. Persist the generator on state.sampling.generator so step_scheduler draws from the same RNG stream.
10. For MixGRPO: call _maybe_make_progressive_window() in prepare_encode before delegating via super(). See § 3.5.
The canonical reference is
QwenImagePipelineWithLogProbStepwise.prepare_encode.
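The reason for deep-copying the scheduler per request can be shown with a toy example. ToyScheduler is a hypothetical stand-in with a single mutable counter; real schedulers carry more state, but the sharing hazard is the same.

```python
import copy


class ToyScheduler:
    """Hypothetical scheduler with a mutable step counter."""

    def __init__(self):
        self.step_index = 0

    def step(self):
        self.step_index += 1
        return self.step_index


# Sharing one instance: request B sees the counter advanced by request A.
shared = ToyScheduler()
req_a_shared, req_b_shared = shared, shared
req_a_shared.step()
leaked = req_b_shared.step_index  # 1, although B has not stepped yet

# Deep-copying per request keeps the counters independent.
template = ToyScheduler()
req_a, req_b = copy.deepcopy(template), copy.deepcopy(template)
req_a.step()
isolated = req_b.step_index  # still 0

print(leaked, isolated)  # 1 0
```

Under continuous batching, requests sit at different step indices at the same moment, so any state the scheduler mutates in step() has to be private to each request.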
RoPE text length. In continuous batching, vllm-omni pads prompt embeddings to a shared target_seq_len. If you compute RoPE txt_seq_lens from mask.sum() (valid token count), each request gets a different RoPE length even though embeddings share the same width; tokens beyond position 50 get wrong positional encoding, causing a rollout/training mismatch in ppo_kl. Always use prompt_embeds.shape[1] instead.
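The two ways of deriving the text length can be compared on plain Python lists. This is a sketch only: the helper names are hypothetical, and the nested lists stand in for a (B, L) mask tensor, so len(mask) plays the role of prompt_embeds.shape[1].

```python
def txt_seq_lens_from_mask(masks):
    """Valid-token count per request (the approach the note warns against)."""
    return [sum(mask) for mask in masks]


def txt_seq_lens_from_width(masks):
    """Padded width per request, the analogue of prompt_embeds.shape[1]."""
    return [len(mask) for mask in masks]


# Two prompts padded to a shared width of 6; 1 = real token, 0 = padding.
masks = [
    [1, 1, 1, 1, 0, 0],
    [1, 1, 0, 0, 0, 0],
]

print(txt_seq_lens_from_mask(masks))   # [4, 2]  lengths differ per request
print(txt_seq_lens_from_width(masks))  # [6, 6]  every request uses the padded width
```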
3.2 step_scheduler
Overrides the default (vanilla scheduler step) to mirror the per-iteration
body of diffuse():
1. Respect sde_window: noise_level is 0.0 outside the window, and the configured noise_level inside.
2. Log the initial latent when entering the SDE window.
3. Call scheduler.step() with noise_pred cast to fp32.
4. Store trajectory in fp32: new_latents.float() goes into state.all_latents; log-prob and timestep go into their respective lists.
5. Keep state.latents in fp32, NOT model dtype. Under continuous batching the engine gathers latents across all in-flight requests; a freshly-added request has fp32 latents while a stepped request would have bf16 latents, producing a “Mixed dtypes” error. Keeping fp32 throughout makes the batch dtype consistent.
6. Advance state.step_index.
The canonical reference is
QwenImagePipelineWithLogProbStepwise.step_scheduler.
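The sde_window rule described above reduces to a small selection function. The sketch below is an assumption-laden illustration: it treats sde_window as a half-open (start, end) pair of step indices, which may not match the actual representation in the adapter, and effective_noise_level is a hypothetical helper name.

```python
def effective_noise_level(step_index, sde_window, noise_level):
    """Return the configured noise level inside the window and 0.0 outside.

    Assumption: sde_window is a half-open (start, end) pair of step indices.
    """
    start, end = sde_window
    return noise_level if start <= step_index < end else 0.0


# Six denoising steps with an SDE window covering steps 2 and 3.
levels = [effective_noise_level(i, (2, 4), 0.7) for i in range(6)]
print(levels)  # [0.0, 0.0, 0.7, 0.7, 0.0, 0.0]
```

Steps outside the window take a noise level of 0.0, so they add no SDE noise; only the steps inside the window use the configured value.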
3.3 denoise_step (usually optional)
Override only if your model needs non-standard transformer kwargs or CFG
logic that differs from the parent class. The default QwenImagePipeline
implementation works for most models.
If you do override:
1. Cast input_batch.latents to the transformer’s weight dtype for the forward pass (bf16).
2. Build positive/negative kwargs via your model’s _build_denoise_kwargs.
3. Call predict_noise_maybe_with_cfg for CFG combination.
4. Return noise_pred.float(); step_scheduler expects fp32.
3.4 post_decode
Packages the trajectory collected during step_scheduler:
1. Call super().post_decode(state) for VAE decoding.
2. Stack all_latents, all_log_probs, all_timesteps from the state lists.
3. Populate output.custom_output with the same keys that forward() produces: all_latents, all_log_probs, all_timesteps, prompt_embeds, prompt_embeds_mask, negative_prompt_embeds, negative_prompt_embeds_mask.
4. Move tensors to CPU so the receiving HTTP server process does not initialise CUDA context on GPU 0.
Why custom_output matters. In stepwise mode the engine ships the DiffusionOutput across an inter-process MessageQueue. Downstream consumers (vllm_omni_async_server.generate → embeds_padding_2_no_padding) read these field names verbatim. If any field is None, it becomes a non-tensor LinkedList in the training TensorDict, breaking mask.shape[0].
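A cheap guard against missing or None fields is to validate the dict before it leaves post_decode. The helper below is a hypothetical sketch, not part of verl_omni; it assumes custom_output is a plain dict and uses only the key names listed above.

```python
REQUIRED_CUSTOM_OUTPUT_KEYS = (
    "all_latents",
    "all_log_probs",
    "all_timesteps",
    "prompt_embeds",
    "prompt_embeds_mask",
    "negative_prompt_embeds",
    "negative_prompt_embeds_mask",
)


def missing_custom_output_keys(custom_output):
    """Return required keys that are absent or set to None."""
    return [k for k in REQUIRED_CUSTOM_OUTPUT_KEYS if custom_output.get(k) is None]


# Placeholder values stand in for CPU tensors in this toy example.
complete = {key: [0.0] for key in REQUIRED_CUSTOM_OUTPUT_KEYS}
broken = dict(complete, negative_prompt_embeds=None)
del broken["all_timesteps"]

print(missing_custom_output_keys(complete))  # []
print(missing_custom_output_keys(broken))    # ['all_timesteps', 'negative_prompt_embeds']
```

Raising on a non-empty result inside post_decode would surface the problem in the rollout process instead of later in the trainer.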
3.5 MixGRPO: Window Positioning
MixGRPO requires all rollouts in a batch to share one SDE window for correct
advantage estimation. In full-forward mode _maybe_make_progressive_window()
runs inside forward(). In stepwise mode forward() is never called, so the
window is never set.
The MixGRPO stepwise adapter must call _maybe_make_progressive_window() in
prepare_encode before delegating via super(). The canonical pattern
uses multiple inheritance:
@VllmOmniPipelineBase.register("QwenImagePipeline", algorithm="mix_grpo_stepwise")
class QwenImageMixGRPOPipelineWithLogProbStepwise(
    QwenImageMixGRPOPipelineWithLogProb,   # window logic
    QwenImagePipelineWithLogProbStepwise,  # stepwise overrides
):
    def prepare_encode(self, state, **kwargs):
        # Fix the SDE window before stepwise prepare_encode draws it
        if state.sampling is not None:
            if state.sampling.extra_args is None:
                state.sampling.extra_args = {}
            self._maybe_make_progressive_window(state.sampling.extra_args, kwargs)
        return super().prepare_encode(state, **kwargs)
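How super() resolves under this multiple-inheritance layout can be checked with toy classes. The sketch assumes the MixGRPO full-forward adapter does not itself define prepare_encode, so the call falls through to the stepwise parent; all class names below are hypothetical stand-ins.

```python
class Base:
    def prepare_encode(self, calls):
        calls.append("Base.prepare_encode")
        return calls


class MixGRPO(Base):
    # Stands in for the full-forward MixGRPO adapter (window logic only).
    def _maybe_make_progressive_window(self, calls):
        calls.append("window fixed")


class Stepwise(Base):
    # Stands in for the FlowGRPO stepwise adapter.
    def prepare_encode(self, calls):
        calls.append("Stepwise.prepare_encode")
        return super().prepare_encode(calls)


class MixGRPOStepwise(MixGRPO, Stepwise):
    def prepare_encode(self, calls):
        # Fix the window first, then delegate along the MRO.
        self._maybe_make_progressive_window(calls)
        return super().prepare_encode(calls)


mro = [cls.__name__ for cls in MixGRPOStepwise.__mro__]
calls = MixGRPOStepwise().prepare_encode([])
print(mro)    # ['MixGRPOStepwise', 'MixGRPO', 'Stepwise', 'Base', 'object']
print(calls)  # ['window fixed', 'Stepwise.prepare_encode', 'Base.prepare_encode']
```

Listing the MixGRPO class first gives its window helper priority, while prepare_encode still resolves to the stepwise implementation because MixGRPO does not override it in this sketch.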
Step 4 — Enable Stepwise Mode
No code changes in the launcher. At runtime:
python3 -m verl_omni.trainer.main_diffusion \
    actor_rollout_ref.rollout.step_execution=true \
    ...
The wiring:
1. DiffusionRolloutConfig.step_execution (default False) is read by vLLMOmniHttpServer.run_server, which sets engine_args["step_execution"] = True.
2. DiffusionRolloutConfig.resolve_algorithm() checks whether a <algorithm>_stepwise class is registered for the architecture. If so, it updates model_config.algorithm in-place to the stepwise variant.
3. VllmOmniPipelineBase.get_pipeline_path(architecture, algorithm) resolves to your stepwise adapter’s dotted path, which is passed to the vllm-omni engine as custom_pipeline_args.pipeline_class.
If no _stepwise class is registered, resolve_algorithm is a no-op and the
engine falls back to the standard full-forward pipeline (stepwise mode is not
available for that model/algorithm pair).
Step 5 — Verify Parity
Before claiming stepwise support is complete, verify that step_execution=True
produces trajectories identical to step_execution=False:
- fp32 latent storage. Confirm all_latents are fp32, not bf16. Expect ratio_mean ≈ 1.0 at step 1.
- Log-prob parity. all_log_probs should match between the two modes (within numerical tolerance).
- Prompt embeddings. prompt_embeds / prompt_embeds_mask shapes and values must be identical.
- MixGRPO window. All rollouts in a batch must share the same SDE window.
- Tokenizer fallback. The warm-up (dummy) path must produce the same tokenization as the diffusers pipeline.
Add a smoke test under tests/special_e2e/ that runs with
step_execution=true and asserts exit code 0.
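A log-prob parity check can be as simple as comparing the two trajectories element by element. The sketch below works on plain lists of floats and uses a hypothetical tolerance of 1e-5; the right tolerance for your model and dtype is an assumption you will need to calibrate.

```python
def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two equal-length sequences."""
    if len(a) != len(b):
        raise ValueError(f"length mismatch: {len(a)} vs {len(b)}")
    return max((abs(x - y) for x, y in zip(a, b)), default=0.0)


def log_probs_match(full_forward, stepwise, atol=1e-5):
    """True if the two log-prob trajectories agree within atol (assumed tolerance)."""
    return max_abs_diff(full_forward, stepwise) <= atol


reference = [-1.0, -2.0, -3.0]
print(log_probs_match(reference, [-1.0, -2.0, -3.0]))  # True
print(log_probs_match(reference, [-1.0, -2.1, -3.0]))  # False
```

Run it on all_log_probs captured from one prompt and seed under step_execution=False and step_execution=True; a length mismatch raises immediately, which also catches trajectories with a different number of recorded steps.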
When to Refactor Instead of Duplicating
The stepwise adapter pattern currently requires 400+ lines of code that
mostly duplicate the standard adapter’s diffuse() / forward() logic.
This is a known maintenance burden. The duplication will shrink once
vllm-omni provides native continuous-batching hooks (e.g. --skip-tokenizer-init,
prompt_token_ids support). Until then:
- Keep the stepwise adapter as thin as possible: delegate to shared helpers in the parent class whenever feasible.
- If you find yourself copying more than a few methods, factor the shared logic into a common.py in the standard pipeline package and import it from both adapters.
- File a feature request with vllm-omni for any missing CB hooks you need, and link it in a TODO comment.
Relationship to Other Guides
- integrating_a_diffusion_model.md. Prerequisite: the standard full-forward integration.
- integrating_a_new_policy_gradient_algorithm_for_diffusion_model.md. Relevant if you need a custom stepwise adapter for a policy-gradient method other than FlowGRPO/MixGRPO.
- common_pitfalls.md. Known issues specific to stepwise mode (fp32 latent storage, RoPE mismatch, MixGRPO window bypass, tokenizer fallback, device placement).