judgeval 0.0.35__py3-none-any.whl → 0.0.37__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
judgeval/common/tracer.py CHANGED
@@ -7,7 +7,11 @@ import functools
7
7
  import inspect
8
8
  import json
9
9
  import os
10
+ import site
11
+ import sysconfig
12
+ import threading
10
13
  import time
14
+ import traceback
11
15
  import uuid
12
16
  import warnings
13
17
  import contextvars
@@ -35,7 +39,6 @@ from rich import print as rprint
35
39
  import types # <--- Add this import
36
40
 
37
41
  # Third-party imports
38
- import pika
39
42
  import requests
40
43
  from litellm import cost_per_token
41
44
  from pydantic import BaseModel
@@ -47,6 +50,7 @@ from google import genai
47
50
 
48
51
  # Local application/library-specific imports
49
52
  from judgeval.constants import (
53
+ JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
50
54
  JUDGMENT_TRACES_SAVE_API_URL,
51
55
  JUDGMENT_TRACES_FETCH_API_URL,
52
56
  RABBITMQ_HOST,
@@ -55,172 +59,56 @@ from judgeval.constants import (
55
59
  JUDGMENT_TRACES_DELETE_API_URL,
56
60
  JUDGMENT_PROJECT_DELETE_API_URL,
57
61
  )
58
- from judgeval.judgment_client import JudgmentClient
59
- from judgeval.data import Example
62
+ from judgeval.data import Example, Trace, TraceSpan
60
63
  from judgeval.scorers import APIJudgmentScorer, JudgevalScorer
61
64
  from judgeval.rules import Rule
62
65
  from judgeval.evaluation_run import EvaluationRun
63
66
  from judgeval.data.result import ScoringResult
67
+ from judgeval.common.utils import validate_api_key
68
+ from judgeval.common.exceptions import JudgmentAPIError
64
69
 
65
70
  # Standard library imports needed for the new class
66
71
  import concurrent.futures
67
72
  from collections.abc import Iterator, AsyncIterator # Add Iterator and AsyncIterator
68
73
 
69
74
  # Define context variables for tracking the current trace and the current span within a trace
70
- current_trace_var = contextvars.ContextVar('current_trace', default=None)
75
+ current_trace_var = contextvars.ContextVar[Optional['TraceClient']]('current_trace', default=None)
71
76
  current_span_var = contextvars.ContextVar('current_span', default=None) # ContextVar for the active span name
72
- in_traced_function_var = contextvars.ContextVar('in_traced_function', default=False) # Track if we're in a traced function
73
77
 
74
78
  # Define type aliases for better code readability and maintainability
75
79
  ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic, AsyncOpenAI, AsyncAnthropic, AsyncTogether, genai.Client, genai.client.AsyncClient] # Supported API clients
76
- TraceEntryType = Literal['enter', 'exit', 'output', 'input', 'evaluation'] # Valid trace entry types
77
80
  SpanType = Literal['span', 'tool', 'llm', 'evaluation', 'chain']
78
- @dataclass
79
- class TraceEntry:
80
- """Represents a single trace entry with its visual representation.
81
-
82
- Visual representations:
83
- - enter: → (function entry)
84
- - exit: ← (function exit)
85
- - output: Output: (function return value)
86
- - input: Input: (function parameters)
87
- - evaluation: Evaluation: (evaluation results)
88
- """
89
- type: TraceEntryType
90
- span_id: str # Unique ID for this specific span instance
91
- depth: int # Indentation level for nested calls
92
- created_at: float # Unix timestamp when entry was created, replacing the deprecated 'timestamp' field
93
- function: Optional[str] = None # Name of the function being traced
94
- message: Optional[str] = None # Human-readable description
95
- duration: Optional[float] = None # Time taken (for exit/evaluation entries)
96
- trace_id: str = None # ID of the trace this entry belongs to
97
- output: Any = None # Function output value
98
- # Use field() for mutable defaults to avoid shared state issues
99
- inputs: dict = field(default_factory=dict)
100
- span_type: SpanType = "span"
101
- evaluation_runs: List[Optional[EvaluationRun]] = field(default=None)
102
- parent_span_id: Optional[str] = None # ID of the parent span instance
103
-
104
- def print_entry(self):
105
- """Print a trace entry with proper formatting and parent relationship information."""
106
- indent = " " * self.depth
107
-
108
- if self.type == "enter":
109
- # Format parent info if present
110
- parent_info = f" (parent_id: {self.parent_span_id})" if self.parent_span_id else ""
111
- print(f"{indent}→ {self.function} (id: {self.span_id}){parent_info} (trace: {self.message})")
112
- elif self.type == "exit":
113
- print(f"{indent}← {self.function} (id: {self.span_id}) ({self.duration:.3f}s)")
114
- elif self.type == "output":
115
- # Format output to align properly
116
- output_str = str(self.output)
117
- print(f"{indent}Output (for id: {self.span_id}): {output_str}")
118
- elif self.type == "input":
119
- # Format inputs to align properly
120
- print(f"{indent}Input (for id: {self.span_id}): {self.inputs}")
121
- elif self.type == "evaluation":
122
- for evaluation_run in self.evaluation_runs:
123
- print(f"{indent}Evaluation (for id: {self.span_id}): {evaluation_run.model_dump()}")
124
-
125
- def _serialize_inputs(self) -> dict:
126
- """Helper method to serialize input data safely.
127
-
128
- Returns a dict with serializable versions of inputs, converting non-serializable
129
- objects to None with a warning.
130
- """
131
- serialized_inputs = {}
132
- for key, value in self.inputs.items():
133
- if isinstance(value, BaseModel):
134
- serialized_inputs[key] = value.model_dump()
135
- elif isinstance(value, (list, tuple)):
136
- # Handle lists/tuples of arguments
137
- serialized_inputs[key] = [
138
- item.model_dump() if isinstance(item, BaseModel)
139
- else None if not self._is_json_serializable(item)
140
- else item
141
- for item in value
142
- ]
143
- else:
144
- if self._is_json_serializable(value):
145
- serialized_inputs[key] = value
146
- else:
147
- serialized_inputs[key] = self.safe_stringify(value, self.function)
148
- return serialized_inputs
149
-
150
- def _is_json_serializable(self, obj: Any) -> bool:
151
- """Helper method to check if an object is JSON serializable."""
152
- try:
153
- json.dumps(obj)
154
- return True
155
- except (TypeError, OverflowError, ValueError):
156
- return False
157
81
 
158
- def safe_stringify(self, output, function_name):
159
- """
160
- Safely converts an object to a string or repr, handling serialization issues gracefully.
161
- """
162
- try:
163
- return str(output)
164
- except (TypeError, OverflowError, ValueError):
165
- pass
166
-
167
- try:
168
- return repr(output)
169
- except (TypeError, OverflowError, ValueError):
170
- pass
171
-
172
- warnings.warn(
173
- f"Output for function {function_name} is not JSON serializable and could not be converted to string. Setting to None."
174
- )
175
- return None
82
+ # --- Evaluation Config Dataclass (Moved from langgraph.py) ---
83
+ @dataclass
84
+ class EvaluationConfig:
85
+ """Configuration for triggering an evaluation from the handler."""
86
+ scorers: List[Union[APIJudgmentScorer, JudgevalScorer]]
87
+ example: Example
88
+ model: Optional[str] = None
89
+ log_results: Optional[bool] = True
90
+ # --- End Evaluation Config Dataclass ---
91
+
92
+ # Temporary as a POC to have log use the existing annotations feature until log endpoints are ready
93
+ @dataclass
94
+ class TraceAnnotation:
95
+ """Represents a single annotation for a trace span."""
96
+ span_id: str
97
+ text: str
98
+ label: str
99
+ score: int
176
100
 
177
101
  def to_dict(self) -> dict:
178
- """Convert the trace entry to a dictionary format for storage/transmission."""
102
+ """Convert the annotation to a dictionary format for storage/transmission."""
179
103
  return {
180
- "type": self.type,
181
- "function": self.function,
182
104
  "span_id": self.span_id,
183
- "trace_id": self.trace_id,
184
- "depth": self.depth,
185
- "message": self.message,
186
- "created_at": datetime.fromtimestamp(self.created_at).isoformat(),
187
- "duration": self.duration,
188
- "output": self._serialize_output(),
189
- "inputs": self._serialize_inputs(),
190
- "evaluation_runs": [evaluation_run.model_dump() for evaluation_run in self.evaluation_runs] if self.evaluation_runs else [],
191
- "span_type": self.span_type,
192
- "parent_span_id": self.parent_span_id,
105
+ "annotation": {
106
+ "text": self.text,
107
+ "label": self.label,
108
+ "score": self.score
109
+ }
193
110
  }
194
-
195
- def _serialize_output(self) -> Any:
196
- """Helper method to serialize output data safely.
197
-
198
- Handles special cases:
199
- - Pydantic models are converted using model_dump()
200
- - We try to serialize into JSON, then string, then the base representation (__repr__)
201
- - Non-serializable objects return None with a warning
202
- """
203
-
204
- if isinstance(self.output, BaseModel):
205
- return self.output.model_dump()
206
-
207
- # NEW check: If output is the dict structure from our stream wrapper
208
- if isinstance(self.output, dict) and 'streamed' in self.output:
209
- # Assume it's already JSON-serializable (content is string, usage is dict or None)
210
- return self.output
211
- # NEW check: If output is the placeholder string before stream completes
212
- elif self.output == "<pending stream>":
213
- # Represent this state clearly in the serialized data
214
- return {"status": "pending stream"}
215
-
216
- try:
217
- # Try to serialize the output to verify it's JSON compatible
218
- json.dumps(self.output)
219
- return self.output
220
- except (TypeError, OverflowError, ValueError):
221
- return self.safe_stringify(self.output, self.function)
222
-
223
-
111
+
224
112
  class TraceManagerClient:
225
113
  """
226
114
  Client for handling trace endpoints with the Judgment API
@@ -257,8 +145,6 @@ class TraceManagerClient:
257
145
  raise ValueError(f"Failed to fetch traces: {response.text}")
258
146
 
259
147
  return response.json()
260
-
261
-
262
148
 
263
149
  def save_trace(self, trace_data: dict):
264
150
  """
@@ -301,6 +187,33 @@ class TraceManagerClient:
301
187
  pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={response.json()['ui_results_url']}]View Trace[/link]\n"
302
188
  rprint(pretty_str)
303
189
 
190
+ ## TODO: Should have a log endpoint, endpoint should also support batched payloads
191
+ def save_annotation(self, annotation: TraceAnnotation):
192
+ json_data = {
193
+ "span_id": annotation.span_id,
194
+ "annotation": {
195
+ "text": annotation.text,
196
+ "label": annotation.label,
197
+ "score": annotation.score
198
+ }
199
+ }
200
+
201
+ response = requests.post(
202
+ JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
203
+ json=json_data,
204
+ headers={
205
+ 'Content-Type': 'application/json',
206
+ 'Authorization': f'Bearer {self.judgment_api_key}',
207
+ 'X-Organization-Id': self.organization_id
208
+ },
209
+ verify=True
210
+ )
211
+
212
+ if response.status_code != HTTPStatus.OK:
213
+ raise ValueError(f"Failed to save annotation: {response.text}")
214
+
215
+ return response.json()
216
+
304
217
  def delete_trace(self, trace_id: str):
