judgeval 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
judgeval/common/tracer.py CHANGED
@@ -10,16 +10,18 @@ import os
10
10
  import time
11
11
  import uuid
12
12
  import warnings
13
+ import contextvars
13
14
  from contextlib import contextmanager
14
15
  from dataclasses import dataclass, field
15
16
  from datetime import datetime
16
17
  from http import HTTPStatus
17
- from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, TypeAlias, Union
18
+ from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, TypeAlias, Union, Callable, Awaitable
18
19
  from rich import print as rprint
19
20
 
20
21
  # Third-party imports
21
22
  import pika
22
23
  import requests
24
+ from litellm import cost_per_token
23
25
  from pydantic import BaseModel
24
26
  from rich import print as rprint
25
27
  from openai import OpenAI
@@ -44,6 +46,12 @@ from judgeval.rules import Rule
44
46
  from judgeval.evaluation_run import EvaluationRun
45
47
  from judgeval.data.result import ScoringResult
46
48
 
49
+ # Standard library imports needed for the new class
50
+ import concurrent.futures
51
+
52
+ # Define context variables for tracking the current trace and the current span within a trace
53
+ current_trace_var = contextvars.ContextVar('current_trace', default=None)
54
+ current_span_var = contextvars.ContextVar('current_span', default=None) # NEW: ContextVar for the active span name
47
55
 
48
56
  # Define type aliases for better code readability and maintainability
49
57
  ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic] # Supported API clients
@@ -62,6 +70,7 @@ class TraceEntry:
62
70
  """
63
71
  type: TraceEntryType
64
72
  function: str # Name of the function being traced
73
+ span_id: str # Unique ID for this specific span instance
65
74
  depth: int # Indentation level for nested calls
66
75
  message: str # Human-readable description
67
76
  timestamp: float # Unix timestamp when entry was created
@@ -71,20 +80,28 @@ class TraceEntry:
71
80
  inputs: dict = field(default_factory=dict)
72
81
  span_type: SpanType = "span"
73
82
  evaluation_runs: List[Optional[EvaluationRun]] = field(default=None)
83
+ parent_span_id: Optional[str] = None # ID of the parent span instance
74
84
 
75
85
  def print_entry(self):
86
+ """Print a trace entry with proper formatting and parent relationship information."""
76
87
  indent = " " * self.depth
88
+
77
89
  if self.type == "enter":
78
- print(f"{indent}→ {self.function} (trace: {self.message})")
90
+ # Format parent info if present
91
+ parent_info = f" (parent_id: {self.parent_span_id})" if self.parent_span_id else ""
92
+ print(f"{indent}→ {self.function} (id: {self.span_id}){parent_info} (trace: {self.message})")
79
93
  elif self.type == "exit":
80
- print(f"{indent}← {self.function} ({self.duration:.3f}s)")
94
+ print(f"{indent}← {self.function} (id: {self.span_id}) ({self.duration:.3f}s)")
81
95
  elif self.type == "output":
82
- print(f"{indent}Output: {self.output}")
96
+ # Format output to align properly
97
+ output_str = str(self.output)
98
+ print(f"{indent}Output (for id: {self.span_id}): {output_str}")
83
99
  elif self.type == "input":
84
- print(f"{indent}Input: {self.inputs}")
100
+ # Format inputs to align properly
101
+ print(f"{indent}Input (for id: {self.span_id}): {self.inputs}")
85
102
  elif self.type == "evaluation":
86
103
  for evaluation_run in self.evaluation_runs:
87
- print(f"{indent}Evaluation: {evaluation_run.model_dump()}")
104
+ print(f"{indent}Evaluation (for id: {self.span_id}): {evaluation_run.model_dump()}")
88
105
 
89
106
  def _serialize_inputs(self) -> dict:
90
107
  """Helper method to serialize input data safely.
@@ -143,6 +160,7 @@ class TraceEntry:
143
160
  return {
144
161
  "type": self.type,
145
162
  "function": self.function,
163
+ "span_id": self.span_id,
146
164
  "depth": self.depth,
147
165
  "message": self.message,
148
166
  "timestamp": self.timestamp,
@@ -150,7 +168,8 @@ class TraceEntry:
150
168
  "output": self._serialize_output(),
151
169
  "inputs": self._serialize_inputs(),
152
170
  "evaluation_runs": [evaluation_run.model_dump() for evaluation_run in self.evaluation_runs] if self.evaluation_runs else [],
153
- "span_type": self.span_type
171
+ "span_type": self.span_type,
172
+ "parent_span_id": self.parent_span_id
154
173
  }
155
174
 
156
175
  def _serialize_output(self) -> Any:
@@ -314,63 +333,79 @@ class TraceClient:
314
333
  overwrite: bool = False,
315
334
  rules: Optional[List[Rule]] = None,
316
335
  enable_monitoring: bool = True,
317
- enable_evaluations: bool = True
336
+ enable_evaluations: bool = True,
337
+ parent_trace_id: Optional[str] = None,
338
+ parent_name: Optional[str] = None
318
339
  ):
319
340
  self.name = name
320
341
  self.trace_id = trace_id or str(uuid.uuid4())
321
342
  self.project_name = project_name
322
343
  self.overwrite = overwrite
323
344
  self.tracer = tracer
324
- # Initialize rules with either provided rules or an empty list
325
345
  self.rules = rules or []
326
346
  self.enable_monitoring = enable_monitoring
327
347
  self.enable_evaluations = enable_evaluations
