Trainer Interface
Last updated: Jul 03, 2026 (API docstrings are auto-generated).
VeRL-Omni provides Ray-based trainers for diffusion / multimodal RL.
TaskRunner builds worker mappings and dispatches to a trainer subclass selected by algorithm.trainer_type:
- policy_gradient → PolicyGradientRayTrainer (FlowGRPO, MixGRPO, DanceGRPO, GRPO-Guard; multi-timestep reverse-process policy gradient)
- direct_preference → DirectPreferenceRayTrainer (DPO, DiffusionNFT, AWM; single forward-timestep preference updates)
Both subclasses inherit shared worker initialization from BaseRayDiffusionTrainer.
Rollout and reward engines are initialized only when algorithm.sample_source=online.
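The dispatch rules above can be sketched in plain Python. This is a hypothetical illustration, not TaskRunner's code: TRAINER_BY_TYPE, select_trainer, and engines_to_init are made-up names, and the trainer classes are represented only by their names. Following the text, actor/ref workers are always built, while rollout and reward engines are added only for sample_source=online; any other value (the "offline" used below is an assumed example) skips them.

```python
# Hypothetical sketch of trainer dispatch and engine gating; not verl_omni code.
TRAINER_BY_TYPE = {
    "policy_gradient": "PolicyGradientRayTrainer",
    "direct_preference": "DirectPreferenceRayTrainer",
}


def select_trainer(trainer_type: str) -> str:
    """Return the trainer class name selected by algorithm.trainer_type."""
    try:
        return TRAINER_BY_TYPE[trainer_type]
    except KeyError:
        raise ValueError(
            f"unknown algorithm.trainer_type={trainer_type!r}; "
            f"expected one of {sorted(TRAINER_BY_TYPE)}"
        ) from None


def engines_to_init(sample_source: str) -> list[str]:
    """Actor/ref workers are always built; rollout/reward only for online sampling."""
    engines = ["actor", "ref"]
    if sample_source == "online":
        engines += ["rollout", "reward"]
    return engines


print(select_trainer("policy_gradient"))
print(engines_to_init("offline"))  # assumed non-online value
```

Keeping the mapping in a dict makes the set of accepted trainer_type values explicit and produces a clear error for typos.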
- BaseRayDiffusionTrainer: Common Ray trainer infrastructure for diffusion training.
- PolicyGradientRayTrainer: Policy-gradient diffusion trainer for FlowGRPO, MixGRPO, DanceGRPO, GRPO-Guard, etc.
- DirectPreferenceRayTrainer: Direct-preference diffusion trainer for DPO, DiffusionNFT, AWM, etc.
- TaskRunner: Ray remote class for executing distributed diffusion training tasks.
Base Ray Diffusion Trainer
BaseRayDiffusionTrainer owns colocated actor/ref worker setup, dataloaders, validation helpers, and checkpointing. init_workers always builds actor/ref workers; rollout and reward engines are added only when algorithm.sample_source=online.
- class verl_omni.trainer.diffusion.ray_diffusion_trainer.BaseRayDiffusionTrainer(config, tokenizer, role_worker_mapping: dict[verl.trainer.ppo.utils.Role, type[verl.single_controller.base.worker.Worker]], resource_pool_manager: verl.single_controller.ray.base.ResourcePoolManager, ray_worker_group_cls: type[verl.single_controller.ray.base.RayWorkerGroup] = verl.single_controller.ray.base.RayWorkerGroup, processor=None, train_dataset: torch.utils.data.dataset.Dataset | None = None, val_dataset: torch.utils.data.dataset.Dataset | None = None, collate_fn=None, train_sampler: torch.utils.data.sampler.Sampler | None = None, device_name=None)
Common Ray trainer infrastructure for diffusion training.
Paradigm-specific trainers own the training loop while sharing worker initialization, validation, checkpointing, and logging behavior.
- __init__(config, tokenizer, role_worker_mapping: dict[verl.trainer.ppo.utils.Role, type[verl.single_controller.base.worker.Worker]], resource_pool_manager: verl.single_controller.ray.base.ResourcePoolManager, ray_worker_group_cls: type[verl.single_controller.ray.base.RayWorkerGroup] = verl.single_controller.ray.base.RayWorkerGroup, processor=None, train_dataset: torch.utils.data.dataset.Dataset | None = None, val_dataset: torch.utils.data.dataset.Dataset | None = None, collate_fn=None, train_sampler: torch.utils.data.sampler.Sampler | None = None, device_name=None)
Initialize the distributed diffusion RL trainer with the Ray backend. Note that this trainer runs in the driver process on a single CPU/GPU node.
- Parameters:
config – Configuration object containing training parameters.
tokenizer – Tokenizer used for encoding and decoding text.
role_worker_mapping (dict[Role, WorkerType]) – Mapping from roles to worker classes.
resource_pool_manager (ResourcePoolManager) – Manager for Ray resource pools.
ray_worker_group_cls (RayWorkerGroup, optional) – Class for Ray worker groups. Defaults to RayWorkerGroup.
processor – Optional data processor, used for multimodal data. Defaults to None.
train_dataset (Optional[Dataset], optional) – Training dataset. Defaults to None.
val_dataset (Optional[Dataset], optional) – Validation dataset. Defaults to None.
collate_fn – Function to collate data samples into batches.
train_sampler (Optional[Sampler], optional) – Sampler for the training dataset. Defaults to None.
device_name (str, optional) – Device name for training (e.g., “cuda”, “cpu”). Defaults to None.
Policy Gradient Ray Trainer
PolicyGradientRayTrainer implements the online training loop for FlowGRPO-style algorithms: rollout generation, reward scoring, advantage estimation over denoising timesteps, and actor updates.
- class verl_omni.trainer.diffusion.ray_diffusion_trainer.PolicyGradientRayTrainer(config, tokenizer, role_worker_mapping: dict[verl.trainer.ppo.utils.Role, type[verl.single_controller.base.worker.Worker]], resource_pool_manager: verl.single_controller.ray.base.ResourcePoolManager, ray_worker_group_cls: type[verl.single_controller.ray.base.RayWorkerGroup] = verl.single_controller.ray.base.RayWorkerGroup, processor=None, train_dataset: torch.utils.data.dataset.Dataset | None = None, val_dataset: torch.utils.data.dataset.Dataset | None = None, collate_fn=None, train_sampler: torch.utils.data.sampler.Sampler | None = None, device_name=None)
Policy-gradient diffusion trainer for FlowGRPO, MixGRPO, DanceGRPO, GRPO-Guard, etc.
Direct Preference Ray Trainer
DirectPreferenceRayTrainer is the extension point for direct-preference algorithms (DPO, DiffusionNFT, AWM) that train with single forward-timestep updates rather than a full multi-step SDE trajectory. The fit implementation is not yet available in-tree.
- class verl_omni.trainer.diffusion.ray_diffusion_trainer.DirectPreferenceRayTrainer(config, *args, **kwargs)
Direct-preference diffusion trainer for DPO, DiffusionNFT, AWM, etc.
- verl_omni.trainer.diffusion.ray_diffusion_trainer.compute_advantage(data: DataProto, adv_estimator: str, norm_adv_by_std_in_grpo: bool = True, global_std: bool = True, config: DiffusionAlgoConfig | None = None) → DataProto
Compute advantage estimates for diffusion policy optimization.
This function computes advantage estimates for diffusion models using the registered advantage estimator (e.g., Flow-GRPO). The advantage estimates are used to guide policy optimization across denoising timesteps.
- Parameters:
data (DataProto) – The data containing batched diffusion model outputs and inputs.
adv_estimator (str) – Name of the advantage estimator to use (e.g., Flow-GRPO).
norm_adv_by_std_in_grpo (bool, optional) – Whether to normalize advantages by standard deviation in GRPO. Defaults to True.
global_std (bool, optional) – Whether to use global standard deviation for normalization. Defaults to True.
config (DiffusionAlgoConfig, optional) – Configuration object for algorithm settings. Defaults to None.
- Returns:
The updated data with computed advantages and returns in its batch.
- Return type:
DataProto
Entry Point
Entrypoint for diffusion model RL training.
- class verl_omni.trainer.main_diffusion.TaskRunner
Ray remote class for executing distributed diffusion training tasks.
This class encapsulates the main training logic and runs as a Ray remote actor to enable distributed execution across multiple nodes and GPUs.
- role_worker_mapping
Dictionary mapping Role enums to Ray remote worker classes
- mapping
Dictionary mapping Role enums to resource pool IDs for GPU allocation
- add_actor_rollout_worker(config)
Add actor (and optional rollout/ref) workers using the unified model engine.
- add_ref_policy_worker(config, ref_policy_cls)
Add a reference policy worker if KL loss or KL reward is used.
- verl_omni.trainer.main_diffusion.main(config)
Main entry point for diffusion model training with Hydra configuration management.
- Parameters:
config – Hydra configuration dictionary containing training parameters.
- verl_omni.trainer.main_diffusion.run_diffusion(config, task_runner_class=None) → None
Initialize Ray and run distributed diffusion training.
- Parameters:
config – Training configuration object containing all necessary parameters for distributed diffusion training including Ray initialization settings, model paths, and training hyperparameters.
task_runner_class – Optional class used in place of TaskRunner, so that recipes can supply a customized TaskRunner. Defaults to None.
Diffusion Algorithms
The verl_omni.trainer.diffusion.diffusion_algos module provides the loss-function and advantage-estimator registries used by the trainer. Custom losses and advantage estimators can be registered via the decorators below.
Diffusion-specific loss functions and KL penalties.
- class verl_omni.trainer.diffusion.diffusion_algos.DiffusionAdvantageEstimator(value)
Advantage estimators specific to diffusion-based training.
- class verl_omni.trainer.diffusion.diffusion_algos.DiffusionLossFn
Abstract base for worker-side diffusion loss functions.
- abstractmethod classmethod compute_loss(**kwargs: Any) → tuple[Tensor, dict[str, Any]]
Compute the pure mathematical loss and related metrics.
Subclasses define concrete tensor arguments (e.g. old_log_prob, log_prob, advantages) in their implementation.
- static prepare_actor_batch(batch: DataProto, reward_tensor: Tensor, config: Any) → DataProto
Prepare rollout outputs for the actor update when the trainer has not already done so.
Reverse-process policy-gradient losses such as FlowGRPO can keep the batch unchanged because their trainer path has already added old_log_probs and advantages. Offline DPO can also keep the batch unchanged because offline preference data plus reference predictions provide the loss inputs directly. Forward-process online losses such as DiffusionNFT, as well as online DPO, override this hook to turn final-latent rollouts and rewards into loss-specific actor tensors.
- class verl_omni.trainer.diffusion.diffusion_algos.DiffusionLossResult(loss: Tensor, metrics: dict[str, Any], add_loss_metric: bool = False)
Output from a batch-aware diffusion loss function.
- class verl_omni.trainer.diffusion.diffusion_algos.FlowGRPOLoss
Flow-GRPO clipped policy objective.
- classmethod compute_loss(*, old_log_prob: Tensor, log_prob: Tensor, advantages: Tensor, config: DiffusionActorConfig, rollout_is_weights: Tensor | None = None) → tuple[Tensor, dict[str, Any]]
Compute the clipped policy objective and related metrics for FlowGRPO.
Adapted from https://github.com/yifan123/flow_grpo/blob/main/scripts/train_sd3_fast.py#L885
- Parameters:
old_log_prob (torch.Tensor) – Log-probabilities of actions under the old policy, shape (batch_size,).
log_prob (torch.Tensor) – Log-probabilities of actions under the current policy, shape (batch_size,).
advantages (torch.Tensor) – Advantage estimates for each action, shape (batch_size,).
config (verl_omni.workers.config.DiffusionActorConfig) – Config for the actor.
rollout_is_weights (Optional[torch.Tensor]) – Optional Rollout Correction multiplier (same shape as log_prob) combining IS weights and RS rejection (rejected samples have weight 0). When provided, the per-element policy loss is multiplied by these (detached) weights before the mean reduction.
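A minimal NumPy sketch of the clipped objective, modeled on the PPO-style clipping in the Flow-GRPO reference script linked above. It is an illustration under stated assumptions, not FlowGRPOLoss itself: flow_grpo_clipped_loss is a hypothetical name, clip_ratio and adv_clip_max are passed in directly instead of being read from DiffusionActorConfig, the reduction is a plain mean, and the clipfrac metric name is made up.

```python
import numpy as np


def flow_grpo_clipped_loss(old_log_prob, log_prob, advantages, *,
                           clip_ratio, adv_clip_max, rollout_is_weights=None):
    """Hypothetical NumPy sketch of a Flow-GRPO-style clipped policy loss."""
    old_log_prob = np.asarray(old_log_prob, dtype=np.float64)
    log_prob = np.asarray(log_prob, dtype=np.float64)
    # Clamp advantages to [-adv_clip_max, adv_clip_max].
    adv = np.clip(np.asarray(advantages, dtype=np.float64), -adv_clip_max, adv_clip_max)
    # Importance ratio between the current and the old (rollout) policy.
    ratio = np.exp(log_prob - old_log_prob)
    unclipped = -adv * ratio
    clipped = -adv * np.clip(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio)
    # Pessimistic choice per element, as in PPO.
    per_elem = np.maximum(unclipped, clipped)
    if rollout_is_weights is not None:
        # Weights act as constants here; rejected samples carry weight 0.
        per_elem = per_elem * np.asarray(rollout_is_weights, dtype=np.float64)
    # Hypothetical metric: fraction of elements whose ratio left the clip range.
    clipfrac = float(np.mean(np.abs(ratio - 1.0) > clip_ratio))
    return float(per_elem.mean()), {"clipfrac": clipfrac}


loss, metrics = flow_grpo_clipped_loss(
    [0.0], [np.log(2.0)], [1.0], clip_ratio=0.2, adv_clip_max=5.0
)
print(loss, metrics)
```

Here the ratio is 2, so with a positive advantage the clipped term (ratio capped at 1.2) is the pessimistic one and bounds how far a single update can push the policy.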
- class verl_omni.trainer.diffusion.diffusion_algos.GRPOGuardLoss
GRPO-Guard clipped policy objective with reverse-SDE mean drift.
- classmethod compute_loss(*, old_log_prob: Tensor, log_prob: Tensor, advantages: Tensor, config: DiffusionActorConfig, old_prev_sample_mean: Tensor, prev_sample_mean: Tensor, std_dev_t: Tensor, sqrt_dt: Tensor, rollout_is_weights: Tensor | None = None) → tuple[Tensor, dict[str, Any]]
Compute the GRPO-Guard policy objective.
GRPO-Guard (https://arxiv.org/abs/2510.22319) augments the standard Flow-GRPO importance ratio with a “ratio-mean bias” term that explicitly penalises drift in the reverse-SDE proposal mean of the current policy relative to the rollout policy. The mean drift is then projected onto the same scale as log_prob - old_log_prob via the per-step diffusion coefficient sqrt_dt * sigma_t, and the final policy loss is rescaled by 1 / sqrt_dt**2 so that gradients have a consistent magnitude across timesteps.
- Parameters:
old_log_prob (torch.Tensor) – Log-probabilities under the old policy, shape (B,).
log_prob (torch.Tensor) – Log-probabilities under the current policy, shape (B,).
advantages (torch.Tensor) – Advantage estimates, shape (B,).
config – Actor configuration; diffusion_loss.clip_ratio and diffusion_loss.adv_clip_max are read from it.
old_prev_sample_mean (torch.Tensor) – Reverse-SDE mean from the rollout policy, shape (B, ...).
prev_sample_mean (torch.Tensor) – Reverse-SDE mean from the current policy, shape (B, ...).
std_dev_t (torch.Tensor) – Per-step SDE standard deviation, shape (B, 1, 1, ...) or scalar.
sqrt_dt (torch.Tensor) – sqrt(-dt) for the current denoising step, shape (B,) or scalar.
rollout_is_weights (Optional[torch.Tensor]) – Optional Rollout Correction multiplier (same shape as log_prob) combining IS weights and RS rejection (rejected samples have weight 0). When provided, the per-element policy loss is multiplied by these (detached) weights before the mean reduction.
- class verl_omni.trainer.diffusion.diffusion_algos.KLLoss
KL divergence between current and reference reverse-SDE means.
- classmethod compute_loss(*, prev_sample_mean: Tensor, ref_prev_sample_mean: Tensor, std_dev_t: Tensor) → tuple[Tensor, dict[str, Any]]
Compute KL divergence given previous sample mean and reference previous sample mean (for images or videos).
- Parameters:
prev_sample_mean – (torch.Tensor) shape is (bs, s, c)
ref_prev_sample_mean – (torch.Tensor) shape is (bs, s, c)
std_dev_t – (torch.Tensor) shape is (bs, 1, 1)
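For two Gaussians that share the same isotropic standard deviation σ_t, the KL divergence reduces to the squared mean difference scaled by 1/(2σ_t²). The NumPy sketch below assumes the squared difference is averaged (not summed) over the non-batch dimensions before scaling, then averaged over the batch; gaussian_mean_kl is a hypothetical name, and KLLoss's exact reduction may differ.

```python
import numpy as np


def gaussian_mean_kl(prev_sample_mean, ref_prev_sample_mean, std_dev_t):
    """Per-element KL(N(mu, s^2 I) || N(mu_ref, s^2 I)) = (mu - mu_ref)^2 / (2 s^2)."""
    diff_sq = (np.asarray(prev_sample_mean, dtype=np.float64)
               - np.asarray(ref_prev_sample_mean, dtype=np.float64)) ** 2
    # Average over non-batch dims (s, c), keeping dims so a (bs, 1, 1) std broadcasts.
    per_sample = diff_sq.mean(axis=tuple(range(1, diff_sq.ndim)), keepdims=True)
    kl = per_sample / (2.0 * np.asarray(std_dev_t, dtype=np.float64) ** 2)
    # Average over the batch.
    return float(kl.mean())


mu = np.zeros((2, 3, 4))       # (bs, s, c)
mu_ref = np.ones((2, 3, 4))
std = np.ones((2, 1, 1))       # (bs, 1, 1)
print(gaussian_mean_kl(mu, mu_ref, std))
```

Because the variance is shared, the log-determinant and trace terms of the general Gaussian KL cancel, which is why only the mean difference appears.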
- verl_omni.trainer.diffusion.diffusion_algos.compute_flow_grpo_outcome_advantage(sample_level_rewards: Tensor, index: ndarray, epsilon: float = 0.0001, norm_adv_by_std_in_grpo: bool = True, global_std: bool = True, config: DictConfig | None = None) → tuple[Tensor, Tensor]
Compute advantage for GRPO, operating only on Outcome reward (with only one scalar reward for each response).
- Parameters:
sample_level_rewards – (torch.Tensor) shape is (bs, response_length)
index – (np.ndarray) index array for grouping
epsilon – (float) small value to avoid division by zero
norm_adv_by_std_in_grpo – (bool) whether to scale the GRPO advantage
global_std – (bool) whether to use global std for advantage normalization
config – (Optional[DictConfig]) algorithm configuration object
Note
If norm_adv_by_std_in_grpo is True, the advantage is scaled by the std, as in the original GRPO. If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783).
- Returns:
advantages – (torch.Tensor) shape is (bs, response_length)
returns – (torch.Tensor) shape is (bs, response_length)
- Return type:
tuple[Tensor, Tensor]
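The group-relative normalization described above can be sketched with NumPy. This is a hypothetical illustration, not compute_flow_grpo_outcome_advantage: it assumes one scalar reward per sample (a 1-D array rather than a (bs, response_length) tensor), broadcasts each advantage across num_timesteps, uses NumPy's population std, and returns the same values as both advantages and returns. Singleton groups simply get a zero advantage here; the library may treat them differently.

```python
import numpy as np


def grpo_outcome_advantage(rewards, index, num_timesteps, epsilon=1e-4,
                           norm_adv_by_std_in_grpo=True, global_std=True):
    """Hypothetical sketch: group-relative advantages from one scalar reward per sample."""
    rewards = np.asarray(rewards, dtype=np.float64)
    index = np.asarray(index)
    adv = np.empty_like(rewards)
    batch_std = rewards.std()  # population std over the whole batch
    for group_id in np.unique(index):
        mask = index == group_id
        group = rewards[mask]
        centered = group - group.mean()  # subtract the per-prompt mean
        if norm_adv_by_std_in_grpo:
            # GRPO scales by std; Dr.GRPO (norm_adv_by_std_in_grpo=False) skips this.
            std = batch_std if global_std else group.std()
            centered = centered / (std + epsilon)
        adv[mask] = centered
    # Broadcast each sample's scalar advantage across denoising timesteps.
    adv_bt = np.repeat(adv[:, None], num_timesteps, axis=1)
    return adv_bt, adv_bt.copy()


advantages, returns = grpo_outcome_advantage(
    [1.0, 3.0, 2.0, 2.0], index=[0, 0, 1, 1], num_timesteps=2, global_std=False
)
print(advantages)
```

In the usage example the second prompt group has identical rewards, so its advantages are zero; this is the situation that the critic/rewards/zero_std_ratio metric counts.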
- verl_omni.trainer.diffusion.diffusion_algos.get_diffusion_adv_estimator_fn(name_or_enum)
Get the diffusion advantage estimator function with a given name.
- verl_omni.trainer.diffusion.diffusion_algos.get_diffusion_loss_fn(name: str) → DiffusionLossFn
Get a worker-side diffusion loss function by name.
- verl_omni.trainer.diffusion.diffusion_algos.register_diffusion_adv_est(name_or_enum: str | DiffusionAdvantageEstimator) → Any
Register a diffusion advantage estimator function with the given name.
- Parameters:
name_or_enum – (str) or (DiffusionAdvantageEstimator) The name or enum of the advantage estimator.
- verl_omni.trainer.diffusion.diffusion_algos.register_diffusion_loss(name: str) → Callable[[type[DiffusionLossFn]], type[DiffusionLossFn]]
Register a worker-side diffusion loss function class.
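The decorator-plus-lookup pattern that register_diffusion_loss and get_diffusion_loss_fn expose can be sketched with local stand-ins. Everything below is hypothetical: DiffusionLossFnSketch, register_loss, get_loss, and the "mean_sq_adv" loss are illustrative names, and the duplicate-name and unknown-name handling is an assumption rather than documented library behavior. In verl_omni the equivalent would be decorating a DiffusionLossFn subclass with register_diffusion_loss(name).

```python
from typing import Any, Callable


class DiffusionLossFnSketch:
    """Hypothetical stand-in for the DiffusionLossFn base class."""

    @classmethod
    def compute_loss(cls, **kwargs: Any) -> tuple[float, dict[str, Any]]:
        raise NotImplementedError


_LOSS_REGISTRY: dict[str, type[DiffusionLossFnSketch]] = {}


def register_loss(name: str) -> Callable[[type[DiffusionLossFnSketch]], type[DiffusionLossFnSketch]]:
    """Class decorator that records a loss class under `name`."""
    def decorator(cls: type[DiffusionLossFnSketch]) -> type[DiffusionLossFnSketch]:
        if name in _LOSS_REGISTRY and _LOSS_REGISTRY[name] is not cls:
            raise ValueError(f"loss {name!r} is already registered")
        _LOSS_REGISTRY[name] = cls
        return cls
    return decorator


def get_loss(name: str) -> type[DiffusionLossFnSketch]:
    """Look up a registered loss class by name."""
    if name not in _LOSS_REGISTRY:
        raise ValueError(f"unknown loss {name!r}; registered: {sorted(_LOSS_REGISTRY)}")
    return _LOSS_REGISTRY[name]


@register_loss("mean_sq_adv")  # hypothetical example loss
class MeanSquaredAdvantageLoss(DiffusionLossFnSketch):
    @classmethod
    def compute_loss(cls, *, advantages: list[float]) -> tuple[float, dict[str, Any]]:
        loss = sum(a * a for a in advantages) / len(advantages)
        return loss, {"num_samples": len(advantages)}


loss, metrics = get_loss("mean_sq_adv").compute_loss(advantages=[1.0, -3.0])
print(loss, metrics)
```

Because the decorator returns the class unchanged, the registered class stays importable and testable on its own; only the name lookup goes through the registry.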
Trainer Config
- class verl_omni.trainer.config.algorithm.DiffusionAlgoConfig(_target_: str = '', trainer_type: str = 'policy_gradient', sample_source: str = 'online', adv_estimator: str = 'flow_grpo', norm_adv_by_std_in_grpo: bool = True, global_std: bool = True, old_policy_decay_schedule: str = 'copy', old_policy_decay: float | None = None, old_policy_update_interval: int = 1, timestep_fraction: float = 1.0, adv_mode: str = 'continuous', paired_preference: bool = False, rollout_correction: RolloutCorrectionConfig = <factory>)
Diffusion-specific algorithm config.
Metrics
Metrics for diffusion (image generation) training.
- verl_omni.trainer.diffusion.diffusion_metric_utils.compute_data_metrics_diffusion(batch: DataProto) → dict[str, Any]
Computes various metrics from a diffusion training batch.
For diffusion (image generation) models, rewards and advantages are indexed over denoising timesteps rather than output tokens.
- Parameters:
batch – A DataProto object containing diffusion batch data. GRPO-style batches include sample_level_rewards [B, T], advantages [B, T], and returns [B, T]. DPO-style batches may only include sample_level_rewards [B].
- Returns:
A dictionary of metrics including:
critic/rewards/mean, max, min: Per-image reward statistics
critic/rewards/zero_std_ratio: Fraction of prompt groups whose reward std is zero
critic/rewards/std_mean: Mean per-prompt reward standard deviation
critic/rewards/group_size: Average number of images sampled per unique prompt
critic/advantages/mean, max, min: Element-wise advantage statistics over B*T, when available
critic/returns/mean, max, min: Element-wise return statistics over B*T, when available
- Return type:
dict[str, Any]
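The per-prompt reward statistics listed above can be sketched as follows. This is a hypothetical helper, not compute_data_metrics_diffusion: group_reward_metrics is a made-up name, it assumes each image carries a prompt identifier (uids) and one scalar reward, and it uses NumPy's population std. The key names are spelled out from the list above.

```python
import numpy as np


def group_reward_metrics(rewards, uids):
    """Hypothetical sketch of per-image and per-prompt reward statistics."""
    rewards = np.asarray(rewards, dtype=np.float64)
    uids = np.asarray(uids)
    # One reward array per unique prompt.
    groups = [rewards[uids == uid] for uid in np.unique(uids)]
    stds = np.array([group.std() for group in groups])
    return {
        "critic/rewards/mean": float(rewards.mean()),
        "critic/rewards/max": float(rewards.max()),
        "critic/rewards/min": float(rewards.min()),
        # Groups with zero spread give zero GRPO advantages and no learning signal.
        "critic/rewards/zero_std_ratio": float(np.mean(stds == 0.0)),
        "critic/rewards/std_mean": float(stds.mean()),
        "critic/rewards/group_size": len(rewards) / len(groups),
    }


print(group_reward_metrics([1.0, 1.0, 0.0, 2.0], ["a", "a", "b", "b"]))
```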
- verl_omni.trainer.diffusion.diffusion_metric_utils.compute_throughput_metrics_diffusion(batch: DataProto, timing_raw: dict[str, float], n_gpus: int) → dict[str, Any]
Computes throughput metrics for diffusion (image/video generation) training.
Unlike language model training, where throughput is measured in tokens per second, diffusion training generates images, so throughput is reported in images per second per GPU.
- Parameters:
batch – A DataProto object containing diffusion batch data.
timing_raw – A dictionary mapping stage names to their execution times in seconds. Must contain a “step” key with the total step time.
n_gpus – Number of GPUs used for training.
- Returns:
A dictionary containing:
perf/total_num_images: Number of images processed in the batch
perf/time_per_step: Time taken for the step in seconds
perf/throughput: Images generated per second per GPU
- Return type:
dict[str, Any]
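The throughput arithmetic can be written out directly. This hypothetical sketch takes the image count as an argument instead of reading it from a DataProto batch; throughput_metrics is a made-up name, while the metric keys and the required "step" entry in timing_raw follow the description above.

```python
def throughput_metrics(num_images: int, timing_raw: dict[str, float], n_gpus: int) -> dict[str, float]:
    """Hypothetical sketch: images per second per GPU from the total step time."""
    step_time = timing_raw["step"]  # total step time in seconds (required key)
    return {
        "perf/total_num_images": num_images,
        "perf/time_per_step": step_time,
        "perf/throughput": num_images / (step_time * n_gpus),
    }


# 64 images in an 8 s step on 4 GPUs -> 64 / (8 * 4) images/s/GPU.
print(throughput_metrics(64, {"step": 8.0, "gen": 5.0}, n_gpus=4))
```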
- verl_omni.trainer.diffusion.diffusion_metric_utils.compute_timing_metrics_diffusion(timing_raw: dict[str, float], num_images: int) → dict[str, Any]
Computes timing metrics for diffusion training.
- Parameters:
timing_raw – A dictionary mapping stage names to their execution times in seconds.
num_images – Total number of images processed in the batch, used to compute per-image timing.
- Returns:
A dictionary containing:
timing_s/{name}: Raw timing in seconds for each stage
timing_per_image_ms/{name}: Per-image timing in milliseconds for core compute stages (gen, ref, old_log_prob, adv, update_actor). Non-compute stages such as save_checkpoint, update_weights, and testing are excluded.
- Return type:
dict[str, Any]
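A hypothetical sketch of the two metric families above. timing_metrics is a made-up name, and it assumes the five stages listed in parentheses are the complete set of core compute stages; the real function may use a different set or handle missing stages differently.

```python
# Assumed set of core compute stages, taken from the list above.
CORE_STAGES = frozenset({"gen", "ref", "old_log_prob", "adv", "update_actor"})


def timing_metrics(timing_raw: dict[str, float], num_images: int) -> dict[str, float]:
    """Hypothetical sketch: raw stage timings plus per-image ms for core stages."""
    metrics = {f"timing_s/{name}": seconds for name, seconds in timing_raw.items()}
    for name, seconds in timing_raw.items():
        if name in CORE_STAGES:  # skip save_checkpoint, update_weights, testing, ...
            metrics[f"timing_per_image_ms/{name}"] = seconds * 1000.0 / num_images
    return metrics


print(timing_metrics({"gen": 2.0, "update_actor": 1.0, "save_checkpoint": 5.0}, num_images=100))
```

Excluding bookkeeping stages keeps the per-image numbers comparable across steps that do and do not save a checkpoint.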