judgeval 0.0.43__py3-none-any.whl → 0.0.45__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. judgeval/__init__.py +5 -4
  2. judgeval/clients.py +6 -6
  3. judgeval/common/__init__.py +7 -2
  4. judgeval/common/exceptions.py +2 -3
  5. judgeval/common/logger.py +74 -49
  6. judgeval/common/s3_storage.py +30 -23
  7. judgeval/common/tracer.py +1302 -984
  8. judgeval/common/utils.py +416 -244
  9. judgeval/constants.py +73 -61
  10. judgeval/data/__init__.py +1 -1
  11. judgeval/data/custom_example.py +3 -2
  12. judgeval/data/datasets/dataset.py +80 -54
  13. judgeval/data/datasets/eval_dataset_client.py +131 -181
  14. judgeval/data/example.py +67 -43
  15. judgeval/data/result.py +11 -9
  16. judgeval/data/scorer_data.py +4 -2
  17. judgeval/data/tool.py +25 -16
  18. judgeval/data/trace.py +57 -29
  19. judgeval/data/trace_run.py +5 -11
  20. judgeval/evaluation_run.py +22 -82
  21. judgeval/integrations/langgraph.py +546 -184
  22. judgeval/judges/base_judge.py +1 -2
  23. judgeval/judges/litellm_judge.py +33 -11
  24. judgeval/judges/mixture_of_judges.py +128 -78
  25. judgeval/judges/together_judge.py +22 -9
  26. judgeval/judges/utils.py +14 -5
  27. judgeval/judgment_client.py +259 -271
  28. judgeval/rules.py +169 -142
  29. judgeval/run_evaluation.py +462 -305
  30. judgeval/scorers/api_scorer.py +20 -11
  31. judgeval/scorers/exceptions.py +1 -0
  32. judgeval/scorers/judgeval_scorer.py +77 -58
  33. judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +46 -15
  34. judgeval/scorers/judgeval_scorers/api_scorers/answer_correctness.py +3 -2
  35. judgeval/scorers/judgeval_scorers/api_scorers/answer_relevancy.py +3 -2
  36. judgeval/scorers/judgeval_scorers/api_scorers/classifier_scorer.py +12 -11
  37. judgeval/scorers/judgeval_scorers/api_scorers/comparison.py +7 -5
  38. judgeval/scorers/judgeval_scorers/api_scorers/contextual_precision.py +3 -2
  39. judgeval/scorers/judgeval_scorers/api_scorers/contextual_recall.py +3 -2
  40. judgeval/scorers/judgeval_scorers/api_scorers/contextual_relevancy.py +5 -2
  41. judgeval/scorers/judgeval_scorers/api_scorers/derailment_scorer.py +2 -1
  42. judgeval/scorers/judgeval_scorers/api_scorers/execution_order.py +17 -8
  43. judgeval/scorers/judgeval_scorers/api_scorers/faithfulness.py +3 -2
  44. judgeval/scorers/judgeval_scorers/api_scorers/groundedness.py +3 -2
  45. judgeval/scorers/judgeval_scorers/api_scorers/hallucination.py +3 -2
  46. judgeval/scorers/judgeval_scorers/api_scorers/instruction_adherence.py +3 -2
  47. judgeval/scorers/judgeval_scorers/api_scorers/json_correctness.py +8 -9
  48. judgeval/scorers/judgeval_scorers/api_scorers/summarization.py +4 -4
  49. judgeval/scorers/judgeval_scorers/api_scorers/tool_dependency.py +5 -5
  50. judgeval/scorers/judgeval_scorers/api_scorers/tool_order.py +5 -2
  51. judgeval/scorers/judgeval_scorers/classifiers/text2sql/text2sql_scorer.py +9 -10
  52. judgeval/scorers/prompt_scorer.py +48 -37
  53. judgeval/scorers/score.py +86 -53
  54. judgeval/scorers/utils.py +11 -7
  55. judgeval/tracer/__init__.py +1 -1
  56. judgeval/utils/alerts.py +23 -12
  57. judgeval/utils/{data_utils.py → file_utils.py} +5 -9
  58. judgeval/utils/requests.py +29 -0
  59. judgeval/version_check.py +5 -2
  60. {judgeval-0.0.43.dist-info → judgeval-0.0.45.dist-info}/METADATA +79 -135
  61. judgeval-0.0.45.dist-info/RECORD +69 -0
  62. judgeval-0.0.43.dist-info/RECORD +0 -68
  63. {judgeval-0.0.43.dist-info → judgeval-0.0.45.dist-info}/WHEEL +0 -0
  64. {judgeval-0.0.43.dist-info → judgeval-0.0.45.dist-info}/licenses/LICENSE.md +0 -0
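
Note on the new judgeval/utils/requests.py (entry 58 above, +29 lines): its full contents are not part of this diff, but the tracer.py hunks below consume it via "from judgeval.utils.requests import requests", so it evidently exposes a requests-compatible object. A minimal sketch of what such a drop-in wrapper could look like, assuming it simply centralizes session defaults; apart from the module path, every name below is illustrative, not taken from the release:

    # Hypothetical sketch of judgeval/utils/requests.py; only its import
    # surface ("from judgeval.utils.requests import requests") is visible
    # in this diff, so treat the details here as assumptions.
    import requests as _requests


    class _SessionWrapper:
        """Session-backed stand-in mirroring the module-level requests API."""

        def __init__(self) -> None:
            self._session = _requests.Session()

        def _request(self, method: str, url: str, **kwargs) -> _requests.Response:
            # One place to set defaults (e.g. a timeout) for every API call,
            # plus connection pooling from the shared Session.
            kwargs.setdefault("timeout", 30)
            return self._session.request(method, url, **kwargs)

        def get(self, url: str, **kwargs) -> _requests.Response:
            return self._request("GET", url, **kwargs)

        def post(self, url: str, **kwargs) -> _requests.Response:
            return self._request("POST", url, **kwargs)

        def delete(self, url: str, **kwargs) -> _requests.Response:
            return self._request("DELETE", url, **kwargs)


    requests = _SessionWrapper()

Callers keep the familiar requests.post(url, json=..., headers=..., verify=True) shape, which matches how tracer.py uses the import in the hunks below.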
judgeval/common/tracer.py CHANGED
@@ -1,6 +1,7 @@
 """
 Tracing system for judgeval that allows for function tracing using decorators.
 """
+
 # Standard library imports
 import asyncio
 import functools
@@ -16,9 +17,13 @@ import warnings
 import contextvars
 import sys
 import json
-from contextlib import contextmanager, asynccontextmanager, AbstractAsyncContextManager, AbstractContextManager # Import context manager bases
-from dataclasses import dataclass, field
-from datetime import datetime, timezone
+from contextlib import (
+    contextmanager,
+    AbstractAsyncContextManager,
+    AbstractContextManager,
+)  # Import context manager bases
+from dataclasses import dataclass
+from datetime import datetime
 from http import HTTPStatus
 from typing import (
     Any,
@@ -37,9 +42,9 @@ from rich import print as rprint
 import types

 # Third-party imports
-import requests
+from requests import RequestException
+from judgeval.utils.requests import requests
 from litellm import cost_per_token as _original_cost_per_token
-from rich import print as rprint
 from openai import OpenAI, AsyncOpenAI
 from together import Together, AsyncTogether
 from anthropic import Anthropic, AsyncAnthropic
@@ -50,12 +55,7 @@ from judgeval.constants import (
     JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
     JUDGMENT_TRACES_SAVE_API_URL,
     JUDGMENT_TRACES_UPSERT_API_URL,
-    JUDGMENT_TRACES_USAGE_CHECK_API_URL,
-    JUDGMENT_TRACES_USAGE_UPDATE_API_URL,
     JUDGMENT_TRACES_FETCH_API_URL,
-    RABBITMQ_HOST,
-    RABBITMQ_PORT,
-    RABBITMQ_QUEUE,
     JUDGMENT_TRACES_DELETE_API_URL,
     JUDGMENT_PROJECT_DELETE_API_URL,
     JUDGMENT_TRACES_SPANS_BATCH_API_URL,
@@ -70,32 +70,37 @@ from judgeval.common.exceptions import JudgmentAPIError

 # Standard library imports needed for the new class
 import concurrent.futures
-from collections.abc import Iterator, AsyncIterator # Add Iterator and AsyncIterator
+from collections.abc import Iterator, AsyncIterator  # Add Iterator and AsyncIterator
 import queue
 import atexit

 # Define context variables for tracking the current trace and the current span within a trace
-current_trace_var = contextvars.ContextVar[Optional['TraceClient']]('current_trace', default=None)
-current_span_var = contextvars.ContextVar('current_span', default=None) # ContextVar for the active span name
+current_trace_var = contextvars.ContextVar[Optional["TraceClient"]](
+    "current_trace", default=None
+)
+current_span_var = contextvars.ContextVar[Optional[str]](
+    "current_span", default=None
+)  # ContextVar for the active span id

 # Define type aliases for better code readability and maintainability
-ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic, AsyncOpenAI, AsyncAnthropic, AsyncTogether, genai.Client, genai.client.AsyncClient] # Supported API clients
-SpanType = Literal['span', 'tool', 'llm', 'evaluation', 'chain']
+ApiClient: TypeAlias = Union[
+    OpenAI,
+    Together,
+    Anthropic,
+    AsyncOpenAI,
+    AsyncAnthropic,
+    AsyncTogether,
+    genai.Client,
+    genai.client.AsyncClient,
+]  # Supported API clients
+SpanType = Literal["span", "tool", "llm", "evaluation", "chain"]

-# --- Evaluation Config Dataclass (Moved from langgraph.py) ---
-@dataclass
-class EvaluationConfig:
-    """Configuration for triggering an evaluation from the handler."""
-    scorers: List[Union[APIJudgmentScorer, JudgevalScorer]]
-    example: Example
-    model: Optional[str] = None
-    log_results: Optional[bool] = True
-# --- End Evaluation Config Dataclass ---

 # Temporary as a POC to have log use the existing annotations feature until log endpoints are ready
 @dataclass
 class TraceAnnotation:
     """Represents a single annotation for a trace span."""
+
     span_id: str
     text: str
     label: str
@@ -105,24 +110,27 @@ class TraceAnnotation:
         """Convert the annotation to a dictionary format for storage/transmission."""
         return {
             "span_id": self.span_id,
-            "annotation": {
-                "text": self.text,
-                "label": self.label,
-                "score": self.score
-            }
+            "annotation": {"text": self.text, "label": self.label, "score": self.score},
         }
-
+
+
 class TraceManagerClient:
     """
     Client for handling trace endpoints with the Judgment API
-
+

     Operations include:
     - Fetching a trace by id
     - Saving a trace
     - Deleting a trace
     """
-    def __init__(self, judgment_api_key: str, organization_id: str, tracer: Optional["Tracer"] = None):
+
+    def __init__(
+        self,
+        judgment_api_key: str,
+        organization_id: str,
+        tracer: Optional["Tracer"] = None,
+    ):
         self.judgment_api_key = judgment_api_key
         self.organization_id = organization_id
         self.tracer = tracer
@@ -139,17 +147,19 @@ class TraceManagerClient:
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {self.judgment_api_key}",
-                "X-Organization-Id": self.organization_id
+                "X-Organization-Id": self.organization_id,
             },
-            verify=True
+            verify=True,
         )

         if response.status_code != HTTPStatus.OK:
             raise ValueError(f"Failed to fetch traces: {response.text}")
-
+
         return response.json()

-    def save_trace(self, trace_data: dict, offline_mode: bool = False, final_save: bool = True):
+    def save_trace(
+        self, trace_data: dict, offline_mode: bool = False, final_save: bool = True
+    ):
         """
         Saves a trace to the Judgment Supabase and optionally to S3 if configured.

@@ -158,12 +168,12 @@ class TraceManagerClient:
             offline_mode: Whether running in offline mode
             final_save: Whether this is the final save (controls S3 saving)
         NOTE we save empty traces in order to properly handle async operations; we need something in the DB to associate the async results with
-
+
         Returns:
             dict: Server response containing UI URL and other metadata
         """
         # Save to Judgment API
-
+
         def fallback_encoder(obj):
             """
             Custom JSON encoder fallback.
@@ -180,7 +190,7 @@ class TraceManagerClient:
             except Exception as e:
                 # If both fail, you might return a placeholder or re-raise
                 return f"<Unserializable object of type {type(obj).__name__}: {e}>"
-
+
         serialized_trace_data = json.dumps(trace_data, default=fallback_encoder)
         response = requests.post(
             JUDGMENT_TRACES_SAVE_API_URL,
@@ -188,71 +198,46 @@ class TraceManagerClient:
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {self.judgment_api_key}",
-                "X-Organization-Id": self.organization_id
+                "X-Organization-Id": self.organization_id,
             },
-            verify=True
+            verify=True,
         )
-
+
         if response.status_code == HTTPStatus.BAD_REQUEST:
-            raise ValueError(f"Failed to save trace data: Check your Trace name for conflicts, set overwrite=True to overwrite existing traces: {response.text}")
+            raise ValueError(
+                f"Failed to save trace data: Check your Trace name for conflicts, set overwrite=True to overwrite existing traces: {response.text}"
+            )
         elif response.status_code != HTTPStatus.OK:
             raise ValueError(f"Failed to save trace data: {response.text}")
-
+
         # Parse server response
         server_response = response.json()
-
+
         # If S3 storage is enabled, save to S3 only on final save
         if self.tracer and self.tracer.use_s3 and final_save:
             try:
                 s3_key = self.tracer.s3_storage.save_trace(
                     trace_data=trace_data,
                     trace_id=trace_data["trace_id"],
-                    project_name=trace_data["project_name"]
+                    project_name=trace_data["project_name"],
                 )
                 print(f"Trace also saved to S3 at key: {s3_key}")
             except Exception as e:
                 warnings.warn(f"Failed to save trace to S3: {str(e)}")
-
+
         if not offline_mode and "ui_results_url" in server_response:
             pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={server_response['ui_results_url']}]View Trace[/link]\n"
             rprint(pretty_str)
-
-        return server_response

-    def check_usage_limits(self, count: int = 1):
-        """
-        Check if the organization can use the requested number of traces without exceeding limits.
-
-        Args:
-            count: Number of traces to check for (default: 1)
-
-        Returns:
-            dict: Server response with rate limit status and usage info
-
-        Raises:
-            ValueError: If rate limits would be exceeded or other errors occur
-        """
-        response = requests.post(
-            JUDGMENT_TRACES_USAGE_CHECK_API_URL,
-            json={"count": count},
-            headers={
-                "Content-Type": "application/json",
-                "Authorization": f"Bearer {self.judgment_api_key}",
-                "X-Organization-Id": self.organization_id
-            },
-            verify=True
-        )
-
-        if response.status_code == HTTPStatus.FORBIDDEN:
-            # Rate limits exceeded
-            error_data = response.json()
-            raise ValueError(f"Rate limit exceeded: {error_data.get('detail', 'Monthly trace limit reached')}")
-        elif response.status_code != HTTPStatus.OK:
-            raise ValueError(f"Failed to check usage limits: {response.text}")
-
-        return response.json()
+        return server_response

-    def upsert_trace(self, trace_data: dict, offline_mode: bool = False, show_link: bool = True, final_save: bool = True):
+    def upsert_trace(
+        self,
+        trace_data: dict,
+        offline_mode: bool = False,
+        show_link: bool = True,
+        final_save: bool = True,
+    ):
         """
         Upserts a trace to the Judgment API (always overwrites if exists).

@@ -261,10 +246,11 @@ class TraceManagerClient:
             offline_mode: Whether running in offline mode
             show_link: Whether to show the UI link (for live tracing)
             final_save: Whether this is the final save (controls S3 saving)
-
+
         Returns:
             dict: Server response containing UI URL and other metadata
         """
+
         def fallback_encoder(obj):
             """
             Custom JSON encoder fallback.
@@ -277,7 +263,7 @@ class TraceManagerClient:
                 return str(obj)
             except Exception as e:
                 return f"<Unserializable object of type {type(obj).__name__}: {e}>"
-
+
         serialized_trace_data = json.dumps(trace_data, default=fallback_encoder)

         response = requests.post(
@@ -286,63 +272,34 @@ class TraceManagerClient:
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {self.judgment_api_key}",
-                "X-Organization-Id": self.organization_id
+                "X-Organization-Id": self.organization_id,
             },
-            verify=True
+            verify=True,
         )
-
+
         if response.status_code != HTTPStatus.OK:
             raise ValueError(f"Failed to upsert trace data: {response.text}")
-
+
         # Parse server response
         server_response = response.json()
-
+
         # If S3 storage is enabled, save to S3 only on final save
         if self.tracer and self.tracer.use_s3 and final_save:
             try:
                 s3_key = self.tracer.s3_storage.save_trace(
                     trace_data=trace_data,
                     trace_id=trace_data["trace_id"],
-                    project_name=trace_data["project_name"]
+                    project_name=trace_data["project_name"],
                 )
                 print(f"Trace also saved to S3 at key: {s3_key}")
             except Exception as e:
                 warnings.warn(f"Failed to save trace to S3: {str(e)}")
-
+
         if not offline_mode and show_link and "ui_results_url" in server_response:
             pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={server_response['ui_results_url']}]View Trace[/link]\n"
             rprint(pretty_str)
-
-        return server_response

-    def update_usage_counters(self, count: int = 1):
-        """
-        Update trace usage counters after successfully saving traces.
-
-        Args:
-            count: Number of traces to count (default: 1)
-
-        Returns:
-            dict: Server response with updated usage information
-
-        Raises:
-            ValueError: If the update fails
-        """
-        response = requests.post(
-            JUDGMENT_TRACES_USAGE_UPDATE_API_URL,
-            json={"count": count},
-            headers={
-                "Content-Type": "application/json",
-                "Authorization": f"Bearer {self.judgment_api_key}",
-                "X-Organization-Id": self.organization_id
-            },
-            verify=True
-        )
-
-        if response.status_code != HTTPStatus.OK:
-            raise ValueError(f"Failed to update usage counters: {response.text}")
-
-        return response.json()
+        return server_response

     ## TODO: Should have a log endpoint, endpoint should also support batched payloads
     def save_annotation(self, annotation: TraceAnnotation):
@@ -351,24 +308,24 @@ class TraceManagerClient:
             "annotation": {
                 "text": annotation.text,
                 "label": annotation.label,
-                "score": annotation.score
-            }
-        }
+                "score": annotation.score,
+            },
+        }

         response = requests.post(
             JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
             json=json_data,
             headers={
-                'Content-Type': 'application/json',
-                'Authorization': f'Bearer {self.judgment_api_key}',
-                'X-Organization-Id': self.organization_id
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {self.judgment_api_key}",
+                "X-Organization-Id": self.organization_id,
             },
-            verify=True
+            verify=True,
         )
-
+
         if response.status_code != HTTPStatus.OK:
             raise ValueError(f"Failed to save annotation: {response.text}")
-
+
         return response.json()

     def delete_trace(self, trace_id: str):
@@ -383,15 +340,15 @@ class TraceManagerClient:
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {self.judgment_api_key}",
-                "X-Organization-Id": self.organization_id
-            }
+                "X-Organization-Id": self.organization_id,
+            },
         )

         if response.status_code != HTTPStatus.OK:
             raise ValueError(f"Failed to delete trace: {response.text}")
-
+
         return response.json()
-
+
     def delete_traces(self, trace_ids: List[str]):
         """
         Delete a batch of traces from the database.
@@ -404,15 +361,15 @@ class TraceManagerClient:
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {self.judgment_api_key}",
-                "X-Organization-Id": self.organization_id
-            }
+                "X-Organization-Id": self.organization_id,
+            },
         )

         if response.status_code != HTTPStatus.OK:
             raise ValueError(f"Failed to delete trace: {response.text}")
-
+
         return response.json()
-
+
     def delete_project(self, project_name: str):
         """
         Deletes a project from the server. Which also deletes all evaluations and traces associated with the project.
@@ -425,31 +382,31 @@ class TraceManagerClient:
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {self.judgment_api_key}",
-                "X-Organization-Id": self.organization_id
-            }
+                "X-Organization-Id": self.organization_id,
+            },
         )

         if response.status_code != HTTPStatus.OK:
             raise ValueError(f"Failed to delete traces: {response.text}")
-
+
         return response.json()


 class TraceClient:
     """Client for managing a single trace context"""
-
+
     def __init__(
         self,
-        tracer: Optional["Tracer"],
+        tracer: "Tracer",
         trace_id: Optional[str] = None,
         name: str = "default",
-        project_name: str = None,
+        project_name: str | None = None,
         overwrite: bool = False,
         rules: Optional[List[Rule]] = None,
         enable_monitoring: bool = True,
         enable_evaluations: bool = True,
         parent_trace_id: Optional[str] = None,
-        parent_name: Optional[str] = None
+        parent_name: Optional[str] = None,
     ):
         self.name = name
         self.trace_id = trace_id or str(uuid.uuid4())
@@ -461,39 +418,48 @@ class TraceClient:
         self.enable_evaluations = enable_evaluations
         self.parent_trace_id = parent_trace_id
         self.parent_name = parent_name
+        self.customer_id: Optional[str] = None  # Added customer_id attribute
+        self.tags: List[Union[str, set, tuple]] = []  # Added tags attribute
+        self.metadata: Dict[str, Any] = {}
+        self.has_notification: Optional[bool] = False  # Initialize has_notification
         self.trace_spans: List[TraceSpan] = []
         self.span_id_to_span: Dict[str, TraceSpan] = {}
         self.evaluation_runs: List[EvaluationRun] = []
         self.annotations: List[TraceAnnotation] = []
-        self.start_time = None # Will be set after first successful save
-        self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id, tracer)
-        self.visited_nodes = []
-        self.executed_tools = []
-        self.executed_node_tools = []
-        self._span_depths: Dict[str, int] = {} # NEW: To track depth of active spans
-
+        self.start_time: Optional[float] = (
+            None  # Will be set after first successful save
+        )
+        self.trace_manager_client = TraceManagerClient(
+            tracer.api_key, tracer.organization_id, tracer
+        )
+        self._span_depths: Dict[str, int] = {}  # NEW: To track depth of active spans
+
         # Get background span service from tracer
-        self.background_span_service = tracer.get_background_span_service() if tracer else None
+        self.background_span_service = (
+            tracer.get_background_span_service() if tracer else None
+        )

     def get_current_span(self):
         """Get the current span from the context var"""
         return self.tracer.get_current_span()
-
+
     def set_current_span(self, span: Any):
         """Set the current span from the context var"""
         return self.tracer.set_current_span(span)
-
+
     def reset_current_span(self, token: Any):
         """Reset the current span from the context var"""
         self.tracer.reset_current_span(token)
-
+
     @contextmanager
     def span(self, name: str, span_type: SpanType = "span"):
         """Context manager for creating a trace span, managing the current span via contextvars"""
         is_first_span = len(self.trace_spans) == 0
         if is_first_span:
             try:
-                trace_id, server_response = self.save_with_rate_limiting(overwrite=self.overwrite, final_save=False)
+                trace_id, server_response = self.save(
+                    overwrite=self.overwrite, final_save=False
+                )
                 # Set start_time after first successful save
                 if self.start_time is None:
                     self.start_time = time.time()
@@ -501,19 +467,23 @@ class TraceClient:
             except Exception as e:
                 warnings.warn(f"Failed to save initial trace for live tracking: {e}")
         start_time = time.time()
-
+
         # Generate a unique ID for *this specific span invocation*
         span_id = str(uuid.uuid4())
-
-        parent_span_id = self.get_current_span() # Get ID of the parent span from context var
-        token = self.set_current_span(span_id) # Set *this* span's ID as the current one
-
+
+        parent_span_id = (
+            self.get_current_span()
+        )  # Get ID of the parent span from context var
+        token = self.set_current_span(
+            span_id
+        )  # Set *this* span's ID as the current one
+
         current_depth = 0
         if parent_span_id and parent_span_id in self._span_depths:
             current_depth = self._span_depths[parent_span_id] + 1
-
-        self._span_depths[span_id] = current_depth # Store depth by span_id
-
+
+        self._span_depths[span_id] = current_depth  # Store depth by span_id
+
         span = TraceSpan(
             span_id=span_id,
             trace_id=self.trace_id,
@@ -525,23 +495,21 @@ class TraceClient:
             function=name,
         )
         self.add_span(span)
-
-
-
+
         # Queue span with initial state (input phase)
         if self.background_span_service:
             self.background_span_service.queue_span(span, span_state="input")
-
+
         try:
             yield self
         finally:
             duration = time.time() - start_time
             span.duration = duration
-
+
             # Queue span with completed state (output phase)
             if self.background_span_service:
                 self.background_span_service.queue_span(span, span_state="completed")
-
+
             # Clean up depth tracking for this span_id
             if span_id in self._span_depths:
                 del self._span_depths[span_id]
@@ -561,12 +529,11 @@ class TraceClient:
         expected_tools: Optional[List[str]] = None,
         additional_metadata: Optional[Dict[str, Any]] = None,
         model: Optional[str] = None,
-        span_id: Optional[str] = None, # <<< ADDED optional span_id parameter
-        log_results: Optional[bool] = True
+        span_id: Optional[str] = None,  # <<< ADDED optional span_id parameter
     ):
         if not self.enable_evaluations:
             return
-
+
         start_time = time.time()  # Record start time

         try:
@@ -574,21 +541,35 @@ class TraceClient:
             if not scorers:
                 warnings.warn("No valid scorers available for evaluation")
                 return
-
+
             # Prevent using JudgevalScorer with rules - only APIJudgmentScorer allowed with rules
-            if self.rules and any(isinstance(scorer, JudgevalScorer) for scorer in scorers):
-                raise ValueError("Cannot use Judgeval scorers, you can only use API scorers when using rules. Please either remove rules or use only APIJudgmentScorer types.")
-
+            if self.rules and any(
+                isinstance(scorer, JudgevalScorer) for scorer in scorers
+            ):
+                raise ValueError(
+                    "Cannot use Judgeval scorers, you can only use API scorers when using rules. Please either remove rules or use only APIJudgmentScorer types."
+                )
+
         except Exception as e:
             warnings.warn(f"Failed to load scorers: {str(e)}")
             return
-
+
         # If example is not provided, create one from the individual parameters
         if example is None:
             # Check if any of the individual parameters are provided
-            if any(param is not None for param in [input, actual_output, expected_output, context,
-                                                   retrieval_context, tools_called, expected_tools,
-                                                   additional_metadata]):
+            if any(
+                param is not None
+                for param in [
+                    input,
+                    actual_output,
+                    expected_output,
+                    context,
+                    retrieval_context,
+                    tools_called,
+                    expected_tools,
+                    additional_metadata,
+                ]
+            ):
                 example = Example(
                     input=input,
                     actual_output=actual_output,
@@ -600,12 +581,14 @@ class TraceClient:
                     additional_metadata=additional_metadata,
                 )
             else:
-                raise ValueError("Either 'example' or at least one of the individual parameters (input, actual_output, etc.) must be provided")
-
+                raise ValueError(
+                    "Either 'example' or at least one of the individual parameters (input, actual_output, etc.) must be provided"
+                )
+
         # Check examples before creating evaluation run
-
+
         # check_examples([example], scorers)
-
+
         # --- Modification: Capture span_id immediately ---
         # span_id_at_eval_call = current_span_var.get()
         # print(f"[TraceClient.async_evaluate] Captured span ID at eval call: {span_id_at_eval_call}")
@@ -617,37 +600,32 @@ class TraceClient:
         # Combine the trace-level rules with any evaluation-specific rules)
         eval_run = EvaluationRun(
             organization_id=self.tracer.organization_id,
-            log_results=log_results,
             project_name=self.project_name,
             eval_name=f"{self.name.capitalize()}-"
-                      f"{span_id_to_use}-" # Keep original eval name format using context var if available
-                      f"[{','.join(scorer.score_type.capitalize() for scorer in scorers)}]",
+            f"{span_id_to_use}-"  # Keep original eval name format using context var if available
+            f"[{','.join(scorer.score_type.capitalize() for scorer in scorers)}]",
             examples=[example],
             scorers=scorers,
             model=model,
-            metadata={},
             judgment_api_key=self.tracer.api_key,
             override=self.overwrite,
-            trace_span_id=span_id_to_use, # Pass the determined ID
-            rules=self.rules # Use the combined rules
+            trace_span_id=span_id_to_use,
         )
-
+
         self.add_eval_run(eval_run, start_time)  # Pass start_time to record_evaluation
-
+
         # Queue evaluation run through background service
         if self.background_span_service and span_id_to_use:
             # Get the current span data to avoid race conditions
             current_span = self.span_id_to_span.get(span_id_to_use)
             if current_span:
                 self.background_span_service.queue_evaluation_run(
-                    eval_run,
-                    span_id=span_id_to_use,
-                    span_data=current_span
+                    eval_run, span_id=span_id_to_use, span_data=current_span
                 )
-
+
     def add_eval_run(self, eval_run: EvaluationRun, start_time: float):
-        # --- Modification: Use span_id from eval_run ---
-        current_span_id = eval_run.trace_span_id # Get ID from the eval_run object
+        # --- Modification: Use span_id from eval_run ---
+        current_span_id = eval_run.trace_span_id  # Get ID from the eval_run object
         # print(f"[TraceClient.add_eval_run] Using span_id from eval_run: {current_span_id}")
         # --- End Modification ---
@@ -658,10 +636,10 @@ class TraceClient:
         self.evaluation_runs.append(eval_run)

     def add_annotation(self, annotation: TraceAnnotation):
-        """Add an annotation to this trace context"""
-        self.annotations.append(annotation)
-        return self
-
+        """Add an annotation to this trace context"""
+        self.annotations.append(annotation)
+        return self
+
     def record_input(self, inputs: dict):
         current_span_id = self.get_current_span()
         if current_span_id:
@@ -670,17 +648,20 @@ class TraceClient:
             if "self" in inputs:
                 del inputs["self"]
             span.inputs = inputs
-
+
             # Queue span with input data
-            if self.background_span_service:
-                self.background_span_service.queue_span(span, span_state="input")
-
+            try:
+                if self.background_span_service:
+                    self.background_span_service.queue_span(span, span_state="input")
+            except Exception as e:
+                warnings.warn(f"Failed to queue span with input data: {e}")
+
     def record_agent_name(self, agent_name: str):
         current_span_id = self.get_current_span()
         if current_span_id:
             span = self.span_id_to_span[current_span_id]
             span.agent_name = agent_name
-
+
             # Queue span with agent_name data
             if self.background_span_service:
                 self.background_span_service.queue_span(span, span_state="agent_name")
@@ -695,11 +676,11 @@ class TraceClient:
         if current_span_id:
             span = self.span_id_to_span[current_span_id]
             span.state_before = state
-
+
             # Queue span with state_before data
             if self.background_span_service:
                 self.background_span_service.queue_span(span, span_state="state_before")
-
+
     def record_state_after(self, state: dict):
         """Records the agent's state after a tool execution on the current span.

@@ -710,7 +691,7 @@ class TraceClient:
         if current_span_id:
             span = self.span_id_to_span[current_span_id]
             span.state_after = state
-
+
             # Queue span with state_after data
             if self.background_span_service:
                 self.background_span_service.queue_span(span, span_state="state_after")
@@ -720,19 +701,19 @@ class TraceClient:
         try:
             result = await coroutine
             setattr(span, field, result)
-
+
             # Queue span with output data now that coroutine is complete
             if self.background_span_service and field == "output":
                 self.background_span_service.queue_span(span, span_state="output")
-
+
             return result
         except Exception as e:
             setattr(span, field, f"Error: {str(e)}")
-
+
             # Queue span even if there was an error
             if self.background_span_service and field == "output":
                 self.background_span_service.queue_span(span, span_state="output")
-
+
             raise

     def record_output(self, output: Any):
@@ -740,56 +721,56 @@ class TraceClient:
         if current_span_id:
             span = self.span_id_to_span[current_span_id]
             span.output = "<pending>" if inspect.iscoroutine(output) else output
-
+
             if inspect.iscoroutine(output):
                 asyncio.create_task(self._update_coroutine(span, output, "output"))
-
+
             # # Queue span with output data (unless it's pending)
             if self.background_span_service and not inspect.iscoroutine(output):
                 self.background_span_service.queue_span(span, span_state="output")

-            return span # Return the created entry
+            return span  # Return the created entry
         # Removed else block - original didn't have one
-        return None # Return None if no span_id found
-
+        return None  # Return None if no span_id found
+
     def record_usage(self, usage: TraceUsage):
         current_span_id = self.get_current_span()
         if current_span_id:
             span = self.span_id_to_span[current_span_id]
             span.usage = usage
-
+
             # Queue span with usage data
             if self.background_span_service:
                 self.background_span_service.queue_span(span, span_state="usage")
-
-            return span # Return the created entry
+
+            return span  # Return the created entry
         # Removed else block - original didn't have one
-        return None # Return None if no span_id found
-
+        return None  # Return None if no span_id found
+
     def record_error(self, error: Dict[str, Any]):
         current_span_id = self.get_current_span()
         if current_span_id:
             span = self.span_id_to_span[current_span_id]
             span.error = error
-
+
             # Queue span with error data
             if self.background_span_service:
                 self.background_span_service.queue_span(span, span_state="error")
-
+
             return span
         return None
-
+
     def add_span(self, span: TraceSpan):
         """Add a trace span to this trace context"""
         self.trace_spans.append(span)
         self.span_id_to_span[span.span_id] = span
         return self
-
+
     def print(self):
         """Print the complete trace with proper visual structure"""
         for span in self.trace_spans:
             span.print_span()
-
+
     def get_duration(self) -> float:
         """
         Get the total duration of this trace
@@ -798,57 +779,23 @@ class TraceClient:
             return 0.0  # No duration if trace hasn't been saved yet
         return time.time() - self.start_time

-    def save(self, overwrite: bool = False) -> Tuple[str, dict]:
-        """
-        Save the current trace to the database.
-        Returns a tuple of (trace_id, server_response) where server_response contains the UI URL and other metadata.
-        """
-        # Set start_time if this is the first save
-        if self.start_time is None:
-            self.start_time = time.time()
-
-        # Calculate total elapsed time
-        total_duration = self.get_duration()
-        # Create trace document - Always use standard keys for top-level counts
-        trace_data = {
-            "trace_id": self.trace_id,
-            "name": self.name,
-            "project_name": self.project_name,
-            "created_at": datetime.fromtimestamp(self.start_time, timezone.utc).isoformat(),
-            "duration": total_duration,
-            "trace_spans": [span.model_dump() for span in self.trace_spans],
-            "evaluation_runs": [run.model_dump() for run in self.evaluation_runs],
-            "overwrite": overwrite,
-            "offline_mode": self.tracer.offline_mode,
-            "parent_trace_id": self.parent_trace_id,
-            "parent_name": self.parent_name
-        }
-        # --- Log trace data before saving ---
-        server_response = self.trace_manager_client.save_trace(trace_data, offline_mode=self.tracer.offline_mode, final_save=True)
-
-        # upload annotations
-        # TODO: batch to the log endpoint
-        for annotation in self.annotations:
-            self.trace_manager_client.save_annotation(annotation)
-
-        return self.trace_id, server_response
-
-    def save_with_rate_limiting(self, overwrite: bool = False, final_save: bool = False) -> Tuple[str, dict]:
+    def save(
+        self, overwrite: bool = False, final_save: bool = False
+    ) -> Tuple[str, dict]:
         """
         Save the current trace to the database with rate limiting checks.
         First checks usage limits, then upserts the trace if allowed.
-
+
         Args:
             overwrite: Whether to overwrite existing traces
             final_save: Whether this is the final save (updates usage counters)
-
+
         Returns a tuple of (trace_id, server_response) where server_response contains the UI URL and other metadata.
         """

-
         # Calculate total elapsed time
         total_duration = self.get_duration()
-
+
         # Create trace document
         trace_data = {
             "trace_id": self.trace_id,
@@ -861,35 +808,20 @@ class TraceClient:
             "overwrite": overwrite,
             "offline_mode": self.tracer.offline_mode,
             "parent_trace_id": self.parent_trace_id,
-            "parent_name": self.parent_name
+            "parent_name": self.parent_name,
+            "customer_id": self.customer_id,
+            "tags": self.tags,
+            "metadata": self.metadata,
         }
-
-        # Check usage limits first
-        try:
-            usage_check_result = self.trace_manager_client.check_usage_limits(count=1)
-            # Usage check passed silently - no need to show detailed info
-        except ValueError as e:
-            # Rate limit exceeded
-            warnings.warn(f"Rate limit check failed for live tracing: {e}")
-            raise e
-
+
         # If usage check passes, upsert the trace
         server_response = self.trace_manager_client.upsert_trace(
-            trace_data,
+            trace_data,
             offline_mode=self.tracer.offline_mode,
             show_link=not final_save,  # Show link only on initial save, not final save
-            final_save=final_save # Pass final_save to control S3 saving
+            final_save=final_save,  # Pass final_save to control S3 saving
        )

-        # Update usage counters only on final save
-        if final_save:
-            try:
-                usage_update_result = self.trace_manager_client.update_usage_counters(count=1)
-                # Usage updated silently - no need to show detailed usage info
-            except ValueError as e:
-                # Log warning but don't fail the trace save since the trace was already saved
-                warnings.warn(f"Usage counter update failed (trace was still saved): {e}")
-
         # Upload annotations
         # TODO: batch to the log endpoint
         for annotation in self.annotations:
@@ -900,8 +832,77 @@ class TraceClient:
     def delete(self):
         return self.trace_manager_client.delete_trace(self.trace_id)
-
-def _capture_exception_for_trace(current_trace: Optional['TraceClient'], exc_info: ExcInfo):
+
+    def update_metadata(self, metadata: dict):
+        """
+        Set metadata for this trace.
+
+        Args:
+            metadata: Metadata as a dictionary
+
+        Supported keys:
+        - customer_id: ID of the customer using this trace
+        - tags: List of tags for this trace
+        - has_notification: Whether this trace has a notification
+        - overwrite: Whether to overwrite existing traces
+        - name: Name of the trace
+        """
+        for k, v in metadata.items():
+            if k == "customer_id":
+                if v is not None:
+                    self.customer_id = str(v)
+                else:
+                    self.customer_id = None
+            elif k == "tags":
+                if isinstance(v, list):
+                    # Validate that all items in the list are of the expected types
+                    for item in v:
+                        if not isinstance(item, (str, set, tuple)):
+                            raise ValueError(
+                                f"Tags must be a list of strings, sets, or tuples, got item of type {type(item)}"
+                            )
+                    self.tags = v
+                else:
+                    raise ValueError(
+                        f"Tags must be a list of strings, sets, or tuples, got {type(v)}"
+                    )
+            elif k == "has_notification":
+                if not isinstance(v, bool):
+                    raise ValueError(
+                        f"has_notification must be a boolean, got {type(v)}"
+                    )
+                self.has_notification = v
+            elif k == "overwrite":
+                if not isinstance(v, bool):
+                    raise ValueError(f"overwrite must be a boolean, got {type(v)}")
+                self.overwrite = v
+            elif k == "name":
+                self.name = v
+            else:
+                self.metadata[k] = v
+
+    def set_customer_id(self, customer_id: str):
+        """
+        Set the customer ID for this trace.
+
+        Args:
+            customer_id: The customer ID to set
+        """
+        self.update_metadata({"customer_id": customer_id})
+
+    def set_tags(self, tags: List[Union[str, set, tuple]]):
+        """
+        Set the tags for this trace.
+
+        Args:
+            tags: List of tags to set
+        """
+        self.update_metadata({"tags": tags})
+
+
+def _capture_exception_for_trace(
+    current_trace: Optional["TraceClient"], exc_info: ExcInfo
+):
     if not current_trace:
         return

@@ -909,18 +910,20 @@ def _capture_exception_for_trace(current_trace: Optional['TraceClient'], exc_inf
     formatted_exception = {
         "type": exc_type.__name__ if exc_type else "UnknownExceptionType",
         "message": str(exc_value) if exc_value else "No exception message",
-        "traceback": traceback.format_tb(exc_traceback_obj) if exc_traceback_obj else []
+        "traceback": traceback.format_tb(exc_traceback_obj)
+        if exc_traceback_obj
+        else [],
     }
-
+
     # This is where we specially handle exceptions that we might want to collect additional data for.
     # When we do this, always try checking the module from sys.modules instead of importing. This will
     # Let us support a wider range of exceptions without needing to import them for all clients.
-
-    # Most clients (requests, httpx, urllib) support the standard format of exposing error.request.url and error.response.status_code
+
+    # Most clients (requests, httpx, urllib) support the standard format of exposing error.request.url and error.response.status_code
     # The alternative is to hand select libraries we want from sys.modules and check for them:
     # As an example: requests_module = sys.modules.get("requests", None) // then do things with requests_module;

-    # General HTTP Like errors
+    # General HTTP Like errors
     try:
         url = getattr(getattr(exc_value, "request", None), "url", None)
         status_code = getattr(getattr(exc_value, "response", None), "status_code", None)
@@ -929,34 +932,43 @@ def _capture_exception_for_trace(current_trace: Optional['TraceClient'], exc_inf
            "url": url if url else "Unknown URL",
            "status_code": status_code if status_code else None,
        }
-    except Exception as e:
+    except Exception:
        pass

    current_trace.record_error(formatted_exception)
-
+
    # Queue the span with error state through background service
    if current_trace.background_span_service:
        current_span_id = current_trace.get_current_span()
        if current_span_id and current_span_id in current_trace.span_id_to_span:
            error_span = current_trace.span_id_to_span[current_span_id]
-            current_trace.background_span_service.queue_span(error_span, span_state="error")
+            current_trace.background_span_service.queue_span(
+                error_span, span_state="error"
+            )
+

 class BackgroundSpanService:
     """
     Background service for queueing and batching trace spans for efficient saving.
-
+
     This service:
     - Queues spans as they complete
     - Batches them for efficient network usage
     - Sends spans periodically or when batches reach a certain size
     - Handles automatic flushing when the main event terminates
     """
-
-    def __init__(self, judgment_api_key: str, organization_id: str,
-                 batch_size: int = 10, flush_interval: float = 5.0, num_workers: int = 1):
+
+    def __init__(
+        self,
+        judgment_api_key: str,
+        organization_id: str,
+        batch_size: int = 10,
+        flush_interval: float = 5.0,
+        num_workers: int = 1,
+    ):
         """
         Initialize the background span service.
-
+
         Args:
             judgment_api_key: API key for Judgment service
             organization_id: Organization ID
@@ -969,41 +981,41 @@ class BackgroundSpanService:
         self.batch_size = batch_size
         self.flush_interval = flush_interval
         self.num_workers = max(1, num_workers)  # Ensure at least 1 worker
-
+
         # Queue for pending spans
-        self._span_queue = queue.Queue()
-
+        self._span_queue: queue.Queue[Dict[str, Any]] = queue.Queue()
+
         # Background threads for processing spans
-        self._worker_threads = []
+        self._worker_threads: List[threading.Thread] = []
         self._shutdown_event = threading.Event()
-
+
         # Track spans that have been sent
-        self._sent_spans = set()
-
+        # self._sent_spans = set()
+
         # Register cleanup on exit
         atexit.register(self.shutdown)
-
+
         # Start the background workers
         self._start_workers()
-
+
     def _start_workers(self):
         """Start the background worker threads."""
         for i in range(self.num_workers):
             if len(self._worker_threads) < self.num_workers:
                 worker_thread = threading.Thread(
-                    target=self._worker_loop,
-                    daemon=True,
-                    name=f"SpanWorker-{i+1}"
+                    target=self._worker_loop, daemon=True, name=f"SpanWorker-{i + 1}"
                 )
                 worker_thread.start()
                 self._worker_threads.append(worker_thread)
-
+
     def _worker_loop(self):
         """Main worker loop that processes spans in batches."""
         batch = []
         last_flush_time = time.time()
-        pending_task_count = 0 # Track how many tasks we've taken from queue but not marked done
-
+        pending_task_count = (
+            0  # Track how many tasks we've taken from queue but not marked done
+        )
+
         while not self._shutdown_event.is_set() or self._span_queue.qsize() > 0:
             try:
                 # First, do a blocking get to wait for at least one item
@@ -1015,7 +1027,7 @@ class BackgroundSpanService:
                 except queue.Empty:
                     # No new spans, continue to check for flush conditions
                     pass
-
+
                 # Then, do non-blocking gets to drain any additional available items
                 # up to our batch size limit
                 while len(batch) < self.batch_size:
@@ -1026,24 +1038,23 @@ class BackgroundSpanService:
                     except queue.Empty:
                         # No more items immediately available
                         break
-
+
                 current_time = time.time()
-                should_flush = (
-                    len(batch) >= self.batch_size or
-                    (batch and (current_time - last_flush_time) >= self.flush_interval)
+                should_flush = len(batch) >= self.batch_size or (
+                    batch and (current_time - last_flush_time) >= self.flush_interval
                 )
-
+
                 if should_flush and batch:
                     self._send_batch(batch)
-
+
                     # Only mark tasks as done after successful sending
                     for _ in range(pending_task_count):
                         self._span_queue.task_done()
                     pending_task_count = 0  # Reset counter
-
+
                     batch.clear()
                     last_flush_time = current_time
-
+
             except Exception as e:
                 warnings.warn(f"Error in span service worker loop: {e}")
                 # On error, still need to mark tasks as done to prevent hanging
@@ -1051,53 +1062,50 @@ class BackgroundSpanService:
                     self._span_queue.task_done()
                 pending_task_count = 0
                 batch.clear()
-
+
         # Final flush on shutdown
         if batch:
             self._send_batch(batch)
             # Mark remaining tasks as done
             for _ in range(pending_task_count):
                 self._span_queue.task_done()
-
+
     def _send_batch(self, batch: List[Dict[str, Any]]):
         """
         Send a batch of spans to the server.
-
+
         Args:
             batch: List of span dictionaries to send
         """
         if not batch:
             return
-
+
         try:
-            # Group spans by type for different endpoints
+            # Group items by type for different endpoints
             spans_to_send = []
             evaluation_runs_to_send = []
-
+
             for item in batch:
-                if item['type'] == 'span':
-                    spans_to_send.append(item['data'])
-                elif item['type'] == 'evaluation_run':
-                    evaluation_runs_to_send.append(item['data'])
-
+                if item["type"] == "span":
+                    spans_to_send.append(item["data"])
+                elif item["type"] == "evaluation_run":
+                    evaluation_runs_to_send.append(item["data"])
+
             # Send spans if any
             if spans_to_send:
                 self._send_spans_batch(spans_to_send)
-
+
             # Send evaluation runs if any
             if evaluation_runs_to_send:
                 self._send_evaluation_runs_batch(evaluation_runs_to_send)
-
+
         except Exception as e:
-            warnings.warn(f"Failed to send span batch: {e}")
-
+            warnings.warn(f"Failed to send batch: {e}")
+
     def _send_spans_batch(self, spans: List[Dict[str, Any]]):
         """Send a batch of spans to the spans endpoint."""
-        payload = {
-            "spans": spans,
-            "organization_id": self.organization_id
-        }
-
+        payload = {"spans": spans, "organization_id": self.organization_id}
+
         # Serialize with fallback encoder
         def fallback_encoder(obj):
             try:
@@ -1107,10 +1115,10 @@ class BackgroundSpanService:
                 return str(obj)
             except Exception as e:
                 return f"<Unserializable object of type {type(obj).__name__}: {e}>"
-
+
         try:
             serialized_data = json.dumps(payload, default=fallback_encoder)
-
+
             # Send the actual HTTP request to the batch endpoint
             response = requests.post(
                 JUDGMENT_TRACES_SPANS_BATCH_API_URL,
@@ -1118,21 +1126,22 @@ class BackgroundSpanService:
                 headers={
                     "Content-Type": "application/json",
                     "Authorization": f"Bearer {self.judgment_api_key}",
-                    "X-Organization-Id": self.organization_id
+                    "X-Organization-Id": self.organization_id,
                 },
                 verify=True,
-                timeout=30 # Add timeout to prevent hanging
+                timeout=30,  # Add timeout to prevent hanging
             )
-
+
             if response.status_code != HTTPStatus.OK:
-                warnings.warn(f"Failed to send spans batch: HTTP {response.status_code} - {response.text}")
-
-
-        except requests.RequestException as e:
+                warnings.warn(
+                    f"Failed to send spans batch: HTTP {response.status_code} - {response.text}"
+                )
+
+        except RequestException as e:
             warnings.warn(f"Network error sending spans batch: {e}")
         except Exception as e:
             warnings.warn(f"Failed to serialize or send spans batch: {e}")
-
+
     def _send_evaluation_runs_batch(self, evaluation_runs: List[Dict[str, Any]]):
         """Send a batch of evaluation runs with their associated span data to the endpoint."""
         # Structure payload to include both evaluation run data and span data
@@ -1142,22 +1151,23 @@ class BackgroundSpanService:
             entry = {
                 "evaluation_run": {
                     # Extract evaluation run fields (excluding span-specific fields)
-                    key: value for key, value in eval_data.items()
-                    if key not in ['associated_span_id', 'span_data', 'queued_at']
+                    key: value
+                    for key, value in eval_data.items()
+                    if key not in ["associated_span_id", "span_data", "queued_at"]
                 },
                 "associated_span": {
-                    "span_id": eval_data.get('associated_span_id'),
-                    "span_data": eval_data.get('span_data')
+                    "span_id": eval_data.get("associated_span_id"),
+                    "span_data": eval_data.get("span_data"),
                 },
-                "queued_at": eval_data.get('queued_at')
+                "queued_at": eval_data.get("queued_at"),
             }
             evaluation_entries.append(entry)
-
+
         payload = {
             "organization_id": self.organization_id,
-            "evaluation_entries": evaluation_entries # Each entry contains both eval run + span data
+            "evaluation_entries": evaluation_entries,  # Each entry contains both eval run + span data
         }
-
+
         # Serialize with fallback encoder
         def fallback_encoder(obj):
             try:
@@ -1167,10 +1177,10 @@ class BackgroundSpanService:
                 return str(obj)
             except Exception as e:
                 return f"<Unserializable object of type {type(obj).__name__}: {e}>"
-
+
         try:
             serialized_data = json.dumps(payload, default=fallback_encoder)
-
+
             # Send the actual HTTP request to the batch endpoint
             response = requests.post(
                 JUDGMENT_TRACES_EVALUATION_RUNS_BATCH_API_URL,
@@ -1178,25 +1188,26 @@ class BackgroundSpanService:
                 headers={
                     "Content-Type": "application/json",
                     "Authorization": f"Bearer {self.judgment_api_key}",
-                    "X-Organization-Id": self.organization_id
+                    "X-Organization-Id": self.organization_id,
                 },
                 verify=True,
-                timeout=30 # Add timeout to prevent hanging
+                timeout=30,  # Add timeout to prevent hanging
             )
-
+
             if response.status_code != HTTPStatus.OK:
-                warnings.warn(f"Failed to send evaluation runs batch: HTTP {response.status_code} - {response.text}")
-
-
-        except requests.RequestException as e:
+                warnings.warn(
+                    f"Failed to send evaluation runs batch: HTTP {response.status_code} - {response.text}"
+                )
+
+        except RequestException as e:
             warnings.warn(f"Network error sending evaluation runs batch: {e}")
         except Exception as e:
             warnings.warn(f"Failed to send evaluation runs batch: {e}")
-
+
     def queue_span(self, span: TraceSpan, span_state: str = "input"):
         """
         Queue a span for background sending.
-
+
         Args:
             span: The TraceSpan object to queue
             span_state: State of the span ("input", "output", "completed")
@@ -1207,15 +1218,17 @@ class BackgroundSpanService:
             "data": {
                 **span.model_dump(),
                 "span_state": span_state,
-                "queued_at": time.time()
-            }
+                "queued_at": time.time(),
+            },
         }
         self._span_queue.put(span_data)
-
-    def queue_evaluation_run(self, evaluation_run: EvaluationRun, span_id: str, span_data: TraceSpan):
+
+    def queue_evaluation_run(
+        self, evaluation_run: EvaluationRun, span_id: str, span_data: TraceSpan
+    ):
         """
         Queue an evaluation run for background sending.
-
+
         Args:
             evaluation_run: The EvaluationRun object to queue
             span_id: The span ID associated with this evaluation run
@@ -1228,11 +1241,11 @@ class BackgroundSpanService:
1228
1241
  **evaluation_run.model_dump(),
1229
1242
  "associated_span_id": span_id,
1230
1243
  "span_data": span_data.model_dump(), # Include span data to avoid race conditions
1231
- "queued_at": time.time()
1232
- }
1244
+ "queued_at": time.time(),
1245
+ },
1233
1246
  }
1234
1247
  self._span_queue.put(eval_data)
1235
-
1248
+
1236
1249
  def flush(self):
1237
1250
  """Force immediate sending of all queued spans."""
1238
1251
  try:
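
queue_span() and queue_evaluation_run() above are the producer side of BackgroundSpanService: both push dict records onto one shared queue.Queue, which flush() then drains via join(). A rough sketch of the record shapes, assuming a "type" tag distinguishes the two kinds (the tag name is an assumption; only the "data" payloads appear in these hunks):

    import queue
    import time

    span_queue: queue.Queue = queue.Queue()

    # One shared queue carries both record kinds; a tag field lets the
    # worker threads route each record into the right batch.
    span_queue.put({"type": "span",
                    "data": {"span_id": "s1", "span_state": "input",
                             "queued_at": time.time()}})
    span_queue.put({"type": "evaluation_run",
                    "data": {"associated_span_id": "s1",
                             "span_data": {"span_id": "s1"},
                             "queued_at": time.time()}})
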
@@ -1240,89 +1253,101 @@ class BackgroundSpanService:
1240
1253
  self._span_queue.join()
1241
1254
  except Exception as e:
1242
1255
  warnings.warn(f"Error during flush: {e}")
1243
-
1256
+
1244
1257
  def shutdown(self):
1245
1258
  """Shutdown the background service and flush remaining spans."""
1246
1259
  if self._shutdown_event.is_set():
1247
1260
  return
1248
-
1261
+
1249
1262
  try:
1250
1263
  # Signal shutdown to stop new items from being queued
1251
1264
  self._shutdown_event.set()
1252
-
1265
+
1253
1266
  # Try to flush any remaining spans
1254
1267
  try:
1255
1268
  self.flush()
1256
1269
  except Exception as e:
1257
- warnings.warn(f"Error during final flush: {e}")
1270
+ warnings.warn(f"Error during final flush: {e}")
1258
1271
  except Exception as e:
1259
1272
  warnings.warn(f"Error during BackgroundSpanService shutdown: {e}")
1260
1273
  finally:
1261
1274
  # Clear the worker threads list (daemon threads will be killed automatically)
1262
1275
  self._worker_threads.clear()
1263
-
1276
+
1264
1277
  def get_queue_size(self) -> int:
1265
1278
  """Get the current size of the span queue."""
1266
1279
  return self._span_queue.qsize()
1267
1280
 
1281
+
1268
1282
  class _DeepTracer:
1269
1283
  _instance: Optional["_DeepTracer"] = None
1270
1284
  _lock: threading.Lock = threading.Lock()
1271
1285
  _refcount: int = 0
1272
- _span_stack: contextvars.ContextVar[List[Dict[str, Any]]] = contextvars.ContextVar("_deep_profiler_span_stack", default=[])
1273
- _skip_stack: contextvars.ContextVar[List[str]] = contextvars.ContextVar("_deep_profiler_skip_stack", default=[])
1286
+ _span_stack: contextvars.ContextVar[List[Dict[str, Any]]] = contextvars.ContextVar(
1287
+ "_deep_profiler_span_stack", default=[]
1288
+ )
1289
+ _skip_stack: contextvars.ContextVar[List[str]] = contextvars.ContextVar(
1290
+ "_deep_profiler_skip_stack", default=[]
1291
+ )
1274
1292
  _original_sys_trace: Optional[Callable] = None
1275
1293
  _original_threading_trace: Optional[Callable] = None
1276
1294
 
1277
- def __init__(self, tracer: 'Tracer'):
1295
+ def __init__(self, tracer: "Tracer"):
1278
1296
  self._tracer = tracer
1279
1297
 
1280
1298
  def _get_qual_name(self, frame) -> str:
1281
1299
  func_name = frame.f_code.co_name
1282
1300
  module_name = frame.f_globals.get("__name__", "unknown_module")
1283
-
1301
+
1284
1302
  try:
1285
1303
  func = frame.f_globals.get(func_name)
1286
1304
  if func is None:
1287
1305
  return f"{module_name}.{func_name}"
1288
1306
  if hasattr(func, "__qualname__"):
1289
- return f"{module_name}.{func.__qualname__}"
1307
+ return f"{module_name}.{func.__qualname__}"
1308
+ return f"{module_name}.{func_name}"
1290
1309
  except Exception:
1291
1310
  return f"{module_name}.{func_name}"
1292
-
1293
- def __new__(cls, tracer: 'Tracer' = None):
1311
+
1312
+ def __new__(cls, tracer: "Tracer"):
1294
1313
  with cls._lock:
1295
1314
  if cls._instance is None:
1296
1315
  cls._instance = super().__new__(cls)
1297
1316
  return cls._instance
1298
-
1317
+
1299
1318
  def _should_trace(self, frame):
1300
1319
  # Skip stack is maintained by the tracer as an optimization to skip earlier
1301
1320
  # frames in the call stack that we've already determined should be skipped
1302
1321
  skip_stack = self._skip_stack.get()
1303
1322
  if len(skip_stack) > 0:
1304
1323
  return False
1305
-
1324
+
1306
1325
  func_name = frame.f_code.co_name
1307
1326
  module_name = frame.f_globals.get("__name__", None)
1308
-
1309
1327
  func = frame.f_globals.get(func_name)
1310
- if func and (hasattr(func, '_judgment_span_name') or hasattr(func, '_judgment_span_type')):
1328
+ if func and (
1329
+ hasattr(func, "_judgment_span_name") or hasattr(func, "_judgment_span_type")
1330
+ ):
1311
1331
  return False
1312
1332
 
1313
1333
  if (
1314
1334
  not module_name
1315
- or func_name.startswith("<") # ex: <listcomp>
1316
- or func_name.startswith("__") and func_name != "__call__" # dunders
1335
+ or func_name.startswith("<") # ex: <listcomp>
1336
+ or func_name.startswith("__")
1337
+ and func_name != "__call__" # dunders
1317
1338
  or not self._is_user_code(frame.f_code.co_filename)
1318
1339
  ):
1319
1340
  return False
1320
-
1341
+
1321
1342
  return True
1322
-
1343
+
1323
1344
  @functools.cache
1324
1345
  def _is_user_code(self, filename: str):
1325
- return bool(filename) and not filename.startswith("<") and not os.path.realpath(filename).startswith(_TRACE_FILEPATH_BLOCKLIST)
1346
+ return (
1347
+ bool(filename)
1348
+ and not filename.startswith("<")
1349
+ and not os.path.realpath(filename).startswith(_TRACE_FILEPATH_BLOCKLIST)
1350
+ )
1326
1351
 
1327
1352
  def _cooperative_sys_trace(self, frame: types.FrameType, event: str, arg: Any):
1328
1353
  """Cooperative trace function for sys.settrace that chains with existing tracers."""
@@ -1334,18 +1359,20 @@ class _DeepTracer:
1334
1359
  except Exception:
1335
1360
  # If the original tracer fails, continue with our tracing
1336
1361
  pass
1337
-
1362
+
1338
1363
  # Then do our own tracing
1339
1364
  our_result = self._trace(frame, event, arg, self._cooperative_sys_trace)
1340
-
1365
+
1341
1366
  # Return our tracer to continue tracing, but respect the original's decision
1342
1367
  # If the original tracer returned None (stop tracing), we should respect that
1343
1368
  if original_result is None and self._original_sys_trace:
1344
1369
  return None
1345
-
1370
+
1346
1371
  return our_result or original_result
1347
-
1348
- def _cooperative_threading_trace(self, frame: types.FrameType, event: str, arg: Any):
1372
+
1373
+ def _cooperative_threading_trace(
1374
+ self, frame: types.FrameType, event: str, arg: Any
1375
+ ):
1349
1376
  """Cooperative trace function for threading.settrace that chains with existing tracers."""
1350
1377
  # First, call the original threading trace function if it exists
1351
1378
  original_result = None
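
Both _cooperative_* functions implement the same chaining idea: call whatever tracer was installed before us, swallow its errors, and stop if it asked to stop. A self-contained sketch of that pattern with sys.settrace (illustrative names, not the library's internals):

    import sys

    _original = sys.gettrace()  # whatever tracer (if any) was installed first

    def cooperative_trace(frame, event, arg):
        # Run the pre-existing tracer first; never let its failure stop ours.
        original_result = None
        if _original is not None:
            try:
                original_result = _original(frame, event, arg)
            except Exception:
                pass
        our_result = cooperative_trace  # keep tracing this frame ourselves
        # If the original tracer asked to stop, respect that decision.
        if original_result is None and _original is not None:
            return None
        return our_result or original_result

    def demo():
        return 1 + 1

    sys.settrace(cooperative_trace)
    demo()
    sys.settrace(None)  # uninstall when done
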
@@ -1355,44 +1382,48 @@ class _DeepTracer:
1355
1382
  except Exception:
1356
1383
  # If the original tracer fails, continue with our tracing
1357
1384
  pass
1358
-
1385
+
1359
1386
  # Then do our own tracing
1360
1387
  our_result = self._trace(frame, event, arg, self._cooperative_threading_trace)
1361
-
1388
+
1362
1389
  # Return our tracer to continue tracing, but respect the original's decision
1363
1390
  # If the original tracer returned None (stop tracing), we should respect that
1364
1391
  if original_result is None and self._original_threading_trace:
1365
1392
  return None
1366
-
1393
+
1367
1394
  return our_result or original_result
1368
-
1369
- def _trace(self, frame: types.FrameType, event: str, arg: Any, continuation_func: Callable):
1395
+
1396
+ def _trace(
1397
+ self, frame: types.FrameType, event: str, arg: Any, continuation_func: Callable
1398
+ ):
1370
1399
  frame.f_trace_lines = False
1371
1400
  frame.f_trace_opcodes = False
1372
1401
 
1373
1402
  if not self._should_trace(frame):
1374
1403
  return
1375
-
1404
+
1376
1405
  if event not in ("call", "return", "exception"):
1377
1406
  return
1378
-
1407
+
1379
1408
  current_trace = self._tracer.get_current_trace()
1380
1409
  if not current_trace:
1381
1410
  return
1382
-
1411
+
1383
1412
  parent_span_id = self._tracer.get_current_span()
1384
1413
  if not parent_span_id:
1385
1414
  return
1386
1415
 
1387
1416
  qual_name = self._get_qual_name(frame)
1388
1417
  instance_name = None
1389
- if 'self' in frame.f_locals:
1390
- instance = frame.f_locals['self']
1418
+ if "self" in frame.f_locals:
1419
+ instance = frame.f_locals["self"]
1391
1420
  class_name = instance.__class__.__name__
1392
- class_identifiers = getattr(Tracer._instance, 'class_identifiers', {})
1393
- instance_name = get_instance_prefixed_name(instance, class_name, class_identifiers)
1421
+ class_identifiers = getattr(self._tracer, "class_identifiers", {})
1422
+ instance_name = get_instance_prefixed_name(
1423
+ instance, class_name, class_identifiers
1424
+ )
1394
1425
  skip_stack = self._skip_stack.get()
1395
-
1426
+
1396
1427
  if event == "call":
1397
1428
  # If we have entries in the skip stack and the current qual_name matches the top entry,
1398
1429
  # push it again to track nesting depth and skip
@@ -1402,9 +1433,9 @@ class _DeepTracer:
1402
1433
  skip_stack.append(qual_name)
1403
1434
  self._skip_stack.set(skip_stack)
1404
1435
  return
1405
-
1436
+
1406
1437
  should_trace = self._should_trace(frame)
1407
-
1438
+
1408
1439
  if not should_trace:
1409
1440
  if not skip_stack:
1410
1441
  self._skip_stack.set([qual_name])
@@ -1416,35 +1447,37 @@ class _DeepTracer:
1416
1447
  skip_stack.pop()
1417
1448
  self._skip_stack.set(skip_stack)
1418
1449
  return
1419
-
1450
+
1420
1451
  if skip_stack:
1421
1452
  return
1422
-
1453
+
1423
1454
  span_stack = self._span_stack.get()
1424
1455
  if event == "call":
1425
1456
  if not self._should_trace(frame):
1426
1457
  return
1427
-
1458
+
1428
1459
  span_id = str(uuid.uuid4())
1429
-
1460
+
1430
1461
  parent_depth = current_trace._span_depths.get(parent_span_id, 0)
1431
1462
  depth = parent_depth + 1
1432
-
1463
+
1433
1464
  current_trace._span_depths[span_id] = depth
1434
-
1465
+
1435
1466
  start_time = time.time()
1436
-
1437
- span_stack.append({
1438
- "span_id": span_id,
1439
- "parent_span_id": parent_span_id,
1440
- "function": qual_name,
1441
- "start_time": start_time
1442
- })
1467
+
1468
+ span_stack.append(
1469
+ {
1470
+ "span_id": span_id,
1471
+ "parent_span_id": parent_span_id,
1472
+ "function": qual_name,
1473
+ "start_time": start_time,
1474
+ }
1475
+ )
1443
1476
  self._span_stack.set(span_stack)
1444
-
1477
+
1445
1478
  token = self._tracer.set_current_span(span_id)
1446
1479
  frame.f_locals["_judgment_span_token"] = token
1447
-
1480
+
1448
1481
  span = TraceSpan(
1449
1482
  span_id=span_id,
1450
1483
  trace_id=current_trace.trace_id,
@@ -1454,57 +1487,55 @@ class _DeepTracer:
1454
1487
  span_type="span",
1455
1488
  parent_span_id=parent_span_id,
1456
1489
  function=qual_name,
1457
- agent_name=instance_name
1490
+ agent_name=instance_name,
1458
1491
  )
1459
1492
  current_trace.add_span(span)
1460
-
1493
+
1461
1494
  inputs = {}
1462
1495
  try:
1463
1496
  args_info = inspect.getargvalues(frame)
1464
1497
  for arg in args_info.args:
1465
1498
  try:
1466
1499
  inputs[arg] = args_info.locals.get(arg)
1467
- except:
1500
+ except Exception:
1468
1501
  inputs[arg] = "<<Unserializable>>"
1469
1502
  current_trace.record_input(inputs)
1470
1503
  except Exception as e:
1471
- current_trace.record_input({
1472
- "error": str(e)
1473
- })
1474
-
1504
+ current_trace.record_input({"error": str(e)})
1505
+
1475
1506
  elif event == "return":
1476
1507
  if not span_stack:
1477
1508
  return
1478
-
1509
+
1479
1510
  current_id = self._tracer.get_current_span()
1480
-
1511
+
1481
1512
  span_data = None
1482
1513
  for i, entry in enumerate(reversed(span_stack)):
1483
1514
  if entry["span_id"] == current_id:
1484
- span_data = span_stack.pop(-(i+1))
1515
+ span_data = span_stack.pop(-(i + 1))
1485
1516
  self._span_stack.set(span_stack)
1486
1517
  break
1487
-
1518
+
1488
1519
  if not span_data:
1489
1520
  return
1490
-
1521
+
1491
1522
  start_time = span_data["start_time"]
1492
1523
  duration = time.time() - start_time
1493
-
1524
+
1494
1525
  current_trace.span_id_to_span[span_data["span_id"]].duration = duration
1495
1526
 
1496
1527
  if arg is not None:
1497
- # exception handling will take priority.
1528
+ # exception handling will take priority.
1498
1529
  current_trace.record_output(arg)
1499
-
1530
+
1500
1531
  if span_data["span_id"] in current_trace._span_depths:
1501
1532
  del current_trace._span_depths[span_data["span_id"]]
1502
-
1533
+
1503
1534
  if span_stack:
1504
1535
  self._tracer.set_current_span(span_stack[-1]["span_id"])
1505
1536
  else:
1506
1537
  self._tracer.set_current_span(span_data["parent_span_id"])
1507
-
1538
+
1508
1539
  if "_judgment_span_token" in frame.f_locals:
1509
1540
  self._tracer.reset_current_span(frame.f_locals["_judgment_span_token"])
1510
1541
 
@@ -1513,10 +1544,9 @@ class _DeepTracer:
1513
1544
  if issubclass(exc_type, (StopIteration, StopAsyncIteration, GeneratorExit)):
1514
1545
  return
1515
1546
  _capture_exception_for_trace(current_trace, arg)
1516
-
1517
-
1547
+
1518
1548
  return continuation_func
1519
-
1549
+
1520
1550
  def __enter__(self):
1521
1551
  with self._lock:
1522
1552
  self._refcount += 1
@@ -1524,14 +1554,14 @@ class _DeepTracer:
1524
1554
  # Store the existing trace functions before setting ours
1525
1555
  self._original_sys_trace = sys.gettrace()
1526
1556
  self._original_threading_trace = threading.gettrace()
1527
-
1557
+
1528
1558
  self._skip_stack.set([])
1529
1559
  self._span_stack.set([])
1530
-
1560
+
1531
1561
  sys.settrace(self._cooperative_sys_trace)
1532
1562
  threading.settrace(self._cooperative_threading_trace)
1533
1563
  return self
1534
-
1564
+
1535
1565
  def __exit__(self, exc_type, exc_val, exc_tb):
1536
1566
  with self._lock:
1537
1567
  self._refcount -= 1
@@ -1539,7 +1569,7 @@ class _DeepTracer:
1539
1569
  # Restore the original trace functions instead of setting to None
1540
1570
  sys.settrace(self._original_sys_trace)
1541
1571
  threading.settrace(self._original_threading_trace)
1542
-
1572
+
1543
1573
  # Clean up the references
1544
1574
  self._original_sys_trace = None
1545
1575
  self._original_threading_trace = None
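
The __enter__/__exit__ pair above uses a refcount so nested _DeepTracer contexts install the trace hooks once and restore the originals only when the last context exits. A minimal standalone version of that guard pattern (hypothetical class name; the installed hook here is a no-op):

    import sys
    import threading

    class HookGuard:
        _lock = threading.Lock()
        _refcount = 0
        _saved = None

        def __enter__(self):
            with HookGuard._lock:
                HookGuard._refcount += 1
                if HookGuard._refcount == 1:
                    # First entry: remember the incumbent hook, install ours.
                    HookGuard._saved = sys.gettrace()
                    sys.settrace(lambda frame, event, arg: None)
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            with HookGuard._lock:
                HookGuard._refcount -= 1
                if HookGuard._refcount == 0:
                    # Last exit: hand the hook back instead of clearing it.
                    sys.settrace(HookGuard._saved)
                    HookGuard._saved = None

    with HookGuard():
        with HookGuard():  # nested use: inner exit leaves the hook installed
            pass
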
@@ -1547,7 +1577,7 @@ class _DeepTracer:
1547
1577
 
1548
1578
  # The commented-out log() helper below is no longer used.

1549
1579
 
1550
- # def log(self, message: str, level: str = "info"):
1580
+ # def log(self, message: str, level: str = "info"):
1551
1581
  # """ Log a message with the span context """
1552
1582
  # current_trace = self._tracer.get_current_trace()
1553
1583
  # if current_trace:
@@ -1555,31 +1585,29 @@ class _DeepTracer:
1555
1585
  # else:
1556
1586
  # print(f"[{level}] {message}")
1557
1587
  # current_trace.record_output({"log": message})
1558
-
1559
- class Tracer:
1560
- _instance = None
1561
1588
 
1589
+
1590
+ class Tracer:
1562
1591
  # Tracer.current_trace class variable is currently used in wrap()
1563
1592
  # TODO: Keep track of cross-context state for current trace and current span ID solely through class variables instead of instance variables?
1564
1593
  # Should be fine to do so as long as we keep Tracer as a singleton
1565
1594
  current_trace: Optional[TraceClient] = None
1566
1595
  # current_span_id: Optional[str] = None
1567
1596
 
1568
- trace_across_async_contexts: bool = False # BY default, we don't trace across async contexts
1569
-
1570
- def __new__(cls, *args, **kwargs):
1571
- if cls._instance is None:
1572
- cls._instance = super(Tracer, cls).__new__(cls)
1573
- return cls._instance
1597
+ trace_across_async_contexts: bool = (
1598
+ False  # By default, we don't trace across async contexts
1599
+ )
1574
1600
 
1575
1601
  def __init__(
1576
- self,
1577
- api_key: str = os.getenv("JUDGMENT_API_KEY"),
1578
- project_name: str = None,
1602
+ self,
1603
+ api_key: str | None = os.getenv("JUDGMENT_API_KEY"),
1604
+ project_name: str | None = None,
1579
1605
  rules: Optional[List[Rule]] = None, # Added rules parameter
1580
- organization_id: str = os.getenv("JUDGMENT_ORG_ID"),
1581
- enable_monitoring: bool = os.getenv("JUDGMENT_MONITORING", "true").lower() == "true",
1582
- enable_evaluations: bool = os.getenv("JUDGMENT_EVALUATIONS", "true").lower() == "true",
1606
+ organization_id: str | None = os.getenv("JUDGMENT_ORG_ID"),
1607
+ enable_monitoring: bool = os.getenv("JUDGMENT_MONITORING", "true").lower()
1608
+ == "true",
1609
+ enable_evaluations: bool = os.getenv("JUDGMENT_EVALUATIONS", "true").lower()
1610
+ == "true",
1583
1611
  # S3 configuration
1584
1612
  use_s3: bool = False,
1585
1613
  s3_bucket_name: Optional[str] = None,
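
Note that the enable_monitoring / enable_evaluations defaults above are computed from environment variables when the def statement is evaluated, not on each call. The parsing rule, in isolation:

    import os

    # Evaluated once, when the def statement runs - not per call. Any value
    # other than "true" (case-insensitive) turns the feature off.
    ENABLE_MONITORING = os.getenv("JUDGMENT_MONITORING", "true").lower() == "true"
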
@@ -1587,116 +1615,132 @@ class Tracer:
1587
1615
  s3_aws_secret_access_key: Optional[str] = None,
1588
1616
  s3_region_name: Optional[str] = None,
1589
1617
  offline_mode: bool = False,
1590
- deep_tracing: bool = True, # Deep tracing is enabled by default
1591
- trace_across_async_contexts: bool = False, # BY default, we don't trace across async contexts
1618
+ deep_tracing: bool = False, # Deep tracing is disabled by default
1619
+ trace_across_async_contexts: bool = False,  # By default, we don't trace across async contexts
1592
1620
  # Background span service configuration
1593
1621
  enable_background_spans: bool = True, # Enable background span service by default
1594
1622
  span_batch_size: int = 50, # Number of spans to batch before sending
1595
1623
  span_flush_interval: float = 1.0, # Time in seconds between automatic flushes
1596
- span_num_workers: int = 10 # Number of worker threads for span processing
1597
- ):
1598
- if not hasattr(self, 'initialized'):
1599
- if not api_key:
1600
- raise ValueError("Tracer must be configured with a Judgment API key")
1601
-
1624
+ span_num_workers: int = 10, # Number of worker threads for span processing
1625
+ ):
1626
+ if not api_key:
1627
+ raise ValueError("Tracer must be configured with a Judgment API key")
1628
+
1629
+ try:
1602
1630
  result, response = validate_api_key(api_key)
1603
- if not result:
1604
- raise JudgmentAPIError(f"Issue with passed in Judgment API key: {response}")
1605
-
1606
- if not organization_id:
1607
- raise ValueError("Tracer must be configured with an Organization ID")
1608
- if use_s3 and not s3_bucket_name:
1609
- raise ValueError("S3 bucket name must be provided when use_s3 is True")
1610
-
1611
- self.api_key: str = api_key
1612
- self.project_name: str = project_name or str(uuid.uuid4())
1613
- self.organization_id: str = organization_id
1614
- self.rules: List[Rule] = rules or [] # Store rules at tracer level
1615
- self.traces: List[Trace] = []
1616
- self.initialized: bool = True
1617
- self.enable_monitoring: bool = enable_monitoring
1618
- self.enable_evaluations: bool = enable_evaluations
1619
- self.class_identifiers: Dict[str, str] = {} # Dictionary to store class identifiers
1620
- self.span_id_to_previous_span_id: Dict[str, str] = {}
1621
- self.trace_id_to_previous_trace: Dict[str, TraceClient] = {}
1622
- self.current_span_id: Optional[str] = None
1623
- self.current_trace: Optional[TraceClient] = None
1624
- self.trace_across_async_contexts: bool = trace_across_async_contexts
1625
- Tracer.trace_across_async_contexts = trace_across_async_contexts
1626
-
1627
- # Initialize S3 storage if enabled
1628
- self.use_s3 = use_s3
1629
- if use_s3:
1630
- from judgeval.common.s3_storage import S3Storage
1631
+ except Exception as e:
1632
+ print(f"Issue with verifying API key, disabling monitoring: {e}")
1633
+ enable_monitoring = False
1634
+ result = True
1635
+
1636
+ if not result:
1637
+ raise JudgmentAPIError(f"Issue with passed in Judgment API key: {response}")
1638
+
1639
+ if not organization_id:
1640
+ raise ValueError("Tracer must be configured with an Organization ID")
1641
+ if use_s3 and not s3_bucket_name:
1642
+ raise ValueError("S3 bucket name must be provided when use_s3 is True")
1643
+
1644
+ self.api_key: str = api_key
1645
+ self.project_name: str = project_name or str(uuid.uuid4())
1646
+ self.organization_id: str = organization_id
1647
+ self.rules: List[Rule] = rules or [] # Store rules at tracer level
1648
+ self.traces: List[Trace] = []
1649
+ self.enable_monitoring: bool = enable_monitoring
1650
+ self.enable_evaluations: bool = enable_evaluations
1651
+ self.class_identifiers: Dict[
1652
+ str, str
1653
+ ] = {} # Dictionary to store class identifiers
1654
+ self.span_id_to_previous_span_id: Dict[str, str | None] = {}
1655
+ self.trace_id_to_previous_trace: Dict[str, TraceClient | None] = {}
1656
+ self.current_span_id: Optional[str] = None
1657
+ self.current_trace: Optional[TraceClient] = None
1658
+ self.trace_across_async_contexts: bool = trace_across_async_contexts
1659
+ Tracer.trace_across_async_contexts = trace_across_async_contexts
1660
+
1661
+ # Initialize S3 storage if enabled
1662
+ self.use_s3 = use_s3
1663
+ if use_s3:
1664
+ from judgeval.common.s3_storage import S3Storage
1665
+
1666
+ try:
1631
1667
  self.s3_storage = S3Storage(
1632
1668
  bucket_name=s3_bucket_name,
1633
1669
  aws_access_key_id=s3_aws_access_key_id,
1634
1670
  aws_secret_access_key=s3_aws_secret_access_key,
1635
- region_name=s3_region_name
1671
+ region_name=s3_region_name,
1636
1672
  )
1637
- self.offline_mode: bool = offline_mode
1638
- self.deep_tracing: bool = deep_tracing # NEW: Store deep tracing setting
1639
-
1640
- # Initialize background span service
1641
- self.enable_background_spans: bool = enable_background_spans
1642
- self.background_span_service: Optional[BackgroundSpanService] = None
1643
- if enable_background_spans and not offline_mode:
1644
- self.background_span_service = BackgroundSpanService(
1645
- judgment_api_key=api_key,
1646
- organization_id=organization_id,
1647
- batch_size=span_batch_size,
1648
- flush_interval=span_flush_interval,
1649
- num_workers=span_num_workers
1650
- )
1651
-
1652
- elif hasattr(self, 'project_name') and self.project_name != project_name:
1653
- warnings.warn(
1654
- f"Attempting to initialize Tracer with project_name='{project_name}' but it was already initialized with "
1655
- f"project_name='{self.project_name}'. Due to the singleton pattern, the original project_name will be used. "
1656
- "To use a different project name, ensure the first Tracer initialization uses the desired project name.",
1657
- RuntimeWarning
1673
+ except Exception as e:
1674
+ print(f"Issue with initializing S3 storage, disabling S3: {e}")
1675
+ self.use_s3 = False
1676
+
1677
+ self.offline_mode: bool = offline_mode
1678
+ self.deep_tracing: bool = deep_tracing # NEW: Store deep tracing setting
1679
+
1680
+ # Initialize background span service
1681
+ self.enable_background_spans: bool = enable_background_spans
1682
+ self.background_span_service: Optional[BackgroundSpanService] = None
1683
+ if enable_background_spans and not offline_mode:
1684
+ self.background_span_service = BackgroundSpanService(
1685
+ judgment_api_key=api_key,
1686
+ organization_id=organization_id,
1687
+ batch_size=span_batch_size,
1688
+ flush_interval=span_flush_interval,
1689
+ num_workers=span_num_workers,
1658
1690
  )
1659
1691
 
1660
- def set_current_span(self, span_id: str):
1661
- self.span_id_to_previous_span_id[span_id] = getattr(self, 'current_span_id', None)
1692
+ def set_current_span(self, span_id: str) -> Optional[contextvars.Token[str | None]]:
1693
+ self.span_id_to_previous_span_id[span_id] = self.current_span_id
1662
1694
  self.current_span_id = span_id
1663
1695
  Tracer.current_span_id = span_id
1664
1696
  try:
1665
1697
  token = current_span_var.set(span_id)
1666
- except:
1698
+ except Exception:
1667
1699
  token = None
1668
1700
  return token
1669
-
1701
+
1670
1702
  def get_current_span(self) -> Optional[str]:
1671
1703
  try:
1672
1704
  current_span_var_val = current_span_var.get()
1673
- except:
1705
+ except Exception:
1674
1706
  current_span_var_val = None
1675
- return (self.current_span_id or current_span_var_val) if self.trace_across_async_contexts else current_span_var_val
1676
-
1677
- def reset_current_span(self, token: Optional[str] = None, span_id: Optional[str] = None):
1678
- if not span_id:
1679
- span_id = self.current_span_id
1707
+ return (
1708
+ (self.current_span_id or current_span_var_val)
1709
+ if self.trace_across_async_contexts
1710
+ else current_span_var_val
1711
+ )
1712
+
1713
+ def reset_current_span(
1714
+ self,
1715
+ token: Optional[contextvars.Token[str | None]] = None,
1716
+ span_id: Optional[str] = None,
1717
+ ):
1680
1718
  try:
1681
- current_span_var.reset(token)
1682
- except:
1719
+ if token:
1720
+ current_span_var.reset(token)
1721
+ except Exception:
1683
1722
  pass
1684
- self.current_span_id = self.span_id_to_previous_span_id.get(span_id)
1685
- Tracer.current_span_id = self.current_span_id
1686
-
1687
- def set_current_trace(self, trace: TraceClient):
1723
+ if not span_id:
1724
+ span_id = self.current_span_id
1725
+ if span_id:
1726
+ self.current_span_id = self.span_id_to_previous_span_id.get(span_id)
1727
+ Tracer.current_span_id = self.current_span_id
1728
+
1729
+ def set_current_trace(
1730
+ self, trace: TraceClient
1731
+ ) -> Optional[contextvars.Token[TraceClient | None]]:
1688
1732
  """
1689
1733
  Set the current trace context in contextvars
1690
1734
  """
1691
- self.trace_id_to_previous_trace[trace.trace_id] = getattr(self, 'current_trace', None)
1735
+ self.trace_id_to_previous_trace[trace.trace_id] = self.current_trace
1692
1736
  self.current_trace = trace
1693
1737
  Tracer.current_trace = trace
1694
1738
  try:
1695
1739
  token = current_trace_var.set(trace)
1696
- except:
1740
+ except Exception:
1697
1741
  token = None
1698
1742
  return token
1699
-
1743
+
1700
1744
  def get_current_trace(self) -> Optional[TraceClient]:
1701
1745
  """
1702
1746
  Get the current trace context.
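
set_current_span() and reset_current_span() above lean on contextvars tokens: set() returns a token that remembers the prior value, and reset(token) restores it. The core mechanism, standalone:

    import contextvars
    from typing import Optional

    current_span_var: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
        "current_span", default=None
    )

    def set_span(span_id: str) -> contextvars.Token:
        # The token remembers the value that was current before this set().
        return current_span_var.set(span_id)

    token = set_span("span-1")
    assert current_span_var.get() == "span-1"
    current_span_var.reset(token)  # restores the previous value (None here)
    assert current_span_var.get() is None
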
@@ -1707,72 +1751,69 @@ class Tracer:
1707
1751
  """
1708
1752
  try:
1709
1753
  current_trace_var_val = current_trace_var.get()
1710
- except:
1754
+ except Exception:
1711
1755
  current_trace_var_val = None
1712
-
1713
- # Use context variable or class variable based on trace_across_async_contexts setting
1714
- context_trace = (self.current_trace or current_trace_var_val) if self.trace_across_async_contexts else current_trace_var_val
1715
-
1716
- # If we found a trace from context, return it
1717
- if context_trace:
1718
- return context_trace
1719
-
1720
- # Fallback: Check the active client potentially set by a callback handler (e.g., LangGraph)
1721
- if hasattr(self, '_active_trace_client') and self._active_trace_client:
1722
- return self._active_trace_client
1723
-
1724
- # If neither is available, return None
1725
- return None
1726
-
1727
- def reset_current_trace(self, token: Optional[str] = None, trace_id: Optional[str] = None):
1728
- if not trace_id and self.current_trace:
1729
- trace_id = self.current_trace.trace_id
1756
+ return (
1757
+ (self.current_trace or current_trace_var_val)
1758
+ if self.trace_across_async_contexts
1759
+ else current_trace_var_val
1760
+ )
1761
+
1762
+ def reset_current_trace(
1763
+ self,
1764
+ token: Optional[contextvars.Token[TraceClient | None]] = None,
1765
+ trace_id: Optional[str] = None,
1766
+ ):
1730
1767
  try:
1731
- current_trace_var.reset(token)
1732
- except:
1768
+ if token:
1769
+ current_trace_var.reset(token)
1770
+ except Exception:
1733
1771
  pass
1734
- self.current_trace = self.trace_id_to_previous_trace.get(trace_id)
1735
- Tracer.current_trace = self.current_trace
1772
+ if not trace_id and self.current_trace:
1773
+ trace_id = self.current_trace.trace_id
1774
+ if trace_id:
1775
+ self.current_trace = self.trace_id_to_previous_trace.get(trace_id)
1776
+ Tracer.current_trace = self.current_trace
1736
1777
 
1737
1778
  @contextmanager
1738
1779
  def trace(
1739
- self,
1740
- name: str,
1741
- project_name: str = None,
1780
+ self,
1781
+ name: str,
1782
+ project_name: str | None = None,
1742
1783
  overwrite: bool = False,
1743
- rules: Optional[List[Rule]] = None # Added rules parameter
1784
+ rules: Optional[List[Rule]] = None, # Added rules parameter
1744
1785
  ) -> Generator[TraceClient, None, None]:
1745
1786
  """Start a new trace context using a context manager"""
1746
1787
  trace_id = str(uuid.uuid4())
1747
1788
  project = project_name if project_name is not None else self.project_name
1748
-
1789
+
1749
1790
  # Get parent trace info from context
1750
1791
  parent_trace = self.get_current_trace()
1751
1792
  parent_trace_id = None
1752
1793
  parent_name = None
1753
-
1794
+
1754
1795
  if parent_trace:
1755
1796
  parent_trace_id = parent_trace.trace_id
1756
1797
  parent_name = parent_trace.name
1757
1798
 
1758
1799
  trace = TraceClient(
1759
- self,
1760
- trace_id,
1761
- name,
1762
- project_name=project,
1800
+ self,
1801
+ trace_id,
1802
+ name,
1803
+ project_name=project,
1763
1804
  overwrite=overwrite,
1764
1805
  rules=self.rules, # Pass combined rules to the trace client
1765
1806
  enable_monitoring=self.enable_monitoring,
1766
1807
  enable_evaluations=self.enable_evaluations,
1767
1808
  parent_trace_id=parent_trace_id,
1768
- parent_name=parent_name
1809
+ parent_name=parent_name,
1769
1810
  )
1770
-
1811
+
1771
1812
  # Set the current trace in context variables
1772
1813
  token = self.set_current_trace(trace)
1773
-
1814
+
1774
1815
  # Automatically create top-level span
1775
- with trace.span(name or "unnamed_trace") as span:
1816
+ with trace.span(name or "unnamed_trace"):
1776
1817
  try:
1777
1818
  # Save the trace to the database to handle Evaluations' trace_id referential integrity
1778
1819
  yield trace
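
Typical use of the trace() context manager, assuming a Tracer constructed as above (all key and ID strings below are placeholders):

    tracer = Tracer(
        api_key="sk-judgment-...",
        organization_id="org-...",
        project_name="demo-project",
    )

    with tracer.trace("checkout_flow") as trace:
        # Runs inside the automatically created top-level "checkout_flow"
        # span; the trace context is reset even if this block raises.
        ...
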
@@ -1780,101 +1821,110 @@ class Tracer:
1780
1821
  # Reset the context variable
1781
1822
  self.reset_current_trace(token)
1782
1823
 
1783
-
1784
1824
  def log(self, msg: str, label: str = "log", score: int = 1):
1785
1825
  """Log a message with the current span context"""
1786
1826
  current_span_id = self.get_current_span()
1787
1827
  current_trace = self.get_current_trace()
1788
- if current_span_id:
1828
+ if current_span_id and current_trace:
1789
1829
  annotation = TraceAnnotation(
1790
- span_id=current_span_id,
1791
- text=msg,
1792
- label=label,
1793
- score=score
1830
+ span_id=current_span_id, text=msg, label=label, score=score
1794
1831
  )
1795
-
1796
1832
  current_trace.add_annotation(annotation)
1797
1833
 
1798
1834
  rprint(f"[bold]{label}:[/bold] {msg}")
1799
-
1800
- def identify(self, identifier: str, track_state: bool = False, track_attributes: Optional[List[str]] = None, field_mappings: Optional[Dict[str, str]] = None):
1835
+
1836
+ def identify(
1837
+ self,
1838
+ identifier: str,
1839
+ track_state: bool = False,
1840
+ track_attributes: Optional[List[str]] = None,
1841
+ field_mappings: Optional[Dict[str, str]] = None,
1842
+ ):
1801
1843
  """
1802
1844
  Class decorator that associates a class with a custom identifier and enables state tracking.
1803
-
1845
+
1804
1846
  This decorator creates a mapping between the class name and the provided
1805
1847
  identifier, which can be useful for tagging, grouping, or referencing
1806
1848
  classes in a standardized way. It also enables automatic state capture
1807
1849
  for instances of the decorated class when used with tracing.
1808
-
1850
+
1809
1851
  Args:
1810
1852
  identifier: The identifier to associate with the decorated class.
1811
1853
  This will be used as the instance name in traces.
1812
- track_state: Whether to automatically capture the state (attributes)
1854
+ track_state: Whether to automatically capture the state (attributes)
1813
1855
  of instances before and after function execution. Defaults to False.
1814
1856
  track_attributes: Optional list of specific attribute names to track.
1815
- If None, all non-private attributes (not starting with '_')
1857
+ If None, all non-private attributes (not starting with '_')
1816
1858
  will be tracked when track_state=True.
1817
- field_mappings: Optional dictionary mapping internal attribute names to
1859
+ field_mappings: Optional dictionary mapping internal attribute names to
1818
1860
  display names in the captured state. For example:
1819
- {"system_prompt": "instructions"} will capture the
1861
+ {"system_prompt": "instructions"} will capture the
1820
1862
  'instructions' attribute as 'system_prompt' in the state.
1821
-
1863
+
1822
1864
  Example:
1823
1865
  @tracer.identify(identifier="user_model", track_state=True, track_attributes=["name", "age"], field_mappings={"system_prompt": "instructions"})
1824
1866
  class User:
1825
1867
  # Class implementation
1826
1868
  """
1869
+
1827
1870
  def decorator(cls):
1828
1871
  class_name = cls.__name__
1829
1872
  self.class_identifiers[class_name] = {
1830
1873
  "identifier": identifier,
1831
1874
  "track_state": track_state,
1832
1875
  "track_attributes": track_attributes,
1833
- "field_mappings": field_mappings or {}
1876
+ "field_mappings": field_mappings or {},
1834
1877
  }
1835
1878
  return cls
1836
-
1879
+
1837
1880
  return decorator
1838
-
1839
- def _capture_instance_state(self, instance: Any, class_config: Dict[str, Any]) -> Dict[str, Any]:
1881
+
1882
+ def _capture_instance_state(
1883
+ self, instance: Any, class_config: Dict[str, Any]
1884
+ ) -> Dict[str, Any]:
1840
1885
  """
1841
1886
  Capture the state of an instance based on class configuration.
1842
1887
  Args:
1843
1888
  instance: The instance to capture the state of.
1844
- class_config: Configuration dictionary for state capture,
1889
+ class_config: Configuration dictionary for state capture,
1845
1890
  expected to contain 'track_attributes' and 'field_mappings'.
1846
1891
  """
1847
- track_attributes = class_config.get('track_attributes')
1848
- field_mappings = class_config.get('field_mappings')
1849
-
1892
+ track_attributes = class_config.get("track_attributes")
1893
+ field_mappings = class_config.get("field_mappings")
1894
+
1850
1895
  if track_attributes:
1851
-
1852
1896
  state = {attr: getattr(instance, attr, None) for attr in track_attributes}
1853
1897
  else:
1854
-
1855
- state = {k: v for k, v in instance.__dict__.items() if not k.startswith('_')}
1898
+ state = {
1899
+ k: v for k, v in instance.__dict__.items() if not k.startswith("_")
1900
+ }
1856
1901
 
1857
1902
  if field_mappings:
1858
- state['field_mappings'] = field_mappings
1903
+ state["field_mappings"] = field_mappings
1859
1904
 
1860
1905
  return state
1861
-
1862
-
1906
+
1863
1907
  def _get_instance_state_if_tracked(self, args):
1864
1908
  """
1865
1909
  Extract instance state if the instance should be tracked.
1866
-
1910
+
1867
1911
  Returns the captured state dict if tracking is enabled, None otherwise.
1868
1912
  """
1869
- if args and hasattr(args[0], '__class__'):
1913
+ if args and hasattr(args[0], "__class__"):
1870
1914
  instance = args[0]
1871
1915
  class_name = instance.__class__.__name__
1872
- if (class_name in self.class_identifiers and
1873
- isinstance(self.class_identifiers[class_name], dict) and
1874
- self.class_identifiers[class_name].get('track_state', False)):
1875
- return self._capture_instance_state(instance, self.class_identifiers[class_name])
1876
-
1877
- def _conditionally_capture_and_record_state(self, trace_client_instance: TraceClient, args: tuple, is_before: bool):
1916
+ if (
1917
+ class_name in self.class_identifiers
1918
+ and isinstance(self.class_identifiers[class_name], dict)
1919
+ and self.class_identifiers[class_name].get("track_state", False)
1920
+ ):
1921
+ return self._capture_instance_state(
1922
+ instance, self.class_identifiers[class_name]
1923
+ )
1924
+
1925
+ def _conditionally_capture_and_record_state(
1926
+ self, trace_client_instance: TraceClient, args: tuple, is_before: bool
1927
+ ):
1878
1928
  """Captures instance state if tracked and records it via the trace_client."""
1879
1929
  state = self._get_instance_state_if_tracked(args)
1880
1930
  if state:
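
_capture_instance_state() above falls back to "all non-underscore attributes" when no track_attributes list is configured. A tiny sketch of that default capture rule (hypothetical class):

    class PlannerAgent:
        def __init__(self):
            self.name = "planner"
            self.model = "gpt-4o-mini"  # placeholder value
            self._scratch = []          # leading underscore: not captured

    agent = PlannerAgent()

    # With track_attributes unset, every non-underscore attribute is captured.
    state = {k: v for k, v in agent.__dict__.items() if not k.startswith("_")}
    assert state == {"name": "planner", "model": "gpt-4o-mini"}
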
@@ -1882,11 +1932,20 @@ class Tracer:
1882
1932
  trace_client_instance.record_state_before(state)
1883
1933
  else:
1884
1934
  trace_client_instance.record_state_after(state)
1885
-
1886
- def observe(self, func=None, *, name=None, span_type: SpanType = "span", project_name: str = None, overwrite: bool = False, deep_tracing: bool = None):
1935
+
1936
+ def observe(
1937
+ self,
1938
+ func=None,
1939
+ *,
1940
+ name=None,
1941
+ span_type: SpanType = "span",
1942
+ project_name: str | None = None,
1943
+ overwrite: bool = False,
1944
+ deep_tracing: bool | None = None,
1945
+ ):
1887
1946
  """
1888
1947
  Decorator to trace function execution with detailed entry/exit information.
1889
-
1948
+
1890
1949
  Args:
1891
1950
  func: The function to decorate
1892
1951
  name: Optional custom name for the span (defaults to function name)
@@ -1897,56 +1956,71 @@ class Tracer:
1897
1956
  If None, uses the tracer's default setting.
1898
1957
  """
1899
1958
  # If monitoring is disabled, return the function as is
1900
- if not self.enable_monitoring:
1901
- return func if func else lambda f: f
1902
-
1903
- if func is None:
1904
- return lambda f: self.observe(f, name=name, span_type=span_type, project_name=project_name,
1905
- overwrite=overwrite, deep_tracing=deep_tracing)
1906
-
1907
- # Use provided name or fall back to function name
1908
- original_span_name = name or func.__name__
1909
-
1910
- # Store custom attributes on the function object
1911
- func._judgment_span_name = original_span_name
1912
- func._judgment_span_type = span_type
1913
-
1914
- # Use the provided deep_tracing value or fall back to the tracer's default
1915
- use_deep_tracing = deep_tracing if deep_tracing is not None else self.deep_tracing
1916
-
1959
+ try:
1960
+ if not self.enable_monitoring:
1961
+ return func if func else lambda f: f
1962
+
1963
+ if func is None:
1964
+ return lambda f: self.observe(
1965
+ f,
1966
+ name=name,
1967
+ span_type=span_type,
1968
+ project_name=project_name,
1969
+ overwrite=overwrite,
1970
+ deep_tracing=deep_tracing,
1971
+ )
1972
+
1973
+ # Use provided name or fall back to function name
1974
+ original_span_name = name or func.__name__
1975
+
1976
+ # Store custom attributes on the function object
1977
+ func._judgment_span_name = original_span_name
1978
+ func._judgment_span_type = span_type
1979
+
1980
+ # Use the provided deep_tracing value or fall back to the tracer's default
1981
+ use_deep_tracing = (
1982
+ deep_tracing if deep_tracing is not None else self.deep_tracing
1983
+ )
1984
+ except Exception:
1985
+ return func
1986
+
1917
1987
  if asyncio.iscoroutinefunction(func):
1988
+
1918
1989
  @functools.wraps(func)
1919
1990
  async def async_wrapper(*args, **kwargs):
1920
1991
  nonlocal original_span_name
1921
1992
  class_name = None
1922
- instance_name = None
1923
1993
  span_name = original_span_name
1924
1994
  agent_name = None
1925
1995
 
1926
- if args and hasattr(args[0], '__class__'):
1996
+ if args and hasattr(args[0], "__class__"):
1927
1997
  class_name = args[0].__class__.__name__
1928
- agent_name = get_instance_prefixed_name(args[0], class_name, self.class_identifiers)
1998
+ agent_name = get_instance_prefixed_name(
1999
+ args[0], class_name, self.class_identifiers
2000
+ )
1929
2001
 
1930
2002
  # Get current trace from context
1931
2003
  current_trace = self.get_current_trace()
1932
-
2004
+
1933
2005
  # If there's no current trace, create a root trace
1934
2006
  if not current_trace:
1935
2007
  trace_id = str(uuid.uuid4())
1936
- project = project_name if project_name is not None else self.project_name
1937
-
2008
+ project = (
2009
+ project_name if project_name is not None else self.project_name
2010
+ )
2011
+
1938
2012
  # Create a new trace client to serve as the root
1939
2013
  current_trace = TraceClient(
1940
2014
  self,
1941
2015
  trace_id,
1942
- span_name, # MODIFIED: Use span_name directly
2016
+ span_name, # MODIFIED: Use span_name directly
1943
2017
  project_name=project,
1944
2018
  overwrite=overwrite,
1945
2019
  rules=self.rules,
1946
2020
  enable_monitoring=self.enable_monitoring,
1947
- enable_evaluations=self.enable_evaluations
2021
+ enable_evaluations=self.enable_evaluations,
1948
2022
  )
1949
-
2023
+
1950
2024
  # Save empty trace and set trace context
1951
2025
  # current_trace.save(empty_save=True, overwrite=overwrite)
1952
2026
  trace_token = self.set_current_trace(current_trace)
@@ -1954,7 +2028,9 @@ class Tracer:
1954
2028
  try:
1955
2029
  # Use span for the function execution within the root trace
1956
2030
  # This sets the current_span_var
1957
- with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
2031
+ with current_trace.span(
2032
+ span_name, span_type=span_type
2033
+ ) as span: # MODIFIED: Use span_name directly
1958
2034
  # Record inputs
1959
2035
  inputs = combine_args_kwargs(func, args, kwargs)
1960
2036
  span.record_input(inputs)
@@ -1962,50 +2038,66 @@ class Tracer:
1962
2038
  span.record_agent_name(agent_name)
1963
2039
 
1964
2040
  # Capture state before execution
1965
- self._conditionally_capture_and_record_state(span, args, is_before=True)
1966
-
1967
- if use_deep_tracing:
1968
- with _DeepTracer(self):
1969
- result = await func(*args, **kwargs)
1970
- else:
1971
- try:
2041
+ self._conditionally_capture_and_record_state(
2042
+ span, args, is_before=True
2043
+ )
2044
+
2045
+ try:
2046
+ if use_deep_tracing:
2047
+ with _DeepTracer(self):
2048
+ result = await func(*args, **kwargs)
2049
+ else:
1972
2050
  result = await func(*args, **kwargs)
1973
- except Exception as e:
1974
- _capture_exception_for_trace(current_trace, sys.exc_info())
1975
- raise e
1976
-
1977
- # Capture state after execution
1978
- self._conditionally_capture_and_record_state(span, args, is_before=False)
1979
-
2051
+ except Exception as e:
2052
+ _capture_exception_for_trace(
2053
+ current_trace, sys.exc_info()
2054
+ )
2055
+ raise e
2056
+
2057
+ # Capture state after execution
2058
+ self._conditionally_capture_and_record_state(
2059
+ span, args, is_before=False
2060
+ )
2061
+
1980
2062
  # Record output
1981
2063
  span.record_output(result)
1982
2064
  return result
1983
2065
  finally:
1984
2066
  # Flush background spans before saving the trace
1985
-
1986
- complete_trace_data = {
1987
- "trace_id": current_trace.trace_id,
1988
- "name": current_trace.name,
1989
- "created_at": datetime.utcfromtimestamp(current_trace.start_time).isoformat(),
1990
- "duration": current_trace.get_duration(),
1991
- "trace_spans": [span.model_dump() for span in current_trace.trace_spans],
1992
- "overwrite": overwrite,
1993
- "offline_mode": self.offline_mode,
1994
- "parent_trace_id": current_trace.parent_trace_id,
1995
- "parent_name": current_trace.parent_name
1996
- }
1997
- # Save the completed trace
1998
- trace_id, server_response = current_trace.save_with_rate_limiting(overwrite=overwrite, final_save=True)
1999
-
2000
- # Store the complete trace data instead of just server response
2001
-
2002
- self.traces.append(complete_trace_data)
2003
-
2004
- # if self.background_span_service:
2005
- # self.background_span_service.flush()
2006
-
2007
- # Reset trace context (span context resets automatically)
2008
- self.reset_current_trace(trace_token)
2067
+ try:
2068
+ complete_trace_data = {
2069
+ "trace_id": current_trace.trace_id,
2070
+ "name": current_trace.name,
2071
+ "created_at": datetime.utcfromtimestamp(
2072
+ current_trace.start_time
2073
+ ).isoformat(),
2074
+ "duration": current_trace.get_duration(),
2075
+ "trace_spans": [
2076
+ span.model_dump()
2077
+ for span in current_trace.trace_spans
2078
+ ],
2079
+ "overwrite": overwrite,
2080
+ "offline_mode": self.offline_mode,
2081
+ "parent_trace_id": current_trace.parent_trace_id,
2082
+ "parent_name": current_trace.parent_name,
2083
+ }
2084
+ # Save the completed trace
2085
+ trace_id, server_response = current_trace.save(
2086
+ overwrite=overwrite, final_save=True
2087
+ )
2088
+
2089
+ # Store the complete trace data instead of just server response
2090
+
2091
+ self.traces.append(complete_trace_data)
2092
+
2093
+ # if self.background_span_service:
2094
+ # self.background_span_service.flush()
2095
+
2096
+ # Reset trace context (span context resets automatically)
2097
+ self.reset_current_trace(trace_token)
2098
+ except Exception as e:
2099
+ warnings.warn(f"Issue with async_wrapper: {e}")
2100
+ return
2009
2101
  else:
2010
2102
  with current_trace.span(span_name, span_type=span_type) as span:
2011
2103
  inputs = combine_args_kwargs(func, args, kwargs)
@@ -2014,24 +2106,28 @@ class Tracer:
2014
2106
  span.record_agent_name(agent_name)
2015
2107
 
2016
2108
  # Capture state before execution
2017
- self._conditionally_capture_and_record_state(span, args, is_before=True)
2018
-
2019
- if use_deep_tracing:
2020
- with _DeepTracer(self):
2021
- result = await func(*args, **kwargs)
2022
- else:
2023
- try:
2109
+ self._conditionally_capture_and_record_state(
2110
+ span, args, is_before=True
2111
+ )
2112
+
2113
+ try:
2114
+ if use_deep_tracing:
2115
+ with _DeepTracer(self):
2116
+ result = await func(*args, **kwargs)
2117
+ else:
2024
2118
  result = await func(*args, **kwargs)
2025
- except Exception as e:
2026
- _capture_exception_for_trace(current_trace, sys.exc_info())
2027
- raise e
2028
-
2029
- # Capture state after execution
2030
- self._conditionally_capture_and_record_state(span, args, is_before=False)
2031
-
2119
+ except Exception as e:
2120
+ _capture_exception_for_trace(current_trace, sys.exc_info())
2121
+ raise e
2122
+
2123
+ # Capture state after execution
2124
+ self._conditionally_capture_and_record_state(
2125
+ span, args, is_before=False
2126
+ )
2127
+
2032
2128
  span.record_output(result)
2033
2129
  return result
2034
-
2130
+
2035
2131
  return async_wrapper
2036
2132
  else:
2037
2133
  # Non-async function implementation with deep tracing
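
Because observe() accepts func=None and returns a re-invoking lambda, it works both with and without arguments, and the branch above gives coroutine functions an async wrapper. A hedged usage sketch (tracer as constructed earlier):

    @tracer.observe(span_type="tool")
    async def fetch_weather(city: str) -> str:
        return f"sunny in {city}"

    # The bare form works too, because observe() accepts func directly:
    @tracer.observe
    def add(a: int, b: int) -> int:
        return a + b
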
@@ -2039,118 +2135,146 @@ class Tracer:
2039
2135
  def wrapper(*args, **kwargs):
2040
2136
  nonlocal original_span_name
2041
2137
  class_name = None
2042
- instance_name = None
2043
2138
  span_name = original_span_name
2044
2139
  agent_name = None
2045
- if args and hasattr(args[0], '__class__'):
2140
+ if args and hasattr(args[0], "__class__"):
2046
2141
  class_name = args[0].__class__.__name__
2047
- agent_name = get_instance_prefixed_name(args[0], class_name, self.class_identifiers)
2142
+ agent_name = get_instance_prefixed_name(
2143
+ args[0], class_name, self.class_identifiers
2144
+ )
2048
2145
  # Get current trace from context
2049
2146
  current_trace = self.get_current_trace()
2050
2147
 
2051
2148
  # If there's no current trace, create a root trace
2052
2149
  if not current_trace:
2053
2150
  trace_id = str(uuid.uuid4())
2054
- project = project_name if project_name is not None else self.project_name
2055
-
2151
+ project = (
2152
+ project_name if project_name is not None else self.project_name
2153
+ )
2154
+
2056
2155
  # Create a new trace client to serve as the root
2057
2156
  current_trace = TraceClient(
2058
2157
  self,
2059
2158
  trace_id,
2060
- span_name, # MODIFIED: Use span_name directly
2159
+ span_name, # MODIFIED: Use span_name directly
2061
2160
  project_name=project,
2062
2161
  overwrite=overwrite,
2063
2162
  rules=self.rules,
2064
2163
  enable_monitoring=self.enable_monitoring,
2065
- enable_evaluations=self.enable_evaluations
2164
+ enable_evaluations=self.enable_evaluations,
2066
2165
  )
2067
-
2166
+
2068
2167
  # Save empty trace and set trace context
2069
2168
  # current_trace.save(empty_save=True, overwrite=overwrite)
2070
2169
  trace_token = self.set_current_trace(current_trace)
2071
-
2170
+
2072
2171
  try:
2073
2172
  # Use span for the function execution within the root trace
2074
2173
  # This sets the current_span_var
2075
- with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
2174
+ with current_trace.span(
2175
+ span_name, span_type=span_type
2176
+ ) as span: # MODIFIED: Use span_name directly
2076
2177
  # Record inputs
2077
2178
  inputs = combine_args_kwargs(func, args, kwargs)
2078
2179
  span.record_input(inputs)
2079
2180
  if agent_name:
2080
2181
  span.record_agent_name(agent_name)
2081
2182
  # Capture state before execution
2082
- self._conditionally_capture_and_record_state(span, args, is_before=True)
2083
-
2084
- if use_deep_tracing:
2085
- with _DeepTracer(self):
2086
- result = func(*args, **kwargs)
2087
- else:
2088
- try:
2183
+ self._conditionally_capture_and_record_state(
2184
+ span, args, is_before=True
2185
+ )
2186
+
2187
+ try:
2188
+ if use_deep_tracing:
2189
+ with _DeepTracer(self):
2190
+ result = func(*args, **kwargs)
2191
+ else:
2089
2192
  result = func(*args, **kwargs)
2090
- except Exception as e:
2091
- _capture_exception_for_trace(current_trace, sys.exc_info())
2092
- raise e
2093
-
2193
+ except Exception as e:
2194
+ _capture_exception_for_trace(
2195
+ current_trace, sys.exc_info()
2196
+ )
2197
+ raise e
2198
+
2094
2199
  # Capture state after execution
2095
- self._conditionally_capture_and_record_state(span, args, is_before=False)
2200
+ self._conditionally_capture_and_record_state(
2201
+ span, args, is_before=False
2202
+ )
2096
2203
 
2097
-
2098
2204
  # Record output
2099
2205
  span.record_output(result)
2100
2206
  return result
2101
2207
  finally:
2102
2208
  # Flush background spans before saving the trace
2103
-
2104
-
2105
- # Save the completed trace
2106
- trace_id, server_response = current_trace.save_with_rate_limiting(overwrite=overwrite, final_save=True)
2107
-
2108
- # Store the complete trace data instead of just server response
2109
- complete_trace_data = {
2110
- "trace_id": current_trace.trace_id,
2111
- "name": current_trace.name,
2112
- "created_at": datetime.utcfromtimestamp(current_trace.start_time).isoformat(),
2113
- "duration": current_trace.get_duration(),
2114
- "trace_spans": [span.model_dump() for span in current_trace.trace_spans],
2115
- "overwrite": overwrite,
2116
- "offline_mode": self.offline_mode,
2117
- "parent_trace_id": current_trace.parent_trace_id,
2118
- "parent_name": current_trace.parent_name
2119
- }
2120
- self.traces.append(complete_trace_data)
2121
- # Reset trace context (span context resets automatically)
2122
- self.reset_current_trace(trace_token)
2209
+ try:
2210
+ # Save the completed trace
2211
+ trace_id, server_response = current_trace.save(
2212
+ overwrite=overwrite, final_save=True
2213
+ )
2214
+
2215
+ # Store the complete trace data instead of just server response
2216
+ complete_trace_data = {
2217
+ "trace_id": current_trace.trace_id,
2218
+ "name": current_trace.name,
2219
+ "created_at": datetime.utcfromtimestamp(
2220
+ current_trace.start_time
2221
+ ).isoformat(),
2222
+ "duration": current_trace.get_duration(),
2223
+ "trace_spans": [
2224
+ span.model_dump()
2225
+ for span in current_trace.trace_spans
2226
+ ],
2227
+ "overwrite": overwrite,
2228
+ "offline_mode": self.offline_mode,
2229
+ "parent_trace_id": current_trace.parent_trace_id,
2230
+ "parent_name": current_trace.parent_name,
2231
+ }
2232
+ self.traces.append(complete_trace_data)
2233
+ # Reset trace context (span context resets automatically)
2234
+ self.reset_current_trace(trace_token)
2235
+ except Exception as e:
2236
+ warnings.warn(f"Issue with save: {e}")
2237
+ return
2123
2238
  else:
2124
2239
  with current_trace.span(span_name, span_type=span_type) as span:
2125
-
2126
2240
  inputs = combine_args_kwargs(func, args, kwargs)
2127
2241
  span.record_input(inputs)
2128
2242
  if agent_name:
2129
2243
  span.record_agent_name(agent_name)
2130
2244
 
2131
2245
  # Capture state before execution
2132
- self._conditionally_capture_and_record_state(span, args, is_before=True)
2246
+ self._conditionally_capture_and_record_state(
2247
+ span, args, is_before=True
2248
+ )
2133
2249
 
2134
- if use_deep_tracing:
2135
- with _DeepTracer(self):
2136
- result = func(*args, **kwargs)
2137
- else:
2138
- try:
2250
+ try:
2251
+ if use_deep_tracing:
2252
+ with _DeepTracer(self):
2253
+ result = func(*args, **kwargs)
2254
+ else:
2139
2255
  result = func(*args, **kwargs)
2140
- except Exception as e:
2141
- _capture_exception_for_trace(current_trace, sys.exc_info())
2142
- raise e
2143
-
2256
+ except Exception as e:
2257
+ _capture_exception_for_trace(current_trace, sys.exc_info())
2258
+ raise e
2259
+
2144
2260
  # Capture state after execution
2145
- self._conditionally_capture_and_record_state(span, args, is_before=False)
2146
-
2261
+ self._conditionally_capture_and_record_state(
2262
+ span, args, is_before=False
2263
+ )
2264
+
2147
2265
  span.record_output(result)
2148
2266
  return result
2149
-
2267
+
2150
2268
  return wrapper
2151
-
2152
- def observe_tools(self, cls=None, *, exclude_methods: Optional[List[str]] = None,
2153
- include_private: bool = False, warn_on_double_decoration: bool = True):
2269
+
2270
+ def observe_tools(
2271
+ self,
2272
+ cls=None,
2273
+ *,
2274
+ exclude_methods: Optional[List[str]] = None,
2275
+ include_private: bool = False,
2276
+ warn_on_double_decoration: bool = True,
2277
+ ):
2154
2278
  """
2155
2279
  Automatically adds @observe(span_type="tool") to all methods in a class.
2156
2280
 
@@ -2162,28 +2286,32 @@ class Tracer:
2162
2286
  """
2163
2287
 
2164
2288
  if exclude_methods is None:
2165
- exclude_methods = ['__init__', '__new__', '__del__', '__str__', '__repr__']
2166
-
2289
+ exclude_methods = ["__init__", "__new__", "__del__", "__str__", "__repr__"]
2290
+
2167
2291
  def decorate_class(cls):
2168
2292
  if not self.enable_monitoring:
2169
2293
  return cls
2170
-
2294
+
2171
2295
  decorated = []
2172
2296
  skipped = []
2173
-
2297
+
2174
2298
  for name in dir(cls):
2175
2299
  method = getattr(cls, name)
2176
-
2177
- if (not callable(method) or
2178
- name in exclude_methods or
2179
- (name.startswith('_') and not include_private) or
2180
- not hasattr(cls, name)):
2300
+
2301
+ if (
2302
+ not callable(method)
2303
+ or name in exclude_methods
2304
+ or (name.startswith("_") and not include_private)
2305
+ or not hasattr(cls, name)
2306
+ ):
2181
2307
  continue
2182
-
2183
- if hasattr(method, '_judgment_span_name'):
2308
+
2309
+ if hasattr(method, "_judgment_span_name"):
2184
2310
  skipped.append(name)
2185
2311
  if warn_on_double_decoration:
2186
- print(f"Warning: {cls.__name__}.{name} already decorated, skipping")
2312
+ print(
2313
+ f"Warning: {cls.__name__}.{name} already decorated, skipping"
2314
+ )
2187
2315
  continue
2188
2316
 
2189
2317
  try:
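
Usage sketch for observe_tools(), assuming the class below is our own; note that passing exclude_methods replaces the default exclusion list rather than extending it:

    @tracer.observe_tools(exclude_methods=["__init__", "scratch"])
    class SearchToolkit:
        def search(self, query: str) -> str:  # becomes a "tool" span
            return f"results for {query}"

        def scratch(self, note: str) -> None:  # excluded, left undecorated
            pass
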
@@ -2193,28 +2321,76 @@ class Tracer:
                 except Exception as e:
                     if warn_on_double_decoration:
                         print(f"Warning: Failed to decorate {cls.__name__}.{name}: {e}")
-
+
             return cls
-
+
         return decorate_class if cls is None else decorate_class(cls)
 
     def async_evaluate(self, *args, **kwargs):
-        if not self.enable_evaluations:
-            return
+        try:
+            if not self.enable_monitoring or not self.enable_evaluations:
+                return
+
+            # --- Get trace_id passed explicitly (if any) ---
+            passed_trace_id = kwargs.pop(
+                "trace_id", None
+            )  # Get and remove trace_id from kwargs
+
+            current_trace = self.get_current_trace()
+
+            if current_trace:
+                # Pass the explicitly provided trace_id if it exists, otherwise let async_evaluate handle it
+                # (Note: TraceClient.async_evaluate doesn't currently use an explicit trace_id, but this is for future proofing/consistency)
+                if passed_trace_id:
+                    kwargs["trace_id"] = (
+                        passed_trace_id  # Re-add if needed by TraceClient.async_evaluate
+                    )
+                current_trace.async_evaluate(*args, **kwargs)
+            else:
+                warnings.warn(
+                    "No trace found (context var or fallback), skipping evaluation"
+                )  # Modified warning
+        except Exception as e:
+            warnings.warn(f"Issue with async_evaluate: {e}")
+
+    def update_metadata(self, metadata: dict):
+        """
+        Update metadata for the current trace.
+
+        Args:
+            metadata: Metadata as a dictionary
+        """
+        current_trace = self.get_current_trace()
+        if current_trace:
+            current_trace.update_metadata(metadata)
+        else:
+            warnings.warn("No current trace found, cannot set metadata")
 
-        # --- Get trace_id passed explicitly (if any) ---
-        passed_trace_id = kwargs.pop('trace_id', None) # Get and remove trace_id from kwargs
+    def set_customer_id(self, customer_id: str):
+        """
+        Set the customer ID for the current trace.
 
+        Args:
+            customer_id: The customer ID to set
+        """
         current_trace = self.get_current_trace()
+        if current_trace:
+            current_trace.set_customer_id(customer_id)
+        else:
+            warnings.warn("No current trace found, cannot set customer ID")
+
+    def set_tags(self, tags: List[Union[str, set, tuple]]):
+        """
+        Set the tags for the current trace.
 
+        Args:
+            tags: List of tags to set
+        """
+        current_trace = self.get_current_trace()
         if current_trace:
-            # Pass the explicitly provided trace_id if it exists, otherwise let async_evaluate handle it
-            # (Note: TraceClient.async_evaluate doesn't currently use an explicit trace_id, but this is for future proofing/consistency)
-            if passed_trace_id:
-                kwargs['trace_id'] = passed_trace_id # Re-add if needed by TraceClient.async_evaluate
-            current_trace.async_evaluate(*args, **kwargs)
+            current_trace.set_tags(tags)
         else:
-            warnings.warn("No trace found (context var or fallback), skipping evaluation") # Modified warning
+            warnings.warn("No current trace found, cannot set tags")
 
     def get_background_span_service(self) -> Optional[BackgroundSpanService]:
         """Get the background span service instance."""
@@ -2231,31 +2407,43 @@ class Tracer:
             self.background_span_service.shutdown()
         self.background_span_service = None
 
-def wrap(client: Any, trace_across_async_contexts: bool = Tracer.trace_across_async_contexts) -> Any:
+
+def _get_current_trace(
+    trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
+):
+    if trace_across_async_contexts:
+        return Tracer.current_trace
+    else:
+        return current_trace_var.get()
+
+
+def wrap(
+    client: Any, trace_across_async_contexts: bool = Tracer.trace_across_async_contexts
+) -> Any:
     """
     Wraps an API client to add tracing capabilities.
     Supports OpenAI, Together, Anthropic, and Google GenAI clients.
     Patches both '.create' and Anthropic's '.stream' methods using a wrapper class.
     """
-    span_name, original_create, original_responses_create, original_stream, original_beta_parse = _get_client_config(client)
+    (
+        span_name,
+        original_create,
+        original_responses_create,
+        original_stream,
+        original_beta_parse,
+    ) = _get_client_config(client)
 
-    def _get_current_trace():
-        if trace_across_async_contexts:
-            return Tracer.current_trace
-        else:
-            return current_trace_var.get()
-
    def _record_input_and_check_streaming(span, kwargs, is_responses=False):
        """Record input and check for streaming"""
        is_streaming = kwargs.get("stream", False)
 
-        # Record input based on whether this is a responses endpoint
+        # Record input based on whether this is a responses endpoint
        if is_responses:
            span.record_input(kwargs)
        else:
            input_data = _format_input_data(client, **kwargs)
            span.record_input(input_data)
-
+
        # Warn about token counting limitations with streaming
        if isinstance(client, (AsyncOpenAI, OpenAI)) and is_streaming:
            if not kwargs.get("stream_options", {}).get("include_usage"):
@@ -2263,88 +2451,101 @@ def wrap(client: Any, trace_across_async_contexts: bool = Tracer.trace_across_as
                    "OpenAI streaming calls don't include token counts by default. "
                    "To enable token counting with streams, set stream_options={'include_usage': True} "
                    "in your API call arguments.",
-                    UserWarning
+                    UserWarning,
                )
-
+
        return is_streaming
-
+
    def _format_and_record_output(span, response, is_streaming, is_async, is_responses):
        """Format and record the output in the span"""
        if is_streaming:
            output_entry = span.record_output("<pending stream>")
            wrapper_func = _async_stream_wrapper if is_async else _sync_stream_wrapper
-            return wrapper_func(response, client, output_entry)
+            return wrapper_func(
+                response, client, output_entry, trace_across_async_contexts
+            )
        else:
-            format_func = _format_response_output_data if is_responses else _format_output_data
+            format_func = (
+                _format_response_output_data if is_responses else _format_output_data
+            )
            output, usage = format_func(client, response)
            span.record_output(output)
            span.record_usage(usage)
-
+
        # Queue the completed LLM span now that it has all data (input, output, usage)
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
        if current_trace and current_trace.background_span_service:
            # Get the current span from the trace client
            current_span_id = current_trace.get_current_span()
            if current_span_id and current_span_id in current_trace.span_id_to_span:
                completed_span = current_trace.span_id_to_span[current_span_id]
-                current_trace.background_span_service.queue_span(completed_span, span_state="completed")
-
+                current_trace.background_span_service.queue_span(
+                    completed_span, span_state="completed"
+                )
+
        return response
-
+
    # --- Traced Async Functions ---
    async def traced_create_async(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
        if not current_trace:
            return await original_create(*args, **kwargs)
-
+
        with current_trace.span(span_name, span_type="llm") as span:
            is_streaming = _record_input_and_check_streaming(span, kwargs)
-
+
            try:
                response_or_iterator = await original_create(*args, **kwargs)
-                return _format_and_record_output(span, response_or_iterator, is_streaming, True, False)
+                return _format_and_record_output(
+                    span, response_or_iterator, is_streaming, True, False
+                )
            except Exception as e:
                _capture_exception_for_trace(span, sys.exc_info())
                raise e
-
+
    async def traced_beta_parse_async(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
        if not current_trace:
            return await original_beta_parse(*args, **kwargs)
-
+
        with current_trace.span(span_name, span_type="llm") as span:
            is_streaming = _record_input_and_check_streaming(span, kwargs)
-
+
            try:
                response_or_iterator = await original_beta_parse(*args, **kwargs)
-                return _format_and_record_output(span, response_or_iterator, is_streaming, True, False)
+                return _format_and_record_output(
+                    span, response_or_iterator, is_streaming, True, False
+                )
            except Exception as e:
                _capture_exception_for_trace(span, sys.exc_info())
                raise e
-
-
+
    # Async responses for OpenAI clients
    async def traced_response_create_async(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
        if not current_trace:
            return await original_responses_create(*args, **kwargs)
-
+
        with current_trace.span(span_name, span_type="llm") as span:
-            is_streaming = _record_input_and_check_streaming(span, kwargs, is_responses=True)
-
+            is_streaming = _record_input_and_check_streaming(
+                span, kwargs, is_responses=True
+            )
+
            try:
                response_or_iterator = await original_responses_create(*args, **kwargs)
-                return _format_and_record_output(span, response_or_iterator, is_streaming, True, True)
+                return _format_and_record_output(
+                    span, response_or_iterator, is_streaming, True, True
+                )
            except Exception as e:
                _capture_exception_for_trace(span, sys.exc_info())
                raise e
-
+
    # Function replacing .stream() for async clients
    def traced_stream_async(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
        if not current_trace or not original_stream:
            return original_stream(*args, **kwargs)
-
+
        original_manager = original_stream(*args, **kwargs)
        return _TracedAsyncStreamManagerWrapper(
            original_manager=original_manager,
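
As the warning in `_record_input_and_check_streaming` notes, OpenAI streaming responses carry no token counts unless usage reporting is opted into. A sketch of the opt-in, so the final chunk carries usage for the span:

    stream = client.chat.completions.create(
        model="gpt-4o-mini",  # illustrative
        messages=[{"role": "user", "content": "hello"}],
        stream=True,
        stream_options={"include_usage": True},  # final chunk then includes usage
    )
    for chunk in stream:
        ...  # consume normally; the wrapper records output and usage at stream end
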
@@ -2352,61 +2553,70 @@ def wrap(client: Any, trace_across_async_contexts: bool = Tracer.trace_across_as
            span_name=span_name,
            trace_client=current_trace,
            stream_wrapper_func=_async_stream_wrapper,
-            input_kwargs=kwargs
+            input_kwargs=kwargs,
+            trace_across_async_contexts=trace_across_async_contexts,
        )
-
+
    # --- Traced Sync Functions ---
    def traced_create_sync(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
        if not current_trace:
            return original_create(*args, **kwargs)
-
+
        with current_trace.span(span_name, span_type="llm") as span:
            is_streaming = _record_input_and_check_streaming(span, kwargs)
-
+
            try:
                response_or_iterator = original_create(*args, **kwargs)
-                return _format_and_record_output(span, response_or_iterator, is_streaming, False, False)
+                return _format_and_record_output(
+                    span, response_or_iterator, is_streaming, False, False
+                )
            except Exception as e:
                _capture_exception_for_trace(span, sys.exc_info())
                raise e
-
+
    def traced_beta_parse_sync(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
        if not current_trace:
            return original_beta_parse(*args, **kwargs)
-
+
        with current_trace.span(span_name, span_type="llm") as span:
            is_streaming = _record_input_and_check_streaming(span, kwargs)
-
+
            try:
                response_or_iterator = original_beta_parse(*args, **kwargs)
-                return _format_and_record_output(span, response_or_iterator, is_streaming, False, False)
+                return _format_and_record_output(
+                    span, response_or_iterator, is_streaming, False, False
+                )
            except Exception as e:
                _capture_exception_for_trace(span, sys.exc_info())
                raise e
-
+
    def traced_response_create_sync(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
        if not current_trace:
            return original_responses_create(*args, **kwargs)
-
+
        with current_trace.span(span_name, span_type="llm") as span:
-            is_streaming = _record_input_and_check_streaming(span, kwargs, is_responses=True)
-
+            is_streaming = _record_input_and_check_streaming(
+                span, kwargs, is_responses=True
+            )
+
            try:
                response_or_iterator = original_responses_create(*args, **kwargs)
-                return _format_and_record_output(span, response_or_iterator, is_streaming, False, True)
+                return _format_and_record_output(
+                    span, response_or_iterator, is_streaming, False, True
+                )
            except Exception as e:
                _capture_exception_for_trace(span, sys.exc_info())
                raise e
-
+
    # Function replacing sync .stream()
    def traced_stream_sync(*args, **kwargs):
-        current_trace = _get_current_trace()
+        current_trace = _get_current_trace(trace_across_async_contexts)
        if not current_trace or not original_stream:
            return original_stream(*args, **kwargs)
-
+
        original_manager = original_stream(*args, **kwargs)
        return _TracedSyncStreamManagerWrapper(
            original_manager=original_manager,
@@ -2414,15 +2624,21 @@ def wrap(client: Any, trace_across_async_contexts: bool = Tracer.trace_across_as
            span_name=span_name,
            trace_client=current_trace,
            stream_wrapper_func=_sync_stream_wrapper,
-            input_kwargs=kwargs
+            input_kwargs=kwargs,
+            trace_across_async_contexts=trace_across_async_contexts,
        )
-
+
    # --- Assign Traced Methods to Client Instance ---
    if isinstance(client, (AsyncOpenAI, AsyncTogether)):
        client.chat.completions.create = traced_create_async
        if hasattr(client, "responses") and hasattr(client.responses, "create"):
            client.responses.create = traced_response_create_async
-        if hasattr(client, "beta") and hasattr(client.beta, "chat") and hasattr(client.beta.chat, "completions") and hasattr(client.beta.chat.completions, "parse"):
+        if (
+            hasattr(client, "beta")
+            and hasattr(client.beta, "chat")
+            and hasattr(client.beta.chat, "completions")
+            and hasattr(client.beta.chat.completions, "parse")
+        ):
            client.beta.chat.completions.parse = traced_beta_parse_async
    elif isinstance(client, AsyncAnthropic):
        client.messages.create = traced_create_async
@@ -2434,7 +2650,12 @@ def wrap(client: Any, trace_across_async_contexts: bool = Tracer.trace_across_as
        client.chat.completions.create = traced_create_sync
        if hasattr(client, "responses") and hasattr(client.responses, "create"):
            client.responses.create = traced_response_create_sync
-        if hasattr(client, "beta") and hasattr(client.beta, "chat") and hasattr(client.beta.chat, "completions") and hasattr(client.beta.chat.completions, "parse"):
+        if (
+            hasattr(client, "beta")
+            and hasattr(client.beta, "chat")
+            and hasattr(client.beta.chat, "completions")
+            and hasattr(client.beta.chat.completions, "parse")
+        ):
            client.beta.chat.completions.parse = traced_beta_parse_sync
    elif isinstance(client, Anthropic):
        client.messages.create = traced_create_sync
@@ -2442,17 +2663,21 @@ def wrap(client: Any, trace_across_async_contexts: bool = Tracer.trace_across_as
        client.messages.stream = traced_stream_sync
    elif isinstance(client, genai.Client):
        client.models.generate_content = traced_create_sync
-
+
    return client
 
+
 # Helper functions for client-specific operations
 
-def _get_client_config(client: ApiClient) -> tuple[str, callable, Optional[callable]]:
+
+def _get_client_config(
+    client: ApiClient,
+) -> tuple[str, Callable, Optional[Callable], Optional[Callable], Optional[Callable]]:
    """Returns configuration tuple for the given API client.
-
+
    Args:
        client: An instance of OpenAI, Together, or Anthropic client
-
+
    Returns:
        tuple: (span_name, create_method, responses_method, stream_method, beta_parse_method)
        - span_name: String identifier for tracing
@@ -2460,23 +2685,36 @@ def _get_client_config(client: ApiClient) -> tuple[str, callable, Optional[calla
        - responses_method: Reference to the client's responses method (if applicable)
        - stream_method: Reference to the client's stream method (if applicable)
        - beta_parse_method: Reference to the client's beta parse method (if applicable)
-
+
    Raises:
        ValueError: If client type is not supported
    """
    if isinstance(client, (OpenAI, AsyncOpenAI)):
-        return "OPENAI_API_CALL", client.chat.completions.create, client.responses.create, None, client.beta.chat.completions.parse
+        return (
+            "OPENAI_API_CALL",
+            client.chat.completions.create,
+            client.responses.create,
+            None,
+            client.beta.chat.completions.parse,
+        )
    elif isinstance(client, (Together, AsyncTogether)):
        return "TOGETHER_API_CALL", client.chat.completions.create, None, None, None
    elif isinstance(client, (Anthropic, AsyncAnthropic)):
-        return "ANTHROPIC_API_CALL", client.messages.create, None, client.messages.stream, None
+        return (
+            "ANTHROPIC_API_CALL",
+            client.messages.create,
+            None,
+            client.messages.stream,
+            None,
+        )
    elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
        return "GOOGLE_API_CALL", client.models.generate_content, None, None, None
    raise ValueError(f"Unsupported client type: {type(client)}")
 
+
 def _format_input_data(client: ApiClient, **kwargs) -> dict:
    """Format input parameters based on client type.
-
+
    Extracts relevant parameters from kwargs based on the client type
    to ensure consistent tracing across different APIs.
    """
@@ -2489,25 +2727,23 @@ def _format_input_data(client: ApiClient, **kwargs) -> dict:
        input_data["response_format"] = kwargs.get("response_format")
        return input_data
    elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
-        return {
-            "model": kwargs.get("model"),
-            "contents": kwargs.get("contents")
-        }
+        return {"model": kwargs.get("model"), "contents": kwargs.get("contents")}
    # Anthropic requires additional max_tokens parameter
    return {
        "model": kwargs.get("model"),
        "messages": kwargs.get("messages"),
-        "max_tokens": kwargs.get("max_tokens")
+        "max_tokens": kwargs.get("max_tokens"),
    }
 
-def _format_response_output_data(client: ApiClient, response: Any) -> dict:
+
+def _format_response_output_data(client: ApiClient, response: Any) -> tuple:
    """Format API response data based on client type.
-
+
    Normalizes different response formats into a consistent structure
    for tracing purposes.
    """
    message_content = None
-    prompt_tokens = 0
+    prompt_tokens = 0
    completion_tokens = 0
    model_name = None
    if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
@@ -2517,14 +2753,16 @@ def _format_response_output_data(client: ApiClient, response: Any) -> dict:
        message_content = response.output
    else:
        warnings.warn(f"Unsupported client type: {type(client)}")
-        return {}
-
-    prompt_cost, completion_cost = cost_per_token(
+        return None, None
+
+    prompt_cost, completion_cost = cost_per_token(
        model=model_name,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
-    total_cost_usd = (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
+    total_cost_usd = (
+        (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
+    )
    usage = TraceUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
@@ -2532,17 +2770,19 @@ def _format_response_output_data(client: ApiClient, response: Any) -> dict:
        prompt_tokens_cost_usd=prompt_cost,
        completion_tokens_cost_usd=completion_cost,
        total_cost_usd=total_cost_usd,
-        model_name=model_name
+        model_name=model_name,
    )
    return message_content, usage
 
 
-def _format_output_data(client: ApiClient, response: Any) -> dict:
+def _format_output_data(
+    client: ApiClient, response: Any
+) -> tuple[Optional[str], Optional[TraceUsage]]:
    """Format API response data based on client type.
-
+
    Normalizes different response formats into a consistent structure
    for tracing purposes.
-
+
    Returns:
        dict containing:
        - content: The generated text
@@ -2557,7 +2797,10 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
        model_name = response.model
        prompt_tokens = response.usage.prompt_tokens
        completion_tokens = response.usage.completion_tokens
-        if hasattr(response.choices[0].message, "parsed") and response.choices[0].message.parsed:
+        if (
+            hasattr(response.choices[0].message, "parsed")
+            and response.choices[0].message.parsed
+        ):
            message_content = response.choices[0].message.parsed
        else:
            message_content = response.choices[0].message.content
@@ -2574,13 +2817,15 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
    else:
        warnings.warn(f"Unsupported client type: {type(client)}")
        return None, None
-
+
    prompt_cost, completion_cost = cost_per_token(
        model=model_name,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
-    total_cost_usd = (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
+    total_cost_usd = (
+        (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
+    )
    usage = TraceUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
@@ -2588,56 +2833,61 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
        prompt_tokens_cost_usd=prompt_cost,
        completion_tokens_cost_usd=completion_cost,
        total_cost_usd=total_cost_usd,
-        model_name=model_name
+        model_name=model_name,
    )
    return message_content, usage
 
+
 def combine_args_kwargs(func, args, kwargs):
    """
    Combine positional arguments and keyword arguments into a single dictionary.
-
+
    Args:
        func: The function being called
        args: Tuple of positional arguments
        kwargs: Dictionary of keyword arguments
-
+
    Returns:
        A dictionary combining both args and kwargs
    """
    try:
        import inspect
+
        sig = inspect.signature(func)
        param_names = list(sig.parameters.keys())
-
+
        args_dict = {}
        for i, arg in enumerate(args):
            if i < len(param_names):
                args_dict[param_names[i]] = arg
            else:
                args_dict[f"arg{i}"] = arg
-
+
        return {**args_dict, **kwargs}
-    except Exception as e:
+    except Exception:
        # Fallback if signature inspection fails
        return {**{f"arg{i}": arg for i, arg in enumerate(args)}, **kwargs}
 
+
 # NOTE: This builds once, can be tweaked if we are missing / capturing other unncessary modules
 # @link https://docs.python.org/3.13/library/sysconfig.html
 _TRACE_FILEPATH_BLOCKLIST = tuple(
    os.path.realpath(p) + os.sep
    for p in {
-        sysconfig.get_paths()['stdlib'],
-        sysconfig.get_paths().get('platstdlib', ''),
+        sysconfig.get_paths()["stdlib"],
+        sysconfig.get_paths().get("platstdlib", ""),
        *site.getsitepackages(),
        site.getusersitepackages(),
        *(
-            [os.path.join(os.path.dirname(__file__), '../../judgeval/')]
-            if os.environ.get('JUDGMENT_DEV')
+            [os.path.join(os.path.dirname(__file__), "../../judgeval/")]
+            if os.environ.get("JUDGMENT_DEV")
            else []
        ),
-    } if p
+    }
+    if p
 )
 
+
 # Add the new TraceThreadPoolExecutor class
 class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
    """
@@ -2649,6 +2899,7 @@ class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
    allowing the Tracer to maintain correct parent-child relationships across
    thread boundaries.
    """
+
    def submit(self, fn, /, *args, **kwargs):
        """
        Submit a callable to be executed with the captured context.
@@ -2669,9 +2920,11 @@ class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
    # Note: The `map` method would also need to be overridden for full context
    # propagation if users rely on it, but `submit` is the most common use case.
 
+
 # Helper functions for stream processing
 # ---------------------------------------
 
+
 def _extract_content_from_chunk(client: ApiClient, chunk: Any) -> Optional[str]:
    """Extracts the text content from a stream chunk based on the client type."""
    try:
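
`TraceThreadPoolExecutor` (two hunks above) captures the `contextvars` context at `submit()` time, so spans created inside worker threads keep the submitting trace as their parent. A usage sketch with a hypothetical worker function:

    from judgeval.common.tracer import TraceThreadPoolExecutor

    def fetch(url: str) -> str:  # hypothetical worker; decorated work nests under the caller's span
        ...

    urls = ["https://example.com/a", "https://example.com/b"]
    with TraceThreadPoolExecutor(max_workers=4) as pool:
        futures = [pool.submit(fetch, u) for u in urls]
        results = [f.result() for f in futures]
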
@@ -2683,34 +2936,49 @@ def _extract_content_from_chunk(client: ApiClient, chunk: Any) -> Optional[str]:
            return chunk.delta.text
        elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
            # Google streams Candidate objects
-            if chunk.candidates and chunk.candidates[0].content and chunk.candidates[0].content.parts:
+            if (
+                chunk.candidates
+                and chunk.candidates[0].content
+                and chunk.candidates[0].content.parts
+            ):
                return chunk.candidates[0].content.parts[0].text
    except (AttributeError, IndexError, KeyError):
        # Handle cases where chunk structure is unexpected or doesn't contain content
-        pass # Return None
+        pass  # Return None
    return None
 
-def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[Dict[str, int]]:
+
+def _extract_usage_from_final_chunk(
+    client: ApiClient, chunk: Any
+) -> Optional[Dict[str, int]]:
    """Extracts usage data if present in the *final* chunk (client-specific)."""
    try:
        # OpenAI/Together include usage in the *last* chunk's `usage` attribute if available
        # This typically requires specific API versions or settings. Often usage is *not* streamed.
        if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
            # Check if usage is directly on the chunk (some models might do this)
-            if hasattr(chunk, 'usage') and chunk.usage:
+            if hasattr(chunk, "usage") and chunk.usage:
                prompt_tokens = chunk.usage.prompt_tokens
                completion_tokens = chunk.usage.completion_tokens
            # Check if usage is nested within choices (less common for final chunk, but check)
-            elif chunk.choices and hasattr(chunk.choices[0], 'usage') and chunk.choices[0].usage:
+            elif (
+                chunk.choices
+                and hasattr(chunk.choices[0], "usage")
+                and chunk.choices[0].usage
+            ):
                prompt_tokens = chunk.choices[0].usage.prompt_tokens
                completion_tokens = chunk.choices[0].usage.completion_tokens
-
+
            prompt_cost, completion_cost = cost_per_token(
-                model=chunk.model,
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-            )
-            total_cost_usd = (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
+                model=chunk.model,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+            )
+            total_cost_usd = (
+                (prompt_cost + completion_cost)
+                if prompt_cost and completion_cost
+                else None
+            )
            return TraceUsage(
                prompt_tokens=chunk.usage.prompt_tokens,
                completion_tokens=chunk.usage.completion_tokens,
@@ -2718,9 +2986,9 @@ def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[D
                prompt_tokens_cost_usd=prompt_cost,
                completion_tokens_cost_usd=completion_cost,
                total_cost_usd=total_cost_usd,
-                model_name=chunk.model
+                model_name=chunk.model,
            )
-        # Anthropic includes usage in the 'message_stop' event type
+        # Anthropic includes usage in the 'message_stop' event type
        elif isinstance(client, (Anthropic, AsyncAnthropic)):
            if chunk.type == "message_stop":
                # Anthropic final usage is often attached to the *message* object, not the chunk directly
@@ -2729,18 +2997,18 @@ def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[D
                # This is a placeholder - Anthropic usage typically needs a separate call or context.
                pass
        elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
-            # Google provides usage metadata on the full response object, not typically streamed per chunk.
-            # It might be in the *last* chunk's usage_metadata if the stream implementation supports it.
-            if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
-                return {
-                    "prompt_tokens": chunk.usage_metadata.prompt_token_count,
-                    "completion_tokens": chunk.usage_metadata.candidates_token_count,
-                    "total_tokens": chunk.usage_metadata.total_token_count
-                }
+            # Google provides usage metadata on the full response object, not typically streamed per chunk.
+            # It might be in the *last* chunk's usage_metadata if the stream implementation supports it.
+            if hasattr(chunk, "usage_metadata") and chunk.usage_metadata:
+                return {
+                    "prompt_tokens": chunk.usage_metadata.prompt_token_count,
+                    "completion_tokens": chunk.usage_metadata.candidates_token_count,
+                    "total_tokens": chunk.usage_metadata.total_token_count,
+                }
 
    except (AttributeError, IndexError, KeyError, TypeError):
        # Handle cases where usage data is missing or malformed
-        pass # Return None
+        pass  # Return None
    return None
 
 
@@ -2748,7 +3016,8 @@ def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[D
 def _sync_stream_wrapper(
    original_stream: Iterator,
    client: ApiClient,
-    span: TraceSpan
+    span: TraceSpan,
+    trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
 ) -> Generator[Any, None, None]:
    """Wraps a synchronous stream iterator to capture content and update the trace."""
    content_parts = []  # Use a list instead of string concatenation
@@ -2758,9 +3027,11 @@ def _sync_stream_wrapper(
        for chunk in original_stream:
            content_part = _extract_content_from_chunk(client, chunk)
            if content_part:
-                content_parts.append(content_part) # Append to list instead of concatenating
-            last_chunk = chunk # Keep track of the last chunk for potential usage data
-            yield chunk # Pass the chunk to the caller
+                content_parts.append(
+                    content_part
+                )  # Append to list instead of concatenating
+            last_chunk = chunk  # Keep track of the last chunk for potential usage data
+            yield chunk  # Pass the chunk to the caller
    finally:
        # Attempt to extract usage from the last chunk received
        if last_chunk:
@@ -2769,23 +3040,25 @@ def _sync_stream_wrapper(
        # Update the trace entry with the accumulated content and usage
        span.output = "".join(content_parts)
        span.usage = final_usage
-
+
        # Queue the completed LLM span now that streaming is done and all data is available
-        # Note: We need to get the TraceClient that owns this span to access the background service
-        # We can find this through the tracer singleton since spans are associated with traces
-        from judgeval.common.tracer import Tracer
-        tracer_instance = Tracer._instance
-        if tracer_instance and tracer_instance.background_span_service:
-            tracer_instance.background_span_service.queue_span(span, span_state="completed")
-
+
+        current_trace = _get_current_trace(trace_across_async_contexts)
+        if current_trace and current_trace.background_span_service:
+            current_trace.background_span_service.queue_span(
+                span, span_state="completed"
+            )
+
        # Note: We might need to adjust _serialize_output if this dict causes issues,
        # but Pydantic's model_dump should handle dicts.
 
+
 # --- Async Stream Wrapper ---
 async def _async_stream_wrapper(
    original_stream: AsyncIterator,
    client: ApiClient,
-    span: TraceSpan
+    span: TraceSpan,
+    trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
 ) -> AsyncGenerator[Any, None]:
    # [Existing logic - unchanged]
    content_parts = []  # Use a list instead of string concatenation
@@ -2794,56 +3067,70 @@ async def _async_stream_wrapper(
    anthropic_input_tokens = 0
    anthropic_output_tokens = 0
 
-    target_span_id = span.span_id
-
    try:
        model_name = ""
        async for chunk in original_stream:
            # Check for OpenAI's final usage chunk
-            if isinstance(client, (AsyncOpenAI, OpenAI)) and hasattr(chunk, 'usage') and chunk.usage is not None:
+            if (
+                isinstance(client, (AsyncOpenAI, OpenAI))
+                and hasattr(chunk, "usage")
+                and chunk.usage is not None
+            ):
                final_usage_data = {
                    "prompt_tokens": chunk.usage.prompt_tokens,
                    "completion_tokens": chunk.usage.completion_tokens,
-                    "total_tokens": chunk.usage.total_tokens
+                    "total_tokens": chunk.usage.total_tokens,
                }
                model_name = chunk.model
                yield chunk
                continue
 
-            if isinstance(client, (AsyncAnthropic, Anthropic)) and hasattr(chunk, 'type'):
+            if isinstance(client, (AsyncAnthropic, Anthropic)) and hasattr(
+                chunk, "type"
+            ):
                if chunk.type == "message_start":
-                    if hasattr(chunk, 'message') and hasattr(chunk.message, 'usage') and hasattr(chunk.message.usage, 'input_tokens'):
-                        anthropic_input_tokens = chunk.message.usage.input_tokens
-                        model_name = chunk.message.model
+                    if (
+                        hasattr(chunk, "message")
+                        and hasattr(chunk.message, "usage")
+                        and hasattr(chunk.message.usage, "input_tokens")
+                    ):
+                        anthropic_input_tokens = chunk.message.usage.input_tokens
+                        model_name = chunk.message.model
                elif chunk.type == "message_delta":
-                    if hasattr(chunk, 'usage') and hasattr(chunk.usage, 'output_tokens'):
+                    if hasattr(chunk, "usage") and hasattr(
+                        chunk.usage, "output_tokens"
+                    ):
                        anthropic_output_tokens = chunk.usage.output_tokens
 
            content_part = _extract_content_from_chunk(client, chunk)
            if content_part:
-                content_parts.append(content_part) # Append to list instead of concatenating
+                content_parts.append(
+                    content_part
+                )  # Append to list instead of concatenating
                last_content_chunk = chunk
 
            yield chunk
    finally:
        anthropic_final_usage = None
-        if isinstance(client, (AsyncAnthropic, Anthropic)) and (anthropic_input_tokens > 0 or anthropic_output_tokens > 0):
-            anthropic_final_usage = {
-                "prompt_tokens": anthropic_input_tokens,
-                "completion_tokens": anthropic_output_tokens,
-                "total_tokens": anthropic_input_tokens + anthropic_output_tokens
-            }
+        if isinstance(client, (AsyncAnthropic, Anthropic)) and (
+            anthropic_input_tokens > 0 or anthropic_output_tokens > 0
+        ):
+            anthropic_final_usage = {
+                "prompt_tokens": anthropic_input_tokens,
+                "completion_tokens": anthropic_output_tokens,
+                "total_tokens": anthropic_input_tokens + anthropic_output_tokens,
+            }
 
        usage_info = None
        if final_usage_data:
-            usage_info = final_usage_data
+            usage_info = final_usage_data
        elif anthropic_final_usage:
-            usage_info = anthropic_final_usage
+            usage_info = anthropic_final_usage
        elif last_content_chunk:
            usage_info = _extract_usage_from_final_chunk(client, last_content_chunk)
 
        if usage_info and not isinstance(usage_info, TraceUsage):
-            prompt_cost, completion_cost = cost_per_token(
+            prompt_cost, completion_cost = cost_per_token(
                model=model_name,
                prompt_tokens=usage_info["prompt_tokens"],
                completion_tokens=usage_info["completion_tokens"],
@@ -2855,21 +3142,23 @@ async def _async_stream_wrapper(
            prompt_tokens_cost_usd=prompt_cost,
            completion_tokens_cost_usd=completion_cost,
            total_cost_usd=prompt_cost + completion_cost,
-            model_name=model_name
+            model_name=model_name,
        )
-    if span and hasattr(span, 'output'):
-        span.output = ''.join(content_parts)
+    if span and hasattr(span, "output"):
+        span.output = "".join(content_parts)
        span.usage = usage_info
-        start_ts = getattr(span, 'created_at', time.time())
+        start_ts = getattr(span, "created_at", time.time())
        span.duration = time.time() - start_ts
-
+
        # Queue the completed LLM span now that async streaming is done and all data is available
-        from judgeval.common.tracer import Tracer
-        tracer_instance = Tracer._instance
-        if tracer_instance and tracer_instance.background_span_service:
-            tracer_instance.background_span_service.queue_span(span, span_state="completed")
+        current_trace = _get_current_trace(trace_across_async_contexts)
+        if current_trace and current_trace.background_span_service:
+            current_trace.background_span_service.queue_span(
+                span, span_state="completed"
+            )
    # else: # Handle error case if necessary, but remove debug print
 
+
 def cost_per_token(*args, **kwargs):
    try:
        return _original_cost_per_token(*args, **kwargs)
@@ -2877,8 +3166,18 @@ def cost_per_token(*args, **kwargs):
        warnings.warn(f"Error calculating cost per token: {e}")
        return None, None
 
+
 class _BaseStreamManagerWrapper:
-    def __init__(self, original_manager, client, span_name, trace_client, stream_wrapper_func, input_kwargs):
+    def __init__(
+        self,
+        original_manager,
+        client,
+        span_name,
+        trace_client,
+        stream_wrapper_func,
+        input_kwargs,
+        trace_across_async_contexts: bool = Tracer.trace_across_async_contexts,
+    ):
        self._original_manager = original_manager
        self._client = client
        self._span_name = span_name
@@ -2886,13 +3185,19 @@ class _BaseStreamManagerWrapper:
        self._stream_wrapper_func = stream_wrapper_func
        self._input_kwargs = input_kwargs
        self._parent_span_id_at_entry = None
+        self._trace_across_async_contexts = trace_across_async_contexts
 
    def _create_span(self):
        start_time = time.time()
        span_id = str(uuid.uuid4())
        current_depth = 0
-        if self._parent_span_id_at_entry and self._parent_span_id_at_entry in self._trace_client._span_depths:
-            current_depth = self._trace_client._span_depths[self._parent_span_id_at_entry] + 1
+        if (
+            self._parent_span_id_at_entry
+            and self._parent_span_id_at_entry in self._trace_client._span_depths
+        ):
+            current_depth = (
+                self._trace_client._span_depths[self._parent_span_id_at_entry] + 1
+            )
        self._trace_client._span_depths[span_id] = current_depth
        span = TraceSpan(
            function=self._span_name,
@@ -2902,7 +3207,7 @@ class _BaseStreamManagerWrapper:
            message=self._span_name,
            created_at=start_time,
            span_type="llm",
-            parent_span_id=self._parent_span_id_at_entry
+            parent_span_id=self._parent_span_id_at_entry,
        )
        self._trace_client.add_span(span)
        return span_id, span
@@ -2914,7 +3219,10 @@ class _BaseStreamManagerWrapper:
        if span_id in self._trace_client._span_depths:
            del self._trace_client._span_depths[span_id]
 
-class _TracedAsyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractAsyncContextManager):
+
+class _TracedAsyncStreamManagerWrapper(
+    _BaseStreamManagerWrapper, AbstractAsyncContextManager
+):
    async def __aenter__(self):
        self._parent_span_id_at_entry = self._trace_client.get_current_span()
        if not self._trace_client:
@@ -2927,17 +3235,22 @@ class _TracedAsyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractAsyncC
        # Call the original __aenter__ and expect it to be an async generator
        raw_iterator = await self._original_manager.__aenter__()
        span.output = "<pending stream>"
-        return self._stream_wrapper_func(raw_iterator, self._client, span)
+        return self._stream_wrapper_func(
+            raw_iterator, self._client, span, self._trace_across_async_contexts
+        )
 
    async def __aexit__(self, exc_type, exc_val, exc_tb):
-        if hasattr(self, '_span_context_token'):
+        if hasattr(self, "_span_context_token"):
            span_id = self._trace_client.get_current_span()
            self._finalize_span(span_id)
            self._trace_client.reset_current_span(self._span_context_token)
-            delattr(self, '_span_context_token')
+            delattr(self, "_span_context_token")
        return await self._original_manager.__aexit__(exc_type, exc_val, exc_tb)
 
-class _TracedSyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractContextManager):
+
+class _TracedSyncStreamManagerWrapper(
+    _BaseStreamManagerWrapper, AbstractContextManager
+):
    def __enter__(self):
        self._parent_span_id_at_entry = self._trace_client.get_current_span()
        if not self._trace_client:
@@ -2949,16 +3262,19 @@ class _TracedSyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractContext
 
        raw_iterator = self._original_manager.__enter__()
        span.output = "<pending stream>"
-        return self._stream_wrapper_func(raw_iterator, self._client, span)
+        return self._stream_wrapper_func(
+            raw_iterator, self._client, span, self._trace_across_async_contexts
+        )
 
    def __exit__(self, exc_type, exc_val, exc_tb):
-        if hasattr(self, '_span_context_token'):
+        if hasattr(self, "_span_context_token"):
            span_id = self._trace_client.get_current_span()
            self._finalize_span(span_id)
            self._trace_client.reset_current_span(self._span_context_token)
-            delattr(self, '_span_context_token')
+            delattr(self, "_span_context_token")
        return self._original_manager.__exit__(exc_type, exc_val, exc_tb)
 
+
 # --- Helper function for instance-prefixed qual_name ---
 def get_instance_prefixed_name(instance, class_name, class_identifiers):
    """
@@ -2967,11 +3283,13 @@ def get_instance_prefixed_name(instance, class_name, class_identifiers):
    """
    if class_name in class_identifiers:
        class_config = class_identifiers[class_name]
-        attr = class_config['identifier']
-
+        attr = class_config["identifier"]
+
        if hasattr(instance, attr):
            instance_name = getattr(instance, attr)
            return instance_name
        else:
-            raise Exception(f"Attribute {attr} does not exist for {class_name}. Check your identify() decorator.")
-    return None
+            raise Exception(
+                f"Attribute {attr} does not exist for {class_name}. Check your identify() decorator."
+            )
+    return None
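
`get_instance_prefixed_name` pairs with the `identify()` decorator referenced in its error message: the decorator registers which attribute names an instance, and spans from that instance are prefixed with its value. A sketch, assuming `identifier` is the keyword the decorator accepts (per the `class_config["identifier"]` lookup above):

    @judgment.identify(identifier="name")  # hypothetical keyword argument
    class Agent:
        def __init__(self, name: str):
            self.name = name  # used as the span-name prefix for this instance
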