judgeval 0.0.35__py3-none-any.whl → 0.0.37__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/common/tracer.py +869 -928
- judgeval/common/utils.py +18 -0
- judgeval/constants.py +6 -3
- judgeval/data/__init__.py +4 -0
- judgeval/data/datasets/dataset.py +3 -2
- judgeval/data/datasets/eval_dataset_client.py +63 -3
- judgeval/data/example.py +29 -7
- judgeval/data/sequence.py +5 -4
- judgeval/data/sequence_run.py +4 -3
- judgeval/data/trace.py +129 -0
- judgeval/evaluation_run.py +1 -1
- judgeval/integrations/langgraph.py +1962 -299
- judgeval/judgment_client.py +85 -66
- judgeval/run_evaluation.py +191 -45
- judgeval/scorers/__init__.py +2 -0
- judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +2 -0
- judgeval/scorers/judgeval_scorers/api_scorers/tool_order.py +18 -0
- judgeval/scorers/score.py +2 -1
- judgeval/utils/data_utils.py +57 -0
- judgeval-0.0.37.dist-info/METADATA +214 -0
- {judgeval-0.0.35.dist-info → judgeval-0.0.37.dist-info}/RECORD +23 -20
- judgeval-0.0.35.dist-info/METADATA +0 -170
- {judgeval-0.0.35.dist-info → judgeval-0.0.37.dist-info}/WHEEL +0 -0
- {judgeval-0.0.35.dist-info → judgeval-0.0.37.dist-info}/licenses/LICENSE.md +0 -0
judgeval/common/tracer.py
CHANGED
@@ -7,7 +7,11 @@ import functools
|
|
7
7
|
import inspect
|
8
8
|
import json
|
9
9
|
import os
|
10
|
+
import site
|
11
|
+
import sysconfig
|
12
|
+
import threading
|
10
13
|
import time
|
14
|
+
import traceback
|
11
15
|
import uuid
|
12
16
|
import warnings
|
13
17
|
import contextvars
|
@@ -35,7 +39,6 @@ from rich import print as rprint
|
|
35
39
|
import types # <--- Add this import
|
36
40
|
|
37
41
|
# Third-party imports
|
38
|
-
import pika
|
39
42
|
import requests
|
40
43
|
from litellm import cost_per_token
|
41
44
|
from pydantic import BaseModel
|
@@ -47,6 +50,7 @@ from google import genai
|
|
47
50
|
|
48
51
|
# Local application/library-specific imports
|
49
52
|
from judgeval.constants import (
|
53
|
+
JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
|
50
54
|
JUDGMENT_TRACES_SAVE_API_URL,
|
51
55
|
JUDGMENT_TRACES_FETCH_API_URL,
|
52
56
|
RABBITMQ_HOST,
|
@@ -55,172 +59,56 @@ from judgeval.constants import (
|
|
55
59
|
JUDGMENT_TRACES_DELETE_API_URL,
|
56
60
|
JUDGMENT_PROJECT_DELETE_API_URL,
|
57
61
|
)
|
58
|
-
from judgeval.
|
59
|
-
from judgeval.data import Example
|
62
|
+
from judgeval.data import Example, Trace, TraceSpan
|
60
63
|
from judgeval.scorers import APIJudgmentScorer, JudgevalScorer
|
61
64
|
from judgeval.rules import Rule
|
62
65
|
from judgeval.evaluation_run import EvaluationRun
|
63
66
|
from judgeval.data.result import ScoringResult
|
67
|
+
from judgeval.common.utils import validate_api_key
|
68
|
+
from judgeval.common.exceptions import JudgmentAPIError
|
64
69
|
|
65
70
|
# Standard library imports needed for the new class
|
66
71
|
import concurrent.futures
|
67
72
|
from collections.abc import Iterator, AsyncIterator # Add Iterator and AsyncIterator
|
68
73
|
|
69
74
|
# Define context variables for tracking the current trace and the current span within a trace
|
70
|
-
current_trace_var = contextvars.ContextVar('current_trace', default=None)
|
75
|
+
current_trace_var = contextvars.ContextVar[Optional['TraceClient']]('current_trace', default=None)
|
71
76
|
current_span_var = contextvars.ContextVar('current_span', default=None) # ContextVar for the active span name
|
72
|
-
in_traced_function_var = contextvars.ContextVar('in_traced_function', default=False) # Track if we're in a traced function
|
73
77
|
|
74
78
|
# Define type aliases for better code readability and maintainability
|
75
79
|
ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic, AsyncOpenAI, AsyncAnthropic, AsyncTogether, genai.Client, genai.client.AsyncClient] # Supported API clients
|
76
|
-
TraceEntryType = Literal['enter', 'exit', 'output', 'input', 'evaluation'] # Valid trace entry types
|
77
80
|
SpanType = Literal['span', 'tool', 'llm', 'evaluation', 'chain']
|
78
|
-
@dataclass
|
79
|
-
class TraceEntry:
|
80
|
-
"""Represents a single trace entry with its visual representation.
|
81
|
-
|
82
|
-
Visual representations:
|
83
|
-
- enter: → (function entry)
|
84
|
-
- exit: ← (function exit)
|
85
|
-
- output: Output: (function return value)
|
86
|
-
- input: Input: (function parameters)
|
87
|
-
- evaluation: Evaluation: (evaluation results)
|
88
|
-
"""
|
89
|
-
type: TraceEntryType
|
90
|
-
span_id: str # Unique ID for this specific span instance
|
91
|
-
depth: int # Indentation level for nested calls
|
92
|
-
created_at: float # Unix timestamp when entry was created, replacing the deprecated 'timestamp' field
|
93
|
-
function: Optional[str] = None # Name of the function being traced
|
94
|
-
message: Optional[str] = None # Human-readable description
|
95
|
-
duration: Optional[float] = None # Time taken (for exit/evaluation entries)
|
96
|
-
trace_id: str = None # ID of the trace this entry belongs to
|
97
|
-
output: Any = None # Function output value
|
98
|
-
# Use field() for mutable defaults to avoid shared state issues
|
99
|
-
inputs: dict = field(default_factory=dict)
|
100
|
-
span_type: SpanType = "span"
|
101
|
-
evaluation_runs: List[Optional[EvaluationRun]] = field(default=None)
|
102
|
-
parent_span_id: Optional[str] = None # ID of the parent span instance
|
103
|
-
|
104
|
-
def print_entry(self):
|
105
|
-
"""Print a trace entry with proper formatting and parent relationship information."""
|
106
|
-
indent = " " * self.depth
|
107
|
-
|
108
|
-
if self.type == "enter":
|
109
|
-
# Format parent info if present
|
110
|
-
parent_info = f" (parent_id: {self.parent_span_id})" if self.parent_span_id else ""
|
111
|
-
print(f"{indent}→ {self.function} (id: {self.span_id}){parent_info} (trace: {self.message})")
|
112
|
-
elif self.type == "exit":
|
113
|
-
print(f"{indent}← {self.function} (id: {self.span_id}) ({self.duration:.3f}s)")
|
114
|
-
elif self.type == "output":
|
115
|
-
# Format output to align properly
|
116
|
-
output_str = str(self.output)
|
117
|
-
print(f"{indent}Output (for id: {self.span_id}): {output_str}")
|
118
|
-
elif self.type == "input":
|
119
|
-
# Format inputs to align properly
|
120
|
-
print(f"{indent}Input (for id: {self.span_id}): {self.inputs}")
|
121
|
-
elif self.type == "evaluation":
|
122
|
-
for evaluation_run in self.evaluation_runs:
|
123
|
-
print(f"{indent}Evaluation (for id: {self.span_id}): {evaluation_run.model_dump()}")
|
124
|
-
|
125
|
-
def _serialize_inputs(self) -> dict:
|
126
|
-
"""Helper method to serialize input data safely.
|
127
|
-
|
128
|
-
Returns a dict with serializable versions of inputs, converting non-serializable
|
129
|
-
objects to None with a warning.
|
130
|
-
"""
|
131
|
-
serialized_inputs = {}
|
132
|
-
for key, value in self.inputs.items():
|
133
|
-
if isinstance(value, BaseModel):
|
134
|
-
serialized_inputs[key] = value.model_dump()
|
135
|
-
elif isinstance(value, (list, tuple)):
|
136
|
-
# Handle lists/tuples of arguments
|
137
|
-
serialized_inputs[key] = [
|
138
|
-
item.model_dump() if isinstance(item, BaseModel)
|
139
|
-
else None if not self._is_json_serializable(item)
|
140
|
-
else item
|
141
|
-
for item in value
|
142
|
-
]
|
143
|
-
else:
|
144
|
-
if self._is_json_serializable(value):
|
145
|
-
serialized_inputs[key] = value
|
146
|
-
else:
|
147
|
-
serialized_inputs[key] = self.safe_stringify(value, self.function)
|
148
|
-
return serialized_inputs
|
149
|
-
|
150
|
-
def _is_json_serializable(self, obj: Any) -> bool:
|
151
|
-
"""Helper method to check if an object is JSON serializable."""
|
152
|
-
try:
|
153
|
-
json.dumps(obj)
|
154
|
-
return True
|
155
|
-
except (TypeError, OverflowError, ValueError):
|
156
|
-
return False
|
157
81
|
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
82
|
+
# --- Evaluation Config Dataclass (Moved from langgraph.py) ---
|
83
|
+
@dataclass
|
84
|
+
class EvaluationConfig:
|
85
|
+
"""Configuration for triggering an evaluation from the handler."""
|
86
|
+
scorers: List[Union[APIJudgmentScorer, JudgevalScorer]]
|
87
|
+
example: Example
|
88
|
+
model: Optional[str] = None
|
89
|
+
log_results: Optional[bool] = True
|
90
|
+
# --- End Evaluation Config Dataclass ---
|
91
|
+
|
92
|
+
# Temporary as a POC to have log use the existing annotations feature until log endpoints are ready
|
93
|
+
@dataclass
|
94
|
+
class TraceAnnotation:
|
95
|
+
"""Represents a single annotation for a trace span."""
|
96
|
+
span_id: str
|
97
|
+
text: str
|
98
|
+
label: str
|
99
|
+
score: int
|
176
100
|
|
177
101
|
def to_dict(self) -> dict:
|
178
|
-
"""Convert the
|
102
|
+
"""Convert the annotation to a dictionary format for storage/transmission."""
|
179
103
|
return {
|
180
|
-
"type": self.type,
|
181
|
-
"function": self.function,
|
182
104
|
"span_id": self.span_id,
|
183
|
-
"
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
"output": self._serialize_output(),
|
189
|
-
"inputs": self._serialize_inputs(),
|
190
|
-
"evaluation_runs": [evaluation_run.model_dump() for evaluation_run in self.evaluation_runs] if self.evaluation_runs else [],
|
191
|
-
"span_type": self.span_type,
|
192
|
-
"parent_span_id": self.parent_span_id,
|
105
|
+
"annotation": {
|
106
|
+
"text": self.text,
|
107
|
+
"label": self.label,
|
108
|
+
"score": self.score
|
109
|
+
}
|
193
110
|
}
|
194
|
-
|
195
|
-
def _serialize_output(self) -> Any:
|
196
|
-
"""Helper method to serialize output data safely.
|
197
|
-
|
198
|
-
Handles special cases:
|
199
|
-
- Pydantic models are converted using model_dump()
|
200
|
-
- We try to serialize into JSON, then string, then the base representation (__repr__)
|
201
|
-
- Non-serializable objects return None with a warning
|
202
|
-
"""
|
203
|
-
|
204
|
-
if isinstance(self.output, BaseModel):
|
205
|
-
return self.output.model_dump()
|
206
|
-
|
207
|
-
# NEW check: If output is the dict structure from our stream wrapper
|
208
|
-
if isinstance(self.output, dict) and 'streamed' in self.output:
|
209
|
-
# Assume it's already JSON-serializable (content is string, usage is dict or None)
|
210
|
-
return self.output
|
211
|
-
# NEW check: If output is the placeholder string before stream completes
|
212
|
-
elif self.output == "<pending stream>":
|
213
|
-
# Represent this state clearly in the serialized data
|
214
|
-
return {"status": "pending stream"}
|
215
|
-
|
216
|
-
try:
|
217
|
-
# Try to serialize the output to verify it's JSON compatible
|
218
|
-
json.dumps(self.output)
|
219
|
-
return self.output
|
220
|
-
except (TypeError, OverflowError, ValueError):
|
221
|
-
return self.safe_stringify(self.output, self.function)
|
222
|
-
|
223
|
-
|
111
|
+
|
224
112
|
class TraceManagerClient:
|
225
113
|
"""
|
226
114
|
Client for handling trace endpoints with the Judgment API
|
@@ -257,8 +145,6 @@ class TraceManagerClient:
|
|
257
145
|
raise ValueError(f"Failed to fetch traces: {response.text}")
|
258
146
|
|
259
147
|
return response.json()
|
260
|
-
|
261
|
-
|
262
148
|
|
263
149
|
def save_trace(self, trace_data: dict):
|
264
150
|
"""
|
@@ -301,6 +187,33 @@ class TraceManagerClient:
|
|
301
187
|
pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={response.json()['ui_results_url']}]View Trace[/link]\n"
|
302
188
|
rprint(pretty_str)
|
303
189
|
|
190
|
+
## TODO: Should have a log endpoint, endpoint should also support batched payloads
|
191
|
+
def save_annotation(self, annotation: TraceAnnotation):
|
192
|
+
json_data = {
|
193
|
+
"span_id": annotation.span_id,
|
194
|
+
"annotation": {
|
195
|
+
"text": annotation.text,
|
196
|
+
"label": annotation.label,
|
197
|
+
"score": annotation.score
|
198
|
+
}
|
199
|
+
}
|
200
|
+
|
201
|
+
response = requests.post(
|
202
|
+
JUDGMENT_TRACES_ADD_ANNOTATION_API_URL,
|
203
|
+
json=json_data,
|
204
|
+
headers={
|
205
|
+
'Content-Type': 'application/json',
|
206
|
+
'Authorization': f'Bearer {self.judgment_api_key}',
|
207
|
+
'X-Organization-Id': self.organization_id
|
208
|
+
},
|
209
|
+
verify=True
|
210
|
+
)
|
211
|
+
|
212
|
+
if response.status_code != HTTPStatus.OK:
|
213
|
+
raise ValueError(f"Failed to save annotation: {response.text}")
|
214
|
+
|
215
|
+
return response.json()
|
216
|
+
|
304
217
|
def delete_trace(self, trace_id: str):
|
305
218
|
"""
|
306
219
|
Delete a trace from the database.
|
@@ -391,15 +304,16 @@ class TraceClient:
|
|
391
304
|
self.enable_evaluations = enable_evaluations
|
392
305
|
self.parent_trace_id = parent_trace_id
|
393
306
|
self.parent_name = parent_name
|
394
|
-
self.
|
395
|
-
self.
|
307
|
+
self.trace_spans: List[TraceSpan] = []
|
308
|
+
self.span_id_to_span: Dict[str, TraceSpan] = {}
|
309
|
+
self.evaluation_runs: List[EvaluationRun] = []
|
310
|
+
self.annotations: List[TraceAnnotation] = []
|
396
311
|
self.start_time = time.time()
|
397
312
|
self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id, tracer)
|
398
313
|
self.visited_nodes = []
|
399
314
|
self.executed_tools = []
|
400
315
|
self.executed_node_tools = []
|
401
316
|
self._span_depths: Dict[str, int] = {} # NEW: To track depth of active spans
|
402
|
-
|
403
317
|
def get_current_span(self):
|
404
318
|
"""Get the current span from the context var"""
|
405
319
|
return current_span_var.get()
|
@@ -429,9 +343,7 @@ class TraceClient:
|
|
429
343
|
|
430
344
|
self._span_depths[span_id] = current_depth # Store depth by span_id
|
431
345
|
|
432
|
-
|
433
|
-
type="enter",
|
434
|
-
function=name,
|
346
|
+
span = TraceSpan(
|
435
347
|
span_id=span_id,
|
436
348
|
trace_id=self.trace_id,
|
437
349
|
depth=current_depth,
|
@@ -439,25 +351,15 @@ class TraceClient:
|
|
439
351
|
created_at=start_time,
|
440
352
|
span_type=span_type,
|
441
353
|
parent_span_id=parent_span_id,
|
354
|
+
function=name,
|
442
355
|
)
|
443
|
-
self.
|
356
|
+
self.add_span(span)
|
444
357
|
|
445
358
|
try:
|
446
359
|
yield self
|
447
360
|
finally:
|
448
361
|
duration = time.time() - start_time
|
449
|
-
|
450
|
-
self.add_entry(TraceEntry(
|
451
|
-
type="exit",
|
452
|
-
function=name,
|
453
|
-
span_id=span_id, # Use the same span_id for exit
|
454
|
-
trace_id=self.trace_id, # Use the trace_id from the trace client
|
455
|
-
depth=exit_depth,
|
456
|
-
message=f"← {name}",
|
457
|
-
created_at=time.time(),
|
458
|
-
duration=duration,
|
459
|
-
span_type=span_type,
|
460
|
-
))
|
362
|
+
span.duration = duration
|
461
363
|
# Clean up depth tracking for this span_id
|
462
364
|
if span_id in self._span_depths:
|
463
365
|
del self._span_depths[span_id]
|
@@ -467,32 +369,24 @@ class TraceClient:
|
|
467
369
|
def async_evaluate(
|
468
370
|
self,
|
469
371
|
scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
|
372
|
+
example: Optional[Example] = None,
|
470
373
|
input: Optional[str] = None,
|
471
|
-
actual_output: Optional[str] = None,
|
472
|
-
expected_output: Optional[str] = None,
|
374
|
+
actual_output: Optional[Union[str, List[str]]] = None,
|
375
|
+
expected_output: Optional[Union[str, List[str]]] = None,
|
473
376
|
context: Optional[List[str]] = None,
|
474
377
|
retrieval_context: Optional[List[str]] = None,
|
475
378
|
tools_called: Optional[List[str]] = None,
|
476
379
|
expected_tools: Optional[List[str]] = None,
|
477
380
|
additional_metadata: Optional[Dict[str, Any]] = None,
|
478
381
|
model: Optional[str] = None,
|
382
|
+
span_id: Optional[str] = None, # <<< ADDED optional span_id parameter
|
479
383
|
log_results: Optional[bool] = True
|
480
384
|
):
|
481
385
|
if not self.enable_evaluations:
|
482
386
|
return
|
483
387
|
|
484
388
|
start_time = time.time() # Record start time
|
485
|
-
|
486
|
-
input=input,
|
487
|
-
actual_output=actual_output,
|
488
|
-
expected_output=expected_output,
|
489
|
-
context=context,
|
490
|
-
retrieval_context=retrieval_context,
|
491
|
-
tools_called=tools_called,
|
492
|
-
expected_tools=expected_tools,
|
493
|
-
additional_metadata=additional_metadata,
|
494
|
-
trace_id=self.trace_id
|
495
|
-
)
|
389
|
+
|
496
390
|
try:
|
497
391
|
# Load appropriate implementations for all scorers
|
498
392
|
if not scorers:
|
@@ -507,13 +401,44 @@ class TraceClient:
|
|
507
401
|
warnings.warn(f"Failed to load scorers: {str(e)}")
|
508
402
|
return
|
509
403
|
|
404
|
+
# If example is not provided, create one from the individual parameters
|
405
|
+
if example is None:
|
406
|
+
# Check if any of the individual parameters are provided
|
407
|
+
if any(param is not None for param in [input, actual_output, expected_output, context,
|
408
|
+
retrieval_context, tools_called, expected_tools,
|
409
|
+
additional_metadata]):
|
410
|
+
example = Example(
|
411
|
+
input=input,
|
412
|
+
actual_output=actual_output,
|
413
|
+
expected_output=expected_output,
|
414
|
+
context=context,
|
415
|
+
retrieval_context=retrieval_context,
|
416
|
+
tools_called=tools_called,
|
417
|
+
expected_tools=expected_tools,
|
418
|
+
additional_metadata=additional_metadata,
|
419
|
+
)
|
420
|
+
else:
|
421
|
+
raise ValueError("Either 'example' or at least one of the individual parameters (input, actual_output, etc.) must be provided")
|
422
|
+
|
423
|
+
# Check examples before creating evaluation run
|
424
|
+
|
425
|
+
# check_examples([example], scorers)
|
426
|
+
|
427
|
+
# --- Modification: Capture span_id immediately ---
|
428
|
+
# span_id_at_eval_call = current_span_var.get()
|
429
|
+
# print(f"[TraceClient.async_evaluate] Captured span ID at eval call: {span_id_at_eval_call}")
|
430
|
+
# Prioritize explicitly passed span_id, fallback to context var
|
431
|
+
span_id_to_use = span_id if span_id is not None else current_span_var.get()
|
432
|
+
# print(f"[TraceClient.async_evaluate] Using span_id: {span_id_to_use}")
|
433
|
+
# --- End Modification ---
|
434
|
+
|
510
435
|
# Combine the trace-level rules with any evaluation-specific rules)
|
511
436
|
eval_run = EvaluationRun(
|
512
437
|
organization_id=self.tracer.organization_id,
|
513
438
|
log_results=log_results,
|
514
439
|
project_name=self.project_name,
|
515
440
|
eval_name=f"{self.name.capitalize()}-"
|
516
|
-
f"{current_span_var.get()}-"
|
441
|
+
f"{current_span_var.get()}-" # Keep original eval name format using context var if available
|
517
442
|
f"[{','.join(scorer.score_type.capitalize() for scorer in scorers)}]",
|
518
443
|
examples=[example],
|
519
444
|
scorers=scorers,
|
@@ -521,296 +446,73 @@ class TraceClient:
|
|
521
446
|
metadata={},
|
522
447
|
judgment_api_key=self.tracer.api_key,
|
523
448
|
override=self.overwrite,
|
524
|
-
trace_span_id=
|
449
|
+
trace_span_id=span_id_to_use, # Pass the determined ID
|
525
450
|
rules=self.rules # Use the combined rules
|
526
451
|
)
|
527
452
|
|
528
453
|
self.add_eval_run(eval_run, start_time) # Pass start_time to record_evaluation
|
529
454
|
|
530
455
|
def add_eval_run(self, eval_run: EvaluationRun, start_time: float):
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
# Determine function name based on previous entry or context var (less ideal)
|
536
|
-
function_name = "unknown_function" # Default
|
537
|
-
if prev_entry and prev_entry.span_type == "llm":
|
538
|
-
function_name = prev_entry.function
|
539
|
-
else:
|
540
|
-
# Try to find the function name associated with the current span_id
|
541
|
-
for entry in reversed(self.entries):
|
542
|
-
if entry.span_id == current_span_id and entry.type == 'enter':
|
543
|
-
function_name = entry.function
|
544
|
-
break
|
545
|
-
|
546
|
-
# Get depth for the current span
|
547
|
-
current_depth = self._span_depths.get(current_span_id, 0)
|
548
|
-
|
549
|
-
self.add_entry(TraceEntry(
|
550
|
-
type="evaluation",
|
551
|
-
function=function_name,
|
552
|
-
span_id=current_span_id, # Associate with current span
|
553
|
-
trace_id=self.trace_id, # Use the trace_id from the trace client
|
554
|
-
depth=current_depth,
|
555
|
-
message=f"Evaluation results for {function_name}",
|
556
|
-
created_at=time.time(),
|
557
|
-
evaluation_runs=[eval_run],
|
558
|
-
duration=duration,
|
559
|
-
span_type="evaluation"
|
560
|
-
))
|
456
|
+
# --- Modification: Use span_id from eval_run ---
|
457
|
+
current_span_id = eval_run.trace_span_id # Get ID from the eval_run object
|
458
|
+
# print(f"[TraceClient.add_eval_run] Using span_id from eval_run: {current_span_id}")
|
459
|
+
# --- End Modification ---
|
561
460
|
|
461
|
+
if current_span_id:
|
462
|
+
span = self.span_id_to_span[current_span_id]
|
463
|
+
span.evaluation_runs.append(eval_run)
|
464
|
+
self.evaluation_runs.append(eval_run)
|
465
|
+
|
466
|
+
def add_annotation(self, annotation: TraceAnnotation):
|
467
|
+
"""Add an annotation to this trace context"""
|
468
|
+
self.annotations.append(annotation)
|
469
|
+
return self
|
470
|
+
|
562
471
|
def record_input(self, inputs: dict):
|
563
472
|
current_span_id = current_span_var.get()
|
564
473
|
if current_span_id:
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
if entry.span_id == current_span_id and entry.type == 'enter':
|
570
|
-
entry_span_type = entry.span_type
|
571
|
-
function_name = entry.function
|
572
|
-
break
|
573
|
-
|
574
|
-
self.add_entry(TraceEntry(
|
575
|
-
type="input",
|
576
|
-
function=function_name,
|
577
|
-
span_id=current_span_id, # Use current span_id
|
578
|
-
trace_id=self.trace_id, # Use the trace_id from the trace client
|
579
|
-
depth=current_depth,
|
580
|
-
message=f"Inputs to {function_name}",
|
581
|
-
created_at=time.time(),
|
582
|
-
inputs=inputs,
|
583
|
-
span_type=entry_span_type,
|
584
|
-
))
|
585
|
-
|
586
|
-
async def _update_coroutine_output(self, entry: TraceEntry, coroutine: Any):
|
474
|
+
span = self.span_id_to_span[current_span_id]
|
475
|
+
span.inputs = inputs
|
476
|
+
|
477
|
+
async def _update_coroutine_output(self, span: TraceSpan, coroutine: Any):
|
587
478
|
"""Helper method to update the output of a trace entry once the coroutine completes"""
|
588
479
|
try:
|
589
480
|
result = await coroutine
|
590
|
-
|
481
|
+
span.output = result
|
591
482
|
return result
|
592
483
|
except Exception as e:
|
593
|
-
|
484
|
+
span.output = f"Error: {str(e)}"
|
594
485
|
raise
|
595
486
|
|
596
487
|
def record_output(self, output: Any):
|
597
488
|
current_span_id = current_span_var.get()
|
598
489
|
if current_span_id:
|
599
|
-
|
600
|
-
|
601
|
-
function_name = "unknown_function" # Default
|
602
|
-
for entry in reversed(self.entries):
|
603
|
-
if entry.span_id == current_span_id and entry.type == 'enter':
|
604
|
-
entry_span_type = entry.span_type
|
605
|
-
function_name = entry.function
|
606
|
-
break
|
607
|
-
|
608
|
-
entry = TraceEntry(
|
609
|
-
type="output",
|
610
|
-
function=function_name,
|
611
|
-
span_id=current_span_id, # Use current span_id
|
612
|
-
depth=current_depth,
|
613
|
-
message=f"Output from {function_name}",
|
614
|
-
created_at=time.time(),
|
615
|
-
output="<pending>" if inspect.iscoroutine(output) else output,
|
616
|
-
span_type=entry_span_type,
|
617
|
-
)
|
618
|
-
self.add_entry(entry)
|
490
|
+
span = self.span_id_to_span[current_span_id]
|
491
|
+
span.output = "<pending>" if inspect.iscoroutine(output) else output
|
619
492
|
|
620
493
|
if inspect.iscoroutine(output):
|
621
|
-
asyncio.create_task(self._update_coroutine_output(
|
494
|
+
asyncio.create_task(self._update_coroutine_output(span, output))
|
622
495
|
|
623
|
-
# Return the created entry
|
624
|
-
|
625
|
-
|
626
|
-
|
627
|
-
|
628
|
-
|
496
|
+
return span # Return the created entry
|
497
|
+
# Removed else block - original didn't have one
|
498
|
+
return None # Return None if no span_id found
|
499
|
+
|
500
|
+
def add_span(self, span: TraceSpan):
|
501
|
+
"""Add a trace span to this trace context"""
|
502
|
+
self.trace_spans.append(span)
|
503
|
+
self.span_id_to_span[span.span_id] = span
|
629
504
|
return self
|
630
505
|
|
631
506
|
def print(self):
|
632
507
|
"""Print the complete trace with proper visual structure"""
|
633
|
-
for
|
634
|
-
|
635
|
-
|
636
|
-
def print_hierarchical(self):
|
637
|
-
"""Print the trace in a hierarchical structure based on parent-child relationships"""
|
638
|
-
# First, build a map of spans
|
639
|
-
spans = {}
|
640
|
-
root_spans = []
|
641
|
-
|
642
|
-
# Collect all enter events first
|
643
|
-
for entry in self.entries:
|
644
|
-
if entry.type == "enter":
|
645
|
-
spans[entry.function] = {
|
646
|
-
"name": entry.function,
|
647
|
-
"depth": entry.depth,
|
648
|
-
"parent_id": entry.parent_span_id,
|
649
|
-
"children": []
|
650
|
-
}
|
651
|
-
|
652
|
-
# If no parent, it's a root span
|
653
|
-
if not entry.parent_span_id:
|
654
|
-
root_spans.append(entry.function)
|
655
|
-
elif entry.parent_span_id not in spans:
|
656
|
-
# If parent doesn't exist yet, temporarily treat as root
|
657
|
-
# (we'll fix this later)
|
658
|
-
root_spans.append(entry.function)
|
659
|
-
|
660
|
-
# Build parent-child relationships
|
661
|
-
for span_name, span in spans.items():
|
662
|
-
parent = span["parent_id"]
|
663
|
-
if parent and parent in spans:
|
664
|
-
spans[parent]["children"].append(span_name)
|
665
|
-
# Remove from root spans if it was temporarily there
|
666
|
-
if span_name in root_spans:
|
667
|
-
root_spans.remove(span_name)
|
668
|
-
|
669
|
-
# Now print the hierarchy
|
670
|
-
def print_span(span_name, level=0):
|
671
|
-
if span_name not in spans:
|
672
|
-
return
|
673
|
-
|
674
|
-
span = spans[span_name]
|
675
|
-
indent = " " * level
|
676
|
-
parent_info = f" (parent_id: {span['parent_id']})" if span["parent_id"] else ""
|
677
|
-
print(f"{indent}→ {span_name}{parent_info}")
|
678
|
-
|
679
|
-
# Print children
|
680
|
-
for child in span["children"]:
|
681
|
-
print_span(child, level + 1)
|
682
|
-
|
683
|
-
# Print starting with root spans
|
684
|
-
print("\nHierarchical Trace Structure:")
|
685
|
-
for root in root_spans:
|
686
|
-
print_span(root)
|
508
|
+
for span in self.trace_spans:
|
509
|
+
span.print_span()
|
687
510
|
|
688
511
|
def get_duration(self) -> float:
|
689
512
|
"""
|
690
513
|
Get the total duration of this trace
|
691
514
|
"""
|
692
515
|
return time.time() - self.start_time
|
693
|
-
|
694
|
-
def condense_trace(self, entries: List[dict]) -> List[dict]:
|
695
|
-
"""
|
696
|
-
Condenses trace entries into a single entry for each span instance,
|
697
|
-
preserving parent-child span relationships using span_id and parent_span_id.
|
698
|
-
"""
|
699
|
-
spans_by_id: Dict[str, dict] = {}
|
700
|
-
evaluation_runs: List[EvaluationRun] = []
|
701
|
-
|
702
|
-
# First pass: Group entries by span_id and gather data
|
703
|
-
for entry in entries:
|
704
|
-
span_id = entry.get("span_id")
|
705
|
-
if not span_id:
|
706
|
-
continue # Skip entries without a span_id (should not happen)
|
707
|
-
|
708
|
-
if entry["type"] == "enter":
|
709
|
-
if span_id not in spans_by_id:
|
710
|
-
spans_by_id[span_id] = {
|
711
|
-
"span_id": span_id,
|
712
|
-
"function": entry["function"],
|
713
|
-
"depth": entry["depth"], # Use the depth recorded at entry time
|
714
|
-
"created_at": entry["created_at"],
|
715
|
-
"trace_id": entry["trace_id"],
|
716
|
-
"parent_span_id": entry.get("parent_span_id"),
|
717
|
-
"span_type": entry.get("span_type", "span"),
|
718
|
-
"inputs": None,
|
719
|
-
"output": None,
|
720
|
-
"evaluation_runs": [],
|
721
|
-
"duration": None
|
722
|
-
}
|
723
|
-
# Handle potential duplicate enter events if necessary (e.g., log warning)
|
724
|
-
|
725
|
-
elif span_id in spans_by_id:
|
726
|
-
current_span_data = spans_by_id[span_id]
|
727
|
-
|
728
|
-
if entry["type"] == "input" and entry["inputs"]:
|
729
|
-
# Merge inputs if multiple are recorded, or just assign
|
730
|
-
if current_span_data["inputs"] is None:
|
731
|
-
current_span_data["inputs"] = entry["inputs"]
|
732
|
-
elif isinstance(current_span_data["inputs"], dict) and isinstance(entry["inputs"], dict):
|
733
|
-
current_span_data["inputs"].update(entry["inputs"])
|
734
|
-
# Add more sophisticated merging if needed
|
735
|
-
|
736
|
-
elif entry["type"] == "output" and "output" in entry:
|
737
|
-
current_span_data["output"] = entry["output"]
|
738
|
-
|
739
|
-
elif entry["type"] == "evaluation" and entry.get("evaluation_runs"):
|
740
|
-
if current_span_data.get("evaluation_runs") is not None:
|
741
|
-
evaluation_runs.extend(entry["evaluation_runs"])
|
742
|
-
|
743
|
-
elif entry["type"] == "exit":
|
744
|
-
if current_span_data["duration"] is None: # Calculate duration only once
|
745
|
-
start_time = datetime.fromisoformat(current_span_data.get("created_at", entry["created_at"]))
|
746
|
-
end_time = datetime.fromisoformat(entry["created_at"])
|
747
|
-
current_span_data["duration"] = (end_time - start_time).total_seconds()
|
748
|
-
# Update depth if exit depth is different (though current span() implementation keeps it same)
|
749
|
-
# current_span_data["depth"] = entry["depth"]
|
750
|
-
|
751
|
-
# Convert dictionary to a list initially for easier access
|
752
|
-
spans_list = list(spans_by_id.values())
|
753
|
-
|
754
|
-
# Build tree structure (adjacency list) and find roots
|
755
|
-
children_map: Dict[Optional[str], List[dict]] = {}
|
756
|
-
roots = []
|
757
|
-
span_map = {span['span_id']: span for span in spans_list} # Map for quick lookup
|
758
|
-
|
759
|
-
for span in spans_list:
|
760
|
-
parent_id = span.get("parent_span_id")
|
761
|
-
if parent_id is None:
|
762
|
-
roots.append(span)
|
763
|
-
else:
|
764
|
-
if parent_id not in children_map:
|
765
|
-
children_map[parent_id] = []
|
766
|
-
children_map[parent_id].append(span)
|
767
|
-
|
768
|
-
# Sort roots by timestamp
|
769
|
-
roots.sort(key=lambda x: datetime.fromisoformat(x.get("created_at", "1970-01-01T00:00:00")))
|
770
|
-
|
771
|
-
# Perform depth-first traversal to get the final sorted list
|
772
|
-
sorted_condensed_list = []
|
773
|
-
visited = set() # To handle potential cycles, though unlikely with UUIDs
|
774
|
-
|
775
|
-
def dfs(span_data):
|
776
|
-
span_id = span_data['span_id']
|
777
|
-
if span_id in visited:
|
778
|
-
return # Avoid infinite loops in case of cycles
|
779
|
-
visited.add(span_id)
|
780
|
-
|
781
|
-
sorted_condensed_list.append(span_data) # Add parent before children
|
782
|
-
|
783
|
-
# Get children, sort them by created_at, and visit them
|
784
|
-
span_children = children_map.get(span_id, [])
|
785
|
-
span_children.sort(key=lambda x: datetime.fromisoformat(x.get("created_at", "1970-01-01T00:00:00")))
|
786
|
-
for child in span_children:
|
787
|
-
# Ensure the child exists in our map before recursing
|
788
|
-
if child['span_id'] in span_map:
|
789
|
-
dfs(child)
|
790
|
-
else:
|
791
|
-
# This case might indicate an issue, but we'll add the child directly
|
792
|
-
# if its parent was processed but the child itself wasn't in the initial list?
|
793
|
-
# Or if the child's 'enter' event was missing. For robustness, add it.
|
794
|
-
if child['span_id'] not in visited:
|
795
|
-
visited.add(child['span_id'])
|
796
|
-
sorted_condensed_list.append(child)
|
797
|
-
|
798
|
-
|
799
|
-
# Start DFS from each root
|
800
|
-
for root_span in roots:
|
801
|
-
if root_span['span_id'] not in visited:
|
802
|
-
dfs(root_span)
|
803
|
-
|
804
|
-
# Handle spans that might not have been reachable from roots (orphans)
|
805
|
-
# Though ideally, all spans should descend from a root.
|
806
|
-
for span_data in spans_list:
|
807
|
-
if span_data['span_id'] not in visited:
|
808
|
-
# Decide how to handle orphans, maybe append them at the end sorted by time?
|
809
|
-
# For now, let's just add them to ensure they aren't lost.
|
810
|
-
sorted_condensed_list.append(span_data)
|
811
|
-
|
812
|
-
|
813
|
-
return sorted_condensed_list, evaluation_runs
|
814
516
|
|
815
517
|
def save(self, overwrite: bool = False) -> Tuple[str, dict]:
|
816
518
|
"""
|
@@ -819,103 +521,391 @@ class TraceClient:
|
|
819
521
|
"""
|
820
522
|
# Calculate total elapsed time
|
821
523
|
total_duration = self.get_duration()
|
822
|
-
|
823
|
-
raw_entries = [entry.to_dict() for entry in self.entries]
|
824
|
-
|
825
|
-
condensed_entries, evaluation_runs = self.condense_trace(raw_entries)
|
826
524
|
|
827
|
-
# Calculate total token counts from LLM API calls
|
828
|
-
total_prompt_tokens = 0
|
829
|
-
total_completion_tokens = 0
|
830
|
-
total_tokens = 0
|
831
|
-
|
832
|
-
total_prompt_tokens_cost = 0.0
|
833
|
-
total_completion_tokens_cost = 0.0
|
834
|
-
total_cost = 0.0
|
835
|
-
|
836
525
|
# Only count tokens for actual LLM API call spans
|
837
526
|
llm_span_names = {"OPENAI_API_CALL", "TOGETHER_API_CALL", "ANTHROPIC_API_CALL", "GOOGLE_API_CALL"}
|
838
|
-
for
|
839
|
-
|
840
|
-
|
841
|
-
|
842
|
-
|
527
|
+
for span in self.trace_spans:
|
528
|
+
span_function_name = span.function # Get function name safely
|
529
|
+
# Check if it's an LLM span AND function name CONTAINS an API call suffix AND output is dict
|
530
|
+
is_llm_span = span.span_type == "llm"
|
531
|
+
has_api_suffix = any(suffix in span_function_name for suffix in llm_span_names)
|
532
|
+
output_is_dict = isinstance(span.output, dict)
|
533
|
+
|
534
|
+
# --- DEBUG PRINT 1: Check if condition passes ---
|
535
|
+
# if is_llm_entry and has_api_suffix and output_is_dict:
|
536
|
+
# elif is_llm_entry:
|
537
|
+
# # Print why it failed if it was an LLM entry
|
538
|
+
# # --- END DEBUG ---
|
539
|
+
|
540
|
+
if is_llm_span and has_api_suffix and output_is_dict:
|
541
|
+
output = span.output
|
542
|
+
usage = output.get("usage", {}) # Gets the 'usage' dict from the 'output' field
|
543
|
+
|
544
|
+
# --- DEBUG PRINT 2: Check extracted usage ---
|
545
|
+
# --- END DEBUG ---
|
546
|
+
|
547
|
+
# --- NEW: Extract model_name correctly from nested inputs ---
|
548
|
+
model_name = None
|
549
|
+
span_inputs = span.inputs
|
550
|
+
if span_inputs:
|
551
|
+
# Try common locations for model name within the inputs structure
|
552
|
+
invocation_params = span_inputs.get("invocation_params", {})
|
553
|
+
serialized_data = span_inputs.get("serialized", {})
|
554
|
+
|
555
|
+
# Look in invocation_params (often directly contains model)
|
556
|
+
if isinstance(invocation_params, dict):
|
557
|
+
model_name = invocation_params.get("model")
|
558
|
+
|
559
|
+
# Fallback: Check serialized 'repr' if it contains model info
|
560
|
+
if not model_name and isinstance(serialized_data, dict):
|
561
|
+
serialized_repr = serialized_data.get("repr", "")
|
562
|
+
if "model_name=" in serialized_repr:
|
563
|
+
try: # Simple parsing attempt
|
564
|
+
model_name = serialized_repr.split("model_name='")[1].split("'")[0]
|
565
|
+
except IndexError: pass # Ignore parsing errors
|
566
|
+
|
567
|
+
# Fallback: Check top-level of invocation_params (sometimes passed flat)
|
568
|
+
if not model_name and isinstance(invocation_params, dict):
|
569
|
+
model_name = invocation_params.get("model") # Redundant check, but safe
|
570
|
+
|
571
|
+
# Fallback: Check top-level of inputs itself (less likely for callbacks)
|
572
|
+
if not model_name:
|
573
|
+
model_name = span_inputs.get("model")
|
574
|
+
|
575
|
+
|
576
|
+
# --- END NEW ---
|
577
|
+
|
843
578
|
prompt_tokens = 0
|
844
|
-
completion_tokens = 0
|
845
|
-
|
846
|
-
# Handle OpenAI/Together format
|
579
|
+
completion_tokens = 0
|
580
|
+
|
581
|
+
# Handle OpenAI/Together format (checks within the 'usage' dict)
|
847
582
|
if "prompt_tokens" in usage:
|
848
583
|
prompt_tokens = usage.get("prompt_tokens", 0)
|
849
584
|
completion_tokens = usage.get("completion_tokens", 0)
|
850
|
-
|
851
|
-
|
852
|
-
# Handle Anthropic format
|
585
|
+
|
586
|
+
# Handle Anthropic format - MAP values to standard keys
|
853
587
|
elif "input_tokens" in usage:
|
854
|
-
prompt_tokens = usage.get("input_tokens", 0)
|
855
|
-
completion_tokens = usage.get("output_tokens", 0)
|
856
|
-
|
857
|
-
|
858
|
-
|
859
|
-
|
588
|
+
prompt_tokens = usage.get("input_tokens", 0) # Get value from input_tokens
|
589
|
+
completion_tokens = usage.get("output_tokens", 0) # Get value from output_tokens
|
590
|
+
|
591
|
+
# *** Overwrite the usage dict in the entry to use standard keys ***
|
592
|
+
original_total = usage.get("total_tokens", 0)
|
593
|
+
original_total_cost = usage.get("total_cost_usd", 0.0) # Preserve if already calculated
|
594
|
+
# Recalculate cost just in case it wasn't done correctly before
|
595
|
+
temp_prompt_cost, temp_completion_cost = 0.0, 0.0
|
596
|
+
if model_name:
|
597
|
+
try:
|
598
|
+
temp_prompt_cost, temp_completion_cost = cost_per_token(
|
599
|
+
model=model_name,
|
600
|
+
prompt_tokens=prompt_tokens,
|
601
|
+
completion_tokens=completion_tokens
|
602
|
+
)
|
603
|
+
except Exception:
|
604
|
+
pass # Ignore cost calculation errors here, focus on keys
|
605
|
+
# Replace the usage dict with one using standard keys but Anthropic values
|
606
|
+
output["usage"] = {
|
607
|
+
"prompt_tokens": prompt_tokens,
|
608
|
+
"completion_tokens": completion_tokens,
|
609
|
+
"total_tokens": original_total,
|
610
|
+
"prompt_tokens_cost_usd": temp_prompt_cost, # Use standard cost key
|
611
|
+
"completion_tokens_cost_usd": temp_completion_cost, # Use standard cost key
|
612
|
+
"total_cost_usd": original_total_cost if original_total_cost > 0 else (temp_prompt_cost + temp_completion_cost)
|
613
|
+
}
|
614
|
+
usage = output["usage"]
|
615
|
+
|
616
|
+
# Calculate costs if model name is available and ensure they are stored with standard keys
|
617
|
+
prompt_tokens = usage.get("prompt_tokens", 0)
|
618
|
+
completion_tokens = usage.get("completion_tokens", 0)
|
860
619
|
|
861
620
|
# Calculate costs if model name is available
|
862
621
|
if model_name:
|
863
622
|
try:
|
623
|
+
# Recalculate costs based on potentially mapped tokens
|
864
624
|
prompt_cost, completion_cost = cost_per_token(
|
865
625
|
model=model_name,
|
866
626
|
prompt_tokens=prompt_tokens,
|
867
627
|
completion_tokens=completion_tokens
|
868
628
|
)
|
869
|
-
total_prompt_tokens_cost += prompt_cost
|
870
|
-
total_completion_tokens_cost += completion_cost
|
871
|
-
total_cost += prompt_cost + completion_cost
|
872
629
|
|
873
630
|
# Add cost information directly to the usage dictionary in the condensed entry
|
631
|
+
# Ensure 'usage' exists in the output dict before modifying it
|
632
|
+
# Add/Update cost information using standard keys
|
633
|
+
|
874
634
|
if "usage" not in output:
|
875
|
-
output["usage"] = {}
|
635
|
+
output["usage"] = {} # Initialize if missing
|
636
|
+
elif not isinstance(output["usage"], dict): # Handle cases where 'usage' might not be a dict (e.g., placeholder string)
|
637
|
+
print(f"[WARN TraceClient.save] Output 'usage' for span {span.span_id} was not a dict ({type(output['usage'])}). Resetting before adding costs.")
|
638
|
+
output["usage"] = {} # Reset to dict
|
639
|
+
|
876
640
|
output["usage"]["prompt_tokens_cost_usd"] = prompt_cost
|
877
641
|
output["usage"]["completion_tokens_cost_usd"] = completion_cost
|
878
642
|
output["usage"]["total_cost_usd"] = prompt_cost + completion_cost
|
879
643
|
except Exception as e:
|
880
644
|
# If cost calculation fails, continue without adding costs
|
881
|
-
print(f"Error calculating cost for model '{model_name}': {str(e)}")
|
645
|
+
print(f"Error calculating cost for model '{model_name}' (span: {span.span_id}): {str(e)}")
|
882
646
|
pass
|
647
|
+
else:
|
648
|
+
print(f"[WARN TraceClient.save] Could not determine model name for cost calculation (span: {span.span_id}). Inputs: {span_inputs}")
|
883
649
|
|
884
|
-
|
650
|
+
|
651
|
+
# Create trace document - Always use standard keys for top-level counts
|
885
652
|
trace_data = {
|
886
653
|
"trace_id": self.trace_id,
|
887
654
|
"name": self.name,
|
888
655
|
"project_name": self.project_name,
|
889
656
|
"created_at": datetime.utcfromtimestamp(self.start_time).isoformat(),
|
890
657
|
"duration": total_duration,
|
891
|
-
"
|
892
|
-
|
893
|
-
"completion_tokens": total_completion_tokens,
|
894
|
-
"total_tokens": total_tokens,
|
895
|
-
"prompt_tokens_cost_usd": total_prompt_tokens_cost,
|
896
|
-
"completion_tokens_cost_usd": total_completion_tokens_cost,
|
897
|
-
"total_cost_usd": total_cost
|
898
|
-
},
|
899
|
-
"entries": condensed_entries,
|
900
|
-
"evaluation_runs": evaluation_runs,
|
658
|
+
"entries": [span.model_dump() for span in self.trace_spans],
|
659
|
+
"evaluation_runs": [run.model_dump() for run in self.evaluation_runs],
|
901
660
|
"overwrite": overwrite,
|
902
661
|
"parent_trace_id": self.parent_trace_id,
|
903
662
|
"parent_name": self.parent_name
|
904
663
|
}
|
905
664
|
# --- Log trace data before saving ---
|
906
|
-
try:
|
907
|
-
rprint(f"[TraceClient.save] Saving trace data for trace_id {self.trace_id}:")
|
908
|
-
rprint(json.dumps(trace_data, indent=2))
|
909
|
-
except Exception as log_e:
|
910
|
-
rprint(f"[TraceClient.save] Error logging trace data: {log_e}")
|
911
|
-
# --- End logging ---
|
912
665
|
self.trace_manager_client.save_trace(trace_data)
|
913
666
|
|
667
|
+
# upload annotations
|
668
|
+
# TODO: batch to the log endpoint
|
669
|
+
for annotation in self.annotations:
|
670
|
+
self.trace_manager_client.save_annotation(annotation)
|
671
|
+
|
914
672
|
return self.trace_id, trace_data
|
915
673
|
|
916
674
|
def delete(self):
    """Delete this trace via the trace manager client.

    Returns:
        Whatever ``trace_manager_client.delete_trace`` returns for
        ``self.trace_id``.
    """
    manager = self.trace_manager_client
    return manager.delete_trace(self.trace_id)
|
918
676
|
|
677
|
+
|
678
|
+
class _DeepTracer:
    """Singleton context manager that installs a ``sys.settrace`` hook.

    While at least one ``with _DeepTracer():`` block is active (tracked by a
    reference count), ``_trace`` is registered through ``sys.settrace`` and
    ``threading.settrace``. For every frame accepted by ``_should_trace`` it
    adds a ``TraceSpan`` to the trace found in ``current_trace_var``, records
    the call arguments as inputs and the return value / exception as output.
    """

    # The single shared instance, created lazily in __new__.
    _instance: Optional["_DeepTracer"] = None
    # Guards instance creation and the reference count.
    _lock: threading.Lock = threading.Lock()
    # Number of currently active `with _DeepTracer():` blocks.
    _refcount: int = 0
    # NOTE(review): `default=[]` is a single list object shared by every
    # context that never called `.set()`; __enter__ only resets it in the
    # context that first enters. Confirm this sharing is acceptable.
    _span_stack: contextvars.ContextVar[List[Dict[str, Any]]] = contextvars.ContextVar("_deep_profiler_span_stack", default=[])
    _skip_stack: contextvars.ContextVar[List[str]] = contextvars.ContextVar("_deep_profiler_skip_stack", default=[])

    def _get_qual_name(self, frame) -> str:
        """Return ``"<module>.<name>"`` for the function running in *frame*.

        The function object is looked up by its code name in the frame's
        globals; when found and it exposes ``__qualname__`` that is used,
        otherwise the plain code name is used. Always returns a string.
        """
        func_name = frame.f_code.co_name
        module_name = frame.f_globals.get("__name__", "unknown_module")
        fallback = f"{module_name}.{func_name}"

        try:
            func = frame.f_globals.get(func_name)
            qualname = getattr(func, "__qualname__", None)
        except Exception:
            return fallback

        # Previously a global without __qualname__ fell through and the
        # method returned None; fall back to the code name instead.
        if qualname is None:
            return fallback
        return f"{module_name}.{qualname}"

    def __new__(cls):
        """Create the shared instance on first use and return it thereafter."""
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
            return cls._instance

    def _should_trace(self, frame):
        """Return True when *frame* should get its own span.

        A frame is rejected when the skip stack is non-empty, when the
        function found in the frame's globals carries ``_judgment_span_name``
        or ``_judgment_span_type``, when the module name is missing, when the
        code name starts with ``<`` or is a dunder other than ``__call__``,
        or when ``_is_user_code`` rejects the file name.
        """
        # Skip stack is maintained by the tracer as an optimization to skip earlier
        # frames in the call stack that we've already determined should be skipped
        skip_stack = self._skip_stack.get()
        if len(skip_stack) > 0:
            return False

        func_name = frame.f_code.co_name
        module_name = frame.f_globals.get("__name__", None)

        func = frame.f_globals.get(func_name)
        # Compare against None rather than truth-testing: the global may be
        # an arbitrary object whose __bool__ is expensive or raises.
        if func is not None and (hasattr(func, '_judgment_span_name') or hasattr(func, '_judgment_span_type')):
            return False

        if (
            not module_name
            or func_name.startswith("<")  # ex: <listcomp>
            or (func_name.startswith("__") and func_name != "__call__")  # dunders
            or not self._is_user_code(frame.f_code.co_filename)
        ):
            return False

        return True

    @functools.cache
    def _is_user_code(self, filename: str):
        """Return True when *filename* is a real file outside the blocklist.

        Results are memoized by ``functools.cache`` (keyed on the instance
        and the file name).
        """
        return bool(filename) and not filename.startswith("<") and not os.path.realpath(filename).startswith(_TRACE_FILEPATH_BLOCKLIST)

    def _trace(self, frame: types.FrameType, event: str, arg: Any):
        """Trace function registered with ``sys.settrace``.

        Handles ``call`` (open a span and record inputs), ``return`` (close
        the span, record duration and output) and ``exception`` (record the
        formatted exception as output). Returns itself to keep tracing the
        frame, or None when the frame/event is ignored.
        """
        # Only call/return/exception events are needed; disable per-line and
        # per-opcode events for this frame.
        frame.f_trace_lines = False
        frame.f_trace_opcodes = False

        if not self._should_trace(frame):
            return

        if event not in ("call", "return", "exception"):
            return

        current_trace = current_trace_var.get()
        if not current_trace:
            return

        parent_span_id = current_span_var.get()
        if not parent_span_id:
            return

        qual_name = self._get_qual_name(frame)
        skip_stack = self._skip_stack.get()

        # NOTE(review): `_should_trace` above already returned True, which
        # requires an empty skip stack, so the skip-stack branches below do
        # not appear reachable from here. Kept unchanged; confirm intent.
        if event == "call":
            # If we have entries in the skip stack and the current qual_name matches the top entry,
            # push it again to track nesting depth and skip
            # As an optimization, we only care about duplicate qual_names.
            if skip_stack:
                if qual_name == skip_stack[-1]:
                    skip_stack.append(qual_name)
                    self._skip_stack.set(skip_stack)
                return

            should_trace = self._should_trace(frame)

            if not should_trace:
                if not skip_stack:
                    self._skip_stack.set([qual_name])
                return
        elif event == "return":
            # If we have entries in skip stack and current qual_name matches the top entry,
            # pop it to track exiting from the skipped section
            if skip_stack and qual_name == skip_stack[-1]:
                skip_stack.pop()
                self._skip_stack.set(skip_stack)
                return

        if skip_stack:
            return

        span_stack = self._span_stack.get()
        if event == "call":
            if not self._should_trace(frame):
                return

            span_id = str(uuid.uuid4())

            # Depth is one more than the parent's recorded depth (0 if unknown).
            parent_depth = current_trace._span_depths.get(parent_span_id, 0)
            depth = parent_depth + 1

            current_trace._span_depths[span_id] = depth

            start_time = time.time()

            span_stack.append({
                "span_id": span_id,
                "parent_span_id": parent_span_id,
                "function": qual_name,
                "start_time": start_time
            })
            self._span_stack.set(span_stack)

            # Make the new span current; the reset token is stashed in the
            # frame's locals so the matching "return" event can find it.
            token = current_span_var.set(span_id)
            frame.f_locals["_judgment_span_token"] = token

            span = TraceSpan(
                span_id=span_id,
                trace_id=current_trace.trace_id,
                depth=depth,
                message=qual_name,
                created_at=start_time,
                span_type="span",
                parent_span_id=parent_span_id,
                function=qual_name
            )
            current_trace.add_span(span)

            inputs = {}
            try:
                args_info = inspect.getargvalues(frame)
                # `arg_name` (not `arg`) so the trace-event argument is not shadowed.
                for arg_name in args_info.args:
                    try:
                        inputs[arg_name] = args_info.locals.get(arg_name)
                    except Exception:
                        inputs[arg_name] = "<<Unserializable>>"
                current_trace.record_input(inputs)
            except Exception as e:
                current_trace.record_input({
                    "error": str(e)
                })

        elif event == "return":
            if not span_stack:
                return

            current_id = current_span_var.get()

            # Find the innermost stack entry belonging to the current span.
            span_data = None
            for i, entry in enumerate(reversed(span_stack)):
                if entry["span_id"] == current_id:
                    span_data = span_stack.pop(-(i+1))
                    self._span_stack.set(span_stack)
                    break

            if not span_data:
                return

            start_time = span_data["start_time"]
            duration = time.time() - start_time

            current_trace.span_id_to_span[span_data["span_id"]].duration = duration

            if arg is not None:
                # exception handling will take priority.
                current_trace.record_output(arg)

            if span_data["span_id"] in current_trace._span_depths:
                del current_trace._span_depths[span_data["span_id"]]

            if span_stack:
                current_span_var.set(span_stack[-1]["span_id"])
            else:
                current_span_var.set(span_data["parent_span_id"])

            if "_judgment_span_token" in frame.f_locals:
                current_span_var.reset(frame.f_locals["_judgment_span_token"])

        elif event == "exception":
            exc_type, exc_value, exc_traceback = arg
            formatted_exception = {
                "type": exc_type.__name__,
                "message": str(exc_value),
                "traceback": traceback.format_tb(exc_traceback)
            }
            current_trace = current_trace_var.get()
            current_trace.record_output({
                "error": formatted_exception
            })

        return self._trace

    def __enter__(self):
        """Increment the reference count; install the hook on first entry."""
        with self._lock:
            self._refcount += 1
            if self._refcount == 1:
                self._skip_stack.set([])
                self._span_stack.set([])
                sys.settrace(self._trace)
                threading.settrace(self._trace)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Decrement the reference count; remove the hook on last exit."""
        with self._lock:
            self._refcount -= 1
            if self._refcount == 0:
                sys.settrace(None)
                threading.settrace(None)

    def log(self, message: str, level: str = "info"):
        """ Log a message with the span context """
        current_trace = current_trace_var.get()
        if current_trace:
            current_trace.log(message, level)
        else:
            # No active trace: print only. The previous code went on to call
            # `current_trace.record_output(...)` here, which raised
            # AttributeError because `current_trace` is None in this branch.
            print(f"[{level}] {message}")
|
908
|
+
|
919
909
|
class Tracer:
|
920
910
|
_instance = None
|
921
911
|
|
@@ -938,12 +928,16 @@ class Tracer:
|
|
938
928
|
s3_aws_access_key_id: Optional[str] = None,
|
939
929
|
s3_aws_secret_access_key: Optional[str] = None,
|
940
930
|
s3_region_name: Optional[str] = None,
|
941
|
-
deep_tracing: bool = True #
|
931
|
+
deep_tracing: bool = True # Deep tracing is enabled by default
|
942
932
|
):
|
943
933
|
if not hasattr(self, 'initialized'):
|
944
934
|
if not api_key:
|
945
935
|
raise ValueError("Tracer must be configured with a Judgment API key")
|
946
936
|
|
937
|
+
result, response = validate_api_key(api_key)
|
938
|
+
if not result:
|
939
|
+
raise JudgmentAPIError(f"Issue with passed in Judgment API key: {response}")
|
940
|
+
|
947
941
|
if not organization_id:
|
948
942
|
raise ValueError("Tracer must be configured with an Organization ID")
|
949
943
|
if use_s3 and not s3_bucket_name:
|
@@ -955,10 +949,11 @@ class Tracer:
|
|
955
949
|
|
956
950
|
self.api_key: str = api_key
|
957
951
|
self.project_name: str = project_name
|
958
|
-
self.client: JudgmentClient = JudgmentClient(judgment_api_key=api_key)
|
959
952
|
self.organization_id: str = organization_id
|
960
953
|
self._current_trace: Optional[str] = None
|
954
|
+
self._active_trace_client: Optional[TraceClient] = None # Add active trace client attribute
|
961
955
|
self.rules: List[Rule] = rules or [] # Store rules at tracer level
|
956
|
+
self.traces: List[Trace] = []
|
962
957
|
self.initialized: bool = True
|
963
958
|
self.enable_monitoring: bool = enable_monitoring
|
964
959
|
self.enable_evaluations: bool = enable_evaluations
|
@@ -991,49 +986,29 @@ class Tracer:
|
|
991
986
|
|
992
987
|
def get_current_trace(self) -> Optional[TraceClient]:
    """
    Return the trace client that is currently in effect, if any.

    The context variable is consulted first. When it holds nothing
    (e.g., context lost across threads/tasks), the tracer's
    ``_active_trace_client`` attribute is used as a fallback. ``None`` is
    returned when neither source yields a trace client.
    """
    candidate = current_trace_var.get()
    if not candidate:
        # Fallback: the active client potentially set by a callback handler.
        candidate = getattr(self, '_active_trace_client', None)

    # Falsy values from either source are reported as "no trace".
    return candidate if candidate else None
|
1016
1007
|
|
1017
|
-
|
1018
|
-
|
1019
|
-
|
1020
|
-
|
1021
|
-
continue
|
1022
|
-
|
1023
|
-
# Create a traced version of the function
|
1024
|
-
# Always use default span type "span" for child functions
|
1025
|
-
traced_func = _create_deep_tracing_wrapper(obj, self, "span")
|
1026
|
-
|
1027
|
-
# Mark the function as traced to avoid double wrapping
|
1028
|
-
traced_func._judgment_traced = True
|
1029
|
-
|
1030
|
-
# Save the original function
|
1031
|
-
original_functions[name] = obj
|
1032
|
-
|
1033
|
-
# Replace with traced version
|
1034
|
-
setattr(module, name, traced_func)
|
1035
|
-
|
1036
|
-
return module, original_functions
|
1008
|
+
def get_active_trace_client(self) -> Optional[TraceClient]:
    """Return the tracer's ``_active_trace_client`` attribute unchanged.

    Unlike ``get_current_trace`` this does not consult the context variable.
    """
    active_client = self._active_trace_client
    return active_client
|
1011
|
+
|
1037
1012
|
|
1038
1013
|
@contextmanager
|
1039
1014
|
def trace(
|
@@ -1080,6 +1055,23 @@ class Tracer:
|
|
1080
1055
|
finally:
|
1081
1056
|
# Reset the context variable
|
1082
1057
|
current_trace_var.reset(token)
|
1058
|
+
|
1059
|
+
|
1060
|
+
def log(self, msg: str, label: str = "log", score: int = 1):
    """Log a message with the current span context.

    When both a current span id and a current trace are available, a
    ``TraceAnnotation`` carrying *msg*, *label* and *score* is attached to
    the trace. The message is also printed with ``rprint``.

    Args:
        msg: Text of the annotation / log line.
        label: Label stored on the annotation and shown in the output.
        score: Score stored on the annotation.
    """
    current_span_id = current_span_var.get()
    current_trace = current_trace_var.get()
    # Require the trace as well as the span id: a span id without a trace
    # previously raised AttributeError on `None.add_annotation`.
    if current_span_id and current_trace:
        annotation = TraceAnnotation(
            span_id=current_span_id,
            text=msg,
            label=label,
            score=score
        )

        current_trace.add_annotation(annotation)

    rprint(f"[bold]{label}:[/bold] {msg}")
|
1083
1075
|
|
1084
1076
|
def observe(self, func=None, *, name=None, span_type: SpanType = "span", project_name: str = None, overwrite: bool = False, deep_tracing: bool = None):
|
1085
1077
|
"""
|
@@ -1115,13 +1107,6 @@ class Tracer:
|
|
1115
1107
|
if asyncio.iscoroutinefunction(func):
|
1116
1108
|
@functools.wraps(func)
|
1117
1109
|
async def async_wrapper(*args, **kwargs):
|
1118
|
-
# Check if we're already in a traced function
|
1119
|
-
if in_traced_function_var.get():
|
1120
|
-
return await func(*args, **kwargs)
|
1121
|
-
|
1122
|
-
# Set in_traced_function_var to True
|
1123
|
-
token = in_traced_function_var.set(True)
|
1124
|
-
|
1125
1110
|
# Get current trace from context
|
1126
1111
|
current_trace = current_trace_var.get()
|
1127
1112
|
|
@@ -1151,81 +1136,47 @@ class Tracer:
|
|
1151
1136
|
# This sets the current_span_var
|
1152
1137
|
with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
|
1153
1138
|
# Record inputs
|
1154
|
-
|
1155
|
-
|
1156
|
-
'kwargs': kwargs
|
1157
|
-
})
|
1139
|
+
inputs = combine_args_kwargs(func, args, kwargs)
|
1140
|
+
span.record_input(inputs)
|
1158
1141
|
|
1159
|
-
# If deep tracing is enabled, apply monkey patching
|
1160
1142
|
if use_deep_tracing:
|
1161
|
-
|
1162
|
-
|
1163
|
-
|
1164
|
-
|
1165
|
-
|
1166
|
-
# Restore original functions if deep tracing was enabled
|
1167
|
-
if use_deep_tracing and module and 'original_functions' in locals():
|
1168
|
-
for name, obj in original_functions.items():
|
1169
|
-
setattr(module, name, obj)
|
1170
|
-
|
1143
|
+
with _DeepTracer():
|
1144
|
+
result = await func(*args, **kwargs)
|
1145
|
+
else:
|
1146
|
+
result = await func(*args, **kwargs)
|
1147
|
+
|
1171
1148
|
# Record output
|
1172
1149
|
span.record_output(result)
|
1173
|
-
|
1174
|
-
# Save the completed trace
|
1175
|
-
current_trace.save(overwrite=overwrite)
|
1176
1150
|
return result
|
1177
1151
|
finally:
|
1152
|
+
# Save the completed trace
|
1153
|
+
trace_id, trace = current_trace.save(overwrite=overwrite)
|
1154
|
+
self.traces.append(trace)
|
1155
|
+
|
1178
1156
|
# Reset trace context (span context resets automatically)
|
1179
1157
|
current_trace_var.reset(trace_token)
|
1180
|
-
# Reset in_traced_function_var
|
1181
|
-
in_traced_function_var.reset(token)
|
1182
1158
|
else:
|
1183
|
-
|
1184
|
-
|
1185
|
-
|
1186
|
-
|
1187
|
-
|
1188
|
-
|
1189
|
-
|
1190
|
-
|
1191
|
-
'kwargs': kwargs
|
1192
|
-
})
|
1193
|
-
|
1194
|
-
# If deep tracing is enabled, apply monkey patching
|
1195
|
-
if use_deep_tracing:
|
1196
|
-
module, original_functions = self._apply_deep_tracing(func, span_type)
|
1197
|
-
|
1198
|
-
# Execute function
|
1159
|
+
with current_trace.span(span_name, span_type=span_type) as span:
|
1160
|
+
inputs = combine_args_kwargs(func, args, kwargs)
|
1161
|
+
span.record_input(inputs)
|
1162
|
+
|
1163
|
+
if use_deep_tracing:
|
1164
|
+
with _DeepTracer():
|
1165
|
+
result = await func(*args, **kwargs)
|
1166
|
+
else:
|
1199
1167
|
result = await func(*args, **kwargs)
|
1200
1168
|
|
1201
|
-
|
1202
|
-
|
1203
|
-
|
1204
|
-
setattr(module, name, obj)
|
1205
|
-
|
1206
|
-
# Record output
|
1207
|
-
span.record_output(result)
|
1208
|
-
|
1209
|
-
return result
|
1210
|
-
finally:
|
1211
|
-
# Reset in_traced_function_var
|
1212
|
-
in_traced_function_var.reset(token)
|
1213
|
-
|
1169
|
+
span.record_output(result)
|
1170
|
+
return result
|
1171
|
+
|
1214
1172
|
return async_wrapper
|
1215
1173
|
else:
|
1216
1174
|
# Non-async function implementation with deep tracing
|
1217
1175
|
@functools.wraps(func)
|
1218
|
-
def wrapper(*args, **kwargs):
|
1219
|
-
# Check if we're already in a traced function
|
1220
|
-
if in_traced_function_var.get():
|
1221
|
-
return func(*args, **kwargs)
|
1222
|
-
|
1223
|
-
# Set in_traced_function_var to True
|
1224
|
-
token = in_traced_function_var.set(True)
|
1225
|
-
|
1176
|
+
def wrapper(*args, **kwargs):
|
1226
1177
|
# Get current trace from context
|
1227
1178
|
current_trace = current_trace_var.get()
|
1228
|
-
|
1179
|
+
|
1229
1180
|
# If there's no current trace, create a root trace
|
1230
1181
|
if not current_trace:
|
1231
1182
|
trace_id = str(uuid.uuid4())
|
@@ -1252,105 +1203,65 @@ class Tracer:
|
|
1252
1203
|
# This sets the current_span_var
|
1253
1204
|
with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
|
1254
1205
|
# Record inputs
|
1255
|
-
|
1256
|
-
|
1257
|
-
'kwargs': kwargs
|
1258
|
-
})
|
1206
|
+
inputs = combine_args_kwargs(func, args, kwargs)
|
1207
|
+
span.record_input(inputs)
|
1259
1208
|
|
1260
|
-
# If deep tracing is enabled, apply monkey patching
|
1261
1209
|
if use_deep_tracing:
|
1262
|
-
|
1263
|
-
|
1264
|
-
|
1265
|
-
|
1266
|
-
|
1267
|
-
# Restore original functions if deep tracing was enabled
|
1268
|
-
if use_deep_tracing and module and 'original_functions' in locals():
|
1269
|
-
for name, obj in original_functions.items():
|
1270
|
-
setattr(module, name, obj)
|
1210
|
+
with _DeepTracer():
|
1211
|
+
result = func(*args, **kwargs)
|
1212
|
+
else:
|
1213
|
+
result = func(*args, **kwargs)
|
1271
1214
|
|
1272
1215
|
# Record output
|
1273
1216
|
span.record_output(result)
|
1274
|
-
|
1275
|
-
# Save the completed trace
|
1276
|
-
current_trace.save(overwrite=overwrite)
|
1277
1217
|
return result
|
1278
1218
|
finally:
|
1219
|
+
# Save the completed trace
|
1220
|
+
trace_id, trace = current_trace.save(overwrite=overwrite)
|
1221
|
+
self.traces.append(trace)
|
1222
|
+
|
1279
1223
|
# Reset trace context (span context resets automatically)
|
1280
1224
|
current_trace_var.reset(trace_token)
|
1281
|
-
# Reset in_traced_function_var
|
1282
|
-
in_traced_function_var.reset(token)
|
1283
1225
|
else:
|
1284
|
-
|
1285
|
-
|
1286
|
-
|
1287
|
-
|
1288
|
-
|
1289
|
-
|
1290
|
-
|
1291
|
-
|
1292
|
-
|
1293
|
-
})
|
1294
|
-
|
1295
|
-
# If deep tracing is enabled, apply monkey patching
|
1296
|
-
if use_deep_tracing:
|
1297
|
-
module, original_functions = self._apply_deep_tracing(func, span_type)
|
1298
|
-
|
1299
|
-
# Execute function
|
1226
|
+
with current_trace.span(span_name, span_type=span_type) as span:
|
1227
|
+
|
1228
|
+
inputs = combine_args_kwargs(func, args, kwargs)
|
1229
|
+
span.record_input(inputs)
|
1230
|
+
|
1231
|
+
if use_deep_tracing:
|
1232
|
+
with _DeepTracer():
|
1233
|
+
result = func(*args, **kwargs)
|
1234
|
+
else:
|
1300
1235
|
result = func(*args, **kwargs)
|
1301
1236
|
|
1302
|
-
|
1303
|
-
|
1304
|
-
|
1305
|
-
setattr(module, name, obj)
|
1306
|
-
|
1307
|
-
# Record output
|
1308
|
-
span.record_output(result)
|
1309
|
-
|
1310
|
-
return result
|
1311
|
-
finally:
|
1312
|
-
# Reset in_traced_function_var
|
1313
|
-
in_traced_function_var.reset(token)
|
1314
|
-
|
1315
|
-
return wrapper
|
1316
|
-
|
1317
|
-
def score(self, func=None, scorers: List[Union[APIJudgmentScorer, JudgevalScorer]] = None, model: str = None, log_results: bool = True, *, name: str = None, span_type: SpanType = "span"):
|
1318
|
-
"""
|
1319
|
-
Decorator to trace function execution with detailed entry/exit information.
|
1320
|
-
"""
|
1321
|
-
if func is None:
|
1322
|
-
return lambda f: self.score(f, scorers=scorers, model=model, log_results=log_results, name=name, span_type=span_type)
|
1323
|
-
|
1324
|
-
if asyncio.iscoroutinefunction(func):
|
1325
|
-
@functools.wraps(func)
|
1326
|
-
async def async_wrapper(*args, **kwargs):
|
1327
|
-
# Get current trace from contextvars
|
1328
|
-
current_trace = current_trace_var.get()
|
1329
|
-
if current_trace and scorers:
|
1330
|
-
current_trace.async_evaluate(scorers=scorers, input=args, actual_output=kwargs, model=model, log_results=log_results)
|
1331
|
-
return await func(*args, **kwargs)
|
1332
|
-
return async_wrapper
|
1333
|
-
else:
|
1334
|
-
@functools.wraps(func)
|
1335
|
-
def wrapper(*args, **kwargs):
|
1336
|
-
# Get current trace from contextvars
|
1337
|
-
current_trace = current_trace_var.get()
|
1338
|
-
if current_trace and scorers:
|
1339
|
-
current_trace.async_evaluate(scorers=scorers, input=args, actual_output=kwargs, model=model, log_results=log_results)
|
1340
|
-
return func(*args, **kwargs)
|
1237
|
+
span.record_output(result)
|
1238
|
+
return result
|
1239
|
+
|
1341
1240
|
return wrapper
|
1342
1241
|
|
1343
1242
|
def async_evaluate(self, *args, **kwargs):
    """Forward an evaluation request to the trace client in effect.

    Does nothing when ``self.enable_evaluations`` is false. Otherwise the
    trace from the context variable is used, falling back to
    ``self._active_trace_client``; if neither is available a warning is
    issued and the call is skipped. A truthy ``trace_id`` keyword is passed
    through; a falsy one is dropped from the forwarded keywords.
    """
    if not self.enable_evaluations:
        return

    # Pull trace_id out so that a falsy value is not forwarded.
    explicit_trace_id = kwargs.pop('trace_id', None)

    # Context variable first; the active client only when that is empty.
    target_trace = current_trace_var.get() or self._active_trace_client

    if not target_trace:
        warnings.warn("No trace found (context var or fallback), skipping evaluation")
        return

    if explicit_trace_id:
        kwargs['trace_id'] = explicit_trace_id
    target_trace.async_evaluate(*args, **kwargs)
