judgeval 0.0.54__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. judgeval/common/api/__init__.py +3 -0
  2. judgeval/common/api/api.py +352 -0
  3. judgeval/common/api/constants.py +165 -0
  4. judgeval/common/storage/__init__.py +6 -0
  5. judgeval/common/tracer/__init__.py +31 -0
  6. judgeval/common/tracer/constants.py +22 -0
  7. judgeval/common/tracer/core.py +1916 -0
  8. judgeval/common/tracer/otel_exporter.py +108 -0
  9. judgeval/common/tracer/otel_span_processor.py +234 -0
  10. judgeval/common/tracer/span_processor.py +37 -0
  11. judgeval/common/tracer/span_transformer.py +211 -0
  12. judgeval/common/tracer/trace_manager.py +92 -0
  13. judgeval/common/utils.py +2 -2
  14. judgeval/constants.py +3 -30
  15. judgeval/data/datasets/eval_dataset_client.py +29 -156
  16. judgeval/data/judgment_types.py +4 -12
  17. judgeval/data/result.py +1 -1
  18. judgeval/data/scorer_data.py +2 -2
  19. judgeval/data/scripts/openapi_transform.py +1 -1
  20. judgeval/data/trace.py +66 -1
  21. judgeval/data/trace_run.py +0 -3
  22. judgeval/evaluation_run.py +0 -2
  23. judgeval/integrations/langgraph.py +43 -164
  24. judgeval/judgment_client.py +17 -211
  25. judgeval/run_evaluation.py +209 -611
  26. judgeval/scorers/__init__.py +2 -6
  27. judgeval/scorers/base_scorer.py +4 -23
  28. judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +3 -3
  29. judgeval/scorers/judgeval_scorers/api_scorers/prompt_scorer.py +215 -0
  30. judgeval/scorers/score.py +2 -1
  31. judgeval/scorers/utils.py +1 -13
  32. judgeval/utils/requests.py +21 -0
  33. judgeval-0.1.0.dist-info/METADATA +202 -0
  34. {judgeval-0.0.54.dist-info → judgeval-0.1.0.dist-info}/RECORD +37 -29
  35. judgeval/common/tracer.py +0 -3215
  36. judgeval/scorers/judgeval_scorers/api_scorers/classifier_scorer.py +0 -73
  37. judgeval/scorers/judgeval_scorers/classifiers/__init__.py +0 -3
  38. judgeval/scorers/judgeval_scorers/classifiers/text2sql/__init__.py +0 -3
  39. judgeval/scorers/judgeval_scorers/classifiers/text2sql/text2sql_scorer.py +0 -53
  40. judgeval-0.0.54.dist-info/METADATA +0 -1384
  41. /judgeval/common/{s3_storage.py → storage/s3_storage.py} +0 -0
  42. {judgeval-0.0.54.dist-info → judgeval-0.1.0.dist-info}/WHEEL +0 -0
  43. {judgeval-0.0.54.dist-info → judgeval-0.1.0.dist-info}/licenses/LICENSE.md +0 -0
judgeval/common/tracer.py DELETED
@@ -1,3215 +0,0 @@
1
- """
2
- Tracing system for judgeval that allows for function tracing using decorators.
3
- """
4
-
5
- # Standard library imports
6
- import asyncio
7
- import functools
8
- import inspect
9
- import os
10
- import site
11
- import sysconfig
12
- import threading
13
- import time
14
- import traceback
15
- import uuid
16
- import contextvars
17
- import sys
18
- import json
19
- from contextlib import (
20
- contextmanager,
21
- AbstractAsyncContextManager,
22
- AbstractContextManager,
23
- ) # Import context manager bases
24
- from dataclasses import dataclass
25
- from datetime import datetime, timezone
26
- from http import HTTPStatus
27
- from typing import (
28
- Any,
29
- Callable,
30
- Dict,
31
- Generator,
32
- List,
33
- Literal,
34
- Optional,
35
- Tuple,
36
- Union,
37
- AsyncGenerator,
38
- TypeAlias,
39
- )
40
- from rich import print as rprint
41
- import types
42
-
43
- # Third-party imports
44
- from requests import RequestException
45
- from judgeval.utils.requests import requests
46
- from litellm import cost_per_token as _original_cost_per_token
47
- from openai import OpenAI, AsyncOpenAI
48
- from together import Together, AsyncTogether
49
- from anthropic import Anthropic, AsyncAnthropic
50
- from google import genai
51
-
52
- # Local application/library-specific imports
53
- from judgeval.constants import (
54
- JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
55
- JUDGMENT_TRACES_UPSERT_API_URL,
56
- JUDGMENT_TRACES_FETCH_API_URL,
57
- JUDGMENT_TRACES_DELETE_API_URL,
58
- JUDGMENT_PROJECT_DELETE_API_URL,
59
- JUDGMENT_TRACES_SPANS_BATCH_API_URL,
60
- JUDGMENT_TRACES_EVALUATION_RUNS_BATCH_API_URL,
61
- )
62
- from judgeval.data import Example, Trace, TraceSpan, TraceUsage
63
- from judgeval.scorers import APIScorerConfig, BaseScorer
64
- from judgeval.evaluation_run import EvaluationRun
65
- from judgeval.common.utils import ExcInfo, validate_api_key
66
- from judgeval.common.logger import judgeval_logger
67
-
68
- # Standard library imports needed for the new class
69
- import concurrent.futures
70
- from collections.abc import Iterator, AsyncIterator # Add Iterator and AsyncIterator
71
- import queue
72
- import atexit
73
-
74
# Context-local tracer state: the TraceClient that is currently active and the
# id of the span that is currently open. Both default to None.
current_trace_var: contextvars.ContextVar[Optional["TraceClient"]] = (
    contextvars.ContextVar("current_trace", default=None)
)
current_span_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
    "current_span", default=None
)

# Client objects of the LLM SDKs the tracer supports (sync and async variants).
ApiClient: TypeAlias = Union[
    OpenAI,
    Together,
    Anthropic,
    AsyncOpenAI,
    AsyncAnthropic,
    AsyncTogether,
    genai.Client,
    genai.client.AsyncClient,
]

