judgeval 0.0.54__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/common/api/__init__.py +3 -0
- judgeval/common/api/api.py +352 -0
- judgeval/common/api/constants.py +165 -0
- judgeval/common/storage/__init__.py +6 -0
- judgeval/common/tracer/__init__.py +31 -0
- judgeval/common/tracer/constants.py +22 -0
- judgeval/common/tracer/core.py +1916 -0
- judgeval/common/tracer/otel_exporter.py +108 -0
- judgeval/common/tracer/otel_span_processor.py +234 -0
- judgeval/common/tracer/span_processor.py +37 -0
- judgeval/common/tracer/span_transformer.py +211 -0
- judgeval/common/tracer/trace_manager.py +92 -0
- judgeval/common/utils.py +2 -2
- judgeval/constants.py +3 -30
- judgeval/data/datasets/eval_dataset_client.py +29 -156
- judgeval/data/judgment_types.py +4 -12
- judgeval/data/result.py +1 -1
- judgeval/data/scorer_data.py +2 -2
- judgeval/data/scripts/openapi_transform.py +1 -1
- judgeval/data/trace.py +66 -1
- judgeval/data/trace_run.py +0 -3
- judgeval/evaluation_run.py +0 -2
- judgeval/integrations/langgraph.py +43 -164
- judgeval/judgment_client.py +17 -211
- judgeval/run_evaluation.py +209 -611
- judgeval/scorers/__init__.py +2 -6
- judgeval/scorers/base_scorer.py +4 -23
- judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +3 -3
- judgeval/scorers/judgeval_scorers/api_scorers/prompt_scorer.py +215 -0
- judgeval/scorers/score.py +2 -1
- judgeval/scorers/utils.py +1 -13
- judgeval/utils/requests.py +21 -0
- judgeval-0.1.0.dist-info/METADATA +202 -0
- {judgeval-0.0.54.dist-info → judgeval-0.1.0.dist-info}/RECORD +37 -29
- judgeval/common/tracer.py +0 -3215
- judgeval/scorers/judgeval_scorers/api_scorers/classifier_scorer.py +0 -73
- judgeval/scorers/judgeval_scorers/classifiers/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/classifiers/text2sql/__init__.py +0 -3
- judgeval/scorers/judgeval_scorers/classifiers/text2sql/text2sql_scorer.py +0 -53
- judgeval-0.0.54.dist-info/METADATA +0 -1384
- /judgeval/common/{s3_storage.py → storage/s3_storage.py} +0 -0
- {judgeval-0.0.54.dist-info → judgeval-0.1.0.dist-info}/WHEEL +0 -0
- {judgeval-0.0.54.dist-info → judgeval-0.1.0.dist-info}/licenses/LICENSE.md +0 -0
@@ -19,15 +19,9 @@ from langchain_core.outputs import LLMResult
|
|
19
19
|
from langchain_core.messages.base import BaseMessage
|
20
20
|
from langchain_core.documents import Document
|
21
21
|
|
22
|
-
# --- Get context vars from tracer module ---
|
23
|
-
# Assuming tracer.py defines these and they are accessible
|
24
|
-
# If not, redefine them here or adjust import
|
25
|
-
|
26
|
-
# from judgeval.common.tracer import current_span_var
|
27
22
|
# TODO: Figure out how to handle context variables. Current solution is to keep track of current span id in Tracer class
|
28
23
|
|
29
24
|
|
30
|
-
# --- NEW __init__ ---
|
31
25
|
class JudgevalCallbackHandler(BaseCallbackHandler):
|
32
26
|
"""
|
33
27
|
LangChain Callback Handler using run_id/parent_run_id for hierarchy.
|
@@ -40,19 +34,11 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
40
34
|
lc_serializable = False
|
41
35
|
lc_kwargs: dict = {}
|
42
36
|
|
43
|
-
# --- NEW __init__ ---
|
44
37
|
def __init__(self, tracer: Tracer):
|
45
38
|
self.tracer = tracer
|
46
|
-
# Initialize tracking/logging variables (preserved across resets)
|
47
39
|
self.executed_nodes: List[str] = []
|
48
|
-
self.executed_tools: List[str] = []
|
49
|
-
self.executed_node_tools: List[str] = []
|
50
|
-
self.traces: List[Dict[str, Any]] = []
|
51
|
-
# Initialize execution state (reset between runs)
|
52
40
|
self._reset_state()
|
53
41
|
|
54
|
-
# --- END NEW __init__ ---
|
55
|
-
|
56
42
|
def _reset_state(self):
|
57
43
|
"""Reset only the critical execution state for reuse across multiple executions"""
|
58
44
|
# Reset core execution state that must be cleared between runs
|
@@ -61,28 +47,16 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
61
47
|
self._span_id_to_start_time: Dict[str, float] = {}
|
62
48
|
self._span_id_to_depth: Dict[str, int] = {}
|
63
49
|
self._root_run_id: Optional[UUID] = None
|
64
|
-
self._trace_saved: bool = False
|
50
|
+
self._trace_saved: bool = False
|
65
51
|
self.span_id_to_token: Dict[str, Any] = {}
|
66
52
|
self.trace_id_to_token: Dict[str, Any] = {}
|
67
53
|
|
68
54
|
# Add timestamp to track when we last reset
|
69
55
|
self._last_reset_time: float = time.time()
|
70
56
|
|
71
|
-
# Preserve tracking/logging variables across executions:
|
72
|
-
# - self.executed_nodes: List[str] = [] # Keep as running log
|
73
|
-
# - self.executed_tools: List[str] = [] # Keep as running log
|
74
|
-
# - self.executed_node_tools: List[str] = [] # Keep as running log
|
75
|
-
# - self.traces: List[Dict[str, Any]] = [] # Keep for collecting multiple traces
|
76
|
-
|
77
57
|
# Also reset tracking/logging variables
|
78
|
-
self.executed_nodes: List[
|
79
|
-
|
80
|
-
] = [] # These last four members are only appended to and never accessed; can probably be removed but still might be useful for future reference?
|
81
|
-
self.executed_tools: List[str] = []
|
82
|
-
self.executed_node_tools: List[str] = []
|
83
|
-
self.traces: List[Dict[str, Any]] = []
|
84
|
-
|
85
|
-
# --- END NEW __init__ ---
|
58
|
+
self.executed_nodes: List[str] = []
|
59
|
+
|
86
60
|
def reset(self):
|
87
61
|
"""Public method to manually reset handler execution state for reuse"""
|
88
62
|
self._reset_state()
|
@@ -91,7 +65,6 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
91
65
|
"""Public method to reset ALL handler state including tracking/logging data"""
|
92
66
|
self._reset_state()
|
93
67
|
|
94
|
-
# --- MODIFIED _ensure_trace_client ---
|
95
68
|
def _ensure_trace_client(
|
96
69
|
self, run_id: UUID, parent_run_id: Optional[UUID], event_name: str
|
97
70
|
) -> Optional[TraceClient]:
|
@@ -127,19 +100,14 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
127
100
|
token = self.tracer.set_current_trace(self._trace_client)
|
128
101
|
if token:
|
129
102
|
self.trace_id_to_token[trace_id] = token
|
103
|
+
|
130
104
|
if self._trace_client:
|
131
|
-
self._root_run_id =
|
132
|
-
|
133
|
-
)
|
134
|
-
self._trace_saved = False # Ensure flag is reset
|
135
|
-
# Set active client on Tracer (important for potential fallbacks)
|
105
|
+
self._root_run_id = run_id
|
106
|
+
self._trace_saved = False
|
136
107
|
self.tracer._active_trace_client = self._trace_client
|
137
108
|
|
138
|
-
# NEW: Initial save for live tracking (follows the new practice)
|
139
109
|
try:
|
140
|
-
|
141
|
-
final_save=False, # Initial save for live tracking
|
142
|
-
)
|
110
|
+
self._trace_client.save(final_save=False)
|
143
111
|
except Exception as e:
|
144
112
|
import warnings
|
145
113
|
|
@@ -207,18 +175,13 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
207
175
|
# Set both fields on the span
|
208
176
|
new_span.inputs = clean_inputs
|
209
177
|
new_span.additional_metadata = metadata
|
210
|
-
new_span.increment_update_id() # Thread-safe increment for span modification
|
211
178
|
else:
|
212
179
|
new_span.inputs = {}
|
213
180
|
new_span.additional_metadata = {}
|
214
181
|
|
215
182
|
trace_client.add_span(new_span)
|
216
183
|
|
217
|
-
|
218
|
-
if trace_client.background_span_service:
|
219
|
-
trace_client.background_span_service.queue_span(
|
220
|
-
new_span, span_state="input"
|
221
|
-
)
|
184
|
+
trace_client.otel_span_processor.queue_span_update(new_span, span_state="input")
|
222
185
|
|
223
186
|
token = self.tracer.set_current_span(span_id)
|
224
187
|
if token:
|
@@ -247,12 +210,10 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
247
210
|
trace_span = trace_client.span_id_to_span.get(span_id)
|
248
211
|
if trace_span:
|
249
212
|
trace_span.duration = duration
|
250
|
-
trace_span.increment_update_id() # Thread-safe increment for span modification
|
251
213
|
|
252
214
|
# Handle outputs and error
|
253
215
|
if error:
|
254
216
|
trace_span.output = error
|
255
|
-
trace_span.increment_update_id() # Thread-safe increment for span modification
|
256
217
|
elif outputs:
|
257
218
|
# Separate metadata from outputs
|
258
219
|
metadata = {}
|
@@ -272,7 +233,6 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
272
233
|
|
273
234
|
# Set both fields on the span
|
274
235
|
trace_span.output = clean_outputs
|
275
|
-
trace_span.increment_update_id() # Thread-safe increment for span modification
|
276
236
|
if metadata:
|
277
237
|
# Merge with existing metadata
|
278
238
|
existing_metadata = trace_span.additional_metadata or {}
|
@@ -280,14 +240,11 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
280
240
|
**existing_metadata,
|
281
241
|
**metadata,
|
282
242
|
}
|
283
|
-
trace_span.increment_update_id() # Thread-safe increment for span modification
|
284
243
|
|
285
|
-
|
286
|
-
|
287
|
-
span_state
|
288
|
-
|
289
|
-
trace_span, span_state=span_state
|
290
|
-
)
|
244
|
+
span_state = "error" if error else "completed"
|
245
|
+
trace_client.otel_span_processor.queue_span_update(
|
246
|
+
trace_span, span_state=span_state
|
247
|
+
)
|
291
248
|
|
292
249
|
# Clean up dictionaries for this specific span
|
293
250
|
if span_id in self._span_id_to_start_time:
|
@@ -298,15 +255,10 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
298
255
|
# Check if this is the root run ending
|
299
256
|
if run_id == self._root_run_id:
|
300
257
|
try:
|
301
|
-
# Reset root run id after attempt
|
302
258
|
self._root_run_id = None
|
303
|
-
# Reset input storage for this handler instance
|
304
|
-
|
305
259
|
if (
|
306
260
|
self._trace_client and not self._trace_saved
|
307
261
|
): # Check if not already saved
|
308
|
-
# Flush background spans before saving the final trace
|
309
|
-
|
310
262
|
complete_trace_data = {
|
311
263
|
"trace_id": self._trace_client.trace_id,
|
312
264
|
"name": self._trace_client.name,
|
@@ -321,6 +273,9 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
321
273
|
"parent_trace_id": self._trace_client.parent_trace_id,
|
322
274
|
"parent_name": self._trace_client.parent_name,
|
323
275
|
}
|
276
|
+
|
277
|
+
self.tracer.flush_background_spans()
|
278
|
+
|
324
279
|
trace_id, trace_data = self._trace_client.save(
|
325
280
|
final_save=True, # Final save with usage counter updates
|
326
281
|
)
|
@@ -331,14 +286,12 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
331
286
|
self.tracer.traces.append(complete_trace_data)
|
332
287
|
self._trace_saved = True # Set flag only after successful save
|
333
288
|
finally:
|
334
|
-
# --- NEW: Consolidated Cleanup Logic ---
|
335
289
|
# This block executes regardless of save success/failure
|
336
290
|
# Reset root run id
|
337
291
|
self._root_run_id = None
|
338
292
|
# Reset input storage for this handler instance
|
339
293
|
if self.tracer._active_trace_client == self._trace_client:
|
340
294
|
self.tracer._active_trace_client = None
|
341
|
-
# --- End Cleanup Logic ---
|
342
295
|
|
343
296
|
# --- Callback Methods ---
|
344
297
|
# Each method now ensures the trace client exists before proceeding
|
@@ -361,10 +314,7 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
361
314
|
)
|
362
315
|
|
363
316
|
name = f"RETRIEVER_{(serialized_name).upper()}"
|
364
|
-
|
365
|
-
trace_client = self._ensure_trace_client(
|
366
|
-
run_id, parent_run_id, name
|
367
|
-
) # Corrected call
|
317
|
+
trace_client = self._ensure_trace_client(run_id, parent_run_id, name)
|
368
318
|
if not trace_client:
|
369
319
|
return
|
370
320
|
|
@@ -392,17 +342,17 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
392
342
|
parent_run_id: Optional[UUID] = None,
|
393
343
|
**kwargs: Any,
|
394
344
|
) -> Any:
|
395
|
-
trace_client = self._ensure_trace_client(
|
396
|
-
run_id, parent_run_id, "RetrieverEnd"
|
397
|
-
) # Corrected call
|
345
|
+
trace_client = self._ensure_trace_client(run_id, parent_run_id, "RetrieverEnd")
|
398
346
|
if not trace_client:
|
399
347
|
return
|
400
348
|
doc_summary = [
|
401
349
|
{
|
402
350
|
"index": i,
|
403
|
-
"page_content":
|
404
|
-
|
405
|
-
|
351
|
+
"page_content": (
|
352
|
+
doc.page_content[:100] + "..."
|
353
|
+
if len(doc.page_content) > 100
|
354
|
+
else doc.page_content
|
355
|
+
),
|
406
356
|
"metadata": doc.metadata,
|
407
357
|
}
|
408
358
|
for i, doc in enumerate(documents)
|
@@ -431,7 +381,7 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
431
381
|
|
432
382
|
# --- Determine Name and Span Type ---
|
433
383
|
span_type: SpanType = "chain"
|
434
|
-
name = serialized_name if serialized_name else "Unknown Chain"
|
384
|
+
name = serialized_name if serialized_name else "Unknown Chain"
|
435
385
|
node_name = metadata.get("langgraph_node") if metadata else None
|
436
386
|
is_langgraph_root_kwarg = (
|
437
387
|
kwargs.get("name") == "LangGraph"
|
@@ -449,25 +399,17 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
449
399
|
name = "LangGraph" # Explicit root detected
|
450
400
|
# Add handling for other potential LangChain internal chains if needed, e.g., "RunnableSequence"
|
451
401
|
|
452
|
-
|
453
|
-
# Pass parent_run_id to _ensure_trace_client
|
454
|
-
trace_client = self._ensure_trace_client(
|
455
|
-
run_id, parent_run_id, name
|
456
|
-
) # Corrected call
|
402
|
+
trace_client = self._ensure_trace_client(run_id, parent_run_id, name)
|
457
403
|
if not trace_client:
|
458
404
|
return
|
459
405
|
|
460
|
-
# --- Update Trace Name if Root ---
|
461
|
-
# If this is the root event (parent_run_id is None) and the trace client was just created,
|
462
|
-
# ensure the trace name reflects the graph's name ('LangGraph' usually).
|
463
406
|
if (
|
464
407
|
is_potential_root_event
|
465
408
|
and run_id == self._root_run_id
|
466
409
|
and trace_client.name != name
|
467
410
|
):
|
468
|
-
trace_client.name = name
|
411
|
+
trace_client.name = name
|
469
412
|
|
470
|
-
# --- Start Span Tracking ---
|
471
413
|
combined_inputs = {
|
472
414
|
"inputs": inputs,
|
473
415
|
"tags": tags,
|
@@ -493,28 +435,20 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
493
435
|
tags: Optional[List[str]] = None,
|
494
436
|
**kwargs: Any,
|
495
437
|
) -> Any:
|
496
|
-
|
497
|
-
trace_client = self._ensure_trace_client(
|
498
|
-
run_id, parent_run_id, "ChainEnd"
|
499
|
-
) # Corrected call
|
438
|
+
trace_client = self._ensure_trace_client(run_id, parent_run_id, "ChainEnd")
|
500
439
|
if not trace_client:
|
501
440
|
return
|
502
441
|
|
503
442
|
span_id = self._run_id_to_span_id.get(run_id)
|
504
|
-
# If it's the root run ending, _end_span_tracking will handle cleanup/save
|
505
443
|
if not span_id and run_id != self._root_run_id:
|
506
|
-
return
|
444
|
+
return
|
507
445
|
|
508
|
-
# Prepare outputs for end tracking (moved down)
|
509
446
|
combined_outputs = {"outputs": outputs, "tags": tags, "kwargs": kwargs}
|
510
447
|
|
511
|
-
# Call end_span_tracking with potentially determined span_type
|
512
448
|
self._end_span_tracking(trace_client, run_id, outputs=combined_outputs)
|
513
449
|
|
514
|
-
# --- Root node cleanup (Existing logic - slightly modified save call) ---
|
515
450
|
if run_id == self._root_run_id:
|
516
451
|
if trace_client and not self._trace_saved:
|
517
|
-
# Store complete trace data instead of server response
|
518
452
|
complete_trace_data = {
|
519
453
|
"trace_id": trace_client.trace_id,
|
520
454
|
"name": trace_client.name,
|
@@ -529,19 +463,19 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
529
463
|
"parent_trace_id": trace_client.parent_trace_id,
|
530
464
|
"parent_name": trace_client.parent_name,
|
531
465
|
}
|
532
|
-
|
466
|
+
|
467
|
+
self.tracer.flush_background_spans()
|
468
|
+
|
469
|
+
trace_client.save(
|
533
470
|
final_save=True,
|
534
471
|
)
|
535
472
|
|
536
473
|
self.tracer.traces.append(complete_trace_data)
|
537
474
|
self._trace_saved = True
|
538
|
-
# Reset tracer's active client *after* successful save
|
539
475
|
if self.tracer._active_trace_client == trace_client:
|
540
476
|
self.tracer._active_trace_client = None
|
541
477
|
|
542
|
-
# Reset root run id after attempt
|
543
478
|
self._root_run_id = None
|
544
|
-
# Reset input storage for this handler instance
|
545
479
|
|
546
480
|
def on_chain_error(
|
547
481
|
self,
|
@@ -551,16 +485,12 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
551
485
|
parent_run_id: Optional[UUID] = None,
|
552
486
|
**kwargs: Any,
|
553
487
|
) -> Any:
|
554
|
-
|
555
|
-
trace_client = self._ensure_trace_client(
|
556
|
-
run_id, parent_run_id, "ChainError"
|
557
|
-
) # Corrected call
|
488
|
+
trace_client = self._ensure_trace_client(run_id, parent_run_id, "ChainError")
|
558
489
|
if not trace_client:
|
559
490
|
return
|
560
491
|
|
561
492
|
span_id = self._run_id_to_span_id.get(run_id)
|
562
493
|
|
563
|
-
# Let _end_span_tracking handle potential root run cleanup
|
564
494
|
if not span_id and run_id != self._root_run_id:
|
565
495
|
return
|
566
496
|
|
@@ -584,10 +514,7 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
584
514
|
else "Unknown Tool (Serialized=None)"
|
585
515
|
)
|
586
516
|
|
587
|
-
|
588
|
-
trace_client = self._ensure_trace_client(
|
589
|
-
run_id, parent_run_id, name
|
590
|
-
) # Corrected call
|
517
|
+
trace_client = self._ensure_trace_client(run_id, parent_run_id, name)
|
591
518
|
if not trace_client:
|
592
519
|
return
|
593
520
|
|
@@ -608,23 +535,6 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
608
535
|
inputs=combined_inputs,
|
609
536
|
)
|
610
537
|
|
611
|
-
# --- Track executed tools (remains the same) ---
|
612
|
-
if name not in self.executed_tools:
|
613
|
-
self.executed_tools.append(
|
614
|
-
name
|
615
|
-
) # Leaving this in for now but can probably be removed
|
616
|
-
parent_node_name = None
|
617
|
-
if parent_run_id and parent_run_id in self._run_id_to_span_id:
|
618
|
-
parent_span_id = self._run_id_to_span_id[parent_run_id]
|
619
|
-
parent_node_name = trace_client.span_id_to_span[parent_span_id].function
|
620
|
-
|
621
|
-
node_tool = f"{parent_node_name}:{name}" if parent_node_name else name
|
622
|
-
if node_tool not in self.executed_node_tools:
|
623
|
-
self.executed_node_tools.append(
|
624
|
-
node_tool
|
625
|
-
) # Leaving this in for now but can probably be removed
|
626
|
-
# --- End Track executed tools ---
|
627
|
-
|
628
538
|
def on_tool_end(
|
629
539
|
self,
|
630
540
|
output: Any,
|
@@ -633,10 +543,7 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
633
543
|
parent_run_id: Optional[UUID] = None,
|
634
544
|
**kwargs: Any,
|
635
545
|
) -> Any:
|
636
|
-
|
637
|
-
trace_client = self._ensure_trace_client(
|
638
|
-
run_id, parent_run_id, "ToolEnd"
|
639
|
-
) # Corrected call
|
546
|
+
trace_client = self._ensure_trace_client(run_id, parent_run_id, "ToolEnd")
|
640
547
|
if not trace_client:
|
641
548
|
return
|
642
549
|
outputs = {"output": output, "kwargs": kwargs}
|
@@ -650,10 +557,7 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
650
557
|
parent_run_id: Optional[UUID] = None,
|
651
558
|
**kwargs: Any,
|
652
559
|
) -> Any:
|
653
|
-
|
654
|
-
trace_client = self._ensure_trace_client(
|
655
|
-
run_id, parent_run_id, "ToolError"
|
656
|
-
) # Corrected call
|
560
|
+
trace_client = self._ensure_trace_client(run_id, parent_run_id, "ToolError")
|
657
561
|
if not trace_client:
|
658
562
|
return
|
659
563
|
self._end_span_tracking(trace_client, run_id, error=error)
|
@@ -674,9 +578,7 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
674
578
|
) -> Any:
|
675
579
|
llm_name = name or serialized.get("name", "LLM Call")
|
676
580
|
|
677
|
-
trace_client = self._ensure_trace_client(
|
678
|
-
run_id, parent_run_id, llm_name
|
679
|
-
) # Corrected call
|
581
|
+
trace_client = self._ensure_trace_client(run_id, parent_run_id, llm_name)
|
680
582
|
if not trace_client:
|
681
583
|
return
|
682
584
|
inputs = {
|
@@ -704,15 +606,11 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
704
606
|
parent_run_id: Optional[UUID] = None,
|
705
607
|
**kwargs: Any,
|
706
608
|
) -> Any:
|
707
|
-
|
708
|
-
trace_client = self._ensure_trace_client(
|
709
|
-
run_id, parent_run_id, "LLMEnd"
|
710
|
-
) # Corrected call
|
609
|
+
trace_client = self._ensure_trace_client(run_id, parent_run_id, "LLMEnd")
|
711
610
|
if not trace_client:
|
712
611
|
return
|
713
612
|
outputs = {"response": response, "kwargs": kwargs}
|
714
613
|
|
715
|
-
# --- Token Usage Extraction and Cost Calculation ---
|
716
614
|
prompt_tokens = None
|
717
615
|
completion_tokens = None
|
718
616
|
total_tokens = None
|
@@ -761,9 +659,7 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
761
659
|
if prompt_tokens is not None and completion_tokens is not None:
|
762
660
|
total_tokens = prompt_tokens + completion_tokens
|
763
661
|
|
764
|
-
# --- Create TraceUsage object and set on span ---
|
765
662
|
if prompt_tokens is not None or completion_tokens is not None:
|
766
|
-
# Calculate costs if model name is available
|
767
663
|
prompt_cost = None
|
768
664
|
completion_cost = None
|
769
665
|
total_cost_usd = None
|
@@ -792,7 +688,6 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
792
688
|
f"Failed to calculate token costs for model {model_name}: {e}"
|
793
689
|
)
|
794
690
|
|
795
|
-
# Create TraceUsage object
|
796
691
|
usage = TraceUsage(
|
797
692
|
prompt_tokens=prompt_tokens,
|
798
693
|
completion_tokens=completion_tokens,
|
@@ -808,15 +703,12 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
808
703
|
model_name=model_name,
|
809
704
|
)
|
810
705
|
|
811
|
-
# Set usage on the actual span (not in outputs)
|
812
706
|
span_id = self._run_id_to_span_id.get(run_id)
|
813
707
|
if span_id and span_id in trace_client.span_id_to_span:
|
814
708
|
trace_span = trace_client.span_id_to_span[span_id]
|
815
709
|
trace_span.usage = usage
|
816
|
-
trace_span.increment_update_id() # Thread-safe increment for span modification
|
817
710
|
|
818
711
|
self._end_span_tracking(trace_client, run_id, outputs=outputs)
|
819
|
-
# --- End Token Usage ---
|
820
712
|
|
821
713
|
def on_llm_error(
|
822
714
|
self,
|
@@ -826,10 +718,7 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
826
718
|
parent_run_id: Optional[UUID] = None,
|
827
719
|
**kwargs: Any,
|
828
720
|
) -> Any:
|
829
|
-
|
830
|
-
trace_client = self._ensure_trace_client(
|
831
|
-
run_id, parent_run_id, "LLMError"
|
832
|
-
) # Corrected call
|
721
|
+
trace_client = self._ensure_trace_client(run_id, parent_run_id, "LLMError")
|
833
722
|
if not trace_client:
|
834
723
|
return
|
835
724
|
self._end_span_tracking(trace_client, run_id, error=error)
|
@@ -848,9 +737,7 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
848
737
|
name: Optional[str] = None,
|
849
738
|
**kwargs: Any,
|
850
739
|
) -> Any:
|
851
|
-
# Reuse on_llm_start logic, adding message formatting if needed
|
852
740
|
chat_model_name = name or serialized.get("name", "ChatModel Call")
|
853
|
-
# Add OPENAI_API_CALL suffix if model is OpenAI and not present
|
854
741
|
is_openai = (
|
855
742
|
any(
|
856
743
|
key.startswith("openai") for key in serialized.get("secrets", {}).keys()
|
@@ -872,7 +759,7 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
872
759
|
)
|
873
760
|
or "together" in chat_model_name.lower()
|
874
761
|
)
|
875
|
-
|
762
|
+
|
876
763
|
is_google = (
|
877
764
|
any(
|
878
765
|
key.startswith("google") for key in serialized.get("secrets", {}).keys()
|
@@ -891,9 +778,7 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
891
778
|
elif is_google and "GOOGLE_API_CALL" not in chat_model_name:
|
892
779
|
chat_model_name = f"{chat_model_name} GOOGLE_API_CALL"
|
893
780
|
|
894
|
-
trace_client = self._ensure_trace_client(
|
895
|
-
run_id, parent_run_id, chat_model_name
|
896
|
-
) # Corrected call with parent_run_id
|
781
|
+
trace_client = self._ensure_trace_client(run_id, parent_run_id, chat_model_name)
|
897
782
|
if not trace_client:
|
898
783
|
return
|
899
784
|
inputs = {
|
@@ -911,7 +796,7 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
911
796
|
chat_model_name,
|
912
797
|
span_type="llm",
|
913
798
|
inputs=inputs,
|
914
|
-
)
|
799
|
+
)
|
915
800
|
|
916
801
|
def on_agent_action(
|
917
802
|
self,
|
@@ -923,10 +808,7 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
923
808
|
) -> Any:
|
924
809
|
action_tool = action.tool
|
925
810
|
name = f"AGENT_ACTION_{(action_tool).upper()}"
|
926
|
-
|
927
|
-
trace_client = self._ensure_trace_client(
|
928
|
-
run_id, parent_run_id, name
|
929
|
-
) # Corrected call
|
811
|
+
trace_client = self._ensure_trace_client(run_id, parent_run_id, name)
|
930
812
|
if not trace_client:
|
931
813
|
return
|
932
814
|
|
@@ -948,10 +830,7 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
|
|
948
830
|
parent_run_id: Optional[UUID] = None,
|
949
831
|
**kwargs: Any,
|
950
832
|
) -> Any:
|
951
|
-
|
952
|
-
trace_client = self._ensure_trace_client(
|
953
|
-
run_id, parent_run_id, "AgentFinish"
|
954
|
-
) # Corrected call
|
833
|
+
trace_client = self._ensure_trace_client(run_id, parent_run_id, "AgentFinish")
|
955
834
|
if not trace_client:
|
956
835
|
return
|
957
836
|
|