328
-
348
+ self.parent_trace_id = parent_trace_id
349
+ self.parent_name = parent_name
329
350
  self.client: JudgmentClient = tracer.client
330
351
  self.entries: List[TraceEntry] = []
331
352
  self.start_time = time.time()
332
- self.span_type = None
333
- self._current_span: Optional[TraceEntry] = None
334
- self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id) # Manages DB operations for trace data
353
+ self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id)
354
+ self.visited_nodes = []
355
+ self.executed_tools = []
356
+ self.executed_node_tools = []
357
+ self._span_depths: Dict[str, int] = {} # NEW: To track depth of active spans
335
358
 
336
359
  @contextmanager
337
360
  def span(self, name: str, span_type: SpanType = "span"):
338
- """Context manager for creating a trace span"""
361
+ """Context manager for creating a trace span, managing the current span via contextvars"""
339
362
  start_time = time.time()
340
363
 
341
- # Record span entry
342
- self.add_entry(TraceEntry(
364
+ # Generate a unique ID for *this specific span invocation*
365
+ span_id = str(uuid.uuid4())
366
+
367
+ parent_span_id = current_span_var.get() # Get ID of the parent span from context var
368
+ token = current_span_var.set(span_id) # Set *this* span's ID as the current one
369
+
370
+ current_depth = 0
371
+ if parent_span_id and parent_span_id in self._span_depths:
372
+ current_depth = self._span_depths[parent_span_id] + 1
373
+
374
+ self._span_depths[span_id] = current_depth # Store depth by span_id
375
+
376
+ entry = TraceEntry(
343
377
  type="enter",
344
378
  function=name,
345
- depth=self.tracer.depth,
379
+ span_id=span_id, # Use the generated span_id
380
+ depth=current_depth,
346
381
  message=name,
347
382
  timestamp=start_time,
348
- span_type=span_type
349
- ))
350
-
351
- # Increment nested depth and set current span
352
- self.tracer.depth += 1
353
- prev_span = self._current_span
354
- self._current_span = name
383
+ span_type=span_type,
384
+ parent_span_id=parent_span_id # Use the parent_id from context var
385
+ )
386
+ self.add_entry(entry)
355
387
 
356
388
  try:
357
389
  yield self
358
390
  finally:
359
- self.tracer.depth -= 1
360
391
  duration = time.time() - start_time
361
-
362
- # Record span exit
392
+ exit_depth = self._span_depths.get(span_id, 0) # Get depth using this span's ID
363
393
  self.add_entry(TraceEntry(
364
394
  type="exit",
365
395
  function=name,
366
- depth=self.tracer.depth,
396
+ span_id=span_id, # Use the same span_id for exit
397
+ depth=exit_depth,
367
398
  message=f"← {name}",
368
399
  timestamp=time.time(),
369
400
  duration=duration,
370
401
  span_type=span_type
371
402
  ))