# Permitted values for a span's ``span_type``.
SpanType: TypeAlias = Literal["span", "tool", "llm", "evaluation", "chain"]
94
-
95
-
96
- # Temporary as a POC to have log use the existing annotations feature until log endpoints are ready
97
- @dataclass
98
- class TraceAnnotation:
99
- """Represents a single annotation for a trace span."""
100
-
101
- span_id: str
102
- text: str
103
- label: str
104
- score: int
105
-
106
- def to_dict(self) -> dict:
107
- """Convert the annotation to a dictionary format for storage/transmission."""
108
- return {
109
- "span_id": self.span_id,
110
- "annotation": {"text": self.text, "label": self.label, "score": self.score},
111
- }
112
-
113
-
114
class TraceManagerClient:
    """
    Client for handling trace endpoints with the Judgment API.

    Operations include:
    - Fetching a trace by id
    - Upserting (saving) a trace, optionally also storing it in S3
    - Saving a span annotation
    - Deleting one or more traces
    - Deleting a project

    Every request sends the API key and organization id given at construction;
    any non-200 response is raised as ``ValueError``.
    """

    def __init__(
        self,
        judgment_api_key: str,
        organization_id: str,
        tracer: Optional["Tracer"] = None,
    ):
        self.judgment_api_key = judgment_api_key
        self.organization_id = organization_id
        # Optional back-reference; upsert_trace reads its S3 settings
        # (use_s3 / s3_storage).
        self.tracer = tracer

    def _headers(self) -> Dict[str, str]:
        """Return the JSON content-type and auth headers shared by every call."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.judgment_api_key}",
            "X-Organization-Id": self.organization_id,
        }

    @staticmethod
    def _raise_for_status(response: Any, error_prefix: str) -> None:
        """Raise ``ValueError("<error_prefix>: <body>")`` unless status is 200 OK."""
        if response.status_code != HTTPStatus.OK:
            raise ValueError(f"{error_prefix}: {response.text}")

    def fetch_trace(self, trace_id: str):
        """
        Fetch a trace by its id.

        Returns:
            The decoded JSON response body.

        Raises:
            ValueError: if the server does not answer 200 OK.
        """
        response = requests.post(
            JUDGMENT_TRACES_FETCH_API_URL,
            json={"trace_id": trace_id},
            headers=self._headers(),
            verify=True,
        )
        self._raise_for_status(response, "Failed to fetch traces")
        return response.json()

    @staticmethod
    def _fallback_encoder(obj: Any) -> str:
        """
        ``json.dumps`` fallback for values it cannot encode natively.

        Tries ``repr(obj)``, then ``str(obj)``, and finally returns a
        placeholder naming the object's type and the failure.
        """
        try:
            return repr(obj)
        except Exception:
            try:
                return str(obj)
            except Exception as e:
                return f"<Unserializable object of type {type(obj).__name__}: {e}>"

    def upsert_trace(
        self,
        trace_data: dict,
        offline_mode: bool = False,
        show_link: bool = True,
        final_save: bool = True,
    ):
        """
        Upserts a trace to the Judgment API (always overwrites if exists).

        Args:
            trace_data: The trace data to upsert; values JSON cannot encode
                are sent as their repr/str.
            offline_mode: Whether running in offline mode (suppresses the link).
            show_link: Whether to print the UI link (for live tracing).
            final_save: Whether this is the final save; only then is the trace
                also written to S3 (when the tracer has S3 enabled).

        Returns:
            dict: Server response, which may contain ``ui_results_url``.

        Raises:
            ValueError: if the server does not answer 200 OK. S3 failures are
                logged as warnings, not raised.
        """
        serialized_trace_data = json.dumps(trace_data, default=self._fallback_encoder)

        response = requests.post(
            JUDGMENT_TRACES_UPSERT_API_URL,
            data=serialized_trace_data,
            headers=self._headers(),
            verify=True,
        )
        self._raise_for_status(response, "Failed to upsert trace data")

        server_response = response.json()

        # S3 is a best-effort secondary copy, written only on the final save.
        if self.tracer and self.tracer.use_s3 and final_save:
            try:
                s3_key = self.tracer.s3_storage.save_trace(
                    trace_data=trace_data,
                    trace_id=trace_data["trace_id"],
                    project_name=trace_data["project_name"],
                )
                judgeval_logger.info(f"Trace also saved to S3 at key: {s3_key}")
            except Exception as e:
                judgeval_logger.warning(f"Failed to save trace to S3: {str(e)}")

        if not offline_mode and show_link and "ui_results_url" in server_response:
            pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={server_response['ui_results_url']}]View Trace[/link]\n"
            rprint(pretty_str)

        return server_response

    # TODO: Should have a log endpoint, endpoint should also support batched payloads
    def save_annotation(self, annotation: "TraceAnnotation"):
        """
        Post a single span annotation.

        The payload is ``annotation.to_dict()``.

        Returns:
            The decoded JSON response body.

        Raises:
            ValueError: if the server does not answer 200 OK.
        """
        response = requests.post(
            JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
            json=annotation.to_dict(),
            headers=self._headers(),
            verify=True,
        )
        self._raise_for_status(response, "Failed to save annotation")
        return response.json()

    def _delete_trace_ids(self, trace_ids: List[str], error_prefix: str):
        """Send one DELETE for ``trace_ids`` and return the decoded JSON body."""
        response = requests.delete(
            JUDGMENT_TRACES_DELETE_API_URL,
            json={"trace_ids": trace_ids},
            headers=self._headers(),
        )
        self._raise_for_status(response, error_prefix)
        return response.json()

    def delete_trace(self, trace_id: str):
        """
        Delete a trace from the database.

        Raises:
            ValueError: if the server does not answer 200 OK.
        """
        return self._delete_trace_ids([trace_id], "Failed to delete trace")

    def delete_traces(self, trace_ids: List[str]):
        """
        Delete a batch of traces from the database.

        Raises:
            ValueError: if the server does not answer 200 OK.
        """
        return self._delete_trace_ids(trace_ids, "Failed to delete traces")

    def delete_project(self, project_name: str):
        """
        Deletes a project from the server, which also deletes all evaluations
        and traces associated with the project.

        Raises:
            ValueError: if the server does not answer 200 OK.
        """
        response = requests.delete(
            JUDGMENT_PROJECT_DELETE_API_URL,
            json={"project_name": project_name},
            headers=self._headers(),
        )
        self._raise_for_status(response, "Failed to delete project")
        return response.json()
316
-
317
-
318
class TraceClient:
    """Client for managing a single trace context.

    A TraceClient owns the spans, evaluation runs, annotations and metadata of
    one trace. Span updates are mirrored to the tracer's background span
    service when one is available, and ``save`` upserts the whole trace via a
    TraceManagerClient.
    """

    def __init__(
        self,
        tracer: "Tracer",
        trace_id: Optional[str] = None,
        name: str = "default",
        project_name: str | None = None,
        enable_monitoring: bool = True,
        enable_evaluations: bool = True,
        parent_trace_id: Optional[str] = None,
        parent_name: Optional[str] = None,
    ):
        self.name = name
        self.trace_id = trace_id or str(uuid.uuid4())
        self.project_name = project_name or "default_project"
        self.tracer = tracer
        self.enable_monitoring = enable_monitoring
        self.enable_evaluations = enable_evaluations
        self.parent_trace_id = parent_trace_id
        self.parent_name = parent_name
        self.customer_id: Optional[str] = None
        self.tags: List[Union[str, set, tuple]] = []
        self.metadata: Dict[str, Any] = {}
        self.has_notification: Optional[bool] = False
        # Sent with every save and incremented after each successful upsert.
        self.update_id: int = 1
        self.trace_spans: List[TraceSpan] = []
        self.span_id_to_span: Dict[str, TraceSpan] = {}
        self.evaluation_runs: List[EvaluationRun] = []
        self.annotations: List[TraceAnnotation] = []
        # Set after the first successful save; get_duration() measures from it.
        self.start_time: Optional[float] = None
        self.trace_manager_client = TraceManagerClient(
            tracer.api_key, tracer.organization_id, tracer
        )
        # Depth of each currently open span, keyed by span id.
        self._span_depths: Dict[str, int] = {}
        # Tasks created by record_output for coroutine outputs. asyncio keeps
        # only weak references to tasks, so they are held here until done.
        self._pending_output_tasks: set = set()

        # Get background span service from tracer
        self.background_span_service = (
            tracer.get_background_span_service() if tracer else None
        )

    def get_current_span(self):
        """Return the id of the span open in this context (via the tracer)."""
        return self.tracer.get_current_span()

    def set_current_span(self, span: Any):
        """Make ``span`` the context's current span; returns the reset token."""
        return self.tracer.set_current_span(span)

    def reset_current_span(self, token: Any):
        """Restore the current span saved in ``token``."""
        self.tracer.reset_current_span(token)

    def _current_trace_span(self) -> Optional[TraceSpan]:
        """Return this trace's span for the context's current span id, or None.

        None covers both "no span is open" and "the open span id is not one of
        this trace's spans".
        """
        current_span_id = self.get_current_span()
        if not current_span_id:
            return None
        return self.span_id_to_span.get(current_span_id)

    def _queue_span(self, span: TraceSpan, span_state: str) -> None:
        """Forward ``span`` to the background span service, if there is one."""
        if self.background_span_service:
            self.background_span_service.queue_span(span, span_state=span_state)

    def _set_current_span_field(
        self, field: str, value: Any, span_state: str
    ) -> Optional[TraceSpan]:
        """Set ``field`` on the current span and queue it with ``span_state``.

        Returns the updated span, or None when there is no current span in
        this trace.
        """
        span = self._current_trace_span()
        if span is None:
            return None
        setattr(span, field, value)
        self._queue_span(span, span_state)
        return span

    @contextmanager
    def span(self, name: str, span_type: SpanType = "span"):
        """Open a span named ``name`` for the duration of the ``with`` block.

        For the first span of the trace, the trace is saved first so it can be
        followed live; a failure there is logged as a warning and ignored.
        The new span becomes the context's current span. Its depth is the
        parent's depth + 1, or 0 when there is no parent or the parent is not
        an open span of this trace. It is queued to the background service on
        entry ("input") and on exit ("completed").

        Yields:
            This TraceClient.
        """
        if not self.trace_spans:
            try:
                self.save(final_save=False)
            except Exception as e:
                judgeval_logger.warning(
                    f"Failed to save initial trace for live tracking: {e}"
                )
        start_time = time.time()

        # A unique id for *this* span invocation.
        span_id = str(uuid.uuid4())
        parent_span_id = self.get_current_span()

        current_depth = 0
        if parent_span_id and parent_span_id in self._span_depths:
            current_depth = self._span_depths[parent_span_id] + 1
        self._span_depths[span_id] = current_depth

        span = TraceSpan(
            span_id=span_id,
            trace_id=self.trace_id,
            depth=current_depth,
            message=name,
            created_at=start_time,
            span_type=span_type,
            parent_span_id=parent_span_id,
            function=name,
        )
        self.add_span(span)

        # Set the current span only right before the try, so the finally
        # below always pairs with it.
        token = self.set_current_span(span_id)
        try:
            self._queue_span(span, "input")
            yield self
        finally:
            try:
                span.duration = time.time() - start_time
                self._queue_span(span, "completed")
            finally:
                # Always restore context, even if queueing raised.
                self._span_depths.pop(span_id, None)
                self.reset_current_span(token)

    def async_evaluate(
        self,
        scorers: List[Union[APIScorerConfig, BaseScorer]],
        example: Optional[Example] = None,
        input: Optional[str] = None,
        actual_output: Optional[Union[str, List[str]]] = None,
        expected_output: Optional[Union[str, List[str]]] = None,
        context: Optional[List[str]] = None,
        retrieval_context: Optional[List[str]] = None,
        tools_called: Optional[List[str]] = None,
        expected_tools: Optional[List[str]] = None,
        additional_metadata: Optional[Dict[str, Any]] = None,
        model: Optional[str] = None,
        span_id: Optional[str] = None,
    ):
        """Attach an evaluation run with ``scorers`` to a span of this trace.

        Does nothing when evaluations are disabled. When ``scorers`` is empty,
        it logs a warning and returns. ``span_id`` defaults to the context's
        current span. When ``example`` is None, it is built from the
        individual fields.

        Raises:
            ValueError: if ``example`` is None and every individual field is None.
        """
        if not self.enable_evaluations:
            return

        start_time = time.time()

        if not scorers:
            judgeval_logger.warning("No valid scorers available for evaluation")
            return

        if example is None:
            example_fields = {
                "input": input,
                "actual_output": actual_output,
                "expected_output": expected_output,
                "context": context,
                "retrieval_context": retrieval_context,
                "tools_called": tools_called,
                "expected_tools": expected_tools,
                "additional_metadata": additional_metadata,
            }
            if all(value is None for value in example_fields.values()):
                raise ValueError(
                    "Either 'example' or at least one of the individual parameters (input, actual_output, etc.) must be provided"
                )
            example = Example(**example_fields)

        span_id_to_use = span_id if span_id is not None else self.get_current_span()

        eval_run = EvaluationRun(
            organization_id=self.tracer.organization_id,
            project_name=self.project_name,
            eval_name=f"{self.name.capitalize()}-"
            f"{span_id_to_use}-"
            f"[{','.join(scorer.score_type.capitalize() for scorer in scorers)}]",
            examples=[example],
            scorers=scorers,
            model=model,
            judgment_api_key=self.tracer.api_key,
            trace_span_id=span_id_to_use,
        )

        self.add_eval_run(eval_run, start_time)

        # Queue through the background service with a snapshot of the span.
        if self.background_span_service and span_id_to_use:
            current_span = self.span_id_to_span.get(span_id_to_use)
            if current_span:
                self.background_span_service.queue_evaluation_run(
                    eval_run, span_id=span_id_to_use, span_data=current_span
                )

    def add_eval_run(self, eval_run: EvaluationRun, start_time: float):
        """Record ``eval_run`` and flag its span (if it belongs to this trace).

        ``start_time`` is accepted for compatibility but not used.
        """
        current_span_id = eval_run.trace_span_id
        if current_span_id:
            span = self.span_id_to_span.get(current_span_id)
            if span is not None:
                span.has_evaluation = True
        self.evaluation_runs.append(eval_run)

    def add_annotation(self, annotation: TraceAnnotation):
        """Add an annotation to this trace context; returns self."""
        self.annotations.append(annotation)
        return self

    def record_input(self, inputs: dict):
        """Store ``inputs`` (without any ``self`` entry) on the current span.

        The caller's dict is not modified. A failure to queue the span is
        logged as a warning rather than raised.
        """
        span = self._current_trace_span()
        if span is None:
            return
        span.inputs = {key: value for key, value in inputs.items() if key != "self"}
        try:
            self._queue_span(span, "input")
        except Exception as e:
            judgeval_logger.warning(f"Failed to queue span with input data: {e}")

    def record_agent_name(self, agent_name: str):
        """Set ``agent_name`` on the current span and queue it."""
        self._set_current_span_field("agent_name", agent_name, "agent_name")

    def record_state_before(self, state: dict):
        """Records the agent's state before a tool execution on the current span.

        Args:
            state: A dictionary representing the agent's state.
        """
        self._set_current_span_field("state_before", state, "state_before")

    def record_state_after(self, state: dict):
        """Records the agent's state after a tool execution on the current span.

        Args:
            state: A dictionary representing the agent's state.
        """
        self._set_current_span_field("state_after", state, "state_after")

    async def _update_coroutine(self, span: TraceSpan, coroutine: Any, field: str):
        """Await ``coroutine`` and store its result (or an error string) on ``span``.

        For ``field == "output"`` the span is queued in both cases. Exceptions
        from the coroutine are re-raised.
        """
        try:
            result = await coroutine
            setattr(span, field, result)
            if self.background_span_service and field == "output":
                self.background_span_service.queue_span(span, span_state="output")
            return result
        except Exception as e:
            setattr(span, field, f"Error: {str(e)}")
            if self.background_span_service and field == "output":
                self.background_span_service.queue_span(span, span_state="output")
            raise

    def record_output(self, output: Any):
        """Store ``output`` on the current span and return that span (or None).

        A coroutine output is recorded as ``"<pending>"`` and resolved by a
        background task. That task requires a running event loop. The span is
        queued only once the real value is known.
        """
        span = self._current_trace_span()
        if span is None:
            return None
        if inspect.iscoroutine(output):
            span.output = "<pending>"
            task = asyncio.create_task(self._update_coroutine(span, output, "output"))
            self._pending_output_tasks.add(task)
            task.add_done_callback(self._pending_output_tasks.discard)
        else:
            span.output = output
            self._queue_span(span, "output")
        return span

    def record_usage(self, usage: TraceUsage):
        """Set ``usage`` on the current span, queue it, and return it (or None)."""
        return self._set_current_span_field("usage", usage, "usage")

    def record_error(self, error: Dict[str, Any]):
        """Set ``error`` on the current span, queue it, and return it (or None)."""
        return self._set_current_span_field("error", error, "error")

    def add_span(self, span: TraceSpan):
        """Add a trace span to this trace context; returns self."""
        self.trace_spans.append(span)
        self.span_id_to_span[span.span_id] = span
        return self

    def print(self):
        """Print every span of the trace in insertion order."""
        for span in self.trace_spans:
            span.print_span()

    def get_duration(self) -> float:
        """Seconds since the first successful save, or 0.0 before any save."""
        if self.start_time is None:
            return 0.0
        return time.time() - self.start_time

    def save(self, final_save: bool = False) -> Tuple[str, dict]:
        """Upsert the current state of this trace to the Judgment API.

        After the upsert, ``start_time`` is set if it was still unset,
        ``update_id`` is incremented, and all collected annotations are posted.

        Args:
            final_save: True for the last save of the trace. It suppresses the
                UI link and lets the trace also be stored in S3 when the
                tracer has S3 enabled.

        Returns:
            (trace_id, server_response), where server_response is the API's
            JSON reply (it may contain the UI URL).
        """
        total_duration = self.get_duration()

        trace_data = {
            "trace_id": self.trace_id,
            "name": self.name,
            "project_name": self.project_name,
            "created_at": datetime.fromtimestamp(
                self.start_time or time.time(), timezone.utc
            ).isoformat(),
            "duration": total_duration,
            "trace_spans": [span.model_dump() for span in self.trace_spans],
            "evaluation_runs": [run.model_dump() for run in self.evaluation_runs],
            "offline_mode": self.tracer.offline_mode,
            "parent_trace_id": self.parent_trace_id,
            "parent_name": self.parent_name,
            "customer_id": self.customer_id,
            "tags": self.tags,
            "metadata": self.metadata,
            "update_id": self.update_id,
        }

        server_response = self.trace_manager_client.upsert_trace(
            trace_data,
            offline_mode=self.tracer.offline_mode,
            show_link=not final_save,  # link only on the initial (live) save
            final_save=final_save,
        )

        if self.start_time is None:
            self.start_time = time.time()

        self.update_id += 1

        # TODO: batch to the log endpoint
        # NOTE(review): every save re-posts all annotations collected so far;
        # if save() runs more than once per trace, earlier annotations are
        # sent again — confirm the endpoint tolerates duplicates.
        for annotation in self.annotations:
            self.trace_manager_client.save_annotation(annotation)
        return self.trace_id, server_response

    def delete(self):
        """Delete this trace on the server."""
        return self.trace_manager_client.delete_trace(self.trace_id)

    def update_metadata(self, metadata: dict):
        """
        Set metadata for this trace.

        Args:
            metadata: Metadata as a dictionary

        Supported keys (anything else is stored in ``self.metadata``):
        - customer_id: ID of the customer using this trace (stored as str)
        - tags: List of tags (str, set or tuple items) for this trace
        - has_notification: Whether this trace has a notification (bool)
        - name: Name of the trace

        Raises:
            ValueError: on a malformed ``tags`` or ``has_notification`` value.
        """
        for k, v in metadata.items():
            if k == "customer_id":
                self.customer_id = str(v) if v is not None else None
            elif k == "tags":
                if not isinstance(v, list):
                    raise ValueError(
                        f"Tags must be a list of strings, sets, or tuples, got {type(v)}"
                    )
                for item in v:
                    if not isinstance(item, (str, set, tuple)):
                        raise ValueError(
                            f"Tags must be a list of strings, sets, or tuples, got item of type {type(item)}"
                        )
                self.tags = v
            elif k == "has_notification":
                if not isinstance(v, bool):
                    raise ValueError(
                        f"has_notification must be a boolean, got {type(v)}"
                    )
                self.has_notification = v
            elif k == "name":
                self.name = v
            else:
                self.metadata[k] = v

    def set_customer_id(self, customer_id: str):
        """
        Set the customer ID for this trace.

        Args:
            customer_id: The customer ID to set
        """
        self.update_metadata({"customer_id": customer_id})

    def set_tags(self, tags: List[Union[str, set, tuple]]):
        """
        Set the tags for this trace.

        Args:
            tags: List of tags to set
        """
        self.update_metadata({"tags": tags})
