mantisdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mantisdk might be problematic; see the registry's advisory page for more details.
- mantisdk/__init__.py +22 -0
- mantisdk/adapter/__init__.py +15 -0
- mantisdk/adapter/base.py +94 -0
- mantisdk/adapter/messages.py +270 -0
- mantisdk/adapter/triplet.py +1028 -0
- mantisdk/algorithm/__init__.py +39 -0
- mantisdk/algorithm/apo/__init__.py +5 -0
- mantisdk/algorithm/apo/apo.py +889 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant01.poml +22 -0
- mantisdk/algorithm/apo/prompts/apply_edit_variant02.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant01.poml +18 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant02.poml +16 -0
- mantisdk/algorithm/apo/prompts/text_gradient_variant03.poml +107 -0
- mantisdk/algorithm/base.py +162 -0
- mantisdk/algorithm/decorator.py +264 -0
- mantisdk/algorithm/fast.py +250 -0
- mantisdk/algorithm/gepa/__init__.py +59 -0
- mantisdk/algorithm/gepa/adapter.py +459 -0
- mantisdk/algorithm/gepa/gepa.py +364 -0
- mantisdk/algorithm/gepa/lib/__init__.py +18 -0
- mantisdk/algorithm/gepa/lib/adapters/README.md +12 -0
- mantisdk/algorithm/gepa/lib/adapters/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/README.md +341 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/__init__.py +1 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/anymaths_adapter.py +174 -0
- mantisdk/algorithm/gepa/lib/adapters/anymaths_adapter/requirements.txt +1 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/README.md +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/default_adapter/default_adapter.py +209 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/README.md +7 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_adapter/dspy_adapter.py +307 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/README.md +99 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/dspy_program_proposal_signature.py +137 -0
- mantisdk/algorithm/gepa/lib/adapters/dspy_full_program_adapter/full_program_adapter.py +266 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/GEPA_RAG.md +621 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/__init__.py +56 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/evaluation_metrics.py +226 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/generic_rag_adapter.py +496 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/rag_pipeline.py +238 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_store_interface.py +212 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/__init__.py +2 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/chroma_store.py +196 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/lancedb_store.py +422 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/milvus_store.py +409 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/qdrant_store.py +368 -0
- mantisdk/algorithm/gepa/lib/adapters/generic_rag_adapter/vector_stores/weaviate_store.py +418 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/README.md +552 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/__init__.py +37 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_adapter.py +705 -0
- mantisdk/algorithm/gepa/lib/adapters/mcp_adapter/mcp_client.py +364 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/README.md +9 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/adapters/terminal_bench_adapter/terminal_bench_adapter.py +217 -0
- mantisdk/algorithm/gepa/lib/api.py +375 -0
- mantisdk/algorithm/gepa/lib/core/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/core/adapter.py +180 -0
- mantisdk/algorithm/gepa/lib/core/data_loader.py +74 -0
- mantisdk/algorithm/gepa/lib/core/engine.py +356 -0
- mantisdk/algorithm/gepa/lib/core/result.py +233 -0
- mantisdk/algorithm/gepa/lib/core/state.py +636 -0
- mantisdk/algorithm/gepa/lib/examples/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/examples/aime.py +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/eval_default.py +111 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/instruction_prompt.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/prompt-templates/optimal_prompt.txt +24 -0
- mantisdk/algorithm/gepa/lib/examples/anymaths-bench/train_anymaths.py +177 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/arc_agi.ipynb +25705 -0
- mantisdk/algorithm/gepa/lib/examples/dspy_full_program_evolution/example.ipynb +348 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/__init__.py +4 -0
- mantisdk/algorithm/gepa/lib/examples/mcp_adapter/mcp_optimization_example.py +455 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/RAG_GUIDE.md +613 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/__init__.py +9 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/rag_optimization.py +824 -0
- mantisdk/algorithm/gepa/lib/examples/rag_adapter/requirements-rag.txt +29 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/instruction_prompt.txt +16 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/prompt-templates/terminus.txt +9 -0
- mantisdk/algorithm/gepa/lib/examples/terminal-bench/train_terminus.py +161 -0
- mantisdk/algorithm/gepa/lib/gepa_utils.py +117 -0
- mantisdk/algorithm/gepa/lib/logging/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/logging/experiment_tracker.py +187 -0
- mantisdk/algorithm/gepa/lib/logging/logger.py +75 -0
- mantisdk/algorithm/gepa/lib/logging/utils.py +103 -0
- mantisdk/algorithm/gepa/lib/proposer/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/base.py +31 -0
- mantisdk/algorithm/gepa/lib/proposer/merge.py +357 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/base.py +49 -0
- mantisdk/algorithm/gepa/lib/proposer/reflective_mutation/reflective_mutation.py +176 -0
- mantisdk/algorithm/gepa/lib/py.typed +0 -0
- mantisdk/algorithm/gepa/lib/strategies/__init__.py +0 -0
- mantisdk/algorithm/gepa/lib/strategies/batch_sampler.py +77 -0
- mantisdk/algorithm/gepa/lib/strategies/candidate_selector.py +50 -0
- mantisdk/algorithm/gepa/lib/strategies/component_selector.py +36 -0
- mantisdk/algorithm/gepa/lib/strategies/eval_policy.py +64 -0
- mantisdk/algorithm/gepa/lib/strategies/instruction_proposal.py +127 -0
- mantisdk/algorithm/gepa/lib/utils/__init__.py +10 -0
- mantisdk/algorithm/gepa/lib/utils/stop_condition.py +196 -0
- mantisdk/algorithm/gepa/tracing.py +105 -0
- mantisdk/algorithm/utils.py +177 -0
- mantisdk/algorithm/verl/__init__.py +5 -0
- mantisdk/algorithm/verl/interface.py +202 -0
- mantisdk/cli/__init__.py +56 -0
- mantisdk/cli/prometheus.py +115 -0
- mantisdk/cli/store.py +131 -0
- mantisdk/cli/vllm.py +29 -0
- mantisdk/client.py +408 -0
- mantisdk/config.py +348 -0
- mantisdk/emitter/__init__.py +43 -0
- mantisdk/emitter/annotation.py +370 -0
- mantisdk/emitter/exception.py +54 -0
- mantisdk/emitter/message.py +61 -0
- mantisdk/emitter/object.py +117 -0
- mantisdk/emitter/reward.py +320 -0
- mantisdk/env_var.py +156 -0
- mantisdk/execution/__init__.py +15 -0
- mantisdk/execution/base.py +64 -0
- mantisdk/execution/client_server.py +443 -0
- mantisdk/execution/events.py +69 -0
- mantisdk/execution/inter_process.py +16 -0
- mantisdk/execution/shared_memory.py +282 -0
- mantisdk/instrumentation/__init__.py +119 -0
- mantisdk/instrumentation/agentops.py +314 -0
- mantisdk/instrumentation/agentops_langchain.py +45 -0
- mantisdk/instrumentation/litellm.py +83 -0
- mantisdk/instrumentation/vllm.py +81 -0
- mantisdk/instrumentation/weave.py +500 -0
- mantisdk/litagent/__init__.py +11 -0
- mantisdk/litagent/decorator.py +536 -0
- mantisdk/litagent/litagent.py +252 -0
- mantisdk/llm_proxy.py +1890 -0
- mantisdk/logging.py +370 -0
- mantisdk/reward.py +7 -0
- mantisdk/runner/__init__.py +11 -0
- mantisdk/runner/agent.py +845 -0
- mantisdk/runner/base.py +182 -0
- mantisdk/runner/legacy.py +309 -0
- mantisdk/semconv.py +170 -0
- mantisdk/server.py +401 -0
- mantisdk/store/__init__.py +23 -0
- mantisdk/store/base.py +897 -0
- mantisdk/store/client_server.py +2092 -0
- mantisdk/store/collection/__init__.py +30 -0
- mantisdk/store/collection/base.py +587 -0
- mantisdk/store/collection/memory.py +970 -0
- mantisdk/store/collection/mongo.py +1412 -0
- mantisdk/store/collection_based.py +1823 -0
- mantisdk/store/insight.py +648 -0
- mantisdk/store/listener.py +58 -0
- mantisdk/store/memory.py +396 -0
- mantisdk/store/mongo.py +165 -0
- mantisdk/store/sqlite.py +3 -0
- mantisdk/store/threading.py +357 -0
- mantisdk/store/utils.py +142 -0
- mantisdk/tracer/__init__.py +16 -0
- mantisdk/tracer/agentops.py +242 -0
- mantisdk/tracer/base.py +287 -0
- mantisdk/tracer/dummy.py +106 -0
- mantisdk/tracer/otel.py +555 -0
- mantisdk/tracer/weave.py +677 -0
- mantisdk/trainer/__init__.py +6 -0
- mantisdk/trainer/init_utils.py +263 -0
- mantisdk/trainer/legacy.py +367 -0
- mantisdk/trainer/registry.py +12 -0
- mantisdk/trainer/trainer.py +618 -0
- mantisdk/types/__init__.py +6 -0
- mantisdk/types/core.py +553 -0
- mantisdk/types/resources.py +204 -0
- mantisdk/types/tracer.py +515 -0
- mantisdk/types/tracing.py +218 -0
- mantisdk/utils/__init__.py +1 -0
- mantisdk/utils/id.py +18 -0
- mantisdk/utils/metrics.py +1025 -0
- mantisdk/utils/otel.py +578 -0
- mantisdk/utils/otlp.py +536 -0
- mantisdk/utils/server_launcher.py +1045 -0
- mantisdk/utils/system_snapshot.py +81 -0
- mantisdk/verl/__init__.py +8 -0
- mantisdk/verl/__main__.py +6 -0
- mantisdk/verl/async_server.py +46 -0
- mantisdk/verl/config.yaml +27 -0
- mantisdk/verl/daemon.py +1154 -0
- mantisdk/verl/dataset.py +44 -0
- mantisdk/verl/entrypoint.py +248 -0
- mantisdk/verl/trainer.py +549 -0
- mantisdk-0.1.0.dist-info/METADATA +119 -0
- mantisdk-0.1.0.dist-info/RECORD +190 -0
- mantisdk-0.1.0.dist-info/WHEEL +4 -0
- mantisdk-0.1.0.dist-info/entry_points.txt +2 -0
- mantisdk-0.1.0.dist-info/licenses/LICENSE +19 -0
|
@@ -0,0 +1,1028 @@
|
|
|
1
|
+
# Copyright (c) Microsoft. All rights reserved.
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
import re
|
|
8
|
+
from enum import Enum
|
|
9
|
+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
|
|
10
|
+
|
|
11
|
+
from opentelemetry.sdk.trace import ReadableSpan
|
|
12
|
+
from pydantic import BaseModel
|
|
13
|
+
|
|
14
|
+
from mantisdk.emitter.reward import get_reward_value
|
|
15
|
+
from mantisdk.semconv import AGL_OPERATION, AGL_REWARD, LightningSpanAttributes
|
|
16
|
+
from mantisdk.types import Span, Triplet
|
|
17
|
+
from mantisdk.utils.otel import filter_and_unflatten_attributes
|
|
18
|
+
|
|
19
|
+
from .base import TraceAdapter
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _attributes_get_multiple(attributes: Dict[str, Any], keys: List[str]) -> Optional[str]:
|
|
25
|
+
"""Get a string from the attributes, if present.
|
|
26
|
+
If there are multiple matches, the first one is returned.
|
|
27
|
+
"""
|
|
28
|
+
for key in keys:
|
|
29
|
+
if key in attributes:
|
|
30
|
+
if isinstance(attributes[key], str):
|
|
31
|
+
return attributes[key]
|
|
32
|
+
else:
|
|
33
|
+
logger.warning(f"Attribute {key} is found but is not a string: {attributes[key]}")
|
|
34
|
+
return None
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _attributes_get_ids_multiple(attributes: Dict[str, Any], keys: List[str]) -> Optional[List[int]]:
|
|
38
|
+
"""Get a list of integers from the attributes, if present.
|
|
39
|
+
If there are multiple matches, the first one is returned.
|
|
40
|
+
"""
|
|
41
|
+
for key in keys:
|
|
42
|
+
if key in attributes:
|
|
43
|
+
if (isinstance(attributes[key], list) or isinstance(attributes[key], tuple)) and all(
|
|
44
|
+
isinstance(x, int) for x in attributes[key]
|
|
45
|
+
):
|
|
46
|
+
return list(attributes[key])
|
|
47
|
+
else:
|
|
48
|
+
logger.warning(f"Attribute {key} is found but is not a list of integers: {attributes[key]}")
|
|
49
|
+
return None
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _attributes_unflatten_multiple(
    attributes: Dict[str, Any], keys: List[str]
) -> Union[Dict[str, Any], List[Any], None]:
    """Unflatten the attributes under the first key that yields a non-empty result.

    Keys are probed in order; a key whose unflattened value is falsy (e.g. an
    empty dict) is skipped. Returns ``None`` when no key produces anything.
    """
    for candidate in keys:
        unflattened = filter_and_unflatten_attributes(attributes, candidate)
        if unflattened:
            return unflattened
    return None
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class Transition(BaseModel):
    """A single transition within a reinforcement learning trajectory.

    Attributes:
        state: Token identifiers describing the model input state.
        action: Token identifiers representing the model output.
        response_id: Identifier of the LLM response used to deduplicate spans.
        agent_name: Human-readable agent name captured from the trace.
        reward: Scalar reward associated with the transition, if available.
    """

    state: List[int]
    action: List[int]
    response_id: Optional[str]
    # Per-token log-probabilities are not currently captured:
    # action_logprobs: List[float]
    agent_name: str
    reward: Optional[float]
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class RewardMatchPolicy(str, Enum):
    """Strategies for matching rewards to LLM call spans.

    !!! note
        Each reward span must expose a payload shaped like `{"type": "reward", "value": <float>|None}`
        as described in `reward.py`.
    """

    # str-valued members so policies compare equal to their plain-string form.
    FIRST_SIBLING = "first_sibling"
    """Use the first sibling in the current trace subtree as the reward unless another LLM call match is found."""

    FIRST_OCCURRENCE = "first_occurrence"
    """Use the first reward encountered in chronological order after the current LLM call match."""
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class TraceTree:
|
|
100
|
+
"""Tree representation of a trace span and its descendants.
|
|
101
|
+
|
|
102
|
+
Attributes:
|
|
103
|
+
id: Unique identifier for the span node.
|
|
104
|
+
span: [`Span`][mantisdk.Span] backing this node.
|
|
105
|
+
children: Child nodes connected to the current span.
|
|
106
|
+
"""
|
|
107
|
+
|
|
108
|
+
    def __init__(
        self,
        id: str,
        span: Span,
        children: Optional[List["TraceTree"]] = None,
    ):
        """Initialize a tree node.

        Args:
            id: Unique identifier for this node (typically the span id).
            span: The span backing this node.
            children: Optional initial child nodes; defaults to a fresh empty list.
        """
        self.id = id
        self.span = span
        # `children or []` guarantees each node owns its own list (no shared
        # mutable default); note it also replaces an explicitly-passed empty list.
        self.children = children or []
