Source code for verl_omni.trainer.main_diffusion

# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Entrypoint for diffusion model RL training."""

import os
import socket

import hydra
import ray
from omegaconf import OmegaConf
from verl.trainer.constants_ppo import get_ppo_ray_runtime_env
from verl.trainer.ppo.utils import need_reference_policy
from verl.utils.device import auto_set_device, is_cuda_available

from verl_omni.trainer.diffusion.ray_diffusion_trainer import (
    DirectPreferenceRayTrainer,
    PolicyGradientRayTrainer,
)
from verl_omni.utils.diffusion_attention import fallback_fa3_if_unavailable, validate_attention_consistency


@hydra.main(config_path="./config", config_name="diffusion_trainer", version_base=None)
def main(config):
    """Hydra entry point for diffusion model RL training.

    Prepares ``config`` in place and then hands it to :func:`run_diffusion`.
    The preparation passes run in this fixed order:

    1. ``auto_set_device`` -- sets ``config.trainer.device = npu`` when running
       on Ascend NPU.
    2. ``OmegaConf.resolve`` -- evaluates interpolations in place.
    3. ``fallback_fa3_if_unavailable`` -- FA3 fallback handling (see that helper).
    4. ``validate_attention_consistency`` -- attention-setting checks (see that helper).

    Args:
        config: Hydra-composed training configuration (``config/diffusion_trainer``).
    """
    preparation_passes = (
        auto_set_device,
        OmegaConf.resolve,
        fallback_fa3_if_unavailable,
        validate_attention_consistency,
    )
    for apply_pass in preparation_passes:
        apply_pass(config)

    run_diffusion(config)
def _nsys_controller_enabled(config) -> bool:
    """Return True when the controller should be launched under NVIDIA Nsight Systems.

    This requires CUDA, ``global_profiler.tool == "nsys"``, and a non-empty
    ``global_profiler.steps``. The checks short-circuit in that order, so
    ``global_profiler.steps`` is only read when the earlier conditions hold.
    """
    if not is_cuda_available:
        return False
    if OmegaConf.select(config, "global_profiler.tool") != "nsys":
        return False
    profiler_steps = OmegaConf.select(config, "global_profiler.steps")
    return profiler_steps is not None and len(profiler_steps) > 0


def run_diffusion(config, task_runner_class=None) -> None:
    """Initialize Ray (if needed) and run distributed diffusion training.

    Args:
        config: Training configuration object. Its fields include the Ray
            initialization settings (``ray_kwargs``), model paths and training
            hyperparameters.
        task_runner_class: Optional replacement for the default ``TaskRunner``
            actor class, e.g. a recipe-specific runner. It is used as-is, so it
            must already be a Ray remote class exposing ``run``.

    Raises:
        RuntimeError: If the nsys controller profiler is requested but ``nvtx``
            is not installed.
    """
    if not ray.is_initialized():
        # Start from verl's default PPO runtime env. Upstream comments describe it
        # as worker env vars for tokenizer parallelism, NCCL debug level, vLLM
        # logging and runtime LoRA updating. User-provided
        # `ray_kwargs.ray_init.runtime_env` entries are merged on top and win on
        # conflict. All other `ray_init` entries are passed straight to `ray.init`.
        default_runtime_env = get_ppo_ray_runtime_env()
        ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
        runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
        runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
        ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
        print(f"ray init kwargs: {ray_init_kwargs}")
        ray.init(**OmegaConf.to_container(ray_init_kwargs))

    if task_runner_class is None:
        # num_cpus=1: please make sure main_task is not scheduled on head.
        task_runner_class = ray.remote(num_cpus=1)(TaskRunner)

    # When NVIDIA Nsight Systems is selected for the controller, launch the TaskRunner
    # under nsys using the Ray runtime env, mirroring verl/trainer/main_ppo.py.
    if _nsys_controller_enabled(config):
        from verl.utils.import_utils import is_nvtx_available

        # Explicit raise (not `assert`) so the check survives `python -O`.
        if not is_nvtx_available():
            raise RuntimeError("nvtx is not available in CUDA platform. Please 'pip3 install nvtx'")
        nsight_options = OmegaConf.to_container(
            config.global_profiler.global_tool_config.nsys.controller_nsight_options
        )
        runner = task_runner_class.options(runtime_env={"nsight": nsight_options}).remote()
    else:
        runner = task_runner_class.remote()

    # Block until the remote training run finishes.
    ray.get(runner.run.remote(config))

    # [Optional] Dump a Ray timeline trace for performance analysis when a path is configured.
    timeline_json_file = config.ray_kwargs.get("timeline_json_file", None)
    if timeline_json_file:
        ray.timeline(filename=timeline_json_file)
