judgeval 0.0.26__py3-none-any.whl → 0.0.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/common/tracer.py +476 -161
- judgeval/constants.py +4 -2
- judgeval/data/__init__.py +0 -3
- judgeval/data/datasets/eval_dataset_client.py +59 -20
- judgeval/data/result.py +34 -56
- judgeval/judgment_client.py +47 -15
- judgeval/run_evaluation.py +20 -36
- judgeval/scorers/score.py +9 -11
- {judgeval-0.0.26.dist-info → judgeval-0.0.27.dist-info}/METADATA +1 -1
- {judgeval-0.0.26.dist-info → judgeval-0.0.27.dist-info}/RECORD +12 -13
- judgeval/data/api_example.py +0 -98
- {judgeval-0.0.26.dist-info → judgeval-0.0.27.dist-info}/WHEEL +0 -0
- {judgeval-0.0.26.dist-info → judgeval-0.0.27.dist-info}/licenses/LICENSE.md +0 -0
judgeval/common/tracer.py
CHANGED
@@ -10,11 +10,12 @@ import os
|
|
10
10
|
import time
|
11
11
|
import uuid
|
12
12
|
import warnings
|
13
|
+
import contextvars
|
13
14
|
from contextlib import contextmanager
|
14
15
|
from dataclasses import dataclass, field
|
15
16
|
from datetime import datetime
|
16
17
|
from http import HTTPStatus
|
17
|
-
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, TypeAlias, Union
|
18
|
+
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, TypeAlias, Union, Callable, Awaitable
|
18
19
|
from rich import print as rprint
|
19
20
|
|
20
21
|
# Third-party imports
|
@@ -45,6 +46,12 @@ from judgeval.rules import Rule
|
|
45
46
|
from judgeval.evaluation_run import EvaluationRun
|
46
47
|
from judgeval.data.result import ScoringResult
|
47
48
|
|
49
|
+
# Standard library imports needed for the new class
|
50
|
+
import concurrent.futures
|
51
|
+
|
52
|
+
# Define context variables for tracking the current trace and the current span within a trace
|
53
|
+
current_trace_var = contextvars.ContextVar('current_trace', default=None)
|
54
|
+
current_span_var = contextvars.ContextVar('current_span', default=None) # NEW: ContextVar for the active span name
|
48
55
|
|
49
56
|
# Define type aliases for better code readability and maintainability
|
50
57
|
ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic] # Supported API clients
|
@@ -63,6 +70,7 @@ class TraceEntry:
|
|
63
70
|
"""
|
64
71
|
type: TraceEntryType
|
65
72
|
function: str # Name of the function being traced
|
73
|
+
span_id: str # Unique ID for this specific span instance
|
66
74
|
depth: int # Indentation level for nested calls
|
67
75
|
message: str # Human-readable description
|
68
76
|
timestamp: float # Unix timestamp when entry was created
|
@@ -72,20 +80,28 @@ class TraceEntry:
|
|
72
80
|
inputs: dict = field(default_factory=dict)
|
73
81
|
span_type: SpanType = "span"
|
74
82
|
evaluation_runs: List[Optional[EvaluationRun]] = field(default=None)
|
83
|
+
parent_span_id: Optional[str] = None # ID of the parent span instance
|
75
84
|
|
76
85
|
def print_entry(self):
|
86
|
+
"""Print a trace entry with proper formatting and parent relationship information."""
|
77
87
|
indent = " " * self.depth
|
88
|
+
|
78
89
|
if self.type == "enter":
|
79
|
-
|
90
|
+
# Format parent info if present
|
91
|
+
parent_info = f" (parent_id: {self.parent_span_id})" if self.parent_span_id else ""
|
92
|
+
print(f"{indent}→ {self.function} (id: {self.span_id}){parent_info} (trace: {self.message})")
|
80
93
|
elif self.type == "exit":
|
81
|
-
print(f"{indent}← {self.function} ({self.duration:.3f}s)")
|
94
|
+
print(f"{indent}← {self.function} (id: {self.span_id}) ({self.duration:.3f}s)")
|
82
95
|
elif self.type == "output":
|
83
|
-
|
96
|
+
# Format output to align properly
|
97
|
+
output_str = str(self.output)
|
98
|
+
print(f"{indent}Output (for id: {self.span_id}): {output_str}")
|
84
99
|
elif self.type == "input":
|
85
|
-
|
100
|
+
# Format inputs to align properly
|
101
|
+
print(f"{indent}Input (for id: {self.span_id}): {self.inputs}")
|
86
102
|
elif self.type == "evaluation":
|
87
103
|
for evaluation_run in self.evaluation_runs:
|
88
|
-
print(f"{indent}Evaluation: {evaluation_run.model_dump()}")
|
104
|
+
print(f"{indent}Evaluation (for id: {self.span_id}): {evaluation_run.model_dump()}")
|
89
105
|
|
90
106
|
def _serialize_inputs(self) -> dict:
|
91
107
|
"""Helper method to serialize input data safely.
|
@@ -144,6 +160,7 @@ class TraceEntry:
|
|
144
160
|
return {
|
145
161
|
"type": self.type,
|
146
162
|
"function": self.function,
|
163
|
+
"span_id": self.span_id,
|
147
164
|
"depth": self.depth,
|
148
165
|
"message": self.message,
|
149
166
|
"timestamp": self.timestamp,
|
@@ -151,7 +168,8 @@ class TraceEntry:
|
|
151
168
|
"output": self._serialize_output(),
|
152
169
|
"inputs": self._serialize_inputs(),
|
153
170
|
"evaluation_runs": [evaluation_run.model_dump() for evaluation_run in self.evaluation_runs] if self.evaluation_runs else [],
|
154
|
-
"span_type": self.span_type
|
171
|
+
"span_type": self.span_type,
|
172
|
+
"parent_span_id": self.parent_span_id
|
155
173
|
}
|
156
174
|
|
157
175
|
def _serialize_output(self) -> Any:
|
@@ -315,66 +333,79 @@ class TraceClient:
|
|
315
333
|
overwrite: bool = False,
|
316
334
|
rules: Optional[List[Rule]] = None,
|
317
335
|
enable_monitoring: bool = True,
|
318
|
-
enable_evaluations: bool = True
|
336
|
+
enable_evaluations: bool = True,
|
337
|
+
parent_trace_id: Optional[str] = None,
|
338
|
+
parent_name: Optional[str] = None
|
319
339
|
):
|
320
340
|
self.name = name
|
321
341
|
self.trace_id = trace_id or str(uuid.uuid4())
|
322
342
|
self.project_name = project_name
|
323
343
|
self.overwrite = overwrite
|
324
344
|
self.tracer = tracer
|
325
|
-
# Initialize rules with either provided rules or an empty list
|
326
345
|
self.rules = rules or []
|
327
346
|
self.enable_monitoring = enable_monitoring
|
328
347
|
self.enable_evaluations = enable_evaluations
|
329
|
-
|
348
|
+
self.parent_trace_id = parent_trace_id
|
349
|
+
self.parent_name = parent_name
|
330
350
|
self.client: JudgmentClient = tracer.client
|
331
351
|
self.entries: List[TraceEntry] = []
|
332
352
|
self.start_time = time.time()
|
333
|
-
self.
|
334
|
-
self.
|
335
|
-
self.
|
336
|
-
self.
|
337
|
-
self.
|
338
|
-
self.executed_node_tools = [] # Track node:tool combinations
|
353
|
+
self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id)
|
354
|
+
self.visited_nodes = []
|
355
|
+
self.executed_tools = []
|
356
|
+
self.executed_node_tools = []
|
357
|
+
self._span_depths: Dict[str, int] = {} # NEW: To track depth of active spans
|
339
358
|
|
340
359
|
@contextmanager
|
341
360
|
def span(self, name: str, span_type: SpanType = "span"):
|
342
|
-
"""Context manager for creating a trace span"""
|
361
|
+
"""Context manager for creating a trace span, managing the current span via contextvars"""
|
343
362
|
start_time = time.time()
|
344
363
|
|
345
|
-
#
|
346
|
-
|
364
|
+
# Generate a unique ID for *this specific span invocation*
|
365
|
+
span_id = str(uuid.uuid4())
|
366
|
+
|
367
|
+
parent_span_id = current_span_var.get() # Get ID of the parent span from context var
|
368
|
+
token = current_span_var.set(span_id) # Set *this* span's ID as the current one
|
369
|
+
|
370
|
+
current_depth = 0
|
371
|
+
if parent_span_id and parent_span_id in self._span_depths:
|
372
|
+
current_depth = self._span_depths[parent_span_id] + 1
|
373
|
+
|
374
|
+
self._span_depths[span_id] = current_depth # Store depth by span_id
|
375
|
+
|
376
|
+
entry = TraceEntry(
|
347
377
|
type="enter",
|
348
378
|
function=name,
|
349
|
-
|
379
|
+
span_id=span_id, # Use the generated span_id
|
380
|
+
depth=current_depth,
|
350
381
|
message=name,
|
351
382
|
timestamp=start_time,
|
352
|
-
span_type=span_type
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
self.tracer.depth += 1
|
357
|
-
prev_span = self._current_span
|
358
|
-
self._current_span = name
|
383
|
+
span_type=span_type,
|
384
|
+
parent_span_id=parent_span_id # Use the parent_id from context var
|
385
|
+
)
|
386
|
+
self.add_entry(entry)
|
359
387
|
|
360
388
|
try:
|
361
389
|
yield self
|
362
390
|
finally:
|
363
|
-
self.tracer.depth -= 1
|
364
391
|
duration = time.time() - start_time
|
365
|
-
|
366
|
-
# Record span exit
|
392
|
+
exit_depth = self._span_depths.get(span_id, 0) # Get depth using this span's ID
|
367
393
|
self.add_entry(TraceEntry(
|
368
394
|
type="exit",
|
369
395
|
function=name,
|
370
|
-
|
396
|
+
span_id=span_id, # Use the same span_id for exit
|
397
|
+
depth=exit_depth,
|
371
398
|
message=f"← {name}",
|
372
399
|
timestamp=time.time(),
|
373
400
|
duration=duration,
|
374
401
|
span_type=span_type
|
375
402
|
))
|
376
|
-
|
377
|
-
|
403
|
+
# Clean up depth tracking for this span_id
|
404
|
+
if span_id in self._span_depths:
|
405
|
+
del self._span_depths[span_id]
|
406
|
+
# Reset context var
|
407
|
+
current_span_var.reset(token)
|
408
|
+
|
378
409
|
def async_evaluate(
|
379
410
|
self,
|
380
411
|
scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
|
@@ -457,7 +488,7 @@ class TraceClient:
|
|
457
488
|
log_results=log_results,
|
458
489
|
project_name=self.project_name,
|
459
490
|
eval_name=f"{self.name.capitalize()}-"
|
460
|
-
f"{
|
491
|
+
f"{current_span_var.get()}-"
|
461
492
|
f"[{','.join(scorer.score_type.capitalize() for scorer in loaded_scorers)}]",
|
462
493
|
examples=[example],
|
463
494
|
scorers=loaded_scorers,
|
@@ -471,24 +502,30 @@ class TraceClient:
|
|
471
502
|
self.add_eval_run(eval_run, start_time) # Pass start_time to record_evaluation
|
472
503
|
|
473
504
|
def add_eval_run(self, eval_run: EvaluationRun, start_time: float):
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
505
|
+
current_span_id = current_span_var.get()
|
506
|
+
if current_span_id:
|
507
|
+
duration = time.time() - start_time
|
508
|
+
prev_entry = self.entries[-1] if self.entries else None
|
509
|
+
# Determine function name based on previous entry or context var (less ideal)
|
510
|
+
function_name = "unknown_function" # Default
|
511
|
+
if prev_entry and prev_entry.span_type == "llm":
|
512
|
+
function_name = prev_entry.function
|
513
|
+
else:
|
514
|
+
# Try to find the function name associated with the current span_id
|
515
|
+
for entry in reversed(self.entries):
|
516
|
+
if entry.span_id == current_span_id and entry.type == 'enter':
|
517
|
+
function_name = entry.function
|
518
|
+
break
|
485
519
|
|
486
|
-
#
|
520
|
+
# Get depth for the current span
|
521
|
+
current_depth = self._span_depths.get(current_span_id, 0)
|
522
|
+
|
487
523
|
self.add_entry(TraceEntry(
|
488
524
|
type="evaluation",
|
489
|
-
function=
|
490
|
-
|
491
|
-
|
525
|
+
function=function_name,
|
526
|
+
span_id=current_span_id, # Associate with current span
|
527
|
+
depth=current_depth,
|
528
|
+
message=f"Evaluation results for {function_name}",
|
492
529
|
timestamp=time.time(),
|
493
530
|
evaluation_runs=[eval_run],
|
494
531
|
duration=duration,
|
@@ -496,16 +533,26 @@ class TraceClient:
|
|
496
533
|
))
|
497
534
|
|
498
535
|
def record_input(self, inputs: dict):
|
499
|
-
|
500
|
-
if
|
536
|
+
current_span_id = current_span_var.get()
|
537
|
+
if current_span_id:
|
538
|
+
entry_span_type = "span"
|
539
|
+
current_depth = self._span_depths.get(current_span_id, 0)
|
540
|
+
function_name = "unknown_function" # Default
|
541
|
+
for entry in reversed(self.entries):
|
542
|
+
if entry.span_id == current_span_id and entry.type == 'enter':
|
543
|
+
entry_span_type = entry.span_type
|
544
|
+
function_name = entry.function
|
545
|
+
break
|
546
|
+
|
501
547
|
self.add_entry(TraceEntry(
|
502
548
|
type="input",
|
503
|
-
function=
|
504
|
-
|
505
|
-
|
549
|
+
function=function_name,
|
550
|
+
span_id=current_span_id, # Use current span_id
|
551
|
+
depth=current_depth,
|
552
|
+
message=f"Inputs to {function_name}",
|
506
553
|
timestamp=time.time(),
|
507
554
|
inputs=inputs,
|
508
|
-
span_type=
|
555
|
+
span_type=entry_span_type
|
509
556
|
))
|
510
557
|
|
511
558
|
async def _update_coroutine_output(self, entry: TraceEntry, coroutine: Any):
|
@@ -519,21 +566,30 @@ class TraceClient:
|
|
519
566
|
raise
|
520
567
|
|
521
568
|
def record_output(self, output: Any):
|
522
|
-
|
523
|
-
if
|
569
|
+
current_span_id = current_span_var.get()
|
570
|
+
if current_span_id:
|
571
|
+
entry_span_type = "span"
|
572
|
+
current_depth = self._span_depths.get(current_span_id, 0)
|
573
|
+
function_name = "unknown_function" # Default
|
574
|
+
for entry in reversed(self.entries):
|
575
|
+
if entry.span_id == current_span_id and entry.type == 'enter':
|
576
|
+
entry_span_type = entry.span_type
|
577
|
+
function_name = entry.function
|
578
|
+
break
|
579
|
+
|
524
580
|
entry = TraceEntry(
|
525
581
|
type="output",
|
526
|
-
function=
|
527
|
-
|
528
|
-
|
582
|
+
function=function_name,
|
583
|
+
span_id=current_span_id, # Use current span_id
|
584
|
+
depth=current_depth,
|
585
|
+
message=f"Output from {function_name}",
|
529
586
|
timestamp=time.time(),
|
530
587
|
output="<pending>" if inspect.iscoroutine(output) else output,
|
531
|
-
span_type=
|
588
|
+
span_type=entry_span_type
|
532
589
|
)
|
533
590
|
self.add_entry(entry)
|
534
591
|
|
535
592
|
if inspect.iscoroutine(output):
|
536
|
-
# Create a task to update the output once the coroutine completes
|
537
593
|
asyncio.create_task(self._update_coroutine_output(entry, output))
|
538
594
|
|
539
595
|
def add_entry(self, entry: TraceEntry):
|
@@ -546,6 +602,58 @@ class TraceClient:
|
|
546
602
|
for entry in self.entries:
|
547
603
|
entry.print_entry()
|
548
604
|
|
605
|
+
def print_hierarchical(self):
|
606
|
+
"""Print the trace in a hierarchical structure based on parent-child relationships"""
|
607
|
+
# First, build a map of spans
|
608
|
+
spans = {}
|
609
|
+
root_spans = []
|
610
|
+
|
611
|
+
# Collect all enter events first
|
612
|
+
for entry in self.entries:
|
613
|
+
if entry.type == "enter":
|
614
|
+
spans[entry.function] = {
|
615
|
+
"name": entry.function,
|
616
|
+
"depth": entry.depth,
|
617
|
+
"parent_id": entry.parent_span_id,
|
618
|
+
"children": []
|
619
|
+
}
|
620
|
+
|
621
|
+
# If no parent, it's a root span
|
622
|
+
if not entry.parent_span_id:
|
623
|
+
root_spans.append(entry.function)
|
624
|
+
elif entry.parent_span_id not in spans:
|
625
|
+
# If parent doesn't exist yet, temporarily treat as root
|
626
|
+
# (we'll fix this later)
|
627
|
+
root_spans.append(entry.function)
|
628
|
+
|
629
|
+
# Build parent-child relationships
|
630
|
+
for span_name, span in spans.items():
|
631
|
+
parent = span["parent_id"]
|
632
|
+
if parent and parent in spans:
|
633
|
+
spans[parent]["children"].append(span_name)
|
634
|
+
# Remove from root spans if it was temporarily there
|
635
|
+
if span_name in root_spans:
|
636
|
+
root_spans.remove(span_name)
|
637
|
+
|
638
|
+
# Now print the hierarchy
|
639
|
+
def print_span(span_name, level=0):
|
640
|
+
if span_name not in spans:
|
641
|
+
return
|
642
|
+
|
643
|
+
span = spans[span_name]
|
644
|
+
indent = " " * level
|
645
|
+
parent_info = f" (parent_id: {span['parent_id']})" if span["parent_id"] else ""
|
646
|
+
print(f"{indent}→ {span_name}{parent_info}")
|
647
|
+
|
648
|
+
# Print children
|
649
|
+
for child in span["children"]:
|
650
|
+
print_span(child, level + 1)
|
651
|
+
|
652
|
+
# Print starting with root spans
|
653
|
+
print("\nHierarchical Trace Structure:")
|
654
|
+
for root in root_spans:
|
655
|
+
print_span(root)
|
656
|
+
|
549
657
|
def get_duration(self) -> float:
|
550
658
|
"""
|
551
659
|
Get the total duration of this trace
|
@@ -554,56 +662,122 @@ class TraceClient:
|
|
554
662
|
|
555
663
|
def condense_trace(self, entries: List[dict]) -> List[dict]:
|
556
664
|
"""
|
557
|
-
Condenses trace entries into a single entry for each
|
665
|
+
Condenses trace entries into a single entry for each span instance,
|
666
|
+
preserving parent-child span relationships using span_id and parent_span_id.
|
558
667
|
"""
|
559
|
-
|
560
|
-
|
561
|
-
|
668
|
+
spans_by_id: Dict[str, dict] = {}
|
669
|
+
|
670
|
+
# First pass: Group entries by span_id and gather data
|
671
|
+
for entry in entries:
|
672
|
+
span_id = entry.get("span_id")
|
673
|
+
if not span_id:
|
674
|
+
continue # Skip entries without a span_id (should not happen)
|
562
675
|
|
563
|
-
for i, entry in enumerate(entries):
|
564
|
-
function = entry["function"]
|
565
|
-
|
566
676
|
if entry["type"] == "enter":
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
#
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
active_functions.remove(function)
|
585
|
-
# del function_entries[function]
|
586
|
-
|
587
|
-
# The OR condition is to handle the LLM client case.
|
588
|
-
# LLM client is a special case where we exit the span, so when we attach evaluations to it,
|
589
|
-
# we have to check if the previous entry is an LLM call.
|
590
|
-
elif function in active_functions or entry["type"] == "evaluation" and entries[i-1]["function"] == entry["function"]:
|
591
|
-
# Update existing function entry with additional data
|
592
|
-
current_entry = function_entries[function]
|
677
|
+
if span_id not in spans_by_id:
|
678
|
+
spans_by_id[span_id] = {
|
679
|
+
"span_id": span_id,
|
680
|
+
"function": entry["function"],
|
681
|
+
"depth": entry["depth"], # Use the depth recorded at entry time
|
682
|
+
"timestamp": entry["timestamp"],
|
683
|
+
"parent_span_id": entry.get("parent_span_id"),
|
684
|
+
"span_type": entry.get("span_type", "span"),
|
685
|
+
"inputs": None,
|
686
|
+
"output": None,
|
687
|
+
"evaluation_runs": [],
|
688
|
+
"duration": None
|
689
|
+
}
|
690
|
+
# Handle potential duplicate enter events if necessary (e.g., log warning)
|
691
|
+
|
692
|
+
elif span_id in spans_by_id:
|
693
|
+
current_span_data = spans_by_id[span_id]
|
593
694
|
|
594
695
|
if entry["type"] == "input" and entry["inputs"]:
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
current_entry["evaluation_runs"] = entry["evaluation_runs"]
|
696
|
+
# Merge inputs if multiple are recorded, or just assign
|
697
|
+
if current_span_data["inputs"] is None:
|
698
|
+
current_span_data["inputs"] = entry["inputs"]
|
699
|
+
elif isinstance(current_span_data["inputs"], dict) and isinstance(entry["inputs"], dict):
|
700
|
+
current_span_data["inputs"].update(entry["inputs"])
|
701
|
+
# Add more sophisticated merging if needed
|
602
702
|
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
703
|
+
elif entry["type"] == "output" and "output" in entry:
|
704
|
+
current_span_data["output"] = entry["output"]
|
705
|
+
|
706
|
+
elif entry["type"] == "evaluation" and entry.get("evaluation_runs"):
|
707
|
+
if current_span_data.get("evaluation_runs") is None:
|
708
|
+
current_span_data["evaluation_runs"] = []
|
709
|
+
current_span_data["evaluation_runs"].extend(entry["evaluation_runs"])
|
710
|
+
|
711
|
+
elif entry["type"] == "exit":
|
712
|
+
if current_span_data["duration"] is None: # Calculate duration only once
|
713
|
+
start_time = current_span_data.get("timestamp", entry["timestamp"])
|
714
|
+
current_span_data["duration"] = entry["timestamp"] - start_time
|
715
|
+
# Update depth if exit depth is different (though current span() implementation keeps it same)
|
716
|
+
# current_span_data["depth"] = entry["depth"]
|
717
|
+
|
718
|
+
# Convert dictionary to a list initially for easier access
|
719
|
+
spans_list = list(spans_by_id.values())
|
720
|
+
|
721
|
+
# Build tree structure (adjacency list) and find roots
|
722
|
+
children_map: Dict[Optional[str], List[dict]] = {}
|
723
|
+
roots = []
|
724
|
+
span_map = {span['span_id']: span for span in spans_list} # Map for quick lookup
|
725
|
+
|
726
|
+
for span in spans_list:
|
727
|
+
parent_id = span.get("parent_span_id")
|
728
|
+
if parent_id is None:
|
729
|
+
roots.append(span)
|
730
|
+
else:
|
731
|
+
if parent_id not in children_map:
|
732
|
+
children_map[parent_id] = []
|
733
|
+
children_map[parent_id].append(span)
|
734
|
+
|
735
|
+
# Sort roots by timestamp
|
736
|
+
roots.sort(key=lambda x: x.get("timestamp", 0))
|
737
|
+
|
738
|
+
# Perform depth-first traversal to get the final sorted list
|
739
|
+
sorted_condensed_list = []
|
740
|
+
visited = set() # To handle potential cycles, though unlikely with UUIDs
|
741
|
+
|
742
|
+
def dfs(span_data):
|
743
|
+
span_id = span_data['span_id']
|
744
|
+
if span_id in visited:
|
745
|
+
return # Avoid infinite loops in case of cycles
|
746
|
+
visited.add(span_id)
|
747
|
+
|
748
|
+
sorted_condensed_list.append(span_data) # Add parent before children
|
749
|
+
|
750
|
+
# Get children, sort them by timestamp, and visit them
|
751
|
+
span_children = children_map.get(span_id, [])
|
752
|
+
span_children.sort(key=lambda x: x.get("timestamp", 0))
|
753
|
+
for child in span_children:
|
754
|
+
# Ensure the child exists in our map before recursing
|
755
|
+
if child['span_id'] in span_map:
|
756
|
+
dfs(child)
|
757
|
+
else:
|
758
|
+
# This case might indicate an issue, but we'll add the child directly
|
759
|
+
# if its parent was processed but the child itself wasn't in the initial list?
|
760
|
+
# Or if the child's 'enter' event was missing. For robustness, add it.
|
761
|
+
if child['span_id'] not in visited:
|
762
|
+
visited.add(child['span_id'])
|
763
|
+
sorted_condensed_list.append(child)
|
764
|
+
|
765
|
+
|
766
|
+
# Start DFS from each root
|
767
|
+
for root_span in roots:
|
768
|
+
if root_span['span_id'] not in visited:
|
769
|
+
dfs(root_span)
|
770
|
+
|
771
|
+
# Handle spans that might not have been reachable from roots (orphans)
|
772
|
+
# Though ideally, all spans should descend from a root.
|
773
|
+
for span_data in spans_list:
|
774
|
+
if span_data['span_id'] not in visited:
|
775
|
+
# Decide how to handle orphans, maybe append them at the end sorted by time?
|
776
|
+
# For now, let's just add them to ensure they aren't lost.
|
777
|
+
sorted_condensed_list.append(span_data)
|
778
|
+
|
779
|
+
|
780
|
+
return sorted_condensed_list
|
607
781
|
|
608
782
|
def save(self, empty_save: bool = False, overwrite: bool = False) -> Tuple[str, dict]:
|
609
783
|
"""
|
@@ -689,7 +863,9 @@ class TraceClient:
|
|
689
863
|
},
|
690
864
|
"entries": condensed_entries,
|
691
865
|
"empty_save": empty_save,
|
692
|
-
"overwrite": overwrite
|
866
|
+
"overwrite": overwrite,
|
867
|
+
"parent_trace_id": self.parent_trace_id,
|
868
|
+
"parent_name": self.parent_name
|
693
869
|
}
|
694
870
|
# Execute asynchrous evaluation in the background
|
695
871
|
if not empty_save: # Only send to RabbitMQ if the trace is not empty
|
@@ -745,7 +921,6 @@ class Tracer:
|
|
745
921
|
self.project_name: str = project_name
|
746
922
|
self.client: JudgmentClient = JudgmentClient(judgment_api_key=api_key)
|
747
923
|
self.organization_id: str = organization_id
|
748
|
-
self.depth: int = 0
|
749
924
|
self._current_trace: Optional[str] = None
|
750
925
|
self.rules: List[Rule] = rules or [] # Store rules at tracer level
|
751
926
|
self.initialized: bool = True
|
@@ -770,6 +945,15 @@ class Tracer:
|
|
770
945
|
"""Start a new trace context using a context manager"""
|
771
946
|
trace_id = str(uuid.uuid4())
|
772
947
|
project = project_name if project_name is not None else self.project_name
|
948
|
+
|
949
|
+
# Get parent trace info from context
|
950
|
+
parent_trace = current_trace_var.get()
|
951
|
+
parent_trace_id = None
|
952
|
+
parent_name = None
|
953
|
+
|
954
|
+
if parent_trace:
|
955
|
+
parent_trace_id = parent_trace.trace_id
|
956
|
+
parent_name = parent_trace.name
|
773
957
|
|
774
958
|
trace = TraceClient(
|
775
959
|
self,
|
@@ -779,10 +963,13 @@ class Tracer:
|
|
779
963
|
overwrite=overwrite,
|
780
964
|
rules=self.rules, # Pass combined rules to the trace client
|
781
965
|
enable_monitoring=self.enable_monitoring,
|
782
|
-
enable_evaluations=self.enable_evaluations
|
966
|
+
enable_evaluations=self.enable_evaluations,
|
967
|
+
parent_trace_id=parent_trace_id,
|
968
|
+
parent_name=parent_name
|
783
969
|
)
|
784
|
-
|
785
|
-
|
970
|
+
|
971
|
+
# Set the current trace in context variables
|
972
|
+
token = current_trace_var.set(trace)
|
786
973
|
|
787
974
|
# Automatically create top-level span
|
788
975
|
with trace.span(name or "unnamed_trace") as span:
|
@@ -791,13 +978,14 @@ class Tracer:
|
|
791
978
|
trace.save(empty_save=True, overwrite=overwrite)
|
792
979
|
yield trace
|
793
980
|
finally:
|
794
|
-
|
981
|
+
# Reset the context variable
|
982
|
+
current_trace_var.reset(token)
|
795
983
|
|
796
984
|
def get_current_trace(self) -> Optional[TraceClient]:
|
797
985
|
"""
|
798
|
-
Get the current trace context
|
986
|
+
Get the current trace context from contextvars
|
799
987
|
"""
|
800
|
-
return
|
988
|
+
return current_trace_var.get()
|
801
989
|
|
802
990
|
def observe(self, func=None, *, name=None, span_type: SpanType = "span", project_name: str = None, overwrite: bool = False):
|
803
991
|
"""
|
@@ -823,20 +1011,56 @@ class Tracer:
|
|
823
1011
|
if asyncio.iscoroutinefunction(func):
|
824
1012
|
@functools.wraps(func)
|
825
1013
|
async def async_wrapper(*args, **kwargs):
|
826
|
-
#
|
827
|
-
|
828
|
-
|
829
|
-
|
1014
|
+
# Get current trace from context
|
1015
|
+
current_trace = current_trace_var.get()
|
1016
|
+
|
1017
|
+
# If there's no current trace, create a root trace
|
1018
|
+
if not current_trace:
|
830
1019
|
trace_id = str(uuid.uuid4())
|
831
|
-
trace_name = func.__name__
|
832
1020
|
project = project_name if project_name is not None else self.project_name
|
833
|
-
|
834
|
-
|
835
|
-
|
836
|
-
|
837
|
-
|
838
|
-
|
839
|
-
|
1021
|
+
|
1022
|
+
# Create a new trace client to serve as the root
|
1023
|
+
current_trace = TraceClient(
|
1024
|
+
self,
|
1025
|
+
trace_id,
|
1026
|
+
span_name, # MODIFIED: Use span_name directly
|
1027
|
+
project_name=project,
|
1028
|
+
overwrite=overwrite,
|
1029
|
+
rules=self.rules,
|
1030
|
+
enable_monitoring=self.enable_monitoring,
|
1031
|
+
enable_evaluations=self.enable_evaluations
|
1032
|
+
)
|
1033
|
+
|
1034
|
+
# Save empty trace and set trace context
|
1035
|
+
current_trace.save(empty_save=True, overwrite=overwrite)
|
1036
|
+
trace_token = current_trace_var.set(current_trace)
|
1037
|
+
|
1038
|
+
try:
|
1039
|
+
# Use span for the function execution within the root trace
|
1040
|
+
# This sets the current_span_var
|
1041
|
+
with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
|
1042
|
+
# Record inputs
|
1043
|
+
span.record_input({
|
1044
|
+
'args': str(args),
|
1045
|
+
'kwargs': kwargs
|
1046
|
+
})
|
1047
|
+
|
1048
|
+
# Execute function
|
1049
|
+
result = await func(*args, **kwargs)
|
1050
|
+
|
1051
|
+
# Record output
|
1052
|
+
span.record_output(result)
|
1053
|
+
|
1054
|
+
# Save the completed trace
|
1055
|
+
current_trace.save(empty_save=False, overwrite=overwrite)
|
1056
|
+
return result
|
1057
|
+
finally:
|
1058
|
+
# Reset trace context (span context resets automatically)
|
1059
|
+
current_trace_var.reset(trace_token)
|
1060
|
+
else:
|
1061
|
+
# Already have a trace context, just create a span in it
|
1062
|
+
# The span method handles current_span_var
|
1063
|
+
with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
|
840
1064
|
# Record inputs
|
841
1065
|
span.record_input({
|
842
1066
|
'args': str(args),
|
@@ -850,30 +1074,62 @@ class Tracer:
|
|
850
1074
|
span.record_output(result)
|
851
1075
|
|
852
1076
|
return result
|
853
|
-
finally:
|
854
|
-
# Only save and cleanup if this is the root observe call
|
855
|
-
if self.depth == 0:
|
856
|
-
trace.save(empty_save=False, overwrite=overwrite)
|
857
|
-
self._current_trace = None
|
858
1077
|
|
859
1078
|
return async_wrapper
|
860
1079
|
else:
|
1080
|
+
# Non-async function implementation remains unchanged
|
861
1081
|
@functools.wraps(func)
|
862
1082
|
def wrapper(*args, **kwargs):
|
863
|
-
#
|
864
|
-
|
865
|
-
|
866
|
-
|
1083
|
+
# Get current trace from context
|
1084
|
+
current_trace = current_trace_var.get()
|
1085
|
+
|
1086
|
+
# If there's no current trace, create a root trace
|
1087
|
+
if not current_trace:
|
867
1088
|
trace_id = str(uuid.uuid4())
|
868
|
-
trace_name = func.__name__
|
869
1089
|
project = project_name if project_name is not None else self.project_name
|
870
|
-
|
871
|
-
|
872
|
-
|
873
|
-
|
874
|
-
|
875
|
-
|
876
|
-
|
1090
|
+
|
1091
|
+
# Create a new trace client to serve as the root
|
1092
|
+
current_trace = TraceClient(
|
1093
|
+
self,
|
1094
|
+
trace_id,
|
1095
|
+
span_name, # MODIFIED: Use span_name directly
|
1096
|
+
project_name=project,
|
1097
|
+
overwrite=overwrite,
|
1098
|
+
rules=self.rules,
|
1099
|
+
enable_monitoring=self.enable_monitoring,
|
1100
|
+
enable_evaluations=self.enable_evaluations
|
1101
|
+
)
|
1102
|
+
|
1103
|
+
# Save empty trace and set trace context
|
1104
|
+
current_trace.save(empty_save=True, overwrite=overwrite)
|
1105
|
+
trace_token = current_trace_var.set(current_trace)
|
1106
|
+
|
1107
|
+
try:
|
1108
|
+
# Use span for the function execution within the root trace
|
1109
|
+
# This sets the current_span_var
|
1110
|
+
with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
|
1111
|
+
# Record inputs
|
1112
|
+
span.record_input({
|
1113
|
+
'args': str(args),
|
1114
|
+
'kwargs': kwargs
|
1115
|
+
})
|
1116
|
+
|
1117
|
+
# Execute function
|
1118
|
+
result = func(*args, **kwargs)
|
1119
|
+
|
1120
|
+
# Record output
|
1121
|
+
span.record_output(result)
|
1122
|
+
|
1123
|
+
# Save the completed trace
|
1124
|
+
current_trace.save(empty_save=False, overwrite=overwrite)
|
1125
|
+
return result
|
1126
|
+
finally:
|
1127
|
+
# Reset trace context (span context resets automatically)
|
1128
|
+
current_trace_var.reset(trace_token)
|
1129
|
+
else:
|
1130
|
+
# Already have a trace context, just create a span in it
|
1131
|
+
# The span method handles current_span_var
|
1132
|
+
with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
|
877
1133
|
# Record inputs
|
878
1134
|
span.record_input({
|
879
1135
|
'args': str(args),
|
@@ -887,11 +1143,6 @@ class Tracer:
|
|
887
1143
|
span.record_output(result)
|
888
1144
|
|
889
1145
|
return result
|
890
|
-
finally:
|
891
|
-
# Only save and cleanup if this is the root observe call
|
892
|
-
if self.depth == 0:
|
893
|
-
trace.save(empty_save=False, overwrite=overwrite)
|
894
|
-
self._current_trace = None
|
895
1146
|
|
896
1147
|
return wrapper
|
897
1148
|
|
@@ -900,27 +1151,36 @@ class Tracer:
|
|
900
1151
|
Decorator to trace function execution with detailed entry/exit information.
|
901
1152
|
"""
|
902
1153
|
if func is None:
|
903
|
-
return lambda f: self.
|
1154
|
+
return lambda f: self.score(f, scorers=scorers, model=model, log_results=log_results, name=name, span_type=span_type)
|
904
1155
|
|
905
1156
|
if asyncio.iscoroutinefunction(func):
|
906
1157
|
@functools.wraps(func)
|
907
1158
|
async def async_wrapper(*args, **kwargs):
|
908
|
-
|
909
|
-
|
1159
|
+
# Get current trace from contextvars
|
1160
|
+
current_trace = current_trace_var.get()
|
1161
|
+
if current_trace and scorers:
|
1162
|
+
current_trace.async_evaluate(scorers=scorers, input=args, actual_output=kwargs, model=model, log_results=log_results)
|
1163
|
+
return await func(*args, **kwargs)
|
910
1164
|
return async_wrapper
|
911
1165
|
else:
|
912
1166
|
@functools.wraps(func)
|
913
1167
|
def wrapper(*args, **kwargs):
|
914
|
-
|
915
|
-
|
1168
|
+
# Get current trace from contextvars
|
1169
|
+
current_trace = current_trace_var.get()
|
1170
|
+
if current_trace and scorers:
|
1171
|
+
current_trace.async_evaluate(scorers=scorers, input=args, actual_output=kwargs, model=model, log_results=log_results)
|
1172
|
+
return func(*args, **kwargs)
|
916
1173
|
return wrapper
|
917
1174
|
|
918
1175
|
def async_evaluate(self, *args, **kwargs):
    """
    Forward an evaluation request to the trace that is active in the current context.

    Does nothing when `self.enable_evaluations` is falsy. Otherwise the active
    trace is read from `current_trace_var`; if one is set, all positional and
    keyword arguments are passed unchanged to that trace's `async_evaluate`.
    If no trace is active, a warning is emitted and the evaluation is skipped.
    """
    # Evaluations are switched off on this tracer: nothing to do.
    if not self.enable_evaluations:
        return

    # Look up the trace bound to the current execution context.
    active_trace = current_trace_var.get()
    if not active_trace:
        warnings.warn("No trace found, skipping evaluation")
        return

    active_trace.async_evaluate(*args, **kwargs)
