mantisdk-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mantisdk might be problematic.
- mantisdk/__init__.py +22 -0
- mantisdk/adapter/__init__.py +15 -0
- mantisdk/adapter/base.py +94 -0
- mantisdk/adapter/messages.py +270 -0
- mantisdk/adapter/triplet.py +1028 -0
- mantisdk/algorithm/__init__.py +39 -0
- mantisdk/algorithm/apo/__init__.py +5 -0
- mantisdk/algorithm/apo/apo.py +889 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
- mantisdk/algorithm/base.py +162 -0
- mantisdk/algorithm/decorator.py +264 -0
- mantisdk/algorithm/fast.py +250 -0
- mantisdk/algorithm/gepa/__init__.py +59 -0
- mantisdk/algorithm/gepa/adapter.py +459 -0
- mantisdk/algorithm/gepa/gepa.py +364 -0
- mantisdk/algorithm/gepa/lib/__init__.py +18 -0
- mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
- mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
- mantisdk/algorithm/gepa/lib/api.py +375 -0
- mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
- mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
- mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
- mantisdk/algorithm/gepa/lib/core/result.py +233 -0
- mantisdk/algorithm/gepa/lib/core/state.py +636 -0
- mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
- mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
- mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
- mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
- mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
- mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
- mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
- mantisdk/algorithm/gepa/lib/py.typed +0 -0
- mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
- mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
- mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
- mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
- mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
- mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
- mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
- mantisdk/algorithm/gepa/tracing.py +105 -0
- mantisdk/algorithm/utils.py +177 -0
- mantisdk/algorithm/verl/__init__.py +5 -0
- mantisdk/algorithm/verl/interface.py +202 -0
- mantisdk/cli/__init__.py +56 -0
- mantisdk/cli/prometheus.py +115 -0
- mantisdk/cli/store.py +131 -0
- mantisdk/cli/vllm.py +29 -0
- mantisdk/client.py +408 -0
- mantisdk/config.py +348 -0
- mantisdk/emitter/__init__.py +43 -0
- mantisdk/emitter/annotation.py +370 -0
- mantisdk/emitter/exception.py +54 -0
- mantisdk/emitter/message.py +61 -0
- mantisdk/emitter/object.py +117 -0
- mantisdk/emitter/reward.py +320 -0
- mantisdk/env_var.py +156 -0
- mantisdk/execution/__init__.py +15 -0
- mantisdk/execution/base.py +64 -0
- mantisdk/execution/client_server.py +443 -0
- mantisdk/execution/events.py +69 -0
- mantisdk/execution/inter_process.py +16 -0
- mantisdk/execution/shared_memory.py +282 -0
- mantisdk/instrumentation/__init__.py +119 -0
- mantisdk/instrumentation/agentops.py +314 -0
- mantisdk/instrumentation/agentops_langchain.py +45 -0
- mantisdk/instrumentation/litellm.py +83 -0
- mantisdk/instrumentation/vllm.py +81 -0
- mantisdk/instrumentation/weave.py +500 -0
- mantisdk/litagent/__init__.py +11 -0
- mantisdk/litagent/decorator.py +536 -0
- mantisdk/litagent/litagent.py +252 -0
- mantisdk/llm_proxy.py +1890 -0
- mantisdk/logging.py +370 -0
- mantisdk/reward.py +7 -0
- mantisdk/runner/__init__.py +11 -0
- mantisdk/runner/agent.py +845 -0
- mantisdk/runner/base.py +182 -0
- mantisdk/runner/legacy.py +309 -0
- mantisdk/semconv.py +170 -0
- mantisdk/server.py +401 -0
- mantisdk/store/__init__.py +23 -0
- mantisdk/store/base.py +897 -0
- mantisdk/store/client_server.py +2092 -0
- mantisdk/store/collection/__init__.py +30 -0
- mantisdk/store/collection/base.py +587 -0
- mantisdk/store/collection/memory.py +970 -0
- mantisdk/store/collection/mongo.py +1412 -0
- mantisdk/store/collection_based.py +1823 -0
- mantisdk/store/insight.py +648 -0
- mantisdk/store/listener.py +58 -0
- mantisdk/store/memory.py +396 -0
- mantisdk/store/mongo.py +165 -0
- mantisdk/store/sqlite.py +3 -0
- mantisdk/store/threading.py +357 -0
- mantisdk/store/utils.py +142 -0
- mantisdk/tracer/__init__.py +16 -0
- mantisdk/tracer/agentops.py +242 -0
- mantisdk/tracer/base.py +287 -0
- mantisdk/tracer/dummy.py +106 -0
- mantisdk/tracer/otel.py +555 -0
- mantisdk/tracer/weave.py +677 -0
- mantisdk/trainer/__init__.py +6 -0
- mantisdk/trainer/init_utils.py +263 -0
- mantisdk/trainer/legacy.py +367 -0
- mantisdk/trainer/registry.py +12 -0
- mantisdk/trainer/trainer.py +618 -0
- mantisdk/types/__init__.py +6 -0
- mantisdk/types/core.py +553 -0
- mantisdk/types/resources.py +204 -0
- mantisdk/types/tracer.py +515 -0
- mantisdk/types/tracing.py +218 -0
- mantisdk/utils/__init__.py +1 -0
- mantisdk/utils/id.py +18 -0
- mantisdk/utils/metrics.py +1025 -0
- mantisdk/utils/otel.py +578 -0
- mantisdk/utils/otlp.py +536 -0
- mantisdk/utils/server_launcher.py +1045 -0
- mantisdk/utils/system_snapshot.py +81 -0
- mantisdk/verl/__init__.py +8 -0
- mantisdk/verl/__main__.py +6 -0
- mantisdk/verl/async_server.py +46 -0
- mantisdk/verl/config.yaml +27 -0
- mantisdk/verl/daemon.py +1154 -0
- mantisdk/verl/dataset.py +44 -0
- mantisdk/verl/entrypoint.py +248 -0
- mantisdk/verl/trainer.py +549 -0
- mantisdk-0.1.0.dist-info/METADATA +119 -0
- mantisdk-0.1.0.dist-info/RECORD +190 -0
- mantisdk-0.1.0.dist-info/WHEEL +4 -0
- mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
- mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
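Most of the release's volume is in mantisdk/verl/daemon.py, whose added contents are shown below. Among the symbols that module exports are two token-padding helpers used when assembling training batches. As a quick orientation, here is a minimal usage sketch (hypothetical example code, not part of the package; it only assumes the signatures visible in the diff below):

from mantisdk.verl.daemon import (
    get_left_padded_ids_and_attention_mask,
    get_right_padded_ids_and_attention_mask,
)

ids = [101, 7592, 102]  # three real token IDs

# Left padding: pad tokens go in front, and the mask is 0 over the pads.
padded, mask = get_left_padded_ids_and_attention_mask(ids, max_length=5, pad_token_id=0)
assert padded == [0, 0, 101, 7592, 102]
assert mask == [0, 0, 1, 1, 1]

# Right padding: pad tokens go at the end.
padded, mask = get_right_padded_ids_and_attention_mask(ids, max_length=5, pad_token_id=0)
assert padded == [101, 7592, 102, 0, 0]
assert mask == [1, 1, 1, 0, 0]

In the daemon itself, prompts are left-padded and responses are right-padded so that the two halves meet in the middle of each training sequence.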
mantisdk/verl/daemon.py
ADDED
|
@@ -0,0 +1,1154 @@
|
|
|
1
|
+
# Copyright (c) Microsoft. All rights reserved.
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
import os
|
|
6
|
+
import random
|
|
7
|
+
import socket
|
|
8
|
+
import threading
|
|
9
|
+
import time
|
|
10
|
+
import uuid
|
|
11
|
+
from collections import defaultdict
|
|
12
|
+
from collections.abc import Mapping
|
|
13
|
+
from typing import Any, Dict, List, Literal, Optional, Tuple, cast
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
import requests
|
|
17
|
+
import torch
|
|
18
|
+
from flask import Flask, Response, abort, request
|
|
19
|
+
from tensordict import TensorDict
|
|
20
|
+
from verl import DataProto
|
|
21
|
+
|
|
22
|
+
from mantisdk import LLM, MantisdkServer, NamedResources, RolloutLegacy
|
|
23
|
+
from mantisdk.adapter.triplet import TracerTraceToTriplet, TraceToTripletBase
|
|
24
|
+
from mantisdk.llm_proxy import LLMProxy, ModelConfig
|
|
25
|
+
from mantisdk.store.base import LightningStore
|
|
26
|
+
from mantisdk.types import EnqueueRolloutRequest, Rollout, RolloutConfig, Task
|
|
27
|
+
|
|
28
|
+
__all__ = [
|
|
29
|
+
"AgentModeDaemon",
|
|
30
|
+
"get_left_padded_ids_and_attention_mask",
|
|
31
|
+
"get_right_padded_ids_and_attention_mask",
|
|
32
|
+
]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def ids_startswith(
|
|
36
|
+
full_ids: List[int], prefix_ids: List[int], tokenizer: Any, debug: bool = False
|
|
37
|
+
) -> Tuple[bool, Tuple[bool, bool, bool]]:
|
|
38
|
+
is_prefix: bool
|
|
39
|
+
template_mismatch, retoken_mismatch, others_mismatch = False, False, False
|
|
40
|
+
if full_ids[: len(prefix_ids)] == prefix_ids:
|
|
41
|
+
is_prefix = True
|
|
42
|
+
return True, (template_mismatch, retoken_mismatch, others_mismatch)
|
|
43
|
+
else:
|
|
44
|
+
is_prefix = False
|
|
45
|
+
|
|
46
|
+
if not debug:
|
|
47
|
+
return is_prefix, (template_mismatch, retoken_mismatch, others_mismatch)
|
|
48
|
+
|
|
49
|
+
def _special_token_sequence(ids: List[int]) -> List[int]:
|
|
50
|
+
return [id for id in ids if id in tokenizer.all_special_ids]
|
|
51
|
+
|
|
52
|
+
def _none_special_token_sequence(ids: List[int]) -> List[int]:
|
|
53
|
+
return [id for id in ids if id not in tokenizer.all_special_ids]
|
|
54
|
+
|
|
55
|
+
# First, handle special tokens
|
|
56
|
+
full_special_ids = _special_token_sequence(full_ids)
|
|
57
|
+
prefix_special_ids = _special_token_sequence(prefix_ids)
|
|
58
|
+
if sum(1 for a, b in zip(full_special_ids, prefix_special_ids) if a != b) > 0:
|
|
59
|
+
template_mismatch = True
|
|
60
|
+
|
|
61
|
+
# Next, handle string content
|
|
62
|
+
full_content_ids = _none_special_token_sequence(full_ids)
|
|
63
|
+
prefix_content_ids = _none_special_token_sequence(prefix_ids)
|
|
64
|
+
full_string = tokenizer.decode(full_ids, skip_special_tokens=True)
|
|
65
|
+
prefix_string = tokenizer.decode(prefix_ids, skip_special_tokens=True)
|
|
66
|
+
if full_content_ids[: len(prefix_content_ids)] != prefix_content_ids and full_string.startswith(prefix_string):
|
|
67
|
+
retoken_mismatch = True
|
|
68
|
+
elif full_content_ids[: len(prefix_content_ids)] != prefix_content_ids and not full_string.startswith(
|
|
69
|
+
prefix_string
|
|
70
|
+
):
|
|
71
|
+
others_mismatch = True
|
|
72
|
+
return is_prefix, (template_mismatch, retoken_mismatch, others_mismatch)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def log_mismatch_detail(
|
|
76
|
+
diagnostic: Tuple[bool, bool, bool],
|
|
77
|
+
full_ids: List[int],
|
|
78
|
+
prefix_ids: List[int],
|
|
79
|
+
global_steps: int,
|
|
80
|
+
rollout_id: str,
|
|
81
|
+
turn_id: int,
|
|
82
|
+
log_dir: str | None = None,
|
|
83
|
+
):
|
|
84
|
+
if log_dir is None:
|
|
85
|
+
return
|
|
86
|
+
os.makedirs(log_dir, exist_ok=True)
|
|
87
|
+
template_mismatch, retoken_mismatch, others_mismatch = diagnostic
|
|
88
|
+
if template_mismatch:
|
|
89
|
+
with open(os.path.join(log_dir, "template_mismatch.log"), "a+") as f:
|
|
90
|
+
print(
|
|
91
|
+
"-" * 10 + f" Global Steps: {global_steps}, Rollout ID: {rollout_id}, Turn ID: {turn_id} " + "-" * 10,
|
|
92
|
+
file=f,
|
|
93
|
+
)
|
|
94
|
+
print(full_ids, file=f)
|
|
95
|
+
print(prefix_ids, file=f)
|
|
96
|
+
if retoken_mismatch:
|
|
97
|
+
with open(os.path.join(log_dir, "retoken_mismatch.log"), "a+") as f:
|
|
98
|
+
print(
|
|
99
|
+
"-" * 10 + f" Global Steps: {global_steps}, Rollout ID: {rollout_id}, Turn ID: {turn_id} " + "-" * 10,
|
|
100
|
+
file=f,
|
|
101
|
+
)
|
|
102
|
+
print(full_ids, file=f)
|
|
103
|
+
print(prefix_ids, file=f)
|
|
104
|
+
if others_mismatch:
|
|
105
|
+
with open(os.path.join(log_dir, "others_mismatch.log"), "a+") as f:
|
|
106
|
+
print(
|
|
107
|
+
"-" * 10 + f" Global Steps: {global_steps}, Rollout ID: {rollout_id}, Turn ID: {turn_id} " + "-" * 10,
|
|
108
|
+
file=f,
|
|
109
|
+
)
|
|
110
|
+
print(full_ids, file=f)
|
|
111
|
+
print(prefix_ids, file=f)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def get_left_padded_ids_and_attention_mask(
|
|
115
|
+
ids: List[int], max_length: int, pad_token_id: int
|
|
116
|
+
) -> Tuple[List[int], List[int]]:
|
|
117
|
+
"""
|
|
118
|
+
Left-pad (or truncate) a sequence of token IDs to a fixed length,
|
|
119
|
+
and build the corresponding attention mask.
|
|
120
|
+
|
|
121
|
+
Args:
|
|
122
|
+
ids: the original list of token IDs.
|
|
123
|
+
max_length: desired total length after padding/truncation.
|
|
124
|
+
pad_token_id: ID to use for padding.
|
|
125
|
+
|
|
126
|
+
Returns:
|
|
127
|
+
padded_ids (any): list of length == max_length.
|
|
128
|
+
attention_mask (any): list of same length: 1 for non-pad tokens, 0 for pads.
|
|
129
|
+
"""
|
|
130
|
+
seq_len = len(ids)
|
|
131
|
+
|
|
132
|
+
if seq_len >= max_length:
|
|
133
|
+
# too long → truncate from the left, keep the last max_length tokens
|
|
134
|
+
trimmed = ids[-max_length:]
|
|
135
|
+
attention_mask = [1] * max_length
|
|
136
|
+
return trimmed, attention_mask
|
|
137
|
+
|
|
138
|
+
# too short → pad on the left
|
|
139
|
+
pad_len = max_length - seq_len
|
|
140
|
+
padded_ids = [pad_token_id] * pad_len + ids
|
|
141
|
+
attention_mask = [0] * pad_len + [1] * seq_len
|
|
142
|
+
return padded_ids, attention_mask
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def get_right_padded_ids_and_attention_mask(
|
|
146
|
+
ids: List[int], max_length: int, pad_token_id: int
|
|
147
|
+
) -> Tuple[List[int], List[int]]:
|
|
148
|
+
"""
|
|
149
|
+
Right-pad (or truncate) a sequence of token IDs to a fixed length,
|
|
150
|
+
and build the corresponding attention mask.
|
|
151
|
+
|
|
152
|
+
Args:
|
|
153
|
+
ids: the original list of token IDs.
|
|
154
|
+
max_length: desired total length after padding/truncation.
|
|
155
|
+
pad_token_id: ID to use for padding.
|
|
156
|
+
|
|
157
|
+
Returns:
|
|
158
|
+
padded_ids (any): list of length == max_length.
|
|
159
|
+
attention_mask (any): list of same length: 1 for non-pad tokens, 0 for pads.
|
|
160
|
+
"""
|
|
161
|
+
seq_len = len(ids)
|
|
162
|
+
|
|
163
|
+
if seq_len >= max_length:
|
|
164
|
+
# too long → truncate to the first max_length tokens
|
|
165
|
+
trimmed = ids[:max_length]
|
|
166
|
+
attention_mask = [1] * max_length
|
|
167
|
+
return trimmed, attention_mask
|
|
168
|
+
|
|
169
|
+
# too short → pad on the right
|
|
170
|
+
pad_len = max_length - seq_len
|
|
171
|
+
padded_ids = ids + [pad_token_id] * pad_len
|
|
172
|
+
attention_mask = [1] * seq_len + [0] * pad_len
|
|
173
|
+
return padded_ids, attention_mask
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _find_available_port() -> int:
|
|
177
|
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
178
|
+
s.bind(("", 0))
|
|
179
|
+
return s.getsockname()[1]
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def _to_native(obj: Any) -> Any:
|
|
183
|
+
"""Convert data retrieved from Parquet to data usable in AGL server."""
|
|
184
|
+
# 1) Arrays -> list (then recurse)
|
|
185
|
+
if isinstance(obj, np.ndarray):
|
|
186
|
+
return _to_native(obj.tolist())
|
|
187
|
+
|
|
188
|
+
# 2) NumPy scalar types -> Python scalars
|
|
189
|
+
if isinstance(obj, np.generic):
|
|
190
|
+
return _to_native(obj.item())
|
|
191
|
+
|
|
192
|
+
# 3) Dict-like -> dict
|
|
193
|
+
if isinstance(obj, Mapping):
|
|
194
|
+
return {_to_native(k): _to_native(v) for k, v in obj.items()} # type: ignore
|
|
195
|
+
|
|
196
|
+
# 4) Lists/Tuples/Sets -> list
|
|
197
|
+
if isinstance(obj, (list, tuple, set)):
|
|
198
|
+
return [_to_native(x) for x in obj] # type: ignore
|
|
199
|
+
|
|
200
|
+
# 5) Anything else: leave as-is
|
|
201
|
+
return obj
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
class AgentModeDaemon:
|
|
205
|
+
"""
|
|
206
|
+
AgentModeDaemon using the MantisdkServer SDK.
|
|
207
|
+
|
|
208
|
+
This class manages the server lifecycle, task queueing, and results
|
|
209
|
+
retrieval, while also running a proxy server for LLM requests. It maintains
|
|
210
|
+
the original interface for compatibility with the RayPPOTrainer.
|
|
211
|
+
"""
|
|
212
|
+
|
|
213
|
+
def __init__(
|
|
214
|
+
self,
|
|
215
|
+
port: Optional[int],
|
|
216
|
+
train_rollout_n: int,
|
|
217
|
+
train_information: Dict[str, Any],
|
|
218
|
+
tokenizer: Any,
|
|
219
|
+
mini_batch_size: int,
|
|
220
|
+
pad_token_id: int,
|
|
221
|
+
reward_fillna_value: float = 0.0,
|
|
222
|
+
llm_timeout_seconds: float = 1200.0,
|
|
223
|
+
mode: Literal["v0", "v1"] = "v1",
|
|
224
|
+
llm_proxy: LLMProxy | None = None,
|
|
225
|
+
store: LightningStore | None = None,
|
|
226
|
+
adapter: TraceToTripletBase | None = None,
|
|
227
|
+
processor: Any = None,
|
|
228
|
+
image_base_dir: Optional[str] = None,
|
|
229
|
+
trace_aggregator: Dict[str, Any] = {"level": "transition"},
|
|
230
|
+
):
|
|
231
|
+
self.mode = mode
|
|
232
|
+
self.llm_timeout_seconds = llm_timeout_seconds
|
|
233
|
+
|
|
234
|
+
# Server and Task Configuration
|
|
235
|
+
if mode == "v0":
|
|
236
|
+
assert port is not None
|
|
237
|
+
self.server_port = port
|
|
238
|
+
self.server = MantisdkServer(
|
|
239
|
+
host="0.0.0.0", port=self.server_port, task_timeout_seconds=self.llm_timeout_seconds
|
|
240
|
+
)
|
|
241
|
+
self.proxy_port = _find_available_port() # Run proxy on a different port
|
|
242
|
+
else:
|
|
243
|
+
assert store is not None
|
|
244
|
+
self.store = store
|
|
245
|
+
if llm_proxy is None:
|
|
246
|
+
self.llm_proxy = LLMProxy(
|
|
247
|
+
port=_find_available_port(),
|
|
248
|
+
model_list=[],
|
|
249
|
+
store=store,
|
|
250
|
+
)
|
|
251
|
+
else:
|
|
252
|
+
# Reuse the existing LLM proxy (probably configured by user)
|
|
253
|
+
self.llm_proxy = llm_proxy
|
|
254
|
+
if adapter is None:
|
|
255
|
+
self.adapter = TracerTraceToTriplet()
|
|
256
|
+
else:
|
|
257
|
+
# Reuse the one from trainer
|
|
258
|
+
self.adapter = adapter
|
|
259
|
+
self._internal_loop: Optional[asyncio.AbstractEventLoop] = None
|
|
260
|
+
self._internal_loop_thread = threading.Thread(target=self._internal_loop_runner, daemon=True)
|
|
261
|
+
self._internal_loop_thread.start()
|
|
262
|
+
|
|
263
|
+
# Training and Data Configuration
|
|
264
|
+
self.train_rollout_n = train_rollout_n
|
|
265
|
+
self.train_information = train_information
|
|
266
|
+
self.mini_batch_size = mini_batch_size
|
|
267
|
+
self.pad_token_id = pad_token_id
|
|
268
|
+
self.tokenizer = tokenizer
|
|
269
|
+
self.processor = processor
|
|
270
|
+
self.reward_fillna_value = reward_fillna_value
|
|
271
|
+
self.image_base_dir = image_base_dir
|
|
272
|
+
self.trace_aggregator = trace_aggregator
|
|
273
|
+
|
|
274
|
+
# Check if model requires multimodal position_ids (e.g., Qwen2-VL)
|
|
275
|
+
self._use_mrope = self._is_mrope_model()
|
|
276
|
+
|
|
277
|
+
# Internal State
|
|
278
|
+
self.backend_llm_server_addresses: List[str] = []
|
|
279
|
+
self._total_tasks_queued = 0
|
|
280
|
+
self._completed_rollouts_v0: Dict[str, RolloutLegacy] = {}
|
|
281
|
+
self._task_id_to_original_sample: Dict[str, Dict[str, Any]] = {}
|
|
282
|
+
self._server_thread: Optional[threading.Thread] = None
|
|
283
|
+
self._proxy_thread: Optional[threading.Thread] = None
|
|
284
|
+
self.is_train = True
|
|
285
|
+
|
|
286
|
+
def _internal_loop_runner(self):
|
|
287
|
+
"""Run the internal loop."""
|
|
288
|
+
loop = asyncio.new_event_loop()
|
|
289
|
+
asyncio.set_event_loop(loop)
|
|
290
|
+
self._internal_loop = loop
|
|
291
|
+
loop.run_forever()
|
|
292
|
+
loop.close()
|
|
293
|
+
|
|
294
|
+
# Multimodal utilities for M-RoPE position embeddings
|
|
295
|
+
|
|
296
|
+
def _is_mrope_model(self) -> bool:
|
|
297
|
+
"""Check if processor requires M-RoPE position embeddings."""
|
|
298
|
+
if self.processor is None or not hasattr(self.processor, "image_processor"):
|
|
299
|
+
return False
|
|
300
|
+
name = self.processor.image_processor.__class__.__name__
|
|
301
|
+
return "Qwen2VLImageProcessor" in name or "Qwen3VLImageProcessor" in name
|
|
302
|
+
|
|
303
|
+
def _resolve_image_path(self, path: str) -> str:
|
|
304
|
+
"""Resolve relative image path with base directory."""
|
|
305
|
+
import os
|
|
306
|
+
|
|
307
|
+
if os.path.isabs(path):
|
|
308
|
+
return path
|
|
309
|
+
if self.image_base_dir is None:
|
|
310
|
+
raise ValueError(f"Relative path '{path}' requires 'image_base_dir' to be set.")
|
|
311
|
+
return os.path.join(self.image_base_dir, path)
|
|
312
|
+
|
|
313
|
+
def _get_image_grid_thw(self, image_urls: List[str]) -> Optional[torch.Tensor]:
|
|
314
|
+
"""Compute image_grid_thw from image URLs for M-RoPE computation.
|
|
315
|
+
|
|
316
|
+
Args:
|
|
317
|
+
image_urls: List of image URLs extracted from triplet prompt payload.
|
|
318
|
+
URLs can be http(s):// URLs or file:// URIs, or data: URIs.
|
|
319
|
+
"""
|
|
320
|
+
from PIL import Image
|
|
321
|
+
from verl.utils.dataset.vision_utils import process_image # pyright: ignore[reportUnknownVariableType]
|
|
322
|
+
|
|
323
|
+
if self.processor is None or not image_urls:
|
|
324
|
+
return None
|
|
325
|
+
|
|
326
|
+
def to_image_uri(url: str) -> str:
|
|
327
|
+
# Already a proper URI (http, https, file, data)
|
|
328
|
+
if url.startswith(("http://", "https://", "file://", "data:")):
|
|
329
|
+
return url
|
|
330
|
+
# Treat as a file path that needs resolution
|
|
331
|
+
resolved = self._resolve_image_path(url)
|
|
332
|
+
return f"file://{resolved}"
|
|
333
|
+
|
|
334
|
+
images: List[Image.Image] = [process_image({"image": to_image_uri(url)}) for url in image_urls]
|
|
335
|
+
model_inputs = self.processor(text=["dummy"], images=images, return_tensors="pt")
|
|
336
|
+
return model_inputs.get("image_grid_thw")
|
|
337
|
+
|
|
338
|
+
def _compute_mrope_position_ids(
|
|
339
|
+
self,
|
|
340
|
+
input_ids: torch.Tensor,
|
|
341
|
+
attention_mask: torch.Tensor,
|
|
342
|
+
image_grid_thw: Optional[torch.Tensor] = None,
|
|
343
|
+
) -> torch.Tensor:
|
|
344
|
+
"""Compute 4D position_ids for M-RoPE models."""
|
|
345
|
+
from typing import Callable
|
|
346
|
+
|
|
347
|
+
get_rope_index: Callable[..., torch.Tensor]
|
|
348
|
+
if "Qwen3VL" in self.processor.__class__.__name__:
|
|
349
|
+
from verl.models.transformers.qwen3_vl import get_rope_index # pyright: ignore[reportUnknownVariableType]
|
|
350
|
+
else:
|
|
351
|
+
from verl.models.transformers.qwen2_vl import get_rope_index # pyright: ignore[reportUnknownVariableType]
|
|
352
|
+
|
|
353
|
+
vision_pos = get_rope_index(
|
|
354
|
+
self.processor, input_ids=input_ids, image_grid_thw=image_grid_thw, attention_mask=attention_mask
|
|
355
|
+
)
|
|
356
|
+
|
|
357
|
+
valid_mask = attention_mask.bool()
|
|
358
|
+
text_pos = torch.zeros((1, len(input_ids)), dtype=torch.long, device=input_ids.device)
|
|
359
|
+
text_pos[0, valid_mask] = torch.arange(valid_mask.sum().item(), device=input_ids.device)
|
|
360
|
+
|
|
361
|
+
return torch.cat([text_pos, vision_pos], dim=0)
|
|
362
|
+
|
|
363
|
+
def _start_proxy_server_v0(self):
|
|
364
|
+
"""
|
|
365
|
+
Initializes and runs a Flask-based proxy server in a separate thread.
|
|
366
|
+
This proxy load-balances requests to the actual backend LLM servers.
|
|
367
|
+
"""
|
|
368
|
+
app = Flask(__name__)
|
|
369
|
+
|
|
370
|
+
num_requests = 0
|
|
371
|
+
last_request_time = 0
|
|
372
|
+
|
|
373
|
+
@app.route("/v1/<path:path>", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
|
|
374
|
+
def proxy(path: str): # type: ignore
|
|
375
|
+
if not self.backend_llm_server_addresses:
|
|
376
|
+
abort(503, description="No backend LLM servers available.")
|
|
377
|
+
|
|
378
|
+
# Randomly choose a backend server for load balancing
|
|
379
|
+
target_server = random.choice(self.backend_llm_server_addresses)
|
|
380
|
+
target_url = f"http://{target_server}/v1/{path}"
|
|
381
|
+
|
|
382
|
+
# Copy client request headers, removing the Host header
|
|
383
|
+
headers = {key: value for key, value in request.headers if key.lower() != "host"}
|
|
384
|
+
|
|
385
|
+
# Log the request for debugging
|
|
386
|
+
nonlocal num_requests, last_request_time
|
|
387
|
+
current_time = time.time()
|
|
388
|
+
num_requests += 1
|
|
389
|
+
if current_time - last_request_time > 60 or num_requests == 1 or num_requests % 100 == 0:
|
|
390
|
+
print(f"Proxying {request.method} request to {target_server}. Request data: {request.get_data()}")
|
|
391
|
+
last_request_time = current_time
|
|
392
|
+
|
|
393
|
+
try:
|
|
394
|
+
# Forward the request to the target backend
|
|
395
|
+
resp = requests.request(
|
|
396
|
+
method=request.method,
|
|
397
|
+
url=target_url,
|
|
398
|
+
headers=headers,
|
|
399
|
+
params=request.args, # type: ignore
|
|
400
|
+
data=request.get_data(),
|
|
401
|
+
cookies=request.cookies,
|
|
402
|
+
allow_redirects=False,
|
|
403
|
+
timeout=self.llm_timeout_seconds,
|
|
404
|
+
)
|
|
405
|
+
# Filter out hop-by-hop headers before returning the response
|
|
406
|
+
excluded_headers = [
|
|
407
|
+
"content-encoding",
|
|
408
|
+
"content-length",
|
|
409
|
+
"transfer-encoding",
|
|
410
|
+
"connection",
|
|
411
|
+
"keep-alive",
|
|
412
|
+
"proxy-authenticate",
|
|
413
|
+
"proxy-authorization",
|
|
414
|
+
"te",
|
|
415
|
+
"trailers",
|
|
416
|
+
"upgrade",
|
|
417
|
+
]
|
|
418
|
+
response_headers = [
|
|
419
|
+
(name, value) for name, value in resp.raw.headers.items() if name.lower() not in excluded_headers
|
|
420
|
+
]
|
|
421
|
+
if resp.status_code == 200:
|
|
422
|
+
# NOTE: from Zhiyuan's code.
|
|
423
|
+
# https://github.com/hzy46/verl_agent_mode/blob/2db65ea9858f645a914120357412a7540f8bd82d/verl/trainer/ppo/ray_trainer.py#L692-L711
|
|
424
|
+
# request_json = json.loads(request.get_data().decode("utf-8"))
|
|
425
|
+
response_json = json.loads(resp.content.decode("utf-8"))
|
|
426
|
+
# response_message = ChatCompletion(**response_json).choices[0].message.model_dump(exclude_unset=True, exclude_none=True)
|
|
427
|
+
# tool_schemas = request_json.get("tools", None)
|
|
428
|
+
# prompt_ids = self.tokenizer.apply_chat_template(request_json["messages"], tools=tool_schemas, add_generation_prompt=True, tokenize=True)
|
|
429
|
+
# full_ids = self.tokenizer.apply_chat_template(request_json["messages"] + [response_message], tools=tool_schemas, add_generation_prompt=False, tokenize=True)
|
|
430
|
+
# TBD: response_ids sometimes ends with "<eos_id>\n", shall we keep the extra "\n"?
|
|
431
|
+
# sometimes it has some differences with the hacky method in the end, but this should align with ToolCompletionCallback
|
|
432
|
+
# response_ids = full_ids[len(prompt_ids):]
|
|
433
|
+
|
|
434
|
+
# NOTE (yuge): They are different. Don't know why.
|
|
435
|
+
# assert response_json['prompt_token_ids'] == prompt_ids
|
|
436
|
+
# patched_response_ids = response_json['response_token_ids'][0]
|
|
437
|
+
# assert patched_response_ids == response_ids[:len(patched_response_ids)], f"{patched_response_ids} != {response_ids[:len(patched_response_ids)]}"
|
|
438
|
+
# response_json['prompt_token_ids'] = prompt_ids
|
|
439
|
+
# response_json['response_token_ids'] = [response_ids]
|
|
440
|
+
replaced_return_content = json.dumps(response_json).encode("utf-8")
|
|
441
|
+
return Response(replaced_return_content, status=resp.status_code, headers=response_headers)
|
|
442
|
+
return Response(resp.content, resp.status_code, response_headers)
|
|
443
|
+
except requests.exceptions.RequestException as e:
|
|
444
|
+
abort(500, description=f"Error proxying request: {e}")
|
|
445
|
+
|
|
446
|
+
def run_app():
|
|
447
|
+
app.run(host="0.0.0.0", port=self.proxy_port, threaded=True, debug=False)
|
|
448
|
+
|
|
449
|
+
self._proxy_thread = threading.Thread(target=run_app, daemon=True)
|
|
450
|
+
self._proxy_thread.start()
|
|
451
|
+
print(f"Proxy server running on port {self.proxy_port}")
|
|
452
|
+
|
|
453
|
+
async def _update_proxy_server_v1(self):
|
|
454
|
+
model_name = self.train_information.get("model")
|
|
455
|
+
if not model_name:
|
|
456
|
+
raise ValueError("Model name is not set.")
|
|
457
|
+
self.llm_proxy.update_model_list(
|
|
458
|
+
[
|
|
459
|
+
ModelConfig(
|
|
460
|
+
{
|
|
461
|
+
"model_name": model_name,
|
|
462
|
+
"litellm_params": {
|
|
463
|
+
"model": "hosted_vllm/" + model_name,
|
|
464
|
+
"api_base": f"http://{address}/v1/",
|
|
465
|
+
},
|
|
466
|
+
}
|
|
467
|
+
)
|
|
468
|
+
for address in self.backend_llm_server_addresses
|
|
469
|
+
],
|
|
470
|
+
)
|
|
471
|
+
|
|
472
|
+
await self.llm_proxy.restart()
|
|
473
|
+
|
|
474
|
+
def start(self):
|
|
475
|
+
"""Starts the main MantisdkServer and the proxy server."""
|
|
476
|
+
|
|
477
|
+
if self.mode == "v0":
|
|
478
|
+
|
|
479
|
+
def run_server():
|
|
480
|
+
"""Run the MantisdkServer in a separate thread."""
|
|
481
|
+
asyncio.run(self.server.run_forever())
|
|
482
|
+
|
|
483
|
+
self._server_thread = threading.Thread(target=run_server, daemon=True)
|
|
484
|
+
self._server_thread.start()
|
|
485
|
+
|
|
486
|
+
# Wait for the server's internal startup event to be set.
|
|
487
|
+
print("Waiting for MantisdkServer to start...")
|
|
488
|
+
is_ready = self.server.startup_event.wait(timeout=20.0) # Wait up to 20s
|
|
489
|
+
if not is_ready:
|
|
490
|
+
raise RuntimeError("MantisdkServer failed to start within the timeout period.")
|
|
491
|
+
|
|
492
|
+
print(f"MantisdkServer control plane running on port {self.server_port}")
|
|
493
|
+
|
|
494
|
+
self._start_proxy_server_v0()
|
|
495
|
+
else:
|
|
496
|
+
# Agent lightning server is no longer needed;
|
|
497
|
+
# Start proxy server in _async_set_up
|
|
498
|
+
pass
|
|
499
|
+
|
|
500
|
+
async def _async_set_up(self, data: Dict[str, Any], server_addresses: List[str], is_train: bool = True):
|
|
501
|
+
"""Async helper to set up data and resources on the server."""
|
|
502
|
+
self.clear_data_and_server()
|
|
503
|
+
if server_addresses != self.backend_llm_server_addresses:
|
|
504
|
+
self.backend_llm_server_addresses = server_addresses
|
|
505
|
+
if self.mode == "v1" and not self.llm_proxy.is_running():
|
|
506
|
+
await self._update_proxy_server_v1()
|
|
507
|
+
self.is_train = is_train
|
|
508
|
+
|
|
509
|
+
# 1. Update resources on the server for clients to use
|
|
510
|
+
if self.mode == "v0":
|
|
511
|
+
llm_resource = LLM(
|
|
512
|
+
endpoint=f"http://127.0.0.1:{self.proxy_port}/v1",
|
|
513
|
+
model=self.train_information.get("model", "default-model"),
|
|
514
|
+
sampling_parameters={
|
|
515
|
+
"temperature": self.train_information.get("temperature", 0.7 if is_train else 0.0)
|
|
516
|
+
},
|
|
517
|
+
)
|
|
518
|
+
else:
|
|
519
|
+
llm_resource = self.llm_proxy.as_resource(
|
|
520
|
+
sampling_parameters={
|
|
521
|
+
"temperature": self.train_information.get("temperature", 0.7 if is_train else 0.0)
|
|
522
|
+
},
|
|
523
|
+
)
|
|
524
|
+
|
|
525
|
+
resources: NamedResources = {"main_llm": llm_resource}
|
|
526
|
+
|
|
527
|
+
if self.mode == "v0":
|
|
528
|
+
resources_id = await self.server.update_resources(resources)
|
|
529
|
+
else:
|
|
530
|
+
resources_update = await self.store.add_resources(resources)
|
|
531
|
+
resources_id = resources_update.resources_id
|
|
532
|
+
|
|
533
|
+
# 2. Queue tasks for agents to process
|
|
534
|
+
keys = list(data.keys())
|
|
535
|
+
num_samples = len(data[keys[0]])
|
|
536
|
+
rollouts_per_sample = self.train_rollout_n if is_train else 1
|
|
537
|
+
|
|
538
|
+
enqueue_rollout_requests: List[EnqueueRolloutRequest] = []
|
|
539
|
+
data_id_to_original_sample: Dict[str, Dict[str, Any]] = {}
|
|
540
|
+
|
|
541
|
+
for i in range(num_samples):
|
|
542
|
+
data_id = str(uuid.uuid4())
|
|
543
|
+
original_sample = {key: data[key][i] for key in keys}
|
|
544
|
+
original_sample["data_id"] = data_id
|
|
545
|
+
data_id_to_original_sample[data_id] = original_sample
|
|
546
|
+
|
|
547
|
+
# For training, each sample is rolled out multiple times
|
|
548
|
+
# Data ID is different from Rollout ID, as one data can have multiple rollouts.
|
|
549
|
+
for _ in range(rollouts_per_sample):
|
|
550
|
+
task_metadata = {"data_id": data_id, "is_train": is_train}
|
|
551
|
+
if self.mode == "v0":
|
|
552
|
+
# Queue immediately
|
|
553
|
+
rollout_id = await self.server.queue_task(
|
|
554
|
+
sample=_to_native(original_sample),
|
|
555
|
+
mode="train" if is_train else "val",
|
|
556
|
+
resources_id=resources_id,
|
|
557
|
+
metadata=task_metadata,
|
|
558
|
+
)
|
|
559
|
+
|
|
560
|
+
# Store original sample data to reconstruct batch information later
|
|
561
|
+
self._task_id_to_original_sample[rollout_id] = original_sample
|
|
562
|
+
self._total_tasks_queued += 1
|
|
563
|
+
else:
|
|
564
|
+
# Collect tasks to enqueue in batch and queue them later
|
|
565
|
+
enqueue_rollout_requests.append(
|
|
566
|
+
EnqueueRolloutRequest(
|
|
567
|
+
input=_to_native(original_sample),
|
|
568
|
+
mode="train" if is_train else "val",
|
|
569
|
+
resources_id=resources_id,
|
|
570
|
+
config=RolloutConfig(
|
|
571
|
+
unresponsive_seconds=self.llm_timeout_seconds,
|
|
572
|
+
timeout_seconds=self.llm_timeout_seconds,
|
|
573
|
+
),
|
|
574
|
+
metadata=task_metadata,
|
|
575
|
+
)
|
|
576
|
+
)
|
|
577
|
+
|
|
578
|
+
if self.mode == "v1":
|
|
579
|
+
# Enqueue all the tasks in a single batch
|
|
580
|
+
rollouts = await self.store.enqueue_many_rollouts(enqueue_rollout_requests)
|
|
581
|
+
self._task_id_to_original_sample.update(
|
|
582
|
+
{
|
|
583
|
+
# Recover the original data and store it for later use.
|
|
584
|
+
rollout.rollout_id: data_id_to_original_sample[cast(Dict[str, Any], rollout.metadata)["data_id"]]
|
|
585
|
+
for rollout in rollouts
|
|
586
|
+
}
|
|
587
|
+
)
|
|
588
|
+
self._total_tasks_queued += len(rollouts)
|
|
589
|
+
|
|
590
|
+
def set_up_data_and_server(self, data: Dict[str, Any], server_addresses: List[str], is_train: bool = True):
|
|
591
|
+
"""Synchronous wrapper for setting up data and server resources."""
|
|
592
|
+
coro = self._async_set_up(data, server_addresses, is_train)
|
|
593
|
+
|
|
594
|
+
if self.mode == "v0":
|
|
595
|
+
if not self.server.loop or not self.server.startup_event.is_set():
|
|
596
|
+
raise RuntimeError("Server is not running or ready.")
|
|
597
|
+
|
|
598
|
+
future = asyncio.run_coroutine_threadsafe(coro, self.server.loop)
|
|
599
|
+
|
|
600
|
+
else:
|
|
601
|
+
if self._internal_loop is None:
|
|
602
|
+
raise RuntimeError("Internal loop is not running.")
|
|
603
|
+
future = asyncio.run_coroutine_threadsafe(coro, self._internal_loop)
|
|
604
|
+
try:
|
|
605
|
+
future.result(timeout=300) # Wait for completion with a timeout
|
|
606
|
+
except Exception as e:
|
|
607
|
+
print(f"Failed to set up data on server: {e}")
|
|
608
|
+
raise
|
|
609
|
+
|
|
610
|
+
def _validate_data(self, rollout: RolloutLegacy):
|
|
611
|
+
if rollout.final_reward is None:
|
|
612
|
+
print(
|
|
613
|
+
f"Warning: Reward is None for rollout {rollout.rollout_id}, will be auto-set to {self.reward_fillna_value}."
|
|
614
|
+
)
|
|
615
|
+
if rollout.triplets is None:
|
|
616
|
+
print(f"Warning: Triplet is None for rollout {rollout.rollout_id}.")
|
|
617
|
+
elif len(rollout.triplets) == 0:
|
|
618
|
+
print(f"Warning: Length of triplets is 0 for rollout {rollout.rollout_id}.")
|
|
619
|
+
elif any(not r.response.get("token_ids", []) for r in rollout.triplets):
|
|
620
|
+
print(f"Warning: Rollout {rollout.rollout_id} contains empty response: {rollout.triplets}")
|
|
621
|
+
elif any(not r.prompt.get("token_ids", []) for r in rollout.triplets):
|
|
622
|
+
print(f"Warning: Rollout {rollout.rollout_id} contains empty prompt: {rollout.triplets}")
|
|
623
|
+
|
|
624
|
+
async def _validate_data_v1(self, rollout: Rollout) -> RolloutLegacy:
|
|
625
|
+
"""Convert Rollout to RolloutLegacy and validate.
|
|
626
|
+
|
|
627
|
+
1. Task: construct from Rollout
|
|
628
|
+
2. Triplets: obtained by querying spans and feeding into the adapter
|
|
629
|
+
3. Final reward: extracted from last triplet's reward, searching backwards if not found
|
|
630
|
+
"""
|
|
631
|
+
# Query spans for this rollout (latest attempt)
|
|
632
|
+
spans = await self.store.query_spans(rollout.rollout_id, attempt_id="latest")
|
|
633
|
+
|
|
634
|
+
# Convert spans to triplets using the adapter
|
|
635
|
+
if not spans:
|
|
636
|
+
# No triplets found, will emit a warning later.
|
|
637
|
+
triplets = []
|
|
638
|
+
else:
|
|
639
|
+
triplets = self.adapter.adapt(spans)
|
|
640
|
+
|
|
641
|
+
# Extract final reward from triplets
|
|
642
|
+
final_reward: Optional[float] = None
|
|
643
|
+
if triplets:
|
|
644
|
+
# Search backwards through triplets for the first non-None reward
|
|
645
|
+
for triplet in reversed(triplets):
|
|
646
|
+
if triplet.reward is not None:
|
|
647
|
+
final_reward = triplet.reward
|
|
648
|
+
break
|
|
649
|
+
|
|
650
|
+
# Construct the Task object from Rollout
|
|
651
|
+
task = Task(
|
|
652
|
+
rollout_id=rollout.rollout_id,
|
|
653
|
+
input=rollout.input,
|
|
654
|
+
mode=rollout.mode,
|
|
655
|
+
resources_id=rollout.resources_id,
|
|
656
|
+
metadata=rollout.metadata or {},
|
|
657
|
+
)
|
|
658
|
+
|
|
659
|
+
# Create the Rollout object (without trace and logs as per user's note)
|
|
660
|
+
result_rollout = RolloutLegacy(
|
|
661
|
+
rollout_id=rollout.rollout_id,
|
|
662
|
+
task=task,
|
|
663
|
+
final_reward=final_reward,
|
|
664
|
+
triplets=triplets,
|
|
665
|
+
metadata=rollout.metadata or {},
|
|
666
|
+
)
|
|
667
|
+
|
|
668
|
+
# Run the same validation as v0
|
|
669
|
+
self._validate_data(result_rollout)
|
|
670
|
+
|
|
671
|
+
return result_rollout
|
|
672
|
+
|
|
673
|
+
async def _async_run_until_finished(self, verbose: bool = True):
|
|
674
|
+
"""Async helper to wait for all tasks to complete."""
|
|
675
|
+
while len(self._completed_rollouts_v0) < self._total_tasks_queued:
|
|
676
|
+
if self.mode == "v0":
|
|
677
|
+
completed_batch = await self.server.retrieve_completed_rollouts()
|
|
678
|
+
else:
|
|
679
|
+
completed_batch = await self.store.wait_for_rollouts(
|
|
680
|
+
rollout_ids=list(self._task_id_to_original_sample.keys()), timeout=0
|
|
681
|
+
)
|
|
682
|
+
for rollout in completed_batch:
|
|
683
|
+
if rollout.rollout_id in self._completed_rollouts_v0:
|
|
684
|
+
# Already processed, skip
|
|
685
|
+
continue
|
|
686
|
+
if isinstance(rollout, Rollout):
|
|
687
|
+
rollout = await self._validate_data_v1(rollout)
|
|
688
|
+
else:
|
|
689
|
+
self._validate_data(rollout)
|
|
690
|
+
if rollout.rollout_id not in self._task_id_to_original_sample:
|
|
691
|
+
print(f"Warning: Received unknown rollout ID {rollout.rollout_id}, skipping.")
|
|
692
|
+
else:
|
|
693
|
+
self._completed_rollouts_v0[rollout.rollout_id] = rollout
|
|
694
|
+
if verbose:
|
|
695
|
+
print(f"Completed {len(self._completed_rollouts_v0)}/{self._total_tasks_queued} tasks...")
|
|
696
|
+
await asyncio.sleep(5)
|
|
697
|
+
|
|
698
|
+
print("All tasks finished.")
|
|
699
|
+
|
|
700
|
+
def run_until_all_finished(self, verbose: bool = True):
|
|
701
|
+
"""Synchronously waits for all queued tasks to be completed and reported."""
|
|
702
|
+
if self._total_tasks_queued == 0:
|
|
703
|
+
print("Warning: No tasks were queued.")
|
|
704
|
+
return
|
|
705
|
+
|
|
706
|
+
if self.mode == "v0":
|
|
707
|
+
if not self.server.loop or not self.server.startup_event.is_set():
|
|
708
|
+
raise RuntimeError("Server is not running or ready.")
|
|
709
|
+
loop = self.server.loop
|
|
710
|
+
else:
|
|
711
|
+
loop = self._internal_loop
|
|
712
|
+
assert loop is not None
|
|
713
|
+
|
|
714
|
+
coro = self._async_run_until_finished(verbose)
|
|
715
|
+
future = asyncio.run_coroutine_threadsafe(coro, loop)
|
|
716
|
+
try:
|
|
717
|
+
future.result() # Wait indefinitely for all tasks to complete
|
|
718
|
+
except Exception as e:
|
|
719
|
+
print(f"Error while waiting for tasks to finish: {e}")
|
|
720
|
+
raise
|
|
721
|
+
|
|
722
|
+
def get_test_metrics(self):
|
|
723
|
+
"""Calculates and returns metrics for a validation run."""
|
|
724
|
+
assert not self.is_train, "This method should only be called during validation."
|
|
725
|
+
assert len(self._completed_rollouts_v0) == self._total_tasks_queued
|
|
726
|
+
|
|
727
|
+
sample_stat_list: List[Dict[str, Any]] = []
|
|
728
|
+
sample_stat_list_by_source: Dict[str, List[Dict[str, Any]]] = defaultdict(
|
|
729
|
+
list
|
|
730
|
+
) # FIXME: Evaluate whether grouping stats by source is actually needed.
|
|
731
|
+
|
|
732
|
+
for rollout_id, rollout in self._completed_rollouts_v0.items():
|
|
733
|
+
final_reward_raw: Optional[float] = rollout.final_reward
|
|
734
|
+
final_reward = self._fillna_reward(rollout)
|
|
735
|
+
if not rollout.triplets:
|
|
736
|
+
print(f"Warning: No triplets found for test rollout {rollout.rollout_id}.")
|
|
737
|
+
sample_stat_list.append({"reward": final_reward, "has_reward": final_reward_raw is not None})
|
|
738
|
+
continue
|
|
739
|
+
response_length_list = [len(triplet.response.get("token_ids", [])) for triplet in rollout.triplets]
|
|
740
|
+
|
|
741
|
+
if "data_source" in self._task_id_to_original_sample[rollout_id]:
|
|
742
|
+
# When a test sample includes a 'data_source' field, record per-source statistics for test results.
|
|
743
|
+
# TODO: This is a flawed design. We should have a better way to handle this.
|
|
744
|
+
data_source = self._task_id_to_original_sample[rollout_id]["data_source"]
|
|
745
|
+
sample_stat_list_by_source[data_source].append(
|
|
746
|
+
{
|
|
747
|
+
"sum_response_length": np.sum(response_length_list),
|
|
748
|
+
"mean_response_length": np.mean(response_length_list) if response_length_list else 0,
|
|
749
|
+
"turn_count": len(rollout.triplets),
|
|
750
|
+
"reward": final_reward,
|
|
751
|
+
"has_reward": final_reward_raw is not None,
|
|
752
|
+
}
|
|
753
|
+
)
|
|
754
|
+
sample_stat_list.append(
|
|
755
|
+
{
|
|
756
|
+
"sum_response_length": np.sum(response_length_list),
|
|
757
|
+
"mean_response_length": np.mean(response_length_list) if response_length_list else 0,
|
|
758
|
+
"turn_count": len(rollout.triplets),
|
|
759
|
+
"reward": final_reward,
|
|
760
|
+
"has_reward": final_reward_raw is not None,
|
|
761
|
+
}
|
|
762
|
+
)
|
|
763
|
+
metric_dict: Dict[str, Any] = {}
|
|
764
|
+
|
|
765
|
+
stats_w_trace = [stat for stat in sample_stat_list if "sum_response_length" in stat]
|
|
766
|
+
stats_w_trace_by_source = {
|
|
767
|
+
data_source: [stat for stat in sample_stats if "sum_response_length" in stat]
|
|
768
|
+
for data_source, sample_stats in sample_stat_list_by_source.items()
|
|
769
|
+
}
|
|
770
|
+
for data_source, sample_stats in sample_stat_list_by_source.items():
|
|
771
|
+
metric_dict.update(
|
|
772
|
+
{
|
|
773
|
+
f"val/{data_source}/n_rollouts": len(sample_stats),
|
|
774
|
+
f"val/{data_source}/n_rollouts_w_trace": len(stats_w_trace_by_source[data_source]),
|
|
775
|
+
f"val/{data_source}/n_rollouts_w_reward": len(
|
|
776
|
+
[stat for stat in sample_stats if stat["has_reward"]]
|
|
777
|
+
),
|
|
778
|
+
f"val/{data_source}/reward": np.mean(
|
|
779
|
+
[stat["reward"] for stat in sample_stats]
|
|
780
|
+
), # each rollout must have a reward (fillna if missing)
|
|
781
|
+
f"val/{data_source}/mean_response_length": np.mean(
|
|
782
|
+
[stat["mean_response_length"] for stat in stats_w_trace_by_source[data_source]]
|
|
783
|
+
),
|
|
784
|
+
f"val/{data_source}/sum_response_length": np.mean(
|
|
785
|
+
[stat["sum_response_length"] for stat in stats_w_trace_by_source[data_source]]
|
|
786
|
+
),
|
|
787
|
+
f"val/{data_source}/turn_count": np.mean(
|
|
788
|
+
[stat["turn_count"] for stat in stats_w_trace_by_source[data_source]]
|
|
789
|
+
),
|
|
790
|
+
}
|
|
791
|
+
)
|
|
792
|
+
metric_dict.update(
|
|
793
|
+
{
|
|
794
|
+
"val/n_rollouts": len(sample_stat_list),
|
|
795
|
+
"val/n_rollouts_w_trace": len(stats_w_trace),
|
|
796
|
+
"val/n_rollouts_w_reward": len([stat for stat in sample_stat_list if stat["has_reward"]]),
|
|
797
|
+
"val/reward": np.mean(
|
|
798
|
+
[stat["reward"] for stat in sample_stat_list]
|
|
799
|
+
), # each rollout must have a reward (fillna if missing)
|
|
800
|
+
"val/mean_response_length": np.mean([stat["mean_response_length"] for stat in stats_w_trace]),
|
|
801
|
+
"val/sum_response_length": np.mean([stat["sum_response_length"] for stat in stats_w_trace]),
|
|
802
|
+
"val/turn_count": np.mean([stat["turn_count"] for stat in stats_w_trace]),
|
|
803
|
+
}
|
|
804
|
+
)
|
|
805
|
+
return metric_dict
|
|
806
|
+
|
|
807
|
+
def get_train_data_batch(
|
|
808
|
+
self, max_prompt_length: int, max_response_length: int, device: torch.device, global_steps: int
|
|
809
|
+
):
|
|
810
|
+
"""
|
|
811
|
+
Processes completed rollouts to generate a training data batch.
|
|
812
|
+
|
|
813
|
+
This function reconstructs the logic from the original AgentModeDaemon,
|
|
814
|
+
using data retrieved from the new server architecture. It handles padding,
|
|
815
|
+
truncation, and tensor creation for the PPO training loop.
|
|
816
|
+
"""
|
|
817
|
+
assert self.is_train, "This method should only be called during training."
|
|
818
|
+
assert len(self._completed_rollouts_v0) == self._total_tasks_queued
|
|
819
|
+
|
|
820
|
+
# 1. Reconstruct the `finished_id_to_sample_info` structure from completed rollouts
|
|
821
|
+
finished_id_to_sample_info: Dict[str, Dict[str, Any]] = {}
|
|
822
|
+
finished_id_to_final_reward: Dict[str, float] = {}
|
|
823
|
+
sample_with_reward_count = 0
|
|
824
|
+
for rollout_id, rollout in self._completed_rollouts_v0.items():
|
|
825
|
+
original_sample = self._task_id_to_original_sample[rollout_id]
|
|
826
|
+
sample_with_reward_count += int(rollout.final_reward is not None)
|
|
827
|
+
final_reward = self._fillna_reward(rollout)
|
|
828
|
+
|
|
829
|
+
if not rollout.triplets:
|
|
830
|
+
finished_id_to_final_reward[rollout_id] = final_reward
|
|
831
|
+
print(f"Warning: No triplets found for training rollout {rollout.rollout_id}, skipping.")
|
|
832
|
+
continue
|
|
833
|
+
|
|
834
|
+
# The client should report triplets that contain prompt_ids and response_ids.
|
|
835
|
+
# Example triplet.prompt: {"token_ids": [...], "image_urls": [...]}
|
|
836
|
+
# Example triplet.response: {"token_ids": [...]}
|
|
837
|
+
trace_list = [
|
|
838
|
+
{
|
|
839
|
+
"prompt_ids": t.prompt.get("token_ids", []),
|
|
840
|
+
"response_ids": t.response.get("token_ids", []),
|
|
841
|
+
"image_urls": t.prompt.get("image_urls", []),
|
|
842
|
+
}
|
|
843
|
+
for t in rollout.triplets
|
|
844
|
+
]
|
|
845
|
+
info = {
|
|
846
|
+
"reward": final_reward,
|
|
847
|
+
"trace_list": trace_list,
|
|
848
|
+
"data_id": original_sample["data_id"],
|
|
849
|
+
}
|
|
850
|
+
finished_id_to_sample_info[rollout_id] = info
|
|
851
|
+
finished_id_to_final_reward[rollout_id] = final_reward
|
|
852
|
+
#
|
|
853
|
+
# --- Data processing and tensor creation logic ---
|
|
854
|
+
# Get all the reported data.
|
|
855
|
+
# prompt_ids are left-padded.
|
|
856
|
+
# response_ids are right-padded.
|
|
857
|
+
# They are concatenated in the middle.
|
|
858
|
+
# Discard handling:
|
|
859
|
+
# - Those exceeding max_prompt_length will be marked for discard, but not
|
|
860
|
+
# discarded here. They are only truncated and marked, to be discarded later.
|
|
861
|
+
# This is for the correctness of the advantage calculation.
|
|
862
|
+
# - The discard for the PPO mini-batch should also be handled this way.
|
|
863
|
+
input_ids_list: List[List[int]] = []
|
|
864
|
+
input_attention_mask_list: List[List[int]] = []
|
|
865
|
+
response_ids_list: List[List[int]] = []
|
|
866
|
+
response_attention_mask_list: List[List[int]] = []
|
|
867
|
+
reward_list: List[float] = []
|
|
868
|
+
data_id_list: List[str] = []
|
|
869
|
+
rollout_id_list: List[str] = []
|
|
870
|
+
turn_index_list: List[int] = []
|
|
871
|
+
is_drop_list: List[bool] = []
|
|
872
|
+
image_grid_thw_list: List[Optional[torch.Tensor]] = [] # For Qwen2-VL mrope
|
|
873
|
+
n_trunc_sample_because_of_response = 0
|
|
874
|
+
|
|
875
|
+
if self.trace_aggregator.get("level", "transition") == "transition":
|
|
876
|
+
for rollout_id, sample_info in finished_id_to_sample_info.items():
|
|
877
|
+
for turn_index, trace in enumerate(sample_info["trace_list"]):
|
|
878
|
+
|
|
879
|
+
reward_list.append(sample_info["reward"])
|
|
880
|
+
prompt_ids, response_ids = trace["prompt_ids"], trace["response_ids"]
|
|
881
|
+
|
|
882
|
+
# Mark samples with prompts exceeding max_prompt_length to be dropped later
|
|
883
|
+
if len(prompt_ids) > max_prompt_length:
|
|
884
|
+
prompt_ids = prompt_ids[:max_prompt_length]
|
|
885
|
+
is_drop_list.append(True)
|
|
886
|
+
else:
|
|
887
|
+
is_drop_list.append(False)
|
|
888
|
+
|
|
889
|
+
# Truncate responses that exceed max_response_length
|
|
890
|
+
if len(response_ids) > max_response_length:
|
|
891
|
+
response_ids = response_ids[:max_response_length]
|
|
892
|
+
n_trunc_sample_because_of_response += 1
|
|
893
|
+
|
|
894
|
+
# Pad prompts to the left and responses to the right
|
|
895
|
+
one_input_ids, one_input_attention_mask = get_left_padded_ids_and_attention_mask(
|
|
896
|
+
prompt_ids, max_prompt_length, self.pad_token_id
|
|
897
|
+
)
|
|
898
|
+
one_response_ids, one_response_attention_mask = get_right_padded_ids_and_attention_mask(
|
|
899
|
+
response_ids, max_response_length, self.pad_token_id
|
|
900
|
+
)
|
|
901
|
+
|
|
902
|
+
input_ids_list.append(one_input_ids)
|
|
903
|
+
input_attention_mask_list.append(one_input_attention_mask)
|
|
904
|
+
response_ids_list.append(one_response_ids)
|
|
905
|
+
response_attention_mask_list.append(one_response_attention_mask)
|
|
906
|
+
data_id_list.append(sample_info["data_id"])
|
|
907
|
+
rollout_id_list.append(rollout_id)
|
|
908
|
+
turn_index_list.append(turn_index)
|
|
909
|
+
|
|
910
|
+
# Compute image_grid_thw for this triplet using image_urls from prompt
|
|
911
|
+
if self._use_mrope:
|
|
912
|
+
image_urls = trace.get("image_urls", [])
|
|
913
|
+
image_grid_thw_list.append(self._get_image_grid_thw(image_urls))
|
|
914
|
+
|
|
915
|
+
elif self.trace_aggregator.get("level", "transition") == "trajectory":
|
|
916
|
+
assert not self._use_mrope, "M-RoPE is not supported in trajectory level yet."
|
|
917
|
+
|
|
918
|
+
response_mask_list: List[List[int]] = []
|
|
919
|
+
unmerged_count: int = 0
|
|
920
|
+
template_mismatch_count, retoken_mismatch_count, others_mismatch_count = 0, 0, 0
|
|
921
|
+
response_per_turn_list: List[int] = []
|
|
922
|
+
|
|
923
|
+
for rollout_id, sample_info in finished_id_to_sample_info.items():
|
|
924
|
+
merged_trace_idx: List[List[int]] = []
|
|
925
|
+
|
|
926
|
+
# Identify which turns can be merged based on token ids prefix matching
|
|
927
|
+
current_merged_trace_idx: List[int] = []
|
|
928
|
+
current_context: List[int] = []
|
|
929
|
+
for turn_index, trace in enumerate(sample_info["trace_list"]):
|
|
930
|
+
response_per_turn_list.append(len(trace["response_ids"]))
|
|
931
|
+
is_prefix, diagnostic = ids_startswith(
|
|
932
|
+
trace["prompt_ids"] + trace["response_ids"],
|
|
933
|
+
current_context,
|
|
934
|
+
self.tokenizer,
|
|
935
|
+
self.trace_aggregator.get("debug", False),
|
|
936
|
+
)
|
|
937
|
+
if not is_prefix and self.trace_aggregator.get("debug", False) == True:
|
|
938
|
+
template_mismatch_count += diagnostic[0]
|
|
939
|
+
retoken_mismatch_count += diagnostic[1]
|
|
940
|
+
others_mismatch_count += diagnostic[2]
|
|
941
|
+
log_mismatch_detail(
|
|
942
|
+
diagnostic,
|
|
943
|
+
trace["prompt_ids"] + trace["response_ids"],
|
|
944
|
+
current_context,
|
|
945
|
+
global_steps,
|
|
946
|
+
rollout_id,
|
|
947
|
+
turn_index,
|
|
948
|
+
self.trace_aggregator.get("mismatch_log_dir", None),
|
|
949
|
+
)
|
|
950
|
+
|
|
951
|
+
if is_prefix:
|
|
952
|
+
current_context = trace["prompt_ids"] + trace["response_ids"]
|
|
953
|
+
current_merged_trace_idx.append(turn_index)
|
|
954
|
+
else:
|
|
955
|
+
merged_trace_idx.append(current_merged_trace_idx)
|
|
956
|
+
current_merged_trace_idx = [turn_index]
|
|
957
|
+
current_context = trace["prompt_ids"] + trace["response_ids"]
|
|
958
|
+
|
|
959
|
+
if current_merged_trace_idx not in merged_trace_idx:
|
|
960
|
+
merged_trace_idx.append(current_merged_trace_idx)
|
|
961
|
+
|
|
962
|
+
if len(merged_trace_idx) > 1:
|
|
963
|
+
unmerged_count += 1

                # Merge all trace segments in merged_trace_idx into training samples
                for current_merged_trace_idx in merged_trace_idx:
                    prompt_ids = sample_info["trace_list"][current_merged_trace_idx[0]]["prompt_ids"]

                    # if the merged_trace_idx doesn't start with the beginning of the prompt_ids, we need to adjust it
                    if current_merged_trace_idx[0] > 0 and len(prompt_ids) > max_prompt_length:
                        response_ids = prompt_ids[max_prompt_length:]
                        prompt_ids = prompt_ids[:max_prompt_length]
                        response_mask = [1] * len(response_ids)
                    else:
                        response_ids = []
                        response_mask = []

                    prompt_length = len(prompt_ids)
                    response_ids += sample_info["trace_list"][current_merged_trace_idx[0]]["response_ids"]
                    response_mask += [1] * len(response_ids)
                    for turn_index in current_merged_trace_idx[1:]:
                        trace = sample_info["trace_list"][turn_index]
                        new_prompt_length = len(trace["prompt_ids"]) - len(response_ids) - prompt_length
                        response_ids += trace["prompt_ids"][-new_prompt_length:]
                        response_ids += trace["response_ids"]
                        response_mask += [0] * new_prompt_length
                        response_mask += [1] * len(trace["response_ids"])
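                    # response_mask marks model-generated tokens with 1 and the intermediate
                    # prompt tokens spliced in between merged turns (e.g. tool results or user
                    # messages) with 0.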

                    reward_list.append(sample_info["reward"])

                    # Mark samples with prompts exceeding max_prompt_length to be dropped later
                    if len(prompt_ids) > max_prompt_length:
                        prompt_ids = prompt_ids[:max_prompt_length]
                        is_drop_list.append(True)
                    else:
                        is_drop_list.append(False)

                    # Truncate responses that exceed max_response_length
                    if len(response_ids) > max_response_length:
                        response_ids = response_ids[:max_response_length]
                        response_mask = response_mask[:max_response_length]
                        n_trunc_sample_because_of_response += 1

                    # Pad prompts to the left and responses to the right
                    one_input_ids, one_input_attention_mask = get_left_padded_ids_and_attention_mask(
                        prompt_ids, max_prompt_length, self.pad_token_id
                    )
                    one_response_ids, one_response_attention_mask = get_right_padded_ids_and_attention_mask(
                        response_ids, max_response_length, self.pad_token_id
                    )
                    one_response_mask, _ = get_right_padded_ids_and_attention_mask(
                        response_mask, max_response_length, 0
                    )
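                    # Left-padding keeps prompts right-aligned against the responses that follow,
                    # while responses and their masks are right-padded to max_response_length;
                    # the mask is padded with 0 so padded positions stay inactive.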

                    input_ids_list.append(one_input_ids)
                    input_attention_mask_list.append(one_input_attention_mask)
                    response_ids_list.append(one_response_ids)
                    response_attention_mask_list.append(one_response_attention_mask)
                    response_mask_list.append(one_response_mask)
                    data_id_list.append(sample_info["data_id"])
                    rollout_id_list.append(rollout_id)
                    # turn_index_list.append(current_merged_trace_idx)
        else:
            raise ValueError(f"Unknown trace_aggregator level: {self.trace_aggregator.get('level')}")

        n_transition = len(input_ids_list)
        batch_input_ids = torch.LongTensor(input_ids_list).to(device)
        input_attention_mask = torch.LongTensor(input_attention_mask_list).to(device)
        batch_response_ids = torch.LongTensor(response_ids_list).to(device)
        response_attention_mask = torch.LongTensor(response_attention_mask_list).to(device)
        response_mask = (
            torch.LongTensor(response_mask_list).to(device) if self.trace_aggregator.get("level", "transition") == "trajectory" else None  # type: ignore
        )

        # Concatenate prompts and responses to form the full sequence
        batch_seq = torch.cat([batch_input_ids, batch_response_ids], dim=-1)
        attention_mask = torch.cat([input_attention_mask, response_attention_mask], dim=-1)

        # Compute position_ids - use mrope for Qwen2-VL, standard 2D otherwise
        if self._use_mrope:
            # For Qwen2-VL: compute 4D position_ids (batch_size, 4, seq_length)
            position_ids_list: list[torch.Tensor] = []
            for i in range(n_transition):
                pos_ids = self._compute_mrope_position_ids(
                    input_ids=batch_seq[i],
                    attention_mask=attention_mask[i],
                    image_grid_thw=image_grid_thw_list[i] if image_grid_thw_list else None,
                )  # (4, seq_length)
                position_ids_list.append(pos_ids)
            # Stack to (batch_size, 4, seq_length)
            position_ids = torch.stack(position_ids_list, dim=0)
        else:
            # Standard 2D position_ids (batch_size, seq_length)
            position_ids = torch.clamp(torch.cumsum(attention_mask, dim=-1) - 1, min=0)
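            # With left-padded prompts, cumsum(attention_mask) - 1 yields positions that stay at 0
            # over the padding and count up from the first real token,
            # e.g. attention_mask [0, 0, 1, 1, 1] -> position_ids [0, 0, 0, 1, 2].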

        is_drop_mask = torch.BoolTensor(is_drop_list).to(device)
        scores = torch.tensor(reward_list, dtype=torch.bfloat16).to(device)

        # Create token-level scores by placing the final reward at the last token position
        token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype)
        # For mrope (3D position_ids), use the first dimension (text position_ids) for eos calculation
        if self._use_mrope:
            # position_ids is (batch_size, 4, seq_length), use first dim for text positions
            text_position_ids = position_ids[:, 0, :]  # (batch_size, seq_length)
            eos_mask_idx = torch.argmax(text_position_ids * attention_mask, dim=-1)  # (bsz,)
        else:
            eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)  # (bsz,)
        # At the eos_mask_idx position of each sample, fill in the corresponding scores.
        # torch.arange(n_transition) generates [0,1,2,...,bsz-1] as indices for the batch dimension.
        token_level_scores[torch.arange(n_transition), eos_mask_idx] = scores
        # Only take the last response_length part of the sequence to get the token-level scores for the model's response part.
        token_level_scores = token_level_scores[:, -max_response_length:]
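        # The scalar rollout reward therefore lands on the last attended token of each sequence,
        # and slicing the final max_response_length columns aligns token_level_scores with the
        # right-padded response tensors.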

        # Form the final batch using TensorDict
        batch = TensorDict(
            {
                "prompts": batch_input_ids,
                "responses": batch_response_ids,
                "input_ids": batch_seq,  # here input_ids become the whole sentences
                "attention_mask": attention_mask,
                "position_ids": position_ids,
                "is_drop_mask": is_drop_mask,
                "token_level_scores": token_level_scores.contiguous(),
                **(
                    {"response_mask": response_mask}
                    if self.trace_aggregator.get("level", "transition") == "trajectory"
                    else {}
                ),
            },  # type: ignore
            batch_size=n_transition,
        )
        data_proto = DataProto(batch=batch)
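        # "response_mask" is only present for trajectory-level aggregation; is_drop_mask flags
        # samples whose prompts exceeded max_prompt_length so they can be dropped downstream.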

        data_metrics = {
            "training/reward": np.mean(list(finished_id_to_final_reward.values())),
            "training/n_rollouts": len(finished_id_to_final_reward),
            "training/n_rollouts_w_trace": len(finished_id_to_sample_info),
            "training/n_rollouts_w_reward": sample_with_reward_count,
            "training/n_truncated_triplets": n_trunc_sample_because_of_response,
            "training/n_triplets": n_transition,
            # log data, only for debug testing
            **(
                {
                    "training/n_unmerged_rollouts": unmerged_count,  # type: ignore
                    "training/n_triplets_by_turn": len(response_per_turn_list),  # type: ignore
                    "training/avg_response_length_by_turn": np.mean(response_per_turn_list),  # type: ignore
                    "training/max_response_length_by_turn": np.max(response_per_turn_list),  # type: ignore
                    "training/min_response_length_by_turn": np.min(response_per_turn_list),  # type: ignore
                }
                if self.trace_aggregator.get("level", "transition") == "trajectory"
                else {}
            ),
            **(
                {
                    "training/template_mismatch_triplets": template_mismatch_count,  # type: ignore
                    "training/retoken_mismatch_triplets": retoken_mismatch_count,  # type: ignore
                    "training/others_mismatch_triplets": others_mismatch_count,  # type: ignore
                    "training/template_mismatch_ratio": template_mismatch_count / len(response_per_turn_list),  # type: ignore
                    "training/retoken_mismatch_ratio": retoken_mismatch_count / len(response_per_turn_list),  # type: ignore
                    "training/others_mismatch_ratio": others_mismatch_count / len(response_per_turn_list),  # type: ignore
                }
                if self.trace_aggregator.get("level", "transition") == "trajectory"
                and self.trace_aggregator.get("debug", False)
                else {}
            ),
        }

        # Add non-tensor data for advantage calculation and logging
        data_proto.non_tensor_batch["data_id_list"] = np.array(data_id_list)  # type: ignore
        data_proto.non_tensor_batch["rollout_id_list"] = np.array(rollout_id_list)  # type: ignore
        if self.trace_aggregator.get("level", "transition") == "transition":
            data_proto.non_tensor_batch["turn_index_list"] = np.array(turn_index_list)  # type: ignore

        return data_proto, data_metrics

    def clear_data_and_server(self):
        """Resets the internal state of the daemon for the next run."""
        self.backend_llm_server_addresses = []
        self._completed_rollouts_v0.clear()
        self._task_id_to_original_sample.clear()
        self._total_tasks_queued = 0
        # For a true reset, the server's internal queues would also need clearing.
        # This implementation assumes that `set_up_data_and_server` is called
        # for each new run, effectively starting a fresh batch.

    def _fillna_reward(self, rollout: RolloutLegacy):
        if rollout.final_reward is None:
            if self.reward_fillna_value is not None:  # type: ignore
                final_reward = self.reward_fillna_value
            else:
                raise ValueError(f"Reward is None for rollout {rollout.rollout_id}, please check the reward function.")
        else:
            final_reward = rollout.final_reward
        return final_reward
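The trajectory-level aggregation above rests on token-id prefix matching between consecutive turns. The following minimal sketch (not part of the package; group_turns_by_prefix and merge_group are hypothetical stand-ins for ids_startswith and the trace handling, operating on plain lists) illustrates how turns are grouped and how the response mask distinguishes model tokens from interleaved prompt tokens:

from typing import Dict, List, Tuple

def group_turns_by_prefix(trace_list: List[Dict[str, List[int]]]) -> List[List[int]]:
    """Group consecutive turns whose full context extends the previous one (prefix match)."""
    groups: List[List[int]] = []
    current: List[int] = []
    context: List[int] = []
    for i, trace in enumerate(trace_list):
        full = trace["prompt_ids"] + trace["response_ids"]
        if full[: len(context)] == context:  # simplified stand-in for ids_startswith
            current.append(i)
        else:
            groups.append(current)
            current = [i]
        context = full
    if current not in groups:
        groups.append(current)
    return groups

def merge_group(trace_list, group) -> Tuple[List[int], List[int], List[int]]:
    """Flatten one group of turns into (prompt_ids, response_ids, response_mask)."""
    prompt_ids = list(trace_list[group[0]]["prompt_ids"])
    response_ids = list(trace_list[group[0]]["response_ids"])
    response_mask = [1] * len(response_ids)
    for i in group[1:]:
        trace = trace_list[i]
        # Tokens appended to the prompt since the previous turn (tool/user messages): mask 0.
        new_prompt = trace["prompt_ids"][len(prompt_ids) + len(response_ids):]
        response_ids += new_prompt + trace["response_ids"]
        response_mask += [0] * len(new_prompt) + [1] * len(trace["response_ids"])
    return prompt_ids, response_ids, response_mask

if __name__ == "__main__":
    traces = [
        {"prompt_ids": [1, 2], "response_ids": [3]},
        {"prompt_ids": [1, 2, 3, 4], "response_ids": [5, 6]},
    ]
    for g in group_turns_by_prefix(traces):
        print(merge_group(traces, g))  # ([1, 2], [3, 4, 5, 6], [1, 0, 1, 1])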