def _get_trainer_cls(config): """Return the trainer class selected by ``algorithm.trainer_type``.""" trainer_type = config.algorithm.trainer_type if trainer_type == "policy_gradient": return PolicyGradientRayTrainer if trainer_type == "direct_preference": return DirectPreferenceRayTrainer raise ValueError( f"Unsupported diffusion trainer_type {trainer_type!r}. Expected one of: 'policy_gradient', 'direct_preference'." )
class TaskRunner:
    """Ray remote class for executing distributed diffusion training tasks.

    This class encapsulates the main training logic and runs as a Ray remote actor
    to enable distributed execution across multiple nodes and GPUs.

    Attributes:
        role_worker_mapping: Maps ``Role`` enum members to Ray remote worker classes.
        mapping: Maps ``Role`` enum members to resource-pool IDs
            (``"global_pool"`` or ``"reward_pool"``).
    """

    def __init__(self):
        self.role_worker_mapping = {}
        self.mapping = {}

    @staticmethod
    def _ref_in_actor(model_config) -> bool:
        """Return True when no separate reference-policy role should be requested.

        This holds when LoRA is active or a LoRA adapter path is configured.
        LoRA counts as active when ``lora.rank`` is positive, or, failing that,
        when the legacy ``lora_rank`` field is positive. Null ``lora``, ``rank``
        or ``lora_rank`` entries are treated as "LoRA disabled" instead of raising.

        Args:
            model_config: The ``actor_rollout_ref.model`` config (any mapping with ``.get``).
        """
        lora_rank = (model_config.get("lora") or {}).get("rank", 0) or 0
        if lora_rank <= 0:
            lora_rank = model_config.get("lora_rank", 0) or 0
        return lora_rank > 0 or model_config.get("lora_adapter_path") is not None

    def add_actor_rollout_worker(self, config):
        """Register the actor (plus optional rollout/ref) worker using the unified model engine.

        Role selection:
            * ``algorithm.sample_source == "offline"`` -> ``Role.Actor``
            * a reference policy is needed and is not hosted in the actor
              (see ``_ref_in_actor``) -> ``Role.ActorRolloutRef``
            * otherwise -> ``Role.ActorRollout``

        The chosen role is mapped to ``"global_pool"``.

        Returns:
            Tuple ``(actor_rollout_cls, ray_worker_group_cls)``.

        Raises:
            ValueError: If offline training is requested but the installed verl
                ``Role`` enum has no ``Actor`` member.
        """
        from verl.single_controller.ray import RayWorkerGroup
        from verl.trainer.ppo.ray_trainer import Role

        from verl_omni.workers.engine_workers import ActorRolloutRefWorker

        actor_rollout_cls = ActorRolloutRefWorker
        ray_worker_group_cls = RayWorkerGroup

        ref_in_actor = self._ref_in_actor(config.actor_rollout_ref.model)

        if config.algorithm.sample_source == "offline":
            if not hasattr(Role, "Actor"):
                raise ValueError("Offline training without rollout requires verl Role.Actor support.")
            role = Role.Actor
        elif need_reference_policy(config) and not ref_in_actor:
            role = Role.ActorRolloutRef
        else:
            role = Role.ActorRollout

        self.role_worker_mapping[role] = ray.remote(actor_rollout_cls)
        self.mapping[role] = "global_pool"

        return actor_rollout_cls, ray_worker_group_cls

    def init_resource_pool_mgr(self, config):
        """Build the ResourcePoolManager from trainer (and optional reward-model) GPU specs.

        The ``"global_pool"`` always has ``trainer.n_gpus_per_node`` GPUs on each of
        ``trainer.nnodes`` nodes. When ``reward.reward_model.enable_resource_pool``
        is set, a separate ``"reward_pool"`` is added. Otherwise the reward model's
        ``nnodes`` / ``n_gpus_per_node`` are overwritten in ``config`` with the
        trainer's values.

        Raises:
            ValueError: If a dedicated reward pool is requested with non-positive
                ``n_gpus_per_node`` or ``nnodes``.
        """
        global_pool_id = "global_pool"
        resource_pool_spec = {
            global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
        }
        reward_model_cfg = config.reward.reward_model
        if reward_model_cfg.enable_resource_pool:
            if reward_model_cfg.n_gpus_per_node <= 0:
                raise ValueError("config.reward.reward_model.n_gpus_per_node must be greater than 0")
            if reward_model_cfg.nnodes <= 0:
                raise ValueError("config.reward.reward_model.nnodes must be greater than 0")
            resource_pool_spec["reward_pool"] = [reward_model_cfg.n_gpus_per_node] * reward_model_cfg.nnodes
        else:
            # No dedicated pool: mirror the trainer's node/GPU counts.
            reward_model_cfg.nnodes = config.trainer.nnodes
            reward_model_cfg.n_gpus_per_node = config.trainer.n_gpus_per_node

        from verl.trainer.ppo.ray_trainer import ResourcePoolManager

        return ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=self.mapping)

    def add_reward_model_resource_pool(self, config):
        """Register reward-model GPU pool for online sampling (used by RewardLoopManager).

        This only records the pool assignment for ``Role.RewardModel`` in
        ``self.mapping``. No reward-model worker is added to ``role_worker_mapping``.
        Nothing is registered unless ``algorithm.sample_source == "online"`` and
        ``reward.reward_model.enable`` is set.
        """
        from verl.trainer.ppo.ray_trainer import Role

        if config.algorithm.sample_source != "online" or not config.reward.reward_model.enable:
            return

        if config.reward.reward_model.enable_resource_pool:
            self.mapping[Role.RewardModel] = "reward_pool"
        else:
            self.mapping[Role.RewardModel] = "global_pool"

    def add_ref_policy_worker(self, config, ref_policy_cls):
        """Hook for adding a separate reference-policy worker; a no-op here.

        The ref policy has been fused into ActorRolloutRefWorker in the new model
        engine, so no separate ref policy worker group is needed.
        """
        return

    def run(self, config):
        """Execute the main diffusion training workflow.

        The workflow:

        1. Register worker roles and resource-pool mappings.
        2. Resolve the model directory and load the tokenizer and processor.
        3. Build the resource-pool manager.
        4. Build the train and validation datasets.
        5. Construct the trainer selected by ``algorithm.trainer_type``.
        6. Initialize its workers and call ``fit()``.

        Args:
            config: Full training configuration. It is resolved in place.
                ``actor_rollout_ref.model.tokenizer_path`` is filled in when it
                is None. ``init_resource_pool_mgr`` may also update reward-model
                node/GPU counts.
        """
        from pprint import pprint

        from verl_omni.utils.fs import resolve_model_local_dir

        print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
        # Print the initial configuration. `resolve=True` will evaluate symbolic values.
        pprint(OmegaConf.to_container(config, resolve=True))
        OmegaConf.resolve(config)

        actor_rollout_cls, ray_worker_group_cls = self.add_actor_rollout_worker(config)
        self.add_reward_model_resource_pool(config)
        # No-op in this class (ref policy is fused into the actor worker); subclasses may override.
        self.add_ref_policy_worker(config, actor_rollout_cls)

        # Resolve the model path to an on-disk directory (downloads from HDFS or HF Hub
        # if necessary). `use_shm` enables shared-memory copy for faster reloads.
        local_path = resolve_model_local_dir(
            config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False)
        )
        # Default tokenizer location: `<model>/tokenizer` if present, else the model root.
        if config.actor_rollout_ref.model.tokenizer_path is None:
            tokenizer_path = os.path.join(local_path, "tokenizer")
            config.actor_rollout_ref.model.tokenizer_path = (
                tokenizer_path if os.path.exists(tokenizer_path) else local_path
            )

        # Instantiate the tokenizer and processor.
        from verl.utils import hf_processor, hf_tokenizer

        trust_remote_code = config.data.get("trust_remote_code", False)
        tokenizer = hf_tokenizer(config.actor_rollout_ref.model.tokenizer_path, trust_remote_code=trust_remote_code)
        # Processor is used for multimodal models and may be None.
        # Same lookup rule as the tokenizer: `<model>/processor` if present, else the model root.
        processor_path = os.path.join(local_path, "processor")
        if not os.path.exists(processor_path):
            processor_path = local_path
        processor = hf_processor(processor_path, trust_remote_code=trust_remote_code, use_fast=True)

        resource_pool_manager = self.init_resource_pool_mgr(config)

        from verl_omni.utils.dataset.rl_dataset import create_rl_dataset, create_rl_sampler, get_collate_fn

        collate_fn = get_collate_fn(config.data)

        # Create training and validation datasets.
        train_dataset = create_rl_dataset(
            config.data.train_files,
            config.data,
            tokenizer,
            processor,
            is_train=True,
            max_samples=config.data.get("train_max_samples", -1),
        )
        val_dataset = create_rl_dataset(
            config.data.val_files,
            config.data,
            tokenizer,
            processor,
            is_train=False,
            max_samples=config.data.get("val_max_samples", -1),
        )
        train_sampler = create_rl_sampler(config.data, train_dataset)

        trainer_cls = _get_trainer_cls(config)
        trainer = trainer_cls(
            config=config,
            tokenizer=tokenizer,
            processor=processor,
            role_worker_mapping=self.role_worker_mapping,
            resource_pool_manager=resource_pool_manager,
            ray_worker_group_cls=ray_worker_group_cls,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            collate_fn=collate_fn,
            train_sampler=train_sampler,
        )
        # Initialize the workers of the trainer.
        trainer.init_workers()

        # Start the training process.
        trainer.fit()
if __name__ == "__main__":
    # `main` is wrapped by `hydra.main`, which composes the config and passes it in.
    main()