judgeval 0.0.36__py3-none-any.whl → 0.0.38__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
judgeval/common/tracer.py CHANGED
@@ -7,7 +7,11 @@ import functools
7
7
  import inspect
8
8
  import json
9
9
  import os
10
+ import site
11
+ import sysconfig
12
+ import threading
10
13
  import time
14
+ import traceback
11
15
  import uuid
12
16
  import warnings
13
17
  import contextvars
@@ -35,7 +39,6 @@ from rich import print as rprint
35
39
  import types # <--- Add this import
36
40
 
37
41
  # Third-party imports
38
- import pika
39
42
  import requests
40
43
  from litellm import cost_per_token
41
44
  from pydantic import BaseModel
@@ -44,10 +47,10 @@ from openai import OpenAI, AsyncOpenAI
44
47
  from together import Together, AsyncTogether
45
48
  from anthropic import Anthropic, AsyncAnthropic
46
49
  from google import genai
47
- from judgeval.run_evaluation import check_examples
48
50
 
49
51
  # Local application/library-specific imports
50
52
  from judgeval.constants import (
53
+ JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
51
54
  JUDGMENT_TRACES_SAVE_API_URL,
52
55
  JUDGMENT_TRACES_FETCH_API_URL,
53
56
  RABBITMQ_HOST,
@@ -56,25 +59,24 @@ from judgeval.constants import (
56
59
  JUDGMENT_TRACES_DELETE_API_URL,
57
60
  JUDGMENT_PROJECT_DELETE_API_URL,
58
61
  )
59
- from judgeval.judgment_client import JudgmentClient
60
- from judgeval.data import Example
62
+ from judgeval.data import Example, Trace, TraceSpan
61
63
  from judgeval.scorers import APIJudgmentScorer, JudgevalScorer
62
64
  from judgeval.rules import Rule
63
65
  from judgeval.evaluation_run import EvaluationRun
64
66
  from judgeval.data.result import ScoringResult
67
+ from judgeval.common.utils import validate_api_key
68
+ from judgeval.common.exceptions import JudgmentAPIError
65
69
 
66
70
  # Standard library imports needed for the new class
67
71
  import concurrent.futures
68
72
  from collections.abc import Iterator, AsyncIterator # Add Iterator and AsyncIterator
69
73
 
70
74
  # Define context variables for tracking the current trace and the current span within a trace
71
- current_trace_var = contextvars.ContextVar('current_trace', default=None)
75
+ current_trace_var = contextvars.ContextVar[Optional['TraceClient']]('current_trace', default=None)
72
76
  current_span_var = contextvars.ContextVar('current_span', default=None) # ContextVar for the active span name
73
- in_traced_function_var = contextvars.ContextVar('in_traced_function', default=False) # Track if we're in a traced function
74
77
 
75
78
  # Define type aliases for better code readability and maintainability
76
79
  ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic, AsyncOpenAI, AsyncAnthropic, AsyncTogether, genai.Client, genai.client.AsyncClient] # Supported API clients
77
- TraceEntryType = Literal['enter', 'exit', 'output', 'input', 'evaluation'] # Valid trace entry types
78
80
  SpanType = Literal['span', 'tool', 'llm', 'evaluation', 'chain']
79
81
 
80
82
  # --- Evaluation Config Dataclass (Moved from langgraph.py) ---
@@ -87,154 +89,26 @@ class EvaluationConfig:
87
89
  log_results: Optional[bool] = True
88
90
  # --- End Evaluation Config Dataclass ---
89
91
 
92
+ # Temporary as a POC to have log use the existing annotations feature until log endpoints are ready
90
93
  @dataclass
91
- class TraceEntry:
92
- """Represents a single trace entry with its visual representation.
93
-
94
- Visual representations:
95
- - enter: → (function entry)
96
- - exit: ← (function exit)
97
- - output: Output: (function return value)
98
- - input: Input: (function parameters)
99
- - evaluation: Evaluation: (evaluation results)
100
- """
101
- type: TraceEntryType
102
- span_id: str # Unique ID for this specific span instance
103
- depth: int # Indentation level for nested calls
104
- created_at: float # Unix timestamp when entry was created, replacing the deprecated 'timestamp' field
105
- function: Optional[str] = None # Name of the function being traced
106
- message: Optional[str] = None # Human-readable description
107
- duration: Optional[float] = None # Time taken (for exit/evaluation entries)
108
- trace_id: str = None # ID of the trace this entry belongs to
109
- output: Any = None # Function output value
110
- # Use field() for mutable defaults to avoid shared state issues
111
- inputs: dict = field(default_factory=dict)
112
- span_type: SpanType = "span"
113
- evaluation_runs: List[Optional[EvaluationRun]] = field(default=None)
114
- parent_span_id: Optional[str] = None # ID of the parent span instance
115
-
116
- def print_entry(self):
117
- """Print a trace entry with proper formatting and parent relationship information."""
118
- indent = " " * self.depth
119
-
120
- if self.type == "enter":
121
- # Format parent info if present
122
- parent_info = f" (parent_id: {self.parent_span_id})" if self.parent_span_id else ""
123
- print(f"{indent}→ {self.function} (id: {self.span_id}){parent_info} (trace: {self.message})")
124
- elif self.type == "exit":
125
- print(f"{indent}← {self.function} (id: {self.span_id}) ({self.duration:.3f}s)")
126
- elif self.type == "output":
127
- # Format output to align properly
128
- output_str = str(self.output)
129
- print(f"{indent}Output (for id: {self.span_id}): {output_str}")
130
- elif self.type == "input":
131
- # Format inputs to align properly
132
- print(f"{indent}Input (for id: {self.span_id}): {self.inputs}")
133
- elif self.type == "evaluation":
134
- for evaluation_run in self.evaluation_runs:
135
- print(f"{indent}Evaluation (for id: {self.span_id}): {evaluation_run.model_dump()}")
136
-
137
- def _serialize_inputs(self) -> dict:
138
- """Helper method to serialize input data safely.
139
-
140
- Returns a dict with serializable versions of inputs, converting non-serializable
141
- objects to None with a warning.
142
- """
143
- serialized_inputs = {}
144
- for key, value in self.inputs.items():
145
- if isinstance(value, BaseModel):
146
- serialized_inputs[key] = value.model_dump()
147
- elif isinstance(value, (list, tuple)):
148
- # Handle lists/tuples of arguments
149
- serialized_inputs[key] = [
150
- item.model_dump() if isinstance(item, BaseModel)
151
- else None if not self._is_json_serializable(item)
152
- else item
153
- for item in value
154
- ]
155
- else:
156
- if self._is_json_serializable(value):
157
- serialized_inputs[key] = value
158
- else:
159
- serialized_inputs[key] = self.safe_stringify(value, self.function)
160
- return serialized_inputs
161
-
162
- def _is_json_serializable(self, obj: Any) -> bool:
163
- """Helper method to check if an object is JSON serializable."""
164
- try:
165
- json.dumps(obj)
166
- return True
167
- except (TypeError, OverflowError, ValueError):
168
- return False
169
-
170
- def safe_stringify(self, output, function_name):
171
- """
172
- Safely converts an object to a string or repr, handling serialization issues gracefully.
173
- """
174
- try:
175
- return str(output)
176
- except (TypeError, OverflowError, ValueError):
177
- pass
178
-
179
- try:
180
- return repr(output)
181
- except (TypeError, OverflowError, ValueError):
182
- pass
183
-
184
- warnings.warn(
185
- f"Output for function {function_name} is not JSON serializable and could not be converted to string. Setting to None."
186
- )
187
- return None
94
+ class TraceAnnotation:
95
+ """Represents a single annotation for a trace span."""
96
+ span_id: str
97
+ text: str
98
+ label: str
99
+ score: int
188
100
 
189
101
  def to_dict(self) -> dict:
190
- """Convert the trace entry to a dictionary format for storage/transmission."""
102
+ """Convert the annotation to a dictionary format for storage/transmission."""
191
103
  return {
192
- "type": self.type,
193
- "function": self.function,
194
104
  "span_id": self.span_id,
195
- "trace_id": self.trace_id,
196
- "depth": self.depth,
197
- "message": self.message,
198
- "created_at": datetime.fromtimestamp(self.created_at).isoformat(),
199
- "duration": self.duration,
200
- "output": self._serialize_output(),
201
- "inputs": self._serialize_inputs(),
202
- "evaluation_runs": [evaluation_run.model_dump() for evaluation_run in self.evaluation_runs] if self.evaluation_runs else [],
203
- "span_type": self.span_type,
204
- "parent_span_id": self.parent_span_id,
105
+ "annotation": {
106
+ "text": self.text,
107
+ "label": self.label,
108
+ "score": self.score
109
+ }
205
110
  }
206
-
207
- def _serialize_output(self) -> Any:
208
- """Helper method to serialize output data safely.
209
-
210
- Handles special cases:
211
- - Pydantic models are converted using model_dump()
212
- - Dictionaries are processed recursively to handle non-serializable values.
213
- - We try to serialize into JSON, then string, then the base representation (__repr__)
214
- - Non-serializable objects return None with a warning
215
- """
216
-
217
- def serialize_value(value):
218
- if isinstance(value, BaseModel):
219
- return value.model_dump()
220
- elif isinstance(value, dict):
221
- # Recursively serialize dictionary values
222
- return {k: serialize_value(v) for k, v in value.items()}
223
- elif isinstance(value, (list, tuple)):
224
- # Recursively serialize list/tuple items
225
- return [serialize_value(item) for item in value]
226
- else:
227
- # Try direct JSON serialization first
228
- try:
229
- json.dumps(value)
230
- return value
231
- except (TypeError, OverflowError, ValueError):
232
- # Fallback to safe stringification
233
- return self.safe_stringify(value, self.function)
234
-
235
- # Start serialization with the top-level output
236
- return serialize_value(self.output)
237
-
111
+
238
112
  class TraceManagerClient:
239
113
  """
240
114
  Client for handling trace endpoints with the Judgment API
@@ -271,10 +145,8 @@ class TraceManagerClient:
271
145
  raise ValueError(f"Failed to fetch traces: {response.text}")
272
146
 
273
147
  return response.json()
274
-
275
-
276
148
 
277
- def save_trace(self, trace_data: dict):
149
+ def save_trace(self, trace_data: dict, offline_mode: bool = False):
278
150
  """
279
151
  Saves a trace to the Judgment Supabase and optionally to S3 if configured.
280
152
 
@@ -311,10 +183,37 @@ class TraceManagerClient:
311
183
  except Exception as e:
312
184
  warnings.warn(f"Failed to save trace to S3: {str(e)}")
313
185
 
314
- if "ui_results_url" in response.json():
186
+ if not offline_mode and "ui_results_url" in response.json():
315
187
  pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={response.json()['ui_results_url']}]View Trace[/link]\n"
316
188
  rprint(pretty_str)
317
189
 
190
+ ## TODO: Should have a log endpoint, endpoint should also support batched payloads
191
+ def save_annotation(self, annotation: TraceAnnotation):
192
+ json_data = {
193
+ "span_id": annotation.span_id,
194
+ "annotation": {
195
+ "text": annotation.text,
196
+ "label": annotation.label,
197
+ "score": annotation.score
198
+ }
199
+ }
200
+
201
+ response = requests.post(
202
+ JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
203
+ json=json_data,
204
+ headers={
205
+ 'Content-Type': 'application/json',
206
+ 'Authorization': f'Bearer {self.judgment_api_key}',
207
+ 'X-Organization-Id': self.organization_id
208
+ },
209
+ verify=True
210
+ )
211
+
212
+ if response.status_code != HTTPStatus.OK:
213
+ raise ValueError(f"Failed to save annotation: {response.text}")
214
+
215
+ return response.json()
216
+
318
217
  def delete_trace(self, trace_id: str):
