judgeval 0.0.41__py3-none-any.whl → 0.0.43__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
judgeval/common/tracer.py CHANGED
@@ -18,7 +18,7 @@ import sys
18
18
  import json
19
19
  from contextlib import contextmanager, asynccontextmanager, AbstractAsyncContextManager, AbstractContextManager # Import context manager bases
20
20
  from dataclasses import dataclass, field
21
- from datetime import datetime
21
+ from datetime import datetime, timezone
22
22
  from http import HTTPStatus
23
23
  from typing import (
24
24
  Any,
@@ -49,12 +49,17 @@ from google import genai
49
49
  from judgeval.constants import (
50
50
  JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
51
51
  JUDGMENT_TRACES_SAVE_API_URL,
52
+ JUDGMENT_TRACES_UPSERT_API_URL,
53
+ JUDGMENT_TRACES_USAGE_CHECK_API_URL,
54
+ JUDGMENT_TRACES_USAGE_UPDATE_API_URL,
52
55
  JUDGMENT_TRACES_FETCH_API_URL,
53
56
  RABBITMQ_HOST,
54
57
  RABBITMQ_PORT,
55
58
  RABBITMQ_QUEUE,
56
59
  JUDGMENT_TRACES_DELETE_API_URL,
57
60
  JUDGMENT_PROJECT_DELETE_API_URL,
61
+ JUDGMENT_TRACES_SPANS_BATCH_API_URL,
62
+ JUDGMENT_TRACES_EVALUATION_RUNS_BATCH_API_URL,
58
63
  )
59
64
  from judgeval.data import Example, Trace, TraceSpan, TraceUsage
60
65
  from judgeval.scorers import APIJudgmentScorer, JudgevalScorer
@@ -66,6 +71,8 @@ from judgeval.common.exceptions import JudgmentAPIError
66
71
  # Standard library imports needed for the new class
67
72
  import concurrent.futures
68
73
  from collections.abc import Iterator, AsyncIterator # Add Iterator and AsyncIterator
74
+ import queue
75
+ import atexit
69
76
 
70
77
  # Define context variables for tracking the current trace and the current span within a trace
71
78
  current_trace_var = contextvars.ContextVar[Optional['TraceClient']]('current_trace', default=None)
@@ -142,13 +149,18 @@ class TraceManagerClient:
142
149
 
143
150
  return response.json()
144
151
 
145
- def save_trace(self, trace_data: dict, offline_mode: bool = False):
152
+ def save_trace(self, trace_data: dict, offline_mode: bool = False, final_save: bool = True):
146
153
  """
147
154
  Saves a trace to the Judgment Supabase and optionally to S3 if configured.
148
155
 
149
156
  Args:
150
157
  trace_data: The trace data to save
158
+ offline_mode: Whether running in offline mode
159
+ final_save: Whether this is the final save (controls S3 saving)
151
160
  NOTE we save empty traces in order to properly handle async operations; we need something in the DB to associate the async results with
161
+
162
+ Returns:
163
+ dict: Server response containing UI URL and other metadata
152
164
  """
153
165
  # Save to Judgment API
154
166
 
@@ -170,7 +182,6 @@ class TraceManagerClient:
170
182
  return f"<Unserializable object of type {type(obj).__name__}: {e}>"
171
183
 
172
184
  serialized_trace_data = json.dumps(trace_data, default=fallback_encoder)
173
-
174
185
  response = requests.post(
175
186
  JUDGMENT_TRACES_SAVE_API_URL,
176
187
  data=serialized_trace_data,
@@ -187,8 +198,11 @@ class TraceManagerClient:
187
198
  elif response.status_code != HTTPStatus.OK:
188
199
  raise ValueError(f"Failed to save trace data: {response.text}")
189
200
 
190
- # If S3 storage is enabled, save to S3 as well
191
- if self.tracer and self.tracer.use_s3:
201
+ # Parse server response
202
+ server_response = response.json()
203
+
204
+ # If S3 storage is enabled, save to S3 only on final save
205
+ if self.tracer and self.tracer.use_s3 and final_save:
192
206
  try:
193
207
  s3_key = self.tracer.s3_storage.save_trace(
194
208
  trace_data=trace_data,
@@ -199,9 +213,136 @@ class TraceManagerClient:
199
213
  except Exception as e:
200
214
  warnings.warn(f"Failed to save trace to S3: {str(e)}")
201
215
 
202
- if not offline_mode and "ui_results_url" in response.json():
203
- pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={response.json()['ui_results_url']}]View Trace[/link]\n"
216
+ if not offline_mode and "ui_results_url" in server_response:
217
+ pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={server_response['ui_results_url']}]View Trace[/link]\n"
204
218
  rprint(pretty_str)
219
+
220
+ return server_response
221
+
222
def check_usage_limits(self, count: int = 1):
    """
    Ask the Judgment API whether `count` more traces can be used without exceeding limits.

    Args:
        count: Number of traces to check for (default: 1)

    Returns:
        dict: Parsed JSON body of the server response

    Raises:
        ValueError: If the server answers 403 (limit exceeded) or any other non-200 status
        requests.RequestException: If the request cannot be completed, including
            when the server does not answer within the timeout
    """
    response = requests.post(
        JUDGMENT_TRACES_USAGE_CHECK_API_URL,
        json={"count": count},
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.judgment_api_key}",
            "X-Organization-Id": self.organization_id
        },
        verify=True,
        # requests never times out unless told to; bound the wait so an
        # unresponsive server cannot block the caller indefinitely.
        timeout=30
    )

    if response.status_code == HTTPStatus.FORBIDDEN:
        # Rate limits exceeded. The error body is expected to be a JSON object
        # with a 'detail' field; fall back to the default text when it is not,
        # so the caller still gets the "Rate limit exceeded" ValueError.
        detail = 'Monthly trace limit reached'
        try:
            error_data = response.json()
        except ValueError:
            error_data = None
        if isinstance(error_data, dict):
            detail = error_data.get('detail', detail)
        raise ValueError(f"Rate limit exceeded: {detail}")
    elif response.status_code != HTTPStatus.OK:
        raise ValueError(f"Failed to check usage limits: {response.text}")

    return response.json()
254
+
255
def upsert_trace(self, trace_data: dict, offline_mode: bool = False, show_link: bool = True, final_save: bool = True):
    """
    Upserts a trace through the Judgment upsert endpoint.

    NOTE(review): the endpoint is expected to overwrite an existing trace with
    the same id; that is server-side behavior and cannot be confirmed here.

    Args:
        trace_data: The trace data to upsert; must contain "trace_id" and
            "project_name" when S3 storage is enabled
        offline_mode: Whether running in offline mode (suppresses the UI link)
        show_link: Whether to show the UI link (for live tracing)
        final_save: Whether this is the final save (controls S3 saving)

    Returns:
        dict: Parsed JSON body of the server response

    Raises:
        ValueError: If the server responds with a status other than 200
        requests.RequestException: If the request cannot be completed, including
            when the server does not answer within the timeout
    """
    def fallback_encoder(obj):
        """
        Custom JSON encoder fallback.
        Tries repr(obj) first, then str(obj), then a placeholder string.
        """
        try:
            return repr(obj)
        except Exception:
            try:
                return str(obj)
            except Exception as e:
                return f"<Unserializable object of type {type(obj).__name__}: {e}>"

    serialized_trace_data = json.dumps(trace_data, default=fallback_encoder)

    response = requests.post(
        JUDGMENT_TRACES_UPSERT_API_URL,
        data=serialized_trace_data,
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.judgment_api_key}",
            "X-Organization-Id": self.organization_id
        },
        verify=True,
        # requests never times out unless told to; bound the wait so an
        # unresponsive server cannot block the traced application indefinitely.
        timeout=30
    )

    if response.status_code != HTTPStatus.OK:
        raise ValueError(f"Failed to upsert trace data: {response.text}")

    # Parse server response
    server_response = response.json()

    # If S3 storage is enabled, save to S3 only on final save.
    # S3 failures are reported as warnings and never fail the upsert.
    if self.tracer and self.tracer.use_s3 and final_save:
        try:
            s3_key = self.tracer.s3_storage.save_trace(
                trace_data=trace_data,
                trace_id=trace_data["trace_id"],
                project_name=trace_data["project_name"]
            )
            print(f"Trace also saved to S3 at key: {s3_key}")
        except Exception as e:
            warnings.warn(f"Failed to save trace to S3: {str(e)}")

    if not offline_mode and show_link and "ui_results_url" in server_response:
        pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={server_response['ui_results_url']}]View Trace[/link]\n"
        rprint(pretty_str)

    return server_response
317
+
318
def update_usage_counters(self, count: int = 1):
    """
    Report `count` saved traces to the Judgment usage-update endpoint.

    Args:
        count: Number of traces to count (default: 1)

    Returns:
        dict: Parsed JSON body of the server response

    Raises:
        ValueError: If the server responds with a status other than 200
        requests.RequestException: If the request cannot be completed, including
            when the server does not answer within the timeout
    """
    response = requests.post(
        JUDGMENT_TRACES_USAGE_UPDATE_API_URL,
        json={"count": count},
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.judgment_api_key}",
            "X-Organization-Id": self.organization_id
        },
        verify=True,
        # requests never times out unless told to; bound the wait so an
        # unresponsive server cannot block the caller indefinitely.
        timeout=30
    )

    if response.status_code != HTTPStatus.OK:
        raise ValueError(f"Failed to update usage counters: {response.text}")

    return response.json()
205
346
 
206
347
  ## TODO: Should have a log endpoint, endpoint should also support batched payloads
207
348
  def save_annotation(self, annotation: TraceAnnotation):
@@ -324,35 +465,48 @@ class TraceClient:
324
465
  self.span_id_to_span: Dict[str, TraceSpan] = {}
325
466
  self.evaluation_runs: List[EvaluationRun] = []
