judgeval 0.0.26__py3-none-any.whl → 0.0.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
judgeval/common/tracer.py CHANGED
@@ -10,11 +10,12 @@ import os
10
10
  import time
11
11
  import uuid
12
12
  import warnings
13
+ import contextvars
13
14
  from contextlib import contextmanager
14
15
  from dataclasses import dataclass, field
15
16
  from datetime import datetime
16
17
  from http import HTTPStatus
17
- from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, TypeAlias, Union
18
+ from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, TypeAlias, Union, Callable, Awaitable
18
19
  from rich import print as rprint
19
20
 
20
21
  # Third-party imports
@@ -45,6 +46,12 @@ from judgeval.rules import Rule
45
46
  from judgeval.evaluation_run import EvaluationRun
46
47
  from judgeval.data.result import ScoringResult
47
48
 
49
+ # Standard library imports needed for the new class
50
+ import concurrent.futures
51
+
52
+ # Define context variables for tracking the current trace and the current span within a trace
53
+ current_trace_var = contextvars.ContextVar('current_trace', default=None)
54
+ current_span_var = contextvars.ContextVar('current_span', default=None) # NEW: ContextVar for the active span name
48
55
 
49
56
  # Define type aliases for better code readability and maintainability
50
57
  ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic] # Supported API clients
@@ -63,29 +70,39 @@ class TraceEntry:
63
70
  """
64
71
  type: TraceEntryType
65
72
  function: str # Name of the function being traced
73
+ span_id: str # Unique ID for this specific span instance
66
74
  depth: int # Indentation level for nested calls
67
75
  message: str # Human-readable description
68
- timestamp: float # Unix timestamp when entry was created
76
+ # created_at: Unix timestamp when entry was created, replacing the deprecated 'timestamp' field
69
77
  duration: Optional[float] = None # Time taken (for exit/evaluation entries)
78
+ trace_id: str = None # ID of the trace this entry belongs to
70
79
  output: Any = None # Function output value
71
80
  # Use field() for mutable defaults to avoid shared state issues
72
81
  inputs: dict = field(default_factory=dict)
73
82
  span_type: SpanType = "span"
74
83
  evaluation_runs: List[Optional[EvaluationRun]] = field(default=None)
84
+ parent_span_id: Optional[str] = None # ID of the parent span instance
75
85
 
76
86
  def print_entry(self):
87
+ """Print a trace entry with proper formatting and parent relationship information."""
77
88
  indent = " " * self.depth
89
+
78
90
  if self.type == "enter":
79
- print(f"{indent}→ {self.function} (trace: {self.message})")
91
+ # Format parent info if present
92
+ parent_info = f" (parent_id: {self.parent_span_id})" if self.parent_span_id else ""
93
+ print(f"{indent}→ {self.function} (id: {self.span_id}){parent_info} (trace: {self.message})")
80
94
  elif self.type == "exit":
81
- print(f"{indent}← {self.function} ({self.duration:.3f}s)")
95
+ print(f"{indent}← {self.function} (id: {self.span_id}) ({self.duration:.3f}s)")
82
96
  elif self.type == "output":
83
- print(f"{indent}Output: {self.output}")
97
+ # Format output to align properly
98
+ output_str = str(self.output)
99
+ print(f"{indent}Output (for id: {self.span_id}): {output_str}")
84
100
  elif self.type == "input":
85
- print(f"{indent}Input: {self.inputs}")
101
+ # Format inputs to align properly
102
+ print(f"{indent}Input (for id: {self.span_id}): {self.inputs}")
86
103
  elif self.type == "evaluation":
87
104
  for evaluation_run in self.evaluation_runs:
88
- print(f"{indent}Evaluation: {evaluation_run.model_dump()}")
105
+ print(f"{indent}Evaluation (for id: {self.span_id}): {evaluation_run.model_dump()}")
89
106
 
90
107
  def _serialize_inputs(self) -> dict:
91
108
  """Helper method to serialize input data safely.
@@ -144,14 +161,17 @@ class TraceEntry:
144
161
  return {
145
162
  "type": self.type,
146
163
  "function": self.function,
164
+ "span_id": self.span_id,
165
+ "trace_id": self.trace_id,
147
166
  "depth": self.depth,
148
167
  "message": self.message,
149
- "timestamp": self.timestamp,
168
+ "created_at": datetime.fromtimestamp(self.created_at).isoformat(),
150
169
  "duration": self.duration,
151
170
  "output": self._serialize_output(),
152
171
  "inputs": self._serialize_inputs(),
153
172
  "evaluation_runs": [evaluation_run.model_dump() for evaluation_run in self.evaluation_runs] if self.evaluation_runs else [],
154
- "span_type": self.span_type
173
+ "span_type": self.span_type,
174
+ "parent_span_id": self.parent_span_id
155
175
  }
156
176
 
157
177
  def _serialize_output(self) -> Any:
@@ -210,13 +230,12 @@ class TraceManagerClient:
210
230
 
211
231
  return response.json()
212
232
 
213
- def save_trace(self, trace_data: dict, empty_save: bool):
233
+ def save_trace(self, trace_data: dict):
214
234
  """
215
235
  Saves a trace to the database
216
236
 
217
237
  Args:
218
238
  trace_data: The trace data to save
219
- empty_save: Whether to save an empty trace
220
239
  NOTE we save empty traces in order to properly handle async operations; we need something in the DB to associate the async results with
221
240
  """