319
218
  """
320
219
  Delete a trace from the database.
@@ -405,15 +304,17 @@ class TraceClient:
405
304
  self.enable_evaluations = enable_evaluations
406
305
  self.parent_trace_id = parent_trace_id
407
306
  self.parent_name = parent_name
408
- self.client: JudgmentClient = tracer.client
409
- self.entries: List[TraceEntry] = []
307
+ self.trace_spans: List[TraceSpan] = []
308
+ self.span_id_to_span: Dict[str, TraceSpan] = {}
309
+ self.evaluation_runs: List[EvaluationRun] = []
310
+ self.annotations: List[TraceAnnotation] = []
410
311
  self.start_time = time.time()
411
312
  self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id, tracer)
412
313
  self.visited_nodes = []
413
314
  self.executed_tools = []
414
315
  self.executed_node_tools = []
415
316
  self._span_depths: Dict[str, int] = {} # NEW: To track depth of active spans
416
-
317
+
417
318
  def get_current_span(self):
418
319
  """Get the current span from the context var"""
419
320
  return current_span_var.get()
@@ -443,9 +344,7 @@ class TraceClient:
443
344
 
444
345
  self._span_depths[span_id] = current_depth # Store depth by span_id
445
346
 
446
- entry = TraceEntry(
447
- type="enter",
448
- function=name,
347
+ span = TraceSpan(
449
348
  span_id=span_id,
450
349
  trace_id=self.trace_id,
451
350
  depth=current_depth,
@@ -453,25 +352,15 @@ class TraceClient:
453
352
  created_at=start_time,
454
353
  span_type=span_type,
455
354
  parent_span_id=parent_span_id,
355
+ function=name,
456
356
  )
457
- self.add_entry(entry)
357
+ self.add_span(span)
458
358
 
459
359
  try:
460
360
  yield self
461
361
  finally:
462
362
  duration = time.time() - start_time
463
- exit_depth = self._span_depths.get(span_id, 0) # Get depth using this span's ID
464
- self.add_entry(TraceEntry(
465
- type="exit",
466
- function=name,
467
- span_id=span_id, # Use the same span_id for exit
468
- trace_id=self.trace_id, # Use the trace_id from the trace client
469
- depth=exit_depth,
470
- message=f"← {name}",
471
- created_at=time.time(),
472
- duration=duration,
473
- span_type=span_type,
474
- ))
363
+ span.duration = duration
475
364
  # Clean up depth tracking for this span_id
476
365
  if span_id in self._span_depths:
477
366
  del self._span_depths[span_id]
@@ -528,19 +417,20 @@ class TraceClient:
528
417
  tools_called=tools_called,
529
418
  expected_tools=expected_tools,
530
419
  additional_metadata=additional_metadata,
531
- trace_id=self.trace_id
532
420
  )
533
421
  else:
534
422
  raise ValueError("Either 'example' or at least one of the individual parameters (input, actual_output, etc.) must be provided")
535
423
 
536
424
  # Check examples before creating evaluation run
537
- check_examples([example], scorers)
425
+
426
+ # check_examples([example], scorers)
538
427
 
539
428
  # --- Modification: Capture span_id immediately ---
540
429
  # span_id_at_eval_call = current_span_var.get()
541
430
  # print(f"[TraceClient.async_evaluate] Captured span ID at eval call: {span_id_at_eval_call}")
542
431
  # Prioritize explicitly passed span_id, fallback to context var
543
- span_id_to_use = span_id if span_id is not None else current_span_var.get()
432
+ current_span_ctx_var = current_span_var.get()
433
+ span_id_to_use = span_id if span_id is not None else current_span_ctx_var if current_span_ctx_var is not None else self.tracer.get_current_span()
544
434
  # print(f"[TraceClient.async_evaluate] Using span_id: {span_id_to_use}")
545
435
  # --- End Modification ---
546
436
 
@@ -550,7 +440,7 @@ class TraceClient:
550
440
  log_results=log_results,
551
441
  project_name=self.project_name,
552
442
  eval_name=f"{self.name.capitalize()}-"
553
- f"{current_span_var.get()}-" # Keep original eval name format using context var if available
443
+ f"{span_id_to_use}-" # Keep original eval name format using context var if available
554
444
  f"[{','.join(scorer.score_type.capitalize() for scorer in scorers)}]",
555
445
  examples=[example],
556
446
  scorers=scorers,
@@ -571,290 +461,60 @@ class TraceClient:
571
461
  # --- End Modification ---
572
462
 
573
463
  if current_span_id:
574
- duration = time.time() - start_time
575
- prev_entry = self.entries[-1] if self.entries else None
576
- # Determine function name based on previous entry or context var (less ideal)
577
- function_name = "unknown_function" # Default
578
- if prev_entry and prev_entry.span_type == "llm":
579
- function_name = prev_entry.function
580
- else:
581
- # Try to find the function name associated with the current span_id
582
- for entry in reversed(self.entries):
583
- if entry.span_id == current_span_id and entry.type == 'enter':
584
- function_name = entry.function
585
- break
586
-
587
- # Get depth for the current span
588
- current_depth = self._span_depths.get(current_span_id, 0)
589
-
590
- self.add_entry(TraceEntry(
591
- type="evaluation",
592
- function=function_name,
593
- span_id=current_span_id, # Associate with current span
594
- trace_id=self.trace_id, # Use the trace_id from the trace client
595
- depth=current_depth,
596
- message=f"Evaluation results for {function_name}",
597
- created_at=time.time(),
598
- evaluation_runs=[eval_run],
599
- duration=duration,
600
- span_type="evaluation"
601
- ))
602
-
464
+ span = self.span_id_to_span[current_span_id]
465
+ span.evaluation_runs.append(eval_run)
466
+ self.evaluation_runs.append(eval_run)
467
+
468
+ def add_annotation(self, annotation: TraceAnnotation):
469
+ """Add an annotation to this trace context"""
470
+ self.annotations.append(annotation)
471
+ return self
472
+
603
473
  def record_input(self, inputs: dict):
604
474
  current_span_id = current_span_var.get()
605
475
  if current_span_id:
606
- entry_span_type = "span"
607
- current_depth = self._span_depths.get(current_span_id, 0)
608
- function_name = "unknown_function" # Default
609
- for entry in reversed(self.entries):
610
- if entry.span_id == current_span_id and entry.type == 'enter':
611
- entry_span_type = entry.span_type
612
- function_name = entry.function
613
- break
614
-
615
- self.add_entry(TraceEntry(
616
- type="input",
617
- function=function_name,
618
- span_id=current_span_id, # Use current span_id from context
619
- trace_id=self.trace_id, # Use the trace_id from the trace client
620
- depth=current_depth,
621
- message=f"Inputs to {function_name}",
622
- created_at=time.time(),
623
- inputs=inputs,
624
- span_type=entry_span_type,
625
- ))
626
- # Removed else block - original didn't have one
476
+ span = self.span_id_to_span[current_span_id]
477
+ span.inputs = inputs
627
478
 
628
- async def _update_coroutine_output(self, entry: TraceEntry, coroutine: Any):
479
+ async def _update_coroutine_output(self, span: TraceSpan, coroutine: Any):
629
480
  """Helper method to update the output of a trace entry once the coroutine completes"""
630
481
  try:
631
482
  result = await coroutine
632
- entry.output = result
483
+ span.output = result
633
484
  return result
634
485
  except Exception as e:
635
- entry.output = f"Error: {str(e)}"
486
+ span.output = f"Error: {str(e)}"
636
487
  raise
637
488
 
638
489
  def record_output(self, output: Any):
639
490
  current_span_id = current_span_var.get()
640
491
  if current_span_id:
641
- entry_span_type = "span"
642
- current_depth = self._span_depths.get(current_span_id, 0)
643
- function_name = "unknown_function" # Default
644
- for entry in reversed(self.entries):
645
- if entry.span_id == current_span_id and entry.type == 'enter':
646
- entry_span_type = entry.span_type
647
- function_name = entry.function
648
- break
649
-
650
- entry = TraceEntry(
651
- type="output",
652
- function=function_name,
653
- span_id=current_span_id, # Use current span_id from context
654
- depth=current_depth,
655
- message=f"Output from {function_name}",
656
- created_at=time.time(),
657
- output="<pending>" if inspect.iscoroutine(output) else output,
658
- span_type=entry_span_type,
659
- trace_id=self.trace_id # Added trace_id for consistency
660
- )
661
- self.add_entry(entry)
492
+ span = self.span_id_to_span[current_span_id]
493
+ span.output = "<pending>" if inspect.iscoroutine(output) else output
662
494
 
663
495
  if inspect.iscoroutine(output):
664
- asyncio.create_task(self._update_coroutine_output(entry, output))
496
+ asyncio.create_task(self._update_coroutine_output(span, output))
665
497
 
666
- return entry # Return the created entry
498
+ return span # Return the created entry
667
499
  # Removed else block - original didn't have one
668
500
  return None # Return None if no span_id found
669
501
 
670
- def add_entry(self, entry: TraceEntry):
671
- """Add a trace entry to this trace context"""
672
- self.entries.append(entry)
502
+ def add_span(self, span: TraceSpan):
503
+ """Add a trace span to this trace context"""
504
+ self.trace_spans.append(span)
505
+ self.span_id_to_span[span.span_id] = span
673
506
  return self
674
507
 
675
508
  def print(self):
676
509
  """Print the complete trace with proper visual structure"""
677
- for entry in self.entries:
678
- entry.print_entry()
679
-
680
- def print_hierarchical(self):
681
- """Print the trace in a hierarchical structure based on parent-child relationships"""
682
- # First, build a map of spans
683
- spans = {}
684
- root_spans = []
685
-
686
- # Collect all enter events first
687
- for entry in self.entries:
688
- if entry.type == "enter":
689
- spans[entry.function] = {
690
- "name": entry.function,
691
- "depth": entry.depth,
692
- "parent_id": entry.parent_span_id,
693
- "children": []
694
- }
695
-
696
- # If no parent, it's a root span
697
- if not entry.parent_span_id:
698
- root_spans.append(entry.function)
699
- elif entry.parent_span_id not in spans:
700
- # If parent doesn't exist yet, temporarily treat as root
701
- # (we'll fix this later)
702
- root_spans.append(entry.function)
703
-
704
- # Build parent-child relationships
705
- for span_name, span in spans.items():
706
- parent = span["parent_id"]
707
- if parent and parent in spans:
708
- spans[parent]["children"].append(span_name)
709
- # Remove from root spans if it was temporarily there
710
- if span_name in root_spans:
711
- root_spans.remove(span_name)
712
-
713
- # Now print the hierarchy
714
- def print_span(span_name, level=0):
715
- if span_name not in spans:
716
- return
717
-
718
- span = spans[span_name]
719
- indent = " " * level
720
- parent_info = f" (parent_id: {span['parent_id']})" if span["parent_id"] else ""
721
- print(f"{indent}→ {span_name}{parent_info}")
722
-
723
- # Print children
724
- for child in span["children"]:
725
- print_span(child, level + 1)
726
-
727
- # Print starting with root spans
728
- print("\nHierarchical Trace Structure:")
729
- for root in root_spans:
730
- print_span(root)
510
+ for span in self.trace_spans:
511
+ span.print_span()
731
512
 
732
513
  def get_duration(self) -> float:
733
514
  """
734
515
  Get the total duration of this trace