305
218
  """
306
219
  Delete a trace from the database.
@@ -391,15 +304,16 @@ class TraceClient:
391
304
  self.enable_evaluations = enable_evaluations
392
305
  self.parent_trace_id = parent_trace_id
393
306
  self.parent_name = parent_name
394
- self.client: JudgmentClient = tracer.client
395
- self.entries: List[TraceEntry] = []
307
+ self.trace_spans: List[TraceSpan] = []
308
+ self.span_id_to_span: Dict[str, TraceSpan] = {}
309
+ self.evaluation_runs: List[EvaluationRun] = []
310
+ self.annotations: List[TraceAnnotation] = []
396
311
  self.start_time = time.time()
397
312
  self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id, tracer)
398
313
  self.visited_nodes = []
399
314
  self.executed_tools = []
400
315
  self.executed_node_tools = []
401
316
  self._span_depths: Dict[str, int] = {} # NEW: To track depth of active spans
402
-
403
317
  def get_current_span(self):
404
318
  """Get the current span from the context var"""
405
319
  return current_span_var.get()
@@ -429,9 +343,7 @@ class TraceClient:
429
343
 
430
344
  self._span_depths[span_id] = current_depth # Store depth by span_id
431
345
 
432
- entry = TraceEntry(
433
- type="enter",
434
- function=name,
346
+ span = TraceSpan(
435
347
  span_id=span_id,
436
348
  trace_id=self.trace_id,
437
349
  depth=current_depth,
@@ -439,25 +351,15 @@ class TraceClient:
439
351
  created_at=start_time,
440
352
  span_type=span_type,
441
353
  parent_span_id=parent_span_id,
354
+ function=name,
442
355
  )
443
- self.add_entry(entry)
356
+ self.add_span(span)
444
357
 
445
358
  try:
446
359
  yield self
447
360
  finally:
448
361
  duration = time.time() - start_time
449
- exit_depth = self._span_depths.get(span_id, 0) # Get depth using this span's ID
450
- self.add_entry(TraceEntry(
451
- type="exit",
452
- function=name,
453
- span_id=span_id, # Use the same span_id for exit
454
- trace_id=self.trace_id, # Use the trace_id from the trace client
455
- depth=exit_depth,
456
- message=f"← {name}",
457
- created_at=time.time(),
458
- duration=duration,
459
- span_type=span_type,
460
- ))
362
+ span.duration = duration
461
363
  # Clean up depth tracking for this span_id
462
364
  if span_id in self._span_depths:
463
365
  del self._span_depths[span_id]
@@ -467,32 +369,24 @@ class TraceClient:
467
369
  def async_evaluate(
468
370
  self,
469
371
  scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
372
+ example: Optional[Example] = None,
470
373
  input: Optional[str] = None,
471
- actual_output: Optional[str] = None,
472
- expected_output: Optional[str] = None,
374
+ actual_output: Optional[Union[str, List[str]]] = None,
375
+ expected_output: Optional[Union[str, List[str]]] = None,
473
376
  context: Optional[List[str]] = None,
474
377
  retrieval_context: Optional[List[str]] = None,
475
378
  tools_called: Optional[List[str]] = None,
476
379
  expected_tools: Optional[List[str]] = None,
477
380
  additional_metadata: Optional[Dict[str, Any]] = None,
478
381
  model: Optional[str] = None,
382
+ span_id: Optional[str] = None, # <<< ADDED optional span_id parameter
479
383
  log_results: Optional[bool] = True
480
384
  ):
481
385
  if not self.enable_evaluations:
482
386
  return
483
387
 
484
388
  start_time = time.time() # Record start time
485
- example = Example(
486
- input=input,
487
- actual_output=actual_output,
488
- expected_output=expected_output,
489
- context=context,
490
- retrieval_context=retrieval_context,
491
- tools_called=tools_called,
492
- expected_tools=expected_tools,
493
- additional_metadata=additional_metadata,
494
- trace_id=self.trace_id
495
- )
389
+
496
390
  try:
497
391
  # Load appropriate implementations for all scorers
498
392
  if not scorers:
@@ -507,13 +401,44 @@ class TraceClient:
507
401
  warnings.warn(f"Failed to load scorers: {str(e)}")
508
402
  return
509
403
 
404
+ # If example is not provided, create one from the individual parameters
405
+ if example is None:
406
+ # Check if any of the individual parameters are provided
407
+ if any(param is not None for param in [input, actual_output, expected_output, context,
408
+ retrieval_context, tools_called, expected_tools,
409
+ additional_metadata]):
410
+ example = Example(
411
+ input=input,
412
+ actual_output=actual_output,
413
+ expected_output=expected_output,
414
+ context=context,
415
+ retrieval_context=retrieval_context,
416
+ tools_called=tools_called,
417
+ expected_tools=expected_tools,
418
+ additional_metadata=additional_metadata,
419
+ )
420
+ else:
421
+ raise ValueError("Either 'example' or at least one of the individual parameters (input, actual_output, etc.) must be provided")
422
+
423
+ # Check examples before creating evaluation run
424
+
425
+ # check_examples([example], scorers)
426
+
427
+ # --- Modification: Capture span_id immediately ---
428
+ # span_id_at_eval_call = current_span_var.get()
429
+ # print(f"[TraceClient.async_evaluate] Captured span ID at eval call: {span_id_at_eval_call}")
430
+ # Prioritize explicitly passed span_id, fallback to context var
431
+ span_id_to_use = span_id if span_id is not None else current_span_var.get()
432
+ # print(f"[TraceClient.async_evaluate] Using span_id: {span_id_to_use}")
433
+ # --- End Modification ---
434
+
510
435
  # Combine the trace-level rules with any evaluation-specific rules)
511
436
  eval_run = EvaluationRun(
512
437
  organization_id=self.tracer.organization_id,
513
438
  log_results=log_results,
514
439
  project_name=self.project_name,
515
440
  eval_name=f"{self.name.capitalize()}-"
516
- f"{current_span_var.get()}-"
441
+ f"{current_span_var.get()}-" # Keep original eval name format using context var if available
517
442
  f"[{','.join(scorer.score_type.capitalize() for scorer in scorers)}]",
518
443
  examples=[example],
519
444
  scorers=scorers,
@@ -521,296 +446,73 @@ class TraceClient:
521
446
  metadata={},
522
447
  judgment_api_key=self.tracer.api_key,
523
448
  override=self.overwrite,
524
- trace_span_id=current_span_var.get(),
449
+ trace_span_id=span_id_to_use, # Pass the determined ID
525
450
  rules=self.rules # Use the combined rules
526
451
  )
527
452
 
528
453
  self.add_eval_run(eval_run, start_time) # Pass start_time to record_evaluation
529
454
 
530
455
  def add_eval_run(self, eval_run: EvaluationRun, start_time: float):
531
- current_span_id = current_span_var.get()
532
- if current_span_id:
533
- duration = time.time() - start_time
534
- prev_entry = self.entries[-1] if self.entries else None
535
- # Determine function name based on previous entry or context var (less ideal)
536
- function_name = "unknown_function" # Default
537
- if prev_entry and prev_entry.span_type == "llm":
538
- function_name = prev_entry.function
539
- else:
540
- # Try to find the function name associated with the current span_id
541
- for entry in reversed(self.entries):
542
- if entry.span_id == current_span_id and entry.type == 'enter':
543
- function_name = entry.function
544
- break
545
-
546
- # Get depth for the current span
547
- current_depth = self._span_depths.get(current_span_id, 0)
548
-
549
- self.add_entry(TraceEntry(
550
- type="evaluation",
551
- function=function_name,
552
- span_id=current_span_id, # Associate with current span
553
- trace_id=self.trace_id, # Use the trace_id from the trace client
554
- depth=current_depth,
555
- message=f"Evaluation results for {function_name}",
556
- created_at=time.time(),
557
- evaluation_runs=[eval_run],
558
- duration=duration,
559
- span_type="evaluation"
560
- ))
456
+ # --- Modification: Use span_id from eval_run ---
457
+ current_span_id = eval_run.trace_span_id # Get ID from the eval_run object
458
+ # print(f"[TraceClient.add_eval_run] Using span_id from eval_run: {current_span_id}")
459
+ # --- End Modification ---
561
460
 
461
+ if current_span_id:
462
+ span = self.span_id_to_span[current_span_id]
463
+ span.evaluation_runs.append(eval_run)
464
+ self.evaluation_runs.append(eval_run)
465
+
466
+ def add_annotation(self, annotation: TraceAnnotation):
467
+ """Add an annotation to this trace context"""
468
+ self.annotations.append(annotation)
469
+ return self
470
+
562
471
  def record_input(self, inputs: dict):
563
472
  current_span_id = current_span_var.get()
564
473
  if current_span_id:
565
- entry_span_type = "span"
566
- current_depth = self._span_depths.get(current_span_id, 0)
567
- function_name = "unknown_function" # Default
568
- for entry in reversed(self.entries):
569
- if entry.span_id == current_span_id and entry.type == 'enter':
570
- entry_span_type = entry.span_type
571
- function_name = entry.function
572
- break
573
-
574
- self.add_entry(TraceEntry(
575
- type="input",
576
- function=function_name,
577
- span_id=current_span_id, # Use current span_id
578
- trace_id=self.trace_id, # Use the trace_id from the trace client
579
- depth=current_depth,
580
- message=f"Inputs to {function_name}",
581
- created_at=time.time(),
582
- inputs=inputs,
583
- span_type=entry_span_type,
584
- ))
585
-
586
- async def _update_coroutine_output(self, entry: TraceEntry, coroutine: Any):
474
+ span = self.span_id_to_span[current_span_id]
475
+ span.inputs = inputs
476
+
477
+ async def _update_coroutine_output(self, span: TraceSpan, coroutine: Any):
587
478
  """Helper method to update the output of a trace entry once the coroutine completes"""
588
479
  try:
589
480
  result = await coroutine
590
- entry.output = result
481
+ span.output = result
591
482
  return result
592
483
  except Exception as e:
593
- entry.output = f"Error: {str(e)}"
484
+ span.output = f"Error: {str(e)}"
594
485
  raise
595
486
 
596
487
  def record_output(self, output: Any):
597
488
  current_span_id = current_span_var.get()
598
489
  if current_span_id:
599
- entry_span_type = "span"
600
- current_depth = self._span_depths.get(current_span_id, 0)
601
- function_name = "unknown_function" # Default
602
- for entry in reversed(self.entries):
603
- if entry.span_id == current_span_id and entry.type == 'enter':
604
- entry_span_type = entry.span_type
605
- function_name = entry.function
606
- break
607
-
608
- entry = TraceEntry(
609
- type="output",
610
- function=function_name,
611
- span_id=current_span_id, # Use current span_id
612
- depth=current_depth,
613
- message=f"Output from {function_name}",
614
- created_at=time.time(),
615
- output="<pending>" if inspect.iscoroutine(output) else output,
616
- span_type=entry_span_type,
617
- )
618
- self.add_entry(entry)
490
+ span = self.span_id_to_span[current_span_id]
491
+ span.output = "<pending>" if inspect.iscoroutine(output) else output
619
492
 
620
493
  if inspect.iscoroutine(output):
621
- asyncio.create_task(self._update_coroutine_output(entry, output))
494
+ asyncio.create_task(self._update_coroutine_output(span, output))
622
495
 
623
- # Return the created entry
624
- return entry
625
-
626
- def add_entry(self, entry: TraceEntry):
627
- """Add a trace entry to this trace context"""
628
- self.entries.append(entry)
496
+ return span # Return the created entry
497
+ # Removed else block - original didn't have one
498
+ return None # Return None if no span_id found
499
+
500
+ def add_span(self, span: TraceSpan):
501
+ """Add a trace span to this trace context"""
502
+ self.trace_spans.append(span)
503
+ self.span_id_to_span[span.span_id] = span
629
504
  return self
630
505
 
631
506
  def print(self):
632
507
  """Print the complete trace with proper visual structure"""
633
- for entry in self.entries:
634
- entry.print_entry()
635
-
636
- def print_hierarchical(self):
637
- """Print the trace in a hierarchical structure based on parent-child relationships"""
638
- # First, build a map of spans
639
- spans = {}
640
- root_spans = []
641
-
642
- # Collect all enter events first
643
- for entry in self.entries:
644
- if entry.type == "enter":
645
- spans[entry.function] = {
646
- "name": entry.function,
647
- "depth": entry.depth,
648
- "parent_id": entry.parent_span_id,
649
- "children": []
650
- }
651
-
652
- # If no parent, it's a root span
653
- if not entry.parent_span_id:
654
- root_spans.append(entry.function)
655
- elif entry.parent_span_id not in spans:
656
- # If parent doesn't exist yet, temporarily treat as root
657
- # (we'll fix this later)
658
- root_spans.append(entry.function)
659
-
660
- # Build parent-child relationships
661
- for span_name, span in spans.items():
662
- parent = span["parent_id"]
663
- if parent and parent in spans:
664
- spans[parent]["children"].append(span_name)
665
- # Remove from root spans if it was temporarily there
666
- if span_name in root_spans:
667
- root_spans.remove(span_name)
668
-
669
- # Now print the hierarchy
670
- def print_span(span_name, level=0):
671
- if span_name not in spans:
672
- return
673
-
674
- span = spans[span_name]
675
- indent = " " * level
676
- parent_info = f" (parent_id: {span['parent_id']})" if span["parent_id"] else ""
677
- print(f"{indent}→ {span_name}{parent_info}")
678
-
679
- # Print children
680
- for child in span["children"]:
681
- print_span(child, level + 1)
682
-
683
- # Print starting with root spans
684
- print("\nHierarchical Trace Structure:")
685
- for root in root_spans:
686
- print_span(root)
508
+ for span in self.trace_spans:
509
+ span.print_span()
687
510
 
688
511
  def get_duration(self) -> float:
689
512
  """