222
241
  response = requests.post(
@@ -235,7 +254,7 @@ class TraceManagerClient:
235
254
  elif response.status_code != HTTPStatus.OK:
236
255
  raise ValueError(f"Failed to save trace data: {response.text}")
237
256
 
238
- if not empty_save and "ui_results_url" in response.json():
257
+ if "ui_results_url" in response.json():
239
258
  pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={response.json()['ui_results_url']}]View Trace[/link]\n"
240
259
  rprint(pretty_str)
241
260
 
@@ -315,66 +334,81 @@ class TraceClient:
315
334
  overwrite: bool = False,
316
335
  rules: Optional[List[Rule]] = None,
317
336
  enable_monitoring: bool = True,
318
- enable_evaluations: bool = True
337
+ enable_evaluations: bool = True,
338
+ parent_trace_id: Optional[str] = None,
339
+ parent_name: Optional[str] = None
319
340
  ):
320
341
  self.name = name
321
342
  self.trace_id = trace_id or str(uuid.uuid4())
322
343
  self.project_name = project_name
323
344
  self.overwrite = overwrite
324
345
  self.tracer = tracer
325
- # Initialize rules with either provided rules or an empty list
326
346
  self.rules = rules or []
327
347
  self.enable_monitoring = enable_monitoring
328
348
  self.enable_evaluations = enable_evaluations
329
-
349
+ self.parent_trace_id = parent_trace_id
350
+ self.parent_name = parent_name
330
351
  self.client: JudgmentClient = tracer.client
331
352
  self.entries: List[TraceEntry] = []
332
353
  self.start_time = time.time()
333
- self.span_type = None
334
- self._current_span: Optional[TraceEntry] = None
335
- self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id) # Manages DB operations for trace data
336
- self.visited_nodes = [] # Track nodes visited through langgraph_node spans
337
- self.executed_tools = [] # Track tools executed through tool spans
338
- self.executed_node_tools = [] # Track node:tool combinations
354
+ self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id)
355
+ self.visited_nodes = []
356
+ self.executed_tools = []
357
+ self.executed_node_tools = []
358
+ self._span_depths: Dict[str, int] = {} # NEW: To track depth of active spans
339
359
 
340
360
  @contextmanager
341
361
  def span(self, name: str, span_type: SpanType = "span"):
342
- """Context manager for creating a trace span"""
362
+ """Context manager for creating a trace span, managing the current span via contextvars"""
343
363
  start_time = time.time()
344
364
 
345
- # Record span entry
346
- self.add_entry(TraceEntry(
365
+ # Generate a unique ID for *this specific span invocation*
366
+ span_id = str(uuid.uuid4())
367
+
368
+ parent_span_id = current_span_var.get() # Get ID of the parent span from context var
369
+ token = current_span_var.set(span_id) # Set *this* span's ID as the current one
370
+
371
+ current_depth = 0
372
+ if parent_span_id and parent_span_id in self._span_depths:
373
+ current_depth = self._span_depths[parent_span_id] + 1
374
+
375
+ self._span_depths[span_id] = current_depth # Store depth by span_id
376
+
377
+ entry = TraceEntry(
347
378
  type="enter",
348
379
  function=name,
349
- depth=self.tracer.depth,
380
+ span_id=span_id, # Use the generated span_id
381
+ trace_id=self.trace_id, # Use the trace_id from the trace client
382
+ depth=current_depth,
350
383
  message=name,
351
- timestamp=start_time,
352
- span_type=span_type
353
- ))
354
-
355
- # Increment nested depth and set current span
356
- self.tracer.depth += 1
357
- prev_span = self._current_span
358
- self._current_span = name
384
+ created_at=start_time,
385
+ span_type=span_type,
386
+ parent_span_id=parent_span_id # Use the parent_id from context var
387
+ )
388
+ self.add_entry(entry)
359
389
 
360
390
  try:
361
391
  yield self
362
392
  finally:
363
- self.tracer.depth -= 1
364
393
  duration = time.time() - start_time
365
-
366
- # Record span exit
394
+ exit_depth = self._span_depths.get(span_id, 0) # Get depth using this span's ID
367
395
  self.add_entry(TraceEntry(
368
396
  type="exit",
369
397
  function=name,
370
- depth=self.tracer.depth,
398
+ span_id=span_id, # Use the same span_id for exit
399
+ trace_id=self.trace_id, # Use the trace_id from the trace client
400
+ depth=exit_depth,
371
401
  message=f"← {name}",
372
- timestamp=time.time(),
402
+ created_at=time.time(),
373
403
  duration=duration,
374
404
  span_type=span_type
375
405
  ))
376
- self._current_span = prev_span
377
-
406
+ # Clean up depth tracking for this span_id
407
+ if span_id in self._span_depths:
408
+ del self._span_depths[span_id]
409
+ # Reset context var
410
+ current_span_var.reset(token)
411
+
378
412
  def async_evaluate(
379
413
  self,
380
414
  scorers: List[Union[APIJudgmentScorer, JudgevalScorer]],
@@ -457,7 +491,7 @@ class TraceClient:
457
491
  log_results=log_results,
458
492
  project_name=self.project_name,
459
493
  eval_name=f"{self.name.capitalize()}-"
460
- f"{self._current_span}-"
494
+ f"{current_span_var.get()}-"
461
495
  f"[{','.join(scorer.score_type.capitalize() for scorer in loaded_scorers)}]",
462
496
  examples=[example],
463
497
  scorers=loaded_scorers,
@@ -465,47 +499,66 @@ class TraceClient:
465
499
  metadata={},
466
500
  judgment_api_key=self.tracer.api_key,
467
501
  override=self.overwrite,
502
+ trace_span_id=current_span_var.get(),
468
503
  rules=loaded_rules # Use the combined rules
469
504
  )