735
516
  """
736
517
  return time.time() - self.start_time
737
-
738
- def condense_trace(self, entries: List[dict]) -> List[dict]:
739
- """
740
- Condenses trace entries into a single entry for each span instance,
741
- preserving parent-child span relationships using span_id and parent_span_id.
742
- """
743
- spans_by_id: Dict[str, dict] = {}
744
- evaluation_runs: List[EvaluationRun] = []
745
-
746
- # First pass: Group entries by span_id and gather data
747
- for entry in entries:
748
- span_id = entry.get("span_id")
749
- if not span_id:
750
- continue # Skip entries without a span_id (should not happen)
751
-
752
- if entry["type"] == "enter":
753
- if span_id not in spans_by_id:
754
- spans_by_id[span_id] = {
755
- "span_id": span_id,
756
- "function": entry["function"],
757
- "depth": entry["depth"], # Use the depth recorded at entry time
758
- "created_at": entry["created_at"],
759
- "trace_id": entry["trace_id"],
760
- "parent_span_id": entry.get("parent_span_id"),
761
- "span_type": entry.get("span_type", "span"),
762
- "inputs": None,
763
- "output": None,
764
- "evaluation_runs": [],
765
- "duration": None
766
- }
767
- # Handle potential duplicate enter events if necessary (e.g., log warning)
768
-
769
- elif span_id in spans_by_id:
770
- current_span_data = spans_by_id[span_id]
771
-
772
- if entry["type"] == "input" and entry["inputs"]:
773
- # Merge inputs if multiple are recorded, or just assign
774
- if current_span_data["inputs"] is None:
775
- current_span_data["inputs"] = entry["inputs"]
776
- elif isinstance(current_span_data["inputs"], dict) and isinstance(entry["inputs"], dict):
777
- current_span_data["inputs"].update(entry["inputs"])
778
- # Add more sophisticated merging if needed
779
-
780
- elif entry["type"] == "output" and "output" in entry:
781
- current_span_data["output"] = entry["output"]
782
-
783
- elif entry["type"] == "evaluation" and entry.get("evaluation_runs"):
784
- if current_span_data.get("evaluation_runs") is not None:
785
- evaluation_runs.extend(entry["evaluation_runs"])
786
-
787
- elif entry["type"] == "exit":
788
- if current_span_data["duration"] is None: # Calculate duration only once
789
- start_time = datetime.fromisoformat(current_span_data.get("created_at", entry["created_at"]))
790
- end_time = datetime.fromisoformat(entry["created_at"])
791
- current_span_data["duration"] = (end_time - start_time).total_seconds()
792
- # Update depth if exit depth is different (though current span() implementation keeps it same)
793
- # current_span_data["depth"] = entry["depth"]
794
-
795
- # Convert dictionary to a list initially for easier access
796
- spans_list = list(spans_by_id.values())
797
-
798
- # Build tree structure (adjacency list) and find roots
799
- children_map: Dict[Optional[str], List[dict]] = {}
800
- roots = []
801
- span_map = {span['span_id']: span for span in spans_list} # Map for quick lookup
802
-
803
- for span in spans_list:
804
- parent_id = span.get("parent_span_id")
805
- if parent_id is None:
806
- roots.append(span)
807
- else:
808
- if parent_id not in children_map:
809
- children_map[parent_id] = []
810
- children_map[parent_id].append(span)
811
-
812
- # Sort roots by timestamp
813
- roots.sort(key=lambda x: datetime.fromisoformat(x.get("created_at", "1970-01-01T00:00:00")))
814
-
815
- # Perform depth-first traversal to get the final sorted list
816
- sorted_condensed_list = []
817
- visited = set() # To handle potential cycles, though unlikely with UUIDs
818
-
819
- def dfs(span_data):
820
- span_id = span_data['span_id']
821
- if span_id in visited:
822
- return # Avoid infinite loops in case of cycles
823
- visited.add(span_id)
824
-
825
- sorted_condensed_list.append(span_data) # Add parent before children
826
-
827
- # Get children, sort them by created_at, and visit them
828
- span_children = children_map.get(span_id, [])
829
- span_children.sort(key=lambda x: datetime.fromisoformat(x.get("created_at", "1970-01-01T00:00:00")))
830
- for child in span_children:
831
- # Ensure the child exists in our map before recursing
832
- if child['span_id'] in span_map:
833
- dfs(child)
834
- else:
835
- # This case might indicate an issue, but we'll add the child directly
836
- # if its parent was processed but the child itself wasn't in the initial list?
837
- # Or if the child's 'enter' event was missing. For robustness, add it.
838
- if child['span_id'] not in visited:
839
- visited.add(child['span_id'])
840
- sorted_condensed_list.append(child)
841
-
842
-
843
- # Start DFS from each root
844
- for root_span in roots:
845
- if root_span['span_id'] not in visited:
846
- dfs(root_span)
847
-
848
- # Handle spans that might not have been reachable from roots (orphans)
849
- # Though ideally, all spans should descend from a root.
850
- for span_data in spans_list:
851
- if span_data['span_id'] not in visited:
852
- # Decide how to handle orphans, maybe append them at the end sorted by time?
853
- # For now, let's just add them to ensure they aren't lost.
854
- sorted_condensed_list.append(span_data)
855
-
856
-
857
- return sorted_condensed_list, evaluation_runs
858
518
 
859
519
  def save(self, overwrite: bool = False) -> Tuple[str, dict]:
860
520
  """
