judgeval 0.7.1__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. judgeval/__init__.py +139 -12
  2. judgeval/api/__init__.py +501 -0
  3. judgeval/api/api_types.py +344 -0
  4. judgeval/cli.py +2 -4
  5. judgeval/constants.py +10 -26
  6. judgeval/data/evaluation_run.py +49 -26
  7. judgeval/data/example.py +2 -2
  8. judgeval/data/judgment_types.py +266 -82
  9. judgeval/data/result.py +4 -5
  10. judgeval/data/scorer_data.py +4 -2
  11. judgeval/data/tool.py +2 -2
  12. judgeval/data/trace.py +7 -50
  13. judgeval/data/trace_run.py +7 -4
  14. judgeval/{dataset.py → dataset/__init__.py} +43 -28
  15. judgeval/env.py +67 -0
  16. judgeval/{run_evaluation.py → evaluation/__init__.py} +29 -95
  17. judgeval/exceptions.py +27 -0
  18. judgeval/integrations/langgraph/__init__.py +788 -0
  19. judgeval/judges/__init__.py +2 -2
  20. judgeval/judges/litellm_judge.py +75 -15
  21. judgeval/judges/together_judge.py +86 -18
  22. judgeval/judges/utils.py +7 -21
  23. judgeval/{common/logger.py → logger.py} +8 -6
  24. judgeval/scorers/__init__.py +0 -4
  25. judgeval/scorers/agent_scorer.py +3 -7
  26. judgeval/scorers/api_scorer.py +8 -13
  27. judgeval/scorers/base_scorer.py +52 -32
  28. judgeval/scorers/example_scorer.py +1 -3
  29. judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +0 -14
  30. judgeval/scorers/judgeval_scorers/api_scorers/prompt_scorer.py +45 -20
  31. judgeval/scorers/judgeval_scorers/api_scorers/tool_dependency.py +2 -2
  32. judgeval/scorers/judgeval_scorers/api_scorers/tool_order.py +3 -3
  33. judgeval/scorers/score.py +21 -31
  34. judgeval/scorers/trace_api_scorer.py +5 -0
  35. judgeval/scorers/utils.py +1 -103
  36. judgeval/tracer/__init__.py +1075 -2
  37. judgeval/tracer/constants.py +1 -0
  38. judgeval/tracer/exporters/__init__.py +37 -0
  39. judgeval/tracer/exporters/s3.py +119 -0
  40. judgeval/tracer/exporters/store.py +43 -0
  41. judgeval/tracer/exporters/utils.py +32 -0
  42. judgeval/tracer/keys.py +67 -0
  43. judgeval/tracer/llm/__init__.py +1233 -0
  44. judgeval/{common/tracer → tracer/llm}/providers.py +5 -10
  45. judgeval/{local_eval_queue.py → tracer/local_eval_queue.py} +15 -10
  46. judgeval/tracer/managers.py +188 -0
  47. judgeval/tracer/processors/__init__.py +181 -0
  48. judgeval/tracer/utils.py +20 -0
  49. judgeval/trainer/__init__.py +5 -0
  50. judgeval/{common/trainer → trainer}/config.py +12 -9
  51. judgeval/{common/trainer → trainer}/console.py +2 -9
  52. judgeval/{common/trainer → trainer}/trainable_model.py +12 -7
  53. judgeval/{common/trainer → trainer}/trainer.py +119 -17
  54. judgeval/utils/async_utils.py +2 -3
  55. judgeval/utils/decorators.py +24 -0
  56. judgeval/utils/file_utils.py +37 -4
  57. judgeval/utils/guards.py +32 -0
  58. judgeval/utils/meta.py +14 -0
  59. judgeval/{common/api/json_encoder.py → utils/serialize.py} +7 -1
  60. judgeval/utils/testing.py +88 -0
  61. judgeval/utils/url.py +10 -0
  62. judgeval/{version_check.py → utils/version_check.py} +3 -3
  63. judgeval/version.py +5 -0
  64. judgeval/warnings.py +4 -0
  65. {judgeval-0.7.1.dist-info → judgeval-0.9.0.dist-info}/METADATA +12 -14
  66. judgeval-0.9.0.dist-info/RECORD +80 -0
  67. judgeval/clients.py +0 -35
  68. judgeval/common/__init__.py +0 -13
  69. judgeval/common/api/__init__.py +0 -3
  70. judgeval/common/api/api.py +0 -375
  71. judgeval/common/api/constants.py +0 -186
  72. judgeval/common/exceptions.py +0 -27
  73. judgeval/common/storage/__init__.py +0 -6
  74. judgeval/common/storage/s3_storage.py +0 -97
  75. judgeval/common/tracer/__init__.py +0 -31
  76. judgeval/common/tracer/constants.py +0 -22
  77. judgeval/common/tracer/core.py +0 -2427
  78. judgeval/common/tracer/otel_exporter.py +0 -108
  79. judgeval/common/tracer/otel_span_processor.py +0 -188
  80. judgeval/common/tracer/span_processor.py +0 -37
  81. judgeval/common/tracer/span_transformer.py +0 -207
  82. judgeval/common/tracer/trace_manager.py +0 -101
  83. judgeval/common/trainer/__init__.py +0 -5
  84. judgeval/common/utils.py +0 -948
  85. judgeval/integrations/langgraph.py +0 -844
  86. judgeval/judges/mixture_of_judges.py +0 -287
  87. judgeval/judgment_client.py +0 -267
  88. judgeval/rules.py +0 -521
  89. judgeval/scorers/judgeval_scorers/api_scorers/execution_order.py +0 -52
  90. judgeval/scorers/judgeval_scorers/api_scorers/hallucination.py +0 -28
  91. judgeval/utils/alerts.py +0 -93
  92. judgeval/utils/requests.py +0 -50
  93. judgeval-0.7.1.dist-info/RECORD +0 -82
  94. {judgeval-0.7.1.dist-info → judgeval-0.9.0.dist-info}/WHEEL +0 -0
  95. {judgeval-0.7.1.dist-info → judgeval-0.9.0.dist-info}/entry_points.txt +0 -0
  96. {judgeval-0.7.1.dist-info → judgeval-0.9.0.dist-info}/licenses/LICENSE.md +0 -0
@@ -1,2427 +0,0 @@
1
- """
2
- Tracing system for judgeval that allows for function tracing using decorators.
3
- """
4
-
5
- from __future__ import annotations
6
-
7
- import asyncio
8
- import atexit
9
- import functools
10
- import inspect
11
- import os
12
- import threading
13
- import time
14
- import traceback
15
- import uuid
16
- import contextvars
17
- import sys
18
- from contextlib import (
19
- contextmanager,
20
- )
21
- from datetime import datetime, timezone
22
- from typing import (
23
- Any,
24
- Callable,
25
- Dict,
26
- Generator,
27
- List,
28
- Optional,
29
- ParamSpec,
30
- Tuple,
31
- TypeVar,
32
- Union,
33
- TypeAlias,
34
- overload,
35
- )
36
- import types
37
- import random
38
-
39
-
40
- from judgeval.common.tracer.constants import _TRACE_FILEPATH_BLOCKLIST
41
-
42
- from judgeval.common.tracer.otel_span_processor import JudgmentSpanProcessor
43
- from judgeval.common.tracer.span_processor import SpanProcessorBase
44
- from judgeval.common.tracer.trace_manager import TraceManagerClient
45
-
46
- from judgeval.data import Example, Trace, TraceSpan, TraceUsage
47
- from judgeval.scorers import APIScorerConfig, BaseScorer
48
- from judgeval.data.evaluation_run import EvaluationRun
49
- from judgeval.local_eval_queue import LocalEvaluationQueue
50
- from judgeval.common.api import JudgmentApiClient
51
- from judgeval.common.utils import OptExcInfo, validate_api_key
52
- from judgeval.common.logger import judgeval_logger
53
-
54
- from litellm import cost_per_token as _original_cost_per_token # type: ignore
55
- from judgeval.common.tracer.providers import (
56
- HAS_OPENAI,
57
- HAS_TOGETHER,
58
- HAS_ANTHROPIC,
59
- HAS_GOOGLE_GENAI,
60
- HAS_GROQ,
61
- ApiClient,
62
- )
63
- from judgeval.constants import DEFAULT_GPT_MODEL
64
-
65
-
66
# Context-local state for the tracing system: the active TraceClient and the
# active span id for the current execution context. Defaults are None so code
# running outside any trace context simply no-ops.
current_trace_var = contextvars.ContextVar[Optional["TraceClient"]](
    "current_trace", default=None
)
current_span_var = contextvars.ContextVar[Optional[str]]("current_span", default=None)