326
467
  self.annotations: List[TraceAnnotation] = []
327
- self.start_time = time.time()
468
+ self.start_time = None # Will be set after first successful save
328
469
  self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id, tracer)
329
470
  self.visited_nodes = []
330
471
  self.executed_tools = []
331
472
  self.executed_node_tools = []
332
473
  self._span_depths: Dict[str, int] = {} # NEW: To track depth of active spans
474
+
475
+ # Get background span service from tracer
476
+ self.background_span_service = tracer.get_background_span_service() if tracer else None
333
477
 
334
478
  def get_current_span(self):
335
479
  """Get the current span from the context var"""
336
- return current_span_var.get()
480
+ return self.tracer.get_current_span()
337
481
 
338
482
  def set_current_span(self, span: Any):
339
483
  """Set the current span from the context var"""
340
- return current_span_var.set(span)
484
+ return self.tracer.set_current_span(span)
341
485
 
342
486
  def reset_current_span(self, token: Any):
343
487
  """Reset the current span from the context var"""
344
- return current_span_var.reset(token)
488
+ self.tracer.reset_current_span(token)
345
489
 
346
490
  @contextmanager
347
491
  def span(self, name: str, span_type: SpanType = "span"):
348
492
  """Context manager for creating a trace span, managing the current span via contextvars"""
493
+ is_first_span = len(self.trace_spans) == 0
494
+ if is_first_span:
495
+ try:
496
+ trace_id, server_response = self.save_with_rate_limiting(overwrite=self.overwrite, final_save=False)
497
+ # Set start_time after first successful save
498
+ if self.start_time is None:
499
+ self.start_time = time.time()
500
+ # Link will be shown by upsert_trace method
501
+ except Exception as e:
502
+ warnings.warn(f"Failed to save initial trace for live tracking: {e}")
349
503
  start_time = time.time()
350
504
 
351
505
  # Generate a unique ID for *this specific span invocation*
352
506
  span_id = str(uuid.uuid4())
353
507
 
354
- parent_span_id = current_span_var.get() # Get ID of the parent span from context var
355
- token = current_span_var.set(span_id) # Set *this* span's ID as the current one
508
+ parent_span_id = self.get_current_span() # Get ID of the parent span from context var
509
+ token = self.set_current_span(span_id) # Set *this* span's ID as the current one
356
510
 
357
511
  current_depth = 0
358
512
  if parent_span_id and parent_span_id in self._span_depths:
@@ -372,16 +526,27 @@ class TraceClient:
372
526
  )
373
527
  self.add_span(span)
374
528
 
529
+
530
+
531
+ # Queue span with initial state (input phase)
532
+ if self.background_span_service:
533
+ self.background_span_service.queue_span(span, span_state="input")
534
+
375
535
  try:
376
536
  yield self
377
537
  finally:
378
538
  duration = time.time() - start_time
379
539
  span.duration = duration
540
+
541
+ # Queue span with completed state (output phase)
542
+ if self.background_span_service:
543
+ self.background_span_service.queue_span(span, span_state="completed")
544
+
380
545
  # Clean up depth tracking for this span_id
381
546
  if span_id in self._span_depths:
382
547
  del self._span_depths[span_id]
383
548
  # Reset context var
384
- current_span_var.reset(token)
549
+ self.reset_current_span(token)
385
550
 
386
551
  def async_evaluate(
387
552
  self,
@@ -445,8 +610,7 @@ class TraceClient:
445
610
  # span_id_at_eval_call = current_span_var.get()
446
611
  # print(f"[TraceClient.async_evaluate] Captured span ID at eval call: {span_id_at_eval_call}")
447
612
  # Prioritize explicitly passed span_id, fallback to context var
448
- current_span_ctx_var = current_span_var.get()
449
- span_id_to_use = span_id if span_id is not None else current_span_ctx_var if current_span_ctx_var is not None else self.tracer.get_current_span()
613
+ span_id_to_use = span_id if span_id is not None else self.get_current_span()
450
614
  # print(f"[TraceClient.async_evaluate] Using span_id: {span_id_to_use}")
451
615
  # --- End Modification ---
452
616
 
@@ -469,6 +633,17 @@ class TraceClient:
469
633
  )
470
634
 
471
635
  self.add_eval_run(eval_run, start_time) # Pass start_time to record_evaluation
636
+
637
+ # Queue evaluation run through background service
638
+ if self.background_span_service and span_id_to_use:
639
+ # Get the current span data to avoid race conditions
640
+ current_span = self.span_id_to_span.get(span_id_to_use)
641
+ if current_span:
642
+ self.background_span_service.queue_evaluation_run(
643
+ eval_run,
644
+ span_id=span_id_to_use,
645
+ span_data=current_span
646
+ )
472
647
 
473
648
  def add_eval_run(self, eval_run: EvaluationRun, start_time: float):
474
649
  # --- Modification: Use span_id from eval_run ---
@@ -488,19 +663,27 @@ class TraceClient:
488
663
  return self
489
664
 
490
665
  def record_input(self, inputs: dict):
491
- current_span_id = current_span_var.get()
666
+ current_span_id = self.get_current_span()
492
667
  if current_span_id:
493
668
  span = self.span_id_to_span[current_span_id]
494
669
  # Ignore self parameter
495
670
  if "self" in inputs:
496
671
  del inputs["self"]
497
672
  span.inputs = inputs
673
+
674
+ # Queue span with input data
675
+ if self.background_span_service:
676
+ self.background_span_service.queue_span(span, span_state="input")
498
677
 
499
678
  def record_agent_name(self, agent_name: str):
500
- current_span_id = current_span_var.get()
679
+ current_span_id = self.get_current_span()
501
680
  if current_span_id:
502
681
  span = self.span_id_to_span[current_span_id]
503
682
  span.agent_name = agent_name
683
+
684
+ # Queue span with agent_name data
685
+ if self.background_span_service:
686
+ self.background_span_service.queue_span(span, span_state="agent_name")
504
687
 
505
688
  def record_state_before(self, state: dict):
506
689
  """Records the agent's state before a tool execution on the current span.
@@ -508,60 +691,91 @@ class TraceClient:
508
691
  Args:
509
692
  state: A dictionary representing the agent's state.
510
693
  """
511
- current_span_id = current_span_var.get()
694
+ current_span_id = self.get_current_span()
512
695
  if current_span_id:
513
696
  span = self.span_id_to_span[current_span_id]
514
697
  span.state_before = state
515
698
 
699
+ # Queue span with state_before data
700
+ if self.background_span_service:
701
+ self.background_span_service.queue_span(span, span_state="state_before")
702
+
516
703
  def record_state_after(self, state: dict):
517
704
  """Records the agent's state after a tool execution on the current span.
518
705
 
519
706
  Args:
520
707
  state: A dictionary representing the agent's state.
521
708
  """
522
- current_span_id = current_span_var.get()
709
+ current_span_id = self.get_current_span()
523
710
  if current_span_id:
524
711
  span = self.span_id_to_span[current_span_id]
525
712
  span.state_after = state
713
+
714
+ # Queue span with state_after data
715
+ if self.background_span_service:
716
+ self.background_span_service.queue_span(span, span_state="state_after")
526
717
 
527
718
  async def _update_coroutine(self, span: TraceSpan, coroutine: Any, field: str):
528
719
  """Helper method to update the output of a trace entry once the coroutine completes"""
529
720
  try:
530
721
  result = await coroutine
531
722
  setattr(span, field, result)
723
+
724
+ # Queue span with output data now that coroutine is complete
725
+ if self.background_span_service and field == "output":
726
+ self.background_span_service.queue_span(span, span_state="output")
727
+
532
728
  return result
533
729
  except Exception as e:
534
730
  setattr(span, field, f"Error: {str(e)}")
731
+
732
+ # Queue span even if there was an error
733
+ if self.background_span_service and field == "output":
734
+ self.background_span_service.queue_span(span, span_state="output")
735
+
535
736
  raise
536
737
 
537
738
  def record_output(self, output: Any):
538
- current_span_id = current_span_var.get()
739
+ current_span_id = self.get_current_span()
539
740
  if current_span_id:
540
741
  span = self.span_id_to_span[current_span_id]
541
742
  span.output = "<pending>" if inspect.iscoroutine(output) else output
542
743
 
543
744
  if inspect.iscoroutine(output):
544
745
  asyncio.create_task(self._update_coroutine(span, output, "output"))
746
+
747
+ # # Queue span with output data (unless it's pending)
748
+ if self.background_span_service and not inspect.iscoroutine(output):
749
+ self.background_span_service.queue_span(span, span_state="output")
545
750
 
546
751
  return span # Return the created entry
547
752
  # Removed else block - original didn't have one
548
753
  return None # Return None if no span_id found
549
754
 
550
755
  def record_usage(self, usage: TraceUsage):
551
- current_span_id = current_span_var.get()
756
+ current_span_id = self.get_current_span()
552
757
  if current_span_id:
553
758
  span = self.span_id_to_span[current_span_id]
554
759
  span.usage = usage
555
760
 
761
+ # Queue span with usage data
762
+ if self.background_span_service:
763
+ self.background_span_service.queue_span(span, span_state="usage")
764
+
556
765
  return span # Return the created entry
557
766
  # Removed else block - original didn't have one
558
767
  return None # Return None if no span_id found
559
768
 
560
769
  def record_error(self, error: Dict[str, Any]):
561
- current_span_id = current_span_var.get()
770
+ current_span_id = self.get_current_span()
562
771
  if current_span_id:
563
772
  span = self.span_id_to_span[current_span_id]
564
773
  span.error = error
774
+
775
+ # Queue span with error data
776
+ if self.background_span_service:
777
+ self.background_span_service.queue_span(span, span_state="error")
778
+
565
779
  return span
566
780
  return None
567
781
 
@@ -580,13 +794,19 @@ class TraceClient:
580
794
  """
581
795
  Get the total duration of this trace
582
796
  """