@@ -863,44 +523,36 @@ class TraceClient:
863
523
  """
864
524
  # Calculate total elapsed time
865
525
  total_duration = self.get_duration()
866
-
867
- raw_entries = [entry.to_dict() for entry in self.entries]
868
-
869
- condensed_entries, evaluation_runs = self.condense_trace(raw_entries)
870
526
 
871
527
  # Only count tokens for actual LLM API call spans
872
528
  llm_span_names = {"OPENAI_API_CALL", "TOGETHER_API_CALL", "ANTHROPIC_API_CALL", "GOOGLE_API_CALL"}
873
- for entry in condensed_entries:
874
- entry_function_name = entry.get("function", "") # Get function name safely
529
+ for span in self.trace_spans:
530
+ span_function_name = span.function # Get function name safely
875
531
  # Check if it's an LLM span AND function name CONTAINS an API call suffix AND output is dict
876
- is_llm_entry = entry.get("span_type") == "llm"
877
- has_api_suffix = any(suffix in entry_function_name for suffix in llm_span_names)
878
- output_is_dict = isinstance(entry.get("output"), dict)
532
+ is_llm_span = span.span_type == "llm"
533
+ has_api_suffix = any(suffix in span_function_name for suffix in llm_span_names)
534
+ output_is_dict = isinstance(span.output, dict)
879
535
 
880
536
  # --- DEBUG PRINT 1: Check if condition passes ---
881
537
  # if is_llm_entry and has_api_suffix and output_is_dict:
882
- # # print(f"[DEBUG TraceClient.save] Processing entry: {entry.get('span_id')} ({entry_function_name}) - Condition PASSED")
883
538
  # elif is_llm_entry:
884
539
  # # Print why it failed if it was an LLM entry
885
- # print(f"[DEBUG TraceClient.save] Skipping LLM entry: {entry.get('span_id')} ({entry_function_name}) - Suffix Match: {has_api_suffix}, Output is Dict: {output_is_dict}")
886
540
  # # --- END DEBUG ---
887
541
 
888
- if is_llm_entry and has_api_suffix and output_is_dict:
889
- output = entry["output"]
542
+ if is_llm_span and has_api_suffix and output_is_dict:
543
+ output = span.output
890
544
  usage = output.get("usage", {}) # Gets the 'usage' dict from the 'output' field
891
545
 
892
546
  # --- DEBUG PRINT 2: Check extracted usage ---
893
- # print(f"[DEBUG TraceClient.save] Extracted usage dict: {usage}")
894
547
  # --- END DEBUG ---
895
548
 
896
549
  # --- NEW: Extract model_name correctly from nested inputs ---
897
550
  model_name = None
898
- entry_inputs = entry.get("inputs", {})
899
- # print(f"[DEBUG TraceClient.save] Inspecting inputs for span {entry.get('span_id')}: {entry_inputs}") # DEBUG Inputs
900
- if entry_inputs:
551
+ span_inputs = span.inputs
552
+ if span_inputs:
901
553
  # Try common locations for model name within the inputs structure
902
- invocation_params = entry_inputs.get("invocation_params", {})
903
- serialized_data = entry_inputs.get("serialized", {})
554
+ invocation_params = span_inputs.get("invocation_params", {})
555
+ serialized_data = span_inputs.get("serialized", {})
904
556
 
905
557
  # Look in invocation_params (often directly contains model)
906
558
  if isinstance(invocation_params, dict):
@@ -920,10 +572,9 @@ class TraceClient:
920
572
 
921
573
  # Fallback: Check top-level of inputs itself (less likely for callbacks)
922
574
  if not model_name:
923
- model_name = entry_inputs.get("model")
575
+ model_name = span_inputs.get("model")
924
576
 
925
577
 
926
- # print(f"[DEBUG TraceClient.save] Determined model_name: {model_name}") # DEBUG Model Name
927
578
  # --- END NEW ---
928
579
 
929
580
  prompt_tokens = 0
@@ -985,7 +636,7 @@ class TraceClient:
985
636
  if "usage" not in output:
986
637
  output["usage"] = {} # Initialize if missing
987
638
  elif not isinstance(output["usage"], dict): # Handle cases where 'usage' might not be a dict (e.g., placeholder string)
988
- print(f"[WARN TraceClient.save] Output 'usage' for span {entry.get('span_id')} was not a dict ({type(output['usage'])}). Resetting before adding costs.")
639
+ print(f"[WARN TraceClient.save] Output 'usage' for span {span.span_id} was not a dict ({type(output['usage'])}). Resetting before adding costs.")
989
640
  output["usage"] = {} # Reset to dict
990
641
 
991
642
  output["usage"]["prompt_tokens_cost_usd"] = prompt_cost
@@ -993,10 +644,10 @@ class TraceClient:
993
644
  output["usage"]["total_cost_usd"] = prompt_cost + completion_cost
994
645
  except Exception as e:
995
646
  # If cost calculation fails, continue without adding costs
996
- print(f"Error calculating cost for model '{model_name}' (span: {entry.get('span_id')}): {str(e)}")
647
+ print(f"Error calculating cost for model '{model_name}' (span: {span.span_id}): {str(e)}")
997
648
  pass
998
649
  else:
999
- print(f"[WARN TraceClient.save] Could not determine model name for cost calculation (span: {entry.get('span_id')}). Inputs: {entry_inputs}")
650
+ print(f"[WARN TraceClient.save] Could not determine model name for cost calculation (span: {span.span_id}). Inputs: {span_inputs}")
1000
651
 
1001
652
 
1002
653
  # Create trace document - Always use standard keys for top-level counts
@@ -1006,20 +657,258 @@ class TraceClient:
1006
657
  "project_name": self.project_name,
1007
658
  "created_at": datetime.utcfromtimestamp(self.start_time).isoformat(),
1008
659
  "duration": total_duration,
1009
- "entries": condensed_entries,
1010
- "evaluation_runs": evaluation_runs,
660
+ "entries": [span.model_dump() for span in self.trace_spans],
661
+ "evaluation_runs": [run.model_dump() for run in self.evaluation_runs],
1011
662
  "overwrite": overwrite,
663
+ "offline_mode": self.tracer.offline_mode,
1012
664
  "parent_trace_id": self.parent_trace_id,
1013
665
  "parent_name": self.parent_name
1014
666
  }
1015
667
  # --- Log trace data before saving ---
1016
- self.trace_manager_client.save_trace(trace_data)
668
+ self.trace_manager_client.save_trace(trace_data, offline_mode=self.tracer.offline_mode)
669
+
670
+ # upload annotations
671
+ # TODO: batch to the log endpoint
672
+ for annotation in self.annotations:
673
+ self.trace_manager_client.save_annotation(annotation)
1017
674
 
1018
675
  return self.trace_id, trace_data
1019
676
 
1020
677
  def delete(self):
1021
678
  return self.trace_manager_client.delete_trace(self.trace_id)
1022
679
 
680
+
681
+ class _DeepTracer:
682
+ _instance: Optional["_DeepTracer"] = None
683
+ _lock: threading.Lock = threading.Lock()
684
+ _refcount: int = 0
685
+ _span_stack: contextvars.ContextVar[List[Dict[str, Any]]] = contextvars.ContextVar("_deep_profiler_span_stack", default=[])
686
+ _skip_stack: contextvars.ContextVar[List[str]] = contextvars.ContextVar("_deep_profiler_skip_stack", default=[])
687
+
688
+ def _get_qual_name(self, frame) -> str:
689
+ func_name = frame.f_code.co_name
690
+ module_name = frame.f_globals.get("__name__", "unknown_module")
691
+
692
+ try:
693
+ func = frame.f_globals.get(func_name)
694
+ if func is None:
695
+ return f"{module_name}.{func_name}"
696
+ if hasattr(func, "__qualname__"):
697
+ return f"{module_name}.{func.__qualname__}"
698
+ except Exception:
699
+ return f"{module_name}.{func_name}"
700
+
701
+ def __new__(cls):
702
+ with cls._lock:
703
+ if cls._instance is None:
704
+ cls._instance = super().__new__(cls)
705
+ return cls._instance
706
+
707
+ def _should_trace(self, frame):
708
+ # Skip stack is maintained by the tracer as an optimization to skip earlier
709
+ # frames in the call stack that we've already determined should be skipped
710
+ skip_stack = self._skip_stack.get()
711
+ if len(skip_stack) > 0:
712
+ return False
713
+
714
+ func_name = frame.f_code.co_name
715
+ module_name = frame.f_globals.get("__name__", None)
716
+
717
+ func = frame.f_globals.get(func_name)
718
+ if func and (hasattr(func, '_judgment_span_name') or hasattr(func, '_judgment_span_type')):
719
+ return False
720
+
721
+ if (
722
+ not module_name
723
+ or func_name.startswith("<") # ex: <listcomp>
724
+ or func_name.startswith("__") and func_name != "__call__" # dunders
725
+ or not self._is_user_code(frame.f_code.co_filename)
726
+ ):
727
+ return False
728
+
729
+ return True
730
+
731
+ @functools.cache
732
+ def _is_user_code(self, filename: str):
733
+ return bool(filename) and not filename.startswith("<") and not os.path.realpath(filename).startswith(_TRACE_FILEPATH_BLOCKLIST)
734
+
735
+ def _trace(self, frame: types.FrameType, event: str, arg: Any):
736
+ frame.f_trace_lines = False
737
+ frame.f_trace_opcodes = False
738
+
739
+
740
+ if not self._should_trace(frame):
741
+ return
742
+
743
+ if event not in ("call", "return", "exception"):
744
+ return
745
+
746
+ current_trace = current_trace_var.get()
747
+ if not current_trace:
748
+ return
749
+
750
+ parent_span_id = current_span_var.get()
751
+ if not parent_span_id:
752
+ return
753
+
754
+ qual_name = self._get_qual_name(frame)
755
+ skip_stack = self._skip_stack.get()
756
+
757
+ if event == "call":
758
+ # If we have entries in the skip stack and the current qual_name matches the top entry,
759
+ # push it again to track nesting depth and skip
760
+ # As an optimization, we only care about duplicate qual_names.
761
+ if skip_stack:
762
+ if qual_name == skip_stack[-1]:
763
+ skip_stack.append(qual_name)
764
+ self._skip_stack.set(skip_stack)
765
+ return
766
+
767
+ should_trace = self._should_trace(frame)
768
+
769
+ if not should_trace:
770
+ if not skip_stack:
771
+ self._skip_stack.set([qual_name])
772
+ return
773
+ elif event == "return":
774
+ # If we have entries in skip stack and current qual_name matches the top entry,
775
+ # pop it to track exiting from the skipped section
776
+ if skip_stack and qual_name == skip_stack[-1]:
777
+ skip_stack.pop()
778
+ self._skip_stack.set(skip_stack)
779
+ return
780
+
781
+ if skip_stack:
782
+ return
783
+
784
+ span_stack = self._span_stack.get()
785
+ if event == "call":
786
+ if not self._should_trace(frame):
787
+ return
788
+
789
+ span_id = str(uuid.uuid4())
790
+
791
+ parent_depth = current_trace._span_depths.get(parent_span_id, 0)
792
+ depth = parent_depth + 1
793
+
794
+ current_trace._span_depths[span_id] = depth
795
+
796
+ start_time = time.time()
797
+
798
+ span_stack.append({
799
+ "span_id": span_id,
800
+ "parent_span_id": parent_span_id,
801
+ "function": qual_name,
802
+ "start_time": start_time
803
+ })
804
+ self._span_stack.set(span_stack)
805
+
806
+ token = current_span_var.set(span_id)
807
+ frame.f_locals["_judgment_span_token"] = token
808
+
809
+ span = TraceSpan(
810
+ span_id=span_id,
811
+ trace_id=current_trace.trace_id,
812
+ depth=depth,
813
+ message=qual_name,
814
+ created_at=start_time,
815
+ span_type="span",
816
+ parent_span_id=parent_span_id,
817
+ function=qual_name
818
+ )
819
+ current_trace.add_span(span)
820
+
821
+ inputs = {}
822
+ try:
823
+ args_info = inspect.getargvalues(frame)
824
+ for arg in args_info.args:
825
+ try:
826
+ inputs[arg] = args_info.locals.get(arg)
827
+ except:
828
+ inputs[arg] = "<<Unserializable>>"
829
+ current_trace.record_input(inputs)
830
+ except Exception as e:
831
+ current_trace.record_input({
832
+ "error": str(e)
833
+ })
834
+
835
+ elif event == "return":
836
+ if not span_stack:
837
+ return
838
+
839
+ current_id = current_span_var.get()
840
+
841
+ span_data = None
842
+ for i, entry in enumerate(reversed(span_stack)):
843
+ if entry["span_id"] == current_id:
844
+ span_data = span_stack.pop(-(i+1))
845
+ self._span_stack.set(span_stack)
846
+ break
847
+
848
+ if not span_data:
849
+ return
850
+
851
+ start_time = span_data["start_time"]
852
+ duration = time.time() - start_time
853
+
854
+ current_trace.span_id_to_span[span_data["span_id"]].duration = duration
855
+
856
+ if arg is not None:
857
+ # exception handling will take priority.
858
+ current_trace.record_output(arg)
859
+
860
+ if span_data["span_id"] in current_trace._span_depths:
861
+ del current_trace._span_depths[span_data["span_id"]]
862
+
863
+ if span_stack:
864
+ current_span_var.set(span_stack[-1]["span_id"])
865
+ else:
866
+ current_span_var.set(span_data["parent_span_id"])
867
+
868
+ if "_judgment_span_token" in frame.f_locals:
869
+ current_span_var.reset(frame.f_locals["_judgment_span_token"])
870
+
871
+ elif event == "exception":
872
+ exc_type, exc_value, exc_traceback = arg
873
+ formatted_exception = {
874
+ "type": exc_type.__name__,
875
+ "message": str(exc_value),
876
+ "traceback": traceback.format_tb(exc_traceback)
877
+ }
878
+ current_trace = current_trace_var.get()
879
+ current_trace.record_output({
880
+ "error": formatted_exception
881
+ })
882
+
883
+ return self._trace
884
+
885
+ def __enter__(self):
886
+ with self._lock:
887
+ self._refcount += 1
888
+ if self._refcount == 1:
889
+ self._skip_stack.set([])
890
+ self._span_stack.set([])
891
+ sys.settrace(self._trace)
892
+ threading.settrace(self._trace)
893
+ return self
894
+
895
+ def __exit__(self, exc_type, exc_val, exc_tb):
896
+ with self._lock:
897
+ self._refcount -= 1
898
+ if self._refcount == 0:
899
+ sys.settrace(None)
900
+ threading.settrace(None)
901
+
902
+
903
+ def log(self, message: str, level: str = "info"):
904
+ """ Log a message with the span context """
905
+ current_trace = current_trace_var.get()
906
+ if current_trace:
907
+ current_trace.log(message, level)
908
+ else:
909
+ print(f"[{level}] {message}")
910
+ current_trace.record_output({"log": message})
911
+
1023
912
  class Tracer:
1024
913
  _instance = None
1025
914
 
@@ -1042,12 +931,17 @@ class Tracer:
1042
931
  s3_aws_access_key_id: Optional[str] = None,
1043
932
  s3_aws_secret_access_key: Optional[str] = None,
1044
933
  s3_region_name: Optional[str] = None,
1045
- deep_tracing: bool = True # NEW: Enable deep tracing by default
934
+ offline_mode: bool = False,
935
+ deep_tracing: bool = True # Deep tracing is enabled by default
1046
936
  ):
1047
937
  if not hasattr(self, 'initialized'):
1048
938
  if not api_key:
1049
939
  raise ValueError("Tracer must be configured with a Judgment API key")
1050
940
 
941
+ result, response = validate_api_key(api_key)
942
+ if not result:
943
+ raise JudgmentAPIError(f"Issue with passed in Judgment API key: {response}")
944
+
1051
945
  if not organization_id:
1052
946
  raise ValueError("Tracer must be configured with an Organization ID")
1053
947
  if use_s3 and not s3_bucket_name:
@@ -1059,11 +953,11 @@ class Tracer:
1059
953
 
1060
954
  self.api_key: str = api_key
1061
955
  self.project_name: str = project_name
1062
- self.client: JudgmentClient = JudgmentClient(judgment_api_key=api_key)
1063
956
  self.organization_id: str = organization_id
1064
957
  self._current_trace: Optional[str] = None
1065
958
  self._active_trace_client: Optional[TraceClient] = None # Add active trace client attribute
1066
959
  self.rules: List[Rule] = rules or [] # Store rules at tracer level
960
+ self.traces: List[Trace] = []
1067
961
  self.initialized: bool = True
1068
962
  self.enable_monitoring: bool = enable_monitoring
1069
963
  self.enable_evaluations: bool = enable_evaluations
@@ -1078,6 +972,7 @@ class Tracer:
1078
972
  aws_secret_access_key=s3_aws_secret_access_key,
1079
973
  region_name=s3_region_name
1080
974
  )
975
+ self.offline_mode: bool = offline_mode
1081
976
  self.deep_tracing: bool = deep_tracing # NEW: Store deep tracing setting
1082
977
 
1083
978
  elif hasattr(self, 'project_name') and self.project_name != project_name:
@@ -1087,6 +982,12 @@ class Tracer:
1087
982
  "To use a different project name, ensure the first Tracer initialization uses the desired project name.",
1088
983
  RuntimeWarning
1089
984
  )
985
+
986
+ def set_current_span(self, span_id: str):
987
+ self.current_span_id = span_id
988
+
989
+ def get_current_span(self) -> Optional[str]:
990
+ return getattr(self, 'current_span_id', None)
1090
991
 
1091
992
  def set_current_trace(self, trace: TraceClient):
1092
993
  """
@@ -1119,45 +1020,6 @@ class Tracer:
1119
1020
  """Returns the TraceClient instance currently marked as active by the handler."""
1120
1021
  return self._active_trace_client
1121
1022
 
1122
- def _apply_deep_tracing(self, func, span_type="span"):
1123
- """
1124
- Apply deep tracing to all functions in the same module as the given function.
1125
-
1126
- Args:
1127
- func: The function being traced
1128
- span_type: Type of span to use for traced functions
1129
-
1130
- Returns:
1131
- A tuple of (module, original_functions_dict) where original_functions_dict
1132
- contains the original functions that were replaced with traced versions.
1133
- """
1134
- module = inspect.getmodule(func)
1135
- if not module:
1136
- return None, {}
1137
-
1138
- # Save original functions
1139
- original_functions = {}
1140
-
1141
- # Find all functions in the module
1142
- for name, obj in inspect.getmembers(module, inspect.isfunction):
1143
- # Skip already wrapped functions
1144
- if hasattr(obj, '_judgment_traced'):
1145
- continue
1146
-
1147
- # Create a traced version of the function
1148
- # Always use default span type "span" for child functions
1149
- traced_func = _create_deep_tracing_wrapper(obj, self, "span")
1150
-
1151
- # Mark the function as traced to avoid double wrapping
1152
- traced_func._judgment_traced = True
1153
-
1154
- # Save the original function
1155
- original_functions[name] = obj
1156
-
1157
- # Replace with traced version
1158
- setattr(module, name, traced_func)
1159
-
1160
- return module, original_functions
1161
1023
 
1162
1024
  @contextmanager