799
-
800
-
801
- def _capture_exception_for_trace(
802
- current_trace: Optional["TraceClient"], exc_info: ExcInfo
803
- ):
804
- if not current_trace:
805
- return
806
-
807
- exc_type, exc_value, exc_traceback_obj = exc_info
808
- formatted_exception = {
809
- "type": exc_type.__name__ if exc_type else "UnknownExceptionType",
810
- "message": str(exc_value) if exc_value else "No exception message",
811
- "traceback": traceback.format_tb(exc_traceback_obj)
812
- if exc_traceback_obj
813
- else [],
814
- }
815
-
816
- # This is where we specially handle exceptions that we might want to collect additional data for.
817
- # When we do this, always try checking the module from sys.modules instead of importing. This will
818
- # Let us support a wider range of exceptions without needing to import them for all clients.
819
-
820
- # Most clients (requests, httpx, urllib) support the standard format of exposing error.request.url and error.response.status_code
821
- # The alternative is to hand select libraries we want from sys.modules and check for them:
822
- # As an example: requests_module = sys.modules.get("requests", None) // then do things with requests_module;
823
-
824
- # General HTTP Like errors
825
- try:
826
- url = getattr(getattr(exc_value, "request", None), "url", None)
827
- status_code = getattr(getattr(exc_value, "response", None), "status_code", None)
828
- if status_code:
829
- formatted_exception["http"] = {
830
- "url": url if url else "Unknown URL",
831
- "status_code": status_code if status_code else None,
832
- }
833
- except Exception:
834
- pass
835
-
836
- current_trace.record_error(formatted_exception)
837
-
838
- # Queue the span with error state through background service
839
- if current_trace.background_span_service:
840
- current_span_id = current_trace.get_current_span()
841
- if current_span_id and current_span_id in current_trace.span_id_to_span:
842
- error_span = current_trace.span_id_to_span[current_span_id]
843
- current_trace.background_span_service.queue_span(
844
- error_span, span_state="error"
845
- )
846
-
847
-
848
- class BackgroundSpanService:
849
- """
850
- Background service for queueing and batching trace spans for efficient saving.
851
-
852
- This service:
853
- - Queues spans as they complete
854
- - Batches them for efficient network usage
855
- - Sends spans periodically or when batches reach a certain size
856
- - Handles automatic flushing when the main event terminates
857
- """
858
-
859
- def __init__(
860
- self,
861
- judgment_api_key: str,
862
- organization_id: str,
863
- batch_size: int = 10,
864
- flush_interval: float = 5.0,
865
- num_workers: int = 1,
866
- ):
867
- """
868
- Initialize the background span service.
869
-
870
- Args:
871
- judgment_api_key: API key for Judgment service
872
- organization_id: Organization ID
873
- batch_size: Number of spans to batch before sending (default: 10)
874
- flush_interval: Time in seconds between automatic flushes (default: 5.0)
875
- num_workers: Number of worker threads to process the queue (default: 1)
876
- """
877
- self.judgment_api_key = judgment_api_key
878
- self.organization_id = organization_id
879
- self.batch_size = batch_size
880
- self.flush_interval = flush_interval
881
- self.num_workers = max(1, num_workers) # Ensure at least 1 worker
882
-
883
- # Queue for pending spans
884
- self._span_queue: queue.Queue[Dict[str, Any]] = queue.Queue()
885
-
886
- # Background threads for processing spans
887
- self._worker_threads: List[threading.Thread] = []
888
- self._shutdown_event = threading.Event()
889
-
890
- # Track spans that have been sent
891
- # self._sent_spans = set()
892
-
893
- # Register cleanup on exit
894
- atexit.register(self.shutdown)
895
-
896
- # Start the background workers
897
- self._start_workers()
898
-
899
- def _start_workers(self):
900
- """Start the background worker threads."""
901
- for i in range(self.num_workers):
902
- if len(self._worker_threads) < self.num_workers:
903
- worker_thread = threading.Thread(
904
- target=self._worker_loop, daemon=True, name=f"SpanWorker-{i + 1}"
905
- )
906
- worker_thread.start()
907
- self._worker_threads.append(worker_thread)
908
-
909
- def _worker_loop(self):
910
- """Main worker loop that processes spans in batches."""
911
- batch = []
912
- last_flush_time = time.time()
913
- pending_task_count = (
914
- 0 # Track how many tasks we've taken from queue but not marked done
915
- )
916
-
917
- while not self._shutdown_event.is_set() or self._span_queue.qsize() > 0:
918
- try:
919
- # First, do a blocking get to wait for at least one item
920
- if not batch: # Only block if we don't have items already
921
- try:
922
- span_data = self._span_queue.get(timeout=1.0)
923
- batch.append(span_data)
924
- pending_task_count += 1
925
- except queue.Empty:
926
- # No new spans, continue to check for flush conditions
927
- pass
928
-
929
- # Then, do non-blocking gets to drain any additional available items
930
- # up to our batch size limit
931
- while len(batch) < self.batch_size:
932
- try:
933
- span_data = self._span_queue.get_nowait() # Non-blocking
934
- batch.append(span_data)
935
- pending_task_count += 1
936
- except queue.Empty:
937
- # No more items immediately available
938
- break
939
-
940
- current_time = time.time()
941
- should_flush = len(batch) >= self.batch_size or (
942
- batch and (current_time - last_flush_time) >= self.flush_interval
943
- )
944
-
945
- if should_flush and batch:
946
- self._send_batch(batch)
947
-
948
- # Only mark tasks as done after successful sending
949
- for _ in range(pending_task_count):
950
- self._span_queue.task_done()
951
- pending_task_count = 0 # Reset counter
952
-
953
- batch.clear()
954
- last_flush_time = current_time
955
-
956
- except Exception as e:
957
- judgeval_logger.warning(f"Error in span service worker loop: {e}")
958
- # On error, still need to mark tasks as done to prevent hanging
959
- for _ in range(pending_task_count):
960
- self._span_queue.task_done()
961
- pending_task_count = 0
962
- batch.clear()
963
-
964
- # Final flush on shutdown
965
- if batch:
966
- self._send_batch(batch)
967
- # Mark remaining tasks as done
968
- for _ in range(pending_task_count):
969
- self._span_queue.task_done()
970
-
971
- def _send_batch(self, batch: List[Dict[str, Any]]):
972
- """
973
- Send a batch of spans to the server.
974
-
975
- Args:
976
- batch: List of span dictionaries to send
977
- """
978
- if not batch:
979
- return
980
-
981
- try:
982
- # Group items by type for different endpoints
983
- spans_to_send = []
984
- evaluation_runs_to_send = []
985
-
986
- for item in batch:
987
- if item["type"] == "span":
988
- spans_to_send.append(item["data"])
989
- elif item["type"] == "evaluation_run":
990
- evaluation_runs_to_send.append(item["data"])
991
-
992
- # Send spans if any
993
- if spans_to_send:
994
- self._send_spans_batch(spans_to_send)
995
-
996
- # Send evaluation runs if any
997
- if evaluation_runs_to_send:
998
- self._send_evaluation_runs_batch(evaluation_runs_to_send)
999
-
1000
- except Exception as e:
1001
- judgeval_logger.warning(f"Failed to send batch: {e}")
1002
-
1003
- def _send_spans_batch(self, spans: List[Dict[str, Any]]):
1004
- """Send a batch of spans to the spans endpoint."""
1005
- payload = {"spans": spans, "organization_id": self.organization_id}
1006
-
1007
- # Serialize with fallback encoder
1008
- def fallback_encoder(obj):
1009
- try:
1010
- return repr(obj)
1011
- except Exception:
1012
- try:
1013
- return str(obj)
1014
- except Exception as e:
1015
- return f"<Unserializable object of type {type(obj).__name__}: {e}>"
1016
-
1017
- try:
1018
- serialized_data = json.dumps(payload, default=fallback_encoder)
1019
-
1020
- # Send the actual HTTP request to the batch endpoint
1021
- response = requests.post(
1022
- JUDGMENT_TRACES_SPANS_BATCH_API_URL,
1023
- data=serialized_data,
1024
- headers={
1025
- "Content-Type": "application/json",
1026
- "Authorization": f"Bearer {self.judgment_api_key}",
1027
- "X-Organization-Id": self.organization_id,
1028
- },
1029
- verify=True,
1030
- timeout=30, # Add timeout to prevent hanging
1031
- )
1032
-
1033
- if response.status_code != HTTPStatus.OK:
1034
- judgeval_logger.warning(
1035
- f"Failed to send spans batch: HTTP {response.status_code} - {response.text}"
1036
- )
1037
-
1038
- except RequestException as e:
1039
- judgeval_logger.warning(f"Network error sending spans batch: {e}")
1040
- except Exception as e:
1041
- judgeval_logger.warning(f"Failed to serialize or send spans batch: {e}")
1042
-
1043
- def _send_evaluation_runs_batch(self, evaluation_runs: List[Dict[str, Any]]):
1044
- """Send a batch of evaluation runs with their associated span data to the endpoint."""
1045
- # Structure payload to include both evaluation run data and span data
1046
- evaluation_entries = []
1047
- for eval_data in evaluation_runs:
1048
- # eval_data already contains the evaluation run data (no need to access ['data'])
1049
- entry = {
1050
- "evaluation_run": {
1051
- # Extract evaluation run fields (excluding span-specific fields)
1052
- key: value
1053
- for key, value in eval_data.items()
1054
- if key not in ["associated_span_id", "span_data", "queued_at"]
1055
- },
1056
- "associated_span": {
1057
- "span_id": eval_data.get("associated_span_id"),
1058
- "span_data": eval_data.get("span_data"),
1059
- },
1060
- "queued_at": eval_data.get("queued_at"),
1061
- }
1062
- evaluation_entries.append(entry)
1063
-
1064
- payload = {
1065
- "organization_id": self.organization_id,
1066
- "evaluation_entries": evaluation_entries, # Each entry contains both eval run + span data
1067
- }
1068
-
1069
- # Serialize with fallback encoder
1070
- def fallback_encoder(obj):
1071
- try:
1072
- return repr(obj)
1073
- except Exception:
1074
- try:
1075
- return str(obj)
1076
- except Exception as e:
1077
- return f"<Unserializable object of type {type(obj).__name__}: {e}>"
1078
-
1079
- try:
1080
- serialized_data = json.dumps(payload, default=fallback_encoder)
1081
-
1082
- # Send the actual HTTP request to the batch endpoint
1083
- response = requests.post(
1084
- JUDGMENT_TRACES_EVALUATION_RUNS_BATCH_API_URL,
1085
- data=serialized_data,
1086
- headers={
1087
- "Content-Type": "application/json",
1088
- "Authorization": f"Bearer {self.judgment_api_key}",
1089
- "X-Organization-Id": self.organization_id,
1090
- },
1091
- verify=True,
1092
- timeout=30, # Add timeout to prevent hanging
1093
- )
1094
-
1095
- if response.status_code != HTTPStatus.OK:
1096
- judgeval_logger.warning(
1097
- f"Failed to send evaluation runs batch: HTTP {response.status_code} - {response.text}"
1098
- )
1099
-
1100
- except RequestException as e:
1101
- judgeval_logger.warning(f"Network error sending evaluation runs batch: {e}")
1102
- except Exception as e:
1103
- judgeval_logger.warning(f"Failed to send evaluation runs batch: {e}")
1104
-
1105
- def queue_span(self, span: TraceSpan, span_state: str = "input"):
1106
- """
1107
- Queue a span for background sending.
1108
-
1109
- Args:
1110
- span: The TraceSpan object to queue
1111
- span_state: State of the span ("input", "output", "completed")
1112
- """
1113
- if not self._shutdown_event.is_set():
1114
- # Increment update_id when queueing the span
1115
- span.increment_update_id()
1116
- span_data = {
1117
- "type": "span",
1118
- "data": {
1119
- **span.model_dump(),
1120
- "span_state": span_state,
1121
- "queued_at": time.time(),
1122
- },
1123
- }
1124
- self._span_queue.put(span_data)
1125
-
1126
- def queue_evaluation_run(
1127
- self, evaluation_run: EvaluationRun, span_id: str, span_data: TraceSpan
1128
- ):
1129
- """
1130
- Queue an evaluation run for background sending.
1131
-
1132
- Args:
1133
- evaluation_run: The EvaluationRun object to queue
1134
- span_id: The span ID associated with this evaluation run
1135
- span_data: The span data at the time of evaluation (to avoid race conditions)
1136
- """
1137
- if not self._shutdown_event.is_set():
1138
- eval_data = {
1139
- "type": "evaluation_run",
1140
- "data": {
1141
- **evaluation_run.model_dump(),
1142
- "associated_span_id": span_id,
1143
- "span_data": span_data.model_dump(), # Include span data to avoid race conditions
1144
- "queued_at": time.time(),
1145
- },
1146
- }
1147
- self._span_queue.put(eval_data)
1148
-
1149
- def flush(self):
1150
- """Force immediate sending of all queued spans."""
1151
- try:
1152
- # Wait for the queue to be processed
1153
- self._span_queue.join()
1154
- except Exception as e:
1155
- judgeval_logger.warning(f"Error during flush: {e}")
1156
-
1157
- def shutdown(self):
1158
- """Shutdown the background service and flush remaining spans."""
1159
- if self._shutdown_event.is_set():
1160
- return
1161
-
1162
- try:
1163
- # Signal shutdown to stop new items from being queued
1164
- self._shutdown_event.set()
1165
-
1166
- # Try to flush any remaining spans
1167
- try:
1168
- self.flush()
1169
- except Exception as e:
1170
- judgeval_logger.warning(f"Error during final flush: {e}")
1171
- except Exception as e:
1172
- judgeval_logger.warning(f"Error during BackgroundSpanService shutdown: {e}")
1173
- finally:
1174
- # Clear the worker threads list (daemon threads will be killed automatically)
1175
- self._worker_threads.clear()
1176
-
1177
- def get_queue_size(self) -> int:
1178
- """Get the current size of the span queue."""
1179
- return self._span_queue.qsize()
1180
-
1181
-
1182
- class _DeepTracer:
1183
- _instance: Optional["_DeepTracer"] = None
1184
- _lock: threading.Lock = threading.Lock()
1185
- _refcount: int = 0
1186
- _span_stack: contextvars.ContextVar[List[Dict[str, Any]]] = contextvars.ContextVar(
1187
- "_deep_profiler_span_stack", default=[]
1188
- )
1189
- _skip_stack: contextvars.ContextVar[List[str]] = contextvars.ContextVar(
1190
- "_deep_profiler_skip_stack", default=[]
1191
- )
1192
- _original_sys_trace: Optional[Callable] = None
1193
- _original_threading_trace: Optional[Callable] = None
1194
-
1195
- def __init__(self, tracer: "Tracer"):
1196
- self._tracer = tracer
1197
-
1198
- def _get_qual_name(self, frame) -> str:
1199
- func_name = frame.f_code.co_name
1200
- module_name = frame.f_globals.get("__name__", "unknown_module")
1201
-
1202
- try:
1203
- func = frame.f_globals.get(func_name)
1204
- if func is None:
1205
- return f"{module_name}.{func_name}"
1206
- if hasattr(func, "__qualname__"):
1207
- return f"{module_name}.{func.__qualname__}"
1208
- return f"{module_name}.{func_name}"
1209
- except Exception:
1210
- return f"{module_name}.{func_name}"
1211
-
1212
- def __new__(cls, tracer: "Tracer"):
1213
- with cls._lock:
1214
- if cls._instance is None:
1215
- cls._instance = super().__new__(cls)
1216
- return cls._instance
1217
-
1218
- def _should_trace(self, frame):
1219
- # Skip stack is maintained by the tracer as an optimization to skip earlier
1220
- # frames in the call stack that we've already determined should be skipped
1221
- skip_stack = self._skip_stack.get()
1222
- if len(skip_stack) > 0:
1223
- return False
1224
-
1225
- func_name = frame.f_code.co_name
1226
- module_name = frame.f_globals.get("__name__", None)
1227
- func = frame.f_globals.get(func_name)
1228
- if func and (
1229
- hasattr(func, "_judgment_span_name") or hasattr(func, "_judgment_span_type")
1230
- ):
1231
- return False
1232
-
1233
- if (
1234
- not module_name
1235
- or func_name.startswith("<") # ex: <listcomp>
1236
- or func_name.startswith("__")
1237
- and func_name != "__call__" # dunders
1238
- or not self._is_user_code(frame.f_code.co_filename)
1239
- ):
1240
- return False
1241
-
1242
- return True
1243
-
1244
- @functools.cache
1245
- def _is_user_code(self, filename: str):
1246
- return (
1247
- bool(filename)
1248
- and not filename.startswith("<")
1249
- and not os.path.realpath(filename).startswith(_TRACE_FILEPATH_BLOCKLIST)
1250
- )
1251
-
1252
- def _cooperative_sys_trace(self, frame: types.FrameType, event: str, arg: Any):
1253
- """Cooperative trace function for sys.settrace that chains with existing tracers."""
1254
- # First, call the original sys trace function if it exists
1255
- original_result = None
1256
- if self._original_sys_trace:
1257
- try:
1258
- original_result = self._original_sys_trace(frame, event, arg)
1259
- except Exception:
1260
- # If the original tracer fails, continue with our tracing
1261
- pass
1262
-
1263
- # Then do our own tracing
1264
- our_result = self._trace(frame, event, arg, self._cooperative_sys_trace)
1265
-
1266
- # Return our tracer to continue tracing, but respect the original's decision
1267
- # If the original tracer returned None (stop tracing), we should respect that
1268
- if original_result is None and self._original_sys_trace:
1269
- return None
1270
-
1271
- return our_result or original_result
1272
-
1273
- def _cooperative_threading_trace(
1274
- self, frame: types.FrameType, event: str, arg: Any
1275
- ):
1276
- """Cooperative trace function for threading.settrace that chains with existing tracers."""
1277
- # First, call the original threading trace function if it exists
1278
- original_result = None
1279
- if self._original_threading_trace:
1280
- try:
1281
- original_result = self._original_threading_trace(frame, event, arg)
1282
- except Exception:
1283
- # If the original tracer fails, continue with our tracing
1284
- pass
1285
-
1286
- # Then do our own tracing
1287
- our_result = self._trace(frame, event, arg, self._cooperative_threading_trace)
1288
-
1289
- # Return our tracer to continue tracing, but respect the original's decision
1290
- # If the original tracer returned None (stop tracing), we should respect that
1291
- if original_result is None and self._original_threading_trace:
1292
- return None
1293
-
1294
- return our_result or original_result
1295
-
1296
- def _trace(
1297
- self, frame: types.FrameType, event: str, arg: Any, continuation_func: Callable
1298
- ):
1299
- frame.f_trace_lines = False
1300
- frame.f_trace_opcodes = False
1301
-
1302
- if not self._should_trace(frame):
1303
- return
1304
-
1305
- if event not in ("call", "return", "exception"):
1306
- return
1307
-
1308
- current_trace = self._tracer.get_current_trace()
1309
- if not current_trace:
1310
- return
1311
-
1312
- parent_span_id = self._tracer.get_current_span()
1313
- if not parent_span_id:
1314
- return
1315
-
1316
- qual_name = self._get_qual_name(frame)
1317
- instance_name = None
1318
- if "self" in frame.f_locals:
1319
- instance = frame.f_locals["self"]
1320
- class_name = instance.__class__.__name__
1321
- class_identifiers = getattr(self._tracer, "class_identifiers", {})
1322
- instance_name = get_instance_prefixed_name(
1323
- instance, class_name, class_identifiers
1324
- )
1325
- skip_stack = self._skip_stack.get()
1326
-
1327
- if event == "call":
1328
- # If we have entries in the skip stack and the current qual_name matches the top entry,
1329
- # push it again to track nesting depth and skip
1330
- # As an optimization, we only care about duplicate qual_names.
1331
- if skip_stack:
1332
- if qual_name == skip_stack[-1]:
1333
- skip_stack.append(qual_name)
1334
- self._skip_stack.set(skip_stack)
1335
- return
1336
-
1337
- should_trace = self._should_trace(frame)
1338
-
1339
- if not should_trace:
1340
- if not skip_stack:
1341
- self._skip_stack.set([qual_name])
1342
- return
1343
- elif event == "return":
1344
- # If we have entries in skip stack and current qual_name matches the top entry,
1345
- # pop it to track exiting from the skipped section
1346
- if skip_stack and qual_name == skip_stack[-1]:
1347
- skip_stack.pop()
1348
- self._skip_stack.set(skip_stack)
1349
- return
1350
-
1351
- if skip_stack:
1352
- return
1353
-
1354
- span_stack = self._span_stack.get()
1355
- if event == "call":
1356
- if not self._should_trace(frame):
1357
- return
1358
-
1359
- span_id = str(uuid.uuid4())
1360
-
1361
- parent_depth = current_trace._span_depths.get(parent_span_id, 0)
1362
- depth = parent_depth + 1
1363
-
1364
- current_trace._span_depths[span_id] = depth
1365
-
1366
- start_time = time.time()
1367
-
1368
- span_stack.append(
1369
- {
1370
- "span_id": span_id,
1371
- "parent_span_id": parent_span_id,
1372
- "function": qual_name,
1373
- "start_time": start_time,
1374
- }
1375
- )
1376
- self._span_stack.set(span_stack)
1377
-
1378
- token = self._tracer.set_current_span(span_id)
1379
- frame.f_locals["_judgment_span_token"] = token
1380
-
1381
- span = TraceSpan(
1382
- span_id=span_id,
1383
- trace_id=current_trace.trace_id,
1384
- depth=depth,
1385
- message=qual_name,
1386
- created_at=start_time,
1387
- span_type="span",
1388
- parent_span_id=parent_span_id,
1389
- function=qual_name,
1390
- agent_name=instance_name,
1391
- )
1392
- current_trace.add_span(span)
1393
-
1394
- inputs = {}
1395
- try:
1396
- args_info = inspect.getargvalues(frame)
1397
- for arg in args_info.args:
1398
- try:
1399
- inputs[arg] = args_info.locals.get(arg)
1400
- except Exception:
1401
- inputs[arg] = "<<Unserializable>>"
1402
- current_trace.record_input(inputs)
1403
- except Exception as e:
1404
- current_trace.record_input({"error": str(e)})
1405
-
1406
- elif event == "return":
1407
- if not span_stack:
1408
- return
1409
-
1410
- current_id = self._tracer.get_current_span()
1411
-
1412
- span_data = None
1413
- for i, entry in enumerate(reversed(span_stack)):
1414
- if entry["span_id"] == current_id:
1415
- span_data = span_stack.pop(-(i + 1))
1416
- self._span_stack.set(span_stack)
1417
- break
1418
-
1419
- if not span_data:
1420
- return
1421
-
1422
- start_time = span_data["start_time"]
1423
- duration = time.time() - start_time
1424
-
1425
- current_trace.span_id_to_span[span_data["span_id"]].duration = duration
1426
-
1427
- if arg is not None:
1428
- # exception handling will take priority.
1429
- current_trace.record_output(arg)
1430
-
1431
- if span_data["span_id"] in current_trace._span_depths:
1432
- del current_trace._span_depths[span_data["span_id"]]
1433
-
1434
- if span_stack:
1435
- self._tracer.set_current_span(span_stack[-1]["span_id"])
1436
- else:
1437
- self._tracer.set_current_span(span_data["parent_span_id"])
1438
-
1439
- if "_judgment_span_token" in frame.f_locals:
1440
- self._tracer.reset_current_span(frame.f_locals["_judgment_span_token"])
1441
-
1442
- elif event == "exception":
1443
- exc_type = arg[0]
1444
- if issubclass(exc_type, (StopIteration, StopAsyncIteration, GeneratorExit)):
1445
- return
1446
- _capture_exception_for_trace(current_trace, arg)
1447
-
1448
- return continuation_func
1449
-
1450
- def __enter__(self):
1451
- with self._lock:
1452
- self._refcount += 1
1453
- if self._refcount == 1:
1454
- # Store the existing trace functions before setting ours
1455
- self._original_sys_trace = sys.gettrace()
1456
- self._original_threading_trace = threading.gettrace()
1457
-
1458
- self._skip_stack.set([])
1459
- self._span_stack.set([])
1460
-
1461
- sys.settrace(self._cooperative_sys_trace)
1462
- threading.settrace(self._cooperative_threading_trace)
1463
- return self
1464
-
1465
- def __exit__(self, exc_type, exc_val, exc_tb):
1466
- with self._lock:
1467
- self._refcount -= 1
1468
- if self._refcount == 0:
1469
- # Restore the original trace functions instead of setting to None
1470
- sys.settrace(self._original_sys_trace)
1471
- threading.settrace(self._original_threading_trace)
1472
-
1473
- # Clean up the references
1474
- self._original_sys_trace = None
1475
- self._original_threading_trace = None
1476
-
1477
-
1478
- # Below commented out function isn't used anymore?
1479
-
1480
- # def log(self, message: str, level: str = "info"):
1481
- # """ Log a message with the span context """
1482
- # current_trace = self._tracer.get_current_trace()
1483
- # if current_trace:
1484
- # current_trace.log(message, level)
1485
- # else:
1486
- # print(f"[{level}] {message}")
1487
- # current_trace.record_output({"log": message})
1488
-
1489
-
1490
- class Tracer:
1491
- # Tracer.current_trace class variable is currently used in wrap()
1492
- # TODO: Keep track of cross-context state for current trace and current span ID solely through class variables instead of instance variables?
1493
- # Should be fine to do so as long as we keep Tracer as a singleton
1494
- current_trace: Optional[TraceClient] = None
1495
- # current_span_id: Optional[str] = None
1496
-
1497
- trace_across_async_contexts: bool = (
1498
- False # BY default, we don't trace across async contexts
1499
- )
1500
-
1501
- def __init__(
1502
- self,
1503
- api_key: str | None = os.getenv("JUDGMENT_API_KEY"),
1504
- organization_id: str | None = os.getenv("JUDGMENT_ORG_ID"),
1505
- project_name: str | None = None,
1506
- deep_tracing: bool = False, # Deep tracing is disabled by default
1507
- enable_monitoring: bool = os.getenv("JUDGMENT_MONITORING", "true").lower()
1508
- == "true",
1509
- enable_evaluations: bool = os.getenv("JUDGMENT_EVALUATIONS", "true").lower()
1510
- == "true",
1511
- # S3 configuration
1512
- use_s3: bool = False,
1513
- s3_bucket_name: Optional[str] = None,
1514
- s3_aws_access_key_id: Optional[str] = None,
1515
- s3_aws_secret_access_key: Optional[str] = None,
1516
- s3_region_name: Optional[str] = None,
1517
- trace_across_async_contexts: bool = False, # BY default, we don't trace across async contexts
1518
- # Background span service configuration
1519
- span_batch_size: int = 50, # Number of spans to batch before sending
1520
- span_flush_interval: float = 1.0, # Time in seconds between automatic flushes
1521
- span_num_workers: int = 10, # Number of worker threads for span processing
1522
- ):
1523
- try:
1524
- if not api_key:
1525
- raise ValueError(
1526
- "api_key parameter must be provided. Please provide a valid API key value or set the JUDGMENT_API_KEY environment variable"
1527
- )
1528
-
1529
- if not organization_id:
1530
- raise ValueError(
1531
- "organization_id parameter must be provided. Please provide a valid organization ID value or set the JUDGMENT_ORG_ID environment variable"
1532
- )
1533
-
1534
- try:
1535
- result, response = validate_api_key(api_key)
1536
- except Exception as e:
1537
- judgeval_logger.error(
1538
- f"Issue with verifying API key, disabling monitoring: {e}"
1539
- )
1540
- enable_monitoring = False
1541
- result = True
1542
-
1543
- if not result:
1544
- raise ValueError(f"Issue with passed in Judgment API key: {response}")
1545
-
1546
- if use_s3 and not s3_bucket_name:
1547
- raise ValueError("S3 bucket name must be provided when use_s3 is True")
1548
-
1549
- self.api_key: str = api_key
1550
- self.project_name: str = project_name or "default_project"
1551
- self.organization_id: str = organization_id
1552
- self.traces: List[Trace] = []
1553
- self.enable_monitoring: bool = enable_monitoring
1554
- self.enable_evaluations: bool = enable_evaluations
1555
- self.class_identifiers: Dict[
1556
- str, str
1557
- ] = {} # Dictionary to store class identifiers
1558
- self.span_id_to_previous_span_id: Dict[str, str | None] = {}
1559
- self.trace_id_to_previous_trace: Dict[str, TraceClient | None] = {}
1560
- self.current_span_id: Optional[str] = None
1561
- self.current_trace: Optional[TraceClient] = None
1562
- self.trace_across_async_contexts: bool = trace_across_async_contexts
1563
- Tracer.trace_across_async_contexts = trace_across_async_contexts
1564
-
1565
- # Initialize S3 storage if enabled
1566
- self.use_s3 = use_s3
1567
- if use_s3:
1568
- from judgeval.common.s3_storage import S3Storage
1569
-
1570
- try:
1571
- self.s3_storage = S3Storage(
1572
- bucket_name=s3_bucket_name,
1573
- aws_access_key_id=s3_aws_access_key_id,
1574
- aws_secret_access_key=s3_aws_secret_access_key,
1575
- region_name=s3_region_name,
1576
- )
1577
- except Exception as e:
1578
- judgeval_logger.error(
1579
- f"Issue with initializing S3 storage, disabling S3: {e}"
1580
- )
1581
- self.use_s3 = False
1582
-
1583
- self.offline_mode = False # This is used to differentiate traces between online and offline (IE experiments vs monitoring page)
1584
- self.deep_tracing: bool = deep_tracing # NEW: Store deep tracing setting
1585
-
1586
- # Initialize background span service
1587
- self.background_span_service: Optional[BackgroundSpanService] = None
1588
- self.background_span_service = BackgroundSpanService(
1589
- judgment_api_key=api_key,
1590
- organization_id=organization_id,
1591
- batch_size=span_batch_size,
1592
- flush_interval=span_flush_interval,
1593
- num_workers=span_num_workers,
1594
- )
1595
- except Exception as e:
1596
- judgeval_logger.error(
1597
- f"Issue with initializing Tracer: {e}. Disabling monitoring and evaluations."
1598
- )
1599
- self.enable_monitoring = False
1600
- self.enable_evaluations = False
1601
-
1602
- def set_current_span(self, span_id: str) -> Optional[contextvars.Token[str | None]]:
1603
- self.span_id_to_previous_span_id[span_id] = self.current_span_id
1604
- self.current_span_id = span_id
1605
- Tracer.current_span_id = span_id
1606
- try:
1607
- token = current_span_var.set(span_id)
1608
- except Exception:
1609
- token = None
1610
- return token
1611
-
1612
- def get_current_span(self) -> Optional[str]:
1613
- try:
1614
- current_span_var_val = current_span_var.get()
1615
- except Exception:
1616
- current_span_var_val = None
1617
- return (
1618
- (self.current_span_id or current_span_var_val)
1619
- if self.trace_across_async_contexts
1620
- else current_span_var_val
1621
- )
1622
-
1623
- def reset_current_span(
1624
- self,
1625
- token: Optional[contextvars.Token[str | None]] = None,
1626
- span_id: Optional[str] = None,
1627
- ):
1628
- try:
1629
- if token:
1630
- current_span_var.reset(token)
1631
- except Exception:
1632
- pass
1633
- if not span_id:
1634
- span_id = self.current_span_id
1635
- if span_id:
1636
- self.current_span_id = self.span_id_to_previous_span_id.get(span_id)
1637
- Tracer.current_span_id = self.current_span_id
1638
-
1639
- def set_current_trace(
1640
- self, trace: TraceClient
1641
- ) -> Optional[contextvars.Token[TraceClient | None]]:
1642
- """
1643
- Set the current trace context in contextvars
1644
- """
1645
- self.trace_id_to_previous_trace[trace.trace_id] = self.current_trace
1646
- self.current_trace = trace
1647
- Tracer.current_trace = trace
1648
- try:
1649
- token = current_trace_var.set(trace)
1650
- except Exception:
1651
- token = None
1652
- return token
1653
-
1654
- def get_current_trace(self) -> Optional[TraceClient]:
1655
- """
1656
- Get the current trace context.
1657
-
1658
- Tries to get the trace client from the context variable first.
1659
- If not found (e.g., context lost across threads/tasks),
1660
- it falls back to the active trace client managed by the callback handler.
1661
- """
1662
- try:
1663
- current_trace_var_val = current_trace_var.get()
1664
- except Exception:
1665
- current_trace_var_val = None
1666
- return (
1667
- (self.current_trace or current_trace_var_val)
1668
- if self.trace_across_async_contexts
1669
- else current_trace_var_val
1670
- )
1671
-
1672
- def reset_current_trace(
1673
- self,
1674
- token: Optional[contextvars.Token[TraceClient | None]] = None,
1675
- trace_id: Optional[str] = None,
1676
- ):
1677
- try:
1678
- if token:
1679
- current_trace_var.reset(token)
1680
- except Exception:
1681
- pass
1682
- if not trace_id and self.current_trace:
1683
- trace_id = self.current_trace.trace_id
1684
- if trace_id:
1685
- self.current_trace = self.trace_id_to_previous_trace.get(trace_id)
1686
- Tracer.current_trace = self.current_trace
1687
-
1688
- @contextmanager
1689
- def trace(
1690
- self, name: str, project_name: str | None = None
1691
- ) -> Generator[TraceClient, None, None]:
1692
- """Start a new trace context using a context manager"""
1693
- trace_id = str(uuid.uuid4())
1694
- project = project_name if project_name is not None else self.project_name
1695
-
1696
- # Get parent trace info from context
1697
- parent_trace = self.get_current_trace()
1698
- parent_trace_id = None
1699
- parent_name = None
1700
-
1701
- if parent_trace:
1702
- parent_trace_id = parent_trace.trace_id
1703
- parent_name = parent_trace.name
1704
-
1705
- trace = TraceClient(
1706
- self,
1707
- trace_id,
1708
- name,
1709
- project_name=project,
1710
- enable_monitoring=self.enable_monitoring,
1711
- enable_evaluations=self.enable_evaluations,
1712
- parent_trace_id=parent_trace_id,
1713
- parent_name=parent_name,
1714
- )
1715
-
1716
- # Set the current trace in context variables
1717
- token = self.set_current_trace(trace)
1718
-
1719
- # Automatically create top-level span
1720
- with trace.span(name or "unnamed_trace"):
1721
- try:
1722
- # Save the trace to the database to handle Evaluations' trace_id referential integrity
1723
- yield trace
1724
- finally:
1725
- # Reset the context variable
1726
- self.reset_current_trace(token)
1727
-
1728
- def log(self, msg: str, label: str = "log", score: int = 1):
1729
- """Log a message with the current span context"""
1730
- current_span_id = self.get_current_span()
1731
- current_trace = self.get_current_trace()
1732
- if current_span_id and current_trace:
1733
- annotation = TraceAnnotation(
1734
- span_id=current_span_id, text=msg, label=label, score=score
1735
- )
1736
- current_trace.add_annotation(annotation)
1737
-
1738
- rprint(f"[bold]{label}:[/bold] {msg}")
1739
-
1740
def identify(
    self,
    identifier: str,
    track_state: bool = False,
    track_attributes: Optional[List[str]] = None,
    field_mappings: Optional[Dict[str, str]] = None,
):
    """
    Class decorator that registers a custom identifier and state-tracking options.

    The decorated class is returned unchanged. Its configuration is stored in
    ``self.class_identifiers`` keyed by the class's ``__name__``, as a dict with
    the keys ``identifier``, ``track_state``, ``track_attributes`` and
    ``field_mappings``.

    Args:
        identifier: The identifier to associate with the decorated class.
            It is used as the instance name in traces.
        track_state: Whether instance state should be captured before and after
            traced calls. Defaults to False.
        track_attributes: Optional list of attribute names to capture. If None,
            all public (non ``_``-prefixed) instance attributes are captured when
            ``track_state`` is True.
        field_mappings: Optional mapping stored with the class configuration
            (``{}`` when omitted). When non-empty, it is attached to the captured
            state under the ``"field_mappings"`` key.
            NOTE(review): the state-capture code does not rename any attributes
            itself; confirm which direction consumers interpret the mapping in.

    Example:
        @tracer.identify(identifier="user_model", track_state=True, track_attributes=["name", "age"], field_mappings={"system_prompt": "instructions"})
        class User:
            # Class implementation
    """

    def decorator(cls):
        # Build a fresh config dict for every decorated class so that no two
        # classes share (and can mutate) the same configuration object.
        self.class_identifiers[cls.__name__] = dict(
            identifier=identifier,
            track_state=track_state,
            track_attributes=track_attributes,
            field_mappings=field_mappings if field_mappings else {},
        )
        return cls

    return decorator
