# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Qwen-Image training-side adapter for diffusers-based diffusion RL.
"""
from typing import Optional
import numpy as np
import torch
from diffusers.models.transformers.transformer_qwenimage import QwenImageTransformer2DModel
from diffusers.pipelines.qwenimage.pipeline_qwenimage import calculate_shift
from tensordict import TensorDict
from verl.utils import tensordict_utils as tu
from verl.utils.device import get_device_name
from verl_omni.pipelines.model_base import DiffusionModelBase
from verl_omni.pipelines.schedulers import FlowMatchSDEDiscreteScheduler
from verl_omni.workers.config import DiffusionModelConfig
from .common import QWEN_IMAGE_VAE_SCALE_FACTOR, apply_true_cfg, build_img_shapes
__all__ = ["QwenImage"]
def _build_qwen_image_scheduler(model_path: str) -> FlowMatchSDEDiscreteScheduler:
    """Load the Qwen-Image flow-matching SDE scheduler from a pretrained checkpoint.

    Args:
        model_path: Value forwarded as ``pretrained_model_name_or_path``; the
            scheduler configuration is read from its ``"scheduler"`` subfolder.

    Returns:
        FlowMatchSDEDiscreteScheduler: The loaded scheduler. This helper does not
        configure timesteps; callers do that separately.
    """
    loader_kwargs = {
        "pretrained_model_name_or_path": model_path,
        "subfolder": "scheduler",
    }
    return FlowMatchSDEDiscreteScheduler.from_pretrained(**loader_kwargs)


def _configure_qwen_image_scheduler(
    scheduler: FlowMatchSDEDiscreteScheduler,
    *,
    height: int,
    width: int,
    num_inference_steps: int,
    device: str,
) -> None:
    """Set resolution-dependent timesteps on ``scheduler`` for a Qwen-Image render.

    The sigma schedule is a linear ramp from ``1.0`` down to
    ``1 / num_inference_steps``. The shift ``mu`` is derived with diffusers'
    ``calculate_shift`` from the packed latent sequence length and the
    scheduler config, falling back to fixed defaults for missing keys.

    Args:
        scheduler: Scheduler whose ``set_timesteps`` is called in place.
        height: Output image height in pixels.
        width: Output image width in pixels.
        num_inference_steps: Number of denoising steps.
        device: Device forwarded to ``scheduler.set_timesteps``.
    """
    vae_factor = QWEN_IMAGE_VAE_SCALE_FACTOR
    # Divide by the VAE factor, then by 2 again; the extra // 2 presumably
    # accounts for 2x2 patch packing of the latents (confirm against common.py).
    rows = height // vae_factor // 2
    cols = width // vae_factor // 2
    image_seq_len = rows * cols

    final_sigma = 1 / num_inference_steps
    sigma_schedule = np.linspace(1.0, final_sigma, num_inference_steps)

    cfg = scheduler.config
    shift_params = (
        cfg.get("base_image_seq_len", 256),
        cfg.get("max_image_seq_len", 4096),
        cfg.get("base_shift", 0.5),
        cfg.get("max_shift", 1.15),
    )
    shift_mu = calculate_shift(image_seq_len, *shift_params)

    scheduler.set_timesteps(
        num_inference_steps,
        device=device,
        sigmas=sigma_schedule,
        mu=shift_mu,
    )


@DiffusionModelBase.register("QwenImagePipeline", algorithm="flow_grpo")
class QwenImage(DiffusionModelBase):
    """Training adapter for the Qwen-Image diffusion model.

    Implements the :class:`~verl_omni.pipelines.model_base.DiffusionModelBase`
    interface for the ``QwenImagePipeline`` architecture. It provides scheduler
    construction and timestep configuration, plus the forward/sampling step used
    during RL training (e.g. FlowGRPO).

    Registered via ``DiffusionModelBase.register`` under ``"QwenImagePipeline"``
    with ``algorithm="flow_grpo"``. The registry presumably selects this adapter
    when the configured architecture matches that name.
    """

    @classmethod
    def build_scheduler(cls, model_config: DiffusionModelConfig):
        """Build the SDE scheduler for Qwen-Image and configure its timesteps.

        Args:
            model_config (DiffusionModelConfig): Diffusion model configuration.
                ``local_path`` locates the pretrained scheduler config. The
                ``pipeline`` section supplies the resolution and step count used
                by :meth:`set_timesteps`.

        Returns:
            FlowMatchSDEDiscreteScheduler: Scheduler with timesteps already set
            on the device reported by ``get_device_name()``.
        """
        scheduler = _build_qwen_image_scheduler(model_config.local_path)
        cls.set_timesteps(scheduler, model_config, get_device_name())
        return scheduler

    @classmethod
    def set_timesteps(cls, scheduler: FlowMatchSDEDiscreteScheduler, model_config: DiffusionModelConfig, device: str):
        """Configure resolution-dependent timesteps and sigmas on ``scheduler``.

        Args:
            scheduler (FlowMatchSDEDiscreteScheduler): Scheduler updated in place.
            model_config (DiffusionModelConfig): Reads ``pipeline.height``,
                ``pipeline.width`` and ``pipeline.num_inference_steps``.
            device (str): Device (e.g. ``"cuda"``) the timesteps are placed on.
        """
        pipeline_cfg = model_config.pipeline
        _configure_qwen_image_scheduler(
            scheduler,
            height=pipeline_cfg.height,
            width=pipeline_cfg.width,
            num_inference_steps=pipeline_cfg.num_inference_steps,
            device=device,
        )

    @classmethod
    def forward_and_sample_previous_step(
        cls,
        module: QwenImageTransformer2DModel,
        scheduler: FlowMatchSDEDiscreteScheduler,
        model_config: DiffusionModelConfig,
        model_inputs: dict[str, torch.Tensor],
        negative_model_inputs: Optional[dict[str, torch.Tensor]],
        scheduler_inputs: Optional[TensorDict | dict[str, torch.Tensor]],
        step: int,
    ):
        """Run the Qwen-Image transformer and score the recorded previous step.

        Used by RL algorithms (e.g. FlowGRPO) that need log-probabilities of the
        sampled denoising trajectory. Applies True-CFG guidance when
        ``model_config.pipeline.true_cfg_scale > 1.0``.

        Args:
            module (QwenImageTransformer2DModel): The Qwen-Image transformer module.
            scheduler (FlowMatchSDEDiscreteScheduler): Scheduler used to compute
                the previous-step distribution and its log-probability.
            model_config (DiffusionModelConfig): Provides
                ``pipeline.true_cfg_scale``, ``algo.noise_level`` and
                ``algo.sde_type``.
            model_inputs (dict[str, torch.Tensor]): Positive-prompt inputs for
                the transformer forward pass.
            negative_model_inputs (Optional[dict[str, torch.Tensor]]): Negative-prompt
                inputs used for True-CFG. May be ``None`` only when True-CFG is
                disabled.
            scheduler_inputs (Optional[TensorDict | dict[str, torch.Tensor]]): Must
                contain ``"all_latents"`` and ``"all_timesteps"``, both indexed by
                step along dim 1. ``all_latents`` must also hold entry ``step + 1``.
            step (int): Current denoising step index.

        Returns:
            tuple: A 4-tuple ``(log_prob, prev_sample_mean, std_dev_t, sqrt_dt)``
            as produced by ``scheduler.sample_previous_step``.

        Raises:
            ValueError: If ``scheduler_inputs`` is ``None``, or if True-CFG is
                enabled but ``negative_model_inputs`` is ``None``.
        """
        # Explicit checks (not ``assert``) so validation survives ``python -O``;
        # both are done before the forward pass to avoid wasted compute.
        if scheduler_inputs is None:
            raise ValueError(
                "scheduler_inputs is required and must contain 'all_latents' and 'all_timesteps'"
            )
        true_cfg_scale = model_config.pipeline.true_cfg_scale
        use_true_cfg = true_cfg_scale > 1.0
        if use_true_cfg and negative_model_inputs is None:
            raise ValueError(
                f"negative_model_inputs is required when true_cfg_scale > 1.0 (got {true_cfg_scale})"
            )

        # Presumably shaped (batch, num_steps + 1, ...) for latents and
        # (batch, num_steps, ...) for timesteps -- TODO confirm with the rollout side.
        latents = scheduler_inputs["all_latents"]
        timesteps = scheduler_inputs["all_timesteps"]

        noise_pred = cls.forward(module, model_config, model_inputs)
        if use_true_cfg:
            neg_noise_pred = cls.forward(module, model_config, negative_model_inputs)
            noise_pred = apply_true_cfg(noise_pred, neg_noise_pred, true_cfg_scale)

        # The recorded next latent is passed as ``prev_sample``, so the scheduler
        # scores that rollout sample instead of drawing a new one. Its first return
        # value (the sample) is therefore discarded. Inputs are upcast to fp32 for
        # the log-prob computation.
        _, log_prob, prev_sample_mean, std_dev_t, sqrt_dt = scheduler.sample_previous_step(
            sample=latents[:, step].float(),
            model_output=noise_pred.float(),
            timestep=timesteps[:, step],
            noise_level=model_config.algo.noise_level,
            prev_sample=latents[:, step + 1].float(),
            sde_type=model_config.algo.sde_type,
            return_logprobs=True,
            return_sqrt_dt=True,
        )
        return log_prob, prev_sample_mean, std_dev_t, sqrt_dt