1163
1025
  def trace(
@@ -1204,6 +1066,23 @@ class Tracer:
1204
1066
  finally:
1205
1067
  # Reset the context variable
1206
1068
  current_trace_var.reset(token)
1069
+
1070
+
1071
+ def log(self, msg: str, label: str = "log", score: int = 1):
1072
+ """Log a message with the current span context"""
1073
+ current_span_id = current_span_var.get()
1074
+ current_trace = current_trace_var.get()
1075
+ if current_span_id:
1076
+ annotation = TraceAnnotation(
1077
+ span_id=current_span_id,
1078
+ text=msg,
1079
+ label=label,
1080
+ score=score
1081
+ )
1082
+
1083
+ current_trace.add_annotation(annotation)
1084
+
1085
+ rprint(f"[bold]{label}:[/bold] {msg}")
1207
1086
 
1208
1087
  def observe(self, func=None, *, name=None, span_type: SpanType = "span", project_name: str = None, overwrite: bool = False, deep_tracing: bool = None):
1209
1088
  """
@@ -1239,13 +1118,6 @@ class Tracer:
1239
1118
  if asyncio.iscoroutinefunction(func):
1240
1119
  @functools.wraps(func)
1241
1120
  async def async_wrapper(*args, **kwargs):
1242
- # Check if we're already in a traced function
1243
- if in_traced_function_var.get():
1244
- return await func(*args, **kwargs)
1245
-
1246
- # Set in_traced_function_var to True
1247
- token = in_traced_function_var.set(True)
1248
-
1249
1121
  # Get current trace from context
1250
1122
  current_trace = current_trace_var.get()
1251
1123
 
@@ -1275,81 +1147,47 @@ class Tracer:
1275
1147
  # This sets the current_span_var
1276
1148
  with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
1277
1149
  # Record inputs
1278
- span.record_input({
1279
- 'args': str(args),
1280
- 'kwargs': kwargs
1281
- })
1150
+ inputs = combine_args_kwargs(func, args, kwargs)
1151
+ span.record_input(inputs)
1282
1152
 
1283
- # If deep tracing is enabled, apply monkey patching
1284
1153
  if use_deep_tracing:
1285
- module, original_functions = self._apply_deep_tracing(func, span_type)
1286
-
1287
- # Execute function
1288
- result = await func(*args, **kwargs)
1289
-
1290
- # Restore original functions if deep tracing was enabled
1291
- if use_deep_tracing and module and 'original_functions' in locals():
1292
- for name, obj in original_functions.items():
1293
- setattr(module, name, obj)
1294
-
1154
+ with _DeepTracer():
1155
+ result = await func(*args, **kwargs)
1156
+ else:
1157
+ result = await func(*args, **kwargs)
1158
+
1295
1159
  # Record output
1296
1160
  span.record_output(result)
1297
-
1298
- # Save the completed trace
1299
- current_trace.save(overwrite=overwrite)
1300
1161
  return result
1301
1162
  finally:
1163
+ # Save the completed trace
1164
+ trace_id, trace = current_trace.save(overwrite=overwrite)
1165
+ self.traces.append(trace)
1166
+
1302
1167
  # Reset trace context (span context resets automatically)
1303
1168
  current_trace_var.reset(trace_token)
1304
- # Reset in_traced_function_var
1305
- in_traced_function_var.reset(token)
1306
1169
  else:
1307
- # Already have a trace context, just create a span in it
1308
- # The span method handles current_span_var
1309
-
1310
- try:
1311
- with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
1312
- # Record inputs
1313
- span.record_input({
1314
- 'args': str(args),
1315
- 'kwargs': kwargs
1316
- })
1317
-
1318
- # If deep tracing is enabled, apply monkey patching
1319
- if use_deep_tracing:
1320
- module, original_functions = self._apply_deep_tracing(func, span_type)
1321
-
1322
- # Execute function
1170
+ with current_trace.span(span_name, span_type=span_type) as span:
1171
+ inputs = combine_args_kwargs(func, args, kwargs)
1172
+ span.record_input(inputs)
1173
+
1174
+ if use_deep_tracing:
1175
+ with _DeepTracer():
1176
+ result = await func(*args, **kwargs)
1177
+ else:
1323
1178
  result = await func(*args, **kwargs)
1324
1179
 
1325
- # Restore original functions if deep tracing was enabled
1326
- if use_deep_tracing and module and 'original_functions' in locals():
1327
- for name, obj in original_functions.items():
1328
- setattr(module, name, obj)
1329
-
1330
- # Record output
1331
- span.record_output(result)
1332
-
1333
- return result
1334
- finally:
1335
- # Reset in_traced_function_var
1336
- in_traced_function_var.reset(token)
1337
-
1180
+ span.record_output(result)
1181
+ return result
1182
+
1338
1183
  return async_wrapper
1339
1184
  else:
1340
1185
  # Non-async function implementation with deep tracing
1341
1186
  @functools.wraps(func)
1342
- def wrapper(*args, **kwargs):
1343
- # Check if we're already in a traced function
1344
- if in_traced_function_var.get():
1345
- return func(*args, **kwargs)
1346
-
1347
- # Set in_traced_function_var to True
1348
- token = in_traced_function_var.set(True)
1349
-
1187
+ def wrapper(*args, **kwargs):
1350
1188
  # Get current trace from context
1351
1189
  current_trace = current_trace_var.get()
1352
-
1190
+
1353
1191
  # If there's no current trace, create a root trace
1354
1192
  if not current_trace:
1355
1193
  trace_id = str(uuid.uuid4())
@@ -1376,66 +1214,40 @@ class Tracer:
1376
1214
  # This sets the current_span_var
1377
1215
  with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
1378
1216
  # Record inputs
1379
- span.record_input({
1380
- 'args': str(args),
1381
- 'kwargs': kwargs
1382
- })
1217
+ inputs = combine_args_kwargs(func, args, kwargs)
1218
+ span.record_input(inputs)
1383
1219
 
1384
- # If deep tracing is enabled, apply monkey patching
1385
1220
  if use_deep_tracing:
1386
- module, original_functions = self._apply_deep_tracing(func, span_type)
1387
-
1388
- # Execute function
1389
- result = func(*args, **kwargs)
1390
-
1391
- # Restore original functions if deep tracing was enabled
1392
- if use_deep_tracing and module and 'original_functions' in locals():
1393
- for name, obj in original_functions.items():
1394
- setattr(module, name, obj)
1221
+ with _DeepTracer():
1222
+ result = func(*args, **kwargs)
1223
+ else:
1224
+ result = func(*args, **kwargs)
1395
1225
 
1396
1226
  # Record output
1397
1227
  span.record_output(result)
1398
-
1399
- # Save the completed trace
1400
- current_trace.save(overwrite=overwrite)
1401
1228
  return result
1402
1229
  finally:
1230
+ # Save the completed trace
1231
+ trace_id, trace = current_trace.save(overwrite=overwrite)
1232
+ self.traces.append(trace)
1233
+
1403
1234
  # Reset trace context (span context resets automatically)
1404
1235
  current_trace_var.reset(trace_token)
1405
- # Reset in_traced_function_var
1406
- in_traced_function_var.reset(token)
1407
1236
  else:
1408
- # Already have a trace context, just create a span in it
1409
- # The span method handles current_span_var
1410
-
1411
- try:
1412
- with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
1413
- # Record inputs
1414
- span.record_input({
1415
- 'args': str(args),
1416
- 'kwargs': kwargs
1417
- })
1418
-
1419
- # If deep tracing is enabled, apply monkey patching
1420
- if use_deep_tracing:
1421
- module, original_functions = self._apply_deep_tracing(func, span_type)
1422
-
1423
- # Execute function
1237
+ with current_trace.span(span_name, span_type=span_type) as span:
1238
+
1239
+ inputs = combine_args_kwargs(func, args, kwargs)
1240
+ span.record_input(inputs)
1241
+
1242
+ if use_deep_tracing:
1243
+ with _DeepTracer():
1244
+ result = func(*args, **kwargs)
1245
+ else:
1424
1246
  result = func(*args, **kwargs)
1425
1247
 
1426
- # Restore original functions if deep tracing was enabled
1427
- if use_deep_tracing and module and 'original_functions' in locals():
1428
- for name, obj in original_functions.items():
1429
- setattr(module, name, obj)
1430
-
1431
- # Record output
1432
- span.record_output(result)
1433
-
1434
- return result
1435
- finally:
1436
- # Reset in_traced_function_var
1437
- in_traced_function_var.reset(token)
1438
-
1248
+ span.record_output(result)
1249
+ return result
1250
+
1439
1251
  return wrapper
1440
1252
 
1441
1253
  def async_evaluate(self, *args, **kwargs):
@@ -1462,64 +1274,94 @@ class Tracer:
1462
1274
  else:
1463
1275
  warnings.warn("No trace found (context var or fallback), skipping evaluation") # Modified warning
1464
1276
 
1465
-
1466
1277
  def wrap(client: Any) -> Any:
1467
1278
  """
1468
1279
  Wraps an API client to add tracing capabilities.
1469
1280
  Supports OpenAI, Together, Anthropic, and Google GenAI clients.
1470
1281
  Patches both '.create' and Anthropic's '.stream' methods using a wrapper class.
1471
1282
  """
1472
- span_name, original_create, original_stream = _get_client_config(client)
1283
+ span_name, original_create, original_responses_create, original_stream = _get_client_config(client)
1284
+
1285
+ def _record_input_and_check_streaming(span, kwargs, is_responses=False):
1286
+ """Record input and check for streaming"""
1287
+ is_streaming = kwargs.get("stream", False)
1473
1288
 
1474
- # --- Define Traced Async Functions ---
1289
+ # Record input based on whether this is a responses endpoint
1290
+ if is_responses:
1291
+ span.record_input(kwargs)
1292
+ else:
1293
+ input_data = _format_input_data(client, **kwargs)
1294
+ span.record_input(input_data)
1295
+
1296
+ # Warn about token counting limitations with streaming
1297
+ if isinstance(client, (AsyncOpenAI, OpenAI)) and is_streaming:
1298
+ if not kwargs.get("stream_options", {}).get("include_usage"):
1299
+ warnings.warn(
1300
+ "OpenAI streaming calls don't include token counts by default. "
1301
+ "To enable token counting with streams, set stream_options={'include_usage': True} "
1302
+ "in your API call arguments.",
1303
+ UserWarning
1304
+ )
1305
+
1306
+ return is_streaming
1307
+
1308
+ def _format_and_record_output(span, response, is_streaming, is_async, is_responses):
1309
+ """Format and record the output in the span"""
1310
+ if is_streaming:
1311
+ output_entry = span.record_output("<pending stream>")
1312
+ wrapper_func = _async_stream_wrapper if is_async else _sync_stream_wrapper
1313
+ return wrapper_func(response, client, output_entry)
1314
+ else:
1315
+ format_func = _format_response_output_data if is_responses else _format_output_data
1316
+ output_data = format_func(client, response)
1317
+ span.record_output(output_data)
1318
+ return response
1319
+
1320
+ def _handle_error(span, e, is_async):
1321
+ """Handle and record errors"""
1322
+ call_type = "async" if is_async else "sync"
1323
+ print(f"Error during wrapped {call_type} API call ({span_name}): {e}")
1324
+ span.record_output({"error": str(e)})
1325
+ raise
1326
+
1327
+ # --- Traced Async Functions ---
1475
1328
  async def traced_create_async(*args, **kwargs):
1476
- # [Existing logic - unchanged]
1477
1329
  current_trace = current_trace_var.get()
1478
1330
  if not current_trace:
1479
- if asyncio.iscoroutinefunction(original_create):
1480
- return await original_create(*args, **kwargs)
1481
- else:
1482
- return original_create(*args, **kwargs)
1483
-
1484
- is_streaming = kwargs.get("stream", False)
1485
-
1331
+ return await original_create(*args, **kwargs)
1332
+
1486
1333
  with current_trace.span(span_name, span_type="llm") as span:
1487
- input_data = _format_input_data(client, **kwargs)
1488
- span.record_input(input_data)
1489
-
1490
- # Warn about token counting limitations with streaming
1491
- if isinstance(client, (AsyncOpenAI, OpenAI)) and is_streaming:
1492
- if not kwargs.get("stream_options", {}).get("include_usage"):
1493
- warnings.warn(
1494
- "OpenAI streaming calls don't include token counts by default. "
1495
- "To enable token counting with streams, set stream_options={'include_usage': True} "
1496
- "in your API call arguments.",
1497
- UserWarning
1498
- )
1499
-
1334
+ is_streaming = _record_input_and_check_streaming(span, kwargs)
1335
+
1500
1336
  try:
1501
- if is_streaming:
1502
- stream_iterator = await original_create(*args, **kwargs)
1503
- output_entry = span.record_output("<pending stream>")
1504
- return _async_stream_wrapper(stream_iterator, client, output_entry)
1505
- else:
1506
- awaited_response = await original_create(*args, **kwargs)
1507
- output_data = _format_output_data(client, awaited_response)
1508
- span.record_output(output_data)
1509
- return awaited_response
1337
+ response_or_iterator = await original_create(*args, **kwargs)
1338
+ return _format_and_record_output(span, response_or_iterator, is_streaming, True, False)
1510
1339
  except Exception as e:
1511
- print(f"Error during wrapped async API call ({span_name}): {e}")
1512
- span.record_output({"error": str(e)})
1513
- raise
1514
-
1515
-
1516
- # Function replacing .stream() - NOW returns the wrapper class instance
1340
+ return _handle_error(span, e, True)
1341
+
1342
+ # Async responses for OpenAI clients
1343
+ async def traced_response_create_async(*args, **kwargs):
1344
+ current_trace = current_trace_var.get()
1345
+ if not current_trace:
1346
+ return await original_responses_create(*args, **kwargs)
1347
+
1348
+ with current_trace.span(span_name, span_type="llm") as span:
1349
+ is_streaming = _record_input_and_check_streaming(span, kwargs, is_responses=True)
1350
+
1351
+ try:
1352
+ response_or_iterator = await original_responses_create(*args, **kwargs)
1353
+ return _format_and_record_output(span, response_or_iterator, is_streaming, True, True)
1354
+ except Exception as e:
1355
+ return _handle_error(span, e, True)
1356
+
1357
+ # Function replacing .stream() for async clients
1517
1358
  def traced_stream_async(*args, **kwargs):
1518
1359
  current_trace = current_trace_var.get()
1519
1360
  if not current_trace or not original_stream:
1520
1361
  return original_stream(*args, **kwargs)
1362
+
1521
1363
  original_manager = original_stream(*args, **kwargs)
1522
- wrapper_manager = _TracedAsyncStreamManagerWrapper(
1364
+ return _TracedAsyncStreamManagerWrapper(
1523
1365
  original_manager=original_manager,
1524
1366
  client=client,
1525
1367
  span_name=span_name,
@@ -1527,104 +1369,74 @@ def wrap(client: Any) -> Any:
1527
1369
  stream_wrapper_func=_async_stream_wrapper,
1528
1370
  input_kwargs=kwargs
1529
1371
  )
1530
- return wrapper_manager
1531
-
1532
- # --- Define Traced Sync Functions ---
1372
+
1373
+ # --- Traced Sync Functions ---
1533
1374
  def traced_create_sync(*args, **kwargs):
1534
- # [Existing logic - unchanged]
1535
1375
  current_trace = current_trace_var.get()
1536
1376
  if not current_trace:
1537
- return original_create(*args, **kwargs)
1538
-
1539
- is_streaming = kwargs.get("stream", False)
1540
-
1377
+ return original_create(*args, **kwargs)
1378
+
1541
1379
  with current_trace.span(span_name, span_type="llm") as span:
1542
- input_data = _format_input_data(client, **kwargs)
1543
- span.record_input(input_data)
1544
-
1545
- # Warn about token counting limitations with streaming
1546
- if isinstance(client, (AsyncOpenAI, OpenAI)) and is_streaming:
1547
- if not kwargs.get("stream_options", {}).get("include_usage"):
1548
- warnings.warn(
1549
- "OpenAI streaming calls don't include token counts by default. "
1550
- "To enable token counting with streams, set stream_options={'include_usage': True} "
1551
- "in your API call arguments.",
1552
- UserWarning
1553
- )
1554
-
1555
- try:
1556
- response_or_iterator = original_create(*args, **kwargs)
1557
- except Exception as e:
1558
- print(f"Error during wrapped sync API call ({span_name}): {e}")
1559
- span.record_output({"error": str(e)})
1560
- raise
1561
-
1562
- if is_streaming:
1563
- output_entry = span.record_output("<pending stream>")
1564
- return _sync_stream_wrapper(response_or_iterator, client, output_entry)
1565
- else:
1566
- output_data = _format_output_data(client, response_or_iterator)
1567
- span.record_output(output_data)
1568
- return response_or_iterator
1569
-
1570
-
1380
+ is_streaming = _record_input_and_check_streaming(span, kwargs)
1381
+
1382
+ try:
1383
+ response_or_iterator = original_create(*args, **kwargs)
1384
+ return _format_and_record_output(span, response_or_iterator, is_streaming, False, False)
1385
+ except Exception as e:
1386
+ return _handle_error(span, e, False)
1387
+
1388
+ def traced_response_create_sync(*args, **kwargs):
1389
+ current_trace = current_trace_var.get()
1390
+ if not current_trace:
1391
+ return original_responses_create(*args, **kwargs)
1392
+
1393
+ with current_trace.span(span_name, span_type="llm") as span:
1394
+ is_streaming = _record_input_and_check_streaming(span, kwargs, is_responses=True)
1395
+
1396
+ try:
1397
+ response_or_iterator = original_responses_create(*args, **kwargs)
1398
+ return _format_and_record_output(span, response_or_iterator, is_streaming, False, True)
1399
+ except Exception as e:
1400
+ return _handle_error(span, e, False)
1401
+
1571
1402
  # Function replacing sync .stream()
1572
1403
  def traced_stream_sync(*args, **kwargs):
1573
- current_trace = current_trace_var.get()
1574
- if not current_trace or not original_stream:
1575
- return original_stream(*args, **kwargs)
1576
- original_manager = original_stream(*args, **kwargs)
1577
- wrapper_manager = _TracedSyncStreamManagerWrapper(
1578
- original_manager=original_manager,
1579
- client=client,
1580
- span_name=span_name,
1581
- trace_client=current_trace,
1582
- stream_wrapper_func=_sync_stream_wrapper,
1583
- input_kwargs=kwargs
1584
- )
1585
- return wrapper_manager
1586
-
1587
-
1404
+ current_trace = current_trace_var.get()
1405
+ if not current_trace or not original_stream:
1406
+ return original_stream(*args, **kwargs)
1407
+
1408
+ original_manager = original_stream(*args, **kwargs)
1409
+ return _TracedSyncStreamManagerWrapper(
1410
+ original_manager=original_manager,
1411
+ client=client,
1412
+ span_name=span_name,
1413
+ trace_client=current_trace,
1414
+ stream_wrapper_func=_sync_stream_wrapper,
1415
+ input_kwargs=kwargs
1416
+ )
1417
+
1588
1418
  # --- Assign Traced Methods to Client Instance ---
1589
- # [Assignment logic remains the same]
1590
1419
  if isinstance(client, (AsyncOpenAI, AsyncTogether)):
1591
1420
  client.chat.completions.create = traced_create_async
1592
- # Wrap the Responses API endpoint for OpenAI clients
1593
1421
  if hasattr(client, "responses") and hasattr(client.responses, "create"):
1594
- # Capture the original responses.create
1595
- original_responses_create = client.responses.create
1596
- def traced_responses(*args, **kwargs):
1597
- # Get the current trace from contextvars
1598
- current_trace = current_trace_var.get()
1599
- # If no active trace, call the original
1600
- if not current_trace:
1601
- return original_responses_create(*args, **kwargs)
1602
- # Trace this responses.create call
1603
- with current_trace.span(span_name, span_type="llm") as span:
1604
- # Record raw input kwargs
1605
- span.record_input(kwargs)
1606
- # Make the actual API call
1607
- response = original_responses_create(*args, **kwargs)
1608
- # Record the output object
1609
- span.record_output(response)
1610
- return response
1611
- # Assign the traced wrapper
1612
- client.responses.create = traced_responses
1422
+ client.responses.create = traced_response_create_async
1613
1423
  elif isinstance(client, AsyncAnthropic):
1614
1424
  client.messages.create = traced_create_async
1615
1425
  if original_stream:
1616
- client.messages.stream = traced_stream_async
1426
+ client.messages.stream = traced_stream_async
1617
1427
  elif isinstance(client, genai.client.AsyncClient):
1618
- client.generate_content = traced_create_async
1428
+ client.models.generate_content = traced_create_async
1619
1429
  elif isinstance(client, (OpenAI, Together)):
1620
- client.chat.completions.create = traced_create_sync
1430
+ client.chat.completions.create = traced_create_sync
1431
+ if hasattr(client, "responses") and hasattr(client.responses, "create"):
1432
+ client.responses.create = traced_response_create_sync
1621
1433
  elif isinstance(client, Anthropic):
1622
- client.messages.create = traced_create_sync
1623
- if original_stream:
1624
- client.messages.stream = traced_stream_sync
1434
+ client.messages.create = traced_create_sync
1435
+ if original_stream:
1436
+ client.messages.stream = traced_stream_sync
1625
1437
  elif isinstance(client, genai.Client):
1626
- client.generate_content = traced_create_sync
1627
-
1438
+ client.models.generate_content = traced_create_sync
1439
+
1628
1440
  return client
1629
1441
 
1630
1442
  # Helper functions for client-specific operations
@@ -1639,19 +1451,20 @@ def _get_client_config(client: ApiClient) -> tuple[str, callable, Optional[calla
1639
1451
  tuple: (span_name, create_method, responses_method, stream_method)
1640
1452
  - span_name: String identifier for tracing
1641
1453
  - create_method: Reference to the client's creation method
1454
+ - responses_method: Reference to the client's responses method (if applicable)
1642
1455
  - stream_method: Reference to the client's stream method (if applicable)
1643
1456
 
1644
1457
  Raises:
1645
1458
  ValueError: If client type is not supported
1646
1459
  """
1647
1460
  if isinstance(client, (OpenAI, AsyncOpenAI)):
1648
- return "OPENAI_API_CALL", client.chat.completions.create, None
1461
+ return "OPENAI_API_CALL", client.chat.completions.create, client.responses.create, None
1649
1462
  elif isinstance(client, (Together, AsyncTogether)):
1650
- return "TOGETHER_API_CALL", client.chat.completions.create, None
1463
+ return "TOGETHER_API_CALL", client.chat.completions.create, None, None
1651
1464
  elif isinstance(client, (Anthropic, AsyncAnthropic)):
1652
- return "ANTHROPIC_API_CALL", client.messages.create, client.messages.stream
1465
+ return "ANTHROPIC_API_CALL", client.messages.create, None, client.messages.stream
1653
1466
  elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
1654
- return "GOOGLE_API_CALL", client.models.generate_content, None
1467
+ return "GOOGLE_API_CALL", client.models.generate_content, None, None
1655
1468
  raise ValueError(f"Unsupported client type: {type(client)}")
1656
1469
 
1657
1470
  def _format_input_data(client: ApiClient, **kwargs) -> dict:
@@ -1677,6 +1490,26 @@ def _format_input_data(client: ApiClient, **kwargs) -> dict:
1677
1490
  "max_tokens": kwargs.get("max_tokens")
1678
1491
  }