1785
-
1786
def _capture_instance_state(
    self, instance: Any, class_config: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Capture a snapshot of an instance's attributes according to *class_config*.

    Args:
        instance: The object whose attributes are captured.
        class_config: Configuration dict (as registered by ``identify``). The
            optional keys ``"track_attributes"`` and ``"field_mappings"`` are read.

    Returns:
        A new dict:
        - If ``track_attributes`` is non-empty, it maps each listed name to
          ``getattr(instance, name, None)``.
        - Otherwise it contains every public (non ``_``-prefixed) entry of
          ``instance.__dict__``. Instances without a ``__dict__`` (e.g.
          ``__slots__``-only classes) contribute nothing instead of raising
          ``AttributeError``.
        A non-empty ``field_mappings`` is copied into the result under the
        ``"field_mappings"`` key.
    """
    track_attributes = class_config.get("track_attributes")
    field_mappings = class_config.get("field_mappings")

    if track_attributes:
        state = {attr: getattr(instance, attr, None) for attr in track_attributes}
    else:
        # ``__slots__``-only classes have no ``__dict__``. Fall back to an
        # empty mapping so that tracing never breaks the traced call.
        instance_dict = getattr(instance, "__dict__", None) or {}
        state = {k: v for k, v in instance_dict.items() if not k.startswith("_")}

    if field_mappings:
        # Copy so later mutation of the recorded state cannot alter the
        # per-class configuration that every captured state shares.
        state["field_mappings"] = dict(field_mappings)

    return state
1810
-
1811
def _get_instance_state_if_tracked(self, args):
    """
    Return the captured state of ``args[0]`` when its class opted into tracking.

    The first positional argument (typically the ``self`` of a traced method)
    is the candidate instance. Its class name is looked up in
    ``self.class_identifiers``. Capture happens only when that entry is a dict
    whose ``"track_state"`` value is truthy; otherwise ``None`` is returned.
    """
    if not args or not hasattr(args[0], "__class__"):
        return None

    candidate = args[0]
    identifiers = self.class_identifiers
    class_name = candidate.__class__.__name__
    config = identifiers[class_name] if class_name in identifiers else None

    if not isinstance(config, dict) or not config.get("track_state", False):
        return None
    return self._capture_instance_state(candidate, config)
1828
-
1829
def _conditionally_capture_and_record_state(
    self, trace_client_instance: TraceClient, args: tuple, is_before: bool
):
    """Record tracked instance state (if any) on *trace_client_instance*.

    State comes from ``_get_instance_state_if_tracked``. A falsy result
    (``None`` or an empty dict) records nothing. ``is_before`` selects
    ``record_state_before``; otherwise ``record_state_after`` is used.
    """
    state = self._get_instance_state_if_tracked(args)
    if not state:
        return
    recorder = (
        trace_client_instance.record_state_before
        if is_before
        else trace_client_instance.record_state_after
    )
    recorder(state)
1839
-
1840
def observe(
    self,
    func=None,
    *,
    name=None,
    span_type: SpanType = "span",
):
    """
    Decorator to trace function execution with detailed entry/exit information.

    Usage and behaviour:
    - Works bare (``@tracer.observe``) or with keyword arguments
      (``@tracer.observe(name=..., span_type=...)``), on both regular and
      ``async def`` functions.
    - Each call runs inside a span that records the bound arguments, the
      return value, and the agent name derived from the first positional
      argument.
    - For classes registered with ``identify(track_state=True)``, instance
      state is recorded before and after the call.
    - If no trace is active when the wrapped function is called, a new root
      ``TraceClient`` is created for the call. Afterwards the trace is saved
      with ``final_save=True``, a summary dict is appended to ``self.traces``,
      and the trace context is reset.
    - Failures while saving or resetting are logged and never replace the
      wrapped function's return value or exception.

    Args:
        func: The function to decorate.
        name: Optional custom name for the span (defaults to ``func.__name__``).
        span_type: Type of span (default "span").

    Returns:
        The wrapped function. Returns ``func`` unchanged when monitoring is
        disabled or setup fails, or a decorator when ``func`` is omitted.
    """
    try:
        # Monitoring disabled: decorate as a no-op.
        if not self.enable_monitoring:
            return func if func else lambda f: f

        # Called with keyword arguments only: return the real decorator.
        if func is None:
            return lambda f: self.observe(f, name=name, span_type=span_type)

        span_name = name or func.__name__

        # Marker attributes. ``observe_tools`` checks ``_judgment_span_name``
        # to skip methods that are already traced.
        func._judgment_span_name = span_name
        func._judgment_span_type = span_type
    except Exception:
        return func

    def _agent_name_for(args):
        """Return the instance-prefixed agent name when called with a receiver."""
        if args and hasattr(args[0], "__class__"):
            return get_instance_prefixed_name(
                args[0], args[0].__class__.__name__, self.class_identifiers
            )
        return None

    def _new_root_trace():
        """Create a TraceClient that serves as the root of a fresh trace."""
        return TraceClient(
            self,
            str(uuid.uuid4()),
            span_name,
            project_name=self.project_name,
            enable_monitoring=self.enable_monitoring,
            enable_evaluations=self.enable_evaluations,
        )

    def _record_entry(span, agent_name, args, kwargs):
        """Record inputs, agent name and pre-call instance state on *span*."""
        span.record_input(combine_args_kwargs(func, args, kwargs))
        if agent_name:
            span.record_agent_name(agent_name)
        self._conditionally_capture_and_record_state(span, args, is_before=True)

    def _record_exit(span, args, result):
        """Record post-call instance state and the return value on *span*."""
        self._conditionally_capture_and_record_state(span, args, is_before=False)
        span.record_output(result)

    def _finalize_root_trace(root_trace, trace_token):
        """Save *root_trace*, store its summary in ``self.traces``, reset the context.

        Never raises. Errors are logged so they cannot mask the wrapped call's
        outcome, and the context reset is attempted even if saving failed.
        """
        try:
            root_trace.save(final_save=True)
            self.traces.append(
                {
                    "trace_id": root_trace.trace_id,
                    "name": root_trace.name,
                    "created_at": datetime.fromtimestamp(
                        root_trace.start_time or time.time(),
                        timezone.utc,
                    ).isoformat(),
                    "duration": root_trace.get_duration(),
                    "trace_spans": [s.model_dump() for s in root_trace.trace_spans],
                    "offline_mode": self.offline_mode,
                    "parent_trace_id": root_trace.parent_trace_id,
                    "parent_name": root_trace.parent_name,
                }
            )
        except Exception as e:
            judgeval_logger.warning(f"Issue with save: {e}")
        try:
            self.reset_current_trace(trace_token)
        except Exception as e:
            judgeval_logger.warning(f"Issue resetting trace context: {e}")

    if asyncio.iscoroutinefunction(func):

        async def _run_in_span_async(trace_client, agent_name, args, kwargs):
            """Await *func* inside a new span of *trace_client*."""
            with trace_client.span(span_name, span_type=span_type) as span:
                _record_entry(span, agent_name, args, kwargs)
                try:
                    if self.deep_tracing:
                        with _DeepTracer(self):
                            result = await func(*args, **kwargs)
                    else:
                        result = await func(*args, **kwargs)
                except Exception:
                    _capture_exception_for_trace(trace_client, sys.exc_info())
                    raise
                _record_exit(span, args, result)
                return result

        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            agent_name = _agent_name_for(args)
            current_trace = self.get_current_trace()
            if current_trace:
                return await _run_in_span_async(
                    current_trace, agent_name, args, kwargs
                )

            # No active trace: this call becomes the root of a new one.
            root_trace = _new_root_trace()
            trace_token = self.set_current_trace(root_trace)
            try:
                return await _run_in_span_async(root_trace, agent_name, args, kwargs)
            finally:
                _finalize_root_trace(root_trace, trace_token)

        return async_wrapper

    def _run_in_span_sync(trace_client, agent_name, args, kwargs):
        """Call *func* inside a new span of *trace_client*."""
        with trace_client.span(span_name, span_type=span_type) as span:
            _record_entry(span, agent_name, args, kwargs)
            try:
                if self.deep_tracing:
                    with _DeepTracer(self):
                        result = func(*args, **kwargs)
                else:
                    result = func(*args, **kwargs)
            except Exception:
                _capture_exception_for_trace(trace_client, sys.exc_info())
                raise
            _record_exit(span, args, result)
            return result

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        agent_name = _agent_name_for(args)
        current_trace = self.get_current_trace()
        if current_trace:
            return _run_in_span_sync(current_trace, agent_name, args, kwargs)

        # No active trace: this call becomes the root of a new one.
        root_trace = _new_root_trace()
        trace_token = self.set_current_trace(root_trace)
        try:
            return _run_in_span_sync(root_trace, agent_name, args, kwargs)
        finally:
            _finalize_root_trace(root_trace, trace_token)

    return wrapper
2147
-
2148
def observe_tools(
    self,
    cls=None,
    *,
    exclude_methods: Optional[List[str]] = None,
    include_private: bool = False,
    warn_on_double_decoration: bool = True,
):
    """
    Automatically adds @observe(span_type="tool") to all methods in a class.

    The class is modified in place and returned. Supports both ``@tracer.observe_tools``
    and ``@tracer.observe_tools(...)``.

    Handling of special members:
    - ``staticmethod`` and ``classmethod`` members are unwrapped, traced, and
      re-wrapped, so they keep their binding behaviour. Without this, a
      traced staticmethod would receive the instance as an extra argument.
    - Methods that already carry the ``_judgment_span_name`` marker set by
      ``observe`` are skipped.

    Args:
        cls: The class to decorate (automatically provided when used as decorator)
        exclude_methods: List of method names to skip decorating. Defaults to common magic methods
        include_private: Whether to decorate methods starting with underscore. Defaults to False
        warn_on_double_decoration: Whether to log when skipping already-decorated methods
            (and when decoration fails). Defaults to True
    """
    import inspect

    if exclude_methods is None:
        exclude_methods = ["__init__", "__new__", "__del__", "__str__", "__repr__"]

    def decorate_class(cls):
        if not self.enable_monitoring:
            return cls

        for name in dir(cls):
            method = getattr(cls, name)

            if (
                not callable(method)
                or name in exclude_methods
                or (name.startswith("_") and not include_private)
            ):
                continue

            if hasattr(method, "_judgment_span_name"):
                if warn_on_double_decoration:
                    judgeval_logger.info(
                        f"{cls.__name__}.{name} already decorated, skipping"
                    )
                continue

            try:
                # Inspect the raw class attribute (without descriptor binding)
                # so static/class methods can be re-wrapped in their descriptor.
                raw = inspect.getattr_static(cls, name)
                if isinstance(raw, staticmethod):
                    traced = staticmethod(self.observe(raw.__func__, span_type="tool"))
                elif isinstance(raw, classmethod):
                    traced = classmethod(self.observe(raw.__func__, span_type="tool"))
                else:
                    traced = self.observe(method, span_type="tool")
                setattr(cls, name, traced)
            except Exception as e:
                if warn_on_double_decoration:
                    judgeval_logger.warning(
                        f"Failed to decorate {cls.__name__}.{name}: {e}"
                    )

        return cls

    return decorate_class if cls is None else decorate_class(cls)
2208
-
2209
def async_evaluate(self, *args, **kwargs):
    """Forward an evaluation request to the active trace's ``async_evaluate``.

    Does nothing unless both monitoring and evaluations are enabled.

    A ``trace_id`` keyword argument is removed from ``kwargs``. It is passed
    on again only when it is truthy.

    Any exception is logged as a warning rather than raised.
    """
    try:
        if not (self.enable_monitoring and self.enable_evaluations):
            return

        explicit_trace_id = kwargs.pop("trace_id", None)
        trace_client = self.get_current_trace()

        if not trace_client:
            judgeval_logger.warning(
                "No trace found (context var or fallback), skipping evaluation"
            )
            return

        if explicit_trace_id:
            kwargs["trace_id"] = explicit_trace_id
        trace_client.async_evaluate(*args, **kwargs)
    except Exception as e:
        judgeval_logger.warning(f"Issue with async_evaluate: {e}")
2235
-
2236
def update_metadata(self, metadata: dict):
    """
    Pass *metadata* to the active trace's ``update_metadata``.

    Logs a warning and does nothing when no trace is active.

    Args:
        metadata: Metadata as a dictionary
    """
    trace = self.get_current_trace()
    if not trace:
        judgeval_logger.warning("No current trace found, cannot set metadata")
        return
    trace.update_metadata(metadata)
2248
-
2249
def set_customer_id(self, customer_id: str):
    """
    Pass *customer_id* to the active trace's ``set_customer_id``.

    Logs a warning and does nothing when no trace is active.

    Args:
        customer_id: The customer ID to set
    """
    trace = self.get_current_trace()
    if not trace:
        judgeval_logger.warning("No current trace found, cannot set customer ID")
        return
    trace.set_customer_id(customer_id)
2261
-
2262
def set_tags(self, tags: List[Union[str, set, tuple]]):
    """
    Pass *tags* to the active trace's ``set_tags``.

    Logs a warning and does nothing when no trace is active.

    Args:
        tags: List of tags to set
    """
    trace = self.get_current_trace()
    if not trace:
        judgeval_logger.warning("No current trace found, cannot set tags")
        return
    trace.set_tags(tags)
2274
-
2275
def get_background_span_service(self) -> Optional[BackgroundSpanService]:
    """Return the background span service attached to this tracer (may be ``None``)."""
    service = self.background_span_service
    return service
2278
-
2279
def flush_background_spans(self):
    """Ask the background span service, if one is attached, to flush pending spans."""
    service = self.background_span_service
    if not service:
        return
    service.flush()
2283
-
2284
def shutdown_background_service(self):
    """Shut down the attached background span service, if any, and detach it."""
    service = self.background_span_service
    if not service:
        return
    service.shutdown()
    self.background_span_service = None
2289
-
2290
-
2291
def _get_current_trace(
    trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
):
    """Return the active trace client.

    When ``trace_across_async_contexts`` is truthy, the class-level
    ``Tracer.current_trace`` attribute is returned. Otherwise the value of the
    ``current_trace_var`` context variable is returned.

    NOTE: the default argument is bound once, when this module is imported.
    """
    return (
        Tracer.current_trace
        if trace_across_async_contexts
        else current_trace_var.get()
    )
2298
-
2299
-
2300
def wrap(
    client: Any, trace_across_async_contexts: bool = Tracer.trace_across_async_contexts
) -> Any:
    """
    Wraps an API client to add tracing capabilities.

    Supports OpenAI, Together, Anthropic, and Google GenAI clients (sync and
    async). The client instance is patched in place and returned:

    - ``chat.completions.create`` / ``messages.create`` /
      ``models.generate_content`` are traced as ``"llm"`` spans.
    - OpenAI ``responses.create`` and ``beta.chat.completions.parse`` are
      traced when both the attribute and its original method are available.
    - Anthropic ``messages.stream`` is replaced by a traced stream manager.

    Calls made while no trace is active are passed straight to the original
    methods.
    """
    (
        span_name,
        original_create,
        original_responses_create,
        original_stream,
        original_beta_parse,
    ) = _get_client_config(client)

    def _record_input_and_check_streaming(span, kwargs, is_responses=False):
        """Record the call's input on *span* and return whether it streams."""
        is_streaming = kwargs.get("stream", False)

        # The Responses endpoint is recorded raw; others are normalised.
        if is_responses:
            span.record_input(kwargs)
        else:
            span.record_input(_format_input_data(client, **kwargs))

        # Warn about token counting limitations with streaming.
        if isinstance(client, (AsyncOpenAI, OpenAI)) and is_streaming:
            # ``stream_options`` may be passed explicitly as None.
            stream_options = kwargs.get("stream_options") or {}
            if not stream_options.get("include_usage"):
                # No extra positional args: ``logging`` would try to %-format
                # them into this placeholder-free message and report an error.
                judgeval_logger.warning(
                    "OpenAI streaming calls don't include token counts by default. "
                    "To enable token counting with streams, set stream_options={'include_usage': True} "
                    "in your API call arguments."
                )

        return is_streaming

    def _format_and_record_output(span, response, is_streaming, is_async, is_responses):
        """Record *response* on *span*. Streams are wrapped and returned instead."""
        if is_streaming:
            output_entry = span.record_output("<pending stream>")
            wrapper_func = _async_stream_wrapper if is_async else _sync_stream_wrapper
            return wrapper_func(
                response, client, output_entry, trace_across_async_contexts
            )

        format_func = (
            _format_response_output_data if is_responses else _format_output_data
        )
        output, usage = format_func(client, response)
        span.record_output(output)
        span.record_usage(usage)

        # Queue the completed LLM span now that it has input, output and usage.
        current_trace = _get_current_trace(trace_across_async_contexts)
        if current_trace and current_trace.background_span_service:
            current_span_id = current_trace.get_current_span()
            if current_span_id and current_span_id in current_trace.span_id_to_span:
                current_trace.background_span_service.queue_span(
                    current_trace.span_id_to_span[current_span_id],
                    span_state="completed",
                )

        return response

    def _make_traced_async(original, is_responses=False):
        """Build an async replacement for *original* that records an llm span."""

        async def traced(*args, **kwargs):
            current_trace = _get_current_trace(trace_across_async_contexts)
            if not current_trace:
                return await original(*args, **kwargs)

            with current_trace.span(span_name, span_type="llm") as span:
                is_streaming = _record_input_and_check_streaming(
                    span, kwargs, is_responses=is_responses
                )
                try:
                    response_or_iterator = await original(*args, **kwargs)
                    return _format_and_record_output(
                        span, response_or_iterator, is_streaming, True, is_responses
                    )
                except Exception:
                    _capture_exception_for_trace(span, sys.exc_info())
                    raise

        return traced

    def _make_traced_sync(original, is_responses=False):
        """Build a sync replacement for *original* that records an llm span."""

        def traced(*args, **kwargs):
            current_trace = _get_current_trace(trace_across_async_contexts)
            if not current_trace:
                return original(*args, **kwargs)

            with current_trace.span(span_name, span_type="llm") as span:
                is_streaming = _record_input_and_check_streaming(
                    span, kwargs, is_responses=is_responses
                )
                try:
                    response_or_iterator = original(*args, **kwargs)
                    return _format_and_record_output(
                        span, response_or_iterator, is_streaming, False, is_responses
                    )
                except Exception:
                    _capture_exception_for_trace(span, sys.exc_info())
                    raise

        return traced

    def _make_traced_stream(manager_cls, stream_wrapper_func):
        """Build a replacement for ``.stream()`` that returns a traced manager."""

        def traced_stream(*args, **kwargs):
            current_trace = _get_current_trace(trace_across_async_contexts)
            if not current_trace:
                return original_stream(*args, **kwargs)

            return manager_cls(
                original_manager=original_stream(*args, **kwargs),
                client=client,
                span_name=span_name,
                trace_client=current_trace,
                stream_wrapper_func=stream_wrapper_func,
                input_kwargs=kwargs,
                trace_across_async_contexts=trace_across_async_contexts,
            )

        return traced_stream

    def _can_patch_responses():
        """True when ``client.responses.create`` exists and its original is known."""
        return (
            original_responses_create is not None
            and hasattr(client, "responses")
            and hasattr(client.responses, "create")
        )

    def _can_patch_beta_parse():
        """True when ``client.beta.chat.completions.parse`` exists and is known."""
        beta = getattr(client, "beta", None)
        chat = getattr(beta, "chat", None)
        completions = getattr(chat, "completions", None)
        return original_beta_parse is not None and hasattr(completions, "parse")

    # --- Assign Traced Methods to Client Instance ---
    if isinstance(client, (AsyncOpenAI, AsyncTogether)):
        client.chat.completions.create = _make_traced_async(original_create)
        if _can_patch_responses():
            client.responses.create = _make_traced_async(
                original_responses_create, is_responses=True
            )
        if _can_patch_beta_parse():
            client.beta.chat.completions.parse = _make_traced_async(original_beta_parse)
    elif isinstance(client, AsyncAnthropic):
        client.messages.create = _make_traced_async(original_create)
        if original_stream:
            client.messages.stream = _make_traced_stream(
                _TracedAsyncStreamManagerWrapper, _async_stream_wrapper
            )
    elif isinstance(client, genai.client.AsyncClient):
        client.models.generate_content = _make_traced_async(original_create)
    elif isinstance(client, (OpenAI, Together)):
        client.chat.completions.create = _make_traced_sync(original_create)
        if _can_patch_responses():
            client.responses.create = _make_traced_sync(
                original_responses_create, is_responses=True
            )
        if _can_patch_beta_parse():
            client.beta.chat.completions.parse = _make_traced_sync(original_beta_parse)
    elif isinstance(client, Anthropic):
        client.messages.create = _make_traced_sync(original_create)
        if original_stream:
            client.messages.stream = _make_traced_stream(
                _TracedSyncStreamManagerWrapper, _sync_stream_wrapper
            )
    elif isinstance(client, genai.Client):
        client.models.generate_content = _make_traced_sync(original_create)

    return client
2548
-
2549
-
2550
- # Helper functions for client-specific operations
2551
-
2552
-
2553
def _get_client_config(
    client: ApiClient,
) -> tuple[str, Callable, Optional[Callable], Optional[Callable], Optional[Callable]]:
    """Returns configuration tuple for the given API client.

    Optional methods (OpenAI ``responses.create`` and
    ``beta.chat.completions.parse``, Anthropic ``messages.stream``) are looked
    up defensively. If the installed SDK does not provide them, ``None`` is
    returned in their slot instead of raising ``AttributeError``.

    Args:
        client: An instance of an OpenAI, Together, Anthropic or Google GenAI client
            (sync or async).

    Returns:
        tuple: (span_name, create_method, responses_method, stream_method, beta_parse_method)
            - span_name: String identifier for tracing
            - create_method: Reference to the client's creation method
            - responses_method: Reference to the client's responses method (if applicable)
            - stream_method: Reference to the client's stream method (if applicable)
            - beta_parse_method: Reference to the client's beta parse method (if applicable)

    Raises:
        ValueError: If client type is not supported
    """

    def _resolve(obj: Any, path: str) -> Optional[Callable]:
        """Follow dotted attribute *path* from *obj*; None if any hop is missing."""
        for part in path.split("."):
            obj = getattr(obj, part, None)
            if obj is None:
                return None
        return obj

    if isinstance(client, (OpenAI, AsyncOpenAI)):
        return (
            "OPENAI_API_CALL",
            client.chat.completions.create,
            _resolve(client, "responses.create"),
            None,
            _resolve(client, "beta.chat.completions.parse"),
        )
    elif isinstance(client, (Together, AsyncTogether)):
        return "TOGETHER_API_CALL", client.chat.completions.create, None, None, None
    elif isinstance(client, (Anthropic, AsyncAnthropic)):
        return (
            "ANTHROPIC_API_CALL",
            client.messages.create,
            None,
            _resolve(client, "messages.stream"),
            None,
        )
    elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
        return "GOOGLE_API_CALL", client.models.generate_content, None, None, None
    raise ValueError(f"Unsupported client type: {type(client)}")
2593
-
2594
-
2595
def _format_input_data(client: ApiClient, **kwargs) -> dict:
    """Select the request parameters worth recording for *client*'s API.

    Per client type:
    - OpenAI/Together (sync or async): ``model`` and ``messages``, plus
      ``response_format`` when it is truthy.
    - Google GenAI: ``model`` and ``contents``.
    - Any other client is treated as Anthropic: ``model``, ``messages`` and
      ``max_tokens``.
    """
    model = kwargs.get("model")

    if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
        payload = {"model": model, "messages": kwargs.get("messages")}
        response_format = kwargs.get("response_format")
        if response_format:
            payload["response_format"] = response_format
        return payload

    if isinstance(client, (genai.Client, genai.client.AsyncClient)):
        return {"model": model, "contents": kwargs.get("contents")}

    # Anthropic requires the additional max_tokens parameter.
    return {
        "model": model,
        "messages": kwargs.get("messages"),
        "max_tokens": kwargs.get("max_tokens"),
    }
2617
-
2618
-
2619
def _format_response_output_data(client: ApiClient, response: Any) -> tuple:
    """Format an OpenAI Responses-API result into ``(message_content, usage)``.

    Only OpenAI / AsyncOpenAI clients are supported. For any other client a
    warning is logged and ``(None, None)`` is returned.

    Returns:
        tuple: ``(message_content, usage)``.
        - ``message_content`` is the concatenated text of every ``"message"``
          output item. Non-message items (e.g. reasoning or tool calls) and
          segments without text (e.g. refusals) are skipped.
        - ``usage`` is a ``TraceUsage`` built from the response's token
          counts and the costs returned by ``cost_per_token``.
    """
    if not isinstance(client, (OpenAI, AsyncOpenAI)):
        judgeval_logger.warning(f"Unsupported client type: {type(client)}")
        return None, None

    def _collect_output_text(resp: Any) -> str:
        """Join the text segments of all message items in ``resp.output``."""
        parts = []
        for item in getattr(resp, "output", None) or []:
            if getattr(item, "type", None) != "message":
                continue
            for seg in getattr(item, "content", None) or []:
                text = getattr(seg, "text", None)
                if text:
                    parts.append(text)
        return "".join(parts)

    model_name = response.model
    prompt_tokens = response.usage.input_tokens
    completion_tokens = response.usage.output_tokens
    cache_read_input_tokens = response.usage.input_tokens_details.cached_tokens
    cache_creation_input_tokens = 0
    message_content = _collect_output_text(response)

    prompt_cost, completion_cost = cost_per_token(
        model=model_name,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        cache_read_input_tokens=cache_read_input_tokens,
        cache_creation_input_tokens=cache_creation_input_tokens,
    )
    # Compare with None explicitly: a legitimate 0.0 cost (e.g. no completion
    # tokens) must still produce a total instead of discarding it.
    total_cost_usd = (
        prompt_cost + completion_cost
        if prompt_cost is not None and completion_cost is not None
        else None
    )
    usage = TraceUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
        cache_read_input_tokens=cache_read_input_tokens,
        cache_creation_input_tokens=cache_creation_input_tokens,
        prompt_tokens_cost_usd=prompt_cost,
        completion_tokens_cost_usd=completion_cost,
        total_cost_usd=total_cost_usd,
        model_name=model_name,
    )
    return message_content, usage
