judgeval 0.0.26__py3-none-any.whl → 0.0.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
judgeval/common/tracer.py CHANGED
@@ -10,11 +10,12 @@ import os
10
10
  import time
11
11
  import uuid
12
12
  import warnings
13
+ import contextvars
13
14
  from contextlib import contextmanager
14
15
  from dataclasses import dataclass, field
15
16
  from datetime import datetime
16
17
  from http import HTTPStatus
17
- from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, TypeAlias, Union
18
+ from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, TypeAlias, Union, Callable, Awaitable
18
19
  from rich import print as rprint
19
20
 
20
21
  # Third-party imports
@@ -45,6 +46,12 @@ from judgeval.rules import Rule
45
46
  from judgeval.evaluation_run import EvaluationRun
46
47
  from judgeval.data.result import ScoringResult
47
48
 
49
+ # Standard library imports needed for the new class
50
+ import concurrent.futures
51
+
52
+ # Define context variables for tracking the current trace and the current span within a trace
53
+ current_trace_var = contextvars.ContextVar('current_trace', default=None)
54
+ current_span_var = contextvars.ContextVar('current_span', default=None) # NEW: ContextVar for the active span name
48
55
 
49
56
  # Define type aliases for better code readability and maintainability
50
57
  ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic] # Supported API clients
@@ -63,6 +70,7 @@ class TraceEntry:
63
70
  """
64
71
  type: TraceEntryType
65
72
  function: str # Name of the function being traced
73
+ span_id: str # Unique ID for this specific span instance
66
74
  depth: int # Indentation level for nested calls
67
75
  message: str # Human-readable description
68
76
  timestamp: float # Unix timestamp when entry was created
@@ -72,20 +80,28 @@ class TraceEntry:
72
80
  inputs: dict = field(default_factory=dict)
73
81
  span_type: SpanType = "span"
74
82
  evaluation_runs: List[Optional[EvaluationRun]] = field(default=None)
83
+ parent_span_id: Optional[str] = None # ID of the parent span instance
75
84
 
76
85
  def print_entry(self):
86
+ """Print a trace entry with proper formatting and parent relationship information."""
77
87
  indent = " " * self.depth
88
+
78
89
  if self.type == "enter":
79
- print(f"{indent}→ {self.function} (trace: {self.message})")
90
+ # Format parent info if present
91
+ parent_info = f" (parent_id: {self.parent_span_id})" if self.parent_span_id else ""
92
+ print(f"{indent}→ {self.function} (id: {self.span_id}){parent_info} (trace: {self.message})")
80
93
  elif self.type == "exit":
81
- print(f"{indent}← {self.function} ({self.duration:.3f}s)")
94
+ print(f"{indent}← {self.function} (id: {self.span_id}) ({self.duration:.3f}s)")
82
95
  elif self.type == "output":
83
- print(f"{indent}Output: {self.output}")
96
+ # Format output to align properly
97
+ output_str = str(self.output)
98
+ print(f"{indent}Output (for id: {self.span_id}): {output_str}")
84
99
  elif self.type == "input":
85
- print(f"{indent}Input: {self.inputs}")
100
+ # Format inputs to align properly
101
+ print(f"{indent}Input (for id: {self.span_id}): {self.inputs}")
86
102
  elif self.type == "evaluation":
87
103
  for evaluation_run in self.evaluation_runs:
88
- print(f"{indent}Evaluation: {evaluation_run.model_dump()}")
104
+ print(f"{indent}Evaluation (for id: {self.span_id}): {evaluation_run.model_dump()}")
89
105
 
90
106
  def _serialize_inputs(self) -> dict:
91
107
  """Helper method to serialize input data safely.
@@ -144,6 +160,7 @@ class TraceEntry:
144
160
  return {
145
161
  "type": self.type,
146
162
  "function": self.function,
163
+ "span_id": self.span_id,
147
164
  "depth": self.depth,
148
165
  "message": self.message,
149
166
  "timestamp": self.timestamp,
@@ -151,7 +168,8 @@ class TraceEntry:
151
168
  "output": self._serialize_output(),
152
169
  "inputs": self._serialize_inputs(),
153
170
  "evaluation_runs": [evaluation_run.model_dump() for evaluation_run in self.evaluation_runs] if self.evaluation_runs else [],
154
- "span_type": self.span_type
171
+ "span_type": self.span_type,
172
+ "parent_span_id": self.parent_span_id
155
173
  }
156
174
 
157
175
  def _serialize_output(self) -> Any:
@@ -315,66 +333,79 @@ class TraceClient:
315
333
  overwrite: bool = False,
316
334
  rules: Optional[List[Rule]] = None,
317
335
  enable_monitoring: bool = True,
318
- enable_evaluations: bool = True
336
+ enable_evaluations: bool = True,
337
+ parent_trace_id: Optional[str] = None,
338
+ parent_name: Optional[str] = None
319
339
  ):
320
340
  self.name = name
