judgeval 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/common/tracer.py +528 -166
- judgeval/constants.py +7 -4
- judgeval/data/__init__.py +0 -3
- judgeval/data/datasets/dataset.py +42 -19
- judgeval/data/datasets/eval_dataset_client.py +59 -20
- judgeval/data/result.py +34 -56
- judgeval/integrations/langgraph.py +16 -12
- judgeval/judgment_client.py +85 -23
- judgeval/rules.py +177 -60
- judgeval/run_evaluation.py +143 -122
- judgeval/scorers/score.py +21 -18
- judgeval/utils/alerts.py +32 -1
- {judgeval-0.0.25.dist-info → judgeval-0.0.27.dist-info}/METADATA +1 -1
- {judgeval-0.0.25.dist-info → judgeval-0.0.27.dist-info}/RECORD +16 -17
- judgeval/data/api_example.py +0 -98
- {judgeval-0.0.25.dist-info → judgeval-0.0.27.dist-info}/WHEEL +0 -0
- {judgeval-0.0.25.dist-info → judgeval-0.0.27.dist-info}/licenses/LICENSE.md +0 -0
judgeval/common/tracer.py
CHANGED
@@ -10,16 +10,18 @@ import os
 import time
 import uuid
 import warnings
+import contextvars
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
 from http import HTTPStatus
-from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, TypeAlias, Union
+from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, TypeAlias, Union, Callable, Awaitable
 from rich import print as rprint

 # Third-party imports
 import pika
 import requests
+from litellm import cost_per_token
 from pydantic import BaseModel
 from rich import print as rprint
 from openai import OpenAI
@@ -44,6 +46,12 @@ from judgeval.rules import Rule
 from judgeval.evaluation_run import EvaluationRun
 from judgeval.data.result import ScoringResult

+# Standard library imports needed for the new class
+import concurrent.futures
+
+# Define context variables for tracking the current trace and the current span within a trace
+current_trace_var = contextvars.ContextVar('current_trace', default=None)
+current_span_var = contextvars.ContextVar('current_span', default=None)  # NEW: ContextVar for the active span name

 # Define type aliases for better code readability and maintainability
 ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic]  # Supported API clients
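The two context variables above replace the tracer-wide `depth` counter and `_current_span` attribute that later hunks delete: every `span()` invocation sets the variable and restores the previous value on exit, so concurrent tasks and threads each see their own active trace and span. A minimal, stdlib-only sketch of that set/reset pattern (the identifiers below are illustrative, not judgeval APIs):

    import contextvars

    current_span_var = contextvars.ContextVar('current_span', default=None)

    def open_span(span_id):
        parent = current_span_var.get()        # whatever span was active before
        token = current_span_var.set(span_id)  # this span becomes the active one
        return parent, token

    parent, outer_token = open_span("outer-span")
    print(parent)                              # None: "outer-span" is a root span
    parent, inner_token = open_span("inner-span")
    print(parent)                              # "outer-span": recorded as parent_span_id
    current_span_var.reset(inner_token)        # leaving the inner span restores the outer one
    print(current_span_var.get())              # "outer-span"
    current_span_var.reset(outer_token)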
@@ -62,6 +70,7 @@ class TraceEntry:
     """
     type: TraceEntryType
     function: str  # Name of the function being traced
+    span_id: str  # Unique ID for this specific span instance
     depth: int  # Indentation level for nested calls
     message: str  # Human-readable description
     timestamp: float  # Unix timestamp when entry was created
@@ -71,20 +80,28 @@ class TraceEntry:
     inputs: dict = field(default_factory=dict)
     span_type: SpanType = "span"
     evaluation_runs: List[Optional[EvaluationRun]] = field(default=None)
+    parent_span_id: Optional[str] = None  # ID of the parent span instance

     def print_entry(self):
+        """Print a trace entry with proper formatting and parent relationship information."""
         indent = " " * self.depth
+
         if self.type == "enter":
-
+            # Format parent info if present
+            parent_info = f" (parent_id: {self.parent_span_id})" if self.parent_span_id else ""
+            print(f"{indent}→ {self.function} (id: {self.span_id}){parent_info} (trace: {self.message})")
         elif self.type == "exit":
-            print(f"{indent}← {self.function} ({self.duration:.3f}s)")
+            print(f"{indent}← {self.function} (id: {self.span_id}) ({self.duration:.3f}s)")
         elif self.type == "output":
-
+            # Format output to align properly
+            output_str = str(self.output)
+            print(f"{indent}Output (for id: {self.span_id}): {output_str}")
         elif self.type == "input":
-
+            # Format inputs to align properly
+            print(f"{indent}Input (for id: {self.span_id}): {self.inputs}")
         elif self.type == "evaluation":
             for evaluation_run in self.evaluation_runs:
-                print(f"{indent}Evaluation: {evaluation_run.model_dump()}")
+                print(f"{indent}Evaluation (for id: {self.span_id}): {evaluation_run.model_dump()}")

     def _serialize_inputs(self) -> dict:
         """Helper method to serialize input data safely.
@@ -143,6 +160,7 @@ class TraceEntry:
         return {
             "type": self.type,
             "function": self.function,
+            "span_id": self.span_id,
            "depth": self.depth,
             "message": self.message,
             "timestamp": self.timestamp,
@@ -150,7 +168,8 @@ class TraceEntry:
             "output": self._serialize_output(),
             "inputs": self._serialize_inputs(),
             "evaluation_runs": [evaluation_run.model_dump() for evaluation_run in self.evaluation_runs] if self.evaluation_runs else [],
-            "span_type": self.span_type
+            "span_type": self.span_type,
+            "parent_span_id": self.parent_span_id
         }

     def _serialize_output(self) -> Any:
@@ -314,63 +333,79 @@ class TraceClient:
         overwrite: bool = False,
         rules: Optional[List[Rule]] = None,
         enable_monitoring: bool = True,
-        enable_evaluations: bool = True
+        enable_evaluations: bool = True,
+        parent_trace_id: Optional[str] = None,
+        parent_name: Optional[str] = None
     ):
         self.name = name
         self.trace_id = trace_id or str(uuid.uuid4())
         self.project_name = project_name
         self.overwrite = overwrite
         self.tracer = tracer
-        # Initialize rules with either provided rules or an empty list
         self.rules = rules or []
         self.enable_monitoring = enable_monitoring
         self.enable_evaluations = enable_evaluations
-
+        self.parent_trace_id = parent_trace_id
+        self.parent_name = parent_name
         self.client: JudgmentClient = tracer.client
         self.entries: List[TraceEntry] = []
         self.start_time = time.time()
-        self.
-        self.
-        self.
+        self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id)
+        self.visited_nodes = []
+        self.executed_tools = []
+        self.executed_node_tools = []
+        self._span_depths: Dict[str, int] = {}  # NEW: To track depth of active spans

     @contextmanager
     def span(self, name: str, span_type: SpanType = "span"):
-        """Context manager for creating a trace span"""
+        """Context manager for creating a trace span, managing the current span via contextvars"""
         start_time = time.time()

-        #
-
+        # Generate a unique ID for *this specific span invocation*
+        span_id = str(uuid.uuid4())
+
+        parent_span_id = current_span_var.get()  # Get ID of the parent span from context var
+        token = current_span_var.set(span_id)  # Set *this* span's ID as the current one
+
+        current_depth = 0
+        if parent_span_id and parent_span_id in self._span_depths:
+            current_depth = self._span_depths[parent_span_id] + 1
+
+        self._span_depths[span_id] = current_depth  # Store depth by span_id
+
+        entry = TraceEntry(
             type="enter",
             function=name,
-
+            span_id=span_id,  # Use the generated span_id
+            depth=current_depth,
             message=name,
             timestamp=start_time,
-            span_type=span_type
-
-
-
-        self.tracer.depth += 1
-        prev_span = self._current_span
-        self._current_span = name
+            span_type=span_type,
+            parent_span_id=parent_span_id  # Use the parent_id from context var
+        )
+        self.add_entry(entry)

         try:
             yield self
         finally:
-            self.tracer.depth -= 1
             duration = time.time() - start_time
-
-            # Record span exit
+            exit_depth = self._span_depths.get(span_id, 0)  # Get depth using this span's ID
             self.add_entry(TraceEntry(
                 type="exit",
                 function=name,
-
+                span_id=span_id,  # Use the same span_id for exit
+                depth=exit_depth,
                 message=f"← {name}",
                 timestamp=time.time(),
                 duration=duration,
                 span_type=span_type
             ))
-
-
+            # Clean up depth tracking for this span_id
+            if span_id in self._span_depths:
+                del self._span_depths[span_id]
+            # Reset context var
+            current_span_var.reset(token)

     def async_evaluate(
         self,
         scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
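With IDs and depths now tracked per span invocation, nesting `span()` calls is what creates the parent/child links and indentation levels. A rough usage sketch, assuming `trace` is an active `TraceClient` (for example the one yielded by `Tracer.trace()`); only `span()`, `record_input()` and `record_output()` are taken from this diff, the rest is placeholder data:

    with trace.span("retrieve_documents", span_type="tool") as s:
        s.record_input({"query": "pricing policy"})
        docs = ["doc-1", "doc-2"]                     # placeholder work
        s.record_output(docs)

        # Entered while the outer span's ID is stored in current_span_var, so this
        # entry records that ID as parent_span_id and gets depth = outer depth + 1.
        with trace.span("rerank", span_type="tool") as inner:
            inner.record_input({"candidates": docs})
            inner.record_output(docs[::-1])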
@@ -453,7 +488,7 @@ class TraceClient:
|
|
453
488
|
log_results=log_results,
|
454
489
|
project_name=self.project_name,
|
455
490
|
eval_name=f"{self.name.capitalize()}-"
|
456
|
-
f"{
|
491
|
+
f"{current_span_var.get()}-"
|
457
492
|
f"[{','.join(scorer.score_type.capitalize() for scorer in loaded_scorers)}]",
|
458
493
|
examples=[example],
|
459
494
|
scorers=loaded_scorers,
|
@@ -467,24 +502,30 @@ class TraceClient:
         self.add_eval_run(eval_run, start_time)  # Pass start_time to record_evaluation

     def add_eval_run(self, eval_run: EvaluationRun, start_time: float):
-
-
-
-
-
-
-
-
-
-
-
+        current_span_id = current_span_var.get()
+        if current_span_id:
+            duration = time.time() - start_time
+            prev_entry = self.entries[-1] if self.entries else None
+            # Determine function name based on previous entry or context var (less ideal)
+            function_name = "unknown_function"  # Default
+            if prev_entry and prev_entry.span_type == "llm":
+                function_name = prev_entry.function
+            else:
+                # Try to find the function name associated with the current span_id
+                for entry in reversed(self.entries):
+                    if entry.span_id == current_span_id and entry.type == 'enter':
+                        function_name = entry.function
+                        break

-            #
+            # Get depth for the current span
+            current_depth = self._span_depths.get(current_span_id, 0)
+
             self.add_entry(TraceEntry(
                 type="evaluation",
-                function=
-
-
+                function=function_name,
+                span_id=current_span_id,  # Associate with current span
+                depth=current_depth,
+                message=f"Evaluation results for {function_name}",
                 timestamp=time.time(),
                 evaluation_runs=[eval_run],
                 duration=duration,
|
|
492
533
|
))
|
493
534
|
|
494
535
|
def record_input(self, inputs: dict):
|
495
|
-
|
496
|
-
if
|
536
|
+
current_span_id = current_span_var.get()
|
537
|
+
if current_span_id:
|
538
|
+
entry_span_type = "span"
|
539
|
+
current_depth = self._span_depths.get(current_span_id, 0)
|
540
|
+
function_name = "unknown_function" # Default
|
541
|
+
for entry in reversed(self.entries):
|
542
|
+
if entry.span_id == current_span_id and entry.type == 'enter':
|
543
|
+
entry_span_type = entry.span_type
|
544
|
+
function_name = entry.function
|
545
|
+
break
|
546
|
+
|
497
547
|
self.add_entry(TraceEntry(
|
498
548
|
type="input",
|
499
|
-
function=
|
500
|
-
|
501
|
-
|
549
|
+
function=function_name,
|
550
|
+
span_id=current_span_id, # Use current span_id
|
551
|
+
depth=current_depth,
|
552
|
+
message=f"Inputs to {function_name}",
|
502
553
|
timestamp=time.time(),
|
503
554
|
inputs=inputs,
|
504
|
-
span_type=
|
555
|
+
span_type=entry_span_type
|
505
556
|
))
|
506
557
|
|
507
558
|
async def _update_coroutine_output(self, entry: TraceEntry, coroutine: Any):
|
@@ -515,21 +566,30 @@ class TraceClient:
|
|
515
566
|
raise
|
516
567
|
|
517
568
|
def record_output(self, output: Any):
|
518
|
-
|
519
|
-
if
|
569
|
+
current_span_id = current_span_var.get()
|
570
|
+
if current_span_id:
|
571
|
+
entry_span_type = "span"
|
572
|
+
current_depth = self._span_depths.get(current_span_id, 0)
|
573
|
+
function_name = "unknown_function" # Default
|
574
|
+
for entry in reversed(self.entries):
|
575
|
+
if entry.span_id == current_span_id and entry.type == 'enter':
|
576
|
+
entry_span_type = entry.span_type
|
577
|
+
function_name = entry.function
|
578
|
+
break
|
579
|
+
|
520
580
|
entry = TraceEntry(
|
521
581
|
type="output",
|
522
|
-
function=
|
523
|
-
|
524
|
-
|
582
|
+
function=function_name,
|
583
|
+
span_id=current_span_id, # Use current span_id
|
584
|
+
depth=current_depth,
|
585
|
+
message=f"Output from {function_name}",
|
525
586
|
timestamp=time.time(),
|
526
587
|
output="<pending>" if inspect.iscoroutine(output) else output,
|
527
|
-
span_type=
|
588
|
+
span_type=entry_span_type
|
528
589
|
)
|
529
590
|
self.add_entry(entry)
|
530
591
|
|
531
592
|
if inspect.iscoroutine(output):
|
532
|
-
# Create a task to update the output once the coroutine completes
|
533
593
|
asyncio.create_task(self._update_coroutine_output(entry, output))
|
534
594
|
|
535
595
|
def add_entry(self, entry: TraceEntry):
|
@@ -542,6 +602,58 @@ class TraceClient:
|
|
542
602
|
for entry in self.entries:
|
543
603
|
entry.print_entry()
|
544
604
|
|
605
|
+
def print_hierarchical(self):
|
606
|
+
"""Print the trace in a hierarchical structure based on parent-child relationships"""
|
607
|
+
# First, build a map of spans
|
608
|
+
spans = {}
|
609
|
+
root_spans = []
|
610
|
+
|
611
|
+
# Collect all enter events first
|
612
|
+
for entry in self.entries:
|
613
|
+
if entry.type == "enter":
|
614
|
+
spans[entry.function] = {
|
615
|
+
"name": entry.function,
|
616
|
+
"depth": entry.depth,
|
617
|
+
"parent_id": entry.parent_span_id,
|
618
|
+
"children": []
|
619
|
+
}
|
620
|
+
|
621
|
+
# If no parent, it's a root span
|
622
|
+
if not entry.parent_span_id:
|
623
|
+
root_spans.append(entry.function)
|
624
|
+
elif entry.parent_span_id not in spans:
|
625
|
+
# If parent doesn't exist yet, temporarily treat as root
|
626
|
+
# (we'll fix this later)
|
627
|
+
root_spans.append(entry.function)
|
628
|
+
|
629
|
+
# Build parent-child relationships
|
630
|
+
for span_name, span in spans.items():
|
631
|
+
parent = span["parent_id"]
|
632
|
+
if parent and parent in spans:
|
633
|
+
spans[parent]["children"].append(span_name)
|
634
|
+
# Remove from root spans if it was temporarily there
|
635
|
+
if span_name in root_spans:
|
636
|
+
root_spans.remove(span_name)
|
637
|
+
|
638
|
+
# Now print the hierarchy
|
639
|
+
def print_span(span_name, level=0):
|
640
|
+
if span_name not in spans:
|
641
|
+
return
|
642
|
+
|
643
|
+
span = spans[span_name]
|
644
|
+
indent = " " * level
|
645
|
+
parent_info = f" (parent_id: {span['parent_id']})" if span["parent_id"] else ""
|
646
|
+
print(f"{indent}→ {span_name}{parent_info}")
|
647
|
+
|
648
|
+
# Print children
|
649
|
+
for child in span["children"]:
|
650
|
+
print_span(child, level + 1)
|
651
|
+
|
652
|
+
# Print starting with root spans
|
653
|
+
print("\nHierarchical Trace Structure:")
|
654
|
+
for root in root_spans:
|
655
|
+
print_span(root)
|
656
|
+
|
545
657
|
def get_duration(self) -> float:
|
546
658
|
"""
|
547
659
|
Get the total duration of this trace
|
@@ -550,56 +662,122 @@ class TraceClient:
|
|
550
662
|
|
551
663
|
def condense_trace(self, entries: List[dict]) -> List[dict]:
|
552
664
|
"""
|
553
|
-
Condenses trace entries into a single entry for each
|
665
|
+
Condenses trace entries into a single entry for each span instance,
|
666
|
+
preserving parent-child span relationships using span_id and parent_span_id.
|
554
667
|
"""
|
555
|
-
|
556
|
-
|
557
|
-
|
668
|
+
spans_by_id: Dict[str, dict] = {}
|
669
|
+
|
670
|
+
# First pass: Group entries by span_id and gather data
|
671
|
+
for entry in entries:
|
672
|
+
span_id = entry.get("span_id")
|
673
|
+
if not span_id:
|
674
|
+
continue # Skip entries without a span_id (should not happen)
|
558
675
|
|
559
|
-
for i, entry in enumerate(entries):
|
560
|
-
function = entry["function"]
|
561
|
-
|
562
676
|
if entry["type"] == "enter":
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
#
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
active_functions.remove(function)
|
581
|
-
# del function_entries[function]
|
582
|
-
|
583
|
-
# The OR condition is to handle the LLM client case.
|
584
|
-
# LLM client is a special case where we exit the span, so when we attach evaluations to it,
|
585
|
-
# we have to check if the previous entry is an LLM call.
|
586
|
-
elif function in active_functions or entry["type"] == "evaluation" and entries[i-1]["function"] == entry["function"]:
|
587
|
-
# Update existing function entry with additional data
|
588
|
-
current_entry = function_entries[function]
|
677
|
+
if span_id not in spans_by_id:
|
678
|
+
spans_by_id[span_id] = {
|
679
|
+
"span_id": span_id,
|
680
|
+
"function": entry["function"],
|
681
|
+
"depth": entry["depth"], # Use the depth recorded at entry time
|
682
|
+
"timestamp": entry["timestamp"],
|
683
|
+
"parent_span_id": entry.get("parent_span_id"),
|
684
|
+
"span_type": entry.get("span_type", "span"),
|
685
|
+
"inputs": None,
|
686
|
+
"output": None,
|
687
|
+
"evaluation_runs": [],
|
688
|
+
"duration": None
|
689
|
+
}
|
690
|
+
# Handle potential duplicate enter events if necessary (e.g., log warning)
|
691
|
+
|
692
|
+
elif span_id in spans_by_id:
|
693
|
+
current_span_data = spans_by_id[span_id]
|
589
694
|
|
590
695
|
if entry["type"] == "input" and entry["inputs"]:
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
current_entry["evaluation_runs"] = entry["evaluation_runs"]
|
696
|
+
# Merge inputs if multiple are recorded, or just assign
|
697
|
+
if current_span_data["inputs"] is None:
|
698
|
+
current_span_data["inputs"] = entry["inputs"]
|
699
|
+
elif isinstance(current_span_data["inputs"], dict) and isinstance(entry["inputs"], dict):
|
700
|
+
current_span_data["inputs"].update(entry["inputs"])
|
701
|
+
# Add more sophisticated merging if needed
|
598
702
|
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
|
703
|
+
elif entry["type"] == "output" and "output" in entry:
|
704
|
+
current_span_data["output"] = entry["output"]
|
705
|
+
|
706
|
+
elif entry["type"] == "evaluation" and entry.get("evaluation_runs"):
|
707
|
+
if current_span_data.get("evaluation_runs") is None:
|
708
|
+
current_span_data["evaluation_runs"] = []
|
709
|
+
current_span_data["evaluation_runs"].extend(entry["evaluation_runs"])
|
710
|
+
|
711
|
+
elif entry["type"] == "exit":
|
712
|
+
if current_span_data["duration"] is None: # Calculate duration only once
|
713
|
+
start_time = current_span_data.get("timestamp", entry["timestamp"])
|
714
|
+
current_span_data["duration"] = entry["timestamp"] - start_time
|
715
|
+
# Update depth if exit depth is different (though current span() implementation keeps it same)
|
716
|
+
# current_span_data["depth"] = entry["depth"]
|
717
|
+
|
718
|
+
# Convert dictionary to a list initially for easier access
|
719
|
+
spans_list = list(spans_by_id.values())
|
720
|
+
|
721
|
+
# Build tree structure (adjacency list) and find roots
|
722
|
+
children_map: Dict[Optional[str], List[dict]] = {}
|
723
|
+
roots = []
|
724
|
+
span_map = {span['span_id']: span for span in spans_list} # Map for quick lookup
|
725
|
+
|
726
|
+
for span in spans_list:
|
727
|
+
parent_id = span.get("parent_span_id")
|
728
|
+
if parent_id is None:
|
729
|
+
roots.append(span)
|
730
|
+
else:
|
731
|
+
if parent_id not in children_map:
|
732
|
+
children_map[parent_id] = []
|
733
|
+
children_map[parent_id].append(span)
|
734
|
+
|
735
|
+
# Sort roots by timestamp
|
736
|
+
roots.sort(key=lambda x: x.get("timestamp", 0))
|
737
|
+
|
738
|
+
# Perform depth-first traversal to get the final sorted list
|
739
|
+
sorted_condensed_list = []
|
740
|
+
visited = set() # To handle potential cycles, though unlikely with UUIDs
|
741
|
+
|
742
|
+
def dfs(span_data):
|
743
|
+
span_id = span_data['span_id']
|
744
|
+
if span_id in visited:
|
745
|
+
return # Avoid infinite loops in case of cycles
|
746
|
+
visited.add(span_id)
|
747
|
+
|
748
|
+
sorted_condensed_list.append(span_data) # Add parent before children
|
749
|
+
|
750
|
+
# Get children, sort them by timestamp, and visit them
|
751
|
+
span_children = children_map.get(span_id, [])
|
752
|
+
span_children.sort(key=lambda x: x.get("timestamp", 0))
|
753
|
+
for child in span_children:
|
754
|
+
# Ensure the child exists in our map before recursing
|
755
|
+
if child['span_id'] in span_map:
|
756
|
+
dfs(child)
|
757
|
+
else:
|
758
|
+
# This case might indicate an issue, but we'll add the child directly
|
759
|
+
# if its parent was processed but the child itself wasn't in the initial list?
|
760
|
+
# Or if the child's 'enter' event was missing. For robustness, add it.
|
761
|
+
if child['span_id'] not in visited:
|
762
|
+
visited.add(child['span_id'])
|
763
|
+
sorted_condensed_list.append(child)
|
764
|
+
|
765
|
+
|
766
|
+
# Start DFS from each root
|
767
|
+
for root_span in roots:
|
768
|
+
if root_span['span_id'] not in visited:
|
769
|
+
dfs(root_span)
|
770
|
+
|
771
|
+
# Handle spans that might not have been reachable from roots (orphans)
|
772
|
+
# Though ideally, all spans should descend from a root.
|
773
|
+
for span_data in spans_list:
|
774
|
+
if span_data['span_id'] not in visited:
|
775
|
+
# Decide how to handle orphans, maybe append them at the end sorted by time?
|
776
|
+
# For now, let's just add them to ensure they aren't lost.
|
777
|
+
sorted_condensed_list.append(span_data)
|
778
|
+
|
779
|
+
|
780
|
+
return sorted_condensed_list
|
603
781
|
|
604
782
|
def save(self, empty_save: bool = False, overwrite: bool = False) -> Tuple[str, dict]:
|
605
783
|
"""
|
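The second half of the new `condense_trace` is a plain pre-order tree walk: group condensed spans under their `parent_span_id`, sort siblings by timestamp, and emit each parent before its children. A self-contained illustration of that ordering on toy data (not judgeval code):

    spans = [
        {"span_id": "a", "parent_span_id": None, "timestamp": 1.0},
        {"span_id": "b", "parent_span_id": "a",  "timestamp": 2.0},
        {"span_id": "c", "parent_span_id": "a",  "timestamp": 1.5},
        {"span_id": "d", "parent_span_id": "c",  "timestamp": 1.6},
    ]

    children = {}
    for s in spans:
        children.setdefault(s["parent_span_id"], []).append(s)

    ordered = []
    def visit(span):
        ordered.append(span["span_id"])
        for child in sorted(children.get(span["span_id"], []), key=lambda x: x["timestamp"]):
            visit(child)

    for root in sorted(children.get(None, []), key=lambda x: x["timestamp"]):
        visit(root)

    print(ordered)  # ['a', 'c', 'd', 'b'] -- parents first, siblings in timestamp order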
@@ -618,34 +796,76 @@ class TraceClient:
|
|
618
796
|
total_completion_tokens = 0
|
619
797
|
total_tokens = 0
|
620
798
|
|
799
|
+
total_prompt_tokens_cost = 0.0
|
800
|
+
total_completion_tokens_cost = 0.0
|
801
|
+
total_cost = 0.0
|
802
|
+
|
621
803
|
for entry in condensed_entries:
|
622
804
|
if entry.get("span_type") == "llm" and isinstance(entry.get("output"), dict):
|
623
|
-
|
805
|
+
output = entry["output"]
|
806
|
+
usage = output.get("usage", {})
|
807
|
+
model_name = entry.get("inputs", {}).get("model", "")
|
808
|
+
prompt_tokens = 0
|
809
|
+
completion_tokens = 0
|
810
|
+
|
624
811
|
# Handle OpenAI/Together format
|
625
812
|
if "prompt_tokens" in usage:
|
626
|
-
|
627
|
-
|
813
|
+
prompt_tokens = usage.get("prompt_tokens", 0)
|
814
|
+
completion_tokens = usage.get("completion_tokens", 0)
|
815
|
+
total_prompt_tokens += prompt_tokens
|
816
|
+
total_completion_tokens += completion_tokens
|
628
817
|
# Handle Anthropic format
|
629
818
|
elif "input_tokens" in usage:
|
630
|
-
|
631
|
-
|
819
|
+
prompt_tokens = usage.get("input_tokens", 0)
|
820
|
+
completion_tokens = usage.get("output_tokens", 0)
|
821
|
+
total_prompt_tokens += prompt_tokens
|
822
|
+
total_completion_tokens += completion_tokens
|
823
|
+
|
632
824
|
total_tokens += usage.get("total_tokens", 0)
|
825
|
+
|
826
|
+
# Calculate costs if model name is available
|
827
|
+
if model_name:
|
828
|
+
try:
|
829
|
+
prompt_cost, completion_cost = cost_per_token(
|
830
|
+
model=model_name,
|
831
|
+
prompt_tokens=prompt_tokens,
|
832
|
+
completion_tokens=completion_tokens
|
833
|
+
)
|
834
|
+
total_prompt_tokens_cost += prompt_cost
|
835
|
+
total_completion_tokens_cost += completion_cost
|
836
|
+
total_cost += prompt_cost + completion_cost
|
837
|
+
|
838
|
+
# Add cost information directly to the usage dictionary in the condensed entry
|
839
|
+
if "usage" not in output:
|
840
|
+
output["usage"] = {}
|
841
|
+
output["usage"]["prompt_tokens_cost_usd"] = prompt_cost
|
842
|
+
output["usage"]["completion_tokens_cost_usd"] = completion_cost
|
843
|
+
output["usage"]["total_cost_usd"] = prompt_cost + completion_cost
|
844
|
+
except Exception as e:
|
845
|
+
# If cost calculation fails, continue without adding costs
|
846
|
+
print(f"Error calculating cost for model '{model_name}': {str(e)}")
|
847
|
+
pass
|
633
848
|
|
634
849
|
# Create trace document
|
635
850
|
trace_data = {
|
636
851
|
"trace_id": self.trace_id,
|
637
852
|
"name": self.name,
|
638
853
|
"project_name": self.project_name,
|
639
|
-
"created_at": datetime.
|
854
|
+
"created_at": datetime.utcfromtimestamp(self.start_time).isoformat(),
|
640
855
|
"duration": total_duration,
|
641
856
|
"token_counts": {
|
642
857
|
"prompt_tokens": total_prompt_tokens,
|
643
858
|
"completion_tokens": total_completion_tokens,
|
644
859
|
"total_tokens": total_tokens,
|
860
|
+
"prompt_tokens_cost_usd": total_prompt_tokens_cost,
|
861
|
+
"completion_tokens_cost_usd": total_completion_tokens_cost,
|
862
|
+
"total_cost_usd": total_cost
|
645
863
|
},
|
646
864
|
"entries": condensed_entries,
|
647
865
|
"empty_save": empty_save,
|
648
|
-
"overwrite": overwrite
|
866
|
+
"overwrite": overwrite,
|
867
|
+
"parent_trace_id": self.parent_trace_id,
|
868
|
+
"parent_name": self.parent_name
|
649
869
|
}
|
650
870
|
# Execute asynchrous evaluation in the background
|
651
871
|
if not empty_save: # Only send to RabbitMQ if the trace is not empty
|
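The cost figures come from litellm's `cost_per_token`, called exactly as in the hunk above: it returns a `(prompt_cost, completion_cost)` pair in USD for a given model name and token counts. A small standalone sketch (the model name and token counts are invented):

    from litellm import cost_per_token

    prompt_tokens, completion_tokens = 1200, 350

    prompt_cost, completion_cost = cost_per_token(
        model="gpt-4o-mini",
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )

    usage = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "prompt_tokens_cost_usd": prompt_cost,
        "completion_tokens_cost_usd": completion_cost,
        "total_cost_usd": prompt_cost + completion_cost,
    }
    print(usage)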
@@ -697,12 +917,10 @@ class Tracer:

         if not organization_id:
             raise ValueError("Tracer must be configured with an Organization ID")
-
         self.api_key: str = api_key
         self.project_name: str = project_name
         self.client: JudgmentClient = JudgmentClient(judgment_api_key=api_key)
         self.organization_id: str = organization_id
-        self.depth: int = 0
         self._current_trace: Optional[str] = None
         self.rules: List[Rule] = rules or []  # Store rules at tracer level
         self.initialized: bool = True
@@ -727,6 +945,15 @@ class Tracer:
|
|
727
945
|
"""Start a new trace context using a context manager"""
|
728
946
|
trace_id = str(uuid.uuid4())
|
729
947
|
project = project_name if project_name is not None else self.project_name
|
948
|
+
|
949
|
+
# Get parent trace info from context
|
950
|
+
parent_trace = current_trace_var.get()
|
951
|
+
parent_trace_id = None
|
952
|
+
parent_name = None
|
953
|
+
|
954
|
+
if parent_trace:
|
955
|
+
parent_trace_id = parent_trace.trace_id
|
956
|
+
parent_name = parent_trace.name
|
730
957
|
|
731
958
|
trace = TraceClient(
|
732
959
|
self,
|
@@ -736,10 +963,13 @@ class Tracer:
|
|
736
963
|
overwrite=overwrite,
|
737
964
|
rules=self.rules, # Pass combined rules to the trace client
|
738
965
|
enable_monitoring=self.enable_monitoring,
|
739
|
-
enable_evaluations=self.enable_evaluations
|
966
|
+
enable_evaluations=self.enable_evaluations,
|
967
|
+
parent_trace_id=parent_trace_id,
|
968
|
+
parent_name=parent_name
|
740
969
|
)
|
741
|
-
|
742
|
-
|
970
|
+
|
971
|
+
# Set the current trace in context variables
|
972
|
+
token = current_trace_var.set(trace)
|
743
973
|
|
744
974
|
# Automatically create top-level span
|
745
975
|
with trace.span(name or "unnamed_trace") as span:
|
@@ -748,13 +978,14 @@ class Tracer:
|
|
748
978
|
trace.save(empty_save=True, overwrite=overwrite)
|
749
979
|
yield trace
|
750
980
|
finally:
|
751
|
-
|
981
|
+
# Reset the context variable
|
982
|
+
current_trace_var.reset(token)
|
752
983
|
|
753
984
|
def get_current_trace(self) -> Optional[TraceClient]:
|
754
985
|
"""
|
755
|
-
Get the current trace context
|
986
|
+
Get the current trace context from contextvars
|
756
987
|
"""
|
757
|
-
return
|
988
|
+
return current_trace_var.get()
|
758
989
|
|
759
990
|
def observe(self, func=None, *, name=None, span_type: SpanType = "span", project_name: str = None, overwrite: bool = False):
|
760
991
|
"""
|
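Because `trace()` now reads the parent trace from `current_trace_var` before constructing the child `TraceClient`, opening a trace inside another one links the two through the new `parent_trace_id`/`parent_name` fields in the saved payload. A rough sketch, assuming a configured `Tracer` instance named `judgment`; the argument names follow the hunks above, and the constructor arguments for `Tracer` itself are not shown in this diff and are omitted:

    with judgment.trace("checkout_flow") as parent:
        parent.record_input({"cart_id": "abc123"})

        # Opened while the outer TraceClient is stored in current_trace_var, so this
        # trace is created with parent_trace_id / parent_name pointing back at it.
        with judgment.trace("payment_subflow") as child:
            child.record_input({"provider": "stripe"})
            child.record_output({"status": "authorized"})
            child.save(empty_save=False)

        parent.record_output({"status": "complete"})
        parent.save(empty_save=False)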
@@ -767,8 +998,9 @@ class Tracer:
|
|
767
998
|
project_name: Optional project name override
|
768
999
|
overwrite: Whether to overwrite existing traces
|
769
1000
|
"""
|
1001
|
+
# If monitoring is disabled, return the function as is
|
770
1002
|
if not self.enable_monitoring:
|
771
|
-
return
|
1003
|
+
return func if func else lambda f: f
|
772
1004
|
|
773
1005
|
if func is None:
|
774
1006
|
return lambda f: self.observe(f, name=name, span_type=span_type, project_name=project_name, overwrite=overwrite)
|
@@ -779,20 +1011,56 @@ class Tracer:
|
|
779
1011
|
if asyncio.iscoroutinefunction(func):
|
780
1012
|
@functools.wraps(func)
|
781
1013
|
async def async_wrapper(*args, **kwargs):
|
782
|
-
#
|
783
|
-
|
784
|
-
|
785
|
-
|
1014
|
+
# Get current trace from context
|
1015
|
+
current_trace = current_trace_var.get()
|
1016
|
+
|
1017
|
+
# If there's no current trace, create a root trace
|
1018
|
+
if not current_trace:
|
786
1019
|
trace_id = str(uuid.uuid4())
|
787
|
-
trace_name = func.__name__
|
788
1020
|
project = project_name if project_name is not None else self.project_name
|
789
|
-
|
790
|
-
|
791
|
-
|
792
|
-
|
793
|
-
|
794
|
-
|
795
|
-
|
1021
|
+
|
1022
|
+
# Create a new trace client to serve as the root
|
1023
|
+
current_trace = TraceClient(
|
1024
|
+
self,
|
1025
|
+
trace_id,
|
1026
|
+
span_name, # MODIFIED: Use span_name directly
|
1027
|
+
project_name=project,
|
1028
|
+
overwrite=overwrite,
|
1029
|
+
rules=self.rules,
|
1030
|
+
enable_monitoring=self.enable_monitoring,
|
1031
|
+
enable_evaluations=self.enable_evaluations
|
1032
|
+
)
|
1033
|
+
|
1034
|
+
# Save empty trace and set trace context
|
1035
|
+
current_trace.save(empty_save=True, overwrite=overwrite)
|
1036
|
+
trace_token = current_trace_var.set(current_trace)
|
1037
|
+
|
1038
|
+
try:
|
1039
|
+
# Use span for the function execution within the root trace
|
1040
|
+
# This sets the current_span_var
|
1041
|
+
with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
|
1042
|
+
# Record inputs
|
1043
|
+
span.record_input({
|
1044
|
+
'args': str(args),
|
1045
|
+
'kwargs': kwargs
|
1046
|
+
})
|
1047
|
+
|
1048
|
+
# Execute function
|
1049
|
+
result = await func(*args, **kwargs)
|
1050
|
+
|
1051
|
+
# Record output
|
1052
|
+
span.record_output(result)
|
1053
|
+
|
1054
|
+
# Save the completed trace
|
1055
|
+
current_trace.save(empty_save=False, overwrite=overwrite)
|
1056
|
+
return result
|
1057
|
+
finally:
|
1058
|
+
# Reset trace context (span context resets automatically)
|
1059
|
+
current_trace_var.reset(trace_token)
|
1060
|
+
else:
|
1061
|
+
# Already have a trace context, just create a span in it
|
1062
|
+
# The span method handles current_span_var
|
1063
|
+
with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
|
796
1064
|
# Record inputs
|
797
1065
|
span.record_input({
|
798
1066
|
'args': str(args),
|
@@ -806,30 +1074,62 @@ class Tracer:
|
|
806
1074
|
span.record_output(result)
|
807
1075
|
|
808
1076
|
return result
|
809
|
-
finally:
|
810
|
-
# Only save and cleanup if this is the root observe call
|
811
|
-
if self.depth == 0:
|
812
|
-
trace.save(empty_save=False, overwrite=overwrite)
|
813
|
-
self._current_trace = None
|
814
1077
|
|
815
1078
|
return async_wrapper
|
816
1079
|
else:
|
1080
|
+
# Non-async function implementation remains unchanged
|
817
1081
|
@functools.wraps(func)
|
818
1082
|
def wrapper(*args, **kwargs):
|
819
|
-
#
|
820
|
-
|
821
|
-
|
822
|
-
|
1083
|
+
# Get current trace from context
|
1084
|
+
current_trace = current_trace_var.get()
|
1085
|
+
|
1086
|
+
# If there's no current trace, create a root trace
|
1087
|
+
if not current_trace:
|
823
1088
|
trace_id = str(uuid.uuid4())
|
824
|
-
trace_name = func.__name__
|
825
1089
|
project = project_name if project_name is not None else self.project_name
|
826
|
-
|
827
|
-
|
828
|
-
|
829
|
-
|
830
|
-
|
831
|
-
|
832
|
-
|
1090
|
+
|
1091
|
+
# Create a new trace client to serve as the root
|
1092
|
+
current_trace = TraceClient(
|
1093
|
+
self,
|
1094
|
+
trace_id,
|
1095
|
+
span_name, # MODIFIED: Use span_name directly
|
1096
|
+
project_name=project,
|
1097
|
+
overwrite=overwrite,
|
1098
|
+
rules=self.rules,
|
1099
|
+
enable_monitoring=self.enable_monitoring,
|
1100
|
+
enable_evaluations=self.enable_evaluations
|
1101
|
+
)
|
1102
|
+
|
1103
|
+
# Save empty trace and set trace context
|
1104
|
+
current_trace.save(empty_save=True, overwrite=overwrite)
|
1105
|
+
trace_token = current_trace_var.set(current_trace)
|
1106
|
+
|
1107
|
+
try:
|
1108
|
+
# Use span for the function execution within the root trace
|
1109
|
+
# This sets the current_span_var
|
1110
|
+
with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
|
1111
|
+
# Record inputs
|
1112
|
+
span.record_input({
|
1113
|
+
'args': str(args),
|
1114
|
+
'kwargs': kwargs
|
1115
|
+
})
|
1116
|
+
|
1117
|
+
# Execute function
|
1118
|
+
result = func(*args, **kwargs)
|
1119
|
+
|
1120
|
+
# Record output
|
1121
|
+
span.record_output(result)
|
1122
|
+
|
1123
|
+
# Save the completed trace
|
1124
|
+
current_trace.save(empty_save=False, overwrite=overwrite)
|
1125
|
+
return result
|
1126
|
+
finally:
|
1127
|
+
# Reset trace context (span context resets automatically)
|
1128
|
+
current_trace_var.reset(trace_token)
|
1129
|
+
else:
|
1130
|
+
# Already have a trace context, just create a span in it
|
1131
|
+
# The span method handles current_span_var
|
1132
|
+
with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
|
833
1133
|
# Record inputs
|
834
1134
|
span.record_input({
|
835
1135
|
'args': str(args),
|
@@ -843,11 +1143,6 @@ class Tracer:
|
|
843
1143
|
span.record_output(result)
|
844
1144
|
|
845
1145
|
return result
|
846
|
-
finally:
|
847
|
-
# Only save and cleanup if this is the root observe call
|
848
|
-
if self.depth == 0:
|
849
|
-
trace.save(empty_save=False, overwrite=overwrite)
|
850
|
-
self._current_trace = None
|
851
1146
|
|
852
1147
|
return wrapper
|
853
1148
|
|
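Net effect of the rewritten wrappers: the outermost `@observe`-decorated call creates and saves the root trace, and any decorated function it calls, sync or async, only opens a child span because `current_trace_var` is already populated. A usage sketch, again assuming a configured `Tracer` instance named `judgment`:

    @judgment.observe(span_type="tool")
    def look_up_account(account_id):
        return {"account_id": account_id, "plan": "pro"}

    @judgment.observe(span_type="function")
    def answer_ticket(ticket):
        # Runs inside the trace created by this call, so it becomes a child span
        # rather than a second root trace.
        account = look_up_account(ticket["account_id"])
        return f"Customer on plan {account['plan']}: {ticket['question']}"

    answer_ticket({"account_id": "42", "question": "How do I rotate my API key?"})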
@@ -856,24 +1151,36 @@ class Tracer:
|
|
856
1151
|
Decorator to trace function execution with detailed entry/exit information.
|
857
1152
|
"""
|
858
1153
|
if func is None:
|
859
|
-
return lambda f: self.
|
1154
|
+
return lambda f: self.score(f, scorers=scorers, model=model, log_results=log_results, name=name, span_type=span_type)
|
860
1155
|
|
861
1156
|
if asyncio.iscoroutinefunction(func):
|
862
1157
|
@functools.wraps(func)
|
863
1158
|
async def async_wrapper(*args, **kwargs):
|
864
|
-
|
865
|
-
|
1159
|
+
# Get current trace from contextvars
|
1160
|
+
current_trace = current_trace_var.get()
|
1161
|
+
if current_trace and scorers:
|
1162
|
+
current_trace.async_evaluate(scorers=scorers, input=args, actual_output=kwargs, model=model, log_results=log_results)
|
1163
|
+
return await func(*args, **kwargs)
|
866
1164
|
return async_wrapper
|
867
1165
|
else:
|
868
1166
|
@functools.wraps(func)
|
869
1167
|
def wrapper(*args, **kwargs):
|
870
|
-
|
871
|
-
|
1168
|
+
# Get current trace from contextvars
|
1169
|
+
current_trace = current_trace_var.get()
|
1170
|
+
if current_trace and scorers:
|
1171
|
+
current_trace.async_evaluate(scorers=scorers, input=args, actual_output=kwargs, model=model, log_results=log_results)
|
1172
|
+
return func(*args, **kwargs)
|
872
1173
|
return wrapper
|
873
1174
|
|
874
1175
|
def async_evaluate(self, *args, **kwargs):
|
875
|
-
if self.
|
876
|
-
|
1176
|
+
if not self.enable_evaluations:
|
1177
|
+
return
|
1178
|
+
|
1179
|
+
# Get current trace from context
|
1180
|
+
current_trace = current_trace_var.get()
|
1181
|
+
|
1182
|
+
if current_trace:
|
1183
|
+
current_trace.async_evaluate(*args, **kwargs)
|
877
1184
|
else:
|
878
1185
|
warnings.warn("No trace found, skipping evaluation")
|
879
1186
|
|
@@ -887,14 +1194,14 @@ def wrap(client: Any) -> Any:
     span_name, original_create = _get_client_config(client)

     def traced_create(*args, **kwargs):
-        # Get the current
-
+        # Get the current trace from contextvars
+        current_trace = current_trace_var.get()

-        # Skip tracing if no
-        if not
+        # Skip tracing if no active trace
+        if not current_trace:
             return original_create(*args, **kwargs)

-        with
+        with current_trace.span(span_name, span_type="llm") as span:
             # Format and record the input parameters
             input_data = _format_input_data(client, **kwargs)
             span.record_input(input_data)
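`wrap()` keeps its public shape: it swaps the client's create method for `traced_create`, which now looks up the active trace via `current_trace_var`, records the call as an "llm" span, and simply falls through to the unwrapped call when no trace is active. A rough sketch combining it with `observe` (the `judgment` Tracer instance is assumed as before; the OpenAI call follows the standard openai-python API):

    from openai import OpenAI

    client = wrap(OpenAI())  # patched create() records inputs, outputs and token usage

    @judgment.observe(span_type="function")
    def summarize(text):
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": f"Summarize: {text}"}],
        )
        return response.choices[0].message.content

    summarize("Tracing spans now carry span_id and parent_span_id.")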
@@ -986,4 +1293,59 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
             "output_tokens": response.usage.output_tokens,
             "total_tokens": response.usage.input_tokens + response.usage.output_tokens
         }
-    }
+    }
+
+# Add a global context-preserving gather function
+# async def trace_gather(*coroutines, return_exceptions=False): # REMOVED
+#     """ # REMOVED
+#     A wrapper around asyncio.gather that ensures the trace context # REMOVED
+#     is available within the gathered coroutines using contextvars.copy_context. # REMOVED
+#     """ # REMOVED
+#     # Get the original asyncio.gather (if we patched it) # REMOVED
+#     original_gather = getattr(asyncio, "_original_gather", asyncio.gather) # REMOVED
+#     # REMOVED
+#     # Use contextvars.copy_context() to ensure context propagation # REMOVED
+#     ctx = contextvars.copy_context() # REMOVED
+#     # REMOVED
+#     # Wrap the gather call within the copied context # REMOVED
+#     return await ctx.run(original_gather, *coroutines, return_exceptions=return_exceptions) # REMOVED
+
+# Store the original gather and apply the patch *once*
+# global _original_gather_stored # REMOVED
+# if not globals().get('_original_gather_stored'): # REMOVED
+#     # Check if asyncio.gather is already our wrapper to prevent double patching # REMOVED
+#     if asyncio.gather.__name__ != 'trace_gather': # REMOVED
+#         asyncio._original_gather = asyncio.gather # REMOVED
+#         asyncio.gather = trace_gather # REMOVED
+#         _original_gather_stored = True # REMOVED
+
+# Add the new TraceThreadPoolExecutor class
+class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
+    """
+    A ThreadPoolExecutor subclass that automatically propagates contextvars
+    from the submitting thread to the worker thread using copy_context().run().
+
+    This ensures that context variables like `current_trace_var` and
+    `current_span_var` are available within functions executed by the pool,
+    allowing the Tracer to maintain correct parent-child relationships across
+    thread boundaries.
+    """
+    def submit(self, fn, /, *args, **kwargs):
+        """
+        Submit a callable to be executed with the captured context.
+        """
+        # Capture context from the submitting thread
+        ctx = contextvars.copy_context()
+
+        # We use functools.partial to bind the arguments to the function *now*,
+        # as ctx.run doesn't directly accept *args, **kwargs in the same way
+        # submit does. It expects ctx.run(callable, arg1, arg2...).
+        func_with_bound_args = functools.partial(fn, *args, **kwargs)
+
+        # Submit the ctx.run callable to the original executor.
+        # ctx.run will execute the (now argument-bound) function within the
+        # captured context in the worker thread.
+        return super().submit(ctx.run, func_with_bound_args)
+
+    # Note: The `map` method would also need to be overridden for full context
+    # propagation if users rely on it, but `submit` is the most common use case.