2664
-
2665
-
2666
def _format_output_data(
    client: ApiClient, response: Any
) -> tuple[Optional[Any], Optional[TraceUsage]]:
    """Extract generated content and token usage from a non-streaming response.

    Normalizes the provider-specific response shapes of OpenAI, Together,
    Google GenAI and Anthropic clients into a consistent structure for
    tracing purposes.

    Args:
        client: The SDK client that produced ``response``; its type selects
            how ``response`` is read.
        response: The raw response object returned by the client call.

    Returns:
        A tuple ``(message_content, usage)``:
          - ``message_content``: the generated text, or the ``parsed``
            structured-output object when an OpenAI/Together message carries
            a truthy ``parsed`` attribute.
          - ``usage``: a ``TraceUsage`` with token counts and costs (costs are
            ``None`` when pricing is unavailable).
        Both elements are ``None`` for unsupported client types.
    """
    prompt_tokens = 0
    completion_tokens = 0
    cache_read_input_tokens = 0
    cache_creation_input_tokens = 0
    model_name = None
    message_content = None

    if isinstance(client, (OpenAI, AsyncOpenAI)):
        model_name = response.model
        prompt_tokens = response.usage.prompt_tokens
        completion_tokens = response.usage.completion_tokens
        # prompt_tokens_details is optional (None on some responses), so fall
        # back to 0 cached tokens instead of raising AttributeError.
        details = getattr(response.usage, "prompt_tokens_details", None)
        cache_read_input_tokens = getattr(details, "cached_tokens", None) or 0
        message = response.choices[0].message
        # Prefer structured output when present, else the plain text.
        message_content = getattr(message, "parsed", None) or message.content

    elif isinstance(client, (Together, AsyncTogether)):
        model_name = "together_ai/" + response.model
        prompt_tokens = response.usage.prompt_tokens
        completion_tokens = response.usage.completion_tokens
        message = response.choices[0].message
        message_content = getattr(message, "parsed", None) or message.content

    elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
        model_name = response.model_version
        # Token counts may be None; normalize to 0 so the sums below work.
        prompt_tokens = response.usage_metadata.prompt_token_count or 0
        completion_tokens = response.usage_metadata.candidates_token_count or 0
        message_content = response.candidates[0].content.parts[0].text

    elif isinstance(client, (Anthropic, AsyncAnthropic)):
        model_name = response.model
        prompt_tokens = response.usage.input_tokens
        completion_tokens = response.usage.output_tokens
        # Cache counters may be None when prompt caching is not in use.
        cache_read_input_tokens = response.usage.cache_read_input_tokens or 0
        cache_creation_input_tokens = (
            response.usage.cache_creation_input_tokens or 0
        )
        message_content = response.content[0].text

    else:
        judgeval_logger.warning(f"Unsupported client type: {type(client)}")
        return None, None

    prompt_cost, completion_cost = cost_per_token(
        model=model_name,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        cache_read_input_tokens=cache_read_input_tokens,
        cache_creation_input_tokens=cache_creation_input_tokens,
    )
    # Either component may legitimately be 0.0 (e.g. no completion tokens);
    # only treat the total as unknown when a component is missing (None).
    total_cost_usd = (
        prompt_cost + completion_cost
        if prompt_cost is not None and completion_cost is not None
        else None
    )
    usage = TraceUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
        cache_read_input_tokens=cache_read_input_tokens,
        cache_creation_input_tokens=cache_creation_input_tokens,
        prompt_tokens_cost_usd=prompt_cost,
        completion_tokens_cost_usd=completion_cost,
        total_cost_usd=total_cost_usd,
        model_name=model_name,
    )
    return message_content, usage
