judgeval 0.0.44__py3-none-any.whl → 0.0.46__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/__init__.py +5 -4
- judgeval/clients.py +6 -6
- judgeval/common/__init__.py +7 -2
- judgeval/common/exceptions.py +2 -3
- judgeval/common/logger.py +74 -49
- judgeval/common/s3_storage.py +30 -23
- judgeval/common/tracer.py +1273 -939
- judgeval/common/utils.py +416 -244
- judgeval/constants.py +73 -61
- judgeval/data/__init__.py +1 -1
- judgeval/data/custom_example.py +3 -2
- judgeval/data/datasets/dataset.py +80 -54
- judgeval/data/datasets/eval_dataset_client.py +131 -181
- judgeval/data/example.py +67 -43
- judgeval/data/result.py +11 -9
- judgeval/data/scorer_data.py +4 -2
- judgeval/data/tool.py +25 -16
- judgeval/data/trace.py +57 -29
- judgeval/data/trace_run.py +5 -11
- judgeval/evaluation_run.py +22 -82
- judgeval/integrations/langgraph.py +546 -184
- judgeval/judges/base_judge.py +1 -2
- judgeval/judges/litellm_judge.py +33 -11
- judgeval/judges/mixture_of_judges.py +128 -78
- judgeval/judges/together_judge.py +22 -9
- judgeval/judges/utils.py +14 -5
- judgeval/judgment_client.py +259 -271
- judgeval/rules.py +169 -142
- judgeval/run_evaluation.py +462 -305
- judgeval/scorers/api_scorer.py +20 -11
- judgeval/scorers/exceptions.py +1 -0
- judgeval/scorers/judgeval_scorer.py +77 -58
- judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +46 -15
- judgeval/scorers/judgeval_scorers/api_scorers/answer_correctness.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/answer_relevancy.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/classifier_scorer.py +12 -11
- judgeval/scorers/judgeval_scorers/api_scorers/comparison.py +7 -5
- judgeval/scorers/judgeval_scorers/api_scorers/contextual_precision.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/contextual_recall.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/contextual_relevancy.py +5 -2
- judgeval/scorers/judgeval_scorers/api_scorers/derailment_scorer.py +2 -1
- judgeval/scorers/judgeval_scorers/api_scorers/execution_order.py +17 -8
- judgeval/scorers/judgeval_scorers/api_scorers/faithfulness.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/groundedness.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/hallucination.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/instruction_adherence.py +3 -2
- judgeval/scorers/judgeval_scorers/api_scorers/json_correctness.py +8 -9
- judgeval/scorers/judgeval_scorers/api_scorers/summarization.py +4 -4
- judgeval/scorers/judgeval_scorers/api_scorers/tool_dependency.py +5 -5
- judgeval/scorers/judgeval_scorers/api_scorers/tool_order.py +5 -2
- judgeval/scorers/judgeval_scorers/classifiers/text2sql/text2sql_scorer.py +9 -10
- judgeval/scorers/prompt_scorer.py +48 -37
- judgeval/scorers/score.py +86 -53
- judgeval/scorers/utils.py +11 -7
- judgeval/tracer/__init__.py +1 -1
- judgeval/utils/alerts.py +23 -12
- judgeval/utils/{data_utils.py → file_utils.py} +5 -9
- judgeval/utils/requests.py +29 -0
- judgeval/version_check.py +5 -2
- {judgeval-0.0.44.dist-info → judgeval-0.0.46.dist-info}/METADATA +79 -135
- judgeval-0.0.46.dist-info/RECORD +69 -0
- judgeval-0.0.44.dist-info/RECORD +0 -68
- {judgeval-0.0.44.dist-info → judgeval-0.0.46.dist-info}/WHEEL +0 -0
- {judgeval-0.0.44.dist-info → judgeval-0.0.46.dist-info}/licenses/LICENSE.md +0 -0
@@ -2,11 +2,15 @@ from typing import Any, Dict, List, Optional, Sequence
|
|
2
2
|
from uuid import UUID
|
3
3
|
import time
|
4
4
|
import uuid
|
5
|
-
import contextvars # <--- Import contextvars
|
6
5
|
from datetime import datetime
|
7
6
|
|
8
|
-
from judgeval.common.tracer import
|
9
|
-
|
7
|
+
from judgeval.common.tracer import (
|
8
|
+
TraceClient,
|
9
|
+
TraceSpan,
|
10
|
+
Tracer,
|
11
|
+
SpanType,
|
12
|
+
cost_per_token,
|
13
|
+
)
|
10
14
|
from judgeval.data.trace import TraceUsage
|
11
15
|
|
12
16
|
from langchain_core.callbacks import BaseCallbackHandler
|
@@ -22,6 +26,7 @@ from langchain_core.documents import Document
|
|
22
26
|
# from judgeval.common.tracer import current_span_var
|
23
27
|
# TODO: Figure out how to handle context variables. Current solution is to keep track of current span id in Tracer class
|
24
28
|
|
29
|
+
|
25
30
|
# --- NEW __init__ ---
|
26
31
|
class JudgevalCallbackHandler(BaseCallbackHandler):
|
27
32
|
"""
|
@@ -29,14 +34,14 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
29
34
|
Manages its own internal TraceClient instance created upon first use.
|
30
35
|
Includes verbose logging and defensive checks.
|
31
36
|
"""
|
37
|
+
|
32
38
|
# Make all properties ignored by LangChain's callback system
|
33
39
|
# to prevent unexpected serialization issues.
|
34
40
|
lc_serializable = False
|
35
|
-
lc_kwargs = {}
|
41
|
+
lc_kwargs: dict = {}
|
36
42
|
|
37
43
|
# --- NEW __init__ ---
|
38
44
|
def __init__(self, tracer: Tracer):
|
39
|
-
|
40
45
|
self.tracer = tracer
|
41
46
|
# Initialize tracking/logging variables (preserved across resets)
|
42
47
|
self.executed_nodes: List[str] = []
|
@@ -45,6 +50,7 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
45
50
|
self.traces: List[Dict[str, Any]] = []
|
46
51
|
# Initialize execution state (reset between runs)
|
47
52
|
self._reset_state()
|
53
|
+
|
48
54
|
# --- END NEW __init__ ---
|
49
55
|
|
50
56
|
def _reset_state(self):
|
@@ -55,19 +61,28 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
55
61
|
self._span_id_to_start_time: Dict[str, float] = {}
|
56
62
|
self._span_id_to_depth: Dict[str, int] = {}
|
57
63
|
self._root_run_id: Optional[UUID] = None
|
58
|
-
self._trace_saved: bool = False
|
64
|
+
self._trace_saved: bool = False # Flag to prevent actions after trace is saved
|
59
65
|
self.span_id_to_token: Dict[str, Any] = {}
|
60
66
|
self.trace_id_to_token: Dict[str, Any] = {}
|
61
|
-
|
67
|
+
|
62
68
|
# Add timestamp to track when we last reset
|
63
69
|
self._last_reset_time: float = time.time()
|
64
|
-
|
70
|
+
|
65
71
|
# Preserve tracking/logging variables across executions:
|
66
72
|
# - self.executed_nodes: List[str] = [] # Keep as running log
|
67
|
-
# - self.executed_tools: List[str] = [] # Keep as running log
|
73
|
+
# - self.executed_tools: List[str] = [] # Keep as running log
|
68
74
|
# - self.executed_node_tools: List[str] = [] # Keep as running log
|
69
75
|
# - self.traces: List[Dict[str, Any]] = [] # Keep for collecting multiple traces
|
70
76
|
|
77
|
+
# Also reset tracking/logging variables
|
78
|
+
self.executed_nodes: List[
|
79
|
+
str
|
80
|
+
] = [] # These last four members are only appended to and never accessed; can probably be removed but still might be useful for future reference?
|
81
|
+
self.executed_tools: List[str] = []
|
82
|
+
self.executed_node_tools: List[str] = []
|
83
|
+
self.traces: List[Dict[str, Any]] = []
|
84
|
+
|
85
|
+
# --- END NEW __init__ ---
|
71
86
|
def reset(self):
|
72
87
|
"""Public method to manually reset handler execution state for reuse"""
|
73
88
|
self._reset_state()
|
@@ -75,14 +90,11 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
75
90
|
def reset_all(self):
|
76
91
|
"""Public method to reset ALL handler state including tracking/logging data"""
|
77
92
|
self._reset_state()
|
78
|
-
# Also reset tracking/logging variables
|
79
|
-
self.executed_nodes: List[str] = []
|
80
|
-
self.executed_tools: List[str] = []
|
81
|
-
self.executed_node_tools: List[str] = []
|
82
|
-
self.traces: List[Dict[str, Any]] = []
|
83
93
|
|
84
94
|
# --- MODIFIED _ensure_trace_client ---
|
85
|
-
def _ensure_trace_client(
|
95
|
+
def _ensure_trace_client(
|
96
|
+
self, run_id: UUID, parent_run_id: Optional[UUID], event_name: str
|
97
|
+
) -> Optional[TraceClient]:
|
86
98
|
"""
|
87
99
|
Ensures the internal trace client is initialized, creating it only once
|
88
100
|
per handler instance lifecycle (effectively per graph invocation).
|
@@ -104,36 +116,44 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
104
116
|
try:
|
105
117
|
# Use event_name as the initial trace name, might be updated later by on_chain_start if root
|
106
118
|
client_instance = TraceClient(
|
107
|
-
self.tracer,
|
108
|
-
|
119
|
+
self.tracer,
|
120
|
+
trace_id,
|
121
|
+
event_name,
|
122
|
+
project_name=project,
|
123
|
+
overwrite=False,
|
124
|
+
rules=self.tracer.rules,
|
109
125
|
enable_monitoring=self.tracer.enable_monitoring,
|
110
|
-
enable_evaluations=self.tracer.enable_evaluations
|
126
|
+
enable_evaluations=self.tracer.enable_evaluations,
|
111
127
|
)
|
112
128
|
self._trace_client = client_instance
|
113
129
|
token = self.tracer.set_current_trace(self._trace_client)
|
114
130
|
if token:
|
115
131
|
self.trace_id_to_token[trace_id] = token
|
116
132
|
if self._trace_client:
|
117
|
-
self._root_run_id =
|
118
|
-
|
133
|
+
self._root_run_id = (
|
134
|
+
run_id # Assign the first run_id encountered as the tentative root
|
135
|
+
)
|
136
|
+
self._trace_saved = False # Ensure flag is reset
|
119
137
|
# Set active client on Tracer (important for potential fallbacks)
|
120
138
|
self.tracer._active_trace_client = self._trace_client
|
121
|
-
|
139
|
+
|
122
140
|
# NEW: Initial save for live tracking (follows the new practice)
|
123
141
|
try:
|
124
|
-
trace_id_saved, server_response = self._trace_client.
|
125
|
-
overwrite=self._trace_client.overwrite,
|
126
|
-
final_save=False # Initial save for live tracking
|
142
|
+
trace_id_saved, server_response = self._trace_client.save(
|
143
|
+
overwrite=self._trace_client.overwrite,
|
144
|
+
final_save=False, # Initial save for live tracking
|
127
145
|
)
|
128
146
|
except Exception as e:
|
129
147
|
import warnings
|
130
|
-
|
131
|
-
|
148
|
+
|
149
|
+
warnings.warn(
|
150
|
+
f"Failed to save initial trace for live tracking: {e}"
|
151
|
+
)
|
152
|
+
|
132
153
|
return self._trace_client
|
133
154
|
else:
|
134
155
|
return None
|
135
|
-
except Exception
|
136
|
-
|
156
|
+
except Exception:
|
137
157
|
self._trace_client = None
|
138
158
|
self._root_run_id = None
|
139
159
|
return None
|
@@ -145,7 +165,7 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
145
165
|
parent_run_id: Optional[UUID],
|
146
166
|
name: str,
|
147
167
|
span_type: SpanType = "span",
|
148
|
-
inputs: Optional[Dict[str, Any]] = None
|
168
|
+
inputs: Optional[Dict[str, Any]] = None,
|
149
169
|
) -> None:
|
150
170
|
"""Start tracking a span, ensuring trace client exists"""
|
151
171
|
|
@@ -157,7 +177,7 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
157
177
|
if parent_run_id and parent_run_id in self._run_id_to_span_id:
|
158
178
|
parent_span_id = self._run_id_to_span_id[parent_run_id]
|
159
179
|
if parent_span_id in self._span_id_to_depth:
|
160
|
-
current_depth = self._span_id_to_depth[parent_span_id] + 1
|
180
|
+
current_depth = self._span_id_to_depth[parent_span_id] + 1
|
161
181
|
|
162
182
|
self._run_id_to_span_id[run_id] = span_id
|
163
183
|
self._span_id_to_start_time[span_id] = start_time
|
@@ -170,23 +190,23 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
170
190
|
function=name,
|
171
191
|
depth=current_depth,
|
172
192
|
created_at=start_time,
|
173
|
-
span_type=span_type
|
193
|
+
span_type=span_type,
|
174
194
|
)
|
175
195
|
|
176
196
|
# Separate metadata from inputs
|
177
197
|
if inputs:
|
178
198
|
metadata = {}
|
179
199
|
clean_inputs = {}
|
180
|
-
|
200
|
+
|
181
201
|
# Extract metadata fields
|
182
|
-
metadata_fields = [
|
202
|
+
metadata_fields = ["tags", "metadata", "kwargs", "serialized"]
|
183
203
|
for field in metadata_fields:
|
184
204
|
if field in inputs:
|
185
205
|
metadata[field] = inputs.pop(field)
|
186
|
-
|
206
|
+
|
187
207
|
# Store the remaining inputs
|
188
208
|
clean_inputs = inputs
|
189
|
-
|
209
|
+
|
190
210
|
# Set both fields on the span
|
191
211
|
new_span.inputs = clean_inputs
|
192
212
|
new_span.additional_metadata = metadata
|
@@ -195,11 +215,13 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
195
215
|
new_span.additional_metadata = {}
|
196
216
|
|
197
217
|
trace_client.add_span(new_span)
|
198
|
-
|
218
|
+
|
199
219
|
# Queue span with initial state (input phase) through background service
|
200
220
|
if trace_client.background_span_service:
|
201
|
-
trace_client.background_span_service.queue_span(
|
202
|
-
|
221
|
+
trace_client.background_span_service.queue_span(
|
222
|
+
new_span, span_state="input"
|
223
|
+
)
|
224
|
+
|
203
225
|
token = self.tracer.set_current_span(span_id)
|
204
226
|
if token:
|
205
227
|
self.span_id_to_token[span_id] = token
|
@@ -209,14 +231,15 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
209
231
|
trace_client: TraceClient,
|
210
232
|
run_id: UUID,
|
211
233
|
outputs: Optional[Any] = None,
|
212
|
-
error: Optional[BaseException] = None
|
234
|
+
error: Optional[BaseException] = None,
|
213
235
|
) -> None:
|
214
236
|
"""End tracking a span, ensuring trace client exists"""
|
215
237
|
|
216
238
|
# Get span ID and check if it exists
|
217
239
|
span_id = self._run_id_to_span_id.get(run_id)
|
218
|
-
|
219
|
-
|
240
|
+
if span_id:
|
241
|
+
token = self.span_id_to_token.pop(span_id, None)
|
242
|
+
self.tracer.reset_current_span(token, span_id)
|
220
243
|
|
221
244
|
start_time = self._span_id_to_start_time.get(span_id) if span_id else None
|
222
245
|
duration = time.time() - start_time if start_time is not None else None
|
@@ -226,7 +249,7 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
226
249
|
trace_span = trace_client.span_id_to_span.get(span_id)
|
227
250
|
if trace_span:
|
228
251
|
trace_span.duration = duration
|
229
|
-
|
252
|
+
|
230
253
|
# Handle outputs and error
|
231
254
|
if error:
|
232
255
|
trace_span.output = error
|
@@ -234,34 +257,41 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
234
257
|
# Separate metadata from outputs
|
235
258
|
metadata = {}
|
236
259
|
clean_outputs = {}
|
237
|
-
|
260
|
+
|
238
261
|
# Extract metadata fields
|
239
|
-
metadata_fields = [
|
262
|
+
metadata_fields = ["tags", "kwargs"]
|
240
263
|
if isinstance(outputs, dict):
|
241
264
|
for field in metadata_fields:
|
242
265
|
if field in outputs:
|
243
266
|
metadata[field] = outputs.pop(field)
|
244
|
-
|
267
|
+
|
245
268
|
# Store the remaining outputs
|
246
269
|
clean_outputs = outputs
|
247
270
|
else:
|
248
271
|
clean_outputs = outputs
|
249
|
-
|
272
|
+
|
250
273
|
# Set both fields on the span
|
251
274
|
trace_span.output = clean_outputs
|
252
275
|
if metadata:
|
253
276
|
# Merge with existing metadata
|
254
277
|
existing_metadata = trace_span.additional_metadata or {}
|
255
|
-
trace_span.additional_metadata = {
|
256
|
-
|
278
|
+
trace_span.additional_metadata = {
|
279
|
+
**existing_metadata,
|
280
|
+
**metadata,
|
281
|
+
}
|
282
|
+
|
257
283
|
# Queue span with completed state through background service
|
258
284
|
if trace_client.background_span_service:
|
259
285
|
span_state = "error" if error else "completed"
|
260
|
-
trace_client.background_span_service.queue_span(
|
286
|
+
trace_client.background_span_service.queue_span(
|
287
|
+
trace_span, span_state=span_state
|
288
|
+
)
|
261
289
|
|
262
290
|
# Clean up dictionaries for this specific span
|
263
|
-
if span_id in self._span_id_to_start_time:
|
264
|
-
|
291
|
+
if span_id in self._span_id_to_start_time:
|
292
|
+
del self._span_id_to_start_time[span_id]
|
293
|
+
if span_id in self._span_id_to_depth:
|
294
|
+
del self._span_id_to_depth[span_id]
|
265
295
|
|
266
296
|
# Check if this is the root run ending
|
267
297
|
if run_id == self._root_run_id:
|
@@ -270,34 +300,38 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
270
300
|
self._root_run_id = None
|
271
301
|
# Reset input storage for this handler instance
|
272
302
|
|
273
|
-
if
|
303
|
+
if (
|
304
|
+
self._trace_client and not self._trace_saved
|
305
|
+
): # Check if not already saved
|
274
306
|
# Flush background spans before saving the final trace
|
275
307
|
|
276
308
|
complete_trace_data = {
|
277
309
|
"trace_id": self._trace_client.trace_id,
|
278
310
|
"name": self._trace_client.name,
|
279
|
-
"created_at": datetime.utcfromtimestamp(
|
311
|
+
"created_at": datetime.utcfromtimestamp(
|
312
|
+
self._trace_client.start_time
|
313
|
+
).isoformat(),
|
280
314
|
"duration": self._trace_client.get_duration(),
|
281
|
-
"trace_spans": [
|
315
|
+
"trace_spans": [
|
316
|
+
span.model_dump() for span in self._trace_client.trace_spans
|
317
|
+
],
|
282
318
|
"overwrite": self._trace_client.overwrite,
|
283
319
|
"offline_mode": self.tracer.offline_mode,
|
284
320
|
"parent_trace_id": self._trace_client.parent_trace_id,
|
285
|
-
"parent_name": self._trace_client.parent_name
|
321
|
+
"parent_name": self._trace_client.parent_name,
|
286
322
|
}
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
overwrite=self._trace_client.overwrite,
|
291
|
-
final_save=True # Final save with usage counter updates
|
323
|
+
trace_id, trace_data = self._trace_client.save(
|
324
|
+
overwrite=self._trace_client.overwrite,
|
325
|
+
final_save=True, # Final save with usage counter updates
|
292
326
|
)
|
293
327
|
token = self.trace_id_to_token.pop(trace_id, None)
|
294
328
|
self.tracer.reset_current_trace(token, trace_id)
|
295
|
-
|
329
|
+
|
296
330
|
# Store complete trace data instead of server response
|
297
331
|
self.tracer.traces.append(complete_trace_data)
|
298
|
-
self._trace_saved = True
|
332
|
+
self._trace_saved = True # Set flag only after successful save
|
299
333
|
finally:
|
300
|
-
# --- NEW: Consolidated Cleanup Logic ---
|
334
|
+
# --- NEW: Consolidated Cleanup Logic ---
|
301
335
|
# This block executes regardless of save success/failure
|
302
336
|
# Reset root run id
|
303
337
|
self._root_run_id = None
|
@@ -309,68 +343,167 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
309
343
|
# --- Callback Methods ---
|
310
344
|
# Each method now ensures the trace client exists before proceeding
|
311
345
|
|
312
|
-
def on_retriever_start(
|
313
|
-
|
346
|
+
def on_retriever_start(
|
347
|
+
self,
|
348
|
+
serialized: Dict[str, Any],
|
349
|
+
query: str,
|
350
|
+
*,
|
351
|
+
run_id: UUID,
|
352
|
+
parent_run_id: Optional[UUID] = None,
|
353
|
+
tags: Optional[List[str]] = None,
|
354
|
+
metadata: Optional[Dict[str, Any]] = None,
|
355
|
+
**kwargs: Any,
|
356
|
+
) -> Any:
|
357
|
+
serialized_name = (
|
358
|
+
serialized.get("name", "Unknown")
|
359
|
+
if serialized
|
360
|
+
else "Unknown (Serialized=None)"
|
361
|
+
)
|
314
362
|
|
315
363
|
name = f"RETRIEVER_{(serialized_name).upper()}"
|
316
364
|
# Pass parent_run_id
|
317
|
-
trace_client = self._ensure_trace_client(
|
318
|
-
|
365
|
+
trace_client = self._ensure_trace_client(
|
366
|
+
run_id, parent_run_id, name
|
367
|
+
) # Corrected call
|
368
|
+
if not trace_client:
|
369
|
+
return
|
319
370
|
|
320
|
-
inputs = {
|
321
|
-
|
371
|
+
inputs = {
|
372
|
+
"query": query,
|
373
|
+
"tags": tags,
|
374
|
+
"metadata": metadata,
|
375
|
+
"kwargs": kwargs,
|
376
|
+
"serialized": serialized,
|
377
|
+
}
|
378
|
+
self._start_span_tracking(
|
379
|
+
trace_client,
|
380
|
+
run_id,
|
381
|
+
parent_run_id,
|
382
|
+
name,
|
383
|
+
span_type="retriever",
|
384
|
+
inputs=inputs,
|
385
|
+
)
|
322
386
|
|
323
|
-
def on_retriever_end(
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
387
|
+
def on_retriever_end(
|
388
|
+
self,
|
389
|
+
documents: Sequence[Document],
|
390
|
+
*,
|
391
|
+
run_id: UUID,
|
392
|
+
parent_run_id: Optional[UUID] = None,
|
393
|
+
**kwargs: Any,
|
394
|
+
) -> Any:
|
395
|
+
trace_client = self._ensure_trace_client(
|
396
|
+
run_id, parent_run_id, "RetrieverEnd"
|
397
|
+
) # Corrected call
|
398
|
+
if not trace_client:
|
399
|
+
return
|
400
|
+
doc_summary = [
|
401
|
+
{
|
402
|
+
"index": i,
|
403
|
+
"page_content": doc.page_content[:100] + "..."
|
404
|
+
if len(doc.page_content) > 100
|
405
|
+
else doc.page_content,
|
406
|
+
"metadata": doc.metadata,
|
407
|
+
}
|
408
|
+
for i, doc in enumerate(documents)
|
409
|
+
]
|
410
|
+
outputs = {
|
411
|
+
"document_count": len(documents),
|
412
|
+
"documents": doc_summary,
|
413
|
+
"kwargs": kwargs,
|
414
|
+
}
|
328
415
|
self._end_span_tracking(trace_client, run_id, outputs=outputs)
|
329
416
|
|
330
|
-
def on_chain_start(
|
331
|
-
|
417
|
+
def on_chain_start(
|
418
|
+
self,
|
419
|
+
serialized: Dict[str, Any],
|
420
|
+
inputs: Dict[str, Any],
|
421
|
+
*,
|
422
|
+
run_id: UUID,
|
423
|
+
parent_run_id: Optional[UUID] = None,
|
424
|
+
tags: Optional[List[str]] = None,
|
425
|
+
metadata: Optional[Dict[str, Any]] = None,
|
426
|
+
**kwargs: Any,
|
427
|
+
) -> None:
|
428
|
+
serialized_name = (
|
429
|
+
serialized.get("name") if serialized else "Unknown (Serialized=None)"
|
430
|
+
)
|
332
431
|
|
333
432
|
# --- Determine Name and Span Type ---
|
334
433
|
span_type: SpanType = "chain"
|
335
|
-
name = serialized_name if serialized_name else "Unknown Chain"
|
434
|
+
name = serialized_name if serialized_name else "Unknown Chain" # Default name
|
336
435
|
node_name = metadata.get("langgraph_node") if metadata else None
|
337
|
-
is_langgraph_root_kwarg =
|
436
|
+
is_langgraph_root_kwarg = (
|
437
|
+
kwargs.get("name") == "LangGraph"
|
438
|
+
) # Check kwargs for explicit root name
|
338
439
|
# More robust root detection: Often the first chain event with parent_run_id=None *is* the root.
|
339
440
|
is_potential_root_event = parent_run_id is None
|
340
441
|
|
341
442
|
if node_name:
|
342
|
-
name = node_name
|
343
|
-
if name not in self.executed_nodes:
|
443
|
+
name = node_name # Use node name if available
|
444
|
+
if name not in self.executed_nodes:
|
445
|
+
self.executed_nodes.append(
|
446
|
+
name
|
447
|
+
) # Leaving this in for now but can probably be removed
|
344
448
|
elif is_langgraph_root_kwarg and is_potential_root_event:
|
345
|
-
|
449
|
+
name = "LangGraph" # Explicit root detected
|
346
450
|
# Add handling for other potential LangChain internal chains if needed, e.g., "RunnableSequence"
|
347
451
|
|
348
452
|
# --- Ensure Trace Client ---
|
349
|
-
|
350
|
-
trace_client = self._ensure_trace_client(
|
351
|
-
|
453
|
+
# Pass parent_run_id to _ensure_trace_client
|
454
|
+
trace_client = self._ensure_trace_client(
|
455
|
+
run_id, parent_run_id, name
|
456
|
+
) # Corrected call
|
457
|
+
if not trace_client:
|
458
|
+
return
|
352
459
|
|
353
460
|
# --- Update Trace Name if Root ---
|
354
461
|
# If this is the root event (parent_run_id is None) and the trace client was just created,
|
355
462
|
# ensure the trace name reflects the graph's name ('LangGraph' usually).
|
356
|
-
if
|
357
|
-
|
463
|
+
if (
|
464
|
+
is_potential_root_event
|
465
|
+
and run_id == self._root_run_id
|
466
|
+
and trace_client.name != name
|
467
|
+
):
|
468
|
+
trace_client.name = name # Update trace name to the determined root name
|
358
469
|
|
359
470
|
# --- Start Span Tracking ---
|
360
|
-
combined_inputs = {
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
471
|
+
combined_inputs = {
|
472
|
+
"inputs": inputs,
|
473
|
+
"tags": tags,
|
474
|
+
"metadata": metadata,
|
475
|
+
"kwargs": kwargs,
|
476
|
+
"serialized": serialized,
|
477
|
+
}
|
478
|
+
self._start_span_tracking(
|
479
|
+
trace_client,
|
480
|
+
run_id,
|
481
|
+
parent_run_id,
|
482
|
+
name,
|
483
|
+
span_type=span_type,
|
484
|
+
inputs=combined_inputs,
|
485
|
+
)
|
365
486
|
|
487
|
+
def on_chain_end(
|
488
|
+
self,
|
489
|
+
outputs: Dict[str, Any],
|
490
|
+
*,
|
491
|
+
run_id: UUID,
|
492
|
+
parent_run_id: Optional[UUID] = None,
|
493
|
+
tags: Optional[List[str]] = None,
|
494
|
+
**kwargs: Any,
|
495
|
+
) -> Any:
|
366
496
|
# Pass parent_run_id
|
367
|
-
trace_client = self._ensure_trace_client(
|
368
|
-
|
497
|
+
trace_client = self._ensure_trace_client(
|
498
|
+
run_id, parent_run_id, "ChainEnd"
|
499
|
+
) # Corrected call
|
500
|
+
if not trace_client:
|
501
|
+
return
|
369
502
|
|
370
503
|
span_id = self._run_id_to_span_id.get(run_id)
|
371
504
|
# If it's the root run ending, _end_span_tracking will handle cleanup/save
|
372
505
|
if not span_id and run_id != self._root_run_id:
|
373
|
-
return
|
506
|
+
return # Don't call end tracking if it's not the root and span wasn't tracked
|
374
507
|
|
375
508
|
# Prepare outputs for end tracking (moved down)
|
376
509
|
combined_outputs = {"outputs": outputs, "tags": tags, "kwargs": kwargs}
|
@@ -385,128 +518,247 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
385
518
|
complete_trace_data = {
|
386
519
|
"trace_id": trace_client.trace_id,
|
387
520
|
"name": trace_client.name,
|
388
|
-
"created_at": datetime.utcfromtimestamp(
|
521
|
+
"created_at": datetime.utcfromtimestamp(
|
522
|
+
trace_client.start_time
|
523
|
+
).isoformat(),
|
389
524
|
"duration": trace_client.get_duration(),
|
390
|
-
"trace_spans": [
|
525
|
+
"trace_spans": [
|
526
|
+
span.model_dump() for span in trace_client.trace_spans
|
527
|
+
],
|
391
528
|
"overwrite": trace_client.overwrite,
|
392
529
|
"offline_mode": self.tracer.offline_mode,
|
393
530
|
"parent_trace_id": trace_client.parent_trace_id,
|
394
|
-
"parent_name": trace_client.parent_name
|
531
|
+
"parent_name": trace_client.parent_name,
|
395
532
|
}
|
396
|
-
|
397
|
-
trace_id_saved, trace_data = trace_client.save_with_rate_limiting(
|
533
|
+
trace_id_saved, trace_data = trace_client.save(
|
398
534
|
overwrite=trace_client.overwrite,
|
399
|
-
final_save=True
|
535
|
+
final_save=True,
|
400
536
|
)
|
401
|
-
|
402
|
-
|
537
|
+
|
403
538
|
self.tracer.traces.append(complete_trace_data)
|
404
539
|
self._trace_saved = True
|
405
540
|
# Reset tracer's active client *after* successful save
|
406
541
|
if self.tracer._active_trace_client == trace_client:
|
407
|
-
self.tracer._active_trace_client = None
|
408
|
-
|
542
|
+
self.tracer._active_trace_client = None
|
543
|
+
|
409
544
|
# Reset root run id after attempt
|
410
545
|
self._root_run_id = None
|
411
546
|
# Reset input storage for this handler instance
|
412
547
|
|
413
|
-
def on_chain_error(
|
548
|
+
def on_chain_error(
|
549
|
+
self,
|
550
|
+
error: BaseException,
|
551
|
+
*,
|
552
|
+
run_id: UUID,
|
553
|
+
parent_run_id: Optional[UUID] = None,
|
554
|
+
**kwargs: Any,
|
555
|
+
) -> Any:
|
414
556
|
# Pass parent_run_id
|
415
|
-
trace_client = self._ensure_trace_client(
|
557
|
+
trace_client = self._ensure_trace_client(
|
558
|
+
run_id, parent_run_id, "ChainError"
|
559
|
+
) # Corrected call
|
416
560
|
if not trace_client:
|
417
561
|
return
|
418
562
|
|
419
563
|
span_id = self._run_id_to_span_id.get(run_id)
|
420
|
-
|
564
|
+
|
421
565
|
# Let _end_span_tracking handle potential root run cleanup
|
422
566
|
if not span_id and run_id != self._root_run_id:
|
423
567
|
return
|
424
568
|
|
425
569
|
self._end_span_tracking(trace_client, run_id, error=error)
|
426
570
|
|
427
|
-
def on_tool_start(
|
428
|
-
|
571
|
+
def on_tool_start(
|
572
|
+
self,
|
573
|
+
serialized: Dict[str, Any],
|
574
|
+
input_str: str,
|
575
|
+
*,
|
576
|
+
run_id: UUID,
|
577
|
+
parent_run_id: Optional[UUID] = None,
|
578
|
+
tags: Optional[List[str]] = None,
|
579
|
+
metadata: Optional[Dict[str, Any]] = None,
|
580
|
+
inputs: Optional[Dict[str, Any]] = None,
|
581
|
+
**kwargs: Any,
|
582
|
+
) -> Any:
|
583
|
+
name = (
|
584
|
+
serialized.get("name", "Unnamed Tool")
|
585
|
+
if serialized
|
586
|
+
else "Unknown Tool (Serialized=None)"
|
587
|
+
)
|
429
588
|
|
430
589
|
# Pass parent_run_id
|
431
|
-
trace_client = self._ensure_trace_client(
|
432
|
-
|
590
|
+
trace_client = self._ensure_trace_client(
|
591
|
+
run_id, parent_run_id, name
|
592
|
+
) # Corrected call
|
593
|
+
if not trace_client:
|
594
|
+
return
|
433
595
|
|
434
|
-
combined_inputs = {
|
435
|
-
|
596
|
+
combined_inputs = {
|
597
|
+
"input_str": input_str,
|
598
|
+
"inputs": inputs,
|
599
|
+
"tags": tags,
|
600
|
+
"metadata": metadata,
|
601
|
+
"kwargs": kwargs,
|
602
|
+
"serialized": serialized,
|
603
|
+
}
|
604
|
+
self._start_span_tracking(
|
605
|
+
trace_client,
|
606
|
+
run_id,
|
607
|
+
parent_run_id,
|
608
|
+
name,
|
609
|
+
span_type="tool",
|
610
|
+
inputs=combined_inputs,
|
611
|
+
)
|
436
612
|
|
437
613
|
# --- Track executed tools (remains the same) ---
|
438
|
-
if name not in self.executed_tools:
|
614
|
+
if name not in self.executed_tools:
|
615
|
+
self.executed_tools.append(
|
616
|
+
name
|
617
|
+
) # Leaving this in for now but can probably be removed
|
439
618
|
parent_node_name = None
|
440
619
|
if parent_run_id and parent_run_id in self._run_id_to_span_id:
|
441
620
|
parent_span_id = self._run_id_to_span_id[parent_run_id]
|
442
621
|
parent_node_name = trace_client.span_id_to_span[parent_span_id].function
|
443
622
|
|
444
623
|
node_tool = f"{parent_node_name}:{name}" if parent_node_name else name
|
445
|
-
if node_tool not in self.executed_node_tools:
|
624
|
+
if node_tool not in self.executed_node_tools:
|
625
|
+
self.executed_node_tools.append(
|
626
|
+
node_tool
|
627
|
+
) # Leaving this in for now but can probably be removed
|
446
628
|
# --- End Track executed tools ---
|
447
629
|
|
448
|
-
|
449
|
-
|
630
|
+
def on_tool_end(
|
631
|
+
self,
|
632
|
+
output: Any,
|
633
|
+
*,
|
634
|
+
run_id: UUID,
|
635
|
+
parent_run_id: Optional[UUID] = None,
|
636
|
+
**kwargs: Any,
|
637
|
+
) -> Any:
|
450
638
|
# Pass parent_run_id
|
451
|
-
trace_client = self._ensure_trace_client(
|
452
|
-
|
639
|
+
trace_client = self._ensure_trace_client(
|
640
|
+
run_id, parent_run_id, "ToolEnd"
|
641
|
+
) # Corrected call
|
642
|
+
if not trace_client:
|
643
|
+
return
|
453
644
|
outputs = {"output": output, "kwargs": kwargs}
|
454
645
|
self._end_span_tracking(trace_client, run_id, outputs=outputs)
|
455
646
|
|
456
|
-
def on_tool_error(
|
457
|
-
|
647
|
+
def on_tool_error(
|
648
|
+
self,
|
649
|
+
error: BaseException,
|
650
|
+
*,
|
651
|
+
run_id: UUID,
|
652
|
+
parent_run_id: Optional[UUID] = None,
|
653
|
+
**kwargs: Any,
|
654
|
+
) -> Any:
|
458
655
|
# Pass parent_run_id
|
459
|
-
trace_client = self._ensure_trace_client(
|
460
|
-
|
656
|
+
trace_client = self._ensure_trace_client(
|
657
|
+
run_id, parent_run_id, "ToolError"
|
658
|
+
) # Corrected call
|
659
|
+
if not trace_client:
|
660
|
+
return
|
461
661
|
self._end_span_tracking(trace_client, run_id, error=error)
|
462
662
|
|
463
|
-
def on_llm_start(
|
464
|
-
|
663
|
+
def on_llm_start(
|
664
|
+
self,
|
665
|
+
serialized: Dict[str, Any],
|
666
|
+
prompts: List[str],
|
667
|
+
*,
|
668
|
+
run_id: UUID,
|
669
|
+
parent_run_id: Optional[UUID] = None,
|
670
|
+
tags: Optional[List[str]] = None,
|
671
|
+
metadata: Optional[Dict[str, Any]] = None,
|
672
|
+
invocation_params: Optional[Dict[str, Any]] = None,
|
673
|
+
options: Optional[Dict[str, Any]] = None,
|
674
|
+
name: Optional[str] = None,
|
675
|
+
**kwargs: Any,
|
676
|
+
) -> Any:
|
465
677
|
llm_name = name or serialized.get("name", "LLM Call")
|
466
678
|
|
467
|
-
trace_client = self._ensure_trace_client(
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
679
|
+
trace_client = self._ensure_trace_client(
|
680
|
+
run_id, parent_run_id, llm_name
|
681
|
+
) # Corrected call
|
682
|
+
if not trace_client:
|
683
|
+
return
|
684
|
+
inputs = {
|
685
|
+
"prompts": prompts,
|
686
|
+
"invocation_params": invocation_params or kwargs,
|
687
|
+
"options": options,
|
688
|
+
"tags": tags,
|
689
|
+
"metadata": metadata,
|
690
|
+
"serialized": serialized,
|
691
|
+
}
|
692
|
+
self._start_span_tracking(
|
693
|
+
trace_client,
|
694
|
+
run_id,
|
695
|
+
parent_run_id,
|
696
|
+
llm_name,
|
697
|
+
span_type="llm",
|
698
|
+
inputs=inputs,
|
699
|
+
)
|
473
700
|
|
701
|
+
def on_llm_end(
|
702
|
+
self,
|
703
|
+
response: LLMResult,
|
704
|
+
*,
|
705
|
+
run_id: UUID,
|
706
|
+
parent_run_id: Optional[UUID] = None,
|
707
|
+
**kwargs: Any,
|
708
|
+
) -> Any:
|
474
709
|
# Pass parent_run_id
|
475
|
-
trace_client = self._ensure_trace_client(
|
710
|
+
trace_client = self._ensure_trace_client(
|
711
|
+
run_id, parent_run_id, "LLMEnd"
|
712
|
+
) # Corrected call
|
476
713
|
if not trace_client:
|
477
714
|
return
|
478
715
|
outputs = {"response": response, "kwargs": kwargs}
|
479
|
-
|
716
|
+
|
480
717
|
# --- Token Usage Extraction and Cost Calculation ---
|
481
718
|
prompt_tokens = None
|
482
719
|
completion_tokens = None
|
483
720
|
total_tokens = None
|
484
721
|
model_name = None
|
485
|
-
|
722
|
+
|
486
723
|
# Extract model name from response if available
|
487
|
-
if
|
488
|
-
|
489
|
-
|
724
|
+
if (
|
725
|
+
hasattr(response, "llm_output")
|
726
|
+
and response.llm_output
|
727
|
+
and isinstance(response.llm_output, dict)
|
728
|
+
):
|
729
|
+
model_name = response.llm_output.get(
|
730
|
+
"model_name"
|
731
|
+
) or response.llm_output.get("model")
|
732
|
+
|
490
733
|
# Try to get model from the first generation if available
|
491
734
|
if not model_name and response.generations and len(response.generations) > 0:
|
492
|
-
if
|
735
|
+
if (
|
736
|
+
hasattr(response.generations[0][0], "generation_info")
|
737
|
+
and response.generations[0][0].generation_info
|
738
|
+
):
|
493
739
|
gen_info = response.generations[0][0].generation_info
|
494
|
-
model_name = gen_info.get(
|
740
|
+
model_name = gen_info.get("model") or gen_info.get("model_name")
|
495
741
|
|
496
742
|
if response.llm_output and isinstance(response.llm_output, dict):
|
497
743
|
# Check for OpenAI/standard 'token_usage' first
|
498
|
-
if
|
499
|
-
token_usage = response.llm_output.get(
|
744
|
+
if "token_usage" in response.llm_output:
|
745
|
+
token_usage = response.llm_output.get("token_usage")
|
500
746
|
if token_usage and isinstance(token_usage, dict):
|
501
|
-
prompt_tokens = token_usage.get(
|
502
|
-
completion_tokens = token_usage.get(
|
503
|
-
total_tokens = token_usage.get(
|
747
|
+
prompt_tokens = token_usage.get("prompt_tokens")
|
748
|
+
completion_tokens = token_usage.get("completion_tokens")
|
749
|
+
total_tokens = token_usage.get(
|
750
|
+
"total_tokens"
|
751
|
+
) # OpenAI provides total
|
504
752
|
# Check for Anthropic 'usage'
|
505
|
-
elif
|
506
|
-
token_usage = response.llm_output.get(
|
753
|
+
elif "usage" in response.llm_output:
|
754
|
+
token_usage = response.llm_output.get("usage")
|
507
755
|
if token_usage and isinstance(token_usage, dict):
|
508
|
-
prompt_tokens = token_usage.get(
|
509
|
-
|
756
|
+
prompt_tokens = token_usage.get(
|
757
|
+
"input_tokens"
|
758
|
+
) # Anthropic uses input_tokens
|
759
|
+
completion_tokens = token_usage.get(
|
760
|
+
"output_tokens"
|
761
|
+
) # Anthropic uses output_tokens
|
510
762
|
# Calculate total if possible
|
511
763
|
if prompt_tokens is not None and completion_tokens is not None:
|
512
764
|
total_tokens = prompt_tokens + completion_tokens
|
@@ -517,57 +769,118 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
517
769
|
prompt_cost = None
|
518
770
|
completion_cost = None
|
519
771
|
total_cost_usd = None
|
520
|
-
|
521
|
-
if
|
772
|
+
|
773
|
+
if (
|
774
|
+
model_name
|
775
|
+
and prompt_tokens is not None
|
776
|
+
and completion_tokens is not None
|
777
|
+
):
|
522
778
|
try:
|
523
779
|
prompt_cost, completion_cost = cost_per_token(
|
524
780
|
model=model_name,
|
525
781
|
prompt_tokens=prompt_tokens,
|
526
|
-
completion_tokens=completion_tokens
|
782
|
+
completion_tokens=completion_tokens,
|
783
|
+
)
|
784
|
+
total_cost_usd = (
|
785
|
+
(prompt_cost + completion_cost)
|
786
|
+
if prompt_cost and completion_cost
|
787
|
+
else None
|
527
788
|
)
|
528
|
-
total_cost_usd = (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
|
529
789
|
except Exception as e:
|
530
790
|
# If cost calculation fails, continue without costs
|
531
791
|
import warnings
|
532
|
-
|
533
|
-
|
792
|
+
|
793
|
+
warnings.warn(
|
794
|
+
f"Failed to calculate token costs for model {model_name}: {e}"
|
795
|
+
)
|
796
|
+
|
534
797
|
# Create TraceUsage object
|
535
798
|
usage = TraceUsage(
|
536
799
|
prompt_tokens=prompt_tokens,
|
537
800
|
completion_tokens=completion_tokens,
|
538
|
-
total_tokens=total_tokens
|
801
|
+
total_tokens=total_tokens
|
802
|
+
or (
|
803
|
+
prompt_tokens + completion_tokens
|
804
|
+
if prompt_tokens and completion_tokens
|
805
|
+
else None
|
806
|
+
),
|
539
807
|
prompt_tokens_cost_usd=prompt_cost,
|
540
808
|
completion_tokens_cost_usd=completion_cost,
|
541
809
|
total_cost_usd=total_cost_usd,
|
542
|
-
model_name=model_name
|
810
|
+
model_name=model_name,
|
543
811
|
)
|
544
|
-
|
812
|
+
|
545
813
|
# Set usage on the actual span (not in outputs)
|
546
814
|
span_id = self._run_id_to_span_id.get(run_id)
|
547
815
|
if span_id and span_id in trace_client.span_id_to_span:
|
548
816
|
trace_span = trace_client.span_id_to_span[span_id]
|
549
817
|
trace_span.usage = usage
|
550
|
-
|
551
|
-
|
818
|
+
|
552
819
|
self._end_span_tracking(trace_client, run_id, outputs=outputs)
|
553
820
|
# --- End Token Usage ---
|
554
821
|
|
555
|
-
def on_llm_error(
|
556
|
-
|
822
|
+
def on_llm_error(
|
823
|
+
self,
|
824
|
+
error: BaseException,
|
825
|
+
*,
|
826
|
+
run_id: UUID,
|
827
|
+
parent_run_id: Optional[UUID] = None,
|
828
|
+
**kwargs: Any,
|
829
|
+
) -> Any:
|
557
830
|
# Pass parent_run_id
|
558
|
-
trace_client = self._ensure_trace_client(
|
559
|
-
|
831
|
+
trace_client = self._ensure_trace_client(
|
832
|
+
run_id, parent_run_id, "LLMError"
|
833
|
+
) # Corrected call
|
834
|
+
if not trace_client:
|
835
|
+
return
|
560
836
|
self._end_span_tracking(trace_client, run_id, error=error)
|
561
837
|
|
562
|
-
def on_chat_model_start(
|
838
|
+
def on_chat_model_start(
|
839
|
+
self,
|
840
|
+
serialized: Dict[str, Any],
|
841
|
+
messages: List[List[BaseMessage]],
|
842
|
+
*,
|
843
|
+
run_id: UUID,
|
844
|
+
parent_run_id: Optional[UUID] = None,
|
845
|
+
tags: Optional[List[str]] = None,
|
846
|
+
metadata: Optional[Dict[str, Any]] = None,
|
847
|
+
invocation_params: Optional[Dict[str, Any]] = None,
|
848
|
+
options: Optional[Dict[str, Any]] = None,
|
849
|
+
name: Optional[str] = None,
|
850
|
+
**kwargs: Any,
|
851
|
+
) -> Any:
|
563
852
|
# Reuse on_llm_start logic, adding message formatting if needed
|
564
853
|
chat_model_name = name or serialized.get("name", "ChatModel Call")
|
565
854
|
# Add OPENAI_API_CALL suffix if model is OpenAI and not present
|
566
|
-
is_openai =
|
567
|
-
|
568
|
-
|
855
|
+
is_openai = (
|
856
|
+
any(
|
857
|
+
key.startswith("openai") for key in serialized.get("secrets", {}).keys()
|
858
|
+
)
|
859
|
+
or "openai" in chat_model_name.lower()
|
860
|
+
)
|
861
|
+
is_anthropic = (
|
862
|
+
any(
|
863
|
+
key.startswith("anthropic")
|
864
|
+
for key in serialized.get("secrets", {}).keys()
|
865
|
+
)
|
866
|
+
or "anthropic" in chat_model_name.lower()
|
867
|
+
or "claude" in chat_model_name.lower()
|
868
|
+
)
|
869
|
+
is_together = (
|
870
|
+
any(
|
871
|
+
key.startswith("together")
|
872
|
+
for key in serialized.get("secrets", {}).keys()
|
873
|
+
)
|
874
|
+
or "together" in chat_model_name.lower()
|
875
|
+
)
|
569
876
|
# Add more checks for other providers like Google if needed
|
570
|
-
is_google =
|
877
|
+
is_google = (
|
878
|
+
any(
|
879
|
+
key.startswith("google") for key in serialized.get("secrets", {}).keys()
|
880
|
+
)
|
881
|
+
or "google" in chat_model_name.lower()
|
882
|
+
or "gemini" in chat_model_name.lower()
|
883
|
+
)
|
571
884
|
|
572
885
|
if is_openai and "OPENAI_API_CALL" not in chat_model_name:
|
573
886
|
chat_model_name = f"{chat_model_name} OPENAI_API_CALL"
|
@@ -577,27 +890,76 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
577
890
|
chat_model_name = f"{chat_model_name} TOGETHER_API_CALL"
|
578
891
|
|
579
892
|
elif is_google and "GOOGLE_API_CALL" not in chat_model_name:
|
580
|
-
|
581
|
-
|
582
|
-
trace_client = self._ensure_trace_client(run_id, parent_run_id, chat_model_name) # Corrected call with parent_run_id
|
583
|
-
if not trace_client: return
|
584
|
-
inputs = {'messages': messages, 'invocation_params': invocation_params or kwargs, 'options': options, 'tags': tags, 'metadata': metadata, 'serialized': serialized}
|
585
|
-
self._start_span_tracking(trace_client, run_id, parent_run_id, chat_model_name, span_type="llm", inputs=inputs) # Use 'llm' span_type for consistency
|
893
|
+
chat_model_name = f"{chat_model_name} GOOGLE_API_CALL"
|
586
894
|
|
587
|
-
|
895
|
+
trace_client = self._ensure_trace_client(
|
896
|
+
run_id, parent_run_id, chat_model_name
|
897
|
+
) # Corrected call with parent_run_id
|
898
|
+
if not trace_client:
|
899
|
+
return
|
900
|
+
inputs = {
|
901
|
+
"messages": messages,
|
902
|
+
"invocation_params": invocation_params or kwargs,
|
903
|
+
"options": options,
|
904
|
+
"tags": tags,
|
905
|
+
"metadata": metadata,
|
906
|
+
"serialized": serialized,
|
907
|
+
}
|
908
|
+
self._start_span_tracking(
|
909
|
+
trace_client,
|
910
|
+
run_id,
|
911
|
+
parent_run_id,
|
912
|
+
chat_model_name,
|
913
|
+
span_type="llm",
|
914
|
+
inputs=inputs,
|
915
|
+
) # Use 'llm' span_type for consistency
|
916
|
+
|
917
|
+
def on_agent_action(
|
918
|
+
self,
|
919
|
+
action: AgentAction,
|
920
|
+
*,
|
921
|
+
run_id: UUID,
|
922
|
+
parent_run_id: Optional[UUID] = None,
|
923
|
+
**kwargs: Any,
|
924
|
+
) -> Any:
|
588
925
|
action_tool = action.tool
|
589
926
|
name = f"AGENT_ACTION_{(action_tool).upper()}"
|
590
927
|
# Pass parent_run_id
|
591
|
-
trace_client = self._ensure_trace_client(
|
592
|
-
|
928
|
+
trace_client = self._ensure_trace_client(
|
929
|
+
run_id, parent_run_id, name
|
930
|
+
) # Corrected call
|
931
|
+
if not trace_client:
|
932
|
+
return
|
593
933
|
|
594
|
-
inputs = {
|
595
|
-
|
934
|
+
inputs = {
|
935
|
+
"tool_input": action.tool_input,
|
936
|
+
"log": action.log,
|
937
|
+
"messages": action.messages,
|
938
|
+
"kwargs": kwargs,
|
939
|
+
}
|
940
|
+
self._start_span_tracking(
|
941
|
+
trace_client, run_id, parent_run_id, name, span_type="agent", inputs=inputs
|
942
|
+
)
|
596
943
|
|
597
|
-
def on_agent_finish(
|
944
|
+
def on_agent_finish(
|
945
|
+
self,
|
946
|
+
finish: AgentFinish,
|
947
|
+
*,
|
948
|
+
run_id: UUID,
|
949
|
+
parent_run_id: Optional[UUID] = None,
|
950
|
+
**kwargs: Any,
|
951
|
+
) -> Any:
|
598
952
|
# Pass parent_run_id
|
599
|
-
trace_client = self._ensure_trace_client(
|
600
|
-
|
953
|
+
trace_client = self._ensure_trace_client(
|
954
|
+
run_id, parent_run_id, "AgentFinish"
|
955
|
+
) # Corrected call
|
956
|
+
if not trace_client:
|
957
|
+
return
|
601
958
|
|
602
|
-
outputs = {
|
603
|
-
|
959
|
+
outputs = {
|
960
|
+
"return_values": finish.return_values,
|
961
|
+
"log": finish.log,
|
962
|
+
"messages": finish.messages,
|
963
|
+
"kwargs": kwargs,
|
964
|
+
}
|
965
|
+
self._end_span_tracking(trace_client, run_id, outputs=outputs)
|