|
|
117
|
+
|
|
118
|
+
    @property
    def start_time(self):
        """Start time of the underlying span (delegates to ``self.span``)."""
        return self.span.start_time
|
|
121
|
+
|
|
122
|
+
    @property
    def end_time(self):
        """End time of the underlying span (delegates to ``self.span``)."""
        return self.span.end_time
|
|
125
|
+
|
|
126
|
+
def find_id(self, id: str) -> "TraceTree | None":
|
|
127
|
+
if self.id == id:
|
|
128
|
+
return self
|
|
129
|
+
for child in self.children:
|
|
130
|
+
found = child.find_id(id)
|
|
131
|
+
if found:
|
|
132
|
+
return found
|
|
133
|
+
return None
|
|
134
|
+
|
|
135
|
+
    def add_child(self, child: "TraceTree") -> None:
        """Append ``child`` to this node's children."""
        self.children.append(child)
|
|
137
|
+
|
|
138
|
+
    def visualize(self, filename: str, interested_span_match: str | None = None) -> None:
        """Render the trace tree with Graphviz for debugging purposes.

        Args:
            filename: Base filename for the generated `.png` diagram.
            interested_span_match: Optional regular expression used to keep only matching spans
                (and their ancestors) in the output.

        !!! note
            The method requires the optional `graphviz` dependency to be available in the runtime
            environment.
        """
        # Imported lazily so the hard dependency exists only when visualizing.
        import graphviz

        dot = graphviz.Digraph(comment="Trace Tree")

        # Memoises, per node id, whether the node (or any descendant) matches.
        should_visit_cache: Dict[str, bool] = {}

        def should_visit(node: "TraceTree") -> bool:
            # Keep a node when its own name matches, or when any descendant matches
            # (so ancestors of an interesting span stay in the diagram).
            if node.id in should_visit_cache:
                return should_visit_cache[node.id]
            if interested_span_match is not None:
                if re.search(interested_span_match, node.span.name):
                    should_visit_cache[node.id] = True
                    return True
                else:
                    should_visit_cache[node.id] = False
                    for child in node.children:
                        if should_visit(child):
                            should_visit_cache[node.id] = True

                    return should_visit_cache[node.id]
            else:
                # No filter: every node is kept.
                return True

        def visit(node: "TraceTree") -> bool:
            # Emit the node (and edges to kept children); returns whether it was kept.
            if not should_visit(node):
                return False
            agent_name = node.agent_name()
            # Label: last 8 chars of the id, span name, and optional agent suffix.
            vis_name = node.id[-8:] + " (" + node.span.name + ")"
            if agent_name is not None:
                vis_name += " [" + agent_name + "]"
            dot.node(node.id, vis_name)  # type: ignore
            for child in node.children:
                if visit(child):
                    dot.edge(node.id, child.id)  # type: ignore
            return True

        visit(self)
        dot.render(filename, format="png", cleanup=True)  # type: ignore