321
341
  self.trace_id = trace_id or str(uuid.uuid4())
322
342
  self.project_name = project_name
323
343
  self.overwrite = overwrite
324
344
  self.tracer = tracer
325
- # Initialize rules with either provided rules or an empty list
326
345
  self.rules = rules or []
327
346
  self.enable_monitoring = enable_monitoring
328
347
  self.enable_evaluations = enable_evaluations
329
-
348
+ self.parent_trace_id = parent_trace_id
349
+ self.parent_name = parent_name
330
350
  self.client: JudgmentClient = tracer.client
331
351
  self.entries: List[TraceEntry] = []
332
352
  self.start_time = time.time()
333
- self.span_type = None
334
- self._current_span: Optional[TraceEntry] = None
335
- self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id) # Manages DB operations for trace data
336
- self.visited_nodes = [] # Track nodes visited through langgraph_node spans
337
- self.executed_tools = [] # Track tools executed through tool spans
338
- self.executed_node_tools = [] # Track node:tool combinations
353
+ self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id)
354
+ self.visited_nodes = []
355
+ self.executed_tools = []
356
+ self.executed_node_tools = []
357
+ self._span_depths: Dict[str, int] = {} # NEW: To track depth of active spans
339
358
 
340
359
  @contextmanager
341
360
  def span(self, name: str, span_type: SpanType = "span"):
342
- """Context manager for creating a trace span"""
361
+ """Context manager for creating a trace span, managing the current span via contextvars"""
343
362
  start_time = time.time()
344
363
 
345
- # Record span entry
346
- self.add_entry(TraceEntry(
364
+ # Generate a unique ID for *this specific span invocation*
365
+ span_id = str(uuid.uuid4())
366
+
367
+ parent_span_id = current_span_var.get() # Get ID of the parent span from context var
368
+ token = current_span_var.set(span_id) # Set *this* span's ID as the current one
369
+
370
+ current_depth = 0
371
+ if parent_span_id and parent_span_id in self._span_depths:
372
+ current_depth = self._span_depths[parent_span_id] + 1
373
+
374
+ self._span_depths[span_id] = current_depth # Store depth by span_id
375
+
376
+ entry = TraceEntry(
347
377
  type="enter",
348
378
  function=name,
349
- depth=self.tracer.depth,
379
+ span_id=span_id, # Use the generated span_id
380
+ depth=current_depth,
350
381
  message=name,
351
382
  timestamp=start_time,
352
- span_type=span_type
353
- ))
354
-
355
- # Increment nested depth and set current span
356
- self.tracer.depth += 1
357
- prev_span = self._current_span
358
- self._current_span = name
383
+ span_type=span_type,
384
+ parent_span_id=parent_span_id # Use the parent_id from context var
385
+ )
386
+ self.add_entry(entry)
359
387
 
360
388
  try:
361
389
  yield self
362
390
  finally:
363
- self.tracer.depth -= 1
364
391
  duration = time.time() - start_time
365
-
366
- # Record span exit
392
+ exit_depth = self._span_depths.get(span_id, 0) # Get depth using this span's ID
367
393
  self.add_entry(TraceEntry(
368
394
  type="exit",
369
395
  function=name,
370
- depth=self.tracer.depth,
396
+ span_id=span_id, # Use the same span_id for exit
397
+ depth=exit_depth,
371
398
  message=f"← {name}",
372
399
  timestamp=time.time(),
373
400
  duration=duration,
374
401
  span_type=span_type
375
402
  ))
