# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FSDP engines for diffusion models."""

import gc
import json
import logging
import os
import warnings
from abc import ABC, abstractmethod
from contextlib import nullcontext
from typing import Callable, Optional

import torch
import torch.distributed
from tensordict import TensorDict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType
from torch.distributed.tensor import DTensor
from verl.trainer.config import CheckpointConfig
from verl.utils import tensordict_utils as tu
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.debug import log_gpu_memory_usage
from verl.utils.device import get_device_id, get_device_name
from verl.utils.fsdp_utils import (
    CPUOffloadPolicy,
    FSDPModule,
    MixedPrecisionPolicy,
    apply_fsdp2,
    fsdp2_clip_grad_norm_,
    fsdp2_load_full_state_dict,
    fsdp_version,
    get_fsdp_wrap_policy,
    get_init_weight_context_manager,
    init_fn,
    load_fsdp_model_to_gpu,
    load_fsdp_optimizer,
    offload_fsdp_model_to_cpu,
    offload_fsdp_optimizer,
    replace_lora_wrapper,
)
from verl.utils.memory_utils import aggressive_empty_cache
from verl.utils.model import convert_weight_keys
from verl.utils.py_functional import append_to_dict
from verl.workers.config import FSDPEngineConfig, FSDPOptimizerConfig
from verl.workers.engine.base import BaseEngine, BaseEngineCtx, EngineRegistry
from verl.workers.engine.fsdp.utils import create_device_mesh, get_sharding_strategy
from verl.workers.engine.utils import enable_full_determinism, prepare_micro_batches

from verl_omni.pipelines.utils import (
    build_scheduler,
    forward,
    forward_and_sample_previous_step,
    prepare_model_inputs,
    prepare_noisy_latents,
)
from verl_omni.utils.fsdp_utils import collect_lora_params
from verl_omni.workers.config import DiffusionModelConfig
from verl_omni.workers.engine.lora_adapter_mixin import LoRAAdapterMixin

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))

device_name = get_device_name()