470
505
 
471
506
  self.add_eval_run(eval_run, start_time) # Pass start_time to record_evaluation
472
507
 
473
508
  def add_eval_run(self, eval_run: EvaluationRun, start_time: float):
474
- """
475
- Add evaluation run data to the trace
476
-
477
- Args:
478
- eval_run (EvaluationRun): The evaluation run to add to the trace
479
- start_time (float): The start time of the evaluation run
480
- """
481
- if self._current_span:
482
- duration = time.time() - start_time # Calculate duration from start_time
483
-
484
- prev_entry = self.entries[-1]
509
+ current_span_id = current_span_var.get()
510
+ if current_span_id:
511
+ duration = time.time() - start_time
512
+ prev_entry = self.entries[-1] if self.entries else None
513
+ # Determine function name based on previous entry or context var (less ideal)
514
+ function_name = "unknown_function" # Default
515
+ if prev_entry and prev_entry.span_type == "llm":
516
+ function_name = prev_entry.function
517
+ else:
518
+ # Try to find the function name associated with the current span_id
519
+ for entry in reversed(self.entries):
520
+ if entry.span_id == current_span_id and entry.type == 'enter':
521
+ function_name = entry.function
522
+ break
485
523
 
486
- # Select the last entry in the trace if it's an LLM call, otherwise use the current span
524
+ # Get depth for the current span
525
+ current_depth = self._span_depths.get(current_span_id, 0)
526
+
487
527
  self.add_entry(TraceEntry(
488
528
  type="evaluation",
489
- function=prev_entry.function if prev_entry.span_type == "llm" else self._current_span,
490
- depth=self.tracer.depth,
491
- message=f"Evaluation results for {self._current_span}",
492
- timestamp=time.time(),
529
+ function=function_name,
530
+ span_id=current_span_id, # Associate with current span
531
+ trace_id=self.trace_id, # Use the trace_id from the trace client
532
+ depth=current_depth,
533
+ message=f"Evaluation results for {function_name}",
534
+ created_at=time.time(),
493
535
  evaluation_runs=[eval_run],
494
536
  duration=duration,
495
537
  span_type="evaluation"
496
538
  ))
497
539
 
498
540
  def record_input(self, inputs: dict):
499
- """Record input parameters for the current span"""
500
- if self._current_span:
541
+ current_span_id = current_span_var.get()
542
+ if current_span_id:
543
+ entry_span_type = "span"
544
+ current_depth = self._span_depths.get(current_span_id, 0)
545
+ function_name = "unknown_function" # Default
546
+ for entry in reversed(self.entries):
547
+ if entry.span_id == current_span_id and entry.type == 'enter':
548
+ entry_span_type = entry.span_type
549
+ function_name = entry.function
550
+ break
551
+
501
552
  self.add_entry(TraceEntry(
502
553
  type="input",
503
- function=self._current_span,
504
- depth=self.tracer.depth,
505
- message=f"Inputs to {self._current_span}",
506
- timestamp=time.time(),
554
+ function=function_name,
555
+ span_id=current_span_id, # Use current span_id
556
+ trace_id=self.trace_id, # Use the trace_id from the trace client
557
+ depth=current_depth,
558
+ message=f"Inputs to {function_name}",
559
+ created_at=time.time(),
507
560
  inputs=inputs,
508
- span_type=self.span_type
561
+ span_type=entry_span_type
509
562
  ))
510
563
 
511
564
  async def _update_coroutine_output(self, entry: TraceEntry, coroutine: Any):
@@ -519,21 +572,30 @@ class TraceClient:
519
572
  raise
520
573
 
521
574
  def record_output(self, output: Any):
522
- """Record output for the current span"""
523
- if self._current_span:
575
+ current_span_id = current_span_var.get()
576
+ if current_span_id:
577
+ entry_span_type = "span"
578
+ current_depth = self._span_depths.get(current_span_id, 0)
579
+ function_name = "unknown_function" # Default
580
+ for entry in reversed(self.entries):
581
+ if entry.span_id == current_span_id and entry.type == 'enter':
582
+ entry_span_type = entry.span_type
583
+ function_name = entry.function
584
+ break
585
+
524
586
  entry = TraceEntry(
525
587
  type="output",
526
- function=self._current_span,
527
- depth=self.tracer.depth,
528
- message=f"Output from {self._current_span}",
529
- timestamp=time.time(),
588
+ function=function_name,
589
+ span_id=current_span_id, # Use current span_id
590
+ depth=current_depth,
591
+ message=f"Output from {function_name}",
592
+ created_at=time.time(),
530
593
  output="<pending>" if inspect.iscoroutine(output) else output,
531
- span_type=self.span_type
594
+ span_type=entry_span_type
532
595
  )
533
596
  self.add_entry(entry)
534
597
 
535
598
  if inspect.iscoroutine(output):
536
- # Create a task to update the output once the coroutine completes
537
599
  asyncio.create_task(self._update_coroutine_output(entry, output))
538
600
 
539
601
  def add_entry(self, entry: TraceEntry):
@@ -546,6 +608,58 @@ class TraceClient:
546
608
  for entry in self.entries:
547
609
  entry.print_entry()
548
610
 
