judgeval 0.0.39__py3-none-any.whl → 0.0.41__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
judgeval/common/tracer.py CHANGED
@@ -5,7 +5,6 @@ Tracing system for judgeval that allows for function tracing using decorators.
5
5
  import asyncio
6
6
  import functools
7
7
  import inspect
8
- import json
9
8
  import os
10
9
  import site
11
10
  import sysconfig
@@ -16,6 +15,7 @@ import uuid
16
15
  import warnings
17
16
  import contextvars
18
17
  import sys
18
+ import json
19
19
  from contextlib import contextmanager, asynccontextmanager, AbstractAsyncContextManager, AbstractContextManager # Import context manager bases
20
20
  from dataclasses import dataclass, field
21
21
  from datetime import datetime
@@ -29,19 +29,16 @@ from typing import (
29
29
  Literal,
30
30
  Optional,
31
31
  Tuple,
32
- Type,
33
- TypeVar,
34
32
  Union,
35
33
  AsyncGenerator,
36
34
  TypeAlias,
37
35
  )
38
36
  from rich import print as rprint
39
- import types # <--- Add this import
37
+ import types
40
38
 
41
39
  # Third-party imports
42
40
  import requests
43
- from litellm import cost_per_token
44
- from pydantic import BaseModel
41
+ from litellm import cost_per_token as _original_cost_per_token
45
42
  from rich import print as rprint
46
43
  from openai import OpenAI, AsyncOpenAI
47
44
  from together import Together, AsyncTogether
@@ -59,12 +56,11 @@ from judgeval.constants import (
59
56
  JUDGMENT_TRACES_DELETE_API_URL,
60
57
  JUDGMENT_PROJECT_DELETE_API_URL,
61
58
  )
62
- from judgeval.data import Example, Trace, TraceSpan
59
+ from judgeval.data import Example, Trace, TraceSpan, TraceUsage
63
60
  from judgeval.scorers import APIJudgmentScorer, JudgevalScorer
64
61
  from judgeval.rules import Rule
65
62
  from judgeval.evaluation_run import EvaluationRun
66
- from judgeval.data.result import ScoringResult
67
- from judgeval.common.utils import validate_api_key
63
+ from judgeval.common.utils import ExcInfo, validate_api_key
68
64
  from judgeval.common.exceptions import JudgmentAPIError
69
65
 
70
66
  # Standard library imports needed for the new class
@@ -155,9 +151,29 @@ class TraceManagerClient:
155
151
  NOTE we save empty traces in order to properly handle async operations; we need something in the DB to associate the async results with
156
152
  """
157
153
  # Save to Judgment API
154
+
155
+ def fallback_encoder(obj):
156
+ """
157
+ Custom JSON encoder fallback.
158
+ Tries to use obj.__repr__(), then str(obj) if that fails or for a simpler string.
159
+ You can choose which one you prefer or try them in sequence.
160
+ """
161
+ try:
162
+ # Option 1: Prefer __repr__ for a more detailed representation
163
+ return repr(obj)
164
+ except Exception:
165
+ # Option 2: Fallback to str() if __repr__ fails or if you prefer str()
166
+ try:
167
+ return str(obj)
168
+ except Exception as e:
169
+ # If both fail, you might return a placeholder or re-raise
170
+ return f"<Unserializable object of type {type(obj).__name__}: {e}>"
171
+
172
+ serialized_trace_data = json.dumps(trace_data, default=fallback_encoder)
173
+
158
174
  response = requests.post(
159
175
  JUDGMENT_TRACES_SAVE_API_URL,
160
- json=trace_data,
176
+ data=serialized_trace_data,
161
177
  headers={
162
178
  "Content-Type": "application/json",
163
179
  "Authorization": f"Bearer {self.judgment_api_key}",
@@ -286,7 +302,7 @@ class TraceClient:
286
302
  tracer: Optional["Tracer"],
287
303
  trace_id: Optional[str] = None,
288
304
  name: str = "default",
289
- project_name: str = "default_project",
305
+ project_name: str = None,
290
306
  overwrite: bool = False,
291
307
  rules: Optional[List[Rule]] = None,
292
308
  enable_monitoring: bool = True,
@@ -296,7 +312,7 @@ class TraceClient:
296
312
  ):
297
313
  self.name = name
298
314
  self.trace_id = trace_id or str(uuid.uuid4())
299
- self.project_name = project_name
315
+ self.project_name = project_name or str(uuid.uuid4())
300
316
  self.overwrite = overwrite
301
317
  self.tracer = tracer
302
318
  self.rules = rules or []
@@ -463,6 +479,7 @@ class TraceClient:
463
479
  if current_span_id:
464
480
  span = self.span_id_to_span[current_span_id]
465
481
  span.evaluation_runs.append(eval_run)
482
+ span.has_evaluation = True # Set the has_evaluation flag
466
483
  self.evaluation_runs.append(eval_run)
467
484
 
468
485
  def add_annotation(self, annotation: TraceAnnotation):
@@ -474,16 +491,47 @@ class TraceClient:
474
491
  current_span_id = current_span_var.get()
475
492
  if current_span_id:
476
493
  span = self.span_id_to_span[current_span_id]
494
+ # Ignore self parameter
495
+ if "self" in inputs:
496
+ del inputs["self"]
477
497
  span.inputs = inputs
498
+
499
+ def record_agent_name(self, agent_name: str):
500
+ current_span_id = current_span_var.get()
501
+ if current_span_id:
502
+ span = self.span_id_to_span[current_span_id]
503
+ span.agent_name = agent_name
504
+
505
+ def record_state_before(self, state: dict):
506
+ """Records the agent's state before a tool execution on the current span.
507
+
508
+ Args:
509
+ state: A dictionary representing the agent's state.
510
+ """
511
+ current_span_id = current_span_var.get()
512
+ if current_span_id:
513
+ span = self.span_id_to_span[current_span_id]
514
+ span.state_before = state
515
+
516
+ def record_state_after(self, state: dict):
517
+ """Records the agent's state after a tool execution on the current span.
518
+
519
+ Args:
520
+ state: A dictionary representing the agent's state.
521
+ """
522
+ current_span_id = current_span_var.get()
523
+ if current_span_id:
524
+ span = self.span_id_to_span[current_span_id]
525
+ span.state_after = state
478
526
 
479
- async def _update_coroutine_output(self, span: TraceSpan, coroutine: Any):
527
+ async def _update_coroutine(self, span: TraceSpan, coroutine: Any, field: str):
480
528
  """Helper method to update the output of a trace entry once the coroutine completes"""
481
529
  try:
482
530
  result = await coroutine
483
- span.output = result
531
+ setattr(span, field, result)
484
532
  return result
485
533
  except Exception as e:
486
- span.output = f"Error: {str(e)}"
534
+ setattr(span, field, f"Error: {str(e)}")
487
535
  raise
488
536
 
489
537
  def record_output(self, output: Any):
@@ -493,12 +541,30 @@ class TraceClient:
493
541
  span.output = "<pending>" if inspect.iscoroutine(output) else output
494
542
 
495
543
  if inspect.iscoroutine(output):
496
- asyncio.create_task(self._update_coroutine_output(span, output))
544
+ asyncio.create_task(self._update_coroutine(span, output, "output"))
545
+
546
+ return span # Return the created entry
547
+ # Removed else block - original didn't have one
548
+ return None # Return None if no span_id found
549
+
550
+ def record_usage(self, usage: TraceUsage):
551
+ current_span_id = current_span_var.get()
552
+ if current_span_id:
553
+ span = self.span_id_to_span[current_span_id]
554
+ span.usage = usage
497
555
 
498
556
  return span # Return the created entry
499
557
  # Removed else block - original didn't have one
500
558
  return None # Return None if no span_id found
501
-
559
+
560
+ def record_error(self, error: Dict[str, Any]):
561
+ current_span_id = current_span_var.get()
562
+ if current_span_id:
563
+ span = self.span_id_to_span[current_span_id]
564
+ span.error = error
565
+ return span
566
+ return None
567
+
502
568
  def add_span(self, span: TraceSpan):
503
569
  """Add a trace span to this trace context"""
504
570
  self.trace_spans.append(span)
@@ -523,133 +589,6 @@ class TraceClient:
523
589
  """