690
513
  Get the total duration of this trace
691
514
  """
692
515
  return time.time() - self.start_time
693
-
694
- def condense_trace(self, entries: List[dict]) -> List[dict]:
695
- """
696
- Condenses trace entries into a single entry for each span instance,
697
- preserving parent-child span relationships using span_id and parent_span_id.
698
- """
699
- spans_by_id: Dict[str, dict] = {}
700
- evaluation_runs: List[EvaluationRun] = []
701
-
702
- # First pass: Group entries by span_id and gather data
703
- for entry in entries:
704
- span_id = entry.get("span_id")
705
- if not span_id:
706
- continue # Skip entries without a span_id (should not happen)
707
-
708
- if entry["type"] == "enter":
709
- if span_id not in spans_by_id:
710
- spans_by_id[span_id] = {
711
- "span_id": span_id,
712
- "function": entry["function"],
713
- "depth": entry["depth"], # Use the depth recorded at entry time
714
- "created_at": entry["created_at"],
715
- "trace_id": entry["trace_id"],
716
- "parent_span_id": entry.get("parent_span_id"),
717
- "span_type": entry.get("span_type", "span"),
718
- "inputs": None,
719
- "output": None,
720
- "evaluation_runs": [],
721
- "duration": None
722
- }
723
- # Handle potential duplicate enter events if necessary (e.g., log warning)
724
-
725
- elif span_id in spans_by_id:
726
- current_span_data = spans_by_id[span_id]
727
-
728
- if entry["type"] == "input" and entry["inputs"]:
729
- # Merge inputs if multiple are recorded, or just assign
730
- if current_span_data["inputs"] is None:
731
- current_span_data["inputs"] = entry["inputs"]
732
- elif isinstance(current_span_data["inputs"], dict) and isinstance(entry["inputs"], dict):
733
- current_span_data["inputs"].update(entry["inputs"])
734
- # Add more sophisticated merging if needed
735
-
736
- elif entry["type"] == "output" and "output" in entry:
737
- current_span_data["output"] = entry["output"]
738
-
739
- elif entry["type"] == "evaluation" and entry.get("evaluation_runs"):
740
- if current_span_data.get("evaluation_runs") is not None:
741
- evaluation_runs.extend(entry["evaluation_runs"])
742
-
743
- elif entry["type"] == "exit":
744
- if current_span_data["duration"] is None: # Calculate duration only once
745
- start_time = datetime.fromisoformat(current_span_data.get("created_at", entry["created_at"]))
746
- end_time = datetime.fromisoformat(entry["created_at"])
747
- current_span_data["duration"] = (end_time - start_time).total_seconds()
748
- # Update depth if exit depth is different (though current span() implementation keeps it same)
749
- # current_span_data["depth"] = entry["depth"]
750
-
751
- # Convert dictionary to a list initially for easier access
752
- spans_list = list(spans_by_id.values())
753
-
754
- # Build tree structure (adjacency list) and find roots
755
- children_map: Dict[Optional[str], List[dict]] = {}
756
- roots = []
757
- span_map = {span['span_id']: span for span in spans_list} # Map for quick lookup
758
-
759
- for span in spans_list:
760
- parent_id = span.get("parent_span_id")
761
- if parent_id is None:
762
- roots.append(span)
763
- else:
764
- if parent_id not in children_map:
765
- children_map[parent_id] = []
766
- children_map[parent_id].append(span)
767
-
768
- # Sort roots by timestamp
769
- roots.sort(key=lambda x: datetime.fromisoformat(x.get("created_at", "1970-01-01T00:00:00")))
770
-
771
- # Perform depth-first traversal to get the final sorted list
772
- sorted_condensed_list = []
773
- visited = set() # To handle potential cycles, though unlikely with UUIDs
774
-
775
- def dfs(span_data):
776
- span_id = span_data['span_id']
777
- if span_id in visited:
778
- return # Avoid infinite loops in case of cycles
779
- visited.add(span_id)
780
-
781
- sorted_condensed_list.append(span_data) # Add parent before children
782
-
783
- # Get children, sort them by created_at, and visit them
784
- span_children = children_map.get(span_id, [])
785
- span_children.sort(key=lambda x: datetime.fromisoformat(x.get("created_at", "1970-01-01T00:00:00")))
786
- for child in span_children:
787
- # Ensure the child exists in our map before recursing
788
- if child['span_id'] in span_map:
789
- dfs(child)
790
- else:
791
- # This case might indicate an issue, but we'll add the child directly
792
- # if its parent was processed but the child itself wasn't in the initial list?
793
- # Or if the child's 'enter' event was missing. For robustness, add it.
794
- if child['span_id'] not in visited:
795
- visited.add(child['span_id'])
796
- sorted_condensed_list.append(child)
797
-
798
-
799
- # Start DFS from each root
800
- for root_span in roots:
801
- if root_span['span_id'] not in visited:
802
- dfs(root_span)
803
-
804
- # Handle spans that might not have been reachable from roots (orphans)
805
- # Though ideally, all spans should descend from a root.
806
- for span_data in spans_list:
807
- if span_data['span_id'] not in visited:
808
- # Decide how to handle orphans, maybe append them at the end sorted by time?
809
- # For now, let's just add them to ensure they aren't lost.
810
- sorted_condensed_list.append(span_data)
811
-
812
-
813
- return sorted_condensed_list, evaluation_runs
814
516
 
815
517
  def save(self, overwrite: bool = False) -> Tuple[str, dict]:
816
518
  """
