Source code for verl_omni.workers.engine.fsdp.diffusers_impl

# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FSDP engines for diffusion models."""

import gc
import json
import logging
import os
import warnings
from abc import ABC, abstractmethod
from contextlib import nullcontext
from typing import Callable, Optional

import torch
import torch.distributed
from tensordict import TensorDict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType
from torch.distributed.tensor import DTensor
from verl.trainer.config import CheckpointConfig
from verl.utils import tensordict_utils as tu
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.device import get_device_id, get_device_name
from verl.utils.fsdp_utils import (
    CPUOffloadPolicy,
    FSDPModule,
    MixedPrecisionPolicy,
    apply_fsdp2,
    fsdp2_clip_grad_norm_,
    fsdp2_load_full_state_dict,
    fsdp_version,
    get_fsdp_wrap_policy,
    get_init_weight_context_manager,
    init_fn,
    load_fsdp_model_to_gpu,
    load_fsdp_optimizer,
    offload_fsdp_model_to_cpu,
    offload_fsdp_optimizer,
    replace_lora_wrapper,
)
from verl.utils.memory_utils import aggressive_empty_cache
from verl.utils.model import convert_weight_keys
from verl.utils.py_functional import append_to_dict
from verl.workers.config import FSDPEngineConfig, FSDPOptimizerConfig
from verl.workers.engine.base import BaseEngine, BaseEngineCtx, EngineRegistry
from verl.workers.engine.fsdp.utils import create_device_mesh, get_sharding_strategy
from verl.workers.engine.utils import enable_full_determinism, prepare_micro_batches

from verl_omni.pipelines.utils import (
    build_scheduler,
    forward,
    forward_and_sample_previous_step,
    prepare_model_inputs,
    prepare_noisy_latents,
)
from verl_omni.utils.fsdp_utils import collect_lora_params
from verl_omni.workers.config import DiffusionModelConfig
from verl_omni.workers.engine.lora_adapter_mixin import LoRAAdapterMixin

# Module-level logger; its level comes from the VERL_LOGGING_LEVEL environment variable (default "WARN").
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))

# Accelerator device name as reported by verl's get_device_name(), resolved once at import time.
device_name = get_device_name()


