# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from tensordict import TensorDict
from verl.trainer.ppo.rollout_corr_helper import compute_rollout_correction_and_rejection_mask
from verl.utils import tensordict_utils as tu
from verl.utils.metric import AggregationType, Metric
from verl_omni.trainer.diffusion.diffusion_algos import get_diffusion_loss_fn
from verl_omni.workers.config import DiffusionActorConfig
def _apply_bypass_rc(
log_prob: torch.Tensor, # (B,) current policy log-prob
old_log_prob: torch.Tensor, # (B,) == rollout_log_prob in bypass
rc_cfg, # RolloutCorrectionConfig
data: TensorDict, # modified in-place
metrics: dict, # modified in-place
) -> None:
"""Compute per-step IS/RS for bypass mode and stash weights into ``data``."""
log_prob_2d = log_prob.unsqueeze(-1) # current policy log-prob (π_θ)
rollout_lp_2d = old_log_prob.unsqueeze(-1) # rollout policy log-prob (π_rollout)
response_mask = torch.ones_like(log_prob_2d)
# In bypass mode, RS checks current→rollout drift: pass current as old_log_prob, rollout as rollout_log_prob.
# This matches the mathematical intent: RS mask is applied to exp(log_prob - rollout_log_prob).
is_weights_proto, modified_mask, rc_metrics = compute_rollout_correction_and_rejection_mask(
old_log_prob=log_prob_2d, # current policy (π_θ)
rollout_log_prob=rollout_lp_2d, # rollout policy (π_rollout)
response_mask=response_mask,
rollout_is=rc_cfg.rollout_is,
rollout_is_threshold=rc_cfg.rollout_is_threshold,
rollout_is_batch_normalize=rc_cfg.rollout_is_batch_normalize,
rollout_rs=rc_cfg.rollout_rs,
rollout_rs_threshold=rc_cfg.rollout_rs_threshold,
)
# ppo_clip: PPO ratio handles IS, only RS mask is applied.
assert rc_cfg.loss_type == "ppo_clip", f"Only loss_type='ppo_clip' is supported, got {rc_cfg.loss_type!r}"
weights: torch.Tensor | None = None
if rc_cfg.rollout_rs:
rs_mask = modified_mask
weights = rs_mask if weights is None else weights * rs_mask
if weights is not None:
existing = data.get("rollout_is_weights", None)
data["rollout_is_weights"] = (
weights.squeeze(-1).to(dtype=log_prob.dtype)
if existing is None
else existing * weights.squeeze(-1).to(dtype=log_prob.dtype)
)
for k, v in rc_metrics.items():
metrics[k] = Metric(value=float(v), aggregation=AggregationType.MEAN)
[docs]
def diffusion_loss(config: DiffusionActorConfig, model_output, data: TensorDict, dp_group=None):
"""Compute loss for diffusion model"""
config.global_batch_info["loss_scale_factor"] = config.loss_scale_factor
metrics = {}
loss_mode = config.diffusion_loss.get("loss_mode", "flow_grpo")
loss_func = get_diffusion_loss_fn(loss_mode)
# Rollout Correction bypass mode only applies to log-prob policy-gradient losses.
if "log_probs" in loss_func.required_model_output_keys:
log_prob = model_output["log_probs"]
old_log_prob = data["old_log_probs"]
rc_cfg = config.rollout_correction
# Rollout Correction bypass mode: compute IS/RS weights per-step and
# stash ``rollout_is_weights`` into ``data`` before loss dispatch.
if rc_cfg.bypass_mode:
_apply_bypass_rc(log_prob, old_log_prob, rc_cfg, data, metrics)
loss_func.validate_inputs(loss_name=loss_mode, model_output=model_output, data=data)
loss_result = loss_func(config=config, model_output=model_output, data=data)
loss_value = loss_result.loss
metrics_values = loss_result.metrics
metrics_values = Metric.from_dict(metrics_values, aggregation=AggregationType.MEAN)
metrics.update(metrics_values)
if loss_result.add_loss_metric:
metrics["actor/loss"] = Metric(value=loss_value, aggregation=AggregationType.MEAN)
if config.use_kl_loss:
loss_func = get_diffusion_loss_fn("kl")
loss_func.validate_inputs(loss_name="kl", model_output=model_output, data=data)
kl_result = loss_func(config=config, model_output=model_output, data=data)
loss_value += kl_result.loss * config.kl_loss_coef
metrics.update(Metric.from_dict(kl_result.metrics, aggregation=AggregationType.MEAN))
metrics["kl_coef"] = config.kl_loss_coef
if kl_result.add_loss_metric:
metrics["actor/weighted_kl_loss"] = Metric(
value=kl_result.loss * config.kl_loss_coef,
aggregation=AggregationType.MEAN,
)
gradient_accumulation_steps = tu.get_non_tensor_data(data, "gradient_accumulation_steps", default=None)
loss_value = loss_value / gradient_accumulation_steps
sp_size = tu.get_non_tensor_data(data, "sp_size", default=None)
if sp_size > 1:
loss_value = loss_value * sp_size
return loss_value, metrics