judgeval 0.0.26__py3-none-any.whl → 0.0.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/common/tracer.py +515 -193
- judgeval/constants.py +4 -2
- judgeval/data/__init__.py +0 -3
- judgeval/data/{api_example.py → custom_api_example.py} +12 -19
- judgeval/data/datasets/eval_dataset_client.py +59 -20
- judgeval/data/result.py +34 -56
- judgeval/evaluation_run.py +1 -0
- judgeval/judgment_client.py +47 -15
- judgeval/run_evaluation.py +20 -36
- judgeval/scorers/score.py +9 -11
- {judgeval-0.0.26.dist-info → judgeval-0.0.28.dist-info}/METADATA +1 -1
- {judgeval-0.0.26.dist-info → judgeval-0.0.28.dist-info}/RECORD +14 -14
- {judgeval-0.0.26.dist-info → judgeval-0.0.28.dist-info}/WHEEL +0 -0
- {judgeval-0.0.26.dist-info → judgeval-0.0.28.dist-info}/licenses/LICENSE.md +0 -0
judgeval/common/tracer.py
CHANGED
@@ -10,11 +10,12 @@ import os
|
|
10
10
|
import time
|
11
11
|
import uuid
|
12
12
|
import warnings
|
13
|
+
import contextvars
|
13
14
|
from contextlib import contextmanager
|
14
15
|
from dataclasses import dataclass, field
|
15
16
|
from datetime import datetime
|
16
17
|
from http import HTTPStatus
|
17
|
-
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, TypeAlias, Union
|
18
|
+
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, TypeAlias, Union, Callable, Awaitable
|
18
19
|
from rich import print as rprint
|
19
20
|
|
20
21
|
# Third-party imports
|
@@ -45,6 +46,12 @@ from judgeval.rules import Rule
|
|
45
46
|
from judgeval.evaluation_run import EvaluationRun
|
46
47
|
from judgeval.data.result import ScoringResult
|
47
48
|
|
49
|
+
# Standard library imports needed for the new class
|
50
|
+
import concurrent.futures
|
51
|
+
|
52
|
+
# Define context variables for tracking the current trace and the current span within a trace
|
53
|
+
current_trace_var = contextvars.ContextVar('current_trace', default=None)
|
54
|
+
current_span_var = contextvars.ContextVar('current_span', default=None) # NEW: ContextVar for the active span name
|
48
55
|
|
49
56
|
# Define type aliases for better code readability and maintainability
|
50
57
|
ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic] # Supported API clients
|
@@ -63,29 +70,39 @@ class TraceEntry:
|
|
63
70
|
"""
|
64
71
|
type: TraceEntryType
|
65
72
|
function: str # Name of the function being traced
|
73
|
+
span_id: str # Unique ID for this specific span instance
|
66
74
|
depth: int # Indentation level for nested calls
|
67
75
|
message: str # Human-readable description
|
68
|
-
|
76
|
+
# created_at: Unix timestamp when entry was created, replacing the deprecated 'timestamp' field
|
69
77
|
duration: Optional[float] = None # Time taken (for exit/evaluation entries)
|
78
|
+
trace_id: str = None # ID of the trace this entry belongs to
|
70
79
|
output: Any = None # Function output value
|
71
80
|
# Use field() for mutable defaults to avoid shared state issues
|
72
81
|
inputs: dict = field(default_factory=dict)
|
73
82
|
span_type: SpanType = "span"
|
74
83
|
evaluation_runs: List[Optional[EvaluationRun]] = field(default=None)
|
84
|
+
parent_span_id: Optional[str] = None # ID of the parent span instance
|
75
85
|
|
76
86
|
def print_entry(self):
|
87
|
+
"""Print a trace entry with proper formatting and parent relationship information."""
|
77
88
|
indent = " " * self.depth
|
89
|
+
|
78
90
|
if self.type == "enter":
|
79
|
-
|
91
|
+
# Format parent info if present
|
92
|
+
parent_info = f" (parent_id: {self.parent_span_id})" if self.parent_span_id else ""
|
93
|
+
print(f"{indent}→ {self.function} (id: {self.span_id}){parent_info} (trace: {self.message})")
|
80
94
|
elif self.type == "exit":
|
81
|
-
print(f"{indent}← {self.function} ({self.duration:.3f}s)")
|
95
|
+
print(f"{indent}← {self.function} (id: {self.span_id}) ({self.duration:.3f}s)")
|
82
96
|
elif self.type == "output":
|
83
|
-
|
97
|
+
# Format output to align properly
|
98
|
+
output_str = str(self.output)
|
99
|
+
print(f"{indent}Output (for id: {self.span_id}): {output_str}")
|
84
100
|
elif self.type == "input":
|
85
|
-
|
101
|
+
# Format inputs to align properly
|
102
|
+
print(f"{indent}Input (for id: {self.span_id}): {self.inputs}")
|
86
103
|
elif self.type == "evaluation":
|
87
104
|
for evaluation_run in self.evaluation_runs:
|
88
|
-
print(f"{indent}Evaluation: {evaluation_run.model_dump()}")
|
105
|
+
print(f"{indent}Evaluation (for id: {self.span_id}): {evaluation_run.model_dump()}")
|
89
106
|
|
90
107
|
def _serialize_inputs(self) -> dict:
|
91
108
|
"""Helper method to serialize input data safely.
|
@@ -144,14 +161,17 @@ class TraceEntry:
|
|
144
161
|
return {
|
145
162
|
"type": self.type,
|
146
163
|
"function": self.function,
|
164
|
+
"span_id": self.span_id,
|
165
|
+
"trace_id": self.trace_id,
|
147
166
|
"depth": self.depth,
|
148
167
|
"message": self.message,
|
149
|
-
"
|
168
|
+
"created_at": datetime.fromtimestamp(self.created_at).isoformat(),
|
150
169
|
"duration": self.duration,
|
151
170
|
"output": self._serialize_output(),
|
152
171
|
"inputs": self._serialize_inputs(),
|
153
172
|
"evaluation_runs": [evaluation_run.model_dump() for evaluation_run in self.evaluation_runs] if self.evaluation_runs else [],
|
154
|
-
"span_type": self.span_type
|
173
|
+
"span_type": self.span_type,
|
174
|
+
"parent_span_id": self.parent_span_id
|
155
175
|
}
|
156
176
|
|
157
177
|
def _serialize_output(self) -> Any:
|
@@ -210,13 +230,12 @@ class TraceManagerClient:
|
|
210
230
|
|
211
231
|
return response.json()
|
212
232
|
|
213
|
-
def save_trace(self, trace_data: dict
|
233
|
+
def save_trace(self, trace_data: dict):
|
214
234
|
"""
|
215
235
|
Saves a trace to the database
|
216
236
|
|
217
237
|
Args:
|
218
238
|
trace_data: The trace data to save
|
219
|
-
empty_save: Whether to save an empty trace
|
220
239
|
NOTE we save empty traces in order to properly handle async operations; we need something in the DB to associate the async results with
|
221
240
|
"""
|
222
241
|
response = requests.post(
|
@@ -235,7 +254,7 @@ class TraceManagerClient:
|
|
235
254
|
elif response.status_code != HTTPStatus.OK:
|
236
255
|
raise ValueError(f"Failed to save trace data: {response.text}")
|
237
256
|
|
238
|
-
if
|
257
|
+
if "ui_results_url" in response.json():
|
239
258
|
pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={response.json()['ui_results_url']}]View Trace[/link]\n"
|
240
259
|
rprint(pretty_str)
|
241
260
|
|
@@ -315,66 +334,81 @@ class TraceClient:
|
|
315
334
|
overwrite: bool = False,
|
316
335
|
rules: Optional[List[Rule]] = None,
|
317
336
|
enable_monitoring: bool = True,
|
318
|
-
enable_evaluations: bool = True
|
337
|
+
enable_evaluations: bool = True,
|
338
|
+
parent_trace_id: Optional[str] = None,
|
339
|
+
parent_name: Optional[str] = None
|
319
340
|
):
|
320
341
|
self.name = name
|
321
342
|
self.trace_id = trace_id or str(uuid.uuid4())
|
322
343
|
self.project_name = project_name
|
323
344
|
self.overwrite = overwrite
|
324
345
|
self.tracer = tracer
|
325
|
-
# Initialize rules with either provided rules or an empty list
|
326
346
|
self.rules = rules or []
|
327
347
|
self.enable_monitoring = enable_monitoring
|
328
348
|
self.enable_evaluations = enable_evaluations
|
329
|
-
|
349
|
+
self.parent_trace_id = parent_trace_id
|
350
|
+
self.parent_name = parent_name
|
330
351
|
self.client: JudgmentClient = tracer.client
|
331
352
|
self.entries: List[TraceEntry] = []
|
332
353
|
self.start_time = time.time()
|
333
|
-
self.
|
334
|
-
self.
|
335
|
-
self.
|
336
|
-
self.
|
337
|
-
self.
|
338
|
-
self.executed_node_tools = [] # Track node:tool combinations
|
354
|
+
self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id)
|
355
|
+
self.visited_nodes = []
|
356
|
+
self.executed_tools = []
|
357
|
+
self.executed_node_tools = []
|
358
|
+
self._span_depths: Dict[str, int] = {} # NEW: To track depth of active spans
|
339
359
|
|
340
360
|
@contextmanager
|
341
361
|
def span(self, name: str, span_type: SpanType = "span"):
|
342
|
-
"""Context manager for creating a trace span"""
|
362
|
+
"""Context manager for creating a trace span, managing the current span via contextvars"""
|
343
363
|
start_time = time.time()
|
344
364
|
|
345
|
-
#
|
346
|
-
|
365
|
+
# Generate a unique ID for *this specific span invocation*
|
366
|
+
span_id = str(uuid.uuid4())
|
367
|
+
|
368
|
+
parent_span_id = current_span_var.get() # Get ID of the parent span from context var
|
369
|
+
token = current_span_var.set(span_id) # Set *this* span's ID as the current one
|
370
|
+
|
371
|
+
current_depth = 0
|
372
|
+
if parent_span_id and parent_span_id in self._span_depths:
|
373
|
+
current_depth = self._span_depths[parent_span_id] + 1
|
374
|
+
|
375
|
+
self._span_depths[span_id] = current_depth # Store depth by span_id
|
376
|
+
|
377
|
+
entry = TraceEntry(
|
347
378
|
type="enter",
|
348
379
|
function=name,
|
349
|
-
|
380
|
+
span_id=span_id, # Use the generated span_id
|
381
|
+
trace_id=self.trace_id, # Use the trace_id from the trace client
|
382
|
+
depth=current_depth,
|
350
383
|
message=name,
|
351
|
-
|
352
|
-
span_type=span_type
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
self.tracer.depth += 1
|
357
|
-
prev_span = self._current_span
|
358
|
-
self._current_span = name
|
384
|
+
created_at=start_time,
|
385
|
+
span_type=span_type,
|
386
|
+
parent_span_id=parent_span_id # Use the parent_id from context var
|
387
|
+
)
|
388
|
+
self.add_entry(entry)
|
359
389
|
|
360
390
|
try:
|
361
391
|
yield self
|
362
392
|
finally:
|
363
|
-
self.tracer.depth -= 1
|
364
393
|
duration = time.time() - start_time
|
365
|
-
|
366
|
-
# Record span exit
|
394
|
+
exit_depth = self._span_depths.get(span_id, 0) # Get depth using this span's ID
|
367
395
|
self.add_entry(TraceEntry(
|
368
396
|
type="exit",
|
369
397
|
function=name,
|
370
|
-
|
398
|
+
span_id=span_id, # Use the same span_id for exit
|
399
|
+
trace_id=self.trace_id, # Use the trace_id from the trace client
|
400
|
+
depth=exit_depth,
|
371
401
|
message=f"← {name}",
|
372
|
-
|
402
|
+
created_at=time.time(),
|
373
403
|
duration=duration,
|
374
404
|
span_type=span_type
|
375
405
|
))
|
376
|
-
|
377
|
-
|
406
|
+
# Clean up depth tracking for this span_id
|
407
|
+
if span_id in self._span_depths:
|
408
|
+
del self._span_depths[span_id]
|
409
|
+
# Reset context var
|
410
|
+
current_span_var.reset(token)
|
411
|
+
|
378
412
|
def async_evaluate(
|
379
413
|
self,
|
380
414
|
scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
|
@@ -457,7 +491,7 @@ class TraceClient:
|
|
457
491
|
log_results=log_results,
|
458
492
|
project_name=self.project_name,
|
459
493
|
eval_name=f"{self.name.capitalize()}-"
|
460
|
-
f"{
|
494
|
+
f"{current_span_var.get()}-"
|
461
495
|
f"[{','.join(scorer.score_type.capitalize() for scorer in loaded_scorers)}]",
|
462
496
|
examples=[example],
|
463
497
|
scorers=loaded_scorers,
|
@@ -465,47 +499,66 @@ class TraceClient:
|
|
465
499
|
metadata={},
|
466
500
|
judgment_api_key=self.tracer.api_key,
|
467
501
|
override=self.overwrite,
|
502
|
+
trace_span_id=current_span_var.get(),
|
468
503
|
rules=loaded_rules # Use the combined rules
|
469
504
|
)
|
470
505
|
|
471
506
|
self.add_eval_run(eval_run, start_time) # Pass start_time to record_evaluation
|
472
507
|
|
473
508
|
def add_eval_run(self, eval_run: EvaluationRun, start_time: float):
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
509
|
+
current_span_id = current_span_var.get()
|
510
|
+
if current_span_id:
|
511
|
+
duration = time.time() - start_time
|
512
|
+
prev_entry = self.entries[-1] if self.entries else None
|
513
|
+
# Determine function name based on previous entry or context var (less ideal)
|
514
|
+
function_name = "unknown_function" # Default
|
515
|
+
if prev_entry and prev_entry.span_type == "llm":
|
516
|
+
function_name = prev_entry.function
|
517
|
+
else:
|
518
|
+
# Try to find the function name associated with the current span_id
|
519
|
+
for entry in reversed(self.entries):
|
520
|
+
if entry.span_id == current_span_id and entry.type == 'enter':
|
521
|
+
function_name = entry.function
|
522
|
+
break
|
485
523
|
|
486
|
-
#
|
524
|
+
# Get depth for the current span
|
525
|
+
current_depth = self._span_depths.get(current_span_id, 0)
|
526
|
+
|
487
527
|
self.add_entry(TraceEntry(
|
488
528
|
type="evaluation",
|
489
|
-
function=
|
490
|
-
|
491
|
-
|
492
|
-
|
529
|
+
function=function_name,
|
530
|
+
span_id=current_span_id, # Associate with current span
|
531
|
+
trace_id=self.trace_id, # Use the trace_id from the trace client
|
532
|
+
depth=current_depth,
|
533
|
+
message=f"Evaluation results for {function_name}",
|
534
|
+
created_at=time.time(),
|
493
535
|
evaluation_runs=[eval_run],
|
494
536
|
duration=duration,
|
495
537
|
span_type="evaluation"
|
496
538
|
))
|
497
539
|
|
498
540
|
def record_input(self, inputs: dict):
|
499
|
-
|
500
|
-
if
|
541
|
+
current_span_id = current_span_var.get()
|
542
|
+
if current_span_id:
|
543
|
+
entry_span_type = "span"
|
544
|
+
current_depth = self._span_depths.get(current_span_id, 0)
|
545
|
+
function_name = "unknown_function" # Default
|
546
|
+
for entry in reversed(self.entries):
|
547
|
+
if entry.span_id == current_span_id and entry.type == 'enter':
|
548
|
+
entry_span_type = entry.span_type
|
549
|
+
function_name = entry.function
|
550
|
+
break
|
551
|
+
|
501
552
|
self.add_entry(TraceEntry(
|
502
553
|
type="input",
|
503
|
-
function=
|
504
|
-
|
505
|
-
|
506
|
-
|
554
|
+
function=function_name,
|
555
|
+
span_id=current_span_id, # Use current span_id
|
556
|
+
trace_id=self.trace_id, # Use the trace_id from the trace client
|
557
|
+
depth=current_depth,
|
558
|
+
message=f"Inputs to {function_name}",
|
559
|
+
created_at=time.time(),
|
507
560
|
inputs=inputs,
|
508
|
-
span_type=
|
561
|
+
span_type=entry_span_type
|
509
562
|
))
|
510
563
|
|
511
564
|
async def _update_coroutine_output(self, entry: TraceEntry, coroutine: Any):
|
@@ -519,21 +572,30 @@ class TraceClient:
|
|
519
572
|
raise
|
520
573
|
|
521
574
|
def record_output(self, output: Any):
|
522
|
-
|
523
|
-
if
|
575
|
+
current_span_id = current_span_var.get()
|
576
|
+
if current_span_id:
|
577
|
+
entry_span_type = "span"
|
578
|
+
current_depth = self._span_depths.get(current_span_id, 0)
|
579
|
+
function_name = "unknown_function" # Default
|
580
|
+
for entry in reversed(self.entries):
|
581
|
+
if entry.span_id == current_span_id and entry.type == 'enter':
|
582
|
+
entry_span_type = entry.span_type
|
583
|
+
function_name = entry.function
|
584
|
+
break
|
585
|
+
|
524
586
|
entry = TraceEntry(
|
525
587
|
type="output",
|
526
|
-
function=
|
527
|
-
|
528
|
-
|
529
|
-
|
588
|
+
function=function_name,
|
589
|
+
span_id=current_span_id, # Use current span_id
|
590
|
+
depth=current_depth,
|
591
|
+
message=f"Output from {function_name}",
|
592
|
+
created_at=time.time(),
|
530
593
|
output="<pending>" if inspect.iscoroutine(output) else output,
|
531
|
-
span_type=
|
594
|
+
span_type=entry_span_type
|
532
595
|
)
|
533
596
|
self.add_entry(entry)
|
534
597
|
|
535
598
|
if inspect.iscoroutine(output):
|
536
|
-
# Create a task to update the output once the coroutine completes
|
537
599
|
asyncio.create_task(self._update_coroutine_output(entry, output))
|
538
600
|
|
539
601
|
def add_entry(self, entry: TraceEntry):
|
@@ -546,6 +608,58 @@ class TraceClient:
|
|
546
608
|
for entry in self.entries:
|
547
609
|
entry.print_entry()
|
548
610
|
|
611
|
+
def print_hierarchical(self):
|
612
|
+
"""Print the trace in a hierarchical structure based on parent-child relationships"""
|
613
|
+
# First, build a map of spans
|
614
|
+
spans = {}
|
615
|
+
root_spans = []
|
616
|
+
|
617
|
+
# Collect all enter events first
|
618
|
+
for entry in self.entries:
|
619
|
+
if entry.type == "enter":
|
620
|
+
spans[entry.function] = {
|
621
|
+
"name": entry.function,
|
622
|
+
"depth": entry.depth,
|
623
|
+
"parent_id": entry.parent_span_id,
|
624
|
+
"children": []
|
625
|
+
}
|
626
|
+
|
627
|
+
# If no parent, it's a root span
|
628
|
+
if not entry.parent_span_id:
|
629
|
+
root_spans.append(entry.function)
|
630
|
+
elif entry.parent_span_id not in spans:
|
631
|
+
# If parent doesn't exist yet, temporarily treat as root
|
632
|
+
# (we'll fix this later)
|
633
|
+
root_spans.append(entry.function)
|
634
|
+
|
635
|
+
# Build parent-child relationships
|
636
|
+
for span_name, span in spans.items():
|
637
|
+
parent = span["parent_id"]
|
638
|
+
if parent and parent in spans:
|
639
|
+
spans[parent]["children"].append(span_name)
|
640
|
+
# Remove from root spans if it was temporarily there
|
641
|
+
if span_name in root_spans:
|
642
|
+
root_spans.remove(span_name)
|
643
|
+
|
644
|
+
# Now print the hierarchy
|
645
|
+
def print_span(span_name, level=0):
|
646
|
+
if span_name not in spans:
|
647
|
+
return
|
648
|
+
|
649
|
+
span = spans[span_name]
|
650
|
+
indent = " " * level
|
651
|
+
parent_info = f" (parent_id: {span['parent_id']})" if span["parent_id"] else ""
|
652
|
+
print(f"{indent}→ {span_name}{parent_info}")
|
653
|
+
|
654
|
+
# Print children
|
655
|
+
for child in span["children"]:
|
656
|
+
print_span(child, level + 1)
|
657
|
+
|
658
|
+
# Print starting with root spans
|
659
|
+
print("\nHierarchical Trace Structure:")
|
660
|
+
for root in root_spans:
|
661
|
+
print_span(root)
|
662
|
+
|
549
663
|
def get_duration(self) -> float:
|
550
664
|
"""
|
551
665
|
Get the total duration of this trace
|
@@ -554,58 +668,126 @@ class TraceClient:
|
|
554
668
|
|
555
669
|
def condense_trace(self, entries: List[dict]) -> List[dict]:
|
556
670
|
"""
|
557
|
-
Condenses trace entries into a single entry for each
|
671
|
+
Condenses trace entries into a single entry for each span instance,
|
672
|
+
preserving parent-child span relationships using span_id and parent_span_id.
|
558
673
|
"""
|
559
|
-
|
560
|
-
|
561
|
-
|
674
|
+
spans_by_id: Dict[str, dict] = {}
|
675
|
+
evaluation_runs: List[EvaluationRun] = []
|
676
|
+
|
677
|
+
# First pass: Group entries by span_id and gather data
|
678
|
+
for entry in entries:
|
679
|
+
span_id = entry.get("span_id")
|
680
|
+
if not span_id:
|
681
|
+
continue # Skip entries without a span_id (should not happen)
|
562
682
|
|
563
|
-
for i, entry in enumerate(entries):
|
564
|
-
function = entry["function"]
|
565
|
-
|
566
683
|
if entry["type"] == "enter":
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
# del function_entries[function]
|
586
|
-
|
587
|
-
# The OR condition is to handle the LLM client case.
|
588
|
-
# LLM client is a special case where we exit the span, so when we attach evaluations to it,
|
589
|
-
# we have to check if the previous entry is an LLM call.
|
590
|
-
elif function in active_functions or entry["type"] == "evaluation" and entries[i-1]["function"] == entry["function"]:
|
591
|
-
# Update existing function entry with additional data
|
592
|
-
current_entry = function_entries[function]
|
684
|
+
if span_id not in spans_by_id:
|
685
|
+
spans_by_id[span_id] = {
|
686
|
+
"span_id": span_id,
|
687
|
+
"function": entry["function"],
|
688
|
+
"depth": entry["depth"], # Use the depth recorded at entry time
|
689
|
+
"created_at": entry["created_at"],
|
690
|
+
"trace_id": entry["trace_id"],
|
691
|
+
"parent_span_id": entry.get("parent_span_id"),
|
692
|
+
"span_type": entry.get("span_type", "span"),
|
693
|
+
"inputs": None,
|
694
|
+
"output": None,
|
695
|
+
"evaluation_runs": [],
|
696
|
+
"duration": None
|
697
|
+
}
|
698
|
+
# Handle potential duplicate enter events if necessary (e.g., log warning)
|
699
|
+
|
700
|
+
elif span_id in spans_by_id:
|
701
|
+
current_span_data = spans_by_id[span_id]
|
593
702
|
|
594
703
|
if entry["type"] == "input" and entry["inputs"]:
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
current_entry["evaluation_runs"] = entry["evaluation_runs"]
|
704
|
+
# Merge inputs if multiple are recorded, or just assign
|
705
|
+
if current_span_data["inputs"] is None:
|
706
|
+
current_span_data["inputs"] = entry["inputs"]
|
707
|
+
elif isinstance(current_span_data["inputs"], dict) and isinstance(entry["inputs"], dict):
|
708
|
+
current_span_data["inputs"].update(entry["inputs"])
|
709
|
+
# Add more sophisticated merging if needed
|
602
710
|
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
711
|
+
elif entry["type"] == "output" and "output" in entry:
|
712
|
+
current_span_data["output"] = entry["output"]
|
713
|
+
|
714
|
+
elif entry["type"] == "evaluation" and entry.get("evaluation_runs"):
|
715
|
+
if current_span_data.get("evaluation_runs") is not None:
|
716
|
+
evaluation_runs.extend(entry["evaluation_runs"])
|
717
|
+
|
718
|
+
elif entry["type"] == "exit":
|
719
|
+
if current_span_data["duration"] is None: # Calculate duration only once
|
720
|
+
start_time = datetime.fromisoformat(current_span_data.get("created_at", entry["created_at"]))
|
721
|
+
end_time = datetime.fromisoformat(entry["created_at"])
|
722
|
+
current_span_data["duration"] = (end_time - start_time).total_seconds()
|
723
|
+
# Update depth if exit depth is different (though current span() implementation keeps it same)
|
724
|
+
# current_span_data["depth"] = entry["depth"]
|
725
|
+
|
726
|
+
# Convert dictionary to a list initially for easier access
|
727
|
+
spans_list = list(spans_by_id.values())
|
728
|
+
|
729
|
+
# Build tree structure (adjacency list) and find roots
|
730
|
+
children_map: Dict[Optional[str], List[dict]] = {}
|
731
|
+
roots = []
|
732
|
+
span_map = {span['span_id']: span for span in spans_list} # Map for quick lookup
|
733
|
+
|
734
|
+
for span in spans_list:
|
735
|
+
parent_id = span.get("parent_span_id")
|
736
|
+
if parent_id is None:
|
737
|
+
roots.append(span)
|
738
|
+
else:
|
739
|
+
if parent_id not in children_map:
|
740
|
+
children_map[parent_id] = []
|
741
|
+
children_map[parent_id].append(span)
|
607
742
|
|
608
|
-
|
743
|
+
# Sort roots by timestamp
|
744
|
+
roots.sort(key=lambda x: datetime.fromisoformat(x.get("created_at", "1970-01-01T00:00:00")))
|
745
|
+
|
746
|
+
# Perform depth-first traversal to get the final sorted list
|
747
|
+
sorted_condensed_list = []
|
748
|
+
visited = set() # To handle potential cycles, though unlikely with UUIDs
|
749
|
+
|
750
|
+
def dfs(span_data):
|
751
|
+
span_id = span_data['span_id']
|
752
|
+
if span_id in visited:
|
753
|
+
return # Avoid infinite loops in case of cycles
|
754
|
+
visited.add(span_id)
|
755
|
+
|
756
|
+
sorted_condensed_list.append(span_data) # Add parent before children
|
757
|
+
|
758
|
+
# Get children, sort them by created_at, and visit them
|
759
|
+
span_children = children_map.get(span_id, [])
|
760
|
+
span_children.sort(key=lambda x: datetime.fromisoformat(x.get("created_at", "1970-01-01T00:00:00")))
|
761
|
+
for child in span_children:
|
762
|
+
# Ensure the child exists in our map before recursing
|
763
|
+
if child['span_id'] in span_map:
|
764
|
+
dfs(child)
|
765
|
+
else:
|
766
|
+
# This case might indicate an issue, but we'll add the child directly
|
767
|
+
# if its parent was processed but the child itself wasn't in the initial list?
|
768
|
+
# Or if the child's 'enter' event was missing. For robustness, add it.
|
769
|
+
if child['span_id'] not in visited:
|
770
|
+
visited.add(child['span_id'])
|
771
|
+
sorted_condensed_list.append(child)
|
772
|
+
|
773
|
+
|
774
|
+
# Start DFS from each root
|
775
|
+
for root_span in roots:
|
776
|
+
if root_span['span_id'] not in visited:
|
777
|
+
dfs(root_span)
|
778
|
+
|
779
|
+
# Handle spans that might not have been reachable from roots (orphans)
|
780
|
+
# Though ideally, all spans should descend from a root.
|
781
|
+
for span_data in spans_list:
|
782
|
+
if span_data['span_id'] not in visited:
|
783
|
+
# Decide how to handle orphans, maybe append them at the end sorted by time?
|
784
|
+
# For now, let's just add them to ensure they aren't lost.
|
785
|
+
sorted_condensed_list.append(span_data)
|
786
|
+
|
787
|
+
|
788
|
+
return sorted_condensed_list, evaluation_runs
|
789
|
+
|
790
|
+
def save(self, overwrite: bool = False) -> Tuple[str, dict]:
|
609
791
|
"""
|
610
792
|
Save the current trace to the database.
|
611
793
|
Returns a tuple of (trace_id, trace_data) where trace_data is the trace data that was saved.
|
@@ -615,7 +797,7 @@ class TraceClient:
|
|
615
797
|
|
616
798
|
raw_entries = [entry.to_dict() for entry in self.entries]
|
617
799
|
|
618
|
-
condensed_entries = self.condense_trace(raw_entries)
|
800
|
+
condensed_entries, evaluation_runs = self.condense_trace(raw_entries)
|
619
801
|
|
620
802
|
# Calculate total token counts from LLM API calls
|
621
803
|
total_prompt_tokens = 0
|
@@ -688,30 +870,32 @@ class TraceClient:
|
|
688
870
|
"total_cost_usd": total_cost
|
689
871
|
},
|
690
872
|
"entries": condensed_entries,
|
691
|
-
"
|
692
|
-
"overwrite": overwrite
|
873
|
+
"evaluation_runs": evaluation_runs,
|
874
|
+
"overwrite": overwrite,
|
875
|
+
"parent_trace_id": self.parent_trace_id,
|
876
|
+
"parent_name": self.parent_name
|
693
877
|
}
|
694
878
|
# Execute asynchrous evaluation in the background
|
695
|
-
if not empty_save: # Only send to RabbitMQ if the trace is not empty
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
|
706
|
-
|
707
|
-
|
879
|
+
# if not empty_save: # Only send to RabbitMQ if the trace is not empty
|
880
|
+
# # Send trace data to evaluation queue via API
|
881
|
+
# try:
|
882
|
+
# response = requests.post(
|
883
|
+
# JUDGMENT_TRACES_ADD_TO_EVAL_QUEUE_API_URL,
|
884
|
+
# json=trace_data,
|
885
|
+
# headers={
|
886
|
+
# "Content-Type": "application/json",
|
887
|
+
# "Authorization": f"Bearer {self.tracer.api_key}",
|
888
|
+
# "X-Organization-Id": self.tracer.organization_id
|
889
|
+
# },
|
890
|
+
# verify=True
|
891
|
+
# )
|
708
892
|
|
709
|
-
|
710
|
-
|
711
|
-
|
712
|
-
|
893
|
+
# if response.status_code != HTTPStatus.OK:
|
894
|
+
# warnings.warn(f"Failed to add trace to evaluation queue: {response.text}")
|
895
|
+
# except Exception as e:
|
896
|
+
# warnings.warn(f"Error sending trace to evaluation queue: {str(e)}")
|
713
897
|
|
714
|
-
self.trace_manager_client.save_trace(trace_data
|
898
|
+
self.trace_manager_client.save_trace(trace_data)
|
715
899
|
|
716
900
|
return self.trace_id, trace_data
|
717
901
|
|
@@ -745,7 +929,6 @@ class Tracer:
|
|
745
929
|
self.project_name: str = project_name
|
746
930
|
self.client: JudgmentClient = JudgmentClient(judgment_api_key=api_key)
|
747
931
|
self.organization_id: str = organization_id
|
748
|
-
self.depth: int = 0
|
749
932
|
self._current_trace: Optional[str] = None
|
750
933
|
self.rules: List[Rule] = rules or [] # Store rules at tracer level
|
751
934
|
self.initialized: bool = True
|
@@ -770,6 +953,15 @@ class Tracer:
|
|
770
953
|
"""Start a new trace context using a context manager"""
|
771
954
|
trace_id = str(uuid.uuid4())
|
772
955
|
project = project_name if project_name is not None else self.project_name
|
956
|
+
|
957
|
+
# Get parent trace info from context
|
958
|
+
parent_trace = current_trace_var.get()
|
959
|
+
parent_trace_id = None
|
960
|
+
parent_name = None
|
961
|
+
|
962
|
+
if parent_trace:
|
963
|
+
parent_trace_id = parent_trace.trace_id
|
964
|
+
parent_name = parent_trace.name
|
773
965
|
|
774
966
|
trace = TraceClient(
|
775
967
|
self,
|
@@ -779,25 +971,28 @@ class Tracer:
|
|
779
971
|
overwrite=overwrite,
|
780
972
|
rules=self.rules, # Pass combined rules to the trace client
|
781
973
|
enable_monitoring=self.enable_monitoring,
|
782
|
-
enable_evaluations=self.enable_evaluations
|
974
|
+
enable_evaluations=self.enable_evaluations,
|
975
|
+
parent_trace_id=parent_trace_id,
|
976
|
+
parent_name=parent_name
|
783
977
|
)
|
784
|
-
|
785
|
-
|
978
|
+
|
979
|
+
# Set the current trace in context variables
|
980
|
+
token = current_trace_var.set(trace)
|
786
981
|
|
787
982
|
# Automatically create top-level span
|
788
983
|
with trace.span(name or "unnamed_trace") as span:
|
789
984
|
try:
|
790
985
|
# Save the trace to the database to handle Evaluations' trace_id referential integrity
|
791
|
-
trace.save(empty_save=True, overwrite=overwrite)
|
792
986
|
yield trace
|
793
987
|
finally:
|
794
|
-
|
988
|
+
# Reset the context variable
|
989
|
+
current_trace_var.reset(token)
|
795
990
|
|
796
991
|
def get_current_trace(self) -> Optional[TraceClient]:
|
797
992
|
"""
|
798
|
-
Get the current trace context
|
993
|
+
Get the current trace context from contextvars
|
799
994
|
"""
|
800
|
-
return
|
995
|
+
return current_trace_var.get()
|
801
996
|
|
802
997
|
def observe(self, func=None, *, name=None, span_type: SpanType = "span", project_name: str = None, overwrite: bool = False):
|
803
998
|
"""
|
@@ -823,20 +1018,56 @@ class Tracer:
|
|
823
1018
|
if asyncio.iscoroutinefunction(func):
|
824
1019
|
@functools.wraps(func)
|
825
1020
|
async def async_wrapper(*args, **kwargs):
|
826
|
-
#
|
827
|
-
|
828
|
-
|
829
|
-
|
1021
|
+
# Get current trace from context
|
1022
|
+
current_trace = current_trace_var.get()
|
1023
|
+
|
1024
|
+
# If there's no current trace, create a root trace
|
1025
|
+
if not current_trace:
|
830
1026
|
trace_id = str(uuid.uuid4())
|
831
|
-
trace_name = func.__name__
|
832
1027
|
project = project_name if project_name is not None else self.project_name
|
833
|
-
|
834
|
-
|
835
|
-
|
836
|
-
|
837
|
-
|
838
|
-
|
839
|
-
|
1028
|
+
|
1029
|
+
# Create a new trace client to serve as the root
|
1030
|
+
current_trace = TraceClient(
|
1031
|
+
self,
|
1032
|
+
trace_id,
|
1033
|
+
span_name, # MODIFIED: Use span_name directly
|
1034
|
+
project_name=project,
|
1035
|
+
overwrite=overwrite,
|
1036
|
+
rules=self.rules,
|
1037
|
+
enable_monitoring=self.enable_monitoring,
|
1038
|
+
enable_evaluations=self.enable_evaluations
|
1039
|
+
)
|
1040
|
+
|
1041
|
+
# Save empty trace and set trace context
|
1042
|
+
# current_trace.save(empty_save=True, overwrite=overwrite)
|
1043
|
+
trace_token = current_trace_var.set(current_trace)
|
1044
|
+
|
1045
|
+
try:
|
1046
|
+
# Use span for the function execution within the root trace
|
1047
|
+
# This sets the current_span_var
|
1048
|
+
with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
|
1049
|
+
# Record inputs
|
1050
|
+
span.record_input({
|
1051
|
+
'args': str(args),
|
1052
|
+
'kwargs': kwargs
|
1053
|
+
})
|
1054
|
+
|
1055
|
+
# Execute function
|
1056
|
+
result = await func(*args, **kwargs)
|
1057
|
+
|
1058
|
+
# Record output
|
1059
|
+
span.record_output(result)
|
1060
|
+
|
1061
|
+
# Save the completed trace
|
1062
|
+
current_trace.save(overwrite=overwrite)
|
1063
|
+
return result
|
1064
|
+
finally:
|
1065
|
+
# Reset trace context (span context resets automatically)
|
1066
|
+
current_trace_var.reset(trace_token)
|
1067
|
+
else:
|
1068
|
+
# Already have a trace context, just create a span in it
|
1069
|
+
# The span method handles current_span_var
|
1070
|
+
with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
|
840
1071
|
# Record inputs
|
841
1072
|
span.record_input({
|
842
1073
|
'args': str(args),
|
@@ -850,30 +1081,62 @@ class Tracer:
|
|
850
1081
|
span.record_output(result)
|
851
1082
|
|
852
1083
|
return result
|
853
|
-
finally:
|
854
|
-
# Only save and cleanup if this is the root observe call
|
855
|
-
if self.depth == 0:
|
856
|
-
trace.save(empty_save=False, overwrite=overwrite)
|
857
|
-
self._current_trace = None
|
858
1084
|
|
859
1085
|
return async_wrapper
|
860
1086
|
else:
|
1087
|
+
# Non-async function implementation remains unchanged
|
861
1088
|
@functools.wraps(func)
|
862
1089
|
def wrapper(*args, **kwargs):
|
863
|
-
#
|
864
|
-
|
865
|
-
|
866
|
-
|
1090
|
+
# Get current trace from context
|
1091
|
+
current_trace = current_trace_var.get()
|
1092
|
+
|
1093
|
+
# If there's no current trace, create a root trace
|
1094
|
+
if not current_trace:
|
867
1095
|
trace_id = str(uuid.uuid4())
|
868
|
-
trace_name = func.__name__
|
869
1096
|
project = project_name if project_name is not None else self.project_name
|
870
|
-
|
871
|
-
|
872
|
-
|
873
|
-
|
874
|
-
|
875
|
-
|
876
|
-
|
1097
|
+
|
1098
|
+
# Create a new trace client to serve as the root
|
1099
|
+
current_trace = TraceClient(
|
1100
|
+
self,
|
1101
|
+
trace_id,
|
1102
|
+
span_name, # MODIFIED: Use span_name directly
|
1103
|
+
project_name=project,
|
1104
|
+
overwrite=overwrite,
|
1105
|
+
rules=self.rules,
|
1106
|
+
enable_monitoring=self.enable_monitoring,
|
1107
|
+
enable_evaluations=self.enable_evaluations
|
1108
|
+
)
|
1109
|
+
|
1110
|
+
# Save empty trace and set trace context
|
1111
|
+
# current_trace.save(empty_save=True, overwrite=overwrite)
|
1112
|
+
trace_token = current_trace_var.set(current_trace)
|
1113
|
+
|
1114
|
+
try:
|
1115
|
+
# Use span for the function execution within the root trace
|
1116
|
+
# This sets the current_span_var
|
1117
|
+
with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
|
1118
|
+
# Record inputs
|
1119
|
+
span.record_input({
|
1120
|
+
'args': str(args),
|
1121
|
+
'kwargs': kwargs
|
1122
|
+
})
|
1123
|
+
|
1124
|
+
# Execute function
|
1125
|
+
result = func(*args, **kwargs)
|
1126
|
+
|
1127
|
+
# Record output
|
1128
|
+
span.record_output(result)
|
1129
|
+
|
1130
|
+
# Save the completed trace
|
1131
|
+
current_trace.save(overwrite=overwrite)
|
1132
|
+
return result
|
1133
|
+
finally:
|
1134
|
+
# Reset trace context (span context resets automatically)
|
1135
|
+
current_trace_var.reset(trace_token)
|
1136
|
+
else:
|
1137
|
+
# Already have a trace context, just create a span in it
|
1138
|
+
# The span method handles current_span_var
|
1139
|
+
with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
|
877
1140
|
# Record inputs
|
878
1141
|
span.record_input({
|
879
1142
|
'args': str(args),
|
@@ -887,11 +1150,6 @@ class Tracer:
|
|
887
1150
|
span.record_output(result)
|
888
1151
|
|
889
1152
|
return result
|
890
|
-
finally:
|
891
|
-
# Only save and cleanup if this is the root observe call
|
892
|
-
if self.depth == 0:
|
893
|
-
trace.save(empty_save=False, overwrite=overwrite)
|
894
|
-
self._current_trace = None
|
895
1153
|
|
896
1154
|
return wrapper
|
897
1155
|
|
@@ -900,27 +1158,36 @@ class Tracer:
|
|
900
1158
|
Decorator to trace function execution with detailed entry/exit information.
|
901
1159
|
"""
|
902
1160
|
if func is None:
|
903
|
-
return lambda f: self.
|
1161
|
+
return lambda f: self.score(f, scorers=scorers, model=model, log_results=log_results, name=name, span_type=span_type)
|
904
1162
|
|
905
1163
|
if asyncio.iscoroutinefunction(func):
|
906
1164
|
@functools.wraps(func)
|
907
1165
|
async def async_wrapper(*args, **kwargs):
|
908
|
-
|
909
|
-
|
1166
|
+
# Get current trace from contextvars
|
1167
|
+
current_trace = current_trace_var.get()
|
1168
|
+
if current_trace and scorers:
|
1169
|
+
current_trace.async_evaluate(scorers=scorers, input=args, actual_output=kwargs, model=model, log_results=log_results)
|
1170
|
+
return await func(*args, **kwargs)
|
910
1171
|
return async_wrapper
|
911
1172
|
else:
|
912
1173
|
@functools.wraps(func)
|
913
1174
|
def wrapper(*args, **kwargs):
|
914
|
-
|
915
|
-
|
1175
|
+
# Get current trace from contextvars
|
1176
|
+
current_trace = current_trace_var.get()
|
1177
|
+
if current_trace and scorers:
|
1178
|
+
current_trace.async_evaluate(scorers=scorers, input=args, actual_output=kwargs, model=model, log_results=log_results)
|
1179
|
+
return func(*args, **kwargs)
|
916
1180
|
return wrapper
|
917
1181
|
|
918
1182
|
def async_evaluate(self, *args, **kwargs):
    """
    Forward an evaluation request to the trace active in the current context.

    Does nothing when ``self.enable_evaluations`` is falsy. Otherwise the
    trace is read from ``current_trace_var``; if one is set, all positional
    and keyword arguments are passed unchanged to its ``async_evaluate``.
    If no trace is set, a warning is emitted and the request is dropped.

    Always returns ``None``; the result of the delegated call is discarded.
    """
    if not self.enable_evaluations:
        return

    # Look the trace up from the context variable on every call.
    active_trace = current_trace_var.get()
    if not active_trace:
        warnings.warn("No trace found, skipping evaluation")
        return

    active_trace.async_evaluate(*args, **kwargs)
|
926
1193
|
|
@@ -934,14 +1201,14 @@ def wrap(client: Any) -> Any:
|
|
934
1201
|
span_name, original_create = _get_client_config(client)
|
935
1202
|
|
936
1203
|
def traced_create(*args, **kwargs):
|
937
|
-
# Get the current
|
938
|
-
|
1204
|
+
# Get the current trace from contextvars
|
1205
|
+
current_trace = current_trace_var.get()
|
939
1206
|
|
940
|
-
# Skip tracing if no
|
941
|
-
if not
|
1207
|
+
# Skip tracing if no active trace
|
1208
|
+
if not current_trace:
|
942
1209
|
return original_create(*args, **kwargs)
|
943
1210
|
|
944
|
-
with
|
1211
|
+
with current_trace.span(span_name, span_type="llm") as span:
|
945
1212
|
# Format and record the input parameters
|
946
1213
|
input_data = _format_input_data(client, **kwargs)
|
947
1214
|
span.record_input(input_data)
|
@@ -1033,4 +1300,59 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
|
|
1033
1300
|
"output_tokens": response.usage.output_tokens,
|
1034
1301
|
"total_tokens": response.usage.input_tokens + response.usage.output_tokens
|
1035
1302
|
}
|
1036
|
-
}
|
1303
|
+
}
|
1304
|
+
|
1305
|
+
# Add a global context-preserving gather function
|
1306
|
+
# async def trace_gather(*coroutines, return_exceptions=False): # REMOVED
|
1307
|
+
# """ # REMOVED
|
1308
|
+
# A wrapper around asyncio.gather that ensures the trace context # REMOVED
|
1309
|
+
# is available within the gathered coroutines using contextvars.copy_context. # REMOVED
|
1310
|
+
# """ # REMOVED
|
1311
|
+
# # Get the original asyncio.gather (if we patched it) # REMOVED
|
1312
|
+
# original_gather = getattr(asyncio, "_original_gather", asyncio.gather) # REMOVED
|
1313
|
+
# # REMOVED
|
1314
|
+
# # Use contextvars.copy_context() to ensure context propagation # REMOVED
|
1315
|
+
# ctx = contextvars.copy_context() # REMOVED
|
1316
|
+
# # REMOVED
|
1317
|
+
# # Wrap the gather call within the copied context # REMOVED
|
1318
|
+
# return await ctx.run(original_gather, *coroutines, return_exceptions=return_exceptions) # REMOVED
|
1319
|
+
|
1320
|
+
# Store the original gather and apply the patch *once*
|
1321
|
+
# global _original_gather_stored # REMOVED
|
1322
|
+
# if not globals().get('_original_gather_stored'): # REMOVED
|
1323
|
+
# # Check if asyncio.gather is already our wrapper to prevent double patching # REMOVED
|
1324
|
+
# if asyncio.gather.__name__ != 'trace_gather': # REMOVED
|
1325
|
+
# asyncio._original_gather = asyncio.gather # REMOVED
|
1326
|
+
# asyncio.gather = trace_gather # REMOVED
|
1327
|
+
# _original_gather_stored = True # REMOVED
|
1328
|
+
|
1329
|
+
# Add the new TraceThreadPoolExecutor class
|
1330
|
+
class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
    """
    A ThreadPoolExecutor subclass that runs every submitted callable inside a
    copy of the submitting thread's contextvars context.

    Worker threads normally start with their own, empty context, so context
    variables such as `current_trace_var` and `current_span_var` would not be
    visible inside pooled functions. Capturing the context at submission time
    and executing the callable through `Context.run` makes those variables
    available in the worker, which lets the Tracer keep parent-child
    relationships across thread boundaries.

    Each submission gets its own context copy, so values a worker sets on a
    context variable are not visible to the submitting thread or to other
    submitted tasks.
    """

    def submit(self, fn, /, *args, **kwargs):
        """
        Schedule `fn(*args, **kwargs)` to run in a worker thread within the
        context captured from the calling thread.

        Args:
            fn: The callable to execute.
            *args: Positional arguments forwarded to `fn`.
            **kwargs: Keyword arguments forwarded to `fn`.

        Returns:
            The `concurrent.futures.Future` produced by the underlying
            `ThreadPoolExecutor.submit` call.
        """
        # Snapshot the context now, in the submitting thread; by the time the
        # worker picks the task up, the caller's context may have changed.
        ctx = contextvars.copy_context()

        # Context.run(callable, *args, **kwargs) forwards the arguments to the
        # callable itself, so no functools.partial wrapper is needed.
        return super().submit(ctx.run, fn, *args, **kwargs)

    # NOTE: `map` is intentionally not overridden. CPython's `Executor.map`
    # dispatches each item through `self.submit`, so mapped calls pick up the
    # captured context as well (verify if targeting another runtime).
|