|
|
188
|
+
|
|
189
|
+
def names_tuple(self) -> Tuple[str, List[Any]]:
|
|
190
|
+
"""Return the span name alongside nested child names.
|
|
191
|
+
|
|
192
|
+
Returns:
|
|
193
|
+
A tuple of the current span name and a list of tuples for each child containing the
|
|
194
|
+
child name and its descendants.
|
|
195
|
+
"""
|
|
196
|
+
name = self.span.name
|
|
197
|
+
agent_name = self.agent_name()
|
|
198
|
+
if agent_name is not None:
|
|
199
|
+
name += " [" + agent_name + "]"
|
|
200
|
+
children_names: List[Tuple[str, List[Any]]] = []
|
|
201
|
+
for child in self.children:
|
|
202
|
+
child_name, child_children = child.names_tuple()
|
|
203
|
+
children_names.append((child_name, child_children))
|
|
204
|
+
return name, children_names
|
|
205
|
+
|
|
206
|
+
def traverse(self) -> List["TraceTree"]:
|
|
207
|
+
"""Traverse the tree depth first and return every node."""
|
|
208
|
+
spans: List["TraceTree"] = [self]
|
|
209
|
+
for child in self.children:
|
|
210
|
+
spans.extend(child.traverse())
|
|
211
|
+
return spans
|
|
212
|
+
|
|
213
|
+
def to_json(self) -> dict[str, Any]:
|
|
214
|
+
"""Convert the tree node into a JSON-serialisable structure."""
|
|
215
|
+
if isinstance(self.span, ReadableSpan):
|
|
216
|
+
span_data = json.loads(self.span.to_json())
|
|
217
|
+
else:
|
|
218
|
+
span_data = self.span.model_dump()
|
|
219
|
+
return {
|
|
220
|
+
"id": self.id,
|
|
221
|
+
"span": span_data,
|
|
222
|
+
"children": [child.to_json() for child in self.children],
|
|
223
|
+
}
|
|
224
|
+
|
|
225
|
+
    @classmethod
    def from_spans(cls, spans: List[Span]) -> "TraceTree":
        """Construct a tree from a flat list of spans.

        Args:
            spans: Spans that collectively form a single trace segment.

        Returns:
            A [`TraceTree`][mantisdk.adapter.triplet.TraceTree] rooted at either the
            discovered root span or a synthetic root when multiple roots are present.

        Raises:
            ValueError: If the span list is empty or no root span can be inferred.
        """

        if not spans:
            raise ValueError("No spans provided to create TraceTree.")

        # Process trace items in topological order
        id_to_span = {span.span_id: span for span in spans}

        # parent span id -> list of child span ids, in input order.
        forward_graph: dict[str, list[str]] = {}
        root_ids: list[str] = []
        for span in spans:
            span_id = span.span_id
            if span.parent_id is None:
                root_ids.append(span.span_id)
            else:
                if span.parent_id not in forward_graph:
                    forward_graph[span.parent_id] = []
                forward_graph[span.parent_id].append(span_id)

        # Diff between span with data and forward_graph keys
        # Sometimes the top-level session span is lost: parents referenced by
        # children but absent from the input become synthetic roots.
        unfound_roots = set(forward_graph.keys()) - set(id_to_span.keys())
        for unfound_root in unfound_roots:
            root_ids.append(unfound_root)

        def visit(node_id: str) -> "TraceTree":
            # Recursively build the subtree rooted at node_id.
            children: list[TraceTree] = []
            if node_id in forward_graph:
                for child_id in forward_graph[node_id]:
                    children.append(visit(child_id))

            if node_id not in id_to_span:
                # Referenced-but-missing span: synthesise one spanning its children.
                # NOTE(review): no `name=` is passed here (unlike the virtual root
                # below) — presumably Span.from_attributes has a default; confirm.
                # NOTE(review): min/max raise ValueError if every child's
                # start/end time is None — assumed not to happen in practice.
                assert len(children) > 0
                virtual_span = Span.from_attributes(
                    rollout_id=children[0].span.rollout_id,
                    attempt_id=children[0].span.attempt_id,
                    sequence_id=children[0].span.sequence_id,
                    trace_id=children[0].span.trace_id,
                    span_id=node_id,
                    parent_id=None,
                    attributes={},
                    start_time=min(child.start_time for child in children if child.start_time is not None),
                    end_time=max(child.end_time for child in children if child.end_time is not None),
                )
                return cls(node_id, virtual_span, children=children)
            else:
                return cls(
                    node_id,
                    id_to_span[node_id],
                    children=children,
                )

        # Create a virtual root span if multiple root spans are found
        if len(root_ids) > 1:
            root_spans = [visit(root_id) for root_id in root_ids]
            virtual_root = TraceTree(
                id="virtual-root",
                span=Span.from_attributes(
                    rollout_id=root_spans[0].span.rollout_id,
                    attempt_id=root_spans[0].span.attempt_id,
                    sequence_id=root_spans[0].span.sequence_id,
                    trace_id=root_spans[0].span.trace_id,
                    span_id=None,  # Generate one
                    parent_id=None,
                    name="virtual-root",
                    attributes={},
                    # Assumes root_ids are in chronological order — TODO confirm;
                    # start comes from the first root, end from the last.
                    start_time=root_spans[0].start_time,
                    end_time=root_spans[-1].end_time,
                ),
                children=root_spans,
            )
            return virtual_root
        elif len(root_ids) == 0:
            # No root spans found
            raise ValueError("No root spans found in the trace.")
        else:
            root_span = visit(root_ids[0])
            return root_span