524
590
  # Calculate total elapsed time
525
591
  total_duration = self.get_duration()
526
-
527
- # Only count tokens for actual LLM API call spans
528
- llm_span_names = {"OPENAI_API_CALL", "TOGETHER_API_CALL", "ANTHROPIC_API_CALL", "GOOGLE_API_CALL"}
529
- for span in self.trace_spans:
530
- span_function_name = span.function # Get function name safely
531
- # Check if it's an LLM span AND function name CONTAINS an API call suffix AND output is dict
532
- is_llm_span = span.span_type == "llm"
533
- has_api_suffix = any(suffix in span_function_name for suffix in llm_span_names)
534
- output_is_dict = isinstance(span.output, dict)
535
-
536
- # --- DEBUG PRINT 1: Check if condition passes ---
537
- # if is_llm_entry and has_api_suffix and output_is_dict:
538
- # elif is_llm_entry:
539
- # # Print why it failed if it was an LLM entry
540
- # # --- END DEBUG ---
541
-
542
- if is_llm_span and has_api_suffix and output_is_dict:
543
- output = span.output
544
- usage = output.get("usage", {}) # Gets the 'usage' dict from the 'output' field
545
-
546
- # --- DEBUG PRINT 2: Check extracted usage ---
547
- # --- END DEBUG ---
548
-
549
- # --- NEW: Extract model_name correctly from nested inputs ---
550
- model_name = None
551
- span_inputs = span.inputs
552
- if span_inputs:
553
- # Try common locations for model name within the inputs structure
554
- invocation_params = span_inputs.get("invocation_params", {})
555
- serialized_data = span_inputs.get("serialized", {})
556
-
557
- # Look in invocation_params (often directly contains model)
558
- if isinstance(invocation_params, dict):
559
- model_name = invocation_params.get("model")
560
-
561
- # Fallback: Check serialized 'repr' if it contains model info
562
- if not model_name and isinstance(serialized_data, dict):
563
- serialized_repr = serialized_data.get("repr", "")
564
- if "model_name=" in serialized_repr:
565
- try: # Simple parsing attempt
566
- model_name = serialized_repr.split("model_name='")[1].split("'")[0]
567
- except IndexError: pass # Ignore parsing errors
568
-
569
- # Fallback: Check top-level of invocation_params (sometimes passed flat)
570
- if not model_name and isinstance(invocation_params, dict):
571
- model_name = invocation_params.get("model") # Redundant check, but safe
572
-
573
- # Fallback: Check top-level of inputs itself (less likely for callbacks)
574
- if not model_name:
575
- model_name = span_inputs.get("model")
576
-
577
-
578
- # --- END NEW ---
579
-
580
- prompt_tokens = 0
581
- completion_tokens = 0
582
-
583
- # Handle OpenAI/Together format (checks within the 'usage' dict)
584
- if "prompt_tokens" in usage:
585
- prompt_tokens = usage.get("prompt_tokens", 0)
586
- completion_tokens = usage.get("completion_tokens", 0)
587
-
588
- # Handle Anthropic format - MAP values to standard keys
589
- elif "input_tokens" in usage:
590
- prompt_tokens = usage.get("input_tokens", 0) # Get value from input_tokens
591
- completion_tokens = usage.get("output_tokens", 0) # Get value from output_tokens
592
-
593
- # *** Overwrite the usage dict in the entry to use standard keys ***
594
- original_total = usage.get("total_tokens", 0)
595
- original_total_cost = usage.get("total_cost_usd", 0.0) # Preserve if already calculated
596
- # Recalculate cost just in case it wasn't done correctly before
597
- temp_prompt_cost, temp_completion_cost = 0.0, 0.0
598
- if model_name:
599
- try:
600
- temp_prompt_cost, temp_completion_cost = cost_per_token(
601
- model=model_name,
602
- prompt_tokens=prompt_tokens,
603
- completion_tokens=completion_tokens
604
- )
605
- except Exception:
606
- pass # Ignore cost calculation errors here, focus on keys
607
- # Replace the usage dict with one using standard keys but Anthropic values
608
- output["usage"] = {
609
- "prompt_tokens": prompt_tokens,
610
- "completion_tokens": completion_tokens,
611
- "total_tokens": original_total,
612
- "prompt_tokens_cost_usd": temp_prompt_cost, # Use standard cost key
613
- "completion_tokens_cost_usd": temp_completion_cost, # Use standard cost key
614
- "total_cost_usd": original_total_cost if original_total_cost > 0 else (temp_prompt_cost + temp_completion_cost)
615
- }
616
- usage = output["usage"]
617
-
618
- # Calculate costs if model name is available and ensure they are stored with standard keys
619
- prompt_tokens = usage.get("prompt_tokens", 0)
620
- completion_tokens = usage.get("completion_tokens", 0)
621
-
622
- # Calculate costs if model name is available
623
- if model_name:
624
- try:
625
- # Recalculate costs based on potentially mapped tokens
626
- prompt_cost, completion_cost = cost_per_token(
627
- model=model_name,
628
- prompt_tokens=prompt_tokens,
629
- completion_tokens=completion_tokens
630
- )
631
-
632
- # Add cost information directly to the usage dictionary in the condensed entry
633
- # Ensure 'usage' exists in the output dict before modifying it
634
- # Add/Update cost information using standard keys
635
-
636
- if "usage" not in output:
637
- output["usage"] = {} # Initialize if missing
638
- elif not isinstance(output["usage"], dict): # Handle cases where 'usage' might not be a dict (e.g., placeholder string)
639
- print(f"[WARN TraceClient.save] Output 'usage' for span {span.span_id} was not a dict ({type(output['usage'])}). Resetting before adding costs.")
640
- output["usage"] = {} # Reset to dict
641
-
642
- output["usage"]["prompt_tokens_cost_usd"] = prompt_cost
643
- output["usage"]["completion_tokens_cost_usd"] = completion_cost
644
- output["usage"]["total_cost_usd"] = prompt_cost + completion_cost
645
- except Exception as e:
646
- # If cost calculation fails, continue without adding costs
647
- print(f"Error calculating cost for model '{model_name}' (span: {span.span_id}): {str(e)}")
648
- pass
649
- else:
650
- print(f"[WARN TraceClient.save] Could not determine model name for cost calculation (span: {span.span_id}). Inputs: {span_inputs}")
651
-
652
-
653
592
  # Create trace document - Always use standard keys for top-level counts
