judgeval 0.0.36__py3-none-any.whl → 0.0.38__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/common/tracer.py +663 -1105
- judgeval/common/utils.py +19 -1
- judgeval/constants.py +3 -3
- judgeval/data/__init__.py +4 -2
- judgeval/data/datasets/dataset.py +2 -11
- judgeval/data/datasets/eval_dataset_client.py +1 -62
- judgeval/data/example.py +29 -8
- judgeval/data/result.py +3 -3
- judgeval/data/trace.py +132 -0
- judgeval/data/{sequence_run.py → trace_run.py} +7 -6
- judgeval/evaluation_run.py +2 -2
- judgeval/integrations/langgraph.py +189 -1769
- judgeval/judges/litellm_judge.py +1 -1
- judgeval/judges/mixture_of_judges.py +1 -1
- judgeval/judges/utils.py +1 -1
- judgeval/judgment_client.py +85 -78
- judgeval/run_evaluation.py +98 -51
- judgeval/scorers/__init__.py +2 -0
- judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +2 -0
- judgeval/scorers/judgeval_scorers/api_scorers/tool_order.py +20 -0
- judgeval/scorers/score.py +1 -1
- judgeval/utils/data_utils.py +57 -0
- judgeval-0.0.38.dist-info/METADATA +247 -0
- {judgeval-0.0.36.dist-info → judgeval-0.0.38.dist-info}/RECORD +26 -24
- judgeval/data/sequence.py +0 -49
- judgeval-0.0.36.dist-info/METADATA +0 -169
- {judgeval-0.0.36.dist-info → judgeval-0.0.38.dist-info}/WHEEL +0 -0
- {judgeval-0.0.36.dist-info → judgeval-0.0.38.dist-info}/licenses/LICENSE.md +0 -0
judgeval/common/tracer.py
CHANGED
@@ -7,7 +7,11 @@ import functools
|
|
7
7
|
import inspect
|
8
8
|
import json
|
9
9
|
import os
|
10
|
+
import site
|
11
|
+
import sysconfig
|
12
|
+
import threading
|
10
13
|
import time
|
14
|
+
import traceback
|
11
15
|
import uuid
|
12
16
|
import warnings
|
13
17
|
import contextvars
|
@@ -35,7 +39,6 @@ from rich import print as rprint
|
|
35
39
|
import types # <--- Add this import
|
36
40
|
|
37
41
|
# Third-party imports
|
38
|
-
import pika
|
39
42
|
import requests
|
40
43
|
from litellm import cost_per_token
|
41
44
|
from pydantic import BaseModel
|
@@ -44,10 +47,10 @@ from openai import OpenAI, AsyncOpenAI
|
|
44
47
|
from together import Together, AsyncTogether
|
45
48
|
from anthropic import Anthropic, AsyncAnthropic
|
46
49
|
from google import genai
|
47
|
-
from judgeval.run_evaluation import check_examples
|
48
50
|
|
49
51
|
# Local application/library-specific imports
|
50
52
|
from judgeval.constants import (
|
53
|
+
JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
|
51
54
|
JUDGMENT_TRACES_SAVE_API_URL,
|
52
55
|
JUDGMENT_TRACES_FETCH_API_URL,
|
53
56
|
RABBITMQ_HOST,
|
@@ -56,25 +59,24 @@ from judgeval.constants import (
|
|
56
59
|
JUDGMENT_TRACES_DELETE_API_URL,
|
57
60
|
JUDGMENT_PROJECT_DELETE_API_URL,
|
58
61
|
)
|
59
|
-
from judgeval.
|
60
|
-
from judgeval.data import Example
|
62
|
+
from judgeval.data import Example, Trace, TraceSpan
|
61
63
|
from judgeval.scorers import APIJudgmentScorer, JudgevalScorer
|
62
64
|
from judgeval.rules import Rule
|
63
65
|
from judgeval.evaluation_run import EvaluationRun
|
64
66
|
from judgeval.data.result import ScoringResult
|
67
|
+
from judgeval.common.utils import validate_api_key
|
68
|
+
from judgeval.common.exceptions import JudgmentAPIError
|
65
69
|
|
66
70
|
# Standard library imports needed for the new class
|
67
71
|
import concurrent.futures
|
68
72
|
from collections.abc import Iterator, AsyncIterator # Add Iterator and AsyncIterator
|
69
73
|
|
70
74
|
# Define context variables for tracking the current trace and the current span within a trace
|
71
|
-
current_trace_var = contextvars.ContextVar('current_trace', default=None)
|
75
|
+
current_trace_var = contextvars.ContextVar[Optional['TraceClient']]('current_trace', default=None)
|
72
76
|
current_span_var = contextvars.ContextVar('current_span', default=None) # ContextVar for the active span name
|
73
|
-
in_traced_function_var = contextvars.ContextVar('in_traced_function', default=False) # Track if we're in a traced function
|
74
77
|
|
75
78
|
# Define type aliases for better code readability and maintainability
|
76
79
|
ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic, AsyncOpenAI, AsyncAnthropic, AsyncTogether, genai.Client, genai.client.AsyncClient] # Supported API clients
|
77
|
-
TraceEntryType = Literal['enter', 'exit', 'output', 'input', 'evaluation'] # Valid trace entry types
|
78
80
|
SpanType = Literal['span', 'tool', 'llm', 'evaluation', 'chain']
|
79
81
|
|
80
82
|
# --- Evaluation Config Dataclass (Moved from langgraph.py) ---
|
@@ -87,154 +89,26 @@ class EvaluationConfig:
|
|
87
89
|
log_results: Optional[bool] = True
|
88
90
|
# --- End Evaluation Config Dataclass ---
|
89
91
|
|
92
|
+
# Temporary as a POC to have log use the existing annotations feature until log endpoints are ready
|
90
93
|
@dataclass
|
91
|
-
class
|
92
|
-
"""Represents a single
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
- output: Output: (function return value)
|
98
|
-
- input: Input: (function parameters)
|
99
|
-
- evaluation: Evaluation: (evaluation results)
|
100
|
-
"""
|
101
|
-
type: TraceEntryType
|
102
|
-
span_id: str # Unique ID for this specific span instance
|
103
|
-
depth: int # Indentation level for nested calls
|
104
|
-
created_at: float # Unix timestamp when entry was created, replacing the deprecated 'timestamp' field
|
105
|
-
function: Optional[str] = None # Name of the function being traced
|
106
|
-
message: Optional[str] = None # Human-readable description
|
107
|
-
duration: Optional[float] = None # Time taken (for exit/evaluation entries)
|
108
|
-
trace_id: str = None # ID of the trace this entry belongs to
|
109
|
-
output: Any = None # Function output value
|
110
|
-
# Use field() for mutable defaults to avoid shared state issues
|
111
|
-
inputs: dict = field(default_factory=dict)
|
112
|
-
span_type: SpanType = "span"
|
113
|
-
evaluation_runs: List[Optional[EvaluationRun]] = field(default=None)
|
114
|
-
parent_span_id: Optional[str] = None # ID of the parent span instance
|
115
|
-
|
116
|
-
def print_entry(self):
|
117
|
-
"""Print a trace entry with proper formatting and parent relationship information."""
|
118
|
-
indent = " " * self.depth
|
119
|
-
|
120
|
-
if self.type == "enter":
|
121
|
-
# Format parent info if present
|
122
|
-
parent_info = f" (parent_id: {self.parent_span_id})" if self.parent_span_id else ""
|
123
|
-
print(f"{indent}→ {self.function} (id: {self.span_id}){parent_info} (trace: {self.message})")
|
124
|
-
elif self.type == "exit":
|
125
|
-
print(f"{indent}← {self.function} (id: {self.span_id}) ({self.duration:.3f}s)")
|
126
|
-
elif self.type == "output":
|
127
|
-
# Format output to align properly
|
128
|
-
output_str = str(self.output)
|
129
|
-
print(f"{indent}Output (for id: {self.span_id}): {output_str}")
|
130
|
-
elif self.type == "input":
|
131
|
-
# Format inputs to align properly
|
132
|
-
print(f"{indent}Input (for id: {self.span_id}): {self.inputs}")
|
133
|
-
elif self.type == "evaluation":
|
134
|
-
for evaluation_run in self.evaluation_runs:
|
135
|
-
print(f"{indent}Evaluation (for id: {self.span_id}): {evaluation_run.model_dump()}")
|
136
|
-
|
137
|
-
def _serialize_inputs(self) -> dict:
|
138
|
-
"""Helper method to serialize input data safely.
|
139
|
-
|
140
|
-
Returns a dict with serializable versions of inputs, converting non-serializable
|
141
|
-
objects to None with a warning.
|
142
|
-
"""
|
143
|
-
serialized_inputs = {}
|
144
|
-
for key, value in self.inputs.items():
|
145
|
-
if isinstance(value, BaseModel):
|
146
|
-
serialized_inputs[key] = value.model_dump()
|
147
|
-
elif isinstance(value, (list, tuple)):
|
148
|
-
# Handle lists/tuples of arguments
|
149
|
-
serialized_inputs[key] = [
|
150
|
-
item.model_dump() if isinstance(item, BaseModel)
|
151
|
-
else None if not self._is_json_serializable(item)
|
152
|
-
else item
|
153
|
-
for item in value
|
154
|
-
]
|
155
|
-
else:
|
156
|
-
if self._is_json_serializable(value):
|
157
|
-
serialized_inputs[key] = value
|
158
|
-
else:
|
159
|
-
serialized_inputs[key] = self.safe_stringify(value, self.function)
|
160
|
-
return serialized_inputs
|
161
|
-
|
162
|
-
def _is_json_serializable(self, obj: Any) -> bool:
|
163
|
-
"""Helper method to check if an object is JSON serializable."""
|
164
|
-
try:
|
165
|
-
json.dumps(obj)
|
166
|
-
return True
|
167
|
-
except (TypeError, OverflowError, ValueError):
|
168
|
-
return False
|
169
|
-
|
170
|
-
def safe_stringify(self, output, function_name):
|
171
|
-
"""
|
172
|
-
Safely converts an object to a string or repr, handling serialization issues gracefully.
|
173
|
-
"""
|
174
|
-
try:
|
175
|
-
return str(output)
|
176
|
-
except (TypeError, OverflowError, ValueError):
|
177
|
-
pass
|
178
|
-
|
179
|
-
try:
|
180
|
-
return repr(output)
|
181
|
-
except (TypeError, OverflowError, ValueError):
|
182
|
-
pass
|
183
|
-
|
184
|
-
warnings.warn(
|
185
|
-
f"Output for function {function_name} is not JSON serializable and could not be converted to string. Setting to None."
|
186
|
-
)
|
187
|
-
return None
|
94
|
+
class TraceAnnotation:
|
95
|
+
"""Represents a single annotation for a trace span."""
|
96
|
+
span_id: str
|
97
|
+
text: str
|
98
|
+
label: str
|
99
|
+
score: int
|
188
100
|
|
189
101
|
def to_dict(self) -> dict:
|
190
|
-
"""Convert the
|
102
|
+
"""Convert the annotation to a dictionary format for storage/transmission."""
|
191
103
|
return {
|
192
|
-
"type": self.type,
|
193
|
-
"function": self.function,
|
194
104
|
"span_id": self.span_id,
|
195
|
-
"
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
"output": self._serialize_output(),
|
201
|
-
"inputs": self._serialize_inputs(),
|
202
|
-
"evaluation_runs": [evaluation_run.model_dump() for evaluation_run in self.evaluation_runs] if self.evaluation_runs else [],
|
203
|
-
"span_type": self.span_type,
|
204
|
-
"parent_span_id": self.parent_span_id,
|
105
|
+
"annotation": {
|
106
|
+
"text": self.text,
|
107
|
+
"label": self.label,
|
108
|
+
"score": self.score
|
109
|
+
}
|
205
110
|
}
|
206
|
-
|
207
|
-
def _serialize_output(self) -> Any:
|
208
|
-
"""Helper method to serialize output data safely.
|
209
|
-
|
210
|
-
Handles special cases:
|
211
|
-
- Pydantic models are converted using model_dump()
|
212
|
-
- Dictionaries are processed recursively to handle non-serializable values.
|
213
|
-
- We try to serialize into JSON, then string, then the base representation (__repr__)
|
214
|
-
- Non-serializable objects return None with a warning
|
215
|
-
"""
|
216
|
-
|
217
|
-
def serialize_value(value):
|
218
|
-
if isinstance(value, BaseModel):
|
219
|
-
return value.model_dump()
|
220
|
-
elif isinstance(value, dict):
|
221
|
-
# Recursively serialize dictionary values
|
222
|
-
return {k: serialize_value(v) for k, v in value.items()}
|
223
|
-
elif isinstance(value, (list, tuple)):
|
224
|
-
# Recursively serialize list/tuple items
|
225
|
-
return [serialize_value(item) for item in value]
|
226
|
-
else:
|
227
|
-
# Try direct JSON serialization first
|
228
|
-
try:
|
229
|
-
json.dumps(value)
|
230
|
-
return value
|
231
|
-
except (TypeError, OverflowError, ValueError):
|
232
|
-
# Fallback to safe stringification
|
233
|
-
return self.safe_stringify(value, self.function)
|
234
|
-
|
235
|
-
# Start serialization with the top-level output
|
236
|
-
return serialize_value(self.output)
|
237
|
-
|
111
|
+
|
238
112
|
class TraceManagerClient:
|
239
113
|
"""
|
240
114
|
Client for handling trace endpoints with the Judgment API
|
@@ -271,10 +145,8 @@ class TraceManagerClient:
|
|
271
145
|
raise ValueError(f"Failed to fetch traces: {response.text}")
|
272
146
|
|
273
147
|
return response.json()
|
274
|
-
|
275
|
-
|
276
148
|
|
277
|
-
def save_trace(self, trace_data: dict):
|
149
|
+
def save_trace(self, trace_data: dict, offline_mode: bool = False):
|
278
150
|
"""
|
279
151
|
Saves a trace to the Judgment Supabase and optionally to S3 if configured.
|
280
152
|
|
@@ -311,10 +183,37 @@ class TraceManagerClient:
|
|
311
183
|
except Exception as e:
|
312
184
|
warnings.warn(f"Failed to save trace to S3: {str(e)}")
|
313
185
|
|
314
|
-
if "ui_results_url" in response.json():
|
186
|
+
if not offline_mode and "ui_results_url" in response.json():
|
315
187
|
pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={response.json()['ui_results_url']}]View Trace[/link]\n"
|
316
188
|
rprint(pretty_str)
|
317
189
|
|
190
|
+
## TODO: Should have a log endpoint, endpoint should also support batched payloads
|
191
|
+
def save_annotation(self, annotation: TraceAnnotation):
|
192
|
+
json_data = {
|
193
|
+
"span_id": annotation.span_id,
|
194
|
+
"annotation": {
|
195
|
+
"text": annotation.text,
|
196
|
+
"label": annotation.label,
|
197
|
+
"score": annotation.score
|
198
|
+
}
|
199
|
+
}
|
200
|
+
|
201
|
+
response = requests.post(
|
202
|
+
JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
|
203
|
+
json=json_data,
|
204
|
+
headers={
|
205
|
+
'Content-Type': 'application/json',
|
206
|
+
'Authorization': f'Bearer {self.judgment_api_key}',
|
207
|
+
'X-Organization-Id': self.organization_id
|
208
|
+
},
|
209
|
+
verify=True
|
210
|
+
)
|
211
|
+
|
212
|
+
if response.status_code != HTTPStatus.OK:
|
213
|
+
raise ValueError(f"Failed to save annotation: {response.text}")
|
214
|
+
|
215
|
+
return response.json()
|
216
|
+
|
318
217
|
def delete_trace(self, trace_id: str):
|
319
218
|
"""
|
320
219
|
Delete a trace from the database.
|
@@ -405,15 +304,17 @@ class TraceClient:
|
|
405
304
|
self.enable_evaluations = enable_evaluations
|
406
305
|
self.parent_trace_id = parent_trace_id
|
407
306
|
self.parent_name = parent_name
|
408
|
-
self.
|
409
|
-
self.
|
307
|
+
self.trace_spans: List[TraceSpan] = []
|
308
|
+
self.span_id_to_span: Dict[str, TraceSpan] = {}
|
309
|
+
self.evaluation_runs: List[EvaluationRun] = []
|
310
|
+
self.annotations: List[TraceAnnotation] = []
|
410
311
|
self.start_time = time.time()
|
411
312
|
self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id, tracer)
|
412
313
|
self.visited_nodes = []
|
413
314
|
self.executed_tools = []
|
414
315
|
self.executed_node_tools = []
|
415
316
|
self._span_depths: Dict[str, int] = {} # NEW: To track depth of active spans
|
416
|
-
|
317
|
+
|
417
318
|
def get_current_span(self):
|
418
319
|
"""Get the current span from the context var"""
|
419
320
|
return current_span_var.get()
|
@@ -443,9 +344,7 @@ class TraceClient:
|
|
443
344
|
|
444
345
|
self._span_depths[span_id] = current_depth # Store depth by span_id
|
445
346
|
|
446
|
-
|
447
|
-
type="enter",
|
448
|
-
function=name,
|
347
|
+
span = TraceSpan(
|
449
348
|
span_id=span_id,
|
450
349
|
trace_id=self.trace_id,
|
451
350
|
depth=current_depth,
|
@@ -453,25 +352,15 @@ class TraceClient:
|
|
453
352
|
created_at=start_time,
|
454
353
|
span_type=span_type,
|
455
354
|
parent_span_id=parent_span_id,
|
355
|
+
function=name,
|
456
356
|
)
|
457
|
-
self.
|
357
|
+
self.add_span(span)
|
458
358
|
|
459
359
|
try:
|
460
360
|
yield self
|
461
361
|
finally:
|
462
362
|
duration = time.time() - start_time
|
463
|
-
|
464
|
-
self.add_entry(TraceEntry(
|
465
|
-
type="exit",
|
466
|
-
function=name,
|
467
|
-
span_id=span_id, # Use the same span_id for exit
|
468
|
-
trace_id=self.trace_id, # Use the trace_id from the trace client
|
469
|
-
depth=exit_depth,
|
470
|
-
message=f"← {name}",
|
471
|
-
created_at=time.time(),
|
472
|
-
duration=duration,
|
473
|
-
span_type=span_type,
|
474
|
-
))
|
363
|
+
span.duration = duration
|
475
364
|
# Clean up depth tracking for this span_id
|
476
365
|
if span_id in self._span_depths:
|
477
366
|
del self._span_depths[span_id]
|
@@ -528,19 +417,20 @@ class TraceClient:
|
|
528
417
|
tools_called=tools_called,
|
529
418
|
expected_tools=expected_tools,
|
530
419
|
additional_metadata=additional_metadata,
|
531
|
-
trace_id=self.trace_id
|
532
420
|
)
|
533
421
|
else:
|
534
422
|
raise ValueError("Either 'example' or at least one of the individual parameters (input, actual_output, etc.) must be provided")
|
535
423
|
|
536
424
|
# Check examples before creating evaluation run
|
537
|
-
|
425
|
+
|
426
|
+
# check_examples([example], scorers)
|
538
427
|
|
539
428
|
# --- Modification: Capture span_id immediately ---
|
540
429
|
# span_id_at_eval_call = current_span_var.get()
|
541
430
|
# print(f"[TraceClient.async_evaluate] Captured span ID at eval call: {span_id_at_eval_call}")
|
542
431
|
# Prioritize explicitly passed span_id, fallback to context var
|
543
|
-
|
432
|
+
current_span_ctx_var = current_span_var.get()
|
433
|
+
span_id_to_use = span_id if span_id is not None else current_span_ctx_var if current_span_ctx_var is not None else self.tracer.get_current_span()
|
544
434
|
# print(f"[TraceClient.async_evaluate] Using span_id: {span_id_to_use}")
|
545
435
|
# --- End Modification ---
|
546
436
|
|
@@ -550,7 +440,7 @@ class TraceClient:
|
|
550
440
|
log_results=log_results,
|
551
441
|
project_name=self.project_name,
|
552
442
|
eval_name=f"{self.name.capitalize()}-"
|
553
|
-
f"{
|
443
|
+
f"{span_id_to_use}-" # Keep original eval name format using context var if available
|
554
444
|
f"[{','.join(scorer.score_type.capitalize() for scorer in scorers)}]",
|
555
445
|
examples=[example],
|
556
446
|
scorers=scorers,
|
@@ -571,290 +461,60 @@ class TraceClient:
|
|
571
461
|
# --- End Modification ---
|
572
462
|
|
573
463
|
if current_span_id:
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
if entry.span_id == current_span_id and entry.type == 'enter':
|
584
|
-
function_name = entry.function
|
585
|
-
break
|
586
|
-
|
587
|
-
# Get depth for the current span
|
588
|
-
current_depth = self._span_depths.get(current_span_id, 0)
|
589
|
-
|
590
|
-
self.add_entry(TraceEntry(
|
591
|
-
type="evaluation",
|
592
|
-
function=function_name,
|
593
|
-
span_id=current_span_id, # Associate with current span
|
594
|
-
trace_id=self.trace_id, # Use the trace_id from the trace client
|
595
|
-
depth=current_depth,
|
596
|
-
message=f"Evaluation results for {function_name}",
|
597
|
-
created_at=time.time(),
|
598
|
-
evaluation_runs=[eval_run],
|
599
|
-
duration=duration,
|
600
|
-
span_type="evaluation"
|
601
|
-
))
|
602
|
-
|
464
|
+
span = self.span_id_to_span[current_span_id]
|
465
|
+
span.evaluation_runs.append(eval_run)
|
466
|
+
self.evaluation_runs.append(eval_run)
|
467
|
+
|
468
|
+
def add_annotation(self, annotation: TraceAnnotation):
|
469
|
+
"""Add an annotation to this trace context"""
|
470
|
+
self.annotations.append(annotation)
|
471
|
+
return self
|
472
|
+
|
603
473
|
def record_input(self, inputs: dict):
|
604
474
|
current_span_id = current_span_var.get()
|
605
475
|
if current_span_id:
|
606
|
-
|
607
|
-
|
608
|
-
function_name = "unknown_function" # Default
|
609
|
-
for entry in reversed(self.entries):
|
610
|
-
if entry.span_id == current_span_id and entry.type == 'enter':
|
611
|
-
entry_span_type = entry.span_type
|
612
|
-
function_name = entry.function
|
613
|
-
break
|
614
|
-
|
615
|
-
self.add_entry(TraceEntry(
|
616
|
-
type="input",
|
617
|
-
function=function_name,
|
618
|
-
span_id=current_span_id, # Use current span_id from context
|
619
|
-
trace_id=self.trace_id, # Use the trace_id from the trace client
|
620
|
-
depth=current_depth,
|
621
|
-
message=f"Inputs to {function_name}",
|
622
|
-
created_at=time.time(),
|
623
|
-
inputs=inputs,
|
624
|
-
span_type=entry_span_type,
|
625
|
-
))
|
626
|
-
# Removed else block - original didn't have one
|
476
|
+
span = self.span_id_to_span[current_span_id]
|
477
|
+
span.inputs = inputs
|
627
478
|
|
628
|
-
async def _update_coroutine_output(self,
|
479
|
+
async def _update_coroutine_output(self, span: TraceSpan, coroutine: Any):
|
629
480
|
"""Helper method to update the output of a trace entry once the coroutine completes"""
|
630
481
|
try:
|
631
482
|
result = await coroutine
|
632
|
-
|
483
|
+
span.output = result
|
633
484
|
return result
|
634
485
|
except Exception as e:
|
635
|
-
|
486
|
+
span.output = f"Error: {str(e)}"
|
636
487
|
raise
|
637
488
|
|
638
489
|
def record_output(self, output: Any):
|
639
490
|
current_span_id = current_span_var.get()
|
640
491
|
if current_span_id:
|
641
|
-
|
642
|
-
|
643
|
-
function_name = "unknown_function" # Default
|
644
|
-
for entry in reversed(self.entries):
|
645
|
-
if entry.span_id == current_span_id and entry.type == 'enter':
|
646
|
-
entry_span_type = entry.span_type
|
647
|
-
function_name = entry.function
|
648
|
-
break
|
649
|
-
|
650
|
-
entry = TraceEntry(
|
651
|
-
type="output",
|
652
|
-
function=function_name,
|
653
|
-
span_id=current_span_id, # Use current span_id from context
|
654
|
-
depth=current_depth,
|
655
|
-
message=f"Output from {function_name}",
|
656
|
-
created_at=time.time(),
|
657
|
-
output="<pending>" if inspect.iscoroutine(output) else output,
|
658
|
-
span_type=entry_span_type,
|
659
|
-
trace_id=self.trace_id # Added trace_id for consistency
|
660
|
-
)
|
661
|
-
self.add_entry(entry)
|
492
|
+
span = self.span_id_to_span[current_span_id]
|
493
|
+
span.output = "<pending>" if inspect.iscoroutine(output) else output
|
662
494
|
|
663
495
|
if inspect.iscoroutine(output):
|
664
|
-
asyncio.create_task(self._update_coroutine_output(
|
496
|
+
asyncio.create_task(self._update_coroutine_output(span, output))
|
665
497
|
|
666
|
-
return
|
498
|
+
return span # Return the created entry
|
667
499
|
# Removed else block - original didn't have one
|
668
500
|
return None # Return None if no span_id found
|
669
501
|
|
670
|
-
def
|
671
|
-
"""Add a trace
|
672
|
-
self.
|
502
|
+
def add_span(self, span: TraceSpan):
|
503
|
+
"""Add a trace span to this trace context"""
|
504
|
+
self.trace_spans.append(span)
|
505
|
+
self.span_id_to_span[span.span_id] = span
|
673
506
|
return self
|
674
507
|
|
675
508
|
def print(self):
|
676
509
|
"""Print the complete trace with proper visual structure"""
|
677
|
-
for
|
678
|
-
|
679
|
-
|
680
|
-
def print_hierarchical(self):
|
681
|
-
"""Print the trace in a hierarchical structure based on parent-child relationships"""
|
682
|
-
# First, build a map of spans
|
683
|
-
spans = {}
|
684
|
-
root_spans = []
|
685
|
-
|
686
|
-
# Collect all enter events first
|
687
|
-
for entry in self.entries:
|
688
|
-
if entry.type == "enter":
|
689
|
-
spans[entry.function] = {
|
690
|
-
"name": entry.function,
|
691
|
-
"depth": entry.depth,
|
692
|
-
"parent_id": entry.parent_span_id,
|
693
|
-
"children": []
|
694
|
-
}
|
695
|
-
|
696
|
-
# If no parent, it's a root span
|
697
|
-
if not entry.parent_span_id:
|
698
|
-
root_spans.append(entry.function)
|
699
|
-
elif entry.parent_span_id not in spans:
|
700
|
-
# If parent doesn't exist yet, temporarily treat as root
|
701
|
-
# (we'll fix this later)
|
702
|
-
root_spans.append(entry.function)
|
703
|
-
|
704
|
-
# Build parent-child relationships
|
705
|
-
for span_name, span in spans.items():
|
706
|
-
parent = span["parent_id"]
|
707
|
-
if parent and parent in spans:
|
708
|
-
spans[parent]["children"].append(span_name)
|
709
|
-
# Remove from root spans if it was temporarily there
|
710
|
-
if span_name in root_spans:
|
711
|
-
root_spans.remove(span_name)
|
712
|
-
|
713
|
-
# Now print the hierarchy
|
714
|
-
def print_span(span_name, level=0):
|
715
|
-
if span_name not in spans:
|
716
|
-
return
|
717
|
-
|
718
|
-
span = spans[span_name]
|
719
|
-
indent = " " * level
|
720
|
-
parent_info = f" (parent_id: {span['parent_id']})" if span["parent_id"] else ""
|
721
|
-
print(f"{indent}→ {span_name}{parent_info}")
|
722
|
-
|
723
|
-
# Print children
|
724
|
-
for child in span["children"]:
|
725
|
-
print_span(child, level + 1)
|
726
|
-
|
727
|
-
# Print starting with root spans
|
728
|
-
print("\nHierarchical Trace Structure:")
|
729
|
-
for root in root_spans:
|
730
|
-
print_span(root)
|
510
|
+
for span in self.trace_spans:
|
511
|
+
span.print_span()
|
731
512
|
|
732
513
|
def get_duration(self) -> float:
|
733
514
|
"""
|
734
515
|
Get the total duration of this trace
|
735
516
|
"""
|
736
517
|
return time.time() - self.start_time
|
737
|
-
|
738
|
-
def condense_trace(self, entries: List[dict]) -> List[dict]:
|
739
|
-
"""
|
740
|
-
Condenses trace entries into a single entry for each span instance,
|
741
|
-
preserving parent-child span relationships using span_id and parent_span_id.
|
742
|
-
"""
|
743
|
-
spans_by_id: Dict[str, dict] = {}
|
744
|
-
evaluation_runs: List[EvaluationRun] = []
|
745
|
-
|
746
|
-
# First pass: Group entries by span_id and gather data
|
747
|
-
for entry in entries:
|
748
|
-
span_id = entry.get("span_id")
|
749
|
-
if not span_id:
|
750
|
-
continue # Skip entries without a span_id (should not happen)
|
751
|
-
|
752
|
-
if entry["type"] == "enter":
|
753
|
-
if span_id not in spans_by_id:
|
754
|
-
spans_by_id[span_id] = {
|
755
|
-
"span_id": span_id,
|
756
|
-
"function": entry["function"],
|
757
|
-
"depth": entry["depth"], # Use the depth recorded at entry time
|
758
|
-
"created_at": entry["created_at"],
|
759
|
-
"trace_id": entry["trace_id"],
|
760
|
-
"parent_span_id": entry.get("parent_span_id"),
|
761
|
-
"span_type": entry.get("span_type", "span"),
|
762
|
-
"inputs": None,
|
763
|
-
"output": None,
|
764
|
-
"evaluation_runs": [],
|
765
|
-
"duration": None
|
766
|
-
}
|
767
|
-
# Handle potential duplicate enter events if necessary (e.g., log warning)
|
768
|
-
|
769
|
-
elif span_id in spans_by_id:
|
770
|
-
current_span_data = spans_by_id[span_id]
|
771
|
-
|
772
|
-
if entry["type"] == "input" and entry["inputs"]:
|
773
|
-
# Merge inputs if multiple are recorded, or just assign
|
774
|
-
if current_span_data["inputs"] is None:
|
775
|
-
current_span_data["inputs"] = entry["inputs"]
|
776
|
-
elif isinstance(current_span_data["inputs"], dict) and isinstance(entry["inputs"], dict):
|
777
|
-
current_span_data["inputs"].update(entry["inputs"])
|
778
|
-
# Add more sophisticated merging if needed
|
779
|
-
|
780
|
-
elif entry["type"] == "output" and "output" in entry:
|
781
|
-
current_span_data["output"] = entry["output"]
|
782
|
-
|
783
|
-
elif entry["type"] == "evaluation" and entry.get("evaluation_runs"):
|
784
|
-
if current_span_data.get("evaluation_runs") is not None:
|
785
|
-
evaluation_runs.extend(entry["evaluation_runs"])
|
786
|
-
|
787
|
-
elif entry["type"] == "exit":
|
788
|
-
if current_span_data["duration"] is None: # Calculate duration only once
|
789
|
-
start_time = datetime.fromisoformat(current_span_data.get("created_at", entry["created_at"]))
|
790
|
-
end_time = datetime.fromisoformat(entry["created_at"])
|
791
|
-
current_span_data["duration"] = (end_time - start_time).total_seconds()
|
792
|
-
# Update depth if exit depth is different (though current span() implementation keeps it same)
|
793
|
-
# current_span_data["depth"] = entry["depth"]
|
794
|
-
|
795
|
-
# Convert dictionary to a list initially for easier access
|
796
|
-
spans_list = list(spans_by_id.values())
|
797
|
-
|
798
|
-
# Build tree structure (adjacency list) and find roots
|
799
|
-
children_map: Dict[Optional[str], List[dict]] = {}
|
800
|
-
roots = []
|
801
|
-
span_map = {span['span_id']: span for span in spans_list} # Map for quick lookup
|
802
|
-
|
803
|
-
for span in spans_list:
|
804
|
-
parent_id = span.get("parent_span_id")
|
805
|
-
if parent_id is None:
|
806
|
-
roots.append(span)
|
807
|
-
else:
|
808
|
-
if parent_id not in children_map:
|
809
|
-
children_map[parent_id] = []
|
810
|
-
children_map[parent_id].append(span)
|
811
|
-
|
812
|
-
# Sort roots by timestamp
|
813
|
-
roots.sort(key=lambda x: datetime.fromisoformat(x.get("created_at", "1970-01-01T00:00:00")))
|
814
|
-
|
815
|
-
# Perform depth-first traversal to get the final sorted list
|
816
|
-
sorted_condensed_list = []
|
817
|
-
visited = set() # To handle potential cycles, though unlikely with UUIDs
|
818
|
-
|
819
|
-
def dfs(span_data):
|
820
|
-
span_id = span_data['span_id']
|
821
|
-
if span_id in visited:
|
822
|
-
return # Avoid infinite loops in case of cycles
|
823
|
-
visited.add(span_id)
|
824
|
-
|
825
|
-
sorted_condensed_list.append(span_data) # Add parent before children
|
826
|
-
|
827
|
-
# Get children, sort them by created_at, and visit them
|
828
|
-
span_children = children_map.get(span_id, [])
|
829
|
-
span_children.sort(key=lambda x: datetime.fromisoformat(x.get("created_at", "1970-01-01T00:00:00")))
|
830
|
-
for child in span_children:
|
831
|
-
# Ensure the child exists in our map before recursing
|
832
|
-
if child['span_id'] in span_map:
|
833
|
-
dfs(child)
|
834
|
-
else:
|
835
|
-
# This case might indicate an issue, but we'll add the child directly
|
836
|
-
# if its parent was processed but the child itself wasn't in the initial list?
|
837
|
-
# Or if the child's 'enter' event was missing. For robustness, add it.
|
838
|
-
if child['span_id'] not in visited:
|
839
|
-
visited.add(child['span_id'])
|
840
|
-
sorted_condensed_list.append(child)
|
841
|
-
|
842
|
-
|
843
|
-
# Start DFS from each root
|
844
|
-
for root_span in roots:
|
845
|
-
if root_span['span_id'] not in visited:
|
846
|
-
dfs(root_span)
|
847
|
-
|
848
|
-
# Handle spans that might not have been reachable from roots (orphans)
|
849
|
-
# Though ideally, all spans should descend from a root.
|
850
|
-
for span_data in spans_list:
|
851
|
-
if span_data['span_id'] not in visited:
|
852
|
-
# Decide how to handle orphans, maybe append them at the end sorted by time?
|
853
|
-
# For now, let's just add them to ensure they aren't lost.
|
854
|
-
sorted_condensed_list.append(span_data)
|
855
|
-
|
856
|
-
|
857
|
-
return sorted_condensed_list, evaluation_runs
|
858
518
|
|
859
519
|
def save(self, overwrite: bool = False) -> Tuple[str, dict]:
|
860
520
|
"""
|
@@ -863,44 +523,36 @@ class TraceClient:
|
|
863
523
|
"""
|
864
524
|
# Calculate total elapsed time
|
865
525
|
total_duration = self.get_duration()
|
866
|
-
|
867
|
-
raw_entries = [entry.to_dict() for entry in self.entries]
|
868
|
-
|
869
|
-
condensed_entries, evaluation_runs = self.condense_trace(raw_entries)
|
870
526
|
|
871
527
|
# Only count tokens for actual LLM API call spans
|
872
528
|
llm_span_names = {"OPENAI_API_CALL", "TOGETHER_API_CALL", "ANTHROPIC_API_CALL", "GOOGLE_API_CALL"}
|
873
|
-
for
|
874
|
-
|
529
|
+
for span in self.trace_spans:
|
530
|
+
span_function_name = span.function # Get function name safely
|
875
531
|
# Check if it's an LLM span AND function name CONTAINS an API call suffix AND output is dict
|
876
|
-
|
877
|
-
has_api_suffix = any(suffix in
|
878
|
-
output_is_dict = isinstance(
|
532
|
+
is_llm_span = span.span_type == "llm"
|
533
|
+
has_api_suffix = any(suffix in span_function_name for suffix in llm_span_names)
|
534
|
+
output_is_dict = isinstance(span.output, dict)
|
879
535
|
|
880
536
|
# --- DEBUG PRINT 1: Check if condition passes ---
|
881
537
|
# if is_llm_entry and has_api_suffix and output_is_dict:
|
882
|
-
# # print(f"[DEBUG TraceClient.save] Processing entry: {entry.get('span_id')} ({entry_function_name}) - Condition PASSED")
|
883
538
|
# elif is_llm_entry:
|
884
539
|
# # Print why it failed if it was an LLM entry
|
885
|
-
# print(f"[DEBUG TraceClient.save] Skipping LLM entry: {entry.get('span_id')} ({entry_function_name}) - Suffix Match: {has_api_suffix}, Output is Dict: {output_is_dict}")
|
886
540
|
# # --- END DEBUG ---
|
887
541
|
|
888
|
-
if
|
889
|
-
output =
|
542
|
+
if is_llm_span and has_api_suffix and output_is_dict:
|
543
|
+
output = span.output
|
890
544
|
usage = output.get("usage", {}) # Gets the 'usage' dict from the 'output' field
|
891
545
|
|
892
546
|
# --- DEBUG PRINT 2: Check extracted usage ---
|
893
|
-
# print(f"[DEBUG TraceClient.save] Extracted usage dict: {usage}")
|
894
547
|
# --- END DEBUG ---
|
895
548
|
|
896
549
|
# --- NEW: Extract model_name correctly from nested inputs ---
|
897
550
|
model_name = None
|
898
|
-
|
899
|
-
|
900
|
-
if entry_inputs:
|
551
|
+
span_inputs = span.inputs
|
552
|
+
if span_inputs:
|
901
553
|
# Try common locations for model name within the inputs structure
|
902
|
-
invocation_params =
|
903
|
-
serialized_data =
|
554
|
+
invocation_params = span_inputs.get("invocation_params", {})
|
555
|
+
serialized_data = span_inputs.get("serialized", {})
|
904
556
|
|
905
557
|
# Look in invocation_params (often directly contains model)
|
906
558
|
if isinstance(invocation_params, dict):
|
@@ -920,10 +572,9 @@ class TraceClient:
|
|
920
572
|
|
921
573
|
# Fallback: Check top-level of inputs itself (less likely for callbacks)
|
922
574
|
if not model_name:
|
923
|
-
model_name =
|
575
|
+
model_name = span_inputs.get("model")
|
924
576
|
|
925
577
|
|
926
|
-
# print(f"[DEBUG TraceClient.save] Determined model_name: {model_name}") # DEBUG Model Name
|
927
578
|
# --- END NEW ---
|
928
579
|
|
929
580
|
prompt_tokens = 0
|
@@ -985,7 +636,7 @@ class TraceClient:
|
|
985
636
|
if "usage" not in output:
|
986
637
|
output["usage"] = {} # Initialize if missing
|
987
638
|
elif not isinstance(output["usage"], dict): # Handle cases where 'usage' might not be a dict (e.g., placeholder string)
|
988
|
-
print(f"[WARN TraceClient.save] Output 'usage' for span {
|
639
|
+
print(f"[WARN TraceClient.save] Output 'usage' for span {span.span_id} was not a dict ({type(output['usage'])}). Resetting before adding costs.")
|
989
640
|
output["usage"] = {} # Reset to dict
|
990
641
|
|
991
642
|
output["usage"]["prompt_tokens_cost_usd"] = prompt_cost
|
@@ -993,10 +644,10 @@ class TraceClient:
|
|
993
644
|
output["usage"]["total_cost_usd"] = prompt_cost + completion_cost
|
994
645
|
except Exception as e:
|
995
646
|
# If cost calculation fails, continue without adding costs
|
996
|
-
print(f"Error calculating cost for model '{model_name}' (span: {
|
647
|
+
print(f"Error calculating cost for model '{model_name}' (span: {span.span_id}): {str(e)}")
|
997
648
|
pass
|
998
649
|
else:
|
999
|
-
print(f"[WARN TraceClient.save] Could not determine model name for cost calculation (span: {
|
650
|
+
print(f"[WARN TraceClient.save] Could not determine model name for cost calculation (span: {span.span_id}). Inputs: {span_inputs}")
|
1000
651
|
|
1001
652
|
|
1002
653
|
# Create trace document - Always use standard keys for top-level counts
|
@@ -1006,20 +657,258 @@ class TraceClient:
|
|
1006
657
|
"project_name": self.project_name,
|
1007
658
|
"created_at": datetime.utcfromtimestamp(self.start_time).isoformat(),
|
1008
659
|
"duration": total_duration,
|
1009
|
-
"entries":
|
1010
|
-
"evaluation_runs": evaluation_runs,
|
660
|
+
"entries": [span.model_dump() for span in self.trace_spans],
|
661
|
+
"evaluation_runs": [run.model_dump() for run in self.evaluation_runs],
|
1011
662
|
"overwrite": overwrite,
|
663
|
+
"offline_mode": self.tracer.offline_mode,
|
1012
664
|
"parent_trace_id": self.parent_trace_id,
|
1013
665
|
"parent_name": self.parent_name
|
1014
666
|
}
|
1015
667
|
# --- Log trace data before saving ---
|
1016
|
-
self.trace_manager_client.save_trace(trace_data)
|
668
|
+
self.trace_manager_client.save_trace(trace_data, offline_mode=self.tracer.offline_mode)
|
669
|
+
|
670
|
+
# upload annotations
|
671
|
+
# TODO: batch to the log endpoint
|
672
|
+
for annotation in self.annotations:
|
673
|
+
self.trace_manager_client.save_annotation(annotation)
|
1017
674
|
|
1018
675
|
return self.trace_id, trace_data
|
1019
676
|
|
1020
677
|
def delete(self):
|
1021
678
|
return self.trace_manager_client.delete_trace(self.trace_id)
|
1022
679
|
|
680
|
+
|
681
|
+
class _DeepTracer:
|
682
|
+
_instance: Optional["_DeepTracer"] = None
|
683
|
+
_lock: threading.Lock = threading.Lock()
|
684
|
+
_refcount: int = 0
|
685
|
+
_span_stack: contextvars.ContextVar[List[Dict[str, Any]]] = contextvars.ContextVar("_deep_profiler_span_stack", default=[])
|
686
|
+
_skip_stack: contextvars.ContextVar[List[str]] = contextvars.ContextVar("_deep_profiler_skip_stack", default=[])
|
687
|
+
|
688
|
+
def _get_qual_name(self, frame) -> str:
|
689
|
+
func_name = frame.f_code.co_name
|
690
|
+
module_name = frame.f_globals.get("__name__", "unknown_module")
|
691
|
+
|
692
|
+
try:
|
693
|
+
func = frame.f_globals.get(func_name)
|
694
|
+
if func is None:
|
695
|
+
return f"{module_name}.{func_name}"
|
696
|
+
if hasattr(func, "__qualname__"):
|
697
|
+
return f"{module_name}.{func.__qualname__}"
|
698
|
+
except Exception:
|
699
|
+
return f"{module_name}.{func_name}"
|
700
|
+
|
701
|
+
def __new__(cls):
|
702
|
+
with cls._lock:
|
703
|
+
if cls._instance is None:
|
704
|
+
cls._instance = super().__new__(cls)
|
705
|
+
return cls._instance
|
706
|
+
|
707
|
+
def _should_trace(self, frame):
|
708
|
+
# Skip stack is maintained by the tracer as an optimization to skip earlier
|
709
|
+
# frames in the call stack that we've already determined should be skipped
|
710
|
+
skip_stack = self._skip_stack.get()
|
711
|
+
if len(skip_stack) > 0:
|
712
|
+
return False
|
713
|
+
|
714
|
+
func_name = frame.f_code.co_name
|
715
|
+
module_name = frame.f_globals.get("__name__", None)
|
716
|
+
|
717
|
+
func = frame.f_globals.get(func_name)
|
718
|
+
if func and (hasattr(func, '_judgment_span_name') or hasattr(func, '_judgment_span_type')):
|
719
|
+
return False
|
720
|
+
|
721
|
+
if (
|
722
|
+
not module_name
|
723
|
+
or func_name.startswith("<") # ex: <listcomp>
|
724
|
+
or func_name.startswith("__") and func_name != "__call__" # dunders
|
725
|
+
or not self._is_user_code(frame.f_code.co_filename)
|
726
|
+
):
|
727
|
+
return False
|
728
|
+
|
729
|
+
return True
|
730
|
+
|
731
|
+
@functools.cache
|
732
|
+
def _is_user_code(self, filename: str):
|
733
|
+
return bool(filename) and not filename.startswith("<") and not os.path.realpath(filename).startswith(_TRACE_FILEPATH_BLOCKLIST)
|
734
|
+
|
735
|
+
def _trace(self, frame: types.FrameType, event: str, arg: Any):
|
736
|
+
frame.f_trace_lines = False
|
737
|
+
frame.f_trace_opcodes = False
|
738
|
+
|
739
|
+
|
740
|
+
if not self._should_trace(frame):
|
741
|
+
return
|
742
|
+
|
743
|
+
if event not in ("call", "return", "exception"):
|
744
|
+
return
|
745
|
+
|
746
|
+
current_trace = current_trace_var.get()
|
747
|
+
if not current_trace:
|
748
|
+
return
|
749
|
+
|
750
|
+
parent_span_id = current_span_var.get()
|
751
|
+
if not parent_span_id:
|
752
|
+
return
|
753
|
+
|
754
|
+
qual_name = self._get_qual_name(frame)
|
755
|
+
skip_stack = self._skip_stack.get()
|
756
|
+
|
757
|
+
if event == "call":
|
758
|
+
# If we have entries in the skip stack and the current qual_name matches the top entry,
|
759
|
+
# push it again to track nesting depth and skip
|
760
|
+
# As an optimization, we only care about duplicate qual_names.
|
761
|
+
if skip_stack:
|
762
|
+
if qual_name == skip_stack[-1]:
|
763
|
+
skip_stack.append(qual_name)
|
764
|
+
self._skip_stack.set(skip_stack)
|
765
|
+
return
|
766
|
+
|
767
|
+
should_trace = self._should_trace(frame)
|
768
|
+
|
769
|
+
if not should_trace:
|
770
|
+
if not skip_stack:
|
771
|
+
self._skip_stack.set([qual_name])
|
772
|
+
return
|
773
|
+
elif event == "return":
|
774
|
+
# If we have entries in skip stack and current qual_name matches the top entry,
|
775
|
+
# pop it to track exiting from the skipped section
|
776
|
+
if skip_stack and qual_name == skip_stack[-1]:
|
777
|
+
skip_stack.pop()
|
778
|
+
self._skip_stack.set(skip_stack)
|
779
|
+
return
|
780
|
+
|
781
|
+
if skip_stack:
|
782
|
+
return
|
783
|
+
|
784
|
+
span_stack = self._span_stack.get()
|
785
|
+
if event == "call":
|
786
|
+
if not self._should_trace(frame):
|
787
|
+
return
|
788
|
+
|
789
|
+
span_id = str(uuid.uuid4())
|
790
|
+
|
791
|
+
parent_depth = current_trace._span_depths.get(parent_span_id, 0)
|
792
|
+
depth = parent_depth + 1
|
793
|
+
|
794
|
+
current_trace._span_depths[span_id] = depth
|
795
|
+
|
796
|
+
start_time = time.time()
|
797
|
+
|
798
|
+
span_stack.append({
|
799
|
+
"span_id": span_id,
|
800
|
+
"parent_span_id": parent_span_id,
|
801
|
+
"function": qual_name,
|
802
|
+
"start_time": start_time
|
803
|
+
})
|
804
|
+
self._span_stack.set(span_stack)
|
805
|
+
|
806
|
+
token = current_span_var.set(span_id)
|
807
|
+
frame.f_locals["_judgment_span_token"] = token
|
808
|
+
|
809
|
+
span = TraceSpan(
|
810
|
+
span_id=span_id,
|
811
|
+
trace_id=current_trace.trace_id,
|
812
|
+
depth=depth,
|
813
|
+
message=qual_name,
|
814
|
+
created_at=start_time,
|
815
|
+
span_type="span",
|
816
|
+
parent_span_id=parent_span_id,
|
817
|
+
function=qual_name
|
818
|
+
)
|
819
|
+
current_trace.add_span(span)
|
820
|
+
|
821
|
+
inputs = {}
|
822
|
+
try:
|
823
|
+
args_info = inspect.getargvalues(frame)
|
824
|
+
for arg in args_info.args:
|
825
|
+
try:
|
826
|
+
inputs[arg] = args_info.locals.get(arg)
|
827
|
+
except:
|
828
|
+
inputs[arg] = "<<Unserializable>>"
|
829
|
+
current_trace.record_input(inputs)
|
830
|
+
except Exception as e:
|
831
|
+
current_trace.record_input({
|
832
|
+
"error": str(e)
|
833
|
+
})
|
834
|
+
|
835
|
+
elif event == "return":
|
836
|
+
if not span_stack:
|
837
|
+
return
|
838
|
+
|
839
|
+
current_id = current_span_var.get()
|
840
|
+
|
841
|
+
span_data = None
|
842
|
+
for i, entry in enumerate(reversed(span_stack)):
|
843
|
+
if entry["span_id"] == current_id:
|
844
|
+
span_data = span_stack.pop(-(i+1))
|
845
|
+
self._span_stack.set(span_stack)
|
846
|
+
break
|
847
|
+
|
848
|
+
if not span_data:
|
849
|
+
return
|
850
|
+
|
851
|
+
start_time = span_data["start_time"]
|
852
|
+
duration = time.time() - start_time
|
853
|
+
|
854
|
+
current_trace.span_id_to_span[span_data["span_id"]].duration = duration
|
855
|
+
|
856
|
+
if arg is not None:
|
857
|
+
# exception handling will take priority.
|
858
|
+
current_trace.record_output(arg)
|
859
|
+
|
860
|
+
if span_data["span_id"] in current_trace._span_depths:
|
861
|
+
del current_trace._span_depths[span_data["span_id"]]
|
862
|
+
|
863
|
+
if span_stack:
|
864
|
+
current_span_var.set(span_stack[-1]["span_id"])
|
865
|
+
else:
|
866
|
+
current_span_var.set(span_data["parent_span_id"])
|
867
|
+
|
868
|
+
if "_judgment_span_token" in frame.f_locals:
|
869
|
+
current_span_var.reset(frame.f_locals["_judgment_span_token"])
|
870
|
+
|
871
|
+
elif event == "exception":
|
872
|
+
exc_type, exc_value, exc_traceback = arg
|
873
|
+
formatted_exception = {
|
874
|
+
"type": exc_type.__name__,
|
875
|
+
"message": str(exc_value),
|
876
|
+
"traceback": traceback.format_tb(exc_traceback)
|
877
|
+
}
|
878
|
+
current_trace = current_trace_var.get()
|
879
|
+
current_trace.record_output({
|
880
|
+
"error": formatted_exception
|
881
|
+
})
|
882
|
+
|
883
|
+
return self._trace
|
884
|
+
|
885
|
+
def __enter__(self):
|
886
|
+
with self._lock:
|
887
|
+
self._refcount += 1
|
888
|
+
if self._refcount == 1:
|
889
|
+
self._skip_stack.set([])
|
890
|
+
self._span_stack.set([])
|
891
|
+
sys.settrace(self._trace)
|
892
|
+
threading.settrace(self._trace)
|
893
|
+
return self
|
894
|
+
|
895
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
896
|
+
with self._lock:
|
897
|
+
self._refcount -= 1
|
898
|
+
if self._refcount == 0:
|
899
|
+
sys.settrace(None)
|
900
|
+
threading.settrace(None)
|
901
|
+
|
902
|
+
|
903
|
+
def log(self, message: str, level: str = "info"):
|
904
|
+
""" Log a message with the span context """
|
905
|
+
current_trace = current_trace_var.get()
|
906
|
+
if current_trace:
|
907
|
+
current_trace.log(message, level)
|
908
|
+
else:
|
909
|
+
print(f"[{level}] {message}")
|
910
|
+
current_trace.record_output({"log": message})
|
911
|
+
|
1023
912
|
class Tracer:
|
1024
913
|
_instance = None
|
1025
914
|
|
@@ -1042,12 +931,17 @@ class Tracer:
|
|
1042
931
|
s3_aws_access_key_id: Optional[str] = None,
|
1043
932
|
s3_aws_secret_access_key: Optional[str] = None,
|
1044
933
|
s3_region_name: Optional[str] = None,
|
1045
|
-
|
934
|
+
offline_mode: bool = False,
|
935
|
+
deep_tracing: bool = True # Deep tracing is enabled by default
|
1046
936
|
):
|
1047
937
|
if not hasattr(self, 'initialized'):
|
1048
938
|
if not api_key:
|
1049
939
|
raise ValueError("Tracer must be configured with a Judgment API key")
|
1050
940
|
|
941
|
+
result, response = validate_api_key(api_key)
|
942
|
+
if not result:
|
943
|
+
raise JudgmentAPIError(f"Issue with passed in Judgment API key: {response}")
|
944
|
+
|
1051
945
|
if not organization_id:
|
1052
946
|
raise ValueError("Tracer must be configured with an Organization ID")
|
1053
947
|
if use_s3 and not s3_bucket_name:
|
@@ -1059,11 +953,11 @@ class Tracer:
|
|
1059
953
|
|
1060
954
|
self.api_key: str = api_key
|
1061
955
|
self.project_name: str = project_name
|
1062
|
-
self.client: JudgmentClient = JudgmentClient(judgment_api_key=api_key)
|
1063
956
|
self.organization_id: str = organization_id
|
1064
957
|
self._current_trace: Optional[str] = None
|
1065
958
|
self._active_trace_client: Optional[TraceClient] = None # Add active trace client attribute
|
1066
959
|
self.rules: List[Rule] = rules or [] # Store rules at tracer level
|
960
|
+
self.traces: List[Trace] = []
|
1067
961
|
self.initialized: bool = True
|
1068
962
|
self.enable_monitoring: bool = enable_monitoring
|
1069
963
|
self.enable_evaluations: bool = enable_evaluations
|
@@ -1078,6 +972,7 @@ class Tracer:
|
|
1078
972
|
aws_secret_access_key=s3_aws_secret_access_key,
|
1079
973
|
region_name=s3_region_name
|
1080
974
|
)
|
975
|
+
self.offline_mode: bool = offline_mode
|
1081
976
|
self.deep_tracing: bool = deep_tracing # NEW: Store deep tracing setting
|
1082
977
|
|
1083
978
|
elif hasattr(self, 'project_name') and self.project_name != project_name:
|
@@ -1087,6 +982,12 @@ class Tracer:
|
|
1087
982
|
"To use a different project name, ensure the first Tracer initialization uses the desired project name.",
|
1088
983
|
RuntimeWarning
|
1089
984
|
)
|
985
|
+
|
986
|
+
def set_current_span(self, span_id: str):
|
987
|
+
self.current_span_id = span_id
|
988
|
+
|
989
|
+
def get_current_span(self) -> Optional[str]:
|
990
|
+
return getattr(self, 'current_span_id', None)
|
1090
991
|
|
1091
992
|
def set_current_trace(self, trace: TraceClient):
|
1092
993
|
"""
|
@@ -1119,45 +1020,6 @@ class Tracer:
|
|
1119
1020
|
"""Returns the TraceClient instance currently marked as active by the handler."""
|
1120
1021
|
return self._active_trace_client
|
1121
1022
|
|
1122
|
-
def _apply_deep_tracing(self, func, span_type="span"):
|
1123
|
-
"""
|
1124
|
-
Apply deep tracing to all functions in the same module as the given function.
|
1125
|
-
|
1126
|
-
Args:
|
1127
|
-
func: The function being traced
|
1128
|
-
span_type: Type of span to use for traced functions
|
1129
|
-
|
1130
|
-
Returns:
|
1131
|
-
A tuple of (module, original_functions_dict) where original_functions_dict
|
1132
|
-
contains the original functions that were replaced with traced versions.
|
1133
|
-
"""
|
1134
|
-
module = inspect.getmodule(func)
|
1135
|
-
if not module:
|
1136
|
-
return None, {}
|
1137
|
-
|
1138
|
-
# Save original functions
|
1139
|
-
original_functions = {}
|
1140
|
-
|
1141
|
-
# Find all functions in the module
|
1142
|
-
for name, obj in inspect.getmembers(module, inspect.isfunction):
|
1143
|
-
# Skip already wrapped functions
|
1144
|
-
if hasattr(obj, '_judgment_traced'):
|
1145
|
-
continue
|
1146
|
-
|
1147
|
-
# Create a traced version of the function
|
1148
|
-
# Always use default span type "span" for child functions
|
1149
|
-
traced_func = _create_deep_tracing_wrapper(obj, self, "span")
|
1150
|
-
|
1151
|
-
# Mark the function as traced to avoid double wrapping
|
1152
|
-
traced_func._judgment_traced = True
|
1153
|
-
|
1154
|
-
# Save the original function
|
1155
|
-
original_functions[name] = obj
|
1156
|
-
|
1157
|
-
# Replace with traced version
|
1158
|
-
setattr(module, name, traced_func)
|
1159
|
-
|
1160
|
-
return module, original_functions
|
1161
1023
|
|
1162
1024
|
@contextmanager
|
1163
1025
|
def trace(
|
@@ -1204,6 +1066,23 @@ class Tracer:
|
|
1204
1066
|
finally:
|
1205
1067
|
# Reset the context variable
|
1206
1068
|
current_trace_var.reset(token)
|
1069
|
+
|
1070
|
+
|
1071
|
+
def log(self, msg: str, label: str = "log", score: int = 1):
|
1072
|
+
"""Log a message with the current span context"""
|
1073
|
+
current_span_id = current_span_var.get()
|
1074
|
+
current_trace = current_trace_var.get()
|
1075
|
+
if current_span_id:
|
1076
|
+
annotation = TraceAnnotation(
|
1077
|
+
span_id=current_span_id,
|
1078
|
+
text=msg,
|
1079
|
+
label=label,
|
1080
|
+
score=score
|
1081
|
+
)
|
1082
|
+
|
1083
|
+
current_trace.add_annotation(annotation)
|
1084
|
+
|
1085
|
+
rprint(f"[bold]{label}:[/bold] {msg}")
|
1207
1086
|
|
1208
1087
|
def observe(self, func=None, *, name=None, span_type: SpanType = "span", project_name: str = None, overwrite: bool = False, deep_tracing: bool = None):
|
1209
1088
|
"""
|
@@ -1239,13 +1118,6 @@ class Tracer:
|
|
1239
1118
|
if asyncio.iscoroutinefunction(func):
|
1240
1119
|
@functools.wraps(func)
|
1241
1120
|
async def async_wrapper(*args, **kwargs):
|
1242
|
-
# Check if we're already in a traced function
|
1243
|
-
if in_traced_function_var.get():
|
1244
|
-
return await func(*args, **kwargs)
|
1245
|
-
|
1246
|
-
# Set in_traced_function_var to True
|
1247
|
-
token = in_traced_function_var.set(True)
|
1248
|
-
|
1249
1121
|
# Get current trace from context
|
1250
1122
|
current_trace = current_trace_var.get()
|
1251
1123
|
|
@@ -1275,81 +1147,47 @@ class Tracer:
|
|
1275
1147
|
# This sets the current_span_var
|
1276
1148
|
with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
|
1277
1149
|
# Record inputs
|
1278
|
-
|
1279
|
-
|
1280
|
-
'kwargs': kwargs
|
1281
|
-
})
|
1150
|
+
inputs = combine_args_kwargs(func, args, kwargs)
|
1151
|
+
span.record_input(inputs)
|
1282
1152
|
|
1283
|
-
# If deep tracing is enabled, apply monkey patching
|
1284
1153
|
if use_deep_tracing:
|
1285
|
-
|
1286
|
-
|
1287
|
-
|
1288
|
-
|
1289
|
-
|
1290
|
-
# Restore original functions if deep tracing was enabled
|
1291
|
-
if use_deep_tracing and module and 'original_functions' in locals():
|
1292
|
-
for name, obj in original_functions.items():
|
1293
|
-
setattr(module, name, obj)
|
1294
|
-
|
1154
|
+
with _DeepTracer():
|
1155
|
+
result = await func(*args, **kwargs)
|
1156
|
+
else:
|
1157
|
+
result = await func(*args, **kwargs)
|
1158
|
+
|
1295
1159
|
# Record output
|
1296
1160
|
span.record_output(result)
|
1297
|
-
|
1298
|
-
# Save the completed trace
|
1299
|
-
current_trace.save(overwrite=overwrite)
|
1300
1161
|
return result
|
1301
1162
|
finally:
|
1163
|
+
# Save the completed trace
|
1164
|
+
trace_id, trace = current_trace.save(overwrite=overwrite)
|
1165
|
+
self.traces.append(trace)
|
1166
|
+
|
1302
1167
|
# Reset trace context (span context resets automatically)
|
1303
1168
|
current_trace_var.reset(trace_token)
|
1304
|
-
# Reset in_traced_function_var
|
1305
|
-
in_traced_function_var.reset(token)
|
1306
1169
|
else:
|
1307
|
-
|
1308
|
-
|
1309
|
-
|
1310
|
-
|
1311
|
-
|
1312
|
-
|
1313
|
-
|
1314
|
-
|
1315
|
-
'kwargs': kwargs
|
1316
|
-
})
|
1317
|
-
|
1318
|
-
# If deep tracing is enabled, apply monkey patching
|
1319
|
-
if use_deep_tracing:
|
1320
|
-
module, original_functions = self._apply_deep_tracing(func, span_type)
|
1321
|
-
|
1322
|
-
# Execute function
|
1170
|
+
with current_trace.span(span_name, span_type=span_type) as span:
|
1171
|
+
inputs = combine_args_kwargs(func, args, kwargs)
|
1172
|
+
span.record_input(inputs)
|
1173
|
+
|
1174
|
+
if use_deep_tracing:
|
1175
|
+
with _DeepTracer():
|
1176
|
+
result = await func(*args, **kwargs)
|
1177
|
+
else:
|
1323
1178
|
result = await func(*args, **kwargs)
|
1324
1179
|
|
1325
|
-
|
1326
|
-
|
1327
|
-
|
1328
|
-
setattr(module, name, obj)
|
1329
|
-
|
1330
|
-
# Record output
|
1331
|
-
span.record_output(result)
|
1332
|
-
|
1333
|
-
return result
|
1334
|
-
finally:
|
1335
|
-
# Reset in_traced_function_var
|
1336
|
-
in_traced_function_var.reset(token)
|
1337
|
-
|
1180
|
+
span.record_output(result)
|
1181
|
+
return result
|
1182
|
+
|
1338
1183
|
return async_wrapper
|
1339
1184
|
else:
|
1340
1185
|
# Non-async function implementation with deep tracing
|
1341
1186
|
@functools.wraps(func)
|
1342
|
-
def wrapper(*args, **kwargs):
|
1343
|
-
# Check if we're already in a traced function
|
1344
|
-
if in_traced_function_var.get():
|
1345
|
-
return func(*args, **kwargs)
|
1346
|
-
|
1347
|
-
# Set in_traced_function_var to True
|
1348
|
-
token = in_traced_function_var.set(True)
|
1349
|
-
|
1187
|
+
def wrapper(*args, **kwargs):
|
1350
1188
|
# Get current trace from context
|
1351
1189
|
current_trace = current_trace_var.get()
|
1352
|
-
|
1190
|
+
|
1353
1191
|
# If there's no current trace, create a root trace
|
1354
1192
|
if not current_trace:
|
1355
1193
|
trace_id = str(uuid.uuid4())
|
@@ -1376,66 +1214,40 @@ class Tracer:
|
|
1376
1214
|
# This sets the current_span_var
|
1377
1215
|
with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
|
1378
1216
|
# Record inputs
|
1379
|
-
|
1380
|
-
|
1381
|
-
'kwargs': kwargs
|
1382
|
-
})
|
1217
|
+
inputs = combine_args_kwargs(func, args, kwargs)
|
1218
|
+
span.record_input(inputs)
|
1383
1219
|
|
1384
|
-
# If deep tracing is enabled, apply monkey patching
|
1385
1220
|
if use_deep_tracing:
|
1386
|
-
|
1387
|
-
|
1388
|
-
|
1389
|
-
|
1390
|
-
|
1391
|
-
# Restore original functions if deep tracing was enabled
|
1392
|
-
if use_deep_tracing and module and 'original_functions' in locals():
|
1393
|
-
for name, obj in original_functions.items():
|
1394
|
-
setattr(module, name, obj)
|
1221
|
+
with _DeepTracer():
|
1222
|
+
result = func(*args, **kwargs)
|
1223
|
+
else:
|
1224
|
+
result = func(*args, **kwargs)
|
1395
1225
|
|
1396
1226
|
# Record output
|
1397
1227
|
span.record_output(result)
|
1398
|
-
|
1399
|
-
# Save the completed trace
|
1400
|
-
current_trace.save(overwrite=overwrite)
|
1401
1228
|
return result
|
1402
1229
|
finally:
|
1230
|
+
# Save the completed trace
|
1231
|
+
trace_id, trace = current_trace.save(overwrite=overwrite)
|
1232
|
+
self.traces.append(trace)
|
1233
|
+
|
1403
1234
|
# Reset trace context (span context resets automatically)
|
1404
1235
|
current_trace_var.reset(trace_token)
|
1405
|
-
# Reset in_traced_function_var
|
1406
|
-
in_traced_function_var.reset(token)
|
1407
1236
|
else:
|
1408
|
-
|
1409
|
-
|
1410
|
-
|
1411
|
-
|
1412
|
-
|
1413
|
-
|
1414
|
-
|
1415
|
-
|
1416
|
-
|
1417
|
-
})
|
1418
|
-
|
1419
|
-
# If deep tracing is enabled, apply monkey patching
|
1420
|
-
if use_deep_tracing:
|
1421
|
-
module, original_functions = self._apply_deep_tracing(func, span_type)
|
1422
|
-
|
1423
|
-
# Execute function
|
1237
|
+
with current_trace.span(span_name, span_type=span_type) as span:
|
1238
|
+
|
1239
|
+
inputs = combine_args_kwargs(func, args, kwargs)
|
1240
|
+
span.record_input(inputs)
|
1241
|
+
|
1242
|
+
if use_deep_tracing:
|
1243
|
+
with _DeepTracer():
|
1244
|
+
result = func(*args, **kwargs)
|
1245
|
+
else:
|
1424
1246
|
result = func(*args, **kwargs)
|
1425
1247
|
|
1426
|
-
|
1427
|
-
|
1428
|
-
|
1429
|
-
setattr(module, name, obj)
|
1430
|
-
|
1431
|
-
# Record output
|
1432
|
-
span.record_output(result)
|
1433
|
-
|
1434
|
-
return result
|
1435
|
-
finally:
|
1436
|
-
# Reset in_traced_function_var
|
1437
|
-
in_traced_function_var.reset(token)
|
1438
|
-
|
1248
|
+
span.record_output(result)
|
1249
|
+
return result
|
1250
|
+
|
1439
1251
|
return wrapper
|
1440
1252
|
|
1441
1253
|
def async_evaluate(self, *args, **kwargs):
|
@@ -1462,64 +1274,94 @@ class Tracer:
|
|
1462
1274
|
else:
|
1463
1275
|
warnings.warn("No trace found (context var or fallback), skipping evaluation") # Modified warning
|
1464
1276
|
|
1465
|
-
|
1466
1277
|
def wrap(client: Any) -> Any:
|
1467
1278
|
"""
|
1468
1279
|
Wraps an API client to add tracing capabilities.
|
1469
1280
|
Supports OpenAI, Together, Anthropic, and Google GenAI clients.
|
1470
1281
|
Patches both '.create' and Anthropic's '.stream' methods using a wrapper class.
|
1471
1282
|
"""
|
1472
|
-
span_name, original_create, original_stream = _get_client_config(client)
|
1283
|
+
span_name, original_create, original_responses_create, original_stream = _get_client_config(client)
|
1284
|
+
|
1285
|
+
def _record_input_and_check_streaming(span, kwargs, is_responses=False):
|
1286
|
+
"""Record input and check for streaming"""
|
1287
|
+
is_streaming = kwargs.get("stream", False)
|
1473
1288
|
|
1474
|
-
|
1289
|
+
# Record input based on whether this is a responses endpoint
|
1290
|
+
if is_responses:
|
1291
|
+
span.record_input(kwargs)
|
1292
|
+
else:
|
1293
|
+
input_data = _format_input_data(client, **kwargs)
|
1294
|
+
span.record_input(input_data)
|
1295
|
+
|
1296
|
+
# Warn about token counting limitations with streaming
|
1297
|
+
if isinstance(client, (AsyncOpenAI, OpenAI)) and is_streaming:
|
1298
|
+
if not kwargs.get("stream_options", {}).get("include_usage"):
|
1299
|
+
warnings.warn(
|
1300
|
+
"OpenAI streaming calls don't include token counts by default. "
|
1301
|
+
"To enable token counting with streams, set stream_options={'include_usage': True} "
|
1302
|
+
"in your API call arguments.",
|
1303
|
+
UserWarning
|
1304
|
+
)
|
1305
|
+
|
1306
|
+
return is_streaming
|
1307
|
+
|
1308
|
+
def _format_and_record_output(span, response, is_streaming, is_async, is_responses):
|
1309
|
+
"""Format and record the output in the span"""
|
1310
|
+
if is_streaming:
|
1311
|
+
output_entry = span.record_output("<pending stream>")
|
1312
|
+
wrapper_func = _async_stream_wrapper if is_async else _sync_stream_wrapper
|
1313
|
+
return wrapper_func(response, client, output_entry)
|
1314
|
+
else:
|
1315
|
+
format_func = _format_response_output_data if is_responses else _format_output_data
|
1316
|
+
output_data = format_func(client, response)
|
1317
|
+
span.record_output(output_data)
|
1318
|
+
return response
|
1319
|
+
|
1320
|
+
def _handle_error(span, e, is_async):
|
1321
|
+
"""Handle and record errors"""
|
1322
|
+
call_type = "async" if is_async else "sync"
|
1323
|
+
print(f"Error during wrapped {call_type} API call ({span_name}): {e}")
|
1324
|
+
span.record_output({"error": str(e)})
|
1325
|
+
raise
|
1326
|
+
|
1327
|
+
# --- Traced Async Functions ---
|
1475
1328
|
async def traced_create_async(*args, **kwargs):
|
1476
|
-
# [Existing logic - unchanged]
|
1477
1329
|
current_trace = current_trace_var.get()
|
1478
1330
|
if not current_trace:
|
1479
|
-
|
1480
|
-
|
1481
|
-
else:
|
1482
|
-
return original_create(*args, **kwargs)
|
1483
|
-
|
1484
|
-
is_streaming = kwargs.get("stream", False)
|
1485
|
-
|
1331
|
+
return await original_create(*args, **kwargs)
|
1332
|
+
|
1486
1333
|
with current_trace.span(span_name, span_type="llm") as span:
|
1487
|
-
|
1488
|
-
|
1489
|
-
|
1490
|
-
# Warn about token counting limitations with streaming
|
1491
|
-
if isinstance(client, (AsyncOpenAI, OpenAI)) and is_streaming:
|
1492
|
-
if not kwargs.get("stream_options", {}).get("include_usage"):
|
1493
|
-
warnings.warn(
|
1494
|
-
"OpenAI streaming calls don't include token counts by default. "
|
1495
|
-
"To enable token counting with streams, set stream_options={'include_usage': True} "
|
1496
|
-
"in your API call arguments.",
|
1497
|
-
UserWarning
|
1498
|
-
)
|
1499
|
-
|
1334
|
+
is_streaming = _record_input_and_check_streaming(span, kwargs)
|
1335
|
+
|
1500
1336
|
try:
|
1501
|
-
|
1502
|
-
|
1503
|
-
output_entry = span.record_output("<pending stream>")
|
1504
|
-
return _async_stream_wrapper(stream_iterator, client, output_entry)
|
1505
|
-
else:
|
1506
|
-
awaited_response = await original_create(*args, **kwargs)
|
1507
|
-
output_data = _format_output_data(client, awaited_response)
|
1508
|
-
span.record_output(output_data)
|
1509
|
-
return awaited_response
|
1337
|
+
response_or_iterator = await original_create(*args, **kwargs)
|
1338
|
+
return _format_and_record_output(span, response_or_iterator, is_streaming, True, False)
|
1510
1339
|
except Exception as e:
|
1511
|
-
|
1512
|
-
|
1513
|
-
|
1514
|
-
|
1515
|
-
|
1516
|
-
|
1340
|
+
return _handle_error(span, e, True)
|
1341
|
+
|
1342
|
+
# Async responses for OpenAI clients
|
1343
|
+
async def traced_response_create_async(*args, **kwargs):
|
1344
|
+
current_trace = current_trace_var.get()
|
1345
|
+
if not current_trace:
|
1346
|
+
return await original_responses_create(*args, **kwargs)
|
1347
|
+
|
1348
|
+
with current_trace.span(span_name, span_type="llm") as span:
|
1349
|
+
is_streaming = _record_input_and_check_streaming(span, kwargs, is_responses=True)
|
1350
|
+
|
1351
|
+
try:
|
1352
|
+
response_or_iterator = await original_responses_create(*args, **kwargs)
|
1353
|
+
return _format_and_record_output(span, response_or_iterator, is_streaming, True, True)
|
1354
|
+
except Exception as e:
|
1355
|
+
return _handle_error(span, e, True)
|
1356
|
+
|
1357
|
+
# Function replacing .stream() for async clients
|
1517
1358
|
def traced_stream_async(*args, **kwargs):
|
1518
1359
|
current_trace = current_trace_var.get()
|
1519
1360
|
if not current_trace or not original_stream:
|
1520
1361
|
return original_stream(*args, **kwargs)
|
1362
|
+
|
1521
1363
|
original_manager = original_stream(*args, **kwargs)
|
1522
|
-
|
1364
|
+
return _TracedAsyncStreamManagerWrapper(
|
1523
1365
|
original_manager=original_manager,
|
1524
1366
|
client=client,
|
1525
1367
|
span_name=span_name,
|
@@ -1527,104 +1369,74 @@ def wrap(client: Any) -> Any:
|
|
1527
1369
|
stream_wrapper_func=_async_stream_wrapper,
|
1528
1370
|
input_kwargs=kwargs
|
1529
1371
|
)
|
1530
|
-
|
1531
|
-
|
1532
|
-
# --- Define Traced Sync Functions ---
|
1372
|
+
|
1373
|
+
# --- Traced Sync Functions ---
|
1533
1374
|
def traced_create_sync(*args, **kwargs):
|
1534
|
-
# [Existing logic - unchanged]
|
1535
1375
|
current_trace = current_trace_var.get()
|
1536
1376
|
if not current_trace:
|
1537
|
-
|
1538
|
-
|
1539
|
-
is_streaming = kwargs.get("stream", False)
|
1540
|
-
|
1377
|
+
return original_create(*args, **kwargs)
|
1378
|
+
|
1541
1379
|
with current_trace.span(span_name, span_type="llm") as span:
|
1542
|
-
|
1543
|
-
|
1544
|
-
|
1545
|
-
|
1546
|
-
|
1547
|
-
|
1548
|
-
|
1549
|
-
|
1550
|
-
|
1551
|
-
|
1552
|
-
|
1553
|
-
|
1554
|
-
|
1555
|
-
|
1556
|
-
|
1557
|
-
|
1558
|
-
|
1559
|
-
|
1560
|
-
|
1561
|
-
|
1562
|
-
|
1563
|
-
|
1564
|
-
return _sync_stream_wrapper(response_or_iterator, client, output_entry)
|
1565
|
-
else:
|
1566
|
-
output_data = _format_output_data(client, response_or_iterator)
|
1567
|
-
span.record_output(output_data)
|
1568
|
-
return response_or_iterator
|
1569
|
-
|
1570
|
-
|
1380
|
+
is_streaming = _record_input_and_check_streaming(span, kwargs)
|
1381
|
+
|
1382
|
+
try:
|
1383
|
+
response_or_iterator = original_create(*args, **kwargs)
|
1384
|
+
return _format_and_record_output(span, response_or_iterator, is_streaming, False, False)
|
1385
|
+
except Exception as e:
|
1386
|
+
return _handle_error(span, e, False)
|
1387
|
+
|
1388
|
+
def traced_response_create_sync(*args, **kwargs):
|
1389
|
+
current_trace = current_trace_var.get()
|
1390
|
+
if not current_trace:
|
1391
|
+
return original_responses_create(*args, **kwargs)
|
1392
|
+
|
1393
|
+
with current_trace.span(span_name, span_type="llm") as span:
|
1394
|
+
is_streaming = _record_input_and_check_streaming(span, kwargs, is_responses=True)
|
1395
|
+
|
1396
|
+
try:
|
1397
|
+
response_or_iterator = original_responses_create(*args, **kwargs)
|
1398
|
+
return _format_and_record_output(span, response_or_iterator, is_streaming, False, True)
|
1399
|
+
except Exception as e:
|
1400
|
+
return _handle_error(span, e, False)
|
1401
|
+
|
1571
1402
|
# Function replacing sync .stream()
|
1572
1403
|
def traced_stream_sync(*args, **kwargs):
|
1573
|
-
|
1574
|
-
|
1575
|
-
|
1576
|
-
|
1577
|
-
|
1578
|
-
|
1579
|
-
|
1580
|
-
|
1581
|
-
|
1582
|
-
|
1583
|
-
|
1584
|
-
|
1585
|
-
|
1586
|
-
|
1587
|
-
|
1404
|
+
current_trace = current_trace_var.get()
|
1405
|
+
if not current_trace or not original_stream:
|
1406
|
+
return original_stream(*args, **kwargs)
|
1407
|
+
|
1408
|
+
original_manager = original_stream(*args, **kwargs)
|
1409
|
+
return _TracedSyncStreamManagerWrapper(
|
1410
|
+
original_manager=original_manager,
|
1411
|
+
client=client,
|
1412
|
+
span_name=span_name,
|
1413
|
+
trace_client=current_trace,
|
1414
|
+
stream_wrapper_func=_sync_stream_wrapper,
|
1415
|
+
input_kwargs=kwargs
|
1416
|
+
)
|
1417
|
+
|
1588
1418
|
# --- Assign Traced Methods to Client Instance ---
|
1589
|
-
# [Assignment logic remains the same]
|
1590
1419
|
if isinstance(client, (AsyncOpenAI, AsyncTogether)):
|
1591
1420
|
client.chat.completions.create = traced_create_async
|
1592
|
-
# Wrap the Responses API endpoint for OpenAI clients
|
1593
1421
|
if hasattr(client, "responses") and hasattr(client.responses, "create"):
|
1594
|
-
|
1595
|
-
original_responses_create = client.responses.create
|
1596
|
-
def traced_responses(*args, **kwargs):
|
1597
|
-
# Get the current trace from contextvars
|
1598
|
-
current_trace = current_trace_var.get()
|
1599
|
-
# If no active trace, call the original
|
1600
|
-
if not current_trace:
|
1601
|
-
return original_responses_create(*args, **kwargs)
|
1602
|
-
# Trace this responses.create call
|
1603
|
-
with current_trace.span(span_name, span_type="llm") as span:
|
1604
|
-
# Record raw input kwargs
|
1605
|
-
span.record_input(kwargs)
|
1606
|
-
# Make the actual API call
|
1607
|
-
response = original_responses_create(*args, **kwargs)
|
1608
|
-
# Record the output object
|
1609
|
-
span.record_output(response)
|
1610
|
-
return response
|
1611
|
-
# Assign the traced wrapper
|
1612
|
-
client.responses.create = traced_responses
|
1422
|
+
client.responses.create = traced_response_create_async
|
1613
1423
|
elif isinstance(client, AsyncAnthropic):
|
1614
1424
|
client.messages.create = traced_create_async
|
1615
1425
|
if original_stream:
|
1616
|
-
|
1426
|
+
client.messages.stream = traced_stream_async
|
1617
1427
|
elif isinstance(client, genai.client.AsyncClient):
|
1618
|
-
client.generate_content = traced_create_async
|
1428
|
+
client.models.generate_content = traced_create_async
|
1619
1429
|
elif isinstance(client, (OpenAI, Together)):
|
1620
|
-
|
1430
|
+
client.chat.completions.create = traced_create_sync
|
1431
|
+
if hasattr(client, "responses") and hasattr(client.responses, "create"):
|
1432
|
+
client.responses.create = traced_response_create_sync
|
1621
1433
|
elif isinstance(client, Anthropic):
|
1622
|
-
|
1623
|
-
|
1624
|
-
|
1434
|
+
client.messages.create = traced_create_sync
|
1435
|
+
if original_stream:
|
1436
|
+
client.messages.stream = traced_stream_sync
|
1625
1437
|
elif isinstance(client, genai.Client):
|
1626
|
-
|
1627
|
-
|
1438
|
+
client.models.generate_content = traced_create_sync
|
1439
|
+
|
1628
1440
|
return client
|
1629
1441
|
|
1630
1442
|
# Helper functions for client-specific operations
|
@@ -1639,19 +1451,20 @@ def _get_client_config(client: ApiClient) -> tuple[str, callable, Optional[calla
|
|
1639
1451
|
tuple: (span_name, create_method, stream_method)
|
1640
1452
|
- span_name: String identifier for tracing
|
1641
1453
|
- create_method: Reference to the client's creation method
|
1454
|
+
- responses_method: Reference to the client's responses method (if applicable)
|
1642
1455
|
- stream_method: Reference to the client's stream method (if applicable)
|
1643
1456
|
|
1644
1457
|
Raises:
|
1645
1458
|
ValueError: If client type is not supported
|
1646
1459
|
"""
|
1647
1460
|
if isinstance(client, (OpenAI, AsyncOpenAI)):
|
1648
|
-
return "OPENAI_API_CALL", client.chat.completions.create, None
|
1461
|
+
return "OPENAI_API_CALL", client.chat.completions.create, client.responses.create, None
|
1649
1462
|
elif isinstance(client, (Together, AsyncTogether)):
|
1650
|
-
return "TOGETHER_API_CALL", client.chat.completions.create, None
|
1463
|
+
return "TOGETHER_API_CALL", client.chat.completions.create, None, None
|
1651
1464
|
elif isinstance(client, (Anthropic, AsyncAnthropic)):
|
1652
|
-
return "ANTHROPIC_API_CALL", client.messages.create, client.messages.stream
|
1465
|
+
return "ANTHROPIC_API_CALL", client.messages.create, None, client.messages.stream
|
1653
1466
|
elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
|
1654
|
-
return "GOOGLE_API_CALL", client.models.generate_content, None
|
1467
|
+
return "GOOGLE_API_CALL", client.models.generate_content, None, None
|
1655
1468
|
raise ValueError(f"Unsupported client type: {type(client)}")
|
1656
1469
|
|
1657
1470
|
def _format_input_data(client: ApiClient, **kwargs) -> dict:
|
@@ -1677,6 +1490,26 @@ def _format_input_data(client: ApiClient, **kwargs) -> dict:
|
|
1677
1490
|
"max_tokens": kwargs.get("max_tokens")
|
1678
1491
|
}
|
1679
1492
|
|
1493
|
+
def _format_response_output_data(client: ApiClient, response: Any) -> dict:
|
1494
|
+
"""Format API response data based on client type.
|
1495
|
+
|
1496
|
+
Normalizes different response formats into a consistent structure
|
1497
|
+
for tracing purposes.
|
1498
|
+
"""
|
1499
|
+
if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
|
1500
|
+
return {
|
1501
|
+
"content": response.output,
|
1502
|
+
"usage": {
|
1503
|
+
"prompt_tokens": response.usage.input_tokens,
|
1504
|
+
"completion_tokens": response.usage.output_tokens,
|
1505
|
+
"total_tokens": response.usage.total_tokens
|
1506
|
+
}
|
1507
|
+
}
|
1508
|
+
else:
|
1509
|
+
warnings.warn(f"Unsupported client type: {type(client)}")
|
1510
|
+
return {}
|
1511
|
+
|
1512
|
+
|
1680
1513
|
def _format_output_data(client: ApiClient, response: Any) -> dict:
|
1681
1514
|
"""Format API response data based on client type.
|
1682
1515
|
|
@@ -1716,117 +1549,51 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
|
|
1716
1549
|
}
|
1717
1550
|
}
|
1718
1551
|
|
1719
|
-
|
1720
|
-
# These are typically utility functions, print statements, logging, etc.
|
1721
|
-
_TRACE_BLOCKLIST = {
|
1722
|
-
# Built-in functions
|
1723
|
-
'print', 'str', 'int', 'float', 'bool', 'list', 'dict', 'set', 'tuple',
|
1724
|
-
'len', 'range', 'enumerate', 'zip', 'map', 'filter', 'sorted', 'reversed',
|
1725
|
-
'min', 'max', 'sum', 'any', 'all', 'abs', 'round', 'format',
|
1726
|
-
# Logging functions
|
1727
|
-
'debug', 'info', 'warning', 'error', 'critical', 'exception', 'log',
|
1728
|
-
# Common utility functions
|
1729
|
-
'sleep', 'time', 'datetime', 'json', 'dumps', 'loads',
|
1730
|
-
# String operations
|
1731
|
-
'join', 'split', 'strip', 'lstrip', 'rstrip', 'replace', 'lower', 'upper',
|
1732
|
-
# Dict operations
|
1733
|
-
'get', 'items', 'keys', 'values', 'update',
|
1734
|
-
# List operations
|
1735
|
-
'append', 'extend', 'insert', 'remove', 'pop', 'clear', 'index', 'count', 'sort',
|
1736
|
-
}
|
1737
|
-
|
1738
|
-
|
1739
|
-
# Add a new function for deep tracing at the module level
|
1740
|
-
def _create_deep_tracing_wrapper(func, tracer, span_type="span"):
|
1552
|
+
def combine_args_kwargs(func, args, kwargs):
|
1741
1553
|
"""
|
1742
|
-
|
1743
|
-
This enables deep tracing without requiring explicit @observe decorators on every function.
|
1554
|
+
Combine positional arguments and keyword arguments into a single dictionary.
|
1744
1555
|
|
1745
1556
|
Args:
|
1746
|
-
func: The function
|
1747
|
-
|
1748
|
-
|
1557
|
+
func: The function being called
|
1558
|
+
args: Tuple of positional arguments
|
1559
|
+
kwargs: Dictionary of keyword arguments
|
1749
1560
|
|
1750
1561
|
Returns:
|
1751
|
-
A
|
1562
|
+
A dictionary combining both args and kwargs
|
1752
1563
|
"""
|
1753
|
-
|
1754
|
-
|
1755
|
-
|
1756
|
-
|
1757
|
-
|
1758
|
-
|
1759
|
-
|
1760
|
-
|
1761
|
-
|
1762
|
-
|
1763
|
-
|
1764
|
-
|
1765
|
-
|
1766
|
-
|
1767
|
-
|
1768
|
-
|
1769
|
-
|
1770
|
-
|
1771
|
-
|
1772
|
-
|
1773
|
-
|
1774
|
-
|
1775
|
-
|
1776
|
-
|
1777
|
-
|
1778
|
-
|
1779
|
-
|
1780
|
-
|
1781
|
-
|
1782
|
-
|
1783
|
-
|
1784
|
-
|
1785
|
-
|
1786
|
-
# Create a span for this function call - use custom span_type if available
|
1787
|
-
with current_trace.span(func_name, span_type=func_span_type) as span:
|
1788
|
-
# Record inputs
|
1789
|
-
span.record_input({
|
1790
|
-
'args': str(args),
|
1791
|
-
'kwargs': kwargs
|
1792
|
-
})
|
1793
|
-
|
1794
|
-
# Execute function
|
1795
|
-
result = await original_func(*args, **kwargs)
|
1796
|
-
|
1797
|
-
# Record output
|
1798
|
-
span.record_output(result)
|
1799
|
-
|
1800
|
-
return result
|
1801
|
-
|
1802
|
-
return async_deep_wrapper
|
1803
|
-
else:
|
1804
|
-
@functools.wraps(func)
|
1805
|
-
def deep_wrapper(*args, **kwargs):
|
1806
|
-
# Get current trace from context
|
1807
|
-
current_trace = current_trace_var.get()
|
1808
|
-
|
1809
|
-
# If no trace context, just call the function
|
1810
|
-
if not current_trace:
|
1811
|
-
return original_func(*args, **kwargs)
|
1812
|
-
|
1813
|
-
# Create a span for this function call - use custom span_type if available
|
1814
|
-
with current_trace.span(func_name, span_type=func_span_type) as span:
|
1815
|
-
# Record inputs
|
1816
|
-
span.record_input({
|
1817
|
-
'args': str(args),
|
1818
|
-
'kwargs': kwargs
|
1819
|
-
})
|
1820
|
-
|
1821
|
-
# Execute function
|
1822
|
-
result = original_func(*args, **kwargs)
|
1823
|
-
|
1824
|
-
# Record output
|
1825
|
-
span.record_output(result)
|
1826
|
-
|
1827
|
-
return result
|
1828
|
-
|
1829
|
-
return deep_wrapper
|
1564
|
+
try:
|
1565
|
+
import inspect
|
1566
|
+
sig = inspect.signature(func)
|
1567
|
+
param_names = list(sig.parameters.keys())
|
1568
|
+
|
1569
|
+
args_dict = {}
|
1570
|
+
for i, arg in enumerate(args):
|
1571
|
+
if i < len(param_names):
|
1572
|
+
args_dict[param_names[i]] = arg
|
1573
|
+
else:
|
1574
|
+
args_dict[f"arg{i}"] = arg
|
1575
|
+
|
1576
|
+
return {**args_dict, **kwargs}
|
1577
|
+
except Exception as e:
|
1578
|
+
# Fallback if signature inspection fails
|
1579
|
+
return {**{f"arg{i}": arg for i, arg in enumerate(args)}, **kwargs}
|
1580
|
+
|
1581
|
+
# NOTE: This builds once, can be tweaked if we are missing / capturing other unncessary modules
|
1582
|
+
# @link https://docs.python.org/3.13/library/sysconfig.html
|
1583
|
+
_TRACE_FILEPATH_BLOCKLIST = tuple(
|
1584
|
+
os.path.realpath(p) + os.sep
|
1585
|
+
for p in {
|
1586
|
+
sysconfig.get_paths()['stdlib'],
|
1587
|
+
sysconfig.get_paths().get('platstdlib', ''),
|
1588
|
+
*site.getsitepackages(),
|
1589
|
+
site.getusersitepackages(),
|
1590
|
+
*(
|
1591
|
+
[os.path.join(os.path.dirname(__file__), '../../judgeval/')]
|
1592
|
+
if os.environ.get('JUDGMENT_DEV')
|
1593
|
+
else []
|
1594
|
+
),
|
1595
|
+
} if p
|
1596
|
+
)
|
1830
1597
|
|
1831
1598
|
# Add the new TraceThreadPoolExecutor class
|
1832
1599
|
class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
|
@@ -1929,7 +1696,7 @@ def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[D
|
|
1929
1696
|
def _sync_stream_wrapper(
|
1930
1697
|
original_stream: Iterator,
|
1931
1698
|
client: ApiClient,
|
1932
|
-
|
1699
|
+
span: TraceSpan
|
1933
1700
|
) -> Generator[Any, None, None]:
|
1934
1701
|
"""Wraps a synchronous stream iterator to capture content and update the trace."""
|
1935
1702
|
content_parts = [] # Use a list instead of string concatenation
|
@@ -1948,7 +1715,7 @@ def _sync_stream_wrapper(
|
|
1948
1715
|
final_usage = _extract_usage_from_final_chunk(client, last_chunk)
|
1949
1716
|
|
1950
1717
|
# Update the trace entry with the accumulated content and usage
|
1951
|
-
|
1718
|
+
span.output = {
|
1952
1719
|
"content": "".join(content_parts), # Join list at the end
|
1953
1720
|
"usage": final_usage if final_usage else {"info": "Usage data not available in stream."}, # Provide placeholder if None
|
1954
1721
|
"streamed": True
|
@@ -1960,7 +1727,7 @@ def _sync_stream_wrapper(
|
|
1960
1727
|
async def _async_stream_wrapper(
|
1961
1728
|
original_stream: AsyncIterator,
|
1962
1729
|
client: ApiClient,
|
1963
|
-
|
1730
|
+
span: TraceSpan
|
1964
1731
|
) -> AsyncGenerator[Any, None]:
|
1965
1732
|
# [Existing logic - unchanged]
|
1966
1733
|
content_parts = [] # Use a list instead of string concatenation
|
@@ -1969,7 +1736,7 @@ async def _async_stream_wrapper(
|
|
1969
1736
|
anthropic_input_tokens = 0
|
1970
1737
|
anthropic_output_tokens = 0
|
1971
1738
|
|
1972
|
-
target_span_id =
|
1739
|
+
target_span_id = span.span_id
|
1973
1740
|
|
1974
1741
|
try:
|
1975
1742
|
async for chunk in original_stream:
|
@@ -2014,19 +1781,17 @@ async def _async_stream_wrapper(
|
|
2014
1781
|
elif last_content_chunk:
|
2015
1782
|
usage_info = _extract_usage_from_final_chunk(client, last_content_chunk)
|
2016
1783
|
|
2017
|
-
if
|
2018
|
-
|
1784
|
+
if span and hasattr(span, 'output'):
|
1785
|
+
span.output = {
|
2019
1786
|
"content": "".join(content_parts), # Join list at the end
|
2020
1787
|
"usage": usage_info if usage_info else {"info": "Usage data not available in stream."},
|
2021
1788
|
"streamed": True
|
2022
1789
|
}
|
2023
|
-
start_ts = getattr(
|
2024
|
-
|
1790
|
+
start_ts = getattr(span, 'created_at', time.time())
|
1791
|
+
span.duration = time.time() - start_ts
|
2025
1792
|
# else: # Handle error case if necessary, but remove debug print
|
2026
1793
|
|
2027
|
-
|
2028
|
-
class _TracedAsyncStreamManagerWrapper(AbstractAsyncContextManager):
|
2029
|
-
"""Wraps an original async stream manager to add tracing."""
|
1794
|
+
class _BaseStreamManagerWrapper:
|
2030
1795
|
def __init__(self, original_manager, client, span_name, trace_client, stream_wrapper_func, input_kwargs):
|
2031
1796
|
self._original_manager = original_manager
|
2032
1797
|
self._client = client
|
@@ -2036,281 +1801,74 @@ class _TracedAsyncStreamManagerWrapper(AbstractAsyncContextManager):
|
|
2036
1801
|
self._input_kwargs = input_kwargs
|
2037
1802
|
self._parent_span_id_at_entry = None
|
2038
1803
|
|
2039
|
-
|
2040
|
-
self._parent_span_id_at_entry = current_span_var.get()
|
2041
|
-
if not self._trace_client:
|
2042
|
-
# If no trace, just delegate to the original manager
|
2043
|
-
return await self._original_manager.__aenter__()
|
2044
|
-
|
2045
|
-
# --- Manually create the 'enter' entry ---
|
1804
|
+
def _create_span(self):
|
2046
1805
|
start_time = time.time()
|
2047
1806
|
span_id = str(uuid.uuid4())
|
2048
1807
|
current_depth = 0
|
2049
1808
|
if self._parent_span_id_at_entry and self._parent_span_id_at_entry in self._trace_client._span_depths:
|
2050
1809
|
current_depth = self._trace_client._span_depths[self._parent_span_id_at_entry] + 1
|
2051
1810
|
self._trace_client._span_depths[span_id] = current_depth
|
2052
|
-
|
2053
|
-
|
2054
|
-
|
2055
|
-
|
1811
|
+
span = TraceSpan(
|
1812
|
+
function=self._span_name,
|
1813
|
+
span_id=span_id,
|
1814
|
+
trace_id=self._trace_client.trace_id,
|
1815
|
+
depth=current_depth,
|
1816
|
+
message=self._span_name,
|
1817
|
+
created_at=start_time,
|
1818
|
+
span_type="llm",
|
1819
|
+
parent_span_id=self._parent_span_id_at_entry
|
2056
1820
|
)
|
2057
|
-
self._trace_client.
|
2058
|
-
|
1821
|
+
self._trace_client.add_span(span)
|
1822
|
+
return span_id, span
|
2059
1823
|
|
2060
|
-
|
2061
|
-
|
1824
|
+
def _finalize_span(self, span_id):
|
1825
|
+
span = self._trace_client.span_id_to_span.get(span_id)
|
1826
|
+
if span:
|
1827
|
+
span.duration = time.time() - span.created_at
|
1828
|
+
if span_id in self._trace_client._span_depths:
|
1829
|
+
del self._trace_client._span_depths[span_id]
|
2062
1830
|
|
2063
|
-
|
2064
|
-
|
2065
|
-
|
2066
|
-
|
2067
|
-
|
2068
|
-
created_at=time.time(), inputs=input_data, span_type="llm"
|
2069
|
-
)
|
2070
|
-
self._trace_client.add_entry(input_entry)
|
2071
|
-
|
2072
|
-
# Call the original __aenter__
|
2073
|
-
raw_iterator = await self._original_manager.__aenter__()
|
1831
|
+
class _TracedAsyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractAsyncContextManager):
|
1832
|
+
async def __aenter__(self):
|
1833
|
+
self._parent_span_id_at_entry = current_span_var.get()
|
1834
|
+
if not self._trace_client:
|
1835
|
+
return await self._original_manager.__aenter__()
|
2074
1836
|
|
2075
|
-
|
2076
|
-
|
2077
|
-
|
2078
|
-
trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Output from {self._span_name}",
|
2079
|
-
created_at=time.time(), output="<pending stream>", span_type="llm"
|
2080
|
-
)
|
2081
|
-
self._trace_client.add_entry(output_entry)
|
1837
|
+
span_id, span = self._create_span()
|
1838
|
+
self._span_context_token = current_span_var.set(span_id)
|
1839
|
+
span.inputs = _format_input_data(self._client, **self._input_kwargs)
|
2082
1840
|
|
2083
|
-
#
|
2084
|
-
|
2085
|
-
|
1841
|
+
# Call the original __aenter__ and expect it to be an async generator
|
1842
|
+
raw_iterator = await self._original_manager.__aenter__()
|
1843
|
+
span.output = "<pending stream>"
|
1844
|
+
return self._stream_wrapper_func(raw_iterator, self._client, span)
|
2086
1845
|
|
2087
1846
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
2088
|
-
# Manually create the 'exit' entry
|
2089
1847
|
if hasattr(self, '_span_context_token'):
|
2090
|
-
|
2091
|
-
|
2092
|
-
|
2093
|
-
|
2094
|
-
|
2095
|
-
break
|
2096
|
-
duration = time.time() - start_time_for_duration if start_time_for_duration else None
|
2097
|
-
exit_depth = self._trace_client._span_depths.get(span_id, 0)
|
2098
|
-
exit_entry = TraceEntry(
|
2099
|
-
type="exit", function=self._span_name, span_id=span_id,
|
2100
|
-
trace_id=self._trace_client.trace_id, depth=exit_depth, message=f"← {self._span_name}",
|
2101
|
-
created_at=time.time(), duration=duration, span_type="llm"
|
2102
|
-
)
|
2103
|
-
self._trace_client.add_entry(exit_entry)
|
2104
|
-
if span_id in self._trace_client._span_depths: del self._trace_client._span_depths[span_id]
|
2105
|
-
current_span_var.reset(self._span_context_token)
|
2106
|
-
delattr(self, '_span_context_token')
|
2107
|
-
|
2108
|
-
# Delegate __aexit__
|
2109
|
-
if hasattr(self._original_manager, "__aexit__"):
|
2110
|
-
return await self._original_manager.__aexit__(exc_type, exc_val, exc_tb)
|
2111
|
-
return None
|
2112
|
-
|
2113
|
-
class _TracedSyncStreamManagerWrapper(AbstractContextManager):
|
2114
|
-
"""Wraps an original sync stream manager to add tracing."""
|
2115
|
-
def __init__(self, original_manager, client, span_name, trace_client, stream_wrapper_func, input_kwargs):
|
2116
|
-
self._original_manager = original_manager
|
2117
|
-
self._client = client
|
2118
|
-
self._span_name = span_name
|
2119
|
-
self._trace_client = trace_client
|
2120
|
-
self._stream_wrapper_func = stream_wrapper_func
|
2121
|
-
self._input_kwargs = input_kwargs
|
2122
|
-
self._parent_span_id_at_entry = None
|
1848
|
+
span_id = current_span_var.get()
|
1849
|
+
self._finalize_span(span_id)
|
1850
|
+
current_span_var.reset(self._span_context_token)
|
1851
|
+
delattr(self, '_span_context_token')
|
1852
|
+
return await self._original_manager.__aexit__(exc_type, exc_val, exc_tb)
|
2123
1853
|
|
1854
|
+
class _TracedSyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractContextManager):
|
2124
1855
|
def __enter__(self):
|
2125
1856
|
self._parent_span_id_at_entry = current_span_var.get()
|
2126
1857
|
if not self._trace_client:
|
2127
|
-
|
1858
|
+
return self._original_manager.__enter__()
|
2128
1859
|
|
2129
|
-
|
2130
|
-
start_time = time.time()
|
2131
|
-
span_id = str(uuid.uuid4())
|
2132
|
-
current_depth = 0
|
2133
|
-
if self._parent_span_id_at_entry and self._parent_span_id_at_entry in self._trace_client._span_depths:
|
2134
|
-
current_depth = self._trace_client._span_depths[self._parent_span_id_at_entry] + 1
|
2135
|
-
self._trace_client._span_depths[span_id] = current_depth
|
2136
|
-
enter_entry = TraceEntry(
|
2137
|
-
type="enter", function=self._span_name, span_id=span_id,
|
2138
|
-
trace_id=self._trace_client.trace_id, depth=current_depth, message=self._span_name,
|
2139
|
-
created_at=start_time, span_type="llm", parent_span_id=self._parent_span_id_at_entry
|
2140
|
-
)
|
2141
|
-
self._trace_client.add_entry(enter_entry)
|
1860
|
+
span_id, span = self._create_span()
|
2142
1861
|
self._span_context_token = current_span_var.set(span_id)
|
1862
|
+
span.inputs = _format_input_data(self._client, **self._input_kwargs)
|
2143
1863
|
|
2144
|
-
# Manually create 'input' entry
|
2145
|
-
input_data = _format_input_data(self._client, **self._input_kwargs)
|
2146
|
-
input_entry = TraceEntry(
|
2147
|
-
type="input", function=self._span_name, span_id=span_id,
|
2148
|
-
trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Inputs to {self._span_name}",
|
2149
|
-
created_at=time.time(), inputs=input_data, span_type="llm"
|
2150
|
-
)
|
2151
|
-
self._trace_client.add_entry(input_entry)
|
2152
|
-
|
2153
|
-
# Call original __enter__
|
2154
1864
|
raw_iterator = self._original_manager.__enter__()
|
2155
|
-
|
2156
|
-
|
2157
|
-
output_entry = TraceEntry(
|
2158
|
-
type="output", function=self._span_name, span_id=span_id,
|
2159
|
-
trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Output from {self._span_name}",
|
2160
|
-
created_at=time.time(), output="<pending stream>", span_type="llm"
|
2161
|
-
)
|
2162
|
-
self._trace_client.add_entry(output_entry)
|
2163
|
-
|
2164
|
-
# Wrap the raw iterator
|
2165
|
-
wrapped_iterator = self._stream_wrapper_func(raw_iterator, self._client, output_entry)
|
2166
|
-
return wrapped_iterator
|
1865
|
+
span.output = "<pending stream>"
|
1866
|
+
return self._stream_wrapper_func(raw_iterator, self._client, span)
|
2167
1867
|
|
2168
1868
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
2169
|
-
# Manually create 'exit' entry
|
2170
1869
|
if hasattr(self, '_span_context_token'):
|
2171
|
-
|
2172
|
-
|
2173
|
-
|
2174
|
-
|
2175
|
-
|
2176
|
-
break
|
2177
|
-
duration = time.time() - start_time_for_duration if start_time_for_duration else None
|
2178
|
-
exit_depth = self._trace_client._span_depths.get(span_id, 0)
|
2179
|
-
exit_entry = TraceEntry(
|
2180
|
-
type="exit", function=self._span_name, span_id=span_id,
|
2181
|
-
trace_id=self._trace_client.trace_id, depth=exit_depth, message=f"← {self._span_name}",
|
2182
|
-
created_at=time.time(), duration=duration, span_type="llm"
|
2183
|
-
)
|
2184
|
-
self._trace_client.add_entry(exit_entry)
|
2185
|
-
if span_id in self._trace_client._span_depths: del self._trace_client._span_depths[span_id]
|
2186
|
-
current_span_var.reset(self._span_context_token)
|
2187
|
-
delattr(self, '_span_context_token')
|
2188
|
-
|
2189
|
-
# Delegate __exit__
|
2190
|
-
if hasattr(self._original_manager, "__exit__"):
|
2191
|
-
return self._original_manager.__exit__(exc_type, exc_val, exc_tb)
|
2192
|
-
return None
|
2193
|
-
|
2194
|
-
# --- NEW Generalized Helper Function (Moved from demo) ---
|
2195
|
-
def prepare_evaluation_for_state(
|
2196
|
-
scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
|
2197
|
-
example: Optional[Example] = None,
|
2198
|
-
# --- Individual components (alternative to 'example') ---
|
2199
|
-
input: Optional[str] = None,
|
2200
|
-
actual_output: Optional[Union[str, List[str]]] = None,
|
2201
|
-
expected_output: Optional[Union[str, List[str]]] = None,
|
2202
|
-
context: Optional[List[str]] = None,
|
2203
|
-
retrieval_context: Optional[List[str]] = None,
|
2204
|
-
tools_called: Optional[List[str]] = None,
|
2205
|
-
expected_tools: Optional[List[str]] = None,
|
2206
|
-
additional_metadata: Optional[Dict[str, Any]] = None,
|
2207
|
-
# --- Other eval parameters ---
|
2208
|
-
model: Optional[str] = None,
|
2209
|
-
log_results: Optional[bool] = True
|
2210
|
-
) -> Optional[EvaluationConfig]:
|
2211
|
-
"""
|
2212
|
-
Prepares an EvaluationConfig object, similar to TraceClient.async_evaluate.
|
2213
|
-
|
2214
|
-
Accepts either a pre-made Example object or individual components to construct one.
|
2215
|
-
Returns the EvaluationConfig object ready to be placed in the state, or None.
|
2216
|
-
"""
|
2217
|
-
final_example = example
|
2218
|
-
|
2219
|
-
# If example is not provided, try to construct one from individual parts
|
2220
|
-
if final_example is None:
|
2221
|
-
# Basic validation: Ensure at least actual_output is present for most scorers
|
2222
|
-
if actual_output is None:
|
2223
|
-
# print("[prepare_evaluation_for_state] Warning: 'actual_output' is required when 'example' is not provided. Skipping evaluation setup.")
|
2224
|
-
return None
|
2225
|
-
try:
|
2226
|
-
final_example = Example(
|
2227
|
-
input=input,
|
2228
|
-
actual_output=actual_output,
|
2229
|
-
expected_output=expected_output,
|
2230
|
-
context=context,
|
2231
|
-
retrieval_context=retrieval_context,
|
2232
|
-
tools_called=tools_called,
|
2233
|
-
expected_tools=expected_tools,
|
2234
|
-
additional_metadata=additional_metadata,
|
2235
|
-
# trace_id will be set by the handler later if needed
|
2236
|
-
)
|
2237
|
-
# print("[prepare_evaluation_for_state] Constructed Example from individual components.")
|
2238
|
-
except Exception as e:
|
2239
|
-
# print(f"[prepare_evaluation_for_state] Error constructing Example: {e}. Skipping evaluation setup.")
|
2240
|
-
return None
|
2241
|
-
|
2242
|
-
# If we have a valid example (provided or constructed) and scorers
|
2243
|
-
if final_example and scorers:
|
2244
|
-
# TODO: Add validation like check_examples if needed here,
|
2245
|
-
# although the handler might implicitly handle some checks via TraceClient.
|
2246
|
-
return EvaluationConfig(
|
2247
|
-
scorers=scorers,
|
2248
|
-
example=final_example,
|
2249
|
-
model=model,
|
2250
|
-
log_results=log_results
|
2251
|
-
)
|
2252
|
-
elif not scorers:
|
2253
|
-
# print("[prepare_evaluation_for_state] No scorers provided. Skipping evaluation setup.")
|
2254
|
-
return None
|
2255
|
-
else: # No valid example
|
2256
|
-
# print("[prepare_evaluation_for_state] No valid Example available. Skipping evaluation setup.")
|
2257
|
-
return None
|
2258
|
-
# --- End NEW Helper Function ---
|
2259
|
-
|
2260
|
-
# --- NEW: Helper function to simplify adding eval config to state ---
|
2261
|
-
def add_evaluation_to_state(
|
2262
|
-
state: Dict[str, Any], # The LangGraph state dictionary
|
2263
|
-
scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
|
2264
|
-
# --- Evaluation components (same as prepare_evaluation_for_state) ---
|
2265
|
-
input: Optional[str] = None,
|
2266
|
-
actual_output: Optional[Union[str, List[str]]] = None,
|
2267
|
-
expected_output: Optional[Union[str, List[str]]] = None,
|
2268
|
-
context: Optional[List[str]] = None,
|
2269
|
-
retrieval_context: Optional[List[str]] = None,
|
2270
|
-
tools_called: Optional[List[str]] = None,
|
2271
|
-
expected_tools: Optional[List[str]] = None,
|
2272
|
-
additional_metadata: Optional[Dict[str, Any]] = None,
|
2273
|
-
# --- Other eval parameters ---
|
2274
|
-
model: Optional[str] = None,
|
2275
|
-
log_results: Optional[bool] = True
|
2276
|
-
) -> None:
|
2277
|
-
"""
|
2278
|
-
Prepares an EvaluationConfig and adds it to the state dictionary
|
2279
|
-
under the '_judgeval_eval' key if successful.
|
2280
|
-
|
2281
|
-
This simplifies the process of setting up evaluations within LangGraph nodes.
|
2282
|
-
|
2283
|
-
Args:
|
2284
|
-
state: The LangGraph state dictionary to modify.
|
2285
|
-
scorers: List of scorer instances.
|
2286
|
-
input: Input for the evaluation example.
|
2287
|
-
actual_output: Actual output for the evaluation example.
|
2288
|
-
expected_output: Expected output for the evaluation example.
|
2289
|
-
context: Context for the evaluation example.
|
2290
|
-
retrieval_context: Retrieval context for the evaluation example.
|
2291
|
-
tools_called: Tools called for the evaluation example.
|
2292
|
-
expected_tools: Expected tools for the evaluation example.
|
2293
|
-
additional_metadata: Additional metadata for the evaluation example.
|
2294
|
-
model: Model name used for generation (optional).
|
2295
|
-
log_results: Whether to log evaluation results (optional, defaults to True).
|
2296
|
-
"""
|
2297
|
-
eval_config = prepare_evaluation_for_state(
|
2298
|
-
scorers=scorers,
|
2299
|
-
input=input,
|
2300
|
-
actual_output=actual_output,
|
2301
|
-
expected_output=expected_output,
|
2302
|
-
context=context,
|
2303
|
-
retrieval_context=retrieval_context,
|
2304
|
-
tools_called=tools_called,
|
2305
|
-
expected_tools=expected_tools,
|
2306
|
-
additional_metadata=additional_metadata,
|
2307
|
-
model=model,
|
2308
|
-
log_results=log_results
|
2309
|
-
)
|
2310
|
-
|
2311
|
-
if eval_config:
|
2312
|
-
state["_judgeval_eval"] = eval_config
|
2313
|
-
# print(f"[_judgeval_eval added to state for node]") # Optional: Log confirmation
|
2314
|
-
|
2315
|
-
# print("[Skipped adding _judgeval_eval to state: prepare_evaluation_for_state failed]")
|
2316
|
-
# --- End NEW Helper ---
|
1870
|
+
span_id = current_span_var.get()
|
1871
|
+
self._finalize_span(span_id)
|
1872
|
+
current_span_var.reset(self._span_context_token)
|
1873
|
+
delattr(self, '_span_context_token')
|
1874
|
+
return self._original_manager.__exit__(exc_type, exc_val, exc_tb)
|