611
+ def print_hierarchical(self):
612
+ """Print the trace in a hierarchical structure based on parent-child relationships"""
613
+ # First, build a map of spans
614
+ spans = {}
615
+ root_spans = []
616
+
617
+ # Collect all enter events first
618
+ for entry in self.entries:
619
+ if entry.type == "enter":
620
+ spans[entry.function] = {
621
+ "name": entry.function,
622
+ "depth": entry.depth,
623
+ "parent_id": entry.parent_span_id,
624
+ "children": []
625
+ }
626
+
627
+ # If no parent, it's a root span
628
+ if not entry.parent_span_id:
629
+ root_spans.append(entry.function)
630
+ elif entry.parent_span_id not in spans:
631
+ # If parent doesn't exist yet, temporarily treat as root
632
+ # (we'll fix this later)
633
+ root_spans.append(entry.function)
634
+
635
+ # Build parent-child relationships
636
+ for span_name, span in spans.items():
637
+ parent = span["parent_id"]
638
+ if parent and parent in spans:
639
+ spans[parent]["children"].append(span_name)
640
+ # Remove from root spans if it was temporarily there
641
+ if span_name in root_spans:
642
+ root_spans.remove(span_name)
643
+
644
+ # Now print the hierarchy
645
+ def print_span(span_name, level=0):
646
+ if span_name not in spans:
647
+ return
648
+
649
+ span = spans[span_name]
650
+ indent = " " * level
651
+ parent_info = f" (parent_id: {span['parent_id']})" if span["parent_id"] else ""
652
+ print(f"{indent}→ {span_name}{parent_info}")
653
+
654
+ # Print children
655
+ for child in span["children"]:
656
+ print_span(child, level + 1)
657
+
658
+ # Print starting with root spans
659
+ print("\nHierarchical Trace Structure:")
660
+ for root in root_spans:
661
+ print_span(root)
662
+
549
663
  def get_duration(self) -> float:
550
664
  """
551
665
  Get the total duration of this trace
@@ -554,58 +668,126 @@ class TraceClient:
554
668
 
555
669
  def condense_trace(self, entries: List[dict]) -> List[dict]:
556
670
  """
