Source code for verl_omni.pipelines.schedulers.flow_match_sde

# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import dataclass
from typing import Literal, Optional

import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import BaseOutput
from diffusers.utils.torch_utils import randn_tensor


@dataclass
class FlowMatchSDEDiscreteSchedulerOutput(BaseOutput):
    """
    Container returned by the scheduler's `step` function.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_channels)` for images):
            Computed sample `(x_{t-1})` of the previous timestep. Feed it back as the next model input in the
            denoising loop.
        log_prob (`torch.FloatTensor` of shape `(batch_size,)`, *optional*):
            Log probability of `prev_sample`; `None` when log probabilities were not requested.
        prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_channels)` for images):
            Mean of the distribution from which `prev_sample` is drawn.
        std_dev_t (`torch.FloatTensor` of shape `(batch_size, 1, 1)`):
            The standard deviation used to compute `prev_sample`.
    """

    prev_sample: torch.FloatTensor
    log_prob: torch.FloatTensor | None
    prev_sample_mean: torch.FloatTensor
    std_dev_t: torch.FloatTensor
class FlowMatchSDEDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
    """SDE version of the FlowMatchEulerDiscreteScheduler.

    The implementation is based on FlowGRPO paper (https://arxiv.org/abs/2505.05470) and diffusers v0.37 branch.
    """

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: float | torch.FloatTensor,
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        per_token_timesteps: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        noise_level: float = 0.7,
        prev_sample: Optional[torch.FloatTensor] = None,
        sde_type: Literal["sde", "cps", "dance_sde"] = "sde",
        return_logprobs: bool = True,
    ) -> FlowMatchSDEDiscreteSchedulerOutput | tuple:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Modified from
        https://github.com/yifan123/flow_grpo/blob/main/flow_grpo/diffusers_patch/sd3_sde_with_logprob.py

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model. Upcast to float32 internally.
            timestep (`float`):
                The current discrete timestep in the diffusion chain. Only used to initialise the internal step
                index when it has not been set yet; integer indices are rejected.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process. Upcast to float32 internally.
            s_churn (`float`):
                Not used by this method; kept in the signature.
            s_tmin (`float`):
                Not used by this method; kept in the signature.
            s_tmax (`float`):
                Not used by this method; kept in the signature.
            s_noise (`float`, defaults to 1.0):
                Not used by this method; kept in the signature.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            per_token_timesteps (`torch.Tensor`, *optional*):
                The timesteps for each token in the sample. Passing a value raises `NotImplementedError`.
            return_dict (`bool`):
                Whether or not to return a
                [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchSDEDiscreteSchedulerOutput`] or tuple.
            noise_level (`float`, *optional*, defaults to 0.7):
                The noise level used in the SDE.
            prev_sample (`torch.FloatTensor`, *optional*):
                The sample from the previous timestep. If not provided, it will be sampled inside the function.
            sde_type (`str`, *optional*, defaults to "sde"):
                The type of SDE to use. Choose between "sde", "cps", and "dance_sde".
            return_logprobs (`bool`, *optional*, defaults to True):
                Whether to return log probabilities of the previous sample.

        Returns:
            A `FlowMatchSDEDiscreteSchedulerOutput` when `return_dict` is True, otherwise the tuple
            `(prev_sample, log_prob, prev_sample_mean, std_dev_t)`.

        Raises:
            ValueError: If `timestep` is an integer or an integer tensor.
        """
        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample. The model output has to be
        # upcast as well: `sample_previous_step` requires every tensor input to be float32, so a
        # half-precision model output would otherwise fail its dtype check.
        sample = sample.to(torch.float32)
        model_output = model_output.to(torch.float32)
        if prev_sample is not None:
            prev_sample = prev_sample.to(torch.float32)

        # `timestep` is deliberately not forwarded: the sequential loop relies on `step_index`.
        prev_sample, log_prob, prev_sample_mean, std_dev_t = self.sample_previous_step(
            sample=sample,
            model_output=model_output,
            generator=generator,
            per_token_timesteps=per_token_timesteps,
            noise_level=noise_level,
            prev_sample=prev_sample,
            sde_type=sde_type,
            return_logprobs=return_logprobs,
        )

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample, log_prob, prev_sample_mean, std_dev_t)

        return FlowMatchSDEDiscreteSchedulerOutput(
            prev_sample=prev_sample, log_prob=log_prob, prev_sample_mean=prev_sample_mean, std_dev_t=std_dev_t
        )

    @staticmethod
    def _draw_noise(reference: torch.Tensor, generator: Optional[torch.Generator]) -> torch.Tensor:
        """Sample standard normal noise with the shape, device and dtype of `reference`."""
        return randn_tensor(
            reference.shape,
            generator=generator,
            device=reference.device,
            dtype=reference.dtype,
        )

    @staticmethod
    def _gaussian_log_prob(value: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
        """Element-wise log-density of `value` under N(mean, std**2); `value` is detached from the graph."""
        return (
            -((value.detach() - mean) ** 2) / (2 * (std**2))
            - torch.log(std)
            - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
        )

    def sample_previous_step(
        self,
        sample: torch.Tensor,
        model_output: torch.Tensor,
        timestep: Optional[torch.FloatTensor] = None,
        generator: Optional[torch.Generator] = None,
        per_token_timesteps: Optional[torch.Tensor] = None,
        noise_level: float = 0.7,
        prev_sample: Optional[torch.Tensor] = None,
        sde_type: Literal["cps", "sde", "dance_sde"] = "sde",
        return_logprobs: bool = True,
        return_sqrt_dt: bool = False,
    ) -> tuple:
        """
        Run a single SDE / CPS reverse step.

        Args:
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process. Must be float32.
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model. Must be float32.
            timestep (`torch.FloatTensor`, *optional*):
                The current discrete timestep in the diffusion chain, iterated element by element to look up one
                sigma per entry. When `None`, the internal `step_index` is used (sequential denoising loop).
            generator (`torch.Generator`, *optional*):
                A random number generator.
            per_token_timesteps (`torch.Tensor`, *optional*):
                The timesteps for each token in the sample. Currently not supported.
            noise_level (`float`, *optional*, defaults to 0.7):
                The noise level used in the SDE.
            prev_sample (`torch.FloatTensor`, *optional*):
                The sample from the previous timestep. If provided, it is used directly for log-probability
                computation instead of being sampled. Must be float32.
            sde_type (`str`, *optional*, defaults to "sde"):
                The type of SDE to use. Choose between "sde", "cps", and "dance_sde".
            return_logprobs (`bool`, *optional*, defaults to True):
                Whether to return log probabilities of the previous sample.
            return_sqrt_dt (`bool`, *optional*, defaults to False):
                Whether to additionally return `sqrt(-dt)` as a tensor of shape `(batch_size,)`. Used by
                GRPO-Guard to compute the importance-ratio normalization (see `GRPOGuardLoss`).

        Returns:
            `(prev_sample, log_prob, prev_sample_mean, std_dev_t)`, with `sqrt_dt` appended as a fifth element
            when `return_sqrt_dt` is True. `log_prob` is averaged over every non-batch dimension, or `None` when
            `return_logprobs` is False. For "sde" the noise actually added has standard deviation
            `std_dev_t * sqrt(-dt)`; for "cps" and "dance_sde" it is `std_dev_t` itself.

        Raises:
            NotImplementedError: If `per_token_timesteps` is provided.
        """
        assert sde_type in ["sde", "cps", "dance_sde"]
        assert sample.dtype == torch.float32
        if prev_sample is not None:
            assert prev_sample.dtype == torch.float32
        assert model_output.dtype == torch.float32

        if per_token_timesteps is not None:
            raise NotImplementedError("per_token_timesteps is not supported yet for FlowMatchSDEDiscreteScheduler.")

        if timestep is None:
            # Sequential denoising: one sigma shared by the whole batch (0-dim tensors).
            sigma_idx = self.step_index
            sigma = self.sigmas[sigma_idx]
            sigma_prev = self.sigmas[sigma_idx + 1]
        else:
            # One sigma per batch element, reshaped to (batch, 1, ..., 1) to broadcast against `sample`.
            sigma_idx = torch.tensor([self.index_for_timestep(t) for t in timestep])
            broadcast_shape = (-1, *([1] * (sample.ndim - 1)))
            sigma = self.sigmas[sigma_idx].view(*broadcast_shape)
            sigma_prev = self.sigmas[sigma_idx + 1].view(*broadcast_shape)

        sigma_max = self.sigmas[1]
        dt = sigma_prev - sigma  # negative whenever sigma decreases towards the next step

        if sde_type == "sde":
            # `1 - sigma` vanishes at sigma == 1; substitute sigmas[1] there to avoid dividing by zero.
            std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma))) * noise_level
            prev_sample_mean = (
                sample * (1 + std_dev_t**2 / (2 * sigma) * dt)
                + model_output * (1 + std_dev_t**2 * (1 - sigma) / (2 * sigma)) * dt
            )
            noise_std = std_dev_t * torch.sqrt(-1 * dt)

            if prev_sample is None:
                prev_sample = prev_sample_mean + noise_std * self._draw_noise(model_output, generator)

            log_prob = self._gaussian_log_prob(prev_sample, prev_sample_mean, noise_std) if return_logprobs else None

        elif sde_type == "cps":
            std_dev_t = sigma_prev * math.sin(noise_level * math.pi / 2)
            pred_original_sample = sample - sigma * model_output
            noise_estimate = sample + model_output * (1 - sigma)
            prev_sample_mean = pred_original_sample * (1 - sigma_prev) + noise_estimate * torch.sqrt(
                sigma_prev**2 - std_dev_t**2
            )

            if prev_sample is None:
                prev_sample = prev_sample_mean + std_dev_t * self._draw_noise(model_output, generator)

            # Unnormalised: only the negative squared error, without variance scaling or constants.
            log_prob = -((prev_sample.detach() - prev_sample_mean) ** 2) if return_logprobs else None

        elif sde_type == "dance_sde":
            # DanceGRPO SDE step from https://github.com/XueZeyue/DanceGRPO
            # Based on score-based SDE correction with eta (noise_level) controlling
            # the stochasticity. This formulation is numerically stable even when
            # sigma is close to 1, unlike the FlowGRPO "sde" variant.
            delta_t = sigma - sigma_prev  # positive

            # Predicted original sample: x_0 = x_t - sigma * model_output
            pred_original_sample = sample - sigma * model_output

            # Score-based SDE correction term
            # score_estimate = -(x_t - x_0 * (1 - sigma)) / sigma^2
            # log_term = -0.5 * eta^2 * score_estimate
            score_estimate = -(sample - pred_original_sample * (1 - sigma)) / (sigma**2)
            log_term = -0.5 * noise_level**2 * score_estimate

            # ODE mean x_t + dt * model_output, then the correction log_term * dt (dt is negative).
            prev_sample_mean = sample + dt * model_output
            prev_sample_mean = prev_sample_mean + log_term * dt

            # Noise standard deviation: eta * sqrt(delta_t)
            std_dev_t = noise_level * torch.sqrt(delta_t)

            if prev_sample is None:
                prev_sample = prev_sample_mean + std_dev_t * self._draw_noise(model_output, generator)

            log_prob = self._gaussian_log_prob(prev_sample, prev_sample_mean, std_dev_t) if return_logprobs else None

        # mean along all but batch dimension
        if log_prob is not None:
            log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))

        if not return_sqrt_dt:
            return prev_sample, log_prob, prev_sample_mean, std_dev_t

        sqrt_dt = torch.sqrt(-1 * dt)
        if sqrt_dt.ndim == 0:
            # Shared sigma: replicate the scalar once per batch element.
            sqrt_dt = sqrt_dt.expand(sample.shape[0]).clone()
        else:
            sqrt_dt = sqrt_dt.reshape(sqrt_dt.shape[0])
        return prev_sample, log_prob, prev_sample_mean, std_dev_t, sqrt_dt