# Span types are plain strings (e.g. the default "span" used by
# TraceClient.span); the alias exists purely for readability in signatures.
SpanType: TypeAlias = str
73
-
74
-
75
class TraceClient:
    """Client for managing a single trace context.

    Accumulates TraceSpan and EvaluationRun objects for one trace, mirrors
    every span update into the tracer's OTel span processor, and persists the
    trace to the backend through a TraceManagerClient (see ``save``).
    Current-span bookkeeping is delegated to the owning Tracer's contextvars.
    """

    def __init__(
        self,
        tracer: Tracer,
        trace_id: Optional[str] = None,
        name: str = "default",
        project_name: Union[str, None] = None,
        enable_monitoring: bool = True,
        enable_evaluations: bool = True,
        parent_trace_id: Optional[str] = None,
        parent_name: Optional[str] = None,
    ):
        self.name = name
        # Generate a fresh UUID when the caller does not supply a trace id.
        self.trace_id = trace_id or str(uuid.uuid4())
        self.project_name = project_name or "default_project"
        self.tracer = tracer
        self.enable_monitoring = enable_monitoring
        self.enable_evaluations = enable_evaluations
        self.parent_trace_id = parent_trace_id
        self.parent_name = parent_name
        self.customer_id: Optional[str] = None
        self.tags: List[Union[str, set, tuple]] = []
        self.metadata: Dict[str, Any] = {}
        self.has_notification: Optional[bool] = False
        # Sent with every save() and incremented afterwards, so the backend can
        # order successive upserts of the same trace.
        self.update_id: int = 1
        self.trace_spans: List[TraceSpan] = []
        self.span_id_to_span: Dict[str, TraceSpan] = {}
        self.evaluation_runs: List[EvaluationRun] = []
        # Set on the first save(); get_duration() returns 0.0 until then.
        self.start_time: Optional[float] = None
        self.trace_manager_client = TraceManagerClient(
            tracer.api_key, tracer.organization_id, tracer
        )
        # Depth of each live span keyed by span id; children derive their depth
        # from their parent's entry here.
        self._span_depths: Dict[str, int] = {}

        self.otel_span_processor = tracer.otel_span_processor

    def get_current_span(self):
        """Get the current span from the context var"""
        return self.tracer.get_current_span()

    def set_current_span(self, span: Any):
        """Set the current span from the context var"""
        return self.tracer.set_current_span(span)

    def reset_current_span(self, token: Any):
        """Reset the current span from the context var"""
        self.tracer.reset_current_span(token)

    @contextmanager
    def span(self, name: str, span_type: SpanType = "span"):
        """Context manager for creating a trace span, managing the current span via contextvars"""
        # On the very first span, push an initial save so the trace shows up
        # for live tracking; failure here is non-fatal and only logged.
        is_first_span = len(self.trace_spans) == 0
        if is_first_span:
            try:
                self.save(final_save=False)
            except Exception as e:
                judgeval_logger.warning(
                    f"Failed to save initial trace for live tracking: {e}"
                )
        start_time = time.time()

        span_id = str(uuid.uuid4())

        parent_span_id = self.get_current_span()
        token = self.set_current_span(span_id)

        # Depth is parent depth + 1 when a live parent exists, else 0 (root).
        current_depth = 0
        if parent_span_id and parent_span_id in self._span_depths:
            current_depth = self._span_depths[parent_span_id] + 1

        self._span_depths[span_id] = current_depth

        span = TraceSpan(
            span_id=span_id,
            trace_id=self.trace_id,
            depth=current_depth,
            message=name,
            created_at=start_time,
            span_type=span_type,
            parent_span_id=parent_span_id,
            function=name,
        )
        self.add_span(span)

        self.otel_span_processor.queue_span_update(span, span_state="input")

        try:
            yield self
        finally:
            # Always record duration, emit the "completed" update, and restore
            # the previous current span, even if the body raised.
            duration = time.time() - start_time
            span.duration = duration

            self.otel_span_processor.queue_span_update(span, span_state="completed")

            if span_id in self._span_depths:
                del self._span_depths[span_id]
            self.reset_current_span(token)

    def async_evaluate(
        self,
        scorer: Union[APIScorerConfig, BaseScorer],
        example: Example,
        model: str = DEFAULT_GPT_MODEL,
    ):
        """Attach an evaluation of ``example`` with ``scorer`` to the current span.

        Hosted scorers (APIScorerConfig, or BaseScorer with server_hosted) are
        queued through the OTel span processor; custom scorers go through the
        tracer's local evaluation queue instead.
        """
        start_time = time.time()
        span_id = self.get_current_span()
        eval_run_name = (
            f"{self.name.capitalize()}-{span_id}-{scorer.score_type.capitalize()}"
        )
        hosted_scoring = isinstance(scorer, APIScorerConfig) or (
            isinstance(scorer, BaseScorer) and scorer.server_hosted
        )
        if hosted_scoring:
            eval_run = EvaluationRun(
                organization_id=self.tracer.organization_id,
                project_name=self.project_name,
                eval_name=eval_run_name,
                examples=[example],
                scorers=[scorer],
                model=model,
                trace_span_id=span_id,
            )

            self.add_eval_run(eval_run, start_time)

            if span_id:
                current_span = self.span_id_to_span.get(span_id)
                if current_span:
                    self.otel_span_processor.queue_evaluation_run(
                        eval_run, span_id=span_id, span_data=current_span
                    )
        else:
            # Handle custom scorers using local evaluation queue
            eval_run = EvaluationRun(
                organization_id=self.tracer.organization_id,
                project_name=self.project_name,
                eval_name=eval_run_name,
                examples=[example],
                scorers=[scorer],
                model=model,
                trace_span_id=span_id,
            )

            self.add_eval_run(eval_run, start_time)

            # Enqueue the evaluation run to the local evaluation queue
            self.tracer.local_eval_queue.enqueue(eval_run)

    def add_eval_run(self, eval_run: EvaluationRun, start_time: float):
        """Record ``eval_run`` on this trace and flag its span as evaluated.

        NOTE(review): ``start_time`` is accepted but unused here — confirm
        whether callers rely on it or it is vestigial.
        """
        current_span_id = eval_run.trace_span_id

        if current_span_id:
            span = self.span_id_to_span[current_span_id]
            span.has_evaluation = True
        self.evaluation_runs.append(eval_run)

    def record_input(self, inputs: dict):
        """Store ``inputs`` on the current span (dropping ``self``) and queue an update."""
        current_span_id = self.get_current_span()
        if current_span_id:
            span = self.span_id_to_span[current_span_id]
            # "self" is noise when tracing bound methods; drop it in place.
            if "self" in inputs:
                del inputs["self"]
            span.inputs = inputs

            try:
                self.otel_span_processor.queue_span_update(span, span_state="input")
            except Exception as e:
                judgeval_logger.warning(f"Failed to queue span with input data: {e}")

    def record_agent_name(self, agent_name: str):
        """Set the agent name on the current span and queue an update."""
        current_span_id = self.get_current_span()
        if current_span_id:
            span = self.span_id_to_span[current_span_id]
            span.agent_name = agent_name

            self.otel_span_processor.queue_span_update(span, span_state="agent_name")

    def record_class_name(self, class_name: str):
        """Set the class name on the current span and queue an update."""
        current_span_id = self.get_current_span()
        if current_span_id:
            span = self.span_id_to_span[current_span_id]
            span.class_name = class_name

            self.otel_span_processor.queue_span_update(span, span_state="class_name")

    def record_state_before(self, state: dict):
        """Records the agent's state before a tool execution on the current span.

        Args:
            state: A dictionary representing the agent's state.
        """
        current_span_id = self.get_current_span()
        if current_span_id:
            span = self.span_id_to_span[current_span_id]
            span.state_before = state

            self.otel_span_processor.queue_span_update(span, span_state="state_before")

    def record_state_after(self, state: dict):
        """Records the agent's state after a tool execution on the current span.

        Args:
            state: A dictionary representing the agent's state.
        """
        current_span_id = self.get_current_span()
        if current_span_id:
            span = self.span_id_to_span[current_span_id]
            span.state_after = state

            self.otel_span_processor.queue_span_update(span, span_state="state_after")

    def record_output(self, output: Any):
        """Store ``output`` on the current span; return the span, or None if no span is current."""
        current_span_id = self.get_current_span()
        if current_span_id:
            span = self.span_id_to_span[current_span_id]
            span.output = output

            self.otel_span_processor.queue_span_update(span, span_state="output")

            return span
        return None

    def record_usage(self, usage: TraceUsage):
        """Store token/cost usage on the current span; return the span, or None."""
        current_span_id = self.get_current_span()
        if current_span_id:
            span = self.span_id_to_span[current_span_id]
            span.usage = usage

            self.otel_span_processor.queue_span_update(span, span_state="usage")

            return span
        return None

    def record_error(self, error: Dict[str, Any]):
        """Store an error payload on the current span; return the span, or None."""
        current_span_id = self.get_current_span()
        if current_span_id:
            span = self.span_id_to_span[current_span_id]
            span.error = error

            self.otel_span_processor.queue_span_update(span, span_state="error")

            return span
        return None

    def add_span(self, span: TraceSpan):
        """Add a trace span to this trace context"""
        self.trace_spans.append(span)
        self.span_id_to_span[span.span_id] = span
        return self

    def print(self):
        """Print the complete trace with proper visual structure"""
        for span in self.trace_spans:
            span.print_span()

    def get_duration(self) -> float:
        """
        Get the total duration of this trace

        Returns 0.0 until the first save() sets start_time.
        """
        if self.start_time is None:
            return 0.0
        return time.time() - self.start_time

    def save(self, final_save: bool = False) -> Tuple[str, dict]:
        """
        Save the current trace to the database with rate limiting checks.
        First checks usage limits, then upserts the trace if allowed.

        Args:
            final_save: Whether this is the final save (updates usage counters)

        Returns a tuple of (trace_id, server_response) where server_response contains the UI URL and other metadata.
        """
        if final_save:
            # Flush any spans still queued in the OTel processor before the
            # final upsert; flush failures are logged, not raised.
            try:
                self.otel_span_processor.flush_pending_spans()
            except Exception as e:
                judgeval_logger.warning(
                    f"Error flushing spans for trace {self.trace_id}: {e}"
                )

        total_duration = self.get_duration()

        trace_data = {
            "trace_id": self.trace_id,
            "name": self.name,
            "project_name": self.project_name,
            "created_at": datetime.fromtimestamp(
                self.start_time or time.time(), timezone.utc
            ).isoformat(),
            "duration": total_duration,
            "trace_spans": [span.model_dump() for span in self.trace_spans],
            "evaluation_runs": [run.model_dump() for run in self.evaluation_runs],
            "offline_mode": self.tracer.offline_mode,
            "parent_trace_id": self.parent_trace_id,
            "parent_name": self.parent_name,
            "customer_id": self.customer_id,
            "tags": self.tags,
            "metadata": self.metadata,
            "update_id": self.update_id,
        }

        server_response = self.trace_manager_client.upsert_trace(
            trace_data,
            offline_mode=self.tracer.offline_mode,
            show_link=not final_save,
            final_save=final_save,
        )

        # First successful save starts the trace clock.
        if self.start_time is None:
            self.start_time = time.time()

        self.update_id += 1

        return self.trace_id, server_response

    def delete(self):
        """Delete this trace from the backend via the trace manager client."""
        return self.trace_manager_client.delete_trace(self.trace_id)

    def update_metadata(self, metadata: dict):
        """
        Set metadata for this trace.

        Args:
            metadata: Metadata as a dictionary

        Supported keys:
        - customer_id: ID of the customer using this trace
        - tags: List of tags for this trace
        - has_notification: Whether this trace has a notification
        - name: Name of the trace

        Unknown keys are stored verbatim in self.metadata; known keys are
        validated and raise ValueError on type mismatch.
        """
        for k, v in metadata.items():
            if k == "customer_id":
                if v is not None:
                    self.customer_id = str(v)
                else:
                    self.customer_id = None
            elif k == "tags":
                if isinstance(v, list):
                    for item in v:
                        if not isinstance(item, (str, set, tuple)):
                            raise ValueError(
                                f"Tags must be a list of strings, sets, or tuples, got item of type {type(item)}"
                            )
                    self.tags = v
                else:
                    raise ValueError(
                        f"Tags must be a list of strings, sets, or tuples, got {type(v)}"
                    )
            elif k == "has_notification":
                if not isinstance(v, bool):
                    raise ValueError(
                        f"has_notification must be a boolean, got {type(v)}"
                    )
                self.has_notification = v
            elif k == "name":
                self.name = v
            else:
                self.metadata[k] = v

    def set_customer_id(self, customer_id: str):
        """
        Set the customer ID for this trace.

        Args:
            customer_id: The customer ID to set
        """
        self.update_metadata({"customer_id": customer_id})

    def set_tags(self, tags: List[Union[str, set, tuple]]):
        """
        Set the tags for this trace.

        Args:
            tags: List of tags to set
        """
        self.update_metadata({"tags": tags})

    def set_reward_score(self, reward_score: Union[float, Dict[str, float]]):
        """
        Set the reward score for this trace to be used for RL or SFT.

        Args:
            reward_score: The reward score to set
        """
        self.update_metadata({"reward_score": reward_score})
