Source code for verl_omni.utils.fsdp_utils

# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP utilities for verl-omni
"""

from collections import OrderedDict
from collections.abc import Callable, Sequence
from contextlib import ExitStack, contextmanager
from functools import partial

from peft.utils.save_and_load import get_peft_model_state_dict
from verl.utils.fsdp_utils import collect_lora_params as _upstream_collect_lora_params
from verl.utils.fsdp_utils import fsdp_version
from verl.utils.fsdp_utils import layered_summon_lora_params as _upstream_layered_summon_lora_params

__all__ = ["collect_lora_params", "fsdp_summon_full_params"]


def _get_fsdp_module_cls():
    try:
        from torch.distributed.fsdp import FSDPModule
    except ImportError:
        from torch.distributed._composable.fsdp import FSDPModule
    return FSDPModule


def _iter_fsdp2_submodules(module):
    fsdp_module_cls = _get_fsdp_module_cls()
    for name, submodule in module.named_modules():
        if isinstance(submodule, fsdp_module_cls) and name != "":
            yield name, submodule


@contextmanager
def fsdp_summon_full_params(module, *, writeback: bool = False, with_grads: bool = False, recurse: bool = True):
    """Summon unsharded params for FSDP1/FSDP2, matching verl's fsdp_merge_unmerge pattern."""
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    version = fsdp_version(module)
    if version == 0:
        yield
        return
    if version == 1:
        with FSDP.summon_full_params(module, writeback=writeback, recurse=recurse, with_grads=with_grads):
            yield
        return

    submodules = list(_iter_fsdp2_submodules(module))
    if not submodules:
        yield
        return

    with ExitStack() as stack:
        for _, submodule in submodules:
            stack.enter_context(FSDP.summon_full_params(submodule, writeback=writeback, with_grads=with_grads))
        yield


def _param_to_cpu(param):
    if hasattr(param, "full_tensor"):
        return param.full_tensor().detach().cpu()
    return param.detach().cpu()


def _peft_lora_params_to_cpu(peft_model, adapter_name: str) -> OrderedDict:
    lora_params = get_peft_model_state_dict(peft_model, adapter_name=adapter_name)
    return OrderedDict((name, _param_to_cpu(param)) for name, param in lora_params.items())


def _collect_base_weights_to_cpu(peft_model) -> OrderedDict:
    from verl.utils.device import get_device_name

    model = peft_model.base_model.model
    orig_dev = "cpu" if "cpu" in str(next(model.parameters()).device) else get_device_name()
    model = model.to("cpu")
    lora_params = OrderedDict()
    for name, param in model.state_dict().items():
        if any(x in name for x in ["_flat_param", "lora_"]):
            continue
        name = name.replace("_fsdp_wrapped_module.", "").replace(".base_layer", "")
        lora_params[name] = _param_to_cpu(param)
    model = model.to(orig_dev)
    return lora_params


def _collect_base_weights_from_state_dict(state_dict) -> OrderedDict:
    lora_params = OrderedDict()
    for name, param in state_dict.items():
        if any(x in name for x in ["_flat_param", "lora_"]):
            continue
        name = name.replace("_fsdp_wrapped_module.", "").replace(".base_layer", "")
        lora_params[name] = _param_to_cpu(param)
    return lora_params


def _collect_lora_params_non_layered(module, peft_model, adapter_name: str, base_sync_done: bool) -> OrderedDict:
    """Collect LoRA/base params without layered summon for FSDP1/FSDP2/non-FSDP modules."""
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from verl.utils.device import get_torch_device

    version = fsdp_version(module)
    if version == 0:
        if base_sync_done:
            return _peft_lora_params_to_cpu(peft_model, adapter_name)
        return _collect_base_weights_to_cpu(peft_model)

    if version == 1:
        with FSDP.summon_full_params(module, writeback=False):
            if base_sync_done:
                lora_params = _peft_lora_params_to_cpu(peft_model, adapter_name)
            else:
                lora_params = _collect_base_weights_to_cpu(peft_model)
        get_torch_device().empty_cache()
        return lora_params

    lora_params = OrderedDict()
    for name, submodule in _iter_fsdp2_submodules(module):
        with FSDP.summon_full_params(submodule, writeback=False):
            if base_sync_done:
                sub_lora_params = get_peft_model_state_dict(
                    peft_model, state_dict=submodule.state_dict(), adapter_name=adapter_name
                )
                block_prefix = name.replace("_fsdp_wrapped_module.", "")
                for param_name, param in sub_lora_params.items():
                    full_name = f"{block_prefix}.{param_name}" if block_prefix else param_name
                    lora_params[full_name] = _param_to_cpu(param)
            else:
                lora_params.update(_collect_base_weights_from_state_dict(submodule.state_dict()))
    get_torch_device().empty_cache()
    return lora_params


