4 changes: 4 additions & 0 deletions .gitignore
@@ -216,3 +216,7 @@ agentlightning/dashboard/**/*.svg

# Docker data
docker/data/

# AGL simulation
examples/simulation/envs/alfworld/alfworld_source/*
wandb/
43 changes: 43 additions & 0 deletions agentlightning/adapter/triplet.py
@@ -683,6 +683,49 @@ def adapt(self, source: Union[Sequence[Span], Sequence[ReadableSpan]], /) -> Lis
)
return trajectory

    def adapt_group(self, span_groups: Dict[str, Dict[str, Span]], /) -> List[Triplet]:
        """Convert pre-grouped spans (one group per agent step) into triplets."""

        def get_token_ids(span: Optional[Span], key: str) -> List[int]:
            return span.attributes.get(key, []) if span else []

        def get_reward_0(span: Optional[Span]) -> float:
            if span is None:
                return 0.0
            return float(span.attributes.get("agentlightning.reward.0.value", 0.0))

        def get_reward_1(span: Optional[Span]) -> Optional[float]:
            if span is None:
                return None
            val = span.attributes.get("agentlightning.reward.1.value")
            return float(val) if val is not None else None

        def get_message(span: Optional[Span]) -> Optional[str]:
            if span is None:
                return None
            return span.attributes.get("agentlightning.object.literal")

        triplets: List[Triplet] = []

        for group in span_groups.values():
            call_span = group.get("call_span")
            object_span = group.get("object_span")
            annotation_span = group.get("annotation_span")

            # May be None if the grouping stage does not record a request id.
            request_id = group.get("request_id")

            triplets.append(
                Triplet(
                    prompt={"token_ids": get_token_ids(call_span, "prompt_token_ids")},
                    response={"token_ids": get_token_ids(call_span, "response_token_ids")},
                    reward=get_reward_0(annotation_span),
                    metadata=dict(
                        response_id=request_id,
                        intrinsic_reward=get_reward_1(annotation_span),
                        message=get_message(object_span),
                    ),
                )
            )
        return triplets


class LlmProxyTraceToTriplet(TraceToTripletBase):
"""Convert telemetry emitted by the LLM Proxy into triplet trajectories.
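Note: the following is a minimal, hypothetical sketch of the input adapt_group expects. The spans are stand-ins built with SimpleNamespace (anything exposing an attributes dict works); real spans come from the tracer.

    from types import SimpleNamespace

    call_span = SimpleNamespace(attributes={
        "prompt_token_ids": [1, 2, 3],
        "response_token_ids": [4, 5],
    })
    annotation_span = SimpleNamespace(attributes={
        "agentlightning.reward.0.value": 1.0,   # extrinsic step reward
        "agentlightning.reward.1.value": 0.25,  # optional intrinsic reward
    })
    object_span = SimpleNamespace(attributes={
        "agentlightning.object.literal": "observation text",
    })

    span_groups = {
        "0": {
            "call_span": call_span,
            "object_span": object_span,
            "annotation_span": annotation_span,
        }
    }

    # adapter.adapt_group(span_groups) would yield one Triplet with reward=1.0
    # and metadata containing intrinsic_reward=0.25 and message="observation text".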
165 changes: 143 additions & 22 deletions agentlightning/verl/daemon.py
@@ -463,6 +463,66 @@ def _validate_data(self, rollout: RolloutLegacy):
elif any(not r.prompt.get("token_ids", []) for r in rollout.triplets):
print(f"Warning: Rollout {rollout.rollout_id} contains empty prompt: {rollout.triplets}")

    def _extract_span_groups(self, spans):
        """Group spans that belong to the same agent step, keyed by step_count.

        The returned mapping looks like (illustrative)::

            {"3": {"call_span": <openai.chat.completion span>,
                   "object_span": <agentlightning.object span>,
                   "annotation_span": <agentlightning.annotation span>}}
        """

        def resolve_step_count(span, next_span, spans, index):
            """Determine step_count for a span via next_span or a fallback search."""
            # Case A: the immediately following span is this span's parent.
            if next_span and span.parent_id == next_span.span_id:
                return next_span.attributes.get("step_count")

            # Case B: fallback -- search forward for the parent agentlightning.operation span.
            for s in spans[index + 1 :]:
                if s.name == "agentlightning.operation" and span.parent_id == s.span_id:
                    return s.attributes.get("step_count")

            return None

        def extract_step_count_from_links(span):
            """Extract step_count from agentlightning.link.* attributes."""
            if span.attributes.get("agentlightning.link.0.key_match") == "step_count":
                return span.attributes.get("agentlightning.link.0.value_match")
            return None

        # step_count (as str) -> {"call_span" | "object_span" | "annotation_span": span}
        span_groups = {}

        for i, span in enumerate(spans):
            next_span = spans[i + 1] if i + 1 < len(spans) else None

            if span.name == "openai.chat.completion":
                step_count = resolve_step_count(span, next_span, spans, i)
                slot = "call_span"
            elif span.name == "agentlightning.object":
                step_count = extract_step_count_from_links(span)
                slot = "object_span"
            elif span.name == "agentlightning.annotation":
                step_count = extract_step_count_from_links(span)
                slot = "annotation_span"
            else:
                continue

            if step_count is None:
                continue
            span_groups.setdefault(str(step_count), {})[slot] = span

        return span_groups

async def _validate_data_v1(self, rollout: Rollout) -> RolloutLegacy:
"""Convert Rollout to RolloutLegacy and validate.

@@ -472,13 +532,15 @@ async def _validate_data_v1(self, rollout: Rollout) -> RolloutLegacy:
"""
# Query spans for this rollout (latest attempt)
spans = await self.store.query_spans(rollout.rollout_id, attempt_id="latest")
span_groups = self._extract_span_groups(spans)

# Convert spans to triplets using the adapter
if not spans:
            # No spans found; a warning will be emitted later.
            triplets = []
        else:
            triplets = self.adapter.adapt_group(span_groups)

# Extract final reward from triplets
final_reward: Optional[float] = None
@@ -646,7 +708,14 @@ def get_test_metrics(self):
)
return metric_dict

def get_train_data_batch(
self,
max_prompt_length: int,
max_response_length: int,
device: torch.device,
use_final_reward_as_step_reward: bool = True,
is_gigpo: bool = False,
):
"""
Processes completed rollouts to generate a training data batch.

@@ -674,12 +743,32 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int,
# The client should report triplets that contain prompt_ids and response_ids.
# Example triplet.prompt: {"token_ids": [...]}
# Example triplet.response: {"token_ids": [...]}
trace_list = []
for t in rollout.triplets:
trace_dict = {
"prompt_ids": t.prompt.get("token_ids", []),
"response_ids": t.response.get("token_ids", []),
"step_reward": t.reward,
}

# Optional fields
intrinsic = t.metadata.get("intrinsic_reward")
message = t.metadata.get("message")

if intrinsic is not None:
trace_dict["step_intrinsic_reward"] = intrinsic

if message is not None:
trace_dict["message"] = message

trace_list.append(trace_dict)

info = {
"reward": final_reward,
"final_reward": final_reward,
"trace_list": trace_list,
"data_id": original_sample["data_id"],
}
@@ -700,17 +789,28 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int,
input_attention_mask_list: List[List[int]] = []
response_ids_list: List[List[int]] = []
response_attention_mask_list: List[List[int]] = []
final_reward_list: List[float] = []
step_reward_list: List[float] = []
data_id_list: List[str] = []
rollout_id_list: List[str] = []
turn_index_list: List[int] = []
is_drop_list: List[bool] = []
n_trunc_sample_because_of_response = 0

        # Optional per-step fields (present only when reported by the adapter)
step_intrinsic_reward_list: List[float] = []
message_list: List[str] = []

for rollout_id, sample_info in finished_id_to_sample_info.items():
for turn_index, trace in enumerate(sample_info["trace_list"]):
final_reward_list.append(sample_info["final_reward"])
step_reward_list.append(trace["step_reward"])
if "step_intrinsic_reward" in trace:
step_intrinsic_reward_list.append(trace["step_intrinsic_reward"])
if "message" in trace:
message_list.append(trace["message"])

prompt_ids, response_ids = trace["prompt_ids"], trace["response_ids"]

# Mark samples with prompts exceeding max_prompt_length to be dropped later
@@ -752,7 +852,10 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int,
attention_mask = torch.cat([input_attention_mask, response_attention_mask], dim=-1)
position_ids = torch.clamp(torch.cumsum(attention_mask, dim=-1) - 1, min=0)
is_drop_mask = torch.BoolTensor(is_drop_list).to(device)
        if use_final_reward_as_step_reward:
            scores = torch.tensor(final_reward_list, dtype=torch.float32).to(device)
        else:
            scores = torch.tensor(step_reward_list, dtype=torch.float32).to(device)

        # Create token-level scores by placing each sequence's scalar reward at its last token position
token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype)
@@ -763,19 +866,33 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int,
# Only take the last response_length part of the sequence to get the token-level scores for the model's response part.
token_level_scores = token_level_scores[:, -max_response_length:]

# Create token-level intrinsic rewards
token_level_intrinsic_rewards = None
if len(step_intrinsic_reward_list) > 0:
step_intrinsic_reward_list = [0.0 if reward is None else reward for reward in step_intrinsic_reward_list]
intrinsic_rewards = torch.tensor(step_intrinsic_reward_list, dtype=torch.float32).to(device)
token_level_intrinsic_rewards = torch.zeros_like(attention_mask, dtype=intrinsic_rewards.dtype)
token_level_intrinsic_rewards[torch.arange(n_transition), eos_mask_idx] = intrinsic_rewards
token_level_intrinsic_rewards = token_level_intrinsic_rewards[:, -max_response_length:]
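            # Note: the scatter above assumes an intrinsic reward was reported
            # for every transition; a partially populated list would not match
            # n_transition and would fail the indexed assignment.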

# Form the final batch using TensorDict
batch_dict = {
"prompts": batch_input_ids,
"responses": batch_response_ids,
"input_ids": batch_seq, # here input_ids become the whole sentences
"attention_mask": attention_mask,
"position_ids": position_ids,
"is_drop_mask": is_drop_mask,
"token_level_scores": token_level_scores.contiguous(),
}
batch_dict["step_rewards"] = torch.tensor(np.array(step_reward_list), dtype=torch.float32).to(device)
if token_level_intrinsic_rewards is not None:
batch_dict["step_intrinsic_rewards"] = torch.tensor(
np.array(step_intrinsic_reward_list), dtype=torch.float32
).to(device)
batch_dict["token_level_intrinsic_rewards"] = token_level_intrinsic_rewards.contiguous()

batch = TensorDict(batch_dict, batch_size=n_transition)
data_proto = DataProto(batch=batch)

data_metrics = {
@@ -792,6 +909,10 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int,
data_proto.non_tensor_batch["rollout_id_list"] = np.array(rollout_id_list) # type: ignore
data_proto.non_tensor_batch["turn_index_list"] = np.array(turn_index_list) # type: ignore

data_proto.non_tensor_batch["step_rewards"] = np.array(step_reward_list)
if len(message_list) > 0 and is_gigpo:
data_proto.non_tensor_batch["anchor_obs"] = np.array(message_list)

return data_proto, data_metrics

def clear_data_and_server(self):
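Note: a self-contained sketch of the token-level reward placement used in get_train_data_batch. The eos index computation is an assumption (eos_mask_idx is computed outside the hunks shown) and presumes right-padded attention masks.

    import torch

    attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                                   [1, 1, 1, 1, 1, 0]])  # two transitions
    scores = torch.tensor([1.0, -0.5])                   # one scalar reward each
    n_transition, seq_len = attention_mask.shape

    # Index of the last attended token in each row (right-padded rows).
    eos_idx = seq_len - 1 - attention_mask.flip(-1).argmax(dim=-1)

    token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype)
    token_level_scores[torch.arange(n_transition), eos_idx] = scores

    max_response_length = 4
    token_level_scores = token_level_scores[:, -max_response_length:]
    # Each row now holds its reward at its last token; all other positions are 0.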
29 changes: 29 additions & 0 deletions examples/simulation/captioners/__init__.py
@@ -0,0 +1,29 @@
import warnings

from .history import HistoryPromptBuilder


def create_prompt_builder(config):
"""
Creates an instance of a prompt builder based on the provided configuration.
This function initializes a prompt builder by extracting relevant configuration
parameters. It can be extended or modified to support different types of prompt
builders beyond just the HistoryPromptBuilder.
Args:
config (Config): An object containing configuration settings, which must
include the following keys:
- max_text_history (int): Maximum number of text history entries to retain.
Returns:
PromptBuilder: An instance of a prompt builder configured with the specified
history limits and any additional parameters defined in the config.
"""

max_history = config.get("max_history", None)
    if max_history is not None:
        warnings.warn(
            "The 'max_history' parameter is deprecated. Please use 'max_text_history' instead.",
            DeprecationWarning,
        )

max_text_history = max_history
if max_text_history is None:
max_text_history = config.max_text_history

return HistoryPromptBuilder(max_text_history=max_text_history)
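Note: a minimal usage sketch, assuming a config object that supports both .get() and attribute access as create_prompt_builder requires; the Config shim below is hypothetical.

    from types import SimpleNamespace

    class Config(SimpleNamespace):
        def get(self, key, default=None):
            return getattr(self, key, default)

    builder = create_prompt_builder(Config(max_text_history=5))  # preferred key
    legacy = create_prompt_builder(Config(max_history=5))        # warns, still works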