judgeval 0.0.40__py3-none-any.whl → 0.0.42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
judgeval/common/tracer.py CHANGED
@@ -5,7 +5,6 @@ Tracing system for judgeval that allows for function tracing using decorators.
5
5
  import asyncio
6
6
  import functools
7
7
  import inspect
8
- import json
9
8
  import os
10
9
  import site
11
10
  import sysconfig
@@ -16,9 +15,10 @@ import uuid
16
15
  import warnings
17
16
  import contextvars
18
17
  import sys
18
+ import json
19
19
  from contextlib import contextmanager, asynccontextmanager, AbstractAsyncContextManager, AbstractContextManager # Import context manager bases
20
20
  from dataclasses import dataclass, field
21
- from datetime import datetime
21
+ from datetime import datetime, timezone
22
22
  from http import HTTPStatus
23
23
  from typing import (
24
24
  Any,
@@ -29,20 +29,16 @@ from typing import (
29
29
  Literal,
30
30
  Optional,
31
31
  Tuple,
32
- Type,
33
- TypeVar,
34
32
  Union,
35
33
  AsyncGenerator,
36
34
  TypeAlias,
37
- Set
38
35
  )
39
36
  from rich import print as rprint
40
- import types # <--- Add this import
37
+ import types
41
38
 
42
39
  # Third-party imports
43
40
  import requests
44
41
  from litellm import cost_per_token as _original_cost_per_token
45
- from pydantic import BaseModel
46
42
  from rich import print as rprint
47
43
  from openai import OpenAI, AsyncOpenAI
48
44
  from together import Together, AsyncTogether
@@ -53,24 +49,30 @@ from google import genai
53
49
  from judgeval.constants import (
54
50
  JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
55
51
  JUDGMENT_TRACES_SAVE_API_URL,
52
+ JUDGMENT_TRACES_UPSERT_API_URL,
53
+ JUDGMENT_TRACES_USAGE_CHECK_API_URL,
54
+ JUDGMENT_TRACES_USAGE_UPDATE_API_URL,
56
55
  JUDGMENT_TRACES_FETCH_API_URL,
57
56
  RABBITMQ_HOST,
58
57
  RABBITMQ_PORT,
59
58
  RABBITMQ_QUEUE,
60
59
  JUDGMENT_TRACES_DELETE_API_URL,
61
60
  JUDGMENT_PROJECT_DELETE_API_URL,
61
+ JUDGMENT_TRACES_SPANS_BATCH_API_URL,
62
+ JUDGMENT_TRACES_EVALUATION_RUNS_BATCH_API_URL,
62
63
  )
63
64
  from judgeval.data import Example, Trace, TraceSpan, TraceUsage
64
65
  from judgeval.scorers import APIJudgmentScorer, JudgevalScorer
65
66
  from judgeval.rules import Rule
66
67
  from judgeval.evaluation_run import EvaluationRun
67
- from judgeval.data.result import ScoringResult
68
- from judgeval.common.utils import validate_api_key
68
+ from judgeval.common.utils import ExcInfo, validate_api_key
69
69
  from judgeval.common.exceptions import JudgmentAPIError
70
70
 
71
71
  # Standard library imports needed for the new class
72
72
  import concurrent.futures
73
73
  from collections.abc import Iterator, AsyncIterator # Add Iterator and AsyncIterator
74
+ import queue
75
+ import atexit
74
76
 
75
77
  # Define context variables for tracking the current trace and the current span within a trace
76
78
  current_trace_var = contextvars.ContextVar[Optional['TraceClient']]('current_trace', default=None)
@@ -147,13 +149,18 @@ class TraceManagerClient:
147
149
 
148
150
  return response.json()
149
151
 
150
- def save_trace(self, trace_data: dict, offline_mode: bool = False):
152
+ def save_trace(self, trace_data: dict, offline_mode: bool = False, final_save: bool = True):
151
153
  """
152
154
  Saves a trace to the Judgment Supabase and optionally to S3 if configured.
153
155
 
154
156
  Args:
155
157
  trace_data: The trace data to save
158
+ offline_mode: Whether running in offline mode
159
+ final_save: Whether this is the final save (controls S3 saving)
156
160
  NOTE we save empty traces in order to properly handle async operations; we need something in the DB to associate the async results with
161
+
162
+ Returns:
163
+ dict: Server response containing UI URL and other metadata
157
164
  """
158
165
  # Save to Judgment API
159
166
 
@@ -175,7 +182,6 @@ class TraceManagerClient:
175
182
  return f"<Unserializable object of type {type(obj).__name__}: {e}>"
176
183
 
177
184
  serialized_trace_data = json.dumps(trace_data, default=fallback_encoder)
178
-
179
185
  response = requests.post(
180
186
  JUDGMENT_TRACES_SAVE_API_URL,
181
187
  data=serialized_trace_data,
@@ -192,8 +198,107 @@ class TraceManagerClient:
192
198
  elif response.status_code != HTTPStatus.OK:
193
199
  raise ValueError(f"Failed to save trace data: {response.text}")
194
200
 
195
- # If S3 storage is enabled, save to S3 as well
196
- if self.tracer and self.tracer.use_s3:
201
+ # Parse server response
202
+ server_response = response.json()
203
+
204
+ # If S3 storage is enabled, save to S3 only on final save
205
+ if self.tracer and self.tracer.use_s3 and final_save:
206
+ try:
207
+ s3_key = self.tracer.s3_storage.save_trace(
208
+ trace_data=trace_data,
209
+ trace_id=trace_data["trace_id"],
210
+ project_name=trace_data["project_name"]
211
+ )
212
+ print(f"Trace also saved to S3 at key: {s3_key}")
213
+ except Exception as e:
214
+ warnings.warn(f"Failed to save trace to S3: {str(e)}")
215
+
216
+ if not offline_mode and "ui_results_url" in server_response:
217
+ pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={server_response['ui_results_url']}]View Trace[/link]\n"
218
+ rprint(pretty_str)
219
+
220
+ return server_response
221
+
222
+ def check_usage_limits(self, count: int = 1):
223
+ """
224
+ Check if the organization can use the requested number of traces without exceeding limits.
225
+
226
+ Args:
227
+ count: Number of traces to check for (default: 1)
228
+
229
+ Returns:
230
+ dict: Server response with rate limit status and usage info
231
+
232
+ Raises:
233
+ ValueError: If rate limits would be exceeded or other errors occur
234
+ """
235
+ response = requests.post(
236
+ JUDGMENT_TRACES_USAGE_CHECK_API_URL,
237
+ json={"count": count},
238
+ headers={
239
+ "Content-Type": "application/json",
240
+ "Authorization": f"Bearer {self.judgment_api_key}",
241
+ "X-Organization-Id": self.organization_id
242
+ },
243
+ verify=True
244
+ )
245
+
246
+ if response.status_code == HTTPStatus.FORBIDDEN:
247
+ # Rate limits exceeded
248
+ error_data = response.json()
249
+ raise ValueError(f"Rate limit exceeded: {error_data.get('detail', 'Monthly trace limit reached')}")
250
+ elif response.status_code != HTTPStatus.OK:
251
+ raise ValueError(f"Failed to check usage limits: {response.text}")
252
+
253
+ return response.json()
254
+
255
+ def upsert_trace(self, trace_data: dict, offline_mode: bool = False, show_link: bool = True, final_save: bool = True):
256
+ """
257
+ Upserts a trace to the Judgment API (always overwrites if exists).
258
+
259
+ Args:
260
+ trace_data: The trace data to upsert
261
+ offline_mode: Whether running in offline mode
262
+ show_link: Whether to show the UI link (for live tracing)
263
+ final_save: Whether this is the final save (controls S3 saving)
264
+
265
+ Returns:
266
+ dict: Server response containing UI URL and other metadata
267
+ """
268
+ def fallback_encoder(obj):
269
+ """
270
+ Custom JSON encoder fallback.
271
+ Tries to use obj.__repr__(), then str(obj) if that fails or for a simpler string.
272
+ """
273
+ try:
274
+ return repr(obj)
275
+ except Exception:
276
+ try:
277
+ return str(obj)
278
+ except Exception as e:
279
+ return f"<Unserializable object of type {type(obj).__name__}: {e}>"
280
+
281
+ serialized_trace_data = json.dumps(trace_data, default=fallback_encoder)
282
+
283
+ response = requests.post(
284
+ JUDGMENT_TRACES_UPSERT_API_URL,
285
+ data=serialized_trace_data,
286
+ headers={
287
+ "Content-Type": "application/json",
288
+ "Authorization": f"Bearer {self.judgment_api_key}",
289
+ "X-Organization-Id": self.organization_id
290
+ },
291
+ verify=True
292
+ )
293
+
294
+ if response.status_code != HTTPStatus.OK:
295
+ raise ValueError(f"Failed to upsert trace data: {response.text}")
296
+
297
+ # Parse server response
298
+ server_response = response.json()
299
+
300
+ # If S3 storage is enabled, save to S3 only on final save
301
+ if self.tracer and self.tracer.use_s3 and final_save:
197
302
  try:
198
303
  s3_key = self.tracer.s3_storage.save_trace(
199
304
  trace_data=trace_data,
@@ -204,9 +309,40 @@ class TraceManagerClient:
204
309
  except Exception as e:
205
310
  warnings.warn(f"Failed to save trace to S3: {str(e)}")
206
311
 
207
- if not offline_mode and "ui_results_url" in response.json():
208
- pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={response.json()['ui_results_url']}]View Trace[/link]\n"
312
+ if not offline_mode and show_link and "ui_results_url" in server_response:
313
+ pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={server_response['ui_results_url']}]View Trace[/link]\n"
209
314
  rprint(pretty_str)
315
+
316
+ return server_response
317
+
318
+ def update_usage_counters(self, count: int = 1):
319
+ """
320
+ Update trace usage counters after successfully saving traces.
321
+
322
+ Args:
323
+ count: Number of traces to count (default: 1)
324
+
325
+ Returns:
326
+ dict: Server response with updated usage information
327
+
328
+ Raises:
329
+ ValueError: If the update fails
330
+ """
331
+ response = requests.post(
332
+ JUDGMENT_TRACES_USAGE_UPDATE_API_URL,
333
+ json={"count": count},
334
+ headers={
335
+ "Content-Type": "application/json",
336
+ "Authorization": f"Bearer {self.judgment_api_key}",
337
+ "X-Organization-Id": self.organization_id
338
+ },
339
+ verify=True
340
+ )
341
+
342
+ if response.status_code != HTTPStatus.OK:
343
+ raise ValueError(f"Failed to update usage counters: {response.text}")
344
+
345
+ return response.json()
210
346
 