class DiffusersFSDPEngine(LoRAAdapterMixin, BaseEngine, ABC):
    """Base Diffusers engine using PyTorch FullyShardedDataParallel (FSDP).

    Supports model sharding, activation/optimizer offloading, LoRA, and sequence parallelism.
    """

    def __init__(
        self,
        model_config: DiffusionModelConfig,
        engine_config: FSDPEngineConfig,
        optimizer_config: FSDPOptimizerConfig,
        checkpoint_config: CheckpointConfig,
    ):
        """
        Initialize the DiffusersFSDPEngine.

        Sets up distributed device meshes, LoRA, and offload policies based on config.

        Args:
            config: Configuration object with FSDP and model settings.
        """
        super().__init__()

        self.model_config = model_config
        self.engine_config = engine_config
        self.optimizer_config = optimizer_config
        self.checkpoint_config = checkpoint_config

        self.mode = None

        self.rank = torch.distributed.get_rank()

        self._init_device_mesh()

        if self.engine_config.full_determinism:
            enable_full_determinism(seed=self.engine_config.seed)

        # set FSDP offload params
        self._is_offload_param = self.engine_config.param_offload
        self._is_offload_optimizer = self.engine_config.optimizer_offload
        self._is_lora = self.model_config.lora_rank > 0

    @property
    def is_param_offload_enabled(self) -> bool:
        return self._is_offload_param

    @property
    def is_optimizer_offload_enabled(self) -> bool:
        return self._is_offload_optimizer

    def is_mp_src_rank_with_outputs(self):
        if self.ulysses_device_mesh is not None:
            is_collect = self.ulysses_device_mesh["ulysses"].get_local_rank() == 0
        else:
            is_collect = True
        return is_collect

    def initialize(self):
        """
        Build the model, optimizer, and learning rate scheduler under FSDP.

        Applies device, dtype, and precision configurations, including mixed precision.
        Sets up checkpoint manager and FLOPs counter.
        """
        # This is used to import external_lib into the huggingface systems
        self._build_model_optimizer()

        self.checkpoint_manager = FSDPCheckpointManager(
            model=self.module,
            optimizer=self.optimizer,
            lr_scheduler=self.lr_scheduler,
            processing_class=self.model_config.get_processor(),
            checkpoint_config=self.checkpoint_config,
            trust_remote_code=self.model_config.trust_remote_code,
        )

        self.to(
            device="cpu",
            model=self._is_offload_param,
            optimizer=self._is_offload_optimizer,
            grad=self._is_offload_param,
        )

        log_gpu_memory_usage("After offload model/optimizer/grad during init", logger=logger)

    def _init_device_mesh(self):
        world_size = torch.distributed.get_world_size()
        from torch.distributed.device_mesh import init_device_mesh

        fsdp_size = self.engine_config.fsdp_size

        self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)
        self.ulysses_device_mesh = None
        self.ulysses_sequence_parallel_size = self.engine_config.ulysses_sequence_parallel_size
        dp_size = self.get_data_parallel_size()
        if self.ulysses_sequence_parallel_size > 1:
            import diffusers
            from packaging import version

            if version.parse(diffusers.__version__) < version.parse("0.38.0"):
                raise RuntimeError(
                    f"Ulysses sequence parallelism requires diffusers >= 0.38.0 (found {diffusers.__version__}). "
                )

            # diffusers' ContextParallelConfig.setup() unconditionally accesses self._mesh["ring", "ulysses"],
            # so the mesh must have both named dimensions even though ring attention is not used.
            self.ulysses_device_mesh = init_device_mesh(
                device_name,
                mesh_shape=(dp_size, 1, self.ulysses_sequence_parallel_size),
                mesh_dim_names=["dp", "ring", "ulysses"],
            )

        self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1

    def _build_module_from_registry(self, torch_dtype: torch.dtype) -> Optional[torch.nn.Module]:
        """Try loading via ``DiffusionModelBase.build_module()``.

        Returns ``None`` if the registry has no custom loader, so the
        caller falls back to ``diffusers.AutoModel``.
        """
        from verl_omni.pipelines.model_base import DiffusionModelBase

        model_cls = DiffusionModelBase.get_class(self.model_config)
        module = model_cls.build_module(self.model_config, torch_dtype)
        if module is None:
            return None

        logger.warning(
            "Built %s via DiffusionModelBase custom loader; engine-level hooks "
            "(attention processors, gradient-checkpointing wrappers, LoRA, "
            "dtype upcast) may be partially effective or silently inactive. "
            "See the docstring of _build_module_from_registry.",
            type(module).__name__,
        )

        try:
            module.to(torch_dtype)
        except AttributeError:
            raise TypeError(
                f"{type(module).__name__} returned by build_module() has no to() method. "
                "Custom models must be torch.nn.Module instances."
            ) from None

        if self.model_config.enable_gradient_checkpointing:
            try:
                module.enable_gradient_checkpointing()
            except AttributeError:
                raise NotImplementedError(
                    f"Gradient checkpointing is enabled in config, but {type(module).__name__} "
                    "does not implement enable_gradient_checkpointing(). "
                    "Either implement it or set enable_gradient_checkpointing=False."
                ) from None
            logger.info(
                "Gradient checkpointing enabled on %s via enable_gradient_checkpointing().",
                type(module).__name__,
            )

        module.can_generate = lambda: False
        return module

    def _build_module(self):
        from diffusers import AutoModel
        from verl.utils.torch_dtypes import PrecisionType

        torch_dtype = self.engine_config.model_dtype

        if torch_dtype is None:
            # if it is training, we force torch_dtype to fp32
            torch_dtype = torch.float32 if not self.engine_config.forward_only else torch.bfloat16

        torch_dtype = PrecisionType.to_dtype(torch_dtype)

        module = self._build_module_from_registry(torch_dtype)
        if module is not None:
            return module

        # Default path: load via diffusers AutoModel
        init_context = get_init_weight_context_manager(use_meta_tensor=True, mesh=self.device_mesh)

        with init_context(), warnings.catch_warnings():
            warnings.simplefilter("ignore")

            module = AutoModel.from_pretrained(
                self.model_config.config_path or self.model_config.local_path,
                torch_dtype=torch_dtype,
                trust_remote_code=self.model_config.trust_remote_code,
                subfolder="" if self.model_config.config_path else self.model_config.transformer_subfolder,
            )
            module.set_attention_backend(self.model_config.attn_backend)

            # some parameters may not in torch_dtype
            module.to(torch_dtype)

            if self.model_config.enable_gradient_checkpointing:
                module.enable_gradient_checkpointing()

            # patch for checkpoint saving
            def save_config(self, save_directory: str | os.PathLike):
                output_config_file = os.path.join(save_directory, "config.json")
                with open(output_config_file, "w", encoding="utf-8") as f:
                    json.dump(self, f, indent=4, sort_keys=True)

            module.can_generate = lambda: False
            module.config.save_pretrained = save_config.__get__(module.config)

        return module

    def _build_fsdp_module(self, module):
        # TODO(ziheng): need to improve
        from torch.distributed.fsdp import CPUOffload, MixedPrecision
        from verl.utils.torch_dtypes import PrecisionType

        mixed_precision_config = self.engine_config.mixed_precision
        if mixed_precision_config is not None:
            param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
            reduce_dtype = PrecisionType.to_dtype(mixed_precision_config.get("reduce_dtype", "fp32"))
            buffer_dtype = PrecisionType.to_dtype(mixed_precision_config.get("buffer_dtype", "fp32"))
        else:
            param_dtype = torch.bfloat16
            reduce_dtype = torch.float32
            buffer_dtype = torch.float32

        mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)

        auto_wrap_policy = get_fsdp_wrap_policy(
            module=module,
            config=self.engine_config.wrap_policy,
            is_lora=self.model_config.lora_rank > 0,
        )

        fsdp_mesh = self.device_mesh
        sharding_strategy = get_sharding_strategy(fsdp_mesh)

        # Note: We force turn off CPUOffload because it causes incorrect results when using grad accumulation
        if self.engine_config.strategy == "fsdp":
            # cpu_offload:
            # - actor: None
            # - critic: None
            # - ref: CPUOffload(offload_params=True)

            # We force reference policy to use CPUOffload to save memory.
            # We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation
            cpu_offload = None
            if self.engine_config.forward_only:
                cpu_offload = CPUOffload(offload_params=True)
                self._is_offload_param = False
                self._is_offload_optimizer = False

            module = FSDP(
                module,
                param_init_fn=init_fn,
                auto_wrap_policy=auto_wrap_policy,
                device_id=get_device_id(),
                sharding_strategy=sharding_strategy,
                mixed_precision=mixed_precision,
                sync_module_states=True,
                device_mesh=self.device_mesh,
                forward_prefetch=self.engine_config.forward_prefetch,
                use_orig_params=self.engine_config.use_orig_params,
                cpu_offload=cpu_offload,
            )
        elif self.engine_config.strategy == "fsdp2":
            # - actor: offload_policy
            # - critic: offload_policy
            # - ref: CPUOffloadPolicy(pin_memory=True)
            assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
            mp_policy = MixedPrecisionPolicy(
                param_dtype=param_dtype, reduce_dtype=reduce_dtype, cast_forward_inputs=True
            )
            offload_policy = None
            if self.engine_config.offload_policy or self.engine_config.forward_only:
                self._is_offload_param = False
                self._is_offload_optimizer = False
                offload_policy = CPUOffloadPolicy(pin_memory=True)

            fsdp_kwargs = {
                "mesh": fsdp_mesh,
                "mp_policy": mp_policy,
                "offload_policy": offload_policy,
                "reshard_after_forward": self.engine_config.reshard_after_forward,
            }
            full_state = module.state_dict()
            apply_fsdp2(module, fsdp_kwargs, self.engine_config)
            fsdp2_load_full_state_dict(module, full_state, fsdp_mesh, offload_policy)
        else:
            raise NotImplementedError(f"Unknown strategy {self.engine_config.strategy}")

        if torch.distributed.get_world_size() == 1 and fsdp_version(module) == 1:
            FSDP.set_state_dict_type(
                module,
                state_dict_type=StateDictType.FULL_STATE_DICT,
                state_dict_config=FullStateDictConfig(),
            )
        elif fsdp_version(module) == 1:
            FSDP.set_state_dict_type(
                module,
                state_dict_type=StateDictType.SHARDED_STATE_DICT,
                state_dict_config=ShardedStateDictConfig(),
            )

        return module

    def _build_scheduler(self):
        return build_scheduler(self.model_config)

    def _build_optimizer(self, module):
        from verl.workers.config.optimizer import build_optimizer

        optimizer = build_optimizer(module.parameters(), self.optimizer_config)

        return optimizer

    def _build_lr_scheduler(self, optimizer):
        from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup

        optim_config = self.optimizer_config

        total_steps = optim_config.total_training_steps
        num_warmup_steps = optim_config.lr_warmup_steps
        lr_scheduler_type = optim_config.lr_scheduler_type
        min_lr_ratio = optim_config.min_lr_ratio
        num_cycles = optim_config.num_cycles
        zero_indexed_step = optim_config.zero_indexed_step
        if num_warmup_steps <= 0:
            num_warmup_steps_ratio = optim_config.lr_warmup_steps_ratio
            num_warmup_steps = int(num_warmup_steps_ratio * total_steps)

        if self.rank == 0:
            print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")

        if lr_scheduler_type == "constant":
            lr_scheduler = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=num_warmup_steps)
        elif lr_scheduler_type == "cosine":
            lr_scheduler = get_cosine_schedule_with_warmup(
                optimizer=optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=total_steps,
                min_lr_ratio=min_lr_ratio,
                num_cycles=num_cycles,
                zero_indexed_step=zero_indexed_step,
            )
        else:
            raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported")
        return lr_scheduler

    def _build_model_optimizer(self):
        from diffusers import ContextParallelConfig
        from verl.utils.model import print_model_size

        # Load base model with specified configuration and dtype
        module = self._build_module()
        # Apply LoRA adapters if low-rank adaptation is enabled
        if self._is_lora:
            module = self._build_lora_module(module)

        if self.use_ulysses_sp:
            sp_size = self.ulysses_sequence_parallel_size
            module.enable_parallelism(
                config=ContextParallelConfig(ulysses_degree=sp_size, mesh=self.ulysses_device_mesh)
            )

        # Load diffusion scheduler
        scheduler = self._build_scheduler()

        # Synchronize all distributed processes before proceeding
        torch.distributed.barrier()
        if self.rank == 0:
            print_model_size(module)
        log_gpu_memory_usage("After init model from Diffusers AutoModel", logger=logger)

        # Wrap model with FSDP for distributed training (sharding, mixed precision, etc.)
        log_gpu_memory_usage("Before FSDP", logger=None)
        module = self._build_fsdp_module(module)
        log_gpu_memory_usage("After FSDP", logger=None)

        if not self.engine_config.forward_only:
            # Initialize optimizer with model parameters and config settings
            optimizer = self._build_optimizer(module)
            # Create learning rate scheduler with warmup and decay settings
            lr_scheduler = self._build_lr_scheduler(optimizer)
        else:
            optimizer = None
            lr_scheduler = None

        self.module = module
        self.scheduler = scheduler
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    def train_mode(self, **kwargs):
        """
        Return a context manager that switches to training mode with FSDP-specific handling.

        Includes parameter and optimizer offload entry/exit.
        """
        return EngineTrainModeCtx(self, **kwargs)

    def eval_mode(self, **kwargs):
        """
        Return a context manager that switches to evaluation mode with FSDP-specific handling.

        Includes activation offload entry/exit.
        """
        return EngineEvalModeCtx(self, **kwargs)

    def get_data_parallel_rank(self):
        if self.ulysses_device_mesh is not None:
            return self.ulysses_device_mesh["dp"].get_local_rank()
        else:
            return torch.distributed.get_rank()

    def get_data_parallel_size(self):
        return torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size

    def get_data_parallel_group(self):
        if self.ulysses_device_mesh is not None:
            return self.ulysses_device_mesh.get_group(mesh_dim="dp")
        else:
            return torch.distributed.group.WORLD

    def get_model_parallel_group(self):
        raise NotImplementedError

    def get_context_parallel_group(self):
        raise NotImplementedError

    def postprocess_batch_func(self, output_lst, indices, data: TensorDict):
        model_output = {}
        losses = []
        aggregated_metrics = {}

        for output in output_lst:
            # model output list
            model_output_lst = {}
            if "model_output" in output:
                for model_output_dict in output["model_output"]:
                    for key, val in model_output_dict.items():
                        model_output_lst.setdefault(key, []).append(val)
                for key, val in model_output_lst.items():
                    model_output.setdefault(key, []).append(torch.stack(val, dim=1))  # (bsz, steps, ...)
            # loss
            if "loss" in output:
                losses.append(output["loss"])

            # metrics
            if "metrics" in output:
                for metrics in output["metrics"]:
                    append_to_dict(aggregated_metrics, metrics)

        # concat results from micro batches
        for key, val in model_output.items():
            model_output[key] = torch.concat(val, dim=0)  # (global_bsz, steps, ...)

        output = {
            "model_output": model_output,  # a dict of tensors in shape (global_bsz, steps, ...)
            "loss": losses,  # micro-batch step-wise losses
            "metrics": aggregated_metrics,
        }

        return output

    @staticmethod
    def _unpad_nested_embeds(embeds: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Convert a jagged nested tensor pair (embeds, mask) to dense padded tensors."""
        batch_size = embeds.size(0)
        max_seq_len = max(embeds.offsets().diff())
        embed_dim = embeds.size(-1)
        embeds = torch.nested.to_padded_tensor(embeds, padding=0, output_size=(batch_size, max_seq_len, embed_dim))
        mask = torch.nested.to_padded_tensor(mask, padding=0, output_size=(batch_size, max_seq_len))
        return embeds, mask

    @staticmethod
    def _pad_embeds_for_sp(embeds: torch.Tensor, mask: torch.Tensor, sp_size: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Pad sequence dimension of (embeds, mask) to a multiple of sp_size."""
        seq_len = embeds.size(1)
        aligned_seq_len = (seq_len + sp_size - 1) // sp_size * sp_size
        if aligned_seq_len > seq_len:
            pad_len = aligned_seq_len - seq_len
            embeds = torch.nn.functional.pad(embeds, (0, 0, 0, pad_len))
            mask = torch.nn.functional.pad(mask, (0, pad_len))
        return embeds, mask

    @abstractmethod
    def forward_backward_batch(
        self, data: TensorDict, loss_function: Callable, forward_only: bool = False
    ) -> list[TensorDict]:
        """Run forward/backward over a batch; implemented by algorithm-specific subclasses."""
        pass

    @abstractmethod
    def prepare_model_inputs(self, micro_batch: TensorDict, step: int):
        """Build model inputs for one diffusion step; implemented by algorithm-specific subclasses."""
        pass

    @abstractmethod
    def prepare_model_outputs(self, output, micro_batch: TensorDict):
        """Post-process raw model output; implemented by algorithm-specific subclasses."""
        pass

    @abstractmethod
    def forward_step(self, micro_batch: TensorDict, loss_function, forward_only, step):
        """Run one diffusion step forward (and loss); implemented by algorithm-specific subclasses."""
        pass

    def optimizer_zero_grad(self):
        """
        Zero gradients and enforce FSDP grad-clipping logic.
        """
        self.optimizer.zero_grad()

    def optimizer_step(self):
        """
        Clip gradients, skip update if non-finite, and step optimizer.

        Returns:
            grad_norm (float): Norm of gradients before clipping.
        """
        assert self.optimizer_config.clip_grad is not None

        if isinstance(self.module, FSDP):
            grad_norm = self.module.clip_grad_norm_(self.optimizer_config.clip_grad)
        elif isinstance(self.module, FSDPModule):
            grad_norm = fsdp2_clip_grad_norm_(self.module.parameters(), max_norm=self.optimizer_config.clip_grad)
        else:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.module.parameters(), max_norm=self.optimizer_config.clip_grad
            )

        if isinstance(grad_norm, DTensor):
            grad_norm = grad_norm.full_tensor()

        # if grad_norm is not finite, skip the update
        if not torch.isfinite(grad_norm):
            print(f"WARN: grad_norm is not finite: {grad_norm}")
            self.optimizer.zero_grad()
        else:
            self.optimizer.step()
        return grad_norm.item()

    def lr_scheduler_step(self):
        """
        Advance FSDP scheduler and return updated learning rate.
        """
        self.lr_scheduler.step()
        lr = self.lr_scheduler.get_last_lr()[0]  # only return the first group
        return lr

    def to(self, device: str, model: bool = True, optimizer: bool = True, grad: bool = True):
        """
        Move FSDP model and/or optimizer to CPU or GPU with offload support.
        Note that this function executes irrespective of offload config. It serves as manual control
        """
        super().to(device=device, model=model, optimizer=optimizer, grad=grad)

        if self.engine_config.forward_only:
            # force cpu_offload
            return

        device_name = get_device_name()

        assert device in (device_name, "cpu")
        if device == device_name:
            if model:
                load_fsdp_model_to_gpu(self.module)
            if optimizer and self.optimizer is not None:
                load_fsdp_optimizer(self.optimizer, device)
            gc.collect()
        elif device == "cpu":
            if model:
                offload_fsdp_model_to_cpu(self.module)
            if optimizer and self.optimizer is not None:
                offload_fsdp_optimizer(self.optimizer)
        else:
            raise ValueError(f"Invalid device type: {device}")

    def save_checkpoint(
        self,
        local_path: str,
        hdfs_path: Optional[str] = None,
        global_step: int = 0,
        max_ckpt_to_keep: Optional[int] = None,
        **kwargs,
    ) -> None:
        """
        Save FSDP checkpoint, handling parameter offload as needed.
        """
        origin_module_device = next(self.module.parameters()).device.type
        if self._is_offload_param or origin_module_device == "cpu":
            load_fsdp_model_to_gpu(self.module)

        self.checkpoint_manager.save_checkpoint(
            local_path=local_path, hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep
        )

        torch.distributed.barrier()
        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.module)
        gc.collect()
        aggressive_empty_cache(force_sync=True)

    def load_checkpoint(
        self, local_path: str, hdfs_path: Optional[str] = None, del_local_after_load: int = True, **kwargs
    ) -> None:
        """
        Load FSDP checkpoint, restoring parameters and optimizer state.
        """
        import torch

        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.module)

        self.checkpoint_manager.load_checkpoint(
            local_path=local_path, hdfs_path=hdfs_path, del_local_after_load=del_local_after_load
        )

        torch.distributed.barrier()
        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.module)

        if self._is_offload_optimizer:
            offload_fsdp_optimizer(self.optimizer)

    def get_per_tensor_param(
        self, layered_summon=False, base_sync_done=False, adapter_name: str | None = None, **kwargs
    ):
        log_gpu_memory_usage("Before load_fsdp_model_to_gpu", logger=logger)

        load_fsdp_model_to_gpu(self.module)

        log_gpu_memory_usage("After load_fsdp_model_to_gpu", logger=logger)

        peft_config = None

        peft_model = getattr(self.module, "_fsdp_wrapped_module", self.module)
        if hasattr(peft_model, "peft_config"):  # LoRA
            peft_config = peft_model.peft_config.get("default", None)
            adapter_ctx = self.use_adapter(adapter_name) if adapter_name is not None else nullcontext()
            with adapter_ctx:
                params = collect_lora_params(
                    module=self.module,
                    layered_summon=layered_summon,
                    base_sync_done=base_sync_done,
                    is_diffusers=True,
                    adapter_name=adapter_name or "default",
                    layer_prefixes=self.model_config.fsdp_layer_prefixes,
                )
            if not base_sync_done:
                params = {replace_lora_wrapper(k, peft_config): v for k, v in params.items()}
        else:
            params = self.module.state_dict()

        params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))

        log_gpu_memory_usage("Before offload_fsdp_model_to_cpu", logger=logger)
        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.module)
        log_gpu_memory_usage("After offload_fsdp_model_to_cpu", logger=logger)

        if peft_config is not None and base_sync_done:
            per_tensor_param = params.items()
        else:
            device = get_device_id()  # used when fsdp2 set cpu_offload_policy
            # TODO: cast fp32 to bf16 to reduce weight sync overhead, need more fine-grained control, e.g MoE gate
            per_tensor_param = (
                (
                    name,
                    param.to(device, non_blocking=True).full_tensor().to(torch.bfloat16, non_blocking=True)
                    if isinstance(param, DTensor)
                    else param,
                )
                for name, param in params.items()
            )

        # we need to add the prefix to make it compatible with rollout engine
        per_tensor_param = ((f"transformer.{name}", tensor) for name, tensor in per_tensor_param)
        peft_config_dict = peft_config.to_dict() if peft_config is not None else None
        return per_tensor_param, peft_config_dict

    def _run_forward_backward_batch(
        self,
        data: TensorDict,
        loss_function: Callable,
        forward_only: bool,
        *,
        timesteps_key: str,
    ) -> dict:
        num_timesteps = data[timesteps_key].shape[1]
        tu.assign_non_tensor(data, sp_size=self.ulysses_sequence_parallel_size)
        tu.assign_non_tensor(data, use_dynamic_bsz=False)

        micro_batches, indices = prepare_micro_batches(
            data=data, dp_group=self.get_data_parallel_group(), same_micro_num_in_dp=True
        )

        gradient_accumulation_steps = len(micro_batches) * num_timesteps
        output_lst = []
        ctx = torch.no_grad() if forward_only else nullcontext()

        for micro_batch in micro_batches:
            micro_batch = micro_batch.to(get_device_id())
            tu.assign_non_tensor(micro_batch, gradient_accumulation_steps=gradient_accumulation_steps)
            meta_info_lst = {"model_output": [], "loss": [], "metrics": []}
            # Forward and backward for each timestep
            with ctx:
                for step in range(num_timesteps):
                    loss, meta_info = self.forward_step(
                        micro_batch, loss_function=loss_function, forward_only=forward_only, step=step
                    )
                    if not forward_only:
                        loss.backward()
                    for key, val in meta_info.items():
                        meta_info_lst[key].append(val)
            output_lst.append(meta_info_lst)

        # postprocess and return
        return self.postprocess_batch_func(output_lst=output_lst, indices=indices, data=data)


