Source code for verl_omni.trainer.diffusion.diffusion_metric_utils

# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Metrics for diffusion (image generation) training.
"""

from typing import Any, Literal

import numpy as np
import torch
from verl import DataProto


def compute_data_metrics_diffusion(batch: DataProto) -> dict[str, Any]:
    """
    Summarize reward, advantage and return statistics of a diffusion training batch.

    Diffusion (image generation) models index rewards and advantages by denoising
    timestep instead of by output token, so the 2-D tensors handled here are
    laid out as [B, T] with T running over timesteps.

    Args:
        batch: A DataProto object containing diffusion batch data.
            ``sample_level_rewards`` is always read and may be [B, T]
            (GRPO-style) or [B] (DPO-style). ``advantages`` and ``returns``
            ([B, T]) and the non-tensor ``uid`` entry are optional; the
            corresponding metrics are only emitted when the key is present.

    Returns:
        A dictionary of metrics including:
            - critic/rewards/mean, max, min: Per-image reward statistics
            - critic/advantages/mean, max, min: Element-wise advantage statistics
              over B*T, when available
            - critic/returns/mean, max, min: Element-wise return statistics
              over B*T, when available
            - critic/rewards/zero_std_ratio: Fraction of prompt groups whose
              reward std is zero (requires ``uid``)
            - critic/rewards/std_mean: Mean per-prompt reward standard deviation
              (requires ``uid``)
            - critic/rewards/group_size: Average number of images sampled per
              unique prompt (requires ``uid``)
    """

    def _mean_max_min(values: torch.Tensor) -> tuple[float, float, float]:
        # Reduce a tensor to plain Python scalars in (mean, max, min) order.
        return (
            values.mean().detach().item(),
            values.max().detach().item(),
            values.min().detach().item(),
        )

    tensors = batch.batch
    raw_rewards = tensors["sample_level_rewards"]
    # [B, T] rewards are averaged over timesteps; [B] rewards are used as-is.
    per_image_reward = raw_rewards.mean(dim=1) if raw_rewards.ndim > 1 else raw_rewards

    r_mean, r_max, r_min = _mean_max_min(per_image_reward)
    metrics = {
        "critic/rewards/mean": r_mean,
        "critic/rewards/max": r_max,
        "critic/rewards/min": r_min,
    }

    if "advantages" in tensors:
        # Flattening [B, T] -> [B*T] gives aggregate statistics across timesteps.
        a_mean, a_max, a_min = _mean_max_min(tensors["advantages"].flatten())
        metrics["critic/advantages/mean"] = a_mean
        metrics["critic/advantages/max"] = a_max
        metrics["critic/advantages/min"] = a_min

    if "returns" in tensors:
        g_mean, g_max, g_min = _mean_max_min(tensors["returns"].flatten())
        metrics["critic/returns/mean"] = g_mean
        metrics["critic/returns/max"] = g_max
        metrics["critic/returns/min"] = g_min

    if "uid" in batch.non_tensor_batch:
        # Samples sharing a uid form one prompt group; measure reward spread per group.
        reward_values = per_image_reward.cpu().float().numpy()
        prompt_ids = np.array(batch.non_tensor_batch["uid"])
        distinct_ids = np.unique(prompt_ids)
        group_stds = np.array(
            [np.std(reward_values[prompt_ids == prompt_id]) for prompt_id in distinct_ids]
        )
        metrics["critic/rewards/zero_std_ratio"] = float((group_stds == 0).mean())
        metrics["critic/rewards/std_mean"] = float(group_stds.mean())
        metrics["critic/rewards/group_size"] = float(len(reward_values) / len(distinct_ids))

    return metrics


def compute_old_policy_metrics(
    update_result: tuple[bool, float, Literal["none", "copy", "ema"]],
) -> dict[str, Any]:
    """
    Turn the outcome of an old-policy adapter refresh into loggable metrics.

    Args:
        update_result: A ``(update_applied, decay, update_type)`` tuple, where
            ``update_type`` is one of "none", "copy" or "ema".

    Returns:
        A dictionary with float-valued entries:
            - old_policy/update_applied: 1.0 if ``update_applied`` is truthy, else 0.0
            - old_policy/copy_update: 1.0 if ``update_type`` is "copy", else 0.0
            - old_policy/ema_update: 1.0 if ``update_type`` is "ema", else 0.0
            - old_policy/decay: ``decay`` converted to float
    """
    was_applied, decay_value, kind = update_result

    metrics: dict[str, Any] = {}
    metrics["old_policy/update_applied"] = float(was_applied)
    # Booleans are logged as 0.0 / 1.0 indicator values.
    metrics["old_policy/copy_update"] = float(kind == "copy")
    metrics["old_policy/ema_update"] = float(kind == "ema")
    metrics["old_policy/decay"] = float(decay_value)
    return metrics


def compute_timing_metrics_diffusion(timing_raw: dict[str, float], num_images: int) -> dict[str, Any]:
    """
    Build raw and per-image timing metrics for a diffusion training step.

    Args:
        timing_raw: A dictionary mapping stage names to their execution times in seconds.
        num_images: Total number of images processed in the batch; the divisor
            for the per-image timings.

    Returns:
        A dictionary containing:
            - timing_s/{name}: Raw timing in seconds for every stage in ``timing_raw``
            - timing_per_image_ms/{name}: Per-image timing in milliseconds, emitted
              only for the core compute stages (gen, ref, old_log_prob, adv,
              update_actor) that appear in ``timing_raw``. Other stages such as
              save_checkpoint, update_weights and testing get no per-image entry.
    """
    # Only these stages are normalised by the image count.
    per_image_stages = ("gen", "ref", "old_log_prob", "adv", "update_actor")

    metrics: dict[str, Any] = {}
    for name, value in timing_raw.items():
        metrics[f"timing_s/{name}"] = value

    for name in per_image_stages:
        if name in timing_raw:
            # seconds -> milliseconds, then spread across the images of the batch
            metrics[f"timing_per_image_ms/{name}"] = timing_raw[name] * 1000 / num_images

    return metrics


def compute_throughput_metrics_diffusion(batch: DataProto, timing_raw: dict[str, float], n_gpus: int) -> dict[str, Any]:
    """
    Report image throughput for a diffusion (image/video generation) training step.

    Language-model training measures throughput in tokens per second; diffusion
    training produces images, so throughput here is images per second.

    Args:
        batch: A DataProto object containing diffusion batch data. The image count
            is the leading dimension of ``advantages`` when that key is present,
            otherwise of ``sample_level_rewards``.
        timing_raw: A dictionary mapping stage names to their execution times in
            seconds. Must contain a "step" key with the total step time.
        n_gpus: Number of GPUs used for training.

    Returns:
        A dictionary containing:
            - perf/total_num_images: Number of images processed in the batch
            - perf/time_per_step: Time taken for the step in seconds
            - perf/throughput: Images generated per second per GPU
    """
    tensors = batch.batch
    # Prefer "advantages"; fall back to the rewards tensor when it is absent.
    size_key = "advantages" if "advantages" in tensors else "sample_level_rewards"
    num_images = tensors[size_key].shape[0]

    step_seconds = timing_raw["step"]
    gpu_seconds = step_seconds * n_gpus

    return {
        "perf/total_num_images": num_images,
        "perf/time_per_step": step_seconds,
        "perf/throughput": num_images / gpu_seconds,
    }


def compute_reward_extra_metrics_diffusion(reward_extra_infos_dict: dict) -> dict[str, Any]:
    """
    Compute per-sub-reward mean metrics for multi-reward tracking.

    Args:
        reward_extra_infos_dict: Mapping from a sub-reward name to its per-sample
            values, given either as a NumPy array or as a list. May be empty or
            falsy (e.g. ``None``), in which case no metrics are produced.

    Returns:
        A dictionary with one ``critic/{key}/mean`` entry per input key whose
        values are numeric and non-empty. Entries are skipped when:
            - the value is an ndarray with a non-numeric dtype or zero elements
            - the value is an empty list, or a list whose first element is not
              a Python or NumPy integer/float
            - the value is neither an ndarray nor a list
    """
    metrics: dict[str, Any] = {}
    if not reward_extra_infos_dict:
        return metrics

    # NumPy scalars such as np.float32 / np.int64 do not subclass int or float,
    # so they are accepted explicitly alongside the Python numeric types.
    numeric_scalar_types = (int, float, np.integer, np.floating)

    for key, values in reward_extra_infos_dict.items():
        if isinstance(values, np.ndarray):
            # An empty array has no meaningful mean (NumPy would return NaN and
            # warn); skip it, mirroring how empty lists are handled below.
            if values.size == 0 or not np.issubdtype(values.dtype, np.number):
                continue
            metrics[f"critic/{key}/mean"] = float(values.mean())
        elif isinstance(values, list) and values:
            # Only the first element is inspected to decide whether the list is numeric.
            if not isinstance(values[0], numeric_scalar_types):
                continue
            metrics[f"critic/{key}/mean"] = float(np.mean(values))

    return metrics