211
347
  ## TODO: Should have a log endpoint, endpoint should also support batched payloads
212
348
  def save_annotation(self, annotation: TraceAnnotation):
@@ -307,7 +443,7 @@ class TraceClient:
307
443
  tracer: Optional["Tracer"],
308
444
  trace_id: Optional[str] = None,
309
445
  name: str = "default",
310
- project_name: str = "default_project",
446
+ project_name: str = None,
311
447
  overwrite: bool = False,
312
448
  rules: Optional[List[Rule]] = None,
313
449
  enable_monitoring: bool = True,
@@ -317,7 +453,7 @@ class TraceClient:
317
453
  ):
318
454
  self.name = name
319
455
  self.trace_id = trace_id or str(uuid.uuid4())
320
- self.project_name = project_name
456
+ self.project_name = project_name or str(uuid.uuid4())
321
457
  self.overwrite = overwrite
322
458
  self.tracer = tracer
323
459
  self.rules = rules or []
@@ -329,35 +465,48 @@ class TraceClient:
329
465
  self.span_id_to_span: Dict[str, TraceSpan] = {}
330
466
  self.evaluation_runs: List[EvaluationRun] = []
331
467
  self.annotations: List[TraceAnnotation] = []
332
- self.start_time = time.time()
468
+ self.start_time = None # Will be set after first successful save
333
469
  self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id, tracer)
334
470
  self.visited_nodes = []
335
471
  self.executed_tools = []
336
472
  self.executed_node_tools = []
337
473
  self._span_depths: Dict[str, int] = {} # NEW: To track depth of active spans
474
+
475
+ # Get background span service from tracer
476
+ self.background_span_service = tracer.get_background_span_service() if tracer else None
338
477
 
339
478
  def get_current_span(self):
340
479
  """Get the current span from the context var"""
341
- return current_span_var.get()
480
+ return self.tracer.get_current_span()
342
481
 
343
482
  def set_current_span(self, span: Any):
344
483
  """Set the current span from the context var"""
345
- return current_span_var.set(span)
484
+ return self.tracer.set_current_span(span)
346
485
 
347
486
  def reset_current_span(self, token: Any):
348
487
  """Reset the current span from the context var"""
349
- return current_span_var.reset(token)
488
+ self.tracer.reset_current_span(token)
350
489
 
351
490
  @contextmanager
352
491
  def span(self, name: str, span_type: SpanType = "span"):
353
492
  """Context manager for creating a trace span, managing the current span via contextvars"""
493
+ is_first_span = len(self.trace_spans) == 0
494
+ if is_first_span:
495
+ try:
496
+ trace_id, server_response = self.save_with_rate_limiting(overwrite=self.overwrite, final_save=False)
497
+ # Set start_time after first successful save
498
+ if self.start_time is None:
499
+ self.start_time = time.time()
500
+ # Link will be shown by upsert_trace method
501
+ except Exception as e:
502
+ warnings.warn(f"Failed to save initial trace for live tracking: {e}")
354
503
  start_time = time.time()
355
504
 
356
505
  # Generate a unique ID for *this specific span invocation*
357
506
  span_id = str(uuid.uuid4())
358
507
 
359
- parent_span_id = current_span_var.get() # Get ID of the parent span from context var
360
- token = current_span_var.set(span_id) # Set *this* span's ID as the current one
508
+ parent_span_id = self.get_current_span() # Get ID of the parent span from context var
509
+ token = self.set_current_span(span_id) # Set *this* span's ID as the current one
361
510
 
362
511
  current_depth = 0
363
512
  if parent_span_id and parent_span_id in self._span_depths:
@@ -377,16 +526,27 @@ class TraceClient:
377
526
  )
378
527
  self.add_span(span)
379
528
 
529
+
530
+
531
+ # Queue span with initial state (input phase)
532
+ if self.background_span_service:
533
+ self.background_span_service.queue_span(span, span_state="input")
534
+
380
535
  try:
381
536
  yield self
382
537
  finally:
383
538
  duration = time.time() - start_time
384
539
  span.duration = duration
540
+
541
+ # Queue span with completed state (output phase)
542
+ if self.background_span_service:
543
+ self.background_span_service.queue_span(span, span_state="completed")
544
+
385
545
  # Clean up depth tracking for this span_id
386
546
  if span_id in self._span_depths:
387
547
  del self._span_depths[span_id]
388
548
  # Reset context var
389
- current_span_var.reset(token)
549
+ self.reset_current_span(token)
390
550
 
391
551
  def async_evaluate(
392
552
  self,
@@ -450,8 +610,7 @@ class TraceClient:
450
610
  # span_id_at_eval_call = current_span_var.get()
451
611
  # print(f"[TraceClient.async_evaluate] Captured span ID at eval call: {span_id_at_eval_call}")
452
612
  # Prioritize explicitly passed span_id, fallback to context var
453
- current_span_ctx_var = current_span_var.get()
454
- span_id_to_use = span_id if span_id is not None else current_span_ctx_var if current_span_ctx_var is not None else self.tracer.get_current_span()
613
+ span_id_to_use = span_id if span_id is not None else self.get_current_span()
455
614
  # print(f"[TraceClient.async_evaluate] Using span_id: {span_id_to_use}")
456
615
  # --- End Modification ---
457
616
 
@@ -474,6 +633,17 @@ class TraceClient:
474
633
  )
475
634
 
476
635
  self.add_eval_run(eval_run, start_time) # Pass start_time to record_evaluation
636
+
637
+ # Queue evaluation run through background service
638
+ if self.background_span_service and span_id_to_use:
639
+ # Get the current span data to avoid race conditions
640
+ current_span = self.span_id_to_span.get(span_id_to_use)
641
+ if current_span:
642
+ self.background_span_service.queue_evaluation_run(
643
+ eval_run,
644
+ span_id=span_id_to_use,
645
+ span_data=current_span
646
+ )
477
647
 
478
648
  def add_eval_run(self, eval_run: EvaluationRun, start_time: float):
479
649
  # --- Modification: Use span_id from eval_run ---
@@ -493,58 +663,119 @@ class TraceClient:
493
663
  return self
494
664
 
495
665
  def record_input(self, inputs: dict):
496
- current_span_id = current_span_var.get()
666
+ current_span_id = self.get_current_span()
497
667
  if current_span_id:
498
668
  span = self.span_id_to_span[current_span_id]
499
669
  # Ignore self parameter
500
670
  if "self" in inputs:
501
671
  del inputs["self"]
502
672
  span.inputs = inputs
673
+
674
+ # Queue span with input data
675
+ if self.background_span_service:
676
+ self.background_span_service.queue_span(span, span_state="input")
503
677
 
504
678
  def record_agent_name(self, agent_name: str):
505
- current_span_id = current_span_var.get()
679
+ current_span_id = self.get_current_span()
506
680
  if current_span_id:
507
681
  span = self.span_id_to_span[current_span_id]
508
682
  span.agent_name = agent_name
683
+
684
+ # Queue span with agent_name data
685
+ if self.background_span_service:
686
+ self.background_span_service.queue_span(span, span_state="agent_name")
687
+
688
+ def record_state_before(self, state: dict):
689
+ """Records the agent's state before a tool execution on the current span.
690
+
691
+ Args:
692
+ state: A dictionary representing the agent's state.
693
+ """
694
+ current_span_id = self.get_current_span()
695
+ if current_span_id:
696
+ span = self.span_id_to_span[current_span_id]
697
+ span.state_before = state
698
+
699
+ # Queue span with state_before data
700
+ if self.background_span_service:
701
+ self.background_span_service.queue_span(span, span_state="state_before")
702
+
703
+ def record_state_after(self, state: dict):
704
+ """Records the agent's state after a tool execution on the current span.
705
+
706
+ Args:
707
+ state: A dictionary representing the agent's state.
708
+ """
709
+ current_span_id = self.get_current_span()
710
+ if current_span_id:
711
+ span = self.span_id_to_span[current_span_id]
712
+ span.state_after = state
713
+
714
+ # Queue span with state_after data
715
+ if self.background_span_service:
716
+ self.background_span_service.queue_span(span, span_state="state_after")
509
717
 
510
718
  async def _update_coroutine(self, span: TraceSpan, coroutine: Any, field: str):
511
719
  """Helper method to update the output of a trace entry once the coroutine completes"""
512
720
  try:
513
721
  result = await coroutine
514
722
  setattr(span, field, result)
723
+
724
+ # Queue span with output data now that coroutine is complete
725
+ if self.background_span_service and field == "output":
726
+ self.background_span_service.queue_span(span, span_state="output")
727
+
515
728
  return result
516
729
  except Exception as e:
517
730
  setattr(span, field, f"Error: {str(e)}")
731
+
732
+ # Queue span even if there was an error
733
+ if self.background_span_service and field == "output":
734
+ self.background_span_service.queue_span(span, span_state="output")
735
+
518
736
  raise
519
737
 
520
738
  def record_output(self, output: Any):
521
- current_span_id = current_span_var.get()
739
+ current_span_id = self.get_current_span()
522
740
  if current_span_id:
523
741
  span = self.span_id_to_span[current_span_id]
524
742
  span.output = "<pending>" if inspect.iscoroutine(output) else output
525
743
 
526
744
  if inspect.iscoroutine(output):
527
745
  asyncio.create_task(self._update_coroutine(span, output, "output"))
746
+
747
+ # # Queue span with output data (unless it's pending)
748
+ if self.background_span_service and not inspect.iscoroutine(output):
749
+ self.background_span_service.queue_span(span, span_state="output")
528
750
 
529
751
  return span # Return the created entry
530
752
  # Removed else block - original didn't have one
531
753
  return None # Return None if no span_id found
532
754
 
533
755
  def record_usage(self, usage: TraceUsage):
534
- current_span_id = current_span_var.get()
756
+ current_span_id = self.get_current_span()
535
757
  if current_span_id:
536
758
  span = self.span_id_to_span[current_span_id]
537
759
  span.usage = usage
538
760
 
761
+ # Queue span with usage data
762
+ if self.background_span_service:
763
+ self.background_span_service.queue_span(span, span_state="usage")
764
+
539
765
  return span # Return the created entry
540
766
  # Removed else block - original didn't have one
541
767
  return None # Return None if no span_id found
542
768
 
543
- def record_error(self, error: Any):
544
- current_span_id = current_span_var.get()
769
+ def record_error(self, error: Dict[str, Any]):
770
+ current_span_id = self.get_current_span()
545
771
  if current_span_id:
546
772
  span = self.span_id_to_span[current_span_id]
547
773
  span.error = error
