judgeval 0.0.44__py3-none-any.whl → 0.0.46__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. judgeval/__init__.py +5 -4
  2. judgeval/clients.py +6 -6
  3. judgeval/common/__init__.py +7 -2
  4. judgeval/common/exceptions.py +2 -3
  5. judgeval/common/logger.py +74 -49
  6. judgeval/common/s3_storage.py +30 -23
  7. judgeval/common/tracer.py +1273 -939
  8. judgeval/common/utils.py +416 -244
  9. judgeval/constants.py +73 -61
  10. judgeval/data/__init__.py +1 -1
  11. judgeval/data/custom_example.py +3 -2
  12. judgeval/data/datasets/dataset.py +80 -54
  13. judgeval/data/datasets/eval_dataset_client.py +131 -181
  14. judgeval/data/example.py +67 -43
  15. judgeval/data/result.py +11 -9
  16. judgeval/data/scorer_data.py +4 -2
  17. judgeval/data/tool.py +25 -16
  18. judgeval/data/trace.py +57 -29
  19. judgeval/data/trace_run.py +5 -11
  20. judgeval/evaluation_run.py +22 -82
  21. judgeval/integrations/langgraph.py +546 -184
  22. judgeval/judges/base_judge.py +1 -2
  23. judgeval/judges/litellm_judge.py +33 -11
  24. judgeval/judges/mixture_of_judges.py +128 -78
  25. judgeval/judges/together_judge.py +22 -9
  26. judgeval/judges/utils.py +14 -5
  27. judgeval/judgment_client.py +259 -271
  28. judgeval/rules.py +169 -142
  29. judgeval/run_evaluation.py +462 -305
  30. judgeval/scorers/api_scorer.py +20 -11
  31. judgeval/scorers/exceptions.py +1 -0
  32. judgeval/scorers/judgeval_scorer.py +77 -58
  33. judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +46 -15
  34. judgeval/scorers/judgeval_scorers/api_scorers/answer_correctness.py +3 -2
  35. judgeval/scorers/judgeval_scorers/api_scorers/answer_relevancy.py +3 -2
  36. judgeval/scorers/judgeval_scorers/api_scorers/classifier_scorer.py +12 -11
  37. judgeval/scorers/judgeval_scorers/api_scorers/comparison.py +7 -5
  38. judgeval/scorers/judgeval_scorers/api_scorers/contextual_precision.py +3 -2
  39. judgeval/scorers/judgeval_scorers/api_scorers/contextual_recall.py +3 -2
  40. judgeval/scorers/judgeval_scorers/api_scorers/contextual_relevancy.py +5 -2
  41. judgeval/scorers/judgeval_scorers/api_scorers/derailment_scorer.py +2 -1
  42. judgeval/scorers/judgeval_scorers/api_scorers/execution_order.py +17 -8
  43. judgeval/scorers/judgeval_scorers/api_scorers/faithfulness.py +3 -2
  44. judgeval/scorers/judgeval_scorers/api_scorers/groundedness.py +3 -2
  45. judgeval/scorers/judgeval_scorers/api_scorers/hallucination.py +3 -2
  46. judgeval/scorers/judgeval_scorers/api_scorers/instruction_adherence.py +3 -2
  47. judgeval/scorers/judgeval_scorers/api_scorers/json_correctness.py +8 -9
  48. judgeval/scorers/judgeval_scorers/api_scorers/summarization.py +4 -4
  49. judgeval/scorers/judgeval_scorers/api_scorers/tool_dependency.py +5 -5
  50. judgeval/scorers/judgeval_scorers/api_scorers/tool_order.py +5 -2
  51. judgeval/scorers/judgeval_scorers/classifiers/text2sql/text2sql_scorer.py +9 -10
  52. judgeval/scorers/prompt_scorer.py +48 -37
  53. judgeval/scorers/score.py +86 -53
  54. judgeval/scorers/utils.py +11 -7
  55. judgeval/tracer/__init__.py +1 -1
  56. judgeval/utils/alerts.py +23 -12
  57. judgeval/utils/{data_utils.py → file_utils.py} +5 -9
  58. judgeval/utils/requests.py +29 -0
  59. judgeval/version_check.py +5 -2
  60. {judgeval-0.0.44.dist-info → judgeval-0.0.46.dist-info}/METADATA +79 -135
  61. judgeval-0.0.46.dist-info/RECORD +69 -0
  62. judgeval-0.0.44.dist-info/RECORD +0 -68
  63. {judgeval-0.0.44.dist-info → judgeval-0.0.46.dist-info}/WHEEL +0 -0
  64. {judgeval-0.0.44.dist-info → judgeval-0.0.46.dist-info}/licenses/LICENSE.md +0 -0
@@ -2,11 +2,15 @@ from typing import Any, Dict, List, Optional, Sequence
2
2
  from uuid import UUID
3
3
  import time
4
4
  import uuid
5
- import contextvars # <--- Import contextvars
6
5
  from datetime import datetime
7
6
 
8
- from judgeval.common.tracer import TraceClient, TraceSpan, Tracer, SpanType, EvaluationConfig, cost_per_token
9
- from judgeval.data import Example # Import Example
7
+ from judgeval.common.tracer import (
8
+ TraceClient,
9
+ TraceSpan,
10
+ Tracer,
11
+ SpanType,
12
+ cost_per_token,
13
+ )
10
14
  from judgeval.data.trace import TraceUsage
11
15
 
12
16
  from langchain_core.callbacks import BaseCallbackHandler
@@ -22,6 +26,7 @@ from langchain_core.documents import Document
22
26
  # from judgeval.common.tracer import current_span_var
23
27
  # TODO: Figure out how to handle context variables. Current solution is to keep track of current span id in Tracer class
24
28
 
29
+
25
30
  # --- NEW __init__ ---
26
31
  class JudgevalCallbackHandler(BaseCallbackHandler):
27
32
  """
@@ -29,14 +34,14 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
29
34
  Manages its own internal TraceClient instance created upon first use.
30
35
  Includes verbose logging and defensive checks.
31
36
  """
37
+
32
38
  # Make all properties ignored by LangChain's callback system
33
39
  # to prevent unexpected serialization issues.
34
40
  lc_serializable = False
35
- lc_kwargs = {}
41
+ lc_kwargs: dict = {}
36
42
 
37
43
  # --- NEW __init__ ---