|
|
316
|
+
|
|
317
|
+
def agent_name(self) -> Optional[str]:
|
|
318
|
+
"""Return the agent name associated with the span, if any.
|
|
319
|
+
|
|
320
|
+
Returns:
|
|
321
|
+
Agent name extracted from known attributes, otherwise `None`.
|
|
322
|
+
"""
|
|
323
|
+
attributes = self.span.attributes
|
|
324
|
+
if attributes is None: # type: ignore
|
|
325
|
+
return None
|
|
326
|
+
|
|
327
|
+
# Case 1: OpenAI Agent SDK
|
|
328
|
+
agent_name = cast(Optional[str], attributes.get("agent.name"))
|
|
329
|
+
if agent_name is not None:
|
|
330
|
+
return agent_name
|
|
331
|
+
|
|
332
|
+
# Case 2: Agentops decorator @agent
|
|
333
|
+
is_agent = attributes.get("agentops.span.kind") == "agent"
|
|
334
|
+
if is_agent:
|
|
335
|
+
agent_name = cast(Optional[str], attributes.get("operation.name"))
|
|
336
|
+
if agent_name is not None:
|
|
337
|
+
return agent_name
|
|
338
|
+
|
|
339
|
+
# Case 3: Autogen team
|
|
340
|
+
agent_name = cast(Optional[str], attributes.get("recipient_agent_type"))
|
|
341
|
+
if agent_name is not None:
|
|
342
|
+
return agent_name
|
|
343
|
+
|
|
344
|
+
# Case 4: LangGraph
|
|
345
|
+
agent_name = cast(Optional[str], attributes.get("langchain.chain.type"))
|
|
346
|
+
if agent_name is not None:
|
|
347
|
+
return agent_name
|
|
348
|
+
|
|
349
|
+
# Case 5: agent-framework
|
|
350
|
+
agent_name = cast(Optional[str], attributes.get("executor.id"))
|
|
351
|
+
if agent_name is not None:
|
|
352
|
+
return agent_name
|
|
353
|
+
|
|
354
|
+
# Case 6: Weave
|
|
355
|
+
is_agent_type = attributes.get("type") == "agent"
|
|
356
|
+
if is_agent_type:
|
|
357
|
+
agent_name = cast(Optional[str], attributes.get("mantisdk.operation.input.name"))
|
|
358
|
+
if agent_name is not None:
|
|
359
|
+
return agent_name
|
|
360
|
+
|
|
361
|
+
# Case 7: Weave + LangChain
|
|
362
|
+
if self.span.name.startswith("langchain.Chain."):
|
|
363
|
+
attributes_lc_name = cast(Optional[str], attributes.get("lc_name"))
|
|
364
|
+
if attributes_lc_name is not None:
|
|
365
|
+
return attributes_lc_name
|
|
366
|
+
|
|
367
|
+
def maybe_reward_dict(self) -> dict[str, Any]:
|
|
368
|
+
"""Return a reward payload if the span encodes one.
|
|
369
|
+
|
|
370
|
+
Returns:
|
|
371
|
+
Dictionary containing reward metadata, or an empty dictionary when no reward is found.
|
|
372
|
+
"""
|
|
373
|
+
reward_value = get_reward_value(self.span)
|
|
374
|
+
if reward_value is not None:
|
|
375
|
+
return {"type": "reward", "value": reward_value}
|
|
376
|
+
else:
|
|
377
|
+
return {}
|
|
378
|
+
|
|
379
|
+
def is_reward_span(self) -> bool:
|
|
380
|
+
"""Return whether the span explicitly encodes a reward.
|
|
381
|
+
|
|
382
|
+
Returns:
|
|
383
|
+
`True` when the span payload describes a reward, otherwise `False`.
|
|
384
|
+
"""
|
|
385
|
+
maybe_reward = self.maybe_reward_dict()
|
|
386
|
+
if maybe_reward and maybe_reward.get("type") == "reward": # type: ignore
|
|
387
|
+
return True
|
|
388
|
+
|
|
389
|
+
# Mantisdk 0.3+
|
|
390
|
+
if (
|
|
391
|
+
self.span.name == AGL_OPERATION
|
|
392
|
+
and self.span.attributes.get(LightningSpanAttributes.OPERATION_NAME.value) == AGL_REWARD
|
|
393
|
+
):
|
|
394
|
+
return True
|
|
395
|
+
|
|
396
|
+
return False
|
|
397
|
+
|
|
398
|
+
    def find_llm_calls(
        self,
        *,
        llm_call_match: str,
        agent_match: Optional[str],
        within_matching_subtree: str | None = None,
        within_reward: Optional[bool] = None,
        within_llm_call: Optional[bool] = None,
        existing_llm_call_response_ids: Optional[set[str]] = None,
    ) -> List[Tuple["TraceTree", str]]:
        """Find LLM call spans matching the supplied filters.

        Args:
            llm_call_match: Regular expression used to match span names that qualify as LLM calls.
            agent_match: Optional regular expression that must match the enclosing agent span name.
            within_matching_subtree: Marker propagated through recursive calls to record matching agents.
            within_reward: When `True`, suppresses LLM matches under reward spans.
            within_llm_call: When `True`, prevents duplicate matches for nested LLM calls.
            existing_llm_call_response_ids: Known response identifiers used to deduplicate spans.

        Returns:
            A list of tuples pairing the matching node with the agent subtree label that triggered the
            match.
        """
        llm_calls: List[Tuple[TraceTree, str]] = []

        # Assume a match until a filter below disqualifies this span.
        is_llm_call = True
        if within_matching_subtree is None or within_reward is True:
            # We must be in an interesting agent subtree, and not in a reward span.
            is_llm_call = False
        if re.search(llm_call_match, self.span.name) is None:
            # The span name does not match the LLM call match.
            is_llm_call = False
        if is_llm_call:
            # Check the response id
            response_id = _attributes_get_multiple(
                self.span.attributes, ["gen_ai.response.id", "mantisdk.operation.output.id"]
            )
            if response_id is None and within_llm_call is True:
                # A nested LLM call without a response id cannot be deduplicated; drop it.
                is_llm_call = False
            if (
                response_id is not None
                and existing_llm_call_response_ids is not None
                and response_id in existing_llm_call_response_ids
            ):
                # Same response already matched elsewhere in the tree.
                is_llm_call = False

        if is_llm_call:
            # NOTE: `response_id` is always bound here — this branch is reachable
            # only when the `if is_llm_call:` block above executed.
            llm_calls.append((self, within_matching_subtree))  # type: ignore
            if existing_llm_call_response_ids is None:
                existing_llm_call_response_ids = set()
            if response_id is not None:
                existing_llm_call_response_ids.add(response_id)
            if within_llm_call is not None:
                # Mark descendants as nested inside an LLM call (only when the
                # caller opted into this tracking by passing a non-None flag).
                within_llm_call = True

        # Entering an agent span updates (or clears) the matching-subtree marker
        # propagated to children.
        agent_name = self.agent_name()
        if agent_name is not None:
            if agent_match is None or re.search(agent_match, agent_name):
                within_matching_subtree = agent_name
            else:
                within_matching_subtree = None

        # Reward-span suppression is likewise opt-in via a non-None flag.
        if within_reward is not None and self.is_reward_span():
            within_reward = True

        for child in self.children:
            llm_calls.extend(
                child.find_llm_calls(
                    llm_call_match=llm_call_match,
                    agent_match=agent_match,
                    within_matching_subtree=within_matching_subtree,
                    within_reward=within_reward,
                    within_llm_call=within_llm_call,
                    existing_llm_call_response_ids=existing_llm_call_response_ids,
                )
            )

        return llm_calls