376
- self._current_span = prev_span
377
-
403
+ # Clean up depth tracking for this span_id
404
+ if span_id in self._span_depths:
405
+ del self._span_depths[span_id]
406
+ # Reset context var
407
+ current_span_var.reset(token)
408
+
378
409
  def async_evaluate(
379
410
  self,
380
411
  scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
@@ -457,7 +488,7 @@ class TraceClient:
457
488
  log_results=log_results,
458
489
  project_name=self.project_name,
459
490
  eval_name=f"{self.name.capitalize()}-"
460
- f"{self._current_span}-"
491
+ f"{current_span_var.get()}-"
461
492
  f"[{','.join(scorer.score_type.capitalize() for scorer in loaded_scorers)}]",
462
493
  examples=[example],
463
494
  scorers=loaded_scorers,
@@ -471,24 +502,30 @@ class TraceClient:
471
502
  self.add_eval_run(eval_run, start_time) # Pass start_time to record_evaluation
472
503
 
473
504
  def add_eval_run(self, eval_run: EvaluationRun, start_time: float):
474
- """
475
- Add evaluation run data to the trace
476
-
477
- Args:
478
- eval_run (EvaluationRun): The evaluation run to add to the trace
479
- start_time (float): The start time of the evaluation run
480
- """
481
- if self._current_span:
482
- duration = time.time() - start_time # Calculate duration from start_time
483
-
484
- prev_entry = self.entries[-1]
505
+ current_span_id = current_span_var.get()
506
+ if current_span_id:
507
+ duration = time.time() - start_time
508
+ prev_entry = self.entries[-1] if self.entries else None
509
+ # Determine function name based on previous entry or context var (less ideal)
510
+ function_name = "unknown_function" # Default
511
+ if prev_entry and prev_entry.span_type == "llm":
512
+ function_name = prev_entry.function
513
+ else:
514
+ # Try to find the function name associated with the current span_id
515
+ for entry in reversed(self.entries):
516
+ if entry.span_id == current_span_id and entry.type == 'enter':
517
+ function_name = entry.function
518
+ break
485
519
 
486
- # Select the last entry in the trace if it's an LLM call, otherwise use the current span
520
+ # Get depth for the current span
521
+ current_depth = self._span_depths.get(current_span_id, 0)
522
+
487
523
  self.add_entry(TraceEntry(
488
524
  type="evaluation",
489
- function=prev_entry.function if prev_entry.span_type == "llm" else self._current_span,
490
- depth=self.tracer.depth,
491
- message=f"Evaluation results for {self._current_span}",
525
+ function=function_name,
526
+ span_id=current_span_id, # Associate with current span
527
+ depth=current_depth,
528
+ message=f"Evaluation results for {function_name}",
492
529
  timestamp=time.time(),
493
530
  evaluation_runs=[eval_run],
494
531
  duration=duration,
@@ -496,16 +533,26 @@ class TraceClient:
496
533
  ))
497
534
 
498
535
  def record_input(self, inputs: dict):
499
- """Record input parameters for the current span"""
500
- if self._current_span:
536
+ current_span_id = current_span_var.get()
537
+ if current_span_id:
538
+ entry_span_type = "span"
539
+ current_depth = self._span_depths.get(current_span_id, 0)
540
+ function_name = "unknown_function" # Default
541
+ for entry in reversed(self.entries):
542
+ if entry.span_id == current_span_id and entry.type == 'enter':
543
+ entry_span_type = entry.span_type
544
+ function_name = entry.function
545
+ break
546
+
501
547
  self.add_entry(TraceEntry(
502
548
  type="input",
503
- function=self._current_span,
504
- depth=self.tracer.depth,
505
- message=f"Inputs to {self._current_span}",
549
+ function=function_name,
550
+ span_id=current_span_id, # Use current span_id
551
+ depth=current_depth,
552
+ message=f"Inputs to {function_name}",
506
553
  timestamp=time.time(),
507
554
  inputs=inputs,
508
- span_type=self.span_type
555
+ span_type=entry_span_type
509
556
  ))
510
557
 
511
558
  async def _update_coroutine_output(self, entry: TraceEntry, coroutine: Any):
@@ -519,21 +566,30 @@ class TraceClient:
519
566
  raise
520
567
 
521
568
  def record_output(self, output: Any):
522
- """Record output for the current span"""
523
- if self._current_span:
569
+ current_span_id = current_span_var.get()
570
+ if current_span_id:
571
+ entry_span_type = "span"
572
+ current_depth = self._span_depths.get(current_span_id, 0)
573
+ function_name = "unknown_function" # Default
574
+ for entry in reversed(self.entries):
575
+ if entry.span_id == current_span_id and entry.type == 'enter':
576
+ entry_span_type = entry.span_type
577
+ function_name = entry.function
578
+ break
579
+
524
580
  entry = TraceEntry(
525
581
  type="output",
526
- function=self._current_span,
527
- depth=self.tracer.depth,
528
- message=f"Output from {self._current_span}",
582
+ function=function_name,
583
+ span_id=current_span_id, # Use current span_id
584
+ depth=current_depth,
585
+ message=f"Output from {function_name}",
529
586
  timestamp=time.time(),
530
587
  output="<pending>" if inspect.iscoroutine(output) else output,
531
- span_type=self.span_type
588
+ span_type=entry_span_type
532
589
  )
533
590
  self.add_entry(entry)
534
591
 
535
592
  if inspect.iscoroutine(output):
536
- # Create a task to update the output once the coroutine completes
537
593
  asyncio.create_task(self._update_coroutine_output(entry, output))
538
594
 
539
595
  def add_entry(self, entry: TraceEntry):
@@ -546,6 +602,58 @@ class TraceClient:
546
602
  for entry in self.entries:
547
603
  entry.print_entry()
548
604
 
605
+ def print_hierarchical(self):
606
+ """Print the trace in a hierarchical structure based on parent-child relationships"""
607
+ # First, build a map of spans
608
+ spans = {}
609
+ root_spans = []
610
+
611
+ # Collect all enter events first
612
+ for entry in self.entries:
613
+ if entry.type == "enter":
614
+ spans[entry.function] = {
615
+ "name": entry.function,
616
+ "depth": entry.depth,
617
+ "parent_id": entry.parent_span_id,
618
+ "children": []
619
+ }
620
+
621
+ # If no parent, it's a root span
622
+ if not entry.parent_span_id:
623
+ root_spans.append(entry.function)
624
+ elif entry.parent_span_id not in spans:
625
+ # If parent doesn't exist yet, temporarily treat as root
626
+ # (we'll fix this later)
627
+ root_spans.append(entry.function)
628
+
629
+ # Build parent-child relationships
630
+ for span_name, span in spans.items():
631
+ parent = span["parent_id"]
632
+ if parent and parent in spans:
633
+ spans[parent]["children"].append(span_name)
634
+ # Remove from root spans if it was temporarily there
635
+ if span_name in root_spans:
636
+ root_spans.remove(span_name)
637
+
638
+ # Now print the hierarchy
639
+ def print_span(span_name, level=0):
640
+ if span_name not in spans:
641
+ return
642
+
643
+ span = spans[span_name]
644
+ indent = " " * level
645
+ parent_info = f" (parent_id: {span['parent_id']})" if span["parent_id"] else ""
646
+ print(f"{indent}→ {span_name}{parent_info}")
647
+
648
+ # Print children
649
+ for child in span["children"]:
650
+ print_span(child, level + 1)
651
+
652
+ # Print starting with root spans
653
+ print("\nHierarchical Trace Structure:")
654
+ for root in root_spans:
655
+ print_span(root)
656
+
549
657
  def get_duration(self) -> float:
550
658
  """
551
659
  Get the total duration of this trace
@@ -554,56 +662,122 @@ class TraceClient:
554
662
 
555
663
  def condense_trace(self, entries: List[dict]) -> List[dict]:
556
664
  """
557
- Condenses trace entries into a single entry for each function call.
665
+ Condenses trace entries into a single entry for each span instance,
666
+ preserving parent-child span relationships using span_id and parent_span_id.
558
667
  """
559
- condensed = []
560
- active_functions = [] # Stack to track nested function calls
561
- function_entries = {} # Store entries for each function
668
+ spans_by_id: Dict[str, dict] = {}
669
+
670
+ # First pass: Group entries by span_id and gather data
671
+ for entry in entries:
672
+ span_id = entry.get("span_id")
673
+ if not span_id:
674
+ continue # Skip entries without a span_id (should not happen)
562
675
 
563
- for i, entry in enumerate(entries):
564
- function = entry["function"]
565
-
566
676
  if entry["type"] == "enter":
567
- # Initialize new function entry
568
- function_entries[function] = {
569
- "depth": entry["depth"],
570
- "function": function,
571
- "timestamp": entry["timestamp"],
572
- "inputs": None,
573
- "output": None,
574
- "evaluation_runs": [],
575
- "span_type": entry.get("span_type", "span")
576
- }
577
- active_functions.append(function)
578
-
579
- elif entry["type"] == "exit" and function in active_functions:
580
- # Complete function entry
581
- current_entry = function_entries[function]
582
- current_entry["duration"] = entry["timestamp"] - current_entry["timestamp"]
583
- condensed.append(current_entry)
584
- active_functions.remove(function)
585
- # del function_entries[function]
586
-
587
- # The OR condition is to handle the LLM client case.
588
- # LLM client is a special case where we exit the span, so when we attach evaluations to it,
589
- # we have to check if the previous entry is an LLM call.
590
- elif function in active_functions or entry["type"] == "evaluation" and entries[i-1]["function"] == entry["function"]:
591
- # Update existing function entry with additional data
592
- current_entry = function_entries[function]
677
+ if span_id not in spans_by_id:
678
+ spans_by_id[span_id] = {
679
+ "span_id": span_id,
680
+ "function": entry["function"],
681
+ "depth": entry["depth"], # Use the depth recorded at entry time
682
+ "timestamp": entry["timestamp"],
683
+ "parent_span_id": entry.get("parent_span_id"),
684
+ "span_type": entry.get("span_type", "span"),
685
+ "inputs": None,
686
+ "output": None,
687
+ "evaluation_runs": [],
688
+ "duration": None
689
+ }
690
+ # Handle potential duplicate enter events if necessary (e.g., log warning)
691
+
692
+ elif span_id in spans_by_id:
693
+ current_span_data = spans_by_id[span_id]
593
694
 
594
695
  if entry["type"] == "input" and entry["inputs"]:
595
- current_entry["inputs"] = entry["inputs"]
596
-
597
- if entry["type"] == "output" and entry["output"]:
598
- current_entry["output"] = entry["output"]
599
-
600
- if entry["type"] == "evaluation" and entry["evaluation_runs"]:
601
- current_entry["evaluation_runs"] = entry["evaluation_runs"]
696
+ # Merge inputs if multiple are recorded, or just assign
697
+ if current_span_data["inputs"] is None:
698
+ current_span_data["inputs"] = entry["inputs"]
699
+ elif isinstance(current_span_data["inputs"], dict) and isinstance(entry["inputs"], dict):
700
+ current_span_data["inputs"].update(entry["inputs"])
701
+ # Add more sophisticated merging if needed
602
702
 
603
- # Sort by timestamp
604
- condensed.sort(key=lambda x: x["timestamp"])
605
-
606
- return condensed
703
+ elif entry["type"] == "output" and "output" in entry:
704
+ current_span_data["output"] = entry["output"]
705
+
706
+ elif entry["type"] == "evaluation" and entry.get("evaluation_runs"):
707
+ if current_span_data.get("evaluation_runs") is None:
708
+ current_span_data["evaluation_runs"] = []
709
+ current_span_data["evaluation_runs"].extend(entry["evaluation_runs"])
710
+
711
+ elif entry["type"] == "exit":
712
+ if current_span_data["duration"] is None: # Calculate duration only once
713
+ start_time = current_span_data.get("timestamp", entry["timestamp"])
714
+ current_span_data["duration"] = entry["timestamp"] - start_time
715
+ # Update depth if exit depth is different (though current span() implementation keeps it same)
716
+ # current_span_data["depth"] = entry["depth"]
717
+
718
+ # Convert dictionary to a list initially for easier access
719
+ spans_list = list(spans_by_id.values())
720
+
721
+ # Build tree structure (adjacency list) and find roots
722
+ children_map: Dict[Optional[str], List[dict]] = {}
723
+ roots = []
724
+ span_map = {span['span_id']: span for span in spans_list} # Map for quick lookup
725
+
726
+ for span in spans_list:
727
+ parent_id = span.get("parent_span_id")
728
+ if parent_id is None:
729
+ roots.append(span)
730
+ else:
731
+ if parent_id not in children_map:
732
+ children_map[parent_id] = []
733
+ children_map[parent_id].append(span)
734
+
735
+ # Sort roots by timestamp
736
+ roots.sort(key=lambda x: x.get("timestamp", 0))
737
+
738
+ # Perform depth-first traversal to get the final sorted list
739
+ sorted_condensed_list = []
740
+ visited = set() # To handle potential cycles, though unlikely with UUIDs
741
+
742
+ def dfs(span_data):
743
+ span_id = span_data['span_id']
744
+ if span_id in visited:
745
+ return # Avoid infinite loops in case of cycles
746
+ visited.add(span_id)
747
+
748
+ sorted_condensed_list.append(span_data) # Add parent before children
749
+
750
+ # Get children, sort them by timestamp, and visit them
751
+ span_children = children_map.get(span_id, [])
752
+ span_children.sort(key=lambda x: x.get("timestamp", 0))
753
+ for child in span_children:
754
+ # Ensure the child exists in our map before recursing
755
+ if child['span_id'] in span_map:
756
+ dfs(child)
757
+ else:
758
+ # This case might indicate an issue, but we'll add the child directly
759
+ # if its parent was processed but the child itself wasn't in the initial list?
760
+ # Or if the child's 'enter' event was missing. For robustness, add it.
761
+ if child['span_id'] not in visited:
762
+ visited.add(child['span_id'])
763
+ sorted_condensed_list.append(child)
764
+
765
+
766
+ # Start DFS from each root
767
+ for root_span in roots:
768
+ if root_span['span_id'] not in visited:
769
+ dfs(root_span)
770
+
771
+ # Handle spans that might not have been reachable from roots (orphans)
772
+ # Though ideally, all spans should descend from a root.
773
+ for span_data in spans_list:
774
+ if span_data['span_id'] not in visited:
775
+ # Decide how to handle orphans, maybe append them at the end sorted by time?
776
+ # For now, let's just add them to ensure they aren't lost.
777
+ sorted_condensed_list.append(span_data)
778
+
779
+
780
+ return sorted_condensed_list
607
781
 
608
782
  def save(self, empty_save: bool = False, overwrite: bool = False) -> Tuple[str, dict]:
609
783
  """
@@ -689,7 +863,9 @@ class TraceClient:
689
863
  },
