ragbits-evaluate 0.0.30rc1__py3-none-any.whl → 1.4.0.dev202602030301__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragbits/evaluate/agent_simulation/__init__.py +4 -49
- ragbits/evaluate/agent_simulation/conversation.py +278 -663
- ragbits/evaluate/agent_simulation/logger.py +1 -1
- ragbits/evaluate/agent_simulation/metrics/__init__.py +0 -10
- ragbits/evaluate/agent_simulation/metrics/builtin.py +49 -59
- ragbits/evaluate/agent_simulation/metrics/collectors.py +17 -37
- ragbits/evaluate/agent_simulation/models.py +18 -198
- ragbits/evaluate/agent_simulation/results.py +49 -125
- ragbits/evaluate/agent_simulation/scenarios.py +19 -95
- ragbits/evaluate/agent_simulation/simulation.py +166 -72
- ragbits/evaluate/metrics/question_answer.py +25 -8
- {ragbits_evaluate-0.0.30rc1.dist-info → ragbits_evaluate-1.4.0.dev202602030301.dist-info}/METADATA +2 -6
- {ragbits_evaluate-0.0.30rc1.dist-info → ragbits_evaluate-1.4.0.dev202602030301.dist-info}/RECORD +14 -25
- ragbits/evaluate/agent_simulation/checkers.py +0 -591
- ragbits/evaluate/agent_simulation/display.py +0 -118
- ragbits/evaluate/agent_simulation/metrics/deepeval.py +0 -295
- ragbits/evaluate/agent_simulation/tracing.py +0 -233
- ragbits/evaluate/api.py +0 -603
- ragbits/evaluate/api_types.py +0 -343
- ragbits/evaluate/execution_manager.py +0 -451
- ragbits/evaluate/stores/__init__.py +0 -36
- ragbits/evaluate/stores/base.py +0 -98
- ragbits/evaluate/stores/file.py +0 -466
- ragbits/evaluate/stores/kv.py +0 -535
- {ragbits_evaluate-0.0.30rc1.dist-info → ragbits_evaluate-1.4.0.dev202602030301.dist-info}/WHEEL +0 -0
|
@@ -1,118 +0,0 @@
|
|
|
1
|
-
"""Rich display components for agent simulation."""
|
|
2
|
-
|
|
3
|
-
from rich.console import Console
|
|
4
|
-
from rich.live import Live
|
|
5
|
-
from rich.panel import Panel
|
|
6
|
-
from rich.text import Text
|
|
7
|
-
|
|
8
|
-
from ragbits.evaluate.agent_simulation.models import Scenario
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
def display_scenario(scenario: Scenario, console: Console | None = None) -> None:
    """Print a one-off Rich panel summarising ``scenario``.

    Args:
        scenario: The scenario whose tasks are listed in the panel.
        console: Rich console to print to. When omitted, a fresh ``Console()``
            is created for this call.
    """
    target = Console() if console is None else console
    panel = _build_panel(scenario)
    target.print(panel)
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
def _build_panel(
    scenario: Scenario,
    current_task_idx: int | None = None,
    task_status: dict[int, str] | None = None,
    metrics: dict[str, str | int | float] | None = None,
) -> Panel:
    """Build a rich panel for the scenario.

    Each task is rendered on its own line, prefixed with a status marker. An
    explicit entry in ``task_status`` always wins. Otherwise the marker is
    derived from ``current_task_idx``: ``▶`` for the running task, ``✓`` for
    earlier tasks, and a blank for the rest.

    Args:
        scenario: The scenario to display.
        current_task_idx: Index of currently running task (for live display).
        task_status: Dict mapping task index to status emoji/text.
        metrics: Optional metrics to display at the bottom.

    Returns:
        Rich Panel object.
    """
    from rich.markup import escape

    lines = Text()

    for i, task in enumerate(scenario.tasks):
        # Status indicator: explicit status first, then position relative to the running task.
        if task_status and i in task_status:
            status = task_status[i]
        elif current_task_idx is not None and i == current_task_idx:
            status = "▶"
        elif current_task_idx is not None and i < current_task_idx:
            status = "✓"
        else:
            status = " "

        style = "bold" if current_task_idx == i else ""
        # Text.append inserts the string literally, so task text needs no escaping here.
        lines.append(f"{status} {i + 1}. {task.task}\n", style=style)
        # Show checker configuration summary
        if task.checkers:
            lines.append(f" → {task.get_checker_summary()}\n", style="green")

    if metrics:
        lines.append("\n")
        for key, value in metrics.items():
            lines.append(f"{key}: {value} ", style="cyan")

    # A str Panel title is parsed as Rich markup. Escape the user-supplied name
    # and group so brackets in them are shown literally instead of being treated
    # as (possibly invalid) tags; only the deliberate [dim] wrapper is markup.
    title = escape(scenario.name)
    if scenario.group:
        title += f" [dim]({escape(str(scenario.group))})[/dim]"

    return Panel(lines, title=title, border_style="blue")
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
class ScenarioLiveDisplay:
    """Live-updating Rich view of a scenario while it executes.

    Intended to be used as a context manager. The mutator methods record state
    (current task, per-task status, metrics) and immediately re-render the
    panel when the live display is active.
    """

    def __init__(self, scenario: Scenario, console: Console | None = None) -> None:
        self.scenario = scenario
        self.console = console or Console()
        # Index of the task currently running, if any.
        self.current_task_idx: int | None = None
        # Explicit per-task markers that take precedence over position-derived status.
        self.task_status: dict[int, str] = {}
        self.metrics: dict[str, str | int | float] = {}
        # Populated only between __enter__ and __exit__.
        self._live: Live | None = None

    def __enter__(self) -> "ScenarioLiveDisplay":
        live = Live(self._render(), console=self.console, refresh_per_second=4)
        self._live = live
        live.__enter__()
        return self

    def __exit__(self, *args: object) -> None:
        live = self._live
        if live:
            live.__exit__(*args)

    def _render(self) -> Panel:
        """Build the panel from the current display state."""
        state = {
            "current_task_idx": self.current_task_idx,
            "task_status": self.task_status,
            "metrics": self.metrics,
        }
        return _build_panel(self.scenario, **state)

    def update(self) -> None:
        """Refresh the display."""
        if not self._live:
            return
        self._live.update(self._render())

    def set_current_task(self, idx: int) -> None:
        """Set the currently running task index."""
        self.current_task_idx = idx
        self.update()

    def mark_task_done(self, idx: int, success: bool = True) -> None:
        """Mark a task as completed (✓ on success, ✗ on failure)."""
        marker = "✓" if success else "✗"
        self.task_status[idx] = marker
        self.update()

    def set_metric(self, key: str, value: str | int | float) -> None:
        """Update a metric value."""
        self.metrics[key] = value
        self.update()