557
- Condenses trace entries into a single entry for each function call.
671
+ Condenses trace entries into a single entry for each span instance,
672
+ preserving parent-child span relationships using span_id and parent_span_id.
558
673
  """
559
- condensed = []
560
- active_functions = [] # Stack to track nested function calls
561
- function_entries = {} # Store entries for each function
674
+ spans_by_id: Dict[str, dict] = {}
675
+ evaluation_runs: List[EvaluationRun] = []
676
+
677
+ # First pass: Group entries by span_id and gather data
678
+ for entry in entries:
679
+ span_id = entry.get("span_id")
680
+ if not span_id:
681
+ continue # Skip entries without a span_id (should not happen)
562
682
 
563
- for i, entry in enumerate(entries):
564
- function = entry["function"]
565
-
566
683
  if entry["type"] == "enter":
567
- # Initialize new function entry
568
- function_entries[function] = {
569
- "depth": entry["depth"],
570
- "function": function,
571
- "timestamp": entry["timestamp"],
572
- "inputs": None,
573
- "output": None,
574
- "evaluation_runs": [],
575
- "span_type": entry.get("span_type", "span")
576
- }
577
- active_functions.append(function)
578
-
579
- elif entry["type"] == "exit" and function in active_functions:
580
- # Complete function entry
581
- current_entry = function_entries[function]
582
- current_entry["duration"] = entry["timestamp"] - current_entry["timestamp"]
583
- condensed.append(current_entry)
584
- active_functions.remove(function)
585
- # del function_entries[function]
586
-
587
- # The OR condition is to handle the LLM client case.
588
- # LLM client is a special case where we exit the span, so when we attach evaluations to it,
589
- # we have to check if the previous entry is an LLM call.
590
- elif function in active_functions or entry["type"] == "evaluation" and entries[i-1]["function"] == entry["function"]:
591
- # Update existing function entry with additional data
592
- current_entry = function_entries[function]
684
+ if span_id not in spans_by_id:
685
+ spans_by_id[span_id] = {
686
+ "span_id": span_id,
687
+ "function": entry["function"],
688
+ "depth": entry["depth"], # Use the depth recorded at entry time
689
+ "created_at": entry["created_at"],
690
+ "trace_id": entry["trace_id"],
691
+ "parent_span_id": entry.get("parent_span_id"),
692
+ "span_type": entry.get("span_type", "span"),
693
+ "inputs": None,
694
+ "output": None,
695
+ "evaluation_runs": [],
696
+ "duration": None
697
+ }
698
+ # Handle potential duplicate enter events if necessary (e.g., log warning)
699
+
700
+ elif span_id in spans_by_id:
701
+ current_span_data = spans_by_id[span_id]
593
702
 
594
703
  if entry["type"] == "input" and entry["inputs"]:
595
- current_entry["inputs"] = entry["inputs"]
596
-
597
- if entry["type"] == "output" and entry["output"]:
598
- current_entry["output"] = entry["output"]
599
-
600
- if entry["type"] == "evaluation" and entry["evaluation_runs"]:
601
- current_entry["evaluation_runs"] = entry["evaluation_runs"]
704
+ # Merge inputs if multiple are recorded, or just assign
705
+ if current_span_data["inputs"] is None:
706
+ current_span_data["inputs"] = entry["inputs"]
707
+ elif isinstance(current_span_data["inputs"], dict) and isinstance(entry["inputs"], dict):
708
+ current_span_data["inputs"].update(entry["inputs"])
709
+ # Add more sophisticated merging if needed
602
710
 
603
- # Sort by timestamp
604
- condensed.sort(key=lambda x: x["timestamp"])
605
-
606
- return condensed
711
+ elif entry["type"] == "output" and "output" in entry:
712
+ current_span_data["output"] = entry["output"]
713
+
714
+ elif entry["type"] == "evaluation" and entry.get("evaluation_runs"):
715
+ if current_span_data.get("evaluation_runs") is not None:
716
+ evaluation_runs.extend(entry["evaluation_runs"])
717
+
718
+ elif entry["type"] == "exit":
719
+ if current_span_data["duration"] is None: # Calculate duration only once
720
+ start_time = datetime.fromisoformat(current_span_data.get("created_at", entry["created_at"]))
721
+ end_time = datetime.fromisoformat(entry["created_at"])
722
+ current_span_data["duration"] = (end_time - start_time).total_seconds()
723
+ # Update depth if exit depth is different (though current span() implementation keeps it same)
724
+ # current_span_data["depth"] = entry["depth"]
725
+
726
+ # Convert dictionary to a list initially for easier access
727
+ spans_list = list(spans_by_id.values())
728
+
729
+ # Build tree structure (adjacency list) and find roots
730
+ children_map: Dict[Optional[str], List[dict]] = {}
731
+ roots = []
732
+ span_map = {span['span_id']: span for span in spans_list} # Map for quick lookup
733
+
734
+ for span in spans_list:
735
+ parent_id = span.get("parent_span_id")
736
+ if parent_id is None:
737
+ roots.append(span)
738
+ else:
739
+ if parent_id not in children_map:
740
+ children_map[parent_id] = []
741
+ children_map[parent_id].append(span)
607
742
 
608
- def save(self, empty_save: bool = False, overwrite: bool = False) -> Tuple[str, dict]:
743
+ # Sort roots by timestamp
744
+ roots.sort(key=lambda x: datetime.fromisoformat(x.get("created_at", "1970-01-01T00:00:00")))
745
+
746
+ # Perform depth-first traversal to get the final sorted list
747
+ sorted_condensed_list = []
748
+ visited = set() # To handle potential cycles, though unlikely with UUIDs
749
+
750
+ def dfs(span_data):
751
+ span_id = span_data['span_id']
752
+ if span_id in visited:
753
+ return # Avoid infinite loops in case of cycles
754
+ visited.add(span_id)
755
+
756
+ sorted_condensed_list.append(span_data) # Add parent before children
757
+
758
+ # Get children, sort them by created_at, and visit them
759
+ span_children = children_map.get(span_id, [])
760
+ span_children.sort(key=lambda x: datetime.fromisoformat(x.get("created_at", "1970-01-01T00:00:00")))
761
+ for child in span_children:
762
+ # Ensure the child exists in our map before recursing
763
+ if child['span_id'] in span_map:
764
+ dfs(child)
765
+ else:
766
+ # This case might indicate an issue, but we'll add the child directly
767
+ # if its parent was processed but the child itself wasn't in the initial list?
768
+ # Or if the child's 'enter' event was missing. For robustness, add it.
769
+ if child['span_id'] not in visited:
770
+ visited.add(child['span_id'])
771
+ sorted_condensed_list.append(child)
772
+
773
+
774
+ # Start DFS from each root
775
+ for root_span in roots:
776
+ if root_span['span_id'] not in visited:
777
+ dfs(root_span)
778
+
779
+ # Handle spans that might not have been reachable from roots (orphans)
780
+ # Though ideally, all spans should descend from a root.
781
+ for span_data in spans_list:
782
+ if span_data['span_id'] not in visited:
783
+ # Decide how to handle orphans, maybe append them at the end sorted by time?
784
+ # For now, let's just add them to ensure they aren't lost.
785
+ sorted_condensed_list.append(span_data)
786
+
787
+
788
+ return sorted_condensed_list, evaluation_runs
789
+
790
+ def save(self, overwrite: bool = False) -> Tuple[str, dict]:
609
791
  """
610
792
  Save the current trace to the database.
611
793
  Returns a tuple of (trace_id, trace_data) where trace_data is the trace data that was saved.
@@ -615,7 +797,7 @@ class TraceClient:
615
797
 
616
798
  raw_entries = [entry.to_dict() for entry in self.entries]
617
799
 
618
- condensed_entries = self.condense_trace(raw_entries)
800
+ condensed_entries, evaluation_runs = self.condense_trace(raw_entries)
619
801
 
620
802
  # Calculate total token counts from LLM API calls
621
803
  total_prompt_tokens = 0
@@ -688,30 +870,32 @@ class TraceClient:
688
870
  "total_cost_usd": total_cost
689
871
  },
690
872
  "entries": condensed_entries,
691
- "empty_save": empty_save,
692
- "overwrite": overwrite
873
+ "evaluation_runs": evaluation_runs,
874
+ "overwrite": overwrite,
875
+ "parent_trace_id": self.parent_trace_id,
876
+ "parent_name": self.parent_name
693
877
  }
694
878
  # Execute asynchrous evaluation in the background
695
- if not empty_save: # Only send to RabbitMQ if the trace is not empty
696
- # Send trace data to evaluation queue via API
697
- try:
698
- response = requests.post(
699
- JUDGMENT_TRACES_ADD_TO_EVAL_QUEUE_API_URL,
700
- json=trace_data,
701
- headers={
702
- "Content-Type": "application/json",
703
- "Authorization": f"Bearer {self.tracer.api_key}",
704
- "X-Organization-Id": self.tracer.organization_id
705
- },
706
- verify=True
707
- )
879
+ # if not empty_save: # Only send to RabbitMQ if the trace is not empty
880
+ # # Send trace data to evaluation queue via API
881
+ # try:
882
+ # response = requests.post(
883
+ # JUDGMENT_TRACES_ADD_TO_EVAL_QUEUE_API_URL,
884
+ # json=trace_data,
885
+ # headers={
886
+ # "Content-Type": "application/json",
887
+ # "Authorization": f"Bearer {self.tracer.api_key}",
888
+ # "X-Organization-Id": self.tracer.organization_id
889
+ # },
890
+ # verify=True
891
+ # )
708
892
 