@@ -819,103 +521,391 @@ class TraceClient:
819
521
  """
820
522
  # Calculate total elapsed time
821
523
  total_duration = self.get_duration()
822
-
823
- raw_entries = [entry.to_dict() for entry in self.entries]
824
-
825
- condensed_entries, evaluation_runs = self.condense_trace(raw_entries)
826
524
 
827
- # Calculate total token counts from LLM API calls
828
- total_prompt_tokens = 0
829
- total_completion_tokens = 0
830
- total_tokens = 0
831
-
832
- total_prompt_tokens_cost = 0.0
833
- total_completion_tokens_cost = 0.0
834
- total_cost = 0.0
835
-
836
525
  # Only count tokens for actual LLM API call spans
837
526
  llm_span_names = {"OPENAI_API_CALL", "TOGETHER_API_CALL", "ANTHROPIC_API_CALL", "GOOGLE_API_CALL"}
838
- for entry in condensed_entries:
839
- if entry.get("span_type") == "llm" and entry.get("function") in llm_span_names and isinstance(entry.get("output"), dict):
840
- output = entry["output"]
841
- usage = output.get("usage", {})
842
- model_name = entry.get("inputs", {}).get("model", "")
527
+ for span in self.trace_spans:
528
+ span_function_name = span.function # Get function name safely
529
+ # Check if it's an LLM span AND function name CONTAINS an API call suffix AND output is dict
530
+ is_llm_span = span.span_type == "llm"
531
+ has_api_suffix = any(suffix in span_function_name for suffix in llm_span_names)
532
+ output_is_dict = isinstance(span.output, dict)
533
+
534
+ # --- DEBUG PRINT 1: Check if condition passes ---
535
+ # if is_llm_entry and has_api_suffix and output_is_dict:
536
+ # elif is_llm_entry:
537
+ # # Print why it failed if it was an LLM entry
538
+ # # --- END DEBUG ---
539
+
540
+ if is_llm_span and has_api_suffix and output_is_dict:
541
+ output = span.output
542
+ usage = output.get("usage", {}) # Gets the 'usage' dict from the 'output' field
543
+
544
+ # --- DEBUG PRINT 2: Check extracted usage ---
545
+ # --- END DEBUG ---
546
+
547
+ # --- NEW: Extract model_name correctly from nested inputs ---
548
+ model_name = None
549
+ span_inputs = span.inputs
550
+ if span_inputs:
551
+ # Try common locations for model name within the inputs structure
552
+ invocation_params = span_inputs.get("invocation_params", {})
553
+ serialized_data = span_inputs.get("serialized", {})
554
+
555
+ # Look in invocation_params (often directly contains model)
556
+ if isinstance(invocation_params, dict):
557
+ model_name = invocation_params.get("model")
558
+
559
+ # Fallback: Check serialized 'repr' if it contains model info
560
+ if not model_name and isinstance(serialized_data, dict):
561
+ serialized_repr = serialized_data.get("repr", "")
562
+ if "model_name=" in serialized_repr:
563
+ try: # Simple parsing attempt
564
+ model_name = serialized_repr.split("model_name='")[1].split("'")[0]
565
+ except IndexError: pass # Ignore parsing errors
566
+
567
+ # Fallback: Check top-level of invocation_params (sometimes passed flat)
568
+ if not model_name and isinstance(invocation_params, dict):
569
+ model_name = invocation_params.get("model") # Redundant check, but safe
570
+
571
+ # Fallback: Check top-level of inputs itself (less likely for callbacks)
572
+ if not model_name:
573
+ model_name = span_inputs.get("model")
574
+
575
+
576
+ # --- END NEW ---
577
+
843
578
  prompt_tokens = 0
844
- completion_tokens = 0
845
-
846
- # Handle OpenAI/Together format
579
+ completion_tokens = 0
580
+
581
+ # Handle OpenAI/Together format (checks within the 'usage' dict)
847
582
  if "prompt_tokens" in usage:
848
583
  prompt_tokens = usage.get("prompt_tokens", 0)
849
584
  completion_tokens = usage.get("completion_tokens", 0)
850
- total_prompt_tokens += prompt_tokens
851
- total_completion_tokens += completion_tokens
852
- # Handle Anthropic format
585
+
586
+ # Handle Anthropic format - MAP values to standard keys
853
587
  elif "input_tokens" in usage:
854
- prompt_tokens = usage.get("input_tokens", 0)
855
- completion_tokens = usage.get("output_tokens", 0)
856
- total_prompt_tokens += prompt_tokens
857
- total_completion_tokens += completion_tokens
858
-
859
- total_tokens += usage.get("total_tokens", 0)
588
+ prompt_tokens = usage.get("input_tokens", 0) # Get value from input_tokens
589
+ completion_tokens = usage.get("output_tokens", 0) # Get value from output_tokens
590
+
591
+ # *** Overwrite the usage dict in the entry to use standard keys ***
592
+ original_total = usage.get("total_tokens", 0)
593
+ original_total_cost = usage.get("total_cost_usd", 0.0) # Preserve if already calculated
594
+ # Recalculate cost just in case it wasn't done correctly before
595
+ temp_prompt_cost, temp_completion_cost = 0.0, 0.0
596
+ if model_name:
597
+ try:
598
+ temp_prompt_cost, temp_completion_cost = cost_per_token(
599
+ model=model_name,
600
+ prompt_tokens=prompt_tokens,
601
+ completion_tokens=completion_tokens
602
+ )
603
+ except Exception:
604
+ pass # Ignore cost calculation errors here, focus on keys
605
+ # Replace the usage dict with one using standard keys but Anthropic values
606
+ output["usage"] = {
607
+ "prompt_tokens": prompt_tokens,
608
+ "completion_tokens": completion_tokens,
609
+ "total_tokens": original_total,
610
+ "prompt_tokens_cost_usd": temp_prompt_cost, # Use standard cost key
611
+ "completion_tokens_cost_usd": temp_completion_cost, # Use standard cost key
612
+ "total_cost_usd": original_total_cost if original_total_cost > 0 else (temp_prompt_cost + temp_completion_cost)
613
+ }
614
+ usage = output["usage"]
615
+
616
+ # Calculate costs if model name is available and ensure they are stored with standard keys
617
+ prompt_tokens = usage.get("prompt_tokens", 0)
618
+ completion_tokens = usage.get("completion_tokens", 0)
860
619
 
861
620
  # Calculate costs if model name is available
862
621
  if model_name:
863
622
  try:
623
+ # Recalculate costs based on potentially mapped tokens
864
624
  prompt_cost, completion_cost = cost_per_token(
865
625
  model=model_name,
866
626
  prompt_tokens=prompt_tokens,
867
627
  completion_tokens=completion_tokens
868
628
  )
869
- total_prompt_tokens_cost += prompt_cost
870
- total_completion_tokens_cost += completion_cost
871
- total_cost += prompt_cost + completion_cost
872
629
 
873
630
  # Add cost information directly to the usage dictionary in the condensed entry
631
+ # Ensure 'usage' exists in the output dict before modifying it
632
+ # Add/Update cost information using standard keys
633
+
874
634
  if "usage" not in output:
875
- output["usage"] = {}
635
+ output["usage"] = {} # Initialize if missing
636
+ elif not isinstance(output["usage"], dict): # Handle cases where 'usage' might not be a dict (e.g., placeholder string)
637
+ print(f"[WARN TraceClient.save] Output 'usage' for span {span.span_id} was not a dict ({type(output['usage'])}). Resetting before adding costs.")
638
+ output["usage"] = {} # Reset to dict
639
+
876
640
  output["usage"]["prompt_tokens_cost_usd"] = prompt_cost
877
641
  output["usage"]["completion_tokens_cost_usd"] = completion_cost
878
642
  output["usage"]["total_cost_usd"] = prompt_cost + completion_cost
879
643
  except Exception as e:
880
644
  # If cost calculation fails, continue without adding costs
881
- print(f"Error calculating cost for model '{model_name}': {str(e)}")
645
+ print(f"Error calculating cost for model '{model_name}' (span: {span.span_id}): {str(e)}")
882
646
  pass
647
+ else:
648
+ print(f"[WARN TraceClient.save] Could not determine model name for cost calculation (span: {span.span_id}). Inputs: {span_inputs}")
883
649
 
884
- # Create trace document
650
+
651
+ # Create trace document - Always use standard keys for top-level counts
885
652
  trace_data = {
886
653
  "trace_id": self.trace_id,
887
654
  "name": self.name,
888
655
  "project_name": self.project_name,
889
656
  "created_at": datetime.utcfromtimestamp(self.start_time).isoformat(),
890
657
  "duration": total_duration,
891
- "token_counts": {
892
- "prompt_tokens": total_prompt_tokens,
893
- "completion_tokens": total_completion_tokens,
894
- "total_tokens": total_tokens,
895
- "prompt_tokens_cost_usd": total_prompt_tokens_cost,
896
- "completion_tokens_cost_usd": total_completion_tokens_cost,
897
- "total_cost_usd": total_cost
898
- },
899
- "entries": condensed_entries,
900
- "evaluation_runs": evaluation_runs,
658
+ "entries": [span.model_dump() for span in self.trace_spans],
659
+ "evaluation_runs": [run.model_dump() for run in self.evaluation_runs],
901
660
  "overwrite": overwrite,
902
661
  "parent_trace_id": self.parent_trace_id,
903
662
  "parent_name": self.parent_name
904
663
  }
905
664
  # --- Log trace data before saving ---
906
- try:
907
- rprint(f"[TraceClient.save] Saving trace data for trace_id {self.trace_id}:")
908
- rprint(json.dumps(trace_data, indent=2))
909
- except Exception as log_e:
910
- rprint(f"[TraceClient.save] Error logging trace data: {log_e}")
911
- # --- End logging ---
912
665
  self.trace_manager_client.save_trace(trace_data)
913
666
 
667
+ # upload annotations
668
+ # TODO: batch to the log endpoint
669
+ for annotation in self.annotations:
670
+ self.trace_manager_client.save_annotation(annotation)
671
+
914
672
  return self.trace_id, trace_data
915
673
 
916
674
def delete(self):
    """Delete this trace via the trace manager client.

    Returns:
        Whatever ``trace_manager_client.delete_trace`` returns for this
        trace's ``trace_id``.
    """
    manager = self.trace_manager_client
    return manager.delete_trace(self.trace_id)
918
676
 
677
+
678
class _DeepTracer:
    """Singleton context manager that records a span for every traced call.

    While at least one ``with _DeepTracer():`` block is active, ``_trace`` is
    installed through ``sys.settrace`` and ``threading.settrace``. For each
    ``call`` event that passes ``_should_trace`` a ``TraceSpan`` is added to
    the trace held in ``current_trace_var``; the matching ``return`` event
    records the duration and output.
    """

    # The single shared instance handed out by ``__new__``.
    _instance: Optional["_DeepTracer"] = None
    # Guards instance creation and the enter/exit reference count.
    _lock: threading.Lock = threading.Lock()
    # Number of currently active ``with _DeepTracer()`` blocks.
    _refcount: int = 0
    # NOTE(review): both defaults are mutable lists shared by every context
    # that never calls ``set``; ``__enter__`` installs fresh lists only in the
    # entering context.
    _span_stack: contextvars.ContextVar[List[Dict[str, Any]]] = contextvars.ContextVar("_deep_profiler_span_stack", default=[])
    _skip_stack: contextvars.ContextVar[List[str]] = contextvars.ContextVar("_deep_profiler_skip_stack", default=[])

    def _get_qual_name(self, frame) -> str:
        """Return ``"<module>.<qualified name>"`` for the code running in *frame*.

        The function object is looked up by name in the frame's globals. When
        it is missing, has no ``__qualname__``, or the lookup raises, the
        plain code-object name is used instead, so a string is always returned.
        """
        func_name = frame.f_code.co_name
        module_name = frame.f_globals.get("__name__", "unknown_module")
        fallback = f"{module_name}.{func_name}"

        try:
            func = frame.f_globals.get(func_name)
            qualname = getattr(func, "__qualname__", None)
        except Exception:
            return fallback

        # Previously a global without ``__qualname__`` fell through and the
        # method implicitly returned None.
        if func is None or qualname is None:
            return fallback
        return f"{module_name}.{qualname}"

    def __new__(cls):
        """Create the shared instance on first use and return it afterwards."""
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
            return cls._instance

    def _should_trace(self, frame):
        """Decide whether the function executing in *frame* gets its own span.

        Returns False when the skip stack is non-empty, when the module-level
        object of the same name carries ``_judgment_span_name`` or
        ``_judgment_span_type``, when the frame has no module name, when the
        name starts with ``<`` or is a dunder other than ``__call__``, or when
        ``_is_user_code`` rejects the file.
        """
        # Skip stack is maintained by the tracer as an optimization to skip earlier
        # frames in the call stack that we've already determined should be skipped
        skip_stack = self._skip_stack.get()
        if len(skip_stack) > 0:
            return False

        func_name = frame.f_code.co_name
        module_name = frame.f_globals.get("__name__", None)

        func = frame.f_globals.get(func_name)
        if func and (hasattr(func, '_judgment_span_name') or hasattr(func, '_judgment_span_type')):
            return False

        if (
            not module_name
            or func_name.startswith("<")  # ex: <listcomp>
            or (func_name.startswith("__") and func_name != "__call__")  # dunders
            or not self._is_user_code(frame.f_code.co_filename)
        ):
            return False

        return True

    @functools.cache
    def _is_user_code(self, filename: str):
        """Return True if *filename* is a real file outside the blocklist.

        ``_TRACE_FILEPATH_BLOCKLIST`` is defined elsewhere in this module; it
        is passed to ``str.startswith`` (presumably a tuple of path prefixes
        -- confirm against its definition). Results are cached per filename.
        """
        return bool(filename) and not filename.startswith("<") and not os.path.realpath(filename).startswith(_TRACE_FILEPATH_BLOCKLIST)

    def _trace(self, frame: types.FrameType, event: str, arg: Any):
        """Trace function installed via ``sys.settrace``/``threading.settrace``.

        Args:
            frame: The frame the event belongs to.
            event: Trace event name; only ``call``, ``return`` and
                ``exception`` are handled.
            arg: Event payload (return value, or the exception triple).

        Returns:
            ``self._trace`` to keep receiving events for the frame, or None
            when the frame is not traced.
        """
        # Only call/return/exception events are needed, not per-line ones.
        frame.f_trace_lines = False
        frame.f_trace_opcodes = False

        if not self._should_trace(frame):
            return

        if event not in ("call", "return", "exception"):
            return

        current_trace = current_trace_var.get()
        if not current_trace:
            return

        parent_span_id = current_span_var.get()
        if not parent_span_id:
            return

        qual_name = self._get_qual_name(frame)
        skip_stack = self._skip_stack.get()

        # NOTE(review): ``_should_trace`` above already returned False if the
        # skip stack was non-empty, so the skip bookkeeping below appears to be
        # unreachable in practice; it is kept unchanged.
        if event == "call":
            # If we have entries in the skip stack and the current qual_name matches the top entry,
            # push it again to track nesting depth and skip
            # As an optimization, we only care about duplicate qual_names.
            if skip_stack:
                if qual_name == skip_stack[-1]:
                    skip_stack.append(qual_name)
                    self._skip_stack.set(skip_stack)
                return

            should_trace = self._should_trace(frame)

            if not should_trace:
                if not skip_stack:
                    self._skip_stack.set([qual_name])
                return
        elif event == "return":
            # If we have entries in skip stack and current qual_name matches the top entry,
            # pop it to track exiting from the skipped section
            if skip_stack and qual_name == skip_stack[-1]:
                skip_stack.pop()
                self._skip_stack.set(skip_stack)
                return

        if skip_stack:
            return

        span_stack = self._span_stack.get()
        if event == "call":
            if not self._should_trace(frame):
                return

            span_id = str(uuid.uuid4())

            parent_depth = current_trace._span_depths.get(parent_span_id, 0)
            depth = parent_depth + 1

            current_trace._span_depths[span_id] = depth

            start_time = time.time()

            span_stack.append({
                "span_id": span_id,
                "parent_span_id": parent_span_id,
                "function": qual_name,
                "start_time": start_time
            })
            self._span_stack.set(span_stack)

            # Make the new span current; the reset token is stashed on the
            # frame so the matching "return" event can find it.
            token = current_span_var.set(span_id)
            frame.f_locals["_judgment_span_token"] = token

            span = TraceSpan(
                span_id=span_id,
                trace_id=current_trace.trace_id,
                depth=depth,
                message=qual_name,
                created_at=start_time,
                span_type="span",
                parent_span_id=parent_span_id,
                function=qual_name
            )
            current_trace.add_span(span)

            inputs = {}
            try:
                args_info = inspect.getargvalues(frame)
                for arg_name in args_info.args:
                    try:
                        inputs[arg_name] = args_info.locals.get(arg_name)
                    except Exception:
                        inputs[arg_name] = "<<Unserializable>>"
                current_trace.record_input(inputs)
            except Exception as e:
                current_trace.record_input({
                    "error": str(e)
                })

        elif event == "return":
            if not span_stack:
                return

            current_id = current_span_var.get()

            # Search from the top of the stack for the span that is current.
            span_data = None
            for i, entry in enumerate(reversed(span_stack)):
                if entry["span_id"] == current_id:
                    span_data = span_stack.pop(-(i + 1))
                    self._span_stack.set(span_stack)
                    break

            if not span_data:
                return

            start_time = span_data["start_time"]
            duration = time.time() - start_time

            current_trace.span_id_to_span[span_data["span_id"]].duration = duration

            if arg is not None:
                # exception handling will take priority.
                current_trace.record_output(arg)

            if span_data["span_id"] in current_trace._span_depths:
                del current_trace._span_depths[span_data["span_id"]]

            if span_stack:
                current_span_var.set(span_stack[-1]["span_id"])
            else:
                current_span_var.set(span_data["parent_span_id"])

            if "_judgment_span_token" in frame.f_locals:
                current_span_var.reset(frame.f_locals["_judgment_span_token"])

        elif event == "exception":
            exc_type, exc_value, exc_traceback = arg
            formatted_exception = {
                "type": exc_type.__name__,
                "message": str(exc_value),
                "traceback": traceback.format_tb(exc_traceback)
            }
            current_trace = current_trace_var.get()
            current_trace.record_output({
                "error": formatted_exception
            })

        return self._trace

    def __enter__(self):
        """Install the trace hook when the first block is entered."""
        with self._lock:
            self._refcount += 1
            if self._refcount == 1:
                self._skip_stack.set([])
                self._span_stack.set([])
                sys.settrace(self._trace)
                threading.settrace(self._trace)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Remove the trace hook when the last active block exits."""
        with self._lock:
            self._refcount -= 1
            if self._refcount == 0:
                sys.settrace(None)
                threading.settrace(None)

    def log(self, message: str, level: str = "info"):
        """Log *message* on the current trace, or print it when none exists.

        With a trace in ``current_trace_var`` the message is passed to the
        trace's ``log`` and also recorded as output. Without one it is only
        printed; previously that path went on to call ``record_output`` on
        ``None`` and raised ``AttributeError``.
        """
        current_trace = current_trace_var.get()
        if not current_trace:
            print(f"[{level}] {message}")
            return
        current_trace.log(message, level)
        current_trace.record_output({"log": message})