1679
1492
 
1493
+ def _format_response_output_data(client: ApiClient, response: Any) -> dict:
1494
+ """Format API response data based on client type.
1495
+
1496
+ Normalizes different response formats into a consistent structure
1497
+ for tracing purposes.
1498
+ """
1499
+ if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
1500
+ return {
1501
+ "content": response.output,
1502
+ "usage": {
1503
+ "prompt_tokens": response.usage.input_tokens,
1504
+ "completion_tokens": response.usage.output_tokens,
1505
+ "total_tokens": response.usage.total_tokens
1506
+ }
1507
+ }
1508
+ else:
1509
+ warnings.warn(f"Unsupported client type: {type(client)}")
1510
+ return {}
1511
+
1512
+
1680
1513
  def _format_output_data(client: ApiClient, response: Any) -> dict:
1681
1514
  """Format API response data based on client type.
1682
1515
 
@@ -1716,117 +1549,51 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
1716
1549
  }
1717
1550
  }
1718
1551
 
1719
- # Define a blocklist of functions that should not be traced
1720
- # These are typically utility functions, print statements, logging, etc.
1721
- _TRACE_BLOCKLIST = {
1722
- # Built-in functions
1723
- 'print', 'str', 'int', 'float', 'bool', 'list', 'dict', 'set', 'tuple',
1724
- 'len', 'range', 'enumerate', 'zip', 'map', 'filter', 'sorted', 'reversed',
1725
- 'min', 'max', 'sum', 'any', 'all', 'abs', 'round', 'format',
1726
- # Logging functions
1727
- 'debug', 'info', 'warning', 'error', 'critical', 'exception', 'log',
1728
- # Common utility functions
1729
- 'sleep', 'time', 'datetime', 'json', 'dumps', 'loads',
1730
- # String operations
1731
- 'join', 'split', 'strip', 'lstrip', 'rstrip', 'replace', 'lower', 'upper',
1732
- # Dict operations
1733
- 'get', 'items', 'keys', 'values', 'update',
1734
- # List operations
1735
- 'append', 'extend', 'insert', 'remove', 'pop', 'clear', 'index', 'count', 'sort',
1736
- }
1737
-
1738
-
1739
- # Add a new function for deep tracing at the module level
1740
- def _create_deep_tracing_wrapper(func, tracer, span_type="span"):
1552
+ def combine_args_kwargs(func, args, kwargs):
1741
1553
  """