|
|
@@ -1,295 +0,0 @@
|
|
|
1
|
-
"""DeepEval metric collectors following the MetricCollector protocol."""
|
|
2
|
-
|
|
3
|
-
from __future__ import annotations
|
|
4
|
-
|
|
5
|
-
from typing import TYPE_CHECKING, Any
|
|
6
|
-
|
|
7
|
-
from ragbits.evaluate.agent_simulation.metrics.collectors import MetricCollector
|
|
8
|
-
|
|
9
|
-
if TYPE_CHECKING:
|
|
10
|
-
from ragbits.evaluate.agent_simulation.results import TurnResult
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class DeepEvalCompletenessMetricCollector(MetricCollector):
    """Score conversation completeness with DeepEval's ConversationCompletenessMetric.

    Every finished turn is buffered as a ``(user, assistant)`` pair. The metric
    is computed once, when the conversation ends, to judge how well the
    assistant addressed the user's requests across the whole conversation.

    Example:
        >>> result = await run_simulation(
        ...     scenario=scenario,
        ...     chat=chat,
        ...     config=SimulationConfig(metrics=[DeepEvalCompletenessMetricCollector]),
        ... )
        >>> print(result.metrics.custom.get("deepeval_completeness"))
    """

    def __init__(self) -> None:
        """Create the collector with an empty turn buffer."""
        # Buffered (user message, assistant message) pairs, in turn order.
        self._turns: list[tuple[str, str]] = []

    def on_turn_start(self, turn_index: int, task_index: int, user_message: str) -> None:
        """Ignore turn starts; messages are captured in ``on_turn_end``.

        Args:
            turn_index: 1-based index of the current turn.
            task_index: 0-based index of the current task.
            user_message: The user message (stored in on_turn_end).
        """
        return None

    def on_turn_end(self, turn_result: TurnResult) -> None:
        """Buffer the completed turn's user and assistant messages.

        Args:
            turn_result: The result of the completed turn.
        """
        pair = (turn_result.user_message, turn_result.assistant_message)
        self._turns.append(pair)

    def on_conversation_end(self, all_turns: list[TurnResult]) -> dict[str, Any]:
        """Run DeepEval's completeness metric over the buffered turns.

        Args:
            all_turns: List of all turn results (unused; the internal buffer is evaluated).

        Returns:
            ``{}`` if no turns were recorded. Otherwise ``deepeval_completeness``
            and ``deepeval_completeness_reason``. If DeepEval cannot be imported
            or evaluation raises, ``deepeval_completeness`` is ``None`` and
            ``deepeval_completeness_error`` holds the exception text.
        """
        if not self._turns:
            return {}

        try:
            from deepeval.metrics import ConversationCompletenessMetric  # type: ignore[attr-defined]
            from deepeval.test_case import ConversationalTestCase, LLMTestCase  # type: ignore[attr-defined]

            test_case = ConversationalTestCase(
                turns=[LLMTestCase(input=question, actual_output=answer) for question, answer in self._turns]
            )
            metric = ConversationCompletenessMetric()
            metric.measure(test_case)
            score = metric.score
            reason = getattr(metric, "reason", None)
        except Exception as exc:
            # Best-effort: a missing or failing DeepEval must not abort the simulation.
            return {
                "deepeval_completeness": None,
                "deepeval_completeness_error": str(exc),
            }

        return {
            "deepeval_completeness": score,
            "deepeval_completeness_reason": reason,
        }

    def reset(self) -> None:
        """Discard buffered turns so the collector can be reused."""
        self._turns = []
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
class DeepEvalRelevancyMetricCollector(MetricCollector):
    """Score conversation relevancy with DeepEval's ConversationRelevancyMetric.

    Turns are buffered as ``(user, assistant)`` pairs and evaluated in one pass
    when the conversation ends, judging how relevant the assistant's replies
    were to the user's queries.

    Example:
        >>> result = await run_simulation(
        ...     scenario=scenario,
        ...     chat=chat,
        ...     config=SimulationConfig(metrics=[DeepEvalRelevancyMetricCollector]),
        ... )
        >>> print(result.metrics.custom.get("deepeval_relevancy"))
    """

    def __init__(self) -> None:
        """Create the collector with an empty turn buffer."""
        # Buffered (user message, assistant message) pairs, in turn order.
        self._turns: list[tuple[str, str]] = []

    def on_turn_start(self, turn_index: int, task_index: int, user_message: str) -> None:
        """Ignore turn starts; messages are captured in ``on_turn_end``.

        Args:
            turn_index: 1-based index of the current turn.
            task_index: 0-based index of the current task.
            user_message: The user message (stored in on_turn_end).
        """
        return None

    def on_turn_end(self, turn_result: TurnResult) -> None:
        """Buffer the completed turn's user and assistant messages.

        Args:
            turn_result: The result of the completed turn.
        """
        pair = (turn_result.user_message, turn_result.assistant_message)
        self._turns.append(pair)

    def on_conversation_end(self, all_turns: list[TurnResult]) -> dict[str, Any]:
        """Run DeepEval's relevancy metric over the buffered turns.

        Args:
            all_turns: List of all turn results (unused; the internal buffer is evaluated).

        Returns:
            ``{}`` if no turns were recorded. Otherwise ``deepeval_relevancy``
            and ``deepeval_relevancy_reason``. If DeepEval cannot be imported or
            evaluation raises, ``deepeval_relevancy`` is ``None`` and
            ``deepeval_relevancy_error`` holds the exception text.
        """
        if not self._turns:
            return {}

        try:
            from deepeval.metrics import ConversationRelevancyMetric  # type: ignore[attr-defined]
            from deepeval.test_case import ConversationalTestCase, LLMTestCase  # type: ignore[attr-defined]

            test_case = ConversationalTestCase(
                turns=[LLMTestCase(input=question, actual_output=answer) for question, answer in self._turns]
            )
            metric = ConversationRelevancyMetric()
            metric.measure(test_case)
            score = metric.score
            reason = getattr(metric, "reason", None)
        except Exception as exc:
            # Best-effort: a missing or failing DeepEval must not abort the simulation.
            return {
                "deepeval_relevancy": None,
                "deepeval_relevancy_error": str(exc),
            }

        return {
            "deepeval_relevancy": score,
            "deepeval_relevancy_reason": reason,
        }

    def reset(self) -> None:
        """Discard buffered turns so the collector can be reused."""
        self._turns = []
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
class DeepEvalKnowledgeRetentionMetricCollector(MetricCollector):
    """Score knowledge retention with DeepEval's KnowledgeRetentionMetric.

    Turns are buffered as ``(user, assistant)`` pairs and evaluated in one pass
    when the conversation ends, judging how well the assistant kept and reused
    information from earlier in the conversation.

    Example:
        >>> result = await run_simulation(
        ...     scenario=scenario,
        ...     chat=chat,
        ...     config=SimulationConfig(metrics=[DeepEvalKnowledgeRetentionMetricCollector]),
        ... )
        >>> print(result.metrics.custom.get("deepeval_knowledge_retention"))
    """

    def __init__(self) -> None:
        """Create the collector with an empty turn buffer."""
        # Buffered (user message, assistant message) pairs, in turn order.
        self._turns: list[tuple[str, str]] = []

    def on_turn_start(self, turn_index: int, task_index: int, user_message: str) -> None:
        """Ignore turn starts; messages are captured in ``on_turn_end``.

        Args:
            turn_index: 1-based index of the current turn.
            task_index: 0-based index of the current task.
            user_message: The user message (stored in on_turn_end).
        """
        return None

    def on_turn_end(self, turn_result: TurnResult) -> None:
        """Buffer the completed turn's user and assistant messages.

        Args:
            turn_result: The result of the completed turn.
        """
        pair = (turn_result.user_message, turn_result.assistant_message)
        self._turns.append(pair)

    def on_conversation_end(self, all_turns: list[TurnResult]) -> dict[str, Any]:
        """Run DeepEval's knowledge-retention metric over the buffered turns.

        Args:
            all_turns: List of all turn results (unused; the internal buffer is evaluated).

        Returns:
            ``{}`` if no turns were recorded. Otherwise
            ``deepeval_knowledge_retention`` and
            ``deepeval_knowledge_retention_reason``. If DeepEval cannot be
            imported or evaluation raises, ``deepeval_knowledge_retention`` is
            ``None`` and ``deepeval_knowledge_retention_error`` holds the
            exception text.
        """
        if not self._turns:
            return {}

        try:
            from deepeval.metrics import KnowledgeRetentionMetric  # type: ignore[attr-defined]
            from deepeval.test_case import ConversationalTestCase, LLMTestCase  # type: ignore[attr-defined]

            test_case = ConversationalTestCase(
                turns=[LLMTestCase(input=question, actual_output=answer) for question, answer in self._turns]
            )
            metric = KnowledgeRetentionMetric()
            metric.measure(test_case)
            score = metric.score
            reason = getattr(metric, "reason", None)
        except Exception as exc:
            # Best-effort: a missing or failing DeepEval must not abort the simulation.
            return {
                "deepeval_knowledge_retention": None,
                "deepeval_knowledge_retention_error": str(exc),
            }

        return {
            "deepeval_knowledge_retention": score,
            "deepeval_knowledge_retention_reason": reason,
        }

    def reset(self) -> None:
        """Discard buffered turns so the collector can be reused."""
        self._turns = []
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
class DeepEvalAllMetricsCollector(MetricCollector):
    """Run every DeepEval conversation metric through a single collector.

    Wraps the completeness, relevancy and knowledge-retention collectors,
    forwards each lifecycle hook to all three (in that order), and merges their
    end-of-conversation results into one dictionary.

    Example:
        >>> result = await run_simulation(
        ...     scenario=scenario,
        ...     chat=chat,
        ...     config=SimulationConfig(metrics=[DeepEvalAllMetricsCollector]),
        ... )
        >>> print(result.metrics.custom.get("deepeval_completeness"))
        >>> print(result.metrics.custom.get("deepeval_relevancy"))
        >>> print(result.metrics.custom.get("deepeval_knowledge_retention"))
    """

    def __init__(self) -> None:
        """Create one instance of each child DeepEval collector."""
        self._completeness = DeepEvalCompletenessMetricCollector()
        self._relevancy = DeepEvalRelevancyMetricCollector()
        self._knowledge_retention = DeepEvalKnowledgeRetentionMetricCollector()

    def _children(self) -> tuple[MetricCollector, ...]:
        """Return the child collectors in their fixed delegation order."""
        return (self._completeness, self._relevancy, self._knowledge_retention)

    def on_turn_start(self, turn_index: int, task_index: int, user_message: str) -> None:
        """Forward the turn start to every child collector.

        Args:
            turn_index: 1-based index of the current turn.
            task_index: 0-based index of the current task.
            user_message: The user message.
        """
        for child in self._children():
            child.on_turn_start(turn_index, task_index, user_message)

    def on_turn_end(self, turn_result: TurnResult) -> None:
        """Forward the finished turn to every child collector.

        Args:
            turn_result: The result of the completed turn.
        """
        for child in self._children():
            child.on_turn_end(turn_result)

    def on_conversation_end(self, all_turns: list[TurnResult]) -> dict[str, Any]:
        """Merge the end-of-conversation results of all child collectors.

        Args:
            all_turns: List of all turn results.

        Returns:
            Dictionary combining all DeepEval metrics (later children win on key clashes).
        """
        combined: dict[str, Any] = {}
        for child in self._children():
            combined.update(child.on_conversation_end(all_turns))
        return combined

    def reset(self) -> None:
        """Reset every child collector."""
        for child in self._children():
            child.reset()