654
593
  trace_data = {
655
594
  "trace_id": self.trace_id,
@@ -657,7 +596,7 @@ class TraceClient:
657
596
  "project_name": self.project_name,
658
597
  "created_at": datetime.utcfromtimestamp(self.start_time).isoformat(),
659
598
  "duration": total_duration,
660
- "entries": [span.model_dump() for span in self.trace_spans],
599
+ "trace_spans": [span.model_dump() for span in self.trace_spans],
661
600
  "evaluation_runs": [run.model_dump() for run in self.evaluation_runs],
662
601
  "overwrite": overwrite,
663
602
  "offline_mode": self.tracer.offline_mode,
@@ -677,13 +616,46 @@ class TraceClient:
677
616
  def delete(self):
678
617
  return self.trace_manager_client.delete_trace(self.trace_id)
679
618
 
619
+ def _capture_exception_for_trace(current_trace: Optional['TraceClient'], exc_info: ExcInfo):
620
+ if not current_trace:
621
+ return
622
+
623
+ exc_type, exc_value, exc_traceback_obj = exc_info
624
+ formatted_exception = {
625
+ "type": exc_type.__name__ if exc_type else "UnknownExceptionType",
626
+ "message": str(exc_value) if exc_value else "No exception message",
627
+ "traceback": traceback.format_tb(exc_traceback_obj) if exc_traceback_obj else []
628
+ }
629
+
630
+ # This is where we specially handle exceptions that we might want to collect additional data for.
631
+ # When we do this, always try checking the module from sys.modules instead of importing. This will
632
+ # Let us support a wider range of exceptions without needing to import them for all clients.
633
+
634
+ # Most clients (requests, httpx, urllib) support the standard format of exposing error.request.url and error.response.status_code
635
+ # The alternative is to hand select libraries we want from sys.modules and check for them:
636
+ # As an example: requests_module = sys.modules.get("requests", None) // then do things with requests_module;
680
637
 
638
+ # General HTTP Like errors
639
+ try:
640
+ url = getattr(getattr(exc_value, "request", None), "url", None)
641
+ status_code = getattr(getattr(exc_value, "response", None), "status_code", None)
642
+ if status_code:
643
+ formatted_exception["http"] = {
644
+ "url": url if url else "Unknown URL",
645
+ "status_code": status_code if status_code else None,
646
+ }
647
+ except Exception as e:
648
+ pass
649
+
650
+ current_trace.record_error(formatted_exception)
681
651
  class _DeepTracer:
682
652
  _instance: Optional["_DeepTracer"] = None
683
653
  _lock: threading.Lock = threading.Lock()
684
654
  _refcount: int = 0
685
655
  _span_stack: contextvars.ContextVar[List[Dict[str, Any]]] = contextvars.ContextVar("_deep_profiler_span_stack", default=[])
686
656
  _skip_stack: contextvars.ContextVar[List[str]] = contextvars.ContextVar("_deep_profiler_skip_stack", default=[])
657
+ _original_sys_trace: Optional[Callable] = None
658
+ _original_threading_trace: Optional[Callable] = None
687
659
 
688
660
  def _get_qual_name(self, frame) -> str:
689
661
  func_name = frame.f_code.co_name
@@ -731,12 +703,53 @@ class _DeepTracer:
731
703
  @functools.cache
732
704
  def _is_user_code(self, filename: str):
733
705
  return bool(filename) and not filename.startswith("<") and not os.path.realpath(filename).startswith(_TRACE_FILEPATH_BLOCKLIST)
706
+
707
+ def _cooperative_sys_trace(self, frame: types.FrameType, event: str, arg: Any):
708
+ """Cooperative trace function for sys.settrace that chains with existing tracers."""
709
+ # First, call the original sys trace function if it exists
710
+ original_result = None
711
+ if self._original_sys_trace:
712
+ try:
713
+ original_result = self._original_sys_trace(frame, event, arg)
714
+ except Exception:
715
+ # If the original tracer fails, continue with our tracing
716
+ pass
717
+
718
+ # Then do our own tracing
719
+ our_result = self._trace(frame, event, arg, self._cooperative_sys_trace)
720
+
721
+ # Return our tracer to continue tracing, but respect the original's decision
722
+ # If the original tracer returned None (stop tracing), we should respect that
723
+ if original_result is None and self._original_sys_trace:
724
+ return None
725
+
726
+ return our_result or original_result
734
727
 
735
- def _trace(self, frame: types.FrameType, event: str, arg: Any):
728
+ def _cooperative_threading_trace(self, frame: types.FrameType, event: str, arg: Any):
729
+ """Cooperative trace function for threading.settrace that chains with existing tracers."""
730
+ # First, call the original threading trace function if it exists
731
+ original_result = None
732
+ if self._original_threading_trace:
733
+ try:
734
+ original_result = self._original_threading_trace(frame, event, arg)
735
+ except Exception:
736
+ # If the original tracer fails, continue with our tracing
737
+ pass
738
+
739
+ # Then do our own tracing
740
+ our_result = self._trace(frame, event, arg, self._cooperative_threading_trace)
741
+
742
+ # Return our tracer to continue tracing, but respect the original's decision
743
+ # If the original tracer returned None (stop tracing), we should respect that
744
+ if original_result is None and self._original_threading_trace:
745
+ return None
746
+
747
+ return our_result or original_result
748
+
749
+ def _trace(self, frame: types.FrameType, event: str, arg: Any, continuation_func: Callable):
736
750
  frame.f_trace_lines = False
737
751
  frame.f_trace_opcodes = False
738
752
 
739
-
740
753
  if not self._should_trace(frame):
741
754
  return
742
755
 
@@ -752,6 +765,12 @@ class _DeepTracer:
752
765
  return
753
766
 
754
767
  qual_name = self._get_qual_name(frame)
768
+ instance_name = None
769
+ if 'self' in frame.f_locals:
770
+ instance = frame.f_locals['self']
771
+ class_name = instance.__class__.__name__
772
+ class_identifiers = getattr(Tracer._instance, 'class_identifiers', {})
773
+ instance_name = get_instance_prefixed_name(instance, class_name, class_identifiers)
755
774
  skip_stack = self._skip_stack.get()
756
775
 
757
776
  if event == "call":
@@ -814,7 +833,8 @@ class _DeepTracer:
814
833
  created_at=start_time,
815
834
  span_type="span",
816
835
  parent_span_id=parent_span_id,
817
- function=qual_name
836
+ function=qual_name,
837
+ agent_name=instance_name
818
838
  )
