mantisdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (190)
  1. mantisdk/__init__.py +22 -0
  2. mantisdk/adapter/__init__.py +15 -0
  3. mantisdk/adapter/base.py +94 -0
  4. mantisdk/adapter/messages.py +270 -0
  5. mantisdk/adapter/triplet.py +1028 -0
  6. mantisdk/algorithm/__init__.py +39 -0
  7. mantisdk/algorithm/apo/__init__.py +5 -0
  8. mantisdk/algorithm/apo/apo.py +889 -0
  9. mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
  10. mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
  11. mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
  12. mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
  13. mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
  14. mantisdk/algorithm/base.py +162 -0
  15. mantisdk/algorithm/decorator.py +264 -0
  16. mantisdk/algorithm/fast.py +250 -0
  17. mantisdk/algorithm/gepa/__init__.py +59 -0
  18. mantisdk/algorithm/gepa/adapter.py +459 -0
  19. mantisdk/algorithm/gepa/gepa.py +364 -0
  20. mantisdk/algorithm/gepa/lib/__init__.py +18 -0
  21. mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
  22. mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
  23. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
  24. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
  25. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
  26. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
  27. mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
  28. mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
  29. mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
  30. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
  31. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
  32. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
  33. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
  34. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
  35. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
  36. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
  37. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
  38. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
  39. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
  40. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
  41. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
  42. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
  43. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
  44. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
  45. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
  46. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
  47. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
  48. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
  49. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
  50. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
  51. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
  52. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
  53. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
  54. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
  55. mantisdk/algorithm/gepa/lib/api.py +375 -0
  56. mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
  57. mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
  58. mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
  59. mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
  60. mantisdk/algorithm/gepa/lib/core/result.py +233 -0
  61. mantisdk/algorithm/gepa/lib/core/state.py +636 -0
  62. mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
  63. mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
  64. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
  65. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
  66. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
  67. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
  68. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
  69. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
  70. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
  71. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
  72. mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
  73. mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
  74. mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
  75. mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
  76. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
  77. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
  78. mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
  79. mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
  80. mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
  81. mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
  82. mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
  83. mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
  84. mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
  85. mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
  86. mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
  87. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
  88. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
  89. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
  90. mantisdk/algorithm/gepa/lib/py.typed +0 -0
  91. mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
  92. mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
  93. mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
  94. mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
  95. mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
  96. mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
  97. mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
  98. mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
  99. mantisdk/algorithm/gepa/tracing.py +105 -0
  100. mantisdk/algorithm/utils.py +177 -0
  101. mantisdk/algorithm/verl/__init__.py +5 -0
  102. mantisdk/algorithm/verl/interface.py +202 -0
  103. mantisdk/cli/__init__.py +56 -0
  104. mantisdk/cli/prometheus.py +115 -0
  105. mantisdk/cli/store.py +131 -0
  106. mantisdk/cli/vllm.py +29 -0
  107. mantisdk/client.py +408 -0
  108. mantisdk/config.py +348 -0
  109. mantisdk/emitter/__init__.py +43 -0
  110. mantisdk/emitter/annotation.py +370 -0
  111. mantisdk/emitter/exception.py +54 -0
  112. mantisdk/emitter/message.py +61 -0
  113. mantisdk/emitter/object.py +117 -0
  114. mantisdk/emitter/reward.py +320 -0
  115. mantisdk/env_var.py +156 -0
  116. mantisdk/execution/__init__.py +15 -0
  117. mantisdk/execution/base.py +64 -0
  118. mantisdk/execution/client_server.py +443 -0
  119. mantisdk/execution/events.py +69 -0
  120. mantisdk/execution/inter_process.py +16 -0
  121. mantisdk/execution/shared_memory.py +282 -0
  122. mantisdk/instrumentation/__init__.py +119 -0
  123. mantisdk/instrumentation/agentops.py +314 -0
  124. mantisdk/instrumentation/agentops_langchain.py +45 -0
  125. mantisdk/instrumentation/litellm.py +83 -0
  126. mantisdk/instrumentation/vllm.py +81 -0
  127. mantisdk/instrumentation/weave.py +500 -0
  128. mantisdk/litagent/__init__.py +11 -0
  129. mantisdk/litagent/decorator.py +536 -0
  130. mantisdk/litagent/litagent.py +252 -0
  131. mantisdk/llm_proxy.py +1890 -0
  132. mantisdk/logging.py +370 -0
  133. mantisdk/reward.py +7 -0
  134. mantisdk/runner/__init__.py +11 -0
  135. mantisdk/runner/agent.py +845 -0
  136. mantisdk/runner/base.py +182 -0
  137. mantisdk/runner/legacy.py +309 -0
  138. mantisdk/semconv.py +170 -0
  139. mantisdk/server.py +401 -0
  140. mantisdk/store/__init__.py +23 -0
  141. mantisdk/store/base.py +897 -0
  142. mantisdk/store/client_server.py +2092 -0
  143. mantisdk/store/collection/__init__.py +30 -0
  144. mantisdk/store/collection/base.py +587 -0
  145. mantisdk/store/collection/memory.py +970 -0
  146. mantisdk/store/collection/mongo.py +1412 -0
  147. mantisdk/store/collection_based.py +1823 -0
  148. mantisdk/store/insight.py +648 -0
  149. mantisdk/store/listener.py +58 -0
  150. mantisdk/store/memory.py +396 -0
  151. mantisdk/store/mongo.py +165 -0
  152. mantisdk/store/sqlite.py +3 -0
  153. mantisdk/store/threading.py +357 -0
  154. mantisdk/store/utils.py +142 -0
  155. mantisdk/tracer/__init__.py +16 -0
  156. mantisdk/tracer/agentops.py +242 -0
  157. mantisdk/tracer/base.py +287 -0
  158. mantisdk/tracer/dummy.py +106 -0
  159. mantisdk/tracer/otel.py +555 -0
  160. mantisdk/tracer/weave.py +677 -0
  161. mantisdk/trainer/__init__.py +6 -0
  162. mantisdk/trainer/init_utils.py +263 -0
  163. mantisdk/trainer/legacy.py +367 -0
  164. mantisdk/trainer/registry.py +12 -0
  165. mantisdk/trainer/trainer.py +618 -0
  166. mantisdk/types/__init__.py +6 -0
  167. mantisdk/types/core.py +553 -0
  168. mantisdk/types/resources.py +204 -0
  169. mantisdk/types/tracer.py +515 -0
  170. mantisdk/types/tracing.py +218 -0
  171. mantisdk/utils/__init__.py +1 -0
  172. mantisdk/utils/id.py +18 -0
  173. mantisdk/utils/metrics.py +1025 -0
  174. mantisdk/utils/otel.py +578 -0
  175. mantisdk/utils/otlp.py +536 -0
  176. mantisdk/utils/server_launcher.py +1045 -0
  177. mantisdk/utils/system_snapshot.py +81 -0
  178. mantisdk/verl/__init__.py +8 -0
  179. mantisdk/verl/__main__.py +6 -0
  180. mantisdk/verl/async_server.py +46 -0
  181. mantisdk/verl/config.yaml +27 -0
  182. mantisdk/verl/daemon.py +1154 -0
  183. mantisdk/verl/dataset.py +44 -0
  184. mantisdk/verl/entrypoint.py +248 -0
  185. mantisdk/verl/trainer.py +549 -0
  186. mantisdk-0.1.0.dist-info/METADATA +119 -0
  187. mantisdk-0.1.0.dist-info/RECORD +190 -0
  188. mantisdk-0.1.0.dist-info/WHEEL +4 -0
  189. mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
  190. mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
mantisdk/verl/daemon.py
@@ -0,0 +1,1154 @@
+# Copyright (c) Microsoft. All rights reserved.
+
+import asyncio
+import json
+import os
+import random
+import socket
+import threading
+import time
+import uuid
+from collections import defaultdict
+from collections.abc import Mapping
+from typing import Any, Dict, List, Literal, Optional, Tuple, cast
+
+import numpy as np
+import requests
+import torch
+from flask import Flask, Response, abort, request
+from tensordict import TensorDict
+from verl import DataProto
+
+from mantisdk import LLM, MantisdkServer, NamedResources, RolloutLegacy
+from mantisdk.adapter.triplet import TracerTraceToTriplet, TraceToTripletBase
+from mantisdk.llm_proxy import LLMProxy, ModelConfig
+from mantisdk.store.base import LightningStore
+from mantisdk.types import EnqueueRolloutRequest, Rollout, RolloutConfig, Task
+
+__all__ = [
+    "AgentModeDaemon",
+    "get_left_padded_ids_and_attention_mask",
+    "get_right_padded_ids_and_attention_mask",
+]
+
+
+def ids_startswith(
+    full_ids: List[int], prefix_ids: List[int], tokenizer: Any, debug: bool = False
+) -> Tuple[bool, Tuple[bool, bool, bool]]:
+    is_prefix: bool
+    template_mismatch, retoken_mismatch, others_mismatch = False, False, False
+    if full_ids[: len(prefix_ids)] == prefix_ids:
+        is_prefix = True
+        return True, (template_mismatch, retoken_mismatch, others_mismatch)
+    else:
+        is_prefix = False
+
+    if not debug:
+        return is_prefix, (template_mismatch, retoken_mismatch, others_mismatch)
+
+    def _special_token_sequence(ids: List[int]) -> List[int]:
+        return [id for id in ids if id in tokenizer.all_special_ids]
+
+    def _none_special_token_sequence(ids: List[int]) -> List[int]:
+        return [id for id in ids if id not in tokenizer.all_special_ids]
+
+    # First, handle special tokens
+    full_special_ids = _special_token_sequence(full_ids)
+    prefix_special_ids = _special_token_sequence(prefix_ids)
+    if sum(1 for a, b in zip(full_special_ids, prefix_special_ids) if a != b) > 0:
+        template_mismatch = True
+
+    # Next, handle string content
+    full_content_ids = _none_special_token_sequence(full_ids)
+    prefix_content_ids = _none_special_token_sequence(prefix_ids)
+    full_string = tokenizer.decode(full_ids, skip_special_tokens=True)
+    prefix_string = tokenizer.decode(prefix_ids, skip_special_tokens=True)
+    if full_content_ids[: len(prefix_content_ids)] != prefix_content_ids and full_string.startswith(prefix_string):
+        retoken_mismatch = True
+    elif full_content_ids[: len(prefix_content_ids)] != prefix_content_ids and not full_string.startswith(
+        prefix_string
+    ):
+        others_mismatch = True
+    return is_prefix, (template_mismatch, retoken_mismatch, others_mismatch)
+
+
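(Editor's aside: a minimal sketch of how the mismatch classification behaves, using a hypothetical ToyTokenizer stub rather than a real tokenizer. "ab" is tokenized as [1, 2] in the earlier context but as the single merged id 3 later, so the id sequences diverge while the decoded text still matches, which is the retokenization case:)

    class ToyTokenizer:
        """Stand-in tokenizer: id 0 is special; 1 -> "a", 2 -> "b", 3 -> "ab"."""
        all_special_ids = [0]
        _vocab = {1: "a", 2: "b", 3: "ab"}

        def decode(self, ids, skip_special_tokens=True):
            return "".join(self._vocab[i] for i in ids if i not in self.all_special_ids)

    print(ids_startswith([0, 3], [0, 1, 2], ToyTokenizer(), debug=True))
    # -> (False, (False, True, False)): not a prefix, flagged as retoken_mismatch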
+def log_mismatch_detail(
+    diagnostic: Tuple[bool, bool, bool],
+    full_ids: List[int],
+    prefix_ids: List[int],
+    global_steps: int,
+    rollout_id: str,
+    turn_id: int,
+    log_dir: str | None = None,
+):
+    if log_dir is None:
+        return
+    os.makedirs(log_dir, exist_ok=True)
+    template_mismatch, retoken_mismatch, others_mismatch = diagnostic
+    if template_mismatch:
+        with open(os.path.join(log_dir, "template_mismatch.log"), "a+") as f:
+            print(
+                "-" * 10 + f" Global Steps: {global_steps}, Rollout ID: {rollout_id}, Turn ID: {turn_id} " + "-" * 10,
+                file=f,
+            )
+            print(full_ids, file=f)
+            print(prefix_ids, file=f)
+    if retoken_mismatch:
+        with open(os.path.join(log_dir, "retoken_mismatch.log"), "a+") as f:
+            print(
+                "-" * 10 + f" Global Steps: {global_steps}, Rollout ID: {rollout_id}, Turn ID: {turn_id} " + "-" * 10,
+                file=f,
+            )
+            print(full_ids, file=f)
+            print(prefix_ids, file=f)
+    if others_mismatch:
+        with open(os.path.join(log_dir, "others_mismatch.log"), "a+") as f:
+            print(
+                "-" * 10 + f" Global Steps: {global_steps}, Rollout ID: {rollout_id}, Turn ID: {turn_id} " + "-" * 10,
+                file=f,
+            )
+            print(full_ids, file=f)
+            print(prefix_ids, file=f)
+
+
+def get_left_padded_ids_and_attention_mask(
+    ids: List[int], max_length: int, pad_token_id: int
+) -> Tuple[List[int], List[int]]:
+    """
+    Left-pad (or truncate) a sequence of token IDs to a fixed length,
+    and build the corresponding attention mask.
+
+    Args:
+        ids: the original list of token IDs.
+        max_length: desired total length after padding/truncation.
+        pad_token_id: ID to use for padding.
+
+    Returns:
+        padded_ids: list of length == max_length.
+        attention_mask: list of the same length: 1 for non-pad tokens, 0 for pads.
+    """
+    seq_len = len(ids)
+
+    if seq_len >= max_length:
+        # too long → truncate from the left, keep the last max_length tokens
+        trimmed = ids[-max_length:]
+        attention_mask = [1] * max_length
+        return trimmed, attention_mask
+
+    # too short → pad on the left
+    pad_len = max_length - seq_len
+    padded_ids = [pad_token_id] * pad_len + ids
+    attention_mask = [0] * pad_len + [1] * seq_len
+    return padded_ids, attention_mask
+
+
+def get_right_padded_ids_and_attention_mask(
+    ids: List[int], max_length: int, pad_token_id: int
+) -> Tuple[List[int], List[int]]:
+    """
+    Right-pad (or truncate) a sequence of token IDs to a fixed length,
+    and build the corresponding attention mask.
+
+    Args:
+        ids: the original list of token IDs.
+        max_length: desired total length after padding/truncation.
+        pad_token_id: ID to use for padding.
+
+    Returns:
+        padded_ids: list of length == max_length.
+        attention_mask: list of the same length: 1 for non-pad tokens, 0 for pads.
+    """
+    seq_len = len(ids)
+
+    if seq_len >= max_length:
+        # too long → truncate to the first max_length tokens
+        trimmed = ids[:max_length]
+        attention_mask = [1] * max_length
+        return trimmed, attention_mask
+
+    # too short → pad on the right
+    pad_len = max_length - seq_len
+    padded_ids = ids + [pad_token_id] * pad_len
+    attention_mask = [1] * seq_len + [0] * pad_len
+    return padded_ids, attention_mask
+
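(Editor's aside: what the two padding helpers return on toy inputs.)

    get_left_padded_ids_and_attention_mask([7, 8, 9], max_length=5, pad_token_id=0)
    # -> ([0, 0, 7, 8, 9], [0, 0, 1, 1, 1])   pads on the left
    get_right_padded_ids_and_attention_mask([7, 8, 9], max_length=5, pad_token_id=0)
    # -> ([7, 8, 9, 0, 0], [1, 1, 1, 0, 0])   pads on the right
    get_left_padded_ids_and_attention_mask([1, 2, 3, 4], max_length=2, pad_token_id=0)
    # -> ([3, 4], [1, 1])                     over-long: keeps the last max_length tokens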
+
+def _find_available_port() -> int:
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(("", 0))
+        return s.getsockname()[1]
+
+
+def _to_native(obj: Any) -> Any:
+    """Convert data retrieved from Parquet to data usable in the AGL server."""
+    # 1) Arrays -> list (then recurse)
+    if isinstance(obj, np.ndarray):
+        return _to_native(obj.tolist())
+
+    # 2) NumPy scalar types -> Python scalars
+    if isinstance(obj, np.generic):
+        return _to_native(obj.item())
+
+    # 3) Dict-like -> dict
+    if isinstance(obj, Mapping):
+        return {_to_native(k): _to_native(v) for k, v in obj.items()}  # type: ignore
+
+    # 4) Lists/Tuples/Sets -> list
+    if isinstance(obj, (list, tuple, set)):
+        return [_to_native(x) for x in obj]  # type: ignore
+
+    # 5) Anything else: leave as-is
+    return obj
+
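(Editor's aside: _to_native recursively strips NumPy containers and scalars, since Parquet batches arrive as NumPy objects but the task queue expects plain JSON-serializable values. With numpy imported as np, as in the module:)

    row = {"question": np.str_("2+2?"), "choices": np.array([3, 4]), "score": np.float64(0.5)}
    _to_native(row)
    # -> {'question': '2+2?', 'choices': [3, 4], 'score': 0.5}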
+
+class AgentModeDaemon:
+    """
+    AgentModeDaemon using the MantisdkServer SDK.
+
+    This class manages the server lifecycle, task queueing, and results
+    retrieval, while also running a proxy server for LLM requests. It maintains
+    the original interface for compatibility with the RayPPOTrainer.
+    """
+
+    def __init__(
+        self,
+        port: Optional[int],
+        train_rollout_n: int,
+        train_information: Dict[str, Any],
+        tokenizer: Any,
+        mini_batch_size: int,
+        pad_token_id: int,
+        reward_fillna_value: float = 0.0,
+        llm_timeout_seconds: float = 1200.0,
+        mode: Literal["v0", "v1"] = "v1",
+        llm_proxy: LLMProxy | None = None,
+        store: LightningStore | None = None,
+        adapter: TraceToTripletBase | None = None,
+        processor: Any = None,
+        image_base_dir: Optional[str] = None,
+        trace_aggregator: Dict[str, Any] = {"level": "transition"},
+    ):
+        self.mode = mode
+        self.llm_timeout_seconds = llm_timeout_seconds
+
+        # Server and Task Configuration
+        if mode == "v0":
+            assert port is not None
+            self.server_port = port
+            self.server = MantisdkServer(
+                host="0.0.0.0", port=self.server_port, task_timeout_seconds=self.llm_timeout_seconds
+            )
+            self.proxy_port = _find_available_port()  # Run proxy on a different port
+        else:
+            assert store is not None
+            self.store = store
+            if llm_proxy is None:
+                self.llm_proxy = LLMProxy(
+                    port=_find_available_port(),
+                    model_list=[],
+                    store=store,
+                )
+            else:
+                # Reuse the existing LLM proxy (probably configured by the user)
+                self.llm_proxy = llm_proxy
+            if adapter is None:
+                self.adapter = TracerTraceToTriplet()
+            else:
+                # Reuse the one from the trainer
+                self.adapter = adapter
+            self._internal_loop: Optional[asyncio.AbstractEventLoop] = None
+            self._internal_loop_thread = threading.Thread(target=self._internal_loop_runner, daemon=True)
+            self._internal_loop_thread.start()
+
+        # Training and Data Configuration
+        self.train_rollout_n = train_rollout_n
+        self.train_information = train_information
+        self.mini_batch_size = mini_batch_size
+        self.pad_token_id = pad_token_id
+        self.tokenizer = tokenizer
+        self.processor = processor
+        self.reward_fillna_value = reward_fillna_value
+        self.image_base_dir = image_base_dir
+        self.trace_aggregator = trace_aggregator
+
+        # Check if the model requires multimodal position_ids (e.g., Qwen2-VL)
+        self._use_mrope = self._is_mrope_model()
+
+        # Internal State
+        self.backend_llm_server_addresses: List[str] = []
+        self._total_tasks_queued = 0
+        self._completed_rollouts_v0: Dict[str, RolloutLegacy] = {}
+        self._task_id_to_original_sample: Dict[str, Dict[str, Any]] = {}
+        self._server_thread: Optional[threading.Thread] = None
+        self._proxy_thread: Optional[threading.Thread] = None
+        self.is_train = True
+
+    def _internal_loop_runner(self):
+        """Run the daemon's private asyncio event loop on a background thread."""
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        self._internal_loop = loop
+        loop.run_forever()
+        loop.close()
+
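(Editor's aside: _internal_loop_runner pins a private asyncio loop to a daemon thread; synchronous trainer code later submits coroutines to it with asyncio.run_coroutine_threadsafe (see set_up_data_and_server below). The same pattern in a self-contained sketch; the `ready` event here is an addition that avoids submitting before the loop exists, a startup race the daemon instead handles by raising:)

    import asyncio
    import threading

    holder = {}
    ready = threading.Event()

    def runner():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        holder["loop"] = loop
        ready.set()
        loop.run_forever()

    threading.Thread(target=runner, daemon=True).start()
    ready.wait()

    async def work():
        await asyncio.sleep(0.1)
        return "done"

    # Submit from the synchronous world and block on the result.
    future = asyncio.run_coroutine_threadsafe(work(), holder["loop"])
    print(future.result(timeout=5))  # -> done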
+    # Multimodal utilities for M-RoPE position embeddings
+
+    def _is_mrope_model(self) -> bool:
+        """Check whether the processor requires M-RoPE position embeddings."""
+        if self.processor is None or not hasattr(self.processor, "image_processor"):
+            return False
+        name = self.processor.image_processor.__class__.__name__
+        return "Qwen2VLImageProcessor" in name or "Qwen3VLImageProcessor" in name
+
+    def _resolve_image_path(self, path: str) -> str:
+        """Resolve a relative image path against the configured base directory."""
+        if os.path.isabs(path):
+            return path
+        if self.image_base_dir is None:
+            raise ValueError(f"Relative path '{path}' requires 'image_base_dir' to be set.")
+        return os.path.join(self.image_base_dir, path)
+
+    def _get_image_grid_thw(self, image_urls: List[str]) -> Optional[torch.Tensor]:
+        """Compute image_grid_thw from image URLs for M-RoPE computation.
+
+        Args:
+            image_urls: List of image URLs extracted from the triplet prompt payload.
+                URLs can be http(s):// URLs, file:// URIs, or data: URIs.
+        """
+        from PIL import Image
+        from verl.utils.dataset.vision_utils import process_image  # pyright: ignore[reportUnknownVariableType]
+
+        if self.processor is None or not image_urls:
+            return None
+
+        def to_image_uri(url: str) -> str:
+            # Already a proper URI (http, https, file, data)
+            if url.startswith(("http://", "https://", "file://", "data:")):
+                return url
+            # Treat as a file path that needs resolution
+            resolved = self._resolve_image_path(url)
+            return f"file://{resolved}"
+
+        images: List[Image.Image] = [process_image({"image": to_image_uri(url)}) for url in image_urls]
+        model_inputs = self.processor(text=["dummy"], images=images, return_tensors="pt")
+        return model_inputs.get("image_grid_thw")
+
+    def _compute_mrope_position_ids(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        image_grid_thw: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Compute 4D position_ids for M-RoPE models."""
+        from typing import Callable
+
+        get_rope_index: Callable[..., torch.Tensor]
+        if "Qwen3VL" in self.processor.__class__.__name__:
+            from verl.models.transformers.qwen3_vl import get_rope_index  # pyright: ignore[reportUnknownVariableType]
+        else:
+            from verl.models.transformers.qwen2_vl import get_rope_index  # pyright: ignore[reportUnknownVariableType]
+
+        vision_pos = get_rope_index(
+            self.processor, input_ids=input_ids, image_grid_thw=image_grid_thw, attention_mask=attention_mask
+        )
+
+        valid_mask = attention_mask.bool()
+        text_pos = torch.zeros((1, len(input_ids)), dtype=torch.long, device=input_ids.device)
+        text_pos[0, valid_mask] = torch.arange(valid_mask.sum().item(), device=input_ids.device)
+
+        return torch.cat([text_pos, vision_pos], dim=0)
+
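(Editor's aside: the returned tensor has shape (4, seq_length): row 0 is the plain text position index rebuilt over the attended tokens, and the remaining three rows appear to be the temporal/height/width M-RoPE components produced by verl's get_rope_index. The batch code below stacks these per-sample tensors into the (batch_size, 4, seq_length) position_ids expected for Qwen2-VL-style models.)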
+    def _start_proxy_server_v0(self):
+        """
+        Initializes and runs a Flask-based proxy server in a separate thread.
+        This proxy load-balances requests to the actual backend LLM servers.
+        """
+        app = Flask(__name__)
+
+        num_requests = 0
+        last_request_time = 0
+
+        @app.route("/v1/<path:path>", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
+        def proxy(path: str):  # type: ignore
+            if not self.backend_llm_server_addresses:
+                abort(503, description="No backend LLM servers available.")
+
+            # Randomly choose a backend server for load balancing
+            target_server = random.choice(self.backend_llm_server_addresses)
+            target_url = f"http://{target_server}/v1/{path}"
+
+            # Copy client request headers, removing the Host header
+            headers = {key: value for key, value in request.headers if key.lower() != "host"}
+
+            # Log the request for debugging (rate-limited)
+            nonlocal num_requests, last_request_time
+            current_time = time.time()
+            num_requests += 1
+            if current_time - last_request_time > 60 or num_requests == 1 or num_requests % 100 == 0:
+                print(f"Proxying {request.method} request to {target_server}. Request data: {request.get_data()}")
+                last_request_time = current_time
+
+            try:
+                # Forward the request to the target backend
+                resp = requests.request(
+                    method=request.method,
+                    url=target_url,
+                    headers=headers,
+                    params=request.args,  # type: ignore
+                    data=request.get_data(),
+                    cookies=request.cookies,
+                    allow_redirects=False,
+                    timeout=self.llm_timeout_seconds,
+                )
+                # Filter out hop-by-hop headers before returning the response
+                excluded_headers = [
+                    "content-encoding",
+                    "content-length",
+                    "transfer-encoding",
+                    "connection",
+                    "keep-alive",
+                    "proxy-authenticate",
+                    "proxy-authorization",
+                    "te",
+                    "trailers",
+                    "upgrade",
+                ]
+                response_headers = [
+                    (name, value) for name, value in resp.raw.headers.items() if name.lower() not in excluded_headers
+                ]
+                if resp.status_code == 200:
+                    # NOTE: from Zhiyuan's code.
+                    # https://github.com/hzy46/verl_agent_mode/blob/2db65ea9858f645a914120357412a7540f8bd82d/verl/trainer/ppo/ray_trainer.py#L692-L711
+                    # request_json = json.loads(request.get_data().decode("utf-8"))
+                    response_json = json.loads(resp.content.decode("utf-8"))
+                    # response_message = ChatCompletion(**response_json).choices[0].message.model_dump(exclude_unset=True, exclude_none=True)
+                    # tool_schemas = request_json.get("tools", None)
+                    # prompt_ids = self.tokenizer.apply_chat_template(request_json["messages"], tools=tool_schemas, add_generation_prompt=True, tokenize=True)
+                    # full_ids = self.tokenizer.apply_chat_template(request_json["messages"] + [response_message], tools=tool_schemas, add_generation_prompt=False, tokenize=True)
+                    # TBD: response_ids sometimes ends with "<eos_id>\n"; shall we keep the extra "\n"?
+                    # It sometimes differs from the hacky method at the end, but this should align with ToolCompletionCallback.
+                    # response_ids = full_ids[len(prompt_ids):]
+
+                    # NOTE (yuge): They are different. Don't know why.
+                    # assert response_json['prompt_token_ids'] == prompt_ids
+                    # patched_response_ids = response_json['response_token_ids'][0]
+                    # assert patched_response_ids == response_ids[:len(patched_response_ids)], f"{patched_response_ids} != {response_ids[:len(patched_response_ids)]}"
+                    # response_json['prompt_token_ids'] = prompt_ids
+                    # response_json['response_token_ids'] = [response_ids]
+                    replaced_return_content = json.dumps(response_json).encode("utf-8")
+                    return Response(replaced_return_content, status=resp.status_code, headers=response_headers)
+                return Response(resp.content, resp.status_code, response_headers)
+            except requests.exceptions.RequestException as e:
+                abort(500, description=f"Error proxying request: {e}")
+
+        def run_app():
+            app.run(host="0.0.0.0", port=self.proxy_port, threaded=True, debug=False)
+
+        self._proxy_thread = threading.Thread(target=run_app, daemon=True)
+        self._proxy_thread.start()
+        print(f"Proxy server running on port {self.proxy_port}")
+
+    async def _update_proxy_server_v1(self):
+        model_name = self.train_information.get("model")
+        if not model_name:
+            raise ValueError("Model name is not set.")
+        self.llm_proxy.update_model_list(
+            [
+                ModelConfig(
+                    {
+                        "model_name": model_name,
+                        "litellm_params": {
+                            "model": "hosted_vllm/" + model_name,
+                            "api_base": f"http://{address}/v1/",
+                        },
+                    }
+                )
+                for address in self.backend_llm_server_addresses
+            ],
+        )
+
+        await self.llm_proxy.restart()
+
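(Editor's aside: each backend vLLM address becomes one LiteLLM-style route behind the same public model name, so the proxy can load-balance across them. For a model "qwen2-7b" served at 10.0.0.5:8000, the generated entry would look like the following; the names are illustrative only:)

    {
        "model_name": "qwen2-7b",
        "litellm_params": {
            "model": "hosted_vllm/qwen2-7b",
            "api_base": "http://10.0.0.5:8000/v1/",
        },
    }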
+    def start(self):
+        """Starts the main MantisdkServer and the proxy server."""
+
+        if self.mode == "v0":
+
+            def run_server():
+                """Run the MantisdkServer in a separate thread."""
+                asyncio.run(self.server.run_forever())
+
+            self._server_thread = threading.Thread(target=run_server, daemon=True)
+            self._server_thread.start()
+
+            # Wait for the server's internal startup event to be set.
+            print("Waiting for MantisdkServer to start...")
+            is_ready = self.server.startup_event.wait(timeout=20.0)  # Wait up to 20s
+            if not is_ready:
+                raise RuntimeError("MantisdkServer failed to start within the timeout period.")
+
+            print(f"MantisdkServer control plane running on port {self.server_port}")
+
+            self._start_proxy_server_v0()
+        else:
+            # The Agent Lightning server is no longer needed in v1;
+            # the proxy server is started in _async_set_up instead.
+            pass
+
+    async def _async_set_up(self, data: Dict[str, Any], server_addresses: List[str], is_train: bool = True):
+        """Async helper to set up data and resources on the server."""
+        self.clear_data_and_server()
+        if server_addresses != self.backend_llm_server_addresses:
+            self.backend_llm_server_addresses = server_addresses
+            if self.mode == "v1" and not self.llm_proxy.is_running():
+                await self._update_proxy_server_v1()
+        self.is_train = is_train
+
+        # 1. Update resources on the server for clients to use
+        if self.mode == "v0":
+            llm_resource = LLM(
+                endpoint=f"http://127.0.0.1:{self.proxy_port}/v1",
+                model=self.train_information.get("model", "default-model"),
+                sampling_parameters={
+                    "temperature": self.train_information.get("temperature", 0.7 if is_train else 0.0)
+                },
+            )
+        else:
+            llm_resource = self.llm_proxy.as_resource(
+                sampling_parameters={
+                    "temperature": self.train_information.get("temperature", 0.7 if is_train else 0.0)
+                },
+            )
+
+        resources: NamedResources = {"main_llm": llm_resource}
+
+        if self.mode == "v0":
+            resources_id = await self.server.update_resources(resources)
+        else:
+            resources_update = await self.store.add_resources(resources)
+            resources_id = resources_update.resources_id
+
+        # 2. Queue tasks for agents to process
+        keys = list(data.keys())
+        num_samples = len(data[keys[0]])
+        rollouts_per_sample = self.train_rollout_n if is_train else 1
+
+        enqueue_rollout_requests: List[EnqueueRolloutRequest] = []
+        data_id_to_original_sample: Dict[str, Dict[str, Any]] = {}
+
+        for i in range(num_samples):
+            data_id = str(uuid.uuid4())
+            original_sample = {key: data[key][i] for key in keys}
+            original_sample["data_id"] = data_id
+            data_id_to_original_sample[data_id] = original_sample
+
+            # For training, each sample is rolled out multiple times.
+            # The data ID is different from the rollout ID, as one data item can have multiple rollouts.
+            for _ in range(rollouts_per_sample):
+                task_metadata = {"data_id": data_id, "is_train": is_train}
+                if self.mode == "v0":
+                    # Queue immediately
+                    rollout_id = await self.server.queue_task(
+                        sample=_to_native(original_sample),
+                        mode="train" if is_train else "val",
+                        resources_id=resources_id,
+                        metadata=task_metadata,
+                    )
+
+                    # Store original sample data to reconstruct batch information later
+                    self._task_id_to_original_sample[rollout_id] = original_sample
+                    self._total_tasks_queued += 1
+                else:
+                    # Collect tasks here and enqueue them in a single batch later
+                    enqueue_rollout_requests.append(
+                        EnqueueRolloutRequest(
+                            input=_to_native(original_sample),
+                            mode="train" if is_train else "val",
+                            resources_id=resources_id,
+                            config=RolloutConfig(
+                                unresponsive_seconds=self.llm_timeout_seconds,
+                                timeout_seconds=self.llm_timeout_seconds,
+                            ),
+                            metadata=task_metadata,
+                        )
+                    )
+
+        if self.mode == "v1":
+            # Enqueue all the tasks in a single batch
+            rollouts = await self.store.enqueue_many_rollouts(enqueue_rollout_requests)
+            self._task_id_to_original_sample.update(
+                {
+                    # Recover the original data and store it for later use.
+                    rollout.rollout_id: data_id_to_original_sample[cast(Dict[str, Any], rollout.metadata)["data_id"]]
+                    for rollout in rollouts
+                }
+            )
+            self._total_tasks_queued += len(rollouts)
+
+    def set_up_data_and_server(self, data: Dict[str, Any], server_addresses: List[str], is_train: bool = True):
+        """Synchronous wrapper for setting up data and server resources."""
+        coro = self._async_set_up(data, server_addresses, is_train)
+
+        if self.mode == "v0":
+            if not self.server.loop or not self.server.startup_event.is_set():
+                raise RuntimeError("Server is not running or ready.")
+            future = asyncio.run_coroutine_threadsafe(coro, self.server.loop)
+        else:
+            if self._internal_loop is None:
+                raise RuntimeError("Internal loop is not running.")
+            future = asyncio.run_coroutine_threadsafe(coro, self._internal_loop)
+        try:
+            future.result(timeout=300)  # Wait for completion with a timeout
+        except Exception as e:
+            print(f"Failed to set up data on server: {e}")
+            raise
+
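(Editor's aside: the `data` argument is columnar, Parquet-style: every key maps to an equal-length list with one entry per sample, and sample i is reassembled as {key: data[key][i]}. An illustrative shape, with hypothetical field names:)

    data = {
        "prompt": ["What is 2+2?", "Name a prime."],
        "data_source": ["math", "math"],
    }
    # num_samples == 2; with train_rollout_n == 4, training enqueues 8 rollouts,
    # each tagged with the data_id of the sample it came from.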
+    def _validate_data(self, rollout: RolloutLegacy):
+        if rollout.final_reward is None:
+            print(
+                f"Warning: Reward is None for rollout {rollout.rollout_id}; it will be auto-set to {self.reward_fillna_value}."
+            )
+        if rollout.triplets is None:
+            print(f"Warning: Triplet is None for rollout {rollout.rollout_id}.")
+        elif len(rollout.triplets) == 0:
+            print(f"Warning: Length of triplets is 0 for rollout {rollout.rollout_id}.")
+        elif any(not r.response.get("token_ids", []) for r in rollout.triplets):
+            print(f"Warning: Rollout {rollout.rollout_id} contains an empty response: {rollout.triplets}")
+        elif any(not r.prompt.get("token_ids", []) for r in rollout.triplets):
+            print(f"Warning: Rollout {rollout.rollout_id} contains an empty prompt: {rollout.triplets}")
+
+    async def _validate_data_v1(self, rollout: Rollout) -> RolloutLegacy:
+        """Convert a Rollout to a RolloutLegacy and validate it.
+
+        1. Task: constructed from the Rollout.
+        2. Triplets: obtained by querying spans and feeding them into the adapter.
+        3. Final reward: extracted from the last triplet's reward, searching backwards if not found.
+        """
+        # Query spans for this rollout (latest attempt)
+        spans = await self.store.query_spans(rollout.rollout_id, attempt_id="latest")
+
+        # Convert spans to triplets using the adapter
+        if not spans:
+            # No triplets found; a warning will be emitted later.
+            triplets = []
+        else:
+            triplets = self.adapter.adapt(spans)
+
+        # Extract the final reward from the triplets
+        final_reward: Optional[float] = None
+        if triplets:
+            # Search backwards through the triplets for the first non-None reward
+            for triplet in reversed(triplets):
+                if triplet.reward is not None:
+                    final_reward = triplet.reward
+                    break
+
+        # Construct the Task object from the Rollout
+        task = Task(
+            rollout_id=rollout.rollout_id,
+            input=rollout.input,
+            mode=rollout.mode,
+            resources_id=rollout.resources_id,
+            metadata=rollout.metadata or {},
+        )
+
+        # Create the legacy Rollout object (trace and logs are intentionally omitted)
+        result_rollout = RolloutLegacy(
+            rollout_id=rollout.rollout_id,
+            task=task,
+            final_reward=final_reward,
+            triplets=triplets,
+            metadata=rollout.metadata or {},
+        )
+
+        # Run the same validation as v0
+        self._validate_data(result_rollout)
+
+        return result_rollout
+
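(Editor's aside: the final-reward rule is "last non-None per-turn reward wins", searching backwards. An equivalent one-liner on a toy reward sequence:)

    rewards = [None, 0.0, 1.0, None]  # per-turn rewards from the adapter's triplets
    final_reward = next((r for r in reversed(rewards) if r is not None), None)
    # -> 1.0: the trailing None is skipped; note that 0.0 counts as a valid reward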
+    async def _async_run_until_finished(self, verbose: bool = True):
+        """Async helper to wait for all tasks to complete."""
+        while len(self._completed_rollouts_v0) < self._total_tasks_queued:
+            if self.mode == "v0":
+                completed_batch = await self.server.retrieve_completed_rollouts()
+            else:
+                completed_batch = await self.store.wait_for_rollouts(
+                    rollout_ids=list(self._task_id_to_original_sample.keys()), timeout=0
+                )
+            for rollout in completed_batch:
+                if rollout.rollout_id in self._completed_rollouts_v0:
+                    # Already processed, skip
+                    continue
+                if isinstance(rollout, Rollout):
+                    rollout = await self._validate_data_v1(rollout)
+                else:
+                    self._validate_data(rollout)
+                if rollout.rollout_id not in self._task_id_to_original_sample:
+                    print(f"Warning: Received unknown rollout ID {rollout.rollout_id}, skipping.")
+                else:
+                    self._completed_rollouts_v0[rollout.rollout_id] = rollout
+            if verbose:
+                print(f"Completed {len(self._completed_rollouts_v0)}/{self._total_tasks_queued} tasks...")
+            await asyncio.sleep(5)
+
+        print("All tasks finished.")
+
+    def run_until_all_finished(self, verbose: bool = True):
+        """Synchronously waits for all queued tasks to be completed and reported."""
+        if self._total_tasks_queued == 0:
+            print("Warning: No tasks were queued.")
+            return
+
+        if self.mode == "v0":
+            if not self.server.loop or not self.server.startup_event.is_set():
+                raise RuntimeError("Server is not running or ready.")
+            loop = self.server.loop
+        else:
+            loop = self._internal_loop
+            assert loop is not None
+
+        coro = self._async_run_until_finished(verbose)
+        future = asyncio.run_coroutine_threadsafe(coro, loop)
+        try:
+            future.result()  # Wait indefinitely for all tasks to complete
+        except Exception as e:
+            print(f"Error while waiting for tasks to finish: {e}")
+            raise
+
+    def get_test_metrics(self):
+        """Calculates and returns metrics for a validation run."""
+        assert not self.is_train, "This method should only be called during validation."
+        assert len(self._completed_rollouts_v0) == self._total_tasks_queued
+
+        sample_stat_list: List[Dict[str, Any]] = []
+        sample_stat_list_by_source: Dict[str, List[Dict[str, Any]]] = defaultdict(
+            list
+        )  # FIXME: Evaluate whether grouping stats by source is actually needed.
+
+        for rollout_id, rollout in self._completed_rollouts_v0.items():
+            final_reward_raw: Optional[float] = rollout.final_reward
+            final_reward = self._fillna_reward(rollout)
+            if not rollout.triplets:
+                print(f"Warning: No triplets found for test rollout {rollout.rollout_id}.")
+                sample_stat_list.append({"reward": final_reward, "has_reward": final_reward_raw is not None})
+                continue
+            response_length_list = [len(triplet.response.get("token_ids", [])) for triplet in rollout.triplets]
+
+            if "data_source" in self._task_id_to_original_sample[rollout_id]:
+                # When a test sample includes a 'data_source' field, record per-source statistics for test results.
+                # TODO: This is a flawed design. We should have a better way to handle this.
+                data_source = self._task_id_to_original_sample[rollout_id]["data_source"]
+                sample_stat_list_by_source[data_source].append(
+                    {
+                        "sum_response_length": np.sum(response_length_list),
+                        "mean_response_length": np.mean(response_length_list) if response_length_list else 0,
+                        "turn_count": len(rollout.triplets),
+                        "reward": final_reward,
+                        "has_reward": final_reward_raw is not None,
+                    }
+                )
+            sample_stat_list.append(
+                {
+                    "sum_response_length": np.sum(response_length_list),
+                    "mean_response_length": np.mean(response_length_list) if response_length_list else 0,
+                    "turn_count": len(rollout.triplets),
+                    "reward": final_reward,
+                    "has_reward": final_reward_raw is not None,
+                }
+            )
+        metric_dict: Dict[str, Any] = {}
+
+        stats_w_trace = [stat for stat in sample_stat_list if "sum_response_length" in stat]
+        stats_w_trace_by_source = {
+            data_source: [stat for stat in sample_stats if "sum_response_length" in stat]
+            for data_source, sample_stats in sample_stat_list_by_source.items()
+        }
+        for data_source, sample_stats in sample_stat_list_by_source.items():
+            metric_dict.update(
+                {
+                    f"val/{data_source}/n_rollouts": len(sample_stats),
+                    f"val/{data_source}/n_rollouts_w_trace": len(stats_w_trace_by_source[data_source]),
+                    f"val/{data_source}/n_rollouts_w_reward": len(
+                        [stat for stat in sample_stats if stat["has_reward"]]
+                    ),
+                    f"val/{data_source}/reward": np.mean(
+                        [stat["reward"] for stat in sample_stats]
+                    ),  # each rollout must have a reward (fillna if missing)
+                    f"val/{data_source}/mean_response_length": np.mean(
+                        [stat["mean_response_length"] for stat in stats_w_trace_by_source[data_source]]
+                    ),
+                    f"val/{data_source}/sum_response_length": np.mean(
+                        [stat["sum_response_length"] for stat in stats_w_trace_by_source[data_source]]
+                    ),
+                    f"val/{data_source}/turn_count": np.mean(
+                        [stat["turn_count"] for stat in stats_w_trace_by_source[data_source]]
+                    ),
+                }
+            )
+        metric_dict.update(
+            {
+                "val/n_rollouts": len(sample_stat_list),
+                "val/n_rollouts_w_trace": len(stats_w_trace),
+                "val/n_rollouts_w_reward": len([stat for stat in sample_stat_list if stat["has_reward"]]),
+                "val/reward": np.mean(
+                    [stat["reward"] for stat in sample_stat_list]
+                ),  # each rollout must have a reward (fillna if missing)
+                "val/mean_response_length": np.mean([stat["mean_response_length"] for stat in stats_w_trace]),
+                "val/sum_response_length": np.mean([stat["sum_response_length"] for stat in stats_w_trace]),
+                "val/turn_count": np.mean([stat["turn_count"] for stat in stats_w_trace]),
+            }
+        )
+        return metric_dict
+
+    def get_train_data_batch(
+        self, max_prompt_length: int, max_response_length: int, device: torch.device, global_steps: int
+    ):
+        """
+        Processes completed rollouts to generate a training data batch.
+
+        This function reconstructs the logic from the original AgentModeDaemon,
+        using data retrieved from the new server architecture. It handles padding,
+        truncation, and tensor creation for the PPO training loop.
+        """
+        assert self.is_train, "This method should only be called during training."
+        assert len(self._completed_rollouts_v0) == self._total_tasks_queued
+
+        # 1. Reconstruct the `finished_id_to_sample_info` structure from completed rollouts
+        finished_id_to_sample_info: Dict[str, Dict[str, Any]] = {}
+        finished_id_to_final_reward: Dict[str, float] = {}
+        sample_with_reward_count = 0
+        for rollout_id, rollout in self._completed_rollouts_v0.items():
+            original_sample = self._task_id_to_original_sample[rollout_id]
+            sample_with_reward_count += int(rollout.final_reward is not None)
+            final_reward = self._fillna_reward(rollout)
+
+            if not rollout.triplets:
+                finished_id_to_final_reward[rollout_id] = final_reward
+                print(f"Warning: No triplets found for training rollout {rollout.rollout_id}, skipping.")
+                continue
+
+            # The client should report triplets that contain prompt_ids and response_ids.
+            # Example triplet.prompt: {"token_ids": [...], "image_urls": [...]}
+            # Example triplet.response: {"token_ids": [...]}
+            trace_list = [
+                {
+                    "prompt_ids": t.prompt.get("token_ids", []),
+                    "response_ids": t.response.get("token_ids", []),
+                    "image_urls": t.prompt.get("image_urls", []),
+                }
+                for t in rollout.triplets
+            ]
+            info = {
+                "reward": final_reward,
+                "trace_list": trace_list,
+                "data_id": original_sample["data_id"],
+            }
+            finished_id_to_sample_info[rollout_id] = info
+            finished_id_to_final_reward[rollout_id] = final_reward
+
+        # --- Data processing and tensor creation logic ---
+        # Get all the reported data.
+        # prompt_ids are left-padded; response_ids are right-padded.
+        # They are concatenated in the middle.
+        # Discard handling:
+        #   - Samples exceeding max_prompt_length are marked for discard, but not
+        #     discarded here. They are only truncated and marked, to be discarded later.
+        #     This is for the correctness of the advantage calculation.
+        #   - The discard for the PPO mini-batch should also be handled this way.
+        input_ids_list: List[List[int]] = []
+        input_attention_mask_list: List[List[int]] = []
+        response_ids_list: List[List[int]] = []
+        response_attention_mask_list: List[List[int]] = []
+        reward_list: List[float] = []
+        data_id_list: List[str] = []
+        rollout_id_list: List[str] = []
+        turn_index_list: List[int] = []
+        is_drop_list: List[bool] = []
+        image_grid_thw_list: List[Optional[torch.Tensor]] = []  # For Qwen2-VL M-RoPE
+        n_trunc_sample_because_of_response = 0
+
+        if self.trace_aggregator.get("level", "transition") == "transition":
+            for rollout_id, sample_info in finished_id_to_sample_info.items():
+                for turn_index, trace in enumerate(sample_info["trace_list"]):
+                    reward_list.append(sample_info["reward"])
+                    prompt_ids, response_ids = trace["prompt_ids"], trace["response_ids"]
+
+                    # Mark samples with prompts exceeding max_prompt_length to be dropped later
+                    if len(prompt_ids) > max_prompt_length:
+                        prompt_ids = prompt_ids[:max_prompt_length]
+                        is_drop_list.append(True)
+                    else:
+                        is_drop_list.append(False)
+
+                    # Truncate responses that exceed max_response_length
+                    if len(response_ids) > max_response_length:
+                        response_ids = response_ids[:max_response_length]
+                        n_trunc_sample_because_of_response += 1
+
+                    # Pad prompts to the left and responses to the right
+                    one_input_ids, one_input_attention_mask = get_left_padded_ids_and_attention_mask(
+                        prompt_ids, max_prompt_length, self.pad_token_id
+                    )
+                    one_response_ids, one_response_attention_mask = get_right_padded_ids_and_attention_mask(
+                        response_ids, max_response_length, self.pad_token_id
+                    )
+
+                    input_ids_list.append(one_input_ids)
+                    input_attention_mask_list.append(one_input_attention_mask)
+                    response_ids_list.append(one_response_ids)
+                    response_attention_mask_list.append(one_response_attention_mask)
+                    data_id_list.append(sample_info["data_id"])
+                    rollout_id_list.append(rollout_id)
+                    turn_index_list.append(turn_index)
+
+                    # Compute image_grid_thw for this triplet using image_urls from the prompt
+                    if self._use_mrope:
+                        image_urls = trace.get("image_urls", [])
+                        image_grid_thw_list.append(self._get_image_grid_thw(image_urls))
+
+        elif self.trace_aggregator.get("level", "transition") == "trajectory":
+            assert not self._use_mrope, "M-RoPE is not supported at trajectory level yet."
+
+            response_mask_list: List[List[int]] = []
+            unmerged_count: int = 0
+            template_mismatch_count, retoken_mismatch_count, others_mismatch_count = 0, 0, 0
+            response_per_turn_list: List[int] = []
+
+            for rollout_id, sample_info in finished_id_to_sample_info.items():
+                merged_trace_idx: List[List[int]] = []
+
+                # Identify which turns can be merged, based on token-id prefix matching
+                current_merged_trace_idx: List[int] = []
+                current_context: List[int] = []
+                for turn_index, trace in enumerate(sample_info["trace_list"]):
+                    response_per_turn_list.append(len(trace["response_ids"]))
+                    is_prefix, diagnostic = ids_startswith(
+                        trace["prompt_ids"] + trace["response_ids"],
+                        current_context,
+                        self.tokenizer,
+                        self.trace_aggregator.get("debug", False),
+                    )
+                    if not is_prefix and self.trace_aggregator.get("debug", False):
+                        template_mismatch_count += diagnostic[0]
+                        retoken_mismatch_count += diagnostic[1]
+                        others_mismatch_count += diagnostic[2]
+                        log_mismatch_detail(
+                            diagnostic,
+                            trace["prompt_ids"] + trace["response_ids"],
+                            current_context,
+                            global_steps,
+                            rollout_id,
+                            turn_index,
+                            self.trace_aggregator.get("mismatch_log_dir", None),
+                        )
+
+                    if is_prefix:
+                        current_context = trace["prompt_ids"] + trace["response_ids"]
+                        current_merged_trace_idx.append(turn_index)
+                    else:
+                        merged_trace_idx.append(current_merged_trace_idx)
+                        current_merged_trace_idx = [turn_index]
+                        current_context = trace["prompt_ids"] + trace["response_ids"]
+
+                if current_merged_trace_idx not in merged_trace_idx:
+                    merged_trace_idx.append(current_merged_trace_idx)
+
+                if len(merged_trace_idx) > 1:
+                    unmerged_count += 1
+
+                # Merge all trace segments in merged_trace_idx into training samples
+                for current_merged_trace_idx in merged_trace_idx:
+                    prompt_ids = sample_info["trace_list"][current_merged_trace_idx[0]]["prompt_ids"]
+
+                    # If the merged segment does not start at the beginning of the rollout and its
+                    # prompt is over-long, move the prompt overflow into the response part.
+                    if current_merged_trace_idx[0] > 0 and len(prompt_ids) > max_prompt_length:
+                        response_ids = prompt_ids[max_prompt_length:]
+                        prompt_ids = prompt_ids[:max_prompt_length]
+                        response_mask = [1] * len(response_ids)
+                    else:
+                        response_ids = []
+                        response_mask = []
+
+                    prompt_length = len(prompt_ids)
+                    first_response_ids = sample_info["trace_list"][current_merged_trace_idx[0]]["response_ids"]
+                    response_ids += first_response_ids
+                    # Editor's fix: mask only the newly appended response tokens. The released code
+                    # extended the mask by [1] * len(response_ids) here, which over-counts whenever
+                    # the prompt-overflow branch above has already filled part of response_mask.
+                    response_mask += [1] * len(first_response_ids)
+                    for turn_index in current_merged_trace_idx[1:]:
+                        trace = sample_info["trace_list"][turn_index]
+                        new_prompt_length = len(trace["prompt_ids"]) - len(response_ids) - prompt_length
+                        response_ids += trace["prompt_ids"][-new_prompt_length:]
+                        response_ids += trace["response_ids"]
+                        response_mask += [0] * new_prompt_length
+                        response_mask += [1] * len(trace["response_ids"])
+
+                    reward_list.append(sample_info["reward"])
+
+                    # Mark samples with prompts exceeding max_prompt_length to be dropped later
+                    if len(prompt_ids) > max_prompt_length:
+                        prompt_ids = prompt_ids[:max_prompt_length]
+                        is_drop_list.append(True)
+                    else:
+                        is_drop_list.append(False)
+
+                    # Truncate responses that exceed max_response_length
+                    if len(response_ids) > max_response_length:
+                        response_ids = response_ids[:max_response_length]
+                        response_mask = response_mask[:max_response_length]
+                        n_trunc_sample_because_of_response += 1
+
+                    # Pad prompts to the left and responses to the right
+                    one_input_ids, one_input_attention_mask = get_left_padded_ids_and_attention_mask(
+                        prompt_ids, max_prompt_length, self.pad_token_id
+                    )
+                    one_response_ids, one_response_attention_mask = get_right_padded_ids_and_attention_mask(
+                        response_ids, max_response_length, self.pad_token_id
+                    )
+                    one_response_mask, _ = get_right_padded_ids_and_attention_mask(
+                        response_mask, max_response_length, 0
+                    )
+
+                    input_ids_list.append(one_input_ids)
+                    input_attention_mask_list.append(one_input_attention_mask)
+                    response_ids_list.append(one_response_ids)
+                    response_attention_mask_list.append(one_response_attention_mask)
+                    response_mask_list.append(one_response_mask)
+                    data_id_list.append(sample_info["data_id"])
+                    rollout_id_list.append(rollout_id)
+                    # turn_index_list.append(current_merged_trace_idx)
+        else:
+            raise ValueError(f"Unknown trace_aggregator level: {self.trace_aggregator.get('level')}")
+
+        n_transition = len(input_ids_list)
+        batch_input_ids = torch.LongTensor(input_ids_list).to(device)
+        input_attention_mask = torch.LongTensor(input_attention_mask_list).to(device)
+        batch_response_ids = torch.LongTensor(response_ids_list).to(device)
+        response_attention_mask = torch.LongTensor(response_attention_mask_list).to(device)
+        response_mask = (
+            torch.LongTensor(response_mask_list).to(device)
+            if self.trace_aggregator.get("level", "transition") == "trajectory"
+            else None  # type: ignore
+        )
+
+        # Concatenate prompts and responses to form the full sequence
+        batch_seq = torch.cat([batch_input_ids, batch_response_ids], dim=-1)
+        attention_mask = torch.cat([input_attention_mask, response_attention_mask], dim=-1)
+
+        # Compute position_ids: M-RoPE for Qwen2-VL-style models, standard 2D otherwise
+        if self._use_mrope:
+            # For Qwen2-VL: compute position_ids of shape (batch_size, 4, seq_length)
+            position_ids_list: list[torch.Tensor] = []
+            for i in range(n_transition):
+                pos_ids = self._compute_mrope_position_ids(
+                    input_ids=batch_seq[i],
+                    attention_mask=attention_mask[i],
+                    image_grid_thw=image_grid_thw_list[i] if image_grid_thw_list else None,
+                )  # (4, seq_length)
+                position_ids_list.append(pos_ids)
+            # Stack to (batch_size, 4, seq_length)
+            position_ids = torch.stack(position_ids_list, dim=0)
+        else:
+            # Standard 2D position_ids (batch_size, seq_length)
+            position_ids = torch.clamp(torch.cumsum(attention_mask, dim=-1) - 1, min=0)
+
+        is_drop_mask = torch.BoolTensor(is_drop_list).to(device)
+        scores = torch.tensor(reward_list, dtype=torch.bfloat16).to(device)
+
+        # Create token-level scores by placing the final reward at the last token position
+        token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype)
+        # For M-RoPE (a 3-dimensional position_ids tensor), use the first row (text position_ids)
+        # for the EOS calculation
+        if self._use_mrope:
+            # position_ids is (batch_size, 4, seq_length); use the first row for text positions
+            text_position_ids = position_ids[:, 0, :]  # (batch_size, seq_length)
+            eos_mask_idx = torch.argmax(text_position_ids * attention_mask, dim=-1)  # (bsz,)
+        else:
+            eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)  # (bsz,)
+        # At the eos_mask_idx position of each sample, fill in the corresponding score.
+        # torch.arange(n_transition) generates [0, 1, 2, ..., bsz - 1] as indices for the batch dimension.
+        token_level_scores[torch.arange(n_transition), eos_mask_idx] = scores
+        # Only take the last response_length part of the sequence to get the token-level scores
+        # for the model's response part.
+        token_level_scores = token_level_scores[:, -max_response_length:]
+
+        # Form the final batch using TensorDict
+        batch = TensorDict(
+            {
+                "prompts": batch_input_ids,
+                "responses": batch_response_ids,
+                "input_ids": batch_seq,  # here input_ids becomes the whole sequence
+                "attention_mask": attention_mask,
+                "position_ids": position_ids,
+                "is_drop_mask": is_drop_mask,
+                "token_level_scores": token_level_scores.contiguous(),
+                **(
+                    {"response_mask": response_mask}
+                    if self.trace_aggregator.get("level", "transition") == "trajectory"
+                    else {}
+                ),
+            },  # type: ignore
+            batch_size=n_transition,
+        )
+        data_proto = DataProto(batch=batch)
+
+        data_metrics = {
+            "training/reward": np.mean(list(finished_id_to_final_reward.values())),
+            "training/n_rollouts": len(finished_id_to_final_reward),
+            "training/n_rollouts_w_trace": len(finished_id_to_sample_info),
+            "training/n_rollouts_w_reward": sample_with_reward_count,
+            "training/n_truncated_triplets": n_trunc_sample_because_of_response,
+            "training/n_triplets": n_transition,
+            # Logging data, only for debugging/testing
+            **(
+                {
+                    "training/n_unmerged_rollouts": unmerged_count,  # type: ignore
+                    "training/n_triplets_by_turn": len(response_per_turn_list),  # type: ignore
+                    "training/avg_response_length_by_turn": np.mean(response_per_turn_list),  # type: ignore
+                    "training/max_response_length_by_turn": np.max(response_per_turn_list),  # type: ignore
+                    "training/min_response_length_by_turn": np.min(response_per_turn_list),  # type: ignore
+                }
+                if self.trace_aggregator.get("level", "transition") == "trajectory"
+                else {}
+            ),
+            **(
+                {
+                    "training/template_mismatch_triplets": template_mismatch_count,  # type: ignore
+                    "training/retoken_mismatch_triplets": retoken_mismatch_count,  # type: ignore
+                    "training/others_mismatch_triplets": others_mismatch_count,  # type: ignore
+                    "training/template_mismatch_ratio": template_mismatch_count / len(response_per_turn_list),  # type: ignore
+                    "training/retoken_mismatch_ratio": retoken_mismatch_count / len(response_per_turn_list),  # type: ignore
+                    "training/others_mismatch_ratio": others_mismatch_count / len(response_per_turn_list),  # type: ignore
+                }
+                if self.trace_aggregator.get("level", "transition") == "trajectory"
+                and self.trace_aggregator.get("debug", False)
+                else {}
+            ),
+        }
+
+        # Add non-tensor data for advantage calculation and logging
+        data_proto.non_tensor_batch["data_id_list"] = np.array(data_id_list)  # type: ignore
+        data_proto.non_tensor_batch["rollout_id_list"] = np.array(rollout_id_list)  # type: ignore
+        if self.trace_aggregator.get("level", "transition") == "transition":
+            data_proto.non_tensor_batch["turn_index_list"] = np.array(turn_index_list)  # type: ignore
+
+        return data_proto, data_metrics
+
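(Editor's aside: to make the batch layout concrete: prompts are left-padded, responses right-padded, and the halves are concatenated, so the last attended token of each row is the final response token. A small sketch of how token_level_scores lands on that token, with toy shapes and a reward of 0.5:)

    import torch

    max_response_length = 4
    attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1, 1, 0]])  # prompt [5, 6], response [7, 8, 9]
    position_ids = torch.clamp(torch.cumsum(attention_mask, dim=-1) - 1, min=0)
    # -> [[0, 0, 0, 1, 2, 3, 4, 4]]
    eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)  # -> tensor([6])

    token_level_scores = torch.zeros_like(attention_mask, dtype=torch.float32)
    token_level_scores[torch.arange(1), eos_mask_idx] = 0.5
    token_level_scores = token_level_scores[:, -max_response_length:]
    # -> [[0.0, 0.0, 0.5, 0.0]]: the reward sits on the last real response token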
+    def clear_data_and_server(self):
+        """Resets the internal state of the daemon for the next run."""
+        self.backend_llm_server_addresses = []
+        self._completed_rollouts_v0.clear()
+        self._task_id_to_original_sample.clear()
+        self._total_tasks_queued = 0
+        # For a true reset, the server's internal queues would also need clearing.
+        # This implementation assumes that `set_up_data_and_server` is called
+        # for each new run, effectively starting a fresh batch.
+
+    def _fillna_reward(self, rollout: RolloutLegacy):
+        if rollout.final_reward is None:
+            if self.reward_fillna_value is not None:  # type: ignore
+                final_reward = self.reward_fillna_value
+            else:
+                raise ValueError(f"Reward is None for rollout {rollout.rollout_id}, please check the reward function.")
+        else:
+            final_reward = rollout.final_reward
+        return final_reward