2748
-
2749
-
2750
def combine_args_kwargs(func, args, kwargs):
    """
    Combine positional and keyword arguments into a single dictionary.

    Each positional argument is keyed by the name of the positional parameter
    it binds to in ``func``'s signature. Positional arguments that do not bind
    to a named positional parameter are keyed as ``arg{i}``, where ``i`` is
    their index in ``args``. This covers values absorbed by ``*args`` and
    surplus arguments. Keyword arguments are merged last and win on key
    collisions.

    Args:
        func: The function being called.
        args: Tuple of positional arguments.
        kwargs: Dictionary of keyword arguments.

    Returns:
        A new dictionary combining both ``args`` and ``kwargs``.
    """
    import inspect

    try:
        sig = inspect.signature(func)
    except (TypeError, ValueError):
        # No signature available (non-callable, or an introspection-less
        # builtin): fall back to index-based names.
        return {**{f"arg{i}": arg for i, arg in enumerate(args)}, **kwargs}

    # Only these parameter kinds can receive a positional value. Keyword-only
    # parameters and **kwargs must never be filled from ``args``, and *args
    # absorbs everything past the named positional parameters.
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    positional_names = [
        name
        for name, param in sig.parameters.items()
        if param.kind in positional_kinds
    ]

    args_dict = {
        (positional_names[i] if i < len(positional_names) else f"arg{i}"): arg
        for i, arg in enumerate(args)
    }
    return {**args_dict, **kwargs}
