# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import random
from typing import Any, Optional
import hydra
import numpy as np
import ray
import torch
import torch.nn.functional as F
from omegaconf import DictConfig, OmegaConf
from pydantic import BaseModel, ConfigDict
from tensordict import TensorDict
from verl.base_config import BaseConfig
from verl.experimental.agent_loop.agent_loop import (
AgentLoopMetrics,
DictConfigWrap,
_agent_loop_registry,
)
from verl.experimental.agent_loop.utils import resolve_config_path
from verl.protocol import DataProto
from verl.utils.config import omega_conf_to_dataclass
from verl.utils.dataset.rl_dataset import get_dataset_class
from verl.utils.profiler import simple_timer
from verl.workers.rollout.llm_server import LLMServerClient
from verl_omni.agent_loop.utils import maybe_per_rollout_seeds
from verl_omni.workers.config import DiffusionModelConfig, DiffusionRolloutConfig
def _config_to_sampling_dict(config: Optional[BaseConfig]) -> dict:
if config is None:
return {}
return {k: v for k, v in config.items() if not k.startswith("_")}
class DiffusionAgentLoopOutput(BaseModel):
    """Output of a single diffusion agent loop run (one sample, not yet padded)."""

    # Tensors are stored in fields typed ``Any``; allow non-pydantic types.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    prompt_ids: list[int]
    """Prompt token ids."""
    response_diffusion_output: Any
    """Response diffusion output (torch.Tensor): image tensor (CHW) / video tensor (TCHW)."""
    response_logprobs: Optional[Any] = None
    """Log probabilities over the denoising timesteps (torch.Tensor), if computed."""
    reward_score: Optional[float] = None
    """Reward score for the trajectory. When left as ``None``, the worker may compute it via
    the reward loop workers."""
    num_turns: int = 0
    """Number of chat turns, including user, assistant, tool."""
    metrics: AgentLoopMetrics
    """Auxiliary performance metrics."""
    extra_fields: dict[str, Any] = {}
    """Extra fields for dynamic addition."""