797
+ if self.start_time is None:
798
+ return 0.0 # No duration if trace hasn't been saved yet
583
799
  return time.time() - self.start_time
584
800
 
585
801
  def save(self, overwrite: bool = False) -> Tuple[str, dict]:
586
802
  """
587
803
  Save the current trace to the database.
588
- Returns a tuple of (trace_id, trace_data) where trace_data is the trace data that was saved.
804
+ Returns a tuple of (trace_id, server_response) where server_response contains the UI URL and other metadata.
589
805
  """
806
+ # Set start_time if this is the first save
807
+ if self.start_time is None:
808
+ self.start_time = time.time()
809
+
590
810
  # Calculate total elapsed time
591
811
  total_duration = self.get_duration()
592
812
  # Create trace document - Always use standard keys for top-level counts
@@ -594,7 +814,7 @@ class TraceClient:
594
814
  "trace_id": self.trace_id,
595
815
  "name": self.name,
596
816
  "project_name": self.project_name,
597
- "created_at": datetime.utcfromtimestamp(self.start_time).isoformat(),
817
+ "created_at": datetime.fromtimestamp(self.start_time, timezone.utc).isoformat(),
598
818
  "duration": total_duration,
599
819
  "trace_spans": [span.model_dump() for span in self.trace_spans],
600
820
  "evaluation_runs": [run.model_dump() for run in self.evaluation_runs],
@@ -604,14 +824,79 @@ class TraceClient:
604
824
  "parent_name": self.parent_name
605
825
  }
606
826
  # --- Log trace data before saving ---
607
- self.trace_manager_client.save_trace(trace_data, offline_mode=self.tracer.offline_mode)
827
+ server_response = self.trace_manager_client.save_trace(trace_data, offline_mode=self.tracer.offline_mode, final_save=True)
608
828
 
609
829
  # upload annotations
610
830
  # TODO: batch to the log endpoint
611
831
  for annotation in self.annotations:
612
832
  self.trace_manager_client.save_annotation(annotation)
613
833
 
614
- return self.trace_id, trace_data
834
+ return self.trace_id, server_response
835
+
836
def save_with_rate_limiting(self, overwrite: bool = False, final_save: bool = False) -> Tuple[str, dict]:
    """
    Save the current trace to the database with rate limiting checks.

    Builds the trace document, checks usage limits, upserts the trace and,
    on the final save only, updates the usage counters. Annotations are
    uploaded on every call.

    Args:
        overwrite: Value sent as the document's "overwrite" field
        final_save: Whether this is the final save (updates usage counters,
            is forwarded to upsert_trace, and suppresses the UI link)

    Returns:
        A tuple of (trace_id, server_response) where server_response is the
        value returned by upsert_trace.

    Raises:
        ValueError: Re-raised from check_usage_limits after a warning is issued;
            also raised by upsert_trace when the upsert fails.
    """
    # Create trace document
    trace_data = {
        "trace_id": self.trace_id,
        "name": self.name,
        "project_name": self.project_name,
        # Timezone-aware UTC timestamp, the same representation save() emits.
        # datetime.utcfromtimestamp() is deprecated since Python 3.12 and
        # returns a naive datetime.
        "created_at": datetime.now(timezone.utc).isoformat(),
        "duration": self.get_duration(),
        "trace_spans": [span.model_dump() for span in self.trace_spans],
        "evaluation_runs": [run.model_dump() for run in self.evaluation_runs],
        "overwrite": overwrite,
        "offline_mode": self.tracer.offline_mode,
        "parent_trace_id": self.parent_trace_id,
        "parent_name": self.parent_name
    }

    # Check usage limits first; a passing check is silent.
    try:
        self.trace_manager_client.check_usage_limits(count=1)
    except ValueError as e:
        # Rate limit exceeded (or the check itself failed): warn, then let the
        # caller decide how to proceed.
        warnings.warn(f"Rate limit check failed for live tracing: {e}")
        raise

    # If usage check passes, upsert the trace
    server_response = self.trace_manager_client.upsert_trace(
        trace_data,
        offline_mode=self.tracer.offline_mode,
        show_link=not final_save,  # Show link only on initial save, not final save
        final_save=final_save  # Pass final_save to control S3 saving
    )

    # Update usage counters only on final save
    if final_save:
        try:
            self.trace_manager_client.update_usage_counters(count=1)
        except ValueError as e:
            # Warn but don't fail: the trace itself was already saved
            warnings.warn(f"Usage counter update failed (trace was still saved): {e}")

    # Upload annotations
    # TODO: batch to the log endpoint
    for annotation in self.annotations:
        self.trace_manager_client.save_annotation(annotation)

    # The trace clock starts once the first save has gone through.
    if self.start_time is None:
        self.start_time = time.time()
    return self.trace_id, server_response
615
900
 
616
901
  def delete(self):
617
902
  return self.trace_manager_client.delete_trace(self.trace_id)
