judgeval 0.7.1__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/__init__.py +139 -12
- judgeval/api/__init__.py +501 -0
- judgeval/api/api_types.py +344 -0
- judgeval/cli.py +2 -4
- judgeval/constants.py +10 -26
- judgeval/data/evaluation_run.py +49 -26
- judgeval/data/example.py +2 -2
- judgeval/data/judgment_types.py +266 -82
- judgeval/data/result.py +4 -5
- judgeval/data/scorer_data.py +4 -2
- judgeval/data/tool.py +2 -2
- judgeval/data/trace.py +7 -50
- judgeval/data/trace_run.py +7 -4
- judgeval/{dataset.py → dataset/__init__.py} +43 -28
- judgeval/env.py +67 -0
- judgeval/{run_evaluation.py → evaluation/__init__.py} +29 -95
- judgeval/exceptions.py +27 -0
- judgeval/integrations/langgraph/__init__.py +788 -0
- judgeval/judges/__init__.py +2 -2
- judgeval/judges/litellm_judge.py +75 -15
- judgeval/judges/together_judge.py +86 -18
- judgeval/judges/utils.py +7 -21
- judgeval/{common/logger.py → logger.py} +8 -6
- judgeval/scorers/__init__.py +0 -4
- judgeval/scorers/agent_scorer.py +3 -7
- judgeval/scorers/api_scorer.py +8 -13
- judgeval/scorers/base_scorer.py +52 -32
- judgeval/scorers/example_scorer.py +1 -3
- judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +0 -14
- judgeval/scorers/judgeval_scorers/api_scorers/prompt_scorer.py +45 -20
- judgeval/scorers/judgeval_scorers/api_scorers/tool_dependency.py +2 -2
- judgeval/scorers/judgeval_scorers/api_scorers/tool_order.py +3 -3
- judgeval/scorers/score.py +21 -31
- judgeval/scorers/trace_api_scorer.py +5 -0
- judgeval/scorers/utils.py +1 -103
- judgeval/tracer/__init__.py +1075 -2
- judgeval/tracer/constants.py +1 -0
- judgeval/tracer/exporters/__init__.py +37 -0
- judgeval/tracer/exporters/s3.py +119 -0
- judgeval/tracer/exporters/store.py +43 -0
- judgeval/tracer/exporters/utils.py +32 -0
- judgeval/tracer/keys.py +67 -0
- judgeval/tracer/llm/__init__.py +1233 -0
- judgeval/{common/tracer → tracer/llm}/providers.py +5 -10
- judgeval/{local_eval_queue.py → tracer/local_eval_queue.py} +15 -10
- judgeval/tracer/managers.py +188 -0
- judgeval/tracer/processors/__init__.py +181 -0
- judgeval/tracer/utils.py +20 -0
- judgeval/trainer/__init__.py +5 -0
- judgeval/{common/trainer → trainer}/config.py +12 -9
- judgeval/{common/trainer → trainer}/console.py +2 -9
- judgeval/{common/trainer → trainer}/trainable_model.py +12 -7
- judgeval/{common/trainer → trainer}/trainer.py +119 -17
- judgeval/utils/async_utils.py +2 -3
- judgeval/utils/decorators.py +24 -0
- judgeval/utils/file_utils.py +37 -4
- judgeval/utils/guards.py +32 -0
- judgeval/utils/meta.py +14 -0
- judgeval/{common/api/json_encoder.py → utils/serialize.py} +7 -1
- judgeval/utils/testing.py +88 -0
- judgeval/utils/url.py +10 -0
- judgeval/{version_check.py → utils/version_check.py} +3 -3
- judgeval/version.py +5 -0
- judgeval/warnings.py +4 -0
- {judgeval-0.7.1.dist-info → judgeval-0.9.0.dist-info}/METADATA +12 -14
- judgeval-0.9.0.dist-info/RECORD +80 -0
- judgeval/clients.py +0 -35
- judgeval/common/__init__.py +0 -13
- judgeval/common/api/__init__.py +0 -3
- judgeval/common/api/api.py +0 -375
- judgeval/common/api/constants.py +0 -186
- judgeval/common/exceptions.py +0 -27
- judgeval/common/storage/__init__.py +0 -6
- judgeval/common/storage/s3_storage.py +0 -97
- judgeval/common/tracer/__init__.py +0 -31
- judgeval/common/tracer/constants.py +0 -22
- judgeval/common/tracer/core.py +0 -2427
- judgeval/common/tracer/otel_exporter.py +0 -108
- judgeval/common/tracer/otel_span_processor.py +0 -188
- judgeval/common/tracer/span_processor.py +0 -37
- judgeval/common/tracer/span_transformer.py +0 -207
- judgeval/common/tracer/trace_manager.py +0 -101
- judgeval/common/trainer/__init__.py +0 -5
- judgeval/common/utils.py +0 -948
- judgeval/integrations/langgraph.py +0 -844
- judgeval/judges/mixture_of_judges.py +0 -287
- judgeval/judgment_client.py +0 -267
- judgeval/rules.py +0 -521
- judgeval/scorers/judgeval_scorers/api_scorers/execution_order.py +0 -52
- judgeval/scorers/judgeval_scorers/api_scorers/hallucination.py +0 -28
- judgeval/utils/alerts.py +0 -93
- judgeval/utils/requests.py +0 -50
- judgeval-0.7.1.dist-info/RECORD +0 -82
- {judgeval-0.7.1.dist-info → judgeval-0.9.0.dist-info}/WHEEL +0 -0
- {judgeval-0.7.1.dist-info → judgeval-0.9.0.dist-info}/entry_points.txt +0 -0
- {judgeval-0.7.1.dist-info → judgeval-0.9.0.dist-info}/licenses/LICENSE.md +0 -0
@@ -1,844 +0,0 @@
|
|
1
|
-
from typing import Any, Dict, List, Optional, Sequence
|
2
|
-
from uuid import UUID
|
3
|
-
import time
|
4
|
-
import uuid
|
5
|
-
from datetime import datetime, timezone
|
6
|
-
|
7
|
-
from judgeval.common.tracer import (
|
8
|
-
TraceClient,
|
9
|
-
TraceSpan,
|
10
|
-
Tracer,
|
11
|
-
SpanType,
|
12
|
-
cost_per_token,
|
13
|
-
)
|
14
|
-
from judgeval.data.trace import TraceUsage
|
15
|
-
|
16
|
-
from langchain_core.callbacks import BaseCallbackHandler
|
17
|
-
from langchain_core.agents import AgentAction, AgentFinish
|
18
|
-
from langchain_core.outputs import LLMResult
|
19
|
-
from langchain_core.messages.base import BaseMessage
|
20
|
-
from langchain_core.documents import Document
|
21
|
-
|
22
|
-
# TODO: Figure out how to handle context variables. Current solution is to keep track of current span id in Tracer class
|
23
|
-
|
24
|
-
|
25
|
-
class JudgevalCallbackHandler(BaseCallbackHandler):
    """
    LangChain callback handler that records runs as Judgeval trace spans.

    Span hierarchy is derived from LangChain's ``run_id`` / ``parent_run_id``
    pairs. The handler lazily creates its own ``TraceClient`` on the first
    callback it receives and performs the final save of that trace when the
    root run ends. A parentless event that arrives after a trace has been
    saved resets the handler, so one instance can be reused across
    invocations.
    """

    # Make all properties ignored by LangChain's callback system
    # to prevent unexpected serialization issues.
    lc_serializable = False
    lc_kwargs: dict = {}

    # Keys moved out of span inputs into ``additional_metadata``.
    _INPUT_METADATA_FIELDS = ("tags", "metadata", "kwargs", "serialized")
    # Keys moved out of span outputs into ``additional_metadata``.
    _OUTPUT_METADATA_FIELDS = ("tags", "kwargs")
    # (suffix appended to the span name, secret-key prefix, name substrings)
    # checked in order; the first provider that matches and whose suffix is
    # not already in the name wins.
    _CHAT_PROVIDER_MARKERS = (
        ("OPENAI_API_CALL", "openai", ("openai",)),
        ("ANTHROPIC_API_CALL", "anthropic", ("anthropic", "claude")),
        ("TOGETHER_API_CALL", "together", ("together",)),
        ("GOOGLE_API_CALL", "google", ("google", "gemini")),
    )

    def __init__(self, tracer: Tracer):
        self.tracer = tracer
        # _reset_state initializes every per-execution attribute,
        # including executed_nodes.
        self._reset_state()

    def _reset_state(self) -> None:
        """Clear all per-execution state so the handler can trace a new run."""
        self._trace_client: Optional[TraceClient] = None
        # LangChain run_id -> span_id of the span created for that run.
        self._run_id_to_span_id: Dict[UUID, str] = {}
        self._span_id_to_start_time: Dict[str, float] = {}
        self._span_id_to_depth: Dict[str, int] = {}
        # run_id of the event that created the current trace client.
        self._root_run_id: Optional[UUID] = None
        self._trace_saved: bool = False
        # Tokens returned by tracer.set_current_span / set_current_trace,
        # kept so the matching reset_* call can be made when the span/trace ends.
        self.span_id_to_token: Dict[str, Any] = {}
        self.trace_id_to_token: Dict[str, Any] = {}

        # Timestamp of the most recent reset.
        self._last_reset_time: float = time.time()

        # Names of LangGraph nodes seen during the current execution.
        self.executed_nodes: List[str] = []

    def reset(self) -> None:
        """
        Manually reset the handler's execution state for reuse.

        Note: this currently clears the same state as ``reset_all``,
        including ``executed_nodes``.
        """
        self._reset_state()

    def reset_all(self) -> None:
        """Reset all handler state, including ``executed_nodes``."""
        self._reset_state()

    def _ensure_trace_client(
        self, run_id: UUID, parent_run_id: Optional[UUID], event_name: str
    ) -> Optional[TraceClient]:
        """
        Return the handler's trace client, creating it on first use.

        A parentless event arriving after the previous trace was saved marks
        a new execution, so the handler state is reset before a fresh client
        is created.

        Args:
            run_id: Run id of the triggering event; recorded as the root run
                id when a new client is created.
            parent_run_id: Parent run id of the event (``None`` at top level).
            event_name: Initial trace name; ``on_chain_start`` may rename the
                trace later for the root run.

        Returns:
            The active ``TraceClient``, or ``None`` if it could not be created.
        """
        if parent_run_id is None and self._trace_saved:
            self._reset_state()

        if self._trace_client is not None:
            return self._trace_client

        trace_id = str(uuid.uuid4())
        try:
            client = TraceClient(
                self.tracer,
                trace_id,
                event_name,
                project_name=self.tracer.project_name,
                enable_monitoring=self.tracer.enable_monitoring,
                enable_evaluations=self.tracer.enable_evaluations,
            )
            self._trace_client = client
            token = self.tracer.set_current_trace(client)
            if token:
                self.trace_id_to_token[trace_id] = token

            self._root_run_id = run_id
            self._trace_saved = False
            self.tracer._active_trace_client = client
        except Exception as e:
            # Tracing is best-effort: never break the user's chain, but make
            # the failure visible instead of silently dropping the trace.
            import warnings

            warnings.warn(f"Failed to initialize Judgeval trace client: {e}")
            self._trace_client = None
            self._root_run_id = None
            return None

        # Non-final save so the trace exists for live tracking while running.
        try:
            client.save(final_save=False)
        except Exception as e:
            import warnings

            warnings.warn(f"Failed to save initial trace for live tracking: {e}")

        return client

    @staticmethod
    def _split_metadata(
        payload: Dict[str, Any], fields: Sequence[str]
    ) -> "tuple[Dict[str, Any], Dict[str, Any]]":
        """
        Split ``payload`` into ``(remaining, metadata)`` without mutating it.

        Every key in ``fields`` present in ``payload`` is moved into
        ``metadata`` (in ``fields`` order); all other keys stay in
        ``remaining``.
        """
        remaining = dict(payload)
        metadata: Dict[str, Any] = {}
        for field in fields:
            if field in remaining:
                metadata[field] = remaining.pop(field)
        return remaining, metadata

    def _start_span_tracking(
        self,
        trace_client: TraceClient,
        run_id: UUID,
        parent_run_id: Optional[UUID],
        name: str,
        span_type: SpanType = "span",
        inputs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Create a span for ``run_id``, add it to the trace and make it current.

        Names wrapped in double underscores (presumably framework-internal
        nodes) are skipped. The span's parent and depth come from the span
        registered for ``parent_run_id``, if any.
        """
        if name.startswith("__") and name.endswith("__"):
            return

        start_time = time.time()
        span_id = str(uuid.uuid4())
        parent_span_id: Optional[str] = None
        depth = 0

        if parent_run_id is not None and parent_run_id in self._run_id_to_span_id:
            parent_span_id = self._run_id_to_span_id[parent_run_id]
            if parent_span_id in self._span_id_to_depth:
                depth = self._span_id_to_depth[parent_span_id] + 1

        self._run_id_to_span_id[run_id] = span_id
        self._span_id_to_start_time[span_id] = start_time
        self._span_id_to_depth[span_id] = depth

        new_span = TraceSpan(
            span_id=span_id,
            trace_id=trace_client.trace_id,
            parent_span_id=parent_span_id,
            function=name,
            depth=depth,
            created_at=start_time,
            span_type=span_type,
        )

        # Bookkeeping keys go to metadata; everything else is the span input.
        clean_inputs, metadata = self._split_metadata(
            inputs or {}, self._INPUT_METADATA_FIELDS
        )
        new_span.inputs = clean_inputs
        new_span.additional_metadata = metadata

        trace_client.add_span(new_span)

        trace_client.otel_span_processor.queue_span_update(new_span, span_state="input")

        token = self.tracer.set_current_span(span_id)
        if token:
            self.span_id_to_token[span_id] = token

    def _end_span_tracking(
        self,
        trace_client: TraceClient,
        run_id: UUID,
        outputs: Optional[Any] = None,
        error: Optional[BaseException] = None,
    ) -> None:
        """
        Complete the span tracked for ``run_id``; finalize the trace if root.

        Args:
            trace_client: Trace the span belongs to.
            run_id: Run whose span is ending.
            outputs: Span output; for dicts, ``tags``/``kwargs`` are moved to
                the span's metadata.
            error: If given, recorded as the span output and the span is
                queued with state ``"error"``.
        """
        span_id = self._run_id_to_span_id.get(run_id)
        if span_id:
            self._complete_span(trace_client, span_id, outputs, error)

        # The root run ending closes the whole trace, whether or not a span
        # was recorded for it (dunder-named roots have no span).
        if run_id == self._root_run_id:
            self._finalize_trace()

    def _complete_span(
        self,
        trace_client: TraceClient,
        span_id: str,
        outputs: Optional[Any],
        error: Optional[BaseException],
    ) -> None:
        """Restore span context, record duration/output and queue the update."""
        token = self.span_id_to_token.pop(span_id, None)
        self.tracer.reset_current_span(token, span_id)

        start_time = self._span_id_to_start_time.pop(span_id, None)
        self._span_id_to_depth.pop(span_id, None)
        duration = time.time() - start_time if start_time is not None else None

        trace_span = trace_client.span_id_to_span.get(span_id)
        if trace_span is None:
            return

        trace_span.duration = duration

        if error is not None:
            trace_span.output = error
        elif outputs is not None:
            if isinstance(outputs, dict):
                clean_outputs, metadata = self._split_metadata(
                    outputs, self._OUTPUT_METADATA_FIELDS
                )
            else:
                clean_outputs, metadata = outputs, {}

            trace_span.output = clean_outputs
            if metadata:
                # Output-side metadata overrides same-named input metadata.
                trace_span.additional_metadata = {
                    **(trace_span.additional_metadata or {}),
                    **metadata,
                }

        span_state = "error" if error is not None else "completed"
        trace_client.otel_span_processor.queue_span_update(
            trace_span, span_state=span_state
        )

    def _build_complete_trace_data(self, trace_client: TraceClient) -> Dict[str, Any]:
        """Snapshot ``trace_client`` into the dict appended to ``tracer.traces``."""
        return {
            "trace_id": trace_client.trace_id,
            "name": trace_client.name,
            "created_at": datetime.fromtimestamp(
                trace_client.start_time, timezone.utc
            ).isoformat(),
            "duration": trace_client.get_duration(),
            "trace_spans": [span.model_dump() for span in trace_client.trace_spans],
            "offline_mode": self.tracer.offline_mode,
            "parent_trace_id": trace_client.parent_trace_id,
            "parent_name": trace_client.parent_name,
        }

    def _finalize_trace(self) -> None:
        """
        Perform the final save of the trace after its root run has ended.

        Snapshots the trace, flushes background spans, saves with
        ``final_save=True``, resets the trace context token and appends the
        snapshot to ``tracer.traces``. Root-run and active-client state is
        cleared even if saving raises.
        """
        trace_client = self._trace_client
        try:
            if trace_client is not None and not self._trace_saved:
                complete_trace_data = self._build_complete_trace_data(trace_client)

                self.tracer.flush_background_spans()

                trace_id, _ = trace_client.save(
                    final_save=True,  # Final save with usage counter updates
                )
                token = self.trace_id_to_token.pop(trace_id, None)
                self.tracer.reset_current_trace(token, trace_id)

                # Store complete trace data instead of server response
                self.tracer.traces.append(complete_trace_data)
                self._trace_saved = True  # Only after a successful save
        finally:
            self._root_run_id = None
            if self.tracer._active_trace_client is trace_client:
                self.tracer._active_trace_client = None

    # --- Callback Methods ---
    # Each method ensures the trace client exists before proceeding.

    def on_retriever_start(
        self,
        serialized: Dict[str, Any],
        query: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Open a ``retriever`` span named ``RETRIEVER_<NAME>``."""
        if serialized:
            # `or` also covers an explicit None name, which .upper() would reject.
            serialized_name = serialized.get("name") or "Unknown"
        else:
            serialized_name = "Unknown (Serialized=None)"

        name = f"RETRIEVER_{serialized_name.upper()}"
        trace_client = self._ensure_trace_client(run_id, parent_run_id, name)
        if trace_client is None:
            return

        inputs = {
            "query": query,
            "tags": tags,
            "metadata": metadata,
            "kwargs": kwargs,
            "serialized": serialized,
        }
        self._start_span_tracking(
            trace_client,
            run_id,
            parent_run_id,
            name,
            span_type="retriever",
            inputs=inputs,
        )

    def on_retriever_end(
        self,
        documents: Sequence[Document],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Close the retriever span with a summary of the returned documents."""
        trace_client = self._ensure_trace_client(run_id, parent_run_id, "RetrieverEnd")
        if trace_client is None:
            return

        # Page content is truncated to 100 characters to keep spans small.
        doc_summary = [
            {
                "index": i,
                "page_content": (
                    doc.page_content[:100] + "..."
                    if len(doc.page_content) > 100
                    else doc.page_content
                ),
                "metadata": doc.metadata,
            }
            for i, doc in enumerate(documents)
        ]
        outputs = {
            "document_count": len(documents),
            "documents": doc_summary,
            "kwargs": kwargs,
        }
        self._end_span_tracking(trace_client, run_id, outputs=outputs)

    def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """
        Open a ``chain`` span.

        The span is named after the LangGraph node (``metadata["langgraph_node"]``)
        when present, ``"LangGraph"`` for a parentless run whose ``name`` kwarg
        is ``"LangGraph"``, otherwise the serialized chain name. For the root
        run, the trace itself is renamed to match.
        """
        serialized_name = (
            serialized.get("name") if serialized else "Unknown (Serialized=None)"
        )

        span_type: SpanType = "chain"
        name = serialized_name if serialized_name else "Unknown Chain"
        node_name = metadata.get("langgraph_node") if metadata else None
        is_langgraph_root_kwarg = kwargs.get("name") == "LangGraph"
        # The first chain event with no parent is usually the root.
        is_potential_root_event = parent_run_id is None

        if node_name:
            name = node_name
            if name not in self.executed_nodes:
                self.executed_nodes.append(name)
        elif is_langgraph_root_kwarg and is_potential_root_event:
            name = "LangGraph"

        trace_client = self._ensure_trace_client(run_id, parent_run_id, name)
        if trace_client is None:
            return

        if (
            is_potential_root_event
            and run_id == self._root_run_id
            and trace_client.name != name
        ):
            trace_client.name = name

        combined_inputs = {
            "inputs": inputs,
            "tags": tags,
            "metadata": metadata,
            "kwargs": kwargs,
            "serialized": serialized,
        }
        self._start_span_tracking(
            trace_client,
            run_id,
            parent_run_id,
            name,
            span_type=span_type,
            inputs=combined_inputs,
        )

    def on_chain_end(
        self,
        outputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Any:
        """Close the chain span; when this is the root run, finalize the trace."""
        trace_client = self._ensure_trace_client(run_id, parent_run_id, "ChainEnd")
        if trace_client is None:
            return

        # Ignore untracked runs (e.g. skipped dunder nodes) unless this is the
        # root run, which must still finalize the trace.
        if run_id not in self._run_id_to_span_id and run_id != self._root_run_id:
            return

        combined_outputs = {"outputs": outputs, "tags": tags, "kwargs": kwargs}

        # _end_span_tracking also performs the final trace save for the root run.
        self._end_span_tracking(trace_client, run_id, outputs=combined_outputs)

    def on_chain_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Close the chain span with ``error``; finalizes the trace for the root run."""
        trace_client = self._ensure_trace_client(run_id, parent_run_id, "ChainError")
        if trace_client is None:
            return

        if run_id not in self._run_id_to_span_id and run_id != self._root_run_id:
            return

        self._end_span_tracking(trace_client, run_id, error=error)

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        inputs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Open a ``tool`` span named after the serialized tool."""
        name = (
            serialized.get("name", "Unnamed Tool")
            if serialized
            else "Unknown Tool (Serialized=None)"
        )

        trace_client = self._ensure_trace_client(run_id, parent_run_id, name)
        if trace_client is None:
            return

        combined_inputs = {
            "input_str": input_str,
            "inputs": inputs,
            "tags": tags,
            "metadata": metadata,
            "kwargs": kwargs,
            "serialized": serialized,
        }
        self._start_span_tracking(
            trace_client,
            run_id,
            parent_run_id,
            name,
            span_type="tool",
            inputs=combined_inputs,
        )

    def on_tool_end(
        self,
        output: Any,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Close the tool span with its output."""
        trace_client = self._ensure_trace_client(run_id, parent_run_id, "ToolEnd")
        if trace_client is None:
            return
        outputs = {"output": output, "kwargs": kwargs}
        self._end_span_tracking(trace_client, run_id, outputs=outputs)

    def on_tool_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Close the tool span with ``error``."""
        trace_client = self._ensure_trace_client(run_id, parent_run_id, "ToolError")
        if trace_client is None:
            return
        self._end_span_tracking(trace_client, run_id, error=error)

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        invocation_params: Optional[Dict[str, Any]] = None,
        options: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        """Open an ``llm`` span for a completion-style model call."""
        # Guard against serialized=None, as the other start callbacks do.
        llm_name = name or (serialized or {}).get("name", "LLM Call")

        trace_client = self._ensure_trace_client(run_id, parent_run_id, llm_name)
        if trace_client is None:
            return
        inputs = {
            "prompts": prompts,
            "invocation_params": invocation_params or kwargs,
            "options": options,
            "tags": tags,
            "metadata": metadata,
            "serialized": serialized,
        }
        self._start_span_tracking(
            trace_client,
            run_id,
            parent_run_id,
            llm_name,
            span_type="llm",
            inputs=inputs,
        )

    @staticmethod
    def _extract_model_name(response: LLMResult) -> Optional[str]:
        """Model name from ``llm_output``, else from the first generation's info."""
        llm_output = getattr(response, "llm_output", None)
        model_name = None
        if llm_output and isinstance(llm_output, dict):
            model_name = llm_output.get("model_name") or llm_output.get("model")

        generations = response.generations
        # generations[0] may itself be empty; check before indexing [0][0].
        if not model_name and generations and generations[0]:
            gen_info = getattr(generations[0][0], "generation_info", None)
            if gen_info:
                model_name = gen_info.get("model") or gen_info.get("model_name")
        return model_name

    @staticmethod
    def _extract_token_counts(
        response: LLMResult,
    ) -> "tuple[Optional[int], Optional[int], Optional[int]]":
        """
        Return ``(prompt, completion, total)`` token counts from ``llm_output``.

        Reads OpenAI-style ``token_usage`` (prompt/completion/total_tokens)
        first, otherwise Anthropic-style ``usage`` (input/output_tokens, total
        computed). Missing values are ``None``.
        """
        llm_output = getattr(response, "llm_output", None)
        if not llm_output or not isinstance(llm_output, dict):
            return None, None, None

        if "token_usage" in llm_output:
            token_usage = llm_output.get("token_usage")
            if token_usage and isinstance(token_usage, dict):
                return (
                    token_usage.get("prompt_tokens"),
                    token_usage.get("completion_tokens"),
                    token_usage.get("total_tokens"),
                )
        elif "usage" in llm_output:
            token_usage = llm_output.get("usage")
            if token_usage and isinstance(token_usage, dict):
                prompt_tokens = token_usage.get("input_tokens")
                completion_tokens = token_usage.get("output_tokens")
                total_tokens = (
                    prompt_tokens + completion_tokens
                    if prompt_tokens is not None and completion_tokens is not None
                    else None
                )
                return prompt_tokens, completion_tokens, total_tokens
        return None, None, None

    @staticmethod
    def _build_usage(
        model_name: Optional[str],
        prompt_tokens: Optional[int],
        completion_tokens: Optional[int],
        total_tokens: Optional[int],
    ) -> TraceUsage:
        """Build a ``TraceUsage``, adding USD costs when the model is known."""
        prompt_cost = None
        completion_cost = None
        total_cost_usd = None

        if model_name and prompt_tokens is not None and completion_tokens is not None:
            try:
                prompt_cost, completion_cost = cost_per_token(
                    model=model_name,
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                )
                # A cost of 0.0 is a real value and must still be summed.
                if prompt_cost is not None and completion_cost is not None:
                    total_cost_usd = prompt_cost + completion_cost
            except Exception as e:
                # If cost calculation fails, continue without costs
                import warnings

                warnings.warn(
                    f"Failed to calculate token costs for model {model_name}: {e}"
                )

        if (
            total_tokens is None
            and prompt_tokens is not None
            and completion_tokens is not None
        ):
            total_tokens = prompt_tokens + completion_tokens

        return TraceUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            prompt_tokens_cost_usd=prompt_cost,
            completion_tokens_cost_usd=completion_cost,
            total_cost_usd=total_cost_usd,
            model_name=model_name,
        )

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Attach token usage/cost (when reported) to the LLM span and close it."""
        trace_client = self._ensure_trace_client(run_id, parent_run_id, "LLMEnd")
        if trace_client is None:
            return
        outputs = {"response": response, "kwargs": kwargs}

        model_name = self._extract_model_name(response)
        prompt_tokens, completion_tokens, total_tokens = self._extract_token_counts(
            response
        )

        if prompt_tokens is not None or completion_tokens is not None:
            usage = self._build_usage(
                model_name, prompt_tokens, completion_tokens, total_tokens
            )
            span_id = self._run_id_to_span_id.get(run_id)
            if span_id and span_id in trace_client.span_id_to_span:
                trace_client.span_id_to_span[span_id].usage = usage

        self._end_span_tracking(trace_client, run_id, outputs=outputs)

    def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Close the LLM span with ``error``."""
        trace_client = self._ensure_trace_client(run_id, parent_run_id, "LLMError")
        if trace_client is None:
            return
        self._end_span_tracking(trace_client, run_id, error=error)

    @classmethod
    def _with_provider_suffix(
        cls, serialized: Optional[Dict[str, Any]], chat_model_name: str
    ) -> str:
        """
        Append a provider tag such as ``OPENAI_API_CALL`` to the model name.

        A provider matches when a key in ``serialized["secrets"]`` starts with
        its prefix or the lowercased name contains one of its markers. The
        first matching provider whose tag is not already in the name wins.
        """
        secrets = (serialized or {}).get("secrets") or {}
        secret_keys = list(secrets.keys())
        lowered_name = chat_model_name.lower()

        for suffix, secret_prefix, name_markers in cls._CHAT_PROVIDER_MARKERS:
            detected = any(
                key.startswith(secret_prefix) for key in secret_keys
            ) or any(marker in lowered_name for marker in name_markers)
            if detected and suffix not in chat_model_name:
                return f"{chat_model_name} {suffix}"
        return chat_model_name

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        invocation_params: Optional[Dict[str, Any]] = None,
        options: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        """Open an ``llm`` span for a chat model call, tagged with its provider."""
        chat_model_name = name or (serialized or {}).get("name", "ChatModel Call")
        chat_model_name = self._with_provider_suffix(serialized, chat_model_name)

        trace_client = self._ensure_trace_client(run_id, parent_run_id, chat_model_name)
        if trace_client is None:
            return
        inputs = {
            "messages": messages,
            "invocation_params": invocation_params or kwargs,
            "options": options,
            "tags": tags,
            "metadata": metadata,
            "serialized": serialized,
        }
        self._start_span_tracking(
            trace_client,
            run_id,
            parent_run_id,
            chat_model_name,
            span_type="llm",
            inputs=inputs,
        )

    def on_agent_action(
        self,
        action: AgentAction,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Open an ``agent`` span named ``AGENT_ACTION_<TOOL>``."""
        name = f"AGENT_ACTION_{action.tool.upper()}"
        trace_client = self._ensure_trace_client(run_id, parent_run_id, name)
        if trace_client is None:
            return

        inputs = {
            "tool_input": action.tool_input,
            "log": action.log,
            "messages": action.messages,
            "kwargs": kwargs,
        }
        self._start_span_tracking(
            trace_client, run_id, parent_run_id, name, span_type="agent", inputs=inputs
        )

    def on_agent_finish(
        self,
        finish: AgentFinish,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Close the agent span with the agent's return values."""
        trace_client = self._ensure_trace_client(run_id, parent_run_id, "AgentFinish")
        if trace_client is None:
            return

        outputs = {
            "return_values": finish.return_values,
            "log": finish.log,
            "messages": finish.messages,
            "kwargs": kwargs,
        }
        self._end_span_tracking(trace_client, run_id, outputs=outputs)
|