|
|
477
|
+
|
|
478
|
+
    def repair_hierarchy(self) -> None:
        """Repair missing parent-child relationships introduced by mixed tracing systems.

        Some agent frameworks emit spans via multiple subsystems, which can cause LLM completion
        spans to float directly under the root span instead of being nested under the correct agent.
        The method re-parents those spans to the closest ancestor that fully envelopes the child in
        time.

        If we don't, when we want to select the LLM completion span with agent as filter.
        We will never get the correct span underneath.
        """
        # If the current node has only one child, recursively repair its hierarchy directly.
        # This special-case handling is needed because when a trace is manually ended
        # (via agentops.end_trace), the AgentOps provider automatically wraps all spans
        # under an extra synthetic root node (e.g., "run_one.session").
        if len(self.children) == 1:
            self.children[0].repair_hierarchy()
            return

        # Snapshot: the loop below mutates self.children while re-parenting.
        nodes_to_repair = list(self.children)

        for repair_node in nodes_to_repair:
            if len(self.children) == 1:
                # If there is only one child, we don't need to repair the hierarchy.
                break
            # Find the closest parent span (but not the root itself): the node whose
            # [start, end] interval envelopes repair_node's with the smallest slack.
            closest_parent = None
            closest_duration = float("inf")
            for node in self.traverse():
                if node.id == repair_node.id:
                    continue
                if node is self:
                    continue
                if node.start_time <= repair_node.start_time and node.end_time >= repair_node.end_time:  # type: ignore
                    # Total slack = (end overshoot) + (start undershoot).
                    duration_delta = node.end_time - repair_node.end_time + repair_node.start_time - node.start_time  # type: ignore
                    # `> 0` excludes a node with the exact same interval
                    # (e.g. repair_node itself reached through another path).
                    if duration_delta > 0 and duration_delta < closest_duration:
                        closest_duration = duration_delta  # type: ignore
                        closest_parent = node

            # Repair the hierarchy
            if closest_parent is not None:
                self.children.remove(repair_node)
                closest_parent.children.append(repair_node)
|
|
521
|
+
|
|
522
|
+
def match_rewards(self, reward_match: str, llm_calls: List["TraceTree"]) -> dict[str, Optional[float]]:
    """Assign rewards to previously matched LLM calls.

    Args:
        reward_match: Strategy identifier from
            [`RewardMatchPolicy`][mantisdk.adapter.triplet.RewardMatchPolicy].
        llm_calls: Trace nodes representing LLM call spans.

    Returns:
        Mapping from span identifier to reward value or `None` when no reward is available.
    """
    call_ids = {call.id for call in llm_calls}
    rewards: dict[str, Optional[float]] = {}

    if reward_match == RewardMatchPolicy.FIRST_OCCURRENCE:
        # Visit every node in the tree in chronological order of start time,
        # accumulating LLM-call candidates as we encounter them.
        chronological = sorted(self.traverse(), key=lambda node: node.start_time)
        candidates: List[Tuple[str, int]] = []
        for node in chronological:
            if node.id in call_ids:
                candidates.append((node.id, node.end_time))

            reward_dict = node.maybe_reward_dict()
            if not reward_dict or reward_dict.get("type") != "reward":
                continue
            # Newest-first: hand the reward to the latest candidate call that
            # ended before this reward span began and has no reward yet.
            for candidate_id, candidate_end in reversed(candidates):
                if candidate_end > node.start_time:
                    # Candidate finished after the reward started; not eligible.
                    continue
                if candidate_id in rewards:
                    # Already rewarded; keep looking further back.
                    continue
                rewards[candidate_id] = reward_dict.get("value", None)
                break

    elif reward_match == RewardMatchPolicy.FIRST_SIBLING:
        # Rewards only match LLM calls that share the same direct parent span.
        for parent in self.traverse():
            candidates = []
            for child in parent.children:
                if child.id in call_ids:
                    candidates.append((child.id, child.end_time))

                reward_dict = child.maybe_reward_dict()
                if not reward_dict or reward_dict.get("type") != "reward":
                    continue
                for candidate_id, candidate_end in reversed(candidates):
                    if candidate_end > child.start_time:
                        continue
                    if candidate_id in rewards:
                        continue
                    rewards[candidate_id] = reward_dict.get("value", None)
                    break

    return rewards