38
44
  def __init__(self, tracer: Tracer):
39
-
40
45
  self.tracer = tracer
41
46
  # Initialize tracking/logging variables (preserved across resets)
42
47
  self.executed_nodes: List[str] = []
@@ -45,6 +50,7 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
45
50
  self.traces: List[Dict[str, Any]] = []
46
51
  # Initialize execution state (reset between runs)
47
52
  self._reset_state()
53
+
48
54
  # --- END NEW __init__ ---
49
55
 
50
56
  def _reset_state(self):
@@ -55,19 +61,28 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
55
61
  self._span_id_to_start_time: Dict[str, float] = {}
56
62
  self._span_id_to_depth: Dict[str, int] = {}
57
63
  self._root_run_id: Optional[UUID] = None
58
- self._trace_saved: bool = False
64
+ self._trace_saved: bool = False # Flag to prevent actions after trace is saved
59
65
  self.span_id_to_token: Dict[str, Any] = {}
60
66
  self.trace_id_to_token: Dict[str, Any] = {}
61
-
67
+
62
68
  # Add timestamp to track when we last reset
63
69
  self._last_reset_time: float = time.time()
64
-
70
+
65
71
  # Preserve tracking/logging variables across executions:
66
72
  # - self.executed_nodes: List[str] = [] # Keep as running log
67
- # - self.executed_tools: List[str] = [] # Keep as running log
73
+ # - self.executed_tools: List[str] = [] # Keep as running log
68
74
  # - self.executed_node_tools: List[str] = [] # Keep as running log
69
75
  # - self.traces: List[Dict[str, Any]] = [] # Keep for collecting multiple traces
70
76
 
77
+ # Also reset tracking/logging variables
78
+ self.executed_nodes: List[
79
+ str
80
+ ] = [] # These last four members are only appended to and never accessed; can probably be removed but still might be useful for future reference?
81
+ self.executed_tools: List[str] = []
82
+ self.executed_node_tools: List[str] = []
83
+ self.traces: List[Dict[str, Any]] = []
84
+
85
+ # --- END NEW __init__ ---
71
86
  def reset(self):
72
87
  """Public method to manually reset handler execution state for reuse"""
73
88
  self._reset_state()
@@ -75,14 +90,11 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
75
90
  def reset_all(self):
76
91
  """Public method to reset ALL handler state including tracking/logging data"""
77
92
  self._reset_state()
78
- # Also reset tracking/logging variables
79
- self.executed_nodes: List[str] = []
80
- self.executed_tools: List[str] = []
81
- self.executed_node_tools: List[str] = []
82
- self.traces: List[Dict[str, Any]] = []
83
93
 
84
94
  # --- MODIFIED _ensure_trace_client ---