690
864
  "entries": condensed_entries,
691
865
  "empty_save": empty_save,
692
- "overwrite": overwrite
866
+ "overwrite": overwrite,
867
+ "parent_trace_id": self.parent_trace_id,
868
+ "parent_name": self.parent_name
693
869
  }
694
870
  # Execute asynchrous evaluation in the background
695
871
  if not empty_save: # Only send to RabbitMQ if the trace is not empty
@@ -745,7 +921,6 @@ class Tracer:
745
921
  self.project_name: str = project_name
746
922
  self.client: JudgmentClient = JudgmentClient(judgment_api_key=api_key)
747
923
  self.organization_id: str = organization_id
748
- self.depth: int = 0
749
924
  self._current_trace: Optional[str] = None
750
925
  self.rules: List[Rule] = rules or [] # Store rules at tracer level
751
926
  self.initialized: bool = True
@@ -770,6 +945,15 @@ class Tracer:
770
945
  """Start a new trace context using a context manager"""
771
946
  trace_id = str(uuid.uuid4())
772
947
  project = project_name if project_name is not None else self.project_name
948
+
949
+ # Get parent trace info from context
950
+ parent_trace = current_trace_var.get()
951
+ parent_trace_id = None
952
+ parent_name = None
953
+
954
+ if parent_trace:
955
+ parent_trace_id = parent_trace.trace_id
956
+ parent_name = parent_trace.name
773
957
 