|
|
576
|
+
|
|
577
|
+
def extract_prompt_image_urls(self, prompt_raw_content: Any) -> List[str]:
    """Extract image URLs from the span attributes, in order of appearance.

    Args:
        prompt_raw_content: The raw content of the prompt, which can be in one of several formats:

            - List[dict]: A list of message entries, each being a dict with at least a "content" key.
            - Dict[str, Any]: A dictionary, often with numeric string keys (e.g., `{"0": {...}, "1": {...}}`), where each value is a message entry.
              If the dict does not have numeric keys, it is treated as a single message entry.

    Returns:
        URLs of every `image_url`-typed content part, in the order encountered.
    """
    # Normalize the raw content into a flat list of message entries.
    if isinstance(prompt_raw_content, list):
        entries: List[Any] = list(prompt_raw_content)
    elif isinstance(prompt_raw_content, dict):
        # Common when the attributes expand to {"0": {...}, "prompt_filter_results": ...}
        digit_keys = [key for key in prompt_raw_content if isinstance(key, str) and key.isdigit()]
        if digit_keys:
            entries = [prompt_raw_content[key] for key in sorted(digit_keys, key=int)]
        else:
            entries = [prompt_raw_content]
    else:
        return []

    urls: List[str] = []
    for entry in entries:
        if not isinstance(entry, dict) or "content" not in entry:
            continue
        content = entry["content"]
        if isinstance(content, str):
            # Content may itself be JSON-encoded; decode to a list if possible.
            try:
                content = json.loads(content)
            except json.JSONDecodeError:
                logger.debug(f"Failed to parse message content as JSON: {content}")
                continue
        if not isinstance(content, list):
            continue
        for part in content:
            if not isinstance(part, dict):
                continue
            if part.get("type") != "image_url":
                continue
            image_url_entry = part.get("image_url")
            if isinstance(image_url_entry, dict) and "url" in image_url_entry:
                urls.append(image_url_entry["url"])
    return urls
|
|
630
|
+
|
|
631
|
+
def span_to_triplet(self, span: Span, agent_name: str) -> Triplet:
    """Convert a span to a triplet.

    Subclass can override this method to add more fields to the triplet,
    such as chat messages and tool calls.

    Args:
        span: The LLM call span whose attributes are mined for tokens and metadata.
        agent_name: Name of the agent the call belongs to; stored in the triplet metadata.
    """
    # Token IDs can live under different attribute keys depending on the tracer.
    prompt_key_candidates = [
        "prompt_token_ids",
        "mantisdk.operation.output.prompt_token_ids",  # Weave tracer
    ]
    response_key_candidates = [
        "response_token_ids",
        "mantisdk.operation.output.response_token_ids.0",  # Weave tracer
        "mantisdk.operation.output.choices.0.token_ids",  # Weave tracer with newer vLLM
        "mantisdk.operation.output.choices.0.provider_specific_fields.token_ids",  # new vLLM + new OpenAI client SDK
    ]
    prompt_token_ids = _attributes_get_ids_multiple(span.attributes, prompt_key_candidates) or []
    response_token_ids = _attributes_get_ids_multiple(span.attributes, response_key_candidates) or []

    response_id = _attributes_get_multiple(
        span.attributes, ["gen_ai.response.id", "mantisdk.operation.output.id"]
    )
    request_metadata = _attributes_unflatten_multiple(
        span.attributes, ["gen_ai.request", "mantisdk.operation.input"]
    )
    response_metadata = _attributes_unflatten_multiple(
        span.attributes, ["gen_ai.response", "mantisdk.operation.output"]
    )
    # Special handling for Weave tracer: messages are handled separately.
    if isinstance(request_metadata, dict):
        request_metadata.pop("messages", None)
    if isinstance(response_metadata, dict):
        for redundant_key in ("choices", "prompt_token_ids", "response_token_ids"):
            response_metadata.pop(redundant_key, None)

    prompt_raw_content = _attributes_unflatten_multiple(
        span.attributes, ["gen_ai.prompt", "mantisdk.operation.input.messages"]
    )
    completion_raw_content = _attributes_unflatten_multiple(
        span.attributes, ["gen_ai.completion", "mantisdk.operation.output.choices"]
    )
    prompt_payload = {
        "token_ids": prompt_token_ids,
        "raw_content": prompt_raw_content,
        "image_urls": self.extract_prompt_image_urls(prompt_raw_content),
    }
    response_payload = {"token_ids": response_token_ids, "raw_content": completion_raw_content}

    # FIXME: logprob doesn't support Weave tracer yet.
    logprobs_content = span.attributes.get("logprobs.content", None)  # type: ignore
    if isinstance(logprobs_content, str):
        # Only attach logprobs when a JSON-encoded string is present.
        response_payload["logprobs"] = json.loads(logprobs_content)

    return Triplet(
        prompt=prompt_payload,
        response=response_payload,
        reward=None,
        metadata=dict(
            request=request_metadata, response=response_metadata, response_id=response_id, agent_name=agent_name
        ),
    )