def _collect_lora_params_with_adapter(
    module,
    layered_summon: bool,
    base_sync_done: bool,
    adapter_name: str,
    layered_summon_fn: Callable,
) -> OrderedDict:
    """Verl-style LoRA collection with explicit ``adapter_name`` for PEFT state."""
    peft_model = getattr(module, "_fsdp_wrapped_module", module)
    if fsdp_version(module) > 0:
        if layered_summon:
            if not base_sync_done:
                raise ValueError(
                    "To use layered_summon, you must make sure base-model is preloaded in vllm, e.g. let "
                    "rollout.load_format=safetensors"
                )
            if layered_summon_fn is _upstream_layered_summon_lora_params:
                return layered_summon_fn(module)
            return layered_summon_fn(module, adapter_name=adapter_name)
        return _collect_lora_params_non_layered(module, peft_model, adapter_name, base_sync_done)

    if base_sync_done:
        return _peft_lora_params_to_cpu(peft_model, adapter_name)
    return _collect_base_weights_to_cpu(peft_model)


def _layered_summon_lora_params_diffusers(
    fsdp_module, adapter_name: str = "default", layer_prefixes: Sequence[str] = ("transformer_blocks.",)
) -> OrderedDict:
    """Layered LoRA param collection for diffusers transformer-block models.

    Args:
        fsdp_module: The FSDP-wrapped module.
        adapter_name: LoRA adapter name.
        layer_prefixes: FSDP layer name prefixes.  Defaults to
            ``["transformer_blocks."]``.
    """
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from verl.utils.device import get_torch_device

    def _prefix_submodules(module, prefix):
        for name, submodule in module.named_modules():
            if name.startswith(prefix) and "." not in name[len(prefix) :]:
                yield name, submodule

    lora_params = OrderedDict()
    prefix_list = []
    for lp in layer_prefixes:
        # FSDP1
        prefix_list.append(f"_fsdp_wrapped_module.{lp}")
        # FSDP2
        prefix_list.append(lp)
    peft_model = getattr(fsdp_module, "_fsdp_wrapped_module", fsdp_module)
    for prefix in prefix_list:
        for name, submodule in _prefix_submodules(fsdp_module, prefix):
            block_prefix = name.replace("_fsdp_wrapped_module.", "")
            if name.endswith(".model") or name.endswith(".layers"):
                continue
            if fsdp_version(submodule) > 0:
                with FSDP.summon_full_params(submodule, writeback=False):
                    sub_lora_params = get_peft_model_state_dict(
                        peft_model, state_dict=submodule.state_dict(), adapter_name=adapter_name
                    )
                    sub_lora_params = {
                        f"{block_prefix}.{param_name}": _param_to_cpu(param)
                        for param_name, param in sub_lora_params.items()
                    }
                    lora_params.update(sub_lora_params)
                    submodule._is_root = False
                get_torch_device().empty_cache()
    return lora_params


[docs] def collect_lora_params( module, layered_summon: bool, base_sync_done: bool, is_diffusers: bool = False, adapter_name: str = "default", layer_prefixes: Sequence[str] = ("transformer_blocks.",), ) -> OrderedDict: """Collect LoRA or base parameters for weight sync to the rollout worker. Raises ``RuntimeError`` when no parameters were collected (e.g. mismatched ``layer_prefixes``). Args: module: The FSDP-wrapped or plain module. layered_summon: Summon one FSDP unit at a time instead of the full model. base_sync_done: If ``True``, collect only LoRA weights; else full base weights. is_diffusers: Use the diffusers-specific layered summon helper. adapter_name: LoRA adapter name (usually ``"default"``). layer_prefixes: FSDP layer name prefixes (``["transformer_blocks."]`` """ use_diffusers_layered = is_diffusers and layered_summon and fsdp_version(module) > 0 if adapter_name == "default" and not use_diffusers_layered and fsdp_version(module) != 2: return _upstream_collect_lora_params(module, layered_summon=layered_summon, base_sync_done=base_sync_done) if is_diffusers: layered_summon_fn = partial( _layered_summon_lora_params_diffusers, adapter_name=adapter_name, layer_prefixes=layer_prefixes ) else: layered_summon_fn = _upstream_layered_summon_lora_params lora_params = _collect_lora_params_with_adapter( module, layered_summon=layered_summon, base_sync_done=base_sync_done, adapter_name=adapter_name, layered_summon_fn=layered_summon_fn, ) if not lora_params: raise RuntimeError( f"collect_lora_params collected 0 parameters with prefixes={layer_prefixes}. " "Check ``fsdp_layer_prefixes`` in the model config matches the model's " "FSDP layer naming (e.g. ``['transformer_blocks.']`` for DiT models)." ) return lora_params