# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Literal, Optional
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import BaseOutput
from diffusers.utils.torch_utils import randn_tensor
# Result container returned by the SDE scheduler's `step` method.
@dataclass
class FlowMatchSDEDiscreteSchedulerOutput(BaseOutput):
    """
    Bundle of tensors produced by one call to the scheduler's `step` function.

    Args:
        prev_sample (`torch.FloatTensor`):
            The sample `(x_{t-1})` for the previous timestep. Feed this back to the model as the next input of
            the denoising loop.
        log_prob (`torch.FloatTensor`, *optional*):
            Log probability of `prev_sample`, averaged over every non-batch dimension. `None` when log
            probabilities were not requested.
        prev_sample_mean (`torch.FloatTensor`):
            Mean of the distribution `prev_sample` is drawn from (same shape as `prev_sample`).
        std_dev_t (`torch.FloatTensor`):
            Standard-deviation term used when computing `prev_sample`. It is either a scalar tensor or
            broadcastable against `prev_sample`.
    """

    # Next model input (x_{t-1}).
    prev_sample: torch.FloatTensor
    # Per-batch-element log probability, or None when not computed.
    log_prob: Optional[torch.FloatTensor]
    # Noise-free mean of the transition.
    prev_sample_mean: torch.FloatTensor
    # Standard-deviation term of the transition.
    std_dev_t: torch.FloatTensor