774
958
  trace = TraceClient(
775
959
  self,
@@ -779,10 +963,13 @@ class Tracer:
779
963
  overwrite=overwrite,
780
964
  rules=self.rules, # Pass combined rules to the trace client
781
965
  enable_monitoring=self.enable_monitoring,
782
- enable_evaluations=self.enable_evaluations
966
+ enable_evaluations=self.enable_evaluations,
967
+ parent_trace_id=parent_trace_id,
968
+ parent_name=parent_name
783
969
  )
784
- prev_trace = self._current_trace
785
- self._current_trace = trace
970
+
971
+ # Set the current trace in context variables
972
+ token = current_trace_var.set(trace)
786
973
 
787
974
  # Automatically create top-level span
788
975
  with trace.span(name or "unnamed_trace") as span:
@@ -791,13 +978,14 @@ class Tracer:
791
978
  trace.save(empty_save=True, overwrite=overwrite)
792
979
  yield trace
793
980
  finally:
794
- self._current_trace = prev_trace
981
+ # Reset the context variable
982
+ current_trace_var.reset(token)
795
983
 
796
984
  def get_current_trace(self) -> Optional[TraceClient]:
797
985
  """
798
- Get the current trace context
986
+ Get the current trace context from contextvars
799
987
  """
800
- return self._current_trace
988
+ return current_trace_var.get()
801
989
 