774
+
775
+ # Queue span with error data
776
+ if self.background_span_service:
777
+ self.background_span_service.queue_span(span, span_state="error")
778
+
548
779
  return span
549
780
  return None
550
781
 
@@ -563,13 +794,19 @@ class TraceClient:
563
794
  """
564
795
  Get the total duration of this trace
565
796
  """
797
+ if self.start_time is None:
798
+ return 0.0 # No duration if trace hasn't been saved yet
566
799
  return time.time() - self.start_time
567
800
 
568
801
  def save(self, overwrite: bool = False) -> Tuple[str, dict]:
569
802
  """
570
803
  Save the current trace to the database.
571
- Returns a tuple of (trace_id, trace_data) where trace_data is the trace data that was saved.
804
+ Returns a tuple of (trace_id, server_response) where server_response contains the UI URL and other metadata.
572
805
  """
806
+ # Set start_time if this is the first save
807
+ if self.start_time is None:
808
+ self.start_time = time.time()
809
+
573
810
  # Calculate total elapsed time
574
811
  total_duration = self.get_duration()
575
812
  # Create trace document - Always use standard keys for top-level counts
@@ -577,9 +814,9 @@ class TraceClient:
577
814
  "trace_id": self.trace_id,
578
815
  "name": self.name,
579
816
  "project_name": self.project_name,
580
- "created_at": datetime.utcfromtimestamp(self.start_time).isoformat(),
817
+ "created_at": datetime.fromtimestamp(self.start_time, timezone.utc).isoformat(),
581
818
  "duration": total_duration,
582
- "entries": [span.model_dump() for span in self.trace_spans],
819
+ "trace_spans": [span.model_dump() for span in self.trace_spans],
583
820
  "evaluation_runs": [run.model_dump() for run in self.evaluation_runs],
584
821
  "overwrite": overwrite,
585
822
  "offline_mode": self.tracer.offline_mode,
@@ -587,19 +824,84 @@ class TraceClient:
587
824
  "parent_name": self.parent_name
588
825
  }
589
826
  # --- Log trace data before saving ---
590
- self.trace_manager_client.save_trace(trace_data, offline_mode=self.tracer.offline_mode)
827
+ server_response = self.trace_manager_client.save_trace(trace_data, offline_mode=self.tracer.offline_mode, final_save=True)
591
828
 
592
829
  # upload annotations
593
830
  # TODO: batch to the log endpoint
594
831
  for annotation in self.annotations:
595
832
  self.trace_manager_client.save_annotation(annotation)
596
833
 
597
- return self.trace_id, trace_data
834
+ return self.trace_id, server_response
835
+
836
+ def save_with_rate_limiting(self, overwrite: bool = False, final_save: bool = False) -> Tuple[str, dict]:
837
+ """
838
+ Save the current trace to the database with rate limiting checks.
839
+ First checks usage limits, then upserts the trace if allowed.
840
+
841
+ Args:
842
+ overwrite: Whether to overwrite existing traces
843
+ final_save: Whether this is the final save (updates usage counters)
844
+
845
+ Returns a tuple of (trace_id, server_response) where server_response contains the UI URL and other metadata.
846
+ """
847
+
848
+
849
+ # Calculate total elapsed time
850
+ total_duration = self.get_duration()
851
+
852
+ # Create trace document
853
+ trace_data = {
854
+ "trace_id": self.trace_id,
855
+ "name": self.name,
856
+ "project_name": self.project_name,
857
+ "created_at": datetime.utcfromtimestamp(time.time()).isoformat(),
858
+ "duration": total_duration,
859
+ "trace_spans": [span.model_dump() for span in self.trace_spans],
860
+ "evaluation_runs": [run.model_dump() for run in self.evaluation_runs],
861
+ "overwrite": overwrite,
862
+ "offline_mode": self.tracer.offline_mode,
863
+ "parent_trace_id": self.parent_trace_id,
864
+ "parent_name": self.parent_name
865
+ }
866
+
867
+ # Check usage limits first
868
+ try:
869
+ usage_check_result = self.trace_manager_client.check_usage_limits(count=1)
870
+ # Usage check passed silently - no need to show detailed info
871
+ except ValueError as e:
872
+ # Rate limit exceeded
873
+ warnings.warn(f"Rate limit check failed for live tracing: {e}")
874
+ raise e
875
+
876
+ # If usage check passes, upsert the trace
877
+ server_response = self.trace_manager_client.upsert_trace(
878
+ trace_data,
879
+ offline_mode=self.tracer.offline_mode,
880
+ show_link=not final_save, # Show link only on initial save, not final save
881
+ final_save=final_save # Pass final_save to control S3 saving
882
+ )
883
+
884
+ # Update usage counters only on final save
885
+ if final_save:
886
+ try:
887
+ usage_update_result = self.trace_manager_client.update_usage_counters(count=1)
888
+ # Usage updated silently - no need to show detailed usage info
889
+ except ValueError as e:
890
+ # Log warning but don't fail the trace save since the trace was already saved
891
+ warnings.warn(f"Usage counter update failed (trace was still saved): {e}")
892
+
893
+ # Upload annotations
894
+ # TODO: batch to the log endpoint
895
+ for annotation in self.annotations:
896
+ self.trace_manager_client.save_annotation(annotation)
897
+ if self.start_time is None:
898
+ self.start_time = time.time()
899
+ return self.trace_id, server_response
598
900
 
599
901
  def delete(self):
600
902
  return self.trace_manager_client.delete_trace(self.trace_id)
601
903
 
602
- def _capture_exception_for_trace(current_trace: Optional['TraceClient'], exc_info: Tuple[Optional[type], Optional[BaseException], Optional[types.TracebackType]]):
904
+ def _capture_exception_for_trace(current_trace: Optional['TraceClient'], exc_info: ExcInfo):
603
905
  if not current_trace:
604
906
  return
605
907
 
@@ -609,7 +911,360 @@ def _capture_exception_for_trace(current_trace: Optional['TraceClient'], exc_inf
609
911
  "message": str(exc_value) if exc_value else "No exception message",
610
912
  "traceback": traceback.format_tb(exc_traceback_obj) if exc_traceback_obj else []
611
913
  }
914
+
915
+ # This is where we specially handle exceptions that we might want to collect additional data for.
916
+ # When we do this, always try checking the module from sys.modules instead of importing. This will
917
+ # Let us support a wider range of exceptions without needing to import them for all clients.
918
+
919
+ # Most clients (requests, httpx, urllib) support the standard format of exposing error.request.url and error.response.status_code
920
+ # The alternative is to hand select libraries we want from sys.modules and check for them:
921
+ # As an example: requests_module = sys.modules.get("requests", None) // then do things with requests_module;
922
+
923
+ # General HTTP Like errors
924
+ try:
925
+ url = getattr(getattr(exc_value, "request", None), "url", None)
926
+ status_code = getattr(getattr(exc_value, "response", None), "status_code", None)
927
+ if status_code:
928
+ formatted_exception["http"] = {
929
+ "url": url if url else "Unknown URL",
930
+ "status_code": status_code if status_code else None,
931
+ }
932
+ except Exception as e:
933
+ pass
934
+
612
935
  current_trace.record_error(formatted_exception)