@EngineRegistry.register(model_type="diffusion_model", backend=["fsdp", "fsdp2"], device=["cuda", "npu"])
class PPODiffusersFSDPEngine(DiffusersFSDPEngine):
    """Diffusers FSDP engine with PPO forward/backward and I/O preparation."""

    def forward_backward_batch(
        self, data: TensorDict, loss_function: Callable, forward_only: bool = False
    ) -> list[TensorDict]:
        return self._run_forward_backward_batch(data, loss_function, forward_only, timesteps_key="all_timesteps")

    def prepare_model_inputs(self, micro_batch: TensorDict, step: int):
        """
        Extract and pre-process universal tensors, then delegate architecture-specific
        input construction to the registered DiffusionModelBase subclass.

        Handles common tensor extraction and nested-embed unpadding here.
        Architecture-specific input dict construction is delegated to the model registry.
        """
        latents = micro_batch["all_latents"]
        timesteps = micro_batch["all_timesteps"]
        prompt_embeds = micro_batch.get("prompt_embeds", None)
        prompt_embeds_mask = micro_batch.get("prompt_embeds_mask", None)
        negative_prompt_embeds = micro_batch.get("negative_prompt_embeds", None)
        negative_prompt_embeds_mask = micro_batch.get("negative_prompt_embeds_mask", None)
        sp_size = self.ulysses_sequence_parallel_size if self.use_ulysses_sp else 1

        if isinstance(prompt_embeds, torch.Tensor) and prompt_embeds.is_nested:
            prompt_embeds, prompt_embeds_mask = self._unpad_nested_embeds(prompt_embeds, prompt_embeds_mask)

        if isinstance(prompt_embeds, torch.Tensor) and sp_size > 1:
            prompt_embeds, prompt_embeds_mask = self._pad_embeds_for_sp(prompt_embeds, prompt_embeds_mask, sp_size)

        if isinstance(negative_prompt_embeds, torch.Tensor) and negative_prompt_embeds.is_nested:
            negative_prompt_embeds, negative_prompt_embeds_mask = self._unpad_nested_embeds(
                negative_prompt_embeds, negative_prompt_embeds_mask
            )

        if isinstance(negative_prompt_embeds, torch.Tensor) and sp_size > 1:
            negative_prompt_embeds, negative_prompt_embeds_mask = self._pad_embeds_for_sp(
                negative_prompt_embeds, negative_prompt_embeds_mask, sp_size
            )

        return prepare_model_inputs(
            module=self.module,
            model_config=self.model_config,
            latents=latents,
            timesteps=timesteps,
            prompt_embeds=prompt_embeds,
            prompt_embeds_mask=prompt_embeds_mask,
            negative_prompt_embeds=negative_prompt_embeds,
            negative_prompt_embeds_mask=negative_prompt_embeds_mask,
            micro_batch=micro_batch,
            step=step,
        )

    def prepare_model_outputs(self, output, micro_batch: TensorDict):
        log_prob, prev_sample_mean, std_dev_t, sqrt_dt = output
        return {
            "log_probs": log_prob,
            "prev_sample_mean": prev_sample_mean,
            "std_dev_t": std_dev_t,
            "sqrt_dt": sqrt_dt,
        }

    def forward_step(self, micro_batch: TensorDict, loss_function, forward_only, step):
        model_inputs, negative_model_inputs = self.prepare_model_inputs(micro_batch=micro_batch, step=step)
        raw_output = forward_and_sample_previous_step(
            module=self.module,
            scheduler=self.scheduler,
            model_config=self.model_config,
            model_inputs=model_inputs,
            negative_model_inputs=negative_model_inputs,
            scheduler_inputs=micro_batch,
            step=step,
        )
        model_output = self.prepare_model_outputs(output=raw_output, micro_batch=micro_batch)

        if loss_function is not None:
            data = tu.get_tensordict(
                {
                    "old_log_probs": micro_batch["old_log_probs"][:, step],
                    "advantages": micro_batch["advantages"][:, step],
                },
            )
            tu.assign_non_tensor(
                data,
                gradient_accumulation_steps=tu.get_non_tensor_data(
                    micro_batch, "gradient_accumulation_steps", default=None
                ),
                sp_size=tu.get_non_tensor_data(micro_batch, "sp_size", default=None),
            )

            # TODO (mike): refactor the data preparation logic here
            if micro_batch.get("ref_log_prob", None) is not None:
                data["ref_log_prob"] = micro_batch["ref_log_prob"][:, step]

            if micro_batch.get("ref_prev_sample_mean", None) is not None:
                data["ref_prev_sample_mean"] = micro_batch["ref_prev_sample_mean"][:, step]

            if micro_batch.get("old_prev_sample_mean", None) is not None:
                data["old_prev_sample_mean"] = micro_batch["old_prev_sample_mean"][:, step]

            if micro_batch.get("rollout_is_weights", None) is not None:
                data["rollout_is_weights"] = micro_batch["rollout_is_weights"][:, step]

            loss, metrics = loss_function(model_output=model_output, data=data, dp_group=self.get_data_parallel_group())
        else:
            assert forward_only, "forward_only must be True when loss_function is None"
            loss = torch.tensor(1.0, device=device_name)
            metrics = {}

        output = {
            "model_output": model_output,
            "loss": loss.detach().item(),
            "metrics": metrics,
        }

        return loss, output