# SDE-sampling scheduler built on top of diffusers' FlowMatchEulerDiscreteScheduler.
class FlowMatchSDEDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
    """SDE version of the FlowMatchEulerDiscreteScheduler.

    The implementation is based on the FlowGRPO paper (https://arxiv.org/abs/2505.05470)
    and the diffusers v0.37 branch.

    Three stochastic update rules are available through `sde_type`:

    * `"sde"`: FlowGRPO SDE step with a Gaussian log-probability.
    * `"cps"`: CPS step; its log-probability is the negative squared error without normalisation terms.
    * `"dance_sde"`: DanceGRPO SDE step with a Gaussian log-probability.
    """

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: float | torch.FloatTensor,
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        per_token_timesteps: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        noise_level: float = 0.7,
        prev_sample: Optional[torch.FloatTensor] = None,
        sde_type: Literal["sde", "cps", "dance_sde"] = "sde",
        return_logprobs: bool = True,
    ) -> FlowMatchSDEDiscreteSchedulerOutput | tuple:
        """
        Predict the sample from the previous timestep by reversing the SDE, then advance the internal
        step index by one.

        Modified from https://github.com/yifan123/flow_grpo/blob/main/flow_grpo/diffusers_patch/sd3_sde_with_logprob.py

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model. It is upcast to float32 internally.
            timestep (`float` or `torch.FloatTensor`):
                The current discrete timestep in the diffusion chain. Must be one of `scheduler.timesteps`,
                not an integer index. It is only used to initialise the step index on the first call.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process. Upcast to float32 internally.
            s_churn (`float`):
                Unused by this method; kept so the signature matches the parent scheduler.
            s_tmin (`float`):
                Unused by this method; kept so the signature matches the parent scheduler.
            s_tmax (`float`):
                Unused by this method; kept so the signature matches the parent scheduler.
            s_noise (`float`, defaults to 1.0):
                Unused by this method; kept so the signature matches the parent scheduler.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            per_token_timesteps (`torch.Tensor`, *optional*):
                The timesteps for each token in the sample. Not supported yet; passing a value raises
                `NotImplementedError`.
            return_dict (`bool`):
                Whether to return a `FlowMatchSDEDiscreteSchedulerOutput` or a plain tuple.
            noise_level (`float`, *optional*, defaults to 0.7):
                The noise level used in the SDE.
            prev_sample (`torch.FloatTensor`, *optional*):
                The sample from the previous timestep. If not provided, it will be sampled inside the function.
            sde_type (`str`, *optional*, defaults to "sde"):
                The type of SDE to use. Choose between "sde", "cps", and "dance_sde".
            return_logprobs (`bool`, *optional*, defaults to True):
                Whether to return log probabilities of the previous sample.

        Returns:
            `FlowMatchSDEDiscreteSchedulerOutput` when `return_dict` is True, otherwise the tuple
            `(prev_sample, log_prob, prev_sample_mean, std_dev_t)`.

        Raises:
            ValueError: If `timestep` is an integer or an integer tensor.
        """
        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample. `model_output` is upcast as
        # well: `sample_previous_step` requires float32 for every tensor input, and models running
        # in bf16/fp16 would otherwise always be rejected.
        sample = sample.to(torch.float32)
        model_output = model_output.to(torch.float32)
        if prev_sample is not None:
            prev_sample = prev_sample.to(torch.float32)

        prev_sample, log_prob, prev_sample_mean, std_dev_t = self.sample_previous_step(
            sample=sample,
            model_output=model_output,
            generator=generator,
            per_token_timesteps=per_token_timesteps,
            noise_level=noise_level,
            prev_sample=prev_sample,
            sde_type=sde_type,
            return_logprobs=return_logprobs,
        )

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample, log_prob, prev_sample_mean, std_dev_t)
        return FlowMatchSDEDiscreteSchedulerOutput(
            prev_sample=prev_sample,
            log_prob=log_prob,
            prev_sample_mean=prev_sample_mean,
            std_dev_t=std_dev_t,
        )

    def sample_previous_step(
        self,
        sample: torch.Tensor,
        model_output: torch.Tensor,
        timestep: Optional[torch.FloatTensor] = None,
        generator: Optional[torch.Generator] = None,
        per_token_timesteps: Optional[torch.Tensor] = None,
        noise_level: float = 0.7,
        prev_sample: Optional[torch.Tensor] = None,
        sde_type: Literal["cps", "sde", "dance_sde"] = "sde",
        return_logprobs: bool = True,
        return_sqrt_dt: bool = False,
    ) -> tuple:
        """
        Run a single SDE / CPS reverse step.

        Args:
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process. Must be float32.
            model_output (`torch.FloatTensor`):
                The direct output from the learned diffusion model. Must be float32.
            timestep (`torch.FloatTensor`, *optional*):
                Iterable of current timesteps, one per batch element. When `None`, the internal
                `step_index` is used (sequential denoising loop).
            generator (`torch.Generator`, *optional*):
                A random number generator.
            per_token_timesteps (`torch.Tensor`, *optional*):
                The timesteps for each token in the sample. Currently not supported.
            noise_level (`float`, *optional*, defaults to 0.7):
                The noise level used in the SDE.
            prev_sample (`torch.FloatTensor`, *optional*):
                The sample from the previous timestep. If provided (float32), it is used directly for
                log-probability computation instead of being sampled.
            sde_type (`str`, *optional*, defaults to "sde"):
                The type of SDE to use. Choose between "sde", "cps", and "dance_sde".
            return_logprobs (`bool`, *optional*, defaults to True):
                Whether to return log probabilities of the previous sample.
            return_sqrt_dt (`bool`, *optional*, defaults to False):
                Whether to additionally return `sqrt(-dt)` as a tensor of shape `(batch_size,)`.
                Used by GRPO-Guard to compute the importance-ratio normalization
                (see `GRPOGuardLoss`).

        Returns:
            `(prev_sample, log_prob, prev_sample_mean, std_dev_t)`, with `sqrt_dt` appended when
            `return_sqrt_dt` is True. `log_prob` is averaged over all non-batch dimensions and is
            `None` when `return_logprobs` is False.

        Raises:
            ValueError: If `sde_type` is not one of the supported values.
            TypeError: If `sample`, `model_output` or a provided `prev_sample` is not float32.
            NotImplementedError: If `per_token_timesteps` is given.
        """
        # Explicit raises instead of `assert`, so the checks survive `python -O`.
        if sde_type not in ("sde", "cps", "dance_sde"):
            raise ValueError(f"Unsupported sde_type {sde_type!r}; expected one of 'sde', 'cps', 'dance_sde'.")
        float32_inputs = {"sample": sample}
        if prev_sample is not None:
            float32_inputs["prev_sample"] = prev_sample
        float32_inputs["model_output"] = model_output
        for arg_name, tensor in float32_inputs.items():
            if tensor.dtype != torch.float32:
                raise TypeError(f"`{arg_name}` must have dtype torch.float32, got {tensor.dtype}.")

        if per_token_timesteps is not None:
            raise NotImplementedError("per_token_timesteps is not supported yet for FlowMatchSDEDiscreteScheduler.")

        if timestep is None:
            # Sequential loop: scalar sigmas taken at the internal step index.
            sigma_idx = self.step_index
            sigma = self.sigmas[sigma_idx]
            sigma_prev = self.sigmas[sigma_idx + 1]
        else:
            # Arbitrary per-element timesteps: gather sigmas and reshape to (batch, 1, ..., 1) so
            # they broadcast against `sample`.
            sigma_idx = torch.tensor([self.index_for_timestep(t) for t in timestep])
            broadcast_shape = (-1,) + (1,) * (sample.ndim - 1)
            # `self.sigmas` may live on a different device than the sample; non-scalar tensors
            # cannot be mixed across devices, so move them to the working device.
            sigma = self.sigmas[sigma_idx].view(broadcast_shape).to(sample.device)
            sigma_prev = self.sigmas[sigma_idx + 1].view(broadcast_shape).to(sample.device)

        # Negative, because sigma decreases along the sampling trajectory.
        dt = sigma_prev - sigma

        if sde_type == "sde":
            # At sigma == 1 the 1 / (1 - sigma) factor diverges, so sigmas[1] is substituted there.
            sigma_max = self.sigmas[1].to(sigma.device)
            std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma))) * noise_level
            prev_sample_mean = (
                sample * (1 + std_dev_t**2 / (2 * sigma) * dt)
                + model_output * (1 + std_dev_t**2 * (1 - sigma) / (2 * sigma)) * dt
            )
            # The injected noise is additionally scaled by sqrt(-dt); `std_dev_t` is returned unscaled.
            noise_std = std_dev_t * torch.sqrt(-1 * dt)
        elif sde_type == "cps":
            std_dev_t = sigma_prev * math.sin(noise_level * math.pi / 2)
            pred_original_sample = sample - sigma * model_output
            noise_estimate = sample + model_output * (1 - sigma)
            prev_sample_mean = pred_original_sample * (1 - sigma_prev) + noise_estimate * torch.sqrt(
                sigma_prev**2 - std_dev_t**2
            )
            noise_std = std_dev_t
        else:  # sde_type == "dance_sde"
            # DanceGRPO SDE step from https://github.com/XueZeyue/DanceGRPO
            # Based on score-based SDE correction with eta (noise_level) controlling
            # the stochasticity. This formulation is numerically stable even when
            # sigma is close to 1, unlike the FlowGRPO "sde" variant.
            delta_t = sigma - sigma_prev  # positive
            # Predicted original sample: x_0 = x_t - sigma * model_output
            pred_original_sample = sample - sigma * model_output
            # Score-based SDE correction term:
            #   score_estimate = -(x_t - x_0 * (1 - sigma)) / sigma^2
            #   log_term = -0.5 * eta^2 * score_estimate
            score_estimate = -(sample - pred_original_sample * (1 - sigma)) / (sigma**2)
            log_term = -0.5 * noise_level**2 * score_estimate
            # ODE mean (x_t + dt * model_output) plus the correction term scaled by dt.
            prev_sample_mean = sample + dt * model_output
            prev_sample_mean = prev_sample_mean + log_term * dt
            # Noise standard deviation: eta * sqrt(delta_t)
            std_dev_t = noise_level * torch.sqrt(delta_t)
            noise_std = std_dev_t

        if prev_sample is None:
            prev_sample = self._draw_prev_sample(prev_sample_mean, noise_std, model_output, generator)

        if not return_logprobs:
            log_prob = None
        elif sde_type == "cps":
            # The CPS variant uses the negative squared error only (no normalisation terms).
            log_prob = -((prev_sample.detach() - prev_sample_mean) ** 2)
        else:
            log_prob = self._gaussian_log_prob(prev_sample, prev_sample_mean, noise_std)

        # mean along all but batch dimension
        if log_prob is not None:
            log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))

        if return_sqrt_dt:
            sqrt_dt = torch.sqrt(-1 * dt)
            if sqrt_dt.ndim == 0:
                # Scalar sigma (sequential loop): replicate once per batch element.
                sqrt_dt = sqrt_dt.expand(sample.shape[0]).clone()
            else:
                # Drop the trailing broadcast dimensions: (batch, 1, ..., 1) -> (batch,).
                sqrt_dt = sqrt_dt.reshape(sqrt_dt.shape[0])
            return prev_sample, log_prob, prev_sample_mean, std_dev_t, sqrt_dt
        return prev_sample, log_prob, prev_sample_mean, std_dev_t

    @staticmethod
    def _draw_prev_sample(
        mean: torch.Tensor,
        noise_std: torch.Tensor,
        like: torch.Tensor,
        generator: Optional[torch.Generator],
    ) -> torch.Tensor:
        """Return `mean + noise_std * eps`, with `eps` standard normal noise shaped/typed like `like`."""
        variance_noise = randn_tensor(
            like.shape,
            generator=generator,
            device=like.device,
            dtype=like.dtype,
        )
        return mean + noise_std * variance_noise

    @staticmethod
    def _gaussian_log_prob(value: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
        """Element-wise log-density of `value` under N(mean, std**2); `value` is detached first."""
        return (
            -((value.detach() - mean) ** 2) / (2 * (std**2))
            - torch.log(std)
            - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
        )