936
+
937
+ # Queue the span with error state through background service
938
+ if current_trace.background_span_service:
939
+ current_span_id = current_trace.get_current_span()
940
+ if current_span_id and current_span_id in current_trace.span_id_to_span:
941
+ error_span = current_trace.span_id_to_span[current_span_id]
942
+ current_trace.background_span_service.queue_span(error_span, span_state="error")
943
+
944
class BackgroundSpanService:
    """
    Background service for queueing and batching trace spans for efficient saving.

    This service:
    - Queues spans and evaluation runs as they are reported
    - Batches them for efficient network usage
    - Sends a batch when it reaches ``batch_size`` items or when
      ``flush_interval`` seconds have passed since the worker's previous flush
    - Registers ``shutdown`` with ``atexit`` so queued items are flushed when
      the interpreter exits
    """

    # Upper bound (seconds) on how long a worker blocks on the queue before it
    # re-checks the shutdown flag and the flush deadline.
    _QUEUE_POLL_SECONDS = 1.0

    # Keys that queue_evaluation_run adds next to the evaluation run's own
    # fields; they are moved out of the "evaluation_run" payload section.
    _EVALUATION_EXTRA_KEYS = ('associated_span_id', 'span_data', 'queued_at')

    def __init__(self, judgment_api_key: str, organization_id: str,
                 batch_size: int = 10, flush_interval: float = 5.0, num_workers: int = 1):
        """
        Initialize the background span service.

        Args:
            judgment_api_key: API key for Judgment service
            organization_id: Organization ID
            batch_size: Number of spans to batch before sending (default: 10)
            flush_interval: Time in seconds between automatic flushes (default: 5.0)
            num_workers: Number of worker threads to process the queue (default: 1)
        """
        self.judgment_api_key = judgment_api_key
        self.organization_id = organization_id
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self.num_workers = max(1, num_workers)  # Ensure at least 1 worker

        # Queue for pending spans
        self._span_queue = queue.Queue()

        # Background threads for processing spans
        self._worker_threads = []
        self._shutdown_event = threading.Event()

        # Kept for compatibility; nothing in this class reads or updates it.
        self._sent_spans = set()

        # Register cleanup on exit
        atexit.register(self.shutdown)

        # Start the background workers
        self._start_workers()

    def _start_workers(self):
        """Start worker threads until ``num_workers`` of them are registered."""
        for index in range(len(self._worker_threads), self.num_workers):
            worker_thread = threading.Thread(
                target=self._worker_loop,
                daemon=True,
                name=f"SpanWorker-{index + 1}"
            )
            worker_thread.start()
            self._worker_threads.append(worker_thread)

    def _worker_loop(self):
        """
        Main worker loop that processes spans in batches.

        The loop runs until shutdown has been signalled and the queue is empty.
        A batch is sent when it is full, when ``flush_interval`` seconds have
        passed since this worker's previous flush, or when shutdown has been
        signalled. While a partial batch is pending the worker blocks on the
        queue until the flush deadline instead of spinning.
        """
        batch: List[Dict[str, Any]] = []
        last_flush_time = time.time()

        while not self._shutdown_event.is_set() or self._span_queue.qsize() > 0:
            try:
                self._fill_batch(batch, last_flush_time)

                current_time = time.time()
                flush_due = (
                    len(batch) >= self.batch_size
                    or (current_time - last_flush_time) >= self.flush_interval
                    or self._shutdown_event.is_set()
                )
                if batch and flush_due:
                    self._flush_batch(batch)
                    last_flush_time = current_time

            except Exception as e:
                warnings.warn(f"Error in span service worker loop: {e}")
                # On error, still need to mark tasks as done to prevent hanging
                self._mark_done(batch)

        # Final flush on shutdown
        if batch:
            self._flush_batch(batch)

    def _fill_batch(self, batch: List[Dict[str, Any]], last_flush_time: float):
        """Move queued items into ``batch``, blocking at most until the next flush deadline."""
        if not batch or len(batch) < self.batch_size:
            timeout = self._QUEUE_POLL_SECONDS
            if batch:
                # Items are already pending: never sleep past their flush deadline.
                until_flush = self.flush_interval - (time.time() - last_flush_time)
                timeout = min(timeout, max(until_flush, 0.0))
            if self._shutdown_event.is_set():
                # Shutting down: drain what is there without waiting for more.
                timeout = 0.0
            try:
                if timeout > 0:
                    batch.append(self._span_queue.get(timeout=timeout))
                else:
                    batch.append(self._span_queue.get_nowait())
            except queue.Empty:
                return

        # Drain whatever else is immediately available, up to the batch size.
        while len(batch) < self.batch_size:
            try:
                batch.append(self._span_queue.get_nowait())
            except queue.Empty:
                break

    def _flush_batch(self, batch: List[Dict[str, Any]]):
        """Send ``batch`` and then mark its items done, whether or not sending worked."""
        try:
            self._send_batch(batch)
        finally:
            self._mark_done(batch)

    def _mark_done(self, batch: List[Dict[str, Any]]):
        """Call ``task_done`` once per item in ``batch`` and empty it so ``join`` can return."""
        for _ in range(len(batch)):
            self._span_queue.task_done()
        batch.clear()

    def _send_batch(self, batch: List[Dict[str, Any]]):
        """
        Send a batch of spans to the server.

        Items are split by their ``'type'`` key: ``'span'`` items go to the
        spans endpoint and ``'evaluation_run'`` items to the evaluation runs
        endpoint. Any exception is reported as a warning, not raised.

        Args:
            batch: List of span dictionaries to send
        """
        if not batch:
            return

        try:
            # Group spans by type for different endpoints
            spans_to_send = []
            evaluation_runs_to_send = []

            for item in batch:
                if item['type'] == 'span':
                    spans_to_send.append(item['data'])
                elif item['type'] == 'evaluation_run':
                    evaluation_runs_to_send.append(item['data'])

            # Send spans if any
            if spans_to_send:
                self._send_spans_batch(spans_to_send)

            # Send evaluation runs if any
            if evaluation_runs_to_send:
                self._send_evaluation_runs_batch(evaluation_runs_to_send)

        except Exception as e:
            warnings.warn(f"Failed to send span batch: {e}")

    @staticmethod
    def _fallback_encoder(obj):
        """``json.dumps`` fallback: use ``repr``, then ``str``, then a placeholder string."""
        try:
            return repr(obj)
        except Exception:
            try:
                return str(obj)
            except Exception as e:
                return f"<Unserializable object of type {type(obj).__name__}: {e}>"

    def _post_batch(self, url: str, payload: Dict[str, Any], label: str):
        """
        Serialize ``payload`` and POST it to ``url``.

        Network errors and non-OK responses are reported as warnings that
        mention ``label``. Serialization errors and any other exception
        propagate to the caller.
        """
        serialized_data = json.dumps(payload, default=self._fallback_encoder)

        try:
            response = requests.post(
                url,
                data=serialized_data,
                headers={
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {self.judgment_api_key}",
                    "X-Organization-Id": self.organization_id
                },
                verify=True,
                timeout=30  # Add timeout to prevent hanging
            )
        except requests.RequestException as e:
            warnings.warn(f"Network error sending {label} batch: {e}")
            return

        if response.status_code != HTTPStatus.OK:
            warnings.warn(f"Failed to send {label} batch: HTTP {response.status_code} - {response.text}")

    def _send_spans_batch(self, spans: List[Dict[str, Any]]):
        """Send a batch of spans to the spans endpoint."""
        payload = {
            "spans": spans,
            "organization_id": self.organization_id
        }

        try:
            self._post_batch(JUDGMENT_TRACES_SPANS_BATCH_API_URL, payload, "spans")
        except Exception as e:
            warnings.warn(f"Failed to serialize or send spans batch: {e}")

    def _send_evaluation_runs_batch(self, evaluation_runs: List[Dict[str, Any]]):
        """Send a batch of evaluation runs with their associated span data to the endpoint."""
        # Each entry carries the evaluation run fields plus the span snapshot
        # that was queued with it.
        evaluation_entries = [
            {
                "evaluation_run": {
                    key: value for key, value in eval_data.items()
                    if key not in self._EVALUATION_EXTRA_KEYS
                },
                "associated_span": {
                    "span_id": eval_data.get('associated_span_id'),
                    "span_data": eval_data.get('span_data')
                },
                "queued_at": eval_data.get('queued_at')
            }
            for eval_data in evaluation_runs
        ]

        payload = {
            "organization_id": self.organization_id,
            "evaluation_entries": evaluation_entries
        }

        try:
            self._post_batch(JUDGMENT_TRACES_EVALUATION_RUNS_BATCH_API_URL, payload, "evaluation runs")
        except Exception as e:
            warnings.warn(f"Failed to send evaluation runs batch: {e}")

    def queue_span(self, span: TraceSpan, span_state: str = "input"):
        """
        Queue a span for background sending.

        Nothing is queued once shutdown has been signalled.

        Args:
            span: The TraceSpan object to queue
            span_state: State label stored with the span, e.g. "input",
                "output", "completed" or "error"
        """
        if not self._shutdown_event.is_set():
            span_data = {
                "type": "span",
                "data": {
                    **span.model_dump(),
                    "span_state": span_state,
                    "queued_at": time.time()
                }
            }
            self._span_queue.put(span_data)

    def queue_evaluation_run(self, evaluation_run: EvaluationRun, span_id: str, span_data: TraceSpan):
        """
        Queue an evaluation run for background sending.

        Nothing is queued once shutdown has been signalled.

        Args:
            evaluation_run: The EvaluationRun object to queue
            span_id: The span ID associated with this evaluation run
            span_data: The span data at the time of evaluation (to avoid race conditions)
        """
        if not self._shutdown_event.is_set():
            eval_data = {
                "type": "evaluation_run",
                "data": {
                    **evaluation_run.model_dump(),
                    "associated_span_id": span_id,
                    "span_data": span_data.model_dump(),  # Include span data to avoid race conditions
                    "queued_at": time.time()
                }
            }
            self._span_queue.put(eval_data)

    def flush(self):
        """
        Block until every queued item has been processed by a worker.

        This joins the queue; it does not itself trigger a send. A partially
        filled batch goes out when its worker's flush interval elapses, or
        right away once shutdown has been signalled.
        """
        try:
            # Wait for the queue to be processed
            self._span_queue.join()
        except Exception as e:
            warnings.warn(f"Error during flush: {e}")

    def shutdown(self):
        """Shutdown the background service and flush remaining spans."""
        if self._shutdown_event.is_set():
            return

        try:
            # Signal shutdown to stop new items from being queued
            self._shutdown_event.set()

            # Try to flush any remaining spans
            try:
                self.flush()
            except Exception as e:
                warnings.warn(f"Error during final flush: {e}")
        except Exception as e:
            warnings.warn(f"Error during BackgroundSpanService shutdown: {e}")
        finally:
            # Workers are daemon threads, so they are not joined here.
            self._worker_threads.clear()

    def get_queue_size(self) -> int:
        """Get the current size of the span queue."""
        return self._span_queue.qsize()
1267
+
613
1268
  class _DeepTracer:
614
1269
  _instance: Optional["_DeepTracer"] = None
615
1270
  _lock: threading.Lock = threading.Lock()
@@ -619,6 +1274,9 @@ class _DeepTracer:
619
1274
  _original_sys_trace: Optional[Callable] = None
620
1275
  _original_threading_trace: Optional[Callable] = None
621
1276
 
1277
    def __init__(self, tracer: 'Tracer') -> None:
        # Keep a handle on the tracer; the tracing callbacks read and update
        # the current trace and span through it.
        # NOTE(review): __new__ appears to hand back a shared instance, so each
        # construction rebinds _tracer -- confirm that is intended.
        self._tracer = tracer
1279
+
622
1280
  def _get_qual_name(self, frame) -> str:
623
1281
  func_name = frame.f_code.co_name
624
1282
  module_name = frame.f_globals.get("__name__", "unknown_module")
@@ -632,7 +1290,7 @@ class _DeepTracer:
632
1290
  except Exception:
633
1291
  return f"{module_name}.{func_name}"
634
1292
 
635
- def __new__(cls):
1293
+ def __new__(cls, tracer: 'Tracer' = None):
636
1294
  with cls._lock:
637
1295
  if cls._instance is None:
638
1296
  cls._instance = super().__new__(cls)
@@ -718,11 +1376,11 @@ class _DeepTracer:
718
1376
  if event not in ("call", "return", "exception"):
719
1377
  return
720
1378
 
721
- current_trace = current_trace_var.get()
1379
+ current_trace = self._tracer.get_current_trace()
722
1380
  if not current_trace:
723
1381
  return
724
1382
 
725
- parent_span_id = current_span_var.get()
1383
+ parent_span_id = self._tracer.get_current_span()
726
1384
  if not parent_span_id:
727
1385
  return
728
1386
 
@@ -784,7 +1442,7 @@ class _DeepTracer:
784
1442
  })
785
1443
  self._span_stack.set(span_stack)
786
1444
 
787
- token = current_span_var.set(span_id)
1445
+ token = self._tracer.set_current_span(span_id)
788
1446
  frame.f_locals["_judgment_span_token"] = token
789
1447
 
