judgeval 0.0.38__py3-none-any.whl → 0.0.40__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/clients.py +6 -4
- judgeval/common/tracer.py +361 -236
- judgeval/constants.py +3 -0
- judgeval/data/__init__.py +2 -1
- judgeval/data/example.py +14 -13
- judgeval/data/tool.py +47 -0
- judgeval/data/trace.py +28 -39
- judgeval/data/trace_run.py +2 -1
- judgeval/evaluation_run.py +4 -7
- judgeval/judgment_client.py +27 -6
- judgeval/run_evaluation.py +395 -37
- judgeval/scorers/__init__.py +4 -1
- judgeval/scorers/judgeval_scorer.py +8 -0
- judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +4 -0
- judgeval/scorers/judgeval_scorers/api_scorers/classifier_scorer.py +124 -0
- judgeval/scorers/judgeval_scorers/api_scorers/tool_dependency.py +20 -0
- judgeval/scorers/judgeval_scorers/classifiers/text2sql/text2sql_scorer.py +1 -1
- judgeval/scorers/prompt_scorer.py +5 -164
- judgeval/scorers/score.py +15 -15
- judgeval-0.0.40.dist-info/METADATA +1441 -0
- {judgeval-0.0.38.dist-info → judgeval-0.0.40.dist-info}/RECORD +23 -20
- judgeval-0.0.38.dist-info/METADATA +0 -247
- {judgeval-0.0.38.dist-info → judgeval-0.0.40.dist-info}/WHEEL +0 -0
- {judgeval-0.0.38.dist-info → judgeval-0.0.40.dist-info}/licenses/LICENSE.md +0 -0
judgeval/common/tracer.py
CHANGED
@@ -34,13 +34,14 @@ from typing import (
     Union,
     AsyncGenerator,
     TypeAlias,
+    Set
 )
 from rich import print as rprint
 import types # <--- Add this import

 # Third-party imports
 import requests
-from litellm import cost_per_token
+from litellm import cost_per_token as _original_cost_per_token
 from pydantic import BaseModel
 from rich import print as rprint
 from openai import OpenAI, AsyncOpenAI
@@ -59,7 +60,7 @@ from judgeval.constants import (
     JUDGMENT_TRACES_DELETE_API_URL,
     JUDGMENT_PROJECT_DELETE_API_URL,
 )
-from judgeval.data import Example, Trace, TraceSpan
+from judgeval.data import Example, Trace, TraceSpan, TraceUsage
 from judgeval.scorers import APIJudgmentScorer, JudgevalScorer
 from judgeval.rules import Rule
 from judgeval.evaluation_run import EvaluationRun
@@ -155,9 +156,29 @@ class TraceManagerClient:
         NOTE we save empty traces in order to properly handle async operations; we need something in the DB to associate the async results with
         """
         # Save to Judgment API
+
+        def fallback_encoder(obj):
+            """
+            Custom JSON encoder fallback.
+            Tries to use obj.__repr__(), then str(obj) if that fails or for a simpler string.
+            You can choose which one you prefer or try them in sequence.
+            """
+            try:
+                # Option 1: Prefer __repr__ for a more detailed representation
+                return repr(obj)
+            except Exception:
+                # Option 2: Fallback to str() if __repr__ fails or if you prefer str()
+                try:
+                    return str(obj)
+                except Exception as e:
+                    # If both fail, you might return a placeholder or re-raise
+                    return f"<Unserializable object of type {type(obj).__name__}: {e}>"
+
+        serialized_trace_data = json.dumps(trace_data, default=fallback_encoder)
+
         response = requests.post(
             JUDGMENT_TRACES_SAVE_API_URL,
-            json=trace_data,
+            data=serialized_trace_data,
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {self.judgment_api_key}",
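Note: the `fallback_encoder` above works through the `default=` hook of `json.dumps`, which the encoder invokes only for values it cannot serialize natively. A minimal standalone sketch of the same pattern (the `Opaque` class is illustrative, not part of judgeval):

```python
import json

class Opaque:
    """A value json.dumps cannot serialize natively."""
    def __repr__(self):
        return "Opaque()"

def fallback_encoder(obj):
    # Called by json.dumps only for values it cannot serialize itself
    try:
        return repr(obj)
    except Exception:
        try:
            return str(obj)
        except Exception as e:
            return f"<Unserializable object of type {type(obj).__name__}: {e}>"

print(json.dumps({"ok": 1, "weird": Opaque()}, default=fallback_encoder))
# {"ok": 1, "weird": "Opaque()"}
```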
@@ -463,6 +484,7 @@ class TraceClient:
         if current_span_id:
             span = self.span_id_to_span[current_span_id]
             span.evaluation_runs.append(eval_run)
+            span.has_evaluation = True # Set the has_evaluation flag
         self.evaluation_runs.append(eval_run)

     def add_annotation(self, annotation: TraceAnnotation):
@@ -474,16 +496,25 @@ class TraceClient:
         current_span_id = current_span_var.get()
         if current_span_id:
             span = self.span_id_to_span[current_span_id]
+            # Ignore self parameter
+            if "self" in inputs:
+                del inputs["self"]
             span.inputs = inputs
+
+    def record_agent_name(self, agent_name: str):
+        current_span_id = current_span_var.get()
+        if current_span_id:
+            span = self.span_id_to_span[current_span_id]
+            span.agent_name = agent_name

-    async def _update_coroutine_output(self, span: TraceSpan, coroutine: Any):
+    async def _update_coroutine(self, span: TraceSpan, coroutine: Any, field: str):
         """Helper method to update the output of a trace entry once the coroutine completes"""
         try:
             result = await coroutine
-            span.output = result
+            setattr(span, field, result)
             return result
         except Exception as e:
-            span.output = f"Error: {str(e)}"
+            setattr(span, field, f"Error: {str(e)}")
             raise

     def record_output(self, output: Any):
@@ -493,12 +524,30 @@ class TraceClient:
             span.output = "<pending>" if inspect.iscoroutine(output) else output

             if inspect.iscoroutine(output):
-                asyncio.create_task(self._update_coroutine_output(span, output))
+                asyncio.create_task(self._update_coroutine(span, output, "output"))
+
+            return span # Return the created entry
+        # Removed else block - original didn't have one
+        return None # Return None if no span_id found
+
+    def record_usage(self, usage: TraceUsage):
+        current_span_id = current_span_var.get()
+        if current_span_id:
+            span = self.span_id_to_span[current_span_id]
+            span.usage = usage

             return span # Return the created entry
         # Removed else block - original didn't have one
         return None # Return None if no span_id found
-
+
+    def record_error(self, error: Any):
+        current_span_id = current_span_var.get()
+        if current_span_id:
+            span = self.span_id_to_span[current_span_id]
+            span.error = error
+            return span
+        return None
+
     def add_span(self, span: TraceSpan):
         """Add a trace span to this trace context"""
         self.trace_spans.append(span)
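Note: `record_usage` and `record_error` follow the same lookup pattern as the existing `record_*` methods: the active span is resolved through a `ContextVar` rather than being passed explicitly. A simplified, self-contained sketch of that mechanism (names are stand-ins for judgeval's internals):

```python
import contextvars
from dataclasses import dataclass
from typing import Any, Dict, Optional

current_span_id: contextvars.ContextVar = contextvars.ContextVar("current_span_id", default=None)

@dataclass
class Span:
    span_id: str
    usage: Any = None
    error: Any = None

spans: Dict[str, Span] = {}

def record_error(error: Any) -> Optional[Span]:
    # Mirrors the shape of TraceClient.record_error: resolve the active
    # span through the ContextVar instead of threading it through calls.
    span_id = current_span_id.get()
    if span_id:
        span = spans[span_id]
        span.error = error
        return span
    return None

spans["s1"] = Span("s1")
token = current_span_id.set("s1")
record_error({"type": "ValueError", "message": "boom"})
current_span_id.reset(token)
print(spans["s1"].error)  # {'type': 'ValueError', 'message': 'boom'}
```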
@@ -523,133 +572,6 @@ class TraceClient:
         """
         # Calculate total elapsed time
         total_duration = self.get_duration()
-
-        # Only count tokens for actual LLM API call spans
-        llm_span_names = {"OPENAI_API_CALL", "TOGETHER_API_CALL", "ANTHROPIC_API_CALL", "GOOGLE_API_CALL"}
-        for span in self.trace_spans:
-            span_function_name = span.function # Get function name safely
-            # Check if it's an LLM span AND function name CONTAINS an API call suffix AND output is dict
-            is_llm_span = span.span_type == "llm"
-            has_api_suffix = any(suffix in span_function_name for suffix in llm_span_names)
-            output_is_dict = isinstance(span.output, dict)
-
-            # --- DEBUG PRINT 1: Check if condition passes ---
-            # if is_llm_entry and has_api_suffix and output_is_dict:
-            # elif is_llm_entry:
-            #     # Print why it failed if it was an LLM entry
-            # # --- END DEBUG ---
-
-            if is_llm_span and has_api_suffix and output_is_dict:
-                output = span.output
-                usage = output.get("usage", {}) # Gets the 'usage' dict from the 'output' field
-
-                # --- DEBUG PRINT 2: Check extracted usage ---
-                # --- END DEBUG ---
-
-                # --- NEW: Extract model_name correctly from nested inputs ---
-                model_name = None
-                span_inputs = span.inputs
-                if span_inputs:
-                    # Try common locations for model name within the inputs structure
-                    invocation_params = span_inputs.get("invocation_params", {})
-                    serialized_data = span_inputs.get("serialized", {})
-
-                    # Look in invocation_params (often directly contains model)
-                    if isinstance(invocation_params, dict):
-                        model_name = invocation_params.get("model")
-
-                    # Fallback: Check serialized 'repr' if it contains model info
-                    if not model_name and isinstance(serialized_data, dict):
-                        serialized_repr = serialized_data.get("repr", "")
-                        if "model_name=" in serialized_repr:
-                            try: # Simple parsing attempt
-                                model_name = serialized_repr.split("model_name='")[1].split("'")[0]
-                            except IndexError: pass # Ignore parsing errors
-
-                    # Fallback: Check top-level of invocation_params (sometimes passed flat)
-                    if not model_name and isinstance(invocation_params, dict):
-                        model_name = invocation_params.get("model") # Redundant check, but safe
-
-                    # Fallback: Check top-level of inputs itself (less likely for callbacks)
-                    if not model_name:
-                        model_name = span_inputs.get("model")
-
-
-                # --- END NEW ---
-
-                prompt_tokens = 0
-                completion_tokens = 0
-
-                # Handle OpenAI/Together format (checks within the 'usage' dict)
-                if "prompt_tokens" in usage:
-                    prompt_tokens = usage.get("prompt_tokens", 0)
-                    completion_tokens = usage.get("completion_tokens", 0)
-
-                # Handle Anthropic format - MAP values to standard keys
-                elif "input_tokens" in usage:
-                    prompt_tokens = usage.get("input_tokens", 0) # Get value from input_tokens
-                    completion_tokens = usage.get("output_tokens", 0) # Get value from output_tokens
-
-                    # *** Overwrite the usage dict in the entry to use standard keys ***
-                    original_total = usage.get("total_tokens", 0)
-                    original_total_cost = usage.get("total_cost_usd", 0.0) # Preserve if already calculated
-                    # Recalculate cost just in case it wasn't done correctly before
-                    temp_prompt_cost, temp_completion_cost = 0.0, 0.0
-                    if model_name:
-                        try:
-                            temp_prompt_cost, temp_completion_cost = cost_per_token(
-                                model=model_name,
-                                prompt_tokens=prompt_tokens,
-                                completion_tokens=completion_tokens
-                            )
-                        except Exception:
-                            pass # Ignore cost calculation errors here, focus on keys
-                    # Replace the usage dict with one using standard keys but Anthropic values
-                    output["usage"] = {
-                        "prompt_tokens": prompt_tokens,
-                        "completion_tokens": completion_tokens,
-                        "total_tokens": original_total,
-                        "prompt_tokens_cost_usd": temp_prompt_cost, # Use standard cost key
-                        "completion_tokens_cost_usd": temp_completion_cost, # Use standard cost key
-                        "total_cost_usd": original_total_cost if original_total_cost > 0 else (temp_prompt_cost + temp_completion_cost)
-                    }
-                    usage = output["usage"]
-
-                # Calculate costs if model name is available and ensure they are stored with standard keys
-                prompt_tokens = usage.get("prompt_tokens", 0)
-                completion_tokens = usage.get("completion_tokens", 0)
-
-                # Calculate costs if model name is available
-                if model_name:
-                    try:
-                        # Recalculate costs based on potentially mapped tokens
-                        prompt_cost, completion_cost = cost_per_token(
-                            model=model_name,
-                            prompt_tokens=prompt_tokens,
-                            completion_tokens=completion_tokens
-                        )
-
-                        # Add cost information directly to the usage dictionary in the condensed entry
-                        # Ensure 'usage' exists in the output dict before modifying it
-                        # Add/Update cost information using standard keys
-
-                        if "usage" not in output:
-                            output["usage"] = {} # Initialize if missing
-                        elif not isinstance(output["usage"], dict): # Handle cases where 'usage' might not be a dict (e.g., placeholder string)
-                            print(f"[WARN TraceClient.save] Output 'usage' for span {span.span_id} was not a dict ({type(output['usage'])}). Resetting before adding costs.")
-                            output["usage"] = {} # Reset to dict
-
-                        output["usage"]["prompt_tokens_cost_usd"] = prompt_cost
-                        output["usage"]["completion_tokens_cost_usd"] = completion_cost
-                        output["usage"]["total_cost_usd"] = prompt_cost + completion_cost
-                    except Exception as e:
-                        # If cost calculation fails, continue without adding costs
-                        print(f"Error calculating cost for model '{model_name}' (span: {span.span_id}): {str(e)}")
-                        pass
-                else:
-                    print(f"[WARN TraceClient.save] Could not determine model name for cost calculation (span: {span.span_id}). Inputs: {span_inputs}")
-
-
         # Create trace document - Always use standard keys for top-level counts
         trace_data = {
             "trace_id": self.trace_id,
@@ -677,13 +599,25 @@ class TraceClient:
     def delete(self):
         return self.trace_manager_client.delete_trace(self.trace_id)

-
+def _capture_exception_for_trace(current_trace: Optional['TraceClient'], exc_info: Tuple[Optional[type], Optional[BaseException], Optional[types.TracebackType]]):
+    if not current_trace:
+        return
+
+    exc_type, exc_value, exc_traceback_obj = exc_info
+    formatted_exception = {
+        "type": exc_type.__name__ if exc_type else "UnknownExceptionType",
+        "message": str(exc_value) if exc_value else "No exception message",
+        "traceback": traceback.format_tb(exc_traceback_obj) if exc_traceback_obj else []
+    }
+    current_trace.record_error(formatted_exception)
 class _DeepTracer:
     _instance: Optional["_DeepTracer"] = None
     _lock: threading.Lock = threading.Lock()
     _refcount: int = 0
     _span_stack: contextvars.ContextVar[List[Dict[str, Any]]] = contextvars.ContextVar("_deep_profiler_span_stack", default=[])
     _skip_stack: contextvars.ContextVar[List[str]] = contextvars.ContextVar("_deep_profiler_skip_stack", default=[])
+    _original_sys_trace: Optional[Callable] = None
+    _original_threading_trace: Optional[Callable] = None

     def _get_qual_name(self, frame) -> str:
         func_name = frame.f_code.co_name
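Note: `_capture_exception_for_trace` consumes a standard `sys.exc_info()` triple and stores the result via `record_error`. The formatting step in isolation, as a self-contained sketch:

```python
import sys
import traceback

def format_exc_info(exc_info):
    # Same dict shape as the formatted_exception built above
    exc_type, exc_value, exc_tb = exc_info
    return {
        "type": exc_type.__name__ if exc_type else "UnknownExceptionType",
        "message": str(exc_value) if exc_value else "No exception message",
        "traceback": traceback.format_tb(exc_tb) if exc_tb else [],
    }

try:
    1 / 0
except ZeroDivisionError:
    info = format_exc_info(sys.exc_info())
    print(info["type"], "-", info["message"])  # ZeroDivisionError - division by zero
```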
@@ -731,12 +665,53 @@ class _DeepTracer:
     @functools.cache
     def _is_user_code(self, filename: str):
         return bool(filename) and not filename.startswith("<") and not os.path.realpath(filename).startswith(_TRACE_FILEPATH_BLOCKLIST)
+
+    def _cooperative_sys_trace(self, frame: types.FrameType, event: str, arg: Any):
+        """Cooperative trace function for sys.settrace that chains with existing tracers."""
+        # First, call the original sys trace function if it exists
+        original_result = None
+        if self._original_sys_trace:
+            try:
+                original_result = self._original_sys_trace(frame, event, arg)
+            except Exception:
+                # If the original tracer fails, continue with our tracing
+                pass
+
+        # Then do our own tracing
+        our_result = self._trace(frame, event, arg, self._cooperative_sys_trace)
+
+        # Return our tracer to continue tracing, but respect the original's decision
+        # If the original tracer returned None (stop tracing), we should respect that
+        if original_result is None and self._original_sys_trace:
+            return None
+
+        return our_result or original_result
+
+    def _cooperative_threading_trace(self, frame: types.FrameType, event: str, arg: Any):
+        """Cooperative trace function for threading.settrace that chains with existing tracers."""
+        # First, call the original threading trace function if it exists
+        original_result = None
+        if self._original_threading_trace:
+            try:
+                original_result = self._original_threading_trace(frame, event, arg)
+            except Exception:
+                # If the original tracer fails, continue with our tracing
+                pass
+
+        # Then do our own tracing
+        our_result = self._trace(frame, event, arg, self._cooperative_threading_trace)
+
+        # Return our tracer to continue tracing, but respect the original's decision
+        # If the original tracer returned None (stop tracing), we should respect that
+        if original_result is None and self._original_threading_trace:
+            return None
+
+        return our_result or original_result

-    def _trace(self, frame: types.FrameType, event: str, arg: Any):
+    def _trace(self, frame: types.FrameType, event: str, arg: Any, continuation_func: Callable):
         frame.f_trace_lines = False
         frame.f_trace_opcodes = False

-
         if not self._should_trace(frame):
             return

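Note: the cooperative tracers exist so `_DeepTracer` no longer clobbers a debugger, profiler, or coverage tool that installed a trace function first. The chaining idea, sketched outside the class (a minimal sketch, not the library's API):

```python
import sys
from typing import Any, Callable, Optional

_original_tracer: Optional[Callable] = None

def cooperative_tracer(frame, event: str, arg: Any):
    """Chain to whatever tracer was installed before ours."""
    original_result = None
    if _original_tracer is not None:
        try:
            original_result = _original_tracer(frame, event, arg)
        except Exception:
            pass  # a broken foreign tracer should not break ours
    # ... our own per-frame bookkeeping would happen here ...
    if original_result is None and _original_tracer is not None:
        return None  # the pre-existing tracer opted out of this frame
    return cooperative_tracer

def install() -> None:
    global _original_tracer
    _original_tracer = sys.gettrace()  # capture the incumbent, if any
    sys.settrace(cooperative_tracer)

def uninstall() -> None:
    sys.settrace(_original_tracer)  # restore rather than clearing to None
```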
@@ -752,6 +727,12 @@ class _DeepTracer:
             return

         qual_name = self._get_qual_name(frame)
+        instance_name = None
+        if 'self' in frame.f_locals:
+            instance = frame.f_locals['self']
+            class_name = instance.__class__.__name__
+            class_identifiers = getattr(Tracer._instance, 'class_identifiers', {})
+            instance_name = get_instance_prefixed_name(instance, class_name, class_identifiers)
         skip_stack = self._skip_stack.get()

         if event == "call":
@@ -814,7 +795,8 @@ class _DeepTracer:
                 created_at=start_time,
                 span_type="span",
                 parent_span_id=parent_span_id,
-                function=qual_name
+                function=qual_name,
+                agent_name=instance_name
             )
             current_trace.add_span(span)

@@ -869,35 +851,40 @@ class _DeepTracer:
             current_span_var.reset(frame.f_locals["_judgment_span_token"])

         elif event == "exception":
-            exc_type, exc_value, exc_traceback = arg
-            formatted_exception = {
-                "type": exc_type.__name__ if exc_type else "UnknownExceptionType",
-                "message": str(exc_value) if exc_value else "No exception message",
-                "traceback": traceback.format_tb(exc_traceback) if exc_traceback else []
-            }
-            current_trace = current_trace_var.get()
-            current_trace.record_output({
-                "error": formatted_exception
-            })
+            exc_type = arg[0]
+            if issubclass(exc_type, (StopIteration, StopAsyncIteration, GeneratorExit)):
+                return
+            _capture_exception_for_trace(current_trace, arg)
+

-        return
+        return continuation_func

     def __enter__(self):
         with self._lock:
             self._refcount += 1
             if self._refcount == 1:
+                # Store the existing trace functions before setting ours
+                self._original_sys_trace = sys.gettrace()
+                self._original_threading_trace = threading.gettrace()
+
                 self._skip_stack.set([])
                 self._span_stack.set([])
-                sys.settrace(self._trace)
-                threading.settrace(self._trace)
+
+                sys.settrace(self._cooperative_sys_trace)
+                threading.settrace(self._cooperative_threading_trace)
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
         with self._lock:
             self._refcount -= 1
             if self._refcount == 0:
-                sys.settrace(None)
-                threading.settrace(None)
+                # Restore the original trace functions instead of setting to None
+                sys.settrace(self._original_sys_trace)
+                threading.settrace(self._original_threading_trace)
+
+                # Clean up the references
+                self._original_sys_trace = None
+                self._original_threading_trace = None


     def log(self, message: str, level: str = "info"):
@@ -946,10 +933,6 @@ class Tracer:
             raise ValueError("Tracer must be configured with an Organization ID")
         if use_s3 and not s3_bucket_name:
             raise ValueError("S3 bucket name must be provided when use_s3 is True")
-        if use_s3 and not (s3_aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")):
-            raise ValueError("AWS Access Key ID must be provided when use_s3 is True")
-        if use_s3 and not (s3_aws_secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY")):
-            raise ValueError("AWS Secret Access Key must be provided when use_s3 is True")

         self.api_key: str = api_key
         self.project_name: str = project_name
@@ -961,6 +944,7 @@ class Tracer:
         self.initialized: bool = True
         self.enable_monitoring: bool = enable_monitoring
         self.enable_evaluations: bool = enable_evaluations
+        self.class_identifiers: Dict[str, str] = {} # Dictionary to store class identifiers

         # Initialize S3 storage if enabled
         self.use_s3 = use_s3
@@ -1084,6 +1068,32 @@ class Tracer:

         rprint(f"[bold]{label}:[/bold] {msg}")

+    def identify(self, identifier: str):
+        """
+        Class decorator that associates a class with a custom identifier.
+
+        This decorator creates a mapping between the class name and the provided
+        identifier, which can be useful for tagging, grouping, or referencing
+        classes in a standardized way.
+
+        Args:
+            identifier: The identifier to associate with the decorated class
+
+        Returns:
+            A decorator function that registers the class with the given identifier
+
+        Example:
+            @tracer.identify(identifier="user_model")
+            class User:
+                # Class implementation
+        """
+        def decorator(cls):
+            class_name = cls.__name__
+            self.class_identifiers[class_name] = identifier
+            return cls
+
+        return decorator
+
     def observe(self, func=None, *, name=None, span_type: SpanType = "span", project_name: str = None, overwrite: bool = False, deep_tracing: bool = None):
         """
         Decorator to trace function execution with detailed entry/exit information.
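Note: a usage sketch for `identify` combined with `observe`; constructor arguments here are illustrative. The identifier names an instance attribute, and `get_instance_prefixed_name` (defined at the end of this file) resolves that attribute's value into the span's `agent_name`:

```python
from judgeval.common.tracer import Tracer

tracer = Tracer(api_key="...", organization_id="...", project_name="demo")  # illustrative args

@tracer.identify(identifier="name")
class Assistant:
    def __init__(self, name: str):
        self.name = name  # resolved via class_identifiers["Assistant"] == "name"

    @tracer.observe(name="answer")
    def answer(self, question: str) -> str:
        return f"answer to {question}"

Assistant("support-bot").answer("hi")  # the span records agent_name="support-bot"
```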
@@ -1106,10 +1116,10 @@ class Tracer:
                                 overwrite=overwrite, deep_tracing=deep_tracing)

         # Use provided name or fall back to function name
-        span_name = name or func.__name__
+        original_span_name = name or func.__name__

         # Store custom attributes on the function object
-        func._judgment_span_name = span_name
+        func._judgment_span_name = original_span_name
         func._judgment_span_type = span_type

         # Use the provided deep_tracing value or fall back to the tracer's default
@@ -1118,6 +1128,16 @@ class Tracer:
         if asyncio.iscoroutinefunction(func):
             @functools.wraps(func)
             async def async_wrapper(*args, **kwargs):
+                nonlocal original_span_name
+                class_name = None
+                instance_name = None
+                span_name = original_span_name
+                agent_name = None
+
+                if args and hasattr(args[0], '__class__'):
+                    class_name = args[0].__class__.__name__
+                    agent_name = get_instance_prefixed_name(args[0], class_name, self.class_identifiers)
+
                 # Get current trace from context
                 current_trace = current_trace_var.get()

@@ -1141,7 +1161,7 @@ class Tracer:
                     # Save empty trace and set trace context
                     # current_trace.save(empty_save=True, overwrite=overwrite)
                     trace_token = current_trace_var.set(current_trace)
-
+
                     try:
                         # Use span for the function execution within the root trace
                         # This sets the current_span_var
@@ -1149,13 +1169,19 @@ class Tracer:
                         # Record inputs
                         inputs = combine_args_kwargs(func, args, kwargs)
                         span.record_input(inputs)
+                        if agent_name:
+                            span.record_agent_name(agent_name)

                         if use_deep_tracing:
                             with _DeepTracer():
                                 result = await func(*args, **kwargs)
                         else:
-                            result = await func(*args, **kwargs)
-
+                            try:
+                                result = await func(*args, **kwargs)
+                            except Exception as e:
+                                _capture_exception_for_trace(current_trace, sys.exc_info())
+                                raise e
+

                         # Record output
                         span.record_output(result)
                         return result
@@ -1170,12 +1196,18 @@ class Tracer:
                     with current_trace.span(span_name, span_type=span_type) as span:
                         inputs = combine_args_kwargs(func, args, kwargs)
                         span.record_input(inputs)
-
+                        if agent_name:
+                            span.record_agent_name(agent_name)
+
                         if use_deep_tracing:
                             with _DeepTracer():
                                 result = await func(*args, **kwargs)
                         else:
-                            result = await func(*args, **kwargs)
+                            try:
+                                result = await func(*args, **kwargs)
+                            except Exception as e:
+                                _capture_exception_for_trace(current_trace, sys.exc_info())
+                                raise e

                         span.record_output(result)
                         return result
@@ -1184,7 +1216,15 @@ class Tracer:
         else:
             # Non-async function implementation with deep tracing
             @functools.wraps(func)
-            def wrapper(*args, **kwargs):
+            def wrapper(*args, **kwargs):
+                nonlocal original_span_name
+                class_name = None
+                instance_name = None
+                span_name = original_span_name
+                agent_name = None
+                if args and hasattr(args[0], '__class__'):
+                    class_name = args[0].__class__.__name__
+                    agent_name = get_instance_prefixed_name(args[0], class_name, self.class_identifiers)
                 # Get current trace from context
                 current_trace = current_trace_var.get()

@@ -1216,12 +1256,17 @@ class Tracer:
                         # Record inputs
                         inputs = combine_args_kwargs(func, args, kwargs)
                         span.record_input(inputs)
-
+                        if agent_name:
+                            span.record_agent_name(agent_name)
                         if use_deep_tracing:
                             with _DeepTracer():
                                 result = func(*args, **kwargs)
                         else:
-                            result = func(*args, **kwargs)
+                            try:
+                                result = func(*args, **kwargs)
+                            except Exception as e:
+                                _capture_exception_for_trace(current_trace, sys.exc_info())
+                                raise e

                         # Record output
                         span.record_output(result)
@@ -1238,12 +1283,18 @@ class Tracer:

                     inputs = combine_args_kwargs(func, args, kwargs)
                     span.record_input(inputs)
-
+                    if agent_name:
+                        span.record_agent_name(agent_name)
+
                     if use_deep_tracing:
                         with _DeepTracer():
                             result = func(*args, **kwargs)
                     else:
-                        result = func(*args, **kwargs)
+                        try:
+                            result = func(*args, **kwargs)
+                        except Exception as e:
+                            _capture_exception_for_trace(current_trace, sys.exc_info())
+                            raise e

                     span.record_output(result)
                     return result
@@ -1313,8 +1364,9 @@ def wrap(client: Any) -> Any:
             return wrapper_func(response, client, output_entry)
         else:
             format_func = _format_response_output_data if is_responses else _format_output_data
-
-            span.record_output(format_func(client, response))
+            output, usage = format_func(client, response)
+            span.record_output(output)
+            span.record_usage(usage)
         return response

     def _handle_error(span, e, is_async):
@@ -1496,18 +1548,35 @@ def _format_response_output_data(client: ApiClient, response: Any) -> dict:
    Normalizes different response formats into a consistent structure
    for tracing purposes.
    """
+    message_content = None
+    prompt_tokens = 0
+    completion_tokens = 0
+    model_name = None
    if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
-        return {
-            "content": response.output,
-            "usage": {
-                "prompt_tokens": response.usage.input_tokens,
-                "completion_tokens": response.usage.output_tokens,
-                "total_tokens": response.usage.total_tokens
-            }
-        }
+        model_name = response.model
+        prompt_tokens = response.usage.input_tokens
+        completion_tokens = response.usage.output_tokens
+        message_content = response.output
    else:
        warnings.warn(f"Unsupported client type: {type(client)}")
        return {}
+
+    prompt_cost, completion_cost = cost_per_token(
+        model=model_name,
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+    )
+    total_cost_usd = (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
+    usage = TraceUsage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens,
+        prompt_tokens_cost_usd=prompt_cost,
+        completion_tokens_cost_usd=completion_cost,
+        total_cost_usd=total_cost_usd,
+        model_name=model_name
+    )
+    return message_content, usage


 def _format_output_data(client: ApiClient, response: Any) -> dict:
@@ -1521,33 +1590,46 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
     - content: The generated text
     - usage: Token usage statistics
     """
+    prompt_tokens = 0
+    completion_tokens = 0
+    model_name = None
+    message_content = None
+
     if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
-        return {
-            "content": response.choices[0].message.content,
-            "usage": {
-                "prompt_tokens": response.usage.prompt_tokens,
-                "completion_tokens": response.usage.completion_tokens,
-                "total_tokens": response.usage.total_tokens
-            }
-        }
+        model_name = response.model
+        prompt_tokens = response.usage.prompt_tokens
+        completion_tokens = response.usage.completion_tokens
+        message_content = response.choices[0].message.content
     elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
-        return {
-            "content": response.candidates[0].content.parts[0].text,
-            "usage": {
-                "prompt_tokens": response.usage_metadata.prompt_token_count,
-                "completion_tokens": response.usage_metadata.candidates_token_count,
-                "total_tokens": response.usage_metadata.total_token_count
-            }
-        }
-    elif isinstance(client, (Anthropic, AsyncAnthropic)):
-        return {
-            "content": response.content[0].text,
-            "usage": {
-                "input_tokens": response.usage.input_tokens,
-                "output_tokens": response.usage.output_tokens,
-                "total_tokens": response.usage.input_tokens + response.usage.output_tokens
-            }
-        }
+        model_name = response.model_version
+        prompt_tokens = response.usage_metadata.prompt_token_count
+        completion_tokens = response.usage_metadata.candidates_token_count
+        message_content = response.candidates[0].content.parts[0].text
+    elif isinstance(client, (Anthropic, AsyncAnthropic)):
+        model_name = response.model
+        prompt_tokens = response.usage.input_tokens
+        completion_tokens = response.usage.output_tokens
+        message_content = response.content[0].text
+    else:
+        warnings.warn(f"Unsupported client type: {type(client)}")
+        return None, None
+
+    prompt_cost, completion_cost = cost_per_token(
+        model=model_name,
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+    )
+    total_cost_usd = (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
+    usage = TraceUsage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens,
+        prompt_tokens_cost_usd=prompt_cost,
+        completion_tokens_cost_usd=completion_cost,
+        total_cost_usd=total_cost_usd,
+        model_name=model_name
+    )
+    return message_content, usage

 def combine_args_kwargs(func, args, kwargs):
     """
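Note: the rewritten `_format_output_data` returns a `(message_content, usage)` pair and normalizes the three providers' field names into one structure. The mapping it encodes, sketched with a plain dataclass standing in for judgeval's `TraceUsage` model:

```python
from dataclasses import dataclass
from types import SimpleNamespace as NS
from typing import Optional

@dataclass
class Usage:  # stand-in for judgeval.data.TraceUsage
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    model_name: Optional[str] = None

def normalize(provider: str, response) -> Usage:
    # Field names per provider, as read by _format_output_data above
    if provider in ("openai", "together"):
        u = response.usage
        return Usage(u.prompt_tokens, u.completion_tokens,
                     u.prompt_tokens + u.completion_tokens, response.model)
    if provider == "google":
        m = response.usage_metadata
        return Usage(m.prompt_token_count, m.candidates_token_count,
                     m.prompt_token_count + m.candidates_token_count,
                     response.model_version)
    if provider == "anthropic":
        u = response.usage
        return Usage(u.input_tokens, u.output_tokens,
                     u.input_tokens + u.output_tokens, response.model)
    raise ValueError(f"unsupported provider: {provider}")

resp = NS(model="gpt-4o", usage=NS(prompt_tokens=10, completion_tokens=3))
print(normalize("openai", resp))
# Usage(prompt_tokens=10, completion_tokens=3, total_tokens=13, model_name='gpt-4o')
```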
@@ -1653,21 +1735,30 @@ def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[D
     # OpenAI/Together include usage in the *last* chunk's `usage` attribute if available
     # This typically requires specific API versions or settings. Often usage is *not* streamed.
     if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # Check if usage is directly on the chunk (some models might do this)
+        if hasattr(chunk, 'usage') and chunk.usage:
+            prompt_tokens = chunk.usage.prompt_tokens
+            completion_tokens = chunk.usage.completion_tokens
+        # Check if usage is nested within choices (less common for final chunk, but check)
+        elif chunk.choices and hasattr(chunk.choices[0], 'usage') and chunk.choices[0].usage:
+            prompt_tokens = chunk.choices[0].usage.prompt_tokens
+            completion_tokens = chunk.choices[0].usage.completion_tokens
+
+        prompt_cost, completion_cost = cost_per_token(
+            model=chunk.model,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+        )
+        total_cost_usd = (prompt_cost + completion_cost) if prompt_cost and completion_cost else None
+        return TraceUsage(
+            prompt_tokens=chunk.usage.prompt_tokens,
+            completion_tokens=chunk.usage.completion_tokens,
+            total_tokens=chunk.usage.total_tokens,
+            prompt_tokens_cost_usd=prompt_cost,
+            completion_tokens_cost_usd=completion_cost,
+            total_cost_usd=total_cost_usd,
+            model_name=chunk.model
+        )
     # Anthropic includes usage in the 'message_stop' event type
     elif isinstance(client, (Anthropic, AsyncAnthropic)):
         if chunk.type == "message_stop":
@@ -1715,11 +1806,8 @@ def _sync_stream_wrapper(
     final_usage = _extract_usage_from_final_chunk(client, last_chunk)

     # Update the trace entry with the accumulated content and usage
-    span.output = {
-        "content": "".join(content_parts),
-        "usage": final_usage if final_usage else {"info": "Usage data not available in stream."}, # Provide placeholder if None
-        "streamed": True
-    }
+    span.output = "".join(content_parts)
+    span.usage = final_usage
    # Note: We might need to adjust _serialize_output if this dict causes issues,
    # but Pydantic's model_dump should handle dicts.

@@ -1739,6 +1827,7 @@ async def _async_stream_wrapper(
     target_span_id = span.span_id

     try:
+        model_name = ""
         async for chunk in original_stream:
             # Check for OpenAI's final usage chunk
             if isinstance(client, (AsyncOpenAI, OpenAI)) and hasattr(chunk, 'usage') and chunk.usage is not None:
@@ -1747,16 +1836,18 @@ async def _async_stream_wrapper(
                     "completion_tokens": chunk.usage.completion_tokens,
                     "total_tokens": chunk.usage.total_tokens
                 }
+                model_name = chunk.model
                 yield chunk
                 continue

             if isinstance(client, (AsyncAnthropic, Anthropic)) and hasattr(chunk, 'type'):
-
-
+                if chunk.type == "message_start":
+                    if hasattr(chunk, 'message') and hasattr(chunk.message, 'usage') and hasattr(chunk.message.usage, 'input_tokens'):
                         anthropic_input_tokens = chunk.message.usage.input_tokens
-
-
-
+                        model_name = chunk.message.model
+                elif chunk.type == "message_delta":
+                    if hasattr(chunk, 'usage') and hasattr(chunk.usage, 'output_tokens'):
+                        anthropic_output_tokens = chunk.usage.output_tokens

             content_part = _extract_content_from_chunk(client, chunk)
             if content_part:
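Note: Anthropic streams report input tokens once in the `message_start` event and output tokens in `message_delta` events, which is why the wrapper tracks them separately before merging. A schematic consumer, with `SimpleNamespace` objects standing in for SDK events:

```python
from types import SimpleNamespace as NS

def accumulate_anthropic_usage(events):
    input_tokens = output_tokens = 0
    model_name = None
    for ev in events:
        if ev.type == "message_start":
            input_tokens = ev.message.usage.input_tokens
            model_name = ev.message.model
        elif ev.type == "message_delta":
            output_tokens = ev.usage.output_tokens
    return model_name, input_tokens, output_tokens

events = [
    NS(type="message_start",
       message=NS(model="claude-sonnet", usage=NS(input_tokens=12))),
    NS(type="message_delta", usage=NS(output_tokens=7)),
]
print(accumulate_anthropic_usage(events))  # ('claude-sonnet', 12, 7)
```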
@@ -1779,18 +1870,37 @@ async def _async_stream_wrapper(
        elif anthropic_final_usage:
            usage_info = anthropic_final_usage
        elif last_content_chunk:
-
+            usage_info = _extract_usage_from_final_chunk(client, last_content_chunk)

+        if usage_info and not isinstance(usage_info, TraceUsage):
+            prompt_cost, completion_cost = cost_per_token(
+                model=model_name,
+                prompt_tokens=usage_info["prompt_tokens"],
+                completion_tokens=usage_info["completion_tokens"],
+            )
+            usage_info = TraceUsage(
+                prompt_tokens=usage_info["prompt_tokens"],
+                completion_tokens=usage_info["completion_tokens"],
+                total_tokens=usage_info["total_tokens"],
+                prompt_tokens_cost_usd=prompt_cost,
+                completion_tokens_cost_usd=completion_cost,
+                total_cost_usd=prompt_cost + completion_cost,
+                model_name=model_name
+            )
        if span and hasattr(span, 'output'):
-            span.output = {
-                "content": ''.join(content_parts),
-                "usage": usage_info if usage_info else {"info": "Usage data not available in stream."},
-                "streamed": True
-            }
+            span.output = ''.join(content_parts)
+            span.usage = usage_info
            start_ts = getattr(span, 'created_at', time.time())
            span.duration = time.time() - start_ts
        # else: # Handle error case if necessary, but remove debug print

+def cost_per_token(*args, **kwargs):
+    try:
+        return _original_cost_per_token(*args, **kwargs)
+    except Exception as e:
+        warnings.warn(f"Error calculating cost per token: {e}")
+        return None, None
+
 class _BaseStreamManagerWrapper:
     def __init__(self, original_manager, client, span_name, trace_client, stream_wrapper_func, input_kwargs):
         self._original_manager = original_manager
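Note: the module imports litellm's `cost_per_token` under an alias and shadows it with the guarded wrapper above, so every existing call site transparently gains the error handling. The pattern in isolation:

```python
import warnings
from litellm import cost_per_token as _original_cost_per_token

def cost_per_token(*args, **kwargs):
    # litellm raises for models it has no pricing data for;
    # degrade to (None, None) instead of failing the trace save
    try:
        return _original_cost_per_token(*args, **kwargs)
    except Exception as e:
        warnings.warn(f"Error calculating cost per token: {e}")
        return None, None

prompt_cost, completion_cost = cost_per_token(
    model="gpt-4o", prompt_tokens=100, completion_tokens=20
)
```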
@@ -1872,3 +1982,18 @@ class _TracedSyncStreamManagerWrapper(_BaseStreamManagerWrapper, AbstractContext
         current_span_var.reset(self._span_context_token)
         delattr(self, '_span_context_token')
         return self._original_manager.__exit__(exc_type, exc_val, exc_tb)
+
+# --- Helper function for instance-prefixed qual_name ---
+def get_instance_prefixed_name(instance, class_name, class_identifiers):
+    """
+    Returns the agent name (prefix) if the class and attribute are found in class_identifiers.
+    Otherwise, returns None.
+    """
+    if class_name in class_identifiers:
+        attr = class_identifiers[class_name]
+        if hasattr(instance, attr):
+            instance_name = getattr(instance, attr)
+            return instance_name
+        else:
+            raise Exception(f"Attribute {class_identifiers[class_name]} does not exist for {class_name}. Check your identify() decorator.")
+    return None