judgeval 0.0.36__py3-none-any.whl → 0.0.37__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/common/tracer.py +565 -858
- judgeval/common/utils.py +18 -0
- judgeval/constants.py +3 -1
- judgeval/data/__init__.py +4 -0
- judgeval/data/datasets/dataset.py +0 -2
- judgeval/data/example.py +29 -7
- judgeval/data/sequence.py +5 -4
- judgeval/data/sequence_run.py +4 -3
- judgeval/data/trace.py +129 -0
- judgeval/evaluation_run.py +1 -1
- judgeval/integrations/langgraph.py +18 -17
- judgeval/judgment_client.py +77 -64
- judgeval/run_evaluation.py +126 -29
- judgeval/scorers/__init__.py +2 -0
- judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +2 -0
- judgeval/scorers/judgeval_scorers/api_scorers/tool_order.py +18 -0
- judgeval/scorers/score.py +1 -1
- judgeval/utils/data_utils.py +57 -0
- judgeval-0.0.37.dist-info/METADATA +214 -0
- {judgeval-0.0.36.dist-info → judgeval-0.0.37.dist-info}/RECORD +22 -19
- judgeval-0.0.36.dist-info/METADATA +0 -169
- {judgeval-0.0.36.dist-info → judgeval-0.0.37.dist-info}/WHEEL +0 -0
- {judgeval-0.0.36.dist-info → judgeval-0.0.37.dist-info}/licenses/LICENSE.md +0 -0
judgeval/common/tracer.py
CHANGED
@@ -7,7 +7,11 @@ import functools
|
|
7
7
|
import inspect
|
8
8
|
import json
|
9
9
|
import os
|
10
|
+
import site
|
11
|
+
import sysconfig
|
12
|
+
import threading
|
10
13
|
import time
|
14
|
+
import traceback
|
11
15
|
import uuid
|
12
16
|
import warnings
|
13
17
|
import contextvars
|
@@ -35,7 +39,6 @@ from rich import print as rprint
|
|
35
39
|
import types # <--- Add this import
|
36
40
|
|
37
41
|
# Third-party imports
|
38
|
-
import pika
|
39
42
|
import requests
|
40
43
|
from litellm import cost_per_token
|
41
44
|
from pydantic import BaseModel
|
@@ -44,10 +47,10 @@ from openai import OpenAI, AsyncOpenAI
|
|
44
47
|
from together import Together, AsyncTogether
|
45
48
|
from anthropic import Anthropic, AsyncAnthropic
|
46
49
|
from google import genai
|
47
|
-
from judgeval.run_evaluation import check_examples
|
48
50
|
|
49
51
|
# Local application/library-specific imports
|
50
52
|
from judgeval.constants import (
|
53
|
+
JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
|
51
54
|
JUDGMENT_TRACES_SAVE_API_URL,
|
52
55
|
JUDGMENT_TRACES_FETCH_API_URL,
|
53
56
|
RABBITMQ_HOST,
|
@@ -56,25 +59,24 @@ from judgeval.constants import (
|
|
56
59
|
JUDGMENT_TRACES_DELETE_API_URL,
|
57
60
|
JUDGMENT_PROJECT_DELETE_API_URL,
|
58
61
|
)
|
59
|
-
from judgeval.
|
60
|
-
from judgeval.data import Example
|
62
|
+
from judgeval.data import Example, Trace, TraceSpan
|
61
63
|
from judgeval.scorers import APIJudgmentScorer, JudgevalScorer
|
62
64
|
from judgeval.rules import Rule
|
63
65
|
from judgeval.evaluation_run import EvaluationRun
|
64
66
|
from judgeval.data.result import ScoringResult
|
67
|
+
from judgeval.common.utils import validate_api_key
|
68
|
+
from judgeval.common.exceptions import JudgmentAPIError
|
65
69
|
|
66
70
|
# Standard library imports needed for the new class
|
67
71
|
import concurrent.futures
|
68
72
|
from collections.abc import Iterator, AsyncIterator # Add Iterator and AsyncIterator
|
69
73
|
|
70
74
|
# Define context variables for tracking the current trace and the current span within a trace
|
71
|
-
current_trace_var = contextvars.ContextVar('current_trace', default=None)
|
75
|
+
current_trace_var = contextvars.ContextVar[Optional['TraceClient']]('current_trace', default=None)
|
72
76
|
current_span_var = contextvars.ContextVar('current_span', default=None) # ContextVar for the active span name
|
73
|
-
in_traced_function_var = contextvars.ContextVar('in_traced_function', default=False) # Track if we're in a traced function
|
74
77
|
|
75
78
|
# Define type aliases for better code readability and maintainability
|
76
79
|
ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic, AsyncOpenAI, AsyncAnthropic, AsyncTogether, genai.Client, genai.client.AsyncClient] # Supported API clients
|
77
|
-
TraceEntryType = Literal['enter', 'exit', 'output', 'input', 'evaluation'] # Valid trace entry types
|
78
80
|
SpanType = Literal['span', 'tool', 'llm', 'evaluation', 'chain']
|
79
81
|
|
80
82
|
# --- Evaluation Config Dataclass (Moved from langgraph.py) ---
|
@@ -87,154 +89,26 @@ class EvaluationConfig:
|
|
87
89
|
log_results: Optional[bool] = True
|
88
90
|
# --- End Evaluation Config Dataclass ---
|
89
91
|
|
92
|
+
# Temporary as a POC to have log use the existing annotations feature until log endpoints are ready
|
90
93
|
@dataclass
|
91
|
-
class
|
92
|
-
"""Represents a single
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
- output: Output: (function return value)
|
98
|
-
- input: Input: (function parameters)
|
99
|
-
- evaluation: Evaluation: (evaluation results)
|
100
|
-
"""
|
101
|
-
type: TraceEntryType
|
102
|
-
span_id: str # Unique ID for this specific span instance
|
103
|
-
depth: int # Indentation level for nested calls
|
104
|
-
created_at: float # Unix timestamp when entry was created, replacing the deprecated 'timestamp' field
|
105
|
-
function: Optional[str] = None # Name of the function being traced
|
106
|
-
message: Optional[str] = None # Human-readable description
|
107
|
-
duration: Optional[float] = None # Time taken (for exit/evaluation entries)
|
108
|
-
trace_id: str = None # ID of the trace this entry belongs to
|
109
|
-
output: Any = None # Function output value
|
110
|
-
# Use field() for mutable defaults to avoid shared state issues
|
111
|
-
inputs: dict = field(default_factory=dict)
|
112
|
-
span_type: SpanType = "span"
|
113
|
-
evaluation_runs: List[Optional[EvaluationRun]] = field(default=None)
|
114
|
-
parent_span_id: Optional[str] = None # ID of the parent span instance
|
115
|
-
|
116
|
-
def print_entry(self):
|
117
|
-
"""Print a trace entry with proper formatting and parent relationship information."""
|
118
|
-
indent = " " * self.depth
|
119
|
-
|
120
|
-
if self.type == "enter":
|
121
|
-
# Format parent info if present
|
122
|
-
parent_info = f" (parent_id: {self.parent_span_id})" if self.parent_span_id else ""
|
123
|
-
print(f"{indent}→ {self.function} (id: {self.span_id}){parent_info} (trace: {self.message})")
|
124
|
-
elif self.type == "exit":
|
125
|
-
print(f"{indent}← {self.function} (id: {self.span_id}) ({self.duration:.3f}s)")
|
126
|
-
elif self.type == "output":
|
127
|
-
# Format output to align properly
|
128
|
-
output_str = str(self.output)
|
129
|
-
print(f"{indent}Output (for id: {self.span_id}): {output_str}")
|
130
|
-
elif self.type == "input":
|
131
|
-
# Format inputs to align properly
|
132
|
-
print(f"{indent}Input (for id: {self.span_id}): {self.inputs}")
|
133
|
-
elif self.type == "evaluation":
|
134
|
-
for evaluation_run in self.evaluation_runs:
|
135
|
-
print(f"{indent}Evaluation (for id: {self.span_id}): {evaluation_run.model_dump()}")
|
136
|
-
|
137
|
-
def _serialize_inputs(self) -> dict:
|
138
|
-
"""Helper method to serialize input data safely.
|
139
|
-
|
140
|
-
Returns a dict with serializable versions of inputs, converting non-serializable
|
141
|
-
objects to None with a warning.
|
142
|
-
"""
|
143
|
-
serialized_inputs = {}
|
144
|
-
for key, value in self.inputs.items():
|
145
|
-
if isinstance(value, BaseModel):
|
146
|
-
serialized_inputs[key] = value.model_dump()
|
147
|
-
elif isinstance(value, (list, tuple)):
|
148
|
-
# Handle lists/tuples of arguments
|
149
|
-
serialized_inputs[key] = [
|
150
|
-
item.model_dump() if isinstance(item, BaseModel)
|
151
|
-
else None if not self._is_json_serializable(item)
|
152
|
-
else item
|
153
|
-
for item in value
|
154
|
-
]
|
155
|
-
else:
|
156
|
-
if self._is_json_serializable(value):
|
157
|
-
serialized_inputs[key] = value
|
158
|
-
else:
|
159
|
-
serialized_inputs[key] = self.safe_stringify(value, self.function)
|
160
|
-
return serialized_inputs
|
161
|
-
|
162
|
-
def _is_json_serializable(self, obj: Any) -> bool:
|
163
|
-
"""Helper method to check if an object is JSON serializable."""
|
164
|
-
try:
|
165
|
-
json.dumps(obj)
|
166
|
-
return True
|
167
|
-
except (TypeError, OverflowError, ValueError):
|
168
|
-
return False
|
169
|
-
|
170
|
-
def safe_stringify(self, output, function_name):
|
171
|
-
"""
|
172
|
-
Safely converts an object to a string or repr, handling serialization issues gracefully.
|
173
|
-
"""
|
174
|
-
try:
|
175
|
-
return str(output)
|
176
|
-
except (TypeError, OverflowError, ValueError):
|
177
|
-
pass
|
178
|
-
|
179
|
-
try:
|
180
|
-
return repr(output)
|
181
|
-
except (TypeError, OverflowError, ValueError):
|
182
|
-
pass
|
183
|
-
|
184
|
-
warnings.warn(
|
185
|
-
f"Output for function {function_name} is not JSON serializable and could not be converted to string. Setting to None."
|
186
|
-
)
|
187
|
-
return None
|
94
|
+
class TraceAnnotation:
|
95
|
+
"""Represents a single annotation for a trace span."""
|
96
|
+
span_id: str
|
97
|
+
text: str
|
98
|
+
label: str
|
99
|
+
score: int
|
188
100
|
|
189
101
|
def to_dict(self) -> dict:
|
190
|
-
"""Convert the
|
102
|
+
"""Convert the annotation to a dictionary format for storage/transmission."""
|
191
103
|
return {
|
192
|
-
"type": self.type,
|
193
|
-
"function": self.function,
|
194
104
|
"span_id": self.span_id,
|
195
|
-
"
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
"output": self._serialize_output(),
|
201
|
-
"inputs": self._serialize_inputs(),
|
202
|
-
"evaluation_runs": [evaluation_run.model_dump() for evaluation_run in self.evaluation_runs] if self.evaluation_runs else [],
|
203
|
-
"span_type": self.span_type,
|
204
|
-
"parent_span_id": self.parent_span_id,
|
105
|
+
"annotation": {
|
106
|
+
"text": self.text,
|
107
|
+
"label": self.label,
|
108
|
+
"score": self.score
|
109
|
+
}
|
205
110
|
}
|
206
|
-
|
207
|
-
def _serialize_output(self) -> Any:
|
208
|
-
"""Helper method to serialize output data safely.
|
209
|
-
|
210
|
-
Handles special cases:
|
211
|
-
- Pydantic models are converted using model_dump()
|
212
|
-
- Dictionaries are processed recursively to handle non-serializable values.
|
213
|
-
- We try to serialize into JSON, then string, then the base representation (__repr__)
|
214
|
-
- Non-serializable objects return None with a warning
|
215
|
-
"""
|
216
|
-
|
217
|
-
def serialize_value(value):
|
218
|
-
if isinstance(value, BaseModel):
|
219
|
-
return value.model_dump()
|
220
|
-
elif isinstance(value, dict):
|
221
|
-
# Recursively serialize dictionary values
|
222
|
-
return {k: serialize_value(v) for k, v in value.items()}
|
223
|
-
elif isinstance(value, (list, tuple)):
|
224
|
-
# Recursively serialize list/tuple items
|
225
|
-
return [serialize_value(item) for item in value]
|
226
|
-
else:
|
227
|
-
# Try direct JSON serialization first
|
228
|
-
try:
|
229
|
-
json.dumps(value)
|
230
|
-
return value
|
231
|
-
except (TypeError, OverflowError, ValueError):
|
232
|
-
# Fallback to safe stringification
|
233
|
-
return self.safe_stringify(value, self.function)
|
234
|
-
|
235
|
-
# Start serialization with the top-level output
|
236
|
-
return serialize_value(self.output)
|
237
|
-
|
111
|
+
|
238
112
|
class TraceManagerClient:
|
239
113
|
"""
|
240
114
|
Client for handling trace endpoints with the Judgment API
|
@@ -271,8 +145,6 @@ class TraceManagerClient:
|
|
271
145
|
raise ValueError(f"Failed to fetch traces: {response.text}")
|
272
146
|
|
273
147
|
return response.json()
|
274
|
-
|
275
|
-
|
276
148
|
|
277
149
|
def save_trace(self, trace_data: dict):
|
278
150
|
"""
|
@@ -315,6 +187,33 @@ class TraceManagerClient:
|
|
315
187
|
pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={response.json()['ui_results_url']}]View Trace[/link]\n"
|
316
188
|
rprint(pretty_str)
|
317
189
|
|
190
|
+
## TODO: Should have a log endpoint, endpoint should also support batched payloads
|
191
|
+
def save_annotation(self, annotation: TraceAnnotation):
|
192
|
+
json_data = {
|
193
|
+
"span_id": annotation.span_id,
|
194
|
+
"annotation": {
|
195
|
+
"text": annotation.text,
|
196
|
+
"label": annotation.label,
|
197
|
+
"score": annotation.score
|
198
|
+
}
|
199
|
+
}
|
200
|
+
|
201
|
+
response = requests.post(
|
202
|
+
JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
|
203
|
+
json=json_data,
|
204
|
+
headers={
|
205
|
+
'Content-Type': 'application/json',
|
206
|
+
'Authorization': f'Bearer {self.judgment_api_key}',
|
207
|
+
'X-Organization-Id': self.organization_id
|
208
|
+
},
|
209
|
+
verify=True
|
210
|
+
)
|
211
|
+
|
212
|
+
if response.status_code != HTTPStatus.OK:
|
213
|
+
raise ValueError(f"Failed to save annotation: {response.text}")
|
214
|
+
|
215
|
+
return response.json()
|
216
|
+
|
318
217
|
def delete_trace(self, trace_id: str):
|
319
218
|
"""
|
320
219
|
Delete a trace from the database.
|
@@ -405,15 +304,16 @@ class TraceClient:
|
|
405
304
|
self.enable_evaluations = enable_evaluations
|
406
305
|
self.parent_trace_id = parent_trace_id
|
407
306
|
self.parent_name = parent_name
|
408
|
-
self.
|
409
|
-
self.
|
307
|
+
self.trace_spans: List[TraceSpan] = []
|
308
|
+
self.span_id_to_span: Dict[str, TraceSpan] = {}
|
309
|
+
self.evaluation_runs: List[EvaluationRun] = []
|
310
|
+
self.annotations: List[TraceAnnotation] = []
|
410
311
|
self.start_time = time.time()
|
411
312
|
self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id, tracer)
|
412
313
|
self.visited_nodes = []
|
413
314
|
self.executed_tools = []
|
414
315
|
self.executed_node_tools = []
|
415
316
|
self._span_depths: Dict[str, int] = {} # NEW: To track depth of active spans
|
416
|
-
|
417
317
|
def get_current_span(self):
|
418
318
|
"""Get the current span from the context var"""
|
419
319
|
return current_span_var.get()
|
@@ -443,9 +343,7 @@ class TraceClient:
|
|
443
343
|
|
444
344
|
self._span_depths[span_id] = current_depth # Store depth by span_id
|
445
345
|
|
446
|
-
|
447
|
-
type="enter",
|
448
|
-
function=name,
|
346
|
+
span = TraceSpan(
|
449
347
|
span_id=span_id,
|
450
348
|
trace_id=self.trace_id,
|
451
349
|
depth=current_depth,
|
@@ -453,25 +351,15 @@ class TraceClient:
|
|
453
351
|
created_at=start_time,
|
454
352
|
span_type=span_type,
|
455
353
|
parent_span_id=parent_span_id,
|
354
|
+
function=name,
|
456
355
|
)
|
457
|
-
self.
|
356
|
+
self.add_span(span)
|
458
357
|
|
459
358
|
try:
|
460
359
|
yield self
|
461
360
|
finally:
|
462
361
|
duration = time.time() - start_time
|
463
|
-
|
464
|
-
self.add_entry(TraceEntry(
|
465
|
-
type="exit",
|
466
|
-
function=name,
|
467
|
-
span_id=span_id, # Use the same span_id for exit
|
468
|
-
trace_id=self.trace_id, # Use the trace_id from the trace client
|
469
|
-
depth=exit_depth,
|
470
|
-
message=f"← {name}",
|
471
|
-
created_at=time.time(),
|
472
|
-
duration=duration,
|
473
|
-
span_type=span_type,
|
474
|
-
))
|
362
|
+
span.duration = duration
|
475
363
|
# Clean up depth tracking for this span_id
|
476
364
|
if span_id in self._span_depths:
|
477
365
|
del self._span_depths[span_id]
|
@@ -528,13 +416,13 @@ class TraceClient:
|
|
528
416
|
tools_called=tools_called,
|
529
417
|
expected_tools=expected_tools,
|
530
418
|
additional_metadata=additional_metadata,
|
531
|
-
trace_id=self.trace_id
|
532
419
|
)
|
533
420
|
else:
|
534
421
|
raise ValueError("Either 'example' or at least one of the individual parameters (input, actual_output, etc.) must be provided")
|
535
422
|
|
536
423
|
# Check examples before creating evaluation run
|
537
|
-
|
424
|
+
|
425
|
+
# check_examples([example], scorers)
|
538
426
|
|
539
427
|
# --- Modification: Capture span_id immediately ---
|
540
428
|
# span_id_at_eval_call = current_span_var.get()
|
@@ -571,290 +459,60 @@ class TraceClient:
|
|
571
459
|
# --- End Modification ---
|
572
460
|
|
573
461
|
if current_span_id:
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
if entry.span_id == current_span_id and entry.type == 'enter':
|
584
|
-
function_name = entry.function
|
585
|
-
break
|
586
|
-
|
587
|
-
# Get depth for the current span
|
588
|
-
current_depth = self._span_depths.get(current_span_id, 0)
|
589
|
-
|
590
|
-
self.add_entry(TraceEntry(
|
591
|
-
type="evaluation",
|
592
|
-
function=function_name,
|
593
|
-
span_id=current_span_id, # Associate with current span
|
594
|
-
trace_id=self.trace_id, # Use the trace_id from the trace client
|
595
|
-
depth=current_depth,
|
596
|
-
message=f"Evaluation results for {function_name}",
|
597
|
-
created_at=time.time(),
|
598
|
-
evaluation_runs=[eval_run],
|
599
|
-
duration=duration,
|
600
|
-
span_type="evaluation"
|
601
|
-
))
|
602
|
-
|
462
|
+
span = self.span_id_to_span[current_span_id]
|
463
|
+
span.evaluation_runs.append(eval_run)
|
464
|
+
self.evaluation_runs.append(eval_run)
|
465
|
+
|
466
|
+
def add_annotation(self, annotation: TraceAnnotation):
|
467
|
+
"""Add an annotation to this trace context"""
|
468
|
+
self.annotations.append(annotation)
|
469
|
+
return self
|
470
|
+
|
603
471
|
def record_input(self, inputs: dict):
|
604
472
|
current_span_id = current_span_var.get()
|
605
473
|
if current_span_id:
|
606
|
-
|
607
|
-
|
608
|
-
function_name = "unknown_function" # Default
|
609
|
-
for entry in reversed(self.entries):
|
610
|
-
if entry.span_id == current_span_id and entry.type == 'enter':
|
611
|
-
entry_span_type = entry.span_type
|
612
|
-
function_name = entry.function
|
613
|
-
break
|
614
|
-
|
615
|
-
self.add_entry(TraceEntry(
|
616
|
-
type="input",
|
617
|
-
function=function_name,
|
618
|
-
span_id=current_span_id, # Use current span_id from context
|
619
|
-
trace_id=self.trace_id, # Use the trace_id from the trace client
|
620
|
-
depth=current_depth,
|
621
|
-
message=f"Inputs to {function_name}",
|
622
|
-
created_at=time.time(),
|
623
|
-
inputs=inputs,
|
624
|
-
span_type=entry_span_type,
|
625
|
-
))
|
626
|
-
# Removed else block - original didn't have one
|
474
|
+
span = self.span_id_to_span[current_span_id]
|
475
|
+
span.inputs = inputs
|
627
476
|
|
628
|
-
async def _update_coroutine_output(self,
|
477
|
+
async def _update_coroutine_output(self, span: TraceSpan, coroutine: Any):
|
629
478
|
"""Helper method to update the output of a trace entry once the coroutine completes"""
|
630
479
|
try:
|
631
480
|
result = await coroutine
|
632
|
-
|
481
|
+
span.output = result
|
633
482
|
return result
|
634
483
|
except Exception as e:
|
635
|
-
|
484
|
+
span.output = f"Error: {str(e)}"
|
636
485
|
raise
|
637
486
|
|
638
487
|
def record_output(self, output: Any):
|
639
488
|
current_span_id = current_span_var.get()
|
640
489
|
if current_span_id:
|
641
|
-
|
642
|
-
|
643
|
-
function_name = "unknown_function" # Default
|
644
|
-
for entry in reversed(self.entries):
|
645
|
-
if entry.span_id == current_span_id and entry.type == 'enter':
|
646
|
-
entry_span_type = entry.span_type
|
647
|
-
function_name = entry.function
|
648
|
-
break
|
649
|
-
|
650
|
-
entry = TraceEntry(
|
651
|
-
type="output",
|
652
|
-
function=function_name,
|
653
|
-
span_id=current_span_id, # Use current span_id from context
|
654
|
-
depth=current_depth,
|
655
|
-
message=f"Output from {function_name}",
|
656
|
-
created_at=time.time(),
|
657
|
-
output="<pending>" if inspect.iscoroutine(output) else output,
|
658
|
-
span_type=entry_span_type,
|
659
|
-
trace_id=self.trace_id # Added trace_id for consistency
|
660
|
-
)
|
661
|
-
self.add_entry(entry)
|
490
|
+
span = self.span_id_to_span[current_span_id]
|
491
|
+
span.output = "<pending>" if inspect.iscoroutine(output) else output
|
662
492
|
|
663
493
|
if inspect.iscoroutine(output):
|
664
|
-
asyncio.create_task(self._update_coroutine_output(
|
494
|
+
asyncio.create_task(self._update_coroutine_output(span, output))
|
665
495
|
|
666
|
-
return
|
496
|
+
return span # Return the created entry
|
667
497
|
# Removed else block - original didn't have one
|
668
498
|
return None # Return None if no span_id found
|
669
499
|
|
670
|
-
def
|
671
|
-
"""Add a trace
|
672
|
-
self.
|
500
|
+
def add_span(self, span: TraceSpan):
|
501
|
+
"""Add a trace span to this trace context"""
|
502
|
+
self.trace_spans.append(span)
|
503
|
+
self.span_id_to_span[span.span_id] = span
|
673
504
|
return self
|
674
505
|
|
675
506
|
def print(self):
|
676
507
|
"""Print the complete trace with proper visual structure"""
|
677
|
-
for
|
678
|
-
|
679
|
-
|
680
|
-
def print_hierarchical(self):
|
681
|
-
"""Print the trace in a hierarchical structure based on parent-child relationships"""
|
682
|
-
# First, build a map of spans
|
683
|
-
spans = {}
|
684
|
-
root_spans = []
|
685
|
-
|
686
|
-
# Collect all enter events first
|
687
|
-
for entry in self.entries:
|
688
|
-
if entry.type == "enter":
|
689
|
-
spans[entry.function] = {
|
690
|
-
"name": entry.function,
|
691
|
-
"depth": entry.depth,
|
692
|
-
"parent_id": entry.parent_span_id,
|
693
|
-
"children": []
|
694
|
-
}
|
695
|
-
|
696
|
-
# If no parent, it's a root span
|
697
|
-
if not entry.parent_span_id:
|
698
|
-
root_spans.append(entry.function)
|
699
|
-
elif entry.parent_span_id not in spans:
|
700
|
-
# If parent doesn't exist yet, temporarily treat as root
|
701
|
-
# (we'll fix this later)
|
702
|
-
root_spans.append(entry.function)
|
703
|
-
|
704
|
-
# Build parent-child relationships
|
705
|
-
for span_name, span in spans.items():
|
706
|
-
parent = span["parent_id"]
|
707
|
-
if parent and parent in spans:
|
708
|
-
spans[parent]["children"].append(span_name)
|
709
|
-
# Remove from root spans if it was temporarily there
|
710
|
-
if span_name in root_spans:
|
711
|
-
root_spans.remove(span_name)
|
712
|
-
|
713
|
-
# Now print the hierarchy
|
714
|
-
def print_span(span_name, level=0):
|
715
|
-
if span_name not in spans:
|
716
|
-
return
|
717
|
-
|
718
|
-
span = spans[span_name]
|
719
|
-
indent = " " * level
|
720
|
-
parent_info = f" (parent_id: {span['parent_id']})" if span["parent_id"] else ""
|
721
|
-
print(f"{indent}→ {span_name}{parent_info}")
|
722
|
-
|
723
|
-
# Print children
|
724
|
-
for child in span["children"]:
|
725
|
-
print_span(child, level + 1)
|
726
|
-
|
727
|
-
# Print starting with root spans
|
728
|
-
print("\nHierarchical Trace Structure:")
|
729
|
-
for root in root_spans:
|
730
|
-
print_span(root)
|
508
|
+
for span in self.trace_spans:
|
509
|
+
span.print_span()
|
731
510
|
|
732
511
|
def get_duration(self) -> float:
|
733
512
|
"""
|
734
513
|
Get the total duration of this trace
|
735
514
|
"""
|
736
515
|
return time.time() - self.start_time
|
737
|
-
|
738
|
-
def condense_trace(self, entries: List[dict]) -> List[dict]:
|
739
|
-
"""
|
740
|
-
Condenses trace entries into a single entry for each span instance,
|
741
|
-
preserving parent-child span relationships using span_id and parent_span_id.
|
742
|
-
"""
|
743
|
-
spans_by_id: Dict[str, dict] = {}
|
744
|
-
evaluation_runs: List[EvaluationRun] = []
|
745
|
-
|
746
|
-
# First pass: Group entries by span_id and gather data
|
747
|
-
for entry in entries:
|
748
|
-
span_id = entry.get("span_id")
|
749
|
-
if not span_id:
|
750
|
-
continue # Skip entries without a span_id (should not happen)
|
751
|
-
|
752
|
-
if entry["type"] == "enter":
|
753
|
-
if span_id not in spans_by_id:
|
754
|
-
spans_by_id[span_id] = {
|
755
|
-
"span_id": span_id,
|
756
|
-
"function": entry["function"],
|
757
|
-
"depth": entry["depth"], # Use the depth recorded at entry time
|
758
|
-
"created_at": entry["created_at"],
|
759
|
-
"trace_id": entry["trace_id"],
|
760
|
-
"parent_span_id": entry.get("parent_span_id"),
|
761
|
-
"span_type": entry.get("span_type", "span"),
|
762
|
-
"inputs": None,
|
763
|
-
"output": None,
|
764
|
-
"evaluation_runs": [],
|
765
|
-
"duration": None
|
766
|
-
}
|
767
|
-
# Handle potential duplicate enter events if necessary (e.g., log warning)
|
768
|
-
|
769
|
-
elif span_id in spans_by_id:
|
770
|
-
current_span_data = spans_by_id[span_id]
|
771
|
-
|
772
|
-
if entry["type"] == "input" and entry["inputs"]:
|
773
|
-
# Merge inputs if multiple are recorded, or just assign
|
774
|
-
if current_span_data["inputs"] is None:
|
775
|
-
current_span_data["inputs"] = entry["inputs"]
|
776
|
-
elif isinstance(current_span_data["inputs"], dict) and isinstance(entry["inputs"], dict):
|
777
|
-
current_span_data["inputs"].update(entry["inputs"])
|
778
|
-
# Add more sophisticated merging if needed
|
779
|
-
|
780
|
-
elif entry["type"] == "output" and "output" in entry:
|
781
|
-
current_span_data["output"] = entry["output"]
|
782
|
-
|
783
|
-
elif entry["type"] == "evaluation" and entry.get("evaluation_runs"):
|
784
|
-
if current_span_data.get("evaluation_runs") is not None:
|
785
|
-
evaluation_runs.extend(entry["evaluation_runs"])
|
786
|
-
|
787
|
-
elif entry["type"] == "exit":
|
788
|
-
if current_span_data["duration"] is None: # Calculate duration only once
|
789
|
-
start_time = datetime.fromisoformat(current_span_data.get("created_at", entry["created_at"]))
|
790
|
-
end_time = datetime.fromisoformat(entry["created_at"])
|
791
|
-
current_span_data["duration"] = (end_time - start_time).total_seconds()
|
792
|
-
# Update depth if exit depth is different (though current span() implementation keeps it same)
|
793
|
-
# current_span_data["depth"] = entry["depth"]
|
794
|
-
|
795
|
-
# Convert dictionary to a list initially for easier access
|
796
|
-
spans_list = list(spans_by_id.values())
|
797
|
-
|
798
|
-
# Build tree structure (adjacency list) and find roots
|
799
|
-
children_map: Dict[Optional[str], List[dict]] = {}
|
800
|
-
roots = []
|
801
|
-
span_map = {span['span_id']: span for span in spans_list} # Map for quick lookup
|
802
|
-
|
803
|
-
for span in spans_list:
|
804
|
-
parent_id = span.get("parent_span_id")
|
805
|
-
if parent_id is None:
|
806
|
-
roots.append(span)
|
807
|
-
else:
|
808
|
-
if parent_id not in children_map:
|
809
|
-
children_map[parent_id] = []
|
810
|
-
children_map[parent_id].append(span)
|
811
|
-
|
812
|
-
# Sort roots by timestamp
|
813
|
-
roots.sort(key=lambda x: datetime.fromisoformat(x.get("created_at", "1970-01-01T00:00:00")))
|
814
|
-
|
815
|
-
# Perform depth-first traversal to get the final sorted list
|
816
|
-
sorted_condensed_list = []
|
817
|
-
visited = set() # To handle potential cycles, though unlikely with UUIDs
|
818
|
-
|
819
|
-
def dfs(span_data):
|
820
|
-
span_id = span_data['span_id']
|
821
|
-
if span_id in visited:
|
822
|
-
return # Avoid infinite loops in case of cycles
|
823
|
-
visited.add(span_id)
|
824
|
-
|
825
|
-
sorted_condensed_list.append(span_data) # Add parent before children
|
826
|
-
|
827
|
-
# Get children, sort them by created_at, and visit them
|
828
|
-
span_children = children_map.get(span_id, [])
|
829
|
-
span_children.sort(key=lambda x: datetime.fromisoformat(x.get("created_at", "1970-01-01T00:00:00")))
|
830
|
-
for child in span_children:
|
831
|
-
# Ensure the child exists in our map before recursing
|
832
|
-
if child['span_id'] in span_map:
|
833
|
-
dfs(child)
|
834
|
-
else:
|
835
|
-
# This case might indicate an issue, but we'll add the child directly
|
836
|
-
# if its parent was processed but the child itself wasn't in the initial list?
|
837
|
-
# Or if the child's 'enter' event was missing. For robustness, add it.
|
838
|
-
if child['span_id'] not in visited:
|
839
|
-
visited.add(child['span_id'])
|
840
|
-
sorted_condensed_list.append(child)
|
841
|
-
|
842
|
-
|
843
|
-
# Start DFS from each root
|
844
|
-
for root_span in roots:
|
845
|
-
if root_span['span_id'] not in visited:
|
846
|
-
dfs(root_span)
|
847
|
-
|
848
|
-
# Handle spans that might not have been reachable from roots (orphans)
|
849
|
-
# Though ideally, all spans should descend from a root.
|
850
|
-
for span_data in spans_list:
|
851
|
-
if span_data['span_id'] not in visited:
|
852
|
-
# Decide how to handle orphans, maybe append them at the end sorted by time?
|
853
|
-
# For now, let's just add them to ensure they aren't lost.
|
854
|
-
sorted_condensed_list.append(span_data)
|
855
|
-
|
856
|
-
|
857
|
-
return sorted_condensed_list, evaluation_runs
|
858
516
|
|
859
517
|
def save(self, overwrite: bool = False) -> Tuple[str, dict]:
|
860
518
|
"""
|
@@ -863,44 +521,36 @@ class TraceClient:
|
|
863
521
|
"""
|
864
522
|
# Calculate total elapsed time
|
865
523
|
total_duration = self.get_duration()
|
866
|
-
|
867
|
-
raw_entries = [entry.to_dict() for entry in self.entries]
|
868
|
-
|
869
|
-
condensed_entries, evaluation_runs = self.condense_trace(raw_entries)
|
870
524
|
|
871
525
|
# Only count tokens for actual LLM API call spans
|
872
526
|
llm_span_names = {"OPENAI_API_CALL", "TOGETHER_API_CALL", "ANTHROPIC_API_CALL", "GOOGLE_API_CALL"}
|
873
|
-
for
|
874
|
-
|
527
|
+
for span in self.trace_spans:
|
528
|
+
span_function_name = span.function # Get function name safely
|
875
529
|
# Check if it's an LLM span AND function name CONTAINS an API call suffix AND output is dict
|
876
|
-
|
877
|
-
has_api_suffix = any(suffix in
|
878
|
-
output_is_dict = isinstance(
|
530
|
+
is_llm_span = span.span_type == "llm"
|
531
|
+
has_api_suffix = any(suffix in span_function_name for suffix in llm_span_names)
|
532
|
+
output_is_dict = isinstance(span.output, dict)
|
879
533
|
|
880
534
|
# --- DEBUG PRINT 1: Check if condition passes ---
|
881
535
|
# if is_llm_entry and has_api_suffix and output_is_dict:
|
882
|
-
# # print(f"[DEBUG TraceClient.save] Processing entry: {entry.get('span_id')} ({entry_function_name}) - Condition PASSED")
|
883
536
|
# elif is_llm_entry:
|
884
537
|
# # Print why it failed if it was an LLM entry
|
885
|
-
# print(f"[DEBUG TraceClient.save] Skipping LLM entry: {entry.get('span_id')} ({entry_function_name}) - Suffix Match: {has_api_suffix}, Output is Dict: {output_is_dict}")
|
886
538
|
# # --- END DEBUG ---
|
887
539
|
|
888
|
-
if
|
889
|
-
output =
|
540
|
+
if is_llm_span and has_api_suffix and output_is_dict:
|
541
|
+
output = span.output
|
890
542
|
usage = output.get("usage", {}) # Gets the 'usage' dict from the 'output' field
|
891
543
|
|
892
544
|
# --- DEBUG PRINT 2: Check extracted usage ---
|
893
|
-
# print(f"[DEBUG TraceClient.save] Extracted usage dict: {usage}")
|
894
545
|
# --- END DEBUG ---
|
895
546
|
|
896
547
|
# --- NEW: Extract model_name correctly from nested inputs ---
|
897
548
|
model_name = None
|
898
|
-
|
899
|
-
|
900
|
-
if entry_inputs:
|
549
|
+
span_inputs = span.inputs
|
550
|
+
if span_inputs:
|
901
551
|
# Try common locations for model name within the inputs structure
|
902
|
-
invocation_params =
|
903
|
-
serialized_data =
|
552
|
+
invocation_params = span_inputs.get("invocation_params", {})
|
553
|
+
serialized_data = span_inputs.get("serialized", {})
|
904
554
|
|
905
555
|
# Look in invocation_params (often directly contains model)
|
906
556
|
if isinstance(invocation_params, dict):
|
@@ -920,10 +570,9 @@ class TraceClient:
|
|
920
570
|
|
921
571
|
# Fallback: Check top-level of inputs itself (less likely for callbacks)
|
922
572
|
if not model_name:
|
923
|
-
model_name =
|
573
|
+
model_name = span_inputs.get("model")
|
924
574
|
|
925
575
|
|
926
|
-
# print(f"[DEBUG TraceClient.save] Determined model_name: {model_name}") # DEBUG Model Name
|
927
576
|
# --- END NEW ---
|
928
577
|
|
929
578
|
prompt_tokens = 0
|
@@ -985,7 +634,7 @@ class TraceClient:
|
|
985
634
|
if "usage" not in output:
|
986
635
|
output["usage"] = {} # Initialize if missing
|
987
636
|
elif not isinstance(output["usage"], dict): # Handle cases where 'usage' might not be a dict (e.g., placeholder string)
|
988
|
-
print(f"[WARN TraceClient.save] Output 'usage' for span {
|
637
|
+
print(f"[WARN TraceClient.save] Output 'usage' for span {span.span_id} was not a dict ({type(output['usage'])}). Resetting before adding costs.")
|
989
638
|
output["usage"] = {} # Reset to dict
|
990
639
|
|
991
640
|
output["usage"]["prompt_tokens_cost_usd"] = prompt_cost
|
@@ -993,10 +642,10 @@ class TraceClient:
|
|
993
642
|
output["usage"]["total_cost_usd"] = prompt_cost + completion_cost
|
994
643
|
except Exception as e:
|
995
644
|
# If cost calculation fails, continue without adding costs
|
996
|
-
print(f"Error calculating cost for model '{model_name}' (span: {
|
645
|
+
print(f"Error calculating cost for model '{model_name}' (span: {span.span_id}): {str(e)}")
|
997
646
|
pass
|
998
647
|
else:
|
999
|
-
print(f"[WARN TraceClient.save] Could not determine model name for cost calculation (span: {
|
648
|
+
print(f"[WARN TraceClient.save] Could not determine model name for cost calculation (span: {span.span_id}). Inputs: {span_inputs}")
|
1000
649
|
|
1001
650
|
|
1002
651
|
# Create trace document - Always use standard keys for top-level counts
|
@@ -1006,8 +655,8 @@ class TraceClient:
|
|
1006
655
|
"project_name": self.project_name,
|
1007
656
|
"created_at": datetime.utcfromtimestamp(self.start_time).isoformat(),
|
1008
657
|
"duration": total_duration,
|
1009
|
-
"entries":
|
1010
|
-
"evaluation_runs": evaluation_runs,
|
658
|
+
"entries": [span.model_dump() for span in self.trace_spans],
|
659
|
+
"evaluation_runs": [run.model_dump() for run in self.evaluation_runs],
|
1011
660
|
"overwrite": overwrite,
|
1012
661
|
"parent_trace_id": self.parent_trace_id,
|
1013
662
|
"parent_name": self.parent_name
|
@@ -1015,11 +664,248 @@ class TraceClient:
|
|
1015
664
|
# --- Log trace data before saving ---
|
1016
665
|
self.trace_manager_client.save_trace(trace_data)
|
1017
666
|
|
667
|
+
# upload annotations
|
668
|
+
# TODO: batch to the log endpoint
|
669
|
+
for annotation in self.annotations:
|
670
|
+
self.trace_manager_client.save_annotation(annotation)
|
671
|
+
|
1018
672
|
return self.trace_id, trace_data
|
1019
673
|
|
1020
674
|
def delete(self):
    """Delete this trace through the trace manager client.

    Returns:
        Whatever ``trace_manager_client.delete_trace`` returns when called
        with this object's ``trace_id``.
    """
    manager = self.trace_manager_client
    outcome = manager.delete_trace(self.trace_id)
    return outcome
|
1022
676
|
|
677
|
+
|
678
|
+
class _DeepTracer:
    """Singleton trace hook that records a span for each traced call.

    Used as a re-entrant context manager: the first ``__enter__`` installs
    ``_trace`` through ``sys.settrace`` / ``threading.settrace`` and the
    ``__exit__`` that brings ``_refcount`` back to zero removes it again.
    """

    _instance: Optional["_DeepTracer"] = None
    _lock: threading.Lock = threading.Lock()
    _refcount: int = 0
    # The defaults are None rather than a list literal: a list default is one
    # shared object that every context without its own value would mutate in
    # place. Readers below substitute a fresh list when the value is unset.
    _span_stack: contextvars.ContextVar[Optional[List[Dict[str, Any]]]] = contextvars.ContextVar("_deep_profiler_span_stack", default=None)
    _skip_stack: contextvars.ContextVar[Optional[List[str]]] = contextvars.ContextVar("_deep_profiler_skip_stack", default=None)

    def _get_qual_name(self, frame) -> str:
        """Return ``"<module>.<name>"`` for the code running in ``frame``.

        The function object is looked up by name in the frame's globals; when
        it is found and exposes ``__qualname__`` that is used, otherwise the
        bare code-object name is used. Always returns a string.
        """
        func_name = frame.f_code.co_name
        module_name = frame.f_globals.get("__name__", "unknown_module")

        try:
            func = frame.f_globals.get(func_name)
            if func is not None and hasattr(func, "__qualname__"):
                return f"{module_name}.{func.__qualname__}"
        except Exception:
            # Best effort only: fall through to the plain name below.
            pass

        # Also reached when the global of that name has no __qualname__
        # (previously this path fell off the end and returned None).
        return f"{module_name}.{func_name}"

    def __new__(cls):
        """Create the single shared instance on first use (guarded by ``_lock``)."""
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
        return cls._instance

    def _should_trace(self, frame):
        """Return True when a span should be recorded for ``frame``."""
        # Skip stack is maintained by the tracer as an optimization to skip earlier
        # frames in the call stack that we've already determined should be skipped
        if self._skip_stack.get():
            return False

        func_name = frame.f_code.co_name
        module_name = frame.f_globals.get("__name__", None)

        # Functions carrying these attributes are not traced here.
        # NOTE(review): presumably they are set by the observe decorator,
        # which records its own span - confirm.
        func = frame.f_globals.get(func_name)
        if func and (hasattr(func, '_judgment_span_name') or hasattr(func, '_judgment_span_type')):
            return False

        if (
            not module_name
            or func_name.startswith("<")  # ex: <listcomp>
            or (func_name.startswith("__") and func_name != "__call__")  # dunders
            or not self._is_user_code(frame.f_code.co_filename)
        ):
            return False

        return True

    @functools.cache
    def _is_user_code(self, filename: str):
        """Return True when ``filename`` is a real file outside the blocklisted paths."""
        return bool(filename) and not filename.startswith("<") and not os.path.realpath(filename).startswith(_TRACE_FILEPATH_BLOCKLIST)

    def _trace(self, frame: types.FrameType, event: str, arg: Any):
        """Trace function installed with ``sys.settrace``.

        On ``call`` a new ``TraceSpan`` is added to the current trace and the
        call's named arguments are recorded as its input; on ``return`` the
        span's duration and (non-None) return value are recorded; on
        ``exception`` the formatted exception is recorded as output.

        Returns ``self._trace`` to keep tracing the frame, or ``None`` on the
        early-exit paths.
        """
        # Only call/return/exception events are used below.
        frame.f_trace_lines = False
        frame.f_trace_opcodes = False

        if not self._should_trace(frame):
            return

        if event not in ("call", "return", "exception"):
            return

        current_trace = current_trace_var.get()
        if not current_trace:
            return

        parent_span_id = current_span_var.get()
        if not parent_span_id:
            return

        qual_name = self._get_qual_name(frame)
        skip_stack = self._skip_stack.get() or []

        # NOTE(review): _should_trace() already returned True above, and it
        # returns False whenever the skip stack is non-empty, so the
        # skip-stack branches below look unreachable - confirm before relying
        # on them.
        if event == "call":
            # If we have entries in the skip stack and the current qual_name matches the top entry,
            # push it again to track nesting depth and skip
            # As an optimization, we only care about duplicate qual_names.
            if skip_stack:
                if qual_name == skip_stack[-1]:
                    skip_stack.append(qual_name)
                    self._skip_stack.set(skip_stack)
                return

            should_trace = self._should_trace(frame)

            if not should_trace:
                if not skip_stack:
                    self._skip_stack.set([qual_name])
                return
        elif event == "return":
            # If we have entries in skip stack and current qual_name matches the top entry,
            # pop it to track exiting from the skipped section
            if skip_stack and qual_name == skip_stack[-1]:
                skip_stack.pop()
                self._skip_stack.set(skip_stack)
                return

        if skip_stack:
            return

        span_stack = self._span_stack.get() or []
        if event == "call":
            if not self._should_trace(frame):
                return

            span_id = str(uuid.uuid4())

            parent_depth = current_trace._span_depths.get(parent_span_id, 0)
            depth = parent_depth + 1

            current_trace._span_depths[span_id] = depth

            start_time = time.time()

            span_stack.append({
                "span_id": span_id,
                "parent_span_id": parent_span_id,
                "function": qual_name,
                "start_time": start_time
            })
            self._span_stack.set(span_stack)

            token = current_span_var.set(span_id)
            frame.f_locals["_judgment_span_token"] = token

            span = TraceSpan(
                span_id=span_id,
                trace_id=current_trace.trace_id,
                depth=depth,
                message=qual_name,
                created_at=start_time,
                span_type="span",
                parent_span_id=parent_span_id,
                function=qual_name
            )
            current_trace.add_span(span)

            inputs = {}
            try:
                args_info = inspect.getargvalues(frame)
                for arg_name in args_info.args:
                    try:
                        inputs[arg_name] = args_info.locals.get(arg_name)
                    except Exception:
                        inputs[arg_name] = "<<Unserializable>>"
                current_trace.record_input(inputs)
            except Exception as e:
                current_trace.record_input({
                    "error": str(e)
                })

        elif event == "return":
            if not span_stack:
                return

            current_id = current_span_var.get()

            # Find the innermost stack entry that belongs to the current span.
            span_data = None
            for i, entry in enumerate(reversed(span_stack)):
                if entry["span_id"] == current_id:
                    span_data = span_stack.pop(-(i+1))
                    self._span_stack.set(span_stack)
                    break

            if not span_data:
                return

            start_time = span_data["start_time"]
            duration = time.time() - start_time

            current_trace.span_id_to_span[span_data["span_id"]].duration = duration

            if arg is not None:
                # exception handling will take priority.
                current_trace.record_output(arg)

            if span_data["span_id"] in current_trace._span_depths:
                del current_trace._span_depths[span_data["span_id"]]

            if span_stack:
                current_span_var.set(span_stack[-1]["span_id"])
            else:
                current_span_var.set(span_data["parent_span_id"])

            if "_judgment_span_token" in frame.f_locals:
                current_span_var.reset(frame.f_locals["_judgment_span_token"])

        elif event == "exception":
            exc_type, exc_value, exc_traceback = arg
            formatted_exception = {
                "type": exc_type.__name__,
                "message": str(exc_value),
                "traceback": traceback.format_tb(exc_traceback)
            }
            current_trace.record_output({
                "error": formatted_exception
            })

        return self._trace

    def __enter__(self):
        """Install the trace hook when this is the first active user."""
        with self._lock:
            self._refcount += 1
            if self._refcount == 1:
                self._skip_stack.set([])
                self._span_stack.set([])
                sys.settrace(self._trace)
                threading.settrace(self._trace)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Remove the trace hook when the last active user leaves."""
        with self._lock:
            self._refcount -= 1
            if self._refcount == 0:
                sys.settrace(None)
                threading.settrace(None)

    def log(self, message: str, level: str = "info"):
        """ Log a message with the span context """
        current_trace = current_trace_var.get()
        if current_trace:
            current_trace.log(message, level)
            current_trace.record_output({"log": message})
        else:
            # No active trace: there is nothing to record the output on, so
            # only print (previously this path dereferenced None).
            print(f"[{level}] {message}")
|
908
|
+
|
1023
909
|
class Tracer:
|
1024
910
|
_instance = None
|
1025
911
|
|
@@ -1042,12 +928,16 @@ class Tracer:
|
|
1042
928
|
s3_aws_access_key_id: Optional[str] = None,
|
1043
929
|
s3_aws_secret_access_key: Optional[str] = None,
|
1044
930
|
s3_region_name: Optional[str] = None,
|
1045
|
-
deep_tracing: bool = True #
|
931
|
+
deep_tracing: bool = True # Deep tracing is enabled by default
|
1046
932
|
):
|
1047
933
|
if not hasattr(self, 'initialized'):
|
1048
934
|
if not api_key:
|
1049
935
|
raise ValueError("Tracer must be configured with a Judgment API key")
|
1050
936
|
|
937
|
+
result, response = validate_api_key(api_key)
|
938
|
+
if not result:
|
939
|
+
raise JudgmentAPIError(f"Issue with passed in Judgment API key: {response}")
|
940
|
+
|
1051
941
|
if not organization_id:
|
1052
942
|
raise ValueError("Tracer must be configured with an Organization ID")
|
1053
943
|
if use_s3 and not s3_bucket_name:
|
@@ -1059,11 +949,11 @@ class Tracer:
|
|
1059
949
|
|
1060
950
|
self.api_key: str = api_key
|
1061
951
|
self.project_name: str = project_name
|
1062
|
-
self.client: JudgmentClient = JudgmentClient(judgment_api_key=api_key)
|
1063
952
|
self.organization_id: str = organization_id
|
1064
953
|
self._current_trace: Optional[str] = None
|
1065
954
|
self._active_trace_client: Optional[TraceClient] = None # Add active trace client attribute
|
1066
955
|
self.rules: List[Rule] = rules or [] # Store rules at tracer level
|
956
|
+
self.traces: List[Trace] = []
|
1067
957
|
self.initialized: bool = True
|
1068
958
|
self.enable_monitoring: bool = enable_monitoring
|
1069
959
|
self.enable_evaluations: bool = enable_evaluations
|
@@ -1119,45 +1009,6 @@ class Tracer:
|
|
1119
1009
|
"""Returns the TraceClient instance currently marked as active by the handler."""
|
1120
1010
|
return self._active_trace_client
|
1121
1011
|
|
1122
|
-
def _apply_deep_tracing(self, func, span_type="span"):
|
1123
|
-
"""
|
1124
|
-
Apply deep tracing to all functions in the same module as the given function.
|
1125
|
-
|
1126
|
-
Args:
|
1127
|
-
func: The function being traced
|
1128
|
-
span_type: Type of span to use for traced functions
|
1129
|
-
|
1130
|
-
Returns:
|
1131
|
-
A tuple of (module, original_functions_dict) where original_functions_dict
|
1132
|
-
contains the original functions that were replaced with traced versions.
|
1133
|
-
"""
|
1134
|
-
module = inspect.getmodule(func)
|
1135
|
-
if not module:
|
1136
|
-
return None, {}
|
1137
|
-
|
1138
|
-
# Save original functions
|
1139
|
-
original_functions = {}
|
1140
|
-
|
1141
|
-
# Find all functions in the module
|
1142
|
-
for name, obj in inspect.getmembers(module, inspect.isfunction):
|
1143
|
-
# Skip already wrapped functions
|
1144
|
-
if hasattr(obj, '_judgment_traced'):
|
1145
|
-
continue
|
1146
|
-
|
1147
|
-
# Create a traced version of the function
|
1148
|
-
# Always use default span type "span" for child functions
|
1149
|
-
traced_func = _create_deep_tracing_wrapper(obj, self, "span")
|
1150
|
-
|
1151
|
-
# Mark the function as traced to avoid double wrapping
|
1152
|
-
traced_func._judgment_traced = True
|
1153
|
-
|
1154
|
-
# Save the original function
|
1155
|
-
original_functions[name] = obj
|
1156
|
-
|
1157
|
-
# Replace with traced version
|
1158
|
-
setattr(module, name, traced_func)
|
1159
|
-
|
1160
|
-
return module, original_functions
|
1161
1012
|
|
1162
1013
|
@contextmanager
|
1163
1014
|
def trace(
|
@@ -1204,6 +1055,23 @@ class Tracer:
|
|
1204
1055
|
finally:
|
1205
1056
|
# Reset the context variable
|
1206
1057
|
current_trace_var.reset(token)
|
1058
|
+
|
1059
|
+
|
1060
|
+
def log(self, msg: str, label: str = "log", score: int = 1):
    """Log a message with the current span context.

    When a span and a trace are both active, ``msg`` is attached to the
    trace as a ``TraceAnnotation`` on the current span. The message is
    always echoed to the console through ``rprint``.

    Args:
        msg: Text of the annotation.
        label: Label stored on the annotation and shown in the console output.
        score: Score stored on the annotation.
    """
    current_span_id = current_span_var.get()
    current_trace = current_trace_var.get()
    # Require the trace as well as the span id: the annotation is stored on
    # the trace object, so a missing trace must not be dereferenced.
    if current_span_id and current_trace:
        annotation = TraceAnnotation(
            span_id=current_span_id,
            text=msg,
            label=label,
            score=score
        )

        current_trace.add_annotation(annotation)

    rprint(f"[bold]{label}:[/bold] {msg}")
|
1207
1075
|
|
1208
1076
|
def observe(self, func=None, *, name=None, span_type: SpanType = "span", project_name: str = None, overwrite: bool = False, deep_tracing: bool = None):
|
1209
1077
|
"""
|
@@ -1239,13 +1107,6 @@ class Tracer:
|
|
1239
1107
|
if asyncio.iscoroutinefunction(func):
|
1240
1108
|
@functools.wraps(func)
|
1241
1109
|
async def async_wrapper(*args, **kwargs):
|
1242
|
-
# Check if we're already in a traced function
|
1243
|
-
if in_traced_function_var.get():
|
1244
|
-
return await func(*args, **kwargs)
|
1245
|
-
|
1246
|
-
# Set in_traced_function_var to True
|
1247
|
-
token = in_traced_function_var.set(True)
|
1248
|
-
|
1249
1110
|
# Get current trace from context
|
1250
1111
|
current_trace = current_trace_var.get()
|
1251
1112
|
|
@@ -1275,81 +1136,47 @@ class Tracer:
|
|
1275
1136
|
# This sets the current_span_var
|
1276
1137
|
with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
|
1277
1138
|
# Record inputs
|
1278
|
-
|
1279
|
-
|
1280
|
-
'kwargs': kwargs
|
1281
|
-
})
|
1139
|
+
inputs = combine_args_kwargs(func, args, kwargs)
|
1140
|
+
span.record_input(inputs)
|
1282
1141
|
|
1283
|
-
# If deep tracing is enabled, apply monkey patching
|
1284
1142
|
if use_deep_tracing:
|
1285
|
-
|
1286
|
-
|
1287
|
-
|
1288
|
-
|
1289
|
-
|
1290
|
-
# Restore original functions if deep tracing was enabled
|
1291
|
-
if use_deep_tracing and module and 'original_functions' in locals():
|
1292
|
-
for name, obj in original_functions.items():
|
1293
|
-
setattr(module, name, obj)
|
1294
|
-
|
1143
|
+
with _DeepTracer():
|
1144
|
+
result = await func(*args, **kwargs)
|
1145
|
+
else:
|
1146
|
+
result = await func(*args, **kwargs)
|
1147
|
+
|
1295
1148
|
# Record output
|
1296
1149
|
span.record_output(result)
|
1297
|
-
|
1298
|
-
# Save the completed trace
|
1299
|
-
current_trace.save(overwrite=overwrite)
|
1300
1150
|
return result
|
1301
1151
|
finally:
|
1152
|
+
# Save the completed trace
|
1153
|
+
trace_id, trace = current_trace.save(overwrite=overwrite)
|
1154
|
+
self.traces.append(trace)
|
1155
|
+
|
1302
1156
|
# Reset trace context (span context resets automatically)
|
1303
1157
|
current_trace_var.reset(trace_token)
|
1304
|
-
# Reset in_traced_function_var
|
1305
|
-
in_traced_function_var.reset(token)
|
1306
1158
|
else:
|
1307
|
-
|
1308
|
-
|
1309
|
-
|
1310
|
-
|
1311
|
-
|
1312
|
-
|
1313
|
-
|
1314
|
-
|
1315
|
-
'kwargs': kwargs
|
1316
|
-
})
|
1317
|
-
|
1318
|
-
# If deep tracing is enabled, apply monkey patching
|
1319
|
-
if use_deep_tracing:
|
1320
|
-
module, original_functions = self._apply_deep_tracing(func, span_type)
|
1321
|
-
|
1322
|
-
# Execute function
|
1159
|
+
with current_trace.span(span_name, span_type=span_type) as span:
|
1160
|
+
inputs = combine_args_kwargs(func, args, kwargs)
|
1161
|
+
span.record_input(inputs)
|
1162
|
+
|
1163
|
+
if use_deep_tracing:
|
1164
|
+
with _DeepTracer():
|
1165
|
+
result = await func(*args, **kwargs)
|
1166
|
+
else:
|
1323
1167
|
result = await func(*args, **kwargs)
|
1324
1168
|
|
1325
|
-
|
1326
|
-
|
1327
|
-
|
1328
|
-
setattr(module, name, obj)
|
1329
|
-
|
1330
|
-
# Record output
|
1331
|
-
span.record_output(result)
|
1332
|
-
|
1333
|
-
return result
|
1334
|
-
finally:
|
1335
|
-
# Reset in_traced_function_var
|
1336
|
-
in_traced_function_var.reset(token)
|
1337
|
-
|
1169
|
+
span.record_output(result)
|
1170
|
+
return result
|
1171
|
+
|
1338
1172
|
return async_wrapper
|
1339
1173
|
else:
|
1340
1174
|
# Non-async function implementation with deep tracing
|
1341
1175
|
@functools.wraps(func)
|
1342
|
-
def wrapper(*args, **kwargs):
|
1343
|
-
# Check if we're already in a traced function
|
1344
|
-
if in_traced_function_var.get():
|
1345
|
-
return func(*args, **kwargs)
|
1346
|
-
|
1347
|
-
# Set in_traced_function_var to True
|
1348
|
-
token = in_traced_function_var.set(True)
|
1349
|
-
|
1176
|
+
def wrapper(*args, **kwargs):
|
1350
1177
|
# Get current trace from context
|
1351
1178
|
current_trace = current_trace_var.get()
|
1352
|
-
|
1179
|
+
|
1353
1180
|
# If there's no current trace, create a root trace
|
1354
1181
|
if not current_trace:
|
1355
1182
|
trace_id = str(uuid.uuid4())
|
@@ -1376,66 +1203,40 @@ class Tracer:
|
|
1376
1203
|
# This sets the current_span_var
|
1377
1204
|
with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
|
1378
1205
|
# Record inputs
|
1379
|
-
|
1380
|
-
|
1381
|
-
'kwargs': kwargs
|
1382
|
-
})
|
1206
|
+
inputs = combine_args_kwargs(func, args, kwargs)
|
1207
|
+
span.record_input(inputs)
|
1383
1208
|
|
1384
|
-
# If deep tracing is enabled, apply monkey patching
|
1385
1209
|
if use_deep_tracing:
|
1386
|
-
|
1387
|
-
|
1388
|
-
|
1389
|
-
|
1390
|
-
|
1391
|
-
# Restore original functions if deep tracing was enabled
|
1392
|
-
if use_deep_tracing and module and 'original_functions' in locals():
|
1393
|
-
for name, obj in original_functions.items():
|
1394
|
-
setattr(module, name, obj)
|
1210
|
+
with _DeepTracer():
|
1211
|
+
result = func(*args, **kwargs)
|
1212
|
+
else:
|
1213
|
+
result = func(*args, **kwargs)
|
1395
1214
|
|
1396
1215
|
# Record output
|
1397
1216
|
span.record_output(result)
|
1398
|
-
|
1399
|
-
# Save the completed trace
|
1400
|
-
current_trace.save(overwrite=overwrite)
|
1401
1217
|
return result
|
1402
1218
|
finally:
|
1219
|
+
# Save the completed trace
|
1220
|
+
trace_id, trace = current_trace.save(overwrite=overwrite)
|
1221
|
+
self.traces.append(trace)
|
1222
|
+
|
1403
1223
|
# Reset trace context (span context resets automatically)
|
1404
1224
|
current_trace_var.reset(trace_token)
|
1405
|
-
# Reset in_traced_function_var
|
1406
|
-
in_traced_function_var.reset(token)
|
1407
1225
|
else:
|
1408
|
-
|
1409
|
-
|
1410
|
-
|
1411
|
-
|
1412
|
-
|
1413
|
-
|
1414
|
-
|
1415
|
-
|
1416
|
-
|
1417
|
-
})
|
1418
|
-
|
1419
|
-
# If deep tracing is enabled, apply monkey patching
|
1420
|
-
if use_deep_tracing:
|
1421
|
-
module, original_functions = self._apply_deep_tracing(func, span_type)
|
1422
|
-
|
1423
|
-
# Execute function
|
1226
|
+
with current_trace.span(span_name, span_type=span_type) as span:
|
1227
|
+
|
1228
|
+
inputs = combine_args_kwargs(func, args, kwargs)
|
1229
|
+
span.record_input(inputs)
|
1230
|
+
|
1231
|
+
if use_deep_tracing:
|
1232
|
+
with _DeepTracer():
|
1233
|
+
result = func(*args, **kwargs)
|
1234
|
+
else:
|
1424
1235
|
result = func(*args, **kwargs)
|
1425
1236
|
|
1426
|
-
|
1427
|
-
|
1428
|
-
|
1429
|
-
setattr(module, name, obj)
|
1430
|
-
|
1431
|
-
# Record output
|
1432
|
-
span.record_output(result)
|
1433
|
-
|
1434
|
-
return result
|
1435
|
-
finally:
|
1436
|
-
# Reset in_traced_function_var
|
1437
|
-
in_traced_function_var.reset(token)
|
1438
|
-
|
1237
|
+
span.record_output(result)
|
1238
|
+
return result
|
1239
|
+
|
1439
1240
|
return wrapper
|
1440
1241
|
|
1441
1242
|
def async_evaluate(self, *args, **kwargs):
|
@@ -1469,7 +1270,7 @@ def wrap(client: Any) -> Any:
|
|
1469
1270
|
Supports OpenAI, Together, Anthropic, and Google GenAI clients.
|
1470
1271
|
Patches both '.create' and Anthropic's '.stream' methods using a wrapper class.
|
1471
1272
|
"""
|
1472
|
-
span_name, original_create, original_stream = _get_client_config(client)
|
1273
|
+
span_name, original_create, responses_create, original_stream = _get_client_config(client)
|
1473
1274
|
|
1474
1275
|
# --- Define Traced Async Functions ---
|
1475
1276
|
async def traced_create_async(*args, **kwargs):
|
@@ -1567,7 +1368,41 @@ def wrap(client: Any) -> Any:
|
|
1567
1368
|
span.record_output(output_data)
|
1568
1369
|
return response_or_iterator
|
1569
1370
|
|
1371
|
+
# --- Define Traced Sync Functions ---
|
1372
|
+
def traced_response_create_sync(*args, **kwargs):
    """Traced replacement for the client's synchronous ``responses.create``.

    Without an active trace the call is forwarded untouched. Otherwise the
    call runs inside an ``"llm"`` span that records the keyword arguments as
    input and either the formatted response or, for streaming calls, a
    wrapped iterator that fills in the output once the stream is consumed.
    Exceptions from the underlying call are recorded on the span and re-raised.
    """
    current_trace = current_trace_var.get()
    if not current_trace:
        return responses_create(*args, **kwargs)

    is_streaming = kwargs.get("stream", False)
    with current_trace.span(span_name, span_type="llm") as span:
        span.record_input(kwargs)

        # Warn about token counting limitations with streaming
        if isinstance(client, (AsyncOpenAI, OpenAI)) and is_streaming:
            # `or {}` also covers an explicit stream_options=None, which
            # would otherwise raise AttributeError on `.get`.
            stream_options = kwargs.get("stream_options") or {}
            if not stream_options.get("include_usage"):
                warnings.warn(
                    "OpenAI streaming calls don't include token counts by default. "
                    "To enable token counting with streams, set stream_options={'include_usage': True} "
                    "in your API call arguments.",
                    UserWarning
                )

        try:
            response_or_iterator = responses_create(*args, **kwargs)
        except Exception as e:
            print(f"Error during wrapped sync API call ({span_name}): {e}")
            span.record_output({"error": str(e)})
            raise
        if is_streaming:
            # Placeholder output; the stream wrapper replaces it when the
            # stream has been consumed.
            output_entry = span.record_output("<pending stream>")
            return _sync_stream_wrapper(response_or_iterator, client, output_entry)
        else:
            output_data = _format_response_output_data(client, response_or_iterator)
            span.record_output(output_data)
            return response_or_iterator
|
1405
|
+
|
1571
1406
|
# Function replacing sync .stream()
|
1572
1407
|
def traced_stream_sync(*args, **kwargs):
|
1573
1408
|
current_trace = current_trace_var.get()
|
@@ -1615,15 +1450,16 @@ def wrap(client: Any) -> Any:
|
|
1615
1450
|
if original_stream:
|
1616
1451
|
client.messages.stream = traced_stream_async
|
1617
1452
|
elif isinstance(client, genai.client.AsyncClient):
|
1618
|
-
client.generate_content = traced_create_async
|
1453
|
+
client.models.generate_content = traced_create_async
|
1619
1454
|
elif isinstance(client, (OpenAI, Together)):
|
1620
1455
|
client.chat.completions.create = traced_create_sync
|
1456
|
+
client.responses.create = traced_response_create_sync
|
1621
1457
|
elif isinstance(client, Anthropic):
|
1622
1458
|
client.messages.create = traced_create_sync
|
1623
1459
|
if original_stream:
|
1624
1460
|
client.messages.stream = traced_stream_sync
|
1625
1461
|
elif isinstance(client, genai.Client):
|
1626
|
-
client.generate_content = traced_create_sync
|
1462
|
+
client.models.generate_content = traced_create_sync
|
1627
1463
|
|
1628
1464
|
return client
|
1629
1465
|
|
@@ -1639,19 +1475,20 @@ def _get_client_config(client: ApiClient) -> tuple[str, callable, Optional[calla
|
|
1639
1475
|
tuple: (span_name, create_method, responses_method, stream_method)
|
1640
1476
|
- span_name: String identifier for tracing
|
1641
1477
|
- create_method: Reference to the client's creation method
|
1478
|
+
- responses_method: Reference to the client's responses method (if applicable)
|
1642
1479
|
- stream_method: Reference to the client's stream method (if applicable)
|
1643
1480
|
|
1644
1481
|
Raises:
|
1645
1482
|
ValueError: If client type is not supported
|
1646
1483
|
"""
|
1647
1484
|
if isinstance(client, (OpenAI, AsyncOpenAI)):
|
1648
|
-
return "OPENAI_API_CALL", client.chat.completions.create, None
|
1485
|
+
return "OPENAI_API_CALL", client.chat.completions.create, client.responses.create, None
|
1649
1486
|
elif isinstance(client, (Together, AsyncTogether)):
|
1650
|
-
return "TOGETHER_API_CALL", client.chat.completions.create, None
|
1487
|
+
return "TOGETHER_API_CALL", client.chat.completions.create, None, None
|
1651
1488
|
elif isinstance(client, (Anthropic, AsyncAnthropic)):
|
1652
|
-
return "ANTHROPIC_API_CALL", client.messages.create, client.messages.stream
|
1489
|
+
return "ANTHROPIC_API_CALL", client.messages.create, None, client.messages.stream
|
1653
1490
|
elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
|
1654
|
-
return "GOOGLE_API_CALL", client.models.generate_content, None
|
1491
|
+
return "GOOGLE_API_CALL", client.models.generate_content, None, None
|
1655
1492
|
raise ValueError(f"Unsupported client type: {type(client)}")
|
1656
1493
|
|
1657
1494
|
def _format_input_data(client: ApiClient, **kwargs) -> dict:
|
@@ -1677,6 +1514,26 @@ def _format_input_data(client: ApiClient, **kwargs) -> dict:
|
|
1677
1514
|
"max_tokens": kwargs.get("max_tokens")
|
1678
1515
|
}
|
1679
1516
|
|
1517
|
+
def _format_response_output_data(client: ApiClient, response: Any) -> dict:
    """Format a ``responses.create`` result based on client type.

    Normalizes the response into a ``{"content", "usage"}`` structure for
    tracing purposes.

    Args:
        client: The wrapped API client; only OpenAI and Together clients
            (sync or async) are handled.
        response: The object returned by the client's ``responses.create``.

    Returns:
        A dict with the response ``output`` under ``"content"`` and token
        counts under ``"usage"``. For any other client type a warning is
        issued and an empty dict is returned.
    """
    if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
        # usage may be absent; tracing must not crash the caller's API call
        # over it, so fall back to a placeholder like the stream wrappers do.
        usage = getattr(response, "usage", None)
        if usage is None:
            usage_data = {"info": "Usage data not available in response."}
        else:
            usage_data = {
                "prompt_tokens": usage.input_tokens,
                "completion_tokens": usage.output_tokens,
                "total_tokens": usage.total_tokens
            }
        return {
            "content": response.output,
            "usage": usage_data
        }
    else:
        warnings.warn(f"Unsupported client type: {type(client)}")
        return {}
|
1535
|
+
|
1536
|
+
|
1680
1537
|
def _format_output_data(client: ApiClient, response: Any) -> dict:
|
1681
1538
|
"""Format API response data based on client type.
|
1682
1539
|
|
@@ -1716,117 +1573,51 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
|
|
1716
1573
|
}
|
1717
1574
|
}
|
1718
1575
|
|
1719
|
-
|
1720
|
-
# These are typically utility functions, print statements, logging, etc.
|
1721
|
-
_TRACE_BLOCKLIST = {
|
1722
|
-
# Built-in functions
|
1723
|
-
'print', 'str', 'int', 'float', 'bool', 'list', 'dict', 'set', 'tuple',
|
1724
|
-
'len', 'range', 'enumerate', 'zip', 'map', 'filter', 'sorted', 'reversed',
|
1725
|
-
'min', 'max', 'sum', 'any', 'all', 'abs', 'round', 'format',
|
1726
|
-
# Logging functions
|
1727
|
-
'debug', 'info', 'warning', 'error', 'critical', 'exception', 'log',
|
1728
|
-
# Common utility functions
|
1729
|
-
'sleep', 'time', 'datetime', 'json', 'dumps', 'loads',
|
1730
|
-
# String operations
|
1731
|
-
'join', 'split', 'strip', 'lstrip', 'rstrip', 'replace', 'lower', 'upper',
|
1732
|
-
# Dict operations
|
1733
|
-
'get', 'items', 'keys', 'values', 'update',
|
1734
|
-
# List operations
|
1735
|
-
'append', 'extend', 'insert', 'remove', 'pop', 'clear', 'index', 'count', 'sort',
|
1736
|
-
}
|
1737
|
-
|
1738
|
-
|
1739
|
-
# Add a new function for deep tracing at the module level
|
1740
|
-
def _create_deep_tracing_wrapper(func, tracer, span_type="span"):
|
1576
|
+
def combine_args_kwargs(func, args, kwargs):
    """
    Combine positional arguments and keyword arguments into a single dictionary.

    Positional values are keyed by the name of the parameter they bind to.
    Values beyond the function's named positional parameters (for example
    those collected by ``*args``) are keyed ``arg<index>``; they never borrow
    the name of a ``*args``, keyword-only or ``**kwargs`` parameter.

    Args:
        func: The function being called
        args: Tuple of positional arguments
        kwargs: Dictionary of keyword arguments

    Returns:
        A dictionary combining both args and kwargs
    """
    try:
        positional_names = [
            param.name
            for param in inspect.signature(func).parameters.values()
            if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD)
        ]
    except Exception:
        # Fallback if signature inspection fails: label everything by index.
        # Kept broad on purpose so input capture never breaks the traced call.
        positional_names = []

    named_count = len(positional_names)
    combined = {
        (positional_names[i] if i < named_count else f"arg{i}"): value
        for i, value in enumerate(args)
    }
    combined.update(kwargs)
    return combined
|
1604
|
+
|
1605
|
+
# NOTE: This is built once; tweak it if we are missing modules or capturing other unnecessary ones
|
1606
|
+
# @link https://docs.python.org/3.13/library/sysconfig.html
|
1607
|
+
# Path prefixes (resolved, each ending with a path separator) whose files are
# treated as non-user code: the stdlib locations reported by sysconfig, the
# global and user site-packages directories, and - only when JUDGMENT_DEV is
# set - a judgeval directory located relative to this file.
_TRACE_FILEPATH_BLOCKLIST = tuple(
    os.path.realpath(candidate) + os.sep
    for candidate in set(
        [
            sysconfig.get_paths()['stdlib'],
            sysconfig.get_paths().get('platstdlib', ''),
        ]
        + list(site.getsitepackages())
        + [site.getusersitepackages()]
        + (
            [os.path.join(os.path.dirname(__file__), '../../judgeval/')]
            if os.environ.get('JUDGMENT_DEV')
            else []
        )
    )
    if candidate  # drop empty entries such as a missing 'platstdlib'
)
|
1830
1621
|
|
1831
1622
|
# Add the new TraceThreadPoolExecutor class
|
1832
1623
|
class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
|
@@ -1929,7 +1720,7 @@ def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[D
|
|
1929
1720
|
def _sync_stream_wrapper(
|
1930
1721
|
original_stream: Iterator,
|
1931
1722
|
client: ApiClient,
|
1932
|
-
|
1723
|
+
span: TraceSpan
|
1933
1724
|
) -> Generator[Any, None, None]:
|
1934
1725
|
"""Wraps a synchronous stream iterator to capture content and update the trace."""
|
1935
1726
|
content_parts = [] # Use a list instead of string concatenation
|
@@ -1948,7 +1739,7 @@ def _sync_stream_wrapper(
|
|
1948
1739
|
final_usage = _extract_usage_from_final_chunk(client, last_chunk)
|
1949
1740
|
|
1950
1741
|
# Update the trace entry with the accumulated content and usage
|
1951
|
-
|
1742
|
+
span.output = {
|
1952
1743
|
"content": "".join(content_parts), # Join list at the end
|
1953
1744
|
"usage": final_usage if final_usage else {"info": "Usage data not available in stream."}, # Provide placeholder if None
|
1954
1745
|
"streamed": True
|
@@ -1960,7 +1751,7 @@ def _sync_stream_wrapper(
|
|
1960
1751
|
async def _async_stream_wrapper(
|
1961
1752
|
original_stream: AsyncIterator,
|
1962
1753
|
client: ApiClient,
|
1963
|
-
|
1754
|
+
span: TraceSpan
|
1964
1755
|
) -> AsyncGenerator[Any, None]:
|
1965
1756
|
# [Existing logic - unchanged]
|
1966
1757
|
content_parts = [] # Use a list instead of string concatenation
|
@@ -1969,7 +1760,7 @@ async def _async_stream_wrapper(
|
|
1969
1760
|
anthropic_input_tokens = 0
|
1970
1761
|
anthropic_output_tokens = 0
|
1971
1762
|
|
1972
|
-
target_span_id =
|
1763
|
+
target_span_id = span.span_id
|
1973
1764
|
|
1974
1765
|
try:
|
1975
1766
|
async for chunk in original_stream:
|
@@ -2014,19 +1805,17 @@ async def _async_stream_wrapper(
|
|
2014
1805
|
elif last_content_chunk:
|
2015
1806
|
usage_info = _extract_usage_from_final_chunk(client, last_content_chunk)
|
2016
1807
|
|
2017
|
-
if
|
2018
|
-
|
1808
|
+
if span and hasattr(span, 'output'):
|
1809
|
+
span.output = {
|
2019
1810
|
"content": "".join(content_parts), # Join list at the end
|
2020
1811
|
"usage": usage_info if usage_info else {"info": "Usage data not available in stream."},
|
2021
1812
|
"streamed": True
|
2022
1813
|
}
|
2023
|
-
start_ts = getattr(
|
2024
|
-
|
1814
|
+
start_ts = getattr(span, 'created_at', time.time())
|
1815
|
+
span.duration = time.time() - start_ts
|
2025
1816
|
# else: # Handle error case if necessary, but remove debug print
|
2026
1817
|
|
2027
|
-
|
2028
|
-
class _TracedAsyncStreamManagerWrapper(AbstractAsyncContextManager):
|
2029
|
-
"""Wraps an original async stream manager to add tracing."""
|
1818
|
+
class _BaseStreamManagerWrapper:
|
2030
1819
|
def __init__(self, original_manager, client, span_name, trace_client, stream_wrapper_func, input_kwargs):
|
2031
1820
|
self._original_manager = original_manager
|
2032
1821
|
self._client = client
|
@@ -2036,160 +1825,77 @@ class _TracedAsyncStreamManagerWrapper(AbstractAsyncContextManager):
|
|
2036
1825
|
self._input_kwargs = input_kwargs
|
2037
1826
|
self._parent_span_id_at_entry = None
|
2038
1827
|
|
2039
|
-
|
2040
|
-
self._parent_span_id_at_entry = current_span_var.get()
|
2041
|
-
if not self._trace_client:
|
2042
|
-
# If no trace, just delegate to the original manager
|
2043
|
-
return await self._original_manager.__aenter__()
|
2044
|
-
|
2045
|
-
# --- Manually create the 'enter' entry ---
|
1828
|
+
def _create_span(self):
    """Create and register a new "llm" span for the wrapped stream.

    The span's depth is one more than the depth recorded for the parent span
    captured in ``self._parent_span_id_at_entry``; it is 0 when there is no
    parent or the parent has no recorded depth. The new depth is stored in
    the trace client's ``_span_depths`` and the span is handed to the trace
    client via ``add_span``.

    Returns:
        A ``(span_id, span)`` tuple for the newly created span.
    """
    trace_client = self._trace_client
    parent_id = self._parent_span_id_at_entry

    created = time.time()
    new_id = str(uuid.uuid4())

    depths = trace_client._span_depths
    depth = depths[parent_id] + 1 if parent_id and parent_id in depths else 0
    depths[new_id] = depth

    llm_span = TraceSpan(
        function=self._span_name,
        span_id=new_id,
        trace_id=trace_client.trace_id,
        depth=depth,
        message=self._span_name,
        created_at=created,
        span_type="llm",
        parent_span_id=parent_id,
    )
    trace_client.add_span(llm_span)
    return new_id, llm_span
|
2062
1847
|
|
2063
|
-
|
2064
|
-
|
2065
|
-
|
2066
|
-
|
2067
|
-
|
2068
|
-
|
2069
|
-
)
|
2070
|
-
self._trace_client.add_entry(input_entry)
|
1848
|
+
def _finalize_span(self, span_id):
    """Close out the span identified by ``span_id``.

    If the trace client knows the span, its ``duration`` is set to the time
    elapsed since its ``created_at``. The span's entry in the trace client's
    ``_span_depths`` bookkeeping is removed when present.

    Args:
        span_id: Identifier of the span to finalize.
    """
    trace_client = self._trace_client

    finished = trace_client.span_id_to_span.get(span_id)
    if finished:
        finished.duration = time.time() - finished.created_at

    depths = trace_client._span_depths
    if span_id in depths:
        del depths[span_id]
|
2071
1854
|
|
2072
|
-
|
2073
|
-
|
1855
|
+
class _TracedAsyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractAsyncContextManager):
    """Wraps an original async stream manager to add tracing.

    On entry an "llm" span is created, made the current span, and the stream
    returned by the wrapped manager is passed through ``stream_wrapper_func``.
    On exit the span is finalized, the previous current span is restored and
    the wrapped manager's ``__aexit__`` is invoked.
    When no trace client is present, the wrapper only delegates.
    """

    async def __aenter__(self):
        """Enter the wrapped manager and return its stream, traced if possible."""
        self._parent_span_id_at_entry = current_span_var.get()
        if not self._trace_client:
            # No active trace: behave exactly like the wrapped manager.
            return await self._original_manager.__aenter__()

        span_id, span = self._create_span()
        # Remember our own span id so exit does not depend on whatever the
        # context variable happens to hold at that point.
        self._span_id = span_id
        self._span_context_token = current_span_var.set(span_id)
        try:
            span.inputs = _format_input_data(self._client, **self._input_kwargs)
            # Call the original __aenter__ and expect it to be an async generator
            raw_iterator = await self._original_manager.__aenter__()
        except BaseException:
            # __aexit__ is not called when __aenter__ raises, so the span and
            # the context variable must be cleaned up here.
            self._close_span()
            raise
        span.output = "<pending stream>"
        return self._stream_wrapper_func(raw_iterator, self._client, span)

    def _close_span(self):
        """Finalize the span opened on entry and restore the previous current span."""
        if not hasattr(self, '_span_context_token'):
            return
        try:
            self._finalize_span(self._span_id)
        finally:
            current_span_var.reset(self._span_context_token)
            delattr(self, '_span_context_token')

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the span (if one was opened) and delegate to the wrapped manager."""
        try:
            self._close_span()
        finally:
            # Always give the wrapped manager the chance to release its
            # resources, even if tracing cleanup failed.
            result = await self._original_manager.__aexit__(exc_type, exc_val, exc_tb)
        return result
|
2123
1877
|
|
1878
|
+
class _TracedSyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractContextManager):
    """Wraps an original sync stream manager to add tracing.

    On entry an "llm" span is created, made the current span, and the stream
    returned by the wrapped manager is passed through ``stream_wrapper_func``.
    On exit the span is finalized, the previous current span is restored and
    the wrapped manager's ``__exit__`` is invoked.
    When no trace client is present, the wrapper only delegates.
    """

    def __enter__(self):
        """Enter the wrapped manager and return its stream, traced if possible."""
        self._parent_span_id_at_entry = current_span_var.get()
        if not self._trace_client:
            # No active trace: behave exactly like the wrapped manager.
            return self._original_manager.__enter__()

        span_id, span = self._create_span()
        # Remember our own span id so exit does not depend on whatever the
        # context variable happens to hold at that point.
        self._span_id = span_id
        self._span_context_token = current_span_var.set(span_id)
        try:
            span.inputs = _format_input_data(self._client, **self._input_kwargs)
            raw_iterator = self._original_manager.__enter__()
        except BaseException:
            # __exit__ is not called when __enter__ raises, so the span and
            # the context variable must be cleaned up here.
            self._close_span()
            raise
        span.output = "<pending stream>"
        return self._stream_wrapper_func(raw_iterator, self._client, span)

    def _close_span(self):
        """Finalize the span opened on entry and restore the previous current span."""
        if not hasattr(self, '_span_context_token'):
            return
        try:
            self._finalize_span(self._span_id)
        finally:
            current_span_var.reset(self._span_context_token)
            delattr(self, '_span_context_token')

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the span (if one was opened) and delegate to the wrapped manager."""
        try:
            self._close_span()
        finally:
            # Always give the wrapped manager the chance to release its
            # resources, even if tracing cleanup failed.
            result = self._original_manager.__exit__(exc_type, exc_val, exc_tb)
        return result
|
2193
1899
|
|
2194
1900
|
# --- NEW Generalized Helper Function (Moved from demo) ---
|
2195
1901
|
def prepare_evaluation_for_state(
|
@@ -2314,3 +2020,4 @@ def add_evaluation_to_state(
|
|
2314
2020
|
|
2315
2021
|
# print("[Skipped adding _judgeval_eval to state: prepare_evaluation_for_state failed]")
|
2316
2022
|
# --- End NEW Helper ---
|
2023
|
+
|