790
1448
  span = TraceSpan(
@@ -818,7 +1476,7 @@ class _DeepTracer:
818
1476
  if not span_stack:
819
1477
  return
820
1478
 
821
- current_id = current_span_var.get()
1479
+ current_id = self._tracer.get_current_span()
822
1480
 
823
1481
  span_data = None
824
1482
  for i, entry in enumerate(reversed(span_stack)):
@@ -843,12 +1501,12 @@ class _DeepTracer:
843
1501
  del current_trace._span_depths[span_data["span_id"]]
844
1502
 
845
1503
  if span_stack:
846
- current_span_var.set(span_stack[-1]["span_id"])
1504
+ self._tracer.set_current_span(span_stack[-1]["span_id"])
847
1505
  else:
848
- current_span_var.set(span_data["parent_span_id"])
1506
+ self._tracer.set_current_span(span_data["parent_span_id"])
849
1507
 
850
1508
  if "_judgment_span_token" in frame.f_locals:
851
- current_span_var.reset(frame.f_locals["_judgment_span_token"])
1509
+ self._tracer.reset_current_span(frame.f_locals["_judgment_span_token"])
852
1510
 
853
1511
  elif event == "exception":
854
1512
  exc_type = arg[0]
@@ -887,18 +1545,28 @@ class _DeepTracer:
887
1545
  self._original_threading_trace = None
888
1546
 
889
1547
 
890
- def log(self, message: str, level: str = "info"):
891
- """ Log a message with the span context """
892
- current_trace = current_trace_var.get()
893
- if current_trace:
894
- current_trace.log(message, level)
895
- else:
896
- print(f"[{level}] {message}")
897
- current_trace.record_output({"log": message})
1548
+ # Below commented out function isn't used anymore?
1549
+
1550
+ # def log(self, message: str, level: str = "info"):
1551
+ # """ Log a message with the span context """
1552
+ # current_trace = self._tracer.get_current_trace()
1553
+ # if current_trace:
1554
+ # current_trace.log(message, level)
1555
+ # else:
1556
+ # print(f"[{level}] {message}")
1557
+ # current_trace.record_output({"log": message})
898
1558
 
899
1559
  class Tracer:
900
1560
  _instance = None
901
1561
 
1562
+ # Tracer.current_trace class variable is currently used in wrap()
1563
+ # TODO: Keep track of cross-context state for current trace and current span ID solely through class variables instead of instance variables?
1564
+ # Should be fine to do so as long as we keep Tracer as a singleton
1565
+ current_trace: Optional[TraceClient] = None
1566
+ # current_span_id: Optional[str] = None
1567
+
1568
+ trace_across_async_contexts: bool = False # BY default, we don't trace across async contexts
1569
+
902
1570
  def __new__(cls, *args, **kwargs):
903
1571
  if cls._instance is None:
904
1572
  cls._instance = super(Tracer, cls).__new__(cls)
@@ -907,7 +1575,7 @@ class Tracer:
907
1575
  def __init__(
908
1576
  self,
909
1577
  api_key: str = os.getenv("JUDGMENT_API_KEY"),
910
- project_name: str = "default_project",
1578
+ project_name: str = None,
911
1579
  rules: Optional[List[Rule]] = None, # Added rules parameter
912
1580
  organization_id: str = os.getenv("JUDGMENT_ORG_ID"),
913
1581
  enable_monitoring: bool = os.getenv("JUDGMENT_MONITORING", "true").lower() == "true",
@@ -919,7 +1587,13 @@ class Tracer:
919
1587
  s3_aws_secret_access_key: Optional[str] = None,
920
1588
  s3_region_name: Optional[str] = None,
921
1589
  offline_mode: bool = False,
922
- deep_tracing: bool = True # Deep tracing is enabled by default
1590
+ deep_tracing: bool = True, # Deep tracing is enabled by default
1591
+ trace_across_async_contexts: bool = False, # BY default, we don't trace across async contexts
1592
+ # Background span service configuration
1593
+ enable_background_spans: bool = True, # Enable background span service by default
1594
+ span_batch_size: int = 50, # Number of spans to batch before sending
1595
+ span_flush_interval: float = 1.0, # Time in seconds between automatic flushes
1596
+ span_num_workers: int = 10 # Number of worker threads for span processing
923
1597
  ):
924
1598
  if not hasattr(self, 'initialized'):
925
1599
  if not api_key:
@@ -935,16 +1609,20 @@ class Tracer:
935
1609
  raise ValueError("S3 bucket name must be provided when use_s3 is True")
936
1610
 
937
1611
  self.api_key: str = api_key
938
- self.project_name: str = project_name
1612
+ self.project_name: str = project_name or str(uuid.uuid4())
939
1613
  self.organization_id: str = organization_id
940
- self._current_trace: Optional[str] = None
941
- self._active_trace_client: Optional[TraceClient] = None # Add active trace client attribute
942
1614
  self.rules: List[Rule] = rules or [] # Store rules at tracer level
943
1615
  self.traces: List[Trace] = []
944
1616
  self.initialized: bool = True
945
1617
  self.enable_monitoring: bool = enable_monitoring
946
1618
  self.enable_evaluations: bool = enable_evaluations
947
1619
  self.class_identifiers: Dict[str, str] = {} # Dictionary to store class identifiers
1620
+ self.span_id_to_previous_span_id: Dict[str, str] = {}
1621
+ self.trace_id_to_previous_trace: Dict[str, TraceClient] = {}
1622
+ self.current_span_id: Optional[str] = None
1623
+ self.current_trace: Optional[TraceClient] = None
1624
+ self.trace_across_async_contexts: bool = trace_across_async_contexts
1625
+ Tracer.trace_across_async_contexts = trace_across_async_contexts
948
1626
 
949
1627
  # Initialize S3 storage if enabled
950
1628
  self.use_s3 = use_s3
@@ -958,6 +1636,18 @@ class Tracer:
958
1636
  )
959
1637
  self.offline_mode: bool = offline_mode
960
1638
  self.deep_tracing: bool = deep_tracing # NEW: Store deep tracing setting
1639
+
1640
+ # Initialize background span service
1641
+ self.enable_background_spans: bool = enable_background_spans
1642
+ self.background_span_service: Optional[BackgroundSpanService] = None
1643
+ if enable_background_spans and not offline_mode:
1644
+ self.background_span_service = BackgroundSpanService(
1645
+ judgment_api_key=api_key,
1646
+ organization_id=organization_id,
1647
+ batch_size=span_batch_size,
1648
+ flush_interval=span_flush_interval,
1649
+ num_workers=span_num_workers
1650
+ )
961
1651
 
962
1652
  elif hasattr(self, 'project_name') and self.project_name != project_name:
963
1653
  warnings.warn(
@@ -968,16 +1658,44 @@ class Tracer:
968
1658
  )
969
1659
 
970
1660
  def set_current_span(self, span_id: str):
1661
+ self.span_id_to_previous_span_id[span_id] = getattr(self, 'current_span_id', None)
971
1662
  self.current_span_id = span_id
1663
+ Tracer.current_span_id = span_id
1664
+ try:
1665
+ token = current_span_var.set(span_id)
1666
+ except:
1667
+ token = None
1668
+ return token
972
1669
 
973
1670
  def get_current_span(self) -> Optional[str]:
974
- return getattr(self, 'current_span_id', None)
1671
+ try:
1672
+ current_span_var_val = current_span_var.get()
1673
+ except:
1674
+ current_span_var_val = None
1675
+ return (self.current_span_id or current_span_var_val) if self.trace_across_async_contexts else current_span_var_val
1676
+
1677
+ def reset_current_span(self, token: Optional[str] = None, span_id: Optional[str] = None):
1678
+ if not span_id:
1679
+ span_id = self.current_span_id
1680
+ try:
1681
+ current_span_var.reset(token)
1682
+ except:
1683
+ pass
1684
+ self.current_span_id = self.span_id_to_previous_span_id.get(span_id)
1685
+ Tracer.current_span_id = self.current_span_id
975
1686
 
976
1687
  def set_current_trace(self, trace: TraceClient):
977
1688
  """
978
1689
  Set the current trace context in contextvars
979
1690
  """
980
- current_trace_var.set(trace)
1691
+ self.trace_id_to_previous_trace[trace.trace_id] = getattr(self, 'current_trace', None)
1692
+ self.current_trace = trace
1693
+ Tracer.current_trace = trace
1694
+ try:
1695
+ token = current_trace_var.set(trace)
1696
+ except:
1697
+ token = None
1698
+ return token
981
1699
 
982
1700
  def get_current_trace(self) -> Optional[TraceClient]:
983
1701
  """
@@ -987,23 +1705,34 @@ class Tracer:
987
1705
  If not found (e.g., context lost across threads/tasks),
988
1706
  it falls back to the active trace client managed by the callback handler.
989
1707
  """
990
- trace_from_context = current_trace_var.get()
991
- if trace_from_context:
992
- return trace_from_context
1708
+ try:
1709
+ current_trace_var_val = current_trace_var.get()
1710
+ except:
1711
+ current_trace_var_val = None
1712
+
1713
+ # Use context variable or class variable based on trace_across_async_contexts setting
1714
+ context_trace = (self.current_trace or current_trace_var_val) if self.trace_across_async_contexts else current_trace_var_val
1715
+
1716
+ # If we found a trace from context, return it
1717
+ if context_trace:
1718
+ return context_trace
993
1719
 
994
- # Fallback: Check the active client potentially set by a callback handler
1720
+ # Fallback: Check the active client potentially set by a callback handler (e.g., LangGraph)
995
1721
  if hasattr(self, '_active_trace_client') and self._active_trace_client:
996
- # warnings.warn("Falling back to _active_trace_client in get_current_trace. ContextVar might be lost.", RuntimeWarning)
997
1722
  return self._active_trace_client
998
1723
 
999
- # If neither is available
1000
- # warnings.warn("No current trace found in context variable or active client fallback.", RuntimeWarning)
1724
+ # If neither is available, return None
1001
1725
  return None
1002
-
1003
- def get_active_trace_client(self) -> Optional[TraceClient]:
1004
- """Returns the TraceClient instance currently marked as active by the handler."""
1005
- return self._active_trace_client
1006
-
1726
+
1727
+ def reset_current_trace(self, token: Optional[str] = None, trace_id: Optional[str] = None):
1728
+ if not trace_id and self.current_trace:
1729
+ trace_id = self.current_trace.trace_id
1730
+ try:
1731
+ current_trace_var.reset(token)
1732
+ except:
1733
+ pass
1734
+ self.current_trace = self.trace_id_to_previous_trace.get(trace_id)
1735
+ Tracer.current_trace = self.current_trace
1007
1736
 
1008
1737
  @contextmanager
1009
1738
  def trace(
@@ -1018,7 +1747,7 @@ class Tracer:
1018
1747
  project = project_name if project_name is not None else self.project_name
1019
1748
 
1020
1749
  # Get parent trace info from context
1021
- parent_trace = current_trace_var.get()
1750
+ parent_trace = self.get_current_trace()
1022
1751
  parent_trace_id = None
1023
1752
  parent_name = None
1024
1753
 
@@ -1040,7 +1769,7 @@ class Tracer:
1040
1769
  )
1041
1770
 
1042
1771
  # Set the current trace in context variables
1043
- token = current_trace_var.set(trace)
1772
+ token = self.set_current_trace(trace)
1044
1773
 
1045
1774
  # Automatically create top-level span
1046
1775
  with trace.span(name or "unnamed_trace") as span:
@@ -1049,13 +1778,13 @@ class Tracer:
1049
1778
  yield trace
1050
1779
  finally:
1051
1780
  # Reset the context variable
1052
- current_trace_var.reset(token)
1781
+ self.reset_current_trace(token)
1053
1782
 
1054
1783
 
1055
1784
  def log(self, msg: str, label: str = "log", score: int = 1):
1056
1785
  """Log a message with the current span context"""
1057
- current_span_id = current_span_var.get()
1058
- current_trace = current_trace_var.get()
1786
+ current_span_id = self.get_current_span()
1787
+ current_trace = self.get_current_trace()
1059
1788
  if current_span_id:
1060
1789
  annotation = TraceAnnotation(
1061
1790
  span_id=current_span_id,
@@ -1068,32 +1797,92 @@ class Tracer:
1068
1797
 
1069
1798
  rprint(f"[bold]{label}:[/bold] {msg}")
1070
1799
 
1071
- def identify(self, identifier: str):
1800
+ def identify(self, identifier: str, track_state: bool = False, track_attributes: Optional[List[str]] = None, field_mappings: Optional[Dict[str, str]] = None):
1072
1801
  """
1073
- Class decorator that associates a class with a custom identifier.
1802
+ Class decorator that associates a class with a custom identifier and enables state tracking.
1074
1803
 
1075
1804
  This decorator creates a mapping between the class name and the provided
1076
1805
  identifier, which can be useful for tagging, grouping, or referencing
1077
- classes in a standardized way.
1806
+ classes in a standardized way. It also enables automatic state capture
1807
+ for instances of the decorated class when used with tracing.
1078
1808
 
1079
1809
  Args:
1080
- identifier: The identifier to associate with the decorated class
1081
-
1082
- Returns:
1083
- A decorator function that registers the class with the given identifier
1810
+ identifier: The identifier to associate with the decorated class.
1811
+ This will be used as the instance name in traces.
1812
+ track_state: Whether to automatically capture the state (attributes)
1813
+ of instances before and after function execution. Defaults to False.
1814
+ track_attributes: Optional list of specific attribute names to track.
1815
+ If None, all non-private attributes (not starting with '_')
1816
+ will be tracked when track_state=True.
1817
+ field_mappings: Optional dictionary mapping internal attribute names to
1818
+ display names in the captured state. For example:
1819
+ {"system_prompt": "instructions"} will capture the
1820
+ 'instructions' attribute as 'system_prompt' in the state.
1084
1821
 
1085
1822
  Example:
1086
- @tracer.identify(identifier="user_model")
1823
+ @tracer.identify(identifier="user_model", track_state=True, track_attributes=["name", "age"], field_mappings={"system_prompt": "instructions"})
1087
1824
  class User:
1088
1825
  # Class implementation
1089
1826
  """
1090
1827
  def decorator(cls):
1091
1828
  class_name = cls.__name__
1092
- self.class_identifiers[class_name] = identifier
1829
+ self.class_identifiers[class_name] = {
1830
+ "identifier": identifier,
1831
+ "track_state": track_state,
1832
+ "track_attributes": track_attributes,
1833
+ "field_mappings": field_mappings or {}
1834
+ }
1093
1835
  return cls
1094
1836
 
1095
1837
  return decorator
1096
1838
 
1839
+ def _capture_instance_state(self, instance: Any, class_config: Dict[str, Any]) -> Dict[str, Any]:
1840
+ """
1841
+ Capture the state of an instance based on class configuration.
1842
+ Args:
1843
+ instance: The instance to capture the state of.
1844
+ class_config: Configuration dictionary for state capture,
1845
+ expected to contain 'track_attributes' and 'field_mappings'.
1846
+ """
1847
+ track_attributes = class_config.get('track_attributes')
1848
+ field_mappings = class_config.get('field_mappings')
1849
+
1850
+ if track_attributes:
1851
+
1852
+ state = {attr: getattr(instance, attr, None) for attr in track_attributes}
1853
+ else:
1854
+
1855
+ state = {k: v for k, v in instance.__dict__.items() if not k.startswith('_')}
1856
+
1857
+ if field_mappings:
1858
+ state['field_mappings'] = field_mappings
1859
+
1860
+ return state
1861
+
1862
+
1863
+ def _get_instance_state_if_tracked(self, args):
1864
+ """
1865
+ Extract instance state if the instance should be tracked.
1866
+
1867
+ Returns the captured state dict if tracking is enabled, None otherwise.
1868
+ """
1869
+ if args and hasattr(args[0], '__class__'):
1870
+ instance = args[0]
1871
+ class_name = instance.__class__.__name__
1872
+ if (class_name in self.class_identifiers and
1873
+ isinstance(self.class_identifiers[class_name], dict) and
1874
+ self.class_identifiers[class_name].get('track_state', False)):
1875
+ return self._capture_instance_state(instance, self.class_identifiers[class_name])
1876
+
1877
+ def _conditionally_capture_and_record_state(self, trace_client_instance: TraceClient, args: tuple, is_before: bool):
1878
+ """Captures instance state if tracked and records it via the trace_client."""
1879
+ state = self._get_instance_state_if_tracked(args)
1880
+ if state:
1881
+ if is_before:
1882
+ trace_client_instance.record_state_before(state)
1883
+ else:
1884
+ trace_client_instance.record_state_after(state)
1885
+
1097
1886
  def observe(self, func=None, *, name=None, span_type: SpanType = "span", project_name: str = None, overwrite: bool = False, deep_tracing: bool = None):
1098
1887
  """
1099
1888
  Decorator to trace function execution with detailed entry/exit information.
@@ -1139,7 +1928,7 @@ class Tracer:
1139
1928
  agent_name = get_instance_prefixed_name(args[0], class_name, self.class_identifiers)
1140
1929
 
1141
1930
  # Get current trace from context
1142
- current_trace = current_trace_var.get()
1931
+ current_trace = self.get_current_trace()
1143
1932
 
1144
1933
  # If there's no current trace, create a root trace
1145
1934
  if not current_trace:
@@ -1160,7 +1949,7 @@ class Tracer:
1160
1949
 
1161
1950
  # Save empty trace and set trace context
1162
1951
  # current_trace.save(empty_save=True, overwrite=overwrite)
1163
- trace_token = current_trace_var.set(current_trace)
1952
+ trace_token = self.set_current_trace(current_trace)
1164
1953
 
1165
1954
  try:
1166
1955
  # Use span for the function execution within the root trace
@@ -1171,9 +1960,12 @@ class Tracer:
1171
1960
  span.record_input(inputs)
1172
1961
  if agent_name:
1173
1962
  span.record_agent_name(agent_name)
1963
+
1964
+ # Capture state before execution
1965
+ self._conditionally_capture_and_record_state(span, args, is_before=True)
1174
1966
 
1175
1967
  if use_deep_tracing:
1176
- with _DeepTracer():
1968
+ with _DeepTracer(self):
1177
1969
  result = await func(*args, **kwargs)
1178
1970
  else:
1179
1971
  try:
@@ -1181,17 +1973,39 @@ class Tracer:
1181
1973
  except Exception as e:
1182
1974
  _capture_exception_for_trace(current_trace, sys.exc_info())
1183
1975
  raise e
1184
-
1976
+
1977
+ # Capture state after execution
1978
+ self._conditionally_capture_and_record_state(span, args, is_before=False)
1979
+
1185
1980
  # Record output
1186
1981
  span.record_output(result)
1187
1982
  return result
1188
1983
  finally:
1984
+ # Flush background spans before saving the trace
1985
+
1986
+ complete_trace_data = {
1987
+ "trace_id": current_trace.trace_id,
1988
+ "name": current_trace.name,
1989
+ "created_at": datetime.utcfromtimestamp(current_trace.start_time).isoformat(),
1990
+ "duration": current_trace.get_duration(),
1991
+ "trace_spans": [span.model_dump() for span in current_trace.trace_spans],
1992
+ "overwrite": overwrite,
1993
+ "offline_mode": self.offline_mode,
1994
+ "parent_trace_id": current_trace.parent_trace_id,
1995
+ "parent_name": current_trace.parent_name
1996
+ }
1189
1997
  # Save the completed trace
1190
- trace_id, trace = current_trace.save(overwrite=overwrite)
1191
- self.traces.append(trace)
1998
+ trace_id, server_response = current_trace.save_with_rate_limiting(overwrite=overwrite, final_save=True)
1999
+
2000
+ # Store the complete trace data instead of just server response
2001
+
2002
+ self.traces.append(complete_trace_data)
2003
+
2004
+ # if self.background_span_service:
2005
+ # self.background_span_service.flush()
1192
2006
 
1193
2007
  # Reset trace context (span context resets automatically)
1194
- current_trace_var.reset(trace_token)
2008
+ self.reset_current_trace(trace_token)
1195
2009
  else:
1196
2010
  with current_trace.span(span_name, span_type=span_type) as span:
1197
2011
  inputs = combine_args_kwargs(func, args, kwargs)
@@ -1199,8 +2013,11 @@ class Tracer:
1199
2013
  if agent_name:
1200
2014
  span.record_agent_name(agent_name)
1201
2015
 
2016
+ # Capture state before execution
2017
+ self._conditionally_capture_and_record_state(span, args, is_before=True)
2018
+
1202
2019
  if use_deep_tracing:
1203
- with _DeepTracer():
2020
+ with _DeepTracer(self):
1204
2021
  result = await func(*args, **kwargs)
1205
2022
  else:
1206
2023
  try:
@@ -1208,6 +2025,9 @@ class Tracer:
1208
2025
  except Exception as e:
1209
2026
  _capture_exception_for_trace(current_trace, sys.exc_info())
1210
2027
  raise e
2028
+
2029
+ # Capture state after execution
2030
+ self._conditionally_capture_and_record_state(span, args, is_before=False)
1211
2031
 
1212
2032
  span.record_output(result)
1213
2033
  return result
@@ -1226,7 +2046,7 @@ class Tracer:
1226
2046
  class_name = args[0].__class__.__name__
1227
2047
  agent_name = get_instance_prefixed_name(args[0], class_name, self.class_identifiers)
1228
2048
  # Get current trace from context
1229
- current_trace = current_trace_var.get()
2049
+ current_trace = self.get_current_trace()
1230
2050
 
1231
2051
  # If there's no current trace, create a root trace
1232
2052
  if not current_trace:
@@ -1247,7 +2067,7 @@ class Tracer:
1247
2067
 
1248
2068
  # Save empty trace and set trace context
1249
2069
  # current_trace.save(empty_save=True, overwrite=overwrite)
1250
- trace_token = current_trace_var.set(current_trace)
2070
+ trace_token = self.set_current_trace(current_trace)
1251
2071
 
1252
2072
  try:
1253
2073
  # Use span for the function execution within the root trace
@@ -1258,8 +2078,11 @@ class Tracer:
1258
2078
  span.record_input(inputs)
1259
2079
  if agent_name:
1260
2080
  span.record_agent_name(agent_name)
2081
+ # Capture state before execution
2082
+ self._conditionally_capture_and_record_state(span, args, is_before=True)
2083
+
1261
2084
  if use_deep_tracing:
1262
- with _DeepTracer():
2085
+ with _DeepTracer(self):
1263
2086
  result = func(*args, **kwargs)
1264
2087
  else:
1265
2088
  try:
@@ -1267,17 +2090,36 @@ class Tracer:
1267
2090
  except Exception as e:
1268
2091
  _capture_exception_for_trace(current_trace, sys.exc_info())
1269
2092
  raise e
2093
+
2094
+ # Capture state after execution
2095
+ self._conditionally_capture_and_record_state(span, args, is_before=False)
2096
+
1270
2097
 
1271
2098
  # Record output
1272
2099
  span.record_output(result)
1273
2100
  return result
1274
2101
  finally:
1275
- # Save the completed trace
1276
- trace_id, trace = current_trace.save(overwrite=overwrite)
1277
- self.traces.append(trace)
2102
+ # Flush background spans before saving the trace
1278
2103
 
2104
+
2105
+ # Save the completed trace
2106
+ trace_id, server_response = current_trace.save_with_rate_limiting(overwrite=overwrite, final_save=True)
2107
+
2108
+ # Store the complete trace data instead of just server response
2109
+ complete_trace_data = {
2110
+ "trace_id": current_trace.trace_id,
2111
+ "name": current_trace.name,
2112
+ "created_at": datetime.utcfromtimestamp(current_trace.start_time).isoformat(),
2113
+ "duration": current_trace.get_duration(),
2114
+ "trace_spans": [span.model_dump() for span in current_trace.trace_spans],
2115
+ "overwrite": overwrite,
2116
+ "offline_mode": self.offline_mode,
2117
+ "parent_trace_id": current_trace.parent_trace_id,
2118
+ "parent_name": current_trace.parent_name
2119
+ }
2120
+ self.traces.append(complete_trace_data)
1279
2121
  # Reset trace context (span context resets automatically)
1280
- current_trace_var.reset(trace_token)
2122
+ self.reset_current_trace(trace_token)
1281
2123
  else:
1282
2124
  with current_trace.span(span_name, span_type=span_type) as span:
1283
2125
 
@@ -1286,8 +2128,11 @@ class Tracer:
1286
2128
  if agent_name:
1287
2129
  span.record_agent_name(agent_name)
1288
2130
 
2131
+ # Capture state before execution
2132
+ self._conditionally_capture_and_record_state(span, args, is_before=True)
2133
+
1289
2134
  if use_deep_tracing:
1290
- with _DeepTracer():
2135
+ with _DeepTracer(self):
1291
2136
  result = func(*args, **kwargs)
1292
2137
  else:
1293
2138
  try:
@@ -1296,11 +2141,63 @@ class Tracer:
1296
2141
  _capture_exception_for_trace(current_trace, sys.exc_info())
1297
2142
  raise e
1298
2143
 
2144
+ # Capture state after execution
2145
+ self._conditionally_capture_and_record_state(span, args, is_before=False)
2146
+
1299
2147
  span.record_output(result)
1300
2148
  return result
1301
2149
 
1302
2150
  return wrapper
1303
2151
 
2152
def observe_tools(self, cls=None, *, exclude_methods: Optional[List[str]] = None,
                  include_private: bool = False, warn_on_double_decoration: bool = True):
    """
    Automatically adds @observe(span_type="tool") to all methods in a class.

    Works both bare (``observe_tools(SomeClass)``) and as a decorator factory
    (``observe_tools(exclude_methods=[...])``), which returns the class
    decorator. The class is modified in place and returned. When
    ``self.enable_monitoring`` is false the class is returned untouched.

    Static methods and class methods are unwrapped, traced, and wrapped back
    in the same descriptor type so they keep their binding behaviour. Nested
    classes and non-callable attributes are left alone.

    Args:
        cls: The class to decorate (automatically provided when used as decorator)
        exclude_methods: List of method names to skip decorating. Defaults to common magic methods
        include_private: Whether to decorate methods starting with underscore. Defaults to False
        warn_on_double_decoration: Whether to print warnings when skipping already-decorated
            methods (also gates the warning printed when decorating a method fails). Defaults to True

    Returns:
        The decorated class, or a class decorator when ``cls`` is None.
    """

    if exclude_methods is None:
        exclude_methods = ['__init__', '__new__', '__del__', '__str__', '__repr__']

    def decorate_class(cls):
        if not self.enable_monitoring:
            return cls

        for name in dir(cls):
            # Filter on the name first so excluded attributes are never fetched.
            if name in exclude_methods or (name.startswith('_') and not include_private):
                continue

            # Static lookup bypasses the descriptor protocol, so staticmethod /
            # classmethod objects are seen as such instead of already unwrapped.
            raw = inspect.getattr_static(cls, name, None)
            if isinstance(raw, (staticmethod, classmethod)):
                rewrap, target = type(raw), raw.__func__
            else:
                rewrap, target = None, raw

            # Nested classes are callable but must stay classes.
            if not callable(target) or isinstance(target, type):
                continue

            # NOTE(review): presumably observe() marks its wrappers with
            # `_judgment_span_name` -- confirm against observe().
            if hasattr(target, '_judgment_span_name'):
                if warn_on_double_decoration:
                    print(f"Warning: {cls.__name__}.{name} already decorated, skipping")
                continue

            try:
                traced = self.observe(target, span_type="tool")
                setattr(cls, name, rewrap(traced) if rewrap else traced)
            except Exception as e:
                # Best effort: one failing method must not abort the rest.
                if warn_on_double_decoration:
                    print(f"Warning: Failed to decorate {cls.__name__}.{name}: {e}")

        return cls

    return decorate_class if cls is None else decorate_class(cls)
2200
+
1304
2201
  def async_evaluate(self, *args, **kwargs):
1305
2202
  if not self.enable_evaluations:
1306
2203
  return
@@ -1308,13 +2205,7 @@ class Tracer:
1308
2205
  # --- Get trace_id passed explicitly (if any) ---
1309
2206
  passed_trace_id = kwargs.pop('trace_id', None) # Get and remove trace_id from kwargs
1310
2207
 
1311
- # --- Get current trace from context FIRST ---
1312
- current_trace = current_trace_var.get()
1313
-
1314
- # --- Fallback Logic: Use active client only if context var is empty ---
1315
- if not current_trace:
1316
- current_trace = self._active_trace_client # Use the fallback
1317
- # --- End Fallback Logic ---
2208
+ current_trace = self.get_current_trace()
1318
2209
 
1319
2210
  if current_trace:
1320
2211
  # Pass the explicitly provided trace_id if it exists, otherwise let async_evaluate handle it
@@ -1325,13 +2216,34 @@ class Tracer:
1325
2216
  else:
1326
2217
  warnings.warn("No trace found (context var or fallback), skipping evaluation") # Modified warning
1327
2218
 
1328
- def wrap(client: Any) -> Any:
2219
def get_background_span_service(self) -> Optional[BackgroundSpanService]:
    """Return this tracer's ``background_span_service`` attribute unchanged.

    Returns:
        The stored service object, or ``None`` when no service is set (for
        example after ``shutdown_background_service()`` has cleared it).
    """
    return self.background_span_service
2222
+
2223
def flush_background_spans(self):
    """Ask the background span service, when one is set, to flush.

    Does nothing if ``self.background_span_service`` is falsy.
    """
    service = self.background_span_service
    if not service:
        return
    service.flush()
2227
+
2228
def shutdown_background_service(self):
    """Shut down the background span service and drop the reference to it.

    Does nothing if ``self.background_span_service`` is falsy. Otherwise the
    service's ``shutdown()`` is called and the attribute is reset to ``None``,
    so later calls become no-ops.
    """
    service = self.background_span_service
    if not service:
        return
    service.shutdown()
    self.background_span_service = None
2233
+
2234
+ def wrap(client: Any, trace_across_async_contexts: bool = Tracer.trace_across_async_contexts) -> Any:
1329
2235
  """
1330
2236
  Wraps an API client to add tracing capabilities.
1331
2237
  Supports OpenAI, Together, Anthropic, and Google GenAI clients.
1332
2238
  Patches both '.create' and Anthropic's '.stream' methods using a wrapper class.
1333
2239
  """
1334
2240
  span_name, original_create, original_responses_create, original_stream = _get_client_config(client)
2241
+
2242
def _get_current_trace():
    """Return the trace that wrapped client calls should attach to.

    Reads the class-level ``Tracer.current_trace`` when
    ``trace_across_async_contexts`` is true, and the value of
    ``current_trace_var`` otherwise.
    """
    if not trace_across_async_contexts:
        return current_trace_var.get()
    return Tracer.current_trace
1335
2247
 
1336
2248
  def _record_input_and_check_streaming(span, kwargs, is_responses=False):
1337
2249
  """Record input and check for streaming"""
@@ -1367,18 +2279,21 @@ def wrap(client: Any) -> Any:
1367
2279
  output, usage = format_func(client, response)
1368
2280
  span.record_output(output)
1369
2281
  span.record_usage(usage)
2282
+
2283
+ # Queue the completed LLM span now that it has all data (input, output, usage)
2284
+ current_trace = _get_current_trace()
2285
+ if current_trace and current_trace.background_span_service:
2286
+ # Get the current span from the trace client
2287
+ current_span_id = current_trace.get_current_span()
2288
+ if current_span_id and current_span_id in current_trace.span_id_to_span:
2289
+ completed_span = current_trace.span_id_to_span[current_span_id]
2290
+ current_trace.background_span_service.queue_span(completed_span, span_state="completed")
2291
+
1370
2292
  return response
1371
2293
 
1372
- def _handle_error(span, e, is_async):
1373
- """Handle and record errors"""
1374
- call_type = "async" if is_async else "sync"
1375
- print(f"Error during wrapped {call_type} API call ({span_name}): {e}")
1376
- span.record_output({"error": str(e)})
1377
- raise
1378
-
1379
2294
  # --- Traced Async Functions ---
1380
2295
  async def traced_create_async(*args, **kwargs):
1381
- current_trace = current_trace_var.get()
2296
+ current_trace = _get_current_trace()
1382
2297
  if not current_trace:
1383
2298
  return await original_create(*args, **kwargs)
1384
2299
 
@@ -1389,11 +2304,12 @@ def wrap(client: Any) -> Any:
1389
2304
  response_or_iterator = await original_create(*args, **kwargs)
1390
2305
  return _format_and_record_output(span, response_or_iterator, is_streaming, True, False)
1391
2306
  except Exception as e:
1392
- return _handle_error(span, e, True)
2307
+ _capture_exception_for_trace(span, sys.exc_info())
2308
+ raise e
1393
2309
 
1394
2310
  # Async responses for OpenAI clients
1395
2311
  async def traced_response_create_async(*args, **kwargs):
1396
- current_trace = current_trace_var.get()
2312
+ current_trace = _get_current_trace()
1397
2313
  if not current_trace:
1398
2314
  return await original_responses_create(*args, **kwargs)
1399
2315
 
@@ -1404,11 +2320,12 @@ def wrap(client: Any) -> Any:
1404
2320
  response_or_iterator = await original_responses_create(*args, **kwargs)
1405
2321
  return _format_and_record_output(span, response_or_iterator, is_streaming, True, True)
1406
2322
  except Exception as e:
1407
- return _handle_error(span, e, True)
2323
+ _capture_exception_for_trace(span, sys.exc_info())
2324
+ raise e
1408
2325
 
1409
2326
  # Function replacing .stream() for async clients
1410
2327
  def traced_stream_async(*args, **kwargs):
1411
- current_trace = current_trace_var.get()
2328
+ current_trace = _get_current_trace()
1412
2329
  if not current_trace or not original_stream:
1413
2330
  return original_stream(*args, **kwargs)
1414
2331
 
@@ -1424,7 +2341,7 @@ def wrap(client: Any) -> Any:
1424
2341
 
1425
2342
  # --- Traced Sync Functions ---
1426
2343
  def traced_create_sync(*args, **kwargs):
1427
- current_trace = current_trace_var.get()
2344
+ current_trace = _get_current_trace()
1428
2345
  if not current_trace:
1429
2346
  return original_create(*args, **kwargs)
1430
2347
 
@@ -1435,10 +2352,11 @@ def wrap(client: Any) -> Any:
1435
2352
  response_or_iterator = original_create(*args, **kwargs)
1436
2353
  return _format_and_record_output(span, response_or_iterator, is_streaming, False, False)
1437
2354
  except Exception as e:
1438
- return _handle_error(span, e, False)
2355
+ _capture_exception_for_trace(span, sys.exc_info())
2356
+ raise e
1439
2357
 
1440
2358
  def traced_response_create_sync(*args, **kwargs):
1441
- current_trace = current_trace_var.get()
2359
+ current_trace = _get_current_trace()
1442
2360
  if not current_trace:
1443
2361
  return original_responses_create(*args, **kwargs)
1444
2362
 
@@ -1449,11 +2367,12 @@ def wrap(client: Any) -> Any:
1449
2367
  response_or_iterator = original_responses_create(*args, **kwargs)
1450
2368
  return _format_and_record_output(span, response_or_iterator, is_streaming, False, True)
1451
2369
  except Exception as e:
1452
- return _handle_error(span, e, False)
2370
+ _capture_exception_for_trace(span, sys.exc_info())
2371
+ raise e
1453
2372
 
1454
2373
  # Function replacing sync .stream()
1455
2374
  def traced_stream_sync(*args, **kwargs):
1456
- current_trace = current_trace_var.get()
2375
+ current_trace = _get_current_trace()
1457
2376
  if not current_trace or not original_stream:
1458
2377
  return original_stream(*args, **kwargs)
1459
2378
 
@@ -1472,6 +2391,8 @@ def wrap(client: Any) -> Any:
1472
2391
  client.chat.completions.create = traced_create_async
1473
2392
  if hasattr(client, "responses") and hasattr(client.responses, "create"):
1474
2393
  client.responses.create = traced_response_create_async
2394
+ if hasattr(client, "beta") and hasattr(client.beta, "chat") and hasattr(client.beta.chat, "completions") and hasattr(client.beta.chat.completions, "parse"):
2395
+ client.beta.chat.completions.parse = traced_create_async
1475
2396
  elif isinstance(client, AsyncAnthropic):
1476
2397
  client.messages.create = traced_create_async
1477
2398
  if original_stream:
@@ -1482,6 +2403,8 @@ def wrap(client: Any) -> Any:
1482
2403
  client.chat.completions.create = traced_create_sync
1483
2404
  if hasattr(client, "responses") and hasattr(client.responses, "create"):
1484
2405
  client.responses.create = traced_response_create_sync
2406
+ if hasattr(client, "beta") and hasattr(client.beta, "chat") and hasattr(client.beta.chat, "completions") and hasattr(client.beta.chat.completions, "parse"):
2407
+ client.beta.chat.completions.parse = traced_create_sync
1485
2408
  elif isinstance(client, Anthropic):
1486
2409
  client.messages.create = traced_create_sync
1487
2410
  if original_stream:
@@ -1808,6 +2731,15 @@ def _sync_stream_wrapper(
1808
2731
  # Update the trace entry with the accumulated content and usage
1809
2732
  span.output = "".join(content_parts)
1810
2733
  span.usage = final_usage
2734
+
2735
+ # Queue the completed LLM span now that streaming is done and all data is available
2736
+ # Note: We need to get the TraceClient that owns this span to access the background service
2737
+ # We can find this through the tracer singleton since spans are associated with traces
2738
+ from judgeval.common.tracer import Tracer
2739
+ tracer_instance = Tracer._instance
2740
+ if tracer_instance and tracer_instance.background_span_service:
2741
+ tracer_instance.background_span_service.queue_span(span, span_state="completed")
2742
+
1811
2743
  # Note: We might need to adjust _serialize_output if this dict causes issues,
1812
2744
  # but Pydantic's model_dump should handle dicts.
1813
2745
 
@@ -1892,6 +2824,12 @@ async def _async_stream_wrapper(
1892
2824
  span.usage = usage_info
1893
2825
  start_ts = getattr(span, 'created_at', time.time())
1894
2826
  span.duration = time.time() - start_ts
2827
+
2828
+ # Queue the completed LLM span now that async streaming is done and all data is available
2829
+ from judgeval.common.tracer import Tracer
2830
+ tracer_instance = Tracer._instance
2831
+ if tracer_instance and tracer_instance.background_span_service:
2832
+ tracer_instance.background_span_service.queue_span(span, span_state="completed")
1895
2833
  # else: # Handle error case if necessary, but remove debug print
1896
2834
 
1897
2835
  def cost_per_token(*args, **kwargs):
@@ -1940,12 +2878,12 @@ class _BaseStreamManagerWrapper:
1940
2878
 
1941
2879
  class _TracedAsyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractAsyncContextManager):
1942
2880
  async def __aenter__(self):
1943
- self._parent_span_id_at_entry = current_span_var.get()
2881
+ self._parent_span_id_at_entry = self._trace_client.get_current_span()
1944
2882
  if not self._trace_client:
1945
2883
  return await self._original_manager.__aenter__()
1946
2884
 
1947
2885
  span_id, span = self._create_span()
1948
- self._span_context_token = current_span_var.set(span_id)
2886
+ self._span_context_token = self._trace_client.set_current_span(span_id)
1949
2887
  span.inputs = _format_input_data(self._client, **self._input_kwargs)
1950
2888
 
1951
2889
  # Call the original __aenter__ and expect it to be an async generator
@@ -1955,20 +2893,20 @@ class _TracedAsyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractAsyncC
1955
2893
 
1956
2894
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Close the tracing span (if one was opened) and exit the wrapped manager.

    A span counts as opened when ``_span_context_token`` exists on this
    wrapper. In that case the trace client's current span is finalized, the
    client's current-span context is reset with the stored token, and the
    token attribute is removed. The wrapped manager's ``__aexit__`` is always
    awaited afterwards and its result is returned.
    """
    span_was_opened = hasattr(self, '_span_context_token')
    if span_was_opened:
        client = self._trace_client
        self._finalize_span(client.get_current_span())
        client.reset_current_span(self._span_context_token)
        del self._span_context_token
    return await self._original_manager.__aexit__(exc_type, exc_val, exc_tb)
1963
2901
 
1964
2902
  class _TracedSyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractContextManager):