908
+
919
909
  class Tracer:
920
910
  _instance = None
921
911
 
@@ -938,12 +928,16 @@ class Tracer:
938
928
  s3_aws_access_key_id: Optional[str] = None,
939
929
  s3_aws_secret_access_key: Optional[str] = None,
940
930
  s3_region_name: Optional[str] = None,
941
- deep_tracing: bool = True # NEW: Enable deep tracing by default
931
+ deep_tracing: bool = True # Deep tracing is enabled by default
942
932
  ):
943
933
  if not hasattr(self, 'initialized'):
944
934
  if not api_key:
945
935
  raise ValueError("Tracer must be configured with a Judgment API key")
946
936
 
937
+ result, response = validate_api_key(api_key)
938
+ if not result:
939
+ raise JudgmentAPIError(f"Issue with passed in Judgment API key: {response}")
940
+
947
941
  if not organization_id:
948
942
  raise ValueError("Tracer must be configured with an Organization ID")
949
943
  if use_s3 and not s3_bucket_name:
@@ -955,10 +949,11 @@ class Tracer:
955
949
 
956
950
  self.api_key: str = api_key
957
951
  self.project_name: str = project_name
958
- self.client: JudgmentClient = JudgmentClient(judgment_api_key=api_key)
959
952
  self.organization_id: str = organization_id
960
953
  self._current_trace: Optional[str] = None
954
+ self._active_trace_client: Optional[TraceClient] = None # Add active trace client attribute
961
955
  self.rules: List[Rule] = rules or [] # Store rules at tracer level
956
+ self.traces: List[Trace] = []
962
957
  self.initialized: bool = True
963
958
  self.enable_monitoring: bool = enable_monitoring
964
959
  self.enable_evaluations: bool = enable_evaluations
@@ -991,49 +986,29 @@ class Tracer:
991
986
 
992
987
def get_current_trace(self) -> Optional[TraceClient]:
    """
    Return the trace client for the current execution context, if any.

    The context variable is consulted first. When it holds nothing
    (e.g., context lost across threads/tasks), the tracer's
    ``_active_trace_client`` attribute is used instead. ``None`` is
    returned when neither source provides a client.
    """
    context_trace = current_trace_var.get()
    if context_trace:
        return context_trace

    # Fallback: the active client potentially set by a callback handler.
    # The attribute may be absent, so read it defensively.
    fallback_client = getattr(self, '_active_trace_client', None)
    return fallback_client or None
1016
1007
 
1017
- # Find all functions in the module
1018
- for name, obj in inspect.getmembers(module, inspect.isfunction):
1019
- # Skip already wrapped functions
1020
- if hasattr(obj, '_judgment_traced'):
1021
- continue
1022
-
1023
- # Create a traced version of the function
1024
- # Always use default span type "span" for child functions
1025
- traced_func = _create_deep_tracing_wrapper(obj, self, "span")
1026
-
1027
- # Mark the function as traced to avoid double wrapping
1028
- traced_func._judgment_traced = True
1029
-
1030
- # Save the original function
1031
- original_functions[name] = obj
1032
-
1033
- # Replace with traced version
1034
- setattr(module, name, traced_func)
1035
-
1036
- return module, original_functions
1008
def get_active_trace_client(self) -> Optional[TraceClient]:
    """Return the tracer's ``_active_trace_client`` attribute unchanged (may be ``None``)."""
    active_client = self._active_trace_client
    return active_client
1011
+
1037
1012
 
1038
1013
  @contextmanager
