judgeval 0.7.1__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. judgeval/__init__.py +139 -12
  2. judgeval/api/__init__.py +501 -0
  3. judgeval/api/api_types.py +344 -0
  4. judgeval/cli.py +2 -4
  5. judgeval/constants.py +10 -26
  6. judgeval/data/evaluation_run.py +49 -26
  7. judgeval/data/example.py +2 -2
  8. judgeval/data/judgment_types.py +266 -82
  9. judgeval/data/result.py +4 -5
  10. judgeval/data/scorer_data.py +4 -2
  11. judgeval/data/tool.py +2 -2
  12. judgeval/data/trace.py +7 -50
  13. judgeval/data/trace_run.py +7 -4
  14. judgeval/{dataset.py → dataset/__init__.py} +43 -28
  15. judgeval/env.py +67 -0
  16. judgeval/{run_evaluation.py → evaluation/__init__.py} +29 -95
  17. judgeval/exceptions.py +27 -0
  18. judgeval/integrations/langgraph/__init__.py +788 -0
  19. judgeval/judges/__init__.py +2 -2
  20. judgeval/judges/litellm_judge.py +75 -15
  21. judgeval/judges/together_judge.py +86 -18
  22. judgeval/judges/utils.py +7 -21
  23. judgeval/{common/logger.py → logger.py} +8 -6
  24. judgeval/scorers/__init__.py +0 -4
  25. judgeval/scorers/agent_scorer.py +3 -7
  26. judgeval/scorers/api_scorer.py +8 -13
  27. judgeval/scorers/base_scorer.py +52 -32
  28. judgeval/scorers/example_scorer.py +1 -3
  29. judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +0 -14
  30. judgeval/scorers/judgeval_scorers/api_scorers/prompt_scorer.py +45 -20
  31. judgeval/scorers/judgeval_scorers/api_scorers/tool_dependency.py +2 -2
  32. judgeval/scorers/judgeval_scorers/api_scorers/tool_order.py +3 -3
  33. judgeval/scorers/score.py +21 -31
  34. judgeval/scorers/trace_api_scorer.py +5 -0
  35. judgeval/scorers/utils.py +1 -103
  36. judgeval/tracer/__init__.py +1075 -2
  37. judgeval/tracer/constants.py +1 -0
  38. judgeval/tracer/exporters/__init__.py +37 -0
  39. judgeval/tracer/exporters/s3.py +119 -0
  40. judgeval/tracer/exporters/store.py +43 -0
  41. judgeval/tracer/exporters/utils.py +32 -0
  42. judgeval/tracer/keys.py +67 -0
  43. judgeval/tracer/llm/__init__.py +1233 -0
  44. judgeval/{common/tracer → tracer/llm}/providers.py +5 -10
  45. judgeval/{local_eval_queue.py → tracer/local_eval_queue.py} +15 -10
  46. judgeval/tracer/managers.py +188 -0
  47. judgeval/tracer/processors/__init__.py +181 -0
  48. judgeval/tracer/utils.py +20 -0
  49. judgeval/trainer/__init__.py +5 -0
  50. judgeval/{common/trainer → trainer}/config.py +12 -9
  51. judgeval/{common/trainer → trainer}/console.py +2 -9
  52. judgeval/{common/trainer → trainer}/trainable_model.py +12 -7
  53. judgeval/{common/trainer → trainer}/trainer.py +119 -17
  54. judgeval/utils/async_utils.py +2 -3
  55. judgeval/utils/decorators.py +24 -0
  56. judgeval/utils/file_utils.py +37 -4
  57. judgeval/utils/guards.py +32 -0
  58. judgeval/utils/meta.py +14 -0
  59. judgeval/{common/api/json_encoder.py → utils/serialize.py} +7 -1
  60. judgeval/utils/testing.py +88 -0
  61. judgeval/utils/url.py +10 -0
  62. judgeval/{version_check.py → utils/version_check.py} +3 -3
  63. judgeval/version.py +5 -0
  64. judgeval/warnings.py +4 -0
  65. {judgeval-0.7.1.dist-info → judgeval-0.9.0.dist-info}/METADATA +12 -14
  66. judgeval-0.9.0.dist-info/RECORD +80 -0
  67. judgeval/clients.py +0 -35
  68. judgeval/common/__init__.py +0 -13
  69. judgeval/common/api/__init__.py +0 -3
  70. judgeval/common/api/api.py +0 -375
  71. judgeval/common/api/constants.py +0 -186
  72. judgeval/common/exceptions.py +0 -27
  73. judgeval/common/storage/__init__.py +0 -6
  74. judgeval/common/storage/s3_storage.py +0 -97
  75. judgeval/common/tracer/__init__.py +0 -31
  76. judgeval/common/tracer/constants.py +0 -22
  77. judgeval/common/tracer/core.py +0 -2427
  78. judgeval/common/tracer/otel_exporter.py +0 -108
  79. judgeval/common/tracer/otel_span_processor.py +0 -188
  80. judgeval/common/tracer/span_processor.py +0 -37
  81. judgeval/common/tracer/span_transformer.py +0 -207
  82. judgeval/common/tracer/trace_manager.py +0 -101
  83. judgeval/common/trainer/__init__.py +0 -5
  84. judgeval/common/utils.py +0 -948
  85. judgeval/integrations/langgraph.py +0 -844
  86. judgeval/judges/mixture_of_judges.py +0 -287
  87. judgeval/judgment_client.py +0 -267
  88. judgeval/rules.py +0 -521
  89. judgeval/scorers/judgeval_scorers/api_scorers/execution_order.py +0 -52
  90. judgeval/scorers/judgeval_scorers/api_scorers/hallucination.py +0 -28
  91. judgeval/utils/alerts.py +0 -93
  92. judgeval/utils/requests.py +0 -50
  93. judgeval-0.7.1.dist-info/RECORD +0 -82
  94. {judgeval-0.7.1.dist-info → judgeval-0.9.0.dist-info}/WHEEL +0 -0
  95. {judgeval-0.7.1.dist-info → judgeval-0.9.0.dist-info}/entry_points.txt +0 -0
  96. {judgeval-0.7.1.dist-info → judgeval-0.9.0.dist-info}/licenses/LICENSE.md +0 -0
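The renames and deletions above flatten most of the judgeval.common namespace into top-level modules (for example items 23, 44, and 50-53, plus the judgeval/common/* deletions near the end of the list), so import paths move with the files. A minimal before/after sketch, assuming Tracer is still exported from the new top-level tracer package (item 36); the listing alone does not confirm which names are re-exported:

    # judgeval 0.7.1 -- this exact import appears in the deleted integration below
    from judgeval.common.tracer import Tracer

    # judgeval 0.9.0 -- judgeval/common/tracer/ is removed; judgeval/tracer/__init__.py replaces it
    from judgeval.tracer import Tracer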
judgeval/integrations/langgraph.py (deleted)
@@ -1,844 +0,0 @@
- from typing import Any, Dict, List, Optional, Sequence
- from uuid import UUID
- import time
- import uuid
- from datetime import datetime, timezone
-
- from judgeval.common.tracer import (
-     TraceClient,
-     TraceSpan,
-     Tracer,
-     SpanType,
-     cost_per_token,
- )
- from judgeval.data.trace import TraceUsage
-
- from langchain_core.callbacks import BaseCallbackHandler
- from langchain_core.agents import AgentAction, AgentFinish
- from langchain_core.outputs import LLMResult
- from langchain_core.messages.base import BaseMessage
- from langchain_core.documents import Document
-
- # TODO: Figure out how to handle context variables. Current solution is to keep track of current span id in Tracer class
-
-
- class JudgevalCallbackHandler(BaseCallbackHandler):
-     """
-     LangChain Callback Handler using run_id/parent_run_id for hierarchy.
-     Manages its own internal TraceClient instance created upon first use.
-     Includes verbose logging and defensive checks.
-     """
-
-     # Make all properties ignored by LangChain's callback system
-     # to prevent unexpected serialization issues.
-     lc_serializable = False
-     lc_kwargs: dict = {}
-
-     def __init__(self, tracer: Tracer):
-         self.tracer = tracer
-         self.executed_nodes: List[str] = []
-         self._reset_state()
-
-     def _reset_state(self):
-         """Reset only the critical execution state for reuse across multiple executions"""
-         # Reset core execution state that must be cleared between runs
-         self._trace_client: Optional[TraceClient] = None
-         self._run_id_to_span_id: Dict[UUID, str] = {}
-         self._span_id_to_start_time: Dict[str, float] = {}
-         self._span_id_to_depth: Dict[str, int] = {}
-         self._root_run_id: Optional[UUID] = None
-         self._trace_saved: bool = False
-         self.span_id_to_token: Dict[str, Any] = {}
-         self.trace_id_to_token: Dict[str, Any] = {}
-
-         # Add timestamp to track when we last reset
-         self._last_reset_time: float = time.time()
-
-         # Also reset tracking/logging variables
-         self.executed_nodes: List[str] = []
-
-     def reset(self):
-         """Public method to manually reset handler execution state for reuse"""
-         self._reset_state()
-
-     def reset_all(self):
-         """Public method to reset ALL handler state including tracking/logging data"""
-         self._reset_state()
-
-     def _ensure_trace_client(
-         self, run_id: UUID, parent_run_id: Optional[UUID], event_name: str
-     ) -> Optional[TraceClient]:
-         """
-         Ensures the internal trace client is initialized, creating it only once
-         per handler instance lifecycle (effectively per graph invocation).
-         Returns the client or None.
-         """
-
-         # If this is a potential new root execution (no parent_run_id) and we had a previous trace saved,
-         # reset state to allow reuse of the handler
-         if parent_run_id is None and self._trace_saved:
-             self._reset_state()
-
-         # If a client already exists, return it.
-         if self._trace_client:
-             return self._trace_client
-
-         # If no client exists, initialize it NOW.
-         trace_id = str(uuid.uuid4())
-         project = self.tracer.project_name
-         try:
-             # Use event_name as the initial trace name, might be updated later by on_chain_start if root
-             client_instance = TraceClient(
-                 self.tracer,
-                 trace_id,
-                 event_name,
-                 project_name=project,
-                 enable_monitoring=self.tracer.enable_monitoring,
-                 enable_evaluations=self.tracer.enable_evaluations,
-             )
-             self._trace_client = client_instance
-             token = self.tracer.set_current_trace(self._trace_client)
-             if token:
-                 self.trace_id_to_token[trace_id] = token
-
-             if self._trace_client:
-                 self._root_run_id = run_id
-                 self._trace_saved = False
-                 self.tracer._active_trace_client = self._trace_client
-
-                 try:
-                     self._trace_client.save(final_save=False)
-                 except Exception as e:
-                     import warnings
-
-                     warnings.warn(
-                         f"Failed to save initial trace for live tracking: {e}"
-                     )
-
-                 return self._trace_client
-             else:
-                 return None
-         except Exception:
-             self._trace_client = None
-             self._root_run_id = None
-             return None
-
-     def _start_span_tracking(
-         self,
-         trace_client: TraceClient,
-         run_id: UUID,
-         parent_run_id: Optional[UUID],
-         name: str,
-         span_type: SpanType = "span",
-         inputs: Optional[Dict[str, Any]] = None,
-     ) -> None:
-         """Start tracking a span, ensuring trace client exists"""
-         if name.startswith("__") and name.endswith("__"):
-             return
-         start_time = time.time()
-         span_id = str(uuid.uuid4())
-         parent_span_id: Optional[str] = None
-         current_depth = 0
-
-         if parent_run_id and parent_run_id in self._run_id_to_span_id:
-             parent_span_id = self._run_id_to_span_id[parent_run_id]
-             if parent_span_id in self._span_id_to_depth:
-                 current_depth = self._span_id_to_depth[parent_span_id] + 1
-
-         self._run_id_to_span_id[run_id] = span_id
-         self._span_id_to_start_time[span_id] = start_time
-         self._span_id_to_depth[span_id] = current_depth
-
-         new_span = TraceSpan(
-             span_id=span_id,
-             trace_id=trace_client.trace_id,
-             parent_span_id=parent_span_id,
-             function=name,
-             depth=current_depth,
-             created_at=start_time,
-             span_type=span_type,
-         )
-
-         # Separate metadata from inputs
-         if inputs:
-             metadata = {}
-             clean_inputs = {}
-
-             # Extract metadata fields
-             metadata_fields = ["tags", "metadata", "kwargs", "serialized"]
-             for field in metadata_fields:
-                 if field in inputs:
-                     metadata[field] = inputs.pop(field)
-
-             # Store the remaining inputs
-             clean_inputs = inputs
-
-             # Set both fields on the span
-             new_span.inputs = clean_inputs
-             new_span.additional_metadata = metadata
-         else:
-             new_span.inputs = {}
-             new_span.additional_metadata = {}
-
-         trace_client.add_span(new_span)
-
-         trace_client.otel_span_processor.queue_span_update(new_span, span_state="input")
-
-         token = self.tracer.set_current_span(span_id)
-         if token:
-             self.span_id_to_token[span_id] = token
-
-     def _end_span_tracking(
-         self,
-         trace_client: TraceClient,
-         run_id: UUID,
-         outputs: Optional[Any] = None,
-         error: Optional[BaseException] = None,
-     ) -> None:
-         """End tracking a span, ensuring trace client exists"""
-
-         # Get span ID and check if it exists
-         span_id = self._run_id_to_span_id.get(run_id)
-         if span_id:
-             token = self.span_id_to_token.pop(span_id, None)
-             self.tracer.reset_current_span(token, span_id)
-
-         start_time = self._span_id_to_start_time.get(span_id) if span_id else None
-         duration = time.time() - start_time if start_time is not None else None
-
-         # Add exit entry (only if span was tracked)
-         if span_id:
-             trace_span = trace_client.span_id_to_span.get(span_id)
-             if trace_span:
-                 trace_span.duration = duration
-
-                 # Handle outputs and error
-                 if error:
-                     trace_span.output = error
-                 elif outputs:
-                     # Separate metadata from outputs
-                     metadata = {}
-                     clean_outputs = {}
-
-                     # Extract metadata fields
-                     metadata_fields = ["tags", "kwargs"]
-                     if isinstance(outputs, dict):
-                         for field in metadata_fields:
-                             if field in outputs:
-                                 metadata[field] = outputs.pop(field)
-
-                         # Store the remaining outputs
-                         clean_outputs = outputs
-                     else:
-                         clean_outputs = outputs
-
-                     # Set both fields on the span
-                     trace_span.output = clean_outputs
-                     if metadata:
-                         # Merge with existing metadata
-                         existing_metadata = trace_span.additional_metadata or {}
-                         trace_span.additional_metadata = {
-                             **existing_metadata,
-                             **metadata,
-                         }
-
-                 span_state = "error" if error else "completed"
-                 trace_client.otel_span_processor.queue_span_update(
-                     trace_span, span_state=span_state
-                 )
-
-             # Clean up dictionaries for this specific span
-             if span_id in self._span_id_to_start_time:
-                 del self._span_id_to_start_time[span_id]
-             if span_id in self._span_id_to_depth:
-                 del self._span_id_to_depth[span_id]
-
-         # Check if this is the root run ending
-         if run_id == self._root_run_id:
-             try:
-                 self._root_run_id = None
-                 if (
-                     self._trace_client and not self._trace_saved
-                 ):  # Check if not already saved
-                     complete_trace_data = {
-                         "trace_id": self._trace_client.trace_id,
-                         "name": self._trace_client.name,
-                         "created_at": datetime.fromtimestamp(
-                             self._trace_client.start_time, timezone.utc
-                         ).isoformat(),
-                         "duration": self._trace_client.get_duration(),
-                         "trace_spans": [
-                             span.model_dump() for span in self._trace_client.trace_spans
-                         ],
-                         "offline_mode": self.tracer.offline_mode,
-                         "parent_trace_id": self._trace_client.parent_trace_id,
-                         "parent_name": self._trace_client.parent_name,
-                     }
-
-                     self.tracer.flush_background_spans()
-
-                     trace_id, trace_data = self._trace_client.save(
-                         final_save=True,  # Final save with usage counter updates
-                     )
-                     token = self.trace_id_to_token.pop(trace_id, None)
-                     self.tracer.reset_current_trace(token, trace_id)
-
-                     # Store complete trace data instead of server response
-                     self.tracer.traces.append(complete_trace_data)
-                     self._trace_saved = True  # Set flag only after successful save
-             finally:
-                 # This block executes regardless of save success/failure
-                 # Reset root run id
-                 self._root_run_id = None
-                 # Reset input storage for this handler instance
-                 if self.tracer._active_trace_client == self._trace_client:
-                     self.tracer._active_trace_client = None
-
-     # --- Callback Methods ---
-     # Each method now ensures the trace client exists before proceeding
-
-     def on_retriever_start(
-         self,
-         serialized: Dict[str, Any],
-         query: str,
-         *,
-         run_id: UUID,
-         parent_run_id: Optional[UUID] = None,
-         tags: Optional[List[str]] = None,
-         metadata: Optional[Dict[str, Any]] = None,
-         **kwargs: Any,
-     ) -> Any:
-         serialized_name = (
-             serialized.get("name", "Unknown")
-             if serialized
-             else "Unknown (Serialized=None)"
-         )
-
-         name = f"RETRIEVER_{(serialized_name).upper()}"
-         trace_client = self._ensure_trace_client(run_id, parent_run_id, name)
-         if not trace_client:
-             return
-
-         inputs = {
-             "query": query,
-             "tags": tags,
-             "metadata": metadata,
-             "kwargs": kwargs,
-             "serialized": serialized,
-         }
-         self._start_span_tracking(
-             trace_client,
-             run_id,
-             parent_run_id,
-             name,
-             span_type="retriever",
-             inputs=inputs,
-         )
-
-     def on_retriever_end(
-         self,
-         documents: Sequence[Document],
-         *,
-         run_id: UUID,
-         parent_run_id: Optional[UUID] = None,
-         **kwargs: Any,
-     ) -> Any:
-         trace_client = self._ensure_trace_client(run_id, parent_run_id, "RetrieverEnd")
-         if not trace_client:
-             return
-         doc_summary = [
-             {
-                 "index": i,
-                 "page_content": (
-                     doc.page_content[:100] + "..."
-                     if len(doc.page_content) > 100
-                     else doc.page_content
-                 ),
-                 "metadata": doc.metadata,
-             }
-             for i, doc in enumerate(documents)
-         ]
-         outputs = {
-             "document_count": len(documents),
-             "documents": doc_summary,
-             "kwargs": kwargs,
-         }
-         self._end_span_tracking(trace_client, run_id, outputs=outputs)
-
-     def on_chain_start(
-         self,
-         serialized: Dict[str, Any],
-         inputs: Dict[str, Any],
-         *,
-         run_id: UUID,
-         parent_run_id: Optional[UUID] = None,
-         tags: Optional[List[str]] = None,
-         metadata: Optional[Dict[str, Any]] = None,
-         **kwargs: Any,
-     ) -> None:
-         serialized_name = (
-             serialized.get("name") if serialized else "Unknown (Serialized=None)"
-         )
-
-         # --- Determine Name and Span Type ---
-         span_type: SpanType = "chain"
-         name = serialized_name if serialized_name else "Unknown Chain"
-         node_name = metadata.get("langgraph_node") if metadata else None
-         is_langgraph_root_kwarg = (
-             kwargs.get("name") == "LangGraph"
-         )  # Check kwargs for explicit root name
-         # More robust root detection: Often the first chain event with parent_run_id=None *is* the root.
-         is_potential_root_event = parent_run_id is None
-
-         if node_name:
-             name = node_name  # Use node name if available
-             if name not in self.executed_nodes:
-                 self.executed_nodes.append(
-                     name
-                 )  # Leaving this in for now but can probably be removed
-         elif is_langgraph_root_kwarg and is_potential_root_event:
-             name = "LangGraph"  # Explicit root detected
-         # Add handling for other potential LangChain internal chains if needed, e.g., "RunnableSequence"
-
-         trace_client = self._ensure_trace_client(run_id, parent_run_id, name)
-         if not trace_client:
-             return
-
-         if (
-             is_potential_root_event
-             and run_id == self._root_run_id
-             and trace_client.name != name
-         ):
-             trace_client.name = name
-
-         combined_inputs = {
-             "inputs": inputs,
-             "tags": tags,
-             "metadata": metadata,
-             "kwargs": kwargs,
-             "serialized": serialized,
-         }
-         self._start_span_tracking(
-             trace_client,
-             run_id,
-             parent_run_id,
-             name,
-             span_type=span_type,
-             inputs=combined_inputs,
-         )
-
-     def on_chain_end(
-         self,
-         outputs: Dict[str, Any],
-         *,
-         run_id: UUID,
-         parent_run_id: Optional[UUID] = None,
-         tags: Optional[List[str]] = None,
-         **kwargs: Any,
-     ) -> Any:
-         trace_client = self._ensure_trace_client(run_id, parent_run_id, "ChainEnd")
-         if not trace_client:
-             return
-
-         span_id = self._run_id_to_span_id.get(run_id)
-         if not span_id and run_id != self._root_run_id:
-             return
-
-         combined_outputs = {"outputs": outputs, "tags": tags, "kwargs": kwargs}
-
-         self._end_span_tracking(trace_client, run_id, outputs=combined_outputs)
-
-         if run_id == self._root_run_id:
-             if trace_client and not self._trace_saved:
-                 complete_trace_data = {
-                     "trace_id": trace_client.trace_id,
-                     "name": trace_client.name,
-                     "created_at": datetime.fromtimestamp(
-                         trace_client.start_time, timezone.utc
-                     ).isoformat(),
-                     "duration": trace_client.get_duration(),
-                     "trace_spans": [
-                         span.model_dump() for span in trace_client.trace_spans
-                     ],
-                     "offline_mode": self.tracer.offline_mode,
-                     "parent_trace_id": trace_client.parent_trace_id,
-                     "parent_name": trace_client.parent_name,
-                 }
-
-                 self.tracer.flush_background_spans()
-
-                 trace_client.save(
-                     final_save=True,
-                 )
-
-                 self.tracer.traces.append(complete_trace_data)
-                 self._trace_saved = True
-                 if self.tracer._active_trace_client == trace_client:
-                     self.tracer._active_trace_client = None
-
-             self._root_run_id = None
-
-     def on_chain_error(
-         self,
-         error: BaseException,
-         *,
-         run_id: UUID,
-         parent_run_id: Optional[UUID] = None,
-         **kwargs: Any,
-     ) -> Any:
-         trace_client = self._ensure_trace_client(run_id, parent_run_id, "ChainError")
-         if not trace_client:
-             return
-
-         span_id = self._run_id_to_span_id.get(run_id)
-
-         if not span_id and run_id != self._root_run_id:
-             return
-
-         self._end_span_tracking(trace_client, run_id, error=error)
-
-     def on_tool_start(
-         self,
-         serialized: Dict[str, Any],
-         input_str: str,
-         *,
-         run_id: UUID,
-         parent_run_id: Optional[UUID] = None,
-         tags: Optional[List[str]] = None,
-         metadata: Optional[Dict[str, Any]] = None,
-         inputs: Optional[Dict[str, Any]] = None,
-         **kwargs: Any,
-     ) -> Any:
-         name = (
-             serialized.get("name", "Unnamed Tool")
-             if serialized
-             else "Unknown Tool (Serialized=None)"
-         )
-
-         trace_client = self._ensure_trace_client(run_id, parent_run_id, name)
-         if not trace_client:
-             return
-
-         combined_inputs = {
-             "input_str": input_str,
-             "inputs": inputs,
-             "tags": tags,
-             "metadata": metadata,
-             "kwargs": kwargs,
-             "serialized": serialized,
-         }
-         self._start_span_tracking(
-             trace_client,
-             run_id,
-             parent_run_id,
-             name,
-             span_type="tool",
-             inputs=combined_inputs,
-         )
-
-     def on_tool_end(
-         self,
-         output: Any,
-         *,
-         run_id: UUID,
-         parent_run_id: Optional[UUID] = None,
-         **kwargs: Any,
-     ) -> Any:
-         trace_client = self._ensure_trace_client(run_id, parent_run_id, "ToolEnd")
-         if not trace_client:
-             return
-         outputs = {"output": output, "kwargs": kwargs}
-         self._end_span_tracking(trace_client, run_id, outputs=outputs)
-
-     def on_tool_error(
-         self,
-         error: BaseException,
-         *,
-         run_id: UUID,
-         parent_run_id: Optional[UUID] = None,
-         **kwargs: Any,
-     ) -> Any:
-         trace_client = self._ensure_trace_client(run_id, parent_run_id, "ToolError")
-         if not trace_client:
-             return
-         self._end_span_tracking(trace_client, run_id, error=error)
-
-     def on_llm_start(
-         self,
-         serialized: Dict[str, Any],
-         prompts: List[str],
-         *,
-         run_id: UUID,
-         parent_run_id: Optional[UUID] = None,
-         tags: Optional[List[str]] = None,
-         metadata: Optional[Dict[str, Any]] = None,
-         invocation_params: Optional[Dict[str, Any]] = None,
-         options: Optional[Dict[str, Any]] = None,
-         name: Optional[str] = None,
-         **kwargs: Any,
-     ) -> Any:
-         llm_name = name or serialized.get("name", "LLM Call")
-
-         trace_client = self._ensure_trace_client(run_id, parent_run_id, llm_name)
-         if not trace_client:
-             return
-         inputs = {
-             "prompts": prompts,
-             "invocation_params": invocation_params or kwargs,
-             "options": options,
-             "tags": tags,
-             "metadata": metadata,
-             "serialized": serialized,
-         }
-         self._start_span_tracking(
-             trace_client,
-             run_id,
-             parent_run_id,
-             llm_name,
-             span_type="llm",
-             inputs=inputs,
-         )
-
-     def on_llm_end(
-         self,
-         response: LLMResult,
-         *,
-         run_id: UUID,
-         parent_run_id: Optional[UUID] = None,
-         **kwargs: Any,
-     ) -> Any:
-         trace_client = self._ensure_trace_client(run_id, parent_run_id, "LLMEnd")
-         if not trace_client:
-             return
-         outputs = {"response": response, "kwargs": kwargs}
-
-         prompt_tokens = None
-         completion_tokens = None
-         total_tokens = None
-         model_name = None
-
-         # Extract model name from response if available
-         if (
-             hasattr(response, "llm_output")
-             and response.llm_output
-             and isinstance(response.llm_output, dict)
-         ):
-             model_name = response.llm_output.get(
-                 "model_name"
-             ) or response.llm_output.get("model")
-
-         # Try to get model from the first generation if available
-         if not model_name and response.generations and len(response.generations) > 0:
-             if (
-                 hasattr(response.generations[0][0], "generation_info")
-                 and response.generations[0][0].generation_info
-             ):
-                 gen_info = response.generations[0][0].generation_info
-                 model_name = gen_info.get("model") or gen_info.get("model_name")
-
-         if response.llm_output and isinstance(response.llm_output, dict):
-             # Check for OpenAI/standard 'token_usage' first
-             if "token_usage" in response.llm_output:
-                 token_usage = response.llm_output.get("token_usage")
-                 if token_usage and isinstance(token_usage, dict):
-                     prompt_tokens = token_usage.get("prompt_tokens")
-                     completion_tokens = token_usage.get("completion_tokens")
-                     total_tokens = token_usage.get(
-                         "total_tokens"
-                     )  # OpenAI provides total
-             # Check for Anthropic 'usage'
-             elif "usage" in response.llm_output:
-                 token_usage = response.llm_output.get("usage")
-                 if token_usage and isinstance(token_usage, dict):
-                     prompt_tokens = token_usage.get(
-                         "input_tokens"
-                     )  # Anthropic uses input_tokens
-                     completion_tokens = token_usage.get(
-                         "output_tokens"
-                     )  # Anthropic uses output_tokens
-                     # Calculate total if possible
-                     if prompt_tokens is not None and completion_tokens is not None:
-                         total_tokens = prompt_tokens + completion_tokens
-
-         if prompt_tokens is not None or completion_tokens is not None:
-             prompt_cost = None
-             completion_cost = None
-             total_cost_usd = None
-
-             if (
-                 model_name
-                 and prompt_tokens is not None
-                 and completion_tokens is not None
-             ):
-                 try:
-                     prompt_cost, completion_cost = cost_per_token(
-                         model=model_name,
-                         prompt_tokens=prompt_tokens,
-                         completion_tokens=completion_tokens,
-                     )
-                     total_cost_usd = (
-                         (prompt_cost + completion_cost)
-                         if prompt_cost and completion_cost
-                         else None
-                     )
-                 except Exception as e:
-                     # If cost calculation fails, continue without costs
-                     import warnings
-
-                     warnings.warn(
-                         f"Failed to calculate token costs for model {model_name}: {e}"
-                     )
-
-             usage = TraceUsage(
-                 prompt_tokens=prompt_tokens,
-                 completion_tokens=completion_tokens,
-                 total_tokens=total_tokens
-                 or (
-                     prompt_tokens + completion_tokens
-                     if prompt_tokens and completion_tokens
-                     else None
-                 ),
-                 prompt_tokens_cost_usd=prompt_cost,
-                 completion_tokens_cost_usd=completion_cost,
-                 total_cost_usd=total_cost_usd,
-                 model_name=model_name,
-             )
-
-             span_id = self._run_id_to_span_id.get(run_id)
-             if span_id and span_id in trace_client.span_id_to_span:
-                 trace_span = trace_client.span_id_to_span[span_id]
-                 trace_span.usage = usage
-
-         self._end_span_tracking(trace_client, run_id, outputs=outputs)
-
-     def on_llm_error(
-         self,
-         error: BaseException,
-         *,
-         run_id: UUID,
-         parent_run_id: Optional[UUID] = None,
-         **kwargs: Any,
-     ) -> Any:
-         trace_client = self._ensure_trace_client(run_id, parent_run_id, "LLMError")
-         if not trace_client:
-             return
-         self._end_span_tracking(trace_client, run_id, error=error)
-
-     def on_chat_model_start(
-         self,
-         serialized: Dict[str, Any],
-         messages: List[List[BaseMessage]],
-         *,
-         run_id: UUID,
-         parent_run_id: Optional[UUID] = None,
-         tags: Optional[List[str]] = None,
-         metadata: Optional[Dict[str, Any]] = None,
-         invocation_params: Optional[Dict[str, Any]] = None,
-         options: Optional[Dict[str, Any]] = None,
-         name: Optional[str] = None,
-         **kwargs: Any,
-     ) -> Any:
-         chat_model_name = name or serialized.get("name", "ChatModel Call")
-         is_openai = (
-             any(
-                 key.startswith("openai") for key in serialized.get("secrets", {}).keys()
-             )
-             or "openai" in chat_model_name.lower()
-         )
-         is_anthropic = (
-             any(
-                 key.startswith("anthropic")
-                 for key in serialized.get("secrets", {}).keys()
-             )
-             or "anthropic" in chat_model_name.lower()
-             or "claude" in chat_model_name.lower()
-         )
-         is_together = (
-             any(
-                 key.startswith("together")
-                 for key in serialized.get("secrets", {}).keys()
-             )
-             or "together" in chat_model_name.lower()
-         )
-
-         is_google = (
-             any(
-                 key.startswith("google") for key in serialized.get("secrets", {}).keys()
-             )
-             or "google" in chat_model_name.lower()
-             or "gemini" in chat_model_name.lower()
-         )
-
-         if is_openai and "OPENAI_API_CALL" not in chat_model_name:
-             chat_model_name = f"{chat_model_name} OPENAI_API_CALL"
-         elif is_anthropic and "ANTHROPIC_API_CALL" not in chat_model_name:
-             chat_model_name = f"{chat_model_name} ANTHROPIC_API_CALL"
-         elif is_together and "TOGETHER_API_CALL" not in chat_model_name:
-             chat_model_name = f"{chat_model_name} TOGETHER_API_CALL"
-
-         elif is_google and "GOOGLE_API_CALL" not in chat_model_name:
-             chat_model_name = f"{chat_model_name} GOOGLE_API_CALL"
-
-         trace_client = self._ensure_trace_client(run_id, parent_run_id, chat_model_name)
-         if not trace_client:
-             return
-         inputs = {
-             "messages": messages,
-             "invocation_params": invocation_params or kwargs,
-             "options": options,
-             "tags": tags,
-             "metadata": metadata,
-             "serialized": serialized,
-         }
-         self._start_span_tracking(
-             trace_client,
-             run_id,
-             parent_run_id,
-             chat_model_name,
-             span_type="llm",
-             inputs=inputs,
-         )
-
-     def on_agent_action(
-         self,
-         action: AgentAction,
-         *,
-         run_id: UUID,
-         parent_run_id: Optional[UUID] = None,
-         **kwargs: Any,
-     ) -> Any:
-         action_tool = action.tool
-         name = f"AGENT_ACTION_{(action_tool).upper()}"
-         trace_client = self._ensure_trace_client(run_id, parent_run_id, name)
-         if not trace_client:
-             return
-
-         inputs = {
-             "tool_input": action.tool_input,
-             "log": action.log,
-             "messages": action.messages,
-             "kwargs": kwargs,
-         }
-         self._start_span_tracking(
-             trace_client, run_id, parent_run_id, name, span_type="agent", inputs=inputs
-         )
-
-     def on_agent_finish(
-         self,
-         finish: AgentFinish,
-         *,
-         run_id: UUID,
-         parent_run_id: Optional[UUID] = None,
-         **kwargs: Any,
-     ) -> Any:
-         trace_client = self._ensure_trace_client(run_id, parent_run_id, "AgentFinish")
-         if not trace_client:
-             return
-
-         outputs = {
-             "return_values": finish.return_values,
-             "log": finish.log,
-             "messages": finish.messages,
-             "kwargs": kwargs,
-         }
-         self._end_span_tracking(trace_client, run_id, outputs=outputs)
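For context on what the deleted module provided: JudgevalCallbackHandler was an ordinary LangChain callback handler, so it was attached through the standard callbacks mechanism rather than a judgeval-specific hook. A rough usage sketch against 0.7.1; the Tracer constructor arguments are assumed (the handler only shows that tracer.project_name and related attributes exist), and compiled_graph stands in for any LangGraph/LangChain runnable:

    from judgeval.common.tracer import Tracer
    from judgeval.integrations.langgraph import JudgevalCallbackHandler

    tracer = Tracer(project_name="my-project")  # assumed constructor; attributes per the code above
    handler = JudgevalCallbackHandler(tracer)

    # LangChain runnables accept callback handlers via the run config
    result = compiled_graph.invoke({"input": "hello"}, config={"callbacks": [handler]})

    handler.reset()  # clears per-run state so the handler can be reused across invocations

In 0.9.0 the integration is rebuilt as a package at judgeval/integrations/langgraph/__init__.py (item 18 in the listing above), so the old module-level import path no longer exists.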