85
- def _ensure_trace_client(self, run_id: UUID, parent_run_id: Optional[UUID], event_name: str) -> Optional[TraceClient]:
95
+ def _ensure_trace_client(
96
+ self, run_id: UUID, parent_run_id: Optional[UUID], event_name: str
97
+ ) -> Optional[TraceClient]:
86
98
  """
87
99
  Ensures the internal trace client is initialized, creating it only once
88
100
  per handler instance lifecycle (effectively per graph invocation).
@@ -104,36 +116,44 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
104
116
  try:
105
117
  # Use event_name as the initial trace name, might be updated later by on_chain_start if root
106
118
  client_instance = TraceClient(
107
- self.tracer, trace_id, event_name, project_name=project,
108
- overwrite=False, rules=self.tracer.rules,
119
+ self.tracer,
120
+ trace_id,
121
+ event_name,
122
+ project_name=project,
123
+ overwrite=False,
124
+ rules=self.tracer.rules,
109
125
  enable_monitoring=self.tracer.enable_monitoring,
110
- enable_evaluations=self.tracer.enable_evaluations
126
+ enable_evaluations=self.tracer.enable_evaluations,
111
127
  )
112
128
  self._trace_client = client_instance
113
129
  token = self.tracer.set_current_trace(self._trace_client)
114
130
  if token:
115
131
  self.trace_id_to_token[trace_id] = token
116
132
  if self._trace_client:
117
- self._root_run_id = run_id # Assign the first run_id encountered as the tentative root
118
- self._trace_saved = False # Ensure flag is reset
133
+ self._root_run_id = (
134
+ run_id # Assign the first run_id encountered as the tentative root
135
+ )
136
+ self._trace_saved = False # Ensure flag is reset
119
137
  # Set active client on Tracer (important for potential fallbacks)
120
138
  self.tracer._active_trace_client = self._trace_client
121
-
139
+
122
140
  # NEW: Initial save for live tracking (follows the new practice)
123
141
  try:
124
- trace_id_saved, server_response = self._trace_client.save_with_rate_limiting(
125
- overwrite=self._trace_client.overwrite,
126
- final_save=False # Initial save for live tracking
142
+ trace_id_saved, server_response = self._trace_client.save(
143
+ overwrite=self._trace_client.overwrite,
144
+ final_save=False, # Initial save for live tracking
127
145
  )
128
146
  except Exception as e:
129
147
  import warnings
130
- warnings.warn(f"Failed to save initial trace for live tracking: {e}")
131
-
148
+
149
+ warnings.warn(
150
+ f"Failed to save initial trace for live tracking: {e}"
151
+ )
152
+
132
153
  return self._trace_client
133
154
  else:
134
155
  return None
135
- except Exception as e:
136
-
156
+ except Exception:
137
157
  self._trace_client = None
138
158
  self._root_run_id = None
139
159
  return None
@@ -145,7 +165,7 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
145
165
  parent_run_id: Optional[UUID],
146
166
  name: str,
147
167
  span_type: SpanType = "span",
148
- inputs: Optional[Dict[str, Any]] = None
168
+ inputs: Optional[Dict[str, Any]] = None,
149
169
  ) -> None:
150
170
  """Start tracking a span, ensuring trace client exists"""
151
171
 
@@ -157,7 +177,7 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
157
177
  if parent_run_id and parent_run_id in self._run_id_to_span_id:
158
178
  parent_span_id = self._run_id_to_span_id[parent_run_id]
159
179
  if parent_span_id in self._span_id_to_depth:
160
- current_depth = self._span_id_to_depth[parent_span_id] + 1
180
+ current_depth = self._span_id_to_depth[parent_span_id] + 1
161
181
 
162
182
  self._run_id_to_span_id[run_id] = span_id
163
183
  self._span_id_to_start_time[span_id] = start_time
@@ -170,23 +190,23 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
170
190
  function=name,
171
191
  depth=current_depth,
172
192
  created_at=start_time,
173
- span_type=span_type
193
+ span_type=span_type,
174
194
  )
175
195
 
176
196
  # Separate metadata from inputs
177
197
  if inputs:
178
198
  metadata = {}
179
199
  clean_inputs = {}
180
-
200
+
181
201
  # Extract metadata fields
182
- metadata_fields = ['tags', 'metadata', 'kwargs', 'serialized']
202
+ metadata_fields = ["tags", "metadata", "kwargs", "serialized"]
183
203
  for field in metadata_fields:
184
204
  if field in inputs:
185
205
  metadata[field] = inputs.pop(field)
186
-
206
+
187
207
  # Store the remaining inputs
188
208
  clean_inputs = inputs
189
-
209
+
190
210
  # Set both fields on the span
191
211
  new_span.inputs = clean_inputs
192
212
  new_span.additional_metadata = metadata
@@ -195,11 +215,13 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
195
215
  new_span.additional_metadata = {}
196
216
 
197
217
  trace_client.add_span(new_span)
198
-
218
+
199
219
  # Queue span with initial state (input phase) through background service
200
220
  if trace_client.background_span_service:
201
- trace_client.background_span_service.queue_span(new_span, span_state="input")
202
-
221
+ trace_client.background_span_service.queue_span(
222
+ new_span, span_state="input"
223
+ )
224
+
203
225
  token = self.tracer.set_current_span(span_id)
204
226
  if token:
205
227
  self.span_id_to_token[span_id] = token
@@ -209,14 +231,15 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
209
231
  trace_client: TraceClient,
210
232
  run_id: UUID,
211
233
  outputs: Optional[Any] = None,
212
- error: Optional[BaseException] = None
234
+ error: Optional[BaseException] = None,
213
235
  ) -> None:
214
236
  """End tracking a span, ensuring trace client exists"""
215
237
 
216
238
  # Get span ID and check if it exists
217
239
  span_id = self._run_id_to_span_id.get(run_id)
218
- token = self.span_id_to_token.pop(span_id, None)
219
- self.tracer.reset_current_span(token, span_id)
240
+ if span_id:
241
+ token = self.span_id_to_token.pop(span_id, None)
242
+ self.tracer.reset_current_span(token, span_id)
220
243
 
221
244
  start_time = self._span_id_to_start_time.get(span_id) if span_id else None
222
245
  duration = time.time() - start_time if start_time is not None else None
@@ -226,7 +249,7 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
226
249
  trace_span = trace_client.span_id_to_span.get(span_id)
227
250
  if trace_span:
228
251
  trace_span.duration = duration
229
-
252
+
230
253
  # Handle outputs and error
231
254
  if error:
232
255
  trace_span.output = error
@@ -234,34 +257,41 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
234
257
  # Separate metadata from outputs
235
258
  metadata = {}
236
259
  clean_outputs = {}
237
-
260
+
238
261
  # Extract metadata fields
239
- metadata_fields = ['tags', 'kwargs']
262
+ metadata_fields = ["tags", "kwargs"]
240
263
  if isinstance(outputs, dict):
241
264
  for field in metadata_fields:
242
265
  if field in outputs:
243
266
  metadata[field] = outputs.pop(field)
244
-
267
+
245
268
  # Store the remaining outputs
246
269
  clean_outputs = outputs
247
270
  else:
248
271
  clean_outputs = outputs
249
-
272
+
250
273
  # Set both fields on the span
251
274
  trace_span.output = clean_outputs
252
275
  if metadata:
253
276
  # Merge with existing metadata
254
277
  existing_metadata = trace_span.additional_metadata or {}
255
- trace_span.additional_metadata = {**existing_metadata, **metadata}
256
-
278
+ trace_span.additional_metadata = {
279
+ **existing_metadata,
280
+ **metadata,
281
+ }
282
+
257
283
  # Queue span with completed state through background service
258
284
  if trace_client.background_span_service:
259
285
  span_state = "error" if error else "completed"
260
- trace_client.background_span_service.queue_span(trace_span, span_state=span_state)
286
+ trace_client.background_span_service.queue_span(
287
+ trace_span, span_state=span_state
288
+ )
261
289
 
262
290
  # Clean up dictionaries for this specific span
263
- if span_id in self._span_id_to_start_time: del self._span_id_to_start_time[span_id]
264
- if span_id in self._span_id_to_depth: del self._span_id_to_depth[span_id]
291
+ if span_id in self._span_id_to_start_time:
292
+ del self._span_id_to_start_time[span_id]
293
+ if span_id in self._span_id_to_depth:
294
+ del self._span_id_to_depth[span_id]
265
295
 
266
296
  # Check if this is the root run ending
267
297
  if run_id == self._root_run_id:
@@ -270,34 +300,38 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
270
300
  self._root_run_id = None
271
301
  # Reset input storage for this handler instance
272
302
 
273
- if self._trace_client and not self._trace_saved: # Check if not already saved
303
+ if (
304
+ self._trace_client and not self._trace_saved
305
+ ): # Check if not already saved
274
306
  # Flush background spans before saving the final trace
275
307
 
276
308
  complete_trace_data = {
277
309
  "trace_id": self._trace_client.trace_id,
278
310
  "name": self._trace_client.name,
279
- "created_at": datetime.utcfromtimestamp(self._trace_client.start_time).isoformat(),
311
+ "created_at": datetime.utcfromtimestamp(
312
+ self._trace_client.start_time
313
+ ).isoformat(),
280
314
  "duration": self._trace_client.get_duration(),
281
- "trace_spans": [span.model_dump() for span in self._trace_client.trace_spans],
315
+ "trace_spans": [
316
+ span.model_dump() for span in self._trace_client.trace_spans
317
+ ],
282
318
  "overwrite": self._trace_client.overwrite,
283
319
  "offline_mode": self.tracer.offline_mode,
284
320
  "parent_trace_id": self._trace_client.parent_trace_id,
285
- "parent_name": self._trace_client.parent_name
321
+ "parent_name": self._trace_client.parent_name,
286
322
  }
287
-
288
- # NEW: Use save_with_rate_limiting with final_save=True for final save
289
- trace_id, trace_data = self._trace_client.save_with_rate_limiting(
290
- overwrite=self._trace_client.overwrite,
291
- final_save=True # Final save with usage counter updates
323
+ trace_id, trace_data = self._trace_client.save(
324
+ overwrite=self._trace_client.overwrite,
325
+ final_save=True, # Final save with usage counter updates
292
326
  )
293
327
  token = self.trace_id_to_token.pop(trace_id, None)
294
328
  self.tracer.reset_current_trace(token, trace_id)
295
-
329
+
296
330
  # Store complete trace data instead of server response
297
331
  self.tracer.traces.append(complete_trace_data)
298
- self._trace_saved = True # Set flag only after successful save
332
+ self._trace_saved = True # Set flag only after successful save
299
333
  finally:
300
- # --- NEW: Consolidated Cleanup Logic ---
334
+ # --- NEW: Consolidated Cleanup Logic ---
301
335
  # This block executes regardless of save success/failure
302
336
  # Reset root run id
303
337
  self._root_run_id = None
@@ -309,68 +343,167 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
309
343
  # --- Callback Methods ---
310
344
  # Each method now ensures the trace client exists before proceeding
311
345
 
312
- def on_retriever_start(self, serialized: Dict[str, Any], query: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Any:
313
- serialized_name = serialized.get('name', 'Unknown') if serialized else "Unknown (Serialized=None)"
346
+ def on_retriever_start(
347
+ self,
348
+ serialized: Dict[str, Any],
349
+ query: str,
350
+ *,
351
+ run_id: UUID,
352
+ parent_run_id: Optional[UUID] = None,
353
+ tags: Optional[List[str]] = None,
354
+ metadata: Optional[Dict[str, Any]] = None,
355
+ **kwargs: Any,
356
+ ) -> Any:
357
+ serialized_name = (
358
+ serialized.get("name", "Unknown")
359
+ if serialized
360
+ else "Unknown (Serialized=None)"
361
+ )
314
362
 
315
363
  name = f"RETRIEVER_{(serialized_name).upper()}"
316
364
  # Pass parent_run_id
317
- trace_client = self._ensure_trace_client(run_id, parent_run_id, name) # Corrected call
318
- if not trace_client: return
365
+ trace_client = self._ensure_trace_client(
366
+ run_id, parent_run_id, name
367
+ ) # Corrected call
368
+ if not trace_client:
369
+ return
319
370
 
320
- inputs = {'query': query, 'tags': tags, 'metadata': metadata, 'kwargs': kwargs, 'serialized': serialized}
321
- self._start_span_tracking(trace_client, run_id, parent_run_id, name, span_type="retriever", inputs=inputs)
371
+ inputs = {
372
+ "query": query,
373
+ "tags": tags,
374
+ "metadata": metadata,
375
+ "kwargs": kwargs,
376
+ "serialized": serialized,
377
+ }
378
+ self._start_span_tracking(
379
+ trace_client,
380
+ run_id,
381
+ parent_run_id,
382
+ name,
383
+ span_type="retriever",
384
+ inputs=inputs,
385
+ )
322
386
 
323
- def on_retriever_end(self, documents: Sequence[Document], *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
324
- trace_client = self._ensure_trace_client(run_id, parent_run_id, "RetrieverEnd") # Corrected call
325
- if not trace_client: return
326
- doc_summary = [{"index": i, "page_content": doc.page_content[:100] + "..." if len(doc.page_content) > 100 else doc.page_content, "metadata": doc.metadata} for i, doc in enumerate(documents)]
327
- outputs = {"document_count": len(documents), "documents": doc_summary, "kwargs": kwargs}
387
+ def on_retriever_end(
388
+ self,
389
+ documents: Sequence[Document],
390
+ *,
391
+ run_id: UUID,
392
+ parent_run_id: Optional[UUID] = None,
393
+ **kwargs: Any,
394
+ ) -> Any:
395
+ trace_client = self._ensure_trace_client(
396
+ run_id, parent_run_id, "RetrieverEnd"
397
+ ) # Corrected call
398
+ if not trace_client:
399
+ return
400
+ doc_summary = [
401
+ {
402
+ "index": i,
403
+ "page_content": doc.page_content[:100] + "..."
404
+ if len(doc.page_content) > 100
405
+ else doc.page_content,
406
+ "metadata": doc.metadata,
407
+ }
408
+ for i, doc in enumerate(documents)
409
+ ]
410
+ outputs = {
411
+ "document_count": len(documents),
412
+ "documents": doc_summary,
413
+ "kwargs": kwargs,
414
+ }
328
415
  self._end_span_tracking(trace_client, run_id, outputs=outputs)
329
416
 
330
- def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
331
- serialized_name = serialized.get('name') if serialized else "Unknown (Serialized=None)"
417
+ def on_chain_start(
418
+ self,
419
+ serialized: Dict[str, Any],
420
+ inputs: Dict[str, Any],
421
+ *,
422
+ run_id: UUID,
423
+ parent_run_id: Optional[UUID] = None,
424
+ tags: Optional[List[str]] = None,
425
+ metadata: Optional[Dict[str, Any]] = None,
426
+ **kwargs: Any,
427
+ ) -> None:
428
+ serialized_name = (
429
+ serialized.get("name") if serialized else "Unknown (Serialized=None)"
430
+ )
332
431
 
333
432
  # --- Determine Name and Span Type ---
334
433
  span_type: SpanType = "chain"
335
- name = serialized_name if serialized_name else "Unknown Chain" # Default name
434
+ name = serialized_name if serialized_name else "Unknown Chain" # Default name
336
435
  node_name = metadata.get("langgraph_node") if metadata else None
337
- is_langgraph_root_kwarg = kwargs.get('name') == 'LangGraph' # Check kwargs for explicit root name
436
+ is_langgraph_root_kwarg = (
437
+ kwargs.get("name") == "LangGraph"
438
+ ) # Check kwargs for explicit root name
338
439
  # More robust root detection: Often the first chain event with parent_run_id=None *is* the root.
339
440
  is_potential_root_event = parent_run_id is None
340
441
 
341
442
  if node_name:
342
- name = node_name # Use node name if available
343
- if name not in self.executed_nodes: self.executed_nodes.append(name) # Leaving this in for now but can probably be removed
443
+ name = node_name # Use node name if available
444
+ if name not in self.executed_nodes:
445
+ self.executed_nodes.append(
446
+ name
447
+ ) # Leaving this in for now but can probably be removed
344
448
  elif is_langgraph_root_kwarg and is_potential_root_event:
345
- name = "LangGraph" # Explicit root detected
449
+ name = "LangGraph" # Explicit root detected
346
450
  # Add handling for other potential LangChain internal chains if needed, e.g., "RunnableSequence"
347
451
 
348
452
  # --- Ensure Trace Client ---
349
- # Pass parent_run_id to _ensure_trace_client
350
- trace_client = self._ensure_trace_client(run_id, parent_run_id, name) # Corrected call
351
- if not trace_client: return
453
+ # Pass parent_run_id to _ensure_trace_client
454
+ trace_client = self._ensure_trace_client(
455
+ run_id, parent_run_id, name
456
+ ) # Corrected call
457
+ if not trace_client:
458
+ return
352
459
 
353
460
  # --- Update Trace Name if Root ---
354
461
  # If this is the root event (parent_run_id is None) and the trace client was just created,
355
462
  # ensure the trace name reflects the graph's name ('LangGraph' usually).
356
- if is_potential_root_event and run_id == self._root_run_id and trace_client.name != name:
357
- trace_client.name = name # Update trace name to the determined root name
463
+ if (
464
+ is_potential_root_event
465
+ and run_id == self._root_run_id
466
+ and trace_client.name != name
467
+ ):
468
+ trace_client.name = name # Update trace name to the determined root name
358
469
 
359
470
  # --- Start Span Tracking ---
360
- combined_inputs = {'inputs': inputs, 'tags': tags, 'metadata': metadata, 'kwargs': kwargs, 'serialized': serialized}
361
- self._start_span_tracking(trace_client, run_id, parent_run_id, name, span_type=span_type, inputs=combined_inputs)
362
-
363
-
364
- def on_chain_end(self, outputs: Dict[str, Any], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, **kwargs: Any) -> Any:
471
+ combined_inputs = {
472
+ "inputs": inputs,
473
+ "tags": tags,
474
+ "metadata": metadata,
475
+ "kwargs": kwargs,
476
+ "serialized": serialized,
477
+ }
478
+ self._start_span_tracking(
479
+ trace_client,
480
+ run_id,
481
+ parent_run_id,
482
+ name,
483
+ span_type=span_type,
484
+ inputs=combined_inputs,
485
+ )
365
486
 
487
+ def on_chain_end(
488
+ self,
489
+ outputs: Dict[str, Any],
490
+ *,
491
+ run_id: UUID,
492
+ parent_run_id: Optional[UUID] = None,
493
+ tags: Optional[List[str]] = None,
494
+ **kwargs: Any,
495
+ ) -> Any:
366
496
  # Pass parent_run_id
367
- trace_client = self._ensure_trace_client(run_id, parent_run_id, "ChainEnd") # Corrected call
368
- if not trace_client: return
497
+ trace_client = self._ensure_trace_client(
498
+ run_id, parent_run_id, "ChainEnd"
499
+ ) # Corrected call
500
+ if not trace_client:
501
+ return
369
502
 
370
503
  span_id = self._run_id_to_span_id.get(run_id)
371
504
  # If it's the root run ending, _end_span_tracking will handle cleanup/save
372
505
  if not span_id and run_id != self._root_run_id:
373
- return # Don't call end tracking if it's not the root and span wasn't tracked
506
+ return # Don't call end tracking if it's not the root and span wasn't tracked
374
507
 
375
508
  # Prepare outputs for end tracking (moved down)
376
509
  combined_outputs = {"outputs": outputs, "tags": tags, "kwargs": kwargs}
@@ -385,128 +518,247 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
385
518
  complete_trace_data = {
386
519
  "trace_id": trace_client.trace_id,
387
520
  "name": trace_client.name,
388
- "created_at": datetime.utcfromtimestamp(trace_client.start_time).isoformat(),
521
+ "created_at": datetime.utcfromtimestamp(
522
+ trace_client.start_time
523
+ ).isoformat(),
389
524
  "duration": trace_client.get_duration(),
390
- "trace_spans": [span.model_dump() for span in trace_client.trace_spans],
525
+ "trace_spans": [
526
+ span.model_dump() for span in trace_client.trace_spans
527
+ ],
391
528
  "overwrite": trace_client.overwrite,
392
529
  "offline_mode": self.tracer.offline_mode,
393
530
  "parent_trace_id": trace_client.parent_trace_id,
394
- "parent_name": trace_client.parent_name
531
+ "parent_name": trace_client.parent_name,
395
532
  }
396
- # NEW: Use save_with_rate_limiting with final_save=True for final save
397
- trace_id_saved, trace_data = trace_client.save_with_rate_limiting(
533
+ trace_id_saved, trace_data = trace_client.save(
398
534
  overwrite=trace_client.overwrite,
399
- final_save=True # Final save with usage counter updates
535
+ final_save=True,
400
536
  )
401
-
402
-
537
+
403
538
  self.tracer.traces.append(complete_trace_data)
404
539
  self._trace_saved = True
405
540
  # Reset tracer's active client *after* successful save
406
541
  if self.tracer._active_trace_client == trace_client:
407
- self.tracer._active_trace_client = None
408
-
542
+ self.tracer._active_trace_client = None
543
+
409
544
  # Reset root run id after attempt
410
545
  self._root_run_id = None
411
546
  # Reset input storage for this handler instance
412
547
 
413
def on_chain_error(
    self,
    error: BaseException,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> Any:
    """Close the span belonging to a chain run that raised.

    Args:
        error: The exception reported for the chain run.
        run_id: Identifier of the run that failed.
        parent_run_id: Identifier of the enclosing run, if any; forwarded to
            ``_ensure_trace_client``.
        **kwargs: Accepted for callback compatibility; not used here.

    Returns:
        Always ``None``.
    """
    client = self._ensure_trace_client(run_id, parent_run_id, "ChainError")
    if not client:
        return

    tracked_span = self._run_id_to_span_id.get(run_id)

    # End tracking when a span was recorded for this run, or when this is the
    # root run (``_end_span_tracking`` is left to deal with root cleanup).
    # Untracked non-root runs are ignored.
    if tracked_span or run_id == self._root_run_id:
        self._end_span_tracking(client, run_id, error=error)
426
570
 
427
def on_tool_start(
    self,
    serialized: Dict[str, Any],
    input_str: str,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    inputs: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> Any:
    """Open a ``"tool"`` span for a tool run and record the tool as executed.

    Args:
        serialized: Serialized tool description; its ``"name"`` entry names
            the span. May be falsy/``None``, in which case a placeholder
            name is used.
        input_str: Raw tool input string, stored in the span inputs.
        run_id: Identifier of this tool run.
        parent_run_id: Identifier of the enclosing run, if any.
        tags: Stored verbatim in the span inputs.
        metadata: Stored verbatim in the span inputs.
        inputs: Structured tool inputs, stored verbatim in the span inputs.
        **kwargs: Extra callback arguments, stored under ``"kwargs"``.

    Returns:
        Always ``None``.
    """
    name = (
        serialized.get("name", "Unnamed Tool")
        if serialized
        else "Unknown Tool (Serialized=None)"
    )

    trace_client = self._ensure_trace_client(run_id, parent_run_id, name)
    if not trace_client:
        return

    combined_inputs = {
        "input_str": input_str,
        "inputs": inputs,
        "tags": tags,
        "metadata": metadata,
        "kwargs": kwargs,
        "serialized": serialized,
    }
    self._start_span_tracking(
        trace_client,
        run_id,
        parent_run_id,
        name,
        span_type="tool",
        inputs=combined_inputs,
    )

    # --- Track executed tools ---
    # Leaving this in for now but can probably be removed
    if name not in self.executed_tools:
        self.executed_tools.append(name)

    parent_node_name = None
    if parent_run_id and parent_run_id in self._run_id_to_span_id:
        parent_span_id = self._run_id_to_span_id[parent_run_id]
        # The run-id map and the trace's span map are maintained separately,
        # so check membership before indexing (on_llm_end does the same).
        # A missing parent span falls back to the bare tool name instead of
        # raising KeyError from inside the callback.
        if parent_span_id in trace_client.span_id_to_span:
            parent_node_name = trace_client.span_id_to_span[parent_span_id].function

    node_tool = f"{parent_node_name}:{name}" if parent_node_name else name
    # Leaving this in for now but can probably be removed
    if node_tool not in self.executed_node_tools:
        self.executed_node_tools.append(node_tool)
    # --- End Track executed tools ---
447
629
 
448
-
449
def on_tool_end(
    self,
    output: Any,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> Any:
    """Close the span of a finished tool run, attaching its output.

    Args:
        output: Value produced by the tool; stored under ``"output"``.
        run_id: Identifier of the tool run that finished.
        parent_run_id: Identifier of the enclosing run, if any; forwarded to
            ``_ensure_trace_client``.
        **kwargs: Extra callback arguments, stored under ``"kwargs"``.

    Returns:
        Always ``None``.
    """
    client = self._ensure_trace_client(run_id, parent_run_id, "ToolEnd")
    if client:
        self._end_span_tracking(
            client,
            run_id,
            outputs={"output": output, "kwargs": kwargs},
        )
455
646
 
456
def on_tool_error(
    self,
    error: BaseException,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> Any:
    """Close the span of a tool run that raised, attaching the error.

    Args:
        error: The exception reported for the tool run.
        run_id: Identifier of the tool run that failed.
        parent_run_id: Identifier of the enclosing run, if any; forwarded to
            ``_ensure_trace_client``.
        **kwargs: Accepted for callback compatibility; not used here.

    Returns:
        Always ``None``.
    """
    client = self._ensure_trace_client(run_id, parent_run_id, "ToolError")
    if client:
        self._end_span_tracking(client, run_id, error=error)
462
662
 
463
def on_llm_start(
    self,
    serialized: Dict[str, Any],
    prompts: List[str],
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    invocation_params: Optional[Dict[str, Any]] = None,
    options: Optional[Dict[str, Any]] = None,
    name: Optional[str] = None,
    **kwargs: Any,
) -> Any:
    """Open an ``"llm"`` span for an LLM run.

    Args:
        serialized: Serialized model description; its ``"name"`` entry is
            used as the span name when ``name`` is not given. May be
            ``None``.
        prompts: Prompt strings, stored in the span inputs.
        run_id: Identifier of this LLM run.
        parent_run_id: Identifier of the enclosing run, if any.
        tags: Stored verbatim in the span inputs.
        metadata: Stored verbatim in the span inputs.
        invocation_params: Stored in the span inputs; when falsy, ``kwargs``
            is stored in its place.
        options: Stored verbatim in the span inputs.
        name: Explicit span name; takes precedence over ``serialized``.
        **kwargs: Extra callback arguments (see ``invocation_params``).

    Returns:
        Always ``None``.
    """
    # ``serialized`` may be None (on_tool_start already allows for that);
    # fall back to the default label instead of raising AttributeError.
    llm_name = name or (serialized or {}).get("name", "LLM Call")

    trace_client = self._ensure_trace_client(run_id, parent_run_id, llm_name)
    if not trace_client:
        return

    inputs = {
        "prompts": prompts,
        "invocation_params": invocation_params or kwargs,
        "options": options,
        "tags": tags,
        "metadata": metadata,
        "serialized": serialized,
    }
    self._start_span_tracking(
        trace_client,
        run_id,
        parent_run_id,
        llm_name,
        span_type="llm",
        inputs=inputs,
    )
473
700
 
701
+ def on_llm_end(
702
+ self,
703
+ response: LLMResult,
704
+ *,
705
+ run_id: UUID,
706
+ parent_run_id: Optional[UUID] = None,
707
+ **kwargs: Any,
708
+ ) -> Any:
474
709
  # Pass parent_run_id
475
- trace_client = self._ensure_trace_client(run_id, parent_run_id, "LLMEnd") # Corrected call
710
+ trace_client = self._ensure_trace_client(
711
+ run_id, parent_run_id, "LLMEnd"
712
+ ) # Corrected call
476
713
  if not trace_client:
477
714
  return
478
715
  outputs = {"response": response, "kwargs": kwargs}
479
-
716
+
480
717
  # --- Token Usage Extraction and Cost Calculation ---
481
718
  prompt_tokens = None
482
719
  completion_tokens = None
483
720
  total_tokens = None
484
721
  model_name = None
485
-
722
+
486
723
  # Extract model name from response if available
487
- if hasattr(response, 'llm_output') and response.llm_output and isinstance(response.llm_output, dict):
488
- model_name = response.llm_output.get('model_name') or response.llm_output.get('model')
489
-
724
+ if (
725
+ hasattr(response, "llm_output")
726
+ and response.llm_output
727
+ and isinstance(response.llm_output, dict)
728
+ ):
729
+ model_name = response.llm_output.get(
730
+ "model_name"
731
+ ) or response.llm_output.get("model")
732
+
490
733
  # Try to get model from the first generation if available
491
734
  if not model_name and response.generations and len(response.generations) > 0:
492
- if hasattr(response.generations[0][0], 'generation_info') and response.generations[0][0].generation_info:
735
+ if (
736
+ hasattr(response.generations[0][0], "generation_info")
737
+ and response.generations[0][0].generation_info
738
+ ):
493
739
  gen_info = response.generations[0][0].generation_info
494
- model_name = gen_info.get('model') or gen_info.get('model_name')
740
+ model_name = gen_info.get("model") or gen_info.get("model_name")
495
741
 
496
742
  if response.llm_output and isinstance(response.llm_output, dict):
497
743
  # Check for OpenAI/standard 'token_usage' first
498
- if 'token_usage' in response.llm_output:
499
- token_usage = response.llm_output.get('token_usage')
744
+ if "token_usage" in response.llm_output:
745
+ token_usage = response.llm_output.get("token_usage")
500
746
  if token_usage and isinstance(token_usage, dict):
501
- prompt_tokens = token_usage.get('prompt_tokens')
502
- completion_tokens = token_usage.get('completion_tokens')
503
- total_tokens = token_usage.get('total_tokens') # OpenAI provides total
747
+ prompt_tokens = token_usage.get("prompt_tokens")
748
+ completion_tokens = token_usage.get("completion_tokens")
749
+ total_tokens = token_usage.get(
750
+ "total_tokens"
751
+ ) # OpenAI provides total
504
752
  # Check for Anthropic 'usage'
505
- elif 'usage' in response.llm_output:
506
- token_usage = response.llm_output.get('usage')
753
+ elif "usage" in response.llm_output:
754
+ token_usage = response.llm_output.get("usage")
507
755
  if token_usage and isinstance(token_usage, dict):
508
- prompt_tokens = token_usage.get('input_tokens') # Anthropic uses input_tokens
509
- completion_tokens = token_usage.get('output_tokens') # Anthropic uses output_tokens
756
+ prompt_tokens = token_usage.get(
757
+ "input_tokens"
758
+ ) # Anthropic uses input_tokens
759
+ completion_tokens = token_usage.get(
760
+ "output_tokens"
761
+ ) # Anthropic uses output_tokens
510
762
  # Calculate total if possible
511
763
  if prompt_tokens is not None and completion_tokens is not None:
512
764
  total_tokens = prompt_tokens + completion_tokens
@@ -517,57 +769,118 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
517
769
  prompt_cost = None
518
770
  completion_cost = None
519
771
  total_cost_usd = None
520
-
521
- if model_name and prompt_tokens is not None and completion_tokens is not None:
772
+
773
+ if (
774
+ model_name
775
+ and prompt_tokens is not None
776
+ and completion_tokens is not None
777
+ ):
522
778
  try:
523
779
  prompt_cost, completion_cost = cost_per_token(
524
780
  model=model_name,
525
781
  prompt_tokens=prompt_tokens,
526
- completion_tokens=completion_tokens
782
+ completion_tokens=completion_tokens,
783
+ )
784
+ total_cost_usd = (
785
+ (prompt_cost + completion_cost)
786
+ if prompt_cost and completion_cost
787
+ else None
527
788
  )
528
- total_cost_usd = (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
529
789
  except Exception as e:
530
790
  # If cost calculation fails, continue without costs
531
791
  import warnings
532
- warnings.warn(f"Failed to calculate token costs for model {model_name}: {e}")
533
-
792
+
793
+ warnings.warn(
794
+ f"Failed to calculate token costs for model {model_name}: {e}"
795
+ )
796
+
534
797
  # Create TraceUsage object
535
798
  usage = TraceUsage(
536
799
  prompt_tokens=prompt_tokens,
537
800
  completion_tokens=completion_tokens,
538
- total_tokens=total_tokens or (prompt_tokens + completion_tokens if prompt_tokens and completion_tokens else None),
801
+ total_tokens=total_tokens
802
+ or (
803
+ prompt_tokens + completion_tokens
804
+ if prompt_tokens and completion_tokens
805
+ else None
806
+ ),
539
807
  prompt_tokens_cost_usd=prompt_cost,
540
808
  completion_tokens_cost_usd=completion_cost,
541
809
  total_cost_usd=total_cost_usd,
542
- model_name=model_name
810
+ model_name=model_name,
543
811
  )
544
-
812
+
545
813
  # Set usage on the actual span (not in outputs)
546
814
  span_id = self._run_id_to_span_id.get(run_id)
547
815
  if span_id and span_id in trace_client.span_id_to_span:
548
816
  trace_span = trace_client.span_id_to_span[span_id]
549
817
  trace_span.usage = usage
550
-
551
-
818
+
552
819
  self._end_span_tracking(trace_client, run_id, outputs=outputs)
553
820
  # --- End Token Usage ---
554
821
 
555
def on_llm_error(
    self,
    error: BaseException,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> Any:
    """Close the span of an LLM run that raised, attaching the error.

    Args:
        error: The exception reported for the LLM run.
        run_id: Identifier of the LLM run that failed.
        parent_run_id: Identifier of the enclosing run, if any; forwarded to
            ``_ensure_trace_client``.
        **kwargs: Accepted for callback compatibility; not used here.

    Returns:
        Always ``None``.
    """
    client = self._ensure_trace_client(run_id, parent_run_id, "LLMError")
    if client:
        self._end_span_tracking(client, run_id, error=error)
561
837
 
562
- def on_chat_model_start(self, serialized: Dict[str, Any], messages: List[List[BaseMessage]], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, invocation_params: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, name: Optional[str] = None, **kwargs: Any) -> Any:
838
+ def on_chat_model_start(
839
+ self,
840
+ serialized: Dict[str, Any],
841
+ messages: List[List[BaseMessage]],
842
+ *,
843
+ run_id: UUID,
844
+ parent_run_id: Optional[UUID] = None,
845
+ tags: Optional[List[str]] = None,
846
+ metadata: Optional[Dict[str, Any]] = None,
847
+ invocation_params: Optional[Dict[str, Any]] = None,
848
+ options: Optional[Dict[str, Any]] = None,
849
+ name: Optional[str] = None,
850
+ **kwargs: Any,
851
+ ) -> Any:
563
852
  # Reuse on_llm_start logic, adding message formatting if needed
564
853
  chat_model_name = name or serialized.get("name", "ChatModel Call")
565
854
  # Add OPENAI_API_CALL suffix if model is OpenAI and not present
566
- is_openai = any(key.startswith('openai') for key in serialized.get('secrets', {}).keys()) or 'openai' in chat_model_name.lower()
567
- is_anthropic = any(key.startswith('anthropic') for key in serialized.get('secrets', {}).keys()) or 'anthropic' in chat_model_name.lower() or 'claude' in chat_model_name.lower()
568
- is_together = any(key.startswith('together') for key in serialized.get('secrets', {}).keys()) or 'together' in chat_model_name.lower()
855
+ is_openai = (
856
+ any(
857
+ key.startswith("openai") for key in serialized.get("secrets", {}).keys()
858
+ )
859
+ or "openai" in chat_model_name.lower()
860
+ )
861
+ is_anthropic = (
862
+ any(
863
+ key.startswith("anthropic")
864
+ for key in serialized.get("secrets", {}).keys()
865
+ )
866
+ or "anthropic" in chat_model_name.lower()
867
+ or "claude" in chat_model_name.lower()
868
+ )
869
+ is_together = (
870
+ any(
871
+ key.startswith("together")
872
+ for key in serialized.get("secrets", {}).keys()
873
+ )
874
+ or "together" in chat_model_name.lower()
875
+ )
569
876
  # Add more checks for other providers like Google if needed
570
- is_google = any(key.startswith('google') for key in serialized.get('secrets', {}).keys()) or 'google' in chat_model_name.lower() or 'gemini' in chat_model_name.lower()
877
+ is_google = (
878
+ any(
879
+ key.startswith("google") for key in serialized.get("secrets", {}).keys()
880
+ )
881
+ or "google" in chat_model_name.lower()
882
+ or "gemini" in chat_model_name.lower()
883
+ )
571
884
 
572
885
  if is_openai and "OPENAI_API_CALL" not in chat_model_name:
573
886
  chat_model_name = f"{chat_model_name} OPENAI_API_CALL"
@@ -577,27 +890,76 @@ class JudgevalCallbackHandler(BaseCallbackHandler):
577
890
  chat_model_name = f"{chat_model_name} TOGETHER_API_CALL"
578
891
 
579
892
  elif is_google and "GOOGLE_API_CALL" not in chat_model_name:
580
- chat_model_name = f"{chat_model_name} GOOGLE_API_CALL"
581
-
582
- trace_client = self._ensure_trace_client(run_id, parent_run_id, chat_model_name) # Corrected call with parent_run_id
583
- if not trace_client: return
584
- inputs = {'messages': messages, 'invocation_params': invocation_params or kwargs, 'options': options, 'tags': tags, 'metadata': metadata, 'serialized': serialized}
585
- self._start_span_tracking(trace_client, run_id, parent_run_id, chat_model_name, span_type="llm", inputs=inputs) # Use 'llm' span_type for consistency
893
+ chat_model_name = f"{chat_model_name} GOOGLE_API_CALL"
586
894
 
587
- def on_agent_action(self, action: AgentAction, *, run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any) -> Any:
895
+ trace_client = self._ensure_trace_client(
896
+ run_id, parent_run_id, chat_model_name
897
+ ) # Corrected call with parent_run_id
898
+ if not trace_client:
899
+ return
900
+ inputs = {
901
+ "messages": messages,
902
+ "invocation_params": invocation_params or kwargs,
903
+ "options": options,
904
+ "tags": tags,
905
+ "metadata": metadata,
906
+ "serialized": serialized,
907
+ }
908
+ self._start_span_tracking(
909
+ trace_client,
910
+ run_id,
911
+ parent_run_id,
912
+ chat_model_name,
913
+ span_type="llm",
914
+ inputs=inputs,
915
+ ) # Use 'llm' span_type for consistency
916
+
917
def on_agent_action(
    self,
    action: AgentAction,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> Any:
    """Open an ``"agent"`` span describing the action an agent chose.

    The span is named ``AGENT_ACTION_<TOOL>``, where ``<TOOL>`` is the
    upper-cased ``action.tool``.

    Args:
        action: The agent action; its ``tool_input``, ``log`` and
            ``messages`` attributes are stored in the span inputs.
        run_id: Identifier of this run.
        parent_run_id: Identifier of the enclosing run, if any.
        **kwargs: Extra callback arguments, stored under ``"kwargs"``.

    Returns:
        Always ``None``.
    """
    tool_label = action.tool.upper()
    name = f"AGENT_ACTION_{tool_label}"

    client = self._ensure_trace_client(run_id, parent_run_id, name)
    if not client:
        return

    span_inputs = {
        "tool_input": action.tool_input,
        "log": action.log,
        "messages": action.messages,
        "kwargs": kwargs,
    }
    self._start_span_tracking(
        client,
        run_id,
        parent_run_id,
        name,
        span_type="agent",
        inputs=span_inputs,
    )
596
943
 
597
def on_agent_finish(
    self,
    finish: AgentFinish,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> Any:
    """Close the span for a run once the agent reports it has finished.

    Args:
        finish: The agent's final result; its ``return_values``, ``log`` and
            ``messages`` attributes are stored in the span outputs.
        run_id: Identifier of the run that finished.
        parent_run_id: Identifier of the enclosing run, if any; forwarded to
            ``_ensure_trace_client``.
        **kwargs: Extra callback arguments, stored under ``"kwargs"``.

    Returns:
        Always ``None``.
    """
    client = self._ensure_trace_client(run_id, parent_run_id, "AgentFinish")
    if client:
        span_outputs = {
            "return_values": finish.return_values,
            "log": finish.log,
            "messages": finish.messages,
            "kwargs": kwargs,
        }
        self._end_span_tracking(client, run_id, outputs=span_outputs)