@@ -648,6 +933,338 @@ def _capture_exception_for_trace(current_trace: Optional['TraceClient'], exc_inf
648
933
  pass
649
934
 
650
935
  current_trace.record_error(formatted_exception)
936
+
937
+ # Queue the span with error state through background service
938
+ if current_trace.background_span_service:
939
+ current_span_id = current_trace.get_current_span()
940
+ if current_span_id and current_span_id in current_trace.span_id_to_span:
941
+ error_span = current_trace.span_id_to_span[current_span_id]
942
+ current_trace.background_span_service.queue_span(error_span, span_state="error")
943
+
944
class BackgroundSpanService:
    """
    Background service for queueing and batching trace spans for efficient saving.

    This service:
    - Queues spans and evaluation runs as they are reported
    - Batches them for efficient network usage
    - Sends a batch when it reaches batch_size items or when flush_interval
      seconds have passed since the worker's previous flush
    - Registers shutdown() with atexit so queued items are flushed on exit

    Delivery is best effort: a failed send is reported with warnings.warn and
    the affected items are dropped, not retried.
    """

    def __init__(self, judgment_api_key: str, organization_id: str,
                 batch_size: int = 10, flush_interval: float = 5.0, num_workers: int = 1):
        """
        Initialize the background span service and start its worker threads.

        Args:
            judgment_api_key: API key for Judgment service
            organization_id: Organization ID
            batch_size: Number of spans to batch before sending (default: 10)
            flush_interval: Time in seconds between automatic flushes (default: 5.0)
            num_workers: Number of worker threads to process the queue (default: 1)
        """
        self.judgment_api_key = judgment_api_key
        self.organization_id = organization_id
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self.num_workers = max(1, num_workers)  # Ensure at least 1 worker

        # Queue for pending spans and evaluation runs
        self._span_queue = queue.Queue()

        # Background threads for processing spans
        self._worker_threads = []
        self._shutdown_event = threading.Event()

        # Kept for compatibility; nothing in this class reads or updates it.
        self._sent_spans = set()

        # Register cleanup on exit
        atexit.register(self.shutdown)

        # Start the background workers
        self._start_workers()

    @staticmethod
    def _fallback_encoder(obj):
        """json.dumps `default` hook: repr(obj), else str(obj), else a placeholder string."""
        try:
            return repr(obj)
        except Exception:
            try:
                return str(obj)
            except Exception as e:
                return f"<Unserializable object of type {type(obj).__name__}: {e}>"

    def _start_workers(self):
        """Start daemon worker threads until num_workers of them are registered."""
        for index in range(len(self._worker_threads), self.num_workers):
            worker_thread = threading.Thread(
                target=self._worker_loop,
                daemon=True,
                name=f"SpanWorker-{index + 1}"
            )
            worker_thread.start()
            self._worker_threads.append(worker_thread)

    def _worker_loop(self):
        """
        Main worker loop: collect queued items into a batch and send it.

        Runs until shutdown has been signalled and the queue is empty. Every
        item taken from the queue is marked done once its batch has been
        handed to _send_batch, whether or not the send succeeded, so that
        flush() can never wait forever on a failed batch.
        """
        batch = []
        last_flush_time = time.time()
        pending_task_count = 0  # Items taken from the queue but not yet marked done

        while not self._shutdown_event.is_set() or self._span_queue.qsize() > 0:
            try:
                # Always wait on the queue rather than spinning. With a partial
                # batch in hand, wait only until its flush deadline. The wait
                # is capped at one second so a shutdown request is noticed
                # promptly, and is skipped entirely once shutdown is under way.
                if self._shutdown_event.is_set():
                    wait_time = 0.0
                elif batch:
                    elapsed = time.time() - last_flush_time
                    wait_time = min(1.0, self.flush_interval - elapsed)
                else:
                    wait_time = 1.0

                try:
                    if wait_time > 0:
                        item = self._span_queue.get(timeout=wait_time)
                    else:
                        item = self._span_queue.get_nowait()
                    batch.append(item)
                    pending_task_count += 1
                except queue.Empty:
                    # Nothing new; fall through to the flush check
                    pass

                # Drain whatever else is immediately available, up to batch_size
                while len(batch) < self.batch_size:
                    try:
                        batch.append(self._span_queue.get_nowait())
                        pending_task_count += 1
                    except queue.Empty:
                        break

                current_time = time.time()
                should_flush = bool(batch) and (
                    len(batch) >= self.batch_size
                    or (current_time - last_flush_time) >= self.flush_interval
                    # On shutdown, send what we have instead of waiting for a full batch
                    or self._shutdown_event.is_set()
                )

                if should_flush:
                    self._send_batch(batch)

                    for _ in range(pending_task_count):
                        self._span_queue.task_done()
                    pending_task_count = 0

                    batch.clear()
                    last_flush_time = current_time

            except Exception as e:
                warnings.warn(f"Error in span service worker loop: {e}")
                # On error, still need to mark tasks as done to prevent hanging
                for _ in range(pending_task_count):
                    self._span_queue.task_done()
                pending_task_count = 0
                batch.clear()

        # Final flush on shutdown
        if batch:
            self._send_batch(batch)
            # Mark remaining tasks as done
            for _ in range(pending_task_count):
                self._span_queue.task_done()

    def _send_batch(self, batch: List[Dict[str, Any]]):
        """
        Send a batch of queued items to the server.

        Items are split by their 'type' field: 'span' items go to the spans
        endpoint, 'evaluation_run' items to the evaluation-runs endpoint.
        Items of any other type are ignored. Errors are reported as warnings.

        Args:
            batch: List of queued item dictionaries ({'type': ..., 'data': ...})
        """
        if not batch:
            return

        try:
            spans_to_send = [item['data'] for item in batch if item['type'] == 'span']
            evaluation_runs_to_send = [
                item['data'] for item in batch if item['type'] == 'evaluation_run'
            ]

            # Send spans if any
            if spans_to_send:
                self._send_spans_batch(spans_to_send)

            # Send evaluation runs if any
            if evaluation_runs_to_send:
                self._send_evaluation_runs_batch(evaluation_runs_to_send)

        except Exception as e:
            warnings.warn(f"Failed to send span batch: {e}")

    def _send_spans_batch(self, spans: List[Dict[str, Any]]):
        """Send a batch of spans to the spans endpoint; failures become warnings."""
        payload = {
            "spans": spans,
            "organization_id": self.organization_id
        }

        try:
            serialized_data = json.dumps(payload, default=self._fallback_encoder)

            # Send the actual HTTP request to the batch endpoint
            response = requests.post(
                JUDGMENT_TRACES_SPANS_BATCH_API_URL,
                data=serialized_data,
                headers={
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {self.judgment_api_key}",
                    "X-Organization-Id": self.organization_id
                },
                verify=True,
                timeout=30  # Add timeout to prevent hanging
            )

            if response.status_code != HTTPStatus.OK:
                warnings.warn(f"Failed to send spans batch: HTTP {response.status_code} - {response.text}")

        except requests.RequestException as e:
            warnings.warn(f"Network error sending spans batch: {e}")
        except Exception as e:
            warnings.warn(f"Failed to serialize or send spans batch: {e}")

    def _send_evaluation_runs_batch(self, evaluation_runs: List[Dict[str, Any]]):
        """Send a batch of evaluation runs with their associated span data; failures become warnings."""
        # Each entry separates the evaluation run's own fields from the
        # span-specific fields that queue_evaluation_run merged into it.
        span_specific_keys = ('associated_span_id', 'span_data', 'queued_at')
        evaluation_entries = [
            {
                "evaluation_run": {
                    key: value for key, value in eval_data.items()
                    if key not in span_specific_keys
                },
                "associated_span": {
                    "span_id": eval_data.get('associated_span_id'),
                    "span_data": eval_data.get('span_data')
                },
                "queued_at": eval_data.get('queued_at')
            }
            for eval_data in evaluation_runs
        ]

        payload = {
            "organization_id": self.organization_id,
            "evaluation_entries": evaluation_entries  # Each entry contains both eval run + span data
        }

        try:
            serialized_data = json.dumps(payload, default=self._fallback_encoder)

            # Send the actual HTTP request to the batch endpoint
            response = requests.post(
                JUDGMENT_TRACES_EVALUATION_RUNS_BATCH_API_URL,
                data=serialized_data,
                headers={
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {self.judgment_api_key}",
                    "X-Organization-Id": self.organization_id
                },
                verify=True,
                timeout=30  # Add timeout to prevent hanging
            )

            if response.status_code != HTTPStatus.OK:
                warnings.warn(f"Failed to send evaluation runs batch: HTTP {response.status_code} - {response.text}")

        except requests.RequestException as e:
            warnings.warn(f"Network error sending evaluation runs batch: {e}")
        except Exception as e:
            warnings.warn(f"Failed to send evaluation runs batch: {e}")

    def queue_span(self, span: "TraceSpan", span_state: str = "input"):
        """
        Queue a snapshot of a span for background sending.

        The span is dumped with model_dump() at call time, so later changes to
        the span object do not affect the queued item. Calls made after
        shutdown has been signalled are ignored.

        Args:
            span: The TraceSpan object to queue
            span_state: Label stored with the snapshot (e.g. "input", "output", "completed")
        """
        if not self._shutdown_event.is_set():
            span_data = {
                "type": "span",
                "data": {
                    **span.model_dump(),
                    "span_state": span_state,
                    "queued_at": time.time()
                }
            }
            self._span_queue.put(span_data)

    def queue_evaluation_run(self, evaluation_run: "EvaluationRun", span_id: str, span_data: "TraceSpan"):
        """
        Queue an evaluation run for background sending.

        Calls made after shutdown has been signalled are ignored.

        Args:
            evaluation_run: The EvaluationRun object to queue
            span_id: The span ID associated with this evaluation run
            span_data: The span at the time of evaluation; it is dumped
                immediately so the queued item holds a snapshot
        """
        if not self._shutdown_event.is_set():
            eval_data = {
                "type": "evaluation_run",
                "data": {
                    **evaluation_run.model_dump(),
                    "associated_span_id": span_id,
                    "span_data": span_data.model_dump(),  # Snapshot to avoid race conditions
                    "queued_at": time.time()
                }
            }
            self._span_queue.put(eval_data)

    def flush(self):
        """
        Block until every queued item has been processed by a worker.

        This waits on the queue; it does not itself trigger a send. Workers
        send on their own schedule, so the call can take up to flush_interval
        seconds while the service is running.
        """
        try:
            # Wait for the queue to be processed
            self._span_queue.join()
        except Exception as e:
            warnings.warn(f"Error during flush: {e}")

    def shutdown(self):
        """Shutdown the background service and flush remaining spans. Safe to call more than once."""
        if self._shutdown_event.is_set():
            return

        try:
            # Signal shutdown to stop new items from being queued
            self._shutdown_event.set()

            # Try to flush any remaining spans
            try:
                self.flush()
            except Exception as e:
                warnings.warn(f"Error during final flush: {e}")
        except Exception as e:
            warnings.warn(f"Error during BackgroundSpanService shutdown: {e}")
        finally:
            # Clear the worker threads list (daemon threads will be killed automatically)
            self._worker_threads.clear()

    def get_queue_size(self) -> int:
        """Get the current size of the span queue."""
        return self._span_queue.qsize()
1267
+
651
1268
  class _DeepTracer:
652
1269
  _instance: Optional["_DeepTracer"] = None
653
1270
  _lock: threading.Lock = threading.Lock()
@@ -657,6 +1274,9 @@ class _DeepTracer:
657
1274
  _original_sys_trace: Optional[Callable] = None
658
1275
  _original_threading_trace: Optional[Callable] = None
659
1276
 
1277
+ def __init__(self, tracer: 'Tracer'):
1278
+ self._tracer = tracer
1279
+
660
1280
  def _get_qual_name(self, frame) -> str:
661
1281
  func_name = frame.f_code.co_name
662
1282
  module_name = frame.f_globals.get("__name__", "unknown_module")
@@ -670,7 +1290,7 @@ class _DeepTracer:
670
1290
  except Exception:
671
1291
  return f"{module_name}.{func_name}"
672
1292
 
673
- def __new__(cls):
1293
+ def __new__(cls, tracer: 'Tracer' = None):
674
1294
  with cls._lock:
675
1295
  if cls._instance is None:
676
1296
  cls._instance = super().__new__(cls)
@@ -756,11 +1376,11 @@ class _DeepTracer:
756
1376
  if event not in ("call", "return", "exception"):
757
1377
  return
758
1378
 
759
- current_trace = current_trace_var.get()
1379
+ current_trace = self._tracer.get_current_trace()
760
1380
  if not current_trace:
761
1381
  return
762
1382
 
763
- parent_span_id = current_span_var.get()
1383
+ parent_span_id = self._tracer.get_current_span()
764
1384
  if not parent_span_id:
765
1385
  return
766
1386
 
@@ -822,7 +1442,7 @@ class _DeepTracer:
822
1442
  })
823
1443
  self._span_stack.set(span_stack)
824
1444
 
825
- token = current_span_var.set(span_id)
1445
+ token = self._tracer.set_current_span(span_id)
826
1446
  frame.f_locals["_judgment_span_token"] = token
827
1447
 