802
990
  def observe(self, func=None, *, name=None, span_type: SpanType = "span", project_name: str = None, overwrite: bool = False):
803
991
  """
@@ -823,20 +1011,56 @@ class Tracer:
823
1011
  if asyncio.iscoroutinefunction(func):
824
1012
  @functools.wraps(func)
825
1013
  async def async_wrapper(*args, **kwargs):
826
- # If there's already a trace, use it. Otherwise create a new one
827
- if self._current_trace:
828
- trace = self._current_trace
829
- else:
1014
+ # Get current trace from context
1015
+ current_trace = current_trace_var.get()
1016
+
1017
+ # If there's no current trace, create a root trace
1018
+ if not current_trace:
830
1019
  trace_id = str(uuid.uuid4())
831
- trace_name = func.__name__
832
1020
  project = project_name if project_name is not None else self.project_name
833
- trace = TraceClient(self, trace_id, trace_name, project_name=project, overwrite=overwrite, rules=self.rules, enable_monitoring=self.enable_monitoring, enable_evaluations=self.enable_evaluations)
834
- self._current_trace = trace
835
- # Only save empty trace for the root call
836
- trace.save(empty_save=True, overwrite=overwrite)
837
-
838
- try:
839
- with trace.span(span_name, span_type=span_type) as span:
1021
+
1022
+ # Create a new trace client to serve as the root
1023
+ current_trace = TraceClient(
1024
+ self,
1025
+ trace_id,
1026
+ span_name, # MODIFIED: Use span_name directly
1027
+ project_name=project,
1028
+ overwrite=overwrite,
1029
+ rules=self.rules,
1030
+ enable_monitoring=self.enable_monitoring,
1031
+ enable_evaluations=self.enable_evaluations
1032
+ )
1033
+
1034
+ # Save empty trace and set trace context
1035
+ current_trace.save(empty_save=True, overwrite=overwrite)
1036
+ trace_token = current_trace_var.set(current_trace)
1037
+
1038
+ try:
1039
+ # Use span for the function execution within the root trace
1040
+ # This sets the current_span_var
1041
+ with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
1042
+ # Record inputs
1043
+ span.record_input({
1044
+ 'args': str(args),
1045
+ 'kwargs': kwargs
1046
+ })
1047
+
1048
+ # Execute function
1049
+ result = await func(*args, **kwargs)
1050
+
1051
+ # Record output
1052
+ span.record_output(result)
1053
+
1054
+ # Save the completed trace
1055
+ current_trace.save(empty_save=False, overwrite=overwrite)
1056
+ return result
1057
+ finally:
1058
+ # Reset trace context (span context resets automatically)
1059
+ current_trace_var.reset(trace_token)
1060
+ else:
1061
+ # Already have a trace context, just create a span in it
1062
+ # The span method handles current_span_var
1063
+ with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
840
1064
  # Record inputs
841
1065
  span.record_input({
842
1066
  'args': str(args),
@@ -850,30 +1074,62 @@ class Tracer:
850
1074
  span.record_output(result)
851
1075
 
852
1076
  return result
853
- finally:
854
- # Only save and cleanup if this is the root observe call
855
- if self.depth == 0:
856
- trace.save(empty_save=False, overwrite=overwrite)
857
- self._current_trace = None
858
1077
 
859
1078
  return async_wrapper
860
1079
  else:
1080
+ # Non-async function implementation remains unchanged
861
1081
  @functools.wraps(func)
862
1082
  def wrapper(*args, **kwargs):
863
- # If there's already a trace, use it. Otherwise create a new one
864
- if self._current_trace:
865
- trace = self._current_trace
866
- else:
1083
+ # Get current trace from context
1084
+ current_trace = current_trace_var.get()
1085
+
1086
+ # If there's no current trace, create a root trace
1087
+ if not current_trace:
867
1088
  trace_id = str(uuid.uuid4())
868
- trace_name = func.__name__
869
1089
  project = project_name if project_name is not None else self.project_name
870
- trace = TraceClient(self, trace_id, trace_name, project_name=project, overwrite=overwrite, rules=self.rules, enable_monitoring=self.enable_monitoring)
871
- self._current_trace = trace
872
- # Only save empty trace for the root call
873
- trace.save(empty_save=True, overwrite=overwrite)
874
-
875
- try:
876
- with trace.span(span_name, span_type=span_type) as span:
1090
+
1091
+ # Create a new trace client to serve as the root
1092
+ current_trace = TraceClient(
1093
+ self,
1094
+ trace_id,
1095
+ span_name, # MODIFIED: Use span_name directly
1096
+ project_name=project,
1097
+ overwrite=overwrite,
1098
+ rules=self.rules,
1099
+ enable_monitoring=self.enable_monitoring,
1100
+ enable_evaluations=self.enable_evaluations
1101
+ )
1102
+
1103
+ # Save empty trace and set trace context
1104
+ current_trace.save(empty_save=True, overwrite=overwrite)
1105
+ trace_token = current_trace_var.set(current_trace)
1106
+
1107
+ try:
1108
+ # Use span for the function execution within the root trace
1109
+ # This sets the current_span_var
1110
+ with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
1111
+ # Record inputs
1112
+ span.record_input({
1113
+ 'args': str(args),
1114
+ 'kwargs': kwargs
1115
+ })
1116
+
1117
+ # Execute function
1118
+ result = func(*args, **kwargs)
1119
+
1120
+ # Record output
1121
+ span.record_output(result)
1122
+
1123
+ # Save the completed trace
1124
+ current_trace.save(empty_save=False, overwrite=overwrite)
1125
+ return result
1126
+ finally:
1127
+ # Reset trace context (span context resets automatically)
1128
+ current_trace_var.reset(trace_token)
1129
+ else:
1130
+ # Already have a trace context, just create a span in it
1131
+ # The span method handles current_span_var
1132
+ with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
877
1133
  # Record inputs
878
1134
  span.record_input({
879
1135
  'args': str(args),
@@ -887,11 +1143,6 @@ class Tracer:
887
1143
  span.record_output(result)
888
1144
 
889
1145
  return result
890
- finally:
891
- # Only save and cleanup if this is the root observe call
892
- if self.depth == 0:
893
- trace.save(empty_save=False, overwrite=overwrite)
894
- self._current_trace = None
895
1146
 
896
1147
  return wrapper
897
1148
 
@@ -900,27 +1151,36 @@ class Tracer:
900
1151
  Decorator to trace function execution with detailed entry/exit information.
901
1152
  """