1742
- Creates a wrapper for a function that automatically traces it when called within a traced function.
1743
- This enables deep tracing without requiring explicit @observe decorators on every function.
1554
+ Combine positional arguments and keyword arguments into a single dictionary.
1744
1555
 
1745
1556
  Args:
1746
- func: The function to wrap
1747
- tracer: The Tracer instance
1748
- span_type: Type of span (default "span")
1557
+ func: The function being called
1558
+ args: Tuple of positional arguments
1559
+ kwargs: Dictionary of keyword arguments
1749
1560
 
1750
1561
  Returns:
1751
- A wrapped function that will be traced when called
1562
+ A dictionary combining both args and kwargs
1752
1563
  """
1753
- # Skip wrapping if the function is not callable or is a built-in
1754
- if not callable(func) or isinstance(func, type) or func.__module__ == 'builtins':
1755
- return func
1756
-
1757
- # Skip functions in the blocklist
1758
- if func.__name__ in _TRACE_BLOCKLIST:
1759
- return func
1760
-
1761
- # Skip functions from certain modules (logging, sys, etc.)
1762
- if func.__module__ and any(func.__module__.startswith(m) for m in ['logging', 'sys', 'os', 'json', 'time', 'datetime']):
1763
- return func
1764
-
1765
-
1766
- # Get function name for the span - check for custom name set by @observe
1767
- func_name = getattr(func, '_judgment_span_name', func.__name__)
1768
-
1769
- # Check for custom span_type set by @observe
1770
- func_span_type = getattr(func, '_judgment_span_type', "span")
1771
-
1772
- # Store original function to prevent losing reference
1773
- original_func = func
1774
-
1775
- # Create appropriate wrapper based on whether the function is async or not
1776
- if asyncio.iscoroutinefunction(func):
1777
- @functools.wraps(func)
1778
- async def async_deep_wrapper(*args, **kwargs):
1779
- # Get current trace from context
1780
- current_trace = current_trace_var.get()
1781
-
1782
- # If no trace context, just call the function
1783
- if not current_trace:
1784
- return await original_func(*args, **kwargs)
1785
-
1786
- # Create a span for this function call - use custom span_type if available
1787
- with current_trace.span(func_name, span_type=func_span_type) as span:
1788
- # Record inputs
1789
- span.record_input({
1790
- 'args': str(args),
1791
- 'kwargs': kwargs
1792
- })
1793
-
1794
- # Execute function
1795
- result = await original_func(*args, **kwargs)
1796
-
1797
- # Record output
1798
- span.record_output(result)
1799
-
1800
- return result
1801
-
1802
- return async_deep_wrapper
1803
- else:
1804
- @functools.wraps(func)
1805
- def deep_wrapper(*args, **kwargs):
1806
- # Get current trace from context
1807
- current_trace = current_trace_var.get()
1808
-
1809
- # If no trace context, just call the function
1810
- if not current_trace:
1811
- return original_func(*args, **kwargs)
1812
-
1813
- # Create a span for this function call - use custom span_type if available
1814
- with current_trace.span(func_name, span_type=func_span_type) as span:
1815
- # Record inputs
1816
- span.record_input({
1817
- 'args': str(args),
1818
- 'kwargs': kwargs
1819
- })
1820
-
1821
- # Execute function
1822
- result = original_func(*args, **kwargs)
1823
-
1824
- # Record output
1825
- span.record_output(result)
1826
-
1827
- return result
1828
-
1829
- return deep_wrapper
1564
+ try:
1565
+ import inspect
1566
+ sig = inspect.signature(func)
1567
+ param_names = list(sig.parameters.keys())
1568
+
1569
+ args_dict = {}
1570
+ for i, arg in enumerate(args):
1571
+ if i < len(param_names):
1572
+ args_dict[param_names[i]] = arg
1573
+ else:
1574
+ args_dict[f"arg{i}"] = arg
1575
+
1576
+ return {**args_dict, **kwargs}
1577
+ except Exception as e:
1578
+ # Fallback if signature inspection fails
1579
+ return {**{f"arg{i}": arg for i, arg in enumerate(args)}, **kwargs}
1580
+
1581
+ # NOTE: This is built once; it can be tweaked if we are missing modules or capturing other unnecessary ones
1582
+ # @link https://docs.python.org/3.13/library/sysconfig.html
1583
+ _TRACE_FILEPATH_BLOCKLIST = tuple(
1584
+ os.path.realpath(p) + os.sep
1585
+ for p in {
1586
+ sysconfig.get_paths()['stdlib'],
1587
+ sysconfig.get_paths().get('platstdlib', ''),
1588
+ *site.getsitepackages(),
1589
+ site.getusersitepackages(),
1590
+ *(
1591
+ [os.path.join(os.path.dirname(__file__), '../../judgeval/')]
1592
+ if os.environ.get('JUDGMENT_DEV')
1593
+ else []
1594
+ ),
1595
+ } if p
1596
+ )
1830
1597
 
1831
1598
  # Add the new TraceThreadPoolExecutor class
1832
1599
  class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
@@ -1929,7 +1696,7 @@ def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[D
1929
1696
  def _sync_stream_wrapper(
1930
1697
  original_stream: Iterator,
1931
1698
  client: ApiClient,
1932
- output_entry: TraceEntry
1699
+ span: TraceSpan
1933
1700
  ) -> Generator[Any, None, None]:
1934
1701
  """Wraps a synchronous stream iterator to capture content and update the trace."""
1935
1702
  content_parts = [] # Use a list instead of string concatenation
@@ -1948,7 +1715,7 @@ def _sync_stream_wrapper(
1948
1715
  final_usage = _extract_usage_from_final_chunk(client, last_chunk)
1949
1716
 
1950
1717
  # Update the span with the accumulated content and usage
1951
- output_entry.output = {
1718
+ span.output = {
1952
1719
  "content": "".join(content_parts), # Join list at the end
1953
1720
  "usage": final_usage if final_usage else {"info": "Usage data not available in stream."}, # Provide placeholder if None
1954
1721
  "streamed": True
@@ -1960,7 +1727,7 @@ def _sync_stream_wrapper(
1960
1727
  async def _async_stream_wrapper(
1961
1728
  original_stream: AsyncIterator,
1962
1729
  client: ApiClient,
1963
- output_entry: TraceEntry
1730
+ span: TraceSpan
1964
1731
  ) -> AsyncGenerator[Any, None]:
1965
1732
  # [Existing logic - unchanged]
1966
1733
  content_parts = [] # Use a list instead of string concatenation
@@ -1969,7 +1736,7 @@ async def _async_stream_wrapper(
1969
1736
  anthropic_input_tokens = 0
1970
1737
  anthropic_output_tokens = 0
1971
1738
 
1972
- target_span_id = getattr(output_entry, 'span_id', 'UNKNOWN')
1739
+ target_span_id = span.span_id
1973
1740
 
1974
1741
  try:
1975
1742
  async for chunk in original_stream:
@@ -2014,19 +1781,17 @@ async def _async_stream_wrapper(
2014
1781
  elif last_content_chunk:
2015
1782
  usage_info = _extract_usage_from_final_chunk(client, last_content_chunk)
2016
1783
 
2017
- if output_entry and hasattr(output_entry, 'output'):
2018
- output_entry.output = {
1784
+ if span and hasattr(span, 'output'):
1785
+ span.output = {
2019
1786
  "content": "".join(content_parts), # Join list at the end
2020
1787
  "usage": usage_info if usage_info else {"info": "Usage data not available in stream."},
2021
1788
  "streamed": True
2022
1789
  }
2023
- start_ts = getattr(output_entry, 'created_at', time.time())
2024
- output_entry.duration = time.time() - start_ts
1790
+ start_ts = getattr(span, 'created_at', time.time())
1791
+ span.duration = time.time() - start_ts
2025
1792
  # else: # Handle error case if necessary, but remove debug print
2026
1793
 
2027
- # --- Define Context Manager Wrapper Classes ---
2028
- class _TracedAsyncStreamManagerWrapper(AbstractAsyncContextManager):
2029
- """Wraps an original async stream manager to add tracing."""
1794
+ class _BaseStreamManagerWrapper:
2030
1795
  def __init__(self, original_manager, client, span_name, trace_client, stream_wrapper_func, input_kwargs):
2031
1796
  self._original_manager = original_manager
2032
1797
  self._client = client
@@ -2036,281 +1801,74 @@ class _TracedAsyncStreamManagerWrapper(AbstractAsyncContextManager):
2036
1801
  self._input_kwargs = input_kwargs
2037
1802
  self._parent_span_id_at_entry = None
2038
1803
 
2039
- async def __aenter__(self):
2040
- self._parent_span_id_at_entry = current_span_var.get()
2041
- if not self._trace_client:
2042
- # If no trace, just delegate to the original manager
2043
- return await self._original_manager.__aenter__()
2044
-
2045
- # --- Manually create the 'enter' entry ---
1804
+ def _create_span(self):
2046
1805
  start_time = time.time()
2047
1806
  span_id = str(uuid.uuid4())
2048
1807
  current_depth = 0
2049
1808
  if self._parent_span_id_at_entry and self._parent_span_id_at_entry in self._trace_client._span_depths:
2050
1809
  current_depth = self._trace_client._span_depths[self._parent_span_id_at_entry] + 1
2051
1810
  self._trace_client._span_depths[span_id] = current_depth
2052
- enter_entry = TraceEntry(
2053
- type="enter", function=self._span_name, span_id=span_id,
2054
- trace_id=self._trace_client.trace_id, depth=current_depth, message=self._span_name,
2055
- created_at=start_time, span_type="llm", parent_span_id=self._parent_span_id_at_entry
1811
+ span = TraceSpan(
1812
+ function=self._span_name,
1813
+ span_id=span_id,
1814
+ trace_id=self._trace_client.trace_id,
1815
+ depth=current_depth,
1816
+ message=self._span_name,
1817
+ created_at=start_time,
1818
+ span_type="llm",
1819
+ parent_span_id=self._parent_span_id_at_entry
2056
1820
  )
2057
- self._trace_client.add_entry(enter_entry)
2058
- # --- End manual 'enter' entry ---
1821
+ self._trace_client.add_span(span)
1822
+ return span_id, span
2059
1823
 
2060
- # Set the current span ID in contextvars
2061
- self._span_context_token = current_span_var.set(span_id)
1824
+ def _finalize_span(self, span_id):
1825
+ span = self._trace_client.span_id_to_span.get(span_id)
1826
+ if span:
1827
+ span.duration = time.time() - span.created_at
1828
+ if span_id in self._trace_client._span_depths:
1829
+ del self._trace_client._span_depths[span_id]
2062
1830
 
2063
- # Manually create 'input' entry
2064
- input_data = _format_input_data(self._client, **self._input_kwargs)
2065
- input_entry = TraceEntry(
2066
- type="input", function=self._span_name, span_id=span_id,
2067
- trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Inputs to {self._span_name}",
2068
- created_at=time.time(), inputs=input_data, span_type="llm"
2069
- )
2070
- self._trace_client.add_entry(input_entry)
2071
-
2072
- # Call the original __aenter__
2073
- raw_iterator = await self._original_manager.__aenter__()
1831
+ class _TracedAsyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractAsyncContextManager):
1832
+ async def __aenter__(self):
1833
+ self._parent_span_id_at_entry = current_span_var.get()
1834
+ if not self._trace_client:
1835
+ return await self._original_manager.__aenter__()
2074
1836
 
2075
- # Manually create pending 'output' entry
2076
- output_entry = TraceEntry(
2077
- type="output", function=self._span_name, span_id=span_id,
2078
- trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Output from {self._span_name}",
2079
- created_at=time.time(), output="<pending stream>", span_type="llm"
2080
- )
2081
- self._trace_client.add_entry(output_entry)
1837
+ span_id, span = self._create_span()
1838
+ self._span_context_token = current_span_var.set(span_id)
1839
+ span.inputs = _format_input_data(self._client, **self._input_kwargs)
2082
1840
 
2083
- # Wrap the raw iterator
2084
- wrapped_iterator = self._stream_wrapper_func(raw_iterator, self._client, output_entry)
2085
- return wrapped_iterator
1841
+ # Call the original __aenter__ and expect it to be an async generator
1842
+ raw_iterator = await self._original_manager.__aenter__()
1843
+ span.output = "<pending stream>"
1844
+ return self._stream_wrapper_func(raw_iterator, self._client, span)
2086
1845
 
2087
1846
  async def __aexit__(self, exc_type, exc_val, exc_tb):
2088
- # Manually create the 'exit' entry
2089
1847
  if hasattr(self, '_span_context_token'):
2090
- span_id = current_span_var.get()
2091
- start_time_for_duration = 0
2092
- for entry in reversed(self._trace_client.entries):
2093
- if entry.span_id == span_id and entry.type == 'enter':
2094
- start_time_for_duration = entry.created_at
2095
- break
2096
- duration = time.time() - start_time_for_duration if start_time_for_duration else None
2097
- exit_depth = self._trace_client._span_depths.get(span_id, 0)
2098
- exit_entry = TraceEntry(
2099
- type="exit", function=self._span_name, span_id=span_id,
2100
- trace_id=self._trace_client.trace_id, depth=exit_depth, message=f"← {self._span_name}",
2101
- created_at=time.time(), duration=duration, span_type="llm"
2102
- )
2103
- self._trace_client.add_entry(exit_entry)
2104
- if span_id in self._trace_client._span_depths: del self._trace_client._span_depths[span_id]
2105
- current_span_var.reset(self._span_context_token)
2106
- delattr(self, '_span_context_token')
2107
-
2108
- # Delegate __aexit__
2109
- if hasattr(self._original_manager, "__aexit__"):
2110
- return await self._original_manager.__aexit__(exc_type, exc_val, exc_tb)
2111
- return None
2112
-
2113
- class _TracedSyncStreamManagerWrapper(AbstractContextManager):
2114
- """Wraps an original sync stream manager to add tracing."""
2115
- def __init__(self, original_manager, client, span_name, trace_client, stream_wrapper_func, input_kwargs):
2116
- self._original_manager = original_manager
2117
- self._client = client
2118
- self._span_name = span_name
2119
- self._trace_client = trace_client
2120
- self._stream_wrapper_func = stream_wrapper_func
2121
- self._input_kwargs = input_kwargs
2122
- self._parent_span_id_at_entry = None
1848
+ span_id = current_span_var.get()
1849
+ self._finalize_span(span_id)
1850
+ current_span_var.reset(self._span_context_token)
1851
+ delattr(self, '_span_context_token')
1852
+ return await self._original_manager.__aexit__(exc_type, exc_val, exc_tb)
2123
1853
 
1854
+ class _TracedSyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractContextManager):
2124
1855
  def __enter__(self):
2125
1856
  self._parent_span_id_at_entry = current_span_var.get()
2126
1857
  if not self._trace_client:
2127
- return self._original_manager.__enter__()
1858
+ return self._original_manager.__enter__()
2128
1859
 
2129
- # Manually create 'enter' entry
2130
- start_time = time.time()
2131
- span_id = str(uuid.uuid4())
2132
- current_depth = 0
2133
- if self._parent_span_id_at_entry and self._parent_span_id_at_entry in self._trace_client._span_depths:
2134
- current_depth = self._trace_client._span_depths[self._parent_span_id_at_entry] + 1
2135
- self._trace_client._span_depths[span_id] = current_depth
2136
- enter_entry = TraceEntry(
2137
- type="enter", function=self._span_name, span_id=span_id,
2138
- trace_id=self._trace_client.trace_id, depth=current_depth, message=self._span_name,
2139
- created_at=start_time, span_type="llm", parent_span_id=self._parent_span_id_at_entry
2140
- )
2141
- self._trace_client.add_entry(enter_entry)
1860
+ span_id, span = self._create_span()
2142
1861
  self._span_context_token = current_span_var.set(span_id)
1862
+ span.inputs = _format_input_data(self._client, **self._input_kwargs)
2143
1863
 
2144
- # Manually create 'input' entry
2145
- input_data = _format_input_data(self._client, **self._input_kwargs)
2146
- input_entry = TraceEntry(
2147
- type="input", function=self._span_name, span_id=span_id,
2148
- trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Inputs to {self._span_name}",
2149
- created_at=time.time(), inputs=input_data, span_type="llm"
2150
- )
2151
- self._trace_client.add_entry(input_entry)
2152
-
2153
- # Call original __enter__
2154
1864
  raw_iterator = self._original_manager.__enter__()
2155
-
2156
- # Manually create 'output' entry (pending)
2157
- output_entry = TraceEntry(
2158
- type="output", function=self._span_name, span_id=span_id,
2159
- trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Output from {self._span_name}",
2160
- created_at=time.time(), output="<pending stream>", span_type="llm"
2161
- )
2162
- self._trace_client.add_entry(output_entry)
2163
-
2164
- # Wrap the raw iterator
2165
- wrapped_iterator = self._stream_wrapper_func(raw_iterator, self._client, output_entry)
2166
- return wrapped_iterator
1865
+ span.output = "<pending stream>"
1866
+ return self._stream_wrapper_func(raw_iterator, self._client, span)
2167
1867
 
2168
1868
  def __exit__(self, exc_type, exc_val, exc_tb):
2169
- # Manually create 'exit' entry
2170
1869
  if hasattr(self, '_span_context_token'):
2171
- span_id = current_span_var.get()
2172
- start_time_for_duration = 0
2173
- for entry in reversed(self._trace_client.entries):
2174
- if entry.span_id == span_id and entry.type == 'enter':
2175
- start_time_for_duration = entry.created_at
2176
- break
2177
- duration = time.time() - start_time_for_duration if start_time_for_duration else None
2178
- exit_depth = self._trace_client._span_depths.get(span_id, 0)
2179
- exit_entry = TraceEntry(
2180
- type="exit", function=self._span_name, span_id=span_id,
2181
- trace_id=self._trace_client.trace_id, depth=exit_depth, message=f"← {self._span_name}",
2182
- created_at=time.time(), duration=duration, span_type="llm"
2183
- )
2184
- self._trace_client.add_entry(exit_entry)
2185
- if span_id in self._trace_client._span_depths: del self._trace_client._span_depths[span_id]
2186
- current_span_var.reset(self._span_context_token)
2187
- delattr(self, '_span_context_token')
2188
-
2189
- # Delegate __exit__
2190
- if hasattr(self._original_manager, "__exit__"):
2191
- return self._original_manager.__exit__(exc_type, exc_val, exc_tb)
2192
- return None
2193
-
2194
- # --- NEW Generalized Helper Function (Moved from demo) ---
2195
- def prepare_evaluation_for_state(
2196
- scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
2197
- example: Optional[Example] = None,
2198
- # --- Individual components (alternative to 'example') ---
2199
- input: Optional[str] = None,
2200
- actual_output: Optional[Union[str, List[str]]] = None,
2201
- expected_output: Optional[Union[str, List[str]]] = None,
2202
- context: Optional[List[str]] = None,
2203
- retrieval_context: Optional[List[str]] = None,
2204
- tools_called: Optional[List[str]] = None,
2205
- expected_tools: Optional[List[str]] = None,
2206
- additional_metadata: Optional[Dict[str, Any]] = None,
2207
- # --- Other eval parameters ---
2208
- model: Optional[str] = None,
2209
- log_results: Optional[bool] = True
2210
- ) -> Optional[EvaluationConfig]:
2211
- """
2212
- Prepares an EvaluationConfig object, similar to TraceClient.async_evaluate.
2213
-
2214
- Accepts either a pre-made Example object or individual components to construct one.
2215
- Returns the EvaluationConfig object ready to be placed in the state, or None.
2216
- """
2217
- final_example = example
2218
-
2219
- # If example is not provided, try to construct one from individual parts
2220
- if final_example is None:
2221
- # Basic validation: Ensure at least actual_output is present for most scorers
2222
- if actual_output is None:
2223
- # print("[prepare_evaluation_for_state] Warning: 'actual_output' is required when 'example' is not provided. Skipping evaluation setup.")
2224
- return None
2225
- try:
2226
- final_example = Example(
2227
- input=input,
2228
- actual_output=actual_output,
2229
- expected_output=expected_output,
2230
- context=context,
2231
- retrieval_context=retrieval_context,
2232
- tools_called=tools_called,
2233
- expected_tools=expected_tools,
2234
- additional_metadata=additional_metadata,
2235
- # trace_id will be set by the handler later if needed
2236
- )
2237
- # print("[prepare_evaluation_for_state] Constructed Example from individual components.")
2238
- except Exception as e:
2239
- # print(f"[prepare_evaluation_for_state] Error constructing Example: {e}. Skipping evaluation setup.")
2240
- return None
2241
-
2242
- # If we have a valid example (provided or constructed) and scorers
2243
- if final_example and scorers:
2244
- # TODO: Add validation like check_examples if needed here,
2245
- # although the handler might implicitly handle some checks via TraceClient.
2246
- return EvaluationConfig(
2247
- scorers=scorers,
2248
- example=final_example,
2249
- model=model,
2250
- log_results=log_results
2251
- )
2252
- elif not scorers:
2253
- # print("[prepare_evaluation_for_state] No scorers provided. Skipping evaluation setup.")
2254
- return None
2255
- else: # No valid example
2256
- # print("[prepare_evaluation_for_state] No valid Example available. Skipping evaluation setup.")
2257
- return None
2258
- # --- End NEW Helper Function ---
2259
-
2260
- # --- NEW: Helper function to simplify adding eval config to state ---
2261
- def add_evaluation_to_state(
2262
- state: Dict[str, Any], # The LangGraph state dictionary
2263
- scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
2264
- # --- Evaluation components (same as prepare_evaluation_for_state) ---
2265
- input: Optional[str] = None,
2266
- actual_output: Optional[Union[str, List[str]]] = None,
2267
- expected_output: Optional[Union[str, List[str]]] = None,
2268
- context: Optional[List[str]] = None,
2269
- retrieval_context: Optional[List[str]] = None,
2270
- tools_called: Optional[List[str]] = None,
2271
- expected_tools: Optional[List[str]] = None,
2272
- additional_metadata: Optional[Dict[str, Any]] = None,
2273
- # --- Other eval parameters ---
2274
- model: Optional[str] = None,
2275
- log_results: Optional[bool] = True
2276
- ) -> None:
2277
- """
2278
- Prepares an EvaluationConfig and adds it to the state dictionary
2279
- under the '_judgeval_eval' key if successful.
2280
-
2281
- This simplifies the process of setting up evaluations within LangGraph nodes.
2282
-
2283
- Args:
2284
- state: The LangGraph state dictionary to modify.
2285
- scorers: List of scorer instances.
2286
- input: Input for the evaluation example.
2287
- actual_output: Actual output for the evaluation example.
2288
- expected_output: Expected output for the evaluation example.
2289
- context: Context for the evaluation example.
2290
- retrieval_context: Retrieval context for the evaluation example.
2291
- tools_called: Tools called for the evaluation example.
2292
- expected_tools: Expected tools for the evaluation example.
2293
- additional_metadata: Additional metadata for the evaluation example.
2294
- model: Model name used for generation (optional).
2295
- log_results: Whether to log evaluation results (optional, defaults to True).
2296
- """
2297
- eval_config = prepare_evaluation_for_state(
2298
- scorers=scorers,
2299
- input=input,
2300
- actual_output=actual_output,
2301
- expected_output=expected_output,
2302
- context=context,
2303
- retrieval_context=retrieval_context,
2304
- tools_called=tools_called,
2305
- expected_tools=expected_tools,
2306
- additional_metadata=additional_metadata,
2307
- model=model,
2308
- log_results=log_results
2309
- )
2310
-
2311
- if eval_config:
2312
- state["_judgeval_eval"] = eval_config
2313
- # print(f"[_judgeval_eval added to state for node]") # Optional: Log confirmation
2314
-
2315
- # print("[Skipped adding _judgeval_eval to state: prepare_evaluation_for_state failed]")
2316
- # --- End NEW Helper ---
1870
+ span_id = current_span_var.get()
1871
+ self._finalize_span(span_id)
1872
+ current_span_var.reset(self._span_context_token)
1873
+ delattr(self, '_span_context_token')
1874
+ return self._original_manager.__exit__(exc_type, exc_val, exc_tb)