709
- if response.status_code != HTTPStatus.OK:
710
- warnings.warn(f"Failed to add trace to evaluation queue: {response.text}")
711
- except Exception as e:
712
- warnings.warn(f"Error sending trace to evaluation queue: {str(e)}")
893
+ # if response.status_code != HTTPStatus.OK:
894
+ # warnings.warn(f"Failed to add trace to evaluation queue: {response.text}")
895
+ # except Exception as e:
896
+ # warnings.warn(f"Error sending trace to evaluation queue: {str(e)}")
713
897
 
714
- self.trace_manager_client.save_trace(trace_data, empty_save)
898
+ self.trace_manager_client.save_trace(trace_data)
715
899
 
716
900
  return self.trace_id, trace_data
717
901
 
@@ -745,7 +929,6 @@ class Tracer:
745
929
  self.project_name: str = project_name
746
930
  self.client: JudgmentClient = JudgmentClient(judgment_api_key=api_key)
747
931
  self.organization_id: str = organization_id
748
- self.depth: int = 0
749
932
  self._current_trace: Optional[str] = None
750
933
  self.rules: List[Rule] = rules or [] # Store rules at tracer level
751
934
  self.initialized: bool = True
@@ -770,6 +953,15 @@ class Tracer:
770
953
  """Start a new trace context using a context manager"""
771
954
  trace_id = str(uuid.uuid4())
772
955
  project = project_name if project_name is not None else self.project_name
956
+
957
+ # Get parent trace info from context
958
+ parent_trace = current_trace_var.get()
959
+ parent_trace_id = None
960
+ parent_name = None
961
+
962
+ if parent_trace:
963
+ parent_trace_id = parent_trace.trace_id
964
+ parent_name = parent_trace.name
773
965
 
774
966
  trace = TraceClient(
775
967
  self,
@@ -779,25 +971,28 @@ class Tracer:
779
971
  overwrite=overwrite,
780
972
  rules=self.rules, # Pass combined rules to the trace client
781
973
  enable_monitoring=self.enable_monitoring,
782
- enable_evaluations=self.enable_evaluations
974
+ enable_evaluations=self.enable_evaluations,
975
+ parent_trace_id=parent_trace_id,
976
+ parent_name=parent_name
783
977
  )
784
- prev_trace = self._current_trace
785
- self._current_trace = trace
978
+
979
+ # Set the current trace in context variables
980
+ token = current_trace_var.set(trace)
786
981
 
787
982
  # Automatically create top-level span
788
983
  with trace.span(name or "unnamed_trace") as span:
789
984
  try:
790
985
  # Save the trace to the database to handle Evaluations' trace_id referential integrity
791
- trace.save(empty_save=True, overwrite=overwrite)
792
986
  yield trace
793
987
  finally:
794
- self._current_trace = prev_trace
988
+ # Reset the context variable
989
+ current_trace_var.reset(token)
795
990
 
796
991
  def get_current_trace(self) -> Optional[TraceClient]:
797
992
  """
798
- Get the current trace context
993
+ Get the current trace context from contextvars
799
994
  """
800
- return self._current_trace
995
+ return current_trace_var.get()
801
996
 
802
997
  def observe(self, func=None, *, name=None, span_type: SpanType = "span", project_name: str = None, overwrite: bool = False):