819
839
  current_trace.add_span(span)
820
840
 
@@ -869,35 +889,40 @@ class _DeepTracer:
869
889
  current_span_var.reset(frame.f_locals["_judgment_span_token"])
870
890
 
871
891
  elif event == "exception":
872
- exc_type, exc_value, exc_traceback = arg
873
- formatted_exception = {
874
- "type": exc_type.__name__,
875
- "message": str(exc_value),
876
- "traceback": traceback.format_tb(exc_traceback)
877
- }
878
- current_trace = current_trace_var.get()
879
- current_trace.record_output({
880
- "error": formatted_exception
881
- })
892
+ exc_type = arg[0]
893
+ if issubclass(exc_type, (StopIteration, StopAsyncIteration, GeneratorExit)):
894
+ return
895
+ _capture_exception_for_trace(current_trace, arg)
896
+
882
897
 
883
- return self._trace
898
+ return continuation_func
884
899
 
885
900
  def __enter__(self):
886
901
  with self._lock:
887
902
  self._refcount += 1
888
903
  if self._refcount == 1:
904
+ # Store the existing trace functions before setting ours
905
+ self._original_sys_trace = sys.gettrace()
906
+ self._original_threading_trace = threading.gettrace()
907
+
889
908
  self._skip_stack.set([])
890
909
  self._span_stack.set([])
891
- sys.settrace(self._trace)
892
- threading.settrace(self._trace)
910
+
911
+ sys.settrace(self._cooperative_sys_trace)
912
+ threading.settrace(self._cooperative_threading_trace)
893
913
  return self
894
914
 
895
915
  def __exit__(self, exc_type, exc_val, exc_tb):
896
916
  with self._lock:
897
917
  self._refcount -= 1
898
918
  if self._refcount == 0:
899
- sys.settrace(None)
900
- threading.settrace(None)
919
+ # Restore the original trace functions instead of setting to None
920
+ sys.settrace(self._original_sys_trace)
921
+ threading.settrace(self._original_threading_trace)
922
+
923
+ # Clean up the references
924
+ self._original_sys_trace = None
925
+ self._original_threading_trace = None
901
926
 
902
927
 
903
928
  def log(self, message: str, level: str = "info"):
