judgeval 0.0.44__py3-none-any.whl → 0.0.46__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. judgeval/__init__.py +5 -4
  2. judgeval/clients.py +6 -6
  3. judgeval/common/__init__.py +7 -2
  4. judgeval/common/exceptions.py +2 -3
  5. judgeval/common/logger.py +74 -49
  6. judgeval/common/s3_storage.py +30 -23
  7. judgeval/common/tracer.py +1273 -939
  8. judgeval/common/utils.py +416 -244
  9. judgeval/constants.py +73 -61
  10. judgeval/data/__init__.py +1 -1
  11. judgeval/data/custom_example.py +3 -2
  12. judgeval/data/datasets/dataset.py +80 -54
  13. judgeval/data/datasets/eval_dataset_client.py +131 -181
  14. judgeval/data/example.py +67 -43
  15. judgeval/data/result.py +11 -9
  16. judgeval/data/scorer_data.py +4 -2
  17. judgeval/data/tool.py +25 -16
  18. judgeval/data/trace.py +57 -29
  19. judgeval/data/trace_run.py +5 -11
  20. judgeval/evaluation_run.py +22 -82
  21. judgeval/integrations/langgraph.py +546 -184
  22. judgeval/judges/base_judge.py +1 -2
  23. judgeval/judges/litellm_judge.py +33 -11
  24. judgeval/judges/mixture_of_judges.py +128 -78
  25. judgeval/judges/together_judge.py +22 -9
  26. judgeval/judges/utils.py +14 -5
  27. judgeval/judgment_client.py +259 -271
  28. judgeval/rules.py +169 -142
  29. judgeval/run_evaluation.py +462 -305
  30. judgeval/scorers/api_scorer.py +20 -11
  31. judgeval/scorers/exceptions.py +1 -0
  32. judgeval/scorers/judgeval_scorer.py +77 -58
  33. judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +46 -15
  34. judgeval/scorers/judgeval_scorers/api_scorers/answer_correctness.py +3 -2
  35. judgeval/scorers/judgeval_scorers/api_scorers/answer_relevancy.py +3 -2
  36. judgeval/scorers/judgeval_scorers/api_scorers/classifier_scorer.py +12 -11
  37. judgeval/scorers/judgeval_scorers/api_scorers/comparison.py +7 -5
  38. judgeval/scorers/judgeval_scorers/api_scorers/contextual_precision.py +3 -2
  39. judgeval/scorers/judgeval_scorers/api_scorers/contextual_recall.py +3 -2
  40. judgeval/scorers/judgeval_scorers/api_scorers/contextual_relevancy.py +5 -2
  41. judgeval/scorers/judgeval_scorers/api_scorers/derailment_scorer.py +2 -1
  42. judgeval/scorers/judgeval_scorers/api_scorers/execution_order.py +17 -8
  43. judgeval/scorers/judgeval_scorers/api_scorers/faithfulness.py +3 -2
  44. judgeval/scorers/judgeval_scorers/api_scorers/groundedness.py +3 -2
  45. judgeval/scorers/judgeval_scorers/api_scorers/hallucination.py +3 -2
  46. judgeval/scorers/judgeval_scorers/api_scorers/instruction_adherence.py +3 -2
  47. judgeval/scorers/judgeval_scorers/api_scorers/json_correctness.py +8 -9
  48. judgeval/scorers/judgeval_scorers/api_scorers/summarization.py +4 -4
  49. judgeval/scorers/judgeval_scorers/api_scorers/tool_dependency.py +5 -5
  50. judgeval/scorers/judgeval_scorers/api_scorers/tool_order.py +5 -2
  51. judgeval/scorers/judgeval_scorers/classifiers/text2sql/text2sql_scorer.py +9 -10
  52. judgeval/scorers/prompt_scorer.py +48 -37
  53. judgeval/scorers/score.py +86 -53
  54. judgeval/scorers/utils.py +11 -7
  55. judgeval/tracer/__init__.py +1 -1
  56. judgeval/utils/alerts.py +23 -12
  57. judgeval/utils/{data_utils.py → file_utils.py} +5 -9
  58. judgeval/utils/requests.py +29 -0
  59. judgeval/version_check.py +5 -2
  60. {judgeval-0.0.44.dist-info → judgeval-0.0.46.dist-info}/METADATA +79 -135
  61. judgeval-0.0.46.dist-info/RECORD +69 -0
  62. judgeval-0.0.44.dist-info/RECORD +0 -68
  63. {judgeval-0.0.44.dist-info → judgeval-0.0.46.dist-info}/WHEEL +0 -0
  64. {judgeval-0.0.44.dist-info → judgeval-0.0.46.dist-info}/licenses/LICENSE.md +0 -0
judgeval/common/tracer.py CHANGED
@@ -1,6 +1,7 @@
 """
 Tracing system for judgeval that allows for function tracing using decorators.
 """
+
 # Standard library imports
 import asyncio
 import functools
@@ -16,9 +17,13 @@ import warnings
 import contextvars
 import sys
 import json
-from contextlib import contextmanager, asynccontextmanager, AbstractAsyncContextManager, AbstractContextManager # Import context manager bases
-from dataclasses import dataclass, field
-from datetime import datetime, timezone
+from contextlib import (
+    contextmanager,
+    AbstractAsyncContextManager,
+    AbstractContextManager,
+)  # Import context manager bases
+from dataclasses import dataclass
+from datetime import datetime
 from http import HTTPStatus
 from typing import (
     Any,
@@ -37,9 +42,9 @@ from rich import print as rprint
 import types

 # Third-party imports
-import requests
+from requests import RequestException
+from judgeval.utils.requests import requests
 from litellm import cost_per_token as _original_cost_per_token
-from rich import print as rprint
 from openai import OpenAI, AsyncOpenAI
 from together import Together, AsyncTogether
 from anthropic import Anthropic, AsyncAnthropic
@@ -50,12 +55,7 @@ from judgeval.constants import (
     JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
     JUDGMENT_TRACES_SAVE_API_URL,
     JUDGMENT_TRACES_UPSERT_API_URL,
-    JUDGMENT_TRACES_USAGE_CHECK_API_URL,
-    JUDGMENT_TRACES_USAGE_UPDATE_API_URL,
     JUDGMENT_TRACES_FETCH_API_URL,
-    RABBITMQ_HOST,
-    RABBITMQ_PORT,
-    RABBITMQ_QUEUE,
     JUDGMENT_TRACES_DELETE_API_URL,
     JUDGMENT_PROJECT_DELETE_API_URL,
     JUDGMENT_TRACES_SPANS_BATCH_API_URL,
@@ -70,32 +70,37 @@ from judgeval.common.exceptions import JudgmentAPIError

 # Standard library imports needed for the new class
 import concurrent.futures
-from collections.abc import Iterator, AsyncIterator # Add Iterator and AsyncIterator
+from collections.abc import Iterator, AsyncIterator  # Add Iterator and AsyncIterator
 import queue
 import atexit

 # Define context variables for tracking the current trace and the current span within a trace
-current_trace_var = contextvars.ContextVar[Optional['TraceClient']]('current_trace', default=None)
-current_span_var = contextvars.ContextVar('current_span', default=None) # ContextVar for the active span name
+current_trace_var = contextvars.ContextVar[Optional["TraceClient"]](
+    "current_trace", default=None
+)
+current_span_var = contextvars.ContextVar[Optional[str]](
+    "current_span", default=None
+)  # ContextVar for the active span id

 # Define type aliases for better code readability and maintainability
-ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic, AsyncOpenAI, AsyncAnthropic, AsyncTogether, genai.Client, genai.client.AsyncClient] # Supported API clients
-SpanType = Literal['span', 'tool', 'llm', 'evaluation', 'chain']
+ApiClient: TypeAlias = Union[
+    OpenAI,
+    Together,
+    Anthropic,
+    AsyncOpenAI,
+    AsyncAnthropic,
+    AsyncTogether,
+    genai.Client,
+    genai.client.AsyncClient,
+]  # Supported API clients
+SpanType = Literal["span", "tool", "llm", "evaluation", "chain"]

-# --- Evaluation Config Dataclass (Moved from langgraph.py) ---
-@dataclass
-class EvaluationConfig:
-    """Configuration for triggering an evaluation from the handler."""
-    scorers: List[Union[APIJudgmentScorer, JudgevalScorer]]
-    example: Example
-    model: Optional[str] = None
-    log_results: Optional[bool] = True
-# --- End Evaluation Config Dataclass ---

 # Temporary as a POC to have log use the existing annotations feature until log endpoints are ready
 @dataclass
 class TraceAnnotation:
     """Represents a single annotation for a trace span."""
+
     span_id: str
     text: str
    label: str
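
Note: the two ContextVars above are what let the tracer know which trace and which span are active in the current execution context, including across threads and asyncio tasks. A minimal standalone sketch of the pattern (illustrative names only, not judgeval's helpers):

    import contextvars
    from typing import Optional

    current_span: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
        "current_span", default=None
    )

    def child():
        # A callee sees whatever span its caller's context set.
        print("child sees:", current_span.get())

    token = current_span.set("span-parent")
    child()                    # -> child sees: span-parent
    current_span.reset(token)  # restore the previous value
    print(current_span.get())  # -> None

Resetting with the token (rather than setting None) is what lets nested spans unwind correctly; this is how TraceClient.reset_current_span is used by the span() context manager further down.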
@@ -105,24 +110,27 @@ class TraceAnnotation:
         """Convert the annotation to a dictionary format for storage/transmission."""
         return {
             "span_id": self.span_id,
-            "annotation": {
-                "text": self.text,
-                "label": self.label,
-                "score": self.score
-            }
+            "annotation": {"text": self.text, "label": self.label, "score": self.score},
         }
-
+
+
 class TraceManagerClient:
     """
     Client for handling trace endpoints with the Judgment API
-
+

     Operations include:
     - Fetching a trace by id
     - Saving a trace
     - Deleting a trace
     """
-    def __init__(self, judgment_api_key: str, organization_id: str, tracer: Optional["Tracer"] = None):
+
+    def __init__(
+        self,
+        judgment_api_key: str,
+        organization_id: str,
+        tracer: Optional["Tracer"] = None,
+    ):
         self.judgment_api_key = judgment_api_key
         self.organization_id = organization_id
         self.tracer = tracer
@@ -139,17 +147,19 @@ class TraceManagerClient:
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {self.judgment_api_key}",
-                "X-Organization-Id": self.organization_id
+                "X-Organization-Id": self.organization_id,
             },
-            verify=True
+            verify=True,
         )

         if response.status_code != HTTPStatus.OK:
             raise ValueError(f"Failed to fetch traces: {response.text}")
-
+
         return response.json()

-    def save_trace(self, trace_data: dict, offline_mode: bool = False, final_save: bool = True):
+    def save_trace(
+        self, trace_data: dict, offline_mode: bool = False, final_save: bool = True
+    ):
         """
         Saves a trace to the Judgment Supabase and optionally to S3 if configured.

@@ -158,12 +168,12 @@ class TraceManagerClient:
             offline_mode: Whether running in offline mode
             final_save: Whether this is the final save (controls S3 saving)
         NOTE we save empty traces in order to properly handle async operations; we need something in the DB to associate the async results with
-
+
         Returns:
             dict: Server response containing UI URL and other metadata
         """
         # Save to Judgment API
-
+
         def fallback_encoder(obj):
             """
             Custom JSON encoder fallback.
@@ -180,7 +190,7 @@ class TraceManagerClient:
             except Exception as e:
                 # If both fail, you might return a placeholder or re-raise
                 return f"<Unserializable object of type {type(obj).__name__}: {e}>"
-
+
         serialized_trace_data = json.dumps(trace_data, default=fallback_encoder)
         response = requests.post(
             JUDGMENT_TRACES_SAVE_API_URL,
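
Note: the nested fallback_encoder above is what keeps a trace save from failing on non-JSON-serializable span contents; json.dumps consults it only for objects it cannot encode natively. The same pattern in isolation (a standalone sketch, not the judgeval API):

    import json

    def fallback_encoder(obj):
        # Mirror the pattern above: try __dict__, then str(), never raise.
        try:
            return obj.__dict__
        except Exception:
            try:
                return str(obj)
            except Exception as e:
                return f"<Unserializable object of type {type(obj).__name__}: {e}>"

    class Session:
        def __init__(self):
            self.user = "abc"

    # The set falls back to str(); Session serializes via __dict__.
    print(json.dumps({"ids": {1, 2}, "session": Session()}, default=fallback_encoder))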
@@ -188,71 +198,46 @@ class TraceManagerClient:
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {self.judgment_api_key}",
-                "X-Organization-Id": self.organization_id
+                "X-Organization-Id": self.organization_id,
             },
-            verify=True
+            verify=True,
         )
-
+
         if response.status_code == HTTPStatus.BAD_REQUEST:
-            raise ValueError(f"Failed to save trace data: Check your Trace name for conflicts, set overwrite=True to overwrite existing traces: {response.text}")
+            raise ValueError(
+                f"Failed to save trace data: Check your Trace name for conflicts, set overwrite=True to overwrite existing traces: {response.text}"
+            )
         elif response.status_code != HTTPStatus.OK:
             raise ValueError(f"Failed to save trace data: {response.text}")
-
+
         # Parse server response
         server_response = response.json()
-
+
         # If S3 storage is enabled, save to S3 only on final save
         if self.tracer and self.tracer.use_s3 and final_save:
             try:
                 s3_key = self.tracer.s3_storage.save_trace(
                     trace_data=trace_data,
                     trace_id=trace_data["trace_id"],
-                    project_name=trace_data["project_name"]
+                    project_name=trace_data["project_name"],
                 )
                 print(f"Trace also saved to S3 at key: {s3_key}")
             except Exception as e:
                 warnings.warn(f"Failed to save trace to S3: {str(e)}")
-
+
         if not offline_mode and "ui_results_url" in server_response:
             pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={server_response['ui_results_url']}]View Trace[/link]\n"
             rprint(pretty_str)
-
-        return server_response

-    def check_usage_limits(self, count: int = 1):
-        """
-        Check if the organization can use the requested number of traces without exceeding limits.
-
-        Args:
-            count: Number of traces to check for (default: 1)
-
-        Returns:
-            dict: Server response with rate limit status and usage info
-
-        Raises:
-            ValueError: If rate limits would be exceeded or other errors occur
-        """
-        response = requests.post(
-            JUDGMENT_TRACES_USAGE_CHECK_API_URL,
-            json={"count": count},
-            headers={
-                "Content-Type": "application/json",
-                "Authorization": f"Bearer {self.judgment_api_key}",
-                "X-Organization-Id": self.organization_id
-            },
-            verify=True
-        )
-
-        if response.status_code == HTTPStatus.FORBIDDEN:
-            # Rate limits exceeded
-            error_data = response.json()
-            raise ValueError(f"Rate limit exceeded: {error_data.get('detail', 'Monthly trace limit reached')}")
-        elif response.status_code != HTTPStatus.OK:
-            raise ValueError(f"Failed to check usage limits: {response.text}")
-
-        return response.json()
+        return server_response

-    def upsert_trace(self, trace_data: dict, offline_mode: bool = False, show_link: bool = True, final_save: bool = True):
+    def upsert_trace(
+        self,
+        trace_data: dict,
+        offline_mode: bool = False,
+        show_link: bool = True,
+        final_save: bool = True,
+    ):
         """
         Upserts a trace to the Judgment API (always overwrites if exists).

@@ -261,10 +246,11 @@ class TraceManagerClient:
             offline_mode: Whether running in offline mode
             show_link: Whether to show the UI link (for live tracing)
             final_save: Whether this is the final save (controls S3 saving)
-
+
         Returns:
             dict: Server response containing UI URL and other metadata
         """
+
         def fallback_encoder(obj):
             """
             Custom JSON encoder fallback.
@@ -277,7 +263,7 @@ class TraceManagerClient:
                 return str(obj)
             except Exception as e:
                 return f"<Unserializable object of type {type(obj).__name__}: {e}>"
-
+
         serialized_trace_data = json.dumps(trace_data, default=fallback_encoder)

         response = requests.post(
@@ -286,63 +272,34 @@ class TraceManagerClient:
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {self.judgment_api_key}",
-                "X-Organization-Id": self.organization_id
+                "X-Organization-Id": self.organization_id,
             },
-            verify=True
+            verify=True,
         )
-
+
         if response.status_code != HTTPStatus.OK:
             raise ValueError(f"Failed to upsert trace data: {response.text}")
-
+
         # Parse server response
         server_response = response.json()
-
+
         # If S3 storage is enabled, save to S3 only on final save
         if self.tracer and self.tracer.use_s3 and final_save:
             try:
                 s3_key = self.tracer.s3_storage.save_trace(
                     trace_data=trace_data,
                     trace_id=trace_data["trace_id"],
-                    project_name=trace_data["project_name"]
+                    project_name=trace_data["project_name"],
                 )
                 print(f"Trace also saved to S3 at key: {s3_key}")
             except Exception as e:
                 warnings.warn(f"Failed to save trace to S3: {str(e)}")
-
+
         if not offline_mode and show_link and "ui_results_url" in server_response:
             pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={server_response['ui_results_url']}]View Trace[/link]\n"
             rprint(pretty_str)
-
-        return server_response

-    def update_usage_counters(self, count: int = 1):
-        """
-        Update trace usage counters after successfully saving traces.
-
-        Args:
-            count: Number of traces to count (default: 1)
-
-        Returns:
-            dict: Server response with updated usage information
-
-        Raises:
-            ValueError: If the update fails
-        """
-        response = requests.post(
-            JUDGMENT_TRACES_USAGE_UPDATE_API_URL,
-            json={"count": count},
-            headers={
-                "Content-Type": "application/json",
-                "Authorization": f"Bearer {self.judgment_api_key}",
-                "X-Organization-Id": self.organization_id
-            },
-            verify=True
-        )
-
-        if response.status_code != HTTPStatus.OK:
-            raise ValueError(f"Failed to update usage counters: {response.text}")
-
-        return response.json()
+        return server_response

     ## TODO: Should have a log endpoint, endpoint should also support batched payloads
     def save_annotation(self, annotation: TraceAnnotation):
@@ -351,24 +308,24 @@ class TraceManagerClient:
             "annotation": {
                 "text": annotation.text,
                 "label": annotation.label,
-                "score": annotation.score
-            }
-        }
+                "score": annotation.score,
+            },
+        }

         response = requests.post(
             JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
             json=json_data,
             headers={
-                'Content-Type': 'application/json',
-                'Authorization': f'Bearer {self.judgment_api_key}',
-                'X-Organization-Id': self.organization_id
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {self.judgment_api_key}",
+                "X-Organization-Id": self.organization_id,
             },
-            verify=True
+            verify=True,
         )
-
+
         if response.status_code != HTTPStatus.OK:
             raise ValueError(f"Failed to save annotation: {response.text}")
-
+
         return response.json()

     def delete_trace(self, trace_id: str):
@@ -383,15 +340,15 @@ class TraceManagerClient:
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {self.judgment_api_key}",
-                "X-Organization-Id": self.organization_id
-            }
+                "X-Organization-Id": self.organization_id,
+            },
         )

         if response.status_code != HTTPStatus.OK:
             raise ValueError(f"Failed to delete trace: {response.text}")
-
+
         return response.json()
-
+
     def delete_traces(self, trace_ids: List[str]):
         """
         Delete a batch of traces from the database.
@@ -404,15 +361,15 @@ class TraceManagerClient:
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {self.judgment_api_key}",
-                "X-Organization-Id": self.organization_id
-            }
+                "X-Organization-Id": self.organization_id,
+            },
         )

         if response.status_code != HTTPStatus.OK:
             raise ValueError(f"Failed to delete trace: {response.text}")
-
+
         return response.json()
-
+
     def delete_project(self, project_name: str):
         """
         Deletes a project from the server. Which also deletes all evaluations and traces associated with the project.
@@ -425,31 +382,31 @@ class TraceManagerClient:
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {self.judgment_api_key}",
-                "X-Organization-Id": self.organization_id
-            }
+                "X-Organization-Id": self.organization_id,
+            },
         )

         if response.status_code != HTTPStatus.OK:
             raise ValueError(f"Failed to delete traces: {response.text}")
-
+
         return response.json()


 class TraceClient:
     """Client for managing a single trace context"""
-
+
     def __init__(
         self,
-        tracer: Optional["Tracer"],
+        tracer: "Tracer",
         trace_id: Optional[str] = None,
         name: str = "default",
-        project_name: str = None,
+        project_name: str | None = None,
         overwrite: bool = False,
         rules: Optional[List[Rule]] = None,
         enable_monitoring: bool = True,
         enable_evaluations: bool = True,
         parent_trace_id: Optional[str] = None,
-        parent_name: Optional[str] = None
+        parent_name: Optional[str] = None,
     ):
         self.name = name
         self.trace_id = trace_id or str(uuid.uuid4())
@@ -461,39 +418,48 @@ class TraceClient:
         self.enable_evaluations = enable_evaluations
         self.parent_trace_id = parent_trace_id
         self.parent_name = parent_name
+        self.customer_id: Optional[str] = None  # Added customer_id attribute
+        self.tags: List[Union[str, set, tuple]] = []  # Added tags attribute
+        self.metadata: Dict[str, Any] = {}
+        self.has_notification: Optional[bool] = False  # Initialize has_notification
         self.trace_spans: List[TraceSpan] = []
         self.span_id_to_span: Dict[str, TraceSpan] = {}
         self.evaluation_runs: List[EvaluationRun] = []
         self.annotations: List[TraceAnnotation] = []
-        self.start_time = None # Will be set after first successful save
-        self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id, tracer)
-        self.visited_nodes = []
-        self.executed_tools = []
-        self.executed_node_tools = []
-        self._span_depths: Dict[str, int] = {} # NEW: To track depth of active spans
-
+        self.start_time: Optional[float] = (
+            None  # Will be set after first successful save
+        )
+        self.trace_manager_client = TraceManagerClient(
+            tracer.api_key, tracer.organization_id, tracer
+        )
+        self._span_depths: Dict[str, int] = {}  # NEW: To track depth of active spans
+
         # Get background span service from tracer
-        self.background_span_service = tracer.get_background_span_service() if tracer else None
+        self.background_span_service = (
+            tracer.get_background_span_service() if tracer else None
+        )

     def get_current_span(self):
         """Get the current span from the context var"""
         return self.tracer.get_current_span()
-
+
     def set_current_span(self, span: Any):
         """Set the current span from the context var"""
         return self.tracer.set_current_span(span)
-
+
     def reset_current_span(self, token: Any):
         """Reset the current span from the context var"""
         self.tracer.reset_current_span(token)
-
+
     @contextmanager
     def span(self, name: str, span_type: SpanType = "span"):
         """Context manager for creating a trace span, managing the current span via contextvars"""
         is_first_span = len(self.trace_spans) == 0
         if is_first_span:
             try:
-                trace_id, server_response = self.save_with_rate_limiting(overwrite=self.overwrite, final_save=False)
+                trace_id, server_response = self.save(
+                    overwrite=self.overwrite, final_save=False
+                )
                 # Set start_time after first successful save
                 if self.start_time is None:
                     self.start_time = time.time()
@@ -501,19 +467,23 @@ class TraceClient:
             except Exception as e:
                 warnings.warn(f"Failed to save initial trace for live tracking: {e}")
         start_time = time.time()
-
+
         # Generate a unique ID for *this specific span invocation*
         span_id = str(uuid.uuid4())
-
-        parent_span_id = self.get_current_span() # Get ID of the parent span from context var
-        token = self.set_current_span(span_id) # Set *this* span's ID as the current one
-
+
+        parent_span_id = (
+            self.get_current_span()
+        )  # Get ID of the parent span from context var
+        token = self.set_current_span(
+            span_id
+        )  # Set *this* span's ID as the current one
+
         current_depth = 0
         if parent_span_id and parent_span_id in self._span_depths:
             current_depth = self._span_depths[parent_span_id] + 1
-
-        self._span_depths[span_id] = current_depth # Store depth by span_id
-
+
+        self._span_depths[span_id] = current_depth  # Store depth by span_id
+
         span = TraceSpan(
             span_id=span_id,
             trace_id=self.trace_id,
@@ -525,23 +495,21 @@ class TraceClient:
             function=name,
         )
         self.add_span(span)
-
-
-
+
         # Queue span with initial state (input phase)
         if self.background_span_service:
             self.background_span_service.queue_span(span, span_state="input")
-
+
         try:
             yield self
         finally:
             duration = time.time() - start_time
             span.duration = duration
-
+
             # Queue span with completed state (output phase)
             if self.background_span_service:
                 self.background_span_service.queue_span(span, span_state="completed")
-
+
             # Clean up depth tracking for this span_id
             if span_id in self._span_depths:
                 del self._span_depths[span_id]
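
Note: with the changes above, entering the first span performs an initial save for live tracking, and each span is queued to the background service once with its inputs and again on completion. Calling code is unchanged; a sketch of typical use (assuming a TraceClient named trace is already in hand):

    # Hypothetical caller; `trace` is a TraceClient obtained from the Tracer.
    with trace.span("retrieve_documents", span_type="tool"):
        trace.record_input({"query": "pricing page"})
        docs = ["doc-1", "doc-2"]  # stand-in for real tool work
        trace.record_output(docs)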
@@ -561,12 +529,11 @@ class TraceClient:
         expected_tools: Optional[List[str]] = None,
         additional_metadata: Optional[Dict[str, Any]] = None,
         model: Optional[str] = None,
-        span_id: Optional[str] = None, # <<< ADDED optional span_id parameter
-        log_results: Optional[bool] = True
+        span_id: Optional[str] = None,  # <<< ADDED optional span_id parameter
     ):
         if not self.enable_evaluations:
             return
-
+
         start_time = time.time()  # Record start time

         try:
@@ -574,21 +541,35 @@ class TraceClient:
             if not scorers:
                 warnings.warn("No valid scorers available for evaluation")
                 return
-
+
             # Prevent using JudgevalScorer with rules - only APIJudgmentScorer allowed with rules
-            if self.rules and any(isinstance(scorer, JudgevalScorer) for scorer in scorers):
-                raise ValueError("Cannot use Judgeval scorers, you can only use API scorers when using rules. Please either remove rules or use only APIJudgmentScorer types.")
-
+            if self.rules and any(
+                isinstance(scorer, JudgevalScorer) for scorer in scorers
+            ):
+                raise ValueError(
+                    "Cannot use Judgeval scorers, you can only use API scorers when using rules. Please either remove rules or use only APIJudgmentScorer types."
+                )
+
         except Exception as e:
             warnings.warn(f"Failed to load scorers: {str(e)}")
             return
-
+
         # If example is not provided, create one from the individual parameters
         if example is None:
             # Check if any of the individual parameters are provided
-            if any(param is not None for param in [input, actual_output, expected_output, context,
-                                                   retrieval_context, tools_called, expected_tools,
-                                                   additional_metadata]):
+            if any(
+                param is not None
+                for param in [
+                    input,
+                    actual_output,
+                    expected_output,
+                    context,
+                    retrieval_context,
+                    tools_called,
+                    expected_tools,
+                    additional_metadata,
+                ]
+            ):
                 example = Example(
                     input=input,
                     actual_output=actual_output,
@@ -600,12 +581,14 @@ class TraceClient:
                     additional_metadata=additional_metadata,
                 )
             else:
-                raise ValueError("Either 'example' or at least one of the individual parameters (input, actual_output, etc.) must be provided")
-
+                raise ValueError(
+                    "Either 'example' or at least one of the individual parameters (input, actual_output, etc.) must be provided"
+                )
+
         # Check examples before creating evaluation run
-
+
         # check_examples([example], scorers)
-
+
         # --- Modification: Capture span_id immediately ---
         # span_id_at_eval_call = current_span_var.get()
         # print(f"[TraceClient.async_evaluate] Captured span ID at eval call: {span_id_at_eval_call}")
@@ -617,37 +600,32 @@ class TraceClient:
         # Combine the trace-level rules with any evaluation-specific rules)
         eval_run = EvaluationRun(
             organization_id=self.tracer.organization_id,
-            log_results=log_results,
             project_name=self.project_name,
             eval_name=f"{self.name.capitalize()}-"
-                      f"{span_id_to_use}-" # Keep original eval name format using context var if available
-                      f"[{','.join(scorer.score_type.capitalize() for scorer in scorers)}]",
+            f"{span_id_to_use}-"  # Keep original eval name format using context var if available
+            f"[{','.join(scorer.score_type.capitalize() for scorer in scorers)}]",
             examples=[example],
             scorers=scorers,
             model=model,
-            metadata={},
             judgment_api_key=self.tracer.api_key,
             override=self.overwrite,
-            trace_span_id=span_id_to_use, # Pass the determined ID
-            rules=self.rules # Use the combined rules
+            trace_span_id=span_id_to_use,
         )
-
+
         self.add_eval_run(eval_run, start_time)  # Pass start_time to record_evaluation
-
+
         # Queue evaluation run through background service
         if self.background_span_service and span_id_to_use:
             # Get the current span data to avoid race conditions
             current_span = self.span_id_to_span.get(span_id_to_use)
             if current_span:
                 self.background_span_service.queue_evaluation_run(
-                    eval_run,
-                    span_id=span_id_to_use,
-                    span_data=current_span
+                    eval_run, span_id=span_id_to_use, span_data=current_span
                 )
-
+
     def add_eval_run(self, eval_run: EvaluationRun, start_time: float):
-        # --- Modification: Use span_id from eval_run ---
-        current_span_id = eval_run.trace_span_id # Get ID from the eval_run object
+        # --- Modification: Use span_id from eval_run ---
+        current_span_id = eval_run.trace_span_id  # Get ID from the eval_run object
         # print(f"[TraceClient.add_eval_run] Using span_id from eval_run: {current_span_id}")
         # --- End Modification ---

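Note: since log_results, metadata, and rules no longer flow into the EvaluationRun built here, a caller of async_evaluate now supplies only scorers, example fields, and an optional model. A hedged sketch of a call under the new signature (the scorer import path and constructor arguments are assumptions based on this package's scorer modules, and the model name is illustrative):

    from judgeval.scorers import FaithfulnessScorer  # assumed import path

    # `trace` is the active TraceClient
    trace.async_evaluate(
        scorers=[FaithfulnessScorer(threshold=0.8)],  # assumed constructor
        input="What does the pricing page say?",
        actual_output="The basic plan is $10/month.",
        retrieval_context=["Basic plan: $10 per month."],
        model="gpt-4.1-mini",  # illustrative model name
    )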
@@ -658,10 +636,10 @@ class TraceClient:
         self.evaluation_runs.append(eval_run)

     def add_annotation(self, annotation: TraceAnnotation):
-        """Add an annotation to this trace context"""
-        self.annotations.append(annotation)
-        return self
-
+        """Add an annotation to this trace context"""
+        self.annotations.append(annotation)
+        return self
+
     def record_input(self, inputs: dict):
         current_span_id = self.get_current_span()
         if current_span_id:
@@ -670,17 +648,20 @@ class TraceClient:
             if "self" in inputs:
                 del inputs["self"]
             span.inputs = inputs
-
+
             # Queue span with input data
-            if self.background_span_service:
-                self.background_span_service.queue_span(span, span_state="input")
-
+            try:
+                if self.background_span_service:
+                    self.background_span_service.queue_span(span, span_state="input")
+            except Exception as e:
+                warnings.warn(f"Failed to queue span with input data: {e}")
+
     def record_agent_name(self, agent_name: str):
         current_span_id = self.get_current_span()
         if current_span_id:
             span = self.span_id_to_span[current_span_id]
             span.agent_name = agent_name
-
+
             # Queue span with agent_name data
             if self.background_span_service:
                 self.background_span_service.queue_span(span, span_state="agent_name")
@@ -695,11 +676,11 @@ class TraceClient:
         if current_span_id:
             span = self.span_id_to_span[current_span_id]
             span.state_before = state
-
+
             # Queue span with state_before data
             if self.background_span_service:
                 self.background_span_service.queue_span(span, span_state="state_before")
-
+
     def record_state_after(self, state: dict):
         """Records the agent's state after a tool execution on the current span.

@@ -710,7 +691,7 @@ class TraceClient:
         if current_span_id:
             span = self.span_id_to_span[current_span_id]
             span.state_after = state
-
+
             # Queue span with state_after data
             if self.background_span_service:
                 self.background_span_service.queue_span(span, span_state="state_after")
@@ -720,19 +701,19 @@ class TraceClient:
         try:
             result = await coroutine
             setattr(span, field, result)
-
+
             # Queue span with output data now that coroutine is complete
             if self.background_span_service and field == "output":
                 self.background_span_service.queue_span(span, span_state="output")
-
+
             return result
         except Exception as e:
             setattr(span, field, f"Error: {str(e)}")
-
+
             # Queue span even if there was an error
             if self.background_span_service and field == "output":
                 self.background_span_service.queue_span(span, span_state="output")
-
+
             raise

     def record_output(self, output: Any):
@@ -740,56 +721,56 @@ class TraceClient:
         if current_span_id:
             span = self.span_id_to_span[current_span_id]
             span.output = "<pending>" if inspect.iscoroutine(output) else output
-
+
             if inspect.iscoroutine(output):
                 asyncio.create_task(self._update_coroutine(span, output, "output"))
-
+
             # # Queue span with output data (unless it's pending)
             if self.background_span_service and not inspect.iscoroutine(output):
                 self.background_span_service.queue_span(span, span_state="output")

-            return span # Return the created entry
+            return span  # Return the created entry
         # Removed else block - original didn't have one
-        return None # Return None if no span_id found
-
+        return None  # Return None if no span_id found
+
     def record_usage(self, usage: TraceUsage):
         current_span_id = self.get_current_span()
         if current_span_id:
             span = self.span_id_to_span[current_span_id]
             span.usage = usage
-
+
             # Queue span with usage data
             if self.background_span_service:
                 self.background_span_service.queue_span(span, span_state="usage")
-
-            return span # Return the created entry
+
+            return span  # Return the created entry
         # Removed else block - original didn't have one
-        return None # Return None if no span_id found
-
+        return None  # Return None if no span_id found
+
     def record_error(self, error: Dict[str, Any]):
         current_span_id = self.get_current_span()
         if current_span_id:
             span = self.span_id_to_span[current_span_id]
             span.error = error
-
+
             # Queue span with error data
             if self.background_span_service:
                 self.background_span_service.queue_span(span, span_state="error")
-
+
             return span
         return None
-
+
     def add_span(self, span: TraceSpan):
         """Add a trace span to this trace context"""
         self.trace_spans.append(span)
         self.span_id_to_span[span.span_id] = span
         return self
-
+
     def print(self):
         """Print the complete trace with proper visual structure"""
         for span in self.trace_spans:
             span.print_span()
-
+
     def get_duration(self) -> float:
         """
         Get the total duration of this trace
@@ -798,57 +779,23 @@ class TraceClient:
             return 0.0  # No duration if trace hasn't been saved yet
         return time.time() - self.start_time

-    def save(self, overwrite: bool = False) -> Tuple[str, dict]:
-        """
-        Save the current trace to the database.
-        Returns a tuple of (trace_id, server_response) where server_response contains the UI URL and other metadata.
-        """
-        # Set start_time if this is the first save
-        if self.start_time is None:
-            self.start_time = time.time()
-
-        # Calculate total elapsed time
-        total_duration = self.get_duration()
-        # Create trace document - Always use standard keys for top-level counts
-        trace_data = {
-            "trace_id": self.trace_id,
-            "name": self.name,
-            "project_name": self.project_name,
-            "created_at": datetime.fromtimestamp(self.start_time, timezone.utc).isoformat(),
-            "duration": total_duration,
-            "trace_spans": [span.model_dump() for span in self.trace_spans],
-            "evaluation_runs": [run.model_dump() for run in self.evaluation_runs],
-            "overwrite": overwrite,
-            "offline_mode": self.tracer.offline_mode,
-            "parent_trace_id": self.parent_trace_id,
-            "parent_name": self.parent_name
-        }
-        # --- Log trace data before saving ---
-        server_response = self.trace_manager_client.save_trace(trace_data, offline_mode=self.tracer.offline_mode, final_save=True)
-
-        # upload annotations
-        # TODO: batch to the log endpoint
-        for annotation in self.annotations:
-            self.trace_manager_client.save_annotation(annotation)
-
-        return self.trace_id, server_response
-
-    def save_with_rate_limiting(self, overwrite: bool = False, final_save: bool = False) -> Tuple[str, dict]:
+    def save(
+        self, overwrite: bool = False, final_save: bool = False
+    ) -> Tuple[str, dict]:
         """
         Save the current trace to the database with rate limiting checks.
         First checks usage limits, then upserts the trace if allowed.
-
+
         Args:
             overwrite: Whether to overwrite existing traces
             final_save: Whether this is the final save (updates usage counters)
-
+
         Returns a tuple of (trace_id, server_response) where server_response contains the UI URL and other metadata.
         """

-
         # Calculate total elapsed time
         total_duration = self.get_duration()
-
+
         # Create trace document
         trace_data = {
             "trace_id": self.trace_id,
@@ -861,35 +808,20 @@ class TraceClient:
             "overwrite": overwrite,
             "offline_mode": self.tracer.offline_mode,
             "parent_trace_id": self.parent_trace_id,
-            "parent_name": self.parent_name
+            "parent_name": self.parent_name,
+            "customer_id": self.customer_id,
+            "tags": self.tags,
+            "metadata": self.metadata,
         }
-
-        # Check usage limits first
-        try:
-            usage_check_result = self.trace_manager_client.check_usage_limits(count=1)
-            # Usage check passed silently - no need to show detailed info
-        except ValueError as e:
-            # Rate limit exceeded
-            warnings.warn(f"Rate limit check failed for live tracing: {e}")
-            raise e
-
+
         # If usage check passes, upsert the trace
         server_response = self.trace_manager_client.upsert_trace(
-            trace_data,
+            trace_data,
             offline_mode=self.tracer.offline_mode,
             show_link=not final_save,  # Show link only on initial save, not final save
-            final_save=final_save # Pass final_save to control S3 saving
+            final_save=final_save,  # Pass final_save to control S3 saving
         )

-        # Update usage counters only on final save
-        if final_save:
-            try:
-                usage_update_result = self.trace_manager_client.update_usage_counters(count=1)
-                # Usage updated silently - no need to show detailed usage info
-            except ValueError as e:
-                # Log warning but don't fail the trace save since the trace was already saved
-                warnings.warn(f"Usage counter update failed (trace was still saved): {e}")
-
         # Upload annotations
         # TODO: batch to the log endpoint
         for annotation in self.annotations:
@@ -900,8 +832,77 @@ class TraceClient:

     def delete(self):
         return self.trace_manager_client.delete_trace(self.trace_id)
-
-def _capture_exception_for_trace(current_trace: Optional['TraceClient'], exc_info: ExcInfo):
+
+    def update_metadata(self, metadata: dict):
+        """
+        Set metadata for this trace.
+
+        Args:
+            metadata: Metadata as a dictionary
+
+        Supported keys:
+        - customer_id: ID of the customer using this trace
+        - tags: List of tags for this trace
+        - has_notification: Whether this trace has a notification
+        - overwrite: Whether to overwrite existing traces
+        - name: Name of the trace
+        """
+        for k, v in metadata.items():
+            if k == "customer_id":
+                if v is not None:
+                    self.customer_id = str(v)
+                else:
+                    self.customer_id = None
+            elif k == "tags":
+                if isinstance(v, list):
+                    # Validate that all items in the list are of the expected types
+                    for item in v:
+                        if not isinstance(item, (str, set, tuple)):
+                            raise ValueError(
+                                f"Tags must be a list of strings, sets, or tuples, got item of type {type(item)}"
+                            )
+                    self.tags = v
+                else:
+                    raise ValueError(
+                        f"Tags must be a list of strings, sets, or tuples, got {type(v)}"
+                    )
+            elif k == "has_notification":
+                if not isinstance(v, bool):
+                    raise ValueError(
+                        f"has_notification must be a boolean, got {type(v)}"
+                    )
+                self.has_notification = v
+            elif k == "overwrite":
+                if not isinstance(v, bool):
+                    raise ValueError(f"overwrite must be a boolean, got {type(v)}")
+                self.overwrite = v
+            elif k == "name":
+                self.name = v
+            else:
+                self.metadata[k] = v
+
+    def set_customer_id(self, customer_id: str):
+        """
+        Set the customer ID for this trace.
+
+        Args:
+            customer_id: The customer ID to set
+        """
+        self.update_metadata({"customer_id": customer_id})
+
+    def set_tags(self, tags: List[Union[str, set, tuple]]):
+        """
+        Set the tags for this trace.
+
+        Args:
+            tags: List of tags to set
+        """
+        self.update_metadata({"tags": tags})
+
+
+def _capture_exception_for_trace(
+    current_trace: Optional["TraceClient"], exc_info: ExcInfo
+):
     if not current_trace:
         return

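Note: the new metadata surface on TraceClient is additive: update_metadata routes the known keys (customer_id, tags, has_notification, overwrite, name) to typed attributes, stores any other key in self.metadata, and save() now ships customer_id, tags, and metadata in the trace document. A usage sketch built directly on the methods above:

    # `trace` is a TraceClient; these calls use only the methods added in this diff.
    trace.set_customer_id("customer-1234")
    trace.set_tags(["checkout", "beta"])
    trace.update_metadata(
        {
            "name": "checkout-flow",    # routed to trace.name
            "overwrite": True,          # routed to trace.overwrite (must be a bool)
            "experiment": "variant-b",  # unknown key: stored in trace.metadata
        }
    )

Passing a non-bool for overwrite or has_notification, or a tag that is not a str, set, or tuple, raises ValueError.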
@@ -909,18 +910,20 @@ def _capture_exception_for_trace(current_trace: Optional['TraceClient'], exc_inf
     formatted_exception = {
         "type": exc_type.__name__ if exc_type else "UnknownExceptionType",
         "message": str(exc_value) if exc_value else "No exception message",
-        "traceback": traceback.format_tb(exc_traceback_obj) if exc_traceback_obj else []
+        "traceback": traceback.format_tb(exc_traceback_obj)
+        if exc_traceback_obj
+        else [],
     }
-
+
     # This is where we specially handle exceptions that we might want to collect additional data for.
     # When we do this, always try checking the module from sys.modules instead of importing. This will
     # Let us support a wider range of exceptions without needing to import them for all clients.
-
-    # Most clients (requests, httpx, urllib) support the standard format of exposing error.request.url and error.response.status_code
+
+    # Most clients (requests, httpx, urllib) support the standard format of exposing error.request.url and error.response.status_code
     # The alternative is to hand select libraries we want from sys.modules and check for them:
     # As an example: requests_module = sys.modules.get("requests", None) // then do things with requests_module;

-    # General HTTP Like errors
+    # General HTTP Like errors
     try:
         url = getattr(getattr(exc_value, "request", None), "url", None)
         status_code = getattr(getattr(exc_value, "response", None), "status_code", None)
@@ -929,34 +932,43 @@ def _capture_exception_for_trace(current_trace: Optional['TraceClient'], exc_inf
             "url": url if url else "Unknown URL",
             "status_code": status_code if status_code else None,
         }
-    except Exception as e:
+    except Exception:
         pass

     current_trace.record_error(formatted_exception)
-
+
     # Queue the span with error state through background service
     if current_trace.background_span_service:
         current_span_id = current_trace.get_current_span()
         if current_span_id and current_span_id in current_trace.span_id_to_span:
             error_span = current_trace.span_id_to_span[current_span_id]
-            current_trace.background_span_service.queue_span(error_span, span_state="error")
+            current_trace.background_span_service.queue_span(
+                error_span, span_state="error"
+            )
+

 class BackgroundSpanService:
     """
     Background service for queueing and batching trace spans for efficient saving.
-
+
     This service:
     - Queues spans as they complete
     - Batches them for efficient network usage
     - Sends spans periodically or when batches reach a certain size
     - Handles automatic flushing when the main event terminates
     """
-
-    def __init__(self, judgment_api_key: str, organization_id: str,
-                 batch_size: int = 10, flush_interval: float = 5.0, num_workers: int = 1):
+
+    def __init__(
+        self,
+        judgment_api_key: str,
+        organization_id: str,
+        batch_size: int = 10,
+        flush_interval: float = 5.0,
+        num_workers: int = 1,
+    ):
         """
         Initialize the background span service.
-
+
         Args:
             judgment_api_key: API key for Judgment service
             organization_id: Organization ID
@@ -969,41 +981,41 @@ class BackgroundSpanService:
         self.batch_size = batch_size
         self.flush_interval = flush_interval
         self.num_workers = max(1, num_workers)  # Ensure at least 1 worker
-
+
         # Queue for pending spans
-        self._span_queue = queue.Queue()
-
+        self._span_queue: queue.Queue[Dict[str, Any]] = queue.Queue()
+
         # Background threads for processing spans
-        self._worker_threads = []
+        self._worker_threads: List[threading.Thread] = []
         self._shutdown_event = threading.Event()
-
+
         # Track spans that have been sent
-        self._sent_spans = set()
-
+        # self._sent_spans = set()
+
         # Register cleanup on exit
         atexit.register(self.shutdown)
-
+
         # Start the background workers
         self._start_workers()
-
+
     def _start_workers(self):
         """Start the background worker threads."""
         for i in range(self.num_workers):
             if len(self._worker_threads) < self.num_workers:
                 worker_thread = threading.Thread(
-                    target=self._worker_loop,
-                    daemon=True,
-                    name=f"SpanWorker-{i+1}"
+                    target=self._worker_loop, daemon=True, name=f"SpanWorker-{i + 1}"
                 )
                 worker_thread.start()
                 self._worker_threads.append(worker_thread)
-
+
     def _worker_loop(self):
         """Main worker loop that processes spans in batches."""
         batch = []
         last_flush_time = time.time()
-        pending_task_count = 0 # Track how many tasks we've taken from queue but not marked done
-
+        pending_task_count = (
+            0  # Track how many tasks we've taken from queue but not marked done
+        )
+
         while not self._shutdown_event.is_set() or self._span_queue.qsize() > 0:
             try:
                 # First, do a blocking get to wait for at least one item
@@ -1015,7 +1027,7 @@ class BackgroundSpanService:
             except queue.Empty:
                 # No new spans, continue to check for flush conditions
                 pass
-
+
             # Then, do non-blocking gets to drain any additional available items
             # up to our batch size limit
             while len(batch) < self.batch_size:
@@ -1026,24 +1038,23 @@ class BackgroundSpanService:
                 except queue.Empty:
                     # No more items immediately available
                     break
-
+
             current_time = time.time()
-            should_flush = (
-                len(batch) >= self.batch_size or
-                (batch and (current_time - last_flush_time) >= self.flush_interval)
+            should_flush = len(batch) >= self.batch_size or (
+                batch and (current_time - last_flush_time) >= self.flush_interval
             )
-
+
             if should_flush and batch:
                 self._send_batch(batch)
-
+
                 # Only mark tasks as done after successful sending
                 for _ in range(pending_task_count):
                     self._span_queue.task_done()
                 pending_task_count = 0  # Reset counter
-
+
                 batch.clear()
                 last_flush_time = current_time
-
+
         except Exception as e:
             warnings.warn(f"Error in span service worker loop: {e}")
             # On error, still need to mark tasks as done to prevent hanging
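
Note: the reworked condition is the standard hybrid batching policy: flush when the batch is full, or when it is non-empty and older than flush_interval. The same logic in isolation:

    import time

    def should_flush(batch, batch_size, last_flush_time, flush_interval):
        # Flush on size, or on age when anything is waiting.
        return len(batch) >= batch_size or (
            len(batch) > 0 and (time.time() - last_flush_time) >= flush_interval
        )

    print(should_flush(["s"] * 10, 10, time.time(), 5.0))   # True: batch full
    print(should_flush(["s"], 10, time.time() - 6.0, 5.0))  # True: batch stale
    print(should_flush([], 10, time.time() - 60.0, 5.0))    # False: nothing queued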
@@ -1051,53 +1062,50 @@ class BackgroundSpanService:
                 self._span_queue.task_done()
             pending_task_count = 0
             batch.clear()
-
+
         # Final flush on shutdown
         if batch:
             self._send_batch(batch)
             # Mark remaining tasks as done
             for _ in range(pending_task_count):
                 self._span_queue.task_done()
-
+
     def _send_batch(self, batch: List[Dict[str, Any]]):
         """
         Send a batch of spans to the server.
-
+
         Args:
             batch: List of span dictionaries to send
         """
         if not batch:
             return
-
+
         try:
-            # Group spans by type for different endpoints
+            # Group items by type for different endpoints
             spans_to_send = []
             evaluation_runs_to_send = []
-
+
            for item in batch:
-                if item['type'] == 'span':
-                    spans_to_send.append(item['data'])
-                elif item['type'] == 'evaluation_run':
-                    evaluation_runs_to_send.append(item['data'])
-
+                if item["type"] == "span":
+                    spans_to_send.append(item["data"])
+                elif item["type"] == "evaluation_run":
+                    evaluation_runs_to_send.append(item["data"])
+
             # Send spans if any
             if spans_to_send:
                 self._send_spans_batch(spans_to_send)
-
+
             # Send evaluation runs if any
             if evaluation_runs_to_send:
                 self._send_evaluation_runs_batch(evaluation_runs_to_send)
-
+
         except Exception as e:
-            warnings.warn(f"Failed to send span batch: {e}")
-
+            warnings.warn(f"Failed to send batch: {e}")
+
     def _send_spans_batch(self, spans: List[Dict[str, Any]]):
         """Send a batch of spans to the spans endpoint."""
-        payload = {
-            "spans": spans,
-            "organization_id": self.organization_id
-        }
-
+        payload = {"spans": spans, "organization_id": self.organization_id}
+
         # Serialize with fallback encoder
         def fallback_encoder(obj):
             try:
@@ -1107,10 +1115,10 @@ class BackgroundSpanService:
                 return str(obj)
             except Exception as e:
                 return f"<Unserializable object of type {type(obj).__name__}: {e}>"
-
+
         try:
             serialized_data = json.dumps(payload, default=fallback_encoder)
-
+
             # Send the actual HTTP request to the batch endpoint
             response = requests.post(
                 JUDGMENT_TRACES_SPANS_BATCH_API_URL,
@@ -1118,21 +1126,22 @@ class BackgroundSpanService:
                 headers={
                     "Content-Type": "application/json",
                     "Authorization": f"Bearer {self.judgment_api_key}",
-                    "X-Organization-Id": self.organization_id
+                    "X-Organization-Id": self.organization_id,
                 },
                 verify=True,
-                timeout=30 # Add timeout to prevent hanging
+                timeout=30,  # Add timeout to prevent hanging
             )
-
+
             if response.status_code != HTTPStatus.OK:
-                warnings.warn(f"Failed to send spans batch: HTTP {response.status_code} - {response.text}")
-
-
-        except requests.RequestException as e:
+                warnings.warn(
+                    f"Failed to send spans batch: HTTP {response.status_code} - {response.text}"
+                )
+
+        except RequestException as e:
             warnings.warn(f"Network error sending spans batch: {e}")
         except Exception as e:
             warnings.warn(f"Failed to serialize or send spans batch: {e}")
-
+
     def _send_evaluation_runs_batch(self, evaluation_runs: List[Dict[str, Any]]):
         """Send a batch of evaluation runs with their associated span data to the endpoint."""
         # Structure payload to include both evaluation run data and span data
@@ -1142,22 +1151,23 @@ class BackgroundSpanService:
             entry = {
                 "evaluation_run": {
                     # Extract evaluation run fields (excluding span-specific fields)
-                    key: value for key, value in eval_data.items()
-                    if key not in ['associated_span_id', 'span_data', 'queued_at']
+                    key: value
+                    for key, value in eval_data.items()
+                    if key not in ["associated_span_id", "span_data", "queued_at"]
                 },
                 "associated_span": {
-                    "span_id": eval_data.get('associated_span_id'),
-                    "span_data": eval_data.get('span_data')
+                    "span_id": eval_data.get("associated_span_id"),
+                    "span_data": eval_data.get("span_data"),
                 },
-                "queued_at": eval_data.get('queued_at')
+                "queued_at": eval_data.get("queued_at"),
             }
             evaluation_entries.append(entry)
-
+
         payload = {
             "organization_id": self.organization_id,
-            "evaluation_entries": evaluation_entries # Each entry contains both eval run + span data
+            "evaluation_entries": evaluation_entries,  # Each entry contains both eval run + span data
         }
-
+
         # Serialize with fallback encoder
         def fallback_encoder(obj):
             try:
@@ -1167,10 +1177,10 @@ class BackgroundSpanService:
                 return str(obj)
             except Exception as e:
                 return f"<Unserializable object of type {type(obj).__name__}: {e}>"
-
+
         try:
             serialized_data = json.dumps(payload, default=fallback_encoder)
-
+
             # Send the actual HTTP request to the batch endpoint
             response = requests.post(
                 JUDGMENT_TRACES_EVALUATION_RUNS_BATCH_API_URL,
@@ -1178,25 +1188,26 @@ class BackgroundSpanService:
                 headers={
                     "Content-Type": "application/json",
                     "Authorization": f"Bearer {self.judgment_api_key}",
-                    "X-Organization-Id": self.organization_id
+                    "X-Organization-Id": self.organization_id,
                 },
                 verify=True,
-                timeout=30 # Add timeout to prevent hanging
+                timeout=30,  # Add timeout to prevent hanging
             )
-
+
             if response.status_code != HTTPStatus.OK:
-                warnings.warn(f"Failed to send evaluation runs batch: HTTP {response.status_code} - {response.text}")
-
-
-        except requests.RequestException as e:
+                warnings.warn(
+                    f"Failed to send evaluation runs batch: HTTP {response.status_code} - {response.text}"
+                )
+
+        except RequestException as e:
             warnings.warn(f"Network error sending evaluation runs batch: {e}")
         except Exception as e:
             warnings.warn(f"Failed to send evaluation runs batch: {e}")
-
+
     def queue_span(self, span: TraceSpan, span_state: str = "input"):
         """
         Queue a span for background sending.
-
+
         Args:
             span: The TraceSpan object to queue
             span_state: State of the span ("input", "output", "completed")
@@ -1207,15 +1218,17 @@ class BackgroundSpanService:
             "data": {
                 **span.model_dump(),
                 "span_state": span_state,
-                "queued_at": time.time()
-            }
+                "queued_at": time.time(),
+            },
         }
         self._span_queue.put(span_data)
-
-    def queue_evaluation_run(self, evaluation_run: EvaluationRun, span_id: str, span_data: TraceSpan):
+
+    def queue_evaluation_run(
+        self, evaluation_run: EvaluationRun, span_id: str, span_data: TraceSpan
+    ):
         """
         Queue an evaluation run for background sending.
-
+
         Args:
             evaluation_run: The EvaluationRun object to queue
             span_id: The span ID associated with this evaluation run
@@ -1228,11 +1241,11 @@ class BackgroundSpanService:
1228
1241
  **evaluation_run.model_dump(),
1229
1242
  "associated_span_id": span_id,
1230
1243
  "span_data": span_data.model_dump(), # Include span data to avoid race conditions
1231
- "queued_at": time.time()
1232
- }
1244
+ "queued_at": time.time(),
1245
+ },
1233
1246
  }
1234
1247
  self._span_queue.put(eval_data)
1235
-
1248
+
1236
1249
  def flush(self):
1237
1250
  """Force immediate sending of all queued spans."""
1238
1251
  try:
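A note on why `flush()` can be a single `join()` call: `queue.Queue` counts unfinished tasks, and `join()` blocks until each dequeued item is acknowledged with `task_done()`. A stripped-down sketch of that producer/worker handshake (this is the underlying stdlib pattern, not the service itself):

```python
import queue
import threading

q: "queue.Queue[dict]" = queue.Queue()

def worker():
    while True:
        item = q.get()
        # ... batch and send item over HTTP ...
        q.task_done()  # lets q.join() return once every item is processed

threading.Thread(target=worker, daemon=True).start()
q.put({"span_state": "input"})
q.join()  # the flush(): blocks until the worker has drained the queue
```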
@@ -1240,89 +1253,101 @@ class BackgroundSpanService:
             self._span_queue.join()
         except Exception as e:
             warnings.warn(f"Error during flush: {e}")
-
+
     def shutdown(self):
         """Shutdown the background service and flush remaining spans."""
         if self._shutdown_event.is_set():
             return
-
+
         try:
             # Signal shutdown to stop new items from being queued
             self._shutdown_event.set()
-
+
             # Try to flush any remaining spans
             try:
                 self.flush()
             except Exception as e:
-                warnings.warn(f"Error during final flush: {e}")
+                warnings.warn(f"Error during final flush: {e}")
         except Exception as e:
             warnings.warn(f"Error during BackgroundSpanService shutdown: {e}")
         finally:
             # Clear the worker threads list (daemon threads will be killed automatically)
             self._worker_threads.clear()
-
+
     def get_queue_size(self) -> int:
         """Get the current size of the span queue."""
         return self._span_queue.qsize()
 
+
 class _DeepTracer:
     _instance: Optional["_DeepTracer"] = None
     _lock: threading.Lock = threading.Lock()
     _refcount: int = 0
-    _span_stack: contextvars.ContextVar[List[Dict[str, Any]]] = contextvars.ContextVar("_deep_profiler_span_stack", default=[])
-    _skip_stack: contextvars.ContextVar[List[str]] = contextvars.ContextVar("_deep_profiler_skip_stack", default=[])
+    _span_stack: contextvars.ContextVar[List[Dict[str, Any]]] = contextvars.ContextVar(
+        "_deep_profiler_span_stack", default=[]
+    )
+    _skip_stack: contextvars.ContextVar[List[str]] = contextvars.ContextVar(
+        "_deep_profiler_skip_stack", default=[]
+    )
     _original_sys_trace: Optional[Callable] = None
     _original_threading_trace: Optional[Callable] = None
 
-    def __init__(self, tracer: 'Tracer'):
+    def __init__(self, tracer: "Tracer"):
         self._tracer = tracer
 
     def _get_qual_name(self, frame) -> str:
         func_name = frame.f_code.co_name
         module_name = frame.f_globals.get("__name__", "unknown_module")
-
+
         try:
             func = frame.f_globals.get(func_name)
             if func is None:
                 return f"{module_name}.{func_name}"
             if hasattr(func, "__qualname__"):
-                return f"{module_name}.{func.__qualname__}"
+                return f"{module_name}.{func.__qualname__}"
+            return f"{module_name}.{func_name}"
         except Exception:
             return f"{module_name}.{func_name}"
-
-    def __new__(cls, tracer: 'Tracer' = None):
+
+    def __new__(cls, tracer: "Tracer"):
         with cls._lock:
             if cls._instance is None:
                 cls._instance = super().__new__(cls)
             return cls._instance
-
+
     def _should_trace(self, frame):
         # Skip stack is maintained by the tracer as an optimization to skip earlier
         # frames in the call stack that we've already determined should be skipped
         skip_stack = self._skip_stack.get()
         if len(skip_stack) > 0:
             return False
-
+
         func_name = frame.f_code.co_name
         module_name = frame.f_globals.get("__name__", None)
-
         func = frame.f_globals.get(func_name)
-        if func and (hasattr(func, '_judgment_span_name') or hasattr(func, '_judgment_span_type')):
+        if func and (
+            hasattr(func, "_judgment_span_name") or hasattr(func, "_judgment_span_type")
+        ):
            return False
 
         if (
             not module_name
-            or func_name.startswith("<") # ex: <listcomp>
-            or func_name.startswith("__") and func_name != "__call__" # dunders
+            or func_name.startswith("<")  # ex: <listcomp>
+            or func_name.startswith("__")
+            and func_name != "__call__"  # dunders
             or not self._is_user_code(frame.f_code.co_filename)
         ):
             return False
-
+
         return True
-
+
     @functools.cache
     def _is_user_code(self, filename: str):
-        return bool(filename) and not filename.startswith("<") and not os.path.realpath(filename).startswith(_TRACE_FILEPATH_BLOCKLIST)
+        return (
+            bool(filename)
+            and not filename.startswith("<")
+            and not os.path.realpath(filename).startswith(_TRACE_FILEPATH_BLOCKLIST)
+        )
 
     def _cooperative_sys_trace(self, frame: types.FrameType, event: str, arg: Any):
         """Cooperative trace function for sys.settrace that chains with existing tracers."""
@@ -1334,18 +1359,20 @@ class _DeepTracer:
         except Exception:
             # If the original tracer fails, continue with our tracing
             pass
-
+
         # Then do our own tracing
         our_result = self._trace(frame, event, arg, self._cooperative_sys_trace)
-
+
         # Return our tracer to continue tracing, but respect the original's decision
         # If the original tracer returned None (stop tracing), we should respect that
         if original_result is None and self._original_sys_trace:
             return None
-
+
         return our_result or original_result
-
-    def _cooperative_threading_trace(self, frame: types.FrameType, event: str, arg: Any):
+
+    def _cooperative_threading_trace(
+        self, frame: types.FrameType, event: str, arg: Any
+    ):
         """Cooperative trace function for threading.settrace that chains with existing tracers."""
         # First, call the original threading trace function if it exists
         original_result = None
@@ -1355,44 +1382,48 @@ class _DeepTracer:
         except Exception:
             # If the original tracer fails, continue with our tracing
             pass
-
+
         # Then do our own tracing
         our_result = self._trace(frame, event, arg, self._cooperative_threading_trace)
-
+
         # Return our tracer to continue tracing, but respect the original's decision
         # If the original tracer returned None (stop tracing), we should respect that
         if original_result is None and self._original_threading_trace:
             return None
-
+
         return our_result or original_result
-
-    def _trace(self, frame: types.FrameType, event: str, arg: Any, continuation_func: Callable):
+
+    def _trace(
+        self, frame: types.FrameType, event: str, arg: Any, continuation_func: Callable
+    ):
         frame.f_trace_lines = False
         frame.f_trace_opcodes = False
 
         if not self._should_trace(frame):
             return
-
+
         if event not in ("call", "return", "exception"):
             return
-
+
         current_trace = self._tracer.get_current_trace()
         if not current_trace:
             return
-
+
         parent_span_id = self._tracer.get_current_span()
         if not parent_span_id:
             return
 
         qual_name = self._get_qual_name(frame)
         instance_name = None
-        if 'self' in frame.f_locals:
-            instance = frame.f_locals['self']
+        if "self" in frame.f_locals:
+            instance = frame.f_locals["self"]
             class_name = instance.__class__.__name__
-            class_identifiers = getattr(Tracer._instance, 'class_identifiers', {})
-            instance_name = get_instance_prefixed_name(instance, class_name, class_identifiers)
+            class_identifiers = getattr(self._tracer, "class_identifiers", {})
+            instance_name = get_instance_prefixed_name(
+                instance, class_name, class_identifiers
+            )
         skip_stack = self._skip_stack.get()
-
+
         if event == "call":
             # If we have entries in the skip stack and the current qual_name matches the top entry,
             # push it again to track nesting depth and skip
@@ -1402,9 +1433,9 @@ class _DeepTracer:
                 skip_stack.append(qual_name)
                 self._skip_stack.set(skip_stack)
                 return
-
+
             should_trace = self._should_trace(frame)
-
+
             if not should_trace:
                 if not skip_stack:
                     self._skip_stack.set([qual_name])
@@ -1416,35 +1447,37 @@ class _DeepTracer:
                 skip_stack.pop()
                 self._skip_stack.set(skip_stack)
             return
-
+
         if skip_stack:
             return
-
+
         span_stack = self._span_stack.get()
         if event == "call":
             if not self._should_trace(frame):
                 return
-
+
             span_id = str(uuid.uuid4())
-
+
             parent_depth = current_trace._span_depths.get(parent_span_id, 0)
             depth = parent_depth + 1
-
+
             current_trace._span_depths[span_id] = depth
-
+
             start_time = time.time()
-
-            span_stack.append({
-                "span_id": span_id,
-                "parent_span_id": parent_span_id,
-                "function": qual_name,
-                "start_time": start_time
-            })
+
+            span_stack.append(
+                {
+                    "span_id": span_id,
+                    "parent_span_id": parent_span_id,
+                    "function": qual_name,
+                    "start_time": start_time,
+                }
+            )
             self._span_stack.set(span_stack)
-
+
             token = self._tracer.set_current_span(span_id)
             frame.f_locals["_judgment_span_token"] = token
-
+
             span = TraceSpan(
                 span_id=span_id,
                 trace_id=current_trace.trace_id,
@@ -1454,57 +1487,55 @@ class _DeepTracer:
                 span_type="span",
                 parent_span_id=parent_span_id,
                 function=qual_name,
-                agent_name=instance_name
+                agent_name=instance_name,
             )
             current_trace.add_span(span)
-
+
             inputs = {}
             try:
                 args_info = inspect.getargvalues(frame)
                 for arg in args_info.args:
                     try:
                         inputs[arg] = args_info.locals.get(arg)
-                    except:
+                    except Exception:
                         inputs[arg] = "<<Unserializable>>"
                 current_trace.record_input(inputs)
             except Exception as e:
-                current_trace.record_input({
-                    "error": str(e)
-                })
-
+                current_trace.record_input({"error": str(e)})
+
         elif event == "return":
             if not span_stack:
                 return
-
+
             current_id = self._tracer.get_current_span()
-
+
             span_data = None
             for i, entry in enumerate(reversed(span_stack)):
                 if entry["span_id"] == current_id:
-                    span_data = span_stack.pop(-(i+1))
+                    span_data = span_stack.pop(-(i + 1))
                     self._span_stack.set(span_stack)
                     break
-
+
             if not span_data:
                 return
-
+
             start_time = span_data["start_time"]
             duration = time.time() - start_time
-
+
             current_trace.span_id_to_span[span_data["span_id"]].duration = duration
 
             if arg is not None:
-                # exception handling will take priority.
+                # exception handling will take priority.
                 current_trace.record_output(arg)
-
+
             if span_data["span_id"] in current_trace._span_depths:
                 del current_trace._span_depths[span_data["span_id"]]
-
+
             if span_stack:
                 self._tracer.set_current_span(span_stack[-1]["span_id"])
             else:
                 self._tracer.set_current_span(span_data["parent_span_id"])
-
+
             if "_judgment_span_token" in frame.f_locals:
                 self._tracer.reset_current_span(frame.f_locals["_judgment_span_token"])
 
@@ -1513,10 +1544,9 @@ class _DeepTracer:
             if issubclass(exc_type, (StopIteration, StopAsyncIteration, GeneratorExit)):
                 return
             _capture_exception_for_trace(current_trace, arg)
-
-
+
         return continuation_func
-
+
     def __enter__(self):
         with self._lock:
             self._refcount += 1
@@ -1524,14 +1554,14 @@ class _DeepTracer:
                 # Store the existing trace functions before setting ours
                 self._original_sys_trace = sys.gettrace()
                 self._original_threading_trace = threading.gettrace()
-
+
                 self._skip_stack.set([])
                 self._span_stack.set([])
-
+
                 sys.settrace(self._cooperative_sys_trace)
                 threading.settrace(self._cooperative_threading_trace)
         return self
-
+
     def __exit__(self, exc_type, exc_val, exc_tb):
         with self._lock:
             self._refcount -= 1
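`__enter__`/`__exit__` above are refcounted so that nested `with _DeepTracer(...)` blocks install the hooks once and restore the originals only when the outermost block exits. A simplified sketch of that pattern (locking and `threading.settrace` omitted):

```python
import sys

class RefCountedTracer:
    def __init__(self):
        self._refcount = 0
        self._original = None

    def __enter__(self):
        self._refcount += 1
        if self._refcount == 1:
            self._original = sys.gettrace()     # stash the previous hook
            sys.settrace(lambda f, e, a: None)  # placeholder trace function
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._refcount -= 1
        if self._refcount == 0:
            sys.settrace(self._original)        # restore, don't just clear
            self._original = None
```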
@@ -1539,7 +1569,7 @@ class _DeepTracer:
                 # Restore the original trace functions instead of setting to None
                 sys.settrace(self._original_sys_trace)
                 threading.settrace(self._original_threading_trace)
-
+
                 # Clean up the references
                 self._original_sys_trace = None
                 self._original_threading_trace = None
@@ -1547,7 +1577,7 @@ class _DeepTracer:
 
     # Below commented out function isn't used anymore?
 
-    # def log(self, message: str, level: str = "info"):
+    # def log(self, message: str, level: str = "info"):
     #     """ Log a message with the span context """
     #     current_trace = self._tracer.get_current_trace()
     #     if current_trace:
@@ -1555,25 +1585,29 @@ class _DeepTracer:
     #     else:
     #         print(f"[{level}] {message}")
     #     current_trace.record_output({"log": message})
-
-class Tracer:
 
+
+class Tracer:
     # Tracer.current_trace class variable is currently used in wrap()
     # TODO: Keep track of cross-context state for current trace and current span ID solely through class variables instead of instance variables?
     # Should be fine to do so as long as we keep Tracer as a singleton
     current_trace: Optional[TraceClient] = None
     # current_span_id: Optional[str] = None
 
-    trace_across_async_contexts: bool = False # BY default, we don't trace across async contexts
+    trace_across_async_contexts: bool = (
+        False  # BY default, we don't trace across async contexts
+    )
 
     def __init__(
-        self,
-        api_key: str = os.getenv("JUDGMENT_API_KEY"),
-        project_name: str = None,
+        self,
+        api_key: str | None = os.getenv("JUDGMENT_API_KEY"),
+        project_name: str | None = None,
         rules: Optional[List[Rule]] = None,  # Added rules parameter
-        organization_id: str = os.getenv("JUDGMENT_ORG_ID"),
-        enable_monitoring: bool = os.getenv("JUDGMENT_MONITORING", "true").lower() == "true",
-        enable_evaluations: bool = os.getenv("JUDGMENT_EVALUATIONS", "true").lower() == "true",
+        organization_id: str | None = os.getenv("JUDGMENT_ORG_ID"),
+        enable_monitoring: bool = os.getenv("JUDGMENT_MONITORING", "true").lower()
+        == "true",
+        enable_evaluations: bool = os.getenv("JUDGMENT_EVALUATIONS", "true").lower()
+        == "true",
         # S3 configuration
         use_s3: bool = False,
         s3_bucket_name: Optional[str] = None,
@@ -1581,26 +1615,32 @@ class Tracer:
         s3_aws_secret_access_key: Optional[str] = None,
         s3_region_name: Optional[str] = None,
         offline_mode: bool = False,
-        deep_tracing: bool = True,  # Deep tracing is enabled by default
-        trace_across_async_contexts: bool = False, # BY default, we don't trace across async contexts
+        deep_tracing: bool = False,  # Deep tracing is disabled by default
+        trace_across_async_contexts: bool = False,  # BY default, we don't trace across async contexts
         # Background span service configuration
         enable_background_spans: bool = True,  # Enable background span service by default
         span_batch_size: int = 50,  # Number of spans to batch before sending
         span_flush_interval: float = 1.0,  # Time in seconds between automatic flushes
-        span_num_workers: int = 10 # Number of worker threads for span processing
-    ):
+        span_num_workers: int = 10,  # Number of worker threads for span processing
+    ):
         if not api_key:
             raise ValueError("Tracer must be configured with a Judgment API key")
-
-        result, response = validate_api_key(api_key)
+
+        try:
+            result, response = validate_api_key(api_key)
+        except Exception as e:
+            print(f"Issue with verifying API key, disabling monitoring: {e}")
+            enable_monitoring = False
+            result = True
+
         if not result:
             raise JudgmentAPIError(f"Issue with passed in Judgment API key: {response}")
-
+
         if not organization_id:
             raise ValueError("Tracer must be configured with an Organization ID")
         if use_s3 and not s3_bucket_name:
             raise ValueError("S3 bucket name must be provided when use_s3 is True")
-
+
         self.api_key: str = api_key
         self.project_name: str = project_name or str(uuid.uuid4())
         self.organization_id: str = organization_id
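Taken together, the constructor changes mean a `Tracer` can now be built even when key validation fails at runtime (monitoring is disabled instead of raising). A hedged construction sketch reflecting this version's defaults; the credential values are placeholders:

```python
from judgeval.common.tracer import Tracer

tracer = Tracer(
    api_key="sk-placeholder",           # normally read from JUDGMENT_API_KEY
    organization_id="org_placeholder",  # normally read from JUDGMENT_ORG_ID
    project_name="my_project",
    deep_tracing=False,  # explicit here; False is the new default in 0.0.46
)
```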
@@ -1608,9 +1648,11 @@ class Tracer:
         self.traces: List[Trace] = []
         self.enable_monitoring: bool = enable_monitoring
         self.enable_evaluations: bool = enable_evaluations
-        self.class_identifiers: Dict[str, str] = {} # Dictionary to store class identifiers
-        self.span_id_to_previous_span_id: Dict[str, str] = {}
-        self.trace_id_to_previous_trace: Dict[str, TraceClient] = {}
+        self.class_identifiers: Dict[
+            str, str
+        ] = {}  # Dictionary to store class identifiers
+        self.span_id_to_previous_span_id: Dict[str, str | None] = {}
+        self.trace_id_to_previous_trace: Dict[str, TraceClient | None] = {}
         self.current_span_id: Optional[str] = None
         self.current_trace: Optional[TraceClient] = None
         self.trace_across_async_contexts: bool = trace_across_async_contexts
@@ -1620,15 +1662,21 @@ class Tracer:
         self.use_s3 = use_s3
         if use_s3:
             from judgeval.common.s3_storage import S3Storage
-            self.s3_storage = S3Storage(
-                bucket_name=s3_bucket_name,
-                aws_access_key_id=s3_aws_access_key_id,
-                aws_secret_access_key=s3_aws_secret_access_key,
-                region_name=s3_region_name
-            )
+
+            try:
+                self.s3_storage = S3Storage(
+                    bucket_name=s3_bucket_name,
+                    aws_access_key_id=s3_aws_access_key_id,
+                    aws_secret_access_key=s3_aws_secret_access_key,
+                    region_name=s3_region_name,
+                )
+            except Exception as e:
+                print(f"Issue with initializing S3 storage, disabling S3: {e}")
+                self.use_s3 = False
+
         self.offline_mode: bool = offline_mode
         self.deep_tracing: bool = deep_tracing  # NEW: Store deep tracing setting
-
+
         # Initialize background span service
         self.enable_background_spans: bool = enable_background_spans
         self.background_span_service: Optional[BackgroundSpanService] = None
@@ -1638,49 +1686,61 @@ class Tracer:
                 organization_id=organization_id,
                 batch_size=span_batch_size,
                 flush_interval=span_flush_interval,
-                num_workers=span_num_workers
+                num_workers=span_num_workers,
             )
 
-    def set_current_span(self, span_id: str):
-        self.span_id_to_previous_span_id[span_id] = getattr(self, 'current_span_id', None)
+    def set_current_span(self, span_id: str) -> Optional[contextvars.Token[str | None]]:
+        self.span_id_to_previous_span_id[span_id] = self.current_span_id
         self.current_span_id = span_id
         Tracer.current_span_id = span_id
         try:
             token = current_span_var.set(span_id)
-        except:
+        except Exception:
             token = None
         return token
-
+
     def get_current_span(self) -> Optional[str]:
         try:
             current_span_var_val = current_span_var.get()
-        except:
+        except Exception:
             current_span_var_val = None
-        return (self.current_span_id or current_span_var_val) if self.trace_across_async_contexts else current_span_var_val
-
-    def reset_current_span(self, token: Optional[str] = None, span_id: Optional[str] = None):
-        if not span_id:
-            span_id = self.current_span_id
+        return (
+            (self.current_span_id or current_span_var_val)
+            if self.trace_across_async_contexts
+            else current_span_var_val
+        )
+
+    def reset_current_span(
+        self,
+        token: Optional[contextvars.Token[str | None]] = None,
+        span_id: Optional[str] = None,
+    ):
         try:
-            current_span_var.reset(token)
-        except:
+            if token:
+                current_span_var.reset(token)
+        except Exception:
             pass
-        self.current_span_id = self.span_id_to_previous_span_id.get(span_id)
-        Tracer.current_span_id = self.current_span_id
-
-    def set_current_trace(self, trace: TraceClient):
+        if not span_id:
+            span_id = self.current_span_id
+        if span_id:
+            self.current_span_id = self.span_id_to_previous_span_id.get(span_id)
+            Tracer.current_span_id = self.current_span_id
+
+    def set_current_trace(
+        self, trace: TraceClient
+    ) -> Optional[contextvars.Token[TraceClient | None]]:
         """
         Set the current trace context in contextvars
         """
-        self.trace_id_to_previous_trace[trace.trace_id] = getattr(self, 'current_trace', None)
+        self.trace_id_to_previous_trace[trace.trace_id] = self.current_trace
         self.current_trace = trace
         Tracer.current_trace = trace
         try:
             token = current_trace_var.set(trace)
-        except:
+        except Exception:
             token = None
         return token
-
+
     def get_current_trace(self) -> Optional[TraceClient]:
         """
         Get the current trace context.
@@ -1691,72 +1751,69 @@ class Tracer:
         """
         try:
             current_trace_var_val = current_trace_var.get()
-        except:
+        except Exception:
             current_trace_var_val = None
-
-        # Use context variable or class variable based on trace_across_async_contexts setting
-        context_trace = (self.current_trace or current_trace_var_val) if self.trace_across_async_contexts else current_trace_var_val
-
-        # If we found a trace from context, return it
-        if context_trace:
-            return context_trace
-
-        # Fallback: Check the active client potentially set by a callback handler (e.g., LangGraph)
-        if hasattr(self, '_active_trace_client') and self._active_trace_client:
-            return self._active_trace_client
-
-        # If neither is available, return None
-        return None
-
-    def reset_current_trace(self, token: Optional[str] = None, trace_id: Optional[str] = None):
-        if not trace_id and self.current_trace:
-            trace_id = self.current_trace.trace_id
+        return (
+            (self.current_trace or current_trace_var_val)
+            if self.trace_across_async_contexts
+            else current_trace_var_val
+        )
+
+    def reset_current_trace(
+        self,
+        token: Optional[contextvars.Token[TraceClient | None]] = None,
+        trace_id: Optional[str] = None,
+    ):
         try:
-            current_trace_var.reset(token)
-        except:
+            if token:
+                current_trace_var.reset(token)
+        except Exception:
             pass
-        self.current_trace = self.trace_id_to_previous_trace.get(trace_id)
-        Tracer.current_trace = self.current_trace
+        if not trace_id and self.current_trace:
+            trace_id = self.current_trace.trace_id
+        if trace_id:
+            self.current_trace = self.trace_id_to_previous_trace.get(trace_id)
+            Tracer.current_trace = self.current_trace
 
     @contextmanager
     def trace(
-        self,
-        name: str,
-        project_name: str = None,
+        self,
+        name: str,
+        project_name: str | None = None,
         overwrite: bool = False,
-        rules: Optional[List[Rule]] = None # Added rules parameter
+        rules: Optional[List[Rule]] = None,  # Added rules parameter
     ) -> Generator[TraceClient, None, None]:
         """Start a new trace context using a context manager"""
         trace_id = str(uuid.uuid4())
         project = project_name if project_name is not None else self.project_name
-
+
         # Get parent trace info from context
         parent_trace = self.get_current_trace()
         parent_trace_id = None
         parent_name = None
-
+
         if parent_trace:
             parent_trace_id = parent_trace.trace_id
             parent_name = parent_trace.name
 
         trace = TraceClient(
-            self,
-            trace_id,
-            name,
-            project_name=project,
+            self,
+            trace_id,
+            name,
+            project_name=project,
             overwrite=overwrite,
             rules=self.rules,  # Pass combined rules to the trace client
             enable_monitoring=self.enable_monitoring,
             enable_evaluations=self.enable_evaluations,
             parent_trace_id=parent_trace_id,
-            parent_name=parent_name
+            parent_name=parent_name,
         )
-
+
         # Set the current trace in context variables
         token = self.set_current_trace(trace)
-
+
         # Automatically create top-level span
-        with trace.span(name or "unnamed_trace") as span:
+        with trace.span(name or "unnamed_trace"):
             try:
                 # Save the trace to the database to handle Evaluations' trace_id referential integrity
                 yield trace
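Usage sketch for the context-manager API above, assuming `tracer` is a configured `Tracer`; per the parent lookup in the diff, nesting a second `trace()` records the outer one as its parent:

```python
with tracer.trace("checkout_flow", project_name="my_project") as trace:
    # runs inside the automatically created top-level span
    with tracer.trace("payment_step") as child:
        # child carries checkout_flow's trace_id/name as parent metadata
        pass
```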
@@ -1764,101 +1821,110 @@ class Tracer:
1764
1821
  # Reset the context variable
1765
1822
  self.reset_current_trace(token)
1766
1823
 
1767
-
1768
1824
  def log(self, msg: str, label: str = "log", score: int = 1):
1769
1825
  """Log a message with the current span context"""
1770
1826
  current_span_id = self.get_current_span()
1771
1827
  current_trace = self.get_current_trace()
1772
- if current_span_id:
1828
+ if current_span_id and current_trace:
1773
1829
  annotation = TraceAnnotation(
1774
- span_id=current_span_id,
1775
- text=msg,
1776
- label=label,
1777
- score=score
1830
+ span_id=current_span_id, text=msg, label=label, score=score
1778
1831
  )
1779
-
1780
1832
  current_trace.add_annotation(annotation)
1781
1833
 
1782
1834
  rprint(f"[bold]{label}:[/bold] {msg}")
1783
-
1784
- def identify(self, identifier: str, track_state: bool = False, track_attributes: Optional[List[str]] = None, field_mappings: Optional[Dict[str, str]] = None):
1835
+
1836
+ def identify(
1837
+ self,
1838
+ identifier: str,
1839
+ track_state: bool = False,
1840
+ track_attributes: Optional[List[str]] = None,
1841
+ field_mappings: Optional[Dict[str, str]] = None,
1842
+ ):
1785
1843
  """
1786
1844
  Class decorator that associates a class with a custom identifier and enables state tracking.
1787
-
1845
+
1788
1846
  This decorator creates a mapping between the class name and the provided
1789
1847
  identifier, which can be useful for tagging, grouping, or referencing
1790
1848
  classes in a standardized way. It also enables automatic state capture
1791
1849
  for instances of the decorated class when used with tracing.
1792
-
1850
+
1793
1851
  Args:
1794
1852
  identifier: The identifier to associate with the decorated class.
1795
1853
  This will be used as the instance name in traces.
1796
- track_state: Whether to automatically capture the state (attributes)
1854
+ track_state: Whether to automatically capture the state (attributes)
1797
1855
  of instances before and after function execution. Defaults to False.
1798
1856
  track_attributes: Optional list of specific attribute names to track.
1799
- If None, all non-private attributes (not starting with '_')
1857
+ If None, all non-private attributes (not starting with '_')
1800
1858
  will be tracked when track_state=True.
1801
- field_mappings: Optional dictionary mapping internal attribute names to
1859
+ field_mappings: Optional dictionary mapping internal attribute names to
1802
1860
  display names in the captured state. For example:
1803
- {"system_prompt": "instructions"} will capture the
1861
+ {"system_prompt": "instructions"} will capture the
1804
1862
  'instructions' attribute as 'system_prompt' in the state.
1805
-
1863
+
1806
1864
  Example:
1807
1865
  @tracer.identify(identifier="user_model", track_state=True, track_attributes=["name", "age"], field_mappings={"system_prompt": "instructions"})
1808
1866
  class User:
1809
1867
  # Class implementation
1810
1868
  """
1869
+
1811
1870
  def decorator(cls):
1812
1871
  class_name = cls.__name__
1813
1872
  self.class_identifiers[class_name] = {
1814
1873
  "identifier": identifier,
1815
1874
  "track_state": track_state,
1816
1875
  "track_attributes": track_attributes,
1817
- "field_mappings": field_mappings or {}
1876
+ "field_mappings": field_mappings or {},
1818
1877
  }
1819
1878
  return cls
1820
-
1879
+
1821
1880
  return decorator
1822
-
1823
- def _capture_instance_state(self, instance: Any, class_config: Dict[str, Any]) -> Dict[str, Any]:
1881
+
1882
+ def _capture_instance_state(
1883
+ self, instance: Any, class_config: Dict[str, Any]
1884
+ ) -> Dict[str, Any]:
1824
1885
  """
1825
1886
  Capture the state of an instance based on class configuration.
1826
1887
  Args:
1827
1888
  instance: The instance to capture the state of.
1828
- class_config: Configuration dictionary for state capture,
1889
+ class_config: Configuration dictionary for state capture,
1829
1890
  expected to contain 'track_attributes' and 'field_mappings'.
1830
1891
  """
1831
- track_attributes = class_config.get('track_attributes')
1832
- field_mappings = class_config.get('field_mappings')
1833
-
1892
+ track_attributes = class_config.get("track_attributes")
1893
+ field_mappings = class_config.get("field_mappings")
1894
+
1834
1895
  if track_attributes:
1835
-
1836
1896
  state = {attr: getattr(instance, attr, None) for attr in track_attributes}
1837
1897
  else:
1838
-
1839
- state = {k: v for k, v in instance.__dict__.items() if not k.startswith('_')}
1898
+ state = {
1899
+ k: v for k, v in instance.__dict__.items() if not k.startswith("_")
1900
+ }
1840
1901
 
1841
1902
  if field_mappings:
1842
- state['field_mappings'] = field_mappings
1903
+ state["field_mappings"] = field_mappings
1843
1904
 
1844
1905
  return state
1845
-
1846
-
1906
+
1847
1907
  def _get_instance_state_if_tracked(self, args):
1848
1908
  """
1849
1909
  Extract instance state if the instance should be tracked.
1850
-
1910
+
1851
1911
  Returns the captured state dict if tracking is enabled, None otherwise.
1852
1912
  """
1853
- if args and hasattr(args[0], '__class__'):
1913
+ if args and hasattr(args[0], "__class__"):
1854
1914
  instance = args[0]
1855
1915
  class_name = instance.__class__.__name__
1856
- if (class_name in self.class_identifiers and
1857
- isinstance(self.class_identifiers[class_name], dict) and
1858
- self.class_identifiers[class_name].get('track_state', False)):
1859
- return self._capture_instance_state(instance, self.class_identifiers[class_name])
1860
-
1861
- def _conditionally_capture_and_record_state(self, trace_client_instance: TraceClient, args: tuple, is_before: bool):
1916
+ if (
1917
+ class_name in self.class_identifiers
1918
+ and isinstance(self.class_identifiers[class_name], dict)
1919
+ and self.class_identifiers[class_name].get("track_state", False)
1920
+ ):
1921
+ return self._capture_instance_state(
1922
+ instance, self.class_identifiers[class_name]
1923
+ )
1924
+
1925
+ def _conditionally_capture_and_record_state(
1926
+ self, trace_client_instance: TraceClient, args: tuple, is_before: bool
1927
+ ):
1862
1928
  """Captures instance state if tracked and records it via the trace_client."""
1863
1929
  state = self._get_instance_state_if_tracked(args)
1864
1930
  if state:
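What `_capture_instance_state` keeps is easy to reproduce by hand; with no `track_attributes` configured it is simply the non-underscore instance attributes (the class and attribute names below are illustrative only):

```python
class Agent:
    def __init__(self):
        self.name = "planner"
        self._secret = "hidden"  # underscore attributes are skipped by default

agent = Agent()
state = {k: v for k, v in agent.__dict__.items() if not k.startswith("_")}
assert state == {"name": "planner"}
```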
@@ -1866,11 +1932,20 @@ class Tracer:
             trace_client_instance.record_state_before(state)
         else:
             trace_client_instance.record_state_after(state)
-
-    def observe(self, func=None, *, name=None, span_type: SpanType = "span", project_name: str = None, overwrite: bool = False, deep_tracing: bool = None):
+
+    def observe(
+        self,
+        func=None,
+        *,
+        name=None,
+        span_type: SpanType = "span",
+        project_name: str | None = None,
+        overwrite: bool = False,
+        deep_tracing: bool | None = None,
+    ):
         """
         Decorator to trace function execution with detailed entry/exit information.
-
+
         Args:
             func: The function to decorate
             name: Optional custom name for the span (defaults to function name)
@@ -1881,56 +1956,71 @@ class Tracer:
                 If None, uses the tracer's default setting.
         """
         # If monitoring is disabled, return the function as is
-        if not self.enable_monitoring:
-            return func if func else lambda f: f
-
-        if func is None:
-            return lambda f: self.observe(f, name=name, span_type=span_type, project_name=project_name,
-                                          overwrite=overwrite, deep_tracing=deep_tracing)
-
-        # Use provided name or fall back to function name
-        original_span_name = name or func.__name__
-
-        # Store custom attributes on the function object
-        func._judgment_span_name = original_span_name
-        func._judgment_span_type = span_type
-
-        # Use the provided deep_tracing value or fall back to the tracer's default
-        use_deep_tracing = deep_tracing if deep_tracing is not None else self.deep_tracing
-
+        try:
+            if not self.enable_monitoring:
+                return func if func else lambda f: f
+
+            if func is None:
+                return lambda f: self.observe(
+                    f,
+                    name=name,
+                    span_type=span_type,
+                    project_name=project_name,
+                    overwrite=overwrite,
+                    deep_tracing=deep_tracing,
+                )
+
+            # Use provided name or fall back to function name
+            original_span_name = name or func.__name__
+
+            # Store custom attributes on the function object
+            func._judgment_span_name = original_span_name
+            func._judgment_span_type = span_type
+
+            # Use the provided deep_tracing value or fall back to the tracer's default
+            use_deep_tracing = (
+                deep_tracing if deep_tracing is not None else self.deep_tracing
+            )
+        except Exception:
+            return func
+
         if asyncio.iscoroutinefunction(func):
+
             @functools.wraps(func)
             async def async_wrapper(*args, **kwargs):
                 nonlocal original_span_name
                 class_name = None
-                instance_name = None
                 span_name = original_span_name
                 agent_name = None
 
-                if args and hasattr(args[0], '__class__'):
+                if args and hasattr(args[0], "__class__"):
                     class_name = args[0].__class__.__name__
-                    agent_name = get_instance_prefixed_name(args[0], class_name, self.class_identifiers)
+                    agent_name = get_instance_prefixed_name(
+                        args[0], class_name, self.class_identifiers
+                    )
 
                 # Get current trace from context
                 current_trace = self.get_current_trace()
-
+
                 # If there's no current trace, create a root trace
                 if not current_trace:
                     trace_id = str(uuid.uuid4())
-                    project = project_name if project_name is not None else self.project_name
-
+                    project = (
+                        project_name if project_name is not None else self.project_name
+                    )
+
                     # Create a new trace client to serve as the root
                     current_trace = TraceClient(
                         self,
                         trace_id,
-                        span_name, # MODIFIED: Use span_name directly
+                        span_name,  # MODIFIED: Use span_name directly
                         project_name=project,
                         overwrite=overwrite,
                         rules=self.rules,
                         enable_monitoring=self.enable_monitoring,
-                        enable_evaluations=self.enable_evaluations
+                        enable_evaluations=self.enable_evaluations,
                     )
-
+
                     # Save empty trace and set trace context
                     # current_trace.save(empty_save=True, overwrite=overwrite)
                     trace_token = self.set_current_trace(current_trace)
@@ -1938,7 +2028,9 @@ class Tracer:
                     try:
                         # Use span for the function execution within the root trace
                         # This sets the current_span_var
-                        with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
+                        with current_trace.span(
+                            span_name, span_type=span_type
+                        ) as span:  # MODIFIED: Use span_name directly
                             # Record inputs
                             inputs = combine_args_kwargs(func, args, kwargs)
                             span.record_input(inputs)
@@ -1946,50 +2038,66 @@ class Tracer:
                                 span.record_agent_name(agent_name)
 
                             # Capture state before execution
-                            self._conditionally_capture_and_record_state(span, args, is_before=True)
-
-                            if use_deep_tracing:
-                                with _DeepTracer(self):
-                                    result = await func(*args, **kwargs)
-                            else:
-                                try:
+                            self._conditionally_capture_and_record_state(
+                                span, args, is_before=True
+                            )
+
+                            try:
+                                if use_deep_tracing:
+                                    with _DeepTracer(self):
+                                        result = await func(*args, **kwargs)
+                                else:
                                     result = await func(*args, **kwargs)
-                                except Exception as e:
-                                    _capture_exception_for_trace(current_trace, sys.exc_info())
-                                    raise e
-
-                            # Capture state after execution
-                            self._conditionally_capture_and_record_state(span, args, is_before=False)
-
+                            except Exception as e:
+                                _capture_exception_for_trace(
+                                    current_trace, sys.exc_info()
+                                )
+                                raise e
+
+                            # Capture state after execution
+                            self._conditionally_capture_and_record_state(
+                                span, args, is_before=False
+                            )
+
                             # Record output
                             span.record_output(result)
                             return result
                     finally:
                         # Flush background spans before saving the trace
-
-                        complete_trace_data = {
-                            "trace_id": current_trace.trace_id,
-                            "name": current_trace.name,
-                            "created_at": datetime.utcfromtimestamp(current_trace.start_time).isoformat(),
-                            "duration": current_trace.get_duration(),
-                            "trace_spans": [span.model_dump() for span in current_trace.trace_spans],
-                            "overwrite": overwrite,
-                            "offline_mode": self.offline_mode,
-                            "parent_trace_id": current_trace.parent_trace_id,
-                            "parent_name": current_trace.parent_name
-                        }
-                        # Save the completed trace
-                        trace_id, server_response = current_trace.save_with_rate_limiting(overwrite=overwrite, final_save=True)
-
-                        # Store the complete trace data instead of just server response
-
-                        self.traces.append(complete_trace_data)
-
-                        # if self.background_span_service:
-                        #     self.background_span_service.flush()
-
-                        # Reset trace context (span context resets automatically)
-                        self.reset_current_trace(trace_token)
+                        try:
+                            complete_trace_data = {
+                                "trace_id": current_trace.trace_id,
+                                "name": current_trace.name,
+                                "created_at": datetime.utcfromtimestamp(
+                                    current_trace.start_time
+                                ).isoformat(),
+                                "duration": current_trace.get_duration(),
+                                "trace_spans": [
+                                    span.model_dump()
+                                    for span in current_trace.trace_spans
+                                ],
+                                "overwrite": overwrite,
+                                "offline_mode": self.offline_mode,
+                                "parent_trace_id": current_trace.parent_trace_id,
+                                "parent_name": current_trace.parent_name,
+                            }
+                            # Save the completed trace
+                            trace_id, server_response = current_trace.save(
+                                overwrite=overwrite, final_save=True
+                            )
+
+                            # Store the complete trace data instead of just server response
+
+                            self.traces.append(complete_trace_data)
+
+                            # if self.background_span_service:
+                            #     self.background_span_service.flush()
+
+                            # Reset trace context (span context resets automatically)
+                            self.reset_current_trace(trace_token)
+                        except Exception as e:
+                            warnings.warn(f"Issue with async_wrapper: {e}")
+                            return
                 else:
                     with current_trace.span(span_name, span_type=span_type) as span:
                         inputs = combine_args_kwargs(func, args, kwargs)
@@ -1998,24 +2106,28 @@ class Tracer:
                         span.record_agent_name(agent_name)
 
                     # Capture state before execution
-                    self._conditionally_capture_and_record_state(span, args, is_before=True)
-
-                    if use_deep_tracing:
-                        with _DeepTracer(self):
-                            result = await func(*args, **kwargs)
-                    else:
-                        try:
+                    self._conditionally_capture_and_record_state(
+                        span, args, is_before=True
+                    )
+
+                    try:
+                        if use_deep_tracing:
+                            with _DeepTracer(self):
+                                result = await func(*args, **kwargs)
+                        else:
                             result = await func(*args, **kwargs)
-                        except Exception as e:
-                            _capture_exception_for_trace(current_trace, sys.exc_info())
-                            raise e
-
-                    # Capture state after execution
-                    self._conditionally_capture_and_record_state(span, args, is_before=False)
-
+                    except Exception as e:
+                        _capture_exception_for_trace(current_trace, sys.exc_info())
+                        raise e
+
+                    # Capture state after execution
+                    self._conditionally_capture_and_record_state(
+                        span, args, is_before=False
+                    )
+
                     span.record_output(result)
                     return result
-
+
             return async_wrapper
         else:
             # Non-async function implementation with deep tracing
@@ -2023,118 +2135,146 @@ class Tracer:
             def wrapper(*args, **kwargs):
                 nonlocal original_span_name
                 class_name = None
-                instance_name = None
                 span_name = original_span_name
                 agent_name = None
-                if args and hasattr(args[0], '__class__'):
+                if args and hasattr(args[0], "__class__"):
                     class_name = args[0].__class__.__name__
-                    agent_name = get_instance_prefixed_name(args[0], class_name, self.class_identifiers)
+                    agent_name = get_instance_prefixed_name(
+                        args[0], class_name, self.class_identifiers
+                    )
                 # Get current trace from context
                 current_trace = self.get_current_trace()
 
                 # If there's no current trace, create a root trace
                 if not current_trace:
                     trace_id = str(uuid.uuid4())
-                    project = project_name if project_name is not None else self.project_name
-
+                    project = (
+                        project_name if project_name is not None else self.project_name
+                    )
+
                     # Create a new trace client to serve as the root
                     current_trace = TraceClient(
                         self,
                         trace_id,
-                        span_name, # MODIFIED: Use span_name directly
+                        span_name,  # MODIFIED: Use span_name directly
                        project_name=project,
                        overwrite=overwrite,
                        rules=self.rules,
                        enable_monitoring=self.enable_monitoring,
-                        enable_evaluations=self.enable_evaluations
+                        enable_evaluations=self.enable_evaluations,
                    )
-
+
                    # Save empty trace and set trace context
                    # current_trace.save(empty_save=True, overwrite=overwrite)
                    trace_token = self.set_current_trace(current_trace)
-
+
                    try:
                        # Use span for the function execution within the root trace
                        # This sets the current_span_var
-                        with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
+                        with current_trace.span(
+                            span_name, span_type=span_type
+                        ) as span:  # MODIFIED: Use span_name directly
                            # Record inputs
                            inputs = combine_args_kwargs(func, args, kwargs)
                            span.record_input(inputs)
                            if agent_name:
                                span.record_agent_name(agent_name)
                            # Capture state before execution
-                            self._conditionally_capture_and_record_state(span, args, is_before=True)
-
-                            if use_deep_tracing:
-                                with _DeepTracer(self):
-                                    result = func(*args, **kwargs)
-                            else:
-                                try:
+                            self._conditionally_capture_and_record_state(
+                                span, args, is_before=True
+                            )
+
+                            try:
+                                if use_deep_tracing:
+                                    with _DeepTracer(self):
+                                        result = func(*args, **kwargs)
+                                else:
                                    result = func(*args, **kwargs)
-                                except Exception as e:
-                                    _capture_exception_for_trace(current_trace, sys.exc_info())
-                                    raise e
-
+                            except Exception as e:
+                                _capture_exception_for_trace(
+                                    current_trace, sys.exc_info()
+                                )
+                                raise e
+
                            # Capture state after execution
-                            self._conditionally_capture_and_record_state(span, args, is_before=False)
+                            self._conditionally_capture_and_record_state(
+                                span, args, is_before=False
+                            )
 
-
                            # Record output
                            span.record_output(result)
                            return result
                    finally:
                        # Flush background spans before saving the trace
-
-
-                        # Save the completed trace
-                        trace_id, server_response = current_trace.save_with_rate_limiting(overwrite=overwrite, final_save=True)
-
-                        # Store the complete trace data instead of just server response
-                        complete_trace_data = {
-                            "trace_id": current_trace.trace_id,
-                            "name": current_trace.name,
-                            "created_at": datetime.utcfromtimestamp(current_trace.start_time).isoformat(),
-                            "duration": current_trace.get_duration(),
-                            "trace_spans": [span.model_dump() for span in current_trace.trace_spans],
-                            "overwrite": overwrite,
-                            "offline_mode": self.offline_mode,
-                            "parent_trace_id": current_trace.parent_trace_id,
-                            "parent_name": current_trace.parent_name
-                        }
-                        self.traces.append(complete_trace_data)
-                        # Reset trace context (span context resets automatically)
-                        self.reset_current_trace(trace_token)
+                        try:
+                            # Save the completed trace
+                            trace_id, server_response = current_trace.save(
+                                overwrite=overwrite, final_save=True
+                            )
+
+                            # Store the complete trace data instead of just server response
+                            complete_trace_data = {
+                                "trace_id": current_trace.trace_id,
+                                "name": current_trace.name,
+                                "created_at": datetime.utcfromtimestamp(
+                                    current_trace.start_time
+                                ).isoformat(),
+                                "duration": current_trace.get_duration(),
+                                "trace_spans": [
+                                    span.model_dump()
+                                    for span in current_trace.trace_spans
+                                ],
+                                "overwrite": overwrite,
+                                "offline_mode": self.offline_mode,
+                                "parent_trace_id": current_trace.parent_trace_id,
+                                "parent_name": current_trace.parent_name,
+                            }
+                            self.traces.append(complete_trace_data)
+                            # Reset trace context (span context resets automatically)
+                            self.reset_current_trace(trace_token)
+                        except Exception as e:
+                            warnings.warn(f"Issue with save: {e}")
+                            return
                else:
                    with current_trace.span(span_name, span_type=span_type) as span:
-
                        inputs = combine_args_kwargs(func, args, kwargs)
                        span.record_input(inputs)
                        if agent_name:
                            span.record_agent_name(agent_name)
 
                        # Capture state before execution
-                        self._conditionally_capture_and_record_state(span, args, is_before=True)
+                        self._conditionally_capture_and_record_state(
+                            span, args, is_before=True
+                        )
 
-                        if use_deep_tracing:
-                            with _DeepTracer(self):
-                                result = func(*args, **kwargs)
-                        else:
-                            try:
+                        try:
+                            if use_deep_tracing:
+                                with _DeepTracer(self):
+                                    result = func(*args, **kwargs)
+                            else:
                                result = func(*args, **kwargs)
-                            except Exception as e:
-                                _capture_exception_for_trace(current_trace, sys.exc_info())
-                                raise e
-
+                        except Exception as e:
+                            _capture_exception_for_trace(current_trace, sys.exc_info())
+                            raise e
+
                        # Capture state after execution
-                        self._conditionally_capture_and_record_state(span, args, is_before=False)
-
+                        self._conditionally_capture_and_record_state(
+                            span, args, is_before=False
+                        )
+
                        span.record_output(result)
                        return result
-
+
            return wrapper
-
-    def observe_tools(self, cls=None, *, exclude_methods: Optional[List[str]] = None,
-                      include_private: bool = False, warn_on_double_decoration: bool = True):
+
+    def observe_tools(
+        self,
+        cls=None,
+        *,
+        exclude_methods: Optional[List[str]] = None,
+        include_private: bool = False,
+        warn_on_double_decoration: bool = True,
+    ):
        """
        Automatically adds @observe(span_type="tool") to all methods in a class.
 
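Hedged usage sketch for the two decorators above, assuming `tracer` is a configured `Tracer` instance:

```python
@tracer.observe(span_type="tool")
def search(query: str) -> str:
    return f"results for {query}"

# observe_tools sweeps a class and applies @observe(span_type="tool")
# to each public method that isn't already decorated:
@tracer.observe_tools()
class ToolBox:
    def lookup(self, key: str) -> str:
        return key.upper()
```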
@@ -2146,28 +2286,32 @@ class Tracer:
         """
 
         if exclude_methods is None:
-            exclude_methods = ['__init__', '__new__', '__del__', '__str__', '__repr__']
-
+            exclude_methods = ["__init__", "__new__", "__del__", "__str__", "__repr__"]
+
         def decorate_class(cls):
             if not self.enable_monitoring:
                 return cls
-
+
             decorated = []
             skipped = []
-
+
             for name in dir(cls):
                 method = getattr(cls, name)
-
-                if (not callable(method) or
-                    name in exclude_methods or
-                    (name.startswith('_') and not include_private) or
-                    not hasattr(cls, name)):
+
+                if (
+                    not callable(method)
+                    or name in exclude_methods
+                    or (name.startswith("_") and not include_private)
+                    or not hasattr(cls, name)
+                ):
                     continue
-
-                if hasattr(method, '_judgment_span_name'):
+
+                if hasattr(method, "_judgment_span_name"):
                     skipped.append(name)
                     if warn_on_double_decoration:
-                        print(f"Warning: {cls.__name__}.{name} already decorated, skipping")
+                        print(
+                            f"Warning: {cls.__name__}.{name} already decorated, skipping"
+                        )
                     continue
 
                 try:
@@ -2177,28 +2321,76 @@ class Tracer:
                 except Exception as e:
                     if warn_on_double_decoration:
                         print(f"Warning: Failed to decorate {cls.__name__}.{name}: {e}")
-
+
             return cls
-
+
         return decorate_class if cls is None else decorate_class(cls)
 
     def async_evaluate(self, *args, **kwargs):
-        if not self.enable_evaluations:
-            return
+        try:
+            if not self.enable_monitoring or not self.enable_evaluations:
+                return
 
-        # --- Get trace_id passed explicitly (if any) ---
-        passed_trace_id = kwargs.pop('trace_id', None) # Get and remove trace_id from kwargs
+            # --- Get trace_id passed explicitly (if any) ---
+            passed_trace_id = kwargs.pop(
+                "trace_id", None
+            )  # Get and remove trace_id from kwargs
 
+            current_trace = self.get_current_trace()
+
+            if current_trace:
+                # Pass the explicitly provided trace_id if it exists, otherwise let async_evaluate handle it
+                # (Note: TraceClient.async_evaluate doesn't currently use an explicit trace_id, but this is for future proofing/consistency)
+                if passed_trace_id:
+                    kwargs["trace_id"] = (
+                        passed_trace_id  # Re-add if needed by TraceClient.async_evaluate
+                    )
+                current_trace.async_evaluate(*args, **kwargs)
+            else:
+                warnings.warn(
+                    "No trace found (context var or fallback), skipping evaluation"
+                )  # Modified warning
+        except Exception as e:
+            warnings.warn(f"Issue with async_evaluate: {e}")
+
+    def update_metadata(self, metadata: dict):
+        """
+        Update metadata for the current trace.
+
+        Args:
+            metadata: Metadata as a dictionary
+        """
         current_trace = self.get_current_trace()
+        if current_trace:
+            current_trace.update_metadata(metadata)
+        else:
+            warnings.warn("No current trace found, cannot set metadata")
+
+    def set_customer_id(self, customer_id: str):
+        """
+        Set the customer ID for the current trace.
 
+        Args:
+            customer_id: The customer ID to set
+        """
+        current_trace = self.get_current_trace()
         if current_trace:
-            # Pass the explicitly provided trace_id if it exists, otherwise let async_evaluate handle it
-            # (Note: TraceClient.async_evaluate doesn't currently use an explicit trace_id, but this is for future proofing/consistency)
-            if passed_trace_id:
-                kwargs['trace_id'] = passed_trace_id # Re-add if needed by TraceClient.async_evaluate
-            current_trace.async_evaluate(*args, **kwargs)
+            current_trace.set_customer_id(customer_id)
         else:
-            warnings.warn("No trace found (context var or fallback), skipping evaluation") # Modified warning
+            warnings.warn("No current trace found, cannot set customer ID")
+
+    def set_tags(self, tags: List[Union[str, set, tuple]]):
+        """
+        Set the tags for the current trace.
+
+        Args:
+            tags: List of tags to set
+        """
+        current_trace = self.get_current_trace()
+        if current_trace:
+            current_trace.set_tags(tags)
+        else:
+            warnings.warn("No current trace found, cannot set tags")
 
     def get_background_span_service(self) -> Optional[BackgroundSpanService]:
         """Get the background span service instance."""
@@ -2215,31 +2407,43 @@ class Tracer:
2215
2407
  self.background_span_service.shutdown()
2216
2408
  self.background_span_service = None
2217
2409
 
2218
- def wrap(client: Any, trace_across_async_contexts: bool = Tracer.trace_across_async_contexts) -> Any:
2410
+
2411
+ def _get_current_trace(
2412
+ trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
2413
+ ):
2414
+ if trace_across_async_contexts:
2415
+ return Tracer.current_trace
2416
+ else:
2417
+ return current_trace_var.get()
2418
+
2419
+
2420
+ def wrap(
2421
+ client: Any, trace_across_async_contexts: bool = Tracer.trace_across_async_contexts
2422
+ ) -> Any:
2219
2423
  """
2220
2424
  Wraps an API client to add tracing capabilities.
2221
2425
  Supports OpenAI, Together, Anthropic, and Google GenAI clients.
2222
2426
  Patches both '.create' and Anthropic's '.stream' methods using a wrapper class.
2223
2427
  """
2224
- span_name, original_create, original_responses_create, original_stream, original_beta_parse = _get_client_config(client)
2428
+ (
2429
+ span_name,
2430
+ original_create,
2431
+ original_responses_create,
2432
+ original_stream,
2433
+ original_beta_parse,
2434
+ ) = _get_client_config(client)
2225
2435
 
2226
- def _get_current_trace():
2227
- if trace_across_async_contexts:
2228
- return Tracer.current_trace
2229
- else:
2230
- return current_trace_var.get()
2231
-
2232
2436
  def _record_input_and_check_streaming(span, kwargs, is_responses=False):
2233
2437
  """Record input and check for streaming"""
2234
2438
  is_streaming = kwargs.get("stream", False)
2235
2439
 
2236
- # Record input based on whether this is a responses endpoint
2440
+ # Record input based on whether this is a responses endpoint
2237
2441
  if is_responses:
2238
2442
  span.record_input(kwargs)
2239
2443
  else:
2240
2444
  input_data = _format_input_data(client, **kwargs)
2241
2445
  span.record_input(input_data)
2242
-
2446
+
2243
2447
  # Warn about token counting limitations with streaming
2244
2448
  if isinstance(client, (AsyncOpenAI, OpenAI)) and is_streaming:
2245
2449
  if not kwargs.get("stream_options", {}).get("include_usage"):
@@ -2247,88 +2451,101 @@ def wrap(client: Any, trace_across_async_contexts: bool = Tracer.trace_across_as
2247
2451
  "OpenAI streaming calls don't include token counts by default. "
2248
2452
  "To enable token counting with streams, set stream_options={'include_usage': True} "
2249
2453
  "in your API call arguments.",
2250
- UserWarning
2454
+ UserWarning,
2251
2455
  )
2252
-
2456
+
2253
2457
  return is_streaming
2254
-
2458
+
2255
2459
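
The warning above is worth showing concretely: OpenAI streams omit usage unless `stream_options={'include_usage': True}` is passed, so the tracer cannot record token counts otherwise. A hedged sketch of a call that keeps token counting working (model name is illustrative):

    from openai import OpenAI
    from judgeval.common.tracer import wrap

    client = wrap(OpenAI())
    stream = client.chat.completions.create(
        model="gpt-4o-mini",  # illustrative model
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
        stream_options={"include_usage": True},  # final chunk then carries usage
    )
    for chunk in stream:
        pass  # chunks pass through the tracing wrapper unchanged
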
     def _format_and_record_output(span, response, is_streaming, is_async, is_responses):
         """Format and record the output in the span"""
         if is_streaming:
             output_entry = span.record_output("<pending stream>")
             wrapper_func = _async_stream_wrapper if is_async else _sync_stream_wrapper
-            return wrapper_func(response, client, output_entry)
+            return wrapper_func(
+                response, client, output_entry, trace_across_async_contexts
+            )
         else:
-            format_func = _format_response_output_data if is_responses else _format_output_data
+            format_func = (
+                _format_response_output_data if is_responses else _format_output_data
+            )
             output, usage = format_func(client, response)
             span.record_output(output)
             span.record_usage(usage)
-
+
         # Queue the completed LLM span now that it has all data (input, output, usage)
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
         if current_trace and current_trace.background_span_service:
             # Get the current span from the trace client
             current_span_id = current_trace.get_current_span()
             if current_span_id and current_span_id in current_trace.span_id_to_span:
                 completed_span = current_trace.span_id_to_span[current_span_id]
-                current_trace.background_span_service.queue_span(completed_span, span_state="completed")
-
+                current_trace.background_span_service.queue_span(
+                    completed_span, span_state="completed"
+                )
+
         return response
-
+
     # --- Traced Async Functions ---
     async def traced_create_async(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
         if not current_trace:
             return await original_create(*args, **kwargs)
-
+
         with current_trace.span(span_name, span_type="llm") as span:
             is_streaming = _record_input_and_check_streaming(span, kwargs)
-
+
             try:
                 response_or_iterator = await original_create(*args, **kwargs)
-                return _format_and_record_output(span, response_or_iterator, is_streaming, True, False)
+                return _format_and_record_output(
+                    span, response_or_iterator, is_streaming, True, False
+                )
             except Exception as e:
                 _capture_exception_for_trace(span, sys.exc_info())
                 raise e
-
+
     async def traced_beta_parse_async(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
         if not current_trace:
             return await original_beta_parse(*args, **kwargs)
-
+
         with current_trace.span(span_name, span_type="llm") as span:
             is_streaming = _record_input_and_check_streaming(span, kwargs)
-
+
             try:
                 response_or_iterator = await original_beta_parse(*args, **kwargs)
-                return _format_and_record_output(span, response_or_iterator, is_streaming, True, False)
+                return _format_and_record_output(
+                    span, response_or_iterator, is_streaming, True, False
+                )
             except Exception as e:
                 _capture_exception_for_trace(span, sys.exc_info())
                 raise e
-
-
+
     # Async responses for OpenAI clients
     async def traced_response_create_async(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
         if not current_trace:
             return await original_responses_create(*args, **kwargs)
-
+
         with current_trace.span(span_name, span_type="llm") as span:
-            is_streaming = _record_input_and_check_streaming(span, kwargs, is_responses=True)
-
+            is_streaming = _record_input_and_check_streaming(
+                span, kwargs, is_responses=True
+            )
+
             try:
                 response_or_iterator = await original_responses_create(*args, **kwargs)
-                return _format_and_record_output(span, response_or_iterator, is_streaming, True, True)
+                return _format_and_record_output(
+                    span, response_or_iterator, is_streaming, True, True
+                )
             except Exception as e:
                 _capture_exception_for_trace(span, sys.exc_info())
                 raise e
-
+
     # Function replacing .stream() for async clients
     def traced_stream_async(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
         if not current_trace or not original_stream:
             return original_stream(*args, **kwargs)
-
+
         original_manager = original_stream(*args, **kwargs)
         return _TracedAsyncStreamManagerWrapper(
             original_manager=original_manager,
@@ -2336,61 +2553,70 @@ def wrap(client: Any, trace_across_async_contexts: bool = Tracer.trace_across_as
             span_name=span_name,
             trace_client=current_trace,
             stream_wrapper_func=_async_stream_wrapper,
-            input_kwargs=kwargs
+            input_kwargs=kwargs,
+            trace_across_async_contexts=trace_across_async_contexts,
         )
-
+
     # --- Traced Sync Functions ---
     def traced_create_sync(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
         if not current_trace:
             return original_create(*args, **kwargs)
-
+
         with current_trace.span(span_name, span_type="llm") as span:
             is_streaming = _record_input_and_check_streaming(span, kwargs)
-
+
             try:
                 response_or_iterator = original_create(*args, **kwargs)
-                return _format_and_record_output(span, response_or_iterator, is_streaming, False, False)
+                return _format_and_record_output(
+                    span, response_or_iterator, is_streaming, False, False
+                )
             except Exception as e:
                 _capture_exception_for_trace(span, sys.exc_info())
                 raise e
-
+
     def traced_beta_parse_sync(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
         if not current_trace:
             return original_beta_parse(*args, **kwargs)
-
+
         with current_trace.span(span_name, span_type="llm") as span:
             is_streaming = _record_input_and_check_streaming(span, kwargs)
-
+
             try:
                 response_or_iterator = original_beta_parse(*args, **kwargs)
-                return _format_and_record_output(span, response_or_iterator, is_streaming, False, False)
+                return _format_and_record_output(
+                    span, response_or_iterator, is_streaming, False, False
+                )
             except Exception as e:
                 _capture_exception_for_trace(span, sys.exc_info())
                 raise e
-
+
     def traced_response_create_sync(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
         if not current_trace:
             return original_responses_create(*args, **kwargs)
-
+
         with current_trace.span(span_name, span_type="llm") as span:
-            is_streaming = _record_input_and_check_streaming(span, kwargs, is_responses=True)
-
+            is_streaming = _record_input_and_check_streaming(
+                span, kwargs, is_responses=True
+            )
+
             try:
                 response_or_iterator = original_responses_create(*args, **kwargs)
-                return _format_and_record_output(span, response_or_iterator, is_streaming, False, True)
+                return _format_and_record_output(
+                    span, response_or_iterator, is_streaming, False, True
+                )
             except Exception as e:
                 _capture_exception_for_trace(span, sys.exc_info())
                 raise e
-
+
     # Function replacing sync .stream()
     def traced_stream_sync(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
         if not current_trace or not original_stream:
             return original_stream(*args, **kwargs)
-
+
         original_manager = original_stream(*args, **kwargs)
         return _TracedSyncStreamManagerWrapper(
             original_manager=original_manager,
@@ -2398,15 +2624,21 @@ def wrap(client: Any, trace_across_async_contexts: bool = Tracer.trace_across_as
             span_name=span_name,
             trace_client=current_trace,
             stream_wrapper_func=_sync_stream_wrapper,
-            input_kwargs=kwargs
+            input_kwargs=kwargs,
+            trace_across_async_contexts=trace_across_async_contexts,
         )
-
+
     # --- Assign Traced Methods to Client Instance ---
     if isinstance(client, (AsyncOpenAI, AsyncTogether)):
         client.chat.completions.create = traced_create_async
         if hasattr(client, "responses") and hasattr(client.responses, "create"):
             client.responses.create = traced_response_create_async
-        if hasattr(client, "beta") and hasattr(client.beta, "chat") and hasattr(client.beta.chat, "completions") and hasattr(client.beta.chat.completions, "parse"):
+        if (
+            hasattr(client, "beta")
+            and hasattr(client.beta, "chat")
+            and hasattr(client.beta.chat, "completions")
+            and hasattr(client.beta.chat.completions, "parse")
+        ):
             client.beta.chat.completions.parse = traced_beta_parse_async
     elif isinstance(client, AsyncAnthropic):
         client.messages.create = traced_create_async
@@ -2418,7 +2650,12 @@ def wrap(client: Any, trace_across_async_contexts: bool = Tracer.trace_across_as
         client.chat.completions.create = traced_create_sync
         if hasattr(client, "responses") and hasattr(client.responses, "create"):
             client.responses.create = traced_response_create_sync
-        if hasattr(client, "beta") and hasattr(client.beta, "chat") and hasattr(client.beta.chat, "completions") and hasattr(client.beta.chat.completions, "parse"):
+        if (
+            hasattr(client, "beta")
+            and hasattr(client.beta, "chat")
+            and hasattr(client.beta.chat, "completions")
+            and hasattr(client.beta.chat.completions, "parse")
+        ):
             client.beta.chat.completions.parse = traced_beta_parse_sync
     elif isinstance(client, Anthropic):
         client.messages.create = traced_create_sync
@@ -2426,17 +2663,21 @@ def wrap(client: Any, trace_across_async_contexts: bool = Tracer.trace_across_as
         client.messages.stream = traced_stream_sync
     elif isinstance(client, genai.Client):
         client.models.generate_content = traced_create_sync
-
+
     return client
 
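
Since `wrap` patches the traced methods onto the client instance and returns that same instance, wrapping is a one-liner at construction time. A sketch (constructor details and model name are illustrative; the `observe` decorator is assumed from the rest of this module):

    from anthropic import Anthropic
    from judgeval.common.tracer import Tracer, wrap

    judgment = Tracer(project_name="my_project")  # illustrative configuration
    client = wrap(Anthropic())  # patches .create and .stream in place

    @judgment.observe(span_type="function")
    def ask(question: str) -> str:
        msg = client.messages.create(
            model="claude-3-5-sonnet-20241022",  # illustrative model
            max_tokens=256,
            messages=[{"role": "user", "content": question}],
        )
        return msg.content[0].text
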
+
 # Helper functions for client-specific operations
 
-def _get_client_config(client: ApiClient) -> tuple[str, callable, Optional[callable]]:
+
+def _get_client_config(
+    client: ApiClient,
+) -> tuple[str, Callable, Optional[Callable], Optional[Callable], Optional[Callable]]:
     """Returns configuration tuple for the given API client.
-
+
     Args:
         client: An instance of OpenAI, Together, or Anthropic client
-
+
     Returns:
         tuple: (span_name, create_method, responses_method, stream_method, beta_parse_method)
             - span_name: String identifier for tracing
@@ -2444,23 +2685,36 @@ def _get_client_config(client: ApiClient) -> tuple[str, callable, Optional[calla
             - responses_method: Reference to the client's responses method (if applicable)
             - stream_method: Reference to the client's stream method (if applicable)
             - beta_parse_method: Reference to the client's beta parse method (if applicable)
-
+
     Raises:
         ValueError: If client type is not supported
     """
     if isinstance(client, (OpenAI, AsyncOpenAI)):
-        return "OPENAI_API_CALL", client.chat.completions.create, client.responses.create, None, client.beta.chat.completions.parse
+        return (
+            "OPENAI_API_CALL",
+            client.chat.completions.create,
+            client.responses.create,
+            None,
+            client.beta.chat.completions.parse,
+        )
     elif isinstance(client, (Together, AsyncTogether)):
         return "TOGETHER_API_CALL", client.chat.completions.create, None, None, None
     elif isinstance(client, (Anthropic, AsyncAnthropic)):
-        return "ANTHROPIC_API_CALL", client.messages.create, None, client.messages.stream, None
+        return (
+            "ANTHROPIC_API_CALL",
+            client.messages.create,
+            None,
+            client.messages.stream,
+            None,
+        )
     elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
         return "GOOGLE_API_CALL", client.models.generate_content, None, None, None
     raise ValueError(f"Unsupported client type: {type(client)}")
 
+
 def _format_input_data(client: ApiClient, **kwargs) -> dict:
     """Format input parameters based on client type.
-
+
     Extracts relevant parameters from kwargs based on the client type
     to ensure consistent tracing across different APIs.
     """
@@ -2473,25 +2727,23 @@ def _format_input_data(client: ApiClient, **kwargs) -> dict:
         input_data["response_format"] = kwargs.get("response_format")
         return input_data
     elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
-        return {
-            "model": kwargs.get("model"),
-            "contents": kwargs.get("contents")
-        }
+        return {"model": kwargs.get("model"), "contents": kwargs.get("contents")}
     # Anthropic requires additional max_tokens parameter
     return {
         "model": kwargs.get("model"),
         "messages": kwargs.get("messages"),
-        "max_tokens": kwargs.get("max_tokens")
+        "max_tokens": kwargs.get("max_tokens"),
     }
 
-def _format_response_output_data(client: ApiClient, response: Any) -> dict:
+
+def _format_response_output_data(client: ApiClient, response: Any) -> tuple:
     """Format API response data based on client type.
-
+
     Normalizes different response formats into a consistent structure
     for tracing purposes.
     """
     message_content = None
-    prompt_tokens = 0
+    prompt_tokens = 0
     completion_tokens = 0
     model_name = None
     if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
@@ -2501,14 +2753,16 @@ def _format_response_output_data(client: ApiClient, response: Any) -> dict:
         message_content = response.output
     else:
         warnings.warn(f"Unsupported client type: {type(client)}")
-        return {}
-
-    prompt_cost, completion_cost = cost_per_token(
+        return None, None
+
+    prompt_cost, completion_cost = cost_per_token(
         model=model_name,
         prompt_tokens=prompt_tokens,
         completion_tokens=completion_tokens,
     )
-    total_cost_usd = (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
+    total_cost_usd = (
+        (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
+    )
     usage = TraceUsage(
         prompt_tokens=prompt_tokens,
         completion_tokens=completion_tokens,
@@ -2516,17 +2770,19 @@ def _format_response_output_data(client: ApiClient, response: Any) -> dict:
         prompt_tokens_cost_usd=prompt_cost,
         completion_tokens_cost_usd=completion_cost,
         total_cost_usd=total_cost_usd,
-        model_name=model_name
+        model_name=model_name,
     )
     return message_content, usage
 
 
2778
+ def _format_output_data(
2779
+ client: ApiClient, response: Any
2780
+ ) -> tuple[Optional[str], Optional[TraceUsage]]:
2525
2781
  """Format API response data based on client type.
2526
-
2782
+
2527
2783
  Normalizes different response formats into a consistent structure
2528
2784
  for tracing purposes.
2529
-
2785
+
2530
2786
  Returns:
2531
2787
  dict containing:
2532
2788
  - content: The generated text
@@ -2541,7 +2797,10 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
2541
2797
  model_name = response.model
2542
2798
  prompt_tokens = response.usage.prompt_tokens
2543
2799
  completion_tokens = response.usage.completion_tokens
2544
- if hasattr(response.choices[0].message, "parsed") and response.choices[0].message.parsed:
2800
+ if (
2801
+ hasattr(response.choices[0].message, "parsed")
2802
+ and response.choices[0].message.parsed
2803
+ ):
2545
2804
  message_content = response.choices[0].message.parsed
2546
2805
  else:
2547
2806
  message_content = response.choices[0].message.content
@@ -2558,13 +2817,15 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
2558
2817
  else:
2559
2818
  warnings.warn(f"Unsupported client type: {type(client)}")
2560
2819
  return None, None
2561
-
2820
+
2562
2821
  prompt_cost, completion_cost = cost_per_token(
2563
2822
  model=model_name,
2564
2823
  prompt_tokens=prompt_tokens,
2565
2824
  completion_tokens=completion_tokens,
2566
2825
  )
2567
- total_cost_usd = (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
2826
+ total_cost_usd = (
2827
+ (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
2828
+ )
2568
2829
  usage = TraceUsage(
2569
2830
  prompt_tokens=prompt_tokens,
2570
2831
  completion_tokens=completion_tokens,
@@ -2572,56 +2833,61 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
2572
2833
  prompt_tokens_cost_usd=prompt_cost,
2573
2834
  completion_tokens_cost_usd=completion_cost,
2574
2835
  total_cost_usd=total_cost_usd,
2575
- model_name=model_name
2836
+ model_name=model_name,
2576
2837
  )
2577
2838
  return message_content, usage
2578
2839
 
2840
+
 def combine_args_kwargs(func, args, kwargs):
     """
     Combine positional arguments and keyword arguments into a single dictionary.
-
+
     Args:
         func: The function being called
         args: Tuple of positional arguments
         kwargs: Dictionary of keyword arguments
-
+
     Returns:
         A dictionary combining both args and kwargs
     """
     try:
         import inspect
+
         sig = inspect.signature(func)
         param_names = list(sig.parameters.keys())
-
+
         args_dict = {}
         for i, arg in enumerate(args):
             if i < len(param_names):
                 args_dict[param_names[i]] = arg
             else:
                 args_dict[f"arg{i}"] = arg
-
+
         return {**args_dict, **kwargs}
-    except Exception as e:
+    except Exception:
         # Fallback if signature inspection fails
         return {**{f"arg{i}": arg for i, arg in enumerate(args)}, **kwargs}
 
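
`combine_args_kwargs` maps positionals onto parameter names via `inspect.signature`, falling back to generic `arg0`, `arg1`, ... keys when inspection fails. An illustration with a hypothetical function:

    def greet(name, punctuation="!"):
        return f"Hello, {name}{punctuation}"

    combine_args_kwargs(greet, ("Ada",), {"punctuation": "?"})
    # -> {"name": "Ada", "punctuation": "?"}

    combine_args_kwargs(object(), ("x",), {})  # not callable; inspection raises
    # -> {"arg0": "x"}
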
+
 # NOTE: This builds once, can be tweaked if we are missing / capturing other unnecessary modules
 # @link https://docs.python.org/3.13/library/sysconfig.html
 _TRACE_FILEPATH_BLOCKLIST = tuple(
     os.path.realpath(p) + os.sep
     for p in {
-        sysconfig.get_paths()['stdlib'],
-        sysconfig.get_paths().get('platstdlib', ''),
+        sysconfig.get_paths()["stdlib"],
+        sysconfig.get_paths().get("platstdlib", ""),
         *site.getsitepackages(),
         site.getusersitepackages(),
         *(
-            [os.path.join(os.path.dirname(__file__), '../../judgeval/')]
-            if os.environ.get('JUDGMENT_DEV')
+            [os.path.join(os.path.dirname(__file__), "../../judgeval/")]
+            if os.environ.get("JUDGMENT_DEV")
             else []
         ),
-    } if p
+    }
+    if p
 )
 
+
 # Add the new TraceThreadPoolExecutor class
 class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
     """
@@ -2633,6 +2899,7 @@ class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
     allowing the Tracer to maintain correct parent-child relationships across
     thread boundaries.
     """
+
     def submit(self, fn, /, *args, **kwargs):
         """
         Submit a callable to be executed with the captured context.
@@ -2653,9 +2920,11 @@ class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
     # Note: The `map` method would also need to be overridden for full context
     # propagation if users rely on it, but `submit` is the most common use case.
 
2923
+
2656
2924
  # Helper functions for stream processing
2657
2925
  # ---------------------------------------
2658
2926
 
2927
+
2659
2928
  def _extract_content_from_chunk(client: ApiClient, chunk: Any) -> Optional[str]:
2660
2929
  """Extracts the text content from a stream chunk based on the client type."""
2661
2930
  try:
@@ -2667,34 +2936,49 @@ def _extract_content_from_chunk(client: ApiClient, chunk: Any) -> Optional[str]:
2667
2936
  return chunk.delta.text
2668
2937
  elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
2669
2938
  # Google streams Candidate objects
2670
- if chunk.candidates and chunk.candidates[0].content and chunk.candidates[0].content.parts:
2939
+ if (
2940
+ chunk.candidates
2941
+ and chunk.candidates[0].content
2942
+ and chunk.candidates[0].content.parts
2943
+ ):
2671
2944
  return chunk.candidates[0].content.parts[0].text
2672
2945
  except (AttributeError, IndexError, KeyError):
2673
2946
  # Handle cases where chunk structure is unexpected or doesn't contain content
2674
- pass # Return None
2947
+ pass # Return None
2675
2948
  return None
2676
2949
 
2677
- def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[Dict[str, int]]:
2950
+
2951
+ def _extract_usage_from_final_chunk(
2952
+ client: ApiClient, chunk: Any
2953
+ ) -> Optional[Dict[str, int]]:
2678
2954
  """Extracts usage data if present in the *final* chunk (client-specific)."""
2679
2955
  try:
2680
2956
  # OpenAI/Together include usage in the *last* chunk's `usage` attribute if available
2681
2957
  # This typically requires specific API versions or settings. Often usage is *not* streamed.
2682
2958
  if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
2683
2959
  # Check if usage is directly on the chunk (some models might do this)
2684
- if hasattr(chunk, 'usage') and chunk.usage:
2960
+ if hasattr(chunk, "usage") and chunk.usage:
2685
2961
  prompt_tokens = chunk.usage.prompt_tokens
2686
2962
  completion_tokens = chunk.usage.completion_tokens
2687
2963
  # Check if usage is nested within choices (less common for final chunk, but check)
2688
- elif chunk.choices and hasattr(chunk.choices[0], 'usage') and chunk.choices[0].usage:
2964
+ elif (
2965
+ chunk.choices
2966
+ and hasattr(chunk.choices[0], "usage")
2967
+ and chunk.choices[0].usage
2968
+ ):
2689
2969
  prompt_tokens = chunk.choices[0].usage.prompt_tokens
2690
2970
  completion_tokens = chunk.choices[0].usage.completion_tokens
2691
-
2971
+
2692
2972
  prompt_cost, completion_cost = cost_per_token(
2693
- model=chunk.model,
2694
- prompt_tokens=prompt_tokens,
2695
- completion_tokens=completion_tokens,
2696
- )
2697
- total_cost_usd = (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
2973
+ model=chunk.model,
2974
+ prompt_tokens=prompt_tokens,
2975
+ completion_tokens=completion_tokens,
2976
+ )
2977
+ total_cost_usd = (
2978
+ (prompt_cost + completion_cost)
2979
+ if prompt_cost and completion_cost
2980
+ else None
2981
+ )
2698
2982
  return TraceUsage(
2699
2983
  prompt_tokens=chunk.usage.prompt_tokens,
2700
2984
  completion_tokens=chunk.usage.completion_tokens,
@@ -2702,9 +2986,9 @@ def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[D
2702
2986
  prompt_tokens_cost_usd=prompt_cost,
2703
2987
  completion_tokens_cost_usd=completion_cost,
2704
2988
  total_cost_usd=total_cost_usd,
2705
- model_name=chunk.model
2989
+ model_name=chunk.model,
2706
2990
  )
2707
- # Anthropic includes usage in the 'message_stop' event type
2991
+ # Anthropic includes usage in the 'message_stop' event type
2708
2992
  elif isinstance(client, (Anthropic, AsyncAnthropic)):
2709
2993
  if chunk.type == "message_stop":
2710
2994
  # Anthropic final usage is often attached to the *message* object, not the chunk directly
@@ -2713,18 +2997,18 @@ def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[D
2713
2997
  # This is a placeholder - Anthropic usage typically needs a separate call or context.
2714
2998
  pass
2715
2999
  elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
2716
- # Google provides usage metadata on the full response object, not typically streamed per chunk.
2717
- # It might be in the *last* chunk's usage_metadata if the stream implementation supports it.
2718
- if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
2719
- return {
2720
- "prompt_tokens": chunk.usage_metadata.prompt_token_count,
2721
- "completion_tokens": chunk.usage_metadata.candidates_token_count,
2722
- "total_tokens": chunk.usage_metadata.total_token_count
2723
- }
3000
+ # Google provides usage metadata on the full response object, not typically streamed per chunk.
3001
+ # It might be in the *last* chunk's usage_metadata if the stream implementation supports it.
3002
+ if hasattr(chunk, "usage_metadata") and chunk.usage_metadata:
3003
+ return {
3004
+ "prompt_tokens": chunk.usage_metadata.prompt_token_count,
3005
+ "completion_tokens": chunk.usage_metadata.candidates_token_count,
3006
+ "total_tokens": chunk.usage_metadata.total_token_count,
3007
+ }
2724
3008
 
2725
3009
  except (AttributeError, IndexError, KeyError, TypeError):
2726
3010
  # Handle cases where usage data is missing or malformed
2727
- pass # Return None
3011
+ pass # Return None
2728
3012
  return None
2729
3013
 
2730
3014
 
@@ -2732,7 +3016,8 @@ def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[D
 def _sync_stream_wrapper(
     original_stream: Iterator,
     client: ApiClient,
-    span: TraceSpan
+    span: TraceSpan,
+    trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
 ) -> Generator[Any, None, None]:
     """Wraps a synchronous stream iterator to capture content and update the trace."""
     content_parts = [] # Use a list instead of string concatenation
@@ -2742,9 +3027,11 @@ def _sync_stream_wrapper(
         for chunk in original_stream:
             content_part = _extract_content_from_chunk(client, chunk)
             if content_part:
-                content_parts.append(content_part) # Append to list instead of concatenating
-            last_chunk = chunk # Keep track of the last chunk for potential usage data
-            yield chunk # Pass the chunk to the caller
+                content_parts.append(
+                    content_part
+                )  # Append to list instead of concatenating
+            last_chunk = chunk  # Keep track of the last chunk for potential usage data
+            yield chunk  # Pass the chunk to the caller
     finally:
         # Attempt to extract usage from the last chunk received
         if last_chunk:
@@ -2753,23 +3040,25 @@ def _sync_stream_wrapper(
         # Update the trace entry with the accumulated content and usage
         span.output = "".join(content_parts)
         span.usage = final_usage
-
+
         # Queue the completed LLM span now that streaming is done and all data is available
-        # Note: We need to get the TraceClient that owns this span to access the background service
-        # We can find this through the tracer singleton since spans are associated with traces
-        from judgeval.common.tracer import Tracer
-        tracer_instance = Tracer._instance
-        if tracer_instance and tracer_instance.background_span_service:
-            tracer_instance.background_span_service.queue_span(span, span_state="completed")
-
+
+        current_trace = _get_current_trace(trace_across_async_contexts)
+        if current_trace and current_trace.background_span_service:
+            current_trace.background_span_service.queue_span(
+                span, span_state="completed"
+            )
+
         # Note: We might need to adjust _serialize_output if this dict causes issues,
         # but Pydantic's model_dump should handle dicts.
 
+
 # --- Async Stream Wrapper ---
 async def _async_stream_wrapper(
     original_stream: AsyncIterator,
     client: ApiClient,
-    span: TraceSpan
+    span: TraceSpan,
+    trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
 ) -> AsyncGenerator[Any, None]:
     # [Existing logic - unchanged]
     content_parts = []  # Use a list instead of string concatenation
@@ -2778,56 +3067,70 @@ async def _async_stream_wrapper(
     anthropic_input_tokens = 0
     anthropic_output_tokens = 0
 
-    target_span_id = span.span_id
-
     try:
         model_name = ""
         async for chunk in original_stream:
             # Check for OpenAI's final usage chunk
-            if isinstance(client, (AsyncOpenAI, OpenAI)) and hasattr(chunk, 'usage') and chunk.usage is not None:
+            if (
+                isinstance(client, (AsyncOpenAI, OpenAI))
+                and hasattr(chunk, "usage")
+                and chunk.usage is not None
+            ):
                 final_usage_data = {
                     "prompt_tokens": chunk.usage.prompt_tokens,
                     "completion_tokens": chunk.usage.completion_tokens,
-                    "total_tokens": chunk.usage.total_tokens
+                    "total_tokens": chunk.usage.total_tokens,
                 }
                 model_name = chunk.model
                 yield chunk
                 continue
 
-            if isinstance(client, (AsyncAnthropic, Anthropic)) and hasattr(chunk, 'type'):
+            if isinstance(client, (AsyncAnthropic, Anthropic)) and hasattr(
+                chunk, "type"
+            ):
                 if chunk.type == "message_start":
-                    if hasattr(chunk, 'message') and hasattr(chunk.message, 'usage') and hasattr(chunk.message.usage, 'input_tokens'):
-                        anthropic_input_tokens = chunk.message.usage.input_tokens
-                        model_name = chunk.message.model
+                    if (
+                        hasattr(chunk, "message")
+                        and hasattr(chunk.message, "usage")
+                        and hasattr(chunk.message.usage, "input_tokens")
+                    ):
+                        anthropic_input_tokens = chunk.message.usage.input_tokens
+                        model_name = chunk.message.model
                 elif chunk.type == "message_delta":
-                    if hasattr(chunk, 'usage') and hasattr(chunk.usage, 'output_tokens'):
+                    if hasattr(chunk, "usage") and hasattr(
+                        chunk.usage, "output_tokens"
+                    ):
                         anthropic_output_tokens = chunk.usage.output_tokens
 
             content_part = _extract_content_from_chunk(client, chunk)
             if content_part:
-                content_parts.append(content_part) # Append to list instead of concatenating
+                content_parts.append(
+                    content_part
+                )  # Append to list instead of concatenating
                 last_content_chunk = chunk
 
             yield chunk
     finally:
         anthropic_final_usage = None
-        if isinstance(client, (AsyncAnthropic, Anthropic)) and (anthropic_input_tokens > 0 or anthropic_output_tokens > 0):
-            anthropic_final_usage = {
-                "prompt_tokens": anthropic_input_tokens,
-                "completion_tokens": anthropic_output_tokens,
-                "total_tokens": anthropic_input_tokens + anthropic_output_tokens
-            }
+        if isinstance(client, (AsyncAnthropic, Anthropic)) and (
+            anthropic_input_tokens > 0 or anthropic_output_tokens > 0
+        ):
+            anthropic_final_usage = {
+                "prompt_tokens": anthropic_input_tokens,
+                "completion_tokens": anthropic_output_tokens,
+                "total_tokens": anthropic_input_tokens + anthropic_output_tokens,
+            }
 
         usage_info = None
         if final_usage_data:
-            usage_info = final_usage_data
+            usage_info = final_usage_data
         elif anthropic_final_usage:
-            usage_info = anthropic_final_usage
+            usage_info = anthropic_final_usage
         elif last_content_chunk:
             usage_info = _extract_usage_from_final_chunk(client, last_content_chunk)
 
         if usage_info and not isinstance(usage_info, TraceUsage):
-            prompt_cost, completion_cost = cost_per_token(
+            prompt_cost, completion_cost = cost_per_token(
                 model=model_name,
                 prompt_tokens=usage_info["prompt_tokens"],
                 completion_tokens=usage_info["completion_tokens"],
@@ -2839,21 +3142,23 @@ async def _async_stream_wrapper(
                 prompt_tokens_cost_usd=prompt_cost,
                 completion_tokens_cost_usd=completion_cost,
                 total_cost_usd=prompt_cost + completion_cost,
-                model_name=model_name
+                model_name=model_name,
             )
-        if span and hasattr(span, 'output'):
-            span.output = ''.join(content_parts)
+        if span and hasattr(span, "output"):
+            span.output = "".join(content_parts)
             span.usage = usage_info
-            start_ts = getattr(span, 'created_at', time.time())
+            start_ts = getattr(span, "created_at", time.time())
             span.duration = time.time() - start_ts
-
+
         # Queue the completed LLM span now that async streaming is done and all data is available
-        from judgeval.common.tracer import Tracer
-        tracer_instance = Tracer._instance
-        if tracer_instance and tracer_instance.background_span_service:
-            tracer_instance.background_span_service.queue_span(span, span_state="completed")
+        current_trace = _get_current_trace(trace_across_async_contexts)
+        if current_trace and current_trace.background_span_service:
+            current_trace.background_span_service.queue_span(
+                span, span_state="completed"
+            )
         # else: # Handle error case if necessary, but remove debug print
 
+
 def cost_per_token(*args, **kwargs):
     try:
         return _original_cost_per_token(*args, **kwargs)
@@ -2861,8 +3166,18 @@ def cost_per_token(*args, **kwargs):
         warnings.warn(f"Error calculating cost per token: {e}")
         return None, None
 
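
This wrapper makes cost calculation best-effort: `_original_cost_per_token` (the aliased cost function imported at the top of the module) can raise for unknown models, and the wrapper degrades to `(None, None)` instead, which is why callers above guard with `if prompt_cost and completion_cost`. A small sketch (token counts and model name are illustrative):

    prompt_cost, completion_cost = cost_per_token(
        model="gpt-4o-mini",  # illustrative model
        prompt_tokens=120,
        completion_tokens=48,
    )
    if prompt_cost and completion_cost:
        total_usd = prompt_cost + completion_cost
    else:
        total_usd = None  # pricing unknown; tracing proceeds without cost data
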
+
 class _BaseStreamManagerWrapper:
-    def __init__(self, original_manager, client, span_name, trace_client, stream_wrapper_func, input_kwargs):
+    def __init__(
+        self,
+        original_manager,
+        client,
+        span_name,
+        trace_client,
+        stream_wrapper_func,
+        input_kwargs,
+        trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
+    ):
         self._original_manager = original_manager
         self._client = client
         self._span_name = span_name
@@ -2870,13 +3185,19 @@ class _BaseStreamManagerWrapper:
         self._stream_wrapper_func = stream_wrapper_func
         self._input_kwargs = input_kwargs
         self._parent_span_id_at_entry = None
+        self._trace_across_async_contexts = trace_across_async_contexts
 
     def _create_span(self):
         start_time = time.time()
         span_id = str(uuid.uuid4())
         current_depth = 0
-        if self._parent_span_id_at_entry and self._parent_span_id_at_entry in self._trace_client._span_depths:
-            current_depth = self._trace_client._span_depths[self._parent_span_id_at_entry] + 1
+        if (
+            self._parent_span_id_at_entry
+            and self._parent_span_id_at_entry in self._trace_client._span_depths
+        ):
+            current_depth = (
+                self._trace_client._span_depths[self._parent_span_id_at_entry] + 1
+            )
         self._trace_client._span_depths[span_id] = current_depth
         span = TraceSpan(
             function=self._span_name,
@@ -2886,7 +3207,7 @@ class _BaseStreamManagerWrapper:
             message=self._span_name,
             created_at=start_time,
             span_type="llm",
-            parent_span_id=self._parent_span_id_at_entry
+            parent_span_id=self._parent_span_id_at_entry,
         )
         self._trace_client.add_span(span)
         return span_id, span
@@ -2898,7 +3219,10 @@ class _BaseStreamManagerWrapper:
         if span_id in self._trace_client._span_depths:
             del self._trace_client._span_depths[span_id]
 
-class _TracedAsyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractAsyncContextManager):
+
+class _TracedAsyncStreamManagerWrapper(
+    _BaseStreamManagerWrapper, AbstractAsyncContextManager
+):
     async def __aenter__(self):
         self._parent_span_id_at_entry = self._trace_client.get_current_span()
         if not self._trace_client:
@@ -2911,17 +3235,22 @@ class _TracedAsyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractAsyncC
         # Call the original __aenter__ and expect it to be an async generator
         raw_iterator = await self._original_manager.__aenter__()
         span.output = "<pending stream>"
-        return self._stream_wrapper_func(raw_iterator, self._client, span)
+        return self._stream_wrapper_func(
+            raw_iterator, self._client, span, self._trace_across_async_contexts
+        )
 
     async def __aexit__(self, exc_type, exc_val, exc_tb):
-        if hasattr(self, '_span_context_token'):
+        if hasattr(self, "_span_context_token"):
             span_id = self._trace_client.get_current_span()
             self._finalize_span(span_id)
             self._trace_client.reset_current_span(self._span_context_token)
-            delattr(self, '_span_context_token')
+            delattr(self, "_span_context_token")
         return await self._original_manager.__aexit__(exc_type, exc_val, exc_tb)
 
-class _TracedSyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractContextManager):
+
+class _TracedSyncStreamManagerWrapper(
+    _BaseStreamManagerWrapper, AbstractContextManager
+):
     def __enter__(self):
         self._parent_span_id_at_entry = self._trace_client.get_current_span()
         if not self._trace_client:
@@ -2933,16 +3262,19 @@ class _TracedSyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractContext
 
         raw_iterator = self._original_manager.__enter__()
         span.output = "<pending stream>"
-        return self._stream_wrapper_func(raw_iterator, self._client, span)
+        return self._stream_wrapper_func(
+            raw_iterator, self._client, span, self._trace_across_async_contexts
+        )
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        if hasattr(self, '_span_context_token'):
+        if hasattr(self, "_span_context_token"):
             span_id = self._trace_client.get_current_span()
             self._finalize_span(span_id)
             self._trace_client.reset_current_span(self._span_context_token)
-            delattr(self, '_span_context_token')
+            delattr(self, "_span_context_token")
         return self._original_manager.__exit__(exc_type, exc_val, exc_tb)
 
+
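
The traced stream managers preserve the `with client.messages.stream(...) as stream:` ergonomics of Anthropic's SDK, so existing call sites need no changes after wrapping; the wrapper records text and usage as events flow through. A sketch (model and prompt are illustrative):

    from anthropic import Anthropic
    from judgeval.common.tracer import wrap

    client = wrap(Anthropic())

    with client.messages.stream(
        model="claude-3-5-sonnet-20241022",  # illustrative model
        max_tokens=256,
        messages=[{"role": "user", "content": "Stream a haiku"}],
    ) as stream:
        for event in stream:  # raw events pass through the traced wrapper
            pass
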
 # --- Helper function for instance-prefixed qual_name ---
 def get_instance_prefixed_name(instance, class_name, class_identifiers):
     """
@@ -2951,11 +3283,13 @@ def get_instance_prefixed_name(instance, class_name, class_identifiers):
     """
     if class_name in class_identifiers:
         class_config = class_identifiers[class_name]
-        attr = class_config['identifier']
-
+        attr = class_config["identifier"]
+
         if hasattr(instance, attr):
             instance_name = getattr(instance, attr)
             return instance_name
         else:
-            raise Exception(f"Attribute {attr} does not exist for {class_name}. Check your identify() decorator.")
-    return None
+            raise Exception(
+                f"Attribute {attr} does not exist for {class_name}. Check your identify() decorator."
+            )
+    return None