# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
from omegaconf import MISSING
from verl.base_config import BaseConfig
from verl.trainer.config import CheckpointConfig
from verl.trainer.config.algorithm import RolloutCorrectionConfig
from verl.utils.profiler import ProfilerConfig
from verl.workers.config.engine import EngineConfig, FSDPEngineConfig
from verl.workers.config.optimizer import OptimizerConfig
from .model import DiffusionModelConfig
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
"DiffusionLossConfig",
"VeOmniDiffusionEngineConfig",
"VeOmniDiffusionOptimizerConfig",
"DiffusionActorConfig",
"FSDPDiffusionActorConfig",
"VeOmniDiffusionActorConfig",
]
[docs]
@dataclass
class DiffusionLossConfig(BaseConfig):
    """Loss settings for the diffusion actor.

    Validation happens once, at construction time:

    * ``loss_mode`` must be one of the recognised mode names.
    * ``adv_clip_max``, ``mix_beta`` and ``adaptive_weight_min`` must be
      strictly positive.

    The remaining fields (``clip_ratio``, ``ref_kl_coef``, ``dpo_beta``) are
    stored without any range check.

    Raises:
        ValueError: If ``loss_mode`` is unknown or a checked field is ``<= 0``.
    """

    loss_mode: str = "flow_grpo"
    clip_ratio: float = 0.0001
    adv_clip_max: float = 5.0
    mix_beta: float = 0.5
    ref_kl_coef: float = 0.0
    adaptive_weight_min: float = 1e-5
    dpo_beta: float = 2000.0

    def __post_init__(self):
        """Reject unknown loss modes and non-positive scale parameters."""
        # Kept as a list because its repr is embedded in the error message.
        supported = ["flow_grpo", "grpo_guard", "diffusion_nft", "dpo", "dance_grpo"]
        if self.loss_mode not in supported:
            raise ValueError(f"Invalid diffusion loss_mode: {self.loss_mode}. Must be one of {supported}")

        # Evaluated in this order so the first offending field is the one reported.
        strictly_positive = (
            ("Diffusion adv_clip_max", self.adv_clip_max),
            ("mix_beta", self.mix_beta),
            ("adaptive_weight_min", self.adaptive_weight_min),
        )
        for label, value in strictly_positive:
            if value <= 0:
                raise ValueError(f"{label} must be positive, got {value}.")
@dataclass
class VeOmniDiffusionEngineConfig(EngineConfig):
    """Engine configuration for the VeOmni diffusion backend.

    Extends ``EngineConfig`` and, after the parent's ``__post_init__`` has run,
    requires ``strategy == "veomni"`` and ``ulysses_parallel_size == 1``.

    Raises:
        ValueError: If ``strategy`` is anything other than ``"veomni"``.
        NotImplementedError: If ``ulysses_parallel_size`` is not ``1``.
    """

    # Deliberately un-annotated so it stays a plain class attribute rather than
    # a dataclass field. Presumably lists the fields the base config allows to
    # be reassigned after construction -- confirm against ``EngineConfig``.
    _mutable_fields = EngineConfig._mutable_fields | {"ulysses_parallel_size"}

    # VeOmni diffusion backend only supports FSDP2.
    strategy: str = "veomni"

    # Parallelism sizes.
    fsdp_size: int = -1
    ulysses_parallel_size: int = 1
    expert_parallel_size: int = 1

    # Initialisation and FSDP behaviour flags.
    init_device: str = "meta"
    reshard_after_forward: bool = True
    forward_prefetch: bool = True

    # Dtype and mixed-precision settings.
    model_dtype: str = "bfloat16"
    mixed_precision: bool = True
    mixed_precision_param_dtype: str = "bfloat16"
    mixed_precision_reduce_dtype: str = "float32"
    mixed_precision_output_dtype: Optional[str] = None
    mixed_precision_cast_forward_inputs: bool = True

    # Activation handling options.
    enable_reentrant: bool = False
    enable_activation_offload: bool = False
    activation_gpu_limit: float = 0.0

    # Per-op implementation selectors; every one defaults to "eager".
    attn_implementation: str = "eager"
    moe_implementation: str = "eager"
    cross_entropy_loss_implementation: str = "eager"
    rms_norm_implementation: str = "eager"
    swiglu_mlp_implementation: str = "eager"
    rotary_pos_emb_implementation: str = "eager"
    load_balancing_loss_implementation: str = "eager"
    rms_norm_gated_implementation: str = "eager"
    causal_conv1d_implementation: str = "eager"
    chunk_gated_delta_rule_implementation: str = "eager"

    def __post_init__(self):
        """Run the parent checks, then enforce the VeOmni-specific constraints."""
        super().__post_init__()

        strategy = self.strategy
        if strategy != "veomni":
            raise ValueError(f"VeOmni diffusion engine requires strategy='veomni', got {strategy!r}")

        sp_size = self.ulysses_parallel_size
        if sp_size != 1:
            raise NotImplementedError("VeOmni Qwen-Image diffusion backend does not support Ulysses SP yet.")
@dataclass
class VeOmniDiffusionOptimizerConfig(OptimizerConfig):
    """Optimizer configuration used by the VeOmni diffusion actor.

    Adds optimizer and learning-rate-schedule fields on top of
    ``OptimizerConfig`` and restricts ``lr_scheduler_type`` to
    ``"constant"``, ``"linear"`` or ``"cosine"``.

    Raises:
        ValueError: If ``lr_scheduler_type`` is not one of the accepted names.
    """

    optimizer: str = "adamw"
    lr_min: float = 0.0
    lr_start: float = 0.0
    lr_decay_ratio: float = 1.0
    lr_scheduler_type: str = "constant"
    eps: float = 1e-8
    fused: bool = False

    def __post_init__(self):
        """Delegate to the parent hook, then validate ``lr_scheduler_type``."""
        super().__post_init__()

        scheduler = self.lr_scheduler_type
        if scheduler in {"constant", "linear", "cosine"}:
            return

        # The accepted names are spelled out literally in the message; keep
        # them in sync with the membership set above when editing either.
        raise ValueError(
            f"Invalid VeOmni lr_scheduler_type={scheduler!r}; "
            "expected one of ['constant', 'linear', 'cosine']."
        )