class _InternalDiffusionAgentLoopOutput(DiffusionAgentLoopOutput):
    """Per-sample agent loop output after padding, ready to be concatenated into a batch."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    prompt_ids: torch.Tensor
    """Prompt token ids, padded to the rollout prompt length."""
    response_diffusion_output: torch.Tensor
    """Generated media with a leading batch axis: image (NCHW) or video (NTCHW)."""
    response_logprobs: torch.Tensor | None = None
    """Log probabilities collected over the denoising timesteps, when available."""
    extra_fields: dict[str, Any] = {}
    """Additional per-sample fields attached at runtime."""
class DiffusionAgentLoopWorker:
    """Diffusion agent loop worker: runs each sample of a batch through an agent loop.

    Args:
        config (DictConfig): whole config for main entrypoint.
        llm_client (LLMServerClient): Client for the LLM server replicas, produced by
            ``LLMServerManager.get_client()`` in the trainer.
        teacher_client (dict[str, LLMServerClient]): Not used by diffusion training; accepted to
            keep the constructor signature compatible with verl's ``AgentLoopManager.create()``,
            which positionally forwards a teacher client argument to each worker.
        reward_loop_worker_handles (list[ray.actor.ActorHandle] | None): Actor handles for
            streaming reward computation. When ``None``, this worker computes no rewards and
            only scores already set by the agent loop are reported.
    """

    def __init__(
        self,
        config: DictConfig,
        llm_client: LLMServerClient,
        teacher_client: dict[str, LLMServerClient] | None = None,
        reward_loop_worker_handles: list[ray.actor.ActorHandle] | None = None,
    ):
        """Resolve rollout/model configs, tokenizer/processor, and register agent loops.

        ``teacher_client`` is accepted for signature compatibility only and is ignored.
        """
        self.config = config
        rollout_config = config.actor_rollout_ref.rollout
        model_config = config.actor_rollout_ref.model
        self.rollout_config: DiffusionRolloutConfig = omega_conf_to_dataclass(rollout_config)
        self.model_config: DiffusionModelConfig = omega_conf_to_dataclass(model_config)

        # Keep a server_manager that is already set (presumably by a subclass) instead of
        # overwriting it with ``llm_client``.
        if not hasattr(self, "server_manager"):
            self.server_manager = llm_client

        self.dataset_cls = get_dataset_class(config.data)
        self.reward_loop_worker_handles = reward_loop_worker_handles
        self.tokenizer = self.model_config.tokenizer
        self.processor = self.model_config.processor
        # Target length for padding prompt embeddings and their masks in _agent_loop_postprocess.
        self.max_prompt_embed_length = self.rollout_config.pipeline.max_sequence_length

        # Register additional agent loops, keyed by their ``name``, from an optional config file.
        agent_loop_config_path = self.rollout_config.agent.agent_loop_config_path
        if agent_loop_config_path:
            resolved_path = resolve_config_path(agent_loop_config_path)
            agent_loop_configs = OmegaConf.load(resolved_path)
            for agent_loop_config in agent_loop_configs:
                _agent_loop_registry[agent_loop_config.name] = agent_loop_config

        # Apply a user-supplied chat template to both processor (if any) and tokenizer.
        if self.model_config.get("custom_chat_template", None) is not None:
            if self.model_config.processor is not None:
                self.model_config.processor.chat_template = self.model_config.custom_chat_template
            self.model_config.tokenizer.chat_template = self.model_config.custom_chat_template

    async def generate_sequences(self, batch: DataProto) -> DataProto:
        """Generate sequences from agent loop.

        Args:
            batch (DataProto): Input batch. If it has no ``agent_name`` column, one filled with
                the configured default agent loop is added to ``batch.non_tensor_batch``.

        Returns:
            DataProto: Output batch with the following fields.

            - ``prompts``: ``[bsz, prompt_length]`` prompt token ids from dataset.
            - ``responses``: diffusion output, typically ``[bsz, C, H, W]`` (image)
              or ``[bsz, T, C, H, W]`` (video).
            - ``rm_scores`` (optional): ``[bsz, 1]`` reward model scores.
            - ``meta_info``:
                - ``metrics``: ``List[dict]``, per-sample agent loop metrics.
                - ``reward_extra_keys`` (optional): ``List[str]``, keys for reward
                  extra info for logging/validation.
        """
        config = self.rollout_config
        # Pipeline and algorithm settings form the base sampling parameters.
        sampling_params = {
            **_config_to_sampling_dict(config.pipeline),
            **_config_to_sampling_dict(config.algo),
            "logprobs": config.calculate_log_probs,
        }

        is_validate = batch.meta_info.get("validate", False)
        per_rollout_seeds: Optional[list[int]] = None
        if is_validate:
            # Validation overrides: val_kwargs settings, a fixed seed, and no logprobs.
            sampling_params.update(_config_to_sampling_dict(config.val_kwargs.pipeline))
            sampling_params.update(_config_to_sampling_dict(config.val_kwargs.algo))
            sampling_params["seed"] = config.val_kwargs.seed
            sampling_params["logprobs"] = False
        else:
            sampling_params["global_steps"] = batch.meta_info["global_steps"]
            global_indices = batch.non_tensor_batch.get("_rollout_seed_global_idx")
            per_rollout_seeds = maybe_per_rollout_seeds(batch.meta_info, len(batch), global_indices)

        if "agent_name" not in batch.non_tensor_batch:
            default_agent_loop = config.agent.default_agent_loop
            batch.non_tensor_batch["agent_name"] = np.array([default_agent_loop] * len(batch), dtype=object)

        # One concurrent task per sample; each gets its own copy of the sampling params so a
        # per-rollout seed does not leak into other samples.
        tasks = []
        for i in range(len(batch)):
            kwargs = {k: v[i] for k, v in batch.non_tensor_batch.items()}
            task_sampling_params = sampling_params.copy()
            if per_rollout_seeds is not None:
                task_sampling_params["seed"] = per_rollout_seeds[i]
            tasks.append(asyncio.create_task(self._run_agent_loop(task_sampling_params, **kwargs)))
        outputs = await asyncio.gather(*tasks)

        output = self._postprocess(outputs, input_non_tensor_batch=batch.non_tensor_batch)
        return output

    async def _run_agent_loop(
        self,
        sampling_params: dict[str, Any],
        *,
        agent_name: str,
        **kwargs,
    ) -> _InternalDiffusionAgentLoopOutput:
        """Instantiate the agent loop registered as ``agent_name`` and run it on one sample.

        ``kwargs`` are the sample's non-tensor fields. They are forwarded to the agent loop's
        ``run`` and then to post-processing.
        """
        assert agent_name in _agent_loop_registry, (
            f"Agent loop {agent_name} not registered, registered agent loops: {_agent_loop_registry.keys()}"
        )

        agent_loop_config = _agent_loop_registry[agent_name]
        agent_loop = hydra.utils.instantiate(
            config=agent_loop_config,
            trainer_config=DictConfigWrap(config=self.config),
            server_manager=self.server_manager,
            tokenizer=self.tokenizer,
            processor=self.processor,
            dataset_cls=self.dataset_cls,
            data_config=DictConfigWrap(self.config.data),
        )
        output: DiffusionAgentLoopOutput = await agent_loop.run(sampling_params, **kwargs)
        return await self._agent_loop_postprocess(output, **kwargs)

    async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalDiffusionAgentLoopOutput:
        """Perform post-processing operations on the output of each individual agent loop.

        Pads prompt embeddings/masks and prompt ids to fixed lengths, gives every tensor a
        leading batch dim of 1, and computes the reward if needed. Requires ``raw_prompt`` in
        ``kwargs``.
        """
        # Pad extra tensor outputs from vllm-omni (e.g. prompt embeddings).
        extra_fields = {}
        for k, v in output.extra_fields.items():
            if isinstance(v, torch.Tensor):
                # NOTE(review): a negative pad amount makes F.pad crop, so embeddings longer than
                # max_prompt_embed_length are truncated rather than rejected.
                if k in ["prompt_embeds", "negative_prompt_embeds"]:
                    # Pad dim 0 (sequence) at the end; the last dim is left untouched.
                    pad_tuple = (0, 0, 0, self.max_prompt_embed_length - v.shape[0])
                    v = F.pad(v, pad_tuple, value=0)
                elif k in ["prompt_embeds_mask", "negative_prompt_embeds_mask"]:
                    pad_tuple = (0, self.max_prompt_embed_length - v.shape[0])
                    v = F.pad(v, pad_tuple, value=0)
                extra_fields[k] = v.unsqueeze(0)
            else:
                extra_fields[k] = v
        extra_fields["raw_prompt"] = kwargs["raw_prompt"]

        prompt_output = self.tokenizer.pad(
            {"input_ids": output.prompt_ids},
            padding="max_length",
            max_length=self.rollout_config.prompt_length,
            return_tensors="pt",
            return_attention_mask=True,
        )
        # A flat id list pads to 1-D tensors; add the batch dim expected downstream.
        if prompt_output["input_ids"].dim() == 1:
            prompt_output["input_ids"] = prompt_output["input_ids"].unsqueeze(0)
            prompt_output["attention_mask"] = prompt_output["attention_mask"].unsqueeze(0)

        response_diffusion_output = output.response_diffusion_output.unsqueeze(0)
        response_logprobs = None
        if output.response_logprobs is not None:
            response_logprobs = output.response_logprobs.unsqueeze(0)
        prompt_ids = prompt_output["input_ids"]
        extra_fields["attention_mask"] = prompt_output["attention_mask"]

        await self._compute_score(
            output,
            prompts=prompt_ids,
            responses=response_diffusion_output,
            kwargs=kwargs,
        )
        # _compute_score may have added reward_extra_info to output.extra_fields.
        if "reward_extra_info" in output.extra_fields:
            extra_fields["reward_extra_info"] = output.extra_fields["reward_extra_info"]

        return _InternalDiffusionAgentLoopOutput(
            prompt_ids=prompt_ids,
            response_diffusion_output=response_diffusion_output,
            response_logprobs=response_logprobs,
            reward_score=output.reward_score,
            num_turns=output.num_turns,
            metrics=output.metrics,
            extra_fields=extra_fields,
        )

    async def _compute_score(self, output, prompts, responses, kwargs):
        """Compute reward score for a single sample via a reward loop worker, if needed.

        No-op when ``output.reward_score`` is already set or no reward loop worker handles
        were given. Otherwise mutates ``output`` in place: sets ``reward_score``,
        ``extra_fields["reward_extra_info"]`` and ``metrics.compute_score``.
        """
        enable_async_reward = self.reward_loop_worker_handles is not None
        if output.reward_score is None and enable_async_reward:
            timing = {}
            with simple_timer("compute_score", timing):
                batch = TensorDict(
                    {
                        "prompts": prompts,  # [1, prompt_length]
                        "responses": responses,  # [1, C, H, W] or [1, T, C, H, W]
                    },
                    batch_size=1,
                )
                # NOTE(review): np.array([v]) turns list/array values into a 2-D array rather
                # than a (1,) object array. Kept as-is; verify what reward workers expect
                # before changing.
                non_tensor_batch = {
                    **{k: np.array([v]) for k, v in kwargs.items()},
                    "__num_turns__": np.array([output.num_turns]),
                    "tool_extra_fields": np.array([output.extra_fields], dtype=object),
                }
                data = DataProto(
                    batch=batch,
                    non_tensor_batch=non_tensor_batch,
                )
                # Pick one reward worker at random for this sample.
                selected_reward_loop_worker_handle = random.choice(self.reward_loop_worker_handles)
                result = await selected_reward_loop_worker_handle.compute_score.remote(data)
                output.reward_score = result["reward_score"]
                output.extra_fields["reward_extra_info"] = result["reward_extra_info"]
            # Read after the ``with`` exits, presumably when simple_timer records the elapsed time.
            output.metrics.compute_score = timing["compute_score"]

    def _postprocess(
        self,
        inputs: list[_InternalDiffusionAgentLoopOutput],
        input_non_tensor_batch: dict | None = None,
    ) -> DataProto:
        """Process the padded outputs from _run_agent_loop and combine them into a batch.

        Tensor-valued extra fields (keys taken from the first sample) are concatenated into the
        tensor batch and deleted from every input's ``extra_fields``, so ``inputs`` is mutated.
        The remaining extra fields become 1-D object arrays in ``non_tensor_batch``.
        """
        # Each per-sample tensor has a leading dim of 1, so cat along dim 0 builds the batch.
        prompt_ids = torch.cat([item.prompt_ids for item in inputs], dim=0)
        response_diffusion_output = torch.cat([item.response_diffusion_output for item in inputs], dim=0)

        optional_outputs = {}
        if inputs[0].response_logprobs is not None:
            optional_outputs["rollout_log_probs"] = torch.cat([item.response_logprobs for item in inputs], dim=0)

        # Move tensor extra fields (e.g. padded prompt embeddings, attention_mask) into the batch.
        extra_keys = [k for k, v in inputs[0].extra_fields.items() if isinstance(v, torch.Tensor)]
        for key in extra_keys:
            optional_outputs[key] = torch.cat([item.extra_fields[key] for item in inputs], dim=0)
            for item in inputs:
                del item.extra_fields[key]

        batch = TensorDict(
            {
                "prompts": prompt_ids,  # [bsz, prompt_length]
                "responses": response_diffusion_output,  # [bsz, C, H, W] or [bsz, T, C, H, W]
                **optional_outputs,
            },
            batch_size=len(inputs),
        )

        # Emit rm_scores only when every sample has a score.
        scores = [item.reward_score for item in inputs]
        if all(score is not None for score in scores):
            batch["rm_scores"] = torch.tensor(scores, dtype=torch.float32).unsqueeze(-1)

        non_tensor_batch = {
            "__num_turns__": np.array([item.num_turns for item in inputs], dtype=np.int32),
        }
        if input_non_tensor_batch:
            non_tensor_batch.update(input_non_tensor_batch)

        # Add reward_extra_info entries (keys taken from the first sample) as columns.
        reward_extra_infos = [item.extra_fields.get("reward_extra_info", {}) for item in inputs]
        reward_extra_keys = list(reward_extra_infos[0].keys())
        for key in reward_extra_keys:
            non_tensor_batch[key] = np.array([info[key] for info in reward_extra_infos])

        metrics = [item.metrics.model_dump() for item in inputs]

        # Collect remaining extra fields as 1-D object arrays; samples missing a key get None.
        # dict.fromkeys keeps a deterministic first-seen key order (a set would not).
        extra_fields = {}
        all_keys = dict.fromkeys(key for item in inputs for key in item.extra_fields)
        for key in all_keys:
            temp_arr = np.empty(len(inputs), dtype=object)
            temp_arr[:] = [item.extra_fields.get(key) for item in inputs]
            extra_fields[key] = temp_arr
        non_tensor_batch.update(extra_fields)

        # Only include reward_extra_keys in meta_info if rm_scores is in batch
        # This avoids conflicts when reward_tensor is merged later in ray_trainer.py
        if "rm_scores" in batch.keys():
            meta_info = {"metrics": metrics, "reward_extra_keys": reward_extra_keys}
        else:
            meta_info = {"metrics": metrics}

        return DataProto(
            batch=batch,
            non_tensor_batch=non_tensor_batch,
            meta_info=meta_info,
        )