803
998
  """
@@ -823,20 +1018,56 @@ class Tracer:
823
1018
  if asyncio.iscoroutinefunction(func):
824
1019
  @functools.wraps(func)
825
1020
  async def async_wrapper(*args, **kwargs):
826
- # If there's already a trace, use it. Otherwise create a new one
827
- if self._current_trace:
828
- trace = self._current_trace
829
- else:
1021
+ # Get current trace from context
1022
+ current_trace = current_trace_var.get()
1023
+
1024
+ # If there's no current trace, create a root trace
1025
+ if not current_trace:
830
1026
  trace_id = str(uuid.uuid4())
831
- trace_name = func.__name__
832
1027
  project = project_name if project_name is not None else self.project_name
833
- trace = TraceClient(self, trace_id, trace_name, project_name=project, overwrite=overwrite, rules=self.rules, enable_monitoring=self.enable_monitoring, enable_evaluations=self.enable_evaluations)
834
- self._current_trace = trace
835
- # Only save empty trace for the root call
836
- trace.save(empty_save=True, overwrite=overwrite)
837
-
838
- try:
839
- with trace.span(span_name, span_type=span_type) as span:
1028
+
1029
+ # Create a new trace client to serve as the root
1030
+ current_trace = TraceClient(
1031
+ self,
1032
+ trace_id,
1033
+ span_name, # MODIFIED: Use span_name directly
1034
+ project_name=project,
1035
+ overwrite=overwrite,
1036
+ rules=self.rules,
1037
+ enable_monitoring=self.enable_monitoring,
1038
+ enable_evaluations=self.enable_evaluations
1039
+ )
1040
+
1041
+ # Save empty trace and set trace context
1042
+ # current_trace.save(empty_save=True, overwrite=overwrite)
1043
+ trace_token = current_trace_var.set(current_trace)
1044
+
1045
+ try:
1046
+ # Use span for the function execution within the root trace
1047
+ # This sets the current_span_var
1048
+ with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
1049
+ # Record inputs
1050
+ span.record_input({
1051
+ 'args': str(args),
1052
+ 'kwargs': kwargs
1053
+ })
1054
+
1055
+ # Execute function
1056
+ result = await func(*args, **kwargs)
1057
+
1058
+ # Record output
1059
+ span.record_output(result)
1060
+
1061
+ # Save the completed trace
1062
+ current_trace.save(overwrite=overwrite)
1063
+ return result
1064
+ finally:
1065
+ # Reset trace context (span context resets automatically)
1066
+ current_trace_var.reset(trace_token)
1067
+ else:
1068
+ # Already have a trace context, just create a span in it
1069
+ # The span method handles current_span_var
1070
+ with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
840
1071
  # Record inputs
841
1072
  span.record_input({
842
1073
  'args': str(args),
@@ -850,30 +1081,62 @@ class Tracer:
850
1081
  span.record_output(result)
851
1082
 
852
1083
  return result
853
- finally:
854
- # Only save and cleanup if this is the root observe call
855
- if self.depth == 0:
856
- trace.save(empty_save=False, overwrite=overwrite)
857
- self._current_trace = None
858
1084
 
859
1085
  return async_wrapper
860
1086
  else:
1087
+ # Non-async function implementation remains unchanged
861
1088
  @functools.wraps(func)
862
1089
  def wrapper(*args, **kwargs):
863
- # If there's already a trace, use it. Otherwise create a new one
864
- if self._current_trace:
865
- trace = self._current_trace
866
- else:
1090
+ # Get current trace from context
1091
+ current_trace = current_trace_var.get()
1092
+
1093
+ # If there's no current trace, create a root trace
1094
+ if not current_trace:
867
1095
  trace_id = str(uuid.uuid4())
868
- trace_name = func.__name__
869
1096
  project = project_name if project_name is not None else self.project_name
870
- trace = TraceClient(self, trace_id, trace_name, project_name=project, overwrite=overwrite, rules=self.rules, enable_monitoring=self.enable_monitoring)
871
- self._current_trace = trace
872
- # Only save empty trace for the root call
873
- trace.save(empty_save=True, overwrite=overwrite)
874
-
875
- try:
876
- with trace.span(span_name, span_type=span_type) as span:
1097
+
1098
+ # Create a new trace client to serve as the root
1099
+ current_trace = TraceClient(
1100
+ self,
1101
+ trace_id,
1102
+ span_name, # MODIFIED: Use span_name directly
1103
+ project_name=project,
1104
+ overwrite=overwrite,
1105
+ rules=self.rules,
1106
+ enable_monitoring=self.enable_monitoring,
1107
+ enable_evaluations=self.enable_evaluations
1108
+ )
1109
+
1110
+ # Save empty trace and set trace context
1111
+ # current_trace.save(empty_save=True, overwrite=overwrite)
1112
+ trace_token = current_trace_var.set(current_trace)
1113
+
1114
+ try:
1115
+ # Use span for the function execution within the root trace
1116
+ # This sets the current_span_var
1117
+ with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
1118
+ # Record inputs
1119
+ span.record_input({
1120
+ 'args': str(args),
1121
+ 'kwargs': kwargs
1122
+ })
1123
+
1124
+ # Execute function
1125
+ result = func(*args, **kwargs)
1126
+
1127
+ # Record output
1128
+ span.record_output(result)
1129
+
1130
+ # Save the completed trace
1131
+ current_trace.save(overwrite=overwrite)
1132
+ return result
1133
+ finally:
1134
+ # Reset trace context (span context resets automatically)
1135
+ current_trace_var.reset(trace_token)
1136
+ else:
1137
+ # Already have a trace context, just create a span in it
1138
+ # The span method handles current_span_var
1139
+ with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
877
1140
  # Record inputs
878
1141
  span.record_input({
879
1142
  'args': str(args),
@@ -887,11 +1150,6 @@ class Tracer:
887
1150
  span.record_output(result)
888
1151
 
889
1152
  return result
890
- finally:
891
- # Only save and cleanup if this is the root observe call
892
- if self.depth == 0:
893
- trace.save(empty_save=False, overwrite=overwrite)
894
- self._current_trace = None
895
1153
 
896
1154
  return wrapper
897
1155
 
@@ -900,27 +1158,36 @@ class Tracer:
900
1158
  Decorator to trace function execution with detailed entry/exit information.
901
1159
  """
902
1160
  if func is None:
903
- return lambda f: self.observe(f, name=name, span_type=span_type)
1161
+ return lambda f: self.score(f, scorers=scorers, model=model, log_results=log_results, name=name, span_type=span_type)
904
1162
 
905
1163
  if asyncio.iscoroutinefunction(func):
906
1164
  @functools.wraps(func)
907
1165
  async def async_wrapper(*args, **kwargs):
908
- if self._current_trace:
909
- self._current_trace.async_evaluate(scorers=[scorers], input=args, actual_output=kwargs, model=model, log_results=log_results)
1166
+ # Get current trace from contextvars
1167
+ current_trace = current_trace_var.get()
1168
+ if current_trace and scorers:
1169
+ current_trace.async_evaluate(scorers=scorers, input=args, actual_output=kwargs, model=model, log_results=log_results)
1170
+ return await func(*args, **kwargs)
910
1171
  return async_wrapper