class DiffusersFSDPEngine(LoRAAdapterMixin, BaseEngine, ABC):
    """Base Diffusers engine using PyTorch FullyShardedDataParallel (FSDP).

    Supports model sharding, activation/optimizer offloading, LoRA, and sequence parallelism.
    Algorithm-specific subclasses implement ``forward_backward_batch``, ``prepare_model_inputs``,
    ``prepare_model_outputs`` and ``forward_step``.
    """

    def __init__(
        self,
        model_config: DiffusionModelConfig,
        engine_config: FSDPEngineConfig,
        optimizer_config: FSDPOptimizerConfig,
        checkpoint_config: CheckpointConfig,
    ):
        """
        Initialize the DiffusersFSDPEngine.

        Sets up distributed device meshes and records the LoRA and offload flags from config.
        The model itself is not built here; :meth:`initialize` does that.

        Args:
            model_config: Diffusion model settings (paths, LoRA rank, attention backend, ...).
            engine_config: FSDP engine settings (strategy, offload, sequence parallel size, ...).
            optimizer_config: Optimizer and learning-rate scheduler settings.
            checkpoint_config: Settings forwarded to the FSDP checkpoint manager.
        """
        super().__init__()

        self.model_config = model_config
        self.engine_config = engine_config
        self.optimizer_config = optimizer_config
        self.checkpoint_config = checkpoint_config

        self.mode = None

        self.rank = torch.distributed.get_rank()

        self._init_device_mesh()

        if self.engine_config.full_determinism:
            enable_full_determinism(seed=self.engine_config.seed)

        # set FSDP offload params
        self._is_offload_param = self.engine_config.param_offload
        self._is_offload_optimizer = self.engine_config.optimizer_offload
        self._is_lora = self.model_config.lora_rank > 0

    @property
    def is_param_offload_enabled(self) -> bool:
        return self._is_offload_param

    @property
    def is_optimizer_offload_enabled(self) -> bool:
        return self._is_offload_optimizer

    def is_mp_src_rank_with_outputs(self):
        """Return True on the rank that should report outputs (local rank 0 of the Ulysses group)."""
        if self.ulysses_device_mesh is not None:
            is_collect = self.ulysses_device_mesh["ulysses"].get_local_rank() == 0
        else:
            is_collect = True
        return is_collect

    def initialize(self):
        """
        Build the model, optimizer, and learning rate scheduler under FSDP.

        Applies device, dtype, and precision configurations, including mixed precision.
        Then creates the FSDP checkpoint manager and moves the model, optimizer and gradients
        to CPU according to the parameter/optimizer offload flags.
        """
        self._build_model_optimizer()

        self.checkpoint_manager = FSDPCheckpointManager(
            model=self.module,
            optimizer=self.optimizer,
            lr_scheduler=self.lr_scheduler,
            processing_class=self.model_config.get_processor(),
            checkpoint_config=self.checkpoint_config,
            trust_remote_code=self.model_config.trust_remote_code,
        )

        self.to(
            device="cpu",
            model=self._is_offload_param,
            optimizer=self._is_offload_optimizer,
            grad=self._is_offload_param,
        )

        log_gpu_memory_usage("After offload model/optimizer/grad during init", logger=logger)

    def _init_device_mesh(self):
        """Create the FSDP device mesh and, when Ulysses SP is enabled, the (dp, ring, ulysses) mesh."""
        world_size = torch.distributed.get_world_size()
        from torch.distributed.device_mesh import init_device_mesh

        fsdp_size = self.engine_config.fsdp_size

        self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)
        self.ulysses_device_mesh = None
        self.ulysses_sequence_parallel_size = self.engine_config.ulysses_sequence_parallel_size
        dp_size = self.get_data_parallel_size()
        if self.ulysses_sequence_parallel_size > 1:
            import diffusers
            from packaging import version

            if version.parse(diffusers.__version__) < version.parse("0.38.0"):
                raise RuntimeError(
                    f"Ulysses sequence parallelism requires diffusers >= 0.38.0 (found {diffusers.__version__}). "
                )
            # diffusers' ContextParallelConfig.setup() unconditionally accesses self._mesh["ring", "ulysses"],
            # so the mesh must have both named dimensions even though ring attention is not used.
            self.ulysses_device_mesh = init_device_mesh(
                device_name,
                mesh_shape=(dp_size, 1, self.ulysses_sequence_parallel_size),
                mesh_dim_names=["dp", "ring", "ulysses"],
            )

        self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1

        # TODO (mike): we will drop this after it supports in diffusers.
        if self.use_ulysses_sp and self.model_config.attn_backend == "_flash_3_varlen_hub":
            raise ValueError(
                "_flash_3_varlen_hub does not support sequence parallelism. "
                "Set fsdp_config.ulysses_sequence_parallel_size=1 or switch to a different attn_backend."
            )

    def _build_module_from_registry(self, torch_dtype: torch.dtype) -> Optional[torch.nn.Module]:
        """Try loading via ``DiffusionModelBase.build_module()``.

        Returns ``None`` if the registry has no custom loader, so the caller falls back to
        ``diffusers.AutoModel``.

        A module built here bypasses the default ``AutoModel`` path of :meth:`_build_module`:
        ``set_attention_backend`` is not called and the ``config.save_pretrained`` patch is not
        installed. This method only casts the module to ``torch_dtype``, enables gradient
        checkpointing when configured, and stubs ``can_generate``. Any other engine-level hook
        must be handled by the custom loader itself.

        Raises:
            TypeError: if the returned object has no callable ``to()`` method.
            NotImplementedError: if gradient checkpointing is enabled in config but the module
                does not implement ``enable_gradient_checkpointing()``.
        """
        from verl_omni.pipelines.model_base import DiffusionModelBase

        model_cls = DiffusionModelBase.get_class(self.model_config)
        module = model_cls.build_module(self.model_config, torch_dtype)
        if module is None:
            return None

        logger.warning(
            "Built %s via DiffusionModelBase custom loader; engine-level hooks "
            "(attention processors, gradient-checkpointing wrappers, LoRA, "
            "dtype upcast) may be partially effective or silently inactive. "
            "See the docstring of _build_module_from_registry.",
            type(module).__name__,
        )

        # Probe for the methods instead of wrapping the calls in ``except AttributeError``:
        # an AttributeError raised *inside* the method must not be misreported as "method missing".
        to_fn = getattr(module, "to", None)
        if not callable(to_fn):
            raise TypeError(
                f"{type(module).__name__} returned by build_module() has no to() method. "
                "Custom models must be torch.nn.Module instances."
            )
        to_fn(torch_dtype)

        if self.model_config.enable_gradient_checkpointing:
            enable_gc = getattr(module, "enable_gradient_checkpointing", None)
            if not callable(enable_gc):
                raise NotImplementedError(
                    f"Gradient checkpointing is enabled in config, but {type(module).__name__} "
                    "does not implement enable_gradient_checkpointing(). "
                    "Either implement it or set enable_gradient_checkpointing=False."
                )
            enable_gc()
            logger.info(
                "Gradient checkpointing enabled on %s via enable_gradient_checkpointing().",
                type(module).__name__,
            )

        module.can_generate = lambda: False
        return module

    def _build_module(self):
        """Instantiate the denoising module in the configured dtype.

        Uses the model-registry loader when one exists. Otherwise it loads via
        ``diffusers.AutoModel`` under a meta-tensor init context, sets the attention backend,
        optionally enables gradient checkpointing, and patches ``config.save_pretrained`` so
        checkpointing writes a plain ``config.json``.
        """
        from diffusers import AutoModel
        from verl.utils.torch_dtypes import PrecisionType

        torch_dtype = self.engine_config.model_dtype

        if torch_dtype is None:
            # if it is training, we force torch_dtype to fp32
            torch_dtype = torch.float32 if not self.engine_config.forward_only else torch.bfloat16

        torch_dtype = PrecisionType.to_dtype(torch_dtype)

        module = self._build_module_from_registry(torch_dtype)
        if module is not None:
            return module

        # Default path: load via diffusers AutoModel
        init_context = get_init_weight_context_manager(use_meta_tensor=True, mesh=self.device_mesh)
        with init_context(), warnings.catch_warnings():
            warnings.simplefilter("ignore")
            module = AutoModel.from_pretrained(
                self.model_config.config_path or self.model_config.local_path,
                torch_dtype=torch_dtype,
                trust_remote_code=self.model_config.trust_remote_code,
                subfolder="" if self.model_config.config_path else self.model_config.transformer_subfolder,
            )

        try:
            module.set_attention_backend(self.model_config.attn_backend)
        except Exception as e:
            if self.model_config.attn_backend == "_flash_3_varlen_hub":
                logger.warning(
                    "Failed to set attention backend to %s (%s). "
                    "Falling back to 'native' attention backend.",
                    self.model_config.attn_backend,
                    e,
                )
                # object.__setattr__ bypasses the config's own __setattr__ (presumably the config
                # forbids plain attribute assignment — confirm).
                object.__setattr__(self.model_config, "attn_backend", "native")
                module.set_attention_backend("native")
            else:
                raise

        # some parameters may not in torch_dtype
        module.to(torch_dtype)

        if self.model_config.enable_gradient_checkpointing:
            module.enable_gradient_checkpointing()

        # patch for checkpoint saving: dump the (dict-like) config straight to config.json
        def save_config(self, save_directory: str | os.PathLike):
            output_config_file = os.path.join(save_directory, "config.json")
            with open(output_config_file, "w", encoding="utf-8") as f:
                json.dump(self, f, indent=4, sort_keys=True)

        module.can_generate = lambda: False
        module.config.save_pretrained = save_config.__get__(module.config)
        return module

    def _build_fsdp_module(self, module):
        """Wrap ``module`` with FSDP1 (``strategy == "fsdp"``) or FSDP2 (``strategy == "fsdp2"``).

        In forward-only mode (and, for FSDP2, also when ``offload_policy`` is set) parameter
        offload is delegated to FSDP itself, so the engine-level offload flags are cleared.

        Raises:
            NotImplementedError: for any other strategy.
        """
        # TODO(ziheng): need to improve
        from torch.distributed.fsdp import CPUOffload, MixedPrecision
        from verl.utils.torch_dtypes import PrecisionType

        mixed_precision_config = self.engine_config.mixed_precision
        if mixed_precision_config is not None:
            param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
            reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32"))
            buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32"))
        else:
            param_dtype = torch.bfloat16
            reduce_dtype = torch.float32
            buffer_dtype = torch.float32

        mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)

        auto_wrap_policy = get_fsdp_wrap_policy(
            module=module,
            config=self.engine_config.wrap_policy,
            is_lora=self.model_config.lora_rank > 0,
        )

        fsdp_mesh = self.device_mesh
        sharding_strategy = get_sharding_strategy(fsdp_mesh)

        # Note: We force turn off CPUOffload because it causes incorrect results when using grad accumulation
        if self.engine_config.strategy == "fsdp":
            # cpu_offload:
            # - actor: None
            # - critic: None
            # - ref: CPUOffload(offload_params=True)

            # We force reference policy to use CPUOffload to save memory.
            # We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation
            cpu_offload = None
            if self.engine_config.forward_only:
                cpu_offload = CPUOffload(offload_params=True)
                self._is_offload_param = False
                self._is_offload_optimizer = False

            module = FSDP(
                module,
                param_init_fn=init_fn,
                auto_wrap_policy=auto_wrap_policy,
                device_id=get_device_id(),
                sharding_strategy=sharding_strategy,
                mixed_precision=mixed_precision,
                sync_module_states=True,
                device_mesh=self.device_mesh,
                forward_prefetch=self.engine_config.forward_prefetch,
                use_orig_params=self.engine_config.use_orig_params,
                cpu_offload=cpu_offload,
            )
        elif self.engine_config.strategy == "fsdp2":
            # - actor: offload_policy
            # - critic: offload_policy
            # - ref: CPUOffloadPolicy(pin_memory=True)
            assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
            mp_policy = MixedPrecisionPolicy(
                param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True
            )
            offload_policy = None
            if self.engine_config.offload_policy or self.engine_config.forward_only:
                self._is_offload_param = False
                self._is_offload_optimizer = False
                offload_policy = CPUOffloadPolicy(pin_memory=True)

            fsdp_kwargs = {
                "mesh": fsdp_mesh,
                "mp_policy": mp_policy,
                "offload_policy": offload_policy,
                "reshard_after_forward": self.engine_config.reshard_after_forward,
            }
            # Capture the full (unsharded) weights before sharding, then load them into the shards.
            full_state = module.state_dict()
            apply_fsdp2(module, fsdp_kwargs, self.engine_config)
            fsdp2_load_full_state_dict(module, full_state, fsdp_mesh, offload_policy)
        else:
            raise NotImplementedError(f"Unknown strategy {self.engine_config.strategy}")

        if torch.distributed.get_world_size() == 1 and fsdp_version(module) == 1:
            FSDP.set_state_dict_type(
                module,
                state_dict_type=StateDictType.FULL_STATE_DICT,
                state_dict_config=FullStateDictConfig(),
            )
        elif fsdp_version(module) == 1:
            FSDP.set_state_dict_type(
                module,
                state_dict_type=StateDictType.SHARDED_STATE_DICT,
                state_dict_config=ShardedStateDictConfig(),
            )

        return module

    def _build_scheduler(self):
        return build_scheduler(self.model_config)

    def _build_optimizer(self, module):
        from verl.workers.config.optimizer import build_optimizer

        optimizer = build_optimizer(module.parameters(), self.optimizer_config)

        return optimizer

    def _build_lr_scheduler(self, optimizer):
        """Build a constant or cosine LR schedule with warmup from ``optimizer_config``.

        When ``lr_warmup_steps <= 0`` the warmup length is derived from ``lr_warmup_steps_ratio``.

        Raises:
            NotImplementedError: for an unsupported ``lr_scheduler_type``.
        """
        from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup

        optim_config = self.optimizer_config

        total_steps = optim_config.total_training_steps
        num_warmup_steps = optim_config.lr_warmup_steps
        lr_scheduler_type = optim_config.lr_scheduler_type
        min_lr_ratio = optim_config.min_lr_ratio
        num_cycles = optim_config.num_cycles
        zero_indexed_step = optim_config.zero_indexed_step
        if num_warmup_steps <= 0:
            num_warmup_steps_ratio = optim_config.lr_warmup_steps_ratio
            num_warmup_steps = int(num_warmup_steps_ratio * total_steps)

        if self.rank == 0:
            print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")

        if lr_scheduler_type == "constant":
            lr_scheduler = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=num_warmup_steps)
        elif lr_scheduler_type == "cosine":
            lr_scheduler = get_cosine_schedule_with_warmup(
                optimizer=optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=total_steps,
                min_lr_ratio=min_lr_ratio,
                num_cycles=num_cycles,
                zero_indexed_step=zero_indexed_step,
            )
        else:
            raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported")
        return lr_scheduler

    def _build_model_optimizer(self):
        """Build module (+LoRA, +context parallelism), diffusion scheduler, FSDP wrapper, optimizer and LR schedule."""
        from diffusers import ContextParallelConfig
        from verl.utils.model import print_model_size

        # Load base model with specified configuration and dtype
        module = self._build_module()
        # Apply LoRA adapters if low-rank adaptation is enabled
        if self._is_lora:
            module = self._build_lora_module(module)

        if self.use_ulysses_sp:
            sp_size = self.ulysses_sequence_parallel_size
            module.enable_parallelism(
                config=ContextParallelConfig(ulysses_degree=sp_size, mesh=self.ulysses_device_mesh)
            )

        # Load diffusion scheduler
        scheduler = self._build_scheduler()

        # Synchronize all distributed processes before proceeding
        torch.distributed.barrier()

        if self.rank == 0:
            print_model_size(module)

        log_gpu_memory_usage("After init model from Diffusers AutoModel", logger=logger)

        # Wrap model with FSDP for distributed training (sharding, mixed precision, etc.)
        log_gpu_memory_usage("Before FSDP", logger=None)
        module = self._build_fsdp_module(module)
        log_gpu_memory_usage("After FSDP", logger=None)

        if not self.engine_config.forward_only:
            # Initialize optimizer with model parameters and config settings
            optimizer = self._build_optimizer(module)
            # Create learning rate scheduler with warmup and decay settings
            lr_scheduler = self._build_lr_scheduler(optimizer)
        else:
            optimizer = None
            lr_scheduler = None

        self.module = module
        self.scheduler = scheduler
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    def train_mode(self, **kwargs):
        """
        Return a context manager that switches to training mode with FSDP-specific handling.

        Includes parameter and optimizer offload entry/exit.
        """
        return EngineTrainModeCtx(self, **kwargs)

    def eval_mode(self, **kwargs):
        """
        Return a context manager that switches to evaluation mode with FSDP-specific handling.

        Includes activation offload entry/exit.
        """
        return EngineEvalModeCtx(self, **kwargs)

    def get_data_parallel_rank(self):
        if self.ulysses_device_mesh is not None:
            return self.ulysses_device_mesh["dp"].get_local_rank()
        else:
            return torch.distributed.get_rank()

    def get_data_parallel_size(self):
        return torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size

    def get_data_parallel_group(self):
        if self.ulysses_device_mesh is not None:
            return self.ulysses_device_mesh.get_group(mesh_dim="dp")
        else:
            return torch.distributed.group.WORLD

    def get_model_parallel_group(self):
        raise NotImplementedError

    def get_context_parallel_group(self):
        raise NotImplementedError

    def postprocess_batch_func(self, output_lst, indices, data: TensorDict):
        """Merge per-micro-batch, per-step outputs into one result dict.

        Each element of ``output_lst`` maps "model_output"/"loss"/"metrics" to per-step lists.
        Per-step model-output tensors are stacked along dim 1 and then concatenated across micro
        batches along dim 0. Losses are kept as a list of per-micro-batch step lists, and metrics
        are accumulated with ``append_to_dict``.
        """
        model_output = {}
        losses = []
        aggregated_metrics = {}

        for output in output_lst:
            # model output list
            model_output_lst = {}
            if "model_output" in output:
                for model_output_dict in output["model_output"]:
                    for key, val in model_output_dict.items():
                        model_output_lst.setdefault(key, []).append(val)
            for key, val in model_output_lst.items():
                model_output.setdefault(key, []).append(torch.stack(val, dim=1))  # (bsz, steps, ...)

            # loss
            if "loss" in output:
                losses.append(output["loss"])

            # metrics
            if "metrics" in output:
                for metrics in output["metrics"]:
                    append_to_dict(aggregated_metrics, metrics)

        # concat results from micro batches
        for key, val in model_output.items():
            model_output[key] = torch.concat(val, dim=0)  # (global_bsz, steps, ...)

        output = {
            "model_output": model_output,  # a dict of tensors in shape (global_bsz, steps, ...)
            "loss": losses,  # micro-batch step-wise losses
            "metrics": aggregated_metrics,
        }

        return output

    @staticmethod
    def _unpad_nested_embeds(
        embeds: torch.Tensor, mask: torch.Tensor | None
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Convert a jagged nested tensor pair (embeds, mask) to dense, zero-padded tensors.

        ``embeds`` must be a jagged-layout nested tensor (its ``offsets()`` give the sequence
        boundaries). ``mask`` is padded to the same length when given; ``None`` is passed through.
        """
        batch_size = embeds.size(0)
        # Reduce on-device and convert once; Python's max() over a tensor iterates it element by
        # element (one host round-trip each) and yields a 0-d tensor rather than an int.
        max_seq_len = int(embeds.offsets().diff().max())
        embed_dim = embeds.size(-1)
        embeds = torch.nested.to_padded_tensor(embeds, padding=0, output_size=(batch_size, max_seq_len, embed_dim))
        if mask is not None:
            mask = torch.nested.to_padded_tensor(mask, padding=0, output_size=(batch_size, max_seq_len))
        return embeds, mask

    @staticmethod
    def _pad_embeds_for_sp(
        embeds: torch.Tensor, mask: torch.Tensor | None, sp_size: int
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Right-pad the sequence dimension (dim 1) of (embeds, mask) to a multiple of ``sp_size``.

        ``mask`` may be ``None``, in which case only ``embeds`` is padded.
        """
        seq_len = embeds.size(1)
        aligned_seq_len = (seq_len + sp_size - 1) // sp_size * sp_size
        if aligned_seq_len > seq_len:
            pad_len = aligned_seq_len - seq_len
            embeds = torch.nn.functional.pad(embeds, (0, 0, 0, pad_len))
            if mask is not None:
                mask = torch.nn.functional.pad(mask, (0, pad_len))
        return embeds, mask

    @abstractmethod
    def forward_backward_batch(
        self, data: TensorDict, loss_function: Callable, forward_only: bool = False
    ) -> list[TensorDict]:
        """Run forward/backward over a batch; implemented by algorithm-specific subclasses."""
        pass

    @abstractmethod
    def prepare_model_inputs(self, micro_batch: TensorDict, step: int):
        """Build model inputs for one diffusion step; implemented by algorithm-specific subclasses."""
        pass

    @abstractmethod
    def prepare_model_outputs(self, output, micro_batch: TensorDict):
        """Post-process raw model output; implemented by algorithm-specific subclasses."""
        pass

    @abstractmethod
    def forward_step(self, micro_batch: TensorDict, loss_function, forward_only, step):
        """Run one diffusion step forward (and loss); implemented by algorithm-specific subclasses."""
        pass

    def optimizer_zero_grad(self):
        """
        Zero the gradients of all parameters managed by the optimizer.
        """
        self.optimizer.zero_grad()

    def optimizer_step(self):
        """
        Clip gradients, skip update if non-finite, and step optimizer.

        Returns:
            grad_norm (float): Norm of gradients before clipping.
        """
        assert self.optimizer_config.clip_grad is not None

        if isinstance(self.module, FSDP):
            grad_norm = self.module.clip_grad_norm_(self.optimizer_config.clip_grad)
        elif isinstance(self.module, FSDPModule):
            grad_norm = fsdp2_clip_grad_norm_(self.module.parameters(), max_norm=self.optimizer_config.clip_grad)
        else:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.module.parameters(), max_norm=self.optimizer_config.clip_grad
            )

        if isinstance(grad_norm, DTensor):
            grad_norm = grad_norm.full_tensor()

        # if grad_norm is not finite, skip the update
        if not torch.isfinite(grad_norm):
            print(f"WARN: grad_norm is not finite: {grad_norm}")
            self.optimizer.zero_grad()
        else:
            self.optimizer.step()
        return grad_norm.item()

    def lr_scheduler_step(self):
        """
        Advance FSDP scheduler and return updated learning rate.
        """
        self.lr_scheduler.step()
        lr = self.lr_scheduler.get_last_lr()[0]  # only return the first group
        return lr

    def to(self, device: str, model: bool = True, optimizer: bool = True, grad: bool = True):
        """
        Move FSDP model and/or optimizer to CPU or GPU with offload support.

        Note that this function executes irrespective of offload config. It serves as manual control.
        In forward-only mode it returns right after the base-class call, because offload is then
        handled by FSDP's own CPU offload.
        """
        super().to(device=device, model=model, optimizer=optimizer, grad=grad)

        if self.engine_config.forward_only:
            # force cpu_offload
            return

        device_name = get_device_name()

        assert device in (device_name, "cpu")
        if device == device_name:
            if model:
                load_fsdp_model_to_gpu(self.module)
            if optimizer and self.optimizer is not None:
                load_fsdp_optimizer(self.optimizer, device)
            gc.collect()
        elif device == "cpu":
            if model:
                offload_fsdp_model_to_cpu(self.module)
            if optimizer and self.optimizer is not None:
                offload_fsdp_optimizer(self.optimizer)
        else:
            raise ValueError(f"Invalid device type: {device}")

    def save_checkpoint(
        self,
        local_path: str,
        hdfs_path: Optional[str] = None,
        global_step: int = 0,
        max_ckpt_to_keep: Optional[int] = None,
        **kwargs,
    ) -> None:
        """
        Save FSDP checkpoint, handling parameter offload as needed.
        """
        origin_module_device = next(self.module.parameters()).device.type
        if self._is_offload_param or origin_module_device == "cpu":
            load_fsdp_model_to_gpu(self.module)

        self.checkpoint_manager.save_checkpoint(
            local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep
        )

        torch.distributed.barrier()
        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.module)

        gc.collect()
        aggressive_empty_cache(force_sync=True)

    def load_checkpoint(
        self, local_path: str, hdfs_path: Optional[str] = None, del_local_after_load: bool = True, **kwargs
    ) -> None:
        """
        Load FSDP checkpoint, restoring parameters and optimizer state.

        Parameters are temporarily moved to the device when parameter offload is enabled, and the
        model/optimizer are offloaded again afterwards according to the offload flags.
        """
        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.module)

        self.checkpoint_manager.load_checkpoint(
            local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load
        )

        torch.distributed.barrier()
        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.module)

        if self._is_offload_optimizer:
            offload_fsdp_optimizer(self.optimizer)

    def get_per_tensor_param(
        self, layered_summon=False, base_sync_done=False, adapter_name: str | None = None, **kwargs
    ):
        """Collect weights for syncing to the rollout engine.

        Returns:
            A lazy iterable of ``("transformer.<name>", tensor)`` pairs and the LoRA peft config as a
            dict (``None`` when the module has no ``peft_config``). Outside the LoRA
            ``base_sync_done`` path, DTensor parameters are gathered and cast to bfloat16.
        """
        log_gpu_memory_usage("Before load_fsdp_model_to_gpu", logger=logger)

        load_fsdp_model_to_gpu(self.module)

        log_gpu_memory_usage("After load_fsdp_model_to_gpu", logger=logger)

        peft_config = None
        peft_model = getattr(self.module, "_fsdp_wrapped_module", self.module)
        if hasattr(peft_model, "peft_config"):  # LoRA
            peft_config = peft_model.peft_config.get("default", None)
            adapter_ctx = self.use_adapter(adapter_name) if adapter_name is not None else nullcontext()
            with adapter_ctx:
                params = collect_lora_params(
                    module=self.module,
                    layered_summon=layered_summon,
                    base_sync_done=base_sync_done,
                    is_diffusers=True,
                    adapter_name=adapter_name or "default",
                    layer_prefixes=self.model_config.fsdp_layer_prefixes,
                )
            if not base_sync_done:
                params = {replace_lora_wrapper(k, peft_config): v for k, v in params.items()}
        else:
            params = self.module.state_dict()

        params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))

        log_gpu_memory_usage("Before offload_fsdp_model_to_cpu", logger=logger)
        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.module)
        log_gpu_memory_usage("After offload_fsdp_model_to_cpu", logger=logger)

        if peft_config is not None and base_sync_done:
            per_tensor_param = params.items()
        else:
            device = get_device_id()  # used when fsdp2 set cpu_offload_policy
            # TODO: cast fp32 to bf16 to reduce weight sync overhead, need more fine-grained control, e.g MoE gate
            per_tensor_param = (
                (
                    name,
                    param.to(device, non_blocking=True).full_tensor().to(torch.bfloat16, non_blocking=True)
                    if isinstance(param, DTensor)
                    else param,
                )
                for name, param in params.items()
            )
        # we need to add the prefix to make it compatible with rollout engine
        # NOTE(review): confirm the prefix is also expected on the LoRA base_sync_done path.
        per_tensor_param = ((f"transformer.{name}", tensor) for name, tensor in per_tensor_param)
        peft_config_dict = peft_config.to_dict() if peft_config is not None else None
        return per_tensor_param, peft_config_dict

    def _run_forward_backward_batch(
        self,
        data: TensorDict,
        loss_function: Callable,
        forward_only: bool,
        *,
        timesteps_key: str,
    ) -> dict:
        """Split ``data`` into micro batches and run ``forward_step`` for every timestep.

        The number of timesteps is ``data[timesteps_key].shape[1]``. Gradient accumulation spans
        all (micro batch, timestep) pairs, and ``loss.backward()`` is called per step unless
        ``forward_only``. Returns the dict produced by :meth:`postprocess_batch_func`.
        """
        num_timesteps = data[timesteps_key].shape[1]
        tu.assign_non_tensor(data, sp_size=self.ulysses_sequence_parallel_size)
        tu.assign_non_tensor(data, use_dynamic_bsz=False)

        micro_batches, indices = prepare_micro_batches(
            data=data, dp_group=self.get_data_parallel_group(), same_micro_num_in_dp=True
        )

        gradient_accumulation_steps = len(micro_batches) * num_timesteps

        output_lst = []

        ctx = torch.no_grad() if forward_only else nullcontext()

        for micro_batch in micro_batches:
            micro_batch = micro_batch.to(get_device_id())
            tu.assign_non_tensor(micro_batch, gradient_accumulation_steps=gradient_accumulation_steps)
            meta_info_lst = {"model_output": [], "loss": [], "metrics": []}

            # Forward and backward for each timestep
            with ctx:
                for step in range(num_timesteps):
                    loss, meta_info = self.forward_step(
                        micro_batch, loss_function=loss_function, forward_only=forward_only, step=step
                    )
                    if not forward_only:
                        loss.backward()
                    for key, val in meta_info.items():
                        meta_info_lst[key].append(val)

            output_lst.append(meta_info_lst)

        # postprocess and return
        return self.postprocess_batch_func(output_lst=output_lst, indices=indices, data=data)
