# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
from omegaconf import MISSING
from verl.base_config import BaseConfig
from verl.utils.profiler import ProfilerConfig
from verl.workers.config.disaggregation import DisaggregationConfig
from verl.workers.config.model import MtpConfig
from verl.workers.config.rollout import (
AgentLoopConfig,
CheckpointEngineConfig,
MultiTurnConfig,
PrometheusConfig,
)
__all__ = [
"DiffusionRolloutAlgoConfig",
"DiffusionPipelineConfig",
"DiffusionSamplingConfig",
"DiffusionRolloutConfig",
]
[docs]
@dataclass
class DiffusionRolloutAlgoConfig(BaseConfig):
"""Algorithm configuration for the SDE-based diffusion rollout."""
noise_level: float = 1.0
sde_type: str = "sde"
sde_window_size: Optional[int] = None
sde_window_range: Optional[list[int]] = None
# MixGRPO-only configs
sample_strategy: str = "random"
iters_per_group: int = 1
sde_window_seed: int = 0
def __post_init__(self):
if self.sample_strategy not in ("random", "progressive"):
raise ValueError(f"Unknown sample_strategy: {self.sample_strategy!r}")
if self.sample_strategy == "progressive" and self.iters_per_group <= 0:
raise ValueError(f"iters_per_group must be positive, got {self.iters_per_group}.")
[docs]
@dataclass
class DiffusionPipelineConfig(BaseConfig):
# for pipeline specific sampling parameters
height: int = 512
width: int = 512
num_inference_steps: int = 10
true_cfg_scale: float = 1.0
max_sequence_length: int = 512
guidance_scale: Optional[float] = None
# Wan2.2 video generation: number of frames (81 = ~3s at 24fps)
num_frames: int = 1
[docs]
@dataclass
class DiffusionSamplingConfig(BaseConfig):
# for validation only
n: int = 1
seed: int = 42
pipeline: DiffusionPipelineConfig = field(default_factory=DiffusionPipelineConfig)
algo: DiffusionRolloutAlgoConfig = field(default_factory=DiffusionRolloutAlgoConfig)
[docs]
@dataclass
class DiffusionRolloutConfig(BaseConfig):
_mutable_fields = {"max_model_len", "load_format", "engine_kwargs", "prompt_length", "expert_parallel_size"}
name: Optional[str] = MISSING
mode: str = "async"
nnodes: int = 0
n_gpus_per_node: int = 8
n: int = 1
# Base seed for deterministic training rollout RNG. Per-step base is
# ``seed + global_step - 1``. null disables rollout seeding.
seed: Optional[int] = None
prompt_length: int = 512
dtype: str = "bfloat16"
gpu_memory_utilization: float = 0.5
enforce_eager: bool = False
cudagraph_capture_sizes: Optional[list] = None
free_cache_engine: bool = True
data_parallel_size: int = 1
expert_parallel_size: int = 1
tensor_model_parallel_size: int = 2
pipeline_model_parallel_size: int = 1
max_num_batched_tokens: int = 8192
logprobs_mode: Optional[str] = "processed_logprobs"
scheduling_policy: Optional[str] = "fcfs"
val_kwargs: DiffusionSamplingConfig = field(default_factory=DiffusionSamplingConfig)
max_model_len: Optional[int] = None
max_num_seqs: int = 1024
# note that the logprob computation should belong to the actor
log_prob_micro_batch_size_per_gpu: Optional[int] = None
disable_log_stats: bool = True
engine_kwargs: dict = field(default_factory=dict)
pipeline: DiffusionPipelineConfig = field(default_factory=DiffusionPipelineConfig)
calculate_log_probs: bool = False
rollout_adapter: str = "default"
agent: AgentLoopConfig = field(default_factory=AgentLoopConfig)
multi_turn: MultiTurnConfig = field(default_factory=MultiTurnConfig)
# Use Prometheus to collect and monitor rollout statistics
prometheus: PrometheusConfig = field(default_factory=PrometheusConfig)
# Checkpoint Engine config for update weights from trainer to rollout
checkpoint_engine: CheckpointEngineConfig = field(default_factory=CheckpointEngineConfig)
enable_chunked_prefill: bool = True
enable_prefix_caching: bool = True
load_format: str = "dummy"
layered_summon: bool = False
skip_tokenizer_init: bool = True
quantization: Optional[str] = None
enable_rollout_routing_replay: bool = False
enable_sleep_mode: bool = True
mtp: Optional[MtpConfig] = field(default_factory=MtpConfig)
profiler: Optional[ProfilerConfig] = None
algo: Optional[DiffusionRolloutAlgoConfig] = field(default_factory=DiffusionRolloutAlgoConfig)
disaggregation: DisaggregationConfig = field(default_factory=DisaggregationConfig)
external_lib: Optional[str] = None
def __post_init__(self):
"""Validate the diffusion rollout config"""
if self.mode == "sync":
raise ValueError(
"Rollout mode 'sync' has been removed. Please set "
"`actor_rollout_ref.rollout.mode=async` or remove the mode setting entirely."
)
if self.rollout_adapter not in ("default", "old"):
raise ValueError(
f"Invalid diffusion rollout rollout_adapter: {self.rollout_adapter}. Must be one of ['default', 'old']."
)
if self.pipeline_model_parallel_size > 1:
if self.name == "vllm_omni":
raise NotImplementedError(
f"Current rollout {self.name=} not implemented pipeline_model_parallel_size > 1 yet."
)