# Source code for verl_omni.workers.config.diffusion.model

# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
from dataclasses import dataclass, field
from typing import Any, Optional

from omegaconf import MISSING
from verl.base_config import BaseConfig
from verl.utils import hf_processor, hf_tokenizer
from verl.utils.fs import copy_to_local
from verl.utils.import_utils import import_external_libs
from verl.workers.config.model import MtpConfig

from verl_omni.utils.fs import resolve_model_local_dir

from .rollout import DiffusionPipelineConfig, DiffusionRolloutAlgoConfig

__all__ = ["DiffusionModelConfig"]

logger = logging.getLogger(__name__)


@dataclass
class DiffusionModelConfig(BaseConfig):
    """Model-side configuration for a diffusion pipeline checkpoint.

    ``__post_init__`` resolves ``path`` to a local directory and then derives
    the remaining runtime fields from that directory:

    * ``tokenizer_path`` defaults to ``<local_path>/tokenizer`` when that path
      exists, otherwise to ``local_path`` itself.
    * ``architecture`` defaults to the ``_class_name`` entry of
      ``<local_path>/model_index.json``.
    * ``transformer_config`` defaults to the parsed
      ``<local_path>/<transformer_subfolder>/config.json``; a missing or
      unreadable file only logs a warning and leaves the field ``None``.
    * ``tokenizer`` / ``processor`` are constructed when ``load_tokenizer`` is
      true.

    Fields listed in ``_mutable_fields`` are the ones ``__post_init__`` assigns
    after construction.

    Raises:
        ValueError: if ``attn_backend`` is not one of the supported backends.
        TypeError: if ``target_modules`` is neither ``None``, a string, nor a
            list of strings.
        OSError / KeyError: if ``architecture`` is unset and
            ``model_index.json`` cannot be opened or lacks ``_class_name``.
    """

    _mutable_fields = {
        "model_type",
        "tokenizer_path",
        "tokenizer",
        "processor",
        "local_path",
        "local_tokenizer_path",
        "architecture",
        "transformer_config",
    }

    path: str = MISSING
    architecture: Optional[str] = None
    # diffusers ``<transformer_subfolder>/config.json``; used by DiffusionFlopsCounter for MFU.
    transformer_config: Optional[dict[str, Any]] = None
    algorithm: str = MISSING
    local_path: Optional[str] = None
    tokenizer_path: Optional[str] = None
    local_tokenizer_path: Optional[str] = None

    # model type, e.g., "diffusion_model"
    model_type: str = "diffusion_model"

    # whether to load tokenizer. This is useful when we only want to load model config
    load_tokenizer: bool = True
    tokenizer: Any = None
    processor: Any = None

    # whether to use shared memory
    use_shm: bool = False
    trust_remote_code: bool = False

    # custom chat template for the model
    custom_chat_template: Optional[str] = None

    external_lib: Optional[str] = None

    enable_gradient_checkpointing: bool = True
    attn_backend: str = "native"

    lora_rank: int = 0
    lora_alpha: int = 64
    lora_init_weights: str = "gaussian"
    target_modules: Optional[Any] = "all-linear"  # allow both "all-linear" and ["q_proj","k_proj"]
    target_parameters: Optional[list[str]] = None  # for lora adapter on nn.Parameter
    exclude_modules: Optional[str] = None

    # megatron lora config
    lora: dict[str, Any] = field(default_factory=dict)

    # path to pre-trained LoRA adapter to load for continued training
    lora_adapter_path: Optional[str] = None

    # Named LoRA policy states required by the algorithm. "reference" uses disabled adapters.
    policy_state_adapters: tuple[str, ...] = ("default",)

    # dtype to convert LoRA parameters to (e.g., "fp32", "bf16"). Default None means no conversion.
    lora_dtype: Optional[str] = None

    mtp: Optional[MtpConfig] = field(default_factory=MtpConfig)

    pipeline: DiffusionPipelineConfig = field(default_factory=DiffusionPipelineConfig)
    algo: Optional[DiffusionRolloutAlgoConfig] = field(default_factory=DiffusionRolloutAlgoConfig)

    fsdp_layer_prefixes: list[str] = field(default_factory=lambda: ["transformer_blocks."])

    # Optional model config path. If unset, the backend uses
    # ``<local_path>/<transformer_subfolder>``.
    config_path: Optional[str] = None
    # Subfolder containing the diffusion transformer weights/config.
    transformer_subfolder: str = "transformer"

    def __post_init__(self):
        """Validate the config and populate the fields derived from ``path``."""
        import_external_libs(self.external_lib)

        valid_backends = {"native", "_native_npu"}
        if self.attn_backend not in valid_backends:
            raise ValueError(f"Invalid attn_backend: {self.attn_backend}. Must be one of {sorted(valid_backends)}")

        self.local_path = resolve_model_local_dir(self.path, use_shm=self.use_shm)

        # Prefer the pipeline's dedicated ``tokenizer`` entry; fall back to the model root.
        if self.tokenizer_path is None:
            tokenizer_path = os.path.join(self.local_path, "tokenizer")
            self.tokenizer_path = tokenizer_path if os.path.exists(tokenizer_path) else self.local_path

        # The architecture is required, so a missing/invalid model_index.json is allowed to raise.
        if self.architecture is None:
            model_index_path = os.path.join(self.local_path, "model_index.json")
            with open(model_index_path, encoding="utf-8") as f:
                self.architecture = json.load(f)["_class_name"]

        # The transformer config is best-effort: it is only needed for MFU accounting,
        # so any failure here is downgraded to a warning instead of aborting construction.
        # NOTE(review): the ``config_path`` field is not consulted here - confirm whether it should be.
        if self.transformer_config is None:
            # Honor the configured subfolder rather than assuming the default name.
            config_path = os.path.join(self.local_path, self.transformer_subfolder, "config.json")
            if os.path.isfile(config_path):
                try:
                    with open(config_path, encoding="utf-8") as f:
                        self.transformer_config = json.load(f)
                # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
                except (OSError, ValueError) as exc:
                    logger.warning("Diffusion MFU disabled: failed to read %s: %s", config_path, exc)
            else:
                logger.warning(
                    "Diffusion MFU disabled: transformer config not found at %s. "
                    "Expected the diffusers pipeline layout `<local_path>/<transformer_subfolder>/config.json`.",
                    config_path,
                )

        # construct tokenizer
        if self.load_tokenizer:
            self.local_tokenizer_path = copy_to_local(self.tokenizer_path, use_shm=self.use_shm)
            self.tokenizer = hf_tokenizer(
                self.local_tokenizer_path, trust_remote_code=self.trust_remote_code, use_fast=True
            )
            # A processor is only built when the checkpoint ships a ``processor`` entry.
            processor_path = os.path.join(self.local_path, "processor")
            if os.path.exists(processor_path):
                self.processor = hf_processor(processor_path, trust_remote_code=self.trust_remote_code)
            else:
                self.processor = None

        # Ensure target_modules is a str or list[str] (only if not None)
        if self.target_modules is not None:
            if not isinstance(self.target_modules, (str, list)):
                raise TypeError(
                    "target_modules must be a string or a list of strings, "
                    f"but got {type(self.target_modules).__name__}"
                )
            if isinstance(self.target_modules, list):
                for x in self.target_modules:
                    if not isinstance(x, str):
                        raise TypeError(
                            f"All elements in target_modules list must be strings, but found {type(x).__name__}"
                        )

    def get_processor(self):
        """Return ``processor`` when it is set, otherwise fall back to ``tokenizer``."""
        return self.processor if self.processor is not None else self.tokenizer