@EngineRegistry.register(model_type="diffusion_model", backend=["fsdp", "fsdp2"], device=["cuda", "npu"])
class PPODiffusersFSDPEngine(DiffusersFSDPEngine):
    """Diffusers FSDP engine with PPO forward/backward and I/O preparation."""

    def forward_backward_batch(
        self, data: TensorDict, loss_function: Callable, forward_only: bool = False
    ) -> list[TensorDict]:
        """Iterate PPO forward (and backward unless ``forward_only``) over every step in ``all_timesteps``."""
        return self._run_forward_backward_batch(data, loss_function, forward_only, timesteps_key="all_timesteps")

    def prepare_model_inputs(self, micro_batch: TensorDict, step: int):
        """
        Gather the architecture-agnostic tensors and hand off to the model registry.

        Nested (jagged) prompt embeddings are densified here, and under Ulysses sequence
        parallelism they are padded to a multiple of the SP degree. Building the actual input
        dicts is left to the registered DiffusionModelBase subclass via ``prepare_model_inputs``.
        """
        all_latents = micro_batch["all_latents"]
        all_timesteps = micro_batch["all_timesteps"]
        sp_degree = self.ulysses_sequence_parallel_size if self.use_ulysses_sp else 1

        def _normalize_pair(embeds, embeds_mask):
            # Densify jagged embeddings first, then align the sequence length for SP sharding.
            if isinstance(embeds, torch.Tensor) and embeds.is_nested:
                embeds, embeds_mask = self._unpad_nested_embeds(embeds, embeds_mask)
            if isinstance(embeds, torch.Tensor) and sp_degree > 1:
                embeds, embeds_mask = self._pad_embeds_for_sp(embeds, embeds_mask, sp_degree)
            return embeds, embeds_mask

        pos_embeds, pos_mask = _normalize_pair(
            micro_batch.get("prompt_embeds", None),
            micro_batch.get("prompt_embeds_mask", None),
        )
        neg_embeds, neg_mask = _normalize_pair(
            micro_batch.get("negative_prompt_embeds", None),
            micro_batch.get("negative_prompt_embeds_mask", None),
        )

        return prepare_model_inputs(
            module=self.module,
            model_config=self.model_config,
            latents=all_latents,
            timesteps=all_timesteps,
            prompt_embeds=pos_embeds,
            prompt_embeds_mask=pos_mask,
            negative_prompt_embeds=neg_embeds,
            negative_prompt_embeds_mask=neg_mask,
            micro_batch=micro_batch,
            step=step,
        )

    def prepare_model_outputs(self, output, micro_batch: TensorDict):
        """Label the 4-tuple from the sampler step as a dict keyed for the PPO loss."""
        log_probs, prev_sample_mean, std_dev_t, sqrt_dt = output
        return dict(
            log_probs=log_probs,
            prev_sample_mean=prev_sample_mean,
            std_dev_t=std_dev_t,
            sqrt_dt=sqrt_dt,
        )

    def forward_step(self, micro_batch: TensorDict, loss_function, forward_only, step):
        """
        Run the denoiser for one diffusion step and, when a loss function is supplied, score it.

        Returns ``(loss, output)``, where ``output`` carries the model outputs, the detached
        scalar loss and the loss metrics. Without a loss function (forward-only use) the loss is
        a constant placeholder tensor.
        """
        model_inputs, negative_model_inputs = self.prepare_model_inputs(micro_batch=micro_batch, step=step)
        raw_output = forward_and_sample_previous_step(
            module=self.module,
            scheduler=self.scheduler,
            model_config=self.model_config,
            model_inputs=model_inputs,
            negative_model_inputs=negative_model_inputs,
            scheduler_inputs=micro_batch,
            step=step,
        )
        model_output = self.prepare_model_outputs(output=raw_output, micro_batch=micro_batch)

        if loss_function is None:
            assert forward_only, "forward_only must be True when loss_function is None"
            loss = torch.tensor(1.0, device=device_name)
            metrics = {}
        else:
            loss_inputs = tu.get_tensordict(
                {
                    "old_log_probs": micro_batch["old_log_probs"][:, step],
                    "advantages": micro_batch["advantages"][:, step],
                },
            )
            tu.assign_non_tensor(
                loss_inputs,
                gradient_accumulation_steps=tu.get_non_tensor_data(
                    micro_batch, "gradient_accumulation_steps", default=None
                ),
                sp_size=tu.get_non_tensor_data(micro_batch, "sp_size", default=None),
            )
            # TODO (mike): refactor the data preparation logic here
            # Optional tensors: forwarded (sliced at this step) only when present in the micro batch.
            for optional_key in ("ref_log_prob", "ref_prev_sample_mean", "old_prev_sample_mean", "rollout_is_weights"):
                if micro_batch.get(optional_key, None) is not None:
                    loss_inputs[optional_key] = micro_batch[optional_key][:, step]

            loss, metrics = loss_function(
                model_output=model_output, data=loss_inputs, dp_group=self.get_data_parallel_group()
            )

        step_output = {
            "model_output": model_output,
            "loss": loss.detach().item(),
            "metrics": metrics,
        }
        return loss, step_output