464
-
465
-
466
- def _capture_exception_for_trace(
467
- current_trace: Optional[TraceClient], exc_info: OptExcInfo
468
- ):
469
- if not current_trace:
470
- return
471
-
472
- exc_type, exc_value, exc_traceback_obj = exc_info
473
- formatted_exception = {
474
- "type": exc_type.__name__ if exc_type else "UnknownExceptionType",
475
- "message": str(exc_value) if exc_value else "No exception message",
476
- "traceback": (
477
- traceback.format_tb(exc_traceback_obj) if exc_traceback_obj else []
478
- ),
479
- }
480
-
481
- # This is where we specially handle exceptions that we might want to collect additional data for.
482
- # When we do this, always try checking the module from sys.modules instead of importing. This will
483
- # Let us support a wider range of exceptions without needing to import them for all clients.
484
-
485
- # Most clients (requests, httpx, urllib) support the standard format of exposing error.request.url and error.response.status_code
486
- # The alternative is to hand select libraries we want from sys.modules and check for them:
487
- # As an example: requests_module = sys.modules.get("requests", None) // then do things with requests_module;
488
-
489
- # General HTTP Like errors
490
- try:
491
- url = getattr(getattr(exc_value, "request", None), "url", None)
492
- status_code = getattr(getattr(exc_value, "response", None), "status_code", None)
493
- if status_code:
494
- formatted_exception["http"] = {
495
- "url": url if url else "Unknown URL",
496
- "status_code": status_code if status_code else None,
497
- }
498
- except Exception:
499
- pass
500
-
501
- current_trace.record_error(formatted_exception)
502
-
503
-
504
class _DeepTracer:
    """Singleton deep tracer that auto-creates spans for user-code calls.

    Installs cooperative ``sys.settrace`` / ``threading.settrace`` hooks while
    active (reference-counted via ``__enter__``/``__exit__``), chaining with
    any pre-existing trace functions. Call/return/exception events on frames
    deemed "user code" are turned into TraceSpans on the tracer's current
    trace.
    """

    _instance: Optional["_DeepTracer"] = None
    _lock: threading.Lock = threading.Lock()
    _refcount: int = 0
    # NOTE(review): ContextVar defaults below are mutable lists shared across
    # contexts until first set() — confirm this is intentional.
    _span_stack: contextvars.ContextVar[List[Dict[str, Any]]] = contextvars.ContextVar(
        "_deep_profiler_span_stack", default=[]
    )
    _skip_stack: contextvars.ContextVar[List[str]] = contextvars.ContextVar(
        "_deep_profiler_skip_stack", default=[]
    )
    _original_sys_trace: Optional[Callable] = None
    _original_threading_trace: Optional[Callable] = None

    def __init__(self, tracer: "Tracer"):
        self._tracer = tracer

    def _get_qual_name(self, frame) -> str:
        """Best-effort ``module.qualname`` for the function executing in ``frame``."""
        func_name = frame.f_code.co_name
        module_name = frame.f_globals.get("__name__", "unknown_module")

        try:
            func = frame.f_globals.get(func_name)
            if func is None:
                return f"{module_name}.{func_name}"
            if hasattr(func, "__qualname__"):
                return f"{module_name}.{func.__qualname__}"
            return f"{module_name}.{func_name}"
        except Exception:
            return f"{module_name}.{func_name}"

    def __new__(cls, tracer: "Tracer"):
        # Classic lock-guarded singleton: every construction returns one shared
        # instance (note __init__ still re-runs and rebinds self._tracer).
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
            return cls._instance

    def _should_trace(self, frame):
        """Decide whether ``frame`` belongs to user code we want to span."""
        # Skip stack is maintained by the tracer as an optimization to skip earlier
        # frames in the call stack that we've already determined should be skipped
        skip_stack = self._skip_stack.get()
        if len(skip_stack) > 0:
            return False

        func_name = frame.f_code.co_name
        module_name = frame.f_globals.get("__name__", None)
        func = frame.f_globals.get(func_name)
        # Functions already wrapped by the @observe-style decorators carry
        # _judgment_span_* attributes; don't double-trace them.
        if func and (
            hasattr(func, "_judgment_span_name") or hasattr(func, "_judgment_span_type")
        ):
            return False

        if (
            not module_name
            or func_name.startswith("<")  # ex: <listcomp>
            or func_name.startswith("__")
            and func_name != "__call__"  # dunders
            or not self._is_user_code(frame.f_code.co_filename)
        ):
            return False

        return True

    @functools.cache
    # NOTE(review): functools.cache on an instance method keys on self and
    # keeps it alive; acceptable for a singleton, but worth confirming.
    def _is_user_code(self, filename: str):
        """True when ``filename`` is a real path outside the blocklisted prefixes."""
        return (
            bool(filename)
            and not filename.startswith("<")
            and not os.path.realpath(filename).startswith(_TRACE_FILEPATH_BLOCKLIST)
        )

    def _cooperative_sys_trace(self, frame: types.FrameType, event: str, arg: Any):
        """Cooperative trace function for sys.settrace that chains with existing tracers."""
        # First, call the original sys trace function if it exists
        original_result = None
        if self._original_sys_trace:
            try:
                original_result = self._original_sys_trace(frame, event, arg)
            except Exception:
                pass

        our_result = self._trace(frame, event, arg, self._cooperative_sys_trace)

        # If the prior tracer opted out of this frame, respect that.
        if original_result is None and self._original_sys_trace:
            return None

        return our_result or original_result

    def _cooperative_threading_trace(
        self, frame: types.FrameType, event: str, arg: Any
    ):
        """Cooperative trace function for threading.settrace that chains with existing tracers."""
        original_result = None
        if self._original_threading_trace:
            try:
                original_result = self._original_threading_trace(frame, event, arg)
            except Exception:
                pass

        our_result = self._trace(frame, event, arg, self._cooperative_threading_trace)

        if original_result is None and self._original_threading_trace:
            return None

        return our_result or original_result

    def _trace(
        self, frame: types.FrameType, event: str, arg: Any, continuation_func: Callable
    ):
        """Core trace callback: map call/return/exception events to span lifecycle."""
        # We only care about function-level events, not lines or opcodes.
        frame.f_trace_lines = False
        frame.f_trace_opcodes = False

        if not self._should_trace(frame):
            return

        if event not in ("call", "return", "exception"):
            return

        current_trace = self._tracer.get_current_trace()
        if not current_trace:
            return

        parent_span_id = self._tracer.get_current_span()
        if not parent_span_id:
            return

        qual_name = self._get_qual_name(frame)
        instance_name = None
        class_name = None
        if "self" in frame.f_locals:
            instance = frame.f_locals["self"]
            class_name = instance.__class__.__name__
            class_identifiers = getattr(self._tracer, "class_identifiers", {})
            instance_name = get_instance_prefixed_name(
                instance, class_name, class_identifiers
            )
        skip_stack = self._skip_stack.get()

        if event == "call":
            # If we have entries in the skip stack and the current qual_name matches the top entry,
            # push it again to track nesting depth and skip
            # As an optimization, we only care about duplicate qual_names.
            if skip_stack:
                if qual_name == skip_stack[-1]:
                    skip_stack.append(qual_name)
                    self._skip_stack.set(skip_stack)
                return

            should_trace = self._should_trace(frame)

            if not should_trace:
                if not skip_stack:
                    self._skip_stack.set([qual_name])
                return
        elif event == "return":
            # If we have entries in skip stack and current qual_name matches the top entry,
            # pop it to track exiting from the skipped section
            if skip_stack and qual_name == skip_stack[-1]:
                skip_stack.pop()
                self._skip_stack.set(skip_stack)
                return

        if skip_stack:
            return

        span_stack = self._span_stack.get()
        if event == "call":
            if not self._should_trace(frame):
                return

            span_id = str(uuid.uuid4())

            parent_depth = current_trace._span_depths.get(parent_span_id, 0)
            depth = parent_depth + 1

            current_trace._span_depths[span_id] = depth

            start_time = time.time()

            span_stack.append(
                {
                    "span_id": span_id,
                    "parent_span_id": parent_span_id,
                    "function": qual_name,
                    "start_time": start_time,
                }
            )
            self._span_stack.set(span_stack)

            token = self._tracer.set_current_span(span_id)
            # Stash the contextvar reset token on the frame so the matching
            # "return" event can restore the previous current span.
            # NOTE(review): writes to frame.f_locals may not persist on all
            # CPython versions — confirm the token round-trips as intended.
            frame.f_locals["_judgment_span_token"] = token

            span = TraceSpan(
                span_id=span_id,
                trace_id=current_trace.trace_id,
                depth=depth,
                message=qual_name,
                created_at=start_time,
                span_type="span",
                parent_span_id=parent_span_id,
                function=qual_name,
                agent_name=instance_name,
                class_name=class_name,
            )
            current_trace.add_span(span)

            inputs = {}
            try:
                args_info = inspect.getargvalues(frame)
                for arg in args_info.args:
                    try:
                        inputs[arg] = args_info.locals.get(arg)
                    except Exception:
                        inputs[arg] = "<<Unserializable>>"
                current_trace.record_input(inputs)
            except Exception as e:
                current_trace.record_input({"error": str(e)})

        elif event == "return":
            if not span_stack:
                return

            current_id = self._tracer.get_current_span()

            # Find the matching entry nearest the top of the stack and pop it.
            span_data = None
            for i, entry in enumerate(reversed(span_stack)):
                if entry["span_id"] == current_id:
                    span_data = span_stack.pop(-(i + 1))
                    self._span_stack.set(span_stack)
                    break

            if not span_data:
                return

            start_time = span_data["start_time"]
            duration = time.time() - start_time

            current_trace.span_id_to_span[span_data["span_id"]].duration = duration

            if arg is not None:
                # exception handling will take priority.
                current_trace.record_output(arg)

            if span_data["span_id"] in current_trace._span_depths:
                del current_trace._span_depths[span_data["span_id"]]

            # Restore current span to the enclosing stack entry, or the
            # original parent when the stack is now empty.
            if span_stack:
                self._tracer.set_current_span(span_stack[-1]["span_id"])
            else:
                self._tracer.set_current_span(span_data["parent_span_id"])

            if "_judgment_span_token" in frame.f_locals:
                self._tracer.reset_current_span(frame.f_locals["_judgment_span_token"])

        elif event == "exception":
            exc_type = arg[0]
            # Generator/iterator protocol exceptions are normal control flow.
            if issubclass(exc_type, (StopIteration, StopAsyncIteration, GeneratorExit)):
                return
            _capture_exception_for_trace(current_trace, arg)

        # Returning the continuation keeps this tracer installed for the frame.
        return continuation_func

    def __enter__(self):
        with self._lock:
            self._refcount += 1
            if self._refcount == 1:
                # Store the existing trace functions before setting ours
                self._original_sys_trace = sys.gettrace()
                self._original_threading_trace = threading.gettrace()

                self._skip_stack.set([])
                self._span_stack.set([])

                sys.settrace(self._cooperative_sys_trace)
                threading.settrace(self._cooperative_threading_trace)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        with self._lock:
            self._refcount -= 1
            if self._refcount == 0:
                # Restore the original trace functions instead of setting to None
                sys.settrace(self._original_sys_trace)
                threading.settrace(self._original_threading_trace)

                # Clean up the references
                self._original_sys_trace = None
                self._original_threading_trace = None