|
|
@@ -1,233 +0,0 @@
|
|
|
1
|
-
"""Tracing utilities for agent simulation.
|
|
2
|
-
|
|
3
|
-
Provides context managers and analyzers for capturing and analyzing
|
|
4
|
-
LLM calls, tool invocations, and token usage during simulation runs.
|
|
5
|
-
"""
|
|
6
|
-
|
|
7
|
-
from collections.abc import Iterator
|
|
8
|
-
from contextlib import contextmanager
|
|
9
|
-
from contextvars import Token
|
|
10
|
-
from dataclasses import dataclass
|
|
11
|
-
from typing import Any
|
|
12
|
-
|
|
13
|
-
from ragbits.agents.tool import ToolCallResult
|
|
14
|
-
from ragbits.core.audit.traces import MemoryTraceHandler, set_trace_handlers
|
|
15
|
-
from ragbits.core.audit.traces.memory import TraceSpan, _TraceSession, _current_session
|
|
16
|
-
from ragbits.core.llms import Usage
|
|
17
|
-
from ragbits.core.llms.base import UsageItem
|
|
18
|
-
|
|
19
|
-
__all__ = [
|
|
20
|
-
"LLMCall",
|
|
21
|
-
"MemoryTraceHandler",
|
|
22
|
-
"TraceAnalyzer",
|
|
23
|
-
"TraceSpan",
|
|
24
|
-
"collect_traces",
|
|
25
|
-
]
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
@dataclass
|
|
29
|
-
class LLMCall:
|
|
30
|
-
"""Represents a single LLM call extracted from traces."""
|
|
31
|
-
|
|
32
|
-
model: str
|
|
33
|
-
prompt_tokens: int
|
|
34
|
-
completion_tokens: int
|
|
35
|
-
total_tokens: int
|
|
36
|
-
duration_ms: float | None = None
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
@contextmanager
def collect_traces(simulation_id: str | None = None) -> Iterator[MemoryTraceHandler]:
    """Context manager for collecting traces during a simulation.

    Installs a fresh context-local trace session and registers a
    ``MemoryTraceHandler`` for the duration of the ``with`` block. On exit, the
    previous session is restored and the handler is unregistered, so repeated
    calls do not leave handlers behind.

    Args:
        simulation_id: Optional identifier for the simulation run. Not used by
            this function itself.

    Yields:
        MemoryTraceHandler instance that captures traces for this context.

    Example:
        with collect_traces(simulation_id="sim-123") as handler:
            # Run simulation code here
            traces = handler.get_traces()
    """
    # Create a new session for this context
    session = _TraceSession()
    token: Token[_TraceSession | None] = _current_session.set(session)

    handler = MemoryTraceHandler()
    try:
        # Registered inside the try so the session is reset even if registration fails.
        set_trace_handlers(handler)
        yield handler
    finally:
        # Restore previous session state first so it happens even if cleanup below fails.
        _current_session.reset(token)
        _unregister_trace_handler(handler)


