# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FSDP engines for diffusion models."""
import gc
import json
import logging
import os
import warnings
from abc import ABC, abstractmethod
from contextlib import nullcontext
from typing import Callable, Optional
import torch
import torch.distributed
from tensordict import TensorDict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType
from torch.distributed.tensor import DTensor
from verl.trainer.config import CheckpointConfig
from verl.utils import tensordict_utils as tu
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.device import get_device_id, get_device_name
from verl.utils.fsdp_utils import (
CPUOffloadPolicy,
FSDPModule,
MixedPrecisionPolicy,
apply_fsdp2,
fsdp2_clip_grad_norm_,
fsdp2_load_full_state_dict,
fsdp_version,
get_fsdp_wrap_policy,
get_init_weight_context_manager,
init_fn,
load_fsdp_model_to_gpu,
load_fsdp_optimizer,
offload_fsdp_model_to_cpu,
offload_fsdp_optimizer,
replace_lora_wrapper,
)
from verl.utils.memory_utils import aggressive_empty_cache
from verl.utils.model import convert_weight_keys
from verl.utils.py_functional import append_to_dict
from verl.workers.config import FSDPEngineConfig, FSDPOptimizerConfig
from verl.workers.engine.base import BaseEngine, BaseEngineCtx, EngineRegistry
from verl.workers.engine.fsdp.utils import create_device_mesh, get_sharding_strategy
from verl.workers.engine.utils import enable_full_determinism, prepare_micro_batches
from verl_omni.pipelines.utils import (
build_scheduler,
forward,
forward_and_sample_previous_step,
prepare_model_inputs,
prepare_noisy_latents,
)
from verl_omni.utils.fsdp_utils import collect_lora_params
from verl_omni.workers.config import DiffusionModelConfig
from verl_omni.workers.engine.lora_adapter_mixin import LoRAAdapterMixin
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
device_name = get_device_name()
[docs]
class DiffusersFSDPEngine(LoRAAdapterMixin, BaseEngine, ABC):
"""Base Diffusers engine using PyTorch FullyShardedDataParallel (FSDP).
Supports model sharding, activation/optimizer offloading, LoRA, and sequence parallelism.
"""
[docs]
def __init__(
self,
model_config: DiffusionModelConfig,
engine_config: FSDPEngineConfig,
optimizer_config: FSDPOptimizerConfig,
checkpoint_config: CheckpointConfig,
):
"""
Initialize the DiffusersFSDPEngine.
Sets up distributed device meshes, LoRA, and offload policies based on config.
Args:
config: Configuration object with FSDP and model settings.
"""
super().__init__()
self.model_config = model_config
self.engine_config = engine_config
self.optimizer_config = optimizer_config
self.checkpoint_config = checkpoint_config
self.mode = None
self.rank = torch.distributed.get_rank()
self._init_device_mesh()
if self.engine_config.full_determinism:
enable_full_determinism(seed=self.engine_config.seed)
# set FSDP offload params
self._is_offload_param = self.engine_config.param_offload
self._is_offload_optimizer = self.engine_config.optimizer_offload
self._is_lora = self.model_config.lora_rank > 0
@property
def is_param_offload_enabled(self) -> bool:
return self._is_offload_param
@property
def is_optimizer_offload_enabled(self) -> bool:
return self._is_offload_optimizer
def is_mp_src_rank_with_outputs(self):
if self.ulysses_device_mesh is not None:
is_collect = self.ulysses_device_mesh["ulysses"].get_local_rank() == 0
else:
is_collect = True
return is_collect
[docs]
def initialize(self):
"""
Build the model, optimizer, and learning rate scheduler under FSDP.
Applies device, dtype, and precision configurations, including mixed precision.
Sets up checkpoint manager and FLOPs counter.
"""
# This is used to import external_lib into the huggingface systems
self._build_model_optimizer()
self.checkpoint_manager = FSDPCheckpointManager(
model=self.module,
optimizer=self.optimizer,
lr_scheduler=self.lr_scheduler,
processing_class=self.model_config.get_processor(),
checkpoint_config=self.checkpoint_config,
trust_remote_code=self.model_config.trust_remote_code,
)
self.to(
device="cpu",
model=self._is_offload_param,
optimizer=self._is_offload_optimizer,
grad=self._is_offload_param,
)
log_gpu_memory_usage("After offload model/optimizer/grad during init", logger=logger)
def _init_device_mesh(self):
world_size = torch.distributed.get_world_size()
from torch.distributed.device_mesh import init_device_mesh
fsdp_size = self.engine_config.fsdp_size
self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)
self.ulysses_device_mesh = None
self.ulysses_sequence_parallel_size = self.engine_config.ulysses_sequence_parallel_size
dp_size = self.get_data_parallel_size()
if self.ulysses_sequence_parallel_size > 1:
import diffusers
from packaging import version
if version.parse(diffusers.__version__) < version.parse("0.38.0"):
raise RuntimeError(
f"Ulysses sequence parallelism requires diffusers >= 0.38.0 (found {diffusers.__version__}). "
)
# diffusers' ContextParallelConfig.setup() unconditionally accesses self._mesh["ring", "ulysses"],
# so the mesh must have both named dimensions even though ring attention is not used.
self.ulysses_device_mesh = init_device_mesh(
device_name,
mesh_shape=(dp_size, 1, self.ulysses_sequence_parallel_size),
mesh_dim_names=["dp", "ring", "ulysses"],
)
self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1
# TODO (mike): we will drop this after it supports in diffusers.
if self.use_ulysses_sp and self.model_config.attn_backend == "_flash_3_varlen_hub":
raise ValueError(
"_flash_3_varlen_hub does not support sequence parallelism. "
"Set fsdp_config.ulysses_sequence_parallel_size=1 or switch to a different attn_backend."
)
def _build_module_from_registry(self, torch_dtype: torch.dtype) -> Optional[torch.nn.Module]:
"""Try loading via ``DiffusionModelBase.build_module()``.
Returns ``None`` if the registry has no custom loader, so the
caller falls back to ``diffusers.AutoModel``.
"""
from verl_omni.pipelines.model_base import DiffusionModelBase
model_cls = DiffusionModelBase.get_class(self.model_config)
module = model_cls.build_module(self.model_config, torch_dtype)
if module is None:
return None
logger.warning(
"Built %s via DiffusionModelBase custom loader; engine-level hooks "
"(attention processors, gradient-checkpointing wrappers, LoRA, "
"dtype upcast) may be partially effective or silently inactive. "
"See the docstring of _build_module_from_registry.",
type(module).__name__,
)
try:
module.to(torch_dtype)
except AttributeError:
raise TypeError(
f"{type(module).__name__} returned by build_module() has no to() method. "
"Custom models must be torch.nn.Module instances."
) from None
if self.model_config.enable_gradient_checkpointing:
try:
module.enable_gradient_checkpointing()
except AttributeError:
raise NotImplementedError(
f"Gradient checkpointing is enabled in config, but {type(module).__name__} "
"does not implement enable_gradient_checkpointing(). "
"Either implement it or set enable_gradient_checkpointing=False."
) from None
logger.info(
"Gradient checkpointing enabled on %s via enable_gradient_checkpointing().",
type(module).__name__,
)
module.can_generate = lambda: False
return module
def _build_module(self):
from diffusers import AutoModel
from verl.utils.torch_dtypes import PrecisionType
torch_dtype = self.engine_config.model_dtype
if torch_dtype is None:
# if it is training, we force torch_dtype to fp32
torch_dtype = torch.float32 if not self.engine_config.forward_only else torch.bfloat16
torch_dtype = PrecisionType.to_dtype(torch_dtype)
module = self._build_module_from_registry(torch_dtype)
if module is not None:
return module
# Default path: load via diffusers AutoModel
init_context = get_init_weight_context_manager(use_meta_tensor=True, mesh=self.device_mesh)
with init_context(), warnings.catch_warnings():
warnings.simplefilter("ignore")
module = AutoModel.from_pretrained(
self.model_config.config_path or self.model_config.local_path,
torch_dtype=torch_dtype,
trust_remote_code=self.model_config.trust_remote_code,
subfolder="" if self.model_config.config_path else self.model_config.transformer_subfolder,
)
try:
module.set_attention_backend(self.model_config.attn_backend)
except Exception as e:
if self.model_config.attn_backend == "_flash_3_varlen_hub":
logger.warning(
"Failed to set attention backend to %s (%s). Falling back to 'native' attention backend.",
self.model_config.attn_backend,
e,
)
object.__setattr__(self.model_config, "attn_backend", "native")
module.set_attention_backend("native")
else:
raise e
# some parameters may not in torch_dtype
module.to(torch_dtype)
if self.model_config.enable_gradient_checkpointing:
module.enable_gradient_checkpointing()
# patch for checkpoint saving
def save_config(self, save_directory: str | os.PathLike):
output_config_file = os.path.join(save_directory, "config.json")
with open(output_config_file, "w", encoding="utf-8") as f:
json.dump(self, f, indent=4, sort_keys=True)
module.can_generate = lambda: False
module.config.save_pretrained = save_config.__get__(module.config)
return module
def _build_fsdp_module(self, module):
# TODO(ziheng): need to improve
from torch.distributed.fsdp import CPUOffload, MixedPrecision
from verl.utils.torch_dtypes import PrecisionType
mixed_precision_config = self.engine_config.mixed_precision
if mixed_precision_config is not None:
param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32"))
buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32"))
else:
param_dtype = torch.bfloat16
reduce_dtype = torch.float32
buffer_dtype = torch.float32
mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)
auto_wrap_policy = get_fsdp_wrap_policy(
module=module,
config=self.engine_config.wrap_policy,
is_lora=self.model_config.lora_rank > 0,
)
fsdp_mesh = self.device_mesh
sharding_strategy = get_sharding_strategy(fsdp_mesh)
# Note: We force turn off CPUOffload because it causes incorrect results when using grad accumulation
if self.engine_config.strategy == "fsdp":
# cpu_offload:
# - actor: None
# - critic: None
# - ref: CPUOffload(offload_params=True)
# We force reference policy to use CPUOffload to save memory.
# We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation
cpu_offload = None
if self.engine_config.forward_only:
cpu_offload = CPUOffload(offload_params=True)
self._is_offload_param = False
self._is_offload_optimizer = False
module = FSDP(
module,
param_init_fn=init_fn,
auto_wrap_policy=auto_wrap_policy,
device_id=get_device_id(),
sharding_strategy=sharding_strategy,
mixed_precision=mixed_precision,
sync_module_states=True,
device_mesh=self.device_mesh,
forward_prefetch=self.engine_config.forward_prefetch,
use_orig_params=self.engine_config.use_orig_params,
cpu_offload=cpu_offload,
)
elif self.engine_config.strategy == "fsdp2":
# - actor: offload_policy
# - critic: offload_policy
# - ref: CPUOffloadPolicy(pin_memory=True)
assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
mp_policy = MixedPrecisionPolicy(
param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True
)
offload_policy = None
if self.engine_config.offload_policy or self.engine_config.forward_only:
self._is_offload_param = False
self._is_offload_optimizer = False
offload_policy = CPUOffloadPolicy(pin_memory=True)
fsdp_kwargs = {
"mesh": fsdp_mesh,
"mp_policy": mp_policy,
"offload_policy": offload_policy,
"reshard_after_forward": self.engine_config.reshard_after_forward,
}
full_state = module.state_dict()
apply_fsdp2(module, fsdp_kwargs, self.engine_config)
fsdp2_load_full_state_dict(module, full_state, fsdp_mesh, offload_policy)
else:
raise NotImplementedError(f"Unknown strategy {self.engine_config.strategy}")
if torch.distributed.get_world_size() == 1 and fsdp_version(module) == 1:
FSDP.set_state_dict_type(
module,
state_dict_type=StateDictType.FULL_STATE_DICT,
state_dict_config=FullStateDictConfig(),
)
elif fsdp_version(module) == 1:
FSDP.set_state_dict_type(
module,
state_dict_type=StateDictType.SHARDED_STATE_DICT,
state_dict_config=ShardedStateDictConfig(),
)
return module
def _build_scheduler(self):
return build_scheduler(self.model_config)
def _build_optimizer(self, module):
from verl.workers.config.optimizer import build_optimizer
optimizer = build_optimizer(module.parameters(), self.optimizer_config)
return optimizer
def _build_lr_scheduler(self, optimizer):
from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup
optim_config = self.optimizer_config
total_steps = optim_config.total_training_steps
num_warmup_steps = optim_config.lr_warmup_steps
lr_scheduler_type = optim_config.lr_scheduler_type
min_lr_ratio = optim_config.min_lr_ratio
num_cycles = optim_config.num_cycles
zero_indexed_step = optim_config.zero_indexed_step
if num_warmup_steps <= 0:
num_warmup_steps_ratio = optim_config.lr_warmup_steps_ratio
num_warmup_steps = int(num_warmup_steps_ratio * total_steps)
if self.rank == 0:
print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")
if lr_scheduler_type == "constant":
lr_scheduler = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=num_warmup_steps)
elif lr_scheduler_type == "cosine":
lr_scheduler = get_cosine_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=total_steps,
min_lr_ratio=min_lr_ratio,
num_cycles=num_cycles,
zero_indexed_step=zero_indexed_step,
)
else:
raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported")
return lr_scheduler
def _build_model_optimizer(self):
from diffusers import ContextParallelConfig
from verl.utils.model import print_model_size
# Load base model with specified configuration and dtype
module = self._build_module()
# Apply LoRA adapters if low-rank adaptation is enabled
if self._is_lora:
module = self._build_lora_module(module)
if self.use_ulysses_sp:
sp_size = self.ulysses_sequence_parallel_size
module.enable_parallelism(
config=ContextParallelConfig(ulysses_degree=sp_size, mesh=self.ulysses_device_mesh)
)
# Load diffusion scheduler
scheduler = self._build_scheduler()
# Synchronize all distributed processes before proceeding
torch.distributed.barrier()
if self.rank == 0:
print_model_size(module)
log_gpu_memory_usage("After init model from Diffusers AutoModel", logger=logger)
# Wrap model with FSDP for distributed training (sharding, mixed precision, etc.)
log_gpu_memory_usage("Before FSDP", logger=None)
module = self._build_fsdp_module(module)
log_gpu_memory_usage("After FSDP", logger=None)
if not self.engine_config.forward_only:
# Initialize optimizer with model parameters and config settings
optimizer = self._build_optimizer(module)
# Create learning rate scheduler with warmup and decay settings
lr_scheduler = self._build_lr_scheduler(optimizer)
else:
optimizer = None
lr_scheduler = None
self.module = module
self.scheduler = scheduler
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
[docs]
def train_mode(self, **kwargs):
"""
Return a context manager that switches to training mode with FSDP-specific handling.
Includes parameter and optimizer offload entry/exit.
"""
return EngineTrainModeCtx(self, **kwargs)
[docs]
def eval_mode(self, **kwargs):
"""
Return a context manager that switches to evaluation mode with FSDP-specific handling.
Includes activation offload entry/exit.
"""
return EngineEvalModeCtx(self, **kwargs)
def get_data_parallel_rank(self):
if self.ulysses_device_mesh is not None:
return self.ulysses_device_mesh["dp"].get_local_rank()
else:
return torch.distributed.get_rank()
def get_data_parallel_size(self):
return torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size
def get_data_parallel_group(self):
if self.ulysses_device_mesh is not None:
return self.ulysses_device_mesh.get_group(mesh_dim="dp")
else:
return torch.distributed.group.WORLD
def get_model_parallel_group(self):
raise NotImplementedError
def get_context_parallel_group(self):
raise NotImplementedError
def postprocess_batch_func(self, output_lst, indices, data: TensorDict):
model_output = {}
losses = []
aggregated_metrics = {}
for output in output_lst:
# model output list
model_output_lst = {}
if "model_output" in output:
for model_output_dict in output["model_output"]:
for key, val in model_output_dict.items():
model_output_lst.setdefault(key, []).append(val)
for key, val in model_output_lst.items():
model_output.setdefault(key, []).append(torch.stack(val, dim=1)) # (bsz, steps, ...)
# loss
if "loss" in output:
losses.append(output["loss"])
# metrics
if "metrics" in output:
for metrics in output["metrics"]:
append_to_dict(aggregated_metrics, metrics)
# concat results from micro batches
for key, val in model_output.items():
model_output[key] = torch.concat(val, dim=0) # (global_bsz, steps, ...)
output = {
"model_output": model_output, # a dict of tensors in shape (global_bsz, steps, ...)
"loss": losses, # micro-batch step-wise losses
"metrics": aggregated_metrics,
}
return output
@staticmethod
def _unpad_nested_embeds(embeds: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Convert a jagged nested tensor pair (embeds, mask) to dense padded tensors."""
batch_size = embeds.size(0)
max_seq_len = max(embeds.offsets().diff())
embed_dim = embeds.size(-1)
embeds = torch.nested.to_padded_tensor(embeds, padding=0, output_size=(batch_size, max_seq_len, embed_dim))
mask = torch.nested.to_padded_tensor(mask, padding=0, output_size=(batch_size, max_seq_len))
return embeds, mask
@staticmethod
def _pad_embeds_for_sp(embeds: torch.Tensor, mask: torch.Tensor, sp_size: int) -> tuple[torch.Tensor, torch.Tensor]:
"""Pad sequence dimension of (embeds, mask) to a multiple of sp_size."""
seq_len = embeds.size(1)
aligned_seq_len = (seq_len + sp_size - 1) // sp_size * sp_size
if aligned_seq_len > seq_len:
pad_len = aligned_seq_len - seq_len
embeds = torch.nn.functional.pad(embeds, (0, 0, 0, pad_len))
mask = torch.nn.functional.pad(mask, (0, pad_len))
return embeds, mask
@abstractmethod
def forward_backward_batch(
self, data: TensorDict, loss_function: Callable, forward_only: bool = False
) -> list[TensorDict]:
"""Run forward/backward over a batch; implemented by algorithm-specific subclasses."""
pass
@abstractmethod
def prepare_model_inputs(self, micro_batch: TensorDict, step: int):
"""Build model inputs for one diffusion step; implemented by algorithm-specific subclasses."""
pass
@abstractmethod
def prepare_model_outputs(self, output, micro_batch: TensorDict):
"""Post-process raw model output; implemented by algorithm-specific subclasses."""
pass
[docs]
@abstractmethod
def forward_step(self, micro_batch: TensorDict, loss_function, forward_only, step):
"""Run one diffusion step forward (and loss); implemented by algorithm-specific subclasses."""
pass
[docs]
def optimizer_zero_grad(self):
"""
Zero gradients and enforce FSDP grad-clipping logic.
"""
self.optimizer.zero_grad()
[docs]
def optimizer_step(self):
"""
Clip gradients, skip update if non-finite, and step optimizer.
Returns:
grad_norm (float): Norm of gradients before clipping.
"""
assert self.optimizer_config.clip_grad is not None
if isinstance(self.module, FSDP):
grad_norm = self.module.clip_grad_norm_(self.optimizer_config.clip_grad)
elif isinstance(self.module, FSDPModule):
grad_norm = fsdp2_clip_grad_norm_(self.module.parameters(), max_norm=self.optimizer_config.clip_grad)
else:
grad_norm = torch.nn.utils.clip_grad_norm_(
self.module.parameters(), max_norm=self.optimizer_config.clip_grad
)
if isinstance(grad_norm, DTensor):
grad_norm = grad_norm.full_tensor()
# if grad_norm is not finite, skip the update
if not torch.isfinite(grad_norm):
print(f"WARN: grad_norm is not finite: {grad_norm}")
self.optimizer.zero_grad()
else:
self.optimizer.step()
return grad_norm.item()
[docs]
def lr_scheduler_step(self):
"""
Advance FSDP scheduler and return updated learning rate.
"""
self.lr_scheduler.step()
lr = self.lr_scheduler.get_last_lr()[0] # only return the first group
return lr
[docs]
def to(self, device: str, model: bool = True, optimizer: bool = True, grad: bool = True):
"""
Move FSDP model and/or optimizer to CPU or GPU with offload support.
Note that this function executes irrespective of offload config. It serves as manual control
"""
super().to(device=device, model=model, optimizer=optimizer, grad=grad)
if self.engine_config.forward_only:
# force cpu_offload
return
device_name = get_device_name()
assert device in (device_name, "cpu")
if device == device_name:
if model:
load_fsdp_model_to_gpu(self.module)
if optimizer and self.optimizer is not None:
load_fsdp_optimizer(self.optimizer, device)
gc.collect()
elif device == "cpu":
if model:
offload_fsdp_model_to_cpu(self.module)
if optimizer and self.optimizer is not None:
offload_fsdp_optimizer(self.optimizer)
else:
raise ValueError(f"Invalid device type: {device}")
[docs]
def save_checkpoint(
self,
local_path: str,
hdfs_path: Optional[str] = None,
global_step: int = 0,
max_ckpt_to_keep: Optional[int] = None,
**kwargs,
) -> None:
"""
Save FSDP checkpoint, handling parameter offload as needed.
"""
origin_module_device = next(self.module.parameters()).device.type
if self._is_offload_param or origin_module_device == "cpu":
load_fsdp_model_to_gpu(self.module)
self.checkpoint_manager.save_checkpoint(
local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep
)
torch.distributed.barrier()
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.module)
gc.collect()
aggressive_empty_cache(force_sync=True)
[docs]
def load_checkpoint(
self, local_path: str, hdfs_path: Optional[str] = None, del_local_after_load: int = True, **kwargs
) -> None:
"""
Load FSDP checkpoint, restoring parameters and optimizer state.
"""
import torch
if self._is_offload_param:
load_fsdp_model_to_gpu(self.module)
self.checkpoint_manager.load_checkpoint(
local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load
)
torch.distributed.barrier()
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.module)
if self._is_offload_optimizer:
offload_fsdp_optimizer(self.optimizer)
[docs]
def get_per_tensor_param(
self, layered_summon=False, base_sync_done=False, adapter_name: str | None = None, **kwargs
):
log_gpu_memory_usage("Before load_fsdp_model_to_gpu", logger=logger)
load_fsdp_model_to_gpu(self.module)
log_gpu_memory_usage("After load_fsdp_model_to_gpu", logger=logger)
peft_config = None
peft_model = getattr(self.module, "_fsdp_wrapped_module", self.module)
if hasattr(peft_model, "peft_config"): # LoRA
peft_config = peft_model.peft_config.get("default", None)
adapter_ctx = self.use_adapter(adapter_name) if adapter_name is not None else nullcontext()
with adapter_ctx:
params = collect_lora_params(
module=self.module,
layered_summon=layered_summon,
base_sync_done=base_sync_done,
is_diffusers=True,
adapter_name=adapter_name or "default",
layer_prefixes=self.model_config.fsdp_layer_prefixes,
)
if not base_sync_done:
params = {replace_lora_wrapper(k, peft_config): v for k, v in params.items()}
else:
params = self.module.state_dict()
params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))
log_gpu_memory_usage("Before offload_fsdp_model_to_cpu", logger=logger)
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.module)
log_gpu_memory_usage("After offload_fsdp_model_to_cpu", logger=logger)
if peft_config is not None and base_sync_done:
per_tensor_param = params.items()
else:
device = get_device_id() # used when fsdp2 set cpu_offload_policy
# TODO: cast fp32 to bf16 to reduce weight sync overhead, need more fine-grained control, e.g MoE gate
per_tensor_param = (
(
name,
param.to(device, non_blocking=True).full_tensor().to(torch.bfloat16, non_blocking=True)
if isinstance(param, DTensor)
else param,
)
for name, param in params.items()
)
# we need to add the prefix to make it compatible with rollout engine
per_tensor_param = ((f"transformer.{name}", tensor) for name, tensor in per_tensor_param)
peft_config_dict = peft_config.to_dict() if peft_config is not None else None
return per_tensor_param, peft_config_dict
def _run_forward_backward_batch(
self,
data: TensorDict,
loss_function: Callable,
forward_only: bool,
*,
timesteps_key: str,
) -> dict:
num_timesteps = data[timesteps_key].shape[1]
tu.assign_non_tensor(data, sp_size=self.ulysses_sequence_parallel_size)
tu.assign_non_tensor(data, use_dynamic_bsz=False)
micro_batches, indices = prepare_micro_batches(
data=data, dp_group=self.get_data_parallel_group(), same_micro_num_in_dp=True
)
gradient_accumulation_steps = len(micro_batches) * num_timesteps
output_lst = []
ctx = torch.no_grad() if forward_only else nullcontext()
for micro_batch in micro_batches:
micro_batch = micro_batch.to(get_device_id())
tu.assign_non_tensor(micro_batch, gradient_accumulation_steps=gradient_accumulation_steps)
meta_info_lst = {"model_output": [], "loss": [], "metrics": []}
# Forward and backward for each timestep
with ctx:
for step in range(num_timesteps):
loss, meta_info = self.forward_step(
micro_batch, loss_function=loss_function, forward_only=forward_only, step=step
)
if not forward_only:
loss.backward()
for key, val in meta_info.items():
meta_info_lst[key].append(val)
output_lst.append(meta_info_lst)
# postprocess and return
return self.postprocess_batch_func(output_lst=output_lst, indices=indices, data=data)
[docs]
@EngineRegistry.register(model_type="diffusion_model", backend=["fsdp", "fsdp2"], device=["cuda", "npu"])
class PPODiffusersFSDPEngine(DiffusersFSDPEngine):
"""Diffusers FSDP engine with PPO forward/backward and I/O preparation."""
[docs]
def forward_backward_batch(
self, data: TensorDict, loss_function: Callable, forward_only: bool = False
) -> list[TensorDict]:
return self._run_forward_backward_batch(data, loss_function, forward_only, timesteps_key="all_timesteps")
[docs]
def prepare_model_outputs(self, output, micro_batch: TensorDict):
log_prob, prev_sample_mean, std_dev_t, sqrt_dt = output
return {
"log_probs": log_prob,
"prev_sample_mean": prev_sample_mean,
"std_dev_t": std_dev_t,
"sqrt_dt": sqrt_dt,
}
[docs]
def forward_step(self, micro_batch: TensorDict, loss_function, forward_only, step):
model_inputs, negative_model_inputs = self.prepare_model_inputs(micro_batch=micro_batch, step=step)
raw_output = forward_and_sample_previous_step(
module=self.module,
scheduler=self.scheduler,
model_config=self.model_config,
model_inputs=model_inputs,
negative_model_inputs=negative_model_inputs,
scheduler_inputs=micro_batch,
step=step,
)
model_output = self.prepare_model_outputs(output=raw_output, micro_batch=micro_batch)
if loss_function is not None:
data = tu.get_tensordict(
{
"old_log_probs": micro_batch["old_log_probs"][:, step],
"advantages": micro_batch["advantages"][:, step],
},
)
tu.assign_non_tensor(
data,
gradient_accumulation_steps=tu.get_non_tensor_data(
micro_batch, "gradient_accumulation_steps", default=None
),
sp_size=tu.get_non_tensor_data(micro_batch, "sp_size", default=None),
)
# TODO (mike): refactor the data preparation logic here
if micro_batch.get("ref_log_prob", None) is not None:
data["ref_log_prob"] = micro_batch["ref_log_prob"][:, step]
if micro_batch.get("ref_prev_sample_mean", None) is not None:
data["ref_prev_sample_mean"] = micro_batch["ref_prev_sample_mean"][:, step]
if micro_batch.get("old_prev_sample_mean", None) is not None:
data["old_prev_sample_mean"] = micro_batch["old_prev_sample_mean"][:, step]
if micro_batch.get("rollout_is_weights", None) is not None:
data["rollout_is_weights"] = micro_batch["rollout_is_weights"][:, step]
loss, metrics = loss_function(model_output=model_output, data=data, dp_group=self.get_data_parallel_group())
else:
assert forward_only, "forward_only must be True when loss_function is None"
loss = torch.tensor(1.0, device=device_name)
metrics = {}
output = {
"model_output": model_output,
"loss": loss.detach().item(),
"metrics": metrics,
}
return loss, output
@EngineRegistry.register(model_type="diffusion_dpo_model", backend=["fsdp", "fsdp2"], device=["cuda", "npu"])
class DPODiffusersFSDPEngine(DiffusersFSDPEngine):
"""Diffusers FSDP engine variant for diffusion DPO."""
def _prepare_noisy_latents(self, micro_batch: TensorDict) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
latents = micro_batch.get("latents_clean", None)
if latents is None:
raise KeyError("Diffusion DPO training requires `latents_clean` in the micro batch.")
return prepare_noisy_latents(
latents=latents,
scheduler=self.scheduler,
noise=micro_batch.get(
"noise", None
), # if noise is not provided, sample noise and timesteps in the forward step
timesteps=micro_batch.get(
"timesteps", None
), # if timesteps is not provided, sample timesteps in the forward step
)
def _prepare_prompt_embeds(
self,
prompt_embeds: torch.Tensor,
prompt_embeds_mask: torch.Tensor | None,
negative_prompt_embeds: torch.Tensor | None,
negative_prompt_embeds_mask: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
"""Apply common nested-tensor and sequence-parallel padding to prompt embeds."""
sp_size = self.ulysses_sequence_parallel_size if self.use_ulysses_sp else 1
if not isinstance(prompt_embeds_mask, torch.Tensor):
prompt_embeds_mask = None
if not isinstance(negative_prompt_embeds, torch.Tensor):
negative_prompt_embeds = None
if not isinstance(negative_prompt_embeds_mask, torch.Tensor):
negative_prompt_embeds_mask = None
if prompt_embeds.is_nested:
prompt_embeds, prompt_embeds_mask = self._unpad_nested_embeds(prompt_embeds, prompt_embeds_mask)
if sp_size > 1:
prompt_embeds, prompt_embeds_mask = self._pad_embeds_for_sp(prompt_embeds, prompt_embeds_mask, sp_size)
if isinstance(negative_prompt_embeds, torch.Tensor) and negative_prompt_embeds.is_nested:
negative_prompt_embeds, negative_prompt_embeds_mask = self._unpad_nested_embeds(
negative_prompt_embeds, negative_prompt_embeds_mask
)
if isinstance(negative_prompt_embeds, torch.Tensor) and sp_size > 1:
negative_prompt_embeds, negative_prompt_embeds_mask = self._pad_embeds_for_sp(
negative_prompt_embeds, negative_prompt_embeds_mask, sp_size
)
return prompt_embeds, prompt_embeds_mask, negative_prompt_embeds, negative_prompt_embeds_mask
def prepare_model_inputs(self, micro_batch: TensorDict, step: int):
del step
noisy_latents, noise, timesteps = self._prepare_noisy_latents(micro_batch)
latent = micro_batch["latents_clean"].to(device=noise.device, dtype=noise.dtype)
prompt_embeds = micro_batch["prompt_embeds"]
prompt_embeds_mask = micro_batch.get("prompt_embeds_mask", None)
negative_prompt_embeds = micro_batch.get("negative_prompt_embeds", None)
negative_prompt_embeds_mask = micro_batch.get("negative_prompt_embeds_mask", None)
prompt_embeds, prompt_embeds_mask, negative_prompt_embeds, negative_prompt_embeds_mask = (
self._prepare_prompt_embeds(
prompt_embeds=prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
)
)
model_inputs, negative_model_inputs = prepare_model_inputs(
module=self.module,
model_config=self.model_config,
latents=noisy_latents,
timesteps=timesteps,
prompt_embeds=prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
micro_batch=micro_batch,
step=0,
)
return model_inputs, negative_model_inputs, {"noise": noise, "latent": latent, "timesteps": timesteps}
def prepare_model_outputs(self, output, micro_batch: TensorDict):
del micro_batch
noise_pred, dpo_context = output
return {
"noise_pred": noise_pred,
"noise": dpo_context["noise"],
"latent": dpo_context["latent"],
"timesteps": dpo_context["timesteps"],
}
def forward_backward_batch(
self, data: TensorDict, loss_function: Callable, forward_only: bool = False
) -> list[TensorDict]:
tu.assign_non_tensor(data, sp_size=self.ulysses_sequence_parallel_size)
tu.assign_non_tensor(data, use_dynamic_bsz=False)
micro_batches, indices = prepare_micro_batches(
data=data, dp_group=self.get_data_parallel_group(), same_micro_num_in_dp=True
)
gradient_accumulation_steps = len(micro_batches)
output_lst = []
ctx = torch.no_grad() if forward_only else nullcontext()
for micro_batch in micro_batches:
micro_batch = micro_batch.to(get_device_id())
tu.assign_non_tensor(micro_batch, gradient_accumulation_steps=gradient_accumulation_steps)
meta_info_lst = {"model_output": [], "loss": [], "metrics": []}
with ctx:
# DPO is a one-shot flow-matching objective over final image latents,
# not a reversed-sampling objective over every rollout timestep.
loss, meta_info = self.forward_step(
micro_batch,
loss_function=loss_function,
forward_only=forward_only,
step=None, # use random step for DPO
)
if not forward_only:
loss.backward()
for key, val in meta_info.items():
meta_info_lst[key].append(val)
output_lst.append(meta_info_lst)
# postprocess and return
return self.postprocess_batch_func(output_lst=output_lst, indices=indices, data=data)
def forward_step(self, micro_batch: TensorDict, loss_function, forward_only, step):
model_inputs, negative_model_inputs, dpo_context = self.prepare_model_inputs(micro_batch=micro_batch, step=step)
noise_pred = forward(
module=self.module,
model_config=self.model_config,
model_inputs=model_inputs,
negative_model_inputs=negative_model_inputs,
)
model_output = self.prepare_model_outputs(output=(noise_pred, dpo_context), micro_batch=micro_batch)
if loss_function is not None:
data = tu.get_tensordict({"sample_level_rewards": micro_batch["sample_level_rewards"]})
uid = tu.get_non_tensor_data(micro_batch, "uid", default=None)
tu.assign_non_tensor(
data,
gradient_accumulation_steps=tu.get_non_tensor_data(
micro_batch, "gradient_accumulation_steps", default=None
),
sp_size=tu.get_non_tensor_data(micro_batch, "sp_size", default=None),
)
if uid is not None:
tu.assign_non_tensor(data, uid=uid)
if micro_batch.get("ref_noise_pred", None) is not None:
ref_noise_pred = micro_batch["ref_noise_pred"]
if ref_noise_pred.ndim == model_output["noise_pred"].ndim + 1 and ref_noise_pred.shape[1] == 1:
ref_noise_pred = ref_noise_pred[:, 0]
data["ref_noise_pred"] = ref_noise_pred
loss, metrics = loss_function(model_output=model_output, data=data, dp_group=self.get_data_parallel_group())
else:
assert forward_only, "forward_only must be True when loss_function is None"
loss = torch.tensor(1.0, device=device_name)
metrics = {}
output = {
"model_output": model_output,
"loss": loss.detach().item(),
"metrics": metrics,
}
return loss, output
@EngineRegistry.register(model_type="diffusion_nft_model", backend=["fsdp", "fsdp2"], device=["cuda", "npu"])
class NFTDiffusersFSDPEngine(DiffusersFSDPEngine):
"""Diffusers FSDP engine for direct-preference / forward-process objectives (e.g. DiffusionNFT)."""
def forward_backward_batch(
self, data: TensorDict, loss_function: Callable, forward_only: bool = False
) -> list[TensorDict]:
return self._run_forward_backward_batch(data, loss_function, forward_only, timesteps_key="train_timesteps")
def prepare_model_inputs(self, micro_batch: TensorDict, step: int):
x0 = micro_batch["latents_clean"]
timestep = micro_batch["train_timesteps"][:, step]
t = timestep.float() / 1000.0
t_expanded = t.view(-1, *([1] * (x0.ndim - 1)))
if micro_batch.get("forward_noise", None) is not None:
forward_noise = micro_batch["forward_noise"]
noise = forward_noise[:, step] if forward_noise.ndim == x0.ndim + 1 else forward_noise
else:
noise = torch.randn_like(x0.float())
xt = (1.0 - t_expanded) * x0 + t_expanded * noise
prompt_embeds = micro_batch["prompt_embeds"]
prompt_embeds_mask = micro_batch["prompt_embeds_mask"]
negative_prompt_embeds = micro_batch.get("negative_prompt_embeds", None)
negative_prompt_embeds_mask = micro_batch.get("negative_prompt_embeds_mask", None)
sp_size = self.ulysses_sequence_parallel_size if self.use_ulysses_sp else 1
if prompt_embeds.is_nested:
prompt_embeds, prompt_embeds_mask = self._unpad_nested_embeds(prompt_embeds, prompt_embeds_mask)
if sp_size > 1:
prompt_embeds, prompt_embeds_mask = self._pad_embeds_for_sp(prompt_embeds, prompt_embeds_mask, sp_size)
if isinstance(negative_prompt_embeds, torch.Tensor) and negative_prompt_embeds.is_nested:
negative_prompt_embeds, negative_prompt_embeds_mask = self._unpad_nested_embeds(
negative_prompt_embeds, negative_prompt_embeds_mask
)
if isinstance(negative_prompt_embeds, torch.Tensor) and sp_size > 1:
negative_prompt_embeds, negative_prompt_embeds_mask = self._pad_embeds_for_sp(
negative_prompt_embeds, negative_prompt_embeds_mask, sp_size
)
model_inputs, negative_model_inputs = prepare_model_inputs(
module=self.module,
model_config=self.model_config,
latents=xt,
timesteps=timestep,
prompt_embeds=prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
micro_batch=micro_batch,
step=step,
)
return model_inputs, negative_model_inputs, x0, xt, t_expanded
def prepare_model_outputs(self, output, micro_batch: TensorDict):
old_prediction, forward_prediction, ref_forward_prediction, x0, xt, t_expanded = output
return {
"old_prediction": old_prediction,
"forward_prediction": forward_prediction,
"ref_forward_prediction": ref_forward_prediction,
"x0": x0,
"xt": xt,
"t_expanded": t_expanded,
}
def forward_step(self, micro_batch: TensorDict, loss_function, forward_only, step):
model_inputs, negative_model_inputs, x0, xt, t_expanded = self.prepare_model_inputs(
micro_batch=micro_batch, step=step
)
with self.use_adapter("old"), torch.no_grad():
old_prediction = forward(
module=self.module,
model_config=self.model_config,
model_inputs=model_inputs,
negative_model_inputs=negative_model_inputs,
).detach()
forward_prediction = forward(
module=self.module,
model_config=self.model_config,
model_inputs=model_inputs,
negative_model_inputs=negative_model_inputs,
)
with torch.no_grad():
with self.disable_adapter():
ref_forward_prediction = forward(
module=self.module,
model_config=self.model_config,
model_inputs=model_inputs,
negative_model_inputs=negative_model_inputs,
).detach()
self._set_adapter("default")
model_output = self.prepare_model_outputs(
output=(old_prediction, forward_prediction, ref_forward_prediction, x0, xt, t_expanded),
micro_batch=micro_batch,
)
if loss_function is not None:
data = tu.get_tensordict({"reward_prob": micro_batch["reward_prob"][:, step]})
tu.assign_non_tensor(
data,
gradient_accumulation_steps=tu.get_non_tensor_data(
micro_batch, "gradient_accumulation_steps", default=None
),
sp_size=tu.get_non_tensor_data(micro_batch, "sp_size", default=None),
)
loss, metrics = loss_function(model_output=model_output, data=data, dp_group=self.get_data_parallel_group())
else:
assert forward_only, "forward_only must be True when loss_function is None"
loss = torch.tensor(1.0, device=x0.device)
metrics = {}
output = {
"model_output": model_output,
"loss": loss.detach().item(),
"metrics": metrics,
}
return loss, output
class EngineEvalModeCtx(BaseEngineCtx):
def __init__(self, engine: DiffusersFSDPEngine, **kwargs):
super().__init__(engine=engine, mode="eval", **kwargs)
def __enter__(self):
assert isinstance(self.engine, DiffusersFSDPEngine)
super().__enter__()
self.engine.module.eval()
def __exit__(self, exc_type, exc_value, traceback):
assert isinstance(self.engine, DiffusersFSDPEngine)
# https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
# unshard the root FSDP module
if self.engine.engine_config.fsdp_size > 1:
if fsdp_version(self.engine.module) == 1:
self.engine.module._handle.reshard(True)
elif fsdp_version(self.engine.module) == 2:
self.engine.module.reshard()
super().__exit__(exc_type, exc_value, traceback)
class EngineTrainModeCtx(BaseEngineCtx):
def __init__(self, engine: DiffusersFSDPEngine, **kwargs):
super().__init__(engine=engine, mode="train", **kwargs)
def __enter__(self):
from verl_omni.pipelines.model_base import DiffusionModelBase
assert isinstance(self.engine, DiffusersFSDPEngine)
super().__enter__()
self.engine.module.train()
DiffusionModelBase.get_class(self.engine.model_config).configure_train_mode(self.engine.module)
def __exit__(self, exc_type, exc_value, traceback):
assert isinstance(self.engine, DiffusersFSDPEngine)
self.engine.optimizer_zero_grad()
super().__exit__(exc_type, exc_value, traceback)