|
926
1186
|
|
@@ -934,14 +1194,14 @@ def wrap(client: Any) -> Any:
|
|
934
1194
|
span_name, original_create = _get_client_config(client)
|
935
1195
|
|
936
1196
|
def traced_create(*args, **kwargs):
|
937
|
-
# Get the current
|
938
|
-
|
1197
|
+
# Get the current trace from contextvars
|
1198
|
+
current_trace = current_trace_var.get()
|
939
1199
|
|
940
|
-
# Skip tracing if no
|
941
|
-
if not
|
1200
|
+
# Skip tracing if no active trace
|
1201
|
+
if not current_trace:
|
942
1202
|
return original_create(*args, **kwargs)
|
943
1203
|
|
944
|
-
with
|
1204
|
+
with current_trace.span(span_name, span_type="llm") as span:
|
945
1205
|
# Format and record the input parameters
|
946
1206
|
input_data = _format_input_data(client, **kwargs)
|
947
1207
|
span.record_input(input_data)
|
@@ -1033,4 +1293,59 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
|
|
1033
1293
|
"output_tokens": response.usage.output_tokens,
|
1034
1294
|
"total_tokens": response.usage.input_tokens + response.usage.output_tokens
|
1035
1295
|
}
|
1036
|
-
}
|
1296
|
+
}
|
1297
|
+
|
1298
|
+
# NOTE: a global context-preserving gather wrapper (`trace_gather`) was removed; the disabled code is kept below for reference only.
|
1299
|
+
# async def trace_gather(*coroutines, return_exceptions=False): # REMOVED
|
1300
|
+
# """ # REMOVED
|
1301
|
+
# A wrapper around asyncio.gather that ensures the trace context # REMOVED
|
1302
|
+
# is available within the gathered coroutines using contextvars.copy_context. # REMOVED
|
1303
|
+
# """ # REMOVED
|
1304
|
+
# # Get the original asyncio.gather (if we patched it) # REMOVED
|
1305
|
+
# original_gather = getattr(asyncio, "_original_gather", asyncio.gather) # REMOVED
|
1306
|
+
# # REMOVED
|
1307
|
+
# # Use contextvars.copy_context() to ensure context propagation # REMOVED
|
1308
|
+
# ctx = contextvars.copy_context() # REMOVED
|
1309
|
+
# # REMOVED
|
1310
|
+
# # Wrap the gather call within the copied context # REMOVED
|
1311
|
+
# return await ctx.run(original_gather, *coroutines, return_exceptions=return_exceptions) # REMOVED
|
1312
|
+
|
1313
|
+
# NOTE: the one-time monkey-patch of asyncio.gather was removed as well; the disabled code is kept below for reference only.
|
1314
|
+
# global _original_gather_stored # REMOVED
|
1315
|
+
# if not globals().get('_original_gather_stored'): # REMOVED
|
1316
|
+
# # Check if asyncio.gather is already our wrapper to prevent double patching # REMOVED
|
1317
|
+
# if asyncio.gather.__name__ != 'trace_gather': # REMOVED
|
1318
|
+
# asyncio._original_gather = asyncio.gather # REMOVED
|
1319
|
+
# asyncio.gather = trace_gather # REMOVED
|
1320
|
+
# _original_gather_stored = True # REMOVED
|
1321
|
+
|
1322
|
+
# Add the new TraceThreadPoolExecutor class
|
1323
|
+
class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
    """
    A ThreadPoolExecutor subclass that automatically propagates contextvars
    from the submitting thread to the worker thread using copy_context().run().

    This ensures that context variables like `current_trace_var` and
    `current_span_var` are available within functions executed by the pool,
    allowing the Tracer to maintain correct parent-child relationships across
    thread boundaries.

    Every submitted task runs inside its own snapshot of the submitter's
    context, taken at submit time. Values a task sets on a context variable
    therefore stay local to that task and do not leak back to the submitting
    thread or into other tasks.
    """

    def submit(self, fn, /, *args, **kwargs):
        """
        Submit a callable to be executed with the captured context.

        Args:
            fn: The callable to run in a worker thread.
            *args: Positional arguments forwarded to `fn`.
            **kwargs: Keyword arguments forwarded to `fn`.

        Returns:
            The future produced by `ThreadPoolExecutor.submit`; its result is
            whatever `fn(*args, **kwargs)` returns.
        """
        # Snapshot the context of the submitting thread *now*, so the worker
        # sees the trace/span that were current at submit time.
        ctx = contextvars.copy_context()

        # Context.run(callable, *args, **kwargs) forwards the arguments to the
        # callable itself, so no functools.partial indirection is needed.
        return super().submit(ctx.run, fn, *args, **kwargs)

    # Note: `map` does not need its own override. The inherited implementation
    # (CPython) dispatches every call through `self.submit`, so mapped calls
    # get the same context propagation as directly submitted ones.
|