def _unregister_trace_handler(handler: MemoryTraceHandler) -> None:
    """Best-effort removal of ``handler`` from the global trace-handler registry."""
    # NOTE(review): set_trace_handlers() presumably appends to a module-level list
    # (`_trace_handlers` in ragbits.core.audit.traces) and nothing here removes
    # it again. Confirm the attribute name. The lookup is guarded, so this is a
    # no-op if the registry is stored differently.
    from ragbits.core.audit import traces as traces_module

    registered = getattr(traces_module, "_trace_handlers", None)
    if isinstance(registered, list) and handler in registered:
        registered.remove(handler)
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
class TraceAnalyzer:
    """Analyzes trace spans to extract tool calls and usage information.

    This class processes trace data collected by MemoryTraceHandler to
    provide structured access to tool invocations and token usage metrics.
    """

    def __init__(self, traces: list[dict[str, Any]]) -> None:
        """Initialize the analyzer with trace data.

        Args:
            traces: List of trace span dictionaries from MemoryTraceHandler.get_traces().
        """
        self._traces = traces
        self._spans = [TraceSpan.from_dict(t) for t in traces]

    @classmethod
    def from_traces(cls, traces: list[dict[str, Any]]) -> "TraceAnalyzer":
        """Create an analyzer from trace dictionaries.

        Args:
            traces: List of trace span dictionaries.

        Returns:
            A new TraceAnalyzer instance.
        """
        return cls(traces)

    def get_tool_calls(self) -> list[ToolCallResult]:
        """Extract all tool call results from the traces.

        Searches through all spans (including nested children), depth-first, for
        tool invocation traces and extracts the tool call information.

        Returns:
            List of ToolCallResult objects representing all tool calls.
        """
        tool_calls: list[ToolCallResult] = []
        self._extract_tool_calls_recursive(self._spans, tool_calls)
        return tool_calls

    def _extract_tool_calls_recursive(self, spans: list[TraceSpan], results: list[ToolCallResult]) -> None:
        """Recursively extract tool calls from spans and their children.

        Args:
            spans: List of spans to process.
            results: List to append found tool calls to.
        """
        for span in spans:
            if self._is_tool_call_span(span):
                tool_result = self._extract_tool_call(span)
                if tool_result:
                    results.append(tool_result)

            if span.children:
                self._extract_tool_calls_recursive(span.children, results)

    def _is_tool_call_span(self, span: TraceSpan) -> bool:
        """Check if a span represents a tool call.

        Heuristic based on the span name: it must contain "tool" plus one of
        "call", "execute" or "invoke" (case-insensitive).

        Args:
            span: The span to check.

        Returns:
            True if the span is a tool call, False otherwise.
        """
        name = span.name.lower()
        return "tool" in name and any(verb in name for verb in ("call", "execute", "invoke"))

    def _extract_tool_call(self, span: TraceSpan) -> ToolCallResult | None:
        """Extract a ToolCallResult from a tool call span.

        Args:
            span: The tool call span.

        Returns:
            ToolCallResult if extraction succeeds, None otherwise.
        """
        inputs = span.inputs
        outputs = span.outputs

        # Accept either naming convention for each field; fall back to the span name.
        tool_name = inputs.get("name", inputs.get("tool_name", span.name))
        tool_id = inputs.get("id", inputs.get("tool_id", ""))
        arguments = inputs.get("arguments", inputs.get("args", {}))
        result = outputs.get("result", outputs.get("returned", None))

        if isinstance(tool_name, str):
            return ToolCallResult(
                id=str(tool_id) if tool_id else "",
                name=tool_name,
                arguments=arguments if isinstance(arguments, dict) else {},
                result=result,
            )
        return None

    def get_usage(self) -> Usage:
        """Extract aggregated token usage from the traces.

        Searches through all spans for LLM call traces and aggregates
        the token usage information.

        NOTE(review): every span carrying usage data is counted, including a
        parent and its children. If a parent span reports usage aggregated from
        its children, totals would be double-counted. Confirm against the
        tracer's span layout.

        Returns:
            Usage object with aggregated token usage across all LLM calls.
        """
        usage_items: list[UsageItem] = []
        self._extract_usage_recursive(self._spans, usage_items)
        return Usage(requests=usage_items)

    def _extract_usage_recursive(self, spans: list[TraceSpan], results: list[UsageItem]) -> None:
        """Recursively extract usage info from spans and their children.

        Args:
            spans: List of spans to process.
            results: List to append found usage items to.
        """
        for span in spans:
            usage_item = self._extract_usage_from_span(span)
            if usage_item:
                results.append(usage_item)

            if span.children:
                self._extract_usage_recursive(span.children, results)

    @staticmethod
    def _build_usage_item(data: dict[str, Any], default_model: Any) -> UsageItem:
        """Build a UsageItem from a mapping of token counts.

        When ``total_tokens`` is absent (or None), it is derived from the
        integer prompt and completion counts instead of being reported as 0.
        """
        prompt_tokens = data.get("prompt_tokens", 0)
        completion_tokens = data.get("completion_tokens", 0)
        total_tokens = data.get("total_tokens")
        if total_tokens is None:
            total_tokens = sum(count for count in (prompt_tokens, completion_tokens) if isinstance(count, int))
        return UsageItem(
            model=data.get("model", default_model),
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            estimated_cost=data.get("estimated_cost", 0.0),
        )

    def _extract_usage_from_span(self, span: TraceSpan) -> UsageItem | None:
        """Extract a UsageItem from a span if it contains usage data.

        Two layouts are recognised: a nested ``usage`` dict in the outputs, or
        token counts placed directly in the outputs.

        Args:
            span: The span to extract usage from.

        Returns:
            UsageItem if extraction succeeds, None otherwise.
        """
        outputs = span.outputs

        usage_data = outputs.get("usage", None)
        if isinstance(usage_data, dict):
            return self._build_usage_item(usage_data, "unknown")

        if "prompt_tokens" in outputs or "completion_tokens" in outputs:
            return self._build_usage_item(outputs, span.inputs.get("model", "unknown"))

        return None
|