2779
-
2780
-
2781
- # NOTE: This builds once, can be tweaked if we are missing / capturing other unncessary modules
2782
- # @link https://docs.python.org/3.13/library/sysconfig.html
2783
- _TRACE_FILEPATH_BLOCKLIST = tuple(
2784
- os.path.realpath(p) + os.sep
2785
- for p in {
2786
- sysconfig.get_paths()["stdlib"],
2787
- sysconfig.get_paths().get("platstdlib", ""),
2788
- *site.getsitepackages(),
2789
- site.getusersitepackages(),
2790
- *(
2791
- [os.path.join(os.path.dirname(__file__), "../../judgeval/")]
2792
- if os.environ.get("JUDGMENT_DEV")
2793
- else []
2794
- ),
2795
- }
2796
- if p
2797
- )
2798
-
2799
-
2800
class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
    """
    A ThreadPoolExecutor that runs every task inside a copy of the submitting
    thread's ``contextvars`` context.

    Context variables such as ``current_trace_var`` and ``current_span_var``
    are therefore visible inside functions executed by the pool. This lets
    the Tracer keep correct parent-child span relationships across thread
    boundaries.

    ``Executor.map`` schedules its work through ``submit``, so ``map`` gets
    the same context propagation without a separate override.
    """

    def submit(self, fn, /, *args, **kwargs):
        """
        Schedule ``fn(*args, **kwargs)`` in a worker thread, run within a
        snapshot of the caller's current context.

        The snapshot is taken at submission time. Later context-variable
        changes in the submitting thread are not seen by the task, and
        changes made by the task stay in its own copy.

        Returns:
            The ``concurrent.futures.Future`` for the scheduled call.
        """
        # Capture the context now, in the submitting thread; a worker thread
        # would otherwise start from its own, unrelated context.
        ctx = contextvars.copy_context()

        # Bind the call's arguments up front so the worker only needs to
        # invoke one argument-less callable via ctx.run.
        func_with_bound_args = functools.partial(fn, *args, **kwargs)

        # ctx.run executes the bound call inside the captured context on the
        # worker thread.
        return super().submit(ctx.run, func_with_bound_args)
2831
-
2832
-
2833
# Helper functions for stream processing
# ---------------------------------------


def _extract_content_from_chunk(client: ApiClient, chunk: Any) -> Optional[str]:
    """Return the generated text carried by a single stream chunk, if any.

    - OpenAI/Together: reads ``choices[0].delta.content``.
    - Anthropic: only ``content_block_delta`` events carry text.
    - Google GenAI: reads the first part of the first candidate.

    Returns ``None`` for chunks without text, for chunks whose structure is
    unexpected, and for unsupported client types.
    """
    try:
        if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
            return chunk.choices[0].delta.content

        if isinstance(client, (Anthropic, AsyncAnthropic)):
            # Other Anthropic event types (message_start, message_delta, ...)
            # carry no generated text.
            return chunk.delta.text if chunk.type == "content_block_delta" else None

        if isinstance(client, (genai.Client, genai.client.AsyncClient)):
            candidates = chunk.candidates
            if candidates and candidates[0].content and candidates[0].content.parts:
                return candidates[0].content.parts[0].text
    except (AttributeError, IndexError, KeyError):
        # Unexpected chunk shape or no content: report "no text".
        pass
    return None
2858
-
2859
-
2860
def _extract_usage_from_final_chunk(
    client: ApiClient, chunk: Any
) -> Optional[TraceUsage]:
    """Build a ``TraceUsage`` from usage data carried by a stream's final chunk.

    Supported sources:
      - OpenAI/Together: ``chunk.usage``, or ``chunk.choices[0].usage`` when
        the former is absent. Together model names get the ``together_ai/``
        prefix used for cost lookup.
      - Google GenAI: ``chunk.usage_metadata``.
    Anthropic streams do not expose usage on the final chunk, so Anthropic
    (and any unsupported client) yields ``None``.

    Returns:
        A ``TraceUsage`` with token counts and costs. Costs are ``None`` when
        pricing is unavailable. Returns ``None`` when the chunk has no usage
        data or an unexpected structure.
    """
    try:
        if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
            # Usage is usually on the chunk itself; some servers nest it in
            # the first choice instead. It is often not streamed at all.
            usage = getattr(chunk, "usage", None)
            if not usage and chunk.choices:
                usage = getattr(chunk.choices[0], "usage", None)
            if not usage:
                return None

            prompt_tokens = usage.prompt_tokens
            completion_tokens = usage.completion_tokens
            total_tokens = usage.total_tokens
            model_name = chunk.model
            if isinstance(client, (Together, AsyncTogether)):
                model_name = "together_ai/" + model_name

        elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
            metadata = getattr(chunk, "usage_metadata", None)
            if not metadata:
                return None
            # GenAI counts may be None; normalize to 0.
            prompt_tokens = metadata.prompt_token_count or 0
            completion_tokens = metadata.candidates_token_count or 0
            total_tokens = metadata.total_token_count or (
                prompt_tokens + completion_tokens
            )
            model_name = getattr(chunk, "model_version", None)

        else:
            # Anthropic reports usage via message_start/message_delta events,
            # not on the final chunk; nothing to extract here.
            return None

        prompt_cost, completion_cost = cost_per_token(
            model=model_name,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        )
        # A 0.0 component is a valid cost; only None means "unknown".
        total_cost_usd = (
            prompt_cost + completion_cost
            if prompt_cost is not None and completion_cost is not None
            else None
        )
        return TraceUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            prompt_tokens_cost_usd=prompt_cost,
            completion_tokens_cost_usd=completion_cost,
            total_cost_usd=total_cost_usd,
            model_name=model_name,
        )
    except (AttributeError, IndexError, KeyError, TypeError):
        # Usage data missing or malformed: report "no usage".
        return None