1965
2903
  def __enter__(self):
1966
- self._parent_span_id_at_entry = current_span_var.get()
2904
+ self._parent_span_id_at_entry = self._trace_client.get_current_span()
1967
2905
  if not self._trace_client:
1968
2906
  return self._original_manager.__enter__()
1969
2907
 
1970
2908
  span_id, span = self._create_span()
1971
- self._span_context_token = current_span_var.set(span_id)
2909
+ self._span_context_token = self._trace_client.set_current_span(span_id)
1972
2910
  span.inputs = _format_input_data(self._client, **self._input_kwargs)
1973
2911
 
1974
2912
  raw_iterator = self._original_manager.__enter__()
@@ -1977,9 +2915,9 @@ class _TracedSyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractContext
1977
2915
 
1978
2916
def __exit__(self, exc_type, exc_val, exc_tb):
    """Close the tracing span (if one was opened) and exit the wrapped manager.

    A span counts as opened when ``_span_context_token`` exists on this
    wrapper. In that case the trace client's current span is finalized, the
    client's current-span context is reset with the stored token, and the
    token attribute is removed. The wrapped manager's ``__exit__`` is always
    called afterwards and its result is returned.
    """
    span_was_opened = hasattr(self, '_span_context_token')
    if span_was_opened:
        client = self._trace_client
        self._finalize_span(client.get_current_span())
        client.reset_current_span(self._span_context_token)
        del self._span_context_token
    return self._original_manager.__exit__(exc_type, exc_val, exc_tb)
1985
2923
 
@@ -1990,10 +2928,12 @@ def get_instance_prefixed_name(instance, class_name, class_identifiers):
1990
2928
  Otherwise, returns None.
1991
2929
  """
1992
2930
  if class_name in class_identifiers:
1993
- attr = class_identifiers[class_name]
2931
+ class_config = class_identifiers[class_name]
2932
+ attr = class_config['identifier']
2933
+
1994
2934
  if hasattr(instance, attr):
1995
2935
  instance_name = getattr(instance, attr)
1996
2936
  return instance_name
1997
2937
  else:
1998
- raise Exception(f"Attribute {class_identifiers[class_name]} does not exist for {class_name}. Check your identify() decorator.")
1999
- return None
2938
+ raise Exception(f"Attribute {attr} does not exist for {class_name}. Check your identify() decorator.")
2939
+ return None