1039
1014
  def trace(
@@ -1080,6 +1055,23 @@ class Tracer:
1080
1055
  finally:
1081
1056
  # Reset the context variable
1082
1057
  current_trace_var.reset(token)
1058
+
1059
+
1060
def log(self, msg: str, label: str = "log", score: int = 1):
    """Attach *msg* to the current span as an annotation and print it.

    Args:
        msg: Text of the annotation.
        label: Label stored on the annotation and shown in the printed line.
        score: Score stored on the annotation.

    The annotation is only created when both a current span id and a
    current trace are present. Previously a span id without a trace led to
    ``add_annotation`` being called on ``None``. The message is printed in
    every case.
    """
    current_span_id = current_span_var.get()
    current_trace = current_trace_var.get()
    if current_span_id and current_trace:
        annotation = TraceAnnotation(
            span_id=current_span_id,
            text=msg,
            label=label,
            score=score
        )
        current_trace.add_annotation(annotation)

    rprint(f"[bold]{label}:[/bold] {msg}")
1083
1075
 
1084
1076
  def observe(self, func=None, *, name=None, span_type: SpanType = "span", project_name: str = None, overwrite: bool = False, deep_tracing: bool = None):
1085
1077
  """
@@ -1115,13 +1107,6 @@ class Tracer:
1115
1107
  if asyncio.iscoroutinefunction(func):
1116
1108
  @functools.wraps(func)
1117
1109
  async def async_wrapper(*args, **kwargs):
1118
- # Check if we're already in a traced function
1119
- if in_traced_function_var.get():
1120
- return await func(*args, **kwargs)
1121
-
1122
- # Set in_traced_function_var to True
1123
- token = in_traced_function_var.set(True)
1124
-
1125
1110
  # Get current trace from context
1126
1111
  current_trace = current_trace_var.get()
1127
1112
 
@@ -1151,81 +1136,47 @@ class Tracer:
1151
1136
  # This sets the current_span_var
1152
1137
  with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
1153
1138
  # Record inputs
1154
- span.record_input({
1155
- 'args': str(args),
1156
- 'kwargs': kwargs
1157
- })
1139
+ inputs = combine_args_kwargs(func, args, kwargs)
1140
+ span.record_input(inputs)
1158
1141
 
1159
- # If deep tracing is enabled, apply monkey patching
1160
1142
  if use_deep_tracing:
1161
- module, original_functions = self._apply_deep_tracing(func, span_type)
1162
-
1163
- # Execute function
1164
- result = await func(*args, **kwargs)
1165
-
1166
- # Restore original functions if deep tracing was enabled
1167
- if use_deep_tracing and module and 'original_functions' in locals():
1168
- for name, obj in original_functions.items():
1169
- setattr(module, name, obj)
1170
-
1143
+ with _DeepTracer():
1144
+ result = await func(*args, **kwargs)
1145
+ else:
1146
+ result = await func(*args, **kwargs)
1147
+
1171
1148
  # Record output
1172
1149
  span.record_output(result)
1173
-
1174
- # Save the completed trace
1175
- current_trace.save(overwrite=overwrite)
1176
1150
  return result
1177
1151
  finally:
1152
+ # Save the completed trace
1153
+ trace_id, trace = current_trace.save(overwrite=overwrite)
1154
+ self.traces.append(trace)
1155
+
1178
1156
  # Reset trace context (span context resets automatically)
1179
1157
  current_trace_var.reset(trace_token)
1180
- # Reset in_traced_function_var
1181
- in_traced_function_var.reset(token)
1182
1158
  else:
1183
- # Already have a trace context, just create a span in it
1184
- # The span method handles current_span_var
1185
-
1186
- try:
1187
- with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
1188
- # Record inputs
1189
- span.record_input({
1190
- 'args': str(args),
1191
- 'kwargs': kwargs
1192
- })
1193
-
1194
- # If deep tracing is enabled, apply monkey patching
1195
- if use_deep_tracing:
1196
- module, original_functions = self._apply_deep_tracing(func, span_type)
1197
-
1198
- # Execute function
1159
+ with current_trace.span(span_name, span_type=span_type) as span:
1160
+ inputs = combine_args_kwargs(func, args, kwargs)
1161
+ span.record_input(inputs)
1162
+
1163
+ if use_deep_tracing:
1164
+ with _DeepTracer():
1165
+ result = await func(*args, **kwargs)
1166
+ else:
1199
1167
  result = await func(*args, **kwargs)
1200
1168
 
1201
- # Restore original functions if deep tracing was enabled
1202
- if use_deep_tracing and module and 'original_functions' in locals():
1203
- for name, obj in original_functions.items():
1204
- setattr(module, name, obj)
1205
-
1206
- # Record output
1207
- span.record_output(result)
1208
-
1209
- return result
1210
- finally:
1211
- # Reset in_traced_function_var
1212
- in_traced_function_var.reset(token)
1213
-
1169
+ span.record_output(result)
1170
+ return result
1171
+
1214
1172
  return async_wrapper
1215
1173
  else:
1216
1174
  # Non-async function implementation with deep tracing
1217
1175
  @functools.wraps(func)
1218
- def wrapper(*args, **kwargs):
1219
- # Check if we're already in a traced function
1220
- if in_traced_function_var.get():
1221
- return func(*args, **kwargs)
1222
-
1223
- # Set in_traced_function_var to True
1224
- token = in_traced_function_var.set(True)
1225
-
1176
+ def wrapper(*args, **kwargs):
1226
1177
  # Get current trace from context
1227
1178
  current_trace = current_trace_var.get()
1228
-
1179
+
1229
1180
  # If there's no current trace, create a root trace
1230
1181
  if not current_trace:
1231
1182
  trace_id = str(uuid.uuid4())
@@ -1252,105 +1203,65 @@ class Tracer:
1252
1203
  # This sets the current_span_var
1253
1204
  with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
1254
1205
  # Record inputs
1255
- span.record_input({
1256
- 'args': str(args),
1257
- 'kwargs': kwargs
1258
- })
1206
+ inputs = combine_args_kwargs(func, args, kwargs)
1207
+ span.record_input(inputs)
1259
1208
 
1260
- # If deep tracing is enabled, apply monkey patching
1261
1209
  if use_deep_tracing:
1262
- module, original_functions = self._apply_deep_tracing(func, span_type)
1263
-
1264
- # Execute function
1265
- result = func(*args, **kwargs)
1266
-
1267
- # Restore original functions if deep tracing was enabled
1268
- if use_deep_tracing and module and 'original_functions' in locals():
1269
- for name, obj in original_functions.items():
1270
- setattr(module, name, obj)
1210
+ with _DeepTracer():
1211
+ result = func(*args, **kwargs)
1212
+ else:
1213
+ result = func(*args, **kwargs)
1271
1214
 
1272
1215
  # Record output
1273
1216
  span.record_output(result)
1274
-
1275
- # Save the completed trace
1276
- current_trace.save(overwrite=overwrite)
1277
1217
  return result
1278
1218
  finally:
1219
+ # Save the completed trace
1220
+ trace_id, trace = current_trace.save(overwrite=overwrite)
1221
+ self.traces.append(trace)
1222
+
1279
1223
  # Reset trace context (span context resets automatically)
1280
1224
  current_trace_var.reset(trace_token)
1281
- # Reset in_traced_function_var
1282
- in_traced_function_var.reset(token)
1283
1225
  else:
1284
- # Already have a trace context, just create a span in it
1285
- # The span method handles current_span_var
1286
-
1287
- try:
1288
- with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
1289
- # Record inputs
1290
- span.record_input({
1291
- 'args': str(args),
1292
- 'kwargs': kwargs
1293
- })
1294
-
1295
- # If deep tracing is enabled, apply monkey patching
1296
- if use_deep_tracing:
1297
- module, original_functions = self._apply_deep_tracing(func, span_type)
1298
-
1299
- # Execute function
1226
+ with current_trace.span(span_name, span_type=span_type) as span:
1227
+
1228
+ inputs = combine_args_kwargs(func, args, kwargs)
1229
+ span.record_input(inputs)
1230
+
1231
+ if use_deep_tracing:
1232
+ with _DeepTracer():
1233
+ result = func(*args, **kwargs)
1234
+ else:
1300
1235
  result = func(*args, **kwargs)
1301
1236
 
1302
- # Restore original functions if deep tracing was enabled
1303
- if use_deep_tracing and module and 'original_functions' in locals():
1304
- for name, obj in original_functions.items():
1305
- setattr(module, name, obj)
1306
-
1307
- # Record output
1308
- span.record_output(result)
1309
-
1310
- return result
1311
- finally:
1312
- # Reset in_traced_function_var
1313
- in_traced_function_var.reset(token)
1314
-
1315
- return wrapper
1316
-
1317
- def score(self, func=None, scorers: List[Union[APIJudgmentScorer, JudgevalScorer]] = None, model: str = None, log_results: bool = True, *, name: str = None, span_type: SpanType = "span"):
1318
- """
1319
- Decorator to trace function execution with detailed entry/exit information.
1320
- """
1321
- if func is None:
1322
- return lambda f: self.score(f, scorers=scorers, model=model, log_results=log_results, name=name, span_type=span_type)
1323
-
1324
- if asyncio.iscoroutinefunction(func):
1325
- @functools.wraps(func)
1326
- async def async_wrapper(*args, **kwargs):
1327
- # Get current trace from contextvars
1328
- current_trace = current_trace_var.get()
1329
- if current_trace and scorers:
1330
- current_trace.async_evaluate(scorers=scorers, input=args, actual_output=kwargs, model=model, log_results=log_results)
1331
- return await func(*args, **kwargs)
1332
- return async_wrapper
1333
- else:
1334
- @functools.wraps(func)
1335
- def wrapper(*args, **kwargs):
1336
- # Get current trace from contextvars
1337
- current_trace = current_trace_var.get()
1338
- if current_trace and scorers:
1339
- current_trace.async_evaluate(scorers=scorers, input=args, actual_output=kwargs, model=model, log_results=log_results)
1340
- return func(*args, **kwargs)
1237
+ span.record_output(result)
1238
+ return result
1239
+
1341
1240
  return wrapper
1342
1241
 
1343
1242
def async_evaluate(self, *args, **kwargs):
    """Forward an evaluation request to the current trace client.

    Does nothing when ``enable_evaluations`` is false. The trace is taken
    from the context variable, falling back to ``_active_trace_client``
    when the context variable is empty. A warning is issued if no trace can
    be found. A truthy ``trace_id`` keyword is passed through; a falsy one
    is dropped.
    """
    if not self.enable_evaluations:
        return

    # Pull trace_id out of kwargs; it is only re-added below when truthy.
    passed_trace_id = kwargs.pop('trace_id', None)

    # Context variable first, active client as fallback.
    current_trace = current_trace_var.get() or self._active_trace_client

    if not current_trace:
        warnings.warn("No trace found (context var or fallback), skipping evaluation")
        return

    if passed_trace_id:
        kwargs['trace_id'] = passed_trace_id
    current_trace.async_evaluate(*args, **kwargs)
1354
1265
 
1355
1266
 
1356
1267
  def wrap(client: Any) -> Any:
@@ -1359,7 +1270,7 @@ def wrap(client: Any) -> Any:
1359
1270
  Supports OpenAI, Together, Anthropic, and Google GenAI clients.
1360
1271
  Patches both '.create' and Anthropic's '.stream' methods using a wrapper class.
1361
1272
  """
1362
- span_name, original_create, original_stream = _get_client_config(client)
1273
+ span_name, original_create, responses_create, original_stream = _get_client_config(client)
1363
1274
 
1364
1275
  # --- Define Traced Async Functions ---
1365
1276
  async def traced_create_async(*args, **kwargs):
@@ -1457,7 +1368,41 @@ def wrap(client: Any) -> Any:
1457
1368
  span.record_output(output_data)
1458
1369
  return response_or_iterator
1459
1370
 
1371
+ # --- Define Traced Sync Functions ---
1372
+ def traced_response_create_sync(*args, **kwargs):
1373
+ # [Existing logic - unchanged]
1374
+ current_trace = current_trace_var.get()
1375
+ if not current_trace:
1376
+ return responses_create(*args, **kwargs)
1377
+
1378
+ is_streaming = kwargs.get("stream", False)
1379
+ with current_trace.span(span_name, span_type="llm") as span:
1380
+ span.record_input(kwargs)
1381
+
1382
+ # Warn about token counting limitations with streaming
1383
+ if isinstance(client, (AsyncOpenAI, OpenAI)) and is_streaming:
1384
+ if not kwargs.get("stream_options", {}).get("include_usage"):
1385
+ warnings.warn(
1386
+ "OpenAI streaming calls don't include token counts by default. "
1387
+ "To enable token counting with streams, set stream_options={'include_usage': True} "
1388
+ "in your API call arguments.",
1389
+ UserWarning
1390
+ )
1460
1391
 
1392
+ try:
1393
+ response_or_iterator = responses_create(*args, **kwargs)
1394
+ except Exception as e:
1395
+ print(f"Error during wrapped sync API call ({span_name}): {e}")
1396
+ span.record_output({"error": str(e)})
1397
+ raise
1398
+ if is_streaming:
1399
+ output_entry = span.record_output("<pending stream>")
1400
+ return _sync_stream_wrapper(response_or_iterator, client, output_entry)
1401
+ else:
1402
+ output_data = _format_response_output_data(client, response_or_iterator)
1403
+ span.record_output(output_data)
1404
+ return response_or_iterator
1405
+
1461
1406
  # Function replacing sync .stream()
1462
1407
  def traced_stream_sync(*args, **kwargs):
1463
1408
  current_trace = current_trace_var.get()
@@ -1505,15 +1450,16 @@ def wrap(client: Any) -> Any:
1505
1450
  if original_stream:
1506
1451
  client.messages.stream = traced_stream_async
1507
1452
  elif isinstance(client, genai.client.AsyncClient):
1508
- client.generate_content = traced_create_async
1453
+ client.models.generate_content = traced_create_async
1509
1454
  elif isinstance(client, (OpenAI, Together)):
1510
1455
  client.chat.completions.create = traced_create_sync
1456
+ client.responses.create = traced_response_create_sync
1511
1457
  elif isinstance(client, Anthropic):
1512
1458
  client.messages.create = traced_create_sync
1513
1459
  if original_stream:
1514
1460
  client.messages.stream = traced_stream_sync
1515
1461
  elif isinstance(client, genai.Client):
1516
- client.generate_content = traced_create_sync
1462
+ client.models.generate_content = traced_create_sync
1517
1463
 
1518
1464
  return client
1519
1465
 
@@ -1529,19 +1475,20 @@ def _get_client_config(client: ApiClient) -> tuple[str, callable, Optional[calla
1529
1475
  tuple: (span_name, create_method, stream_method)
1530
1476
  - span_name: String identifier for tracing
1531
1477
  - create_method: Reference to the client's creation method
1478
+ - responses_method: Reference to the client's responses method (if applicable)
1532
1479
  - stream_method: Reference to the client's stream method (if applicable)
1533
1480
 
1534
1481
  Raises:
1535
1482
  ValueError: If client type is not supported
1536
1483
  """
1537
1484
  if isinstance(client, (OpenAI, AsyncOpenAI)):
1538
- return "OPENAI_API_CALL", client.chat.completions.create, None
1485
+ return "OPENAI_API_CALL", client.chat.completions.create, client.responses.create, None
1539
1486
  elif isinstance(client, (Together, AsyncTogether)):
1540
- return "TOGETHER_API_CALL", client.chat.completions.create, None
1487
+ return "TOGETHER_API_CALL", client.chat.completions.create, None, None
1541
1488
  elif isinstance(client, (Anthropic, AsyncAnthropic)):
1542
- return "ANTHROPIC_API_CALL", client.messages.create, client.messages.stream
1489
+ return "ANTHROPIC_API_CALL", client.messages.create, None, client.messages.stream
1543
1490
  elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
1544
- return "GOOGLE_API_CALL", client.models.generate_content, None
1491
+ return "GOOGLE_API_CALL", client.models.generate_content, None, None
1545
1492
  raise ValueError(f"Unsupported client type: {type(client)}")
1546
1493
 
1547
1494
  def _format_input_data(client: ApiClient, **kwargs) -> dict:
@@ -1567,6 +1514,26 @@ def _format_input_data(client: ApiClient, **kwargs) -> dict:
1567
1514
  "max_tokens": kwargs.get("max_tokens")
1568
1515
  }
1569
1516
 
1517
+ def _format_response_output_data(client: ApiClient, response: Any) -> dict:
1518
+ """Format API response data based on client type.
1519
+
1520
+ Normalizes different response formats into a consistent structure
1521
+ for tracing purposes.
1522
+ """
1523
+ if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
1524
+ return {
1525
+ "content": response.output,
1526
+ "usage": {
1527
+ "prompt_tokens": response.usage.input_tokens,
1528
+ "completion_tokens": response.usage.output_tokens,
1529
+ "total_tokens": response.usage.total_tokens
1530
+ }
1531
+ }
1532
+ else:
1533
+ warnings.warn(f"Unsupported client type: {type(client)}")
1534
+ return {}
1535
+
1536
+
1570
1537
  def _format_output_data(client: ApiClient, response: Any) -> dict:
1571
1538
  """Format API response data based on client type.
1572
1539
 
@@ -1600,123 +1567,57 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
1600
1567
  return {
1601
1568
  "content": response.content[0].text,
1602
1569
  "usage": {
1603
- "input_tokens": response.usage.input_tokens,
1604
- "output_tokens": response.usage.output_tokens,
1570
+ "prompt_tokens": response.usage.input_tokens,
1571
+ "completion_tokens": response.usage.output_tokens,
1605
1572
  "total_tokens": response.usage.input_tokens + response.usage.output_tokens
1606
1573
  }
1607
1574
  }
1608
1575
 
1609
- # Define a blocklist of functions that should not be traced
1610
- # These are typically utility functions, print statements, logging, etc.
1611
- _TRACE_BLOCKLIST = {
1612
- # Built-in functions
1613
- 'print', 'str', 'int', 'float', 'bool', 'list', 'dict', 'set', 'tuple',
1614
- 'len', 'range', 'enumerate', 'zip', 'map', 'filter', 'sorted', 'reversed',
1615
- 'min', 'max', 'sum', 'any', 'all', 'abs', 'round', 'format',
1616
- # Logging functions
1617
- 'debug', 'info', 'warning', 'error', 'critical', 'exception', 'log',
1618
- # Common utility functions
1619
- 'sleep', 'time', 'datetime', 'json', 'dumps', 'loads',
1620
- # String operations
1621
- 'join', 'split', 'strip', 'lstrip', 'rstrip', 'replace', 'lower', 'upper',
1622
- # Dict operations
1623
- 'get', 'items', 'keys', 'values', 'update',
1624
- # List operations
1625
- 'append', 'extend', 'insert', 'remove', 'pop', 'clear', 'index', 'count', 'sort',
1626
- }
1627
-
1628
-
1629
- # Add a new function for deep tracing at the module level
1630
- def _create_deep_tracing_wrapper(func, tracer, span_type="span"):
1576
+ def combine_args_kwargs(func, args, kwargs):
1631
1577
  """
1632
- Creates a wrapper for a function that automatically traces it when called within a traced function.
1633
- This enables deep tracing without requiring explicit @observe decorators on every function.
1578
+ Combine positional arguments and keyword arguments into a single dictionary.
1634
1579
 
1635
1580
  Args:
1636
- func: The function to wrap
1637
- tracer: The Tracer instance
1638
- span_type: Type of span (default "span")
1581
+ func: The function being called
1582
+ args: Tuple of positional arguments
1583
+ kwargs: Dictionary of keyword arguments
1639
1584
 
1640
1585
  Returns:
1641
- A wrapped function that will be traced when called
1586
+ A dictionary combining both args and kwargs
1642
1587
  """
1643
- # Skip wrapping if the function is not callable or is a built-in
1644
- if not callable(func) or isinstance(func, type) or func.__module__ == 'builtins':
1645
- return func
1646
-
1647
- # Skip functions in the blocklist
1648
- if func.__name__ in _TRACE_BLOCKLIST:
1649
- return func
1650
-
1651
- # Skip functions from certain modules (logging, sys, etc.)
1652
- if func.__module__ and any(func.__module__.startswith(m) for m in ['logging', 'sys', 'os', 'json', 'time', 'datetime']):
1653
- return func
1654
-
1655
-
1656
- # Get function name for the span - check for custom name set by @observe
1657
- func_name = getattr(func, '_judgment_span_name', func.__name__)
1658
-
1659
- # Check for custom span_type set by @observe
1660
- func_span_type = getattr(func, '_judgment_span_type', "span")
1661
-
1662
- # Store original function to prevent losing reference
1663
- original_func = func
1664
-
1665
- # Create appropriate wrapper based on whether the function is async or not
1666
- if asyncio.iscoroutinefunction(func):
1667
- @functools.wraps(func)
1668
- async def async_deep_wrapper(*args, **kwargs):
1669
- # Get current trace from context
1670
- current_trace = current_trace_var.get()
1671
-
1672
- # If no trace context, just call the function
1673
- if not current_trace:
1674
- return await original_func(*args, **kwargs)
1675
-
1676
- # Create a span for this function call - use custom span_type if available
1677
- with current_trace.span(func_name, span_type=func_span_type) as span:
1678
- # Record inputs
1679
- span.record_input({
1680
- 'args': str(args),
1681
- 'kwargs': kwargs
1682
- })
1683
-
1684
- # Execute function
1685
- result = await original_func(*args, **kwargs)
1686
-
1687
- # Record output
1688
- span.record_output(result)
1689
-
1690
- return result
1691
-
1692
- return async_deep_wrapper
1693
- else:
1694
- @functools.wraps(func)
1695
- def deep_wrapper(*args, **kwargs):
1696
- # Get current trace from context
1697
- current_trace = current_trace_var.get()
1698
-
1699
- # If no trace context, just call the function
1700
- if not current_trace:
1701
- return original_func(*args, **kwargs)
1702
-
1703
- # Create a span for this function call - use custom span_type if available
1704
- with current_trace.span(func_name, span_type=func_span_type) as span:
1705
- # Record inputs
1706
- span.record_input({
1707
- 'args': str(args),
1708
- 'kwargs': kwargs
1709
- })
1710
-
1711
- # Execute function
1712
- result = original_func(*args, **kwargs)
1713
-
1714
- # Record output
1715
- span.record_output(result)
1716
-
1717
- return result
1718
-
1719
- return deep_wrapper
1588
+ try:
1589
+ import inspect
1590
+ sig = inspect.signature(func)
1591
+ param_names = list(sig.parameters.keys())
1592
+
1593
+ args_dict = {}
1594
+ for i, arg in enumerate(args):
1595
+ if i < len(param_names):
1596
+ args_dict[param_names[i]] = arg
1597
+ else:
1598
+ args_dict[f"arg{i}"] = arg
1599
+
1600
+ return {**args_dict, **kwargs}
1601
+ except Exception as e:
1602
+ # Fallback if signature inspection fails
1603
+ return {**{f"arg{i}": arg for i, arg in enumerate(args)}, **kwargs}
1604
+
1605
+ # NOTE: This builds once, can be tweaked if we are missing / capturing other unnecessary modules
1606
+ # @link https://docs.python.org/3.13/library/sysconfig.html
1607
+ _TRACE_FILEPATH_BLOCKLIST = tuple(
1608
+ os.path.realpath(p) + os.sep
1609
+ for p in {
1610
+ sysconfig.get_paths()['stdlib'],
1611
+ sysconfig.get_paths().get('platstdlib', ''),
1612
+ *site.getsitepackages(),
1613
+ site.getusersitepackages(),
1614
+ *(
1615
+ [os.path.join(os.path.dirname(__file__), '../../judgeval/')]
1616
+ if os.environ.get('JUDGMENT_DEV')
1617
+ else []
1618
+ ),
1619
+ } if p
1620
+ )
1720
1621
 
1721
1622
  # Add the new TraceThreadPoolExecutor class
1722
1623
  class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
@@ -1819,7 +1720,7 @@ def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[D
1819
1720
  def _sync_stream_wrapper(
1820
1721
  original_stream: Iterator,
1821
1722
  client: ApiClient,
1822
- output_entry: TraceEntry
1723
+ span: TraceSpan
1823
1724
  ) -> Generator[Any, None, None]:
1824
1725
  """Wraps a synchronous stream iterator to capture content and update the trace."""
1825
1726
  content_parts = [] # Use a list instead of string concatenation
@@ -1838,7 +1739,7 @@ def _sync_stream_wrapper(
1838
1739
  final_usage = _extract_usage_from_final_chunk(client, last_chunk)
1839
1740
 
1840
1741
  # Update the trace entry with the accumulated content and usage
1841
- output_entry.output = {
1742
+ span.output = {
1842
1743
  "content": "".join(content_parts), # Join list at the end
1843
1744
  "usage": final_usage if final_usage else {"info": "Usage data not available in stream."}, # Provide placeholder if None
1844
1745
  "streamed": True
@@ -1850,7 +1751,7 @@ def _sync_stream_wrapper(
1850
1751
  async def _async_stream_wrapper(
1851
1752
  original_stream: AsyncIterator,
1852
1753
  client: ApiClient,
1853
- output_entry: TraceEntry
1754
+ span: TraceSpan
1854
1755
  ) -> AsyncGenerator[Any, None]:
1855
1756
  # [Existing logic - unchanged]
1856
1757
  content_parts = [] # Use a list instead of string concatenation
@@ -1859,7 +1760,7 @@ async def _async_stream_wrapper(
1859
1760
  anthropic_input_tokens = 0
1860
1761
  anthropic_output_tokens = 0
1861
1762
 
1862
- target_span_id = getattr(output_entry, 'span_id', 'UNKNOWN')
1763
+ target_span_id = span.span_id
1863
1764
 
1864
1765
  try:
1865
1766
  async for chunk in original_stream:
@@ -1891,8 +1792,8 @@ async def _async_stream_wrapper(
1891
1792
  anthropic_final_usage = None
1892
1793
  if isinstance(client, (AsyncAnthropic, Anthropic)) and (anthropic_input_tokens > 0 or anthropic_output_tokens > 0):
1893
1794
  anthropic_final_usage = {
1894
- "input_tokens": anthropic_input_tokens,
1895
- "output_tokens": anthropic_output_tokens,
1795
+ "prompt_tokens": anthropic_input_tokens,
1796
+ "completion_tokens": anthropic_output_tokens,
1896
1797
  "total_tokens": anthropic_input_tokens + anthropic_output_tokens
1897
1798
  }
1898
1799
 
@@ -1904,19 +1805,17 @@ async def _async_stream_wrapper(
1904
1805
  elif last_content_chunk:
1905
1806
  usage_info = _extract_usage_from_final_chunk(client, last_content_chunk)
1906
1807
 
1907
- if output_entry and hasattr(output_entry, 'output'):
1908
- output_entry.output = {
1808
+ if span and hasattr(span, 'output'):
1809
+ span.output = {
1909
1810
  "content": "".join(content_parts), # Join list at the end
1910
1811
  "usage": usage_info if usage_info else {"info": "Usage data not available in stream."},
1911
1812
  "streamed": True
1912
1813
  }
1913
- start_ts = getattr(output_entry, 'created_at', time.time())
1914
- output_entry.duration = time.time() - start_ts
1814
+ start_ts = getattr(span, 'created_at', time.time())
1815
+ span.duration = time.time() - start_ts
1915
1816
  # else: # Handle error case if necessary, but remove debug print
1916
1817
 
1917
- # --- Define Context Manager Wrapper Classes ---
1918
- class _TracedAsyncStreamManagerWrapper(AbstractAsyncContextManager):
1919
- """Wraps an original async stream manager to add tracing."""
1818
+ class _BaseStreamManagerWrapper:
1920
1819
  def __init__(self, original_manager, client, span_name, trace_client, stream_wrapper_func, input_kwargs):
1921
1820
  self._original_manager = original_manager
1922
1821
  self._client = client
@@ -1926,157 +1825,199 @@ class _TracedAsyncStreamManagerWrapper(AbstractAsyncContextManager):
1926
1825
  self._input_kwargs = input_kwargs
1927
1826
  self._parent_span_id_at_entry = None
1928
1827
 
1929
- async def __aenter__(self):
1930
- self._parent_span_id_at_entry = current_span_var.get()
1931
- if not self._trace_client:
1932
- # If no trace, just delegate to the original manager
1933
- return await self._original_manager.__aenter__()
1934
-
1935
- # --- Manually create the 'enter' entry ---
1828
+ def _create_span(self):
1936
1829
  start_time = time.time()
1937
1830
  span_id = str(uuid.uuid4())
1938
1831
  current_depth = 0
1939
1832
  if self._parent_span_id_at_entry and self._parent_span_id_at_entry in self._trace_client._span_depths:
1940
1833
  current_depth = self._trace_client._span_depths[self._parent_span_id_at_entry] + 1
1941
1834
  self._trace_client._span_depths[span_id] = current_depth
1942
- enter_entry = TraceEntry(
1943
- type="enter", function=self._span_name, span_id=span_id,
1944
- trace_id=self._trace_client.trace_id, depth=current_depth, message=self._span_name,
1945
- created_at=start_time, span_type="llm", parent_span_id=self._parent_span_id_at_entry
1835
+ span = TraceSpan(
1836
+ function=self._span_name,
1837
+ span_id=span_id,
1838
+ trace_id=self._trace_client.trace_id,
1839
+ depth=current_depth,
1840
+ message=self._span_name,
1841
+ created_at=start_time,
1842
+ span_type="llm",
1843
+ parent_span_id=self._parent_span_id_at_entry
1946
1844
  )
1947
- self._trace_client.add_entry(enter_entry)
1948
- # --- End manual 'enter' entry ---
1949
-
1950
- # Set the current span ID in contextvars
1951
- self._span_context_token = current_span_var.set(span_id)
1845
+ self._trace_client.add_span(span)
1846
+ return span_id, span
1952
1847
 
1953
- # Manually create 'input' entry
1954
- input_data = _format_input_data(self._client, **self._input_kwargs)
1955
- input_entry = TraceEntry(
1956
- type="input", function=self._span_name, span_id=span_id,
1957
- trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Inputs to {self._span_name}",
1958
- created_at=time.time(), inputs=input_data, span_type="llm"
1959
- )
1960
- self._trace_client.add_entry(input_entry)
1848
+ def _finalize_span(self, span_id):
1849
+ span = self._trace_client.span_id_to_span.get(span_id)
1850
+ if span:
1851
+ span.duration = time.time() - span.created_at
1852
+ if span_id in self._trace_client._span_depths:
1853
+ del self._trace_client._span_depths[span_id]
1961
1854
 
1962
- # Call the original __aenter__
1963
- raw_iterator = await self._original_manager.__aenter__()
1855
+ class _TracedAsyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractAsyncContextManager):
1856
+ async def __aenter__(self):
1857
+ self._parent_span_id_at_entry = current_span_var.get()
1858
+ if not self._trace_client:
1859
+ return await self._original_manager.__aenter__()
1964
1860
 
1965
- # Manually create pending 'output' entry
1966
- output_entry = TraceEntry(
1967
- type="output", function=self._span_name, span_id=span_id,
1968
- trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Output from {self._span_name}",
1969
- created_at=time.time(), output="<pending stream>", span_type="llm"
1970
- )
1971
- self._trace_client.add_entry(output_entry)
1861
+ span_id, span = self._create_span()
1862
+ self._span_context_token = current_span_var.set(span_id)
1863
+ span.inputs = _format_input_data(self._client, **self._input_kwargs)
1972
1864
 
1973
- # Wrap the raw iterator
1974
- wrapped_iterator = self._stream_wrapper_func(raw_iterator, self._client, output_entry)
1975
- return wrapped_iterator
1865
+ # Call the original __aenter__ and expect it to be an async generator
1866
+ raw_iterator = await self._original_manager.__aenter__()
1867
+ span.output = "<pending stream>"
1868
+ return self._stream_wrapper_func(raw_iterator, self._client, span)
1976
1869
 
1977
1870
  async def __aexit__(self, exc_type, exc_val, exc_tb):
1978
- # Manually create the 'exit' entry
1979
1871
  if hasattr(self, '_span_context_token'):
1980
- span_id = current_span_var.get()
1981
- start_time_for_duration = 0
1982
- for entry in reversed(self._trace_client.entries):
1983
- if entry.span_id == span_id and entry.type == 'enter':
1984
- start_time_for_duration = entry.created_at
1985
- break
1986
- duration = time.time() - start_time_for_duration if start_time_for_duration else None
1987
- exit_depth = self._trace_client._span_depths.get(span_id, 0)
1988
- exit_entry = TraceEntry(
1989
- type="exit", function=self._span_name, span_id=span_id,
1990
- trace_id=self._trace_client.trace_id, depth=exit_depth, message=f"← {self._span_name}",
1991
- created_at=time.time(), duration=duration, span_type="llm"
1992
- )
1993
- self._trace_client.add_entry(exit_entry)
1994
- if span_id in self._trace_client._span_depths: del self._trace_client._span_depths[span_id]
1995
- current_span_var.reset(self._span_context_token)
1996
- delattr(self, '_span_context_token')
1997
-
1998
- # Delegate __aexit__
1999
- if hasattr(self._original_manager, "__aexit__"):
2000
- return await self._original_manager.__aexit__(exc_type, exc_val, exc_tb)
2001
- return None
2002
-
2003
- class _TracedSyncStreamManagerWrapper(AbstractContextManager):
2004
- """Wraps an original sync stream manager to add tracing."""
2005
- def __init__(self, original_manager, client, span_name, trace_client, stream_wrapper_func, input_kwargs):
2006
- self._original_manager = original_manager
2007
- self._client = client
2008
- self._span_name = span_name
2009
- self._trace_client = trace_client
2010
- self._stream_wrapper_func = stream_wrapper_func
2011
- self._input_kwargs = input_kwargs
2012
- self._parent_span_id_at_entry = None
1872
+ span_id = current_span_var.get()
1873
+ self._finalize_span(span_id)
1874
+ current_span_var.reset(self._span_context_token)
1875
+ delattr(self, '_span_context_token')
1876
+ return await self._original_manager.__aexit__(exc_type, exc_val, exc_tb)
2013
1877
 
1878
+ class _TracedSyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractContextManager):
2014
1879
  def __enter__(self):
2015
1880
  self._parent_span_id_at_entry = current_span_var.get()
2016
1881
  if not self._trace_client:
2017
- return self._original_manager.__enter__()
1882
+ return self._original_manager.__enter__()
2018
1883
 
2019
- # Manually create 'enter' entry
2020
- start_time = time.time()
2021
- span_id = str(uuid.uuid4())
2022
- current_depth = 0
2023
- if self._parent_span_id_at_entry and self._parent_span_id_at_entry in self._trace_client._span_depths:
2024
- current_depth = self._trace_client._span_depths[self._parent_span_id_at_entry] + 1
2025
- self._trace_client._span_depths[span_id] = current_depth
2026
- enter_entry = TraceEntry(
2027
- type="enter", function=self._span_name, span_id=span_id,
2028
- trace_id=self._trace_client.trace_id, depth=current_depth, message=self._span_name,
2029
- created_at=start_time, span_type="llm", parent_span_id=self._parent_span_id_at_entry
2030
- )
2031
- self._trace_client.add_entry(enter_entry)
1884
+ span_id, span = self._create_span()
2032
1885
  self._span_context_token = current_span_var.set(span_id)
1886
+ span.inputs = _format_input_data(self._client, **self._input_kwargs)
2033
1887
 
2034
- # Manually create 'input' entry
2035
- input_data = _format_input_data(self._client, **self._input_kwargs)
2036
- input_entry = TraceEntry(
2037
- type="input", function=self._span_name, span_id=span_id,
2038
- trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Inputs to {self._span_name}",
2039
- created_at=time.time(), inputs=input_data, span_type="llm"
2040
- )
2041
- self._trace_client.add_entry(input_entry)
2042
-
2043
- # Call original __enter__
2044
1888
  raw_iterator = self._original_manager.__enter__()
2045
-
2046
- # Manually create 'output' entry (pending)
2047
- output_entry = TraceEntry(
2048
- type="output", function=self._span_name, span_id=span_id,
2049
- trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Output from {self._span_name}",
2050
- created_at=time.time(), output="<pending stream>", span_type="llm"
2051
- )
2052
- self._trace_client.add_entry(output_entry)
2053
-
2054
- # Wrap the raw iterator
2055
- wrapped_iterator = self._stream_wrapper_func(raw_iterator, self._client, output_entry)
2056
- return wrapped_iterator
1889
+ span.output = "<pending stream>"
1890
+ return self._stream_wrapper_func(raw_iterator, self._client, span)
2057
1891
 
2058
1892
  def __exit__(self, exc_type, exc_val, exc_tb):
2059
- # Manually create 'exit' entry
2060
1893
  if hasattr(self, '_span_context_token'):
2061
- span_id = current_span_var.get()
2062
- start_time_for_duration = 0
2063
- for entry in reversed(self._trace_client.entries):
2064
- if entry.span_id == span_id and entry.type == 'enter':
2065
- start_time_for_duration = entry.created_at
2066
- break
2067
- duration = time.time() - start_time_for_duration if start_time_for_duration else None
2068
- exit_depth = self._trace_client._span_depths.get(span_id, 0)
2069
- exit_entry = TraceEntry(
2070
- type="exit", function=self._span_name, span_id=span_id,
2071
- trace_id=self._trace_client.trace_id, depth=exit_depth, message=f"← {self._span_name}",
2072
- created_at=time.time(), duration=duration, span_type="llm"
2073
- )
2074
- self._trace_client.add_entry(exit_entry)
2075
- if span_id in self._trace_client._span_depths: del self._trace_client._span_depths[span_id]
2076
- current_span_var.reset(self._span_context_token)
2077
- delattr(self, '_span_context_token')
2078
-
2079
- # Delegate __exit__
2080
- if hasattr(self._original_manager, "__exit__"):
2081
- return self._original_manager.__exit__(exc_type, exc_val, exc_tb)
1894
+ span_id = current_span_var.get()
1895
+ self._finalize_span(span_id)
1896
+ current_span_var.reset(self._span_context_token)
1897
+ delattr(self, '_span_context_token')
1898
+ return self._original_manager.__exit__(exc_type, exc_val, exc_tb)
1899
+
1900
+ # --- NEW Generalized Helper Function (Moved from demo) ---
1901
+ def prepare_evaluation_for_state(
1902
+ scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
1903
+ example: Optional[Example] = None,
1904
+ # --- Individual components (alternative to 'example') ---
1905
+ input: Optional[str] = None,
1906
+ actual_output: Optional[Union[str, List[str]]] = None,
1907
+ expected_output: Optional[Union[str, List[str]]] = None,
1908
+ context: Optional[List[str]] = None,
1909
+ retrieval_context: Optional[List[str]] = None,
1910
+ tools_called: Optional[List[str]] = None,
1911
+ expected_tools: Optional[List[str]] = None,
1912
+ additional_metadata: Optional[Dict[str, Any]] = None,
1913
+ # --- Other eval parameters ---
1914
+ model: Optional[str] = None,
1915
+ log_results: Optional[bool] = True
1916
+ ) -> Optional[EvaluationConfig]:
1917
+ """
1918
+ Prepares an EvaluationConfig object, similar to TraceClient.async_evaluate.
1919
+
1920
+ Accepts either a pre-made Example object or individual components to construct one.
1921
+ Returns the EvaluationConfig object ready to be placed in the state, or None.
1922
+ """
1923
+ final_example = example
1924
+
1925
+ # If example is not provided, try to construct one from individual parts
1926
+ if final_example is None:
1927
+ # Basic validation: Ensure at least actual_output is present for most scorers
1928
+ if actual_output is None:
1929
+ # print("[prepare_evaluation_for_state] Warning: 'actual_output' is required when 'example' is not provided. Skipping evaluation setup.")
1930
+ return None
1931
+ try:
1932
+ final_example = Example(
1933
+ input=input,
1934
+ actual_output=actual_output,
1935
+ expected_output=expected_output,
1936
+ context=context,
1937
+ retrieval_context=retrieval_context,
1938
+ tools_called=tools_called,
1939
+ expected_tools=expected_tools,
1940
+ additional_metadata=additional_metadata,
1941
+ # trace_id will be set by the handler later if needed
1942
+ )
1943
+ # print("[prepare_evaluation_for_state] Constructed Example from individual components.")
1944
+ except Exception as e:
1945
+ # print(f"[prepare_evaluation_for_state] Error constructing Example: {e}. Skipping evaluation setup.")
1946
+ return None
1947
+
1948
+ # If we have a valid example (provided or constructed) and scorers
1949
+ if final_example and scorers:
1950
+ # TODO: Add validation like check_examples if needed here,
1951
+ # although the handler might implicitly handle some checks via TraceClient.
1952
+ return EvaluationConfig(
1953
+ scorers=scorers,
1954
+ example=final_example,
1955
+ model=model,
1956
+ log_results=log_results
1957
+ )
1958
+ elif not scorers:
1959
+ # print("[prepare_evaluation_for_state] No scorers provided. Skipping evaluation setup.")
2082
1960
  return None
1961
+ else: # No valid example
1962
+ # print("[prepare_evaluation_for_state] No valid Example available. Skipping evaluation setup.")
1963
+ return None
1964
+ # --- End NEW Helper Function ---
1965
+
1966
+ # --- NEW: Helper function to simplify adding eval config to state ---
1967
+ def add_evaluation_to_state(
1968
+ state: Dict[str, Any], # The LangGraph state dictionary
1969
+ scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
1970
+ # --- Evaluation components (same as prepare_evaluation_for_state) ---
1971
+ input: Optional[str] = None,
1972
+ actual_output: Optional[Union[str, List[str]]] = None,
1973
+ expected_output: Optional[Union[str, List[str]]] = None,
1974
+ context: Optional[List[str]] = None,
1975
+ retrieval_context: Optional[List[str]] = None,
1976
+ tools_called: Optional[List[str]] = None,
1977
+ expected_tools: Optional[List[str]] = None,
1978
+ additional_metadata: Optional[Dict[str, Any]] = None,
1979
+ # --- Other eval parameters ---
1980
+ model: Optional[str] = None,
1981
+ log_results: Optional[bool] = True
1982
+ ) -> None:
1983
+ """
1984
+ Prepares an EvaluationConfig and adds it to the state dictionary
1985
+ under the '_judgeval_eval' key if successful.
1986
+
1987
+ This simplifies the process of setting up evaluations within LangGraph nodes.
1988
+
1989
+ Args:
1990
+ state: The LangGraph state dictionary to modify.
1991
+ scorers: List of scorer instances.
1992
+ input: Input for the evaluation example.
1993
+ actual_output: Actual output for the evaluation example.
1994
+ expected_output: Expected output for the evaluation example.
1995
+ context: Context for the evaluation example.
1996
+ retrieval_context: Retrieval context for the evaluation example.
1997
+ tools_called: Tools called for the evaluation example.
1998
+ expected_tools: Expected tools for the evaluation example.
1999
+ additional_metadata: Additional metadata for the evaluation example.
2000
+ model: Model name used for generation (optional).
2001
+ log_results: Whether to log evaluation results (optional, defaults to True).
2002
+ """
2003
+ eval_config = prepare_evaluation_for_state(
2004
+ scorers=scorers,
2005
+ input=input,
2006
+ actual_output=actual_output,
2007
+ expected_output=expected_output,
2008
+ context=context,
2009
+ retrieval_context=retrieval_context,
2010
+ tools_called=tools_called,
2011
+ expected_tools=expected_tools,
2012
+ additional_metadata=additional_metadata,
2013
+ model=model,
2014
+ log_results=log_results
2015
+ )
2016
+
2017
+ if eval_config:
2018
+ state["_judgeval_eval"] = eval_config
2019
+ # print(f"[_judgeval_eval added to state for node]") # Optional: Log confirmation
2020
+
2021
+ # print("[Skipped adding _judgeval_eval to state: prepare_evaluation_for_state failed]")
2022
+ # --- End NEW Helper ---
2023
+