|
|
701
|
+
|
|
702
|
+
def to_trajectory(
    self,
    llm_call_match: str = r"openai\.chat\.completion",
    agent_match: Optional[str] = None,
    exclude_llm_call_in_reward: bool = True,
    dedup_llm_call: bool = True,
    reward_match: RewardMatchPolicy = RewardMatchPolicy.FIRST_OCCURRENCE,
    final_reward: Optional[float] = None,
    _skip_empty_token_spans: bool = False,
) -> List[Triplet]:
    """Convert the trace tree into a trajectory of [`Triplet`][mantisdk.Triplet] items.

    Args:
        llm_call_match: Regular expression for LLM call span names.
        agent_match: Optional regular expression for agent span names.
        exclude_llm_call_in_reward: When `True`, prevents searching for rewards under the LLM
            call subtree.
        dedup_llm_call: When `True`, deduplicates spans using the LLM response identifier.
        reward_match: Reward matching policy used to associate reward spans with LLM calls.
        final_reward: Optional reward appended to the final transition when provided.

    Returns:
        A list of [`Triplet`][mantisdk.Triplet] objects ordered by call sequence.
    """
    # Locate all LLM call spans in the tree.
    llm_calls = self.find_llm_calls(
        llm_call_match=llm_call_match,
        agent_match=agent_match,
        within_matching_subtree="*" if agent_match is None else None,
        within_reward=False if exclude_llm_call_in_reward else None,
        within_llm_call=False if dedup_llm_call else None,
        existing_llm_call_response_ids=set(),
    )

    # Build a triplet per call, dropping calls with unrecorded token IDs when requested.
    pending: List[Tuple[str, Triplet]] = []
    kept_calls: List[Tuple[TraceTree, str]] = []
    for call_node, agent_name in llm_calls:
        triplet = self.span_to_triplet(call_node.span, agent_name)
        # This is a hot-fix for Tinker+CrewAI, which has some anonymous requests outside the trained agent.
        # TODO: We might need to reconsider this.
        tokens_missing = not triplet.prompt.get("token_ids") or not triplet.response.get("token_ids")
        if _skip_empty_token_spans and tokens_missing:
            logger.warning(f"Skipping LLM call with unrecorded token IDs: {triplet}")
            continue
        kept_calls.append((call_node, agent_name))
        pending.append((call_node.id, triplet))

    # Attach matched rewards, then optionally override the last transition's reward.
    rewards = self.match_rewards(reward_match, [node for node, _ in kept_calls])
    transitions = [
        triplet.model_copy(update={"reward": rewards.get(call_id, None)}) for call_id, triplet in pending
    ]
    if final_reward is not None and transitions:
        transitions[-1] = transitions[-1].model_copy(update={"reward": final_reward})
    return transitions
|
|
759
|
+
|
|
760
|
+
def __repr__(self):
    # Implicit string-literal concatenation; children recurse via their own __repr__.
    return (
        f"TraceTree(id={self.id}, span={self.span}, start_time={self.start_time}, "
        f"end_time={self.end_time}, children={self.children})"
    )
|
|
765
|
+
|
|
766
|
+
|
|
767
|
+
class TraceToTripletBase(TraceAdapter[List[Triplet]]):
    """Base class for adapters that emit [`Triplet`][mantisdk.Triplet] trajectories.

    Concrete subclasses (e.g. tracer- or proxy-based adapters) implement
    ``adapt()`` to turn a sequence of trace spans into an ordered list of triplets.
    """
|
|
769
|
+
|
|
770
|
+
|
|
771
|
+
class TracerTraceToTriplet(TraceToTripletBase):
    """Convert tracer-emitted spans into triplet trajectories.

    Attributes:
        repair_hierarchy: When `True`, repair the span tree using
            [`TraceTree.repair_hierarchy()`][mantisdk.adapter.triplet.TraceTree.repair_hierarchy]
            before matching calls and rewards.
        llm_call_match: Regular expression pattern that selects LLM call span names.
        agent_match: Optional regular expression pattern for agent span names. When omitted, spans
            from any agent are considered.
        exclude_llm_call_in_reward: When `True`, ignore matches under reward spans while searching
            for rewards.
        reward_match: Strategy used to associate rewards with LLM calls.
    """

    def __init__(
        self,
        repair_hierarchy: bool = True,
        llm_call_match: str = r"openai\.chat\.completion",
        agent_match: Optional[str] = None,
        exclude_llm_call_in_reward: bool = True,
        reward_match: RewardMatchPolicy = RewardMatchPolicy.FIRST_OCCURRENCE,
        _skip_empty_token_spans: bool = False,
    ):
        self.repair_hierarchy = repair_hierarchy
        self.llm_call_match = llm_call_match
        self.agent_match = agent_match
        self.exclude_llm_call_in_reward = exclude_llm_call_in_reward
        self.reward_match = reward_match
        self._skip_empty_token_spans = _skip_empty_token_spans

    def _build_trace_tree(self, source: Union[Sequence[Span], Sequence[ReadableSpan]]) -> TraceTree:
        """Normalize raw OpenTelemetry spans and build the (optionally repaired) trace tree.

        Shared by `visualize` and `adapt` so that normalization and hierarchy
        repair stay consistent between the two entry points (the logic was
        previously duplicated in both).
        """
        source_normalized = [
            Span.from_opentelemetry(span, "dummy", "dummy", 0) if isinstance(span, ReadableSpan) else span
            for span in source
        ]
        trace_tree = TraceTree.from_spans(source_normalized)
        if self.repair_hierarchy:
            trace_tree.repair_hierarchy()
        return trace_tree

    def visualize(
        self,
        source: Union[List[Span], List[ReadableSpan]],
        /,
        filename: str = "trace_tree",
        interested_span_match: str | None = None,
    ) -> TraceTree:
        """Visualize the trace tree built from the supplied spans.

        Args:
            source: Collection of Mantisdk [`Span`][mantisdk.Span] objects
                or raw `opentelemetry.sdk.trace.ReadableSpan` instances.
            filename: Base filename for the generated image; `.png` is appended automatically.
            interested_span_match: Optional regular expression used to highlight a subset of spans.

        Returns:
            The [`TraceTree`][mantisdk.adapter.triplet.TraceTree] built from the provided
            spans.
        """
        trace_tree = self._build_trace_tree(source)
        trace_tree.visualize(filename, interested_span_match=interested_span_match)
        return trace_tree

    def adapt(self, source: Union[Sequence[Span], Sequence[ReadableSpan]], /) -> List[Triplet]:  # type: ignore
        """Convert tracer spans into [`Triplet`][mantisdk.Triplet] trajectories.

        Args:
            source: Mantisdk spans or raw OpenTelemetry spans that form a trace.

        Returns:
            Ordered list of trajectory transitions with prompt, response, and reward information.
        """
        return self._build_trace_tree(source).to_trajectory(
            llm_call_match=self.llm_call_match,
            agent_match=self.agent_match,
            exclude_llm_call_in_reward=self.exclude_llm_call_in_reward,
            reward_match=self.reward_match,
            _skip_empty_token_spans=self._skip_empty_token_spans,
        )
