judgeval 0.0.36__py3-none-any.whl → 0.0.37__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
judgeval/common/tracer.py CHANGED
@@ -7,7 +7,11 @@ import functools
7
7
  import inspect
8
8
  import json
9
9
  import os
10
+ import site
11
+ import sysconfig
12
+ import threading
10
13
  import time
14
+ import traceback
11
15
  import uuid
12
16
  import warnings
13
17
  import contextvars
@@ -35,7 +39,6 @@ from rich import print as rprint
35
39
  import types # <--- Add this import
36
40
 
37
41
  # Third-party imports
38
- import pika
39
42
  import requests
40
43
  from litellm import cost_per_token
41
44
  from pydantic import BaseModel
@@ -44,10 +47,10 @@ from openai import OpenAI, AsyncOpenAI
44
47
  from together import Together, AsyncTogether
45
48
  from anthropic import Anthropic, AsyncAnthropic
46
49
  from google import genai
47
- from judgeval.run_evaluation import check_examples
48
50
 
49
51
  # Local application/library-specific imports
50
52
  from judgeval.constants import (
53
+ JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
51
54
  JUDGMENT_TRACES_SAVE_API_URL,
52
55
  JUDGMENT_TRACES_FETCH_API_URL,
53
56
  RABBITMQ_HOST,
@@ -56,25 +59,24 @@ from judgeval.constants import (
56
59
  JUDGMENT_TRACES_DELETE_API_URL,
57
60
  JUDGMENT_PROJECT_DELETE_API_URL,
58
61
  )
59
- from judgeval.judgment_client import JudgmentClient
60
- from judgeval.data import Example
62
+ from judgeval.data import Example, Trace, TraceSpan
61
63
  from judgeval.scorers import APIJudgmentScorer, JudgevalScorer
62
64
  from judgeval.rules import Rule
63
65
  from judgeval.evaluation_run import EvaluationRun
64
66
  from judgeval.data.result import ScoringResult
67
+ from judgeval.common.utils import validate_api_key
68
+ from judgeval.common.exceptions import JudgmentAPIError
65
69
 
66
70
  # Standard library imports needed for the new class
67
71
  import concurrent.futures
68
72
  from collections.abc import Iterator, AsyncIterator # Add Iterator and AsyncIterator
69
73
 
70
74
  # Define context variables for tracking the current trace and the current span within a trace
71
- current_trace_var = contextvars.ContextVar('current_trace', default=None)
75
+ current_trace_var = contextvars.ContextVar[Optional['TraceClient']]('current_trace', default=None)
72
76
  current_span_var = contextvars.ContextVar('current_span', default=None) # ContextVar for the active span name
73
- in_traced_function_var = contextvars.ContextVar('in_traced_function', default=False) # Track if we're in a traced function
74
77
 
75
78
  # Define type aliases for better code readability and maintainability
76
79
  ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic, AsyncOpenAI, AsyncAnthropic, AsyncTogether, genai.Client, genai.client.AsyncClient] # Supported API clients
77
- TraceEntryType = Literal['enter', 'exit', 'output', 'input', 'evaluation'] # Valid trace entry types
78
80
  SpanType = Literal['span', 'tool', 'llm', 'evaluation', 'chain']
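The context variables above carry the active trace and span across call boundaries; 0.0.37 parameterizes current_trace_var with Optional['TraceClient'] so type checkers know what get() returns. A minimal sketch of the get/set/reset pattern used throughout this module (illustrative only):

    token = current_trace_var.set(trace_client)      # bind the active TraceClient for this context
    try:
        active = current_trace_var.get()             # TraceClient inside a trace, None outside
    finally:
        current_trace_var.reset(token)               # restore whatever was bound before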
79
81
 
80
82
  # --- Evaluation Config Dataclass (Moved from langgraph.py) ---
@@ -87,154 +89,26 @@ class EvaluationConfig:
87
89
  log_results: Optional[bool] = True
88
90
  # --- End Evaluation Config Dataclass ---
89
91
 
92
+ # Temporary as a POC to have log use the existing annotations feature until log endpoints are ready
90
93
  @dataclass
91
- class TraceEntry:
92
- """Represents a single trace entry with its visual representation.
93
-
94
- Visual representations:
95
- - enter: → (function entry)
96
- - exit: ← (function exit)
97
- - output: Output: (function return value)
98
- - input: Input: (function parameters)
99
- - evaluation: Evaluation: (evaluation results)
100
- """
101
- type: TraceEntryType
102
- span_id: str # Unique ID for this specific span instance
103
- depth: int # Indentation level for nested calls
104
- created_at: float # Unix timestamp when entry was created, replacing the deprecated 'timestamp' field
105
- function: Optional[str] = None # Name of the function being traced
106
- message: Optional[str] = None # Human-readable description
107
- duration: Optional[float] = None # Time taken (for exit/evaluation entries)
108
- trace_id: str = None # ID of the trace this entry belongs to
109
- output: Any = None # Function output value
110
- # Use field() for mutable defaults to avoid shared state issues
111
- inputs: dict = field(default_factory=dict)
112
- span_type: SpanType = "span"
113
- evaluation_runs: List[Optional[EvaluationRun]] = field(default=None)
114
- parent_span_id: Optional[str] = None # ID of the parent span instance
115
-
116
- def print_entry(self):
117
- """Print a trace entry with proper formatting and parent relationship information."""
118
- indent = " " * self.depth
119
-
120
- if self.type == "enter":
121
- # Format parent info if present
122
- parent_info = f" (parent_id: {self.parent_span_id})" if self.parent_span_id else ""
123
- print(f"{indent}→ {self.function} (id: {self.span_id}){parent_info} (trace: {self.message})")
124
- elif self.type == "exit":
125
- print(f"{indent}← {self.function} (id: {self.span_id}) ({self.duration:.3f}s)")
126
- elif self.type == "output":
127
- # Format output to align properly
128
- output_str = str(self.output)
129
- print(f"{indent}Output (for id: {self.span_id}): {output_str}")
130
- elif self.type == "input":
131
- # Format inputs to align properly
132
- print(f"{indent}Input (for id: {self.span_id}): {self.inputs}")
133
- elif self.type == "evaluation":
134
- for evaluation_run in self.evaluation_runs:
135
- print(f"{indent}Evaluation (for id: {self.span_id}): {evaluation_run.model_dump()}")
136
-
137
- def _serialize_inputs(self) -> dict:
138
- """Helper method to serialize input data safely.
139
-
140
- Returns a dict with serializable versions of inputs, converting non-serializable
141
- objects to None with a warning.
142
- """
143
- serialized_inputs = {}
144
- for key, value in self.inputs.items():
145
- if isinstance(value, BaseModel):
146
- serialized_inputs[key] = value.model_dump()
147
- elif isinstance(value, (list, tuple)):
148
- # Handle lists/tuples of arguments
149
- serialized_inputs[key] = [
150
- item.model_dump() if isinstance(item, BaseModel)
151
- else None if not self._is_json_serializable(item)
152
- else item
153
- for item in value
154
- ]
155
- else:
156
- if self._is_json_serializable(value):
157
- serialized_inputs[key] = value
158
- else:
159
- serialized_inputs[key] = self.safe_stringify(value, self.function)
160
- return serialized_inputs
161
-
162
- def _is_json_serializable(self, obj: Any) -> bool:
163
- """Helper method to check if an object is JSON serializable."""
164
- try:
165
- json.dumps(obj)
166
- return True
167
- except (TypeError, OverflowError, ValueError):
168
- return False
169
-
170
- def safe_stringify(self, output, function_name):
171
- """
172
- Safely converts an object to a string or repr, handling serialization issues gracefully.
173
- """
174
- try:
175
- return str(output)
176
- except (TypeError, OverflowError, ValueError):
177
- pass
178
-
179
- try:
180
- return repr(output)
181
- except (TypeError, OverflowError, ValueError):
182
- pass
183
-
184
- warnings.warn(
185
- f"Output for function {function_name} is not JSON serializable and could not be converted to string. Setting to None."
186
- )
187
- return None
94
+ class TraceAnnotation:
95
+ """Represents a single annotation for a trace span."""
96
+ span_id: str
97
+ text: str
98
+ label: str
99
+ score: int
188
100
 
189
101
  def to_dict(self) -> dict:
190
- """Convert the trace entry to a dictionary format for storage/transmission."""
102
+ """Convert the annotation to a dictionary format for storage/transmission."""
191
103
  return {
192
- "type": self.type,
193
- "function": self.function,
194
104
  "span_id": self.span_id,
195
- "trace_id": self.trace_id,
196
- "depth": self.depth,
197
- "message": self.message,
198
- "created_at": datetime.fromtimestamp(self.created_at).isoformat(),
199
- "duration": self.duration,
200
- "output": self._serialize_output(),
201
- "inputs": self._serialize_inputs(),
202
- "evaluation_runs": [evaluation_run.model_dump() for evaluation_run in self.evaluation_runs] if self.evaluation_runs else [],
203
- "span_type": self.span_type,
204
- "parent_span_id": self.parent_span_id,
105
+ "annotation": {
106
+ "text": self.text,
107
+ "label": self.label,
108
+ "score": self.score
109
+ }
205
110
  }
206
-
207
- def _serialize_output(self) -> Any:
208
- """Helper method to serialize output data safely.
209
-
210
- Handles special cases:
211
- - Pydantic models are converted using model_dump()
212
- - Dictionaries are processed recursively to handle non-serializable values.
213
- - We try to serialize into JSON, then string, then the base representation (__repr__)
214
- - Non-serializable objects return None with a warning
215
- """
216
-
217
- def serialize_value(value):
218
- if isinstance(value, BaseModel):
219
- return value.model_dump()
220
- elif isinstance(value, dict):
221
- # Recursively serialize dictionary values
222
- return {k: serialize_value(v) for k, v in value.items()}
223
- elif isinstance(value, (list, tuple)):
224
- # Recursively serialize list/tuple items
225
- return [serialize_value(item) for item in value]
226
- else:
227
- # Try direct JSON serialization first
228
- try:
229
- json.dumps(value)
230
- return value
231
- except (TypeError, OverflowError, ValueError):
232
- # Fallback to safe stringification
233
- return self.safe_stringify(value, self.function)
234
-
235
- # Start serialization with the top-level output
236
- return serialize_value(self.output)
237
-
111
+
238
112
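For illustration, constructing a TraceAnnotation and serializing it produces the nested payload shape used by the annotation endpoint below (the span_id here is a made-up placeholder):

    annotation = TraceAnnotation(
        span_id="7f3c9a2e-...",                      # hypothetical span UUID
        text="Retrieved 3 documents",
        label="retrieval",
        score=1,
    )
    annotation.to_dict()
    # -> {"span_id": "7f3c9a2e-...", "annotation": {"text": "Retrieved 3 documents", "label": "retrieval", "score": 1}}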
  class TraceManagerClient:
239
113
  """
240
114
  Client for handling trace endpoints with the Judgment API
@@ -271,8 +145,6 @@ class TraceManagerClient:
271
145
  raise ValueError(f"Failed to fetch traces: {response.text}")
272
146
 
273
147
  return response.json()
274
-
275
-
276
148
 
277
149
  def save_trace(self, trace_data: dict):
278
150
  """
@@ -315,6 +187,33 @@ class TraceManagerClient:
315
187
  pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={response.json()['ui_results_url']}]View Trace[/link]\n"
316
188
  rprint(pretty_str)
317
189
 
190
+ ## TODO: Should have a log endpoint, endpoint should also support batched payloads
191
+ def save_annotation(self, annotation: TraceAnnotation):
192
+ json_data = {
193
+ "span_id": annotation.span_id,
194
+ "annotation": {
195
+ "text": annotation.text,
196
+ "label": annotation.label,
197
+ "score": annotation.score
198
+ }
199
+ }
200
+
201
+ response = requests.post(
202
+ JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
203
+ json=json_data,
204
+ headers={
205
+ 'Content-Type': 'application/json',
206
+ 'Authorization': f'Bearer {self.judgment_api_key}',
207
+ 'X-Organization-Id': self.organization_id
208
+ },
209
+ verify=True
210
+ )
211
+
212
+ if response.status_code != HTTPStatus.OK:
213
+ raise ValueError(f"Failed to save annotation: {response.text}")
214
+
215
+ return response.json()
216
+
318
217
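In practice annotations are queued on a TraceClient via add_annotation() and posted here when the trace is saved; a direct call is also possible. Sketch (the annotation values are illustrative):

    manager = trace_client.trace_manager_client       # TraceManagerClient built from the Tracer's credentials
    manager.save_annotation(TraceAnnotation(
        span_id=current_span_var.get(),               # annotate the currently active span
        text="output verified against source documents",
        label="manual_review",
        score=1,
    ))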
  def delete_trace(self, trace_id: str):
319
218
  """
320
219
  Delete a trace from the database.
@@ -405,15 +304,16 @@ class TraceClient:
405
304
  self.enable_evaluations = enable_evaluations
406
305
  self.parent_trace_id = parent_trace_id
407
306
  self.parent_name = parent_name
408
- self.client: JudgmentClient = tracer.client
409
- self.entries: List[TraceEntry] = []
307
+ self.trace_spans: List[TraceSpan] = []
308
+ self.span_id_to_span: Dict[str, TraceSpan] = {}
309
+ self.evaluation_runs: List[EvaluationRun] = []
310
+ self.annotations: List[TraceAnnotation] = []
410
311
  self.start_time = time.time()
411
312
  self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id, tracer)
412
313
  self.visited_nodes = []
413
314
  self.executed_tools = []
414
315
  self.executed_node_tools = []
415
316
  self._span_depths: Dict[str, int] = {} # NEW: To track depth of active spans
416
-
417
317
  def get_current_span(self):
418
318
  """Get the current span from the context var"""
419
319
  return current_span_var.get()
@@ -443,9 +343,7 @@ class TraceClient:
443
343
 
444
344
  self._span_depths[span_id] = current_depth # Store depth by span_id
445
345
 
446
- entry = TraceEntry(
447
- type="enter",
448
- function=name,
346
+ span = TraceSpan(
449
347
  span_id=span_id,
450
348
  trace_id=self.trace_id,
451
349
  depth=current_depth,
@@ -453,25 +351,15 @@ class TraceClient:
453
351
  created_at=start_time,
454
352
  span_type=span_type,
455
353
  parent_span_id=parent_span_id,
354
+ function=name,
456
355
  )
457
- self.add_entry(entry)
356
+ self.add_span(span)
458
357
 
459
358
  try:
460
359
  yield self
461
360
  finally:
462
361
  duration = time.time() - start_time
463
- exit_depth = self._span_depths.get(span_id, 0) # Get depth using this span's ID
464
- self.add_entry(TraceEntry(
465
- type="exit",
466
- function=name,
467
- span_id=span_id, # Use the same span_id for exit
468
- trace_id=self.trace_id, # Use the trace_id from the trace client
469
- depth=exit_depth,
470
- message=f"← {name}",
471
- created_at=time.time(),
472
- duration=duration,
473
- span_type=span_type,
474
- ))
362
+ span.duration = duration
475
363
  # Clean up depth tracking for this span_id
476
364
  if span_id in self._span_depths:
477
365
  del self._span_depths[span_id]
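A sketch of how the span context manager is used (this mirrors the @observe wrapper later in the file; retrieve and query are illustrative user code). Note that span() yields the TraceClient itself, and the new TraceSpan's duration is filled in when the block exits:

    with trace_client.span("retrieve_documents", span_type="tool") as span:
        span.record_input({"query": query})           # stored on the active TraceSpan
        docs = retrieve(query)                        # hypothetical user function
        span.record_output(docs)
    # on exit: the span's duration is set and its depth bookkeeping is removed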
@@ -528,13 +416,13 @@ class TraceClient:
528
416
  tools_called=tools_called,
529
417
  expected_tools=expected_tools,
530
418
  additional_metadata=additional_metadata,
531
- trace_id=self.trace_id
532
419
  )
533
420
  else:
534
421
  raise ValueError("Either 'example' or at least one of the individual parameters (input, actual_output, etc.) must be provided")
535
422
 
536
423
  # Check examples before creating evaluation run
537
- check_examples([example], scorers)
424
+
425
+ # check_examples([example], scorers)
538
426
 
539
427
  # --- Modification: Capture span_id immediately ---
540
428
  # span_id_at_eval_call = current_span_var.get()
@@ -571,290 +459,60 @@ class TraceClient:
571
459
  # --- End Modification ---
572
460
 
573
461
  if current_span_id:
574
- duration = time.time() - start_time
575
- prev_entry = self.entries[-1] if self.entries else None
576
- # Determine function name based on previous entry or context var (less ideal)
577
- function_name = "unknown_function" # Default
578
- if prev_entry and prev_entry.span_type == "llm":
579
- function_name = prev_entry.function
580
- else:
581
- # Try to find the function name associated with the current span_id
582
- for entry in reversed(self.entries):
583
- if entry.span_id == current_span_id and entry.type == 'enter':
584
- function_name = entry.function
585
- break
586
-
587
- # Get depth for the current span
588
- current_depth = self._span_depths.get(current_span_id, 0)
589
-
590
- self.add_entry(TraceEntry(
591
- type="evaluation",
592
- function=function_name,
593
- span_id=current_span_id, # Associate with current span
594
- trace_id=self.trace_id, # Use the trace_id from the trace client
595
- depth=current_depth,
596
- message=f"Evaluation results for {function_name}",
597
- created_at=time.time(),
598
- evaluation_runs=[eval_run],
599
- duration=duration,
600
- span_type="evaluation"
601
- ))
602
-
462
+ span = self.span_id_to_span[current_span_id]
463
+ span.evaluation_runs.append(eval_run)
464
+ self.evaluation_runs.append(eval_run)
465
+
466
+ def add_annotation(self, annotation: TraceAnnotation):
467
+ """Add an annotation to this trace context"""
468
+ self.annotations.append(annotation)
469
+ return self
470
+
603
471
  def record_input(self, inputs: dict):
604
472
  current_span_id = current_span_var.get()
605
473
  if current_span_id:
606
- entry_span_type = "span"
607
- current_depth = self._span_depths.get(current_span_id, 0)
608
- function_name = "unknown_function" # Default
609
- for entry in reversed(self.entries):
610
- if entry.span_id == current_span_id and entry.type == 'enter':
611
- entry_span_type = entry.span_type
612
- function_name = entry.function
613
- break
614
-
615
- self.add_entry(TraceEntry(
616
- type="input",
617
- function=function_name,
618
- span_id=current_span_id, # Use current span_id from context
619
- trace_id=self.trace_id, # Use the trace_id from the trace client
620
- depth=current_depth,
621
- message=f"Inputs to {function_name}",
622
- created_at=time.time(),
623
- inputs=inputs,
624
- span_type=entry_span_type,
625
- ))
626
- # Removed else block - original didn't have one
474
+ span = self.span_id_to_span[current_span_id]
475
+ span.inputs = inputs
627
476
 
628
- async def _update_coroutine_output(self, entry: TraceEntry, coroutine: Any):
477
+ async def _update_coroutine_output(self, span: TraceSpan, coroutine: Any):
629
478
  """Helper method to update the output of a trace entry once the coroutine completes"""
630
479
  try:
631
480
  result = await coroutine
632
- entry.output = result
481
+ span.output = result
633
482
  return result
634
483
  except Exception as e:
635
- entry.output = f"Error: {str(e)}"
484
+ span.output = f"Error: {str(e)}"
636
485
  raise
637
486
 
638
487
  def record_output(self, output: Any):
639
488
  current_span_id = current_span_var.get()
640
489
  if current_span_id:
641
- entry_span_type = "span"
642
- current_depth = self._span_depths.get(current_span_id, 0)
643
- function_name = "unknown_function" # Default
644
- for entry in reversed(self.entries):
645
- if entry.span_id == current_span_id and entry.type == 'enter':
646
- entry_span_type = entry.span_type
647
- function_name = entry.function
648
- break
649
-
650
- entry = TraceEntry(
651
- type="output",
652
- function=function_name,
653
- span_id=current_span_id, # Use current span_id from context
654
- depth=current_depth,
655
- message=f"Output from {function_name}",
656
- created_at=time.time(),
657
- output="<pending>" if inspect.iscoroutine(output) else output,
658
- span_type=entry_span_type,
659
- trace_id=self.trace_id # Added trace_id for consistency
660
- )
661
- self.add_entry(entry)
490
+ span = self.span_id_to_span[current_span_id]
491
+ span.output = "<pending>" if inspect.iscoroutine(output) else output
662
492
 
663
493
  if inspect.iscoroutine(output):
664
- asyncio.create_task(self._update_coroutine_output(entry, output))
494
+ asyncio.create_task(self._update_coroutine_output(span, output))
665
495
 
666
- return entry # Return the created entry
496
+ return span # Return the created entry
667
497
  # Removed else block - original didn't have one
668
498
  return None # Return None if no span_id found
669
499
 
670
- def add_entry(self, entry: TraceEntry):
671
- """Add a trace entry to this trace context"""
672
- self.entries.append(entry)
500
+ def add_span(self, span: TraceSpan):
501
+ """Add a trace span to this trace context"""
502
+ self.trace_spans.append(span)
503
+ self.span_id_to_span[span.span_id] = span
673
504
  return self
674
505
 
675
506
  def print(self):
676
507
  """Print the complete trace with proper visual structure"""
677
- for entry in self.entries:
678
- entry.print_entry()
679
-
680
- def print_hierarchical(self):
681
- """Print the trace in a hierarchical structure based on parent-child relationships"""
682
- # First, build a map of spans
683
- spans = {}
684
- root_spans = []
685
-
686
- # Collect all enter events first
687
- for entry in self.entries:
688
- if entry.type == "enter":
689
- spans[entry.function] = {
690
- "name": entry.function,
691
- "depth": entry.depth,
692
- "parent_id": entry.parent_span_id,
693
- "children": []
694
- }
695
-
696
- # If no parent, it's a root span
697
- if not entry.parent_span_id:
698
- root_spans.append(entry.function)
699
- elif entry.parent_span_id not in spans:
700
- # If parent doesn't exist yet, temporarily treat as root
701
- # (we'll fix this later)
702
- root_spans.append(entry.function)
703
-
704
- # Build parent-child relationships
705
- for span_name, span in spans.items():
706
- parent = span["parent_id"]
707
- if parent and parent in spans:
708
- spans[parent]["children"].append(span_name)
709
- # Remove from root spans if it was temporarily there
710
- if span_name in root_spans:
711
- root_spans.remove(span_name)
712
-
713
- # Now print the hierarchy
714
- def print_span(span_name, level=0):
715
- if span_name not in spans:
716
- return
717
-
718
- span = spans[span_name]
719
- indent = " " * level
720
- parent_info = f" (parent_id: {span['parent_id']})" if span["parent_id"] else ""
721
- print(f"{indent}→ {span_name}{parent_info}")
722
-
723
- # Print children
724
- for child in span["children"]:
725
- print_span(child, level + 1)
726
-
727
- # Print starting with root spans
728
- print("\nHierarchical Trace Structure:")
729
- for root in root_spans:
730
- print_span(root)
508
+ for span in self.trace_spans:
509
+ span.print_span()
731
510
 
732
511
  def get_duration(self) -> float:
733
512
  """
734
513
  Get the total duration of this trace
735
514
  """
736
515
  return time.time() - self.start_time
737
-
738
- def condense_trace(self, entries: List[dict]) -> List[dict]:
739
- """
740
- Condenses trace entries into a single entry for each span instance,
741
- preserving parent-child span relationships using span_id and parent_span_id.
742
- """
743
- spans_by_id: Dict[str, dict] = {}
744
- evaluation_runs: List[EvaluationRun] = []
745
-
746
- # First pass: Group entries by span_id and gather data
747
- for entry in entries:
748
- span_id = entry.get("span_id")
749
- if not span_id:
750
- continue # Skip entries without a span_id (should not happen)
751
-
752
- if entry["type"] == "enter":
753
- if span_id not in spans_by_id:
754
- spans_by_id[span_id] = {
755
- "span_id": span_id,
756
- "function": entry["function"],
757
- "depth": entry["depth"], # Use the depth recorded at entry time
758
- "created_at": entry["created_at"],
759
- "trace_id": entry["trace_id"],
760
- "parent_span_id": entry.get("parent_span_id"),
761
- "span_type": entry.get("span_type", "span"),
762
- "inputs": None,
763
- "output": None,
764
- "evaluation_runs": [],
765
- "duration": None
766
- }
767
- # Handle potential duplicate enter events if necessary (e.g., log warning)
768
-
769
- elif span_id in spans_by_id:
770
- current_span_data = spans_by_id[span_id]
771
-
772
- if entry["type"] == "input" and entry["inputs"]:
773
- # Merge inputs if multiple are recorded, or just assign
774
- if current_span_data["inputs"] is None:
775
- current_span_data["inputs"] = entry["inputs"]
776
- elif isinstance(current_span_data["inputs"], dict) and isinstance(entry["inputs"], dict):
777
- current_span_data["inputs"].update(entry["inputs"])
778
- # Add more sophisticated merging if needed
779
-
780
- elif entry["type"] == "output" and "output" in entry:
781
- current_span_data["output"] = entry["output"]
782
-
783
- elif entry["type"] == "evaluation" and entry.get("evaluation_runs"):
784
- if current_span_data.get("evaluation_runs") is not None:
785
- evaluation_runs.extend(entry["evaluation_runs"])
786
-
787
- elif entry["type"] == "exit":
788
- if current_span_data["duration"] is None: # Calculate duration only once
789
- start_time = datetime.fromisoformat(current_span_data.get("created_at", entry["created_at"]))
790
- end_time = datetime.fromisoformat(entry["created_at"])
791
- current_span_data["duration"] = (end_time - start_time).total_seconds()
792
- # Update depth if exit depth is different (though current span() implementation keeps it same)
793
- # current_span_data["depth"] = entry["depth"]
794
-
795
- # Convert dictionary to a list initially for easier access
796
- spans_list = list(spans_by_id.values())
797
-
798
- # Build tree structure (adjacency list) and find roots
799
- children_map: Dict[Optional[str], List[dict]] = {}
800
- roots = []
801
- span_map = {span['span_id']: span for span in spans_list} # Map for quick lookup
802
-
803
- for span in spans_list:
804
- parent_id = span.get("parent_span_id")
805
- if parent_id is None:
806
- roots.append(span)
807
- else:
808
- if parent_id not in children_map:
809
- children_map[parent_id] = []
810
- children_map[parent_id].append(span)
811
-
812
- # Sort roots by timestamp
813
- roots.sort(key=lambda x: datetime.fromisoformat(x.get("created_at", "1970-01-01T00:00:00")))
814
-
815
- # Perform depth-first traversal to get the final sorted list
816
- sorted_condensed_list = []
817
- visited = set() # To handle potential cycles, though unlikely with UUIDs
818
-
819
- def dfs(span_data):
820
- span_id = span_data['span_id']
821
- if span_id in visited:
822
- return # Avoid infinite loops in case of cycles
823
- visited.add(span_id)
824
-
825
- sorted_condensed_list.append(span_data) # Add parent before children
826
-
827
- # Get children, sort them by created_at, and visit them
828
- span_children = children_map.get(span_id, [])
829
- span_children.sort(key=lambda x: datetime.fromisoformat(x.get("created_at", "1970-01-01T00:00:00")))
830
- for child in span_children:
831
- # Ensure the child exists in our map before recursing
832
- if child['span_id'] in span_map:
833
- dfs(child)
834
- else:
835
- # This case might indicate an issue, but we'll add the child directly
836
- # if its parent was processed but the child itself wasn't in the initial list?
837
- # Or if the child's 'enter' event was missing. For robustness, add it.
838
- if child['span_id'] not in visited:
839
- visited.add(child['span_id'])
840
- sorted_condensed_list.append(child)
841
-
842
-
843
- # Start DFS from each root
844
- for root_span in roots:
845
- if root_span['span_id'] not in visited:
846
- dfs(root_span)
847
-
848
- # Handle spans that might not have been reachable from roots (orphans)
849
- # Though ideally, all spans should descend from a root.
850
- for span_data in spans_list:
851
- if span_data['span_id'] not in visited:
852
- # Decide how to handle orphans, maybe append them at the end sorted by time?
853
- # For now, let's just add them to ensure they aren't lost.
854
- sorted_condensed_list.append(span_data)
855
-
856
-
857
- return sorted_condensed_list, evaluation_runs
858
516
 
859
517
  def save(self, overwrite: bool = False) -> Tuple[str, dict]:
860
518
  """
@@ -863,44 +521,36 @@ class TraceClient:
863
521
  """
864
522
  # Calculate total elapsed time
865
523
  total_duration = self.get_duration()
866
-
867
- raw_entries = [entry.to_dict() for entry in self.entries]
868
-
869
- condensed_entries, evaluation_runs = self.condense_trace(raw_entries)
870
524
 
871
525
  # Only count tokens for actual LLM API call spans
872
526
  llm_span_names = {"OPENAI_API_CALL", "TOGETHER_API_CALL", "ANTHROPIC_API_CALL", "GOOGLE_API_CALL"}
873
- for entry in condensed_entries:
874
- entry_function_name = entry.get("function", "") # Get function name safely
527
+ for span in self.trace_spans:
528
+ span_function_name = span.function # Get function name safely
875
529
  # Check if it's an LLM span AND function name CONTAINS an API call suffix AND output is dict
876
- is_llm_entry = entry.get("span_type") == "llm"
877
- has_api_suffix = any(suffix in entry_function_name for suffix in llm_span_names)
878
- output_is_dict = isinstance(entry.get("output"), dict)
530
+ is_llm_span = span.span_type == "llm"
531
+ has_api_suffix = any(suffix in span_function_name for suffix in llm_span_names)
532
+ output_is_dict = isinstance(span.output, dict)
879
533
 
880
534
  # --- DEBUG PRINT 1: Check if condition passes ---
881
535
  # if is_llm_entry and has_api_suffix and output_is_dict:
882
- # # print(f"[DEBUG TraceClient.save] Processing entry: {entry.get('span_id')} ({entry_function_name}) - Condition PASSED")
883
536
  # elif is_llm_entry:
884
537
  # # Print why it failed if it was an LLM entry
885
- # print(f"[DEBUG TraceClient.save] Skipping LLM entry: {entry.get('span_id')} ({entry_function_name}) - Suffix Match: {has_api_suffix}, Output is Dict: {output_is_dict}")
886
538
  # # --- END DEBUG ---
887
539
 
888
- if is_llm_entry and has_api_suffix and output_is_dict:
889
- output = entry["output"]
540
+ if is_llm_span and has_api_suffix and output_is_dict:
541
+ output = span.output
890
542
  usage = output.get("usage", {}) # Gets the 'usage' dict from the 'output' field
891
543
 
892
544
  # --- DEBUG PRINT 2: Check extracted usage ---
893
- # print(f"[DEBUG TraceClient.save] Extracted usage dict: {usage}")
894
545
  # --- END DEBUG ---
895
546
 
896
547
  # --- NEW: Extract model_name correctly from nested inputs ---
897
548
  model_name = None
898
- entry_inputs = entry.get("inputs", {})
899
- # print(f"[DEBUG TraceClient.save] Inspecting inputs for span {entry.get('span_id')}: {entry_inputs}") # DEBUG Inputs
900
- if entry_inputs:
549
+ span_inputs = span.inputs
550
+ if span_inputs:
901
551
  # Try common locations for model name within the inputs structure
902
- invocation_params = entry_inputs.get("invocation_params", {})
903
- serialized_data = entry_inputs.get("serialized", {})
552
+ invocation_params = span_inputs.get("invocation_params", {})
553
+ serialized_data = span_inputs.get("serialized", {})
904
554
 
905
555
  # Look in invocation_params (often directly contains model)
906
556
  if isinstance(invocation_params, dict):
@@ -920,10 +570,9 @@ class TraceClient:
920
570
 
921
571
  # Fallback: Check top-level of inputs itself (less likely for callbacks)
922
572
  if not model_name:
923
- model_name = entry_inputs.get("model")
573
+ model_name = span_inputs.get("model")
924
574
 
925
575
 
926
- # print(f"[DEBUG TraceClient.save] Determined model_name: {model_name}") # DEBUG Model Name
927
576
  # --- END NEW ---
928
577
 
929
578
  prompt_tokens = 0
@@ -985,7 +634,7 @@ class TraceClient:
985
634
  if "usage" not in output:
986
635
  output["usage"] = {} # Initialize if missing
987
636
  elif not isinstance(output["usage"], dict): # Handle cases where 'usage' might not be a dict (e.g., placeholder string)
988
- print(f"[WARN TraceClient.save] Output 'usage' for span {entry.get('span_id')} was not a dict ({type(output['usage'])}). Resetting before adding costs.")
637
+ print(f"[WARN TraceClient.save] Output 'usage' for span {span.span_id} was not a dict ({type(output['usage'])}). Resetting before adding costs.")
989
638
  output["usage"] = {} # Reset to dict
990
639
 
991
640
  output["usage"]["prompt_tokens_cost_usd"] = prompt_cost
@@ -993,10 +642,10 @@ class TraceClient:
993
642
  output["usage"]["total_cost_usd"] = prompt_cost + completion_cost
994
643
  except Exception as e:
995
644
  # If cost calculation fails, continue without adding costs
996
- print(f"Error calculating cost for model '{model_name}' (span: {entry.get('span_id')}): {str(e)}")
645
+ print(f"Error calculating cost for model '{model_name}' (span: {span.span_id}): {str(e)}")
997
646
  pass
998
647
  else:
999
- print(f"[WARN TraceClient.save] Could not determine model name for cost calculation (span: {entry.get('span_id')}). Inputs: {entry_inputs}")
648
+ print(f"[WARN TraceClient.save] Could not determine model name for cost calculation (span: {span.span_id}). Inputs: {span_inputs}")
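The prompt/completion cost fields written into output["usage"] above come from litellm (imported at the top of this module); the actual call site falls between the hunks shown here, but the pattern is approximately the following sketch, assuming cost_per_token returns a (prompt_cost_usd, completion_cost_usd) pair:

    prompt_cost, completion_cost = cost_per_token(
        model=model_name,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    # these two values feed the *_cost_usd fields written into output["usage"] above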
1000
649
 
1001
650
 
1002
651
  # Create trace document - Always use standard keys for top-level counts
@@ -1006,8 +655,8 @@ class TraceClient:
1006
655
  "project_name": self.project_name,
1007
656
  "created_at": datetime.utcfromtimestamp(self.start_time).isoformat(),
1008
657
  "duration": total_duration,
1009
- "entries": condensed_entries,
1010
- "evaluation_runs": evaluation_runs,
658
+ "entries": [span.model_dump() for span in self.trace_spans],
659
+ "evaluation_runs": [run.model_dump() for run in self.evaluation_runs],
1011
660
  "overwrite": overwrite,
1012
661
  "parent_trace_id": self.parent_trace_id,
1013
662
  "parent_name": self.parent_name
@@ -1015,11 +664,248 @@ class TraceClient:
1015
664
  # --- Log trace data before saving ---
1016
665
  self.trace_manager_client.save_trace(trace_data)
1017
666
 
667
+ # upload annotations
668
+ # TODO: batch to the log endpoint
669
+ for annotation in self.annotations:
670
+ self.trace_manager_client.save_annotation(annotation)
671
+
1018
672
  return self.trace_id, trace_data
1019
673
 
1020
674
  def delete(self):
1021
675
  return self.trace_manager_client.delete_trace(self.trace_id)
1022
676
 
677
+
678
+ class _DeepTracer:
679
+ _instance: Optional["_DeepTracer"] = None
680
+ _lock: threading.Lock = threading.Lock()
681
+ _refcount: int = 0
682
+ _span_stack: contextvars.ContextVar[List[Dict[str, Any]]] = contextvars.ContextVar("_deep_profiler_span_stack", default=[])
683
+ _skip_stack: contextvars.ContextVar[List[str]] = contextvars.ContextVar("_deep_profiler_skip_stack", default=[])
684
+
685
+ def _get_qual_name(self, frame) -> str:
686
+ func_name = frame.f_code.co_name
687
+ module_name = frame.f_globals.get("__name__", "unknown_module")
688
+
689
+ try:
690
+ func = frame.f_globals.get(func_name)
691
+ if func is None:
692
+ return f"{module_name}.{func_name}"
693
+ if hasattr(func, "__qualname__"):
694
+ return f"{module_name}.{func.__qualname__}"
695
+ except Exception:
696
+ return f"{module_name}.{func_name}"
697
+
698
+ def __new__(cls):
699
+ with cls._lock:
700
+ if cls._instance is None:
701
+ cls._instance = super().__new__(cls)
702
+ return cls._instance
703
+
704
+ def _should_trace(self, frame):
705
+ # Skip stack is maintained by the tracer as an optimization to skip earlier
706
+ # frames in the call stack that we've already determined should be skipped
707
+ skip_stack = self._skip_stack.get()
708
+ if len(skip_stack) > 0:
709
+ return False
710
+
711
+ func_name = frame.f_code.co_name
712
+ module_name = frame.f_globals.get("__name__", None)
713
+
714
+ func = frame.f_globals.get(func_name)
715
+ if func and (hasattr(func, '_judgment_span_name') or hasattr(func, '_judgment_span_type')):
716
+ return False
717
+
718
+ if (
719
+ not module_name
720
+ or func_name.startswith("<") # ex: <listcomp>
721
+ or func_name.startswith("__") and func_name != "__call__" # dunders
722
+ or not self._is_user_code(frame.f_code.co_filename)
723
+ ):
724
+ return False
725
+
726
+ return True
727
+
728
+ @functools.cache
729
+ def _is_user_code(self, filename: str):
730
+ return bool(filename) and not filename.startswith("<") and not os.path.realpath(filename).startswith(_TRACE_FILEPATH_BLOCKLIST)
731
+
732
+ def _trace(self, frame: types.FrameType, event: str, arg: Any):
733
+ frame.f_trace_lines = False
734
+ frame.f_trace_opcodes = False
735
+
736
+
737
+ if not self._should_trace(frame):
738
+ return
739
+
740
+ if event not in ("call", "return", "exception"):
741
+ return
742
+
743
+ current_trace = current_trace_var.get()
744
+ if not current_trace:
745
+ return
746
+
747
+ parent_span_id = current_span_var.get()
748
+ if not parent_span_id:
749
+ return
750
+
751
+ qual_name = self._get_qual_name(frame)
752
+ skip_stack = self._skip_stack.get()
753
+
754
+ if event == "call":
755
+ # If we have entries in the skip stack and the current qual_name matches the top entry,
756
+ # push it again to track nesting depth and skip
757
+ # As an optimization, we only care about duplicate qual_names.
758
+ if skip_stack:
759
+ if qual_name == skip_stack[-1]:
760
+ skip_stack.append(qual_name)
761
+ self._skip_stack.set(skip_stack)
762
+ return
763
+
764
+ should_trace = self._should_trace(frame)
765
+
766
+ if not should_trace:
767
+ if not skip_stack:
768
+ self._skip_stack.set([qual_name])
769
+ return
770
+ elif event == "return":
771
+ # If we have entries in skip stack and current qual_name matches the top entry,
772
+ # pop it to track exiting from the skipped section
773
+ if skip_stack and qual_name == skip_stack[-1]:
774
+ skip_stack.pop()
775
+ self._skip_stack.set(skip_stack)
776
+ return
777
+
778
+ if skip_stack:
779
+ return
780
+
781
+ span_stack = self._span_stack.get()
782
+ if event == "call":
783
+ if not self._should_trace(frame):
784
+ return
785
+
786
+ span_id = str(uuid.uuid4())
787
+
788
+ parent_depth = current_trace._span_depths.get(parent_span_id, 0)
789
+ depth = parent_depth + 1
790
+
791
+ current_trace._span_depths[span_id] = depth
792
+
793
+ start_time = time.time()
794
+
795
+ span_stack.append({
796
+ "span_id": span_id,
797
+ "parent_span_id": parent_span_id,
798
+ "function": qual_name,
799
+ "start_time": start_time
800
+ })
801
+ self._span_stack.set(span_stack)
802
+
803
+ token = current_span_var.set(span_id)
804
+ frame.f_locals["_judgment_span_token"] = token
805
+
806
+ span = TraceSpan(
807
+ span_id=span_id,
808
+ trace_id=current_trace.trace_id,
809
+ depth=depth,
810
+ message=qual_name,
811
+ created_at=start_time,
812
+ span_type="span",
813
+ parent_span_id=parent_span_id,
814
+ function=qual_name
815
+ )
816
+ current_trace.add_span(span)
817
+
818
+ inputs = {}
819
+ try:
820
+ args_info = inspect.getargvalues(frame)
821
+ for arg in args_info.args:
822
+ try:
823
+ inputs[arg] = args_info.locals.get(arg)
824
+ except:
825
+ inputs[arg] = "<<Unserializable>>"
826
+ current_trace.record_input(inputs)
827
+ except Exception as e:
828
+ current_trace.record_input({
829
+ "error": str(e)
830
+ })
831
+
832
+ elif event == "return":
833
+ if not span_stack:
834
+ return
835
+
836
+ current_id = current_span_var.get()
837
+
838
+ span_data = None
839
+ for i, entry in enumerate(reversed(span_stack)):
840
+ if entry["span_id"] == current_id:
841
+ span_data = span_stack.pop(-(i+1))
842
+ self._span_stack.set(span_stack)
843
+ break
844
+
845
+ if not span_data:
846
+ return
847
+
848
+ start_time = span_data["start_time"]
849
+ duration = time.time() - start_time
850
+
851
+ current_trace.span_id_to_span[span_data["span_id"]].duration = duration
852
+
853
+ if arg is not None:
854
+ # exception handling will take priority.
855
+ current_trace.record_output(arg)
856
+
857
+ if span_data["span_id"] in current_trace._span_depths:
858
+ del current_trace._span_depths[span_data["span_id"]]
859
+
860
+ if span_stack:
861
+ current_span_var.set(span_stack[-1]["span_id"])
862
+ else:
863
+ current_span_var.set(span_data["parent_span_id"])
864
+
865
+ if "_judgment_span_token" in frame.f_locals:
866
+ current_span_var.reset(frame.f_locals["_judgment_span_token"])
867
+
868
+ elif event == "exception":
869
+ exc_type, exc_value, exc_traceback = arg
870
+ formatted_exception = {
871
+ "type": exc_type.__name__,
872
+ "message": str(exc_value),
873
+ "traceback": traceback.format_tb(exc_traceback)
874
+ }
875
+ current_trace = current_trace_var.get()
876
+ current_trace.record_output({
877
+ "error": formatted_exception
878
+ })
879
+
880
+ return self._trace
881
+
882
+ def __enter__(self):
883
+ with self._lock:
884
+ self._refcount += 1
885
+ if self._refcount == 1:
886
+ self._skip_stack.set([])
887
+ self._span_stack.set([])
888
+ sys.settrace(self._trace)
889
+ threading.settrace(self._trace)
890
+ return self
891
+
892
+ def __exit__(self, exc_type, exc_val, exc_tb):
893
+ with self._lock:
894
+ self._refcount -= 1
895
+ if self._refcount == 0:
896
+ sys.settrace(None)
897
+ threading.settrace(None)
898
+
899
+
900
+ def log(self, message: str, level: str = "info"):
901
+ """ Log a message with the span context """
902
+ current_trace = current_trace_var.get()
903
+ if current_trace:
904
+ current_trace.log(message, level)
905
+ else:
906
+ print(f"[{level}] {message}")
907
+ current_trace.record_output({"log": message})
908
+
1023
909
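_DeepTracer replaces the previous module-level monkey patching (_apply_deep_tracing, removed later in this diff): while active it installs a hook via sys.settrace/threading.settrace and opens a child TraceSpan for every user-code call made under the decorated function, with the refcount letting nested uses share a single hook. The @observe wrappers engage it roughly like this (sketch; my_agent_step and query are illustrative, and an active trace is assumed):

    with current_trace.span("my_agent_step") as span:   # binds current_span_var
        span.record_input({"query": query})
        with _DeepTracer():                              # hook installed on first enter, removed on last exit
            result = my_agent_step(query)                # nested user-defined calls each get their own span
        span.record_output(result)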
  class Tracer:
1024
910
  _instance = None
1025
911
 
@@ -1042,12 +928,16 @@ class Tracer:
1042
928
  s3_aws_access_key_id: Optional[str] = None,
1043
929
  s3_aws_secret_access_key: Optional[str] = None,
1044
930
  s3_region_name: Optional[str] = None,
1045
- deep_tracing: bool = True # NEW: Enable deep tracing by default
931
+ deep_tracing: bool = True # Deep tracing is enabled by default
1046
932
  ):
1047
933
  if not hasattr(self, 'initialized'):
1048
934
  if not api_key:
1049
935
  raise ValueError("Tracer must be configured with a Judgment API key")
1050
936
 
937
+ result, response = validate_api_key(api_key)
938
+ if not result:
939
+ raise JudgmentAPIError(f"Issue with passed in Judgment API key: {response}")
940
+
1051
941
  if not organization_id:
1052
942
  raise ValueError("Tracer must be configured with an Organization ID")
1053
943
  if use_s3 and not s3_bucket_name:
@@ -1059,11 +949,11 @@ class Tracer:
1059
949
 
1060
950
  self.api_key: str = api_key
1061
951
  self.project_name: str = project_name
1062
- self.client: JudgmentClient = JudgmentClient(judgment_api_key=api_key)
1063
952
  self.organization_id: str = organization_id
1064
953
  self._current_trace: Optional[str] = None
1065
954
  self._active_trace_client: Optional[TraceClient] = None # Add active trace client attribute
1066
955
  self.rules: List[Rule] = rules or [] # Store rules at tracer level
956
+ self.traces: List[Trace] = []
1067
957
  self.initialized: bool = True
1068
958
  self.enable_monitoring: bool = enable_monitoring
1069
959
  self.enable_evaluations: bool = enable_evaluations
@@ -1119,45 +1009,6 @@ class Tracer:
1119
1009
  """Returns the TraceClient instance currently marked as active by the handler."""
1120
1010
  return self._active_trace_client
1121
1011
 
1122
- def _apply_deep_tracing(self, func, span_type="span"):
1123
- """
1124
- Apply deep tracing to all functions in the same module as the given function.
1125
-
1126
- Args:
1127
- func: The function being traced
1128
- span_type: Type of span to use for traced functions
1129
-
1130
- Returns:
1131
- A tuple of (module, original_functions_dict) where original_functions_dict
1132
- contains the original functions that were replaced with traced versions.
1133
- """
1134
- module = inspect.getmodule(func)
1135
- if not module:
1136
- return None, {}
1137
-
1138
- # Save original functions
1139
- original_functions = {}
1140
-
1141
- # Find all functions in the module
1142
- for name, obj in inspect.getmembers(module, inspect.isfunction):
1143
- # Skip already wrapped functions
1144
- if hasattr(obj, '_judgment_traced'):
1145
- continue
1146
-
1147
- # Create a traced version of the function
1148
- # Always use default span type "span" for child functions
1149
- traced_func = _create_deep_tracing_wrapper(obj, self, "span")
1150
-
1151
- # Mark the function as traced to avoid double wrapping
1152
- traced_func._judgment_traced = True
1153
-
1154
- # Save the original function
1155
- original_functions[name] = obj
1156
-
1157
- # Replace with traced version
1158
- setattr(module, name, traced_func)
1159
-
1160
- return module, original_functions
1161
1012
 
1162
1013
  @contextmanager
1163
1014
  def trace(
@@ -1204,6 +1055,23 @@ class Tracer:
1204
1055
  finally:
1205
1056
  # Reset the context variable
1206
1057
  current_trace_var.reset(token)
1058
+
1059
+
1060
+ def log(self, msg: str, label: str = "log", score: int = 1):
1061
+ """Log a message with the current span context"""
1062
+ current_span_id = current_span_var.get()
1063
+ current_trace = current_trace_var.get()
1064
+ if current_span_id:
1065
+ annotation = TraceAnnotation(
1066
+ span_id=current_span_id,
1067
+ text=msg,
1068
+ label=label,
1069
+ score=score
1070
+ )
1071
+
1072
+ current_trace.add_annotation(annotation)
1073
+
1074
+ rprint(f"[bold]{label}:[/bold] {msg}")
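Usage sketch: called inside an @observe-decorated function, Tracer.log attaches a TraceAnnotation to the active span, and the annotation is uploaded when the trace is saved (credentials and names below are illustrative):

    tracer = Tracer(api_key="...", organization_id="...", project_name="demo_project")

    @tracer.observe(span_type="tool")
    def lookup(query: str):
        tracer.log("cache miss, querying vector store", label="debug", score=1)
        return search(query)                             # hypothetical helper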
1207
1075
 
1208
1076
  def observe(self, func=None, *, name=None, span_type: SpanType = "span", project_name: str = None, overwrite: bool = False, deep_tracing: bool = None):
1209
1077
  """
@@ -1239,13 +1107,6 @@ class Tracer:
1239
1107
  if asyncio.iscoroutinefunction(func):
1240
1108
  @functools.wraps(func)
1241
1109
  async def async_wrapper(*args, **kwargs):
1242
- # Check if we're already in a traced function
1243
- if in_traced_function_var.get():
1244
- return await func(*args, **kwargs)
1245
-
1246
- # Set in_traced_function_var to True
1247
- token = in_traced_function_var.set(True)
1248
-
1249
1110
  # Get current trace from context
1250
1111
  current_trace = current_trace_var.get()
1251
1112
 
@@ -1275,81 +1136,47 @@ class Tracer:
1275
1136
  # This sets the current_span_var
1276
1137
  with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
1277
1138
  # Record inputs
1278
- span.record_input({
1279
- 'args': str(args),
1280
- 'kwargs': kwargs
1281
- })
1139
+ inputs = combine_args_kwargs(func, args, kwargs)
1140
+ span.record_input(inputs)
1282
1141
 
1283
- # If deep tracing is enabled, apply monkey patching
1284
1142
  if use_deep_tracing:
1285
- module, original_functions = self._apply_deep_tracing(func, span_type)
1286
-
1287
- # Execute function
1288
- result = await func(*args, **kwargs)
1289
-
1290
- # Restore original functions if deep tracing was enabled
1291
- if use_deep_tracing and module and 'original_functions' in locals():
1292
- for name, obj in original_functions.items():
1293
- setattr(module, name, obj)
1294
-
1143
+ with _DeepTracer():
1144
+ result = await func(*args, **kwargs)
1145
+ else:
1146
+ result = await func(*args, **kwargs)
1147
+
1295
1148
  # Record output
1296
1149
  span.record_output(result)
1297
-
1298
- # Save the completed trace
1299
- current_trace.save(overwrite=overwrite)
1300
1150
  return result
1301
1151
  finally:
1152
+ # Save the completed trace
1153
+ trace_id, trace = current_trace.save(overwrite=overwrite)
1154
+ self.traces.append(trace)
1155
+
1302
1156
  # Reset trace context (span context resets automatically)
1303
1157
  current_trace_var.reset(trace_token)
1304
- # Reset in_traced_function_var
1305
- in_traced_function_var.reset(token)
1306
1158
  else:
1307
- # Already have a trace context, just create a span in it
1308
- # The span method handles current_span_var
1309
-
1310
- try:
1311
- with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
1312
- # Record inputs
1313
- span.record_input({
1314
- 'args': str(args),
1315
- 'kwargs': kwargs
1316
- })
1317
-
1318
- # If deep tracing is enabled, apply monkey patching
1319
- if use_deep_tracing:
1320
- module, original_functions = self._apply_deep_tracing(func, span_type)
1321
-
1322
- # Execute function
1159
+ with current_trace.span(span_name, span_type=span_type) as span:
1160
+ inputs = combine_args_kwargs(func, args, kwargs)
1161
+ span.record_input(inputs)
1162
+
1163
+ if use_deep_tracing:
1164
+ with _DeepTracer():
1165
+ result = await func(*args, **kwargs)
1166
+ else:
1323
1167
  result = await func(*args, **kwargs)
1324
1168
 
1325
- # Restore original functions if deep tracing was enabled
1326
- if use_deep_tracing and module and 'original_functions' in locals():
1327
- for name, obj in original_functions.items():
1328
- setattr(module, name, obj)
1329
-
1330
- # Record output
1331
- span.record_output(result)
1332
-
1333
- return result
1334
- finally:
1335
- # Reset in_traced_function_var
1336
- in_traced_function_var.reset(token)
1337
-
1169
+ span.record_output(result)
1170
+ return result
1171
+
1338
1172
  return async_wrapper
1339
1173
  else:
1340
1174
  # Non-async function implementation with deep tracing
1341
1175
  @functools.wraps(func)
1342
- def wrapper(*args, **kwargs):
1343
- # Check if we're already in a traced function
1344
- if in_traced_function_var.get():
1345
- return func(*args, **kwargs)
1346
-
1347
- # Set in_traced_function_var to True
1348
- token = in_traced_function_var.set(True)
1349
-
1176
+ def wrapper(*args, **kwargs):
1350
1177
  # Get current trace from context
1351
1178
  current_trace = current_trace_var.get()
1352
-
1179
+
1353
1180
  # If there's no current trace, create a root trace
1354
1181
  if not current_trace:
1355
1182
  trace_id = str(uuid.uuid4())
@@ -1376,66 +1203,40 @@ class Tracer:
1376
1203
  # This sets the current_span_var
1377
1204
  with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
1378
1205
  # Record inputs
1379
- span.record_input({
1380
- 'args': str(args),
1381
- 'kwargs': kwargs
1382
- })
1206
+ inputs = combine_args_kwargs(func, args, kwargs)
1207
+ span.record_input(inputs)
1383
1208
 
1384
- # If deep tracing is enabled, apply monkey patching
1385
1209
  if use_deep_tracing:
1386
- module, original_functions = self._apply_deep_tracing(func, span_type)
1387
-
1388
- # Execute function
1389
- result = func(*args, **kwargs)
1390
-
1391
- # Restore original functions if deep tracing was enabled
1392
- if use_deep_tracing and module and 'original_functions' in locals():
1393
- for name, obj in original_functions.items():
1394
- setattr(module, name, obj)
1210
+ with _DeepTracer():
1211
+ result = func(*args, **kwargs)
1212
+ else:
1213
+ result = func(*args, **kwargs)
1395
1214
 
1396
1215
  # Record output
1397
1216
  span.record_output(result)
1398
-
1399
- # Save the completed trace
1400
- current_trace.save(overwrite=overwrite)
1401
1217
  return result
1402
1218
  finally:
1219
+ # Save the completed trace
1220
+ trace_id, trace = current_trace.save(overwrite=overwrite)
1221
+ self.traces.append(trace)
1222
+
1403
1223
  # Reset trace context (span context resets automatically)
1404
1224
  current_trace_var.reset(trace_token)
1405
- # Reset in_traced_function_var
1406
- in_traced_function_var.reset(token)
1407
1225
  else:
1408
- # Already have a trace context, just create a span in it
1409
- # The span method handles current_span_var
1410
-
1411
- try:
1412
- with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
1413
- # Record inputs
1414
- span.record_input({
1415
- 'args': str(args),
1416
- 'kwargs': kwargs
1417
- })
1418
-
1419
- # If deep tracing is enabled, apply monkey patching
1420
- if use_deep_tracing:
1421
- module, original_functions = self._apply_deep_tracing(func, span_type)
1422
-
1423
- # Execute function
1226
+ with current_trace.span(span_name, span_type=span_type) as span:
1227
+
1228
+ inputs = combine_args_kwargs(func, args, kwargs)
1229
+ span.record_input(inputs)
1230
+
1231
+ if use_deep_tracing:
1232
+ with _DeepTracer():
1233
+ result = func(*args, **kwargs)
1234
+ else:
1424
1235
  result = func(*args, **kwargs)
1425
1236
 
1426
- # Restore original functions if deep tracing was enabled
1427
- if use_deep_tracing and module and 'original_functions' in locals():
1428
- for name, obj in original_functions.items():
1429
- setattr(module, name, obj)
1430
-
1431
- # Record output
1432
- span.record_output(result)
1433
-
1434
- return result
1435
- finally:
1436
- # Reset in_traced_function_var
1437
- in_traced_function_var.reset(token)
1438
-
1237
+ span.record_output(result)
1238
+ return result
1239
+
1439
1240
  return wrapper
1440
1241
 
1441
1242
  def async_evaluate(self, *args, **kwargs):
@@ -1469,7 +1270,7 @@ def wrap(client: Any) -> Any:
1469
1270
  Supports OpenAI, Together, Anthropic, and Google GenAI clients.
1470
1271
  Patches both '.create' and Anthropic's '.stream' methods using a wrapper class.
1471
1272
  """
1472
- span_name, original_create, original_stream = _get_client_config(client)
1273
+ span_name, original_create, responses_create, original_stream = _get_client_config(client)
1473
1274
 
1474
1275
  # --- Define Traced Async Functions ---
1475
1276
  async def traced_create_async(*args, **kwargs):
@@ -1567,7 +1368,41 @@ def wrap(client: Any) -> Any:
1567
1368
  span.record_output(output_data)
1568
1369
  return response_or_iterator
1569
1370
 
1371
+ # --- Define Traced Sync Functions ---
1372
+ def traced_response_create_sync(*args, **kwargs):
1373
+ # [Existing logic - unchanged]
1374
+ current_trace = current_trace_var.get()
1375
+ if not current_trace:
1376
+ return responses_create(*args, **kwargs)
1570
1377
 
1378
+ is_streaming = kwargs.get("stream", False)
1379
+ with current_trace.span(span_name, span_type="llm") as span:
1380
+ span.record_input(kwargs)
1381
+
1382
+ # Warn about token counting limitations with streaming
1383
+ if isinstance(client, (AsyncOpenAI, OpenAI)) and is_streaming:
1384
+ if not kwargs.get("stream_options", {}).get("include_usage"):
1385
+ warnings.warn(
1386
+ "OpenAI streaming calls don't include token counts by default. "
1387
+ "To enable token counting with streams, set stream_options={'include_usage': True} "
1388
+ "in your API call arguments.",
1389
+ UserWarning
1390
+ )
1391
+
1392
+ try:
1393
+ response_or_iterator = responses_create(*args, **kwargs)
1394
+ except Exception as e:
1395
+ print(f"Error during wrapped sync API call ({span_name}): {e}")
1396
+ span.record_output({"error": str(e)})
1397
+ raise
1398
+ if is_streaming:
1399
+ output_entry = span.record_output("<pending stream>")
1400
+ return _sync_stream_wrapper(response_or_iterator, client, output_entry)
1401
+ else:
1402
+ output_data = _format_response_output_data(client, response_or_iterator)
1403
+ span.record_output(output_data)
1404
+ return response_or_iterator
1405
+
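With traced_response_create_sync in place, wrapping an OpenAI client also traces the Responses API; a sketch (model name and prompt are illustrative):

    client = wrap(OpenAI())                              # patches chat.completions.create and responses.create
    resp = client.responses.create(
        model="gpt-4.1-mini",                            # illustrative model name
        input="Summarize this trace format.",
    )
    # under an active trace, the call is recorded as an "OPENAI_API_CALL" llm span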
1571
1406
  # Function replacing sync .stream()
1572
1407
  def traced_stream_sync(*args, **kwargs):
1573
1408
  current_trace = current_trace_var.get()
@@ -1615,15 +1450,16 @@ def wrap(client: Any) -> Any:
1615
1450
  if original_stream:
1616
1451
  client.messages.stream = traced_stream_async
1617
1452
  elif isinstance(client, genai.client.AsyncClient):
1618
- client.generate_content = traced_create_async
1453
+ client.models.generate_content = traced_create_async
1619
1454
  elif isinstance(client, (OpenAI, Together)):
1620
1455
  client.chat.completions.create = traced_create_sync
1456
+ client.responses.create = traced_response_create_sync
1621
1457
  elif isinstance(client, Anthropic):
1622
1458
  client.messages.create = traced_create_sync
1623
1459
  if original_stream:
1624
1460
  client.messages.stream = traced_stream_sync
1625
1461
  elif isinstance(client, genai.Client):
1626
- client.generate_content = traced_create_sync
1462
+ client.models.generate_content = traced_create_sync
1627
1463
 
1628
1464
  return client
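Note the Google GenAI change above: the traced function is now assigned to client.models.generate_content, the same attribute that _get_client_config reads the original method from. Sketch (API key and model name are illustrative):

    from google import genai

    gclient = wrap(genai.Client(api_key="..."))
    gclient.models.generate_content(model="gemini-2.0-flash", contents="hello")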
1629
1465
 
@@ -1639,19 +1475,20 @@ def _get_client_config(client: ApiClient) -> tuple[str, callable, Optional[calla
1639
1475
  tuple: (span_name, create_method, stream_method)
1640
1476
  - span_name: String identifier for tracing
1641
1477
  - create_method: Reference to the client's creation method
1478
+ - responses_method: Reference to the client's responses method (if applicable)
1642
1479
  - stream_method: Reference to the client's stream method (if applicable)
1643
1480
 
1644
1481
  Raises:
1645
1482
  ValueError: If client type is not supported
1646
1483
  """
1647
1484
  if isinstance(client, (OpenAI, AsyncOpenAI)):
1648
- return "OPENAI_API_CALL", client.chat.completions.create, None
1485
+ return "OPENAI_API_CALL", client.chat.completions.create, client.responses.create, None
1649
1486
  elif isinstance(client, (Together, AsyncTogether)):
1650
- return "TOGETHER_API_CALL", client.chat.completions.create, None
1487
+ return "TOGETHER_API_CALL", client.chat.completions.create, None, None
1651
1488
  elif isinstance(client, (Anthropic, AsyncAnthropic)):
1652
- return "ANTHROPIC_API_CALL", client.messages.create, client.messages.stream
1489
+ return "ANTHROPIC_API_CALL", client.messages.create, None, client.messages.stream
1653
1490
  elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
1654
- return "GOOGLE_API_CALL", client.models.generate_content, None
1491
+ return "GOOGLE_API_CALL", client.models.generate_content, None, None
1655
1492
  raise ValueError(f"Unsupported client type: {type(client)}")
1656
1493
 
1657
1494
  def _format_input_data(client: ApiClient, **kwargs) -> dict:
@@ -1677,6 +1514,26 @@ def _format_input_data(client: ApiClient, **kwargs) -> dict:
1677
1514
  "max_tokens": kwargs.get("max_tokens")
1678
1515
  }
1679
1516
 
1517
+ def _format_response_output_data(client: ApiClient, response: Any) -> dict:
1518
+ """Format API response data based on client type.
1519
+
1520
+ Normalizes different response formats into a consistent structure
1521
+ for tracing purposes.
1522
+ """
1523
+ if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
1524
+ return {
1525
+ "content": response.output,
1526
+ "usage": {
1527
+ "prompt_tokens": response.usage.input_tokens,
1528
+ "completion_tokens": response.usage.output_tokens,
1529
+ "total_tokens": response.usage.total_tokens
1530
+ }
1531
+ }
1532
+ else:
1533
+ warnings.warn(f"Unsupported client type: {type(client)}")
1534
+ return {}
1535
+
1536
+
1680
1537
  def _format_output_data(client: ApiClient, response: Any) -> dict:
1681
1538
  """Format API response data based on client type.
1682
1539
 
@@ -1716,117 +1573,51 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
  }
  }
 
- # Define a blocklist of functions that should not be traced
- # These are typically utility functions, print statements, logging, etc.
- _TRACE_BLOCKLIST = {
- # Built-in functions
- 'print', 'str', 'int', 'float', 'bool', 'list', 'dict', 'set', 'tuple',
- 'len', 'range', 'enumerate', 'zip', 'map', 'filter', 'sorted', 'reversed',
- 'min', 'max', 'sum', 'any', 'all', 'abs', 'round', 'format',
- # Logging functions
- 'debug', 'info', 'warning', 'error', 'critical', 'exception', 'log',
- # Common utility functions
- 'sleep', 'time', 'datetime', 'json', 'dumps', 'loads',
- # String operations
- 'join', 'split', 'strip', 'lstrip', 'rstrip', 'replace', 'lower', 'upper',
- # Dict operations
- 'get', 'items', 'keys', 'values', 'update',
- # List operations
- 'append', 'extend', 'insert', 'remove', 'pop', 'clear', 'index', 'count', 'sort',
- }
-
-
- # Add a new function for deep tracing at the module level
- def _create_deep_tracing_wrapper(func, tracer, span_type="span"):
+ def combine_args_kwargs(func, args, kwargs):
  """
- Creates a wrapper for a function that automatically traces it when called within a traced function.
- This enables deep tracing without requiring explicit @observe decorators on every function.
+ Combine positional arguments and keyword arguments into a single dictionary.
 
  Args:
- func: The function to wrap
- tracer: The Tracer instance
- span_type: Type of span (default "span")
+ func: The function being called
+ args: Tuple of positional arguments
+ kwargs: Dictionary of keyword arguments
 
  Returns:
- A wrapped function that will be traced when called
+ A dictionary combining both args and kwargs
  """
- # Skip wrapping if the function is not callable or is a built-in
- if not callable(func) or isinstance(func, type) or func.__module__ == 'builtins':
- return func
-
- # Skip functions in the blocklist
- if func.__name__ in _TRACE_BLOCKLIST:
- return func
-
- # Skip functions from certain modules (logging, sys, etc.)
- if func.__module__ and any(func.__module__.startswith(m) for m in ['logging', 'sys', 'os', 'json', 'time', 'datetime']):
- return func
-
-
- # Get function name for the span - check for custom name set by @observe
- func_name = getattr(func, '_judgment_span_name', func.__name__)
-
- # Check for custom span_type set by @observe
- func_span_type = getattr(func, '_judgment_span_type', "span")
-
- # Store original function to prevent losing reference
- original_func = func
-
- # Create appropriate wrapper based on whether the function is async or not
- if asyncio.iscoroutinefunction(func):
- @functools.wraps(func)
- async def async_deep_wrapper(*args, **kwargs):
- # Get current trace from context
- current_trace = current_trace_var.get()
-
- # If no trace context, just call the function
- if not current_trace:
- return await original_func(*args, **kwargs)
-
- # Create a span for this function call - use custom span_type if available
- with current_trace.span(func_name, span_type=func_span_type) as span:
- # Record inputs
- span.record_input({
- 'args': str(args),
- 'kwargs': kwargs
- })
-
- # Execute function
- result = await original_func(*args, **kwargs)
-
- # Record output
- span.record_output(result)
-
- return result
-
- return async_deep_wrapper
- else:
- @functools.wraps(func)
- def deep_wrapper(*args, **kwargs):
- # Get current trace from context
- current_trace = current_trace_var.get()
-
- # If no trace context, just call the function
- if not current_trace:
- return original_func(*args, **kwargs)
-
- # Create a span for this function call - use custom span_type if available
- with current_trace.span(func_name, span_type=func_span_type) as span:
- # Record inputs
- span.record_input({
- 'args': str(args),
- 'kwargs': kwargs
- })
-
- # Execute function
- result = original_func(*args, **kwargs)
-
- # Record output
- span.record_output(result)
-
- return result
-
- return deep_wrapper
+ try:
+ import inspect
+ sig = inspect.signature(func)
+ param_names = list(sig.parameters.keys())
+
+ args_dict = {}
+ for i, arg in enumerate(args):
+ if i < len(param_names):
+ args_dict[param_names[i]] = arg
+ else:
+ args_dict[f"arg{i}"] = arg
+
+ return {**args_dict, **kwargs}
+ except Exception as e:
+ # Fallback if signature inspection fails
+ return {**{f"arg{i}": arg for i, arg in enumerate(args)}, **kwargs}
+
+ # NOTE: This builds once, can be tweaked if we are missing / capturing other unnecessary modules
+ # @link https://docs.python.org/3.13/library/sysconfig.html
+ _TRACE_FILEPATH_BLOCKLIST = tuple(
+ os.path.realpath(p) + os.sep
+ for p in {
+ sysconfig.get_paths()['stdlib'],
+ sysconfig.get_paths().get('platstdlib', ''),
+ *site.getsitepackages(),
+ site.getusersitepackages(),
+ *(
+ [os.path.join(os.path.dirname(__file__), '../../judgeval/')]
+ if os.environ.get('JUDGMENT_DEV')
+ else []
+ ),
+ } if p
+ )
 
  # Add the new TraceThreadPoolExecutor class
  class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
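Two small sketches of the additions above. `combine_args_kwargs` maps positional arguments onto parameter names via `inspect.signature`, and `_TRACE_FILEPATH_BLOCKLIST` is a tuple of path prefixes suited to a `str.startswith` check; the `_should_trace_file` helper below is a hypothetical illustration, not part of the diff.

    import os

    def greet(name, punctuation="!"):
        return f"Hello, {name}{punctuation}"

    combine_args_kwargs(greet, ("Ada",), {"punctuation": "?"})
    # -> {"name": "Ada", "punctuation": "?"}

    def _should_trace_file(filename: str) -> bool:
        # Hypothetical: skip frames whose source file lives under the stdlib or
        # site-packages prefixes collected in _TRACE_FILEPATH_BLOCKLIST.
        return not os.path.realpath(filename).startswith(_TRACE_FILEPATH_BLOCKLIST)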
@@ -1929,7 +1720,7 @@ def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[D
  def _sync_stream_wrapper(
  original_stream: Iterator,
  client: ApiClient,
- output_entry: TraceEntry
+ span: TraceSpan
  ) -> Generator[Any, None, None]:
  """Wraps a synchronous stream iterator to capture content and update the trace."""
  content_parts = [] # Use a list instead of string concatenation
@@ -1948,7 +1739,7 @@ def _sync_stream_wrapper(
  final_usage = _extract_usage_from_final_chunk(client, last_chunk)
 
  # Update the trace entry with the accumulated content and usage
- output_entry.output = {
+ span.output = {
  "content": "".join(content_parts), # Join list at the end
  "usage": final_usage if final_usage else {"info": "Usage data not available in stream."}, # Provide placeholder if None
  "streamed": True
@@ -1960,7 +1751,7 @@ def _sync_stream_wrapper(
  async def _async_stream_wrapper(
  original_stream: AsyncIterator,
  client: ApiClient,
- output_entry: TraceEntry
+ span: TraceSpan
  ) -> AsyncGenerator[Any, None]:
  # [Existing logic - unchanged]
  content_parts = [] # Use a list instead of string concatenation
@@ -1969,7 +1760,7 @@ async def _async_stream_wrapper(
  anthropic_input_tokens = 0
  anthropic_output_tokens = 0
 
- target_span_id = getattr(output_entry, 'span_id', 'UNKNOWN')
+ target_span_id = span.span_id
 
  try:
  async for chunk in original_stream:
@@ -2014,19 +1805,17 @@ async def _async_stream_wrapper(
  elif last_content_chunk:
  usage_info = _extract_usage_from_final_chunk(client, last_content_chunk)
 
- if output_entry and hasattr(output_entry, 'output'):
- output_entry.output = {
+ if span and hasattr(span, 'output'):
+ span.output = {
  "content": "".join(content_parts), # Join list at the end
  "usage": usage_info if usage_info else {"info": "Usage data not available in stream."},
  "streamed": True
  }
- start_ts = getattr(output_entry, 'created_at', time.time())
- output_entry.duration = time.time() - start_ts
+ start_ts = getattr(span, 'created_at', time.time())
+ span.duration = time.time() - start_ts
  # else: # Handle error case if necessary, but remove debug print
 
- # --- Define Context Manager Wrapper Classes ---
- class _TracedAsyncStreamManagerWrapper(AbstractAsyncContextManager):
- """Wraps an original async stream manager to add tracing."""
+ class _BaseStreamManagerWrapper:
  def __init__(self, original_manager, client, span_name, trace_client, stream_wrapper_func, input_kwargs):
  self._original_manager = original_manager
  self._client = client
@@ -2036,160 +1825,77 @@ class _TracedAsyncStreamManagerWrapper(AbstractAsyncContextManager):
  self._input_kwargs = input_kwargs
  self._parent_span_id_at_entry = None
 
- async def __aenter__(self):
- self._parent_span_id_at_entry = current_span_var.get()
- if not self._trace_client:
- # If no trace, just delegate to the original manager
- return await self._original_manager.__aenter__()
-
- # --- Manually create the 'enter' entry ---
+ def _create_span(self):
  start_time = time.time()
  span_id = str(uuid.uuid4())
  current_depth = 0
  if self._parent_span_id_at_entry and self._parent_span_id_at_entry in self._trace_client._span_depths:
  current_depth = self._trace_client._span_depths[self._parent_span_id_at_entry] + 1
  self._trace_client._span_depths[span_id] = current_depth
- enter_entry = TraceEntry(
- type="enter", function=self._span_name, span_id=span_id,
- trace_id=self._trace_client.trace_id, depth=current_depth, message=self._span_name,
- created_at=start_time, span_type="llm", parent_span_id=self._parent_span_id_at_entry
+ span = TraceSpan(
+ function=self._span_name,
+ span_id=span_id,
+ trace_id=self._trace_client.trace_id,
+ depth=current_depth,
+ message=self._span_name,
+ created_at=start_time,
+ span_type="llm",
+ parent_span_id=self._parent_span_id_at_entry
  )
- self._trace_client.add_entry(enter_entry)
- # --- End manual 'enter' entry ---
-
- # Set the current span ID in contextvars
- self._span_context_token = current_span_var.set(span_id)
+ self._trace_client.add_span(span)
+ return span_id, span
 
- # Manually create 'input' entry
- input_data = _format_input_data(self._client, **self._input_kwargs)
- input_entry = TraceEntry(
- type="input", function=self._span_name, span_id=span_id,
- trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Inputs to {self._span_name}",
- created_at=time.time(), inputs=input_data, span_type="llm"
- )
- self._trace_client.add_entry(input_entry)
+ def _finalize_span(self, span_id):
+ span = self._trace_client.span_id_to_span.get(span_id)
+ if span:
+ span.duration = time.time() - span.created_at
+ if span_id in self._trace_client._span_depths:
+ del self._trace_client._span_depths[span_id]
 
- # Call the original __aenter__
- raw_iterator = await self._original_manager.__aenter__()
+ class _TracedAsyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractAsyncContextManager):
+ async def __aenter__(self):
+ self._parent_span_id_at_entry = current_span_var.get()
+ if not self._trace_client:
+ return await self._original_manager.__aenter__()
 
- # Manually create pending 'output' entry
- output_entry = TraceEntry(
- type="output", function=self._span_name, span_id=span_id,
- trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Output from {self._span_name}",
- created_at=time.time(), output="<pending stream>", span_type="llm"
- )
- self._trace_client.add_entry(output_entry)
+ span_id, span = self._create_span()
+ self._span_context_token = current_span_var.set(span_id)
+ span.inputs = _format_input_data(self._client, **self._input_kwargs)
 
- # Wrap the raw iterator
- wrapped_iterator = self._stream_wrapper_func(raw_iterator, self._client, output_entry)
- return wrapped_iterator
+ # Call the original __aenter__ and expect it to be an async generator
+ raw_iterator = await self._original_manager.__aenter__()
+ span.output = "<pending stream>"
+ return self._stream_wrapper_func(raw_iterator, self._client, span)
 
  async def __aexit__(self, exc_type, exc_val, exc_tb):
- # Manually create the 'exit' entry
  if hasattr(self, '_span_context_token'):
- span_id = current_span_var.get()
- start_time_for_duration = 0
- for entry in reversed(self._trace_client.entries):
- if entry.span_id == span_id and entry.type == 'enter':
- start_time_for_duration = entry.created_at
- break
- duration = time.time() - start_time_for_duration if start_time_for_duration else None
- exit_depth = self._trace_client._span_depths.get(span_id, 0)
- exit_entry = TraceEntry(
- type="exit", function=self._span_name, span_id=span_id,
- trace_id=self._trace_client.trace_id, depth=exit_depth, message=f"← {self._span_name}",
- created_at=time.time(), duration=duration, span_type="llm"
- )
- self._trace_client.add_entry(exit_entry)
- if span_id in self._trace_client._span_depths: del self._trace_client._span_depths[span_id]
- current_span_var.reset(self._span_context_token)
- delattr(self, '_span_context_token')
-
- # Delegate __aexit__
- if hasattr(self._original_manager, "__aexit__"):
- return await self._original_manager.__aexit__(exc_type, exc_val, exc_tb)
- return None
-
- class _TracedSyncStreamManagerWrapper(AbstractContextManager):
- """Wraps an original sync stream manager to add tracing."""
- def __init__(self, original_manager, client, span_name, trace_client, stream_wrapper_func, input_kwargs):
- self._original_manager = original_manager
- self._client = client
- self._span_name = span_name
- self._trace_client = trace_client
- self._stream_wrapper_func = stream_wrapper_func
- self._input_kwargs = input_kwargs
- self._parent_span_id_at_entry = None
+ span_id = current_span_var.get()
+ self._finalize_span(span_id)
+ current_span_var.reset(self._span_context_token)
+ delattr(self, '_span_context_token')
+ return await self._original_manager.__aexit__(exc_type, exc_val, exc_tb)
 
+ class _TracedSyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractContextManager):
  def __enter__(self):
  self._parent_span_id_at_entry = current_span_var.get()
  if not self._trace_client:
- return self._original_manager.__enter__()
+ return self._original_manager.__enter__()
 
- # Manually create 'enter' entry
- start_time = time.time()
- span_id = str(uuid.uuid4())
- current_depth = 0
- if self._parent_span_id_at_entry and self._parent_span_id_at_entry in self._trace_client._span_depths:
- current_depth = self._trace_client._span_depths[self._parent_span_id_at_entry] + 1
- self._trace_client._span_depths[span_id] = current_depth
- enter_entry = TraceEntry(
- type="enter", function=self._span_name, span_id=span_id,
- trace_id=self._trace_client.trace_id, depth=current_depth, message=self._span_name,
- created_at=start_time, span_type="llm", parent_span_id=self._parent_span_id_at_entry
- )
- self._trace_client.add_entry(enter_entry)
+ span_id, span = self._create_span()
  self._span_context_token = current_span_var.set(span_id)
+ span.inputs = _format_input_data(self._client, **self._input_kwargs)
 
- # Manually create 'input' entry
- input_data = _format_input_data(self._client, **self._input_kwargs)
- input_entry = TraceEntry(
- type="input", function=self._span_name, span_id=span_id,
- trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Inputs to {self._span_name}",
- created_at=time.time(), inputs=input_data, span_type="llm"
- )
- self._trace_client.add_entry(input_entry)
-
- # Call original __enter__
  raw_iterator = self._original_manager.__enter__()
-
- # Manually create 'output' entry (pending)
- output_entry = TraceEntry(
- type="output", function=self._span_name, span_id=span_id,
- trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Output from {self._span_name}",
- created_at=time.time(), output="<pending stream>", span_type="llm"
- )
- self._trace_client.add_entry(output_entry)
-
- # Wrap the raw iterator
- wrapped_iterator = self._stream_wrapper_func(raw_iterator, self._client, output_entry)
- return wrapped_iterator
+ span.output = "<pending stream>"
+ return self._stream_wrapper_func(raw_iterator, self._client, span)
 
  def __exit__(self, exc_type, exc_val, exc_tb):
- # Manually create 'exit' entry
  if hasattr(self, '_span_context_token'):
- span_id = current_span_var.get()
- start_time_for_duration = 0
- for entry in reversed(self._trace_client.entries):
- if entry.span_id == span_id and entry.type == 'enter':
- start_time_for_duration = entry.created_at
- break
- duration = time.time() - start_time_for_duration if start_time_for_duration else None
- exit_depth = self._trace_client._span_depths.get(span_id, 0)
- exit_entry = TraceEntry(
- type="exit", function=self._span_name, span_id=span_id,
- trace_id=self._trace_client.trace_id, depth=exit_depth, message=f"← {self._span_name}",
- created_at=time.time(), duration=duration, span_type="llm"
- )
- self._trace_client.add_entry(exit_entry)
- if span_id in self._trace_client._span_depths: del self._trace_client._span_depths[span_id]
- current_span_var.reset(self._span_context_token)
- delattr(self, '_span_context_token')
-
- # Delegate __exit__
- if hasattr(self._original_manager, "__exit__"):
- return self._original_manager.__exit__(exc_type, exc_val, exc_tb)
- return None
+ span_id = current_span_var.get()
+ self._finalize_span(span_id)
+ current_span_var.reset(self._span_context_token)
+ delattr(self, '_span_context_token')
+ return self._original_manager.__exit__(exc_type, exc_val, exc_tb)
 
  # --- NEW Generalized Helper Function (Moved from demo) ---
  def prepare_evaluation_for_state(
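To close the loop, a hedged end-to-end sketch of how streamed calls would flow through the pieces above: iterator-style streams go through `_sync_stream_wrapper`, Anthropic's `messages.stream(...)` context manager goes through `_TracedSyncStreamManagerWrapper`, and in both cases the accumulated text and usage land on the span's `output`. Model names, prompts, and the assumption that `stream=True` is routed into the stream wrapper are illustrative, not taken verbatim from this diff.

    from anthropic import Anthropic
    from openai import OpenAI

    openai_client = wrap(OpenAI())
    anthropic_client = wrap(Anthropic())

    # Iterator-style stream: chunks pass through _sync_stream_wrapper, which joins
    # the content and records usage on the span once the stream is exhausted.
    for chunk in openai_client.chat.completions.create(
        model="gpt-4.1-mini",
        messages=[{"role": "user", "content": "Say hi"}],
        stream=True,
    ):
        pass

    # Context-manager stream: __enter__ creates the span and wraps the raw
    # iterator; __exit__ finalizes the span's duration.
    with anthropic_client.messages.stream(
        model="claude-3-5-sonnet-latest",
        max_tokens=128,
        messages=[{"role": "user", "content": "Say hi"}],
    ) as stream:
        for event in stream:
            pass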
@@ -2314,3 +2020,4 @@ def add_evaluation_to_state(
 
  # print("[Skipped adding _judgeval_eval to state: prepare_evaluation_for_state failed]")
  # --- End NEW Helper ---
+