372
- self._current_span = prev_span
373
-
403
+ # Clean up depth tracking for this span_id
404
+ if span_id in self._span_depths:
405
+ del self._span_depths[span_id]
406
+ # Reset context var
407
+ current_span_var.reset(token)
408
+
374
409
  def async_evaluate(
375
410
  self,
376
411
  scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
@@ -453,7 +488,7 @@ class TraceClient:
453
488
  log_results=log_results,
454
489
  project_name=self.project_name,
455
490
  eval_name=f"{self.name.capitalize()}-"
456
- f"{self._current_span}-"
491
+ f"{current_span_var.get()}-"
457
492
  f"[{','.join(scorer.score_type.capitalize() for scorer in loaded_scorers)}]",
458
493
  examples=[example],
459
494
  scorers=loaded_scorers,
@@ -467,24 +502,30 @@ class TraceClient:
467
502
  self.add_eval_run(eval_run, start_time) # Pass start_time to record_evaluation
468
503
 
469
504
  def add_eval_run(self, eval_run: EvaluationRun, start_time: float):
470
- """
471
- Add evaluation run data to the trace
472
-
473
- Args:
474
- eval_run (EvaluationRun): The evaluation run to add to the trace
475
- start_time (float): The start time of the evaluation run
476
- """
477
- if self._current_span:
478
- duration = time.time() - start_time # Calculate duration from start_time
479
-
480
- prev_entry = self.entries[-1]
505
+ current_span_id = current_span_var.get()
506
+ if current_span_id:
507
+ duration = time.time() - start_time
508
+ prev_entry = self.entries[-1] if self.entries else None
509
+ # Determine function name based on previous entry or context var (less ideal)
510
+ function_name = "unknown_function" # Default
511
+ if prev_entry and prev_entry.span_type == "llm":
512
+ function_name = prev_entry.function
513
+ else:
514
+ # Try to find the function name associated with the current span_id
515
+ for entry in reversed(self.entries):
516
+ if entry.span_id == current_span_id and entry.type == 'enter':
517
+ function_name = entry.function
518
+ break
481
519
 
482
- # Select the last entry in the trace if it's an LLM call, otherwise use the current span
520
+ # Get depth for the current span
521
+ current_depth = self._span_depths.get(current_span_id, 0)
522
+
483
523
  self.add_entry(TraceEntry(
484
524
  type="evaluation",
485
- function=prev_entry.function if prev_entry.span_type == "llm" else self._current_span,
486
- depth=self.tracer.depth,
487
- message=f"Evaluation results for {self._current_span}",
525
+ function=function_name,
526
+ span_id=current_span_id, # Associate with current span
527
+ depth=current_depth,
528
+ message=f"Evaluation results for {function_name}",
488
529
  timestamp=time.time(),
489
530
  evaluation_runs=[eval_run],
490
531
  duration=duration,
@@ -492,16 +533,26 @@ class TraceClient:
492
533
  ))
493
534
 
494
535
  def record_input(self, inputs: dict):
495
- """Record input parameters for the current span"""
496
- if self._current_span:
536
+ current_span_id = current_span_var.get()
537
+ if current_span_id:
538
+ entry_span_type = "span"
539
+ current_depth = self._span_depths.get(current_span_id, 0)
540
+ function_name = "unknown_function" # Default
541
+ for entry in reversed(self.entries):
542
+ if entry.span_id == current_span_id and entry.type == 'enter':
543
+ entry_span_type = entry.span_type
544
+ function_name = entry.function
545
+ break
546
+
497
547
  self.add_entry(TraceEntry(
498
548
  type="input",
499
- function=self._current_span,
500
- depth=self.tracer.depth,
501
- message=f"Inputs to {self._current_span}",
549
+ function=function_name,
550
+ span_id=current_span_id, # Use current span_id
551
+ depth=current_depth,
552
+ message=f"Inputs to {function_name}",
502
553
  timestamp=time.time(),
503
554
  inputs=inputs,
504
- span_type=self.span_type
555
+ span_type=entry_span_type
505
556
  ))
506
557
 
507
558
  async def _update_coroutine_output(self, entry: TraceEntry, coroutine: Any):
@@ -515,21 +566,30 @@ class TraceClient:
515
566
  raise
516
567
 
517
568
  def record_output(self, output: Any):
518
- """Record output for the current span"""
519
- if self._current_span:
569
+ current_span_id = current_span_var.get()
570
+ if current_span_id:
571
+ entry_span_type = "span"
572
+ current_depth = self._span_depths.get(current_span_id, 0)
573
+ function_name = "unknown_function" # Default
574
+ for entry in reversed(self.entries):
575
+ if entry.span_id == current_span_id and entry.type == 'enter':
576
+ entry_span_type = entry.span_type
577
+ function_name = entry.function
578
+ break
579
+
520
580
  entry = TraceEntry(
521
581
  type="output",
522
- function=self._current_span,
523
- depth=self.tracer.depth,
524
- message=f"Output from {self._current_span}",
582
+ function=function_name,
583
+ span_id=current_span_id, # Use current span_id
584
+ depth=current_depth,
585
+ message=f"Output from {function_name}",
525
586
  timestamp=time.time(),
526
587
  output="<pending>" if inspect.iscoroutine(output) else output,
527
- span_type=self.span_type
588
+ span_type=entry_span_type
528
589
  )
529
590
  self.add_entry(entry)
530
591
 
531
592
  if inspect.iscoroutine(output):
532
- # Create a task to update the output once the coroutine completes
533
593
  asyncio.create_task(self._update_coroutine_output(entry, output))
534
594
 
535
595
  def add_entry(self, entry: TraceEntry):
@@ -542,6 +602,58 @@ class TraceClient:
542
602
  for entry in self.entries:
543
603
  entry.print_entry()
544
604
 
605
+ def print_hierarchical(self):
606
+ """Print the trace in a hierarchical structure based on parent-child relationships"""
607
+ # First, build a map of spans
608
+ spans = {}
609
+ root_spans = []
610
+
611
+ # Collect all enter events first
612
+ for entry in self.entries:
613
+ if entry.type == "enter":
614
+ spans[entry.function] = {
615
+ "name": entry.function,
616
+ "depth": entry.depth,
617
+ "parent_id": entry.parent_span_id,
618
+ "children": []
619
+ }
620
+
621
+ # If no parent, it's a root span
622
+ if not entry.parent_span_id:
623
+ root_spans.append(entry.function)
624
+ elif entry.parent_span_id not in spans:
625
+ # If parent doesn't exist yet, temporarily treat as root
626
+ # (we'll fix this later)
627
+ root_spans.append(entry.function)
628
+
629
+ # Build parent-child relationships
630
+ for span_name, span in spans.items():
631
+ parent = span["parent_id"]
632
+ if parent and parent in spans:
633
+ spans[parent]["children"].append(span_name)
634
+ # Remove from root spans if it was temporarily there
635
+ if span_name in root_spans:
636
+ root_spans.remove(span_name)
637
+
638
+ # Now print the hierarchy
639
+ def print_span(span_name, level=0):
640
+ if span_name not in spans:
641
+ return
642
+
643
+ span = spans[span_name]
644
+ indent = " " * level
645
+ parent_info = f" (parent_id: {span['parent_id']})" if span["parent_id"] else ""
646
+ print(f"{indent}→ {span_name}{parent_info}")
647
+
648
+ # Print children
649
+ for child in span["children"]:
650
+ print_span(child, level + 1)
651
+
652
+ # Print starting with root spans
653
+ print("\nHierarchical Trace Structure:")
654
+ for root in root_spans:
655
+ print_span(root)
656
+
545
657
  def get_duration(self) -> float:
546
658
  """
547
659
  Get the total duration of this trace
@@ -550,56 +662,122 @@ class TraceClient:
550
662
 
551
663
  def condense_trace(self, entries: List[dict]) -> List[dict]:
552
664
  """
553
- Condenses trace entries into a single entry for each function call.
665
+ Condenses trace entries into a single entry for each span instance,
666
+ preserving parent-child span relationships using span_id and parent_span_id.
554
667
  """
555
- condensed = []
556
- active_functions = [] # Stack to track nested function calls
557
- function_entries = {} # Store entries for each function
668
+ spans_by_id: Dict[str, dict] = {}
669
+
670
+ # First pass: Group entries by span_id and gather data
671
+ for entry in entries:
672
+ span_id = entry.get("span_id")
673
+ if not span_id:
674
+ continue # Skip entries without a span_id (should not happen)
558
675
 
559
- for i, entry in enumerate(entries):
560
- function = entry["function"]
561
-
562
676
  if entry["type"] == "enter":
563
- # Initialize new function entry
564
- function_entries[function] = {
565
- "depth": entry["depth"],
566
- "function": function,
567
- "timestamp": entry["timestamp"],
568
- "inputs": None,
569
- "output": None,
570
- "evaluation_runs": [],
571
- "span_type": entry.get("span_type", "span")
572
- }
573
- active_functions.append(function)
574
-
575
- elif entry["type"] == "exit" and function in active_functions:
576
- # Complete function entry
577
- current_entry = function_entries[function]
578
- current_entry["duration"] = entry["timestamp"] - current_entry["timestamp"]
579
- condensed.append(current_entry)
580
- active_functions.remove(function)
581
- # del function_entries[function]
582
-
583
- # The OR condition is to handle the LLM client case.
584
- # LLM client is a special case where we exit the span, so when we attach evaluations to it,
585
- # we have to check if the previous entry is an LLM call.
586
- elif function in active_functions or entry["type"] == "evaluation" and entries[i-1]["function"] == entry["function"]:
587
- # Update existing function entry with additional data
588
- current_entry = function_entries[function]
677
+ if span_id not in spans_by_id:
678
+ spans_by_id[span_id] = {
679
+ "span_id": span_id,
680
+ "function": entry["function"],
681
+ "depth": entry["depth"], # Use the depth recorded at entry time
682
+ "timestamp": entry["timestamp"],
683
+ "parent_span_id": entry.get("parent_span_id"),
684
+ "span_type": entry.get("span_type", "span"),
685
+ "inputs": None,
686
+ "output": None,
687
+ "evaluation_runs": [],
688
+ "duration": None
689
+ }
690
+ # Handle potential duplicate enter events if necessary (e.g., log warning)
691
+
692
+ elif span_id in spans_by_id:
693
+ current_span_data = spans_by_id[span_id]
589
694
 
590
695
  if entry["type"] == "input" and entry["inputs"]:
591
- current_entry["inputs"] = entry["inputs"]
592
-
593
- if entry["type"] == "output" and entry["output"]:
594
- current_entry["output"] = entry["output"]
595
-
596
- if entry["type"] == "evaluation" and entry["evaluation_runs"]:
597
- current_entry["evaluation_runs"] = entry["evaluation_runs"]
696
+ # Merge inputs if multiple are recorded, or just assign
697
+ if current_span_data["inputs"] is None:
698
+ current_span_data["inputs"] = entry["inputs"]
699
+ elif isinstance(current_span_data["inputs"], dict) and isinstance(entry["inputs"], dict):
700
+ current_span_data["inputs"].update(entry["inputs"])
701
+ # Add more sophisticated merging if needed
598
702
 
599
- # Sort by timestamp
600
- condensed.sort(key=lambda x: x["timestamp"])
601
-
602
- return condensed
703
+ elif entry["type"] == "output" and "output" in entry:
704
+ current_span_data["output"] = entry["output"]
705
+
706
+ elif entry["type"] == "evaluation" and entry.get("evaluation_runs"):
707
+ if current_span_data.get("evaluation_runs") is None:
708
+ current_span_data["evaluation_runs"] = []
709
+ current_span_data["evaluation_runs"].extend(entry["evaluation_runs"])
710
+
711
+ elif entry["type"] == "exit":
712
+ if current_span_data["duration"] is None: # Calculate duration only once
713
+ start_time = current_span_data.get("timestamp", entry["timestamp"])
714
+ current_span_data["duration"] = entry["timestamp"] - start_time
715
+ # Update depth if exit depth is different (though current span() implementation keeps it same)
716
+ # current_span_data["depth"] = entry["depth"]
717
+
718
+ # Convert dictionary to a list initially for easier access
719
+ spans_list = list(spans_by_id.values())
720
+
721
+ # Build tree structure (adjacency list) and find roots
722
+ children_map: Dict[Optional[str], List[dict]] = {}
723
+ roots = []
724
+ span_map = {span['span_id']: span for span in spans_list} # Map for quick lookup
725
+
726
+ for span in spans_list:
727
+ parent_id = span.get("parent_span_id")
728
+ if parent_id is None:
729
+ roots.append(span)
730
+ else:
731
+ if parent_id not in children_map:
732
+ children_map[parent_id] = []
733
+ children_map[parent_id].append(span)
734
+
735
+ # Sort roots by timestamp
736
+ roots.sort(key=lambda x: x.get("timestamp", 0))
737
+
738
+ # Perform depth-first traversal to get the final sorted list
739
+ sorted_condensed_list = []
740
+ visited = set() # To handle potential cycles, though unlikely with UUIDs
741
+
742
+ def dfs(span_data):
743
+ span_id = span_data['span_id']
744
+ if span_id in visited:
745
+ return # Avoid infinite loops in case of cycles
746
+ visited.add(span_id)
747
+
748
+ sorted_condensed_list.append(span_data) # Add parent before children
749
+
750
+ # Get children, sort them by timestamp, and visit them
751
+ span_children = children_map.get(span_id, [])
752
+ span_children.sort(key=lambda x: x.get("timestamp", 0))
753
+ for child in span_children:
754
+ # Ensure the child exists in our map before recursing
755
+ if child['span_id'] in span_map:
756
+ dfs(child)
757
+ else:
758
+ # This case might indicate an issue, but we'll add the child directly
759
+ # if its parent was processed but the child itself wasn't in the initial list?
760
+ # Or if the child's 'enter' event was missing. For robustness, add it.
761
+ if child['span_id'] not in visited:
762
+ visited.add(child['span_id'])
763
+ sorted_condensed_list.append(child)
764
+
765
+
766
+ # Start DFS from each root
767
+ for root_span in roots:
768
+ if root_span['span_id'] not in visited:
769
+ dfs(root_span)
770
+
771
+ # Handle spans that might not have been reachable from roots (orphans)
772
+ # Though ideally, all spans should descend from a root.
773
+ for span_data in spans_list:
774
+ if span_data['span_id'] not in visited:
775
+ # Decide how to handle orphans, maybe append them at the end sorted by time?
776
+ # For now, let's just add them to ensure they aren't lost.
777
+ sorted_condensed_list.append(span_data)
778
+
779
+
780
+ return sorted_condensed_list
603
781
 
604
782
  def save(self, empty_save: bool = False, overwrite: bool = False) -> Tuple[str, dict]:
605
783
  """
@@ -618,34 +796,76 @@ class TraceClient:
618
796
  total_completion_tokens = 0
619
797
  total_tokens = 0
620
798
 
799
+ total_prompt_tokens_cost = 0.0
800
+ total_completion_tokens_cost = 0.0
801
+ total_cost = 0.0
802
+
621
803
  for entry in condensed_entries:
622
804
  if entry.get("span_type") == "llm" and isinstance(entry.get("output"), dict):
623
- usage = entry["output"].get("usage", {})
805
+ output = entry["output"]
806
+ usage = output.get("usage", {})
807
+ model_name = entry.get("inputs", {}).get("model", "")
808
+ prompt_tokens = 0
809
+ completion_tokens = 0
810
+
624
811
  # Handle OpenAI/Together format
625
812
  if "prompt_tokens" in usage:
626
- total_prompt_tokens += usage.get("prompt_tokens", 0)
627
- total_completion_tokens += usage.get("completion_tokens", 0)
813
+ prompt_tokens = usage.get("prompt_tokens", 0)
814
+ completion_tokens = usage.get("completion_tokens", 0)
815
+ total_prompt_tokens += prompt_tokens
816
+ total_completion_tokens += completion_tokens
628
817
  # Handle Anthropic format
629
818
  elif "input_tokens" in usage:
630
- total_prompt_tokens += usage.get("input_tokens", 0)
631
- total_completion_tokens += usage.get("output_tokens", 0)
819
+ prompt_tokens = usage.get("input_tokens", 0)
820
+ completion_tokens = usage.get("output_tokens", 0)
821
+ total_prompt_tokens += prompt_tokens
822
+ total_completion_tokens += completion_tokens
823
+
632
824
  total_tokens += usage.get("total_tokens", 0)
825
+
826
+ # Calculate costs if model name is available
827
+ if model_name:
828
+ try:
829
+ prompt_cost, completion_cost = cost_per_token(
830
+ model=model_name,
831
+ prompt_tokens=prompt_tokens,
832
+ completion_tokens=completion_tokens
833
+ )
834
+ total_prompt_tokens_cost += prompt_cost
835
+ total_completion_tokens_cost += completion_cost
836
+ total_cost += prompt_cost + completion_cost
837
+
838
+ # Add cost information directly to the usage dictionary in the condensed entry
839
+ if "usage" not in output:
840
+ output["usage"] = {}
841
+ output["usage"]["prompt_tokens_cost_usd"] = prompt_cost
842
+ output["usage"]["completion_tokens_cost_usd"] = completion_cost
843
+ output["usage"]["total_cost_usd"] = prompt_cost + completion_cost
844
+ except Exception as e:
845
+ # If cost calculation fails, continue without adding costs
846
+ print(f"Error calculating cost for model '{model_name}': {str(e)}")
847
+ pass
633
848
 
634
849
  # Create trace document
635
850
  trace_data = {
636
851
  "trace_id": self.trace_id,
637
852
  "name": self.name,
638
853
  "project_name": self.project_name,
639
- "created_at": datetime.fromtimestamp(self.start_time).isoformat(),
854
+ "created_at": datetime.utcfromtimestamp(self.start_time).isoformat(),
640
855
  "duration": total_duration,
641
856
  "token_counts": {
642
857
  "prompt_tokens": total_prompt_tokens,
643
858
  "completion_tokens": total_completion_tokens,
644
859
  "total_tokens": total_tokens,
860
+ "prompt_tokens_cost_usd": total_prompt_tokens_cost,
861
+ "completion_tokens_cost_usd": total_completion_tokens_cost,
862
+ "total_cost_usd": total_cost
645
863
  },
646
864
  "entries": condensed_entries,
647
865
  "empty_save": empty_save,
648
- "overwrite": overwrite
866
+ "overwrite": overwrite,
867
+ "parent_trace_id": self.parent_trace_id,
868
+ "parent_name": self.parent_name
649
869
  }
650
870
  # Execute asynchrous evaluation in the background
651
871
  if not empty_save: # Only send to RabbitMQ if the trace is not empty
@@ -697,12 +917,10 @@ class Tracer:
697
917
 
698
918
  if not organization_id:
699
919
  raise ValueError("Tracer must be configured with an Organization ID")
700
-
701
920
  self.api_key: str = api_key
702
921
  self.project_name: str = project_name
703
922
  self.client: JudgmentClient = JudgmentClient(judgment_api_key=api_key)
704
923
  self.organization_id: str = organization_id
705
- self.depth: int = 0
706
924
  self._current_trace: Optional[str] = None
707
925
  self.rules: List[Rule] = rules or [] # Store rules at tracer level
708
926
  self.initialized: bool = True
@@ -727,6 +945,15 @@ class Tracer:
727
945
  """Start a new trace context using a context manager"""
728
946
  trace_id = str(uuid.uuid4())
729
947
  project = project_name if project_name is not None else self.project_name
948
+
949
+ # Get parent trace info from context
950
+ parent_trace = current_trace_var.get()
951
+ parent_trace_id = None
952
+ parent_name = None
953
+
954
+ if parent_trace:
955
+ parent_trace_id = parent_trace.trace_id
956
+ parent_name = parent_trace.name
730
957
 
731
958
  trace = TraceClient(
732
959
  self,
@@ -736,10 +963,13 @@ class Tracer:
736
963
  overwrite=overwrite,
737
964
  rules=self.rules, # Pass combined rules to the trace client
738
965
  enable_monitoring=self.enable_monitoring,
739
- enable_evaluations=self.enable_evaluations
966
+ enable_evaluations=self.enable_evaluations,
967
+ parent_trace_id=parent_trace_id,
968
+ parent_name=parent_name
740
969
  )
741
- prev_trace = self._current_trace
742
- self._current_trace = trace
970
+
971
+ # Set the current trace in context variables
972
+ token = current_trace_var.set(trace)
743
973
 
744
974
  # Automatically create top-level span
745
975
  with trace.span(name or "unnamed_trace") as span:
@@ -748,13 +978,14 @@ class Tracer:
748
978
  trace.save(empty_save=True, overwrite=overwrite)
749
979
  yield trace
750
980
  finally:
751
- self._current_trace = prev_trace
981
+ # Reset the context variable
982
+ current_trace_var.reset(token)
752
983
 
753
984
  def get_current_trace(self) -> Optional[TraceClient]:
754
985
  """
755
- Get the current trace context
986
+ Get the current trace context from contextvars
756
987
  """
757
- return self._current_trace
988
+ return current_trace_var.get()
758
989
 
759
990
  def observe(self, func=None, *, name=None, span_type: SpanType = "span", project_name: str = None, overwrite: bool = False):
760
991
  """
@@ -767,8 +998,9 @@ class Tracer:
767
998
  project_name: Optional project name override
768
999
  overwrite: Whether to overwrite existing traces
769
1000
  """
1001
+ # If monitoring is disabled, return the function as is
770
1002
  if not self.enable_monitoring:
771
- return
1003
+ return func if func else lambda f: f
772
1004
 
773
1005
  if func is None:
774
1006
  return lambda f: self.observe(f, name=name, span_type=span_type, project_name=project_name, overwrite=overwrite)
@@ -779,20 +1011,56 @@ class Tracer:
779
1011
  if asyncio.iscoroutinefunction(func):
780
1012
  @functools.wraps(func)
781
1013
  async def async_wrapper(*args, **kwargs):
782
- # If there's already a trace, use it. Otherwise create a new one
783
- if self._current_trace:
784
- trace = self._current_trace
785
- else:
1014
+ # Get current trace from context
1015
+ current_trace = current_trace_var.get()
1016
+
1017
+ # If there's no current trace, create a root trace
1018
+ if not current_trace:
786
1019
  trace_id = str(uuid.uuid4())
787
- trace_name = func.__name__
788
1020
  project = project_name if project_name is not None else self.project_name
789
- trace = TraceClient(self, trace_id, trace_name, project_name=project, overwrite=overwrite, rules=self.rules, enable_monitoring=self.enable_monitoring, enable_evaluations=self.enable_evaluations)
790
- self._current_trace = trace
791
- # Only save empty trace for the root call
792
- trace.save(empty_save=True, overwrite=overwrite)
793
-
794
- try:
795
- with trace.span(span_name, span_type=span_type) as span:
1021
+
1022
+ # Create a new trace client to serve as the root
1023
+ current_trace = TraceClient(
1024
+ self,
1025
+ trace_id,
1026
+ span_name, # MODIFIED: Use span_name directly
1027
+ project_name=project,
1028
+ overwrite=overwrite,
1029
+ rules=self.rules,
1030
+ enable_monitoring=self.enable_monitoring,
1031
+ enable_evaluations=self.enable_evaluations
1032
+ )
1033
+
1034
+ # Save empty trace and set trace context
1035
+ current_trace.save(empty_save=True, overwrite=overwrite)
1036
+ trace_token = current_trace_var.set(current_trace)
1037
+
1038
+ try:
1039
+ # Use span for the function execution within the root trace
1040
+ # This sets the current_span_var
1041
+ with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
1042
+ # Record inputs
1043
+ span.record_input({
1044
+ 'args': str(args),
1045
+ 'kwargs': kwargs
1046
+ })
1047
+
1048
+ # Execute function
1049
+ result = await func(*args, **kwargs)
1050
+
1051
+ # Record output
1052
+ span.record_output(result)
1053
+
1054
+ # Save the completed trace
1055
+ current_trace.save(empty_save=False, overwrite=overwrite)
1056
+ return result
1057
+ finally:
1058
+ # Reset trace context (span context resets automatically)
1059
+ current_trace_var.reset(trace_token)
1060
+ else:
1061
+ # Already have a trace context, just create a span in it
1062
+ # The span method handles current_span_var
1063
+ with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
796
1064
  # Record inputs
797
1065
  span.record_input({
798
1066
  'args': str(args),
@@ -806,30 +1074,62 @@ class Tracer:
806
1074
  span.record_output(result)
807
1075
 
808
1076
  return result
809
- finally:
810
- # Only save and cleanup if this is the root observe call
811
- if self.depth == 0:
812
- trace.save(empty_save=False, overwrite=overwrite)
813
- self._current_trace = None
814
1077
 
815
1078
  return async_wrapper
816
1079
  else:
1080
+ # Non-async function implementation remains unchanged
817
1081
  @functools.wraps(func)
818
1082
  def wrapper(*args, **kwargs):
819
- # If there's already a trace, use it. Otherwise create a new one
820
- if self._current_trace:
821
- trace = self._current_trace
822
- else:
1083
+ # Get current trace from context
1084
+ current_trace = current_trace_var.get()
1085
+
1086
+ # If there's no current trace, create a root trace
1087
+ if not current_trace:
823
1088
  trace_id = str(uuid.uuid4())
824
- trace_name = func.__name__
825
1089
  project = project_name if project_name is not None else self.project_name
826
- trace = TraceClient(self, trace_id, trace_name, project_name=project, overwrite=overwrite, rules=self.rules, enable_monitoring=self.enable_monitoring)
827
- self._current_trace = trace
828
- # Only save empty trace for the root call
829
- trace.save(empty_save=True, overwrite=overwrite)
830
-
831
- try:
832
- with trace.span(span_name, span_type=span_type) as span:
1090
+
1091
+ # Create a new trace client to serve as the root
1092
+ current_trace = TraceClient(
1093
+ self,
1094
+ trace_id,
1095
+ span_name, # MODIFIED: Use span_name directly
1096
+ project_name=project,
1097
+ overwrite=overwrite,
1098
+ rules=self.rules,
1099
+ enable_monitoring=self.enable_monitoring,
1100
+ enable_evaluations=self.enable_evaluations
1101
+ )
1102
+
1103
+ # Save empty trace and set trace context
1104
+ current_trace.save(empty_save=True, overwrite=overwrite)
1105
+ trace_token = current_trace_var.set(current_trace)
1106
+
1107
+ try:
1108
+ # Use span for the function execution within the root trace
1109
+ # This sets the current_span_var
1110
+ with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
1111
+ # Record inputs
1112
+ span.record_input({
1113
+ 'args': str(args),
1114
+ 'kwargs': kwargs
1115
+ })
1116
+
1117
+ # Execute function
1118
+ result = func(*args, **kwargs)
1119
+
1120
+ # Record output
1121
+ span.record_output(result)
1122
+
1123
+ # Save the completed trace
1124
+ current_trace.save(empty_save=False, overwrite=overwrite)
1125
+ return result
1126
+ finally:
1127
+ # Reset trace context (span context resets automatically)
1128
+ current_trace_var.reset(trace_token)
1129
+ else:
1130
+ # Already have a trace context, just create a span in it
1131
+ # The span method handles current_span_var
1132
+ with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
833
1133
  # Record inputs
834
1134
  span.record_input({
835
1135
  'args': str(args),
@@ -843,11 +1143,6 @@ class Tracer:
843
1143
  span.record_output(result)
844
1144
 
845
1145
  return result
846
- finally:
847
- # Only save and cleanup if this is the root observe call
848
- if self.depth == 0:
849
- trace.save(empty_save=False, overwrite=overwrite)
850
- self._current_trace = None
851
1146
 
852
1147
  return wrapper
853
1148
 
@@ -856,24 +1151,36 @@ class Tracer:
856
1151
  Decorator to trace function execution with detailed entry/exit information.
857
1152
  """
858
1153
  if func is None:
859
- return lambda f: self.observe(f, name=name, span_type=span_type)
1154
+ return lambda f: self.score(f, scorers=scorers, model=model, log_results=log_results, name=name, span_type=span_type)
860
1155
 
861
1156
  if asyncio.iscoroutinefunction(func):
862
1157
  @functools.wraps(func)
863
1158
  async def async_wrapper(*args, **kwargs):
864
- if self._current_trace:
865
- self._current_trace.async_evaluate(scorers=[scorers], input=args, actual_output=kwargs, model=model, log_results=log_results)
1159
+ # Get current trace from contextvars
1160
+ current_trace = current_trace_var.get()
1161
+ if current_trace and scorers:
1162
+ current_trace.async_evaluate(scorers=scorers, input=args, actual_output=kwargs, model=model, log_results=log_results)
1163
+ return await func(*args, **kwargs)
866
1164
  return async_wrapper
867
1165
  else:
868
1166
  @functools.wraps(func)
869
1167
  def wrapper(*args, **kwargs):
870
- if self._current_trace:
871
- self._current_trace.async_evaluate(scorers=[scorers], input=args, actual_output=kwargs, model="gpt-4o-mini", log_results=True)
1168
+ # Get current trace from contextvars
1169
+ current_trace = current_trace_var.get()
1170
+ if current_trace and scorers:
1171
+ current_trace.async_evaluate(scorers=scorers, input=args, actual_output=kwargs, model=model, log_results=log_results)
1172
+ return func(*args, **kwargs)
872
1173
  return wrapper
873
1174
 
874
1175
  def async_evaluate(self, *args, **kwargs):
875
- if self._current_trace:
876
- self._current_trace.async_evaluate(*args, **kwargs)
1176
+ if not self.enable_evaluations:
1177
+ return
1178
+
1179
+ # Get current trace from context
1180
+ current_trace = current_trace_var.get()
1181
+
1182
+ if current_trace:
1183
+ current_trace.async_evaluate(*args, **kwargs)
877
1184
  else:
878
1185
  warnings.warn("No trace found, skipping evaluation")
879
1186
 
@@ -887,14 +1194,14 @@ def wrap(client: Any) -> Any:
887
1194
  span_name, original_create = _get_client_config(client)
888
1195
 
889
1196
  def traced_create(*args, **kwargs):
890
- # Get the current tracer instance (might be created after client was wrapped)
891
- tracer = Tracer._instance
1197
+ # Get the current trace from contextvars
1198
+ current_trace = current_trace_var.get()
892
1199
 
893
- # Skip tracing if no tracer exists or no active trace
894
- if not tracer or not tracer._current_trace:
1200
+ # Skip tracing if no active trace
1201
+ if not current_trace:
895
1202
  return original_create(*args, **kwargs)
896
1203
 
897
- with tracer._current_trace.span(span_name, span_type="llm") as span:
1204
+ with current_trace.span(span_name, span_type="llm") as span:
898
1205
  # Format and record the input parameters
899
1206
  input_data = _format_input_data(client, **kwargs)
900
1207
  span.record_input(input_data)
@@ -986,4 +1293,59 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
986
1293
  "output_tokens": response.usage.output_tokens,
987
1294
  "total_tokens": response.usage.input_tokens + response.usage.output_tokens
988
1295
  }
989
- }
1296
+ }
1297
+
1298
+ # Add a global context-preserving gather function
1299
+ # async def trace_gather(*coroutines, return_exceptions=False): # REMOVED
1300
+ # """ # REMOVED
1301
+ # A wrapper around asyncio.gather that ensures the trace context # REMOVED
1302
+ # is available within the gathered coroutines using contextvars.copy_context. # REMOVED
1303
+ # """ # REMOVED
1304
+ # # Get the original asyncio.gather (if we patched it) # REMOVED
1305
+ # original_gather = getattr(asyncio, "_original_gather", asyncio.gather) # REMOVED
1306
+ # # REMOVED
1307
+ # # Use contextvars.copy_context() to ensure context propagation # REMOVED
1308
+ # ctx = contextvars.copy_context() # REMOVED
1309
+ # # REMOVED
1310
+ # # Wrap the gather call within the copied context # REMOVED
1311
+ # return await ctx.run(original_gather, *coroutines, return_exceptions=return_exceptions) # REMOVED
1312
+
1313
+ # Store the original gather and apply the patch *once*
1314
+ # global _original_gather_stored # REMOVED
1315
+ # if not globals().get('_original_gather_stored'): # REMOVED
1316
+ # # Check if asyncio.gather is already our wrapper to prevent double patching # REMOVED
1317
+ # if asyncio.gather.__name__ != 'trace_gather': # REMOVED
1318
+ # asyncio._original_gather = asyncio.gather # REMOVED
1319
+ # asyncio.gather = trace_gather # REMOVED
1320
+ # _original_gather_stored = True # REMOVED
1321
+
1322
+ # Add the new TraceThreadPoolExecutor class
1323
class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
    """
    A ThreadPoolExecutor that runs each task inside a copy of the submitter's context.

    Worker threads generally run with their own context, separate from the
    thread that submitted the work. As a result, context variables such as
    ``current_trace_var`` and ``current_span_var`` would otherwise read as
    their defaults inside pooled tasks.

    This subclass snapshots the caller's context with
    ``contextvars.copy_context()`` at submission time. It then executes the
    task via ``Context.run``, so spans created in worker threads attach to the
    trace/span that was active when the work was submitted.

    Each task receives its own copy of the context. ``ContextVar.set`` calls
    made inside a task therefore do not leak back into the submitter or into
    other tasks.
    """

    def submit(self, fn, /, *args, **kwargs):
        """
        Schedule ``fn(*args, **kwargs)`` to run in a copy of the current context.

        Args:
            fn: The callable to execute in a worker thread.
            *args: Positional arguments forwarded to ``fn``.
            **kwargs: Keyword arguments forwarded to ``fn``.

        Returns:
            concurrent.futures.Future: A future representing the result of ``fn``.

        Raises:
            RuntimeError: Propagated from ``ThreadPoolExecutor.submit`` if the
                executor has already been shut down.
        """
        # Snapshot the context *now*, in the submitting thread. Capturing it
        # later inside the worker would see the worker's own context instead.
        ctx = contextvars.copy_context()

        # Context.run(callable, *args, **kwargs) forwards its extra arguments
        # to the callable, so no functools.partial is needed. A fresh copy per
        # task also matters: a single Context cannot be entered by more than
        # one thread at the same time.
        return super().submit(ctx.run, fn, *args, **kwargs)

    # Note: ``map`` is inherited from ``concurrent.futures.Executor``, which
    # schedules every call through ``self.submit``. It therefore gets the same
    # context propagation without needing its own override.