@EngineRegistry.register(model_type="diffusion_dpo_model", backend=["fsdp", "fsdp2"], device=["cuda"])
class DPODiffusersFSDPEngine(DiffusersFSDPEngine):
    """Diffusers FSDP engine variant for diffusion DPO."""

    def _prepare_noisy_latents(self, micro_batch: TensorDict) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        latents = micro_batch.get("image_latents", None)
        if latents is None:
            raise KeyError("Diffusion DPO training requires `image_latents` in the micro batch.")

        return prepare_noisy_latents(
            latents=latents,
            scheduler=self.scheduler,
            noise=micro_batch.get(
                "noise", None
            ),  # if noise is not provided, sample noise and timesteps in the forward step
            timesteps=micro_batch.get(
                "timesteps", None
            ),  # if timesteps is not provided, sample timesteps in the forward step
        )

    def _prepare_prompt_embeds(
        self,
        prompt_embeds: torch.Tensor,
        prompt_embeds_mask: torch.Tensor | None,
        negative_prompt_embeds: torch.Tensor | None,
        negative_prompt_embeds_mask: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
        """Apply common nested-tensor and sequence-parallel padding to prompt embeds."""
        sp_size = self.ulysses_sequence_parallel_size if self.use_ulysses_sp else 1
        if not isinstance(prompt_embeds_mask, torch.Tensor):
            prompt_embeds_mask = None
        if not isinstance(negative_prompt_embeds, torch.Tensor):
            negative_prompt_embeds = None
        if not isinstance(negative_prompt_embeds_mask, torch.Tensor):
            negative_prompt_embeds_mask = None

        if prompt_embeds.is_nested:
            prompt_embeds, prompt_embeds_mask = self._unpad_nested_embeds(prompt_embeds, prompt_embeds_mask)

        if sp_size > 1:
            prompt_embeds, prompt_embeds_mask = self._pad_embeds_for_sp(prompt_embeds, prompt_embeds_mask, sp_size)

        if isinstance(negative_prompt_embeds, torch.Tensor) and negative_prompt_embeds.is_nested:
            negative_prompt_embeds, negative_prompt_embeds_mask = self._unpad_nested_embeds(
                negative_prompt_embeds, negative_prompt_embeds_mask
            )

        if isinstance(negative_prompt_embeds, torch.Tensor) and sp_size > 1:
            negative_prompt_embeds, negative_prompt_embeds_mask = self._pad_embeds_for_sp(
                negative_prompt_embeds, negative_prompt_embeds_mask, sp_size
            )

        return prompt_embeds, prompt_embeds_mask, negative_prompt_embeds, negative_prompt_embeds_mask

    def prepare_model_inputs(self, micro_batch: TensorDict, step: int):
        del step

        noisy_latents, noise, timesteps = self._prepare_noisy_latents(micro_batch)
        latent = micro_batch["image_latents"].to(device=noise.device, dtype=noise.dtype)
        prompt_embeds = micro_batch["prompt_embeds"]
        prompt_embeds_mask = micro_batch.get("prompt_embeds_mask", None)
        negative_prompt_embeds = micro_batch.get("negative_prompt_embeds", None)
        negative_prompt_embeds_mask = micro_batch.get("negative_prompt_embeds_mask", None)
        prompt_embeds, prompt_embeds_mask, negative_prompt_embeds, negative_prompt_embeds_mask = (
            self._prepare_prompt_embeds(
                prompt_embeds=prompt_embeds,
                prompt_embeds_mask=prompt_embeds_mask,
                negative_prompt_embeds=negative_prompt_embeds,
                negative_prompt_embeds_mask=negative_prompt_embeds_mask,
            )
        )

        model_inputs, negative_model_inputs = prepare_model_inputs(
            module=self.module,
            model_config=self.model_config,
            latents=noisy_latents,
            timesteps=timesteps,
            prompt_embeds=prompt_embeds,
            prompt_embeds_mask=prompt_embeds_mask,
            negative_prompt_embeds=negative_prompt_embeds,
            negative_prompt_embeds_mask=negative_prompt_embeds_mask,
            micro_batch=micro_batch,
            step=0,
        )
        return model_inputs, negative_model_inputs, {"noise": noise, "latent": latent, "timesteps": timesteps}

    def prepare_model_outputs(self, output, micro_batch: TensorDict):
        del micro_batch

        noise_pred, dpo_context = output

        return {
            "noise_pred": noise_pred,
            "noise": dpo_context["noise"],
            "latent": dpo_context["latent"],
            "timesteps": dpo_context["timesteps"],
        }

    def forward_backward_batch(
        self, data: TensorDict, loss_function: Callable, forward_only: bool = False
    ) -> list[TensorDict]:
        tu.assign_non_tensor(data, sp_size=self.ulysses_sequence_parallel_size)
        tu.assign_non_tensor(data, use_dynamic_bsz=False)

        micro_batches, indices = prepare_micro_batches(
            data=data, dp_group=self.get_data_parallel_group(), same_micro_num_in_dp=True
        )

        gradient_accumulation_steps = len(micro_batches)

        output_lst = []

        ctx = torch.no_grad() if forward_only else nullcontext()

        for micro_batch in micro_batches:
            micro_batch = micro_batch.to(get_device_id())
            tu.assign_non_tensor(micro_batch, gradient_accumulation_steps=gradient_accumulation_steps)
            meta_info_lst = {"model_output": [], "loss": [], "metrics": []}
            with ctx:
                # DPO is a one-shot flow-matching objective over final image latents,
                # not a reversed-sampling objective over every rollout timestep.
                loss, meta_info = self.forward_step(
                    micro_batch,
                    loss_function=loss_function,
                    forward_only=forward_only,
                    step=None,  # use random step for DPO
                )

                if not forward_only:
                    loss.backward()

                for key, val in meta_info.items():
                    meta_info_lst[key].append(val)

            output_lst.append(meta_info_lst)

        # postprocess and return
        return self.postprocess_batch_func(output_lst=output_lst, indices=indices, data=data)

    def forward_step(self, micro_batch: TensorDict, loss_function, forward_only, step):
        model_inputs, negative_model_inputs, dpo_context = self.prepare_model_inputs(micro_batch=micro_batch, step=step)
        noise_pred = forward(
            module=self.module,
            model_config=self.model_config,
            model_inputs=model_inputs,
            negative_model_inputs=negative_model_inputs,
        )
        model_output = self.prepare_model_outputs(output=(noise_pred, dpo_context), micro_batch=micro_batch)

        if loss_function is not None:
            data = tu.get_tensordict({"sample_level_rewards": micro_batch["sample_level_rewards"]})
            uid = tu.get_non_tensor_data(micro_batch, "uid", default=None)
            tu.assign_non_tensor(
                data,
                gradient_accumulation_steps=tu.get_non_tensor_data(
                    micro_batch, "gradient_accumulation_steps", default=None
                ),
                sp_size=tu.get_non_tensor_data(micro_batch, "sp_size", default=None),
            )
            if uid is not None:
                tu.assign_non_tensor(data, uid=uid)
            if micro_batch.get("ref_noise_pred", None) is not None:
                ref_noise_pred = micro_batch["ref_noise_pred"]
                if ref_noise_pred.ndim == model_output["noise_pred"].ndim + 1 and ref_noise_pred.shape[1] == 1:
                    ref_noise_pred = ref_noise_pred[:, 0]
                data["ref_noise_pred"] = ref_noise_pred
            loss, metrics = loss_function(model_output=model_output, data=data, dp_group=self.get_data_parallel_group())
        else:
            assert forward_only, "forward_only must be True when loss_function is None"
            loss = torch.tensor(1.0, device=device_name)
            metrics = {}

        output = {
            "model_output": model_output,
            "loss": loss.detach().item(),
            "metrics": metrics,
        }

        return loss, output


@EngineRegistry.register(model_type="diffusion_nft_model", backend=["fsdp", "fsdp2"], device=["cuda", "npu"])
class NFTDiffusersFSDPEngine(DiffusersFSDPEngine):
    """Diffusers FSDP engine for direct-preference / forward-process objectives (e.g. DiffusionNFT)."""

    def forward_backward_batch(
        self, data: TensorDict, loss_function: Callable, forward_only: bool = False
    ) -> list[TensorDict]:
        return self._run_forward_backward_batch(data, loss_function, forward_only, timesteps_key="train_timesteps")

    def prepare_model_inputs(self, micro_batch: TensorDict, step: int):
        x0 = micro_batch["latents_clean"]
        timestep = micro_batch["train_timesteps"][:, step]
        t = timestep.float() / 1000.0
        t_expanded = t.view(-1, *([1] * (x0.ndim - 1)))

        if micro_batch.get("forward_noise", None) is not None:
            forward_noise = micro_batch["forward_noise"]
            noise = forward_noise[:, step] if forward_noise.ndim == x0.ndim + 1 else forward_noise
        else:
            noise = torch.randn_like(x0.float())
        xt = (1.0 - t_expanded) * x0 + t_expanded * noise

        prompt_embeds = micro_batch["prompt_embeds"]
        prompt_embeds_mask = micro_batch["prompt_embeds_mask"]
        negative_prompt_embeds = micro_batch.get("negative_prompt_embeds", None)
        negative_prompt_embeds_mask = micro_batch.get("negative_prompt_embeds_mask", None)
        sp_size = self.ulysses_sequence_parallel_size if self.use_ulysses_sp else 1

        if prompt_embeds.is_nested:
            prompt_embeds, prompt_embeds_mask = self._unpad_nested_embeds(prompt_embeds, prompt_embeds_mask)

        if sp_size > 1:
            prompt_embeds, prompt_embeds_mask = self._pad_embeds_for_sp(prompt_embeds, prompt_embeds_mask, sp_size)

        if isinstance(negative_prompt_embeds, torch.Tensor) and negative_prompt_embeds.is_nested:
            negative_prompt_embeds, negative_prompt_embeds_mask = self._unpad_nested_embeds(
                negative_prompt_embeds, negative_prompt_embeds_mask
            )

        if isinstance(negative_prompt_embeds, torch.Tensor) and sp_size > 1:
            negative_prompt_embeds, negative_prompt_embeds_mask = self._pad_embeds_for_sp(
                negative_prompt_embeds, negative_prompt_embeds_mask, sp_size
            )

        model_inputs, negative_model_inputs = prepare_model_inputs(
            module=self.module,
            model_config=self.model_config,
            latents=xt,
            timesteps=timestep,
            prompt_embeds=prompt_embeds,
            prompt_embeds_mask=prompt_embeds_mask,
            negative_prompt_embeds=negative_prompt_embeds,
            negative_prompt_embeds_mask=negative_prompt_embeds_mask,
            micro_batch=micro_batch,
            step=step,
        )
        return model_inputs, negative_model_inputs, x0, xt, t_expanded

    def prepare_model_outputs(self, output, micro_batch: TensorDict):
        old_prediction, forward_prediction, ref_forward_prediction, x0, xt, t_expanded = output
        return {
            "old_prediction": old_prediction,
            "forward_prediction": forward_prediction,
            "ref_forward_prediction": ref_forward_prediction,
            "x0": x0,
            "xt": xt,
            "t_expanded": t_expanded,
        }

    def forward_step(self, micro_batch: TensorDict, loss_function, forward_only, step):
        model_inputs, negative_model_inputs, x0, xt, t_expanded = self.prepare_model_inputs(
            micro_batch=micro_batch, step=step
        )

        with self.use_adapter("old"), torch.no_grad():
            old_prediction = forward(
                module=self.module,
                model_config=self.model_config,
                model_inputs=model_inputs,
                negative_model_inputs=negative_model_inputs,
            ).detach()

        forward_prediction = forward(
            module=self.module,
            model_config=self.model_config,
            model_inputs=model_inputs,
            negative_model_inputs=negative_model_inputs,
        )

        with torch.no_grad():
            with self.disable_adapter():
                ref_forward_prediction = forward(
                    module=self.module,
                    model_config=self.model_config,
                    model_inputs=model_inputs,
                    negative_model_inputs=negative_model_inputs,
                ).detach()
        self._set_adapter("default")

        model_output = self.prepare_model_outputs(
            output=(old_prediction, forward_prediction, ref_forward_prediction, x0, xt, t_expanded),
            micro_batch=micro_batch,
        )

        if loss_function is not None:
            data = tu.get_tensordict({"reward_prob": micro_batch["reward_prob"][:, step]})
            tu.assign_non_tensor(
                data,
                gradient_accumulation_steps=tu.get_non_tensor_data(
                    micro_batch, "gradient_accumulation_steps", default=None
                ),
                sp_size=tu.get_non_tensor_data(micro_batch, "sp_size", default=None),
            )
            loss, metrics = loss_function(model_output=model_output, data=data, dp_group=self.get_data_parallel_group())
        else:
            assert forward_only, "forward_only must be True when loss_function is None"
            loss = torch.tensor(1.0, device=x0.device)
            metrics = {}

        output = {
            "model_output": model_output,
            "loss": loss.detach().item(),
            "metrics": metrics,
        }
        return loss, output


class EngineEvalModeCtx(BaseEngineCtx):
    def __init__(self, engine: DiffusersFSDPEngine, **kwargs):
        super().__init__(engine=engine, mode="eval", **kwargs)

    def __enter__(self):
        assert isinstance(self.engine, DiffusersFSDPEngine)
        super().__enter__()
        self.engine.module.eval()

    def __exit__(self, exc_type, exc_value, traceback):
        assert isinstance(self.engine, DiffusersFSDPEngine)

        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
        # unshard the root FSDP module
        if self.engine.engine_config.fsdp_size > 1:
            if fsdp_version(self.engine.module) == 1:
                self.engine.module._handle.reshard(True)
            elif fsdp_version(self.engine.module) == 2:
                self.engine.module.reshard()

        super().__exit__(exc_type, exc_value, traceback)


class EngineTrainModeCtx(BaseEngineCtx):
    def __init__(self, engine: DiffusersFSDPEngine, **kwargs):
        super().__init__(engine=engine, mode="train", **kwargs)

    def __enter__(self):
        assert isinstance(self.engine, DiffusersFSDPEngine)
        super().__enter__()
        self.engine.module.train()

    def __exit__(self, exc_type, exc_value, traceback):
        assert isinstance(self.engine, DiffusersFSDPEngine)
        self.engine.optimizer_zero_grad()
        super().__exit__(exc_type, exc_value, traceback)
