Trainer Interface

Last updated: Jul 03, 2026 (API docstrings are auto-generated).

VeRL-Omni provides Ray-based trainers for diffusion and multimodal RL. TaskRunner builds the worker mappings and dispatches to the trainer subclass selected by algorithm.trainer_type: PolicyGradientRayTrainer or DirectPreferenceRayTrainer.

Both subclasses inherit shared worker init from BaseRayDiffusionTrainer. Rollout and reward engines are initialized only when algorithm.sample_source=online.
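To illustrate the dispatch described above, the sketch below maps algorithm.trainer_type to a trainer class and lists which workers get built for a given algorithm.sample_source. It is a simplified model, not TaskRunner's code. The stand-in classes, the helper names select_trainer_cls and roles_to_build, the role strings, and the "direct_preference" and "offline" values are hypothetical. Only "policy_gradient" and "online" appear as DiffusionAlgoConfig defaults.

```python
from dataclasses import dataclass


class PolicyGradientRayTrainer:  # hypothetical stand-in for the real trainer class
    pass


class DirectPreferenceRayTrainer:  # hypothetical stand-in for the real trainer class
    pass


@dataclass
class AlgoConfig:
    # Defaults mirror DiffusionAlgoConfig; any other values are assumptions.
    trainer_type: str = "policy_gradient"
    sample_source: str = "online"


# Assumed mapping from algorithm.trainer_type to a trainer subclass.
TRAINER_BY_TYPE = {
    "policy_gradient": PolicyGradientRayTrainer,
    "direct_preference": DirectPreferenceRayTrainer,  # hypothetical key
}


def select_trainer_cls(cfg: AlgoConfig) -> type:
    """Pick the trainer subclass for cfg.trainer_type, failing loudly on unknown values."""
    try:
        return TRAINER_BY_TYPE[cfg.trainer_type]
    except KeyError:
        raise ValueError(f"unknown algorithm.trainer_type: {cfg.trainer_type!r}") from None


def roles_to_build(cfg: AlgoConfig) -> list:
    """Actor/ref workers are always built; rollout and reward engines only for online sampling."""
    roles = ["actor", "ref"]
    if cfg.sample_source == "online":
        roles += ["rollout", "reward"]
    return roles
```

Keeping the mapping in a dictionary makes an unsupported trainer_type fail with a clear error instead of silently falling back to a default trainer.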

  • verl_omni.trainer.diffusion.ray_diffusion_trainer.BaseRayDiffusionTrainer – Common Ray trainer infrastructure for diffusion training.

  • verl_omni.trainer.diffusion.ray_diffusion_trainer.PolicyGradientRayTrainer – Policy-gradient diffusion trainer for FlowGRPO, MixGRPO, DanceGRPO, GRPO-Guard, etc.

  • verl_omni.trainer.diffusion.ray_diffusion_trainer.DirectPreferenceRayTrainer – Direct-preference diffusion trainer for DPO, DiffusionNFT, AWM, etc.

  • verl_omni.trainer.main_diffusion.TaskRunner – Ray remote class for executing distributed diffusion training tasks.

Base Ray Diffusion Trainer

BaseRayDiffusionTrainer owns colocated actor/ref worker setup, dataloaders, validation helpers, and checkpointing. init_workers always builds actor/ref workers; rollout and reward engines are added only when algorithm.sample_source=online.

class verl_omni.trainer.diffusion.ray_diffusion_trainer.BaseRayDiffusionTrainer(config, tokenizer, role_worker_mapping: dict[verl.trainer.ppo.utils.Role, type[verl.single_controller.base.worker.Worker]], resource_pool_manager: verl.single_controller.ray.base.ResourcePoolManager, ray_worker_group_cls: type[verl.single_controller.ray.base.RayWorkerGroup] = verl.single_controller.ray.base.RayWorkerGroup, processor=None, train_dataset: torch.utils.data.dataset.Dataset | None = None, val_dataset: torch.utils.data.dataset.Dataset | None = None, collate_fn=None, train_sampler: torch.utils.data.sampler.Sampler | None = None, device_name=None)

Common Ray trainer infrastructure for diffusion training.

Paradigm-specific trainers own the training loop while sharing worker initialization, validation, checkpointing, and logging behavior.

__init__(config, tokenizer, role_worker_mapping: dict[verl.trainer.ppo.utils.Role, type[verl.single_controller.base.worker.Worker]], resource_pool_manager: verl.single_controller.ray.base.ResourcePoolManager, ray_worker_group_cls: type[verl.single_controller.ray.base.RayWorkerGroup] = verl.single_controller.ray.base.RayWorkerGroup, processor=None, train_dataset: torch.utils.data.dataset.Dataset | None = None, val_dataset: torch.utils.data.dataset.Dataset | None = None, collate_fn=None, train_sampler: torch.utils.data.sampler.Sampler | None = None, device_name=None)

Initialize the distributed diffusion trainer with the Ray backend. Note that this trainer runs as the driver process on a single CPU/GPU node.

Parameters:
  • config – Configuration object containing training parameters.

  • tokenizer – Tokenizer used for encoding and decoding text.

  • role_worker_mapping (dict[Role, WorkerType]) – Mapping from roles to worker classes.

  • resource_pool_manager (ResourcePoolManager) – Manager for Ray resource pools.

  • ray_worker_group_cls (RayWorkerGroup, optional) – Class for Ray worker groups. Defaults to RayWorkerGroup.

  • processor – Optional data processor, used for multimodal data.

  • train_dataset (Optional[Dataset], optional) – Training dataset. Defaults to None.

  • val_dataset (Optional[Dataset], optional) – Validation dataset. Defaults to None.

  • collate_fn – Function to collate data samples into batches.

  • train_sampler (Optional[Sampler], optional) – Sampler for the training dataset. Defaults to None.

  • device_name (str, optional) – Device name for training (e.g., “cuda”, “cpu”). Defaults to None.

init_workers()

Initialize distributed training workers using Ray backend.

Policy Gradient Ray Trainer

PolicyGradientRayTrainer implements the online training loop for FlowGRPO-style algorithms: rollout generation, reward scoring, advantage estimation over denoising timesteps, and actor updates.

class verl_omni.trainer.diffusion.ray_diffusion_trainer.PolicyGradientRayTrainer(config, tokenizer, role_worker_mapping: dict[verl.trainer.ppo.utils.Role, type[verl.single_controller.base.worker.Worker]], resource_pool_manager: verl.single_controller.ray.base.ResourcePoolManager, ray_worker_group_cls: type[verl.single_controller.ray.base.RayWorkerGroup] = verl.single_controller.ray.base.RayWorkerGroup, processor=None, train_dataset: torch.utils.data.dataset.Dataset | None = None, val_dataset: torch.utils.data.dataset.Dataset | None = None, collate_fn=None, train_sampler: torch.utils.data.sampler.Sampler | None = None, device_name=None)

Policy-gradient diffusion trainer for FlowGRPO, MixGRPO, DanceGRPO, GRPO-Guard, etc.

fit()

The training loop for FlowGRPO-style algorithms. The driver process only needs to call the worker group's compute functions through RPC to construct the PPO dataflow; the lightweight advantage computation runs on the driver process.
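A minimal sketch of the dataflow that one fit() iteration drives, assuming the stage order of verl's PPO trainer. The stage names gen, old_log_prob, ref, adv, and update_actor come from the timing metrics documented on this page; "reward", fit_step, and the stages dictionary are hypothetical stand-ins for the worker-group RPCs and the driver-side advantage step.

```python
from typing import Callable

# Assumed stage order, modeled on verl's PPO dataflow. "reward" is a hypothetical label;
# the other names appear in this page's timing metrics.
STAGE_ORDER = ("gen", "reward", "old_log_prob", "ref", "adv", "update_actor")


def fit_step(batch: dict, stages: dict, use_ref: bool = True) -> list:
    """Run one hypothetical policy-gradient step and return the stages that ran, in order.

    stages maps each stage name to a callable that takes and returns the batch dict.
    """
    executed = []
    for name in STAGE_ORDER:
        if name == "ref" and not use_ref:
            continue  # no KL loss or KL reward, so no reference-policy pass
        # In the real trainer most stages are RPCs to worker groups; "adv" runs on the driver.
        stage: Callable[[dict], dict] = stages[name]
        batch = stage(batch)
        executed.append(name)
    return executed
```

Skipping the reference pass when it is not needed mirrors add_ref_policy_worker, which only adds that worker if KL loss or KL reward is used.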

Direct Preference Ray Trainer

DirectPreferenceRayTrainer is the extension point for direct-preference algorithms (DPO, DiffusionNFT, AWM) that train with single forward-timestep updates rather than a full multi-step SDE trajectory. The fit implementation is not yet available in-tree.

class verl_omni.trainer.diffusion.ray_diffusion_trainer.DirectPreferenceRayTrainer(config, *args, **kwargs)

Direct-preference diffusion trainer for DPO, DiffusionNFT, AWM, etc.

fit()

Training loop for direct-preference algorithms (DPO, DiffusionNFT, etc.). Offline algorithms read pre-computed rewards from the dataset. Online algorithms generate rollouts and compute rewards live.

verl_omni.trainer.diffusion.ray_diffusion_trainer.compute_advantage(data: DataProto, adv_estimator: str, norm_adv_by_std_in_grpo: bool = True, global_std: bool = True, config: DiffusionAlgoConfig | None = None) -> DataProto

Compute advantage estimates for diffusion policy optimization.

This function computes advantage estimates for diffusion models using the registered advantage estimator (e.g., Flow-GRPO). The advantage estimates are used to guide policy optimization across denoising timesteps.

Parameters:
  • data (DataProto) – The data containing batched diffusion model outputs and inputs.

  • adv_estimator (str) – Name of the advantage estimator to use (e.g., 'flow_grpo', the DiffusionAlgoConfig default).

  • norm_adv_by_std_in_grpo (bool, optional) – Whether to normalize advantages by standard deviation in GRPO. Defaults to True.

  • global_std (bool, optional) – Whether to use global standard deviation for normalization. Defaults to True.

  • config (DiffusionAlgoConfig, optional) – Configuration object for algorithm settings. Defaults to None.

Returns:

The updated data with computed advantages and returns in its batch.

Return type:

DataProto

Entry Point

Entrypoint for diffusion model RL training.

class verl_omni.trainer.main_diffusion.TaskRunner

Ray remote class for executing distributed diffusion training tasks.

This class encapsulates the main training logic and runs as a Ray remote actor to enable distributed execution across multiple nodes and GPUs.

role_worker_mapping

Dictionary mapping Role enums to Ray remote worker classes

mapping

Dictionary mapping Role enums to resource pool IDs for GPU allocation

add_actor_rollout_worker(config)

Add actor (and optional rollout/ref) workers using the unified model engine.

add_ref_policy_worker(config, ref_policy_cls)

Add reference policy worker if KL loss or KL reward is used.

add_reward_model_resource_pool(config)

Register reward-model GPU pool for online sampling (used by RewardLoopManager).

init_resource_pool_mgr(config)

Initialize resource pool manager.

run(config)

Execute the main diffusion training workflow.

Parameters:

config – Training configuration object containing all parameters needed for setting up and running the diffusion training process.

verl_omni.trainer.main_diffusion.main(config)

Main entry point for diffusion model training with Hydra configuration management.

Parameters:

config – Hydra configuration dictionary containing training parameters.

verl_omni.trainer.main_diffusion.run_diffusion(config, task_runner_class=None) -> None

Initialize Ray and run distributed diffusion training.

Parameters:
  • config – Training configuration object containing all necessary parameters for distributed diffusion training including Ray initialization settings, model paths, and training hyperparameters.

  • task_runner_class – Optional replacement for TaskRunner, so that recipes can supply their own TaskRunner class. Defaults to None.

Diffusion Algorithms

The verl_omni.trainer.diffusion.diffusion_algos module provides the loss-function and advantage-estimator registries used by the trainer. Custom losses and advantage estimators can be registered via the decorators below.

Diffusion-specific loss functions and KL penalties.

class verl_omni.trainer.diffusion.diffusion_algos.DiffusionAdvantageEstimator(value)

Advantage estimators specific to diffusion-based training.

class verl_omni.trainer.diffusion.diffusion_algos.DiffusionLossFn

Abstract base for worker-side diffusion loss functions.

abstractmethod classmethod compute_loss(**kwargs: Any) -> tuple[Tensor, dict[str, Any]]

Compute the pure mathematical loss and related metrics.

Subclasses define concrete tensor arguments (e.g. old_log_prob, log_prob, advantages) in their implementation.

static prepare_actor_batch(batch: DataProto, reward_tensor: Tensor, config: Any) -> DataProto

Prepare rollout outputs for actor update when the trainer has not already done so.

Reverse-process policy-gradient losses such as FlowGRPO can leave the batch unchanged because their trainer path has already added old_log_probs and advantages. Offline DPO can also leave the batch unchanged because the offline preference data plus reference predictions provide the loss inputs directly. Forward-process online losses such as DiffusionNFT, as well as online DPO, override this hook to turn final-latent rollouts and rewards into loss-specific actor tensors.

validate_inputs(*, loss_name: str, model_output: dict[str, Any], data: TensorDict) -> None

Validate that the worker batch contains inputs required by this loss.

class verl_omni.trainer.diffusion.diffusion_algos.DiffusionLossResult(loss: Tensor, metrics: dict[str, Any], add_loss_metric: bool = False)

Output from a batch-aware diffusion loss function.

class verl_omni.trainer.diffusion.diffusion_algos.FlowGRPOLoss

Flow-GRPO clipped policy objective.

classmethod compute_loss(*, old_log_prob: Tensor, log_prob: Tensor, advantages: Tensor, config: DiffusionActorConfig, rollout_is_weights: Tensor | None = None) -> tuple[Tensor, dict[str, Any]]

Compute the clipped policy objective and related metrics for FlowGRPO.

Adapted from https://github.com/yifan123/flow_grpo/blob/main/scripts/train_sd3_fast.py#L885

Parameters:
  • old_log_prob (torch.Tensor) – Log-probabilities of actions under the old policy, shape (batch_size,).

  • log_prob (torch.Tensor) – Log-probabilities of actions under the current policy, shape (batch_size,).

  • advantages (torch.Tensor) – Advantage estimates for each action, shape (batch_size,).

  • config (verl_omni.workers.config.DiffusionActorConfig) – Config for the actor.

  • rollout_is_weights (Optional[torch.Tensor]) – Optional Rollout Correction multiplier (same shape as log_prob) combining IS weights and RS rejection (rejected samples have weight 0). When provided, the per-element policy loss is multiplied by these (detached) weights before the mean reduction.
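A plain-Python sketch of the clipped objective, following the PPO-style formulation in the linked flow_grpo script as we understand it. It assumes one scalar log-probability per sample and takes clip_ratio and adv_clip_max as explicit arguments (the names mirror the diffusion_loss config fields listed for GRPOGuardLoss and are assumed to apply here too); the real method works on tensors, reads these values from config, and also returns metrics.

```python
import math
from typing import Optional, Sequence


def flow_grpo_loss(
    old_log_prob: Sequence[float],
    log_prob: Sequence[float],
    advantages: Sequence[float],
    clip_ratio: float,
    adv_clip_max: float,
    rollout_is_weights: Optional[Sequence[float]] = None,
) -> float:
    """Sketch of a PPO-style clipped policy loss averaged over the batch."""
    losses = []
    for i, (old, new, adv) in enumerate(zip(old_log_prob, log_prob, advantages)):
        adv = max(-adv_clip_max, min(adv_clip_max, adv))  # clamp the advantage
        ratio = math.exp(new - old)  # importance ratio pi_theta / pi_old
        clipped_ratio = max(1.0 - clip_ratio, min(1.0 + clip_ratio, ratio))
        # Pessimistic bound: the larger of the two negated surrogate terms.
        loss = max(-adv * ratio, -adv * clipped_ratio)
        if rollout_is_weights is not None:
            loss *= rollout_is_weights[i]  # weight applied before the mean; 0 rejects a sample
        losses.append(loss)
    return sum(losses) / len(losses)
```

Because the weights multiply the per-element loss before the mean, a rejected sample (weight 0) still counts in the denominator; this follows the "before the mean reduction" wording above.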

class verl_omni.trainer.diffusion.diffusion_algos.GRPOGuardLoss

GRPO-Guard clipped policy objective with reverse-SDE mean drift.

classmethod compute_loss(*, old_log_prob: Tensor, log_prob: Tensor, advantages: Tensor, config: DiffusionActorConfig, old_prev_sample_mean: Tensor, prev_sample_mean: Tensor, std_dev_t: Tensor, sqrt_dt: Tensor, rollout_is_weights: Tensor | None = None) -> tuple[Tensor, dict[str, Any]]

Compute the GRPO-Guard policy objective.

GRPO-Guard (https://arxiv.org/abs/2510.22319) augments the standard Flow-GRPO importance ratio with a “ratio-mean bias” term that explicitly penalises drift in the reverse-SDE proposal mean of the current policy relative to the rollout policy. The mean drift is then projected onto the same scale as log_prob - old_log_prob via the per-step diffusion coefficient sqrt_dt * sigma_t, and the final policy loss is rescaled by 1 / sqrt_dt**2 so that gradients have a consistent magnitude across timesteps.

Parameters:
  • old_log_prob (torch.Tensor) – Log-probabilities under the old policy, shape (B,).

  • log_prob (torch.Tensor) – Log-probabilities under the current policy, shape (B,).

  • advantages (torch.Tensor) – Advantage estimates, shape (B,).

  • config – Actor configuration; diffusion_loss.clip_ratio and diffusion_loss.adv_clip_max are read from it.

  • old_prev_sample_mean (torch.Tensor) – Reverse-SDE mean from the rollout policy, shape (B, ...).

  • prev_sample_mean (torch.Tensor) – Reverse-SDE mean from the current policy, shape (B, ...).

  • std_dev_t (torch.Tensor) – Per-step SDE standard deviation, shape (B, 1, 1, ...) or scalar.

  • sqrt_dt (torch.Tensor) – sqrt(-dt) for the current denoising step, shape (B,) or scalar.

  • rollout_is_weights (Optional[torch.Tensor]) – Optional Rollout Correction multiplier (same shape as log_prob) combining IS weights and RS rejection (rejected samples have weight 0). When provided, the per-element policy loss is multiplied by these (detached) weights before the mean reduction.

class verl_omni.trainer.diffusion.diffusion_algos.KLLoss

KL divergence between current and reference reverse-SDE means.

classmethod compute_loss(*, prev_sample_mean: Tensor, ref_prev_sample_mean: Tensor, std_dev_t: Tensor) -> tuple[Tensor, dict[str, Any]]

Compute the KL divergence between the current and reference reverse-SDE means (prev_sample_mean and ref_prev_sample_mean), for images or videos.

Parameters:
  • prev_sample_mean (torch.Tensor) – Shape (bs, s, c).

  • ref_prev_sample_mean (torch.Tensor) – Shape (bs, s, c).

  • std_dev_t (torch.Tensor) – Shape (bs, 1, 1).
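For two Gaussians that share the isotropic variance σ², KL(N(μ, σ²I) ‖ N(μ_ref, σ²I)) = ‖μ − μ_ref‖² / (2σ²). The sketch below applies this per sample, with nested lists standing in for tensors and std_dev_t flattened to one σ per sample. Averaging over the (s, c) elements instead of summing, then averaging over the batch, is an assumption about the reduction; a summed version would be larger by a factor of s·c.

```python
def kl_loss(prev_sample_mean, ref_prev_sample_mean, std_dev_t):
    """Sketch of the equal-variance Gaussian KL, averaged over elements and batch.

    prev_sample_mean, ref_prev_sample_mean: nested lists of shape (bs, s, c).
    std_dev_t: list of length bs (the (bs, 1, 1) tensor with singleton dims dropped).
    """
    per_sample = []
    for mu, mu_ref, sigma in zip(prev_sample_mean, ref_prev_sample_mean, std_dev_t):
        # Squared differences over every (s, c) element of this sample.
        sq = [(a - b) ** 2 for row, row_ref in zip(mu, mu_ref) for a, b in zip(row, row_ref)]
        per_sample.append(sum(sq) / len(sq) / (2.0 * sigma**2))
    return sum(per_sample) / len(per_sample)
```

Since both distributions share σ, the KL is symmetric here, so the argument order does not affect the value.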

verl_omni.trainer.diffusion.diffusion_algos.compute_flow_grpo_outcome_advantage(sample_level_rewards: Tensor, index: ndarray, epsilon: float = 0.0001, norm_adv_by_std_in_grpo: bool = True, global_std: bool = True, config: DictConfig | None = None) -> tuple[Tensor, Tensor]

Compute GRPO advantages from outcome rewards only (a single scalar reward per response).

Parameters:
  • sample_level_rewards (torch.Tensor) – Shape (bs, response_length).

  • index (np.ndarray) – Index array for grouping.

  • epsilon (float) – Small value to avoid division by zero.

  • norm_adv_by_std_in_grpo (bool) – Whether to scale the GRPO advantage by the standard deviation.

  • global_std (bool) – Whether to use the global standard deviation for advantage normalization.

  • config (Optional[DictConfig]) – Algorithm configuration object.

Note

If norm_adv_by_std_in_grpo is True, the advantage is scaled by the std, as in the original GRPO. If False, the advantage is not scaled, as in Dr.GRPO (https://arxiv.org/abs/2503.20783).

Returns:

  • advantages (torch.Tensor) – Shape (bs, response_length).

  • returns (torch.Tensor) – Shape (bs, response_length).

Return type:

tuple[Tensor, Tensor]
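The grouping logic can be sketched in plain Python as follows. The sketch assumes one scalar reward per sample that is broadcast across num_timesteps denoising steps, population standard deviations, and returns equal to advantages (as in verl's outcome-only GRPO). The simplified signature is hypothetical, and reading global_std=True as "take the std over the whole batch instead of per group" is our interpretation of that flag.

```python
import statistics
from collections import defaultdict


def flow_grpo_outcome_advantage(rewards, index, num_timesteps, epsilon=1e-4,
                                norm_adv_by_std_in_grpo=True, global_std=True):
    """Sketch: group-relative advantages broadcast over denoising timesteps.

    rewards: one scalar reward per sample; index: group id (e.g. prompt uid) per sample.
    Returns (advantages, returns), each a nested list of shape (bs, num_timesteps).
    """
    groups = defaultdict(list)
    for r, g in zip(rewards, index):
        groups[g].append(r)
    group_mean = {g: statistics.fmean(rs) for g, rs in groups.items()}
    # Population std (assumption); the library may use a different estimator.
    group_std = {g: statistics.pstdev(rs) for g, rs in groups.items()}
    batch_std = statistics.pstdev(rewards)

    advantages = []
    for r, g in zip(rewards, index):
        adv = r - group_mean[g]  # always centered on the group mean
        if norm_adv_by_std_in_grpo:
            adv /= (batch_std if global_std else group_std[g]) + epsilon
        advantages.append([adv] * num_timesteps)  # same advantage at every timestep
    returns = [row[:] for row in advantages]  # assumed: returns equal advantages
    return advantages, returns
```

Setting norm_adv_by_std_in_grpo=False gives the Dr.GRPO behavior from the note above: advantages stay as raw deviations from the group mean.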

verl_omni.trainer.diffusion.diffusion_algos.get_diffusion_adv_estimator_fn(name_or_enum)

Get the diffusion advantage estimator function with a given name.

verl_omni.trainer.diffusion.diffusion_algos.get_diffusion_loss_fn(name: str) -> DiffusionLossFn

Get a worker-side diffusion loss function by name.

verl_omni.trainer.diffusion.diffusion_algos.register_diffusion_adv_est(name_or_enum: str | DiffusionAdvantageEstimator) -> Any

Register a diffusion advantage estimator function with the given name.

Parameters:

name_or_enum (str | DiffusionAdvantageEstimator) – The name or enum member of the advantage estimator.

verl_omni.trainer.diffusion.diffusion_algos.register_diffusion_loss(name: str) -> Callable[[type[DiffusionLossFn]], type[DiffusionLossFn]]

Register a worker-side diffusion loss function class.

Trainer Config

class verl_omni.trainer.config.algorithm.DiffusionAlgoConfig(_target_: str = '', trainer_type: str = 'policy_gradient', sample_source: str = 'online', adv_estimator: str = 'flow_grpo', norm_adv_by_std_in_grpo: bool = True, global_std: bool = True, old_policy_decay_schedule: str = 'copy', old_policy_decay: float | None = None, old_policy_update_interval: int = 1, timestep_fraction: float = 1.0, adv_mode: str = 'continuous', paired_preference: bool = False, rollout_correction: RolloutCorrectionConfig = <factory>)

Diffusion-specific algorithm config.

Metrics

Metrics for diffusion (image generation) training.

verl_omni.trainer.diffusion.diffusion_metric_utils.compute_data_metrics_diffusion(batch: DataProto) -> dict[str, Any]

Computes various metrics from a diffusion training batch.

For diffusion (image generation) models, rewards and advantages are indexed over denoising timesteps rather than output tokens.

Parameters:

batch – A DataProto object containing diffusion batch data. GRPO-style batches include sample_level_rewards [B, T], advantages [B, T], and returns [B, T]. DPO-style batches may only include sample_level_rewards [B].

Returns:

A dictionary of metrics including:

  • critic/rewards/mean, max, min: Per-image reward statistics

  • critic/rewards/zero_std_ratio: Fraction of prompt groups whose reward std is zero

  • critic/rewards/std_mean: Mean per-prompt reward standard deviation

  • critic/rewards/group_size: Average number of images sampled per unique prompt

  • critic/advantages/mean, max, min: Element-wise advantage statistics over B*T, when available

  • critic/returns/mean, max, min: Element-wise return statistics over B*T, when available

Return type:

dict[str, Any]
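A sketch of how the critic/rewards/* entries could be computed from one reward per image plus a prompt-group id. The uids argument, the helper name reward_metrics, and the use of population standard deviation are assumptions; the real function reads its inputs from the DataProto batch, and [B, T] rewards are assumed here to be already reduced to one value per image.

```python
import statistics
from collections import defaultdict
from typing import Dict, List


def reward_metrics(rewards: List[float], uids: List[str]) -> Dict[str, float]:
    """Sketch of the critic/rewards/* metrics for one scalar reward per image."""
    groups = defaultdict(list)
    for r, uid in zip(rewards, uids):
        groups[uid].append(r)
    stds = [statistics.pstdev(rs) for rs in groups.values()]  # one std per prompt group
    return {
        "critic/rewards/mean": statistics.fmean(rewards),
        "critic/rewards/max": max(rewards),
        "critic/rewards/min": min(rewards),
        "critic/rewards/zero_std_ratio": sum(s == 0.0 for s in stds) / len(stds),
        "critic/rewards/std_mean": statistics.fmean(stds),
        "critic/rewards/group_size": len(rewards) / len(groups),
    }
```

A high zero_std_ratio signals that many prompt groups give identical rewards to every image, so group-relative advantages for those prompts are zero.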

verl_omni.trainer.diffusion.diffusion_metric_utils.compute_throughput_metrics_diffusion(batch: DataProto, timing_raw: dict[str, float], n_gpus: int) -> dict[str, Any]

Computes throughput metrics for diffusion (image/video generation) training.

Unlike language-model training, where throughput is measured in tokens per second, diffusion training produces images, so throughput is reported as images per second per GPU.

Parameters:
  • batch – A DataProto object containing diffusion batch data.

  • timing_raw – A dictionary mapping stage names to their execution times in seconds. Must contain a “step” key with the total step time.

  • n_gpus – Number of GPUs used for training.

Returns:

A dictionary containing:

  • perf/total_num_images: Number of images processed in the batch

  • perf/time_per_step: Time taken for the step in seconds

  • perf/throughput: Images generated per second per GPU

Return type:

dict[str, Any]
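The throughput figure follows from the definitions above: perf/throughput = total_num_images / (time_per_step × n_gpus). The sketch below takes the image count directly instead of a DataProto batch; that simplification and the helper name throughput_metrics are hypothetical.

```python
from typing import Dict


def throughput_metrics(num_images: int, timing_raw: Dict[str, float], n_gpus: int) -> dict:
    """Sketch of the perf/* metrics; timing_raw must contain the total step time under "step"."""
    step_time = timing_raw["step"]  # seconds for the whole training step
    return {
        "perf/total_num_images": num_images,
        "perf/time_per_step": step_time,
        "perf/throughput": num_images / (step_time * n_gpus),  # images / second / GPU
    }
```

For example, 32 images in an 8 s step on 2 GPUs gives 32 / (8 × 2) = 2 images per second per GPU.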

verl_omni.trainer.diffusion.diffusion_metric_utils.compute_timing_metrics_diffusion(timing_raw: dict[str, float], num_images: int) -> dict[str, Any]

Computes timing metrics for diffusion training.

Parameters:
  • timing_raw – A dictionary mapping stage names to their execution times in seconds.

  • num_images – Total number of images processed in the batch, used to compute per-image timing.

Returns:

A dictionary containing:

  • timing_s/{name}: Raw timing in seconds for each stage

  • timing_per_image_ms/{name}: Per-image timing in milliseconds for core compute stages (gen, ref, old_log_prob, adv, update_actor). Non-compute stages such as save_checkpoint, update_weights, and testing are excluded.

Return type:

dict[str, Any]
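A sketch of the two timing families, assuming an allowlist of the core compute stages named above. The stage names come from the docstring; the helper timing_metrics and the allowlist approach are hypothetical (the real code may instead exclude a list of non-compute stages).

```python
from typing import Dict

# Core compute stages that get per-image timings (names taken from the docstring).
CORE_STAGES = ("gen", "ref", "old_log_prob", "adv", "update_actor")


def timing_metrics(timing_raw: Dict[str, float], num_images: int) -> Dict[str, float]:
    """Sketch: raw seconds for every stage, per-image milliseconds for core stages only."""
    metrics = {f"timing_s/{name}": secs for name, secs in timing_raw.items()}
    for name in CORE_STAGES:
        if name in timing_raw:
            metrics[f"timing_per_image_ms/{name}"] = timing_raw[name] * 1000.0 / num_images
    return metrics
```

With {"gen": 2.0, "save_checkpoint": 1.0} and 4 images, gen gets 500 ms per image, while save_checkpoint is reported only as raw seconds.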