mantisdk-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mantisdk might be problematic.

Files changed (190)
  1. mantisdk/__init__.py +22 -0
  2. mantisdk/adapter/__init__.py +15 -0
  3. mantisdk/adapter/base.py +94 -0
  4. mantisdk/adapter/messages.py +270 -0
  5. mantisdk/adapter/triplet.py +1028 -0
  6. mantisdk/algorithm/__init__.py +39 -0
  7. mantisdk/algorithm/apo/__init__.py +5 -0
  8. mantisdk/algorithm/apo/apo.py +889 -0
  9. mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
  10. mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
  11. mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
  12. mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
  13. mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
  14. mantisdk/algorithm/base.py +162 -0
  15. mantisdk/algorithm/decorator.py +264 -0
  16. mantisdk/algorithm/fast.py +250 -0
  17. mantisdk/algorithm/gepa/__init__.py +59 -0
  18. mantisdk/algorithm/gepa/adapter.py +459 -0
  19. mantisdk/algorithm/gepa/gepa.py +364 -0
  20. mantisdk/algorithm/gepa/lib/__init__.py +18 -0
  21. mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
  22. mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
  23. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
  24. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
  25. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
  26. mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
  27. mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
  28. mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
  29. mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
  30. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
  31. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
  32. mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
  33. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
  34. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
  35. mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
  36. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
  37. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
  38. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
  39. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
  40. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
  41. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
  42. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
  43. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
  44. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
  45. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
  46. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
  47. mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
  48. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
  49. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
  50. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
  51. mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
  52. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
  53. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
  54. mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
  55. mantisdk/algorithm/gepa/lib/api.py +375 -0
  56. mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
  57. mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
  58. mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
  59. mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
  60. mantisdk/algorithm/gepa/lib/core/result.py +233 -0
  61. mantisdk/algorithm/gepa/lib/core/state.py +636 -0
  62. mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
  63. mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
  64. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
  65. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
  66. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
  67. mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
  68. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
  69. mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
  70. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
  71. mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
  72. mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
  73. mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
  74. mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
  75. mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
  76. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
  77. mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
  78. mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
  79. mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
  80. mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
  81. mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
  82. mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
  83. mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
  84. mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
  85. mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
  86. mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
  87. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
  88. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
  89. mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
  90. mantisdk/algorithm/gepa/lib/py.typed +0 -0
  91. mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
  92. mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
  93. mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
  94. mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
  95. mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
  96. mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
  97. mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
  98. mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
  99. mantisdk/algorithm/gepa/tracing.py +105 -0
  100. mantisdk/algorithm/utils.py +177 -0
  101. mantisdk/algorithm/verl/__init__.py +5 -0
  102. mantisdk/algorithm/verl/interface.py +202 -0
  103. mantisdk/cli/__init__.py +56 -0
  104. mantisdk/cli/prometheus.py +115 -0
  105. mantisdk/cli/store.py +131 -0
  106. mantisdk/cli/vllm.py +29 -0
  107. mantisdk/client.py +408 -0
  108. mantisdk/config.py +348 -0
  109. mantisdk/emitter/__init__.py +43 -0
  110. mantisdk/emitter/annotation.py +370 -0
  111. mantisdk/emitter/exception.py +54 -0
  112. mantisdk/emitter/message.py +61 -0
  113. mantisdk/emitter/object.py +117 -0
  114. mantisdk/emitter/reward.py +320 -0
  115. mantisdk/env_var.py +156 -0
  116. mantisdk/execution/__init__.py +15 -0
  117. mantisdk/execution/base.py +64 -0
  118. mantisdk/execution/client_server.py +443 -0
  119. mantisdk/execution/events.py +69 -0
  120. mantisdk/execution/inter_process.py +16 -0
  121. mantisdk/execution/shared_memory.py +282 -0
  122. mantisdk/instrumentation/__init__.py +119 -0
  123. mantisdk/instrumentation/agentops.py +314 -0
  124. mantisdk/instrumentation/agentops_langchain.py +45 -0
  125. mantisdk/instrumentation/litellm.py +83 -0
  126. mantisdk/instrumentation/vllm.py +81 -0
  127. mantisdk/instrumentation/weave.py +500 -0
  128. mantisdk/litagent/__init__.py +11 -0
  129. mantisdk/litagent/decorator.py +536 -0
  130. mantisdk/litagent/litagent.py +252 -0
  131. mantisdk/llm_proxy.py +1890 -0
  132. mantisdk/logging.py +370 -0
  133. mantisdk/reward.py +7 -0
  134. mantisdk/runner/__init__.py +11 -0
  135. mantisdk/runner/agent.py +845 -0
  136. mantisdk/runner/base.py +182 -0
  137. mantisdk/runner/legacy.py +309 -0
  138. mantisdk/semconv.py +170 -0
  139. mantisdk/server.py +401 -0
  140. mantisdk/store/__init__.py +23 -0
  141. mantisdk/store/base.py +897 -0
  142. mantisdk/store/client_server.py +2092 -0
  143. mantisdk/store/collection/__init__.py +30 -0
  144. mantisdk/store/collection/base.py +587 -0
  145. mantisdk/store/collection/memory.py +970 -0
  146. mantisdk/store/collection/mongo.py +1412 -0
  147. mantisdk/store/collection_based.py +1823 -0
  148. mantisdk/store/insight.py +648 -0
  149. mantisdk/store/listener.py +58 -0
  150. mantisdk/store/memory.py +396 -0
  151. mantisdk/store/mongo.py +165 -0
  152. mantisdk/store/sqlite.py +3 -0
  153. mantisdk/store/threading.py +357 -0
  154. mantisdk/store/utils.py +142 -0
  155. mantisdk/tracer/__init__.py +16 -0
  156. mantisdk/tracer/agentops.py +242 -0
  157. mantisdk/tracer/base.py +287 -0
  158. mantisdk/tracer/dummy.py +106 -0
  159. mantisdk/tracer/otel.py +555 -0
  160. mantisdk/tracer/weave.py +677 -0
  161. mantisdk/trainer/__init__.py +6 -0
  162. mantisdk/trainer/init_utils.py +263 -0
  163. mantisdk/trainer/legacy.py +367 -0
  164. mantisdk/trainer/registry.py +12 -0
  165. mantisdk/trainer/trainer.py +618 -0
  166. mantisdk/types/__init__.py +6 -0
  167. mantisdk/types/core.py +553 -0
  168. mantisdk/types/resources.py +204 -0
  169. mantisdk/types/tracer.py +515 -0
  170. mantisdk/types/tracing.py +218 -0
  171. mantisdk/utils/__init__.py +1 -0
  172. mantisdk/utils/id.py +18 -0
  173. mantisdk/utils/metrics.py +1025 -0
  174. mantisdk/utils/otel.py +578 -0
  175. mantisdk/utils/otlp.py +536 -0
  176. mantisdk/utils/server_launcher.py +1045 -0
  177. mantisdk/utils/system_snapshot.py +81 -0
  178. mantisdk/verl/__init__.py +8 -0
  179. mantisdk/verl/__main__.py +6 -0
  180. mantisdk/verl/async_server.py +46 -0
  181. mantisdk/verl/config.yaml +27 -0
  182. mantisdk/verl/daemon.py +1154 -0
  183. mantisdk/verl/dataset.py +44 -0
  184. mantisdk/verl/entrypoint.py +248 -0
  185. mantisdk/verl/trainer.py +549 -0
  186. mantisdk-0.1.0.dist-info/METADATA +119 -0
  187. mantisdk-0.1.0.dist-info/RECORD +190 -0
  188. mantisdk-0.1.0.dist-info/WHEEL +4 -0
  189. mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
  190. mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
mantisdk/adapter/triplet.py
@@ -0,0 +1,1028 @@
1
+ # Copyright (c) Microsoft. All rights reserved.
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import logging
7
+ import re
8
+ from enum import Enum
9
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
10
+
11
+ from opentelemetry.sdk.trace import ReadableSpan
12
+ from pydantic import BaseModel
13
+
14
+ from mantisdk.emitter.reward import get_reward_value
15
+ from mantisdk.semconv import AGL_OPERATION, AGL_REWARD, LightningSpanAttributes
16
+ from mantisdk.types import Span, Triplet
17
+ from mantisdk.utils.otel import filter_and_unflatten_attributes
18
+
19
+ from .base import TraceAdapter
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ def _attributes_get_multiple(attributes: Dict[str, Any], keys: List[str]) -> Optional[str]:
25
+ """Get a string from the attributes, if present.
26
+ If there are multiple matches, the first one is returned.
27
+ """
28
+ for key in keys:
29
+ if key in attributes:
30
+ if isinstance(attributes[key], str):
31
+ return attributes[key]
32
+ else:
33
+ logger.warning(f"Attribute {key} is found but is not a string: {attributes[key]}")
34
+ return None
35
+
36
+
37
+ def _attributes_get_ids_multiple(attributes: Dict[str, Any], keys: List[str]) -> Optional[List[int]]:
38
+ """Get a list of integers from the attributes, if present.
39
+ If there are multiple matches, the first one is returned.
40
+ """
41
+ for key in keys:
42
+ if key in attributes:
43
+ if (isinstance(attributes[key], list) or isinstance(attributes[key], tuple)) and all(
44
+ isinstance(x, int) for x in attributes[key]
45
+ ):
46
+ return list(attributes[key])
47
+ else:
48
+ logger.warning(f"Attribute {key} is found but is not a list of integers: {attributes[key]}")
49
+ return None
50
+
51
+
52
+ def _attributes_unflatten_multiple(
53
+ attributes: Dict[str, Any], keys: List[str]
54
+ ) -> Union[Dict[str, Any], List[Any], None]:
55
+ """Unflatten the attributes, if present.
56
+ If there are multiple matches, the first one is returned.
57
+ """
58
+ for key in keys:
59
+ result = filter_and_unflatten_attributes(attributes, key)
60
+ if result:
61
+ return result
62
+ return None
63
+
64
+
65
+ class Transition(BaseModel):
66
+ """A single transition within a reinforcement learning trajectory.
67
+
68
+ Attributes:
69
+ state: Token identifiers describing the model input state.
70
+ action: Token identifiers representing the model output.
71
+ response_id: Identifier of the LLM response used to deduplicate spans.
72
+ agent_name: Human-readable agent name captured from the trace.
73
+ reward: Scalar reward associated with the transition, if available.
74
+ """
75
+
76
+ state: List[int]
77
+ action: List[int]
78
+ response_id: Optional[str]
79
+ # action_logprobs: List[float]
80
+ agent_name: str
81
+ reward: Optional[float]
82
+
83
+
84
+ class RewardMatchPolicy(str, Enum):
85
+ """Strategies for matching rewards to LLM call spans.
86
+
87
+ !!! note
88
+ Each reward span must expose a payload shaped like `{"type": "reward", "value": <float>|None}`
89
+ as described in `reward.py`.
90
+ """
91
+
92
+ FIRST_SIBLING = "first_sibling"
93
+ """Use the first sibling in the current trace subtree as the reward unless another LLM call match is found."""
94
+
95
+ FIRST_OCCURRENCE = "first_occurrence"
96
+ """Use the first reward encountered in chronological order after the current LLM call match."""
97
+
98
+
99
+ class TraceTree:
100
+ """Tree representation of a trace span and its descendants.
101
+
102
+ Attributes:
103
+ id: Unique identifier for the span node.
104
+ span: [`Span`][mantisdk.Span] backing this node.
105
+ children: Child nodes connected to the current span.
106
+ """
107
+
108
+ def __init__(
109
+ self,
110
+ id: str,
111
+ span: Span,
112
+ children: Optional[List["TraceTree"]] = None,
113
+ ):
114
+ self.id = id
115
+ self.span = span
116
+ self.children = children or []
117
+
118
+ @property
119
+ def start_time(self):
120
+ return self.span.start_time
121
+
122
+ @property
123
+ def end_time(self):
124
+ return self.span.end_time
125
+
126
+ def find_id(self, id: str) -> "TraceTree | None":
127
+ if self.id == id:
128
+ return self
129
+ for child in self.children:
130
+ found = child.find_id(id)
131
+ if found:
132
+ return found
133
+ return None
134
+
135
+ def add_child(self, child: "TraceTree") -> None:
136
+ self.children.append(child)
137
+
138
+ def visualize(self, filename: str, interested_span_match: str | None = None) -> None:
139
+ """Render the trace tree with Graphviz for debugging purposes.
140
+
141
+ Args:
142
+ filename: Base filename for the generated `.png` diagram.
143
+ interested_span_match: Optional regular expression used to keep only matching spans
144
+ (and their ancestors) in the output.
145
+
146
+ !!! note
147
+ The method requires the optional `graphviz` dependency to be available in the runtime
148
+ environment.
149
+ """
150
+ import graphviz
151
+
152
+ dot = graphviz.Digraph(comment="Trace Tree")
153
+
154
+ should_visit_cache: Dict[str, bool] = {}
155
+
156
+ def should_visit(node: "TraceTree") -> bool:
157
+ if node.id in should_visit_cache:
158
+ return should_visit_cache[node.id]
159
+ if interested_span_match is not None:
160
+ if re.search(interested_span_match, node.span.name):
161
+ should_visit_cache[node.id] = True
162
+ return True
163
+ else:
164
+ should_visit_cache[node.id] = False
165
+ for child in node.children:
166
+ if should_visit(child):
167
+ should_visit_cache[node.id] = True
168
+
169
+ return should_visit_cache[node.id]
170
+ else:
171
+ return True
172
+
173
+ def visit(node: "TraceTree") -> bool:
174
+ if not should_visit(node):
175
+ return False
176
+ agent_name = node.agent_name()
177
+ vis_name = node.id[-8:] + " (" + node.span.name + ")"
178
+ if agent_name is not None:
179
+ vis_name += " [" + agent_name + "]"
180
+ dot.node(node.id, vis_name) # type: ignore
181
+ for child in node.children:
182
+ if visit(child):
183
+ dot.edge(node.id, child.id) # type: ignore
184
+ return True
185
+
186
+ visit(self)
187
+ dot.render(filename, format="png", cleanup=True) # type: ignore
188
+
189
+ def names_tuple(self) -> Tuple[str, List[Any]]:
190
+ """Return the span name alongside nested child names.
191
+
192
+ Returns:
193
+ A tuple of the current span name and a list of tuples for each child containing the
194
+ child name and its descendants.
195
+ """
196
+ name = self.span.name
197
+ agent_name = self.agent_name()
198
+ if agent_name is not None:
199
+ name += " [" + agent_name + "]"
200
+ children_names: List[Tuple[str, List[Any]]] = []
201
+ for child in self.children:
202
+ child_name, child_children = child.names_tuple()
203
+ children_names.append((child_name, child_children))
204
+ return name, children_names
205
+
206
+ def traverse(self) -> List["TraceTree"]:
207
+ """Traverse the tree depth first and return every node."""
208
+ spans: List["TraceTree"] = [self]
209
+ for child in self.children:
210
+ spans.extend(child.traverse())
211
+ return spans
212
+
213
+ def to_json(self) -> dict[str, Any]:
214
+ """Convert the tree node into a JSON-serialisable structure."""
215
+ if isinstance(self.span, ReadableSpan):
216
+ span_data = json.loads(self.span.to_json())
217
+ else:
218
+ span_data = self.span.model_dump()
219
+ return {
220
+ "id": self.id,
221
+ "span": span_data,
222
+ "children": [child.to_json() for child in self.children],
223
+ }
224
+
225
+ @classmethod
226
+ def from_spans(cls, spans: List[Span]) -> "TraceTree":
227
+ """Construct a tree from a flat list of spans.
228
+
229
+ Args:
230
+ spans: Spans that collectively form a single trace segment.
231
+
232
+ Returns:
233
+ A [`TraceTree`][mantisdk.adapter.triplet.TraceTree] rooted at either the
234
+ discovered root span or a synthetic root when multiple roots are present.
235
+
236
+ Raises:
237
+ ValueError: If the span list is empty or no root span can be inferred.
238
+ """
239
+
240
+ if not spans:
241
+ raise ValueError("No spans provided to create TraceTree.")
242
+
243
+ # Process trace items in topological order
244
+ id_to_span = {span.span_id: span for span in spans}
245
+
246
+ forward_graph: dict[str, list[str]] = {}
247
+ root_ids: list[str] = []
248
+ for span in spans:
249
+ span_id = span.span_id
250
+ if span.parent_id is None:
251
+ root_ids.append(span.span_id)
252
+ else:
253
+ if span.parent_id not in forward_graph:
254
+ forward_graph[span.parent_id] = []
255
+ forward_graph[span.parent_id].append(span_id)
256
+
257
+ # Diff between span with data and forward_graph keys
258
+ # Sometimes the top-level session span is lost.
259
+ unfound_roots = set(forward_graph.keys()) - set(id_to_span.keys())
260
+ for unfound_root in unfound_roots:
261
+ root_ids.append(unfound_root)
262
+
263
+ def visit(node_id: str) -> "TraceTree":
264
+ children: list[TraceTree] = []
265
+ if node_id in forward_graph:
266
+ for child_id in forward_graph[node_id]:
267
+ children.append(visit(child_id))
268
+
269
+ if node_id not in id_to_span:
270
+ assert len(children) > 0
271
+ virtual_span = Span.from_attributes(
272
+ rollout_id=children[0].span.rollout_id,
273
+ attempt_id=children[0].span.attempt_id,
274
+ sequence_id=children[0].span.sequence_id,
275
+ trace_id=children[0].span.trace_id,
276
+ span_id=node_id,
277
+ parent_id=None,
278
+ attributes={},
279
+ start_time=min(child.start_time for child in children if child.start_time is not None),
280
+ end_time=max(child.end_time for child in children if child.end_time is not None),
281
+ )
282
+ return cls(node_id, virtual_span, children=children)
283
+ else:
284
+ return cls(
285
+ node_id,
286
+ id_to_span[node_id],
287
+ children=children,
288
+ )
289
+
290
+ # Create a virtual root span if multiple root spans are found
291
+ if len(root_ids) > 1:
292
+ root_spans = [visit(root_id) for root_id in root_ids]
293
+ virtual_root = TraceTree(
294
+ id="virtual-root",
295
+ span=Span.from_attributes(
296
+ rollout_id=root_spans[0].span.rollout_id,
297
+ attempt_id=root_spans[0].span.attempt_id,
298
+ sequence_id=root_spans[0].span.sequence_id,
299
+ trace_id=root_spans[0].span.trace_id,
300
+ span_id=None, # Generate one
301
+ parent_id=None,
302
+ name="virtual-root",
303
+ attributes={},
304
+ start_time=root_spans[0].start_time,
305
+ end_time=root_spans[-1].end_time,
306
+ ),
307
+ children=root_spans,
308
+ )
309
+ return virtual_root
310
+ elif len(root_ids) == 0:
311
+ # No root spans found
312
+ raise ValueError("No root spans found in the trace.")
313
+ else:
314
+ root_span = visit(root_ids[0])
315
+ return root_span
316
+
317
+ def agent_name(self) -> Optional[str]:
318
+ """Return the agent name associated with the span, if any.
319
+
320
+ Returns:
321
+ Agent name extracted from known attributes, otherwise `None`.
322
+ """
323
+ attributes = self.span.attributes
324
+ if attributes is None: # type: ignore
325
+ return None
326
+
327
+ # Case 1: OpenAI Agent SDK
328
+ agent_name = cast(Optional[str], attributes.get("agent.name"))
329
+ if agent_name is not None:
330
+ return agent_name
331
+
332
+ # Case 2: Agentops decorator @agent
333
+ is_agent = attributes.get("agentops.span.kind") == "agent"
334
+ if is_agent:
335
+ agent_name = cast(Optional[str], attributes.get("operation.name"))
336
+ if agent_name is not None:
337
+ return agent_name
338
+
339
+ # Case 3: Autogen team
340
+ agent_name = cast(Optional[str], attributes.get("recipient_agent_type"))
341
+ if agent_name is not None:
342
+ return agent_name
343
+
344
+ # Case 4: LangGraph
345
+ agent_name = cast(Optional[str], attributes.get("langchain.chain.type"))
346
+ if agent_name is not None:
347
+ return agent_name
348
+
349
+ # Case 5: agent-framework
350
+ agent_name = cast(Optional[str], attributes.get("executor.id"))
351
+ if agent_name is not None:
352
+ return agent_name
353
+
354
+ # Case 6: Weave
355
+ is_agent_type = attributes.get("type") == "agent"
356
+ if is_agent_type:
357
+ agent_name = cast(Optional[str], attributes.get("mantisdk.operation.input.name"))
358
+ if agent_name is not None:
359
+ return agent_name
360
+
361
+ # Case 7: Weave + LangChain
362
+ if self.span.name.startswith("langchain.Chain."):
363
+ attributes_lc_name = cast(Optional[str], attributes.get("lc_name"))
364
+ if attributes_lc_name is not None:
365
+ return attributes_lc_name
366
+
367
+ def maybe_reward_dict(self) -> dict[str, Any]:
368
+ """Return a reward payload if the span encodes one.
369
+
370
+ Returns:
371
+ Dictionary containing reward metadata, or an empty dictionary when no reward is found.
372
+ """
373
+ reward_value = get_reward_value(self.span)
374
+ if reward_value is not None:
375
+ return {"type": "reward", "value": reward_value}
376
+ else:
377
+ return {}
378
+
379
+ def is_reward_span(self) -> bool:
380
+ """Return whether the span explicitly encodes a reward.
381
+
382
+ Returns:
383
+ `True` when the span payload describes a reward, otherwise `False`.
384
+ """
385
+ maybe_reward = self.maybe_reward_dict()
386
+ if maybe_reward and maybe_reward.get("type") == "reward": # type: ignore
387
+ return True
388
+
389
+ # Mantisdk 0.3+
390
+ if (
391
+ self.span.name == AGL_OPERATION
392
+ and self.span.attributes.get(LightningSpanAttributes.OPERATION_NAME.value) == AGL_REWARD
393
+ ):
394
+ return True
395
+
396
+ return False
397
+
398
+ def find_llm_calls(
399
+ self,
400
+ *,
401
+ llm_call_match: str,
402
+ agent_match: Optional[str],
403
+ within_matching_subtree: str | None = None,
404
+ within_reward: Optional[bool] = None,
405
+ within_llm_call: Optional[bool] = None,
406
+ existing_llm_call_response_ids: Optional[set[str]] = None,
407
+ ) -> List[Tuple["TraceTree", str]]:
408
+ """Find LLM call spans matching the supplied filters.
409
+
410
+ Args:
411
+ llm_call_match: Regular expression used to match span names that qualify as LLM calls.
412
+ agent_match: Optional regular expression that must match the enclosing agent span name.
413
+ within_matching_subtree: Marker propagated through recursive calls to record matching agents.
414
+ within_reward: When `True`, suppresses LLM matches under reward spans.
415
+ within_llm_call: When `True`, prevents duplicate matches for nested LLM calls.
416
+ existing_llm_call_response_ids: Known response identifiers used to deduplicate spans.
417
+
418
+ Returns:
419
+ A list of tuples pairing the matching node with the agent subtree label that triggered the
420
+ match.
421
+ """
422
+ llm_calls: List[Tuple[TraceTree, str]] = []
423
+
424
+ is_llm_call = True
425
+ if within_matching_subtree is None or within_reward is True:
426
+ # We must be in an interesting agent subtree, and not in a reward span.
427
+ is_llm_call = False
428
+ if re.search(llm_call_match, self.span.name) is None:
429
+ # The span name does not match the LLM call match.
430
+ is_llm_call = False
431
+ if is_llm_call:
432
+ # Check the response id
433
+ response_id = _attributes_get_multiple(
434
+ self.span.attributes, ["gen_ai.response.id", "mantisdk.operation.output.id"]
435
+ )
436
+ if response_id is None and within_llm_call is True:
437
+ is_llm_call = False
438
+ if (
439
+ response_id is not None
440
+ and existing_llm_call_response_ids is not None
441
+ and response_id in existing_llm_call_response_ids
442
+ ):
443
+ is_llm_call = False
444
+
445
+ if is_llm_call:
446
+ llm_calls.append((self, within_matching_subtree)) # type: ignore
447
+ if existing_llm_call_response_ids is None:
448
+ existing_llm_call_response_ids = set()
449
+ if response_id is not None:
450
+ existing_llm_call_response_ids.add(response_id)
451
+ if within_llm_call is not None:
452
+ within_llm_call = True
453
+
454
+ agent_name = self.agent_name()
455
+ if agent_name is not None:
456
+ if agent_match is None or re.search(agent_match, agent_name):
457
+ within_matching_subtree = agent_name
458
+ else:
459
+ within_matching_subtree = None
460
+
461
+ if within_reward is not None and self.is_reward_span():
462
+ within_reward = True
463
+
464
+ for child in self.children:
465
+ llm_calls.extend(
466
+ child.find_llm_calls(
467
+ llm_call_match=llm_call_match,
468
+ agent_match=agent_match,
469
+ within_matching_subtree=within_matching_subtree,
470
+ within_reward=within_reward,
471
+ within_llm_call=within_llm_call,
472
+ existing_llm_call_response_ids=existing_llm_call_response_ids,
473
+ )
474
+ )
475
+
476
+ return llm_calls
477
+
478
+ def repair_hierarchy(self) -> None:
479
+ """Repair missing parent-child relationships introduced by mixed tracing systems.
480
+
481
+ Some agent frameworks emit spans via multiple subsystems, which can cause LLM completion
482
+ spans to float directly under the root span instead of being nested under the correct agent.
483
+ The method re-parents those spans to the closest ancestor that fully envelops the child in
484
+ time.
485
+
486
+ Without this repair, selecting LLM completion spans with an agent filter
487
+ would never return the correct spans underneath.
488
+ """
489
+ # If the current node has only one child, recursively repair its hierarchy directly.
490
+ # This special-case handling is needed because when a trace is manually ended
491
+ # (via agentops.end_trace), the AgentOps provider automatically wraps all spans
492
+ # under an extra synthetic root node (e.g., "run_one.session").
493
+ if len(self.children) == 1:
494
+ self.children[0].repair_hierarchy()
495
+ return
496
+
497
+ nodes_to_repair = list(self.children)
498
+
499
+ for repair_node in nodes_to_repair:
500
+ if len(self.children) == 1:
501
+ # If there is only one child, we don't need to repair the hierarchy.
502
+ break
503
+ # Find the closest parent span (but not the root itself)
504
+ closest_parent = None
505
+ closest_duration = float("inf")
506
+ for node in self.traverse():
507
+ if node.id == repair_node.id:
508
+ continue
509
+ if node is self:
510
+ continue
511
+ if node.start_time <= repair_node.start_time and node.end_time >= repair_node.end_time: # type: ignore
512
+ duration_delta = node.end_time - repair_node.end_time + repair_node.start_time - node.start_time # type: ignore
513
+ if duration_delta > 0 and duration_delta < closest_duration:
514
+ closest_duration = duration_delta # type: ignore
515
+ closest_parent = node
516
+
517
+ # Repair the hierarchy
518
+ if closest_parent is not None:
519
+ self.children.remove(repair_node)
520
+ closest_parent.children.append(repair_node)
521
+
522
+ def match_rewards(self, reward_match: str, llm_calls: List["TraceTree"]) -> dict[str, Optional[float]]:
523
+ """Assign rewards to previously matched LLM calls.
524
+
525
+ Args:
526
+ reward_match: Strategy identifier from
527
+ [`RewardMatchPolicy`][mantisdk.adapter.triplet.RewardMatchPolicy].
528
+ llm_calls: Trace nodes representing LLM call spans.
529
+
530
+ Returns:
531
+ Mapping from span identifier to reward value or `None` when no reward is available.
532
+ """
533
+ llm_call_ids = set([llm_call.id for llm_call in llm_calls])
534
+ rewards: dict[str, Optional[float]] = {}
535
+
536
+ if reward_match == RewardMatchPolicy.FIRST_OCCURRENCE:
537
+ time_sorted: List[TraceTree] = cast(List[TraceTree], sorted(self.traverse(), key=lambda x: x.start_time)) # type: ignore
538
+ assign_to: List[Tuple[str, int]] = [] # type: ignore
539
+ for item in time_sorted:
540
+ if item.id in llm_call_ids:
541
+ assign_to.append((item.id, item.end_time)) # type: ignore
542
+
543
+ # get reward
544
+ agentops_output = item.maybe_reward_dict()
545
+ if agentops_output and agentops_output.get("type") == "reward":
546
+ for assign_to_id, assign_to_end_time in reversed(assign_to):
547
+ # This reward happens before the end of the LLM call.
548
+ if assign_to_end_time > item.start_time: # type: ignore
549
+ continue
550
+ # Ok, we found someone to assign to
551
+ if assign_to_id in rewards:
552
+ # If the reward is already set, skip
553
+ continue
554
+ rewards[assign_to_id] = agentops_output.get("value", None)
555
+ break
556
+
557
+ elif reward_match == RewardMatchPolicy.FIRST_SIBLING:
558
+ for item in self.traverse():
559
+ assign_to: List[Tuple[str, int]] = []
560
+ for child in item.children:
561
+ if child.id in llm_call_ids:
562
+ assign_to.append((child.id, child.end_time)) # type: ignore
563
+
564
+ agentops_output = child.maybe_reward_dict()
565
+ if agentops_output and agentops_output.get("type") == "reward":
566
+ for assign_to_id, assign_to_end_time in reversed(assign_to):
567
+ if assign_to_end_time > child.start_time: # type: ignore
568
+ # This reward happens before the end of the LLM call.
569
+ continue
570
+ if assign_to_id in rewards:
571
+ continue
572
+ rewards[assign_to_id] = agentops_output.get("value", None)
573
+ break
574
+
575
+ return rewards
576
+
577
+ def extract_prompt_image_urls(self, prompt_raw_content: Any) -> List[str]:
578
+ """Extract image URLs from the span attributes, in order of appearance.
579
+
580
+ Args:
581
+ prompt_raw_content: The raw content of the prompt, which can be in one of several formats:
582
+
583
+ - List[dict]: A list of message entries, each being a dict with at least a "content" key.
584
+ - Dict[str, Any]: A dictionary, often with numeric string keys (e.g., `{"0": {...}, "1": {...}}`), where each value is a message entry.
585
+ If the dict does not have numeric keys, it is treated as a single message entry.
586
+ """
587
+ message_entries: List[Any] = []
588
+ if isinstance(prompt_raw_content, list):
589
+ message_entries = cast(List[Any], prompt_raw_content)
590
+ elif isinstance(prompt_raw_content, dict):
591
+ # Common when the attributes expand to {"0": {...}, "prompt_filter_results": ...}
592
+ numeric_keys = [
593
+ key
594
+ for key in cast(Dict[str, Any], prompt_raw_content).keys()
595
+ if isinstance(key, str) and key.isdigit() # pyright: ignore[reportUnnecessaryIsInstance]
596
+ ]
597
+ if numeric_keys:
598
+ for key in sorted(numeric_keys, key=int):
599
+ message_entries.append(prompt_raw_content[key])
600
+ else:
601
+ message_entries = [prompt_raw_content]
602
+ else:
603
+ return []
604
+
605
+ image_urls: List[str] = []
606
+ for message in cast(List[Dict[str, Any]], message_entries):
607
+ if (
608
+ not isinstance(message, dict) # pyright: ignore[reportUnnecessaryIsInstance]
609
+ or "content" not in message
610
+ ):
611
+ continue
612
+ content = message["content"]
613
+ if isinstance(content, str):
614
+ try:
615
+ content = json.loads(content) # This content should now be a list
616
+ except json.JSONDecodeError:
617
+ logger.debug(f"Failed to parse message content as JSON: {content}")
618
+ continue
619
+ if isinstance(content, list):
620
+ for content_part in cast(List[Dict[str, Any]], content):
621
+ if not isinstance(content_part, dict): # pyright: ignore[reportUnnecessaryIsInstance]
622
+ continue
623
+ if content_part.get("type") == "image_url":
624
+ image_url_dict = cast(Dict[str, Any], content_part.get("image_url"))
625
+ if not isinstance(image_url_dict, dict): # pyright: ignore[reportUnnecessaryIsInstance]
626
+ continue
627
+ if "url" in image_url_dict:
628
+ image_urls.append(image_url_dict["url"])
629
+ return image_urls
630
+
631
+ def span_to_triplet(self, span: Span, agent_name: str) -> Triplet:
632
+ """Convert a span to a triplet.
633
+
634
+ Subclasses can override this method to add more fields to the triplet,
635
+ such as chat messages and tool calls.
636
+ """
637
+ prompt_token_ids = (
638
+ _attributes_get_ids_multiple(
639
+ span.attributes,
640
+ [
641
+ "prompt_token_ids",
642
+ "mantisdk.operation.output.prompt_token_ids", # Weave tracer
643
+ ],
644
+ )
645
+ or []
646
+ )
647
+ response_token_ids = (
648
+ _attributes_get_ids_multiple(
649
+ span.attributes,
650
+ [
651
+ "response_token_ids",
652
+ "mantisdk.operation.output.response_token_ids.0", # Weave tracer
653
+ "mantisdk.operation.output.choices.0.token_ids", # Weave tracer with newer vLLM
654
+ "mantisdk.operation.output.choices.0.provider_specific_fields.token_ids", # new vLLM + new OpenAI client SDK
655
+ ],
656
+ )
657
+ or []
658
+ )
659
+
660
+ response_id = _attributes_get_multiple(
661
+ span.attributes, ["gen_ai.response.id", "mantisdk.operation.output.id"]
662
+ )
663
+ request_metadata = _attributes_unflatten_multiple(
664
+ span.attributes, ["gen_ai.request", "mantisdk.operation.input"]
665
+ )
666
+ response_metadata = _attributes_unflatten_multiple(
667
+ span.attributes, ["gen_ai.response", "mantisdk.operation.output"]
668
+ )
669
+ # Special handling for Weave tracer: messages are handled separately
670
+ if isinstance(request_metadata, dict):
671
+ request_metadata.pop("messages", None)
672
+ if isinstance(response_metadata, dict):
673
+ response_metadata.pop("choices", None)
674
+ response_metadata.pop("prompt_token_ids", None)
675
+ response_metadata.pop("response_token_ids", None)
676
+
677
+ prompt_raw_content = _attributes_unflatten_multiple(
678
+ span.attributes, ["gen_ai.prompt", "mantisdk.operation.input.messages"]
679
+ )
680
+ completion_raw_content = _attributes_unflatten_multiple(
681
+ span.attributes, ["gen_ai.completion", "mantisdk.operation.output.choices"]
682
+ )
683
+ image_urls = self.extract_prompt_image_urls(prompt_raw_content)
684
+ prompt_payload = {"token_ids": prompt_token_ids, "raw_content": prompt_raw_content, "image_urls": image_urls}
685
+ response_payload = {"token_ids": response_token_ids, "raw_content": completion_raw_content}
686
+
687
+ # FIXME: logprobs are not supported with the Weave tracer yet.
688
+ logprobs_content = span.attributes.get("logprobs.content", None) # type: ignore
689
+ if isinstance(logprobs_content, str):
690
+ logprobs_content = json.loads(logprobs_content)
691
+ response_payload["logprobs"] = logprobs_content
692
+
693
+ return Triplet(
694
+ prompt=prompt_payload,
695
+ response=response_payload,
696
+ reward=None,
697
+ metadata=dict(
698
+ request=request_metadata, response=response_metadata, response_id=response_id, agent_name=agent_name
699
+ ),
700
+ )
701
+
702
+ def to_trajectory(
703
+ self,
704
+ llm_call_match: str = r"openai\.chat\.completion",
705
+ agent_match: Optional[str] = None,
706
+ exclude_llm_call_in_reward: bool = True,
707
+ dedup_llm_call: bool = True,
708
+ reward_match: RewardMatchPolicy = RewardMatchPolicy.FIRST_OCCURRENCE,
709
+ final_reward: Optional[float] = None,
710
+ _skip_empty_token_spans: bool = False,
711
+ ) -> List[Triplet]:
712
+ """Convert the trace tree into a trajectory of [`Triplet`][mantisdk.Triplet] items.
713
+
714
+ Args:
715
+ llm_call_match: Regular expression for LLM call span names.
716
+ agent_match: Optional regular expression for agent span names.
717
+ exclude_llm_call_in_reward: When `True`, prevents searching for rewards under the LLM
718
+ call subtree.
719
+ dedup_llm_call: When `True`, deduplicates spans using the LLM response identifier.
720
+ reward_match: Reward matching policy used to associate reward spans with LLM calls.
721
+ final_reward: Optional reward appended to the final transition when provided.
722
+
723
+ Returns:
724
+ A list of [`Triplet`][mantisdk.Triplet] objects ordered by call sequence.
725
+ """
726
+ # Find all LLM calls
727
+ llm_calls = self.find_llm_calls(
728
+ llm_call_match=llm_call_match,
729
+ agent_match=agent_match,
730
+ within_matching_subtree="*" if agent_match is None else None,
731
+ within_reward=False if exclude_llm_call_in_reward else None,
732
+ within_llm_call=False if dedup_llm_call else None,
733
+ existing_llm_call_response_ids=set(),
734
+ )
735
+
736
+ id_transitions: List[Tuple[str, Triplet]] = []
737
+ # We need to filter out the LLM calls with unrecorded token IDs
738
+ filtered_llm_calls: List[Tuple[TraceTree, str]] = []
739
+ for llm_call, agent_name in llm_calls:
740
+ triplet = self.span_to_triplet(llm_call.span, agent_name)
741
+ # This is a hot-fix for Tinker+CrewAI, which has some anonymous requests outside the trained agent.
742
+ # TODO: We might need to reconsider this.
743
+ if _skip_empty_token_spans and (
744
+ not triplet.prompt.get("token_ids") or not triplet.response.get("token_ids")
745
+ ):
746
+ logger.warning(f"Skipping LLM call with unrecorded token IDs: {triplet}")
747
+ continue
748
+ filtered_llm_calls.append((llm_call, agent_name))
749
+ id_transitions.append((llm_call.id, triplet))
750
+
751
+ rewards = self.match_rewards(reward_match, [call for call, _ in filtered_llm_calls])
752
+ transitions = [
753
+ transition.model_copy(update={"reward": rewards.get(id, None)}) for id, transition in id_transitions
754
+ ]
755
+ if final_reward is not None and len(transitions) > 0:
756
+ # Add the final reward to the last transition
757
+ transitions[-1] = transitions[-1].model_copy(update={"reward": final_reward})
758
+ return transitions
759
+
760
+ def __repr__(self):
761
+ return (
762
+ f"TraceTree(id={self.id}, span={self.span}, start_time={self.start_time}, "
763
+ + f"end_time={self.end_time}, children={self.children})"
764
+ )
765
+
766
+
767
+ class TraceToTripletBase(TraceAdapter[List[Triplet]]):
768
+ """Base class for adapters that emit [`Triplet`][mantisdk.Triplet] trajectories."""
769
+
770
+
771
+ class TracerTraceToTriplet(TraceToTripletBase):
772
+ """Convert tracer-emitted spans into triplet trajectories.
773
+
774
+ Attributes:
775
+ repair_hierarchy: When `True`, repair the span tree using
776
+ [`TraceTree.repair_hierarchy()`][mantisdk.adapter.triplet.TraceTree.repair_hierarchy]
777
+ before matching calls and rewards.
778
+ llm_call_match: Regular expression pattern that selects LLM call span names.
779
+ agent_match: Optional regular expression pattern for agent span names. When omitted, spans
780
+ from any agent are considered.
781
+ exclude_llm_call_in_reward: When `True`, ignore matches under reward spans while searching
782
+ for rewards.
783
+ reward_match: Strategy used to associate rewards with LLM calls.
784
+ """
785
+
786
+ def __init__(
787
+ self,
788
+ repair_hierarchy: bool = True,
789
+ llm_call_match: str = r"openai\.chat\.completion",
790
+ agent_match: Optional[str] = None,
791
+ exclude_llm_call_in_reward: bool = True,
792
+ reward_match: RewardMatchPolicy = RewardMatchPolicy.FIRST_OCCURRENCE,
793
+ _skip_empty_token_spans: bool = False,
794
+ ):
795
+ self.repair_hierarchy = repair_hierarchy
796
+ self.llm_call_match = llm_call_match
797
+ self.agent_match = agent_match
798
+ self.exclude_llm_call_in_reward = exclude_llm_call_in_reward
799
+ self.reward_match = reward_match
800
+ self._skip_empty_token_spans = _skip_empty_token_spans
801
+
802
+ def visualize(
803
+ self,
804
+ source: Union[List[Span], List[ReadableSpan]],
805
+ /,
806
+ filename: str = "trace_tree",
807
+ interested_span_match: str | None = None,
808
+ ) -> TraceTree:
809
+ """Visualize the trace tree built from the supplied spans.
810
+
811
+ Args:
812
+ source: Collection of Mantisdk [`Span`][mantisdk.Span] objects
813
+ or raw `opentelemetry.sdk.trace.ReadableSpan` instances.
814
+ filename: Base filename for the generated image; `.png` is appended automatically.
815
+ interested_span_match: Optional regular expression used to highlight a subset of spans.
816
+
817
+ Returns:
818
+ The [`TraceTree`][mantisdk.adapter.triplet.TraceTree] built from the provided
819
+ spans.
820
+ """
821
+ source_normalized = [
822
+ Span.from_opentelemetry(span, "dummy", "dummy", 0) if isinstance(span, ReadableSpan) else span
823
+ for span in source
824
+ ]
825
+ trace_tree = TraceTree.from_spans(source_normalized)
826
+ if self.repair_hierarchy:
827
+ trace_tree.repair_hierarchy()
828
+ trace_tree.visualize(filename, interested_span_match=interested_span_match)
829
+ return trace_tree
830
+
831
+ def adapt(self, source: Union[Sequence[Span], Sequence[ReadableSpan]], /) -> List[Triplet]: # type: ignore
832
+ """Convert tracer spans into [`Triplet`][mantisdk.Triplet] trajectories.
833
+
834
+ Args:
835
+ source: Mantisdk spans or raw OpenTelemetry spans that form a trace.
836
+
837
+ Returns:
838
+ Ordered list of trajectory transitions with prompt, response, and reward information.
839
+ """
840
+ source_normalized = [
841
+ Span.from_opentelemetry(span, "dummy", "dummy", 0) if isinstance(span, ReadableSpan) else span
842
+ for span in source
843
+ ]
844
+ trace_tree = TraceTree.from_spans(source_normalized)
845
+ if self.repair_hierarchy:
846
+ trace_tree.repair_hierarchy()
847
+ trajectory = trace_tree.to_trajectory(
848
+ llm_call_match=self.llm_call_match,
849
+ agent_match=self.agent_match,
850
+ exclude_llm_call_in_reward=self.exclude_llm_call_in_reward,
851
+ reward_match=self.reward_match,
852
+ _skip_empty_token_spans=self._skip_empty_token_spans,
853
+ )
854
+ return trajectory
855
+
856
+
857
+ class LlmProxyTraceToTriplet(TraceToTripletBase):
858
+ """Convert telemetry emitted by the LLM Proxy into triplet trajectories.
859
+
860
+ !!! warning
861
+ This adapter is experimental and might be merged with
862
+ [`TracerTraceToTriplet`][mantisdk.TracerTraceToTriplet] in the future.
863
+
864
+ !!! danger
865
+ Do not rely on timestamps when using this adapter. Proxy spans can originate on different
866
+ machines with unsynchronised clocks, so `sequence_id` is treated as the sole source of
867
+ ordering.
868
+
869
+ Strategy:
870
+
871
+ 1. Sort spans by `(sequence_id, start_time)` for deterministic processing.
872
+ 2. Extract token identifiers from `litellm_request` or `raw_gen_ai_request` spans.
873
+ 3. Extract rewards from spans exposing AgentOps-style payloads or explicit reward spans.
874
+ 4. Match each reward to the most recent unmatched LLM call whose sequence is smaller.
875
+ """
876
+
877
+ def _literal_eval_maybe(self, v: Any) -> Any:
878
+ import ast
879
+
880
+ if isinstance(v, str):
881
+ try:
882
+ return ast.literal_eval(v)
883
+ except Exception:
884
+ return v
885
+ return v
886
+
887
+ def _extract_tokens_from_raw(self, attrs: Dict[str, Any]) -> Tuple[List[int], List[int]]:
888
+ """Extract token ids from raw_gen_ai_request attributes.
889
+
890
+ - llm.hosted_vllm.prompt_token_ids: string -> List[int]
891
+ - llm.hosted_vllm.response_token_ids: string -> List[List[int]] -> take first
892
+ - llm.hosted_vllm.choices: string -> [{'token_ids': [...]}] -> take first
893
+ """
894
+ prompt_ids: List[int] = []
895
+ resp_ids: List[int] = []
896
+
897
+ # prompt
898
+ p = attrs.get("llm.hosted_vllm.prompt_token_ids")
899
+ p = self._literal_eval_maybe(p)
900
+ if isinstance(p, list) and all(isinstance(x, int) for x in p): # type: ignore
901
+ prompt_ids = cast(List[int], p)
902
+
903
+ # response preferred path
904
+ r = attrs.get("llm.hosted_vllm.response_token_ids")
905
+ r = self._literal_eval_maybe(r)
906
+ if isinstance(r, list) and len(r) > 0 and isinstance(r[0], list): # type: ignore
907
+ first = cast(List[Any], r[0])
908
+ if all(isinstance(x, int) for x in first):
909
+ resp_ids = cast(List[int], first)
910
+
911
+ # fallback via choices
912
+ if not resp_ids:
913
+ choices = attrs.get("llm.hosted_vllm.choices")
914
+ choices = self._literal_eval_maybe(choices)
915
+ if isinstance(choices, list) and choices:
916
+ cand = cast(Any, choices[0])
917
+ if isinstance(cand, dict):
918
+ tids = cast(Dict[str, Any], cand).get("token_ids")
919
+ if isinstance(tids, list) and all(isinstance(x, int) for x in tids): # type: ignore
920
+ resp_ids = cast(List[int], tids)
921
+
922
+ return prompt_ids, resp_ids
923
+
924
+ def _extract_tokens_from_openai(self, attrs: Dict[str, Any]) -> Tuple[List[int], List[int]]:
925
+ prompt_ids = cast(Any, attrs.get("prompt_token_ids") or [])
926
+ resp_ids = cast(Any, attrs.get("response_token_ids") or [])
927
+ prompt_ids = self._literal_eval_maybe(prompt_ids)
928
+ resp_ids = self._literal_eval_maybe(resp_ids)
929
+ if not (isinstance(prompt_ids, list) and all(isinstance(x, int) for x in prompt_ids)): # type: ignore
930
+ prompt_ids = []
931
+ if not (isinstance(resp_ids, list) and all(isinstance(x, int) for x in resp_ids)): # type: ignore
932
+ resp_ids = []
933
+ return cast(List[int], prompt_ids), cast(List[int], resp_ids)
934
+
935
+ def _maybe_reward_value(self, span: Span) -> Optional[float]:
936
+ """Parse reward from typical AgentOps payloads or explicit reward spans."""
937
+ return get_reward_value(span)
938
+
939
+ def _request_id_from_attrs(self, attrs: Dict[str, Any]) -> Optional[str]:
940
+ # Prefer OpenAI-like id if present, else proxy raw id.
941
+ rid = attrs.get("gen_ai.response.id") or attrs.get("llm.hosted_vllm.id")
942
+ return str(rid) if isinstance(rid, str) and rid else None
943
+
944
+ def adapt(self, source: Sequence[Span], /) -> List[Triplet]: # type: ignore
945
+ """Convert LLM Proxy spans into [`Triplet`][mantisdk.Triplet] trajectories.
946
+
947
+ Args:
948
+ source: Spans emitted by the LLM Proxy containing prompt, response, and reward data.
949
+
950
+ Returns:
951
+ Ordered trajectory transitions matched purely by `sequence_id`.
952
+ """
953
+ # 1) Sort deterministically by (sequence_id, start_time).
954
+ spans = sorted(
955
+ source,
956
+ key=lambda s: (s.sequence_id, s.start_time),
957
+ )
958
+
959
+ # 2) Collect LLM calls with token IDs.
960
+ llm_items: List[Dict[str, Any]] = []
961
+ seen_request_ids: set[str] = set()
962
+ for s in spans:
963
+ attrs = s.attributes or {}
964
+ prompt_ids: List[int] = []
965
+ resp_ids: List[int] = []
966
+
967
+ if s.name == "raw_gen_ai_request":
968
+ prompt_ids, resp_ids = self._extract_tokens_from_raw(attrs)
969
+ elif s.name == "litellm_request":
970
+ # Some proxies never include token ids here. Ignore unless present.
971
+ prompt_ids, resp_ids = self._extract_tokens_from_openai(attrs)
972
+
973
+ if prompt_ids and resp_ids:
974
+ rid = self._request_id_from_attrs(attrs)
975
+ if rid:
976
+ # Duplicated request ID. This request is already handled.
977
+ if rid in seen_request_ids:
978
+ continue
979
+ seen_request_ids.add(rid)
980
+ llm_items.append(
981
+ dict(
982
+ span=s,
983
+ seq=s.sequence_id,
984
+ response_ids=resp_ids,
985
+ prompt_ids=prompt_ids,
986
+ request_id=rid,
987
+ )
988
+ )
989
+
990
+ # Order LLM items by sequence only.
991
+ llm_items.sort(key=lambda x: x["seq"])
992
+
993
+ # Collect rewards by sequence only.
994
+ rewards: List[Tuple[int, Optional[float]]] = []
995
+ for s in spans:
996
+ val = self._maybe_reward_value(s)
997
+ if val is not None:
998
+ rewards.append((s.sequence_id, val))
999
+
1000
+ # First-occurrence matching by sequence_id only:
1001
+ # For reward at sequence R, assign to the most recent unmatched LLM with seq < R.
1002
+ assigned: Dict[str, Optional[float]] = {}
1003
+ for r_seq, r_val in sorted(rewards, key=lambda x: x[0]):
1004
+ for item in reversed(llm_items):
1005
+ sid = item["span"].span_id
1006
+ if sid in assigned:
1007
+ continue
1008
+ if item["seq"] < r_seq:
1009
+ assigned[sid] = r_val
1010
+ break
1011
+
1012
+ # Build triplets in LLM sequence order.
1013
+ triplets: List[Triplet] = []
1014
+ for item in llm_items:
1015
+ s = item["span"]
1016
+ triplets.append(
1017
+ Triplet(
1018
+ prompt={"token_ids": item["prompt_ids"]},
1019
+ response={"token_ids": item["response_ids"]},
1020
+ reward=assigned.get(s.span_id, None),
1021
+ metadata=dict(
1022
+ # This is called response_id to align with the other adapters.
1023
+ response_id=item["request_id"],
1024
+ ),
1025
+ )
1026
+ )
1027
+
1028
+ return triplets
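
For reference, below is a minimal usage sketch of the tracer adapter defined in this file, based only on the signatures visible in the diff above. The spans variable is a hypothetical list of mantisdk Span objects (or OpenTelemetry ReadableSpan instances) collected by a tracer for a single rollout, and the import path is assumed from the file location mantisdk/adapter/triplet.py.

# Hypothetical sketch; `spans` and the exact import path are assumptions.
from mantisdk.adapter.triplet import RewardMatchPolicy, TracerTraceToTriplet

adapter = TracerTraceToTriplet(
    repair_hierarchy=True,                       # re-parent LLM spans that float under the root
    llm_call_match=r"openai\.chat\.completion",  # span names treated as LLM calls
    agent_match=None,                            # accept calls from any agent
    reward_match=RewardMatchPolicy.FIRST_OCCURRENCE,
)

triplets = adapter.adapt(spans)  # -> List[Triplet]
for t in triplets:
    print(len(t.prompt["token_ids"]), len(t.response["token_ids"]), t.reward)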