2925
-
2926
-
2927
# --- Sync Stream Wrapper ---
def _sync_stream_wrapper(
    original_stream: Iterator,
    client: ApiClient,
    span: TraceSpan,
    trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
) -> Generator[Any, None, None]:
    """Re-yield a synchronous stream while recording its output on ``span``.

    Each chunk is passed through unchanged. When iteration ends for any
    reason (exhaustion, early close, or an exception), the ``finally`` block
    does three things:
      - stores the joined text of all chunks in ``span.output``;
      - stores any usage found on the last chunk in ``span.usage``;
      - queues the span as completed on the current trace's background
        span service, if one exists.
    """
    collected_text = []  # joined once at the end
    most_recent_chunk = None
    try:
        for chunk in original_stream:
            text = _extract_content_from_chunk(client, chunk)
            if text:
                collected_text.append(text)
            # Remember every chunk (not only content ones): usage, when
            # streamed at all, arrives on the last one.
            most_recent_chunk = chunk
            yield chunk
    finally:
        usage = (
            _extract_usage_from_final_chunk(client, most_recent_chunk)
            if most_recent_chunk
            else None
        )
        span.output = "".join(collected_text)
        span.usage = usage

        # Only now is all span data available, so queue it here.
        trace = _get_current_trace(trace_across_async_contexts)
        if trace and trace.background_span_service:
            trace.background_span_service.queue_span(span, span_state="completed")
2966
-
2967
-
2968
# --- Async Stream Wrapper ---
async def _async_stream_wrapper(
    original_stream: AsyncIterator,
    client: ApiClient,
    span: TraceSpan,
    trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
) -> AsyncGenerator[Any, None]:
    """Re-yield an async stream while recording output, usage and cost on ``span``.

    Each chunk is passed through unchanged. When iteration ends for any
    reason, the ``finally`` block:
      - writes the joined text to ``span.output``;
      - resolves token usage into a ``TraceUsage`` on ``span.usage``;
      - sets ``span.duration``;
      - queues the span as completed on the current trace's background span
        service.

    Usage sources, in priority order:
      1. an OpenAI chunk whose ``usage`` attribute is set;
      2. Anthropic ``message_start`` / ``message_delta`` token counts;
      3. ``_extract_usage_from_final_chunk`` on the last content chunk.
    """
    content_parts = []  # joined once at the end
    final_usage_data = None
    last_content_chunk = None
    anthropic_input_tokens = 0
    anthropic_output_tokens = 0
    model_name = ""

    try:
        async for chunk in original_stream:
            # OpenAI's final usage chunk: record usage and pass it through
            # without content extraction.
            if (
                isinstance(client, (AsyncOpenAI, OpenAI))
                and hasattr(chunk, "usage")
                and chunk.usage is not None
            ):
                final_usage_data = {
                    "prompt_tokens": chunk.usage.prompt_tokens,
                    "completion_tokens": chunk.usage.completion_tokens,
                    "total_tokens": chunk.usage.total_tokens,
                }
                model_name = chunk.model
                yield chunk
                continue

            # Anthropic reports input tokens (and the model) on message_start
            # and output tokens on message_delta.
            if isinstance(client, (AsyncAnthropic, Anthropic)) and hasattr(
                chunk, "type"
            ):
                if chunk.type == "message_start":
                    if (
                        hasattr(chunk, "message")
                        and hasattr(chunk.message, "usage")
                        and hasattr(chunk.message.usage, "input_tokens")
                    ):
                        anthropic_input_tokens = chunk.message.usage.input_tokens
                        model_name = chunk.message.model
                elif chunk.type == "message_delta":
                    if hasattr(chunk, "usage") and hasattr(
                        chunk.usage, "output_tokens"
                    ):
                        anthropic_output_tokens = chunk.usage.output_tokens

            content_part = _extract_content_from_chunk(client, chunk)
            if content_part:
                content_parts.append(content_part)
                last_content_chunk = chunk

            yield chunk
    finally:
        anthropic_final_usage = None
        if isinstance(client, (AsyncAnthropic, Anthropic)) and (
            anthropic_input_tokens > 0 or anthropic_output_tokens > 0
        ):
            anthropic_final_usage = {
                "prompt_tokens": anthropic_input_tokens,
                "completion_tokens": anthropic_output_tokens,
                "total_tokens": anthropic_input_tokens + anthropic_output_tokens,
            }

        usage_info = None
        if final_usage_data:
            usage_info = final_usage_data
        elif anthropic_final_usage:
            usage_info = anthropic_final_usage
        elif last_content_chunk:
            usage_info = _extract_usage_from_final_chunk(client, last_content_chunk)

        # Raw token-count dicts still need pricing; a TraceUsage is final.
        if usage_info and not isinstance(usage_info, TraceUsage):
            prompt_cost, completion_cost = cost_per_token(
                model=model_name,
                prompt_tokens=usage_info["prompt_tokens"],
                completion_tokens=usage_info["completion_tokens"],
            )
            # cost_per_token returns (None, None) when pricing fails (e.g. an
            # empty model name). Adding them directly would raise TypeError
            # inside this finally block, and the span would never be queued.
            total_cost_usd = (
                prompt_cost + completion_cost
                if prompt_cost is not None and completion_cost is not None
                else None
            )
            usage_info = TraceUsage(
                prompt_tokens=usage_info["prompt_tokens"],
                completion_tokens=usage_info["completion_tokens"],
                total_tokens=usage_info["total_tokens"],
                prompt_tokens_cost_usd=prompt_cost,
                completion_tokens_cost_usd=completion_cost,
                total_cost_usd=total_cost_usd,
                model_name=model_name,
            )

        if span and hasattr(span, "output"):
            span.output = "".join(content_parts)
            span.usage = usage_info
            start_ts = getattr(span, "created_at", time.time())
            span.duration = time.time() - start_ts

            # Queue the completed LLM span now that all data is available.
            current_trace = _get_current_trace(trace_across_async_contexts)
            if current_trace and current_trace.background_span_service:
                current_trace.background_span_service.queue_span(
                    span, span_state="completed"
                )
3072
-
3073
-
3074
def cost_per_token(*args, **kwargs):
    """Return ``(prompt_cost_usd, completion_cost_usd)`` from ``_original_cost_per_token``.

    All arguments are forwarded unchanged. If the underlying calculation
    raises, a warning is logged and ``(None, None)`` is returned. When both
    costs come back as 0, a warning is logged and the zeros are still
    returned.
    """
    try:
        prompt_usd, completion_usd = _original_cost_per_token(*args, **kwargs)
        if prompt_usd == 0 and completion_usd == 0:
            # Usually means the model is missing from the pricing table.
            judgeval_logger.warning("LiteLLM returned a total of 0 for cost per token")
        return prompt_usd, completion_usd
    except Exception as err:
        judgeval_logger.warning(f"Error calculating cost per token: {err}")
        return None, None
3088
-
3089
-
3090
class _BaseStreamManagerWrapper:
    """Shared plumbing for the traced sync/async stream context managers.

    Holds the wrapped SDK stream manager along with what is needed to record
    the streamed call as an ``llm`` span on ``trace_client``:
      - the client and the call's input kwargs;
      - the span name;
      - the function used to wrap the raw stream iterator.
    """

    def __init__(
        self,
        original_manager,
        client,
        span_name,
        trace_client,
        stream_wrapper_func,
        input_kwargs,
        trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
    ):
        # Wrapped SDK objects and call metadata.
        self._original_manager = original_manager
        self._client = client
        self._input_kwargs = input_kwargs
        self._stream_wrapper_func = stream_wrapper_func
        # Tracing configuration.
        self._span_name = span_name
        self._trace_client = trace_client
        self._trace_across_async_contexts = trace_across_async_contexts
        # Parent span id, filled in by the subclasses' enter methods.
        self._parent_span_id_at_entry = None

    def _create_span(self):
        """Create and register a new ``llm`` span; return ``(span_id, span)``.

        The span's depth is one more than the recorded depth of the parent
        captured at entry, or 0 when that parent is unknown. The depth is
        stored in ``trace_client._span_depths`` and the span is added to the
        trace client.
        """
        started_at = time.time()
        new_span_id = str(uuid.uuid4())
        parent_id = self._parent_span_id_at_entry
        span_depths = self._trace_client._span_depths

        depth = (
            span_depths[parent_id] + 1
            if parent_id and parent_id in span_depths
            else 0
        )
        span_depths[new_span_id] = depth

        new_span = TraceSpan(
            function=self._span_name,
            span_id=new_span_id,
            trace_id=self._trace_client.trace_id,
            depth=depth,
            message=self._span_name,
            created_at=started_at,
            span_type="llm",
            parent_span_id=parent_id,
        )
        self._trace_client.add_span(new_span)
        return new_span_id, new_span

    def _finalize_span(self, span_id):
        """Stamp the span's duration (if it is known) and drop its depth entry."""
        span = self._trace_client.span_id_to_span.get(span_id)
        if span:
            span.duration = time.time() - span.created_at
        self._trace_client._span_depths.pop(span_id, None)
3141
-
3142
-
3143
class _TracedAsyncStreamManagerWrapper(
    _BaseStreamManagerWrapper, AbstractAsyncContextManager
):
    """Async context manager that records a streamed LLM call as one span.

    On entry it:
      - opens an ``llm`` span under the current span and records the call's
        inputs;
      - enters the wrapped SDK manager;
      - returns the raw stream wrapped by ``stream_wrapper_func``.
    On exit it finalizes the span, restores the previous current span and
    delegates to the wrapped manager's ``__aexit__``.
    """

    async def __aenter__(self):
        # Without a trace client there is nothing to record: pass straight
        # through. This must run before any trace-client call.
        if not self._trace_client:
            return await self._original_manager.__aenter__()

        self._parent_span_id_at_entry = self._trace_client.get_current_span()
        span_id, span = self._create_span()
        token = self._trace_client.set_current_span(span_id)
        try:
            span.inputs = _format_input_data(self._client, **self._input_kwargs)
            # Call the original __aenter__ and expect it to be an async iterator.
            raw_iterator = await self._original_manager.__aenter__()
        except BaseException:
            # __aexit__ is not invoked when __aenter__ raises, so undo the
            # span bookkeeping here rather than leaking the current-span context.
            self._finalize_span(span_id)
            self._trace_client.reset_current_span(token)
            raise

        self._span_context_token = token
        # Remember which span we opened, so exit does not depend on ambient
        # context.
        self._traced_span_id = span_id
        span.output = "<pending stream>"
        return self._stream_wrapper_func(
            raw_iterator, self._client, span, self._trace_across_async_contexts
        )

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if hasattr(self, "_span_context_token"):
            self._finalize_span(self._traced_span_id)
            self._trace_client.reset_current_span(self._span_context_token)
            del self._span_context_token
            del self._traced_span_id
        return await self._original_manager.__aexit__(exc_type, exc_val, exc_tb)
3169
-
3170
-
3171
class _TracedSyncStreamManagerWrapper(
    _BaseStreamManagerWrapper, AbstractContextManager
):
    """Context manager that records a streamed LLM call as one span.

    On entry it:
      - opens an ``llm`` span under the current span and records the call's
        inputs;
      - enters the wrapped SDK manager;
      - returns the raw stream wrapped by ``stream_wrapper_func``.
    On exit it finalizes the span, restores the previous current span and
    delegates to the wrapped manager's ``__exit__``.
    """

    def __enter__(self):
        # Without a trace client there is nothing to record: pass straight
        # through. This must run before any trace-client call.
        if not self._trace_client:
            return self._original_manager.__enter__()

        self._parent_span_id_at_entry = self._trace_client.get_current_span()
        span_id, span = self._create_span()
        token = self._trace_client.set_current_span(span_id)
        try:
            span.inputs = _format_input_data(self._client, **self._input_kwargs)
            raw_iterator = self._original_manager.__enter__()
        except BaseException:
            # __exit__ is not invoked when __enter__ raises, so undo the span
            # bookkeeping here rather than leaking the current-span context.
            self._finalize_span(span_id)
            self._trace_client.reset_current_span(token)
            raise

        self._span_context_token = token
        # Remember which span we opened, so exit does not depend on ambient
        # context.
        self._traced_span_id = span_id
        span.output = "<pending stream>"
        return self._stream_wrapper_func(
            raw_iterator, self._client, span, self._trace_across_async_contexts
        )

    def __exit__(self, exc_type, exc_val, exc_tb):
        if hasattr(self, "_span_context_token"):
            self._finalize_span(self._traced_span_id)
            self._trace_client.reset_current_span(self._span_context_token)
            del self._span_context_token
            del self._traced_span_id
        return self._original_manager.__exit__(exc_type, exc_val, exc_tb)
3196
-
3197
-
3198
# --- Helper function for instance-prefixed qual_name ---
def get_instance_prefixed_name(instance, class_name, class_identifiers):
    """
    Look up the identifying attribute value of ``instance`` for ``class_name``.

    ``class_identifiers`` maps class names to configs whose ``"identifier"``
    entry names the attribute to read.

    Returns:
        The attribute's value when ``class_name`` is registered, otherwise
        ``None``.

    Raises:
        Exception: if ``class_name`` is registered but ``instance`` lacks the
            configured attribute.
    """
    if class_name not in class_identifiers:
        return None

    identifier_attr = class_identifiers[class_name]["identifier"]
    if not hasattr(instance, identifier_attr):
        raise Exception(
            f"Attribute {identifier_attr} does not exist for {class_name}. Check your identify() decorator."
        )
    return getattr(instance, identifier_attr)