# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import dataclass
from typing import Literal, Optional

import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import BaseOutput
from diffusers.utils.torch_utils import randn_tensor


@dataclass
class FlowMatchSDEDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    The four fields mirror, in order, the tuple returned by `step(..., return_dict=False)`.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_channels)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop. When a `prev_sample` was passed into `step`, that tensor (upcast to float32) is returned
            here unchanged instead of a freshly drawn sample.
        log_prob (`torch.FloatTensor` of shape `(batch_size,)`, *optional*):
            The log probability of the previous sample, averaged over every non-batch dimension. `None` when
            `return_logprobs=False` was requested.
        prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_channels)` for images):
            The mean of the computed sample of previous timestep, i.e. the step result before noise is added.
        std_dev_t (`torch.FloatTensor`):
            The noise scale used to compute `prev_sample`. For the `"sde"` type the injected noise is additionally
            multiplied by `sqrt(-dt)`; for `"cps"` and `"dance_sde"` it is the standard deviation itself.
            NOTE(review): `step` derives this from a single entry of `self.sigmas`, so it looks like a scalar tensor
            there; a `(batch_size, 1, 1)`-style shape only arises when `sample_previous_step` is called with explicit
            per-sample timesteps -- confirm against callers.
    """

    # Field order is part of the interface: it matches the positional tuple returned by `step`.
    prev_sample: torch.FloatTensor
    log_prob: Optional[torch.FloatTensor]
    prev_sample_mean: torch.FloatTensor
    std_dev_t: torch.FloatTensor


class FlowMatchSDEDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
    """SDE version of the FlowMatchEulerDiscreteScheduler.
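    The implementation is based on the FlowGRPO paper (https://arxiv.org/abs/2505.05470)
    and the diffusers v0.37 branch. In addition to the FlowGRPO "sde" step, the "cps" and
    "dance_sde" variants can be selected through the `sde_type` argument.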
    """

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: float | torch.FloatTensor,
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        per_token_timesteps: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        noise_level: float = 0.7,
        prev_sample: Optional[torch.FloatTensor] = None,
        sde_type: Literal["sde", "cps", "dance_sde"] = "sde",
        return_logprobs: bool = True,
    ) -> FlowMatchSDEDiscreteSchedulerOutput | tuple:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        The sigma pair used for the step is taken from the scheduler's internal `step_index`, which is advanced by
        one after every call. All tensor inputs are upcast to float32 before the update is computed, and the returned
        tensors are float32 as well.

        Modified from https://github.com/yifan123/flow_grpo/blob/main/flow_grpo/diffusers_patch/sd3_sde_with_logprob.py

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model. May be in reduced precision; it is upcast here.
            timestep (`float` or `torch.FloatTensor`):
                The current discrete timestep in the diffusion chain. It is only used to initialise the internal
                step index on the first call; integer indices are rejected.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`):
                Unused by this method; kept so the signature matches the parent scheduler's `step`.
            s_tmin  (`float`):
                Unused by this method; kept for signature compatibility.
            s_tmax  (`float`):
                Unused by this method; kept for signature compatibility.
            s_noise (`float`, defaults to 1.0):
                Unused by this method; kept for signature compatibility. The amount of injected noise is controlled
                by `noise_level` instead.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            per_token_timesteps (`torch.Tensor`, *optional*):
                The timesteps for each token in the sample. Not supported: any non-`None` value makes
                `sample_previous_step` raise `NotImplementedError`.
            return_dict (`bool`):
                Whether or not to return a
                [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchSDEDiscreteSchedulerOutput`] or tuple.
            noise_level (`float`, *optional*, defaults to 0.7):
                The noise level used in the SDE.
            prev_sample (`torch.FloatTensor`, *optional*):
                The sample from the previous timestep. If not provided, it will be sampled inside the function.
            sde_type (`str`, *optional*, defaults to "sde"):
                The type of SDE to use. Choose between "sde", "cps", and "dance_sde".
            return_logprobs (`bool`, *optional*, defaults to True):
                Whether to return log probabilities of the previous sample.

        Returns:
            [`FlowMatchSDEDiscreteSchedulerOutput`] if `return_dict` is `True`, otherwise the tuple
            `(prev_sample, log_prob, prev_sample_mean, std_dev_t)`. `log_prob` is `None` when
            `return_logprobs` is `False`.

        Raises:
            ValueError: If `timestep` is an integer or an integer tensor.
        """

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample. `model_output` must be upcast as
        # well: `sample_previous_step` asserts that every tensor it receives is float32, so a bf16/fp16
        # model output would otherwise fail that check.
        sample = sample.to(torch.float32)
        model_output = model_output.to(torch.float32)
        if prev_sample is not None:
            prev_sample = prev_sample.to(torch.float32)

        # `timestep` is intentionally not forwarded: the sequential loop relies on `self.step_index`.
        prev_sample, log_prob, prev_sample_mean, std_dev_t = self.sample_previous_step(
            sample=sample,
            model_output=model_output,
            generator=generator,
            per_token_timesteps=per_token_timesteps,
            noise_level=noise_level,
            prev_sample=prev_sample,
            sde_type=sde_type,
            return_logprobs=return_logprobs,
        )

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample, log_prob, prev_sample_mean, std_dev_t)

        return FlowMatchSDEDiscreteSchedulerOutput(
            prev_sample=prev_sample, log_prob=log_prob, prev_sample_mean=prev_sample_mean, std_dev_t=std_dev_t
        )

    def sample_previous_step(
        self,
        sample: torch.Tensor,
        model_output: torch.Tensor,
        timestep: Optional[torch.FloatTensor] = None,
        generator: Optional[torch.Generator] = None,
        per_token_timesteps: Optional[torch.Tensor] = None,
        noise_level: float = 0.7,
        prev_sample: Optional[torch.Tensor] = None,
        sde_type: Literal["cps", "sde", "dance_sde"] = "sde",
        return_logprobs: bool = True,
        return_sqrt_dt: bool = False,
    ):
        """
        Run a single SDE / CPS reverse step.

        Every `sde_type` produces a step mean and a noise standard deviation; the next sample is
        `mean + noise_std * eps` with `eps ~ N(0, I)` unless `prev_sample` is supplied, in which case the
        supplied tensor is only scored. This method does not advance `step_index`.

        Args:
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process. Must be float32.
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model. Must be float32.
            timestep (`torch.FloatTensor`, *optional*):
                The current discrete timestep in the diffusion chain, one entry per batch element. When `None`,
                the internal `step_index` is used (sequential denoising loop).
            generator (`torch.Generator`, *optional*):
                A random number generator.
            per_token_timesteps (`torch.Tensor`, *optional*):
                The timesteps for each token in the sample. Currently not supported.
            noise_level (`float`, *optional*, defaults to 0.7):
                The noise level used in the SDE.
            prev_sample (`torch.FloatTensor`, *optional*):
                The sample from the previous timestep. If provided, it is used directly for
                log-probability computation instead of being sampled. Must be float32.
            sde_type (`str`, *optional*, defaults to "sde"):
                The type of SDE to use. Choose between "sde", "cps", and "dance_sde".
            return_logprobs (`bool`, *optional*, defaults to True):
                Whether to return log probabilities of the previous sample.
            return_sqrt_dt (`bool`, *optional*, defaults to False):
                Whether to additionally return `sqrt(-dt)` as a tensor of shape `(batch_size,)`.
                Used by GRPO-Guard to compute the importance-ratio normalization
                (see `GRPOGuardLoss`).

        Returns:
            `(prev_sample, log_prob, prev_sample_mean, std_dev_t)`, with `sqrt_dt` appended as a fifth element
            when `return_sqrt_dt` is `True`. `log_prob` is averaged over all non-batch dimensions and is `None`
            when `return_logprobs` is `False`. For "cps" the log-probability is the negative squared error only
            (no variance normalisation).

        Raises:
            NotImplementedError: If `per_token_timesteps` is given.
            ValueError: If `sde_type` is unknown and the assertions have been stripped (`python -O`).
        """
        assert sde_type in ["sde", "cps", "dance_sde"]
        assert sample.dtype == torch.float32
        if prev_sample is not None:
            assert prev_sample.dtype == torch.float32
        assert model_output.dtype == torch.float32

        if per_token_timesteps is not None:
            raise NotImplementedError("per_token_timesteps is not supported yet for FlowMatchSDEDiscreteScheduler.")

        if timestep is None:
            # Sequential loop: a single sigma pair shared by the whole batch.
            sigma_idx = self.step_index
            sigma = self.sigmas[sigma_idx]
            sigma_prev = self.sigmas[sigma_idx + 1]
        else:
            # One sigma pair per batch element, reshaped to broadcast against `sample`.
            sigma_idx = torch.tensor([self.index_for_timestep(t) for t in timestep])
            broadcast_shape = (-1, *([1] * (sample.ndim - 1)))
            # The gathered sigmas live on the scheduler's device. Unlike the 0-dim case above, a
            # multi-element tensor cannot be mixed with tensors on another device, so move it.
            sigma = self.sigmas[sigma_idx].to(sample.device).view(broadcast_shape)
            sigma_prev = self.sigmas[sigma_idx + 1].to(sample.device).view(broadcast_shape)

        # sigma decreases along the schedule, so dt is negative.
        dt = sigma_prev - sigma

        # Each branch defines: prev_sample_mean, std_dev_t (returned to the caller), noise_std (the
        # standard deviation of the noise actually added) and whether the log-prob is a full Gaussian.
        if sde_type == "sde":
            # sigma == 1 would divide by zero below; substitute the second sigma of the schedule there.
            sigma_max = self.sigmas[1].to(sigma.device)
            std_dev_t = torch.sqrt(sigma / (1 - torch.where(sigma == 1, sigma_max, sigma))) * noise_level

            prev_sample_mean = (
                sample * (1 + std_dev_t**2 / (2 * sigma) * dt)
                + model_output * (1 + std_dev_t**2 * (1 - sigma) / (2 * sigma)) * dt
            )
            noise_std = std_dev_t * torch.sqrt(-1 * dt)
            gaussian_log_prob = True

        elif sde_type == "cps":
            std_dev_t = sigma_prev * math.sin(noise_level * math.pi / 2)
            pred_original_sample = sample - sigma * model_output
            noise_estimate = sample + model_output * (1 - sigma)
            prev_sample_mean = pred_original_sample * (1 - sigma_prev) + noise_estimate * torch.sqrt(
                sigma_prev**2 - std_dev_t**2
            )
            noise_std = std_dev_t
            gaussian_log_prob = False

        elif sde_type == "dance_sde":
            # DanceGRPO SDE step from https://github.com/XueZeyue/DanceGRPO
            # Based on score-based SDE correction with eta (noise_level) controlling
            # the stochasticity. This formulation is numerically stable even when
            # sigma is close to 1, unlike the FlowGRPO "sde" variant.
            delta_t = sigma - sigma_prev  # positive

            # Predicted original sample: x_0 = x_t - sigma * model_output
            pred_original_sample = sample - sigma * model_output

            # Score-based SDE correction term
            # score_estimate = -(x_t - x_0 * (1 - sigma)) / sigma^2
            # log_term = -0.5 * eta^2 * score_estimate
            score_estimate = -(sample - pred_original_sample * (1 - sigma)) / (sigma**2)
            log_term = -0.5 * noise_level**2 * score_estimate

            # ODE mean x_t + dt * model_output, plus the score correction scaled by dt.
            prev_sample_mean = sample + dt * model_output
            prev_sample_mean = prev_sample_mean + log_term * dt

            # Noise standard deviation: eta * sqrt(delta_t)
            std_dev_t = noise_level * torch.sqrt(delta_t)
            noise_std = std_dev_t
            gaussian_log_prob = True

        else:
            # Only reachable when assertions are disabled; fail clearly instead of with an
            # UnboundLocalError further down.
            raise ValueError(f"Unsupported sde_type: {sde_type!r}. Choose between 'sde', 'cps' and 'dance_sde'.")

        if prev_sample is None:
            variance_noise = randn_tensor(
                model_output.shape,
                generator=generator,
                device=model_output.device,
                dtype=model_output.dtype,
            )
            prev_sample = prev_sample_mean + noise_std * variance_noise

        log_prob = None
        if return_logprobs:
            # prev_sample is detached so gradients only flow through the predicted mean.
            squared_error = (prev_sample.detach() - prev_sample_mean) ** 2
            if gaussian_log_prob:
                log_prob = (
                    -squared_error / (2 * (noise_std**2))
                    - torch.log(noise_std)
                    - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
                )
            else:
                log_prob = -squared_error
            # mean along all but batch dimension
            log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))

        if return_sqrt_dt:
            sqrt_dt = torch.sqrt(-1 * dt)
            if sqrt_dt.ndim == 0:
                # Shared scalar step size: replicate it for every batch element.
                sqrt_dt = sqrt_dt.expand(sample.shape[0]).clone()
            else:
                # Drop the trailing broadcast dimensions.
                sqrt_dt = sqrt_dt.reshape(sqrt_dt.shape[0])
            return prev_sample, log_prob, prev_sample_mean, std_dev_t, sqrt_dt
        return prev_sample, log_prob, prev_sample_mean, std_dev_t