902
1153
  if func is None:
903
- return lambda f: self.observe(f, name=name, span_type=span_type)
1154
+ return lambda f: self.score(f, scorers=scorers, model=model, log_results=log_results, name=name, span_type=span_type)
904
1155
 
905
1156
  if asyncio.iscoroutinefunction(func):
906
1157
  @functools.wraps(func)
907
1158
  async def async_wrapper(*args, **kwargs):
908
- if self._current_trace:
909
- self._current_trace.async_evaluate(scorers=[scorers], input=args, actual_output=kwargs, model=model, log_results=log_results)
1159
+ # Get current trace from contextvars
1160
+ current_trace = current_trace_var.get()
1161
+ if current_trace and scorers:
1162
+ current_trace.async_evaluate(scorers=scorers, input=args, actual_output=kwargs, model=model, log_results=log_results)
1163
+ return await func(*args, **kwargs)
910
1164
  return async_wrapper
911
1165
  else:
912
1166
  @functools.wraps(func)
913
1167
  def wrapper(*args, **kwargs):
914
- if self._current_trace:
915
- self._current_trace.async_evaluate(scorers=[scorers], input=args, actual_output=kwargs, model="gpt-4o-mini", log_results=True)
1168
+ # Get current trace from contextvars
1169
+ current_trace = current_trace_var.get()
1170
+ if current_trace and scorers:
1171
+ current_trace.async_evaluate(scorers=scorers, input=args, actual_output=kwargs, model=model, log_results=log_results)
1172
+ return func(*args, **kwargs)
916
1173
  return wrapper
917
1174
 
918
1175
  def async_evaluate(self, *args, **kwargs):
919
1176
  if not self.enable_evaluations:
920
1177
  return
921
1178
 