[docs]
@dataclass
class DiffusionActorConfig(BaseConfig):
    """Base actor configuration shared by the diffusion training backends.

    ``strategy`` and ``rollout_n`` default to OmegaConf's ``MISSING`` sentinel
    and must be given real values; ``__post_init__`` asserts that both were.
    ``ppo_micro_batch_size_per_gpu`` also defaults to ``MISSING`` but is not
    checked here.

    Raises:
        AssertionError: If ``strategy`` or ``rollout_n`` is still ``MISSING``.
    """

    # Deliberately un-annotated so it stays a plain class attribute rather than
    # a dataclass field. Presumably lists the fields ``BaseConfig`` allows to be
    # reassigned after construction -- confirm against ``BaseConfig``.
    _mutable_fields = BaseConfig._mutable_fields | {
        "ppo_mini_batch_size",
        "ppo_micro_batch_size_per_gpu",
        "engine",
        "model_config",
    }

    strategy: str = MISSING

    # Batch sizing.
    ppo_mini_batch_size: int = 256
    ppo_micro_batch_size_per_gpu: int = MISSING

    # Loss-related settings.
    diffusion_loss: DiffusionLossConfig = field(default_factory=DiffusionLossConfig)
    loss_scale_factor: Optional[float] = None
    use_kl_loss: bool = False
    kl_loss_coef: float = 0.001

    # Update-loop settings.
    ppo_epochs: int = 1
    shuffle: bool = False
    data_loader_seed: int = 42

    # Nested sub-configs. ``engine`` starts as an empty ``BaseConfig``; the
    # subclasses in this file replace it in their own ``__post_init__``.
    checkpoint: CheckpointConfig = field(default_factory=CheckpointConfig)
    optim: OptimizerConfig = field(default_factory=OptimizerConfig)
    engine: BaseConfig = field(default_factory=BaseConfig)

    rollout_n: int = MISSING  # must be overridden by the sampling config

    # NOTE(review): annotated as ``DiffusionModelConfig`` but the default is a
    # bare ``BaseConfig`` -- looks like a placeholder until the real model
    # config is assigned; confirm with the code that populates it.
    model_config: DiffusionModelConfig = field(default_factory=BaseConfig)
    log_prob_micro_batch_size_per_gpu: Optional[int] = None
    profiler: Optional[ProfilerConfig] = None

    # Store global batch info for loss aggregation:
    # dp_size: data parallel size
    # global_batch_size: global batch size
    global_batch_info: dict = field(default_factory=dict)

    # Rollout Correction config.
    # When bypass_mode=True, ``diffusion_loss`` computes per-step RS from here.
    rollout_correction: RolloutCorrectionConfig = field(default_factory=RolloutCorrectionConfig)

    def __post_init__(self):
        """Check that the mandatory fields were overridden from ``MISSING``."""
        # NOTE(review): ``assert`` statements are removed under ``python -O``,
        # so this check is skipped in optimised runs. Kept as asserts so the
        # raised exception type (AssertionError) stays the same.
        for required in (self.strategy, self.rollout_n):
            assert required != MISSING
[docs]
@dataclass
class FSDPDiffusionActorConfig(DiffusionActorConfig):
    """Diffusion actor configuration for the FSDP backends.

    After the base validation runs, ``engine`` is pointed at ``fsdp_config``
    and the actor's ``strategy`` is copied onto that engine config.
    """

    # Training strategy: fsdp or fsdp2
    strategy: str = "fsdp"
    grad_clip: float = 1.0
    fsdp_config: FSDPEngineConfig = field(default_factory=FSDPEngineConfig)

    def __post_init__(self):
        """Validate via the base class, then wire up and sync the engine config."""
        super().__post_init__()

        self.engine = self.fsdp_config
        engine_cfg = self.engine

        # Propagate the actor strategy to the engine config so engine_workers
        # selects the matching FSDP version. EngineConfig.strategy defaults to
        # None, so without this sync engine_workers.py would always fall back
        # to FSDP1 even when actor.strategy="fsdp2".
        # ``object.__setattr__`` skips any ``__setattr__`` override on the
        # config class (presumably needed because the field is not mutable
        # through normal assignment -- confirm against ``BaseConfig``).
        object.__setattr__(engine_cfg, "strategy", self.strategy)
@dataclass
class VeOmniDiffusionActorConfig(DiffusionActorConfig):
    """Diffusion actor configuration for the VeOmni backend.

    Overrides the ``strategy`` default to ``"veomni"``, swaps the ``optim``
    default for ``VeOmniDiffusionOptimizerConfig`` and, once the base
    validation has run, points ``engine`` at ``veomni_config``.
    """

    strategy: str = "veomni"
    veomni_config: VeOmniDiffusionEngineConfig = field(default_factory=VeOmniDiffusionEngineConfig)
    optim: VeOmniDiffusionOptimizerConfig = field(default_factory=VeOmniDiffusionOptimizerConfig)

    def __post_init__(self):
        """Validate via the base class, then expose ``veomni_config`` as ``engine``."""
        super().__post_init__()

        veomni_engine = self.veomni_config
        self.engine = veomni_engine