@EngineRegistry.register(model_type="diffusion_dpo_model", backend=["fsdp", "fsdp2"], device=["cuda", "npu"])
class DPODiffusersFSDPEngine(DiffusersFSDPEngine):
    """Diffusers FSDP engine variant for diffusion DPO.

    Each micro batch gets a single forward pass on noised ``latents_clean``. The
    predicted noise (plus the noise, clean latents and timesteps used) is handed to
    the loss function. There is no loop over rollout timesteps.
    """

    def _prepare_noisy_latents(self, micro_batch: TensorDict) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Noise the clean latents of ``micro_batch``.

        Returns:
            The ``(noisy_latents, noise, timesteps)`` triple from ``prepare_noisy_latents``.

        Raises:
            KeyError: if ``latents_clean`` is missing from the micro batch.
        """
        latents = micro_batch.get("latents_clean", None)
        if latents is None:
            raise KeyError("Diffusion DPO training requires `latents_clean` in the micro batch.")
        return prepare_noisy_latents(
            latents=latents,
            scheduler=self.scheduler,
            # if noise is not provided, sample noise and timesteps in the forward step
            noise=micro_batch.get("noise", None),
            # if timesteps is not provided, sample timesteps in the forward step
            timesteps=micro_batch.get("timesteps", None),
        )

    def _prepare_prompt_embeds(
        self,
        prompt_embeds: torch.Tensor,
        prompt_embeds_mask: torch.Tensor | None,
        negative_prompt_embeds: torch.Tensor | None,
        negative_prompt_embeds_mask: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
        """Apply common nested-tensor and sequence-parallel padding to prompt embeds.

        Any mask or negative embedding that is not a ``torch.Tensor`` is replaced with
        ``None``. Nested tensors are unpadded. When Ulysses sequence parallelism is
        active with more than one rank, embeddings are padded for SP.
        """
        sp_size = self.ulysses_sequence_parallel_size if self.use_ulysses_sp else 1
        if not isinstance(prompt_embeds_mask, torch.Tensor):
            prompt_embeds_mask = None
        if not isinstance(negative_prompt_embeds, torch.Tensor):
            negative_prompt_embeds = None
        if not isinstance(negative_prompt_embeds_mask, torch.Tensor):
            negative_prompt_embeds_mask = None

        if prompt_embeds.is_nested:
            prompt_embeds, prompt_embeds_mask = self._unpad_nested_embeds(prompt_embeds, prompt_embeds_mask)
        if sp_size > 1:
            prompt_embeds, prompt_embeds_mask = self._pad_embeds_for_sp(prompt_embeds, prompt_embeds_mask, sp_size)

        if isinstance(negative_prompt_embeds, torch.Tensor) and negative_prompt_embeds.is_nested:
            negative_prompt_embeds, negative_prompt_embeds_mask = self._unpad_nested_embeds(
                negative_prompt_embeds, negative_prompt_embeds_mask
            )
        if isinstance(negative_prompt_embeds, torch.Tensor) and sp_size > 1:
            negative_prompt_embeds, negative_prompt_embeds_mask = self._pad_embeds_for_sp(
                negative_prompt_embeds, negative_prompt_embeds_mask, sp_size
            )
        return prompt_embeds, prompt_embeds_mask, negative_prompt_embeds, negative_prompt_embeds_mask

    def prepare_model_inputs(self, micro_batch: TensorDict, step: int):
        """Build transformer inputs from noised clean latents.

        ``step`` is ignored: DPO does not index rollout timesteps. The module-level
        ``prepare_model_inputs`` is therefore always called with ``step=0``.

        Returns:
            ``(model_inputs, negative_model_inputs, dpo_context)``, where ``dpo_context``
            carries the ``noise``, clean ``latent`` (cast to the noise's device/dtype)
            and ``timesteps`` needed by the loss.
        """
        del step
        noisy_latents, noise, timesteps = self._prepare_noisy_latents(micro_batch)
        latent = micro_batch["latents_clean"].to(device=noise.device, dtype=noise.dtype)

        prompt_embeds, prompt_embeds_mask, negative_prompt_embeds, negative_prompt_embeds_mask = (
            self._prepare_prompt_embeds(
                prompt_embeds=micro_batch["prompt_embeds"],
                prompt_embeds_mask=micro_batch.get("prompt_embeds_mask", None),
                negative_prompt_embeds=micro_batch.get("negative_prompt_embeds", None),
                negative_prompt_embeds_mask=micro_batch.get("negative_prompt_embeds_mask", None),
            )
        )

        model_inputs, negative_model_inputs = prepare_model_inputs(
            module=self.module,
            model_config=self.model_config,
            latents=noisy_latents,
            timesteps=timesteps,
            prompt_embeds=prompt_embeds,
            prompt_embeds_mask=prompt_embeds_mask,
            negative_prompt_embeds=negative_prompt_embeds,
            negative_prompt_embeds_mask=negative_prompt_embeds_mask,
            micro_batch=micro_batch,
            step=0,
        )
        return model_inputs, negative_model_inputs, {"noise": noise, "latent": latent, "timesteps": timesteps}

    def prepare_model_outputs(self, output, micro_batch: TensorDict):
        """Flatten ``(noise_pred, dpo_context)`` into the dict consumed by the DPO loss."""
        del micro_batch
        noise_pred, dpo_context = output
        return {
            "noise_pred": noise_pred,
            "noise": dpo_context["noise"],
            "latent": dpo_context["latent"],
            "timesteps": dpo_context["timesteps"],
        }

    def forward_backward_batch(
        self, data: TensorDict, loss_function: Callable, forward_only: bool = False
    ) -> list[TensorDict]:
        """Split ``data`` into micro batches and run one forward (and backward) per micro batch.

        Dynamic batch sizing is disabled, and every DP rank gets the same number of
        micro batches (``same_micro_num_in_dp=True``). Under ``forward_only`` the
        forwards run under ``torch.no_grad()`` and no backward is issued.
        """
        tu.assign_non_tensor(data, sp_size=self.ulysses_sequence_parallel_size)
        tu.assign_non_tensor(data, use_dynamic_bsz=False)

        micro_batches, indices = prepare_micro_batches(
            data=data, dp_group=self.get_data_parallel_group(), same_micro_num_in_dp=True
        )
        gradient_accumulation_steps = len(micro_batches)

        output_lst = []
        ctx = torch.no_grad() if forward_only else nullcontext()
        for micro_batch in micro_batches:
            micro_batch = micro_batch.to(get_device_id())
            tu.assign_non_tensor(micro_batch, gradient_accumulation_steps=gradient_accumulation_steps)
            meta_info_lst = {"model_output": [], "loss": [], "metrics": []}
            with ctx:
                # DPO is a one-shot flow-matching objective over final image latents,
                # not a reversed-sampling objective over every rollout timestep.
                loss, meta_info = self.forward_step(
                    micro_batch,
                    loss_function=loss_function,
                    forward_only=forward_only,
                    step=None,  # use random step for DPO
                )
                if not forward_only:
                    loss.backward()
            for key, val in meta_info.items():
                meta_info_lst[key].append(val)
            output_lst.append(meta_info_lst)

        # postprocess and return
        return self.postprocess_batch_func(output_lst=output_lst, indices=indices, data=data)

    def forward_step(self, micro_batch: TensorDict, loss_function, forward_only, step):
        """Predict noise once and, when a loss is given, compute the DPO loss.

        A ``ref_noise_pred`` with an extra singleton second axis is squeezed to match
        ``noise_pred``'s rank before it is passed to the loss.

        Returns:
            ``(loss, output)``, where ``output`` holds ``model_output``, the detached
            scalar ``loss`` and ``metrics``.
        """
        model_inputs, negative_model_inputs, dpo_context = self.prepare_model_inputs(micro_batch=micro_batch, step=step)
        noise_pred = forward(
            module=self.module,
            model_config=self.model_config,
            model_inputs=model_inputs,
            negative_model_inputs=negative_model_inputs,
        )
        model_output = self.prepare_model_outputs(output=(noise_pred, dpo_context), micro_batch=micro_batch)

        if loss_function is not None:
            data = tu.get_tensordict({"sample_level_rewards": micro_batch["sample_level_rewards"]})
            uid = tu.get_non_tensor_data(micro_batch, "uid", default=None)
            tu.assign_non_tensor(
                data,
                gradient_accumulation_steps=tu.get_non_tensor_data(
                    micro_batch, "gradient_accumulation_steps", default=None
                ),
                sp_size=tu.get_non_tensor_data(micro_batch, "sp_size", default=None),
            )
            if uid is not None:
                tu.assign_non_tensor(data, uid=uid)
            if micro_batch.get("ref_noise_pred", None) is not None:
                ref_noise_pred = micro_batch["ref_noise_pred"]
                # Drop a singleton step axis so the reference matches noise_pred's rank.
                if ref_noise_pred.ndim == model_output["noise_pred"].ndim + 1 and ref_noise_pred.shape[1] == 1:
                    ref_noise_pred = ref_noise_pred[:, 0]
                data["ref_noise_pred"] = ref_noise_pred
            loss, metrics = loss_function(model_output=model_output, data=data, dp_group=self.get_data_parallel_group())
        else:
            assert forward_only, "forward_only must be True when loss_function is None"
            loss = torch.tensor(1.0, device=device_name)
            metrics = {}

        output = {
            "model_output": model_output,
            "loss": loss.detach().item(),
            "metrics": metrics,
        }
        return loss, output


@EngineRegistry.register(model_type="diffusion_nft_model", backend=["fsdp", "fsdp2"], device=["cuda", "npu"])
class NFTDiffusersFSDPEngine(DiffusersFSDPEngine):
    """Diffusers FSDP engine for direct-preference / forward-process objectives (e.g. DiffusionNFT)."""

    def forward_backward_batch(
        self, data: TensorDict, loss_function: Callable, forward_only: bool = False
    ) -> list[TensorDict]:
        """Delegate to the shared batch loop, stepping over ``train_timesteps``."""
        return self._run_forward_backward_batch(data, loss_function, forward_only, timesteps_key="train_timesteps")

    def prepare_model_inputs(self, micro_batch: TensorDict, step: int):
        """Build the forward-process input ``xt = (1 - t) * x0 + t * noise`` for training step ``step``.

        ``noise`` comes from ``forward_noise`` when present; that tensor is indexed
        with ``[:, step]`` when it has one more dim than ``x0``. Otherwise the noise
        is freshly sampled.

        Returns:
            ``(model_inputs, negative_model_inputs, x0, xt, t_expanded)``.
        """
        x0 = micro_batch["latents_clean"]
        timestep = micro_batch["train_timesteps"][:, step]
        # NOTE(review): assumes timesteps are on a 0..1000 scale; TODO confirm against the scheduler config.
        t = timestep.float() / 1000.0
        # Reshape to (batch, 1, 1, ...) so t broadcasts against x0.
        t_expanded = t.view(-1, *([1] * (x0.ndim - 1)))
        if micro_batch.get("forward_noise", None) is not None:
            forward_noise = micro_batch["forward_noise"]
            noise = forward_noise[:, step] if forward_noise.ndim == x0.ndim + 1 else forward_noise
        else:
            noise = torch.randn_like(x0.float())
        xt = (1.0 - t_expanded) * x0 + t_expanded * noise

        prompt_embeds = micro_batch["prompt_embeds"]
        prompt_embeds_mask = micro_batch["prompt_embeds_mask"]
        negative_prompt_embeds = micro_batch.get("negative_prompt_embeds", None)
        negative_prompt_embeds_mask = micro_batch.get("negative_prompt_embeds_mask", None)

        sp_size = self.ulysses_sequence_parallel_size if self.use_ulysses_sp else 1
        if prompt_embeds.is_nested:
            prompt_embeds, prompt_embeds_mask = self._unpad_nested_embeds(prompt_embeds, prompt_embeds_mask)
        if sp_size > 1:
            prompt_embeds, prompt_embeds_mask = self._pad_embeds_for_sp(prompt_embeds, prompt_embeds_mask, sp_size)
        if isinstance(negative_prompt_embeds, torch.Tensor) and negative_prompt_embeds.is_nested:
            negative_prompt_embeds, negative_prompt_embeds_mask = self._unpad_nested_embeds(
                negative_prompt_embeds, negative_prompt_embeds_mask
            )
        if isinstance(negative_prompt_embeds, torch.Tensor) and sp_size > 1:
            negative_prompt_embeds, negative_prompt_embeds_mask = self._pad_embeds_for_sp(
                negative_prompt_embeds, negative_prompt_embeds_mask, sp_size
            )

        model_inputs, negative_model_inputs = prepare_model_inputs(
            module=self.module,
            model_config=self.model_config,
            latents=xt,
            timesteps=timestep,
            prompt_embeds=prompt_embeds,
            prompt_embeds_mask=prompt_embeds_mask,
            negative_prompt_embeds=negative_prompt_embeds,
            negative_prompt_embeds_mask=negative_prompt_embeds_mask,
            micro_batch=micro_batch,
            step=step,
        )
        return model_inputs, negative_model_inputs, x0, xt, t_expanded

    def prepare_model_outputs(self, output, micro_batch: TensorDict):
        """Name the six items of ``output`` for the loss function."""
        old_prediction, forward_prediction, ref_forward_prediction, x0, xt, t_expanded = output
        return {
            "old_prediction": old_prediction,
            "forward_prediction": forward_prediction,
            "ref_forward_prediction": ref_forward_prediction,
            "x0": x0,
            "xt": xt,
            "t_expanded": t_expanded,
        }

    def forward_step(self, micro_batch: TensorDict, loss_function, forward_only, step):
        """Run three forwards on the same inputs and optionally compute the loss.

        The three forwards use: the ``"old"`` adapter (no grad), the current adapter
        (with grad), and the adapter disabled (no grad, reference).

        Returns:
            ``(loss, output)``, where ``output`` holds ``model_output``, the detached
            scalar ``loss`` and ``metrics``.
        """
        model_inputs, negative_model_inputs, x0, xt, t_expanded = self.prepare_model_inputs(
            micro_batch=micro_batch, step=step
        )
        try:
            with self.use_adapter("old"), torch.no_grad():
                old_prediction = forward(
                    module=self.module,
                    model_config=self.model_config,
                    model_inputs=model_inputs,
                    negative_model_inputs=negative_model_inputs,
                ).detach()

            # Trainable prediction: outside no_grad so gradients flow.
            forward_prediction = forward(
                module=self.module,
                model_config=self.model_config,
                model_inputs=model_inputs,
                negative_model_inputs=negative_model_inputs,
            )

            with torch.no_grad():
                with self.disable_adapter():
                    ref_forward_prediction = forward(
                        module=self.module,
                        model_config=self.model_config,
                        model_inputs=model_inputs,
                        negative_model_inputs=negative_model_inputs,
                    ).detach()
        finally:
            # Always re-activate the trainable "default" adapter, even if one of the
            # adapter-switched forwards raised, so later steps don't run on the wrong adapter.
            self._set_adapter("default")

        model_output = self.prepare_model_outputs(
            output=(old_prediction, forward_prediction, ref_forward_prediction, x0, xt, t_expanded),
            micro_batch=micro_batch,
        )

        if loss_function is not None:
            data = tu.get_tensordict({"reward_prob": micro_batch["reward_prob"][:, step]})
            tu.assign_non_tensor(
                data,
                gradient_accumulation_steps=tu.get_non_tensor_data(
                    micro_batch, "gradient_accumulation_steps", default=None
                ),
                sp_size=tu.get_non_tensor_data(micro_batch, "sp_size", default=None),
            )
            loss, metrics = loss_function(model_output=model_output, data=data, dp_group=self.get_data_parallel_group())
        else:
            assert forward_only, "forward_only must be True when loss_function is None"
            loss = torch.tensor(1.0, device=x0.device)
            metrics = {}

        output = {
            "model_output": model_output,
            "loss": loss.detach().item(),
            "metrics": metrics,
        }
        return loss, output


class EngineEvalModeCtx(BaseEngineCtx):
    """Puts a ``DiffusersFSDPEngine`` module in eval mode; reshards the root FSDP module on exit."""

    def __init__(self, engine: DiffusersFSDPEngine, **kwargs):
        super().__init__(engine=engine, mode="eval", **kwargs)

    def __enter__(self):
        assert isinstance(self.engine, DiffusersFSDPEngine)
        super().__enter__()
        self.engine.module.eval()

    def __exit__(self, exc_type, exc_value, traceback):
        assert isinstance(self.engine, DiffusersFSDPEngine)
        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
        # Reshard the root FSDP module so its parameters do not stay unsharded
        # (fully materialized) after the eval forward.
        if self.engine.engine_config.fsdp_size > 1:
            if fsdp_version(self.engine.module) == 1:
                self.engine.module._handle.reshard(True)
            elif fsdp_version(self.engine.module) == 2:
                self.engine.module.reshard()

        super().__exit__(exc_type, exc_value, traceback)


class EngineTrainModeCtx(BaseEngineCtx):
    """Puts a ``DiffusersFSDPEngine`` module in train mode; zeroes optimizer grads on exit."""

    def __init__(self, engine: DiffusersFSDPEngine, **kwargs):
        super().__init__(engine=engine, mode="train", **kwargs)

    def __enter__(self):
        # Local import, kept as in the original (presumably to avoid an import cycle; verify).
        from verl_omni.pipelines.model_base import DiffusionModelBase

        assert isinstance(self.engine, DiffusersFSDPEngine)
        super().__enter__()
        self.engine.module.train()
        # Let the pipeline-specific model class apply its own train-mode configuration.
        DiffusionModelBase.get_class(self.engine.model_config).configure_train_mode(self.engine.module)

    def __exit__(self, exc_type, exc_value, traceback):
        assert isinstance(self.engine, DiffusersFSDPEngine)
        self.engine.optimizer_zero_grad()
        super().__exit__(exc_type, exc_value, traceback)