828
1448
  span = TraceSpan(
@@ -856,7 +1476,7 @@ class _DeepTracer:
856
1476
  if not span_stack:
857
1477
  return
858
1478
 
859
- current_id = current_span_var.get()
1479
+ current_id = self._tracer.get_current_span()
860
1480
 
861
1481
  span_data = None
862
1482
  for i, entry in enumerate(reversed(span_stack)):
@@ -881,12 +1501,12 @@ class _DeepTracer:
881
1501
  del current_trace._span_depths[span_data["span_id"]]
882
1502
 
883
1503
  if span_stack:
884
- current_span_var.set(span_stack[-1]["span_id"])
1504
+ self._tracer.set_current_span(span_stack[-1]["span_id"])
885
1505
  else:
886
- current_span_var.set(span_data["parent_span_id"])
1506
+ self._tracer.set_current_span(span_data["parent_span_id"])
887
1507
 
888
1508
  if "_judgment_span_token" in frame.f_locals:
889
- current_span_var.reset(frame.f_locals["_judgment_span_token"])
1509
+ self._tracer.reset_current_span(frame.f_locals["_judgment_span_token"])
890
1510
 
891
1511
  elif event == "exception":
892
1512
  exc_type = arg[0]
@@ -925,18 +1545,28 @@ class _DeepTracer:
925
1545
  self._original_threading_trace = None
926
1546
 
927
1547
 
928
- def log(self, message: str, level: str = "info"):
929
- """ Log a message with the span context """
930
- current_trace = current_trace_var.get()
931
- if current_trace:
932
- current_trace.log(message, level)
933
- else:
934
- print(f"[{level}] {message}")
935
- current_trace.record_output({"log": message})
1548
+ # NOTE: the commented-out log() method below no longer appears to be used.
1549
+
1550
+ # def log(self, message: str, level: str = "info"):
1551
+ # """ Log a message with the span context """
1552
+ # current_trace = self._tracer.get_current_trace()
1553
+ # if current_trace:
1554
+ # current_trace.log(message, level)
1555
+ # else:
1556
+ # print(f"[{level}] {message}")
1557
+ # current_trace.record_output({"log": message})
936
1558
 
937
1559
  class Tracer:
938
1560
  _instance = None
939
1561
 
1562
+ # Tracer.current_trace class variable is currently used in wrap()
1563
+ # TODO: Keep track of cross-context state for current trace and current span ID solely through class variables instead of instance variables?
1564
+ # Should be fine to do so as long as we keep Tracer as a singleton
1565
+ current_trace: Optional[TraceClient] = None
1566
+ # current_span_id: Optional[str] = None
1567
+
1568
+ trace_across_async_contexts: bool = False # BY default, we don't trace across async contexts
1569
+
940
1570
  def __new__(cls, *args, **kwargs):
941
1571
  if cls._instance is None:
942
1572
  cls._instance = super(Tracer, cls).__new__(cls)
@@ -957,7 +1587,13 @@ class Tracer:
957
1587
  s3_aws_secret_access_key: Optional[str] = None,
958
1588
  s3_region_name: Optional[str] = None,
959
1589
  offline_mode: bool = False,
960
- deep_tracing: bool = True # Deep tracing is enabled by default
1590
+ deep_tracing: bool = True, # Deep tracing is enabled by default
1591
+ trace_across_async_contexts: bool = False, # BY default, we don't trace across async contexts
1592
+ # Background span service configuration
1593
+ enable_background_spans: bool = True, # Enable background span service by default
1594
+ span_batch_size: int = 50, # Number of spans to batch before sending
1595
+ span_flush_interval: float = 1.0, # Time in seconds between automatic flushes
1596
+ span_num_workers: int = 10 # Number of worker threads for span processing
961
1597
  ):
962
1598
  if not hasattr(self, 'initialized'):
963
1599
  if not api_key:
@@ -975,14 +1611,18 @@ class Tracer:
975
1611
  self.api_key: str = api_key
976
1612
  self.project_name: str = project_name or str(uuid.uuid4())
977
1613
  self.organization_id: str = organization_id
978
- self._current_trace: Optional[str] = None
979
- self._active_trace_client: Optional[TraceClient] = None # Add active trace client attribute
980
1614
  self.rules: List[Rule] = rules or [] # Store rules at tracer level
981
1615
  self.traces: List[Trace] = []
982
1616
  self.initialized: bool = True
983
1617
  self.enable_monitoring: bool = enable_monitoring
984
1618
  self.enable_evaluations: bool = enable_evaluations
985
1619
  self.class_identifiers: Dict[str, str] = {} # Dictionary to store class identifiers
1620
+ self.span_id_to_previous_span_id: Dict[str, str] = {}
1621
+ self.trace_id_to_previous_trace: Dict[str, TraceClient] = {}
1622
+ self.current_span_id: Optional[str] = None
1623
+ self.current_trace: Optional[TraceClient] = None
1624
+ self.trace_across_async_contexts: bool = trace_across_async_contexts
1625
+ Tracer.trace_across_async_contexts = trace_across_async_contexts
986
1626
 
987
1627
  # Initialize S3 storage if enabled
988
1628
  self.use_s3 = use_s3
@@ -996,6 +1636,18 @@ class Tracer:
996
1636
  )
997
1637
  self.offline_mode: bool = offline_mode
998
1638
  self.deep_tracing: bool = deep_tracing # NEW: Store deep tracing setting
1639
+
1640
+ # Initialize background span service
1641
+ self.enable_background_spans: bool = enable_background_spans
1642
+ self.background_span_service: Optional[BackgroundSpanService] = None
1643
+ if enable_background_spans and not offline_mode:
1644
+ self.background_span_service = BackgroundSpanService(
1645
+ judgment_api_key=api_key,
1646
+ organization_id=organization_id,
1647
+ batch_size=span_batch_size,
1648
+ flush_interval=span_flush_interval,
1649
+ num_workers=span_num_workers
1650
+ )
999
1651
 
1000
1652
  elif hasattr(self, 'project_name') and self.project_name != project_name:
1001
1653
  warnings.warn(
@@ -1006,16 +1658,44 @@ class Tracer:
1006
1658
  )
1007
1659
 
1008
1660
def set_current_span(self, span_id: str):
    """
    Mark ``span_id`` as the current span and return the context-variable token.

    The span that was current before this call is remembered under ``span_id``
    in ``span_id_to_previous_span_id`` so that ``reset_current_span`` can
    restore it. The instance attribute, the ``Tracer.current_span_id`` class
    attribute and ``current_span_var`` are all updated.

    Args:
        span_id: ID of the span that becomes current.

    Returns:
        The token returned by ``current_span_var.set``, or ``None`` if setting
        the context variable failed.
    """
    self.span_id_to_previous_span_id[span_id] = getattr(self, 'current_span_id', None)
    self.current_span_id = span_id
    Tracer.current_span_id = span_id
    try:
        token = current_span_var.set(span_id)
    except Exception:
        # Best effort only. Catching Exception (not a bare except) lets
        # KeyboardInterrupt and SystemExit propagate.
        token = None
    return token
1010
1669
 
1011
1670
def get_current_span(self) -> Optional[str]:
    """
    Return the ID of the current span, or ``None`` if there is none.

    With ``trace_across_async_contexts`` enabled, the instance-level
    ``current_span_id`` takes precedence and the context variable is only a
    fallback. Otherwise only ``current_span_var`` is consulted.
    """
    try:
        current_span_var_val = current_span_var.get()
    except Exception:
        # Treat a failed lookup as "no span", without swallowing
        # KeyboardInterrupt/SystemExit the way a bare except would.
        current_span_var_val = None

    if self.trace_across_async_contexts:
        return self.current_span_id or current_span_var_val
    return current_span_var_val
1676
+
1677
def reset_current_span(self, token: Optional[contextvars.Token] = None, span_id: Optional[str] = None):
    """
    Restore the span that was current before ``span_id`` was set.

    Args:
        token: Token previously returned by ``set_current_span``. When given,
            ``current_span_var`` is reset with it on a best-effort basis.
        span_id: ID of the span being left; defaults to the instance's
            ``current_span_id``.
    """
    if not span_id:
        span_id = self.current_span_id

    if token is not None:
        try:
            current_span_var.reset(token)
        except Exception:
            # Deliberately best effort: a token that is stale or belongs to
            # another context must not break tracing. The instance and class
            # state below is still restored.
            pass

    self.current_span_id = self.span_id_to_previous_span_id.get(span_id)
    Tracer.current_span_id = self.current_span_id
1013
1686
 
1014
1687
def set_current_trace(self, trace: TraceClient):
    """
    Make ``trace`` the current trace and return the context-variable token.

    The trace that was current before this call is remembered under
    ``trace.trace_id`` in ``trace_id_to_previous_trace`` so that
    ``reset_current_trace`` can restore it. The instance attribute, the
    ``Tracer.current_trace`` class attribute and ``current_trace_var`` are all
    updated.

    Args:
        trace: The trace client that becomes current.

    Returns:
        The token returned by ``current_trace_var.set``, or ``None`` if setting
        the context variable failed.
    """
    self.trace_id_to_previous_trace[trace.trace_id] = getattr(self, 'current_trace', None)
    self.current_trace = trace
    Tracer.current_trace = trace
    try:
        token = current_trace_var.set(trace)
    except Exception:
        # Best effort only. Catching Exception (not a bare except) lets
        # KeyboardInterrupt and SystemExit propagate.
        token = None
    return token
1019
1699
 
1020
1700
  def get_current_trace(self) -> Optional[TraceClient]:
1021
1701
  """
@@ -1025,23 +1705,34 @@ class Tracer:
1025
1705
  If not found (e.g., context lost across threads/tasks),
1026
1706
  it falls back to the active trace client managed by the callback handler.
1027
1707
  """
1028
- trace_from_context = current_trace_var.get()
1029
- if trace_from_context:
1030
- return trace_from_context
1708
+ try:
1709
+ current_trace_var_val = current_trace_var.get()
1710
+ except:
1711
+ current_trace_var_val = None
1712
+
1713
+ # Use context variable or class variable based on trace_across_async_contexts setting
1714
+ context_trace = (self.current_trace or current_trace_var_val) if self.trace_across_async_contexts else current_trace_var_val
1715
+
1716
+ # If we found a trace from context, return it
1717
+ if context_trace:
1718
+ return context_trace
1031
1719
 
1032
- # Fallback: Check the active client potentially set by a callback handler
1720
+ # Fallback: Check the active client potentially set by a callback handler (e.g., LangGraph)
1033
1721
  if hasattr(self, '_active_trace_client') and self._active_trace_client:
1034
- # warnings.warn("Falling back to _active_trace_client in get_current_trace. ContextVar might be lost.", RuntimeWarning)
1035
1722
  return self._active_trace_client
1036
1723
 
1037
- # If neither is available
1038
- # warnings.warn("No current trace found in context variable or active client fallback.", RuntimeWarning)
1724
+ # If neither is available, return None
1039
1725
  return None