791
-
792
-
793
# Generic type parameters: T preserves a wrapped callable's type, P its
# parameter specification. (Presumably consumed by decorator signatures later
# in this module, e.g. on Tracer — confirm against the full file.)
T = TypeVar("T", bound=Callable[..., Any])
P = ParamSpec("P")
795
-
796
-
797
- class Tracer:
798
- # Tracer.current_trace class variable is currently used in wrap()
799
- # TODO: Keep track of cross-context state for current trace and current span ID solely through class variables instead of instance variables?
800
- # Should be fine to do so as long as we keep Tracer as a singleton
801
- current_trace: Optional[TraceClient] = None
802
- # current_span_id: Optional[str] = None
803
-
804
- trace_across_async_contexts: bool = (
805
- False # BY default, we don't trace across async contexts
806
- )
807
-
808
- def __init__(
809
- self,
810
- api_key: Union[str, None] = os.getenv("JUDGMENT_API_KEY"),
811
- organization_id: Union[str, None] = os.getenv("JUDGMENT_ORG_ID"),
812
- project_name: Union[str, None] = None,
813
- deep_tracing: bool = False, # Deep tracing is disabled by default
814
- enable_monitoring: bool = os.getenv("JUDGMENT_MONITORING", "true").lower()
815
- == "true",
816
- enable_evaluations: bool = os.getenv("JUDGMENT_EVALUATIONS", "true").lower()
817
- == "true",
818
- show_trace_urls: bool = os.getenv("JUDGMENT_SHOW_TRACE_URLS", "true").lower()
819
- == "true",
820
- # S3 configuration
821
- use_s3: bool = False,
822
- s3_bucket_name: Optional[str] = None,
823
- s3_aws_access_key_id: Optional[str] = None,
824
- s3_aws_secret_access_key: Optional[str] = None,
825
- s3_region_name: Optional[str] = None,
826
- trace_across_async_contexts: bool = False, # BY default, we don't trace across async contexts
827
- span_batch_size: int = 50,
828
- span_flush_interval: float = 1.0,
829
- span_max_queue_size: int = 2048,
830
- span_export_timeout: int = 30000,
831
- ):
832
- try:
833
- if not api_key:
834
- raise ValueError(
835
- "api_key parameter must be provided. Please provide a valid API key value or set the JUDGMENT_API_KEY environment variable"
836
- )
837
-
838
- if not organization_id:
839
- raise ValueError(
840
- "organization_id parameter must be provided. Please provide a valid organization ID value or set the JUDGMENT_ORG_ID environment variable"
841
- )
842
-
843
- try:
844
- result, response = validate_api_key(api_key)
845
- except Exception as e:
846
- judgeval_logger.error(
847
- f"Issue with verifying API key, disabling monitoring: {e}"
848
- )
849
- enable_monitoring = False
850
- result = True
851
-
852
- if not result:
853
- raise ValueError(f"Issue with passed in Judgment API key: {response}")
854
-
855
- if use_s3 and not s3_bucket_name:
856
- raise ValueError("S3 bucket name must be provided when use_s3 is True")
857
-
858
- self.api_key: str = api_key
859
- self.project_name: str = project_name or "default_project"
860
- self.organization_id: str = organization_id
861
- self.traces: List[Trace] = []
862
- self.enable_monitoring: bool = enable_monitoring
863
- self.enable_evaluations: bool = enable_evaluations
864
- self.show_trace_urls: bool = show_trace_urls
865
- self.class_identifiers: Dict[
866
- str, str
867
- ] = {} # Dictionary to store class identifiers
868
- self.span_id_to_previous_span_id: Dict[str, Union[str, None]] = {}
869
- self.trace_id_to_previous_trace: Dict[str, Union[TraceClient, None]] = {}
870
- self.current_span_id: Optional[str] = None
871
- self.current_trace: Optional[TraceClient] = None
872
- self.trace_across_async_contexts: bool = trace_across_async_contexts
873
- Tracer.trace_across_async_contexts = trace_across_async_contexts
874
-
875
- # Initialize S3 storage if enabled
876
- self.use_s3 = use_s3
877
- if use_s3:
878
- from judgeval.common.storage.s3_storage import S3Storage
879
-
880
- try:
881
- self.s3_storage = S3Storage(
882
- bucket_name=s3_bucket_name,
883
- aws_access_key_id=s3_aws_access_key_id,
884
- aws_secret_access_key=s3_aws_secret_access_key,
885
- region_name=s3_region_name,
886
- )
887
- except Exception as e:
888
- judgeval_logger.error(
889
- f"Issue with initializing S3 storage, disabling S3: {e}"
890
- )
891
- self.use_s3 = False
892
-
893
- self.offline_mode = False # This is used to differentiate traces between online and offline (IE experiments vs monitoring page)
894
- self.deep_tracing: bool = deep_tracing
895
-
896
- self.span_batch_size = span_batch_size
897
- self.span_flush_interval = span_flush_interval
898
- self.span_max_queue_size = span_max_queue_size
899
- self.span_export_timeout = span_export_timeout
900
- self.otel_span_processor: SpanProcessorBase
901
- if enable_monitoring:
902
- self.otel_span_processor = JudgmentSpanProcessor(
903
- judgment_api_key=api_key,
904
- organization_id=organization_id,
905
- batch_size=span_batch_size,
906
- flush_interval=span_flush_interval,
907
- max_queue_size=span_max_queue_size,
908
- export_timeout=span_export_timeout,
909
- )
910
- else:
911
- self.otel_span_processor = SpanProcessorBase()
912
-
913
- # Initialize local evaluation queue for custom scorers
914
- self.local_eval_queue = LocalEvaluationQueue()
915
-
916
- # Start workers with callback to log results only if monitoring is enabled
917
- if enable_evaluations and enable_monitoring:
918
- self.local_eval_queue.start_workers(
919
- callback=self._log_eval_results_callback
920
- )
921
-
922
- atexit.register(self._cleanup_on_exit)
923
- except Exception as e:
924
- judgeval_logger.error(
925
- f"Issue with initializing Tracer: {e}. Disabling monitoring and evaluations."
926
- )
927
- self.enable_monitoring = False
928
- self.enable_evaluations = False
929
-
930
- def set_current_span(
931
- self, span_id: str
932
- ) -> Optional[contextvars.Token[Union[str, None]]]:
933
- self.span_id_to_previous_span_id[span_id] = self.current_span_id
934
- self.current_span_id = span_id
935
- Tracer.current_span_id = span_id
936
- try:
937
- token = current_span_var.set(span_id)
938
- except Exception:
939
- token = None
940
- return token
941
-
942
- def get_current_span(self) -> Optional[str]:
943
- try:
944
- current_span_var_val = current_span_var.get()
945
- except Exception:
946
- current_span_var_val = None
947
- return (
948
- (self.current_span_id or current_span_var_val)
949
- if self.trace_across_async_contexts
950
- else current_span_var_val
951
- )
952
-
953
- def reset_current_span(
954
- self,
955
- token: Optional[contextvars.Token[Union[str, None]]] = None,
956
- span_id: Optional[str] = None,
957
- ):
958
- try:
959
- if token:
960
- current_span_var.reset(token)
961
- except Exception:
962
- pass
963
- if not span_id:
964
- span_id = self.current_span_id
965
- if span_id:
966
- self.current_span_id = self.span_id_to_previous_span_id.get(span_id)
967
- Tracer.current_span_id = self.current_span_id
968
-
969
- def set_current_trace(
970
- self, trace: TraceClient
971
- ) -> Optional[contextvars.Token[Union[TraceClient, None]]]:
972
- """
973
- Set the current trace context in contextvars
974
- """
975
- self.trace_id_to_previous_trace[trace.trace_id] = self.current_trace
976
- self.current_trace = trace
977
- Tracer.current_trace = trace
978
- try:
979
- token = current_trace_var.set(trace)
980
- except Exception:
981
- token = None
982
- return token
983
-
984
- def get_current_trace(self) -> Optional[TraceClient]:
985
- """
986
- Get the current trace context.
987
-
988
- Tries to get the trace client from the context variable first.
989
- If not found (e.g., context lost across threads/tasks),
990
- it falls back to the active trace client managed by the callback handler.
991
- """
992
- try:
993
- current_trace_var_val = current_trace_var.get()
994
- except Exception:
995
- current_trace_var_val = None
996
- return (
997
- (self.current_trace or current_trace_var_val)
998
- if self.trace_across_async_contexts
999
- else current_trace_var_val
1000
- )
1001
-
1002
- def reset_current_trace(
1003
- self,
1004
- token: Optional[contextvars.Token[Union[TraceClient, None]]] = None,
1005
- trace_id: Optional[str] = None,
1006
- ):
1007
- try:
1008
- if token:
1009
- current_trace_var.reset(token)
1010
- except Exception:
1011
- pass
1012
- if not trace_id and self.current_trace:
1013
- trace_id = self.current_trace.trace_id
1014
- if trace_id:
1015
- self.current_trace = self.trace_id_to_previous_trace.get(trace_id)
1016
- Tracer.current_trace = self.current_trace
1017
-
1018
- @contextmanager
1019
- def trace(
1020
- self, name: str, project_name: Union[str, None] = None
1021
- ) -> Generator[TraceClient, None, None]:
1022
- """Start a new trace context using a context manager"""
1023
- trace_id = str(uuid.uuid4())
1024
- project = project_name if project_name is not None else self.project_name
1025
-
1026
- # Get parent trace info from context
1027
- parent_trace = self.get_current_trace()
1028
- parent_trace_id = None
1029
- parent_name = None
1030
-
1031
- if parent_trace:
1032
- parent_trace_id = parent_trace.trace_id
1033
- parent_name = parent_trace.name
1034
-
1035
- trace = TraceClient(
1036
- self,
1037
- trace_id,
1038
- name,
1039
- project_name=project,
1040
- enable_monitoring=self.enable_monitoring,
1041
- enable_evaluations=self.enable_evaluations,
1042
- parent_trace_id=parent_trace_id,
1043
- parent_name=parent_name,
1044
- )
1045
-
1046
- # Set the current trace in context variables
1047
- token = self.set_current_trace(trace)
1048
-
1049
- with trace.span(name or "unnamed_trace"):
1050
- try:
1051
- # Save the trace to the database to handle Evaluations' trace_id referential integrity
1052
- yield trace
1053
- finally:
1054
- # Reset the context variable
1055
- self.reset_current_trace(token)
1056
-
1057
- def agent(
1058
- self,
1059
- identifier: Optional[str] = None,
1060
- track_state: Optional[bool] = False,
1061
- track_attributes: Optional[List[str]] = None,
1062
- field_mappings: Optional[Dict[str, str]] = None,
1063
- ):
1064
- """
1065
- Class decorator that associates a class with a custom identifier and enables state tracking.
1066
-
1067
- This decorator creates a mapping between the class name and the provided
1068
- identifier, which can be useful for tagging, grouping, or referencing
1069
- classes in a standardized way. It also enables automatic state capture
1070
- for instances of the decorated class when used with tracing.
1071
-
1072
- Args:
1073
- identifier: The identifier to associate with the decorated class.
1074
- This will be used as the instance name in traces.
1075
- track_state: Whether to automatically capture the state (attributes)
1076
- of instances before and after function execution. Defaults to False.
1077
- track_attributes: Optional list of specific attribute names to track.
1078
- If None, all non-private attributes (not starting with '_')
1079
- will be tracked when track_state=True.
1080
- field_mappings: Optional dictionary mapping internal attribute names to
1081
- display names in the captured state. For example:
1082
- {"system_prompt": "instructions"} will capture the
1083
- 'instructions' attribute as 'system_prompt' in the state.
1084
-
1085
- Example:
1086
- @tracer.identify(identifier="user_model", track_state=True, track_attributes=["name", "age"], field_mappings={"system_prompt": "instructions"})
1087
- class User:
1088
- # Class implementation
1089
- """
1090
-
1091
- def decorator(cls):
1092
- class_name = cls.__name__
1093
- self.class_identifiers[class_name] = {
1094
- "identifier": identifier,
1095
- "track_state": track_state,
1096
- "track_attributes": track_attributes,
1097
- "field_mappings": field_mappings or {},
1098
- "class_name": class_name,
1099
- }
1100
- return cls
1101
-
1102
- return decorator
1103
-
1104
- def identify(self, *args, **kwargs):
1105
- judgeval_logger.warning(
1106
- "identify() is deprecated and may not be supported in future versions of judgeval. Use the agent() decorator instead."
1107
- )
1108
- return self.agent(*args, **kwargs)
1109
-
1110
- def _capture_instance_state(
1111
- self, instance: Any, class_config: Dict[str, Any]
1112
- ) -> Dict[str, Any]:
1113
- """
1114
- Capture the state of an instance based on class configuration.
1115
- Args:
1116
- instance: The instance to capture the state of.
1117
- class_config: Configuration dictionary for state capture,
1118
- expected to contain 'track_attributes' and 'field_mappings'.
1119
- """
1120
- track_attributes = class_config.get("track_attributes")
1121
- field_mappings = class_config.get("field_mappings")
1122
-
1123
- if track_attributes:
1124
- state = {attr: getattr(instance, attr, None) for attr in track_attributes}
1125
- else:
1126
- state = {
1127
- k: v for k, v in instance.__dict__.items() if not k.startswith("_")
1128
- }
1129
-
1130
- if field_mappings:
1131
- state["field_mappings"] = field_mappings
1132
-
1133
- return state
1134
-
1135
- def _get_instance_state_if_tracked(self, args):
1136
- """
1137
- Extract instance state if the instance should be tracked.
1138
-
1139
- Returns the captured state dict if tracking is enabled, None otherwise.
1140
- """
1141
- if args and hasattr(args[0], "__class__"):
1142
- instance = args[0]
1143
- class_name = instance.__class__.__name__
1144
- if (
1145
- class_name in self.class_identifiers
1146
- and isinstance(self.class_identifiers[class_name], dict)
1147
- and self.class_identifiers[class_name].get("track_state", False)
1148
- ):
1149
- return self._capture_instance_state(
1150
- instance, self.class_identifiers[class_name]
1151
- )
1152
-
1153
- def _conditionally_capture_and_record_state(
1154
- self, trace_client_instance: TraceClient, args: tuple, is_before: bool
1155
- ):
1156
- """Captures instance state if tracked and records it via the trace_client."""
1157
- state = self._get_instance_state_if_tracked(args)
1158
- if state:
1159
- if is_before:
1160
- trace_client_instance.record_state_before(state)
1161
- else:
1162
- trace_client_instance.record_state_after(state)
1163
-
1164
- @overload
1165
- def observe(
1166
- self, func: T, *, name: Optional[str] = None, span_type: SpanType = "span"
1167
- ) -> T: ...
1168
-
1169
- @overload
1170
- def observe(
1171
- self,
1172
- *,
1173
- name: Optional[str] = None,
1174
- span_type: SpanType = "span",
1175
- ) -> Callable[[T], T]: ...
1176
-
1177
- def observe(
1178
- self,
1179
- func: Optional[T] = None,
1180
- *,
1181
- name: Optional[str] = None,
1182
- span_type: SpanType = "span",
1183
- ):
1184
- """
1185
- Decorator to trace function execution with detailed entry/exit information.
1186
-
1187
- Args:
1188
- func: The function to decorate
1189
- name: Optional custom name for the span (defaults to function name)
1190
- span_type: Type of span (default "span").
1191
- """
1192
- # If monitoring is disabled, return the function as is
1193
- try:
1194
- if not self.enable_monitoring:
1195
- return func if func else lambda f: f
1196
-
1197
- if func is None:
1198
- return lambda func: self.observe(
1199
- func,
1200
- name=name,
1201
- span_type=span_type,
1202
- )
1203
-
1204
- # Use provided name or fall back to function name
1205
- original_span_name = name or func.__name__
1206
-
1207
- # Store custom attributes on the function object
1208
- func._judgment_span_name = original_span_name # type: ignore
1209
- func._judgment_span_type = span_type # type: ignore
1210
-
1211
- except Exception:
1212
- return func
1213
-
1214
- def _record_span_data(span, args, kwargs):
1215
- """Helper function to record inputs, agent info, and state on a span."""
1216
- # Get class and agent info
1217
- class_name = None
1218
- agent_name = None
1219
- if args and hasattr(args[0], "__class__"):
1220
- class_name = args[0].__class__.__name__
1221
- agent_name = get_instance_prefixed_name(
1222
- args[0], class_name, self.class_identifiers
1223
- )
1224
-
1225
- # Record inputs, agent name, class name
1226
- inputs = combine_args_kwargs(func, args, kwargs)
1227
- span.record_input(inputs)
1228
- if agent_name:
1229
- span.record_agent_name(agent_name)
1230
- if class_name and class_name in self.class_identifiers:
1231
- span.record_class_name(class_name)
1232
-
1233
- # Capture state before execution
1234
- self._conditionally_capture_and_record_state(span, args, is_before=True)
1235
-
1236
- return class_name, agent_name
1237
-
1238
- def _finalize_span_data(span, result, args):
1239
- """Helper function to record outputs and final state on a span."""
1240
- # Record output
1241
- span.record_output(result)
1242
-
1243
- # Capture state after execution
1244
- self._conditionally_capture_and_record_state(span, args, is_before=False)
1245
-
1246
- def _cleanup_trace(current_trace, trace_token, wrapper_type="function"):
1247
- """Helper function to handle trace cleanup in finally blocks."""
1248
- try:
1249
- trace_id, server_response = current_trace.save(final_save=True)
1250
-
1251
- complete_trace_data = {
1252
- "trace_id": current_trace.trace_id,
1253
- "name": current_trace.name,
1254
- "project_name": current_trace.project_name,
1255
- "created_at": datetime.fromtimestamp(
1256
- current_trace.start_time or time.time(),
1257
- timezone.utc,
1258
- ).isoformat(),
1259
- "duration": current_trace.get_duration(),
1260
- "trace_spans": [
1261
- span.model_dump() for span in current_trace.trace_spans
1262
- ],
1263
- "evaluation_runs": [
1264
- run.model_dump() for run in current_trace.evaluation_runs
1265
- ],
1266
- "offline_mode": self.offline_mode,
1267
- "parent_trace_id": current_trace.parent_trace_id,
1268
- "parent_name": current_trace.parent_name,
1269
- "customer_id": current_trace.customer_id,
1270
- "tags": current_trace.tags,
1271
- "metadata": current_trace.metadata,
1272
- "update_id": current_trace.update_id,
1273
- }
1274
- self.traces.append(complete_trace_data)
1275
- self.reset_current_trace(trace_token)
1276
- except Exception as e:
1277
- judgeval_logger.warning(f"Issue with {wrapper_type} cleanup: {e}")
1278
-
1279
- def _execute_in_span(
1280
- current_trace, span_name, span_type, execution_func, args, kwargs
1281
- ):
1282
- """Helper function to execute code within a span context."""
1283
- with current_trace.span(span_name, span_type=span_type) as span:
1284
- _record_span_data(span, args, kwargs)
1285
-
1286
- try:
1287
- result = execution_func()
1288
- _finalize_span_data(span, result, args)
1289
- return result
1290
- except Exception as e:
1291
- _capture_exception_for_trace(current_trace, sys.exc_info())
1292
- raise e
1293
-
1294
- async def _execute_in_span_async(
1295
- current_trace, span_name, span_type, async_execution_func, args, kwargs
1296
- ):
1297
- """Helper function to execute async code within a span context."""
1298
- with current_trace.span(span_name, span_type=span_type) as span:
1299
- _record_span_data(span, args, kwargs)
1300
-
1301
- try:
1302
- result = await async_execution_func()
1303
- _finalize_span_data(span, result, args)
1304
- return result
1305
- except Exception as e:
1306
- _capture_exception_for_trace(current_trace, sys.exc_info())
1307
- raise e
1308
-
1309
- def _create_new_trace(self, span_name):
1310
- """Helper function to create a new trace and set it as current."""
1311
- trace_id = str(uuid.uuid4())
1312
- project = self.project_name
1313
-
1314
- current_trace = TraceClient(
1315
- self,
1316
- trace_id,
1317
- span_name,
1318
- project_name=project,
1319
- enable_monitoring=self.enable_monitoring,
1320
- enable_evaluations=self.enable_evaluations,
1321
- )
1322
-
1323
- trace_token = self.set_current_trace(current_trace)
1324
- return current_trace, trace_token
1325
-
1326
- def _execute_with_auto_trace_creation(
1327
- span_name, span_type, execution_func, args, kwargs
1328
- ):
1329
- """Helper function that handles automatic trace creation and span execution."""
1330
- current_trace = self.get_current_trace()
1331
-
1332
- if not current_trace:
1333
- current_trace, trace_token = _create_new_trace(self, span_name)
1334
-
1335
- try:
1336
- result = _execute_in_span(
1337
- current_trace,
1338
- span_name,
1339
- span_type,
1340
- execution_func,
1341
- args,
1342
- kwargs,
1343
- )
1344
- return result
1345
- finally:
1346
- # Cleanup the trace we created
1347
- _cleanup_trace(current_trace, trace_token, "auto_trace")
1348
- else:
1349
- # Use existing trace
1350
- return _execute_in_span(
1351
- current_trace, span_name, span_type, execution_func, args, kwargs
1352
- )
1353
-
1354
- async def _execute_with_auto_trace_creation_async(
1355
- span_name, span_type, async_execution_func, args, kwargs
1356
- ):
1357
- """Helper function that handles automatic trace creation and async span execution."""
1358
- current_trace = self.get_current_trace()
1359
-
1360
- if not current_trace:
1361
- current_trace, trace_token = _create_new_trace(self, span_name)
1362
-
1363
- try:
1364
- result = await _execute_in_span_async(
1365
- current_trace,
1366
- span_name,
1367
- span_type,
1368
- async_execution_func,
1369
- args,
1370
- kwargs,
1371
- )
1372
- return result
1373
- finally:
1374
- # Cleanup the trace we created
1375
- _cleanup_trace(current_trace, trace_token, "async_auto_trace")
1376
- else:
1377
- # Use existing trace
1378
- return await _execute_in_span_async(
1379
- current_trace,
1380
- span_name,
1381
- span_type,
1382
- async_execution_func,
1383
- args,
1384
- kwargs,
1385
- )
1386
-
1387
- # Check for generator functions first
1388
- if inspect.isgeneratorfunction(func):
1389
-
1390
- @functools.wraps(func)
1391
- def generator_wrapper(*args, **kwargs):
1392
- # Get the generator from the original function
1393
- generator = func(*args, **kwargs)
1394
-
1395
- # Create wrapper generator that creates spans for each yield
1396
- def traced_generator():
1397
- while True:
1398
- try:
1399
- # Handle automatic trace creation and span execution
1400
- item = _execute_with_auto_trace_creation(
1401
- original_span_name,
1402
- span_type,
1403
- lambda: next(generator),
1404
- args,
1405
- kwargs,
1406
- )
1407
- yield item
1408
- except StopIteration:
1409
- break
1410
-
1411
- return traced_generator()
1412
-
1413
- return generator_wrapper
1414
-
1415
- # Check for async generator functions
1416
- elif inspect.isasyncgenfunction(func):
1417
-
1418
- @functools.wraps(func)
1419
- def async_generator_wrapper(*args, **kwargs):
1420
- # Get the async generator from the original function
1421
- async_generator = func(*args, **kwargs)
1422
-
1423
- # Create wrapper async generator that creates spans for each yield
1424
- async def traced_async_generator():
1425
- while True:
1426
- try:
1427
- # Handle automatic trace creation and span execution
1428
- item = await _execute_with_auto_trace_creation_async(
1429
- original_span_name,
1430
- span_type,
1431
- lambda: async_generator.__anext__(),
1432
- args,
1433
- kwargs,
1434
- )
1435
- if inspect.iscoroutine(item):
1436
- item = await item
1437
- yield item
1438
- except StopAsyncIteration:
1439
- break
1440
-
1441
- return traced_async_generator()
1442
-
1443
- return async_generator_wrapper
1444
-
1445
- elif asyncio.iscoroutinefunction(func):
1446
-
1447
- @functools.wraps(func)
1448
- async def async_wrapper(*args, **kwargs):
1449
- nonlocal original_span_name
1450
- span_name = original_span_name
1451
-
1452
- async def async_execution():
1453
- if self.deep_tracing:
1454
- with _DeepTracer(self):
1455
- return await func(*args, **kwargs)
1456
- else:
1457
- return await func(*args, **kwargs)
1458
-
1459
- result = await _execute_with_auto_trace_creation_async(
1460
- span_name, span_type, async_execution, args, kwargs
1461
- )
1462
-
1463
- return result
1464
-
1465
- return async_wrapper
1466
- else:
1467
- # Non-async function implementation with deep tracing
1468
- @functools.wraps(func)
1469
- def wrapper(*args, **kwargs):
1470
- nonlocal original_span_name
1471
- span_name = original_span_name
1472
-
1473
- def sync_execution():
1474
- if self.deep_tracing:
1475
- with _DeepTracer(self):
1476
- return func(*args, **kwargs)
1477
- else:
1478
- return func(*args, **kwargs)
1479
-
1480
- return _execute_with_auto_trace_creation(
1481
- span_name, span_type, sync_execution, args, kwargs
1482
- )
1483
-
1484
- return wrapper
1485
-
1486
- def observe_tools(
1487
- self,
1488
- cls=None,
1489
- *,
1490
- exclude_methods: Optional[List[str]] = None,
1491
- include_private: bool = False,
1492
- warn_on_double_decoration: bool = True,
1493
- ):
1494
- """
1495
- Automatically adds @observe(span_type="tool") to all methods in a class.
1496
-
1497
- Args:
1498
- cls: The class to decorate (automatically provided when used as decorator)
1499
- exclude_methods: List of method names to skip decorating. Defaults to common magic methods
1500
- include_private: Whether to decorate methods starting with underscore. Defaults to False
1501
- warn_on_double_decoration: Whether to print warnings when skipping already-decorated methods. Defaults to True
1502
- """
1503
-
1504
- if exclude_methods is None:
1505
- exclude_methods = ["__init__", "__new__", "__del__", "__str__", "__repr__"]
1506
-
1507
- def decorate_class(cls):
1508
- if not self.enable_monitoring:
1509
- return cls
1510
-
1511
- decorated = []
1512
- skipped = []
1513
-
1514
- for name in dir(cls):
1515
- method = getattr(cls, name)
1516
-
1517
- if (
1518
- not callable(method)
1519
- or name in exclude_methods
1520
- or (name.startswith("_") and not include_private)
1521
- or not hasattr(cls, name)
1522
- ):
1523
- continue
1524
-
1525
- if hasattr(method, "_judgment_span_name"):
1526
- skipped.append(name)
1527
- if warn_on_double_decoration:
1528
- judgeval_logger.info(
1529
- f"{cls.__name__}.{name} already decorated, skipping"
1530
- )
1531
- continue
1532
-
1533
- try:
1534
- decorated_method = self.observe(method, span_type="tool")
1535
- setattr(cls, name, decorated_method)
1536
- decorated.append(name)
1537
- except Exception as e:
1538
- if warn_on_double_decoration:
1539
- judgeval_logger.warning(
1540
- f"Failed to decorate {cls.__name__}.{name}: {e}"
1541
- )
1542
-
1543
- return cls
1544
-
1545
- return decorate_class if cls is None else decorate_class(cls)
1546
-
1547
- def async_evaluate(
1548
- self,
1549
- scorer: Union[APIScorerConfig, BaseScorer],
1550
- example: Example,
1551
- model: str = DEFAULT_GPT_MODEL,
1552
- sampling_rate: float = 1,
1553
- ):
1554
- try:
1555
- if not self.enable_monitoring or not self.enable_evaluations:
1556
- return
1557
-
1558
- if not isinstance(scorer, (APIScorerConfig, BaseScorer)):
1559
- judgeval_logger.warning(
1560
- f"Scorer must be an instance of APIScorerConfig or BaseScorer, got {type(scorer)}, skipping evaluation"
1561
- )
1562
- return
1563
-
1564
- if not isinstance(example, Example):
1565
- judgeval_logger.warning(
1566
- f"Example must be an instance of Example, got {type(example)} skipping evaluation"
1567
- )
1568
- return
1569
-
1570
- if sampling_rate < 0:
1571
- judgeval_logger.warning(
1572
- "Cannot set sampling_rate below 0, skipping evaluation"
1573
- )
1574
- return
1575
-
1576
- if sampling_rate > 1:
1577
- judgeval_logger.warning(
1578
- "Cannot set sampling_rate above 1, skipping evaluation"
1579
- )
1580
- return
1581
-
1582
- percentage = random.uniform(0, 1)
1583
- if percentage > sampling_rate:
1584
- judgeval_logger.info("Skipping async_evaluate due to sampling rate")
1585
- return
1586
-
1587
- current_trace = self.get_current_trace()
1588
- if current_trace:
1589
- current_trace.async_evaluate(
1590
- scorer=scorer, example=example, model=model
1591
- )
1592
- else:
1593
- judgeval_logger.warning(
1594
- "No trace found (context var or fallback), skipping evaluation"
1595
- )
1596
- except Exception as e:
1597
- judgeval_logger.warning(f"Issue with async_evaluate: {e}")
1598
-
1599
- def update_metadata(self, metadata: dict):
1600
- """
1601
- Update metadata for the current trace.
1602
-
1603
- Args:
1604
- metadata: Metadata as a dictionary
1605
- """
1606
- current_trace = self.get_current_trace()
1607
- if current_trace:
1608
- current_trace.update_metadata(metadata)
1609
- else:
1610
- judgeval_logger.warning("No current trace found, cannot set metadata")
1611
-
1612
- def set_customer_id(self, customer_id: str):
1613
- """
1614
- Set the customer ID for the current trace.
1615
-
1616
- Args:
1617
- customer_id: The customer ID to set
1618
- """
1619
- current_trace = self.get_current_trace()
1620
- if current_trace:
1621
- current_trace.set_customer_id(customer_id)
1622
- else:
1623
- judgeval_logger.warning("No current trace found, cannot set customer ID")
1624
-
1625
- def set_tags(self, tags: List[Union[str, set, tuple]]):
1626
- """
1627
- Set the tags for the current trace.
1628
-
1629
- Args:
1630
- tags: List of tags to set
1631
- """
1632
- current_trace = self.get_current_trace()
1633
- if current_trace:
1634
- current_trace.set_tags(tags)
1635
- else:
1636
- judgeval_logger.warning("No current trace found, cannot set tags")
1637
-
1638
- def set_reward_score(self, reward_score: Union[float, Dict[str, float]]):
1639
- """
1640
- Set the reward score for this trace to be used for RL or SFT.
1641
-
1642
- Args:
1643
- reward_score: The reward score to set
1644
- """
1645
- current_trace = self.get_current_trace()
1646
- if current_trace:
1647
- current_trace.set_reward_score(reward_score)
1648
- else:
1649
- judgeval_logger.warning("No current trace found, cannot set reward score")
1650
-
1651
- def get_otel_span_processor(self) -> SpanProcessorBase:
1652
- """Get the OpenTelemetry span processor instance."""
1653
- return self.otel_span_processor
1654
-
1655
- def flush_background_spans(self, timeout_millis: int = 30000):
1656
- """Flush all pending spans in the background service."""
1657
- self.otel_span_processor.force_flush(timeout_millis)
1658
-
1659
- def shutdown_background_service(self):
1660
- """Shutdown the background span service."""
1661
- self.otel_span_processor.shutdown()
1662
- self.otel_span_processor = SpanProcessorBase()
1663
-
1664
- def wait_for_completion(self, timeout: Optional[float] = 30.0) -> bool:
1665
- """Wait for all evaluations and span processing to complete.
1666
-
1667
- This method blocks until all queued evaluations are processed and
1668
- all pending spans are flushed to the server.
1669
-
1670
- Args:
1671
- timeout: Maximum time to wait in seconds. Defaults to 30 seconds.
1672
- None means wait indefinitely.
1673
-
1674
- Returns:
1675
- True if all processing completed within the timeout, False otherwise.
1676
-
1677
- """
1678
- try:
1679
- judgeval_logger.debug(
1680
- "Waiting for all evaluations and spans to complete..."
1681
- )
1682
-
1683
- # Wait for all queued evaluation work to complete
1684
- eval_completed = self.local_eval_queue.wait_for_completion()
1685
- if not eval_completed:
1686
- judgeval_logger.warning(
1687
- f"Local evaluation queue did not complete within {timeout} seconds"
1688
- )
1689
- return False
1690
-
1691
- self.flush_background_spans()
1692
-
1693
- judgeval_logger.debug("All evaluations and spans completed successfully")
1694
- return True
1695
-
1696
- except Exception as e:
1697
- judgeval_logger.warning(f"Error while waiting for completion: {e}")
1698
- return False
1699
-
1700
- def _log_eval_results_callback(self, evaluation_run, scoring_results):
1701
- """Callback to log evaluation results after local processing."""
1702
- try:
1703
- if scoring_results and self.enable_evaluations and self.enable_monitoring:
1704
- # Convert scoring results to the format expected by API client
1705
- results_dict = [
1706
- result.model_dump(warnings=False) for result in scoring_results
1707
- ]
1708
- api_client = JudgmentApiClient(self.api_key, self.organization_id)
1709
- api_client.log_evaluation_results(
1710
- results_dict, evaluation_run.model_dump(warnings=False)
1711
- )
1712
- except Exception as e:
1713
- judgeval_logger.warning(f"Failed to log local evaluation results: {e}")
1714
-
1715
- def _cleanup_on_exit(self):
1716
- """Cleanup handler called on application exit to ensure spans are flushed."""
1717
- try:
1718
- # Wait for all queued evaluation work to complete before stopping
1719
- completed = self.local_eval_queue.wait_for_completion()
1720
- if not completed:
1721
- judgeval_logger.warning(
1722
- "Local evaluation queue did not complete within 30 seconds"
1723
- )
1724
-
1725
- self.local_eval_queue.stop_workers()
1726
- self.flush_background_spans()
1727
- except Exception as e:
1728
- judgeval_logger.warning(f"Error during tracer cleanup: {e}")
1729
- finally:
1730
- try:
1731
- self.shutdown_background_service()
1732
- except Exception as e:
1733
- judgeval_logger.warning(
1734
- f"Error during background service shutdown: {e}"
1735
- )
1736
-
1737
- def trace_to_message_history(
1738
- self, trace: Union[Trace, TraceClient]
1739
- ) -> List[Dict[str, str]]:
1740
- """
1741
- Extract message history from a trace for training purposes.
1742
-
1743
- This method processes trace spans to reconstruct the conversation flow,
1744
- extracting messages in chronological order from LLM, user, and tool spans.
1745
-
1746
- Args:
1747
- trace: Trace or TraceClient instance to extract messages from
1748
-
1749
- Returns:
1750
- List of message dictionaries with 'role' and 'content' keys
1751
-
1752
- Raises:
1753
- ValueError: If no trace is provided
1754
- """
1755
- if not trace:
1756
- raise ValueError("No trace provided")
1757
-
1758
- # Handle both Trace and TraceClient objects
1759
- if isinstance(trace, TraceClient):
1760
- spans = trace.trace_spans
1761
- else:
1762
- spans = trace.trace_spans if hasattr(trace, "trace_spans") else []
1763
-
1764
- messages = []
1765
- first_found = False
1766
-
1767
- # Process spans in chronological order
1768
- for span in sorted(
1769
- spans, key=lambda s: s.created_at if hasattr(s, "created_at") else 0
1770
- ):
1771
- # Skip spans without output (except for first LLM span which may have input messages)
1772
- if span.output is None and span.span_type != "llm":
1773
- continue
1774
-
1775
- if span.span_type == "llm":
1776
- # For the first LLM span, extract input messages (system + user prompts)
1777
- if not first_found and hasattr(span, "inputs") and span.inputs:
1778
- input_messages = span.inputs.get("messages", [])
1779
- if input_messages:
1780
- first_found = True
1781
- # Add input messages (typically system and user messages)
1782
- for msg in input_messages:
1783
- if (
1784
- isinstance(msg, dict)
1785
- and "role" in msg
1786
- and "content" in msg
1787
- ):
1788
- messages.append(
1789
- {"role": msg["role"], "content": msg["content"]}
1790
- )
1791
-
1792
- # Add assistant response from span output
1793
- if span.output is not None:
1794
- messages.append({"role": "assistant", "content": str(span.output)})
1795
-
1796
- elif span.span_type == "user":
1797
- # Add user messages
1798
- if span.output is not None:
1799
- messages.append({"role": "user", "content": str(span.output)})
1800
-
1801
- elif span.span_type == "tool":
1802
- # Add tool responses as user messages (common pattern in training)
1803
- if span.output is not None:
1804
- messages.append({"role": "user", "content": str(span.output)})
1805
-
1806
- return messages
1807
-
1808
- def get_current_message_history(self) -> List[Dict[str, str]]:
1809
- """
1810
- Get message history from the current trace.
1811
-
1812
- Returns:
1813
- List of message dictionaries from the current trace context
1814
-
1815
- Raises:
1816
- ValueError: If no current trace is found
1817
- """
1818
- current_trace = self.get_current_trace()
1819
- if not current_trace:
1820
- raise ValueError("No current trace found")
1821
-
1822
- return self.trace_to_message_history(current_trace)
1823
-
1824
-
1825
- def _get_current_trace(
1826
- trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
1827
- ):
1828
- if trace_across_async_contexts:
1829
- return Tracer.current_trace
1830
- else:
1831
- return current_trace_var.get()
1832
-
1833
-
1834
- def wrap(
1835
- client: Any, trace_across_async_contexts: bool = Tracer.trace_across_async_contexts
1836
- ) -> Any:
1837
- """
1838
- Wraps an API client to add tracing capabilities.
1839
- Supports OpenAI, Together, Anthropic, Google GenAI clients, and TrainableModel.
1840
- Patches both '.create' and Anthropic's '.stream' methods using a wrapper class.
1841
- """
1842
- (
1843
- span_name,
1844
- original_create,
1845
- original_responses_create,
1846
- original_stream,
1847
- original_beta_parse,
1848
- ) = _get_client_config(client)
1849
-
1850
- def process_span(span, response):
1851
- """Format and record the output in the span"""
1852
- output, usage = _format_output_data(client, response)
1853
- span.record_output(output)
1854
- span.record_usage(usage)
1855
-
1856
- return response
1857
-
1858
- def wrapped(function):
1859
- def wrapper(*args, **kwargs):
1860
- current_trace = _get_current_trace(trace_across_async_contexts)
1861
- if not current_trace:
1862
- return function(*args, **kwargs)
1863
-
1864
- with current_trace.span(span_name, span_type="llm") as span:
1865
- span.record_input(kwargs)
1866
-
1867
- try:
1868
- response = function(*args, **kwargs)
1869
- return process_span(span, response)
1870
- except Exception as e:
1871
- _capture_exception_for_trace(span, sys.exc_info())
1872
- raise e
1873
-
1874
- return wrapper
1875
-
1876
- def wrapped_async(function):
1877
- async def wrapper(*args, **kwargs):
1878
- current_trace = _get_current_trace(trace_across_async_contexts)
1879
- if not current_trace:
1880
- return await function(*args, **kwargs)
1881
-
1882
- with current_trace.span(span_name, span_type="llm") as span:
1883
- span.record_input(kwargs)
1884
-
1885
- try:
1886
- response = await function(*args, **kwargs)
1887
- return process_span(span, response)
1888
- except Exception as e:
1889
- _capture_exception_for_trace(span, sys.exc_info())
1890
- raise e
1891
-
1892
- return wrapper
1893
-
1894
- if HAS_OPENAI:
1895
- from judgeval.common.tracer.providers import openai_OpenAI, openai_AsyncOpenAI
1896
-
1897
- assert openai_OpenAI is not None, "OpenAI client not found"
1898
- assert openai_AsyncOpenAI is not None, "OpenAI async client not found"
1899
- if isinstance(client, (openai_OpenAI)):
1900
- setattr(client.chat.completions, "create", wrapped(original_create))
1901
- setattr(client.responses, "create", wrapped(original_responses_create))
1902
- setattr(client.beta.chat.completions, "parse", wrapped(original_beta_parse))
1903
- elif isinstance(client, (openai_AsyncOpenAI)):
1904
- setattr(client.chat.completions, "create", wrapped_async(original_create))
1905
- setattr(
1906
- client.responses, "create", wrapped_async(original_responses_create)
1907
- )
1908
- setattr(
1909
- client.beta.chat.completions,
1910
- "parse",
1911
- wrapped_async(original_beta_parse),
1912
- )
1913
-
1914
- if HAS_TOGETHER:
1915
- from judgeval.common.tracer.providers import (
1916
- together_Together,
1917
- together_AsyncTogether,
1918
- )
1919
-
1920
- assert together_Together is not None, "Together client not found"
1921
- assert together_AsyncTogether is not None, "Together async client not found"
1922
- if isinstance(client, (together_Together)):
1923
- setattr(client.chat.completions, "create", wrapped(original_create))
1924
- elif isinstance(client, (together_AsyncTogether)):
1925
- setattr(client.chat.completions, "create", wrapped_async(original_create))
1926
-
1927
- if HAS_ANTHROPIC:
1928
- from judgeval.common.tracer.providers import (
1929
- anthropic_Anthropic,
1930
- anthropic_AsyncAnthropic,
1931
- )
1932
-
1933
- assert anthropic_Anthropic is not None, "Anthropic client not found"
1934
- assert anthropic_AsyncAnthropic is not None, "Anthropic async client not found"
1935
- if isinstance(client, (anthropic_Anthropic)):
1936
- setattr(client.messages, "create", wrapped(original_create))
1937
- elif isinstance(client, (anthropic_AsyncAnthropic)):
1938
- setattr(client.messages, "create", wrapped_async(original_create))
1939
-
1940
- if HAS_GOOGLE_GENAI:
1941
- from judgeval.common.tracer.providers import (
1942
- google_genai_Client,
1943
- google_genai_AsyncClient,
1944
- )
1945
-
1946
- assert google_genai_Client is not None, "Google GenAI client not found"
1947
- assert google_genai_AsyncClient is not None, (
1948
- "Google GenAI async client not found"
1949
- )
1950
- if isinstance(client, (google_genai_Client)):
1951
- setattr(client.models, "generate_content", wrapped(original_create))
1952
- elif isinstance(client, (google_genai_AsyncClient)):
1953
- setattr(client.models, "generate_content", wrapped_async(original_create))
1954
-
1955
- if HAS_GROQ:
1956
- from judgeval.common.tracer.providers import groq_Groq, groq_AsyncGroq
1957
-
1958
- assert groq_Groq is not None, "Groq client not found"
1959
- assert groq_AsyncGroq is not None, "Groq async client not found"
1960
- if isinstance(client, (groq_Groq)):
1961
- setattr(client.chat.completions, "create", wrapped(original_create))
1962
- elif isinstance(client, (groq_AsyncGroq)):
1963
- setattr(client.chat.completions, "create", wrapped_async(original_create))
1964
-
1965
- # Check for TrainableModel from judgeval.common.trainer
1966
- try:
1967
- from judgeval.common.trainer import TrainableModel
1968
-
1969
- if isinstance(client, TrainableModel):
1970
- # Define a wrapper function that can be reapplied to new model instances
1971
- def wrap_model_instance(model_instance):
1972
- """Wrap a model instance with tracing functionality"""
1973
- if hasattr(model_instance, "chat") and hasattr(
1974
- model_instance.chat, "completions"
1975
- ):
1976
- if hasattr(model_instance.chat.completions, "create"):
1977
- setattr(
1978
- model_instance.chat.completions,
1979
- "create",
1980
- wrapped(model_instance.chat.completions.create),
1981
- )
1982
- if hasattr(model_instance.chat.completions, "acreate"):
1983
- setattr(
1984
- model_instance.chat.completions,
1985
- "acreate",
1986
- wrapped_async(model_instance.chat.completions.acreate),
1987
- )
1988
-
1989
- # Register the wrapper function with the TrainableModel
1990
- client._register_tracer_wrapper(wrap_model_instance)
1991
-
1992
- # Apply wrapping to the current model
1993
- wrap_model_instance(client._current_model)
1994
- except ImportError:
1995
- pass # TrainableModel not available
1996
-
1997
- return client
1998
-
1999
-
2000
- # Helper functions for client-specific operations
2001
-
2002
-
2003
- def _get_client_config(
2004
- client: ApiClient,
2005
- ) -> tuple[str, Callable, Optional[Callable], Optional[Callable], Optional[Callable]]:
2006
- """Returns configuration tuple for the given API client.
2007
-
2008
- Args:
2009
- client: An instance of OpenAI, Together, or Anthropic client
2010
-
2011
- Returns:
2012
- tuple: (span_name, create_method, responses_method, stream_method, beta_parse_method)
2013
- - span_name: String identifier for tracing
2014
- - create_method: Reference to the client's creation method
2015
- - responses_method: Reference to the client's responses method (if applicable)
2016
- - stream_method: Reference to the client's stream method (if applicable)
2017
- - beta_parse_method: Reference to the client's beta parse method (if applicable)
2018
-
2019
- Raises:
2020
- ValueError: If client type is not supported
2021
- """
2022
-
2023
- if HAS_OPENAI:
2024
- from judgeval.common.tracer.providers import openai_OpenAI, openai_AsyncOpenAI
2025
-
2026
- assert openai_OpenAI is not None, "OpenAI client not found"
2027
- assert openai_AsyncOpenAI is not None, "OpenAI async client not found"
2028
- if isinstance(client, (openai_OpenAI)):
2029
- return (
2030
- "OPENAI_API_CALL",
2031
- client.chat.completions.create,
2032
- client.responses.create,
2033
- None,
2034
- client.beta.chat.completions.parse,
2035
- )
2036
- elif isinstance(client, (openai_AsyncOpenAI)):
2037
- return (
2038
- "OPENAI_API_CALL",
2039
- client.chat.completions.create,
2040
- client.responses.create,
2041
- None,
2042
- client.beta.chat.completions.parse,
2043
- )
2044
- if HAS_TOGETHER:
2045
- from judgeval.common.tracer.providers import (
2046
- together_Together,
2047
- together_AsyncTogether,
2048
- )
2049
-
2050
- assert together_Together is not None, "Together client not found"
2051
- assert together_AsyncTogether is not None, "Together async client not found"
2052
- if isinstance(client, (together_Together)):
2053
- return "TOGETHER_API_CALL", client.chat.completions.create, None, None, None
2054
- elif isinstance(client, (together_AsyncTogether)):
2055
- return "TOGETHER_API_CALL", client.chat.completions.create, None, None, None
2056
- if HAS_ANTHROPIC:
2057
- from judgeval.common.tracer.providers import (
2058
- anthropic_Anthropic,
2059
- anthropic_AsyncAnthropic,
2060
- )
2061
-
2062
- assert anthropic_Anthropic is not None, "Anthropic client not found"
2063
- assert anthropic_AsyncAnthropic is not None, "Anthropic async client not found"
2064
- if isinstance(client, (anthropic_Anthropic)):
2065
- return (
2066
- "ANTHROPIC_API_CALL",
2067
- client.messages.create,
2068
- None,
2069
- client.messages.stream,
2070
- None,
2071
- )
2072
- elif isinstance(client, (anthropic_AsyncAnthropic)):
2073
- return (
2074
- "ANTHROPIC_API_CALL",
2075
- client.messages.create,
2076
- None,
2077
- client.messages.stream,
2078
- None,
2079
- )
2080
- if HAS_GOOGLE_GENAI:
2081
- from judgeval.common.tracer.providers import (
2082
- google_genai_Client,
2083
- google_genai_AsyncClient,
2084
- )
2085
-
2086
- assert google_genai_Client is not None, "Google GenAI client not found"
2087
- assert google_genai_AsyncClient is not None, (
2088
- "Google GenAI async client not found"
2089
- )
2090
- if isinstance(client, (google_genai_Client)):
2091
- return "GOOGLE_API_CALL", client.models.generate_content, None, None, None
2092
- elif isinstance(client, (google_genai_AsyncClient)):
2093
- return "GOOGLE_API_CALL", client.models.generate_content, None, None, None
2094
- if HAS_GROQ:
2095
- from judgeval.common.tracer.providers import groq_Groq, groq_AsyncGroq
2096
-
2097
- assert groq_Groq is not None, "Groq client not found"
2098
- assert groq_AsyncGroq is not None, "Groq async client not found"
2099
- if isinstance(client, (groq_Groq)):
2100
- return "GROQ_API_CALL", client.chat.completions.create, None, None, None
2101
- elif isinstance(client, (groq_AsyncGroq)):
2102
- return "GROQ_API_CALL", client.chat.completions.create, None, None, None
2103
-
2104
- # Check for TrainableModel
2105
- try:
2106
- from judgeval.common.trainer import TrainableModel
2107
-
2108
- if isinstance(client, TrainableModel):
2109
- return (
2110
- "FIREWORKS_TRAINABLE_MODEL_CALL",
2111
- client._current_model.chat.completions.create,
2112
- None,
2113
- None,
2114
- None,
2115
- )
2116
- except ImportError:
2117
- pass # TrainableModel not available
2118
-
2119
- raise ValueError(f"Unsupported client type: {type(client)}")
2120
-
2121
-
2122
- def _format_output_data(
2123
- client: ApiClient, response: Any
2124
- ) -> tuple[Optional[str], Optional[TraceUsage]]:
2125
- """Format API response data based on client type.
2126
-
2127
- Normalizes different response formats into a consistent structure
2128
- for tracing purposes.
2129
-
2130
- Returns:
2131
- dict containing:
2132
- - content: The generated text
2133
- - usage: Token usage statistics
2134
- """
2135
- prompt_tokens = 0
2136
- completion_tokens = 0
2137
- cache_read_input_tokens = 0
2138
- cache_creation_input_tokens = 0
2139
- model_name = None
2140
- message_content = None
2141
-
2142
- if HAS_OPENAI:
2143
- from judgeval.common.tracer.providers import (
2144
- openai_OpenAI,
2145
- openai_AsyncOpenAI,
2146
- openai_ChatCompletion,
2147
- openai_Response,
2148
- openai_ParsedChatCompletion,
2149
- )
2150
-
2151
- assert openai_OpenAI is not None, "OpenAI client not found"
2152
- assert openai_AsyncOpenAI is not None, "OpenAI async client not found"
2153
- assert openai_ChatCompletion is not None, "OpenAI chat completion not found"
2154
- assert openai_Response is not None, "OpenAI response not found"
2155
- assert openai_ParsedChatCompletion is not None, (
2156
- "OpenAI parsed chat completion not found"
2157
- )
2158
-
2159
- if isinstance(client, (openai_OpenAI, openai_AsyncOpenAI)):
2160
- if isinstance(response, openai_ChatCompletion):
2161
- model_name = response.model
2162
- prompt_tokens = response.usage.prompt_tokens if response.usage else 0
2163
- completion_tokens = (
2164
- response.usage.completion_tokens if response.usage else 0
2165
- )
2166
- cache_read_input_tokens = (
2167
- response.usage.prompt_tokens_details.cached_tokens
2168
- if response.usage
2169
- and response.usage.prompt_tokens_details
2170
- and response.usage.prompt_tokens_details.cached_tokens
2171
- else 0
2172
- )
2173
-
2174
- if isinstance(response, openai_ParsedChatCompletion):
2175
- message_content = response.choices[0].message.parsed
2176
- else:
2177
- message_content = response.choices[0].message.content
2178
- elif isinstance(response, openai_Response):
2179
- model_name = response.model
2180
- prompt_tokens = response.usage.input_tokens if response.usage else 0
2181
- completion_tokens = (
2182
- response.usage.output_tokens if response.usage else 0
2183
- )
2184
- cache_read_input_tokens = (
2185
- response.usage.input_tokens_details.cached_tokens
2186
- if response.usage and response.usage.input_tokens_details
2187
- else 0
2188
- )
2189
- if hasattr(response.output[0], "content"):
2190
- message_content = "".join(
2191
- seg.text
2192
- for seg in response.output[0].content
2193
- if hasattr(seg, "text")
2194
- )
2195
- # Note: LiteLLM seems to use cache_read_input_tokens to calculate the cost for OpenAI
2196
- return message_content, _create_usage(
2197
- model_name,
2198
- prompt_tokens,
2199
- completion_tokens,
2200
- cache_read_input_tokens,
2201
- cache_creation_input_tokens,
2202
- )
2203
-
2204
- if HAS_TOGETHER:
2205
- from judgeval.common.tracer.providers import (
2206
- together_Together,
2207
- together_AsyncTogether,
2208
- )
2209
-
2210
- assert together_Together is not None, "Together client not found"
2211
- assert together_AsyncTogether is not None, "Together async client not found"
2212
-
2213
- if isinstance(client, (together_Together, together_AsyncTogether)):
2214
- model_name = "together_ai/" + response.model
2215
- prompt_tokens = response.usage.prompt_tokens
2216
- completion_tokens = response.usage.completion_tokens
2217
- message_content = response.choices[0].message.content
2218
-
2219
- # As of 2025-07-14, Together does not do any input cache token tracking
2220
- return message_content, _create_usage(
2221
- model_name,
2222
- prompt_tokens,
2223
- completion_tokens,
2224
- cache_read_input_tokens,
2225
- cache_creation_input_tokens,
2226
- )
2227
-
2228
- if HAS_GOOGLE_GENAI:
2229
- from judgeval.common.tracer.providers import (
2230
- google_genai_Client,
2231
- google_genai_AsyncClient,
2232
- )
2233
-
2234
- assert google_genai_Client is not None, "Google GenAI client not found"
2235
- assert google_genai_AsyncClient is not None, (
2236
- "Google GenAI async client not found"
2237
- )
2238
- if isinstance(client, (google_genai_Client, google_genai_AsyncClient)):
2239
- model_name = response.model_version
2240
- prompt_tokens = response.usage_metadata.prompt_token_count
2241
- completion_tokens = response.usage_metadata.candidates_token_count
2242
- message_content = response.candidates[0].content.parts[0].text
2243
-
2244
- if hasattr(response.usage_metadata, "cached_content_token_count"):
2245
- cache_read_input_tokens = (
2246
- response.usage_metadata.cached_content_token_count
2247
- )
2248
- return message_content, _create_usage(
2249
- model_name,
2250
- prompt_tokens,
2251
- completion_tokens,
2252
- cache_read_input_tokens,
2253
- cache_creation_input_tokens,
2254
- )
2255
-
2256
- if HAS_ANTHROPIC:
2257
- from judgeval.common.tracer.providers import (
2258
- anthropic_Anthropic,
2259
- anthropic_AsyncAnthropic,
2260
- )
2261
-
2262
- assert anthropic_Anthropic is not None, "Anthropic client not found"
2263
- assert anthropic_AsyncAnthropic is not None, "Anthropic async client not found"
2264
- if isinstance(client, (anthropic_Anthropic, anthropic_AsyncAnthropic)):
2265
- model_name = response.model
2266
- prompt_tokens = response.usage.input_tokens
2267
- completion_tokens = response.usage.output_tokens
2268
- cache_read_input_tokens = response.usage.cache_read_input_tokens
2269
- cache_creation_input_tokens = response.usage.cache_creation_input_tokens
2270
- message_content = response.content[0].text
2271
- return message_content, _create_usage(
2272
- model_name,
2273
- prompt_tokens,
2274
- completion_tokens,
2275
- cache_read_input_tokens,
2276
- cache_creation_input_tokens,
2277
- )
2278
-
2279
- if HAS_GROQ:
2280
- from judgeval.common.tracer.providers import groq_Groq, groq_AsyncGroq
2281
-
2282
- assert groq_Groq is not None, "Groq client not found"
2283
- assert groq_AsyncGroq is not None, "Groq async client not found"
2284
- if isinstance(client, (groq_Groq, groq_AsyncGroq)):
2285
- model_name = "groq/" + response.model
2286
- prompt_tokens = response.usage.prompt_tokens
2287
- completion_tokens = response.usage.completion_tokens
2288
- message_content = response.choices[0].message.content
2289
- return message_content, _create_usage(
2290
- model_name,
2291
- prompt_tokens,
2292
- completion_tokens,
2293
- cache_read_input_tokens,
2294
- cache_creation_input_tokens,
2295
- )
2296
-
2297
- # Check for TrainableModel
2298
- try:
2299
- from judgeval.common.trainer import TrainableModel
2300
-
2301
- if isinstance(client, TrainableModel):
2302
- # TrainableModel uses Fireworks LLM internally, so response format should be similar to OpenAI
2303
- if (
2304
- hasattr(response, "model")
2305
- and hasattr(response, "usage")
2306
- and hasattr(response, "choices")
2307
- ):
2308
- model_name = response.model
2309
- prompt_tokens = response.usage.prompt_tokens if response.usage else 0
2310
- completion_tokens = (
2311
- response.usage.completion_tokens if response.usage else 0
2312
- )
2313
- message_content = response.choices[0].message.content
2314
-
2315
- # Use LiteLLM cost calculation with fireworks_ai prefix
2316
- # LiteLLM supports Fireworks AI models for cost calculation when prefixed with "fireworks_ai/"
2317
- fireworks_model_name = f"fireworks_ai/{model_name}"
2318
- return message_content, _create_usage(
2319
- fireworks_model_name,
2320
- prompt_tokens,
2321
- completion_tokens,
2322
- cache_read_input_tokens,
2323
- cache_creation_input_tokens,
2324
- )
2325
- except ImportError:
2326
- pass # TrainableModel not available
2327
-
2328
- judgeval_logger.warning(f"Unsupported client type: {type(client)}")
2329
- return None, None
2330
-
2331
-
2332
- def _create_usage(
2333
- model_name: str,
2334
- prompt_tokens: int,
2335
- completion_tokens: int,
2336
- cache_read_input_tokens: int = 0,
2337
- cache_creation_input_tokens: int = 0,
2338
- ) -> TraceUsage:
2339
- """Helper function to create TraceUsage object with cost calculation."""
2340
- prompt_cost, completion_cost = cost_per_token(
2341
- model=model_name,
2342
- prompt_tokens=prompt_tokens,
2343
- completion_tokens=completion_tokens,
2344
- cache_read_input_tokens=cache_read_input_tokens,
2345
- cache_creation_input_tokens=cache_creation_input_tokens,
2346
- )
2347
- total_cost_usd = (
2348
- (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
2349
- )
2350
- return TraceUsage(
2351
- prompt_tokens=prompt_tokens,
2352
- completion_tokens=completion_tokens,
2353
- total_tokens=prompt_tokens + completion_tokens,
2354
- cache_read_input_tokens=cache_read_input_tokens,
2355
- cache_creation_input_tokens=cache_creation_input_tokens,
2356
- prompt_tokens_cost_usd=prompt_cost,
2357
- completion_tokens_cost_usd=completion_cost,
2358
- total_cost_usd=total_cost_usd,
2359
- model_name=model_name,
2360
- )
2361
-
2362
-
2363
- def combine_args_kwargs(func, args, kwargs):
2364
- """
2365
- Combine positional arguments and keyword arguments into a single dictionary.
2366
-
2367
- Args:
2368
- func: The function being called
2369
- args: Tuple of positional arguments
2370
- kwargs: Dictionary of keyword arguments
2371
-
2372
- Returns:
2373
- A dictionary combining both args and kwargs
2374
- """
2375
- try:
2376
- import inspect
2377
-
2378
- sig = inspect.signature(func)
2379
- param_names = list(sig.parameters.keys())
2380
-
2381
- args_dict = {}
2382
- for i, arg in enumerate(args):
2383
- if i < len(param_names):
2384
- args_dict[param_names[i]] = arg
2385
- else:
2386
- args_dict[f"arg{i}"] = arg
2387
-
2388
- return {**args_dict, **kwargs}
2389
- except Exception:
2390
- # Fallback if signature inspection fails
2391
- return {**{f"arg{i}": arg for i, arg in enumerate(args)}, **kwargs}
2392
-
2393
-
2394
- def cost_per_token(*args, **kwargs):
2395
- try:
2396
- prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar = (
2397
- _original_cost_per_token(*args, **kwargs)
2398
- )
2399
- if (
2400
- prompt_tokens_cost_usd_dollar == 0
2401
- and completion_tokens_cost_usd_dollar == 0
2402
- ):
2403
- judgeval_logger.warning("LiteLLM returned a total of 0 for cost per token")
2404
- return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
2405
- except Exception as e:
2406
- judgeval_logger.warning(f"Error calculating cost per token: {e}")
2407
- return None, None
2408
-
2409
-
2410
- # --- Helper function for instance-prefixed qual_name ---
2411
- def get_instance_prefixed_name(instance, class_name, class_identifiers):
2412
- """
2413
- Returns the agent name (prefix) if the class and attribute are found in class_identifiers.
2414
- Otherwise, returns None.
2415
- """
2416
- if class_name in class_identifiers:
2417
- class_config = class_identifiers[class_name]
2418
- attr = class_config.get("identifier")
2419
- if attr:
2420
- if hasattr(instance, attr) and not callable(getattr(instance, attr)):
2421
- instance_name = getattr(instance, attr)
2422
- return instance_name
2423
- else:
2424
- raise Exception(
2425
- f"Attribute {attr} does not exist for {class_name}. Check your agent() decorator."
2426
- )
2427
- return None