# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Diffusion-specific algorithm config additions for verl_omni."""
from dataclasses import dataclass, field
from typing import Optional
from verl.base_config import BaseConfig
from verl.trainer.config.algorithm import RolloutCorrectionConfig
from verl_omni.trainer.diffusion.diffusion_trainer_utils import OLD_POLICY_DECAY_SCHEDULES
__all__ = ["DiffusionAlgoConfig", "RolloutCorrectionConfig"]
[docs]
@dataclass
class DiffusionAlgoConfig(BaseConfig):
"""Diffusion-specific algorithm config."""
trainer_type: str = "policy_gradient"
sample_source: str = "online"
adv_estimator: str = "flow_grpo"
norm_adv_by_std_in_grpo: bool = True
global_std: bool = True
old_policy_decay_schedule: str = "copy"
old_policy_decay: Optional[float] = None
old_policy_update_interval: int = 1
timestep_fraction: float = 1.0
adv_mode: str = "continuous"
paired_preference: bool = False # True for pair-based algorithms (e.g. DPO)
rollout_correction: RolloutCorrectionConfig = field(default_factory=RolloutCorrectionConfig)
def __post_init__(self):
valid_adv_modes = {"continuous", "positive_only", "negative_only", "one_only", "binary"}
if self.adv_mode not in valid_adv_modes:
raise ValueError(f"Invalid adv_mode: {self.adv_mode}. Must be one of {sorted(valid_adv_modes)}")
if self.old_policy_decay_schedule not in OLD_POLICY_DECAY_SCHEDULES:
raise ValueError(
f"Invalid old_policy_decay_schedule: {self.old_policy_decay_schedule}. "
f"Must be one of {sorted(OLD_POLICY_DECAY_SCHEDULES)}"
)
if self.old_policy_decay is not None and not 0 <= self.old_policy_decay <= 1:
raise ValueError(f"old_policy_decay must be in [0, 1], got {self.old_policy_decay}.")
if self.old_policy_update_interval <= 0:
raise ValueError(f"old_policy_update_interval must be positive, got {self.old_policy_update_interval}.")
if not 0 < self.timestep_fraction <= 1:
raise ValueError(f"timestep_fraction must be in (0, 1], got {self.timestep_fraction}.")