1040
-
1041
- def get_active_trace_client(self) -> Optional[TraceClient]:
1042
- """Returns the TraceClient instance currently marked as active by the handler."""
1043
- return self._active_trace_client
1044
-
1726
+
1727
def reset_current_trace(self, token: Optional[contextvars.Token] = None, trace_id: Optional[str] = None):
    """
    Restore the trace that was current before ``trace_id`` was set.

    Args:
        token: Token previously returned by ``set_current_trace``. When given,
            ``current_trace_var`` is reset with it on a best-effort basis.
        trace_id: ID of the trace being left; defaults to the ID of the
            instance's ``current_trace`` when one is set.
    """
    if not trace_id and self.current_trace:
        trace_id = self.current_trace.trace_id

    if token is not None:
        try:
            current_trace_var.reset(token)
        except Exception:
            # Deliberately best effort: a token that is stale or belongs to
            # another context must not break tracing. The instance and class
            # state below is still restored.
            pass

    self.current_trace = self.trace_id_to_previous_trace.get(trace_id)
    Tracer.current_trace = self.current_trace
1045
1736
 
1046
1737
  @contextmanager
1047
1738
  def trace(
@@ -1056,7 +1747,7 @@ class Tracer:
1056
1747
  project = project_name if project_name is not None else self.project_name
1057
1748
 
1058
1749
  # Get parent trace info from context
1059
- parent_trace = current_trace_var.get()
1750
+ parent_trace = self.get_current_trace()
1060
1751
  parent_trace_id = None
1061
1752
  parent_name = None
1062
1753
 
@@ -1078,7 +1769,7 @@ class Tracer:
1078
1769
  )
1079
1770
 
1080
1771
  # Set the current trace in context variables
1081
- token = current_trace_var.set(trace)
1772
+ token = self.set_current_trace(trace)
1082
1773
 
1083
1774
  # Automatically create top-level span
1084
1775
  with trace.span(name or "unnamed_trace") as span:
@@ -1087,13 +1778,13 @@ class Tracer:
1087
1778
  yield trace
1088
1779
  finally:
1089
1780
  # Reset the context variable
1090
- current_trace_var.reset(token)
1781
+ self.reset_current_trace(token)
1091
1782
 
1092
1783
 
1093
1784
  def log(self, msg: str, label: str = "log", score: int = 1):
1094
1785
  """Log a message with the current span context"""
1095
- current_span_id = current_span_var.get()
1096
- current_trace = current_trace_var.get()
1786
+ current_span_id = self.get_current_span()
1787
+ current_trace = self.get_current_trace()
1097
1788
  if current_span_id:
1098
1789
  annotation = TraceAnnotation(
1099
1790
  span_id=current_span_id,
@@ -1237,7 +1928,7 @@ class Tracer:
1237
1928
  agent_name = get_instance_prefixed_name(args[0], class_name, self.class_identifiers)
1238
1929
 
1239
1930
  # Get current trace from context
1240
- current_trace = current_trace_var.get()
1931
+ current_trace = self.get_current_trace()
1241
1932
 
1242
1933
  # If there's no current trace, create a root trace
1243
1934
  if not current_trace:
@@ -1258,7 +1949,7 @@ class Tracer:
1258
1949
 
1259
1950
  # Save empty trace and set trace context
1260
1951
  # current_trace.save(empty_save=True, overwrite=overwrite)
1261
- trace_token = current_trace_var.set(current_trace)
1952
+ trace_token = self.set_current_trace(current_trace)
1262
1953
 
1263
1954
  try:
1264
1955
  # Use span for the function execution within the root trace
@@ -1274,7 +1965,7 @@ class Tracer:
1274
1965
  self._conditionally_capture_and_record_state(span, args, is_before=True)
1275
1966
 
1276
1967
  if use_deep_tracing:
1277
- with _DeepTracer():
1968
+ with _DeepTracer(self):
1278
1969
  result = await func(*args, **kwargs)
1279
1970
  else:
1280
1971
  try:
@@ -1290,12 +1981,31 @@ class Tracer:
1290
1981
  span.record_output(result)
1291
1982
  return result
1292
1983
  finally:
1984
+ # Flush background spans before saving the trace
1985
+
1986
+ complete_trace_data = {
1987
+ "trace_id": current_trace.trace_id,
1988
+ "name": current_trace.name,
1989
+ "created_at": datetime.utcfromtimestamp(current_trace.start_time).isoformat(),
1990
+ "duration": current_trace.get_duration(),
1991
+ "trace_spans": [span.model_dump() for span in current_trace.trace_spans],
1992
+ "overwrite": overwrite,
1993
+ "offline_mode": self.offline_mode,
1994
+ "parent_trace_id": current_trace.parent_trace_id,
1995
+ "parent_name": current_trace.parent_name
1996
+ }
1293
1997
  # Save the completed trace
1294
- trace_id, trace = current_trace.save(overwrite=overwrite)
1295
- self.traces.append(trace)
1998
+ trace_id, server_response = current_trace.save_with_rate_limiting(overwrite=overwrite, final_save=True)
1999
+
2000
+ # Store the complete trace data instead of just server response
2001
+
2002
+ self.traces.append(complete_trace_data)
2003
+
2004
+ # if self.background_span_service:
2005
+ # self.background_span_service.flush()
1296
2006
 
1297
2007
  # Reset trace context (span context resets automatically)
1298
- current_trace_var.reset(trace_token)
2008
+ self.reset_current_trace(trace_token)
1299
2009
  else:
1300
2010
  with current_trace.span(span_name, span_type=span_type) as span:
1301
2011
  inputs = combine_args_kwargs(func, args, kwargs)
@@ -1307,7 +2017,7 @@ class Tracer:
1307
2017
  self._conditionally_capture_and_record_state(span, args, is_before=True)
1308
2018
 
1309
2019
  if use_deep_tracing:
1310
- with _DeepTracer():
2020
+ with _DeepTracer(self):
1311
2021
  result = await func(*args, **kwargs)
1312
2022
  else:
1313
2023
  try:
@@ -1336,7 +2046,7 @@ class Tracer:
1336
2046
  class_name = args[0].__class__.__name__
1337
2047
  agent_name = get_instance_prefixed_name(args[0], class_name, self.class_identifiers)
1338
2048
  # Get current trace from context
1339
- current_trace = current_trace_var.get()
2049
+ current_trace = self.get_current_trace()
1340
2050
 
1341
2051
  # If there's no current trace, create a root trace
1342
2052
  if not current_trace:
@@ -1357,7 +2067,7 @@ class Tracer:
1357
2067
 
1358
2068
  # Save empty trace and set trace context
1359
2069
  # current_trace.save(empty_save=True, overwrite=overwrite)
1360
- trace_token = current_trace_var.set(current_trace)
2070
+ trace_token = self.set_current_trace(current_trace)
1361
2071
 
1362
2072
  try:
1363
2073
  # Use span for the function execution within the root trace
@@ -1372,7 +2082,7 @@ class Tracer:
1372
2082
  self._conditionally_capture_and_record_state(span, args, is_before=True)
1373
2083
 
1374
2084
  if use_deep_tracing:
1375
- with _DeepTracer():
2085
+ with _DeepTracer(self):
1376
2086
  result = func(*args, **kwargs)
1377
2087
  else:
1378
2088
  try:
@@ -1389,12 +2099,27 @@ class Tracer:
1389
2099
  span.record_output(result)
1390
2100
  return result
1391
2101
  finally:
1392
- # Save the completed trace
1393
- trace_id, trace = current_trace.save(overwrite=overwrite)
1394
- self.traces.append(trace)
2102
+ # Flush background spans before saving the trace
1395
2103
 
2104
+
2105
+ # Save the completed trace
2106
+ trace_id, server_response = current_trace.save_with_rate_limiting(overwrite=overwrite, final_save=True)
2107
+
2108
+ # Store the complete trace data instead of just server response
2109
+ complete_trace_data = {
2110
+ "trace_id": current_trace.trace_id,
2111
+ "name": current_trace.name,
2112
+ "created_at": datetime.utcfromtimestamp(current_trace.start_time).isoformat(),
2113
+ "duration": current_trace.get_duration(),
2114
+ "trace_spans": [span.model_dump() for span in current_trace.trace_spans],
2115
+ "overwrite": overwrite,
2116
+ "offline_mode": self.offline_mode,
2117
+ "parent_trace_id": current_trace.parent_trace_id,
2118
+ "parent_name": current_trace.parent_name
2119
+ }
2120
+ self.traces.append(complete_trace_data)
1396
2121
  # Reset trace context (span context resets automatically)
1397
- current_trace_var.reset(trace_token)
2122
+ self.reset_current_trace(trace_token)
1398
2123
  else:
1399
2124
  with current_trace.span(span_name, span_type=span_type) as span:
1400
2125
 
@@ -1407,7 +2132,7 @@ class Tracer:
1407
2132
  self._conditionally_capture_and_record_state(span, args, is_before=True)
1408
2133
 
1409
2134
  if use_deep_tracing:
1410
- with _DeepTracer():
2135
+ with _DeepTracer(self):
1411
2136
  result = func(*args, **kwargs)
1412
2137
  else:
1413
2138
  try:
@@ -1424,6 +2149,55 @@ class Tracer:
1424
2149
 
1425
2150
  return wrapper
1426
2151
 
2152
def observe_tools(self, cls=None, *, exclude_methods: Optional[List[str]] = None,
                  include_private: bool = False, warn_on_double_decoration: bool = True):
    """
    Automatically adds @observe(span_type="tool") to all methods in a class.

    Can be used bare (``@tracer.observe_tools``) or with keyword arguments
    (``@tracer.observe_tools(include_private=True)``). Every callable attribute
    listed by ``dir(cls)`` is wrapped and re-bound on ``cls``. ``staticmethod``
    and ``classmethod`` attributes are wrapped via their underlying function
    and put back inside the same descriptor, so they still work when called
    through an instance.

    Args:
        cls: The class to decorate (automatically provided when used as decorator)
        exclude_methods: List of method names to skip decorating. Defaults to common magic methods
        include_private: Whether to decorate methods starting with underscore. Defaults to False
        warn_on_double_decoration: Whether to print warnings when skipping already-decorated
            methods (also gates the warning printed when decorating fails). Defaults to True

    Returns:
        The class itself (modified in place) when ``cls`` is given, otherwise a
        decorator that takes the class. The class is returned untouched when
        monitoring is disabled.
    """
    import inspect  # local import: only needed to look at raw class attributes

    if exclude_methods is None:
        exclude_methods = ['__init__', '__new__', '__del__', '__str__', '__repr__']

    def decorate_class(cls):
        if not self.enable_monitoring:
            return cls

        for name in dir(cls):
            # Filter on the name first so skipped attributes are never resolved.
            if name in exclude_methods or (name.startswith('_') and not include_private):
                continue

            method = getattr(cls, name)
            if not callable(method):
                continue

            if hasattr(method, '_judgment_span_name'):
                if warn_on_double_decoration:
                    print(f"Warning: {cls.__name__}.{name} already decorated, skipping")
                continue

            try:
                # getattr() has already unwrapped static/class methods. Look at the
                # raw attribute so the descriptor can be restored after wrapping;
                # otherwise instance calls would pass an extra positional argument.
                raw_attr = inspect.getattr_static(cls, name)
                if isinstance(raw_attr, staticmethod):
                    decorated_method = staticmethod(self.observe(raw_attr.__func__, span_type="tool"))
                elif isinstance(raw_attr, classmethod):
                    decorated_method = classmethod(self.observe(raw_attr.__func__, span_type="tool"))
                else:
                    decorated_method = self.observe(method, span_type="tool")
                setattr(cls, name, decorated_method)
            except Exception as e:
                if warn_on_double_decoration:
                    print(f"Warning: Failed to decorate {cls.__name__}.{name}: {e}")

        return cls

    return decorate_class if cls is None else decorate_class(cls)
2200
+
1427
2201
  def async_evaluate(self, *args, **kwargs):
1428
2202
  if not self.enable_evaluations:
1429
2203
  return
@@ -1431,13 +2205,7 @@ class Tracer:
1431
2205
  # --- Get trace_id passed explicitly (if any) ---
1432
2206
  passed_trace_id = kwargs.pop('trace_id', None) # Get and remove trace_id from kwargs
1433
2207
 
1434
- # --- Get current trace from context FIRST ---
1435
- current_trace = current_trace_var.get()
1436
-
1437
- # --- Fallback Logic: Use active client only if context var is empty ---
1438
- if not current_trace:
1439
- current_trace = self._active_trace_client # Use the fallback
1440
- # --- End Fallback Logic ---
2208
+ current_trace = self.get_current_trace()
1441
2209
 
1442
2210
  if current_trace:
1443
2211
  # Pass the explicitly provided trace_id if it exists, otherwise let async_evaluate handle it
@@ -1448,13 +2216,34 @@ class Tracer:
1448
2216
  else:
1449
2217
  warnings.warn("No trace found (context var or fallback), skipping evaluation") # Modified warning
1450
2218
 
1451
- def wrap(client: Any) -> Any:
2219
def get_background_span_service(self) -> Optional[BackgroundSpanService]:
    """Return the tracer's background span service, or ``None`` when there is none."""
    service = self.background_span_service
    return service
2222
+
2223
def flush_background_spans(self):
    """Ask the background span service, if one is set, to flush its pending items."""
    if not self.background_span_service:
        return
    self.background_span_service.flush()
2227
+
2228
def shutdown_background_service(self):
    """Shut down the background span service, if one is set, and clear the reference to it."""
    if not self.background_span_service:
        return
    self.background_span_service.shutdown()
    self.background_span_service = None
2233
+
2234
+ def wrap(client: Any, trace_across_async_contexts: bool = Tracer.trace_across_async_contexts) -> Any:
1452
2235
  """
1453
2236
  Wraps an API client to add tracing capabilities.
1454
2237
  Supports OpenAI, Together, Anthropic, and Google GenAI clients.
1455
2238
  Patches both '.create' and Anthropic's '.stream' methods using a wrapper class.
1456
2239
  """
1457
- span_name, original_create, original_responses_create, original_stream = _get_client_config(client)
2240
+ span_name, original_create, original_responses_create, original_stream, original_beta_parse = _get_client_config(client)
2241
+
2242
def _get_current_trace():
    """Return the active trace for this wrapped client.

    Reads the ``Tracer.current_trace`` class attribute when
    ``trace_across_async_contexts`` is truthy, otherwise the value of
    ``current_trace_var``.
    """
    return Tracer.current_trace if trace_across_async_contexts else current_trace_var.get()
1458
2247
 
1459
2248
  def _record_input_and_check_streaming(span, kwargs, is_responses=False):
1460
2249
  """Record input and check for streaming"""
@@ -1490,11 +2279,21 @@ def wrap(client: Any) -> Any:
1490
2279
  output, usage = format_func(client, response)
1491
2280
  span.record_output(output)
1492
2281
  span.record_usage(usage)
2282
+
2283
+ # Queue the completed LLM span now that it has all data (input, output, usage)
2284
+ current_trace = _get_current_trace()
2285
+ if current_trace and current_trace.background_span_service:
2286
+ # Get the current span from the trace client
2287
+ current_span_id = current_trace.get_current_span()
2288
+ if current_span_id and current_span_id in current_trace.span_id_to_span:
2289
+ completed_span = current_trace.span_id_to_span[current_span_id]
2290
+ current_trace.background_span_service.queue_span(completed_span, span_state="completed")
2291
+
1493
2292
  return response
1494
2293
 
1495
2294
  # --- Traced Async Functions ---
1496
2295
  async def traced_create_async(*args, **kwargs):
1497
- current_trace = current_trace_var.get()
2296
+ current_trace = _get_current_trace()
1498
2297
  if not current_trace:
1499
2298
  return await original_create(*args, **kwargs)
1500
2299
 
@@ -1508,9 +2307,25 @@ def wrap(client: Any) -> Any:
1508
2307
  _capture_exception_for_trace(span, sys.exc_info())
1509
2308
  raise e
1510
2309
 
2310
async def traced_beta_parse_async(*args, **kwargs):
    """Traced async stand-in for the client's ``beta.chat.completions.parse``.

    When no trace is active the original method is awaited unchanged.
    Otherwise the call runs inside an "llm" span that records the input, the
    formatted output and any exception, which is then re-raised.
    """
    current_trace = _get_current_trace()
    if not current_trace:
        return await original_beta_parse(*args, **kwargs)

    with current_trace.span(span_name, span_type="llm") as span:
        is_streaming = _record_input_and_check_streaming(span, kwargs)

        try:
            response_or_iterator = await original_beta_parse(*args, **kwargs)
            # NOTE(review): the trailing True/False presumably mean
            # is_async / is_responses; confirm against _format_and_record_output.
            return _format_and_record_output(span, response_or_iterator, is_streaming, True, False)
        except Exception:
            _capture_exception_for_trace(span, sys.exc_info())
            # Bare raise re-raises the active exception with its traceback as-is.
            raise
2324
+
2325
+
1511
2326
  # Async responses for OpenAI clients
1512
2327
  async def traced_response_create_async(*args, **kwargs):
1513
- current_trace = current_trace_var.get()
2328
+ current_trace = _get_current_trace()
1514
2329
  if not current_trace:
1515
2330
  return await original_responses_create(*args, **kwargs)
1516
2331
 
@@ -1526,7 +2341,7 @@ def wrap(client: Any) -> Any:
1526
2341
 
1527
2342
  # Function replacing .stream() for async clients
1528
2343
  def traced_stream_async(*args, **kwargs):
1529
- current_trace = current_trace_var.get()
2344
+ current_trace = _get_current_trace()
1530
2345
  if not current_trace or not original_stream:
1531
2346
  return original_stream(*args, **kwargs)
1532
2347
 
@@ -1542,7 +2357,7 @@ def wrap(client: Any) -> Any:
1542
2357
 
1543
2358
  # --- Traced Sync Functions ---
1544
2359
  def traced_create_sync(*args, **kwargs):
1545
- current_trace = current_trace_var.get()
2360
+ current_trace = _get_current_trace()
1546
2361
  if not current_trace:
1547
2362
  return original_create(*args, **kwargs)
1548
2363
 
@@ -1555,9 +2370,24 @@ def wrap(client: Any) -> Any:
1555
2370
  except Exception as e:
1556
2371
  _capture_exception_for_trace(span, sys.exc_info())
1557
2372
  raise e
2373
+
2374
def traced_beta_parse_sync(*args, **kwargs):
    """Traced sync stand-in for the client's ``beta.chat.completions.parse``.

    When no trace is active the original method is called unchanged.
    Otherwise the call runs inside an "llm" span that records the input, the
    formatted output and any exception, which is then re-raised.
    """
    current_trace = _get_current_trace()
    if not current_trace:
        return original_beta_parse(*args, **kwargs)

    with current_trace.span(span_name, span_type="llm") as span:
        is_streaming = _record_input_and_check_streaming(span, kwargs)

        try:
            response_or_iterator = original_beta_parse(*args, **kwargs)
            # NOTE(review): the trailing False/False presumably mean
            # is_async / is_responses; confirm against _format_and_record_output.
            return _format_and_record_output(span, response_or_iterator, is_streaming, False, False)
        except Exception:
            _capture_exception_for_trace(span, sys.exc_info())
            # Bare raise re-raises the active exception with its traceback as-is.
            raise
1558
2388
 
1559
2389
  def traced_response_create_sync(*args, **kwargs):
1560
- current_trace = current_trace_var.get()
2390
+ current_trace = _get_current_trace()
1561
2391
  if not current_trace:
1562
2392
  return original_responses_create(*args, **kwargs)
1563
2393
 
@@ -1573,7 +2403,7 @@ def wrap(client: Any) -> Any:
1573
2403
 
1574
2404
  # Function replacing sync .stream()
1575
2405
  def traced_stream_sync(*args, **kwargs):
1576
- current_trace = current_trace_var.get()
2406
+ current_trace = _get_current_trace()
1577
2407
  if not current_trace or not original_stream:
1578
2408
  return original_stream(*args, **kwargs)
1579
2409
 
@@ -1592,6 +2422,8 @@ def wrap(client: Any) -> Any:
1592
2422
  client.chat.completions.create = traced_create_async
1593
2423
  if hasattr(client, "responses") and hasattr(client.responses, "create"):
1594
2424
  client.responses.create = traced_response_create_async
2425
+ if hasattr(client, "beta") and hasattr(client.beta, "chat") and hasattr(client.beta.chat, "completions") and hasattr(client.beta.chat.completions, "parse"):
2426
+ client.beta.chat.completions.parse = traced_beta_parse_async
1595
2427
  elif isinstance(client, AsyncAnthropic):
1596
2428
  client.messages.create = traced_create_async
1597
2429
  if original_stream:
@@ -1602,6 +2434,8 @@ def wrap(client: Any) -> Any:
1602
2434
  client.chat.completions.create = traced_create_sync
1603
2435
  if hasattr(client, "responses") and hasattr(client.responses, "create"):
1604
2436
  client.responses.create = traced_response_create_sync
2437
+ if hasattr(client, "beta") and hasattr(client.beta, "chat") and hasattr(client.beta.chat, "completions") and hasattr(client.beta.chat.completions, "parse"):
2438
+ client.beta.chat.completions.parse = traced_beta_parse_sync
1605
2439
  elif isinstance(client, Anthropic):
1606
2440
  client.messages.create = traced_create_sync
1607
2441
  if original_stream:
@@ -1620,23 +2454,24 @@ def _get_client_config(client: ApiClient) -> tuple[str, callable, Optional[calla
1620
2454
  client: An instance of OpenAI, Together, or Anthropic client
1621
2455
 
1622
2456
  Returns:
1623
- tuple: (span_name, create_method, stream_method)
2457
+ tuple: (span_name, create_method, responses_method, stream_method, beta_parse_method)
1624
2458
  - span_name: String identifier for tracing
1625
2459
  - create_method: Reference to the client's creation method
1626
2460
  - responses_method: Reference to the client's responses method (if applicable)
1627
2461
  - stream_method: Reference to the client's stream method (if applicable)
2462
+ - beta_parse_method: Reference to the client's beta parse method (if applicable)
1628
2463
 
1629
2464
  Raises:
1630
2465
  ValueError: If client type is not supported
1631
2466
  """
1632
2467
  if isinstance(client, (OpenAI, AsyncOpenAI)):
1633
- return "OPENAI_API_CALL", client.chat.completions.create, client.responses.create, None
2468
+ return "OPENAI_API_CALL", client.chat.completions.create, client.responses.create, None, client.beta.chat.completions.parse
1634
2469
  elif isinstance(client, (Together, AsyncTogether)):
1635
- return "TOGETHER_API_CALL", client.chat.completions.create, None, None
2470
+ return "TOGETHER_API_CALL", client.chat.completions.create, None, None, None
1636
2471
  elif isinstance(client, (Anthropic, AsyncAnthropic)):
1637
- return "ANTHROPIC_API_CALL", client.messages.create, None, client.messages.stream
2472
+ return "ANTHROPIC_API_CALL", client.messages.create, None, client.messages.stream, None
1638
2473
  elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
1639
- return "GOOGLE_API_CALL", client.models.generate_content, None, None
2474
+ return "GOOGLE_API_CALL", client.models.generate_content, None, None, None
1640
2475
  raise ValueError(f"Unsupported client type: {type(client)}")
1641
2476
 
1642
2477
  def _format_input_data(client: ApiClient, **kwargs) -> dict:
@@ -1646,10 +2481,13 @@ def _format_input_data(client: ApiClient, **kwargs) -> dict:
1646
2481
  to ensure consistent tracing across different APIs.
1647
2482
  """
1648
2483
  if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
1649
- return {
2484
+ input_data = {
1650
2485
  "model": kwargs.get("model"),
1651
2486
  "messages": kwargs.get("messages"),
1652
2487
  }
2488
+ if kwargs.get("response_format"):
2489
+ input_data["response_format"] = kwargs.get("response_format")
2490
+ return input_data
1653
2491
  elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
1654
2492
  return {
1655
2493
  "model": kwargs.get("model"),
@@ -1719,7 +2557,10 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
1719
2557
  model_name = response.model
1720
2558
  prompt_tokens = response.usage.prompt_tokens
1721
2559
  completion_tokens = response.usage.completion_tokens
1722
- message_content = response.choices[0].message.content
2560
+ if hasattr(response.choices[0].message, "parsed") and response.choices[0].message.parsed:
2561
+ message_content = response.choices[0].message.parsed
2562
+ else:
2563
+ message_content = response.choices[0].message.content
1723
2564
  elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
1724
2565
  model_name = response.model_version
1725
2566
  prompt_tokens = response.usage_metadata.prompt_token_count
@@ -1928,6 +2769,15 @@ def _sync_stream_wrapper(
1928
2769
  # Update the trace entry with the accumulated content and usage
1929
2770
  span.output = "".join(content_parts)
1930
2771
  span.usage = final_usage
2772
+
2773
+ # Queue the completed LLM span now that streaming is done and all data is available
2774
+ # Note: We need to get the TraceClient that owns this span to access the background service
2775
+ # We can find this through the tracer singleton since spans are associated with traces
2776
+ from judgeval.common.tracer import Tracer
2777
+ tracer_instance = Tracer._instance
2778
+ if tracer_instance and tracer_instance.background_span_service:
2779
+ tracer_instance.background_span_service.queue_span(span, span_state="completed")
2780
+
1931
2781
  # Note: We might need to adjust _serialize_output if this dict causes issues,
1932
2782
  # but Pydantic's model_dump should handle dicts.
1933
2783
 
@@ -2012,6 +2862,12 @@ async def _async_stream_wrapper(
2012
2862
  span.usage = usage_info
2013
2863
  start_ts = getattr(span, 'created_at', time.time())
2014
2864
  span.duration = time.time() - start_ts
2865
+
2866
+ # Queue the completed LLM span now that async streaming is done and all data is available
2867
+ from judgeval.common.tracer import Tracer
2868
+ tracer_instance = Tracer._instance
2869
+ if tracer_instance and tracer_instance.background_span_service:
2870
+ tracer_instance.background_span_service.queue_span(span, span_state="completed")
2015
2871
  # else: # Handle error case if necessary, but remove debug print
2016
2872
 
2017
2873
  def cost_per_token(*args, **kwargs):
@@ -2060,12 +2916,12 @@ class _BaseStreamManagerWrapper:
2060
2916
 
2061
2917
  class _TracedAsyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractAsyncContextManager):
2062
2918
  async def __aenter__(self):
2063
- self._parent_span_id_at_entry = current_span_var.get()
2919
+ self._parent_span_id_at_entry = self._trace_client.get_current_span()
2064
2920
  if not self._trace_client:
2065
2921
  return await self._original_manager.__aenter__()
2066
2922
 
2067
2923
  span_id, span = self._create_span()
2068
- self._span_context_token = current_span_var.set(span_id)
2924
+ self._span_context_token = self._trace_client.set_current_span(span_id)
2069
2925
  span.inputs = _format_input_data(self._client, **self._input_kwargs)
2070
2926
 
2071
2927
  # Call the original __aenter__ and expect it to be an async generator
@@ -2075,20 +2931,20 @@ class _TracedAsyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractAsyncC
2075
2931
 
2076
2932
  async def __aexit__(self, exc_type, exc_val, exc_tb):
2077
2933
  if hasattr(self, '_span_context_token'):
2078
- span_id = current_span_var.get()
2934
+ span_id = self._trace_client.get_current_span()
2079
2935
  self._finalize_span(span_id)
2080
- current_span_var.reset(self._span_context_token)
2936
+ self._trace_client.reset_current_span(self._span_context_token)
2081
2937
  delattr(self, '_span_context_token')
2082
2938
  return await self._original_manager.__aexit__(exc_type, exc_val, exc_tb)
2083
2939
 
2084
2940
  class _TracedSyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractContextManager):
2085
2941
  def __enter__(self):
2086
- self._parent_span_id_at_entry = current_span_var.get()
2942
+ self._parent_span_id_at_entry = self._trace_client.get_current_span()
2087
2943
  if not self._trace_client:
2088
2944
  return self._original_manager.__enter__()
2089
2945
 
2090
2946
  span_id, span = self._create_span()
2091
- self._span_context_token = current_span_var.set(span_id)
2947
+ self._span_context_token = self._trace_client.set_current_span(span_id)
2092
2948
  span.inputs = _format_input_data(self._client, **self._input_kwargs)
2093
2949
 
2094
2950
  raw_iterator = self._original_manager.__enter__()
@@ -2097,9 +2953,9 @@ class _TracedSyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractContext
2097
2953
 
2098
2954
  def __exit__(self, exc_type, exc_val, exc_tb):
2099
2955
  if hasattr(self, '_span_context_token'):
2100
- span_id = current_span_var.get()
2956
+ span_id = self._trace_client.get_current_span()
2101
2957
  self._finalize_span(span_id)
2102
- current_span_var.reset(self._span_context_token)
2958
+ self._trace_client.reset_current_span(self._span_context_token)
2103
2959
  delattr(self, '_span_context_token')
2104
2960
  return self._original_manager.__exit__(exc_type, exc_val, exc_tb)
2105
2961
 
@@ -2118,4 +2974,4 @@ def get_instance_prefixed_name(instance, class_name, class_identifiers):
2118
2974
  return instance_name
2119
2975
  else:
2120
2976
  raise Exception(f"Attribute {attr} does not exist for {class_name}. Check your identify() decorator.")
2121
- return None
2977
+ return None