|
|
855
|
+
|
|
856
|
+
|
|
857
|
+
class LlmProxyTraceToTriplet(TraceToTripletBase):
|
|
858
|
+
"""Convert telemetry emitted by the LLM Proxy into triplet trajectories.
|
|
859
|
+
|
|
860
|
+
!!! warning
|
|
861
|
+
This adapter is experimental and might be merged with
|
|
862
|
+
[`TracerTraceToTriplet`][mantisdk.TracerTraceToTriplet] in the future.
|
|
863
|
+
|
|
864
|
+
!!! danger
|
|
865
|
+
Do not rely on timestamps when using this adapter. Proxy spans can originate on different
|
|
866
|
+
machines with unsynchronised clocks, so `sequence_id` is treated as the sole source of
|
|
867
|
+
ordering.
|
|
868
|
+
|
|
869
|
+
Strategy:
|
|
870
|
+
|
|
871
|
+
1. Sort spans by `(sequence_id, start_time)` for deterministic processing.
|
|
872
|
+
2. Extract token identifiers from `litellm_request` or `raw_gen_ai_request` spans.
|
|
873
|
+
3. Extract rewards from spans exposing AgentOps-style payloads or explicit reward spans.
|
|
874
|
+
4. Match each reward to the most recent unmatched LLM call whose sequence is smaller.
|
|
875
|
+
"""
|
|
876
|
+
|
|
877
|
+
def _literal_eval_maybe(self, v: Any) -> Any:
|
|
878
|
+
import ast
|
|
879
|
+
|
|
880
|
+
if isinstance(v, str):
|
|
881
|
+
try:
|
|
882
|
+
return ast.literal_eval(v)
|
|
883
|
+
except Exception:
|
|
884
|
+
return v
|
|
885
|
+
return v
|
|
886
|
+
|
|
887
|
+
def _extract_tokens_from_raw(self, attrs: Dict[str, Any]) -> Tuple[List[int], List[int]]:
    """Extract token ids from raw_gen_ai_request attributes.

    - llm.hosted_vllm.prompt_token_ids: string -> List[int]
    - llm.hosted_vllm.response_token_ids: string -> List[List[int]] -> take first
    - llm.hosted_vllm.choices: string -> [{'token_ids': [...]}] -> take first
    """
    prompt_tokens: List[int] = []
    response_tokens: List[int] = []

    # Prompt token IDs arrive as a stringified list of ints.
    raw_prompt = self._literal_eval_maybe(attrs.get("llm.hosted_vllm.prompt_token_ids"))
    if isinstance(raw_prompt, list) and all(isinstance(t, int) for t in raw_prompt):
        prompt_tokens = raw_prompt

    # Preferred path: response_token_ids is a list of per-choice token lists.
    raw_response = self._literal_eval_maybe(attrs.get("llm.hosted_vllm.response_token_ids"))
    if isinstance(raw_response, list) and len(raw_response) > 0 and isinstance(raw_response[0], list):
        first_choice_tokens = raw_response[0]
        if all(isinstance(t, int) for t in first_choice_tokens):
            response_tokens = first_choice_tokens

    # Fallback: pull token_ids from the first entry of the choices payload.
    if not response_tokens:
        raw_choices = self._literal_eval_maybe(attrs.get("llm.hosted_vllm.choices"))
        if isinstance(raw_choices, list) and raw_choices:
            first_choice = raw_choices[0]
            if isinstance(first_choice, dict):
                token_ids = first_choice.get("token_ids")
                if isinstance(token_ids, list) and all(isinstance(t, int) for t in token_ids):
                    response_tokens = token_ids

    return prompt_tokens, response_tokens
|
|
923
|
+
|
|
924
|
+
def _extract_tokens_from_openai(self, attrs: Dict[str, Any]) -> Tuple[List[int], List[int]]:
    """Extract prompt/response token ids from OpenAI-style attributes.

    Values may be stringified lists; anything that does not decode to a flat
    list of ints is replaced by an empty list.
    """
    prompt_tokens = self._literal_eval_maybe(attrs.get("prompt_token_ids") or [])
    response_tokens = self._literal_eval_maybe(attrs.get("response_token_ids") or [])
    if not (isinstance(prompt_tokens, list) and all(isinstance(t, int) for t in prompt_tokens)):
        prompt_tokens = []
    if not (isinstance(response_tokens, list) and all(isinstance(t, int) for t in response_tokens)):
        response_tokens = []
    return prompt_tokens, response_tokens
|
|
934
|
+
|
|
935
|
+
def _maybe_reward_value(self, span: Span) -> Optional[float]:
    """Parse reward from typical AgentOps payloads or explicit reward spans."""
    # Delegates to the module-level helper so all adapters share one reward-parsing rule.
    return get_reward_value(span)
|
|
938
|
+
|
|
939
|
+
def _request_id_from_attrs(self, attrs: Dict[str, Any]) -> Optional[str]:
|
|
940
|
+
# Prefer OpenAI-like id if present, else proxy raw id.
|
|
941
|
+
rid = attrs.get("gen_ai.response.id") or attrs.get("llm.hosted_vllm.id")
|
|
942
|
+
return str(rid) if isinstance(rid, str) and rid else None
|
|
943
|
+
|
|
944
|
+
def adapt(self, source: Sequence[Span], /) -> List[Triplet]:  # type: ignore
    """Convert LLM Proxy spans into [`Triplet`][mantisdk.Triplet] trajectories.

    Args:
        source: Spans emitted by the LLM Proxy containing prompt, response, and reward data.

    Returns:
        Ordered trajectory transitions matched purely by `sequence_id`.
    """
    # 1) Deterministic processing order: sequence id first, start time as tie-break.
    ordered_spans = sorted(source, key=lambda s: (s.sequence_id, s.start_time))

    # 2) Gather LLM calls that carry token IDs, deduplicating by request id.
    calls: List[Dict[str, Any]] = []
    known_request_ids: set[str] = set()
    for span in ordered_spans:
        attributes = span.attributes or {}

        if span.name == "raw_gen_ai_request":
            prompt_tokens, response_tokens = self._extract_tokens_from_raw(attributes)
        elif span.name == "litellm_request":
            # Some proxies never include token ids here. Ignore unless present.
            prompt_tokens, response_tokens = self._extract_tokens_from_openai(attributes)
        else:
            prompt_tokens, response_tokens = [], []

        if not (prompt_tokens and response_tokens):
            continue
        request_id = self._request_id_from_attrs(attributes)
        if request_id:
            if request_id in known_request_ids:
                # Duplicated request ID. This request is already handled.
                continue
            known_request_ids.add(request_id)
        calls.append(
            dict(
                span=span,
                seq=span.sequence_id,
                response_ids=response_tokens,
                prompt_ids=prompt_tokens,
                request_id=request_id,
            )
        )

    # Sequence id is the only ordering signal we trust across machines.
    calls.sort(key=lambda c: c["seq"])

    # 3) Collect reward events, again keyed by sequence id only.
    reward_events: List[Tuple[int, Optional[float]]] = []
    for span in ordered_spans:
        reward_value = self._maybe_reward_value(span)
        if reward_value is not None:
            reward_events.append((span.sequence_id, reward_value))

    # 4) First-occurrence matching: each reward goes to the most recent
    #    still-unmatched LLM call whose sequence id is smaller.
    reward_by_span_id: Dict[str, Optional[float]] = {}
    for reward_seq, value in sorted(reward_events, key=lambda e: e[0]):
        for call in reversed(calls):
            span_identifier = call["span"].span_id
            if span_identifier in reward_by_span_id:
                continue
            if call["seq"] < reward_seq:
                reward_by_span_id[span_identifier] = value
                break

    # 5) Emit triplets in LLM call order.
    return [
        Triplet(
            prompt={"token_ids": call["prompt_ids"]},
            response={"token_ids": call["response_ids"]},
            reward=reward_by_span_id.get(call["span"].span_id, None),
            metadata=dict(
                # This is called response_id to align with the other adapters.
                response_id=call["request_id"],
            ),
        )
        for call in calls
    ]
|