Source code for verl_omni.trainer.diffusion.ray_diffusion_trainer

# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Flow-GRPO / diffusion trainer with a Ray-based single controller.
This trainer supports model-agnostic model initialization with Hugging Face.
"""

import json
import logging
import os
import uuid
from abc import ABC, abstractmethod
from collections import defaultdict
from pprint import pprint
from typing import Any, Literal, Optional

import numpy as np
import ray
import torch
from omegaconf import OmegaConf, open_dict
from PIL import Image
from torch.utils.data import Dataset, Sampler
from torchdata.stateful_dataloader import StatefulDataLoader
from tqdm import tqdm
from verl import DataProto
from verl.checkpoint_engine import CheckpointEngineManager
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup, ResourcePoolManager
from verl.single_controller.ray.base import create_colocated_worker_cls
from verl.trainer.ppo.metric_utils import compute_variance_proxy_metrics, process_validation_metrics
from verl.trainer.ppo.reward import extract_reward
from verl.trainer.ppo.utils import Role, WorkerType, need_reference_policy, need_reward_model
from verl.utils import tensordict_utils as tu
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi
from verl.utils.config import omega_conf_to_dataclass
from verl.utils.debug import marked_timer
from verl.utils.import_utils import load_class_from_fqn
from verl.utils.metric import reduce_metrics
from verl.utils.py_functional import rename_dict
from verl.utils.tracking import ValidationGenerationsLogger
from verl.workers.rollout.llm_server import LLMServerManager

from verl_omni.trainer.config import DiffusionAlgoConfig
from verl_omni.trainer.diffusion.diffusion_algos import (
    DiffusionAdvantageEstimator,
    get_diffusion_adv_estimator_fn,
    get_diffusion_loss_fn,
)
from verl_omni.trainer.diffusion.diffusion_metric_utils import (
    compute_data_metrics_diffusion,
    compute_old_policy_metrics,
    compute_reward_extra_metrics_diffusion,
    compute_throughput_metrics_diffusion,
    compute_timing_metrics_diffusion,
)
from verl_omni.trainer.diffusion.diffusion_trainer_utils import NoOpCheckpointManager, old_policy_decay
from verl_omni.trainer.diffusion.rollout_correction import (
    apply_bypass_mode_to_diffusion_batch,
    apply_rollout_correction_to_diffusion_batch,
    rollout_correction_enabled,
)
from verl_omni.workers.utils.padding import embeds_padding_2_no_padding

sys_logger = logging.getLogger(__name__)


def compute_advantage(
    data: DataProto,
    adv_estimator: str,
    norm_adv_by_std_in_grpo: bool = True,
    global_std: bool = True,
    config: Optional[DiffusionAlgoConfig] = None,
) -> DataProto:
    """Attach advantage and return estimates to a diffusion rollout batch.

    The estimator registered under ``adv_estimator`` is looked up through
    ``get_diffusion_adv_estimator_fn`` and called with keyword arguments
    assembled from ``data``. Its two results are stored in ``data.batch``.

    Args:
        data (DataProto): Batch to annotate. ``data.batch`` must contain
            ``sample_level_rewards``. ``reward_baselines`` (in ``data.batch``)
            and ``uid`` (in ``data.non_tensor_batch``) are forwarded when present.
        adv_estimator (str): Name of the registered advantage estimator.
        norm_adv_by_std_in_grpo (bool, optional): Forwarded to the estimator
            only when ``adv_estimator`` is Flow-GRPO. Defaults to True.
        global_std (bool, optional): Forwarded to the estimator only when
            ``adv_estimator`` is Flow-GRPO. Defaults to True.
        config (DiffusionAlgoConfig, optional): Algorithm configuration passed
            through to the estimator. Defaults to None.

    Returns:
        DataProto: The same ``data`` object, mutated in place, with
        ``advantages`` and ``returns`` written into ``data.batch``.
    """
    tensors = data.batch
    extras = data.non_tensor_batch

    estimator_inputs = {
        "sample_level_rewards": tensors["sample_level_rewards"],
        "config": config,
    }
    # Optional inputs: only forwarded when the batch actually carries them.
    if "uid" in extras:
        estimator_inputs["index"] = extras["uid"]
    if "reward_baselines" in tensors:
        estimator_inputs["reward_baselines"] = tensors["reward_baselines"]

    estimate = get_diffusion_adv_estimator_fn(adv_estimator)

    # The normalisation switches are only passed to the Flow-GRPO estimator.
    if adv_estimator == DiffusionAdvantageEstimator.FLOW_GRPO:
        estimator_inputs["norm_adv_by_std_in_grpo"] = norm_adv_by_std_in_grpo
        estimator_inputs["global_std"] = global_std

    advantages, returns = estimate(**estimator_inputs)
    tensors["advantages"] = advantages
    tensors["returns"] = returns
    return data
[docs] class BaseRayDiffusionTrainer(ABC): """Common Ray trainer infrastructure for diffusion training. Paradigm-specific trainers own the training loop while sharing worker initialization, validation, checkpointing, and logging behavior. """
def __init__(
    self,
    config,
    tokenizer,
    role_worker_mapping: dict[Role, WorkerType],
    resource_pool_manager: ResourcePoolManager,
    ray_worker_group_cls: type[RayWorkerGroup] = RayWorkerGroup,
    processor=None,
    train_dataset: Optional[Dataset] = None,
    val_dataset: Optional[Dataset] = None,
    collate_fn=None,
    train_sampler: Optional[Sampler] = None,
    device_name=None,
):
    """Set up the driver-side state of the Ray diffusion trainer.

    Stores the collaborators, validates the role mapping against the
    configured sample source, derives the reference-policy / reward-model /
    LoRA flags from ``config`` and builds the train and validation dataloaders.

    Args:
        config: Configuration object containing training parameters.
        tokenizer: Tokenizer used for encoding and decoding text.
        role_worker_mapping (dict[Role, WorkerType]): Mapping from roles to worker classes.
        resource_pool_manager (ResourcePoolManager): Manager for Ray resource pools.
        ray_worker_group_cls (RayWorkerGroup, optional): Class for Ray worker groups.
            Defaults to RayWorkerGroup.
        processor: Optional data processor, used for multimodal data.
        train_dataset (Optional[Dataset], optional): Training dataset. Defaults to None.
        val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None.
        collate_fn: Function to collate data samples into batches.
        train_sampler (Optional[Sampler], optional): Sampler for the training dataset.
            Defaults to None.
        device_name (str, optional): Device name for training (e.g., "cuda", "cpu").
            Falls back to ``config.trainer.device`` when falsy. Defaults to None.
    """
    self.tokenizer = tokenizer
    self.processor = processor
    self.config = config

    self.hybrid_engine = config.actor_rollout_ref.hybrid_engine

    # Online sampling needs a hybrid actor/rollout role; any other sample
    # source only needs a plain actor role.
    if config.algorithm.sample_source != "online":
        assert Role.Actor in role_worker_mapping, f"{role_worker_mapping.keys()=}"
    else:
        assert self.hybrid_engine, "Currently, only support hybrid engine"
        has_rollout_role = any(
            role in role_worker_mapping for role in (Role.ActorRollout, Role.ActorRolloutRef)
        )
        assert has_rollout_role, f"{role_worker_mapping.keys()=}"

    self.role_worker_mapping = role_worker_mapping
    self.resource_pool_manager = resource_pool_manager
    self.use_reference_policy = need_reference_policy(self.config)
    self.use_rm = need_reward_model(self.config)
    self.ray_worker_group_cls = ray_worker_group_cls
    self.device_name = device_name or self.config.trainer.device

    trainer_cfg = self.config.trainer
    self.validation_generations_logger = ValidationGenerationsLogger(
        project_name=trainer_cfg.project_name,
        experiment_name=trainer_cfg.experiment_name,
    )

    # if ref_in_actor is True, the reference policy will be actor without lora applied
    model_cfg = config.actor_rollout_ref.model
    rank = model_cfg.get("lora", {}).get("rank", 0)
    if rank <= 0:
        # Fall back to the flat ``lora_rank`` key when the nested one is unset.
        rank = model_cfg.get("lora_rank", 0)
    self.ref_in_actor = rank > 0 or model_cfg.get("lora_adapter_path") is not None

    self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)

    # No checkpoint manager exists at construction time.
    self.checkpoint_manager = None
def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler: Optional[Sampler]):
    """Create the train and validation dataloaders and derive the step budget.

    Datasets, sampler and collate function are built from ``self.config.data``
    when the corresponding argument is ``None``. The resulting loaders are
    stored on ``self.train_dataloader`` / ``self.val_dataloader`` and
    ``self.total_training_steps`` is set.

    Args:
        train_dataset: Training dataset, or ``None`` to build one from config.
        val_dataset: Validation dataset, or ``None`` to build one from config.
        collate_fn: Batch collate function, or ``None`` to use the configured one.
        train_sampler (Optional[Sampler]): Training sampler, or ``None`` to
            build one from config.

    Raises:
        AssertionError: If either dataloader yields no batches.
    """
    # TODO: we have to make sure the batch size is divisible by the dp size
    from verl_omni.utils.dataset.rl_dataset import create_rl_dataset, create_rl_sampler, get_collate_fn

    data_cfg = self.config.data

    if train_dataset is None:
        train_dataset = create_rl_dataset(
            data_cfg.train_files,
            data_cfg,
            self.tokenizer,
            self.processor,
            max_samples=data_cfg.get("train_max_samples", -1),
        )
    if val_dataset is None:
        val_dataset = create_rl_dataset(
            data_cfg.val_files,
            data_cfg,
            self.tokenizer,
            self.processor,
            max_samples=data_cfg.get("val_max_samples", -1),
        )
    self.train_dataset, self.val_dataset = train_dataset, val_dataset

    if train_sampler is None:
        train_sampler = create_rl_sampler(data_cfg, self.train_dataset)
    if collate_fn is None:
        collate_fn = get_collate_fn(data_cfg)

    num_workers = data_cfg["dataloader_num_workers"]

    self.train_dataloader = StatefulDataLoader(
        dataset=self.train_dataset,
        batch_size=data_cfg.get("gen_batch_size", data_cfg.train_batch_size),
        num_workers=num_workers,
        drop_last=True,
        collate_fn=collate_fn,
        sampler=train_sampler,
    )

    # Prefer the configured value; otherwise validate the whole set in one batch.
    val_batch_size = data_cfg.val_batch_size
    if val_batch_size is None:
        val_batch_size = len(self.val_dataset)

    self.val_dataloader = StatefulDataLoader(
        dataset=self.val_dataset,
        batch_size=val_batch_size,
        num_workers=num_workers,
        shuffle=data_cfg.get("validation_shuffle", True),
        drop_last=False,
        collate_fn=collate_fn,
    )

    assert len(self.train_dataloader) >= 1, "Train dataloader is empty!"
    assert len(self.val_dataloader) >= 1, "Validation dataloader is empty!"

    print(
        f"Size of train dataloader: {len(self.train_dataloader)}, Size of val dataloader: "
        f"{len(self.val_dataloader)}"
    )

    # An explicit ``trainer.total_training_steps`` overrides the epoch-derived count.
    total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
    if self.config.trainer.total_training_steps is not None:
        total_training_steps = self.config.trainer.total_training_steps

    self.total_training_steps = total_training_steps
    print(f"Total training steps: {self.total_training_steps}")

    # Best effort: publish the step budget to the actor optimizer config.
    # A failure here is reported but deliberately does not abort setup.
    try:
        OmegaConf.set_struct(self.config, True)
        with open_dict(self.config):
            if OmegaConf.select(self.config, "actor_rollout_ref.actor.optim"):
                self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
    except Exception as e:
        print(f"Warning: Could not set total_training_steps in config. Structure missing? Error: {e}")


def _dump_generations(self, inputs, outputs, gts, scores, reward_extra_infos_dict, dump_path):
    """Dump rollout/validation samples as JPEG images plus one JSONL index.

    Images go to ``<dump_path>/<global_steps>/<i>.jpg`` and the index to
    ``<dump_path>/<global_steps>.jsonl``, with one JSON object per sample.

    Args:
        inputs: Per-sample prompt texts; its length defines the sample count.
        outputs: Image tensor, permuted with ``(0, 2, 3, 1)`` before saving
            (presumably a channel-first batch with values in [0, 1] — confirm
            against the rollout worker).
        gts: Per-sample ground truths.
        scores: Per-sample scores.
        reward_extra_infos_dict: Extra per-sample columns; entries whose length
            differs from ``len(inputs)`` are skipped.
        dump_path: Directory that receives the images and the JSONL file.
    """
    os.makedirs(dump_path, exist_ok=True)
    visual_folder = os.path.join(dump_path, f"{self.global_steps}")
    os.makedirs(visual_folder, exist_ok=True)

    output_paths = []
    images_pil = outputs.cpu().float().permute(0, 2, 3, 1).numpy()
    images_pil = (images_pil * 255).round().clip(0, 255).astype("uint8")
    for i, image in enumerate(images_pil):
        image_path = os.path.join(visual_folder, f"{i}.jpg")
        Image.fromarray(image).save(image_path)
        output_paths.append(image_path)

    filename = os.path.join(dump_path, f"{self.global_steps}.jsonl")

    n = len(inputs)
    base_data = {
        "input": inputs,
        "output": output_paths,
        "gts": gts,
        "score": scores,
        "step": [self.global_steps] * n,
    }
    # Only columns aligned with the samples can be split row by row.
    for k, v in reward_extra_infos_dict.items():
        if len(v) == n:
            base_data[k] = v

    lines = [
        json.dumps({k: v[i] for k, v in base_data.items()}, ensure_ascii=False)
        for i in range(n)
    ]

    # ``ensure_ascii=False`` emits raw non-ASCII characters, so the file must be
    # opened as UTF-8 explicitly; the locale default encoding may not be able
    # to represent them.
    with open(filename, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")

    print(f"Dumped generations to {filename}")


def _log_rollout_data(
    self, batch: DataProto, reward_extra_infos_dict: dict, timing_raw: dict, rollout_data_dir: str
):
    """Log rollout data to disk.

    Args:
        batch (DataProto): The batch containing rollout data
        reward_extra_infos_dict (dict): Additional reward information to log
        timing_raw (dict): Timing information for profiling
        rollout_data_dir (str): Directory path to save the rollout data
    """
    with marked_timer("dump_rollout_generations", timing_raw, color="green"):
        inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
        outputs = batch.batch["responses"]
        scores = batch.batch["sample_level_scores"].sum(-1).cpu().tolist()
        sample_gts = [item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in batch]

        # numpy arrays are converted to plain lists before being dumped.
        reward_extra_infos_to_dump = {
            k: (v.tolist() if isinstance(v, np.ndarray) else v) for k, v in reward_extra_infos_dict.items()
        }
        # Add request ids unless the caller already supplied that column.
        if "request_id" in batch.non_tensor_batch:
            reward_extra_infos_to_dump.setdefault(
                "request_id",
                batch.non_tensor_batch["request_id"].tolist(),
            )

        self._dump_generations(
            inputs=inputs,
            outputs=outputs,
            gts=sample_gts,
            scores=scores,
            reward_extra_infos_dict=reward_extra_infos_to_dump,
            dump_path=rollout_data_dir,
        )


def _maybe_log_val_generations(self, inputs, outputs, scores):
    """Log a table of validation samples to the configured logger (wandb or swanlab).

    Does nothing when ``trainer.log_val_generations`` is 0. Otherwise a
    deterministic subset of at most that many ``(input, output, score)``
    samples is passed to ``self.validation_generations_logger``.
    """
    generations_to_log = self.config.trainer.log_val_generations

    if generations_to_log == 0:
        return

    if "wandb" in self.config.trainer.logger:
        import wandb

        outputs = [wandb.Image(image.float(), file_type="jpg") for image in outputs]

    # Create tuples of (input, output, score) and sort by input text
    samples = list(zip(inputs, outputs, scores, strict=True))
    samples.sort(key=lambda x: x[0])

    # Use fixed random seed for deterministic shuffling
    rng = np.random.RandomState(42)
    rng.shuffle(samples)

    # Take first N samples after shuffling
    samples = samples[:generations_to_log]

    # Log to each configured logger
    self.validation_generations_logger.log(self.config.trainer.logger, samples, self.global_steps)


def _get_gen_batch(self, batch: DataProto) -> DataProto:
    """Split the generation inputs off ``batch``.

    Every non-tensor key except the reward-related ones is popped from
    ``batch`` into the returned batch. The reward-related keys that remain in
    ``batch`` are then copied into the returned batch as well.
    """
    reward_keys = {"data_source", "reward_model", "extra_info", "uid"} & batch.non_tensor_batch.keys()

    # pop those keys for generation
    batch_keys_to_pop = []
    non_tensor_batch_keys_to_pop = set(batch.non_tensor_batch.keys()) - reward_keys
    gen_batch = batch.pop(
        batch_keys=batch_keys_to_pop,
        non_tensor_batch_keys=list(non_tensor_batch_keys_to_pop),
    )

    # For agent loop, we need reward model keys to compute score.
    gen_batch.non_tensor_batch.update(batch.non_tensor_batch)

    return gen_batch


def _compute_reward_colocate(self, batch: DataProto) -> tuple[torch.Tensor, dict[str, Any]] | torch.Tensor:
    """Compute rewards for ``batch`` with the colocated reward model.

    Raises:
        AssertionError: If ``self.reward_loop_manager`` is ``None``.
    """
    # NOTE(review): ``_validate`` passes this result to ``DataProto.union``, so
    # the return annotation above looks inaccurate — confirm what
    # ``compute_rm_score`` actually returns.
    assert self.reward_loop_manager is not None, "RewardLoopManager is None"
    batch_reward = self.reward_loop_manager.compute_rm_score(batch)
    return batch_reward


def _validate(self):
    """Run generation and reward scoring over the validation dataloader.

    Returns:
        The metric dictionary built by ``_val_metrics_update``.
    """
    data_source_lst = []
    reward_extra_infos_dict: dict[str, list] = defaultdict(list)

    # Lists to collect samples for the table
    sample_inputs = []
    sample_outputs = []
    sample_gts = []
    sample_scores = []
    sample_turns = []
    sample_uids = []

    for test_data in self.val_dataloader:
        test_batch = DataProto.from_single_dict(test_data)

        if "uid" not in test_batch.non_tensor_batch:
            test_batch.non_tensor_batch["uid"] = np.array(
                [str(uuid.uuid4()) for _ in range(len(test_batch.batch))], dtype=object
            )

        # repeat test batch
        test_batch = test_batch.repeat(
            repeat_times=self.config.actor_rollout_ref.rollout.val_kwargs.n, interleave=True
        )

        ground_truths = [
            item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in test_batch
        ]
        sample_gts.extend(ground_truths)

        test_gen_batch = self._get_gen_batch(test_batch)
        test_gen_batch.meta_info = {
            "recompute_log_prob": False,
            "validate": True,
            "global_steps": self.global_steps,
        }
        print(f"test_gen_batch meta info: {test_gen_batch.meta_info}")

        # pad to be divisible by dp_size
        size_divisor = self.config.actor_rollout_ref.rollout.agent.num_workers
        test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor)
        test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded)

        if self.use_rm and "rm_scores" not in test_output_gen_batch_padded.batch.keys():
            # for colocate reward models, we need to sleep rollout model
            # to spare GPU memory for reward model
            self.checkpoint_manager.sleep_replicas()
            batch_reward = self._compute_reward_colocate(test_output_gen_batch_padded)
            test_output_gen_batch_padded = test_output_gen_batch_padded.union(batch_reward)
            # wake up rollout model
            # replace with wake_up method once supported
            self.checkpoint_manager.update_weights(self.global_steps)

        # unpad
        test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)

        print("validation generation end")

        # Store generated outputs
        output_images = test_output_gen_batch.batch["responses"]
        sample_outputs.append(output_images)

        test_batch = test_batch.union(test_output_gen_batch)
        test_batch.meta_info["validate"] = True

        # Store original inputs
        input_ids = test_batch.batch["prompts"]
        input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
        sample_inputs.extend(input_texts)
        sample_uids.extend(test_batch.non_tensor_batch["uid"])

        # evaluate using reward_function
        reward_tensor, reward_extra_info = extract_reward(test_batch)

        scores = reward_tensor.sum(-1).cpu().tolist()
        sample_scores.extend(scores)

        reward_extra_infos_dict["reward"].extend(scores)
        for key, values in reward_extra_info.items():
            # ``reward_extra_infos_dict`` is a defaultdict(list): indexing
            # creates the bucket on first use, so no explicit guard is needed.
            bucket = reward_extra_infos_dict[key]
            if isinstance(values, np.ndarray):
                bucket.extend(values.tolist())
            else:
                bucket.extend(values if isinstance(values, list) else [values])

        # collect num_turns of each prompt
        if "__num_turns__" in test_batch.non_tensor_batch:
            sample_turns.append(test_batch.non_tensor_batch["__num_turns__"])

        data_source_lst.append(
            test_batch.non_tensor_batch.get("data_source", ["unknown"] * reward_tensor.shape[0])
        )

    sample_outputs = torch.cat(sample_outputs, dim=0)
    self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)

    # dump generations
    val_data_dir = self.config.trainer.get("validation_data_dir", None)
    if val_data_dir:
        self._dump_generations(
            inputs=sample_inputs,
            outputs=sample_outputs,
            gts=sample_gts,
            scores=sample_scores,
            reward_extra_infos_dict=reward_extra_infos_dict,
            dump_path=val_data_dir,
        )

    # Every collected column must be empty or aligned with the scores.
    for key_info, lst in reward_extra_infos_dict.items():
        assert len(lst) == 0 or len(lst) == len(sample_scores), f"{key_info}: {len(lst)=}, {len(sample_scores)=}"

    data_sources = np.concatenate(data_source_lst, axis=0)

    return self._val_metrics_update(data_sources, sample_uids, reward_extra_infos_dict, sample_turns)


def _val_metrics_update(self, data_sources, sample_uids, reward_extra_infos_dict, sample_turns):
    """Flatten validation metrics into a ``{"<section>/<source>/<var>/<metric>": value}`` dict.

    A metric goes into the ``val-core`` section when it belongs to the core
    variable (``acc`` if present, else ``reward``), its name starts with
    ``mean``, ``maj`` or ``best``, and it carries the largest ``@N`` found for
    that variable. Everything else goes into ``val-aux``. When ``sample_turns``
    is non-empty, min/max/mean turn counts are added under ``val-aux/num_turns``.
    """
    data_src2var2metric2val = process_validation_metrics(data_sources, sample_uids, reward_extra_infos_dict)
    metric_dict = {}
    for data_source, var2metric2val in data_src2var2metric2val.items():
        core_var = "acc" if "acc" in var2metric2val else "reward"
        for var_name, metric2val in var2metric2val.items():
            # Metric names look like "<stat>@<N>/<suffix>"; pick the largest N.
            n_max = max(int(name.split("@")[-1].split("/")[0]) for name in metric2val)
            for metric_name, metric_val in metric2val.items():
                is_core = (
                    var_name == core_var
                    and metric_name.startswith(("mean", "maj", "best"))
                    and f"@{n_max}" in metric_name
                )
                metric_sec = "val-core" if is_core else "val-aux"
                pfx = f"{metric_sec}/{data_source}/{var_name}/{metric_name}"
                metric_dict[pfx] = metric_val

    if len(sample_turns) > 0:
        sample_turns = np.concatenate(sample_turns)
        metric_dict["val-aux/num_turns/min"] = sample_turns.min()
        metric_dict["val-aux/num_turns/max"] = sample_turns.max()
        metric_dict["val-aux/num_turns/mean"] = sample_turns.mean()

    return metric_dict
def init_workers(self):
    """Initialize distributed training workers using Ray backend.

    The colocated worker groups are always created. The online rollout stack
    is set up on top of their resource pool unless
    ``algorithm.sample_source`` is ``"offline"``.
    """
    colocated_pool = self._init_colocated_workers()

    # The offline sample source does not need a rollout stack.
    if self.config.algorithm.sample_source != "offline":
        self._init_online_rollout_stack(colocated_pool)
def _init_colocated_workers(self):
    """Create Ray pools and colocated actor/ref worker groups (online and offline).

    Sets ``self.resource_pool_to_cls``, ``self.actor_rollout_wg`` and, when a
    reference policy is used, ``self.ref_policy_wg``.

    Returns:
        The resource pool obtained for the selected actor role.

    Raises:
        NotImplementedError: if ``self.hybrid_engine`` is false and the selected
            actor role is not ``Role.Actor``.
    """
    self.resource_pool_manager.create_resource_pool()

    # One (initially empty) role-name -> worker-class mapping per resource pool.
    self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}

    # create actor and rollout (offline uses Role.Actor only; online uses hybrid actor_rollout roles)
    if Role.Actor in self.role_worker_mapping:
        actor_role = Role.Actor
    elif Role.ActorRolloutRef in self.role_worker_mapping:
        actor_role = Role.ActorRolloutRef
    else:
        actor_role = Role.ActorRollout

    if self.hybrid_engine or actor_role == Role.Actor:
        actor_rollout_resource_pool = self.resource_pool_manager.get_resource_pool(actor_role)
        actor_rollout_cls = RayClassWithInitArgs(
            cls=self.role_worker_mapping[actor_role],
            config=self.config.actor_rollout_ref,
            role=str(actor_role),
        )
        self.resource_pool_to_cls[actor_rollout_resource_pool][str(actor_role)] = actor_rollout_cls
    else:
        raise NotImplementedError

    # create reference policy if needed
    if self.use_reference_policy and Role.RefPolicy in self.role_worker_mapping:
        resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
        ref_policy_cls = RayClassWithInitArgs(
            self.role_worker_mapping[Role.RefPolicy],
            config=self.config.actor_rollout_ref,
            role=str(Role.RefPolicy),
        )
        self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls

    # initialize WorkerGroup
    # NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
    # you should not use `create_colocated_worker_cls`.
    # Instead, directly pass different resource pool to different worker groups.
    # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
    all_wg = {}
    wg_kwargs = {}  # Setting up kwargs for RayWorkerGroup
    if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None:
        wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout
    # Forward profiling steps and (when nsys is selected) per-worker Nsight options to the
    # Ray worker group so that workers can be launched under nsys with the right capture range.
    if OmegaConf.select(self.config, "global_profiler.steps") is not None:
        wg_kwargs["profile_steps"] = OmegaConf.select(self.config.global_profiler, "steps")
        if OmegaConf.select(self.config.global_profiler, "tool") == "nsys":
            worker_nsight_options = OmegaConf.select(
                self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options"
            )
            assert worker_nsight_options is not None, (
                "global_profiler.global_tool_config.nsys.worker_nsight_options must be set "
                "when using nsys with global_profiler.steps"
            )
            wg_kwargs["worker_nsight_options"] = OmegaConf.to_container(worker_nsight_options)
    wg_kwargs["device_name"] = self.device_name

    for resource_pool, class_dict in self.resource_pool_to_cls.items():
        # Pools that received no role above have nothing to spawn.
        if not class_dict:
            continue
        worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
        wg_dict = self.ray_worker_group_cls(
            resource_pool=resource_pool,
            ray_cls_with_init=worker_dict_cls,
            **wg_kwargs,
        )
        spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
        all_wg.update(spawn_wg)

    if self.use_reference_policy and not self.ref_in_actor:
        if str(Role.RefPolicy) in all_wg:
            self.ref_policy_wg = all_wg[str(Role.RefPolicy)]
            self.ref_policy_wg.init_model()
        else:
            # Model engine: ActorRolloutRefWorker
            assert str(Role.ActorRolloutRef) in all_wg, f"{all_wg.keys()=}"
            self.ref_policy_wg = all_wg[str(Role.ActorRolloutRef)]

    # we should create rollout at the end so that vllm can have a better estimation of kv cache memory
    self.actor_rollout_wg = all_wg[str(actor_role)]
    self.actor_rollout_wg.init_model()

    # When the reference lives inside the actor, both handles point at the same worker group.
    if self.ref_in_actor:
        self.ref_policy_wg = self.actor_rollout_wg

    return actor_rollout_resource_pool


def _init_online_rollout_stack(self, actor_rollout_resource_pool):
    """Initialize rollout, reward, and checkpoint engines (online sampling only).

    Sets ``self.reward_loop_manager``, ``self.async_rollout_mode``,
    ``self.llm_server_manager``, ``self.async_rollout_manager`` and
    ``self.checkpoint_manager``, then puts the rollout replicas to sleep.

    Args:
        actor_rollout_resource_pool: resource pool forwarded to
            ``LLMServerManager.create`` as ``rollout_resource_pool``.
    """
    # create reward loop manager
    from verl.experimental.reward_loop import RewardLoopManager

    # initialize reward loop manager
    # reward model (colocate or standalone): get resource_pool
    # no reward model: resource_pool = None
    resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) if self.use_rm else None
    self.reward_loop_manager = RewardLoopManager(
        config=self.config,
        rm_resource_pool=resource_pool,
    )

    # create async rollout manager and request scheduler
    # Note: mode is always "async" since sync mode is deprecated
    self.async_rollout_mode = True

    # Support custom AgentLoopManager via config
    manager_class_fqn = self.config.actor_rollout_ref.rollout.get("agent", {}).get("agent_loop_manager_class")
    if manager_class_fqn:
        AgentLoopManager = load_class_from_fqn(manager_class_fqn, "AgentLoopManager")
    else:
        # NOTE(review): the extent of this branch was reconstructed from a flattened listing;
        # confirm that the worker-class override below belongs to the default-manager path only.
        from verl.experimental.agent_loop import AgentLoopManager

        from verl_omni.agent_loop import DiffusionAgentLoopWorker

        AgentLoopManager.agent_loop_workers_class = ray.remote(DiffusionAgentLoopWorker)

    # infrastructure overview: https://verl.readthedocs.io/en/latest/advance/reward_loop.html#architecture-design
    # agent_reward_loop: streaming reward computation with actor rollout
    # two conditions satisfied: (1) no reward model, or (2) reward model with extra resource pool
    enable_agent_reward_loop = not self.use_rm or self.config.reward.reward_model.enable_resource_pool

    # if enable_agent_reward_loop, we directly pass reward_loop_workers to agent loop manager
    # to stream reward computation with actor rollout
    reward_loop_worker_handles = self.reward_loop_manager.reward_loop_workers if enable_agent_reward_loop else None

    self.llm_server_manager = LLMServerManager.create(
        config=self.config,
        worker_group=self.actor_rollout_wg,
        rollout_resource_pool=actor_rollout_resource_pool,
    )
    self.async_rollout_manager = AgentLoopManager.create(
        config=self.config,
        llm_client=self.llm_server_manager.get_client(),
        reward_loop_worker_handles=reward_loop_worker_handles,
    )

    checkpoint_engine_config = omega_conf_to_dataclass(self.config.actor_rollout_ref.rollout.checkpoint_engine)
    self.checkpoint_manager = CheckpointEngineManager(
        config=checkpoint_engine_config,
        actor_wg=self.actor_rollout_wg,
        replicas=self.llm_server_manager.get_replicas(),
    )

    # sleep all replicas to load checkpoint
    self.checkpoint_manager.sleep_replicas()


def _save_checkpoint(self):
    """Save the actor checkpoint and dataloader state for the current global step.

    Writes under ``<trainer.default_local_dir>/global_step_<N>``: the actor goes to
    the ``actor`` sub-folder (through the worker group) and the dataloader state to
    ``data.pt``. Unless ``actor.checkpoint.async_save`` is truthy, the step number is
    then written to ``latest_checkpointed_iteration.txt`` in ``default_local_dir``.
    """
    from verl.utils.fs import local_mkdir_safe

    # path: given_path + `/global_step_{global_steps}` + `/actor`
    local_global_step_folder = os.path.join(
        self.config.trainer.default_local_dir, f"global_step_{self.global_steps}"
    )

    print(f"local_global_step_folder: {local_global_step_folder}")
    actor_local_path = os.path.join(local_global_step_folder, "actor")

    # No remote copy is requested when no HDFS directory is configured.
    actor_remote_path = (
        None
        if self.config.trainer.default_hdfs_dir is None
        else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "actor")
    )

    remove_previous_ckpt_in_save = self.config.trainer.get("remove_previous_ckpt_in_save", False)
    if remove_previous_ckpt_in_save:
        print("Warning: remove_previous_ckpt_in_save is deprecated," + " set max_actor_ckpt_to_keep=1 instead")
    # The deprecated flag is mapped onto "keep exactly one checkpoint".
    max_actor_ckpt_to_keep = (
        self.config.trainer.get("max_actor_ckpt_to_keep", None) if not remove_previous_ckpt_in_save else 1
    )

    self.actor_rollout_wg.save_checkpoint(
        actor_local_path, actor_remote_path, self.global_steps, max_ckpt_to_keep=max_actor_ckpt_to_keep
    )

    # save dataloader
    local_mkdir_safe(local_global_step_folder)
    dataloader_local_path = os.path.join(local_global_step_folder, "data.pt")
    dataloader_state_dict = self.train_dataloader.state_dict()
    torch.save(dataloader_state_dict, dataloader_local_path)

    # latest checkpointed iteration tracker (for atomic usage)
    # The flag is probed both as an attribute and as a mapping key.
    if (
        hasattr(self.config.actor_rollout_ref.actor.checkpoint, "async_save")
        and self.config.actor_rollout_ref.actor.checkpoint.async_save
    ) or (
        "async_save" in self.config.actor_rollout_ref.actor.checkpoint
        and self.config.actor_rollout_ref.actor.checkpoint["async_save"]
    ):
        print("skip write latest_checkpointed_iteration.txt when async_save is True")
        return
    local_latest_checkpointed_iteration = os.path.join(
        self.config.trainer.default_local_dir, "latest_checkpointed_iteration.txt"
    )
    with open(local_latest_checkpointed_iteration, "w") as f:
        f.write(str(self.global_steps))


def _load_checkpoint(self):
    """Restore actor weights, global step and dataloader state from a local checkpoint.

    Returns ``0`` when resuming is disabled, or when ``resume_mode`` is ``"auto"``
    and no checkpoint folder is found. On a successful load the method sets
    ``self.global_steps`` from the folder name and ends without an explicit return.

    Raises:
        NotImplementedError: if ``trainer.default_hdfs_dir`` is set.
    """
    if self.config.trainer.resume_mode == "disable":
        return 0

    # load from hdfs
    if self.config.trainer.default_hdfs_dir is not None:
        raise NotImplementedError("load from hdfs is not implemented yet")
    else:
        checkpoint_folder = self.config.trainer.default_local_dir
        if not os.path.isabs(checkpoint_folder):
            working_dir = os.getcwd()
            checkpoint_folder = os.path.join(working_dir, checkpoint_folder)
        global_step_folder = find_latest_ckpt_path(checkpoint_folder)  # None if no latest

    # find global_step_folder
    if self.config.trainer.resume_mode == "auto":
        if global_step_folder is None:
            print("Training from scratch")
            return 0
    else:
        if self.config.trainer.resume_mode == "resume_path":
            assert isinstance(self.config.trainer.resume_from_path, str), "resume ckpt must be str type"
            assert "global_step_" in self.config.trainer.resume_from_path, (
                "resume ckpt must specify the global_steps"
            )
            global_step_folder = self.config.trainer.resume_from_path
            if not os.path.isabs(global_step_folder):
                working_dir = os.getcwd()
                global_step_folder = os.path.join(working_dir, global_step_folder)
    print(f"Load from checkpoint folder: {global_step_folder}")
    # set global step
    # The step number is whatever follows the last "global_step_" in the folder path.
    self.global_steps = int(global_step_folder.split("global_step_")[-1])

    print(f"Setting global step to {self.global_steps}")
    print(f"Resuming from {global_step_folder}")

    actor_path = os.path.join(global_step_folder, "actor")
    # load actor
    self.actor_rollout_wg.load_checkpoint(
        actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load
    )

    # load dataloader
    dataloader_local_path = os.path.join(global_step_folder, "data.pt")
    if os.path.exists(dataloader_local_path):
        # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
        dataloader_state_dict = torch.load(dataloader_local_path, weights_only=False)
        self.train_dataloader.load_state_dict(dataloader_state_dict)
    else:
        print(f"Warning: No dataloader state found at {dataloader_local_path}, will start from scratch")


def _update_actor(self, batch: DataProto) -> DataProto:
    """Run one actor update on ``batch`` and return its metrics.

    Args:
        batch: training batch; its ``meta_info["multi_turn"]`` is set here.

    Returns:
        An empty-data ``DataProto`` whose ``meta_info["metrics"]`` holds the actor
        metrics renamed with ``rename_dict(..., "actor/")``; an ``actor/mfu`` entry,
        if present, is moved to ``perf/mfu/actor``.
    """
    rollout_config = self.config.actor_rollout_ref.rollout
    batch.meta_info["multi_turn"] = rollout_config.multi_turn.enable
    # update actor
    batch_td = batch.to_tensordict()
    # convert from padding to no-padding
    batch_td = embeds_padding_2_no_padding(batch_td)
    ppo_mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size
    # The configured mini-batch size is multiplied by rollout.n.
    ppo_mini_batch_size = ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n
    ppo_epochs = self.config.actor_rollout_ref.actor.ppo_epochs
    seed = self.config.actor_rollout_ref.actor.data_loader_seed
    shuffle = self.config.actor_rollout_ref.actor.shuffle
    tu.assign_non_tensor(
        batch_td,
        global_batch_size=ppo_mini_batch_size,
        mini_batch_size=ppo_mini_batch_size,
        epochs=ppo_epochs,
        seed=seed,
        dataloader_kwargs={"shuffle": shuffle},
        height=self.config.actor_rollout_ref.model.pipeline.height,
        width=self.config.actor_rollout_ref.model.pipeline.width,
        vae_scale_factor=self.config.actor_rollout_ref.model.get("vae_scale_factor", 8),
    )

    actor_output = self.actor_rollout_wg.update_actor(batch_td)
    actor_output = tu.get(actor_output, "metrics")
    actor_output = rename_dict(actor_output, "actor/")
    if (actor_mfu := actor_output.pop("actor/mfu", None)) is not None:
        actor_output["perf/mfu/actor"] = actor_mfu
    return DataProto.from_single_dict(data={}, meta_info={"metrics": actor_output})


def _start_profiling(self, do_profile: bool) -> None:
    """Start profiling for all worker groups if profiling is enabled."""
    if do_profile:
        self.actor_rollout_wg.start_profile(role="e2e", profile_step=self.global_steps)
        # A separate reference worker group is profiled as well.
        if self.use_reference_policy and not self.ref_in_actor:
            self.ref_policy_wg.start_profile(profile_step=self.global_steps)


def _stop_profiling(self, do_profile: bool) -> None:
    """Stop profiling for all worker groups if profiling is enabled."""
    if do_profile:
        self.actor_rollout_wg.stop_profile()
        if self.use_reference_policy and not self.ref_in_actor:
            self.ref_policy_wg.stop_profile()


@abstractmethod
def fit(self):
    """Run the trainer-type-specific training loop."""
    pass
class PolicyGradientRayTrainer(BaseRayDiffusionTrainer):
    """Policy-gradient diffusion trainer for FlowGRPO, MixGRPO, DanceGRPO, GRPO-Guard, etc."""

    def _compute_ref_log_prob(self, batch: DataProto) -> DataProto:
        """Compute reference-policy log-probs for ``batch``.

        Inference runs on the actor worker group (with ``no_lora_adapter=True``) when
        ``self.ref_in_actor`` is set, otherwise on the reference worker group.

        Returns:
            A ``DataProto`` with ``ref_log_prob`` and, when the worker output provides
            ``prev_sample_mean``, ``ref_prev_sample_mean`` (both cast to float).
        """
        batch_td = batch.to_tensordict()
        batch_td = embeds_padding_2_no_padding(batch_td)
        metadata = {
            "compute_loss": False,
            "height": self.config.actor_rollout_ref.model.pipeline.height,
            "width": self.config.actor_rollout_ref.model.pipeline.width,
            "vae_scale_factor": self.config.actor_rollout_ref.model.get("vae_scale_factor", 8),
        }
        if self.ref_in_actor:
            metadata["no_lora_adapter"] = True
        tu.assign_non_tensor(batch_td, **metadata)
        if self.ref_in_actor:
            output = self.actor_rollout_wg.infer_actor_batch(batch_td)
        else:
            output = self.ref_policy_wg.infer_ref_batch(batch_td)
        # gather output
        log_probs = tu.get(output, "log_probs")
        ref_log_prob_dict = {"ref_log_prob": log_probs.float()}
        # Same guard as in _compute_old_log_prob: a missing prev_sample_mean is skipped
        # instead of raising AttributeError on ``None.float()``.
        prev_sample_mean = tu.get(output, "prev_sample_mean")
        if prev_sample_mean is not None:
            ref_log_prob_dict["ref_prev_sample_mean"] = prev_sample_mean.float()
        ref_log_prob = tu.get_tensordict(ref_log_prob_dict)
        return DataProto.from_tensordict(ref_log_prob)

    def _compute_old_log_prob(self, batch: DataProto) -> tuple[DataProto, Optional[float]]:
        """Recompute log-probs of ``batch`` with the actor worker group.

        Returns:
            A pair ``(old_log_prob, mfu)``: a ``DataProto`` with ``old_log_probs`` (and
            ``old_prev_sample_mean`` when the worker output provides it), and the
            ``"mfu"`` entry of the worker metrics (``None`` if absent).
        """
        batch_td = batch.to_tensordict()
        batch_td = embeds_padding_2_no_padding(batch_td)
        tu.assign_non_tensor(
            batch_td,
            compute_loss=False,
            height=self.config.actor_rollout_ref.model.pipeline.height,
            width=self.config.actor_rollout_ref.model.pipeline.width,
            vae_scale_factor=self.config.actor_rollout_ref.model.get("vae_scale_factor", 8),
        )
        output = self.actor_rollout_wg.infer_actor_batch(batch_td)
        log_probs = tu.get(output, "log_probs")
        old_log_prob_dict = {"old_log_probs": log_probs.float()}
        prev_sample_mean = tu.get(output, "prev_sample_mean")
        if prev_sample_mean is not None:
            old_log_prob_dict["old_prev_sample_mean"] = prev_sample_mean.float()
        old_log_prob = tu.get_tensordict(old_log_prob_dict)
        old_log_prob_mfu = tu.get(output, "metrics").get("mfu")
        return DataProto.from_tensordict(old_log_prob), old_log_prob_mfu

    def fit(self):
        """
        The training loop of FlowGRPO.
        The driver process only needs to call the compute functions of the worker group through RPC
        to construct the PPO dataflow.
        The light-weight advantage computation is done on the driver process.
        """
        from omegaconf import OmegaConf
        from verl.utils.tracking import Tracking

        logger = Tracking(
            project_name=self.config.trainer.project_name,
            experiment_name=self.config.trainer.experiment_name,
            default_backend=self.config.trainer.logger,
            config=OmegaConf.to_container(self.config, resolve=True),
        )

        self.global_steps = 0

        # load checkpoint and update weights before doing anything
        self._load_checkpoint()
        self.checkpoint_manager.update_weights(self.global_steps)

        # Resume at the epoch implied by the restored step counter.
        current_epoch = self.global_steps // len(self.train_dataloader)

        # perform validation before training
        # currently, we only support validation using the reward_function.
        if self.config.trainer.get("val_before_train", True):
            val_metrics = self._validate()
            assert val_metrics, f"{val_metrics=}"
            pprint(f"Initial validation metrics: {val_metrics}")
            logger.log(data=val_metrics, step=self.global_steps)
            if self.config.trainer.get("val_only", False):
                return

        # add tqdm
        progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")

        # we start from step 1
        self.global_steps += 1
        last_val_metrics = None
        self.max_steps_duration = 0

        # Profiler step state machine. Mirrors verl/trainer/ppo/ray_trainer.py.
        prev_step_profile = False
        curr_step_profile = (
            self.global_steps in self.config.global_profiler.steps
            if self.config.global_profiler.steps is not None
            else False
        )
        next_step_profile = False

        for epoch in range(current_epoch, self.config.trainer.total_epochs):
            for batch_dict in self.train_dataloader:
                if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"):
                    self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=False)
                metrics = {}
                timing_raw = {}

                with marked_timer("start_profile", timing_raw):
                    # In continuous mode the profiler is only started on the first step of a run.
                    self._start_profiling(
                        not prev_step_profile and curr_step_profile
                        if self.config.global_profiler.profile_continuous_steps
                        else curr_step_profile
                    )
                batch: DataProto = DataProto.from_single_dict(batch_dict)

                # add uid to batch
                batch.non_tensor_batch["uid"] = np.array(
                    [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object
                )

                gen_batch = self._get_gen_batch(batch)

                # Pass step metadata to rollout before expansion.
                gen_batch.meta_info["global_steps"] = self.global_steps
                # Per-step rollout seed for reproducibility
                rollout_seed_cfg = self.config.actor_rollout_ref.rollout.get("seed")
                if rollout_seed_cfg is not None:
                    gen_batch.meta_info["rollout_seed"] = int(rollout_seed_cfg) + self.global_steps - 1
                gen_batch_output = gen_batch.repeat(
                    repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
                )
                gen_batch_output.non_tensor_batch["_rollout_seed_global_idx"] = np.arange(
                    len(gen_batch_output), dtype=np.int64
                )

                is_last_step = self.global_steps >= self.total_training_steps
                # NOTE(review): the extent of the "step" timer was reconstructed from a
                # flattened listing; confirm which stages it is meant to cover.
                with marked_timer("step", timing_raw):
                    # generate a batch
                    with marked_timer("gen", timing_raw, color="red"):
                        gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)
                        self.checkpoint_manager.sleep_replicas()

                        timing_raw.update(gen_batch_output.meta_info["timing"])
                        gen_batch_output.meta_info.pop("timing", None)

                    # repeat to align with repeated responses in rollout
                    batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
                    batch = batch.union(gen_batch_output)

                    with marked_timer("reward", timing_raw, color="yellow"):
                        # compute reward model score
                        if self.use_rm and "rm_scores" not in batch.batch.keys():
                            batch_reward = self._compute_reward_colocate(batch)
                            batch = batch.union(batch_reward)

                        # extract reward_tensor and reward_extra_infos_dict for training
                        reward_tensor, reward_extra_infos_dict = extract_reward(batch)

                    # Bypass mode: skip old_log_prob recompute (2 policies).
                    # Decoupled mode: recompute old_log_probs as proximal anchor (3 policies).
                    rollout_corr_config = self.config.algorithm.get("rollout_correction", None)
                    bypass_recomputing_logprobs = rollout_corr_config and rollout_corr_config.get("bypass_mode", False)
                    if bypass_recomputing_logprobs:  # Use `rollout_log_probs`
                        apply_bypass_mode_to_diffusion_batch(batch)
                    else:  # Recompute old_log_probs
                        with marked_timer("old_log_prob", timing_raw, color="blue"):
                            old_log_prob, old_log_prob_mfu = self._compute_old_log_prob(batch)
                            if old_log_prob_mfu is not None:
                                metrics.update({"perf/mfu/actor_infer": old_log_prob_mfu})
                            batch = batch.union(old_log_prob)

                    # The message names the key that is actually checked.
                    assert "old_log_probs" in batch.batch, f'"old_log_probs" not in {batch.batch.keys()=}'

                    # Decoupled-mode rollout correction (old vs rollout).
                    # In bypass mode old == rollout, so correction runs per-step in ``diffusion_loss``.
                    if not bypass_recomputing_logprobs and rollout_correction_enabled(rollout_corr_config):
                        with marked_timer("rollout_corr", timing_raw, color="cyan"):
                            batch, rollout_corr_metrics = apply_rollout_correction_to_diffusion_batch(
                                batch, rollout_corr_config
                            )
                        metrics.update(rollout_corr_metrics)

                    if self.use_reference_policy:
                        # compute reference log_prob
                        with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"):
                            ref_log_prob = self._compute_ref_log_prob(batch)
                            batch = batch.union(ref_log_prob)

                    with marked_timer("adv", timing_raw, color="brown"):
                        # we combine with rule-based rm
                        reward_extra_infos_dict: dict[str, list]
                        batch.batch["sample_level_scores"] = reward_tensor

                        if reward_extra_infos_dict:
                            batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})

                        # Expand the scores along dim 1 to the size of old_log_probs' dim 1.
                        num_timesteps = batch.batch["old_log_probs"].shape[1]
                        batch.batch["sample_level_rewards"] = batch.batch["sample_level_scores"].expand(
                            -1, num_timesteps
                        )

                        # compute advantages, executed on the driver process
                        norm_adv_by_std_in_grpo = self.config.algorithm.get(
                            "norm_adv_by_std_in_grpo", True
                        )  # GRPO adv normalization factor

                        batch = compute_advantage(
                            batch,
                            adv_estimator=self.config.algorithm.adv_estimator,
                            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
                            global_std=self.config.algorithm.global_std,
                            config=self.config.algorithm,
                        )

                    # update actor
                    with marked_timer("update_actor", timing_raw, color="red"):
                        actor_output = self._update_actor(batch)

                    # Check if the ESI (Elastic Server Instance)/training plan is close to expiration.
                    esi_close_to_expiration = should_save_ckpt_esi(
                        max_steps_duration=self.max_steps_duration,
                        redundant_time=self.config.trainer.esi_redundant_time,
                    )
                    # Check if the conditions for saving a checkpoint are met.
                    # The conditions include a mandatory condition (1) and
                    # one of the following optional conditions (2/3/4):
                    # 1. The save frequency is set to a positive value.
                    # 2. It's the last training step.
                    # 3. The current step number is a multiple of the save frequency.
                    # 4. The ESI(Elastic Server Instance)/training plan is close to expiration.
                    if self.config.trainer.save_freq > 0 and (
                        is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration
                    ):
                        if esi_close_to_expiration:
                            print("Force saving checkpoint: ESI instance expiration approaching.")
                        with marked_timer("save_checkpoint", timing_raw, color="green"):
                            self._save_checkpoint()

                    # update weights from trainer to rollout
                    with marked_timer("update_weights", timing_raw, color="red"):
                        self.checkpoint_manager.update_weights(self.global_steps)

                    actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
                    metrics.update(actor_output_metrics)

                    # Log rollout generations if enabled
                    rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
                    if rollout_data_dir:
                        self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir)

                # validate
                if self.config.trainer.test_freq > 0 and (
                    is_last_step or self.global_steps % self.config.trainer.test_freq == 0
                ):
                    with marked_timer("testing", timing_raw, color="green"):
                        val_metrics: dict = self._validate()
                        if is_last_step:
                            last_val_metrics = val_metrics
                    metrics.update(val_metrics)

                with marked_timer("stop_profile", timing_raw):
                    next_step_profile = (
                        self.global_steps + 1 in self.config.global_profiler.steps
                        if self.config.global_profiler.steps is not None
                        else False
                    )
                    # In continuous mode the profiler is only stopped when the next step is not profiled.
                    self._stop_profiling(
                        curr_step_profile and not next_step_profile
                        if self.config.global_profiler.profile_continuous_steps
                        else curr_step_profile
                    )
                    prev_step_profile = curr_step_profile
                    curr_step_profile = next_step_profile

                steps_duration = timing_raw["step"]
                self.max_steps_duration = max(self.max_steps_duration, steps_duration)

                # training metrics
                metrics.update(
                    {
                        "training/global_step": self.global_steps,
                        "training/epoch": epoch,
                    }
                )
                # collect metrics
                metrics.update(compute_data_metrics_diffusion(batch=batch))
                n_gpus = self.resource_pool_manager.get_n_gpus()
                num_images = batch.batch["advantages"].shape[0]
                metrics.update(compute_timing_metrics_diffusion(timing_raw=timing_raw, num_images=num_images))
                metrics.update(compute_throughput_metrics_diffusion(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
                metrics.update(compute_reward_extra_metrics_diffusion(reward_extra_infos_dict))
                # compute variance proxy metrics
                gradient_norm = metrics.get("actor/grad_norm", None)
                metrics.update(compute_variance_proxy_metrics(batch=batch, gradient_norm=gradient_norm))

                logger.log(data=metrics, step=self.global_steps)

                progress_bar.update(1)
                self.global_steps += 1

                if is_last_step:
                    if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"):
                        self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=True)
                    pprint(f"Final validation metrics: {last_val_metrics}")
                    progress_bar.close()
                    return

                # this is experimental and may be changed/removed in the future
                # in favor of a general-purpose data buffer pool
                if hasattr(self.train_dataset, "on_batch_end"):
                    # The dataset may be changed after each training batch
                    self.train_dataset.on_batch_end(batch=batch)
class DirectPreferenceRayTrainer(BaseRayDiffusionTrainer):
    """Direct-preference diffusion trainer for DPO, DiffusionNFT, AWM, etc."""

    def __init__(
        self,
        config,
        *args,
        **kwargs,
    ):
        """Set up offline/online mode, reference-policy usage and the loss function.

        Args:
            config: trainer configuration; forwarded to the base class together
                with ``*args`` and ``**kwargs``.

        Raises:
            ValueError: if an ``"old"`` policy-state adapter is configured together
                with incompatible rollout/loss settings.
        """
        super().__init__(config, *args, **kwargs)
        self.is_offline = config.algorithm.get("sample_source", "online") == "offline"
        loss_mode = config.actor_rollout_ref.actor.diffusion_loss.loss_mode
        # DPO needs trainer-side ref noise preds; DiffusionNFT computes ref in the actor engine.
        self.use_reference_policy = need_reference_policy(self.config) or (loss_mode == "dpo")
        self._has_old_adapter = "old" in tuple(
            config.actor_rollout_ref.model.get("policy_state_adapters", ("default",))
        )
        if self._has_old_adapter:
            self._validate_old_adapter_config()
        self._loss_fn = get_diffusion_loss_fn(loss_mode)

    def _validate_old_adapter_config(self):
        """Check the settings required when an ``"old"`` adapter is configured.

        Raises:
            ValueError: if ``rollout.rollout_adapter`` is not ``"old"`` or
                ``actor.diffusion_loss.loss_mode`` is not ``"diffusion_nft"``.
        """
        rollout_cfg = self.config.actor_rollout_ref.rollout
        actor_loss_cfg = self.config.actor_rollout_ref.actor.diffusion_loss
        if rollout_cfg.rollout_adapter != "old":
            raise ValueError("Old-adapter algorithms require actor_rollout_ref.rollout.rollout_adapter=old.")
        if actor_loss_cfg.loss_mode != "diffusion_nft":
            raise ValueError(
                "Old-adapter algorithms require actor_rollout_ref.actor.diffusion_loss.loss_mode=diffusion_nft."
            )

    def init_workers(self):
        """Initialize actor-only workers for offline, or full stack for online preference training."""
        actor_rollout_resource_pool = self._init_colocated_workers()
        if self.is_offline:
            # Offline mode has no rollout stack: the managers are set to None and a
            # NoOpCheckpointManager stands in for the checkpoint manager.
            self.reward_loop_manager = None
            self.llm_server_manager = None
            self.checkpoint_manager = NoOpCheckpointManager()
            return
        self._init_online_rollout_stack(actor_rollout_resource_pool)

    def _validate(self):
        """Run validation, or return a placeholder metric when offline without a rollout manager."""
        if self.is_offline and not hasattr(self, "async_rollout_manager"):
            print("Skipping validation generation because offline rollout is disabled.")
            return {"val/offline/skipped": 1.0}
        return super()._validate()

    def _update_actor(self, batch: DataProto) -> DataProto:
        """Run one actor update on ``batch`` and return its metrics.

        The mini-batch size is multiplied by 2 when ``algorithm.paired_preference``
        is set, otherwise by ``rollout.n``. Shuffling is turned off for paired data.

        Returns:
            An empty-data ``DataProto`` whose ``meta_info["metrics"]`` holds the actor
            metrics renamed with ``rename_dict(..., "actor/")``; an ``actor/mfu`` entry,
            if present, is moved to ``perf/mfu/actor``.
        """
        rollout_config = self.config.actor_rollout_ref.rollout
        batch.meta_info["multi_turn"] = rollout_config.multi_turn.enable
        batch_td = batch.to_tensordict()
        batch_td = embeds_padding_2_no_padding(batch_td)
        ppo_mini_batch_size = self.config.actor_rollout_ref.actor.ppo_mini_batch_size
        paired = self.config.algorithm.get("paired_preference", False)
        ppo_mini_batch_size = (
            ppo_mini_batch_size * 2 if paired else ppo_mini_batch_size * self.config.actor_rollout_ref.rollout.n
        )  # direct preference has a pair per prompt
        ppo_epochs = self.config.actor_rollout_ref.actor.ppo_epochs
        seed = self.config.actor_rollout_ref.actor.data_loader_seed
        shuffle = self.config.actor_rollout_ref.actor.shuffle
        if paired and shuffle:
            # Adjacent literals are concatenated verbatim, so each sentence carries its
            # own trailing space to keep the logged message readable.
            sys_logger.warning(
                "Shuffle is not supported for direct preference during actor update. "
                "This is to prevent the chosen/rejected pairs from being split across different micro batches. "
                "Setting shuffle to False."
            )
            shuffle = False
        tu.assign_non_tensor(
            batch_td,
            global_batch_size=ppo_mini_batch_size,
            mini_batch_size=ppo_mini_batch_size,
            epochs=ppo_epochs,
            seed=seed,
            dataloader_kwargs={"shuffle": shuffle},
            height=self.config.actor_rollout_ref.model.pipeline.height,
            width=self.config.actor_rollout_ref.model.pipeline.width,
            vae_scale_factor=self.config.actor_rollout_ref.model.get("vae_scale_factor", 8),
        )

        actor_output = self.actor_rollout_wg.update_actor(batch_td)
        actor_output = tu.get(actor_output, "metrics")
        actor_output = rename_dict(actor_output, "actor/")
        if (actor_mfu := actor_output.pop("actor/mfu", None)) is not None:
            actor_output["perf/mfu/actor"] = actor_mfu
        return DataProto.from_single_dict(data={}, meta_info={"metrics": actor_output})

    @staticmethod
    def _drop_singleton_dim1(tensor: torch.Tensor) -> torch.Tensor:
        """Return ``tensor[:, 0]`` when dim 1 has size 1, otherwise ``tensor`` unchanged."""
        if tensor.ndim >= 2 and tensor.shape[1] == 1:
            return tensor[:, 0]
        return tensor

    def _compute_ref_noise_pred(self, batch: DataProto) -> Optional[DataProto]:
        """Reference transformer output and shared flow tensors.

        Inference runs on the actor worker group (with ``no_lora_adapter=True``) when
        ``self.ref_in_actor`` is set, otherwise on the reference worker group.

        Returns:
            ``None`` when the worker returns no output; otherwise a ``DataProto`` with
            ``ref_noise_pred``, ``noise`` and ``timesteps`` cast to float, each with a
            size-1 dim 1 removed.
        """
        batch_td = batch.to_tensordict()
        batch_td = embeds_padding_2_no_padding(batch_td)
        metadata = {
            "compute_loss": False,
            "height": self.config.actor_rollout_ref.model.pipeline.height,
            "width": self.config.actor_rollout_ref.model.pipeline.width,
            "vae_scale_factor": self.config.actor_rollout_ref.model.get("vae_scale_factor", 8),
        }
        if self.ref_in_actor:
            metadata["no_lora_adapter"] = True
        tu.assign_non_tensor(batch_td, **metadata)
        if self.ref_in_actor:
            output = self.actor_rollout_wg.infer_actor_batch(batch_td)
        else:
            output = self.ref_policy_wg.infer_ref_batch(batch_td)
        if output is None:
            return None

        noise_pred = self._drop_singleton_dim1(tu.get(output, "noise_pred"))
        noise = self._drop_singleton_dim1(tu.get(output, "noise"))
        timesteps = self._drop_singleton_dim1(tu.get(output, "timesteps"))

        ref_output = {
            "ref_noise_pred": noise_pred.float(),
            "noise": noise.float(),
            "timesteps": timesteps.float(),
        }
        return DataProto.from_tensordict(tu.get_tensordict(ref_output))

    def _prepare_actor_batch(self, batch: DataProto, reward_tensor: torch.Tensor) -> DataProto:
        """Delegate algorithm-specific rollout-to-actor batch preparation.

        The reward tensor is cast to float; when it has more than one dimension,
        ``squeeze(-1)`` is applied first.
        """
        reward_tensor = reward_tensor.squeeze(-1).float() if reward_tensor.ndim > 1 else reward_tensor.float()
        return self._loss_fn.prepare_actor_batch(batch, reward_tensor, self.config)

    def _update_old_policy(self) -> tuple[bool, float, Literal["none", "copy", "ema"]]:
        """Refresh the ``"old"`` adapter from ``"default"`` on scheduled steps.

        Returns:
            ``(updated, decay, mode)``. ``(False, 0.0, "none")`` when the current step
            is not a multiple of ``algorithm.old_policy_update_interval``; otherwise
            ``mode`` is ``"copy"`` for a decay of 0 and ``"ema"`` for any other decay.
        """
        algo_cfg = self.config.algorithm
        if self.global_steps % algo_cfg.old_policy_update_interval != 0:
            return False, 0.0, "none"
        decay = algo_cfg.old_policy_decay
        # Without an explicit decay, derive one from the configured schedule.
        if decay is None:
            decay = old_policy_decay(self.global_steps, algo_cfg.old_policy_decay_schedule)
        if decay == 0:
            self.actor_rollout_wg.copy_adapter(source="default", target="old")
            return True, float(decay), "copy"
        else:
            self.actor_rollout_wg.ema_update_adapter(source="default", target="old", decay=decay)
            return True, float(decay), "ema"
[docs] def fit(self): """ Training loop for direct-preference algorithms (DPO, DiffusionNFT, etc.). Offline algorithms read pre-computed rewards from the dataset. Online algorithms generate rollouts and compute rewards live. """ from omegaconf import OmegaConf from verl.utils.tracking import Tracking logger = Tracking( project_name=self.config.trainer.project_name, experiment_name=self.config.trainer.experiment_name, default_backend=self.config.trainer.logger, config=OmegaConf.to_container(self.config, resolve=True), ) self.global_steps = 0 # load checkpoint and update weights before doing anything self._load_checkpoint() if self._has_old_adapter: self.actor_rollout_wg.copy_adapter(source="default", target="old") self.checkpoint_manager.update_weights(self.global_steps) current_epoch = self.global_steps // len(self.train_dataloader) # perform validation before training # currently, we only support validation using the reward_function. if self.config.trainer.get("val_before_train", True): val_metrics = self._validate() assert val_metrics, f"{val_metrics=}" pprint(f"Initial validation metrics: {val_metrics}") logger.log(data=val_metrics, step=self.global_steps) if self.config.trainer.get("val_only", False): return # add tqdm progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") # we start from step 1 self.global_steps += 1 last_val_metrics = None self.max_steps_duration = 0 # Profiler step state machine. Mirrors verl/trainer/ppo/ray_trainer.py. 
prev_step_profile = False curr_step_profile = ( self.global_steps in self.config.global_profiler.steps if self.config.global_profiler.steps is not None else False ) next_step_profile = False for epoch in range(current_epoch, self.config.trainer.total_epochs): for batch_dict in self.train_dataloader: if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"): self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=False) metrics = {} timing_raw = {} with marked_timer("start_profile", timing_raw): self._start_profiling( not prev_step_profile and curr_step_profile if self.config.global_profiler.profile_continuous_steps else curr_step_profile ) batch: DataProto = DataProto.from_single_dict(batch_dict) if "uid" not in batch.non_tensor_batch: batch.non_tensor_batch["uid"] = np.array( [str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object ) is_last_step = self.global_steps >= self.total_training_steps with marked_timer("step", timing_raw): reward_extra_infos_dict: dict[str, list] = {} if self.is_offline: reward_tensor = batch.batch["sample_level_scores"] with marked_timer("adv", timing_raw, color="brown"): batch.batch["sample_level_scores"] = reward_tensor if reward_extra_infos_dict: batch.non_tensor_batch.update( {k: np.array(v) for k, v in reward_extra_infos_dict.items()} ) batch.batch["sample_level_rewards"] = batch.batch["sample_level_scores"] if self.use_reference_policy: with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"): ref_infer_res = self._compute_ref_noise_pred(batch) if ref_infer_res is not None: batch = batch.union(ref_infer_res) with marked_timer("update_actor", timing_raw, color="red"): actor_output = self._update_actor(batch) else: gen_batch = self._get_gen_batch(batch) gen_batch.meta_info["global_steps"] = self.global_steps rollout_seed_cfg = self.config.actor_rollout_ref.rollout.get("seed") if rollout_seed_cfg is not None: gen_batch.meta_info["rollout_seed"] = int(rollout_seed_cfg) + self.global_steps - 1 
gen_batch_output = gen_batch.repeat( repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True, ) gen_batch_output.non_tensor_batch["_rollout_seed_global_idx"] = np.arange( len(gen_batch_output), dtype=np.int64 ) with marked_timer("gen", timing_raw, color="red"): gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output) self.checkpoint_manager.sleep_replicas() timing_raw.update(gen_batch_output.meta_info["timing"]) gen_batch_output.meta_info.pop("timing", None) batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) batch = batch.union(gen_batch_output) with marked_timer("reward", timing_raw, color="yellow"): if self.use_rm and "rm_scores" not in batch.batch.keys(): batch_reward = self._compute_reward_colocate(batch) batch = batch.union(batch_reward) reward_tensor, reward_extra_infos_dict = extract_reward(batch) with marked_timer("prepare_actor_batch", timing_raw, color="brown"): batch.batch["sample_level_scores"] = reward_tensor if reward_extra_infos_dict: batch.non_tensor_batch.update( {k: np.array(v) for k, v in reward_extra_infos_dict.items()} ) batch = self._prepare_actor_batch(batch, reward_tensor) batch.batch["sample_level_rewards"] = batch.batch["sample_level_scores"] if self.use_reference_policy: with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"): ref_infer_res = self._compute_ref_noise_pred(batch) if ref_infer_res is not None: batch = batch.union(ref_infer_res) with marked_timer("update_actor", timing_raw, color="red"): actor_output = self._update_actor(batch) if self._has_old_adapter: metrics.update(compute_old_policy_metrics(self._update_old_policy())) # Check if the ESI (Elastic Server Instance)/training plan is close to expiration. esi_close_to_expiration = should_save_ckpt_esi( max_steps_duration=self.max_steps_duration, redundant_time=self.config.trainer.esi_redundant_time, ) # Check if the conditions for saving a checkpoint are met. 
# The conditions include a mandatory condition (1) and # one of the following optional conditions (2/3/4): # 1. The save frequency is set to a positive value. # 2. It's the last training step. # 3. The current step number is a multiple of the save frequency. # 4. The ESI(Elastic Server Instance)/training plan is close to expiration. if self.config.trainer.save_freq > 0 and ( is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration ): if esi_close_to_expiration: print("Force saving checkpoint: ESI instance expiration approaching.") with marked_timer("save_checkpoint", timing_raw, color="green"): self._save_checkpoint() # update weights from trainer to rollout with marked_timer("update_weights", timing_raw, color="red"): self.checkpoint_manager.update_weights(self.global_steps) actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) metrics.update(actor_output_metrics) # Log rollout generations if enabled rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) if rollout_data_dir and not self.is_offline: self._log_rollout_data(batch, reward_extra_infos_dict, timing_raw, rollout_data_dir) # validate if self.config.trainer.test_freq > 0 and ( is_last_step or self.global_steps % self.config.trainer.test_freq == 0 ): with marked_timer("testing", timing_raw, color="green"): val_metrics: dict = self._validate() if is_last_step: last_val_metrics = val_metrics metrics.update(val_metrics) with marked_timer("stop_profile", timing_raw): next_step_profile = ( self.global_steps + 1 in self.config.global_profiler.steps if self.config.global_profiler.steps is not None else False ) self._stop_profiling( curr_step_profile and not next_step_profile if self.config.global_profiler.profile_continuous_steps else curr_step_profile ) prev_step_profile = curr_step_profile curr_step_profile = next_step_profile steps_duration = timing_raw["step"] self.max_steps_duration = max(self.max_steps_duration, steps_duration) # 
training metrics metrics.update( { "training/global_step": self.global_steps, "training/epoch": epoch, } ) # collect metrics metrics.update(compute_data_metrics_diffusion(batch=batch)) n_gpus = self.resource_pool_manager.get_n_gpus() num_images = ( batch.batch["advantages"].shape[0] if "advantages" in batch.batch else batch.batch["sample_level_scores"].shape[0] ) metrics.update(compute_timing_metrics_diffusion(timing_raw=timing_raw, num_images=num_images)) metrics.update(compute_throughput_metrics_diffusion(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) if "advantages" in batch.batch: gradient_norm = metrics.get("actor/grad_norm", None) metrics.update(compute_variance_proxy_metrics(batch=batch, gradient_norm=gradient_norm)) logger.log(data=metrics, step=self.global_steps) progress_bar.update(1) self.global_steps += 1 if is_last_step: if hasattr(self.actor_rollout_wg, "async_calls_finalize_fn_exec"): self.actor_rollout_wg.async_calls_finalize_fn_exec(blocking=True) pprint(f"Final validation metrics: {last_val_metrics}") progress_bar.close() return # this is experimental and may be changed/removed in the future # in favor of a general-purpose data buffer pool if hasattr(self.train_dataset, "on_batch_end"): # The dataset may be changed after each training batch self.train_dataset.on_batch_end(batch=batch)