911
1172
  else:
912
1173
  @functools.wraps(func)
913
1174
  def wrapper(*args, **kwargs):
914
- if self._current_trace:
915
- self._current_trace.async_evaluate(scorers=[scorers], input=args, actual_output=kwargs, model="gpt-4o-mini", log_results=True)
1175
+ # Get current trace from contextvars
1176
+ current_trace = current_trace_var.get()
1177
+ if current_trace and scorers:
1178
+ current_trace.async_evaluate(scorers=scorers, input=args, actual_output=kwargs, model=model, log_results=log_results)
1179
+ return func(*args, **kwargs)
916
1180
  return wrapper
917
1181
 
918
1182
  def async_evaluate(self, *args, **kwargs):
919
1183
  if not self.enable_evaluations:
920
1184
  return
921
1185
 
922
- if self._current_trace:
923
- self._current_trace.async_evaluate(*args, **kwargs)
1186
+ # Get current trace from context
1187
+ current_trace = current_trace_var.get()
1188
+
1189
+ if current_trace:
1190
+ current_trace.async_evaluate(*args, **kwargs)
924
1191
  else:
925
1192
  warnings.warn("No trace found, skipping evaluation")
926
1193
 
@@ -934,14 +1201,14 @@ def wrap(client: Any) -> Any:
934
1201
  span_name, original_create = _get_client_config(client)
935
1202
 
936
1203
  def traced_create(*args, **kwargs):
937
- # Get the current tracer instance (might be created after client was wrapped)
938
- tracer = Tracer._instance
1204
+ # Get the current trace from contextvars
1205
+ current_trace = current_trace_var.get()
939
1206
 
940
- # Skip tracing if no tracer exists or no active trace
941
- if not tracer or not tracer._current_trace:
1207
+ # Skip tracing if no active trace
1208
+ if not current_trace:
942
1209
  return original_create(*args, **kwargs)
943
1210
 
944
- with tracer._current_trace.span(span_name, span_type="llm") as span:
1211
+ with current_trace.span(span_name, span_type="llm") as span:
945
1212
  # Format and record the input parameters
946
1213
  input_data = _format_input_data(client, **kwargs)
947
1214
  span.record_input(input_data)
@@ -1033,4 +1300,59 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
1033
1300
  "output_tokens": response.usage.output_tokens,
1034
1301
  "total_tokens": response.usage.input_tokens + response.usage.output_tokens
1035
1302
  }
1036
- }
1303
+ }
1304
+
1305
+ # Add a global context-preserving gather function
1306
+ # async def trace_gather(*coroutines, return_exceptions=False): # REMOVED
1307
+ # """ # REMOVED
1308
+ # A wrapper around asyncio.gather that ensures the trace context # REMOVED
1309
+ # is available within the gathered coroutines using contextvars.copy_context. # REMOVED
1310
+ # """ # REMOVED
1311
+ # # Get the original asyncio.gather (if we patched it) # REMOVED
1312
+ # original_gather = getattr(asyncio, "_original_gather", asyncio.gather) # REMOVED
1313
+ # # REMOVED
1314
+ # # Use contextvars.copy_context() to ensure context propagation # REMOVED
1315
+ # ctx = contextvars.copy_context() # REMOVED
1316
+ # # REMOVED
1317
+ # # Wrap the gather call within the copied context # REMOVED
1318
+ # return await ctx.run(original_gather, *coroutines, return_exceptions=return_exceptions) # REMOVED
1319
+
1320
+ # Store the original gather and apply the patch *once*
1321
+ # global _original_gather_stored # REMOVED
1322
+ # if not globals().get('_original_gather_stored'): # REMOVED
1323
+ # # Check if asyncio.gather is already our wrapper to prevent double patching # REMOVED
1324
+ # if asyncio.gather.__name__ != 'trace_gather': # REMOVED
1325
+ # asyncio._original_gather = asyncio.gather # REMOVED
1326
+ # asyncio.gather = trace_gather # REMOVED
1327
+ # _original_gather_stored = True # REMOVED
1328
+
1329
+ # Add the new TraceThreadPoolExecutor class
1330
+ class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
1331
+ """
1332
+ A ThreadPoolExecutor subclass that automatically propagates contextvars
1333
+ from the submitting thread to the worker thread using copy_context().run().
1334
+
1335
+ This ensures that context variables like `current_trace_var` and
1336
+ `current_span_var` are available within functions executed by the pool,
1337
+ allowing the Tracer to maintain correct parent-child relationships across
1338
+ thread boundaries.
1339
+ """
1340
+ def submit(self, fn, /, *args, **kwargs):
1341
+ """
1342
+ Submit a callable to be executed with the captured context.
1343
+ """
1344
+ # Capture context from the submitting thread
1345
+ ctx = contextvars.copy_context()
1346
+
1347
+ # We use functools.partial to bind the arguments to the function *now*,
1348
+ # as ctx.run doesn't directly accept *args, **kwargs in the same way
1349
+ # submit does. It expects ctx.run(callable, arg1, arg2...).
1350
+ func_with_bound_args = functools.partial(fn, *args, **kwargs)
1351
+
1352
+ # Submit the ctx.run callable to the original executor.
1353
+ # ctx.run will execute the (now argument-bound) function within the
1354
+ # captured context in the worker thread.
1355
+ return super().submit(ctx.run, func_with_bound_args)
1356
+
1357
+ # Note: The `map` method would also need to be overridden for full context
1358
+ # propagation if users rely on it, but `submit` is the most common use case.