judgeval 0.0.39__py3-none-any.whl → 0.0.40__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
judgeval/common/tracer.py CHANGED
@@ -34,13 +34,14 @@ from typing import (
     Union,
     AsyncGenerator,
     TypeAlias,
+    Set
 )
 from rich import print as rprint
 import types  # <--- Add this import
 
 # Third-party imports
 import requests
-from litellm import cost_per_token
+from litellm import cost_per_token as _original_cost_per_token
 from pydantic import BaseModel
 from rich import print as rprint
 from openai import OpenAI, AsyncOpenAI
@@ -59,7 +60,7 @@ from judgeval.constants import (
     JUDGMENT_TRACES_DELETE_API_URL,
     JUDGMENT_PROJECT_DELETE_API_URL,
 )
-from judgeval.data import Example, Trace, TraceSpan
+from judgeval.data import Example, Trace, TraceSpan, TraceUsage
 from judgeval.scorers import APIJudgmentScorer, JudgevalScorer
 from judgeval.rules import Rule
 from judgeval.evaluation_run import EvaluationRun
@@ -155,9 +156,29 @@ class TraceManagerClient:
         NOTE we save empty traces in order to properly handle async operations; we need something in the DB to associate the async results with
         """
         # Save to Judgment API
+
+        def fallback_encoder(obj):
+            """
+            Custom JSON encoder fallback: try repr(obj) first, then str(obj).
+            """
+            try:
+                return repr(obj)
+            except Exception:
+                try:
+                    return str(obj)
+                except Exception as e:
+                    # If both fail, emit a placeholder rather than raising
+                    return f"<Unserializable object of type {type(obj).__name__}: {e}>"
+
+        serialized_trace_data = json.dumps(trace_data, default=fallback_encoder)
+
         response = requests.post(
             JUDGMENT_TRACES_SAVE_API_URL,
-            json=trace_data,
+            data=serialized_trace_data,
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {self.judgment_api_key}",
@@ -463,6 +484,7 @@ class TraceClient:
         if current_span_id:
             span = self.span_id_to_span[current_span_id]
             span.evaluation_runs.append(eval_run)
+            span.has_evaluation = True  # Set the has_evaluation flag
         self.evaluation_runs.append(eval_run)
 
     def add_annotation(self, annotation: TraceAnnotation):
@@ -474,16 +496,25 @@ class TraceClient:
         current_span_id = current_span_var.get()
         if current_span_id:
             span = self.span_id_to_span[current_span_id]
+            # Ignore self parameter
+            if "self" in inputs:
+                del inputs["self"]
             span.inputs = inputs
+
+    def record_agent_name(self, agent_name: str):
+        current_span_id = current_span_var.get()
+        if current_span_id:
+            span = self.span_id_to_span[current_span_id]
+            span.agent_name = agent_name
 
-    async def _update_coroutine_output(self, span: TraceSpan, coroutine: Any):
+    async def _update_coroutine(self, span: TraceSpan, coroutine: Any, field: str):
         """Helper method to update the output of a trace entry once the coroutine completes"""
         try:
             result = await coroutine
-            span.output = result
+            setattr(span, field, result)
             return result
         except Exception as e:
-            span.output = f"Error: {str(e)}"
+            setattr(span, field, f"Error: {str(e)}")
             raise
 
     def record_output(self, output: Any):
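The rename from `_update_coroutine_output` to `_update_coroutine` generalizes the helper: any span field can now be filled in once a pending coroutine resolves, via `setattr`, although in this diff it is still only called with `field="output"`. A minimal sketch of the pattern, independent of judgeval's classes:

    import asyncio

    class Span:
        output = None
        usage = None

    async def update_field(span, coroutine, field):
        # Await the pending value, then store it on the named attribute.
        result = await coroutine
        setattr(span, field, result)
        return result

    async def main():
        async def produce():
            return "hello"
        span = Span()
        await update_field(span, produce(), "output")
        assert span.output == "hello"

    asyncio.run(main())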
@@ -493,12 +524,30 @@ class TraceClient:
             span.output = "<pending>" if inspect.iscoroutine(output) else output
 
             if inspect.iscoroutine(output):
-                asyncio.create_task(self._update_coroutine_output(span, output))
+                asyncio.create_task(self._update_coroutine(span, output, "output"))
+
+            return span  # Return the created entry
+        # Removed else block - original didn't have one
+        return None  # Return None if no span_id found
+
+    def record_usage(self, usage: TraceUsage):
+        current_span_id = current_span_var.get()
+        if current_span_id:
+            span = self.span_id_to_span[current_span_id]
+            span.usage = usage
 
             return span  # Return the created entry
         # Removed else block - original didn't have one
         return None  # Return None if no span_id found
-
+
+    def record_error(self, error: Any):
+        current_span_id = current_span_var.get()
+        if current_span_id:
+            span = self.span_id_to_span[current_span_id]
+            span.error = error
+            return span
+        return None
+
     def add_span(self, span: TraceSpan):
         """Add a trace span to this trace context"""
         self.trace_spans.append(span)
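`record_usage` expects the new `TraceUsage` model imported from `judgeval.data` at the top of this diff. The real model is not shown here; a hypothetical stand-in, reconstructed from just the fields this file populates, would look like:

    from typing import Optional
    from pydantic import BaseModel

    class TraceUsage(BaseModel):
        # Hypothetical reconstruction; the actual model lives in judgeval.data.
        prompt_tokens: int
        completion_tokens: int
        total_tokens: int
        prompt_tokens_cost_usd: Optional[float] = None
        completion_tokens_cost_usd: Optional[float] = None
        total_cost_usd: Optional[float] = None
        model_name: Optional[str] = None

    usage = TraceUsage(prompt_tokens=12, completion_tokens=34, total_tokens=46,
                       model_name="gpt-4o-mini")
    # trace_client.record_usage(usage)  # attaches it to the current span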
@@ -523,133 +572,6 @@ class TraceClient:
         """
         # Calculate total elapsed time
         total_duration = self.get_duration()
-
-        # Only count tokens for actual LLM API call spans
-        llm_span_names = {"OPENAI_API_CALL", "TOGETHER_API_CALL", "ANTHROPIC_API_CALL", "GOOGLE_API_CALL"}
-        for span in self.trace_spans:
-            span_function_name = span.function # Get function name safely
-            # Check if it's an LLM span AND function name CONTAINS an API call suffix AND output is dict
-            is_llm_span = span.span_type == "llm"
-            has_api_suffix = any(suffix in span_function_name for suffix in llm_span_names)
-            output_is_dict = isinstance(span.output, dict)
-
-            # --- DEBUG PRINT 1: Check if condition passes ---
-            # if is_llm_entry and has_api_suffix and output_is_dict:
-            # elif is_llm_entry:
-            #     # Print why it failed if it was an LLM entry
-            # # --- END DEBUG ---
-
-            if is_llm_span and has_api_suffix and output_is_dict:
-                output = span.output
-                usage = output.get("usage", {}) # Gets the 'usage' dict from the 'output' field
-
-                # --- DEBUG PRINT 2: Check extracted usage ---
-                # --- END DEBUG ---
-
-                # --- NEW: Extract model_name correctly from nested inputs ---
-                model_name = None
-                span_inputs = span.inputs
-                if span_inputs:
-                    # Try common locations for model name within the inputs structure
-                    invocation_params = span_inputs.get("invocation_params", {})
-                    serialized_data = span_inputs.get("serialized", {})
-
-                    # Look in invocation_params (often directly contains model)
-                    if isinstance(invocation_params, dict):
-                        model_name = invocation_params.get("model")
-
-                    # Fallback: Check serialized 'repr' if it contains model info
-                    if not model_name and isinstance(serialized_data, dict):
-                        serialized_repr = serialized_data.get("repr", "")
-                        if "model_name=" in serialized_repr:
-                            try: # Simple parsing attempt
-                                model_name = serialized_repr.split("model_name='")[1].split("'")[0]
-                            except IndexError: pass # Ignore parsing errors
-
-                    # Fallback: Check top-level of invocation_params (sometimes passed flat)
-                    if not model_name and isinstance(invocation_params, dict):
-                        model_name = invocation_params.get("model") # Redundant check, but safe
-
-                    # Fallback: Check top-level of inputs itself (less likely for callbacks)
-                    if not model_name:
-                        model_name = span_inputs.get("model")
-
-
-                # --- END NEW ---
-
-                prompt_tokens = 0
-                completion_tokens = 0
-
-                # Handle OpenAI/Together format (checks within the 'usage' dict)
-                if "prompt_tokens" in usage:
-                    prompt_tokens = usage.get("prompt_tokens", 0)
-                    completion_tokens = usage.get("completion_tokens", 0)
-
-                # Handle Anthropic format - MAP values to standard keys
-                elif "input_tokens" in usage:
-                    prompt_tokens = usage.get("input_tokens", 0) # Get value from input_tokens
-                    completion_tokens = usage.get("output_tokens", 0) # Get value from output_tokens
-
-                    # *** Overwrite the usage dict in the entry to use standard keys ***
-                    original_total = usage.get("total_tokens", 0)
-                    original_total_cost = usage.get("total_cost_usd", 0.0) # Preserve if already calculated
-                    # Recalculate cost just in case it wasn't done correctly before
-                    temp_prompt_cost, temp_completion_cost = 0.0, 0.0
-                    if model_name:
-                        try:
-                            temp_prompt_cost, temp_completion_cost = cost_per_token(
-                                model=model_name,
-                                prompt_tokens=prompt_tokens,
-                                completion_tokens=completion_tokens
-                            )
-                        except Exception:
-                            pass # Ignore cost calculation errors here, focus on keys
-                    # Replace the usage dict with one using standard keys but Anthropic values
-                    output["usage"] = {
-                        "prompt_tokens": prompt_tokens,
-                        "completion_tokens": completion_tokens,
-                        "total_tokens": original_total,
-                        "prompt_tokens_cost_usd": temp_prompt_cost, # Use standard cost key
-                        "completion_tokens_cost_usd": temp_completion_cost, # Use standard cost key
-                        "total_cost_usd": original_total_cost if original_total_cost > 0 else (temp_prompt_cost + temp_completion_cost)
-                    }
-                    usage = output["usage"]
-
-                # Calculate costs if model name is available and ensure they are stored with standard keys
-                prompt_tokens = usage.get("prompt_tokens", 0)
-                completion_tokens = usage.get("completion_tokens", 0)
-
-                # Calculate costs if model name is available
-                if model_name:
-                    try:
-                        # Recalculate costs based on potentially mapped tokens
-                        prompt_cost, completion_cost = cost_per_token(
-                            model=model_name,
-                            prompt_tokens=prompt_tokens,
-                            completion_tokens=completion_tokens
-                        )
-
-                        # Add cost information directly to the usage dictionary in the condensed entry
-                        # Ensure 'usage' exists in the output dict before modifying it
-                        # Add/Update cost information using standard keys
-
-                        if "usage" not in output:
-                            output["usage"] = {} # Initialize if missing
-                        elif not isinstance(output["usage"], dict): # Handle cases where 'usage' might not be a dict (e.g., placeholder string)
-                            print(f"[WARN TraceClient.save] Output 'usage' for span {span.span_id} was not a dict ({type(output['usage'])}). Resetting before adding costs.")
-                            output["usage"] = {} # Reset to dict
-
-                        output["usage"]["prompt_tokens_cost_usd"] = prompt_cost
-                        output["usage"]["completion_tokens_cost_usd"] = completion_cost
-                        output["usage"]["total_cost_usd"] = prompt_cost + completion_cost
-                    except Exception as e:
-                        # If cost calculation fails, continue without adding costs
-                        print(f"Error calculating cost for model '{model_name}' (span: {span.span_id}): {str(e)}")
-                        pass
-                else:
-                    print(f"[WARN TraceClient.save] Could not determine model name for cost calculation (span: {span.span_id}). Inputs: {span_inputs}")
-
-
         # Create trace document - Always use standard keys for top-level counts
         trace_data = {
             "trace_id": self.trace_id,
@@ -677,13 +599,25 @@ class TraceClient:
     def delete(self):
         return self.trace_manager_client.delete_trace(self.trace_id)
 
-
+def _capture_exception_for_trace(current_trace: Optional['TraceClient'], exc_info: Tuple[Optional[type], Optional[BaseException], Optional[types.TracebackType]]):
+    if not current_trace:
+        return
+
+    exc_type, exc_value, exc_traceback_obj = exc_info
+    formatted_exception = {
+        "type": exc_type.__name__ if exc_type else "UnknownExceptionType",
+        "message": str(exc_value) if exc_value else "No exception message",
+        "traceback": traceback.format_tb(exc_traceback_obj) if exc_traceback_obj else []
+    }
+    current_trace.record_error(formatted_exception)
 class _DeepTracer:
     _instance: Optional["_DeepTracer"] = None
     _lock: threading.Lock = threading.Lock()
     _refcount: int = 0
     _span_stack: contextvars.ContextVar[List[Dict[str, Any]]] = contextvars.ContextVar("_deep_profiler_span_stack", default=[])
     _skip_stack: contextvars.ContextVar[List[str]] = contextvars.ContextVar("_deep_profiler_skip_stack", default=[])
+    _original_sys_trace: Optional[Callable] = None
+    _original_threading_trace: Optional[Callable] = None
 
     def _get_qual_name(self, frame) -> str:
         func_name = frame.f_code.co_name
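`_capture_exception_for_trace` centralizes the error payload that was previously built inline in `_DeepTracer._trace` (see the exception branch later in this diff) and routes it through the new `record_error`. A runnable sketch of the same capture pattern, using a stub in place of a real `TraceClient`:

    import sys
    import traceback

    class StubTrace:
        # Stand-in for a TraceClient; only record_error is exercised here.
        def record_error(self, error):
            print(error["type"], "-", error["message"])

    def capture(current_trace, exc_info):
        exc_type, exc_value, exc_tb = exc_info
        current_trace.record_error({
            "type": exc_type.__name__ if exc_type else "UnknownExceptionType",
            "message": str(exc_value) if exc_value else "No exception message",
            "traceback": traceback.format_tb(exc_tb) if exc_tb else [],
        })

    try:
        1 / 0
    except ZeroDivisionError:
        capture(StubTrace(), sys.exc_info())  # -> ZeroDivisionError - division by zero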
@@ -731,12 +665,53 @@ class _DeepTracer:
     @functools.cache
     def _is_user_code(self, filename: str):
         return bool(filename) and not filename.startswith("<") and not os.path.realpath(filename).startswith(_TRACE_FILEPATH_BLOCKLIST)
+
+    def _cooperative_sys_trace(self, frame: types.FrameType, event: str, arg: Any):
+        """Cooperative trace function for sys.settrace that chains with existing tracers."""
+        # First, call the original sys trace function if it exists
+        original_result = None
+        if self._original_sys_trace:
+            try:
+                original_result = self._original_sys_trace(frame, event, arg)
+            except Exception:
+                # If the original tracer fails, continue with our tracing
+                pass
+
+        # Then do our own tracing
+        our_result = self._trace(frame, event, arg, self._cooperative_sys_trace)
+
+        # Return our tracer to continue tracing, but respect the original's decision
+        # If the original tracer returned None (stop tracing), we should respect that
+        if original_result is None and self._original_sys_trace:
+            return None
+
+        return our_result or original_result
+
+    def _cooperative_threading_trace(self, frame: types.FrameType, event: str, arg: Any):
+        """Cooperative trace function for threading.settrace that chains with existing tracers."""
+        # First, call the original threading trace function if it exists
+        original_result = None
+        if self._original_threading_trace:
+            try:
+                original_result = self._original_threading_trace(frame, event, arg)
+            except Exception:
+                # If the original tracer fails, continue with our tracing
+                pass
+
+        # Then do our own tracing
+        our_result = self._trace(frame, event, arg, self._cooperative_threading_trace)
+
+        # Return our tracer to continue tracing, but respect the original's decision
+        # If the original tracer returned None (stop tracing), we should respect that
+        if original_result is None and self._original_threading_trace:
+            return None
+
+        return our_result or original_result
 
-    def _trace(self, frame: types.FrameType, event: str, arg: Any):
+    def _trace(self, frame: types.FrameType, event: str, arg: Any, continuation_func: Callable):
         frame.f_trace_lines = False
         frame.f_trace_opcodes = False
 
-
         if not self._should_trace(frame):
             return
 
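`sys.settrace` installs a single process-wide trace function, so overwriting it unconditionally would silently disable any debugger or coverage tool already attached; the two cooperative wrappers above chain to the pre-existing tracer instead. A standalone sketch of the chaining idea (`prior_tracer` stands in for whatever was installed first):

    import sys

    def prior_tracer(frame, event, arg):
        # Stand-in for a debugger or coverage hook installed before ours.
        return prior_tracer

    sys.settrace(prior_tracer)
    original = sys.gettrace()      # capture the tracer already in place

    def cooperative(frame, event, arg):
        original_result = None
        if original is not None:
            try:
                original_result = original(frame, event, arg)
            except Exception:
                pass               # a failing prior tracer must not kill ours
        # ... our own per-frame work would go here ...
        if original_result is None and original is not None:
            return None            # the prior tracer opted out; respect that
        return cooperative         # keep receiving events for this frame

    sys.settrace(cooperative)
    # ... traced code runs ...
    sys.settrace(original)         # restore on teardown rather than clearing to None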
@@ -752,6 +727,12 @@ class _DeepTracer:
             return
 
         qual_name = self._get_qual_name(frame)
+        instance_name = None
+        if 'self' in frame.f_locals:
+            instance = frame.f_locals['self']
+            class_name = instance.__class__.__name__
+            class_identifiers = getattr(Tracer._instance, 'class_identifiers', {})
+            instance_name = get_instance_prefixed_name(instance, class_name, class_identifiers)
         skip_stack = self._skip_stack.get()
 
         if event == "call":
@@ -814,7 +795,8 @@ class _DeepTracer:
                 created_at=start_time,
                 span_type="span",
                 parent_span_id=parent_span_id,
-                function=qual_name
+                function=qual_name,
+                agent_name=instance_name
             )
             current_trace.add_span(span)
 
@@ -869,35 +851,40 @@ class _DeepTracer:
             current_span_var.reset(frame.f_locals["_judgment_span_token"])
 
         elif event == "exception":
-            exc_type, exc_value, exc_traceback = arg
-            formatted_exception = {
-                "type": exc_type.__name__,
-                "message": str(exc_value),
-                "traceback": traceback.format_tb(exc_traceback)
-            }
-            current_trace = current_trace_var.get()
-            current_trace.record_output({
-                "error": formatted_exception
-            })
+            exc_type = arg[0]
+            if issubclass(exc_type, (StopIteration, StopAsyncIteration, GeneratorExit)):
+                return
+            _capture_exception_for_trace(current_trace, arg)
+
 
-        return self._trace
+        return continuation_func
 
     def __enter__(self):
         with self._lock:
             self._refcount += 1
             if self._refcount == 1:
+                # Store the existing trace functions before setting ours
+                self._original_sys_trace = sys.gettrace()
+                self._original_threading_trace = threading.gettrace()
+
                 self._skip_stack.set([])
                 self._span_stack.set([])
-                sys.settrace(self._trace)
-                threading.settrace(self._trace)
+
+                sys.settrace(self._cooperative_sys_trace)
+                threading.settrace(self._cooperative_threading_trace)
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         with self._lock:
             self._refcount -= 1
             if self._refcount == 0:
-                sys.settrace(None)
-                threading.settrace(None)
+                # Restore the original trace functions instead of setting to None
+                sys.settrace(self._original_sys_trace)
+                threading.settrace(self._original_threading_trace)
+
+                # Clean up the references
+                self._original_sys_trace = None
+                self._original_threading_trace = None
 
 
     def log(self, message: str, level: str = "info"):
@@ -946,10 +933,6 @@ class Tracer:
             raise ValueError("Tracer must be configured with an Organization ID")
         if use_s3 and not s3_bucket_name:
             raise ValueError("S3 bucket name must be provided when use_s3 is True")
-        if use_s3 and not (s3_aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")):
-            raise ValueError("AWS Access Key ID must be provided when use_s3 is True")
-        if use_s3 and not (s3_aws_secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY")):
-            raise ValueError("AWS Secret Access Key must be provided when use_s3 is True")
 
         self.api_key: str = api_key
         self.project_name: str = project_name
@@ -961,6 +944,7 @@ class Tracer:
         self.initialized: bool = True
         self.enable_monitoring: bool = enable_monitoring
         self.enable_evaluations: bool = enable_evaluations
+        self.class_identifiers: Dict[str, str] = {}  # Dictionary to store class identifiers
 
         # Initialize S3 storage if enabled
         self.use_s3 = use_s3
@@ -1084,6 +1068,32 @@ class Tracer:
 
         rprint(f"[bold]{label}:[/bold] {msg}")
 
+    def identify(self, identifier: str):
+        """
+        Class decorator that associates a class with a custom identifier.
+
+        This decorator creates a mapping between the class name and the provided
+        identifier, which can be useful for tagging, grouping, or referencing
+        classes in a standardized way.
+
+        Args:
+            identifier: The identifier to associate with the decorated class
+
+        Returns:
+            A decorator function that registers the class with the given identifier
+
+        Example:
+            @tracer.identify(identifier="user_model")
+            class User:
+                # Class implementation
+        """
+        def decorator(cls):
+            class_name = cls.__name__
+            self.class_identifiers[class_name] = identifier
+            return cls
+
+        return decorator
+
     def observe(self, func=None, *, name=None, span_type: SpanType = "span", project_name: str = None, overwrite: bool = False, deep_tracing: bool = None):
         """
         Decorator to trace function execution with detailed entry/exit information.
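Note how this interacts with `get_instance_prefixed_name`, added at the bottom of this diff: the registered identifier is treated as the name of an instance attribute to read back, and its value becomes the span's `agent_name`. A usage sketch (the `ResearchAgent` class and the constructor arguments are illustrative):

    tracer = Tracer(api_key="...", organization_id="...", project_name="demo")

    @tracer.identify(identifier="agent_id")
    class ResearchAgent:
        def __init__(self, agent_id: str):
            self.agent_id = agent_id   # read back via getattr(instance, "agent_id")

        @tracer.observe(span_type="tool")
        def search(self, query: str):
            return f"results for {query}"

    # Spans from ResearchAgent("agent-7").search("llms") should carry
    # agent_name="agent-7"; a missing attribute raises, per the helper below.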
@@ -1106,10 +1116,10 @@ class Tracer:
                                 overwrite=overwrite, deep_tracing=deep_tracing)
 
         # Use provided name or fall back to function name
-        span_name = name or func.__name__
+        original_span_name = name or func.__name__
 
         # Store custom attributes on the function object
-        func._judgment_span_name = span_name
+        func._judgment_span_name = original_span_name
         func._judgment_span_type = span_type
 
         # Use the provided deep_tracing value or fall back to the tracer's default
@@ -1118,6 +1128,16 @@ class Tracer:
         if asyncio.iscoroutinefunction(func):
             @functools.wraps(func)
             async def async_wrapper(*args, **kwargs):
+                nonlocal original_span_name
+                class_name = None
+                instance_name = None
+                span_name = original_span_name
+                agent_name = None
+
+                if args and hasattr(args[0], '__class__'):
+                    class_name = args[0].__class__.__name__
+                    agent_name = get_instance_prefixed_name(args[0], class_name, self.class_identifiers)
+
                 # Get current trace from context
                 current_trace = current_trace_var.get()
 
@@ -1141,7 +1161,7 @@ class Tracer:
                     # Save empty trace and set trace context
                     # current_trace.save(empty_save=True, overwrite=overwrite)
                     trace_token = current_trace_var.set(current_trace)
-
+
                     try:
                         # Use span for the function execution within the root trace
                         # This sets the current_span_var
@@ -1149,13 +1169,19 @@ class Tracer:
                         # Record inputs
                         inputs = combine_args_kwargs(func, args, kwargs)
                         span.record_input(inputs)
+                        if agent_name:
+                            span.record_agent_name(agent_name)
 
                         if use_deep_tracing:
                             with _DeepTracer():
                                 result = await func(*args, **kwargs)
                         else:
-                            result = await func(*args, **kwargs)
-
+                            try:
+                                result = await func(*args, **kwargs)
+                            except Exception as e:
+                                _capture_exception_for_trace(current_trace, sys.exc_info())
+                                raise e
+
                         # Record output
                         span.record_output(result)
                         return result
@@ -1170,12 +1196,18 @@ class Tracer:
                 with current_trace.span(span_name, span_type=span_type) as span:
                     inputs = combine_args_kwargs(func, args, kwargs)
                     span.record_input(inputs)
-
+                    if agent_name:
+                        span.record_agent_name(agent_name)
+
                     if use_deep_tracing:
                         with _DeepTracer():
                             result = await func(*args, **kwargs)
                     else:
-                        result = await func(*args, **kwargs)
+                        try:
+                            result = await func(*args, **kwargs)
+                        except Exception as e:
+                            _capture_exception_for_trace(current_trace, sys.exc_info())
+                            raise e
 
                     span.record_output(result)
                     return result
@@ -1184,7 +1216,15 @@ class Tracer:
         else:
             # Non-async function implementation with deep tracing
             @functools.wraps(func)
-            def wrapper(*args, **kwargs):
+            def wrapper(*args, **kwargs):
+                nonlocal original_span_name
+                class_name = None
+                instance_name = None
+                span_name = original_span_name
+                agent_name = None
+                if args and hasattr(args[0], '__class__'):
+                    class_name = args[0].__class__.__name__
+                    agent_name = get_instance_prefixed_name(args[0], class_name, self.class_identifiers)
                 # Get current trace from context
                 current_trace = current_trace_var.get()
 
@@ -1216,12 +1256,17 @@ class Tracer:
                         # Record inputs
                         inputs = combine_args_kwargs(func, args, kwargs)
                         span.record_input(inputs)
-
+                        if agent_name:
+                            span.record_agent_name(agent_name)
                         if use_deep_tracing:
                             with _DeepTracer():
                                 result = func(*args, **kwargs)
                         else:
-                            result = func(*args, **kwargs)
+                            try:
+                                result = func(*args, **kwargs)
+                            except Exception as e:
+                                _capture_exception_for_trace(current_trace, sys.exc_info())
+                                raise e
 
                         # Record output
                         span.record_output(result)
@@ -1238,12 +1283,18 @@ class Tracer:
 
                 inputs = combine_args_kwargs(func, args, kwargs)
                 span.record_input(inputs)
-
+                if agent_name:
+                    span.record_agent_name(agent_name)
+
                 if use_deep_tracing:
                     with _DeepTracer():
                         result = func(*args, **kwargs)
                 else:
-                    result = func(*args, **kwargs)
+                    try:
+                        result = func(*args, **kwargs)
+                    except Exception as e:
+                        _capture_exception_for_trace(current_trace, sys.exc_info())
+                        raise e
 
                 span.record_output(result)
                 return result
@@ -1313,8 +1364,9 @@ def wrap(client: Any) -> Any:
                 return wrapper_func(response, client, output_entry)
             else:
                 format_func = _format_response_output_data if is_responses else _format_output_data
-                output_data = format_func(client, response)
-                span.record_output(output_data)
+                output, usage = format_func(client, response)
+                span.record_output(output)
+                span.record_usage(usage)
         return response
 
     def _handle_error(span, e, is_async):
@@ -1496,18 +1548,35 @@ def _format_response_output_data(client: ApiClient, response: Any) -> dict:
     Normalizes different response formats into a consistent structure
     for tracing purposes.
     """
+    message_content = None
+    prompt_tokens = 0
+    completion_tokens = 0
+    model_name = None
     if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
-        return {
-            "content": response.output,
-            "usage": {
-                "prompt_tokens": response.usage.input_tokens,
-                "completion_tokens": response.usage.output_tokens,
-                "total_tokens": response.usage.total_tokens
-            }
-        }
+        model_name = response.model
+        prompt_tokens = response.usage.input_tokens
+        completion_tokens = response.usage.output_tokens
+        message_content = response.output
     else:
         warnings.warn(f"Unsupported client type: {type(client)}")
         return {}
+
+    prompt_cost, completion_cost = cost_per_token(
+        model=model_name,
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+    )
+    total_cost_usd = (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
+    usage = TraceUsage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens,
+        prompt_tokens_cost_usd=prompt_cost,
+        completion_tokens_cost_usd=completion_cost,
+        total_cost_usd=total_cost_usd,
+        model_name=model_name
+    )
+    return message_content, usage
 
 
 def _format_output_data(client: ApiClient, response: Any) -> dict:
@@ -1521,33 +1590,46 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
         - content: The generated text
         - usage: Token usage statistics
     """
+    prompt_tokens = 0
+    completion_tokens = 0
+    model_name = None
+    message_content = None
+
     if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
-        return {
-            "content": response.choices[0].message.content,
-            "usage": {
-                "prompt_tokens": response.usage.prompt_tokens,
-                "completion_tokens": response.usage.completion_tokens,
-                "total_tokens": response.usage.total_tokens
-            }
-        }
+        model_name = response.model
+        prompt_tokens = response.usage.prompt_tokens
+        completion_tokens = response.usage.completion_tokens
+        message_content = response.choices[0].message.content
     elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
-        return {
-            "content": response.candidates[0].content.parts[0].text,
-            "usage": {
-                "prompt_tokens": response.usage_metadata.prompt_token_count,
-                "completion_tokens": response.usage_metadata.candidates_token_count,
-                "total_tokens": response.usage_metadata.total_token_count
-            }
-        }
-    # Anthropic has a different response structure
-    return {
-        "content": response.content[0].text,
-        "usage": {
-            "prompt_tokens": response.usage.input_tokens,
-            "completion_tokens": response.usage.output_tokens,
-            "total_tokens": response.usage.input_tokens + response.usage.output_tokens
-        }
-    }
+        model_name = response.model_version
+        prompt_tokens = response.usage_metadata.prompt_token_count
+        completion_tokens = response.usage_metadata.candidates_token_count
+        message_content = response.candidates[0].content.parts[0].text
+    elif isinstance(client, (Anthropic, AsyncAnthropic)):
+        model_name = response.model
+        prompt_tokens = response.usage.input_tokens
+        completion_tokens = response.usage.output_tokens
+        message_content = response.content[0].text
+    else:
+        warnings.warn(f"Unsupported client type: {type(client)}")
+        return None, None
+
+    prompt_cost, completion_cost = cost_per_token(
+        model=model_name,
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+    )
+    total_cost_usd = (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
+    usage = TraceUsage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens,
+        prompt_tokens_cost_usd=prompt_cost,
+        completion_tokens_cost_usd=completion_cost,
+        total_cost_usd=total_cost_usd,
+        model_name=model_name
+    )
+    return message_content, usage
 
 def combine_args_kwargs(func, args, kwargs):
     """
@@ -1653,21 +1735,30 @@ def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[D
     # OpenAI/Together include usage in the *last* chunk's `usage` attribute if available
     # This typically requires specific API versions or settings. Often usage is *not* streamed.
     if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
-        # Check if usage is directly on the chunk (some models might do this)
-        if hasattr(chunk, 'usage') and chunk.usage:
-            return {
-                "prompt_tokens": chunk.usage.prompt_tokens,
-                "completion_tokens": chunk.usage.completion_tokens,
-                "total_tokens": chunk.usage.total_tokens
-            }
-        # Check if usage is nested within choices (less common for final chunk, but check)
-        elif chunk.choices and hasattr(chunk.choices[0], 'usage') and chunk.choices[0].usage:
-            usage = chunk.choices[0].usage
-            return {
-                "prompt_tokens": usage.prompt_tokens,
-                "completion_tokens": usage.completion_tokens,
-                "total_tokens": usage.total_tokens
-            }
+        # Check if usage is directly on the chunk (some models might do this)
+        if hasattr(chunk, 'usage') and chunk.usage:
+            prompt_tokens = chunk.usage.prompt_tokens
+            completion_tokens = chunk.usage.completion_tokens
+        # Check if usage is nested within choices (less common for final chunk, but check)
+        elif chunk.choices and hasattr(chunk.choices[0], 'usage') and chunk.choices[0].usage:
+            prompt_tokens = chunk.choices[0].usage.prompt_tokens
+            completion_tokens = chunk.choices[0].usage.completion_tokens
+
+        prompt_cost, completion_cost = cost_per_token(
+            model=chunk.model,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+        )
+        total_cost_usd = (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
+        return TraceUsage(
+            prompt_tokens=chunk.usage.prompt_tokens,
+            completion_tokens=chunk.usage.completion_tokens,
+            total_tokens=chunk.usage.total_tokens,
+            prompt_tokens_cost_usd=prompt_cost,
+            completion_tokens_cost_usd=completion_cost,
+            total_cost_usd=total_cost_usd,
+            model_name=chunk.model
+        )
    # Anthropic includes usage in the 'message_stop' event type
     elif isinstance(client, (Anthropic, AsyncAnthropic)):
         if chunk.type == "message_stop":
@@ -1715,11 +1806,8 @@ def _sync_stream_wrapper(
     final_usage = _extract_usage_from_final_chunk(client, last_chunk)
 
     # Update the trace entry with the accumulated content and usage
-    span.output = {
-        "content": "".join(content_parts), # Join list at the end
-        "usage": final_usage if final_usage else {"info": "Usage data not available in stream."}, # Provide placeholder if None
-        "streamed": True
-    }
+    span.output = "".join(content_parts)
+    span.usage = final_usage
     # Note: We might need to adjust _serialize_output if this dict causes issues,
     # but Pydantic's model_dump should handle dicts.
 
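After this change a streamed span stores the joined text on `span.output` and the usage object (or `None`) on `span.usage`, replacing the old `{"content", "usage", "streamed"}` wrapper dict. The accumulation itself is unchanged; a sketch of the pass-through pattern:

    class SpanStub:
        output = None
        usage = None

    def stream_with_capture(chunks, span):
        content_parts = []
        for chunk in chunks:
            content_parts.append(chunk)
            yield chunk                          # the caller still sees every chunk
        span.output = "".join(content_parts)     # flat string, no wrapper dict

    span = SpanStub()
    assert list(stream_with_capture(iter(["Hel", "lo"]), span)) == ["Hel", "lo"]
    assert span.output == "Hello"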
@@ -1739,6 +1827,7 @@ async def _async_stream_wrapper(
     target_span_id = span.span_id
 
     try:
+        model_name = ""
         async for chunk in original_stream:
             # Check for OpenAI's final usage chunk
             if isinstance(client, (AsyncOpenAI, OpenAI)) and hasattr(chunk, 'usage') and chunk.usage is not None:
@@ -1747,16 +1836,18 @@ async def _async_stream_wrapper(
                     "completion_tokens": chunk.usage.completion_tokens,
                     "total_tokens": chunk.usage.total_tokens
                 }
+                model_name = chunk.model
                 yield chunk
                 continue
 
             if isinstance(client, (AsyncAnthropic, Anthropic)) and hasattr(chunk, 'type'):
-                if chunk.type == "message_start":
-                    if hasattr(chunk, 'message') and hasattr(chunk.message, 'usage') and hasattr(chunk.message.usage, 'input_tokens'):
+                if chunk.type == "message_start":
+                    if hasattr(chunk, 'message') and hasattr(chunk.message, 'usage') and hasattr(chunk.message.usage, 'input_tokens'):
                         anthropic_input_tokens = chunk.message.usage.input_tokens
-                elif chunk.type == "message_delta":
-                    if hasattr(chunk, 'usage') and hasattr(chunk.usage, 'output_tokens'):
-                        anthropic_output_tokens += chunk.usage.output_tokens
+                        model_name = chunk.message.model
+                elif chunk.type == "message_delta":
+                    if hasattr(chunk, 'usage') and hasattr(chunk.usage, 'output_tokens'):
+                        anthropic_output_tokens = chunk.usage.output_tokens
 
             content_part = _extract_content_from_chunk(client, chunk)
             if content_part:
@@ -1779,18 +1870,37 @@ async def _async_stream_wrapper(
             elif anthropic_final_usage:
                 usage_info = anthropic_final_usage
             elif last_content_chunk:
-                usage_info = _extract_usage_from_final_chunk(client, last_content_chunk)
+                usage_info = _extract_usage_from_final_chunk(client, last_content_chunk)
 
+            if usage_info and not isinstance(usage_info, TraceUsage):
+                prompt_cost, completion_cost = cost_per_token(
+                    model=model_name,
+                    prompt_tokens=usage_info["prompt_tokens"],
+                    completion_tokens=usage_info["completion_tokens"],
+                )
+                usage_info = TraceUsage(
+                    prompt_tokens=usage_info["prompt_tokens"],
+                    completion_tokens=usage_info["completion_tokens"],
+                    total_tokens=usage_info["total_tokens"],
+                    prompt_tokens_cost_usd=prompt_cost,
+                    completion_tokens_cost_usd=completion_cost,
+                    total_cost_usd=prompt_cost + completion_cost,
+                    model_name=model_name
+                )
             if span and hasattr(span, 'output'):
-                span.output = {
-                    "content": "".join(content_parts), # Join list at the end
-                    "usage": usage_info if usage_info else {"info": "Usage data not available in stream."},
-                    "streamed": True
-                }
+                span.output = ''.join(content_parts)
+                span.usage = usage_info
                 start_ts = getattr(span, 'created_at', time.time())
                 span.duration = time.time() - start_ts
         # else: # Handle error case if necessary, but remove debug print
 
+def cost_per_token(*args, **kwargs):
+    try:
+        return _original_cost_per_token(*args, **kwargs)
+    except Exception as e:
+        warnings.warn(f"Error calculating cost per token: {e}")
+        return None, None
+
 class _BaseStreamManagerWrapper:
     def __init__(self, original_manager, client, span_name, trace_client, stream_wrapper_func, input_kwargs):
         self._original_manager = original_manager
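The module-level `cost_per_token` shim above is what keeps the pricing lookups throughout this file non-fatal: any exception from litellm degrades to `(None, None)` plus a warning instead of breaking the traced call. A behavior sketch, with a stand-in for litellm's failure mode:

    import warnings

    def _original_cost_per_token(*args, **kwargs):
        raise ValueError("model not mapped")  # stand-in for litellm's failure mode

    def cost_per_token(*args, **kwargs):
        try:
            return _original_cost_per_token(*args, **kwargs)
        except Exception as e:
            warnings.warn(f"Error calculating cost per token: {e}")
            return None, None

    assert cost_per_token(model="internal-llm",
                          prompt_tokens=10,
                          completion_tokens=2) == (None, None)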
@@ -1872,3 +1982,18 @@ class _TracedSyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractContext
             current_span_var.reset(self._span_context_token)
             delattr(self, '_span_context_token')
         return self._original_manager.__exit__(exc_type, exc_val, exc_tb)
+
+# --- Helper function for instance-prefixed qual_name ---
+def get_instance_prefixed_name(instance, class_name, class_identifiers):
+    """
+    Returns the agent name (prefix) if the class and attribute are found in class_identifiers.
+    Otherwise, returns None.
+    """
+    if class_name in class_identifiers:
+        attr = class_identifiers[class_name]
+        if hasattr(instance, attr):
+            instance_name = getattr(instance, attr)
+            return instance_name
+        else:
+            raise Exception(f"Attribute {class_identifiers[class_name]} does not exist for {class_name}. Check your identify() decorator.")
+    return None