922
- if self._current_trace:
923
- self._current_trace.async_evaluate(*args, **kwargs)
1179
+ # Get current trace from context
1180
+ current_trace = current_trace_var.get()
1181
+
1182
+ if current_trace:
1183
+ current_trace.async_evaluate(*args, **kwargs)
924
1184
  else:
925
1185
  warnings.warn("No trace found, skipping evaluation")
926
1186
 
@@ -934,14 +1194,14 @@ def wrap(client: Any) -> Any:
934
1194
  span_name, original_create = _get_client_config(client)
935
1195
 
936
1196
  def traced_create(*args, **kwargs):
937
- # Get the current tracer instance (might be created after client was wrapped)
938
- tracer = Tracer._instance
1197
+ # Get the current trace from contextvars
1198
+ current_trace = current_trace_var.get()
939
1199
 
940
- # Skip tracing if no tracer exists or no active trace
941
- if not tracer or not tracer._current_trace:
1200
+ # Skip tracing if no active trace
1201
+ if not current_trace:
942
1202
  return original_create(*args, **kwargs)
943
1203
 
944
- with tracer._current_trace.span(span_name, span_type="llm") as span:
1204
+ with current_trace.span(span_name, span_type="llm") as span:
945
1205
  # Format and record the input parameters
946
1206
  input_data = _format_input_data(client, **kwargs)
947
1207
  span.record_input(input_data)
@@ -1033,4 +1293,59 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
1033
1293
  "output_tokens": response.usage.output_tokens,
1034
1294
  "total_tokens": response.usage.input_tokens + response.usage.output_tokens
1035
1295
  }
1036
- }
1296
+ }
1297
+
1298
+ # Add a global context-preserving gather function
1299
+ # async def trace_gather(*coroutines, return_exceptions=False): # REMOVED
1300
+ # """ # REMOVED
1301
+ # A wrapper around asyncio.gather that ensures the trace context # REMOVED
1302
+ # is available within the gathered coroutines using contextvars.copy_context. # REMOVED
1303
+ # """ # REMOVED
1304
+ # # Get the original asyncio.gather (if we patched it) # REMOVED
1305
+ # original_gather = getattr(asyncio, "_original_gather", asyncio.gather) # REMOVED
1306
+ # # REMOVED
1307
+ # # Use contextvars.copy_context() to ensure context propagation # REMOVED
1308
+ # ctx = contextvars.copy_context() # REMOVED
1309
+ # # REMOVED
1310
+ # # Wrap the gather call within the copied context # REMOVED
1311
+ # return await ctx.run(original_gather, *coroutines, return_exceptions=return_exceptions) # REMOVED
1312
+
1313
+ # Store the original gather and apply the patch *once*
1314
+ # global _original_gather_stored # REMOVED
1315
+ # if not globals().get('_original_gather_stored'): # REMOVED
1316
+ # # Check if asyncio.gather is already our wrapper to prevent double patching # REMOVED
1317
+ # if asyncio.gather.__name__ != 'trace_gather': # REMOVED
1318
+ # asyncio._original_gather = asyncio.gather # REMOVED
1319
+ # asyncio.gather = trace_gather # REMOVED
1320
+ # _original_gather_stored = True # REMOVED
1321
+
1322
+ # Add the new TraceThreadPoolExecutor class
1323
class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
    """
    A ThreadPoolExecutor that runs every submitted callable inside a copy of
    the submitting thread's contextvars context.

    A plain ThreadPoolExecutor does not run tasks in the submitter's current
    context. This subclass makes context variables such as
    ``current_trace_var`` and ``current_span_var`` available to functions
    executed by the pool, so spans created inside pooled work keep their
    parent-child relationship with the active trace across thread boundaries.

    Each submission gets its own snapshot, taken at ``submit()`` time.
    Values the task sets inside the worker stay in that snapshot: they do
    not leak back into the submitting thread or into other tasks.

    ``map()`` needs no separate override. In CPython,
    ``concurrent.futures.Executor.map`` schedules its work through
    ``self.submit``, so it picks up the same context propagation.
    """

    def submit(self, fn, /, *args, **kwargs):
        """
        Schedule ``fn(*args, **kwargs)`` to run in the submitter's context.

        Args:
            fn: Callable to execute in a worker thread.
            *args: Positional arguments forwarded to ``fn``.
            **kwargs: Keyword arguments forwarded to ``fn``.

        Returns:
            concurrent.futures.Future: A future that resolves to ``fn``'s
            return value (or raises its exception). This matches the
            behaviour of ``ThreadPoolExecutor.submit``.
        """
        # Take the snapshot now, in the submitting thread, so the task sees
        # the trace/span that was active when it was scheduled. A fresh copy
        # per call is also required: one Context object cannot be entered
        # from two threads at the same time, so tasks must never share one.
        ctx = contextvars.copy_context()
        # Context.run(callable, *args, **kwargs) forwards the arguments
        # itself, so no functools.partial wrapper is needed.
        return super().submit(ctx.run, fn, *args, **kwargs)