|
1354
1265
|
|
1355
1266
|
|
1356
1267
|
def wrap(client: Any) -> Any:
|
@@ -1359,7 +1270,7 @@ def wrap(client: Any) -> Any:
|
|
1359
1270
|
Supports OpenAI, Together, Anthropic, and Google GenAI clients.
|
1360
1271
|
Patches both '.create' and Anthropic's '.stream' methods using a wrapper class.
|
1361
1272
|
"""
|
1362
|
-
span_name, original_create, original_stream = _get_client_config(client)
|
1273
|
+
span_name, original_create, responses_create, original_stream = _get_client_config(client)
|
1363
1274
|
|
1364
1275
|
# --- Define Traced Async Functions ---
|
1365
1276
|
async def traced_create_async(*args, **kwargs):
|
@@ -1457,7 +1368,41 @@ def wrap(client: Any) -> Any:
|
|
1457
1368
|
span.record_output(output_data)
|
1458
1369
|
return response_or_iterator
|
1459
1370
|
|
1371
|
+
# --- Define Traced Sync Functions ---
|
1372
|
+
def traced_response_create_sync(*args, **kwargs):
|
1373
|
+
# [Existing logic - unchanged]
|
1374
|
+
current_trace = current_trace_var.get()
|
1375
|
+
if not current_trace:
|
1376
|
+
return responses_create(*args, **kwargs)
|
1377
|
+
|
1378
|
+
is_streaming = kwargs.get("stream", False)
|
1379
|
+
with current_trace.span(span_name, span_type="llm") as span:
|
1380
|
+
span.record_input(kwargs)
|
1381
|
+
|
1382
|
+
# Warn about token counting limitations with streaming
|
1383
|
+
if isinstance(client, (AsyncOpenAI, OpenAI)) and is_streaming:
|
1384
|
+
if not kwargs.get("stream_options", {}).get("include_usage"):
|
1385
|
+
warnings.warn(
|
1386
|
+
"OpenAI streaming calls don't include token counts by default. "
|
1387
|
+
"To enable token counting with streams, set stream_options={'include_usage': True} "
|
1388
|
+
"in your API call arguments.",
|
1389
|
+
UserWarning
|
1390
|
+
)
|
1460
1391
|
|
1392
|
+
try:
|
1393
|
+
response_or_iterator = responses_create(*args, **kwargs)
|
1394
|
+
except Exception as e:
|
1395
|
+
print(f"Error during wrapped sync API call ({span_name}): {e}")
|
1396
|
+
span.record_output({"error": str(e)})
|
1397
|
+
raise
|
1398
|
+
if is_streaming:
|
1399
|
+
output_entry = span.record_output("<pending stream>")
|
1400
|
+
return _sync_stream_wrapper(response_or_iterator, client, output_entry)
|
1401
|
+
else:
|
1402
|
+
output_data = _format_response_output_data(client, response_or_iterator)
|
1403
|
+
span.record_output(output_data)
|
1404
|
+
return response_or_iterator
|
1405
|
+
|
1461
1406
|
# Function replacing sync .stream()
|
1462
1407
|
def traced_stream_sync(*args, **kwargs):
|
1463
1408
|
current_trace = current_trace_var.get()
|
@@ -1505,15 +1450,16 @@ def wrap(client: Any) -> Any:
|
|
1505
1450
|
if original_stream:
|
1506
1451
|
client.messages.stream = traced_stream_async
|
1507
1452
|
elif isinstance(client, genai.client.AsyncClient):
|
1508
|
-
client.generate_content = traced_create_async
|
1453
|
+
client.models.generate_content = traced_create_async
|
1509
1454
|
elif isinstance(client, (OpenAI, Together)):
|
1510
1455
|
client.chat.completions.create = traced_create_sync
|
1456
|
+
client.responses.create = traced_response_create_sync
|
1511
1457
|
elif isinstance(client, Anthropic):
|
1512
1458
|
client.messages.create = traced_create_sync
|
1513
1459
|
if original_stream:
|
1514
1460
|
client.messages.stream = traced_stream_sync
|
1515
1461
|
elif isinstance(client, genai.Client):
|
1516
|
-
client.generate_content = traced_create_sync
|
1462
|
+
client.models.generate_content = traced_create_sync
|
1517
1463
|
|
1518
1464
|
return client
|
1519
1465
|
|
@@ -1529,19 +1475,20 @@ def _get_client_config(client: ApiClient) -> tuple[str, callable, Optional[calla
|
|
1529
1475
|
tuple: (span_name, create_method, stream_method)
|
1530
1476
|
- span_name: String identifier for tracing
|
1531
1477
|
- create_method: Reference to the client's creation method
|
1478
|
+
- responses_method: Reference to the client's responses method (if applicable)
|
1532
1479
|
- stream_method: Reference to the client's stream method (if applicable)
|
1533
1480
|
|
1534
1481
|
Raises:
|
1535
1482
|
ValueError: If client type is not supported
|
1536
1483
|
"""
|
1537
1484
|
if isinstance(client, (OpenAI, AsyncOpenAI)):
|
1538
|
-
return "OPENAI_API_CALL", client.chat.completions.create, None
|
1485
|
+
return "OPENAI_API_CALL", client.chat.completions.create, client.responses.create, None
|
1539
1486
|
elif isinstance(client, (Together, AsyncTogether)):
|
1540
|
-
return "TOGETHER_API_CALL", client.chat.completions.create, None
|
1487
|
+
return "TOGETHER_API_CALL", client.chat.completions.create, None, None
|
1541
1488
|
elif isinstance(client, (Anthropic, AsyncAnthropic)):
|
1542
|
-
return "ANTHROPIC_API_CALL", client.messages.create, client.messages.stream
|
1489
|
+
return "ANTHROPIC_API_CALL", client.messages.create, None, client.messages.stream
|
1543
1490
|
elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
|
1544
|
-
return "GOOGLE_API_CALL", client.models.generate_content, None
|
1491
|
+
return "GOOGLE_API_CALL", client.models.generate_content, None, None
|
1545
1492
|
raise ValueError(f"Unsupported client type: {type(client)}")
|
1546
1493
|
|
1547
1494
|
def _format_input_data(client: ApiClient, **kwargs) -> dict:
|
@@ -1567,6 +1514,26 @@ def _format_input_data(client: ApiClient, **kwargs) -> dict:
|
|
1567
1514
|
"max_tokens": kwargs.get("max_tokens")
|
1568
1515
|
}
|
1569
1516
|
|
1517
|
+
def _format_response_output_data(client: ApiClient, response: Any) -> dict:
|
1518
|
+
"""Format API response data based on client type.
|
1519
|
+
|
1520
|
+
Normalizes different response formats into a consistent structure
|
1521
|
+
for tracing purposes.
|
1522
|
+
"""
|
1523
|
+
if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
|
1524
|
+
return {
|
1525
|
+
"content": response.output,
|
1526
|
+
"usage": {
|
1527
|
+
"prompt_tokens": response.usage.input_tokens,
|
1528
|
+
"completion_tokens": response.usage.output_tokens,
|
1529
|
+
"total_tokens": response.usage.total_tokens
|
1530
|
+
}
|
1531
|
+
}
|
1532
|
+
else:
|
1533
|
+
warnings.warn(f"Unsupported client type: {type(client)}")
|
1534
|
+
return {}
|
1535
|
+
|
1536
|
+
|
1570
1537
|
def _format_output_data(client: ApiClient, response: Any) -> dict:
|
1571
1538
|
"""Format API response data based on client type.
|
1572
1539
|
|
@@ -1600,123 +1567,57 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
|
|
1600
1567
|
return {
|
1601
1568
|
"content": response.content[0].text,
|
1602
1569
|
"usage": {
|
1603
|
-
"
|
1604
|
-
"
|
1570
|
+
"prompt_tokens": response.usage.input_tokens,
|
1571
|
+
"completion_tokens": response.usage.output_tokens,
|
1605
1572
|
"total_tokens": response.usage.input_tokens + response.usage.output_tokens
|
1606
1573
|
}
|
1607
1574
|
}
|
1608
1575
|
|
1609
|
-
|
1610
|
-
# These are typically utility functions, print statements, logging, etc.
|
1611
|
-
_TRACE_BLOCKLIST = {
|
1612
|
-
# Built-in functions
|
1613
|
-
'print', 'str', 'int', 'float', 'bool', 'list', 'dict', 'set', 'tuple',
|
1614
|
-
'len', 'range', 'enumerate', 'zip', 'map', 'filter', 'sorted', 'reversed',
|
1615
|
-
'min', 'max', 'sum', 'any', 'all', 'abs', 'round', 'format',
|
1616
|
-
# Logging functions
|
1617
|
-
'debug', 'info', 'warning', 'error', 'critical', 'exception', 'log',
|
1618
|
-
# Common utility functions
|
1619
|
-
'sleep', 'time', 'datetime', 'json', 'dumps', 'loads',
|
1620
|
-
# String operations
|
1621
|
-
'join', 'split', 'strip', 'lstrip', 'rstrip', 'replace', 'lower', 'upper',
|
1622
|
-
# Dict operations
|
1623
|
-
'get', 'items', 'keys', 'values', 'update',
|
1624
|
-
# List operations
|
1625
|
-
'append', 'extend', 'insert', 'remove', 'pop', 'clear', 'index', 'count', 'sort',
|
1626
|
-
}
|
1627
|
-
|
1628
|
-
|
1629
|
-
# Add a new function for deep tracing at the module level
|
1630
|
-
def _create_deep_tracing_wrapper(func, tracer, span_type="span"):
|
1576
|
+
def combine_args_kwargs(func, args, kwargs):
|
1631
1577
|
"""
|
1632
|
-
|
1633
|
-
This enables deep tracing without requiring explicit @observe decorators on every function.
|
1578
|
+
Combine positional arguments and keyword arguments into a single dictionary.
|
1634
1579
|
|
1635
1580
|
Args:
|
1636
|
-
func: The function
|
1637
|
-
|
1638
|
-
|
1581
|
+
func: The function being called
|
1582
|
+
args: Tuple of positional arguments
|
1583
|
+
kwargs: Dictionary of keyword arguments
|
1639
1584
|
|
1640
1585
|
Returns:
|
1641
|
-
A
|
1586
|
+
A dictionary combining both args and kwargs
|
1642
1587
|
"""
|
1643
|
-
|
1644
|
-
|
1645
|
-
|
1646
|
-
|
1647
|
-
|
1648
|
-
|
1649
|
-
|
1650
|
-
|
1651
|
-
|
1652
|
-
|
1653
|
-
|
1654
|
-
|
1655
|
-
|
1656
|
-
|
1657
|
-
|
1658
|
-
|
1659
|
-
|
1660
|
-
|
1661
|
-
|
1662
|
-
|
1663
|
-
|
1664
|
-
|
1665
|
-
|
1666
|
-
|
1667
|
-
|
1668
|
-
|
1669
|
-
|
1670
|
-
|
1671
|
-
|
1672
|
-
|
1673
|
-
|
1674
|
-
|
1675
|
-
|
1676
|
-
# Create a span for this function call - use custom span_type if available
|
1677
|
-
with current_trace.span(func_name, span_type=func_span_type) as span:
|
1678
|
-
# Record inputs
|
1679
|
-
span.record_input({
|
1680
|
-
'args': str(args),
|
1681
|
-
'kwargs': kwargs
|
1682
|
-
})
|
1683
|
-
|
1684
|
-
# Execute function
|
1685
|
-
result = await original_func(*args, **kwargs)
|
1686
|
-
|
1687
|
-
# Record output
|
1688
|
-
span.record_output(result)
|
1689
|
-
|
1690
|
-
return result
|
1691
|
-
|
1692
|
-
return async_deep_wrapper
|
1693
|
-
else:
|
1694
|
-
@functools.wraps(func)
|
1695
|
-
def deep_wrapper(*args, **kwargs):
|
1696
|
-
# Get current trace from context
|
1697
|
-
current_trace = current_trace_var.get()
|
1698
|
-
|
1699
|
-
# If no trace context, just call the function
|
1700
|
-
if not current_trace:
|
1701
|
-
return original_func(*args, **kwargs)
|
1702
|
-
|
1703
|
-
# Create a span for this function call - use custom span_type if available
|
1704
|
-
with current_trace.span(func_name, span_type=func_span_type) as span:
|
1705
|
-
# Record inputs
|
1706
|
-
span.record_input({
|
1707
|
-
'args': str(args),
|
1708
|
-
'kwargs': kwargs
|
1709
|
-
})
|
1710
|
-
|
1711
|
-
# Execute function
|
1712
|
-
result = original_func(*args, **kwargs)
|
1713
|
-
|
1714
|
-
# Record output
|
1715
|
-
span.record_output(result)
|
1716
|
-
|
1717
|
-
return result
|
1718
|
-
|
1719
|
-
return deep_wrapper
|
1588
|
+
try:
|
1589
|
+
import inspect
|
1590
|
+
sig = inspect.signature(func)
|
1591
|
+
param_names = list(sig.parameters.keys())
|
1592
|
+
|
1593
|
+
args_dict = {}
|
1594
|
+
for i, arg in enumerate(args):
|
1595
|
+
if i < len(param_names):
|
1596
|
+
args_dict[param_names[i]] = arg
|
1597
|
+
else:
|
1598
|
+
args_dict[f"arg{i}"] = arg
|
1599
|
+
|
1600
|
+
return {**args_dict, **kwargs}
|
1601
|
+
except Exception as e:
|
1602
|
+
# Fallback if signature inspection fails
|
1603
|
+
return {**{f"arg{i}": arg for i, arg in enumerate(args)}, **kwargs}
|
1604
|
+
|
1605
|
+
# NOTE: This builds once, can be tweaked if we are missing / capturing other unncessary modules
|
1606
|
+
# @link https://docs.python.org/3.13/library/sysconfig.html
|
1607
|
+
_TRACE_FILEPATH_BLOCKLIST = tuple(
|
1608
|
+
os.path.realpath(p) + os.sep
|
1609
|
+
for p in {
|
1610
|
+
sysconfig.get_paths()['stdlib'],
|
1611
|
+
sysconfig.get_paths().get('platstdlib', ''),
|
1612
|
+
*site.getsitepackages(),
|
1613
|
+
site.getusersitepackages(),
|
1614
|
+
*(
|
1615
|
+
[os.path.join(os.path.dirname(__file__), '../../judgeval/')]
|
1616
|
+
if os.environ.get('JUDGMENT_DEV')
|
1617
|
+
else []
|
1618
|
+
),
|
1619
|
+
} if p
|
1620
|
+
)
|
1720
1621
|
|
1721
1622
|
# Add the new TraceThreadPoolExecutor class
|
1722
1623
|
class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
|
@@ -1819,7 +1720,7 @@ def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[D
|
|
1819
1720
|
def _sync_stream_wrapper(
|
1820
1721
|
original_stream: Iterator,
|
1821
1722
|
client: ApiClient,
|
1822
|
-
|
1723
|
+
span: TraceSpan
|
1823
1724
|
) -> Generator[Any, None, None]:
|
1824
1725
|
"""Wraps a synchronous stream iterator to capture content and update the trace."""
|
1825
1726
|
content_parts = [] # Use a list instead of string concatenation
|
@@ -1838,7 +1739,7 @@ def _sync_stream_wrapper(
|
|
1838
1739
|
final_usage = _extract_usage_from_final_chunk(client, last_chunk)
|
1839
1740
|
|
1840
1741
|
# Update the trace entry with the accumulated content and usage
|
1841
|
-
|
1742
|
+
span.output = {
|
1842
1743
|
"content": "".join(content_parts), # Join list at the end
|
1843
1744
|
"usage": final_usage if final_usage else {"info": "Usage data not available in stream."}, # Provide placeholder if None
|
1844
1745
|
"streamed": True
|
@@ -1850,7 +1751,7 @@ def _sync_stream_wrapper(
|
|
1850
1751
|
async def _async_stream_wrapper(
|
1851
1752
|
original_stream: AsyncIterator,
|
1852
1753
|
client: ApiClient,
|
1853
|
-
|
1754
|
+
span: TraceSpan
|
1854
1755
|
) -> AsyncGenerator[Any, None]:
|
1855
1756
|
# [Existing logic - unchanged]
|
1856
1757
|
content_parts = [] # Use a list instead of string concatenation
|
@@ -1859,7 +1760,7 @@ async def _async_stream_wrapper(
|
|
1859
1760
|
anthropic_input_tokens = 0
|
1860
1761
|
anthropic_output_tokens = 0
|
1861
1762
|
|
1862
|
-
target_span_id =
|
1763
|
+
target_span_id = span.span_id
|
1863
1764
|
|
1864
1765
|
try:
|
1865
1766
|
async for chunk in original_stream:
|
@@ -1891,8 +1792,8 @@ async def _async_stream_wrapper(
|
|
1891
1792
|
anthropic_final_usage = None
|
1892
1793
|
if isinstance(client, (AsyncAnthropic, Anthropic)) and (anthropic_input_tokens > 0 or anthropic_output_tokens > 0):
|
1893
1794
|
anthropic_final_usage = {
|
1894
|
-
"
|
1895
|
-
"
|
1795
|
+
"prompt_tokens": anthropic_input_tokens,
|
1796
|
+
"completion_tokens": anthropic_output_tokens,
|
1896
1797
|
"total_tokens": anthropic_input_tokens + anthropic_output_tokens
|
1897
1798
|
}
|
1898
1799
|
|
@@ -1904,19 +1805,17 @@ async def _async_stream_wrapper(
|
|
1904
1805
|
elif last_content_chunk:
|
1905
1806
|
usage_info = _extract_usage_from_final_chunk(client, last_content_chunk)
|
1906
1807
|
|
1907
|
-
if
|
1908
|
-
|
1808
|
+
if span and hasattr(span, 'output'):
|
1809
|
+
span.output = {
|
1909
1810
|
"content": "".join(content_parts), # Join list at the end
|
1910
1811
|
"usage": usage_info if usage_info else {"info": "Usage data not available in stream."},
|
1911
1812
|
"streamed": True
|
1912
1813
|
}
|
1913
|
-
start_ts = getattr(
|
1914
|
-
|
1814
|
+
start_ts = getattr(span, 'created_at', time.time())
|
1815
|
+
span.duration = time.time() - start_ts
|
1915
1816
|
# else: # Handle error case if necessary, but remove debug print
|
1916
1817
|
|
1917
|
-
|
1918
|
-
class _TracedAsyncStreamManagerWrapper(AbstractAsyncContextManager):
|
1919
|
-
"""Wraps an original async stream manager to add tracing."""
|
1818
|
+
class _BaseStreamManagerWrapper:
|
1920
1819
|
def __init__(self, original_manager, client, span_name, trace_client, stream_wrapper_func, input_kwargs):
|
1921
1820
|
self._original_manager = original_manager
|
1922
1821
|
self._client = client
|
@@ -1926,157 +1825,199 @@ class _TracedAsyncStreamManagerWrapper(AbstractAsyncContextManager):
|
|
1926
1825
|
self._input_kwargs = input_kwargs
|
1927
1826
|
self._parent_span_id_at_entry = None
|
1928
1827
|
|
1929
|
-
|
1930
|
-
self._parent_span_id_at_entry = current_span_var.get()
|
1931
|
-
if not self._trace_client:
|
1932
|
-
# If no trace, just delegate to the original manager
|
1933
|
-
return await self._original_manager.__aenter__()
|
1934
|
-
|
1935
|
-
# --- Manually create the 'enter' entry ---
|
1828
|
+
def _create_span(self):
|
1936
1829
|
start_time = time.time()
|
1937
1830
|
span_id = str(uuid.uuid4())
|
1938
1831
|
current_depth = 0
|
1939
1832
|
if self._parent_span_id_at_entry and self._parent_span_id_at_entry in self._trace_client._span_depths:
|
1940
1833
|
current_depth = self._trace_client._span_depths[self._parent_span_id_at_entry] + 1
|
1941
1834
|
self._trace_client._span_depths[span_id] = current_depth
|
1942
|
-
|
1943
|
-
|
1944
|
-
|
1945
|
-
|
1835
|
+
span = TraceSpan(
|
1836
|
+
function=self._span_name,
|
1837
|
+
span_id=span_id,
|
1838
|
+
trace_id=self._trace_client.trace_id,
|
1839
|
+
depth=current_depth,
|
1840
|
+
message=self._span_name,
|
1841
|
+
created_at=start_time,
|
1842
|
+
span_type="llm",
|
1843
|
+
parent_span_id=self._parent_span_id_at_entry
|
1946
1844
|
)
|
1947
|
-
self._trace_client.
|
1948
|
-
|
1949
|
-
|
1950
|
-
# Set the current span ID in contextvars
|
1951
|
-
self._span_context_token = current_span_var.set(span_id)
|
1845
|
+
self._trace_client.add_span(span)
|
1846
|
+
return span_id, span
|
1952
1847
|
|
1953
|
-
|
1954
|
-
|
1955
|
-
|
1956
|
-
|
1957
|
-
|
1958
|
-
|
1959
|
-
)
|
1960
|
-
self._trace_client.add_entry(input_entry)
|
1848
|
+
def _finalize_span(self, span_id):
|
1849
|
+
span = self._trace_client.span_id_to_span.get(span_id)
|
1850
|
+
if span:
|
1851
|
+
span.duration = time.time() - span.created_at
|
1852
|
+
if span_id in self._trace_client._span_depths:
|
1853
|
+
del self._trace_client._span_depths[span_id]
|
1961
1854
|
|
1962
|
-
|
1963
|
-
|
1855
|
+
class _TracedAsyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractAsyncContextManager):
|
1856
|
+
async def __aenter__(self):
|
1857
|
+
self._parent_span_id_at_entry = current_span_var.get()
|
1858
|
+
if not self._trace_client:
|
1859
|
+
return await self._original_manager.__aenter__()
|
1964
1860
|
|
1965
|
-
|
1966
|
-
|
1967
|
-
|
1968
|
-
trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Output from {self._span_name}",
|
1969
|
-
created_at=time.time(), output="<pending stream>", span_type="llm"
|
1970
|
-
)
|
1971
|
-
self._trace_client.add_entry(output_entry)
|
1861
|
+
span_id, span = self._create_span()
|
1862
|
+
self._span_context_token = current_span_var.set(span_id)
|
1863
|
+
span.inputs = _format_input_data(self._client, **self._input_kwargs)
|
1972
1864
|
|
1973
|
-
#
|
1974
|
-
|
1975
|
-
|
1865
|
+
# Call the original __aenter__ and expect it to be an async generator
|
1866
|
+
raw_iterator = await self._original_manager.__aenter__()
|
1867
|
+
span.output = "<pending stream>"
|
1868
|
+
return self._stream_wrapper_func(raw_iterator, self._client, span)
|
1976
1869
|
|
1977
1870
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
1978
|
-
# Manually create the 'exit' entry
|
1979
1871
|
if hasattr(self, '_span_context_token'):
|
1980
|
-
|
1981
|
-
|
1982
|
-
|
1983
|
-
|
1984
|
-
|
1985
|
-
break
|
1986
|
-
duration = time.time() - start_time_for_duration if start_time_for_duration else None
|
1987
|
-
exit_depth = self._trace_client._span_depths.get(span_id, 0)
|
1988
|
-
exit_entry = TraceEntry(
|
1989
|
-
type="exit", function=self._span_name, span_id=span_id,
|
1990
|
-
trace_id=self._trace_client.trace_id, depth=exit_depth, message=f"← {self._span_name}",
|
1991
|
-
created_at=time.time(), duration=duration, span_type="llm"
|
1992
|
-
)
|
1993
|
-
self._trace_client.add_entry(exit_entry)
|
1994
|
-
if span_id in self._trace_client._span_depths: del self._trace_client._span_depths[span_id]
|
1995
|
-
current_span_var.reset(self._span_context_token)
|
1996
|
-
delattr(self, '_span_context_token')
|
1997
|
-
|
1998
|
-
# Delegate __aexit__
|
1999
|
-
if hasattr(self._original_manager, "__aexit__"):
|
2000
|
-
return await self._original_manager.__aexit__(exc_type, exc_val, exc_tb)
|
2001
|
-
return None
|
2002
|
-
|
2003
|
-
class _TracedSyncStreamManagerWrapper(AbstractContextManager):
|
2004
|
-
"""Wraps an original sync stream manager to add tracing."""
|
2005
|
-
def __init__(self, original_manager, client, span_name, trace_client, stream_wrapper_func, input_kwargs):
|
2006
|
-
self._original_manager = original_manager
|
2007
|
-
self._client = client
|
2008
|
-
self._span_name = span_name
|
2009
|
-
self._trace_client = trace_client
|
2010
|
-
self._stream_wrapper_func = stream_wrapper_func
|
2011
|
-
self._input_kwargs = input_kwargs
|
2012
|
-
self._parent_span_id_at_entry = None
|
1872
|
+
span_id = current_span_var.get()
|
1873
|
+
self._finalize_span(span_id)
|
1874
|
+
current_span_var.reset(self._span_context_token)
|
1875
|
+
delattr(self, '_span_context_token')
|
1876
|
+
return await self._original_manager.__aexit__(exc_type, exc_val, exc_tb)
|
2013
1877
|
|
1878
|
+
class _TracedSyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractContextManager):
|
2014
1879
|
def __enter__(self):
|
2015
1880
|
self._parent_span_id_at_entry = current_span_var.get()
|
2016
1881
|
if not self._trace_client:
|
2017
|
-
|
1882
|
+
return self._original_manager.__enter__()
|
2018
1883
|
|
2019
|
-
|
2020
|
-
start_time = time.time()
|
2021
|
-
span_id = str(uuid.uuid4())
|
2022
|
-
current_depth = 0
|
2023
|
-
if self._parent_span_id_at_entry and self._parent_span_id_at_entry in self._trace_client._span_depths:
|
2024
|
-
current_depth = self._trace_client._span_depths[self._parent_span_id_at_entry] + 1
|
2025
|
-
self._trace_client._span_depths[span_id] = current_depth
|
2026
|
-
enter_entry = TraceEntry(
|
2027
|
-
type="enter", function=self._span_name, span_id=span_id,
|
2028
|
-
trace_id=self._trace_client.trace_id, depth=current_depth, message=self._span_name,
|
2029
|
-
created_at=start_time, span_type="llm", parent_span_id=self._parent_span_id_at_entry
|
2030
|
-
)
|
2031
|
-
self._trace_client.add_entry(enter_entry)
|
1884
|
+
span_id, span = self._create_span()
|
2032
1885
|
self._span_context_token = current_span_var.set(span_id)
|
1886
|
+
span.inputs = _format_input_data(self._client, **self._input_kwargs)
|
2033
1887
|
|
2034
|
-
# Manually create 'input' entry
|
2035
|
-
input_data = _format_input_data(self._client, **self._input_kwargs)
|
2036
|
-
input_entry = TraceEntry(
|
2037
|
-
type="input", function=self._span_name, span_id=span_id,
|
2038
|
-
trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Inputs to {self._span_name}",
|
2039
|
-
created_at=time.time(), inputs=input_data, span_type="llm"
|
2040
|
-
)
|
2041
|
-
self._trace_client.add_entry(input_entry)
|
2042
|
-
|
2043
|
-
# Call original __enter__
|
2044
1888
|
raw_iterator = self._original_manager.__enter__()
|
2045
|
-
|
2046
|
-
|
2047
|
-
output_entry = TraceEntry(
|
2048
|
-
type="output", function=self._span_name, span_id=span_id,
|
2049
|
-
trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Output from {self._span_name}",
|
2050
|
-
created_at=time.time(), output="<pending stream>", span_type="llm"
|
2051
|
-
)
|
2052
|
-
self._trace_client.add_entry(output_entry)
|
2053
|
-
|
2054
|
-
# Wrap the raw iterator
|
2055
|
-
wrapped_iterator = self._stream_wrapper_func(raw_iterator, self._client, output_entry)
|
2056
|
-
return wrapped_iterator
|
1889
|
+
span.output = "<pending stream>"
|
1890
|
+
return self._stream_wrapper_func(raw_iterator, self._client, span)
|
2057
1891
|
|
2058
1892
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
2059
|
-
# Manually create 'exit' entry
|
2060
1893
|
if hasattr(self, '_span_context_token'):
|
2061
|
-
|
2062
|
-
|
2063
|
-
|
2064
|
-
|
2065
|
-
|
2066
|
-
|
2067
|
-
|
2068
|
-
|
2069
|
-
|
2070
|
-
|
2071
|
-
|
2072
|
-
|
2073
|
-
|
2074
|
-
|
2075
|
-
|
2076
|
-
|
2077
|
-
|
2078
|
-
|
2079
|
-
|
2080
|
-
|
2081
|
-
|
1894
|
+
span_id = current_span_var.get()
|
1895
|
+
self._finalize_span(span_id)
|
1896
|
+
current_span_var.reset(self._span_context_token)
|
1897
|
+
delattr(self, '_span_context_token')
|
1898
|
+
return self._original_manager.__exit__(exc_type, exc_val, exc_tb)
|
1899
|
+
|
1900
|
+
# --- NEW Generalized Helper Function (Moved from demo) ---
|
1901
|
+
def prepare_evaluation_for_state(
|
1902
|
+
scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
|
1903
|
+
example: Optional[Example] = None,
|
1904
|
+
# --- Individual components (alternative to 'example') ---
|
1905
|
+
input: Optional[str] = None,
|
1906
|
+
actual_output: Optional[Union[str, List[str]]] = None,
|
1907
|
+
expected_output: Optional[Union[str, List[str]]] = None,
|
1908
|
+
context: Optional[List[str]] = None,
|
1909
|
+
retrieval_context: Optional[List[str]] = None,
|
1910
|
+
tools_called: Optional[List[str]] = None,
|
1911
|
+
expected_tools: Optional[List[str]] = None,
|
1912
|
+
additional_metadata: Optional[Dict[str, Any]] = None,
|
1913
|
+
# --- Other eval parameters ---
|
1914
|
+
model: Optional[str] = None,
|
1915
|
+
log_results: Optional[bool] = True
|
1916
|
+
) -> Optional[EvaluationConfig]:
|
1917
|
+
"""
|
1918
|
+
Prepares an EvaluationConfig object, similar to TraceClient.async_evaluate.
|
1919
|
+
|
1920
|
+
Accepts either a pre-made Example object or individual components to construct one.
|
1921
|
+
Returns the EvaluationConfig object ready to be placed in the state, or None.
|
1922
|
+
"""
|
1923
|
+
final_example = example
|
1924
|
+
|
1925
|
+
# If example is not provided, try to construct one from individual parts
|
1926
|
+
if final_example is None:
|
1927
|
+
# Basic validation: Ensure at least actual_output is present for most scorers
|
1928
|
+
if actual_output is None:
|
1929
|
+
# print("[prepare_evaluation_for_state] Warning: 'actual_output' is required when 'example' is not provided. Skipping evaluation setup.")
|
1930
|
+
return None
|
1931
|
+
try:
|
1932
|
+
final_example = Example(
|
1933
|
+
input=input,
|
1934
|
+
actual_output=actual_output,
|
1935
|
+
expected_output=expected_output,
|
1936
|
+
context=context,
|
1937
|
+
retrieval_context=retrieval_context,
|
1938
|
+
tools_called=tools_called,
|
1939
|
+
expected_tools=expected_tools,
|
1940
|
+
additional_metadata=additional_metadata,
|
1941
|
+
# trace_id will be set by the handler later if needed
|
1942
|
+
)
|
1943
|
+
# print("[prepare_evaluation_for_state] Constructed Example from individual components.")
|
1944
|
+
except Exception as e:
|
1945
|
+
# print(f"[prepare_evaluation_for_state] Error constructing Example: {e}. Skipping evaluation setup.")
|
1946
|
+
return None
|
1947
|
+
|
1948
|
+
# If we have a valid example (provided or constructed) and scorers
|
1949
|
+
if final_example and scorers:
|
1950
|
+
# TODO: Add validation like check_examples if needed here,
|
1951
|
+
# although the handler might implicitly handle some checks via TraceClient.
|
1952
|
+
return EvaluationConfig(
|
1953
|
+
scorers=scorers,
|
1954
|
+
example=final_example,
|
1955
|
+
model=model,
|
1956
|
+
log_results=log_results
|
1957
|
+
)
|
1958
|
+
elif not scorers:
|
1959
|
+
# print("[prepare_evaluation_for_state] No scorers provided. Skipping evaluation setup.")
|
2082
1960
|
return None
|
1961
|
+
else: # No valid example
|
1962
|
+
# print("[prepare_evaluation_for_state] No valid Example available. Skipping evaluation setup.")
|
1963
|
+
return None
|
1964
|
+
# --- End NEW Helper Function ---
|
1965
|
+
|
1966
|
+
# --- NEW: Helper function to simplify adding eval config to state ---
|
1967
|
+
def add_evaluation_to_state(
|
1968
|
+
state: Dict[str, Any], # The LangGraph state dictionary
|
1969
|
+
scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
|
1970
|
+
# --- Evaluation components (same as prepare_evaluation_for_state) ---
|
1971
|
+
input: Optional[str] = None,
|
1972
|
+
actual_output: Optional[Union[str, List[str]]] = None,
|
1973
|
+
expected_output: Optional[Union[str, List[str]]] = None,
|
1974
|
+
context: Optional[List[str]] = None,
|
1975
|
+
retrieval_context: Optional[List[str]] = None,
|
1976
|
+
tools_called: Optional[List[str]] = None,
|
1977
|
+
expected_tools: Optional[List[str]] = None,
|
1978
|
+
additional_metadata: Optional[Dict[str, Any]] = None,
|
1979
|
+
# --- Other eval parameters ---
|
1980
|
+
model: Optional[str] = None,
|
1981
|
+
log_results: Optional[bool] = True
|
1982
|
+
) -> None:
|
1983
|
+
"""
|
1984
|
+
Prepares an EvaluationConfig and adds it to the state dictionary
|
1985
|
+
under the '_judgeval_eval' key if successful.
|
1986
|
+
|
1987
|
+
This simplifies the process of setting up evaluations within LangGraph nodes.
|
1988
|
+
|
1989
|
+
Args:
|
1990
|
+
state: The LangGraph state dictionary to modify.
|
1991
|
+
scorers: List of scorer instances.
|
1992
|
+
input: Input for the evaluation example.
|
1993
|
+
actual_output: Actual output for the evaluation example.
|
1994
|
+
expected_output: Expected output for the evaluation example.
|
1995
|
+
context: Context for the evaluation example.
|
1996
|
+
retrieval_context: Retrieval context for the evaluation example.
|
1997
|
+
tools_called: Tools called for the evaluation example.
|
1998
|
+
expected_tools: Expected tools for the evaluation example.
|
1999
|
+
additional_metadata: Additional metadata for the evaluation example.
|
2000
|
+
model: Model name used for generation (optional).
|
2001
|
+
log_results: Whether to log evaluation results (optional, defaults to True).
|
2002
|
+
"""
|
2003
|
+
eval_config = prepare_evaluation_for_state(
|
2004
|
+
scorers=scorers,
|
2005
|
+
input=input,
|
2006
|
+
actual_output=actual_output,
|
2007
|
+
expected_output=expected_output,
|
2008
|
+
context=context,
|
2009
|
+
retrieval_context=retrieval_context,
|
2010
|
+
tools_called=tools_called,
|
2011
|
+
expected_tools=expected_tools,
|
2012
|
+
additional_metadata=additional_metadata,
|
2013
|
+
model=model,
|
2014
|
+
log_results=log_results
|
2015
|
+
)
|
2016
|
+
|
2017
|
+
if eval_config:
|
2018
|
+
state["_judgeval_eval"] = eval_config
|
2019
|
+
# print(f"[_judgeval_eval added to state for node]") # Optional: Log confirmation
|
2020
|
+
|
2021
|
+
# print("[Skipped adding _judgeval_eval to state: prepare_evaluation_for_state failed]")
|
2022
|
+
# --- End NEW Helper ---
|
2023
|
+
|