@@ -920,7 +945,7 @@ class Tracer:
920
945
  def __init__(
921
946
  self,
922
947
  api_key: str = os.getenv("JUDGMENT_API_KEY"),
923
- project_name: str = "default_project",
948
+ project_name: str = None,
924
949
  rules: Optional[List[Rule]] = None, # Added rules parameter
925
950
  organization_id: str = os.getenv("JUDGMENT_ORG_ID"),
926
951
  enable_monitoring: bool = os.getenv("JUDGMENT_MONITORING", "true").lower() == "true",
@@ -946,13 +971,9 @@ class Tracer:
946
971
  raise ValueError("Tracer must be configured with an Organization ID")
947
972
  if use_s3 and not s3_bucket_name:
948
973
  raise ValueError("S3 bucket name must be provided when use_s3 is True")
949
- if use_s3 and not (s3_aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")):
950
- raise ValueError("AWS Access Key ID must be provided when use_s3 is True")
951
- if use_s3 and not (s3_aws_secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY")):
952
- raise ValueError("AWS Secret Access Key must be provided when use_s3 is True")
953
974
 
954
975
  self.api_key: str = api_key
955
- self.project_name: str = project_name
976
+ self.project_name: str = project_name or str(uuid.uuid4())
956
977
  self.organization_id: str = organization_id
957
978
  self._current_trace: Optional[str] = None
958
979
  self._active_trace_client: Optional[TraceClient] = None # Add active trace client attribute
@@ -961,6 +982,7 @@ class Tracer:
961
982
  self.initialized: bool = True
962
983
  self.enable_monitoring: bool = enable_monitoring
963
984
  self.enable_evaluations: bool = enable_evaluations
985
+ self.class_identifiers: Dict[str, str] = {} # Dictionary to store class identifiers
964
986
 
965
987
  # Initialize S3 storage if enabled
966
988
  self.use_s3 = use_s3
@@ -1084,6 +1106,92 @@ class Tracer:
1084
1106
 
1085
1107
  rprint(f"[bold]{label}:[/bold] {msg}")
1086
1108
 
1109
+ def identify(self, identifier: str, track_state: bool = False, track_attributes: Optional[List[str]] = None, field_mappings: Optional[Dict[str, str]] = None):
1110
+ """
1111
+ Class decorator that associates a class with a custom identifier and enables state tracking.
1112
+
1113
+ This decorator creates a mapping between the class name and the provided
1114
+ identifier, which can be useful for tagging, grouping, or referencing
1115
+ classes in a standardized way. It also enables automatic state capture
1116
+ for instances of the decorated class when used with tracing.
1117
+
1118
+ Args:
1119
+ identifier: The identifier to associate with the decorated class.
1120
+ This will be used as the instance name in traces.
1121
+ track_state: Whether to automatically capture the state (attributes)
1122
+ of instances before and after function execution. Defaults to False.
1123
+ track_attributes: Optional list of specific attribute names to track.
1124
+ If None, all non-private attributes (not starting with '_')
1125
+ will be tracked when track_state=True.
1126
+ field_mappings: Optional dictionary mapping internal attribute names to
1127
+ display names in the captured state. For example:
1128
+ {"system_prompt": "instructions"} will capture the
1129
+ 'instructions' attribute as 'system_prompt' in the state.
1130
+
1131
+ Example:
1132
+ @tracer.identify(identifier="user_model", track_state=True, track_attributes=["name", "age"], field_mappings={"system_prompt": "instructions"})
1133
+ class User:
1134
+ # Class implementation
1135
+ """
1136
+ def decorator(cls):
1137
+ class_name = cls.__name__
1138
+ self.class_identifiers[class_name] = {
1139
+ "identifier": identifier,
1140
+ "track_state": track_state,
1141
+ "track_attributes": track_attributes,
1142
+ "field_mappings": field_mappings or {}
1143
+ }
1144
+ return cls
1145
+
1146
+ return decorator
1147
+
1148
+ def _capture_instance_state(self, instance: Any, class_config: Dict[str, Any]) -> Dict[str, Any]:
1149
+ """
1150
+ Capture the state of an instance based on class configuration.
1151
+ Args:
1152
+ instance: The instance to capture the state of.
1153
+ class_config: Configuration dictionary for state capture,
1154
+ expected to contain 'track_attributes' and 'field_mappings'.
1155
+ """
1156
+ track_attributes = class_config.get('track_attributes')
1157
+ field_mappings = class_config.get('field_mappings')
1158
+
1159
+ if track_attributes:
1160
+
1161
+ state = {attr: getattr(instance, attr, None) for attr in track_attributes}
1162
+ else:
1163
+
1164
+ state = {k: v for k, v in instance.__dict__.items() if not k.startswith('_')}
1165
+
1166
+ if field_mappings:
1167
+ state['field_mappings'] = field_mappings
1168
+
1169
+ return state
1170
+
1171
+
1172
+ def _get_instance_state_if_tracked(self, args):
1173
+ """
1174
+ Extract instance state if the instance should be tracked.
1175
+
1176
+ Returns the captured state dict if tracking is enabled, None otherwise.
1177
+ """
1178
+ if args and hasattr(args[0], '__class__'):
1179
+ instance = args[0]
1180
+ class_name = instance.__class__.__name__
1181
+ if (class_name in self.class_identifiers and
1182
+ isinstance(self.class_identifiers[class_name], dict) and
1183
+ self.class_identifiers[class_name].get('track_state', False)):
1184
+ return self._capture_instance_state(instance, self.class_identifiers[class_name])
1185
+
1186
+ def _conditionally_capture_and_record_state(self, trace_client_instance: TraceClient, args: tuple, is_before: bool):
1187
+ """Captures instance state if tracked and records it via the trace_client."""
1188
+ state = self._get_instance_state_if_tracked(args)
1189
+ if state:
1190
+ if is_before:
1191
+ trace_client_instance.record_state_before(state)
1192
+ else:
1193
+ trace_client_instance.record_state_after(state)
1194
+
1087
1195
  def observe(self, func=None, *, name=None, span_type: SpanType = "span", project_name: str = None, overwrite: bool = False, deep_tracing: bool = None):
1088
1196
  """
1089
1197
  Decorator to trace function execution with detailed entry/exit information.
@@ -1106,10 +1214,10 @@ class Tracer:
1106
1214
  overwrite=overwrite, deep_tracing=deep_tracing)
1107
1215
 
1108
1216
  # Use provided name or fall back to function name
1109
- span_name = name or func.__name__
1217
+ original_span_name = name or func.__name__
1110
1218
 
1111
1219
  # Store custom attributes on the function object
1112
- func._judgment_span_name = span_name
1220
+ func._judgment_span_name = original_span_name
1113
1221
  func._judgment_span_type = span_type
1114
1222
 
1115
1223
  # Use the provided deep_tracing value or fall back to the tracer's default
@@ -1118,6 +1226,16 @@ class Tracer:
1118
1226
  if asyncio.iscoroutinefunction(func):
1119
1227
  @functools.wraps(func)
1120
1228
  async def async_wrapper(*args, **kwargs):
1229
+ nonlocal original_span_name
1230
+ class_name = None
1231
+ instance_name = None
1232
+ span_name = original_span_name
1233
+ agent_name = None
1234
+
1235
+ if args and hasattr(args[0], '__class__'):
1236
+ class_name = args[0].__class__.__name__
1237
+ agent_name = get_instance_prefixed_name(args[0], class_name, self.class_identifiers)
1238
+
1121
1239
  # Get current trace from context
1122
1240
  current_trace = current_trace_var.get()
1123
1241
 
@@ -1141,7 +1259,7 @@ class Tracer:
1141
1259
  # Save empty trace and set trace context
1142
1260
  # current_trace.save(empty_save=True, overwrite=overwrite)
1143
1261
  trace_token = current_trace_var.set(current_trace)
1144
-
1262
+
1145
1263
  try:
1146
1264
  # Use span for the function execution within the root trace
1147
1265
  # This sets the current_span_var
@@ -1149,12 +1267,24 @@ class Tracer:
1149
1267
  # Record inputs
1150
1268
  inputs = combine_args_kwargs(func, args, kwargs)
1151
1269
  span.record_input(inputs)
1270
+ if agent_name:
1271
+ span.record_agent_name(agent_name)
1272
+
1273
+ # Capture state before execution
1274
+ self._conditionally_capture_and_record_state(span, args, is_before=True)
1152
1275
 
1153
1276
  if use_deep_tracing:
1154
1277
  with _DeepTracer():
1155
1278
  result = await func(*args, **kwargs)
1156
1279
  else:
1157
- result = await func(*args, **kwargs)
1280
+ try:
1281
+ result = await func(*args, **kwargs)
1282
+ except Exception as e:
1283
+ _capture_exception_for_trace(current_trace, sys.exc_info())
1284
+ raise e
1285
+
1286
+ # Capture state after execution
1287
+ self._conditionally_capture_and_record_state(span, args, is_before=False)
1158
1288
 
1159
1289
  # Record output
1160
1290
  span.record_output(result)
@@ -1170,12 +1300,24 @@ class Tracer:
1170
1300
  with current_trace.span(span_name, span_type=span_type) as span:
1171
1301
  inputs = combine_args_kwargs(func, args, kwargs)
1172
1302
  span.record_input(inputs)
1173
-
1303
+ if agent_name:
1304
+ span.record_agent_name(agent_name)
1305
+
1306
+ # Capture state before execution
1307
+ self._conditionally_capture_and_record_state(span, args, is_before=True)
1308
+
1174
1309
  if use_deep_tracing:
1175
1310
  with _DeepTracer():
1176
1311
  result = await func(*args, **kwargs)
1177
1312
  else:
1178
- result = await func(*args, **kwargs)
1313
+ try:
1314
+ result = await func(*args, **kwargs)
1315
+ except Exception as e:
1316
+ _capture_exception_for_trace(current_trace, sys.exc_info())
1317
+ raise e
1318
+
1319
+ # Capture state after execution
1320
+ self._conditionally_capture_and_record_state(span, args, is_before=False)
1179
1321
 
1180
1322
  span.record_output(result)
1181
1323
  return result
@@ -1184,7 +1326,15 @@ class Tracer:
1184
1326
  else:
1185
1327
  # Non-async function implementation with deep tracing
1186
1328
  @functools.wraps(func)
1187
- def wrapper(*args, **kwargs):
1329
+ def wrapper(*args, **kwargs):
1330
+ nonlocal original_span_name
1331
+ class_name = None
1332
+ instance_name = None
1333
+ span_name = original_span_name
1334
+ agent_name = None
1335
+ if args and hasattr(args[0], '__class__'):
1336
+ class_name = args[0].__class__.__name__
1337
+ agent_name = get_instance_prefixed_name(args[0], class_name, self.class_identifiers)
1188
1338
  # Get current trace from context
1189
1339
  current_trace = current_trace_var.get()
1190
1340
 
@@ -1216,12 +1366,24 @@ class Tracer:
1216
1366
  # Record inputs
1217
1367
  inputs = combine_args_kwargs(func, args, kwargs)
1218
1368
  span.record_input(inputs)
1219
-
1369
+ if agent_name:
1370
+ span.record_agent_name(agent_name)
1371
+ # Capture state before execution
1372
+ self._conditionally_capture_and_record_state(span, args, is_before=True)
1373
+
1220
1374
  if use_deep_tracing:
1221
1375
  with _DeepTracer():
1222
1376
  result = func(*args, **kwargs)
1223
1377
  else:
1224
- result = func(*args, **kwargs)
1378
+ try:
1379
+ result = func(*args, **kwargs)
1380
+ except Exception as e:
1381
+ _capture_exception_for_trace(current_trace, sys.exc_info())
1382
+ raise e
1383
+
1384
+ # Capture state after execution
1385
+ self._conditionally_capture_and_record_state(span, args, is_before=False)
1386
+
1225
1387
 
1226
1388
  # Record output
1227
1389
  span.record_output(result)
@@ -1238,12 +1400,24 @@ class Tracer:
1238
1400
 
1239
1401
  inputs = combine_args_kwargs(func, args, kwargs)
1240
1402
  span.record_input(inputs)
1241
-
1403
+ if agent_name:
1404
+ span.record_agent_name(agent_name)
1405
+
1406
+ # Capture state before execution
1407
+ self._conditionally_capture_and_record_state(span, args, is_before=True)
1408
+
1242
1409
  if use_deep_tracing:
1243
1410
  with _DeepTracer():
1244
1411
  result = func(*args, **kwargs)
1245
1412
  else:
1246
- result = func(*args, **kwargs)
1413
+ try:
1414
+ result = func(*args, **kwargs)
1415
+ except Exception as e:
1416
+ _capture_exception_for_trace(current_trace, sys.exc_info())
1417
+ raise e
1418
+
1419
+ # Capture state after execution
1420
+ self._conditionally_capture_and_record_state(span, args, is_before=False)
1247
1421
 
1248
1422
  span.record_output(result)
1249
1423
  return result
@@ -1313,17 +1487,11 @@ def wrap(client: Any) -> Any:
1313
1487
  return wrapper_func(response, client, output_entry)
1314
1488
  else:
1315
1489
  format_func = _format_response_output_data if is_responses else _format_output_data
1316
- output_data = format_func(client, response)
1317
- span.record_output(output_data)
1490
+ output, usage = format_func(client, response)
1491
+ span.record_output(output)
1492
+ span.record_usage(usage)
1318
1493
  return response
1319
1494
 
1320
- def _handle_error(span, e, is_async):
1321
- """Handle and record errors"""
1322
- call_type = "async" if is_async else "sync"
1323
- print(f"Error during wrapped {call_type} API call ({span_name}): {e}")
1324
- span.record_output({"error": str(e)})
1325
- raise
1326
-
1327
1495
  # --- Traced Async Functions ---
1328
1496
  async def traced_create_async(*args, **kwargs):
1329
1497
  current_trace = current_trace_var.get()
@@ -1337,7 +1505,8 @@ def wrap(client: Any) -> Any:
1337
1505
  response_or_iterator = await original_create(*args, **kwargs)
1338
1506
  return _format_and_record_output(span, response_or_iterator, is_streaming, True, False)
1339
1507
  except Exception as e:
1340
- return _handle_error(span, e, True)
1508
+ _capture_exception_for_trace(span, sys.exc_info())
1509
+ raise e
1341
1510
 
1342
1511
  # Async responses for OpenAI clients
1343
1512
  async def traced_response_create_async(*args, **kwargs):
@@ -1352,7 +1521,8 @@ def wrap(client: Any) -> Any:
1352
1521
  response_or_iterator = await original_responses_create(*args, **kwargs)
1353
1522
  return _format_and_record_output(span, response_or_iterator, is_streaming, True, True)
1354
1523
  except Exception as e:
1355
- return _handle_error(span, e, True)
1524
+ _capture_exception_for_trace(span, sys.exc_info())
1525
+ raise e
1356
1526
 
1357
1527
  # Function replacing .stream() for async clients
1358
1528
  def traced_stream_async(*args, **kwargs):
@@ -1383,7 +1553,8 @@ def wrap(client: Any) -> Any:
1383
1553
  response_or_iterator = original_create(*args, **kwargs)
1384
1554
  return _format_and_record_output(span, response_or_iterator, is_streaming, False, False)
1385
1555
  except Exception as e:
1386
- return _handle_error(span, e, False)
1556
+ _capture_exception_for_trace(span, sys.exc_info())
1557
+ raise e
1387
1558
 
1388
1559
  def traced_response_create_sync(*args, **kwargs):
1389
1560
  current_trace = current_trace_var.get()
@@ -1397,7 +1568,8 @@ def wrap(client: Any) -> Any:
1397
1568
  response_or_iterator = original_responses_create(*args, **kwargs)
1398
1569
  return _format_and_record_output(span, response_or_iterator, is_streaming, False, True)
1399
1570
  except Exception as e:
1400
- return _handle_error(span, e, False)
1571
+ _capture_exception_for_trace(span, sys.exc_info())
1572
+ raise e
1401
1573
 
1402
1574
  # Function replacing sync .stream()
1403
1575
  def traced_stream_sync(*args, **kwargs):
@@ -1496,18 +1668,35 @@ def _format_response_output_data(client: ApiClient, response: Any) -> dict:
1496
1668
  Normalizes different response formats into a consistent structure
1497
1669
  for tracing purposes.
1498
1670
  """
1671
+ message_content = None
1672
+ prompt_tokens = 0
1673
+ completion_tokens = 0
1674
+ model_name = None
1499
1675
  if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
1500
- return {
1501
- "content": response.output,
1502
- "usage": {
1503
- "prompt_tokens": response.usage.input_tokens,
1504
- "completion_tokens": response.usage.output_tokens,
1505
- "total_tokens": response.usage.total_tokens
1506
- }
1507
- }
1676
+ model_name = response.model
1677
+ prompt_tokens = response.usage.input_tokens
1678
+ completion_tokens = response.usage.output_tokens
1679
+ message_content = response.output
1508
1680
  else:
1509
1681
  warnings.warn(f"Unsupported client type: {type(client)}")
1510
1682
  return {}
1683
+
1684
+ prompt_cost, completion_cost = cost_per_token(
1685
+ model=model_name,
1686
+ prompt_tokens=prompt_tokens,
1687
+ completion_tokens=completion_tokens,
1688
+ )
1689
+ total_cost_usd = (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
1690
+ usage = TraceUsage(
1691
+ prompt_tokens=prompt_tokens,
1692
+ completion_tokens=completion_tokens,
1693
+ total_tokens=prompt_tokens + completion_tokens,
1694
+ prompt_tokens_cost_usd=prompt_cost,
1695
+ completion_tokens_cost_usd=completion_cost,
1696
+ total_cost_usd=total_cost_usd,
1697
+ model_name=model_name
1698
+ )
1699
+ return message_content, usage
1511
1700
 
1512
1701
 
1513
1702
  def _format_output_data(client: ApiClient, response: Any) -> dict:
@@ -1521,33 +1710,46 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
1521
1710
  - content: The generated text
1522
1711
  - usage: Token usage statistics
1523
1712
  """
1713
+ prompt_tokens = 0
1714
+ completion_tokens = 0
1715
+ model_name = None
1716
+ message_content = None
1717
+
1524
1718
  if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
1525
- return {
1526
- "content": response.choices[0].message.content,
1527
- "usage": {
1528
- "prompt_tokens": response.usage.prompt_tokens,
1529
- "completion_tokens": response.usage.completion_tokens,
1530
- "total_tokens": response.usage.total_tokens
1531
- }
1532
- }
1719
+ model_name = response.model
1720
+ prompt_tokens = response.usage.prompt_tokens
1721
+ completion_tokens = response.usage.completion_tokens
1722
+ message_content = response.choices[0].message.content
1533
1723
  elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
1534
- return {
1535
- "content": response.candidates[0].content.parts[0].text,
1536
- "usage": {
1537
- "prompt_tokens": response.usage_metadata.prompt_token_count,
1538
- "completion_tokens": response.usage_metadata.candidates_token_count,
1539
- "total_tokens": response.usage_metadata.total_token_count
1540
- }
1541
- }
1542
- # Anthropic has a different response structure
1543
- return {
1544
- "content": response.content[0].text,
1545
- "usage": {
1546
- "prompt_tokens": response.usage.input_tokens,
1547
- "completion_tokens": response.usage.output_tokens,
1548
- "total_tokens": response.usage.input_tokens + response.usage.output_tokens
1549
- }
1550
- }
1724
+ model_name = response.model_version
1725
+ prompt_tokens = response.usage_metadata.prompt_token_count
1726
+ completion_tokens = response.usage_metadata.candidates_token_count
1727
+ message_content = response.candidates[0].content.parts[0].text
1728
+ elif isinstance(client, (Anthropic, AsyncAnthropic)):
1729
+ model_name = response.model
1730
+ prompt_tokens = response.usage.input_tokens
1731
+ completion_tokens = response.usage.output_tokens
1732
+ message_content = response.content[0].text
1733
+ else:
1734
+ warnings.warn(f"Unsupported client type: {type(client)}")
1735
+ return None, None
1736
+
1737
+ prompt_cost, completion_cost = cost_per_token(
1738
+ model=model_name,
1739
+ prompt_tokens=prompt_tokens,
1740
+ completion_tokens=completion_tokens,
1741
+ )
1742
+ total_cost_usd = (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
1743
+ usage = TraceUsage(
1744
+ prompt_tokens=prompt_tokens,
1745
+ completion_tokens=completion_tokens,
1746
+ total_tokens=prompt_tokens + completion_tokens,
1747
+ prompt_tokens_cost_usd=prompt_cost,
1748
+ completion_tokens_cost_usd=completion_cost,
1749
+ total_cost_usd=total_cost_usd,
1750
+ model_name=model_name
1751
+ )
1752
+ return message_content, usage
1551
1753
 
1552
1754
  def combine_args_kwargs(func, args, kwargs):
1553
1755
  """
@@ -1653,21 +1855,30 @@ def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[D
1653
1855
  # OpenAI/Together include usage in the *last* chunk's `usage` attribute if available
1654
1856
  # This typically requires specific API versions or settings. Often usage is *not* streamed.
1655
1857
  if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
1656
- # Check if usage is directly on the chunk (some models might do this)
1657
- if hasattr(chunk, 'usage') and chunk.usage:
1658
- return {
1659
- "prompt_tokens": chunk.usage.prompt_tokens,
1660
- "completion_tokens": chunk.usage.completion_tokens,
1661
- "total_tokens": chunk.usage.total_tokens
1662
- }
1663
- # Check if usage is nested within choices (less common for final chunk, but check)
1664
- elif chunk.choices and hasattr(chunk.choices[0], 'usage') and chunk.choices[0].usage:
1665
- usage = chunk.choices[0].usage
1666
- return {
1667
- "prompt_tokens": usage.prompt_tokens,
1668
- "completion_tokens": usage.completion_tokens,
1669
- "total_tokens": usage.total_tokens
1670
- }
1858
+ # Check if usage is directly on the chunk (some models might do this)
1859
+ if hasattr(chunk, 'usage') and chunk.usage:
1860
+ prompt_tokens = chunk.usage.prompt_tokens
1861
+ completion_tokens = chunk.usage.completion_tokens
1862
+ # Check if usage is nested within choices (less common for final chunk, but check)
1863
+ elif chunk.choices and hasattr(chunk.choices[0], 'usage') and chunk.choices[0].usage:
1864
+ prompt_tokens = chunk.choices[0].usage.prompt_tokens
1865
+ completion_tokens = chunk.choices[0].usage.completion_tokens
1866
+
1867
+ prompt_cost, completion_cost = cost_per_token(
1868
+ model=chunk.model,
1869
+ prompt_tokens=prompt_tokens,
1870
+ completion_tokens=completion_tokens,
1871
+ )
1872
+ total_cost_usd = (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
1873
+ return TraceUsage(
1874
+ prompt_tokens=chunk.usage.prompt_tokens,
1875
+ completion_tokens=chunk.usage.completion_tokens,
1876
+ total_tokens=chunk.usage.total_tokens,
1877
+ prompt_tokens_cost_usd=prompt_cost,
1878
+ completion_tokens_cost_usd=completion_cost,
1879
+ total_cost_usd=total_cost_usd,
1880
+ model_name=chunk.model
1881
+ )
1671
1882
  # Anthropic includes usage in the 'message_stop' event type
1672
1883
  elif isinstance(client, (Anthropic, AsyncAnthropic)):
1673
1884
  if chunk.type == "message_stop":
@@ -1715,11 +1926,8 @@ def _sync_stream_wrapper(
1715
1926
  final_usage = _extract_usage_from_final_chunk(client, last_chunk)
1716
1927
 
1717
1928
  # Update the trace entry with the accumulated content and usage
1718
- span.output = {
1719
- "content": "".join(content_parts), # Join list at the end
1720
- "usage": final_usage if final_usage else {"info": "Usage data not available in stream."}, # Provide placeholder if None
1721
- "streamed": True
1722
- }
1929
+ span.output = "".join(content_parts)
1930
+ span.usage = final_usage
1723
1931
  # Note: We might need to adjust _serialize_output if this dict causes issues,
1724
1932
  # but Pydantic's model_dump should handle dicts.
1725
1933
 
@@ -1739,6 +1947,7 @@ async def _async_stream_wrapper(
1739
1947
  target_span_id = span.span_id
1740
1948
 
1741
1949
  try:
1950
+ model_name = ""
1742
1951
  async for chunk in original_stream:
1743
1952
  # Check for OpenAI's final usage chunk
1744
1953
  if isinstance(client, (AsyncOpenAI, OpenAI)) and hasattr(chunk, 'usage') and chunk.usage is not None:
@@ -1747,16 +1956,18 @@ async def _async_stream_wrapper(
1747
1956
  "completion_tokens": chunk.usage.completion_tokens,
1748
1957
  "total_tokens": chunk.usage.total_tokens
1749
1958
  }
1959
+ model_name = chunk.model
1750
1960
  yield chunk
1751
1961
  continue
1752
1962
 
1753
1963
  if isinstance(client, (AsyncAnthropic, Anthropic)) and hasattr(chunk, 'type'):
1754
- if chunk.type == "message_start":
1755
- if hasattr(chunk, 'message') and hasattr(chunk.message, 'usage') and hasattr(chunk.message.usage, 'input_tokens'):
1964
+ if chunk.type == "message_start":
1965
+ if hasattr(chunk, 'message') and hasattr(chunk.message, 'usage') and hasattr(chunk.message.usage, 'input_tokens'):
1756
1966
  anthropic_input_tokens = chunk.message.usage.input_tokens
1757
- elif chunk.type == "message_delta":
1758
- if hasattr(chunk, 'usage') and hasattr(chunk.usage, 'output_tokens'):
1759
- anthropic_output_tokens += chunk.usage.output_tokens
1967
+ model_name = chunk.message.model
1968
+ elif chunk.type == "message_delta":
1969
+ if hasattr(chunk, 'usage') and hasattr(chunk.usage, 'output_tokens'):
1970
+ anthropic_output_tokens = chunk.usage.output_tokens
1760
1971
 
1761
1972
  content_part = _extract_content_from_chunk(client, chunk)
1762
1973
  if content_part:
@@ -1779,18 +1990,37 @@ async def _async_stream_wrapper(
1779
1990
  elif anthropic_final_usage:
1780
1991
  usage_info = anthropic_final_usage
1781
1992
  elif last_content_chunk:
1782
- usage_info = _extract_usage_from_final_chunk(client, last_content_chunk)
1993
+ usage_info = _extract_usage_from_final_chunk(client, last_content_chunk)
1783
1994
 
1995
+ if usage_info and not isinstance(usage_info, TraceUsage):
1996
+ prompt_cost, completion_cost = cost_per_token(
1997
+ model=model_name,
1998
+ prompt_tokens=usage_info["prompt_tokens"],
1999
+ completion_tokens=usage_info["completion_tokens"],
2000
+ )
2001
+ usage_info = TraceUsage(
2002
+ prompt_tokens=usage_info["prompt_tokens"],
2003
+ completion_tokens=usage_info["completion_tokens"],
2004
+ total_tokens=usage_info["total_tokens"],
2005
+ prompt_tokens_cost_usd=prompt_cost,
2006
+ completion_tokens_cost_usd=completion_cost,
2007
+ total_cost_usd=prompt_cost + completion_cost,
2008
+ model_name=model_name
2009
+ )
1784
2010
  if span and hasattr(span, 'output'):
1785
- span.output = {
1786
- "content": "".join(content_parts), # Join list at the end
1787
- "usage": usage_info if usage_info else {"info": "Usage data not available in stream."},
1788
- "streamed": True
1789
- }
2011
+ span.output = ''.join(content_parts)
2012
+ span.usage = usage_info
1790
2013
  start_ts = getattr(span, 'created_at', time.time())
1791
2014
  span.duration = time.time() - start_ts
1792
2015
  # else: # Handle error case if necessary, but remove debug print
1793
2016
 
2017
def cost_per_token(*args, **kwargs):
    """
    Best-effort wrapper around the imported ``_original_cost_per_token``.

    All positional and keyword arguments are forwarded unchanged. If the
    underlying call raises any ``Exception``, a warning describing the error
    is emitted and ``(None, None)`` is returned instead of propagating it.

    Returns:
        Whatever ``_original_cost_per_token`` returns on success, otherwise
        the placeholder pair ``(None, None)``.
    """
    try:
        costs = _original_cost_per_token(*args, **kwargs)
    except Exception as exc:
        # Cost lookup failed: report it and hand back a placeholder pair.
        warnings.warn(f"Error calculating cost per token: {exc}")
        costs = (None, None)
    return costs
2023
+
1794
2024
  class _BaseStreamManagerWrapper:
1795
2025
  def __init__(self, original_manager, client, span_name, trace_client, stream_wrapper_func, input_kwargs):
1796
2026
  self._original_manager = original_manager
@@ -1872,3 +2102,20 @@ class _TracedSyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractContext
1872
2102
  current_span_var.reset(self._span_context_token)
1873
2103
  delattr(self, '_span_context_token')
1874
2104
  return self._original_manager.__exit__(exc_type, exc_val, exc_tb)
2105
+
2106
+ # --- Helper function for instance-prefixed qual_name ---
2107
def get_instance_prefixed_name(instance, class_name, class_identifiers):
    """
    Return the agent name (prefix) configured for ``instance``.

    Args:
        instance: Object whose identifying attribute is read.
        class_name: Name of the instance's class; used as the key into
            ``class_identifiers``.
        class_identifiers: Mapping of class name -> config dict containing an
            ``'identifier'`` entry that names the attribute to read. ``None``
            or an empty mapping is treated as "nothing registered".

    Returns:
        The value of the configured attribute on ``instance``, or ``None``
        when ``class_name`` has no entry in ``class_identifiers``.

    Raises:
        AttributeError: If the class is registered but ``instance`` lacks the
            configured attribute. ``AttributeError`` is a subclass of
            ``Exception``, so existing ``except Exception`` handlers still
            catch it.
        KeyError: If the class config has no ``'identifier'`` entry.
    """
    # Guard clause: unregistered classes (or no registry at all) have no name.
    if not class_identifiers or class_name not in class_identifiers:
        return None

    attr = class_identifiers[class_name]['identifier']
    if not hasattr(instance, attr):
        raise AttributeError(
            f"Attribute {attr} does not exist for {class_name}. Check your identify() decorator."
        )
    return getattr(instance, attr)