judgeval 0.0.30__py3-none-any.whl → 0.0.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- judgeval/__init__.py +3 -1
- judgeval/common/tracer.py +352 -117
- judgeval/constants.py +5 -3
- judgeval/data/__init__.py +4 -0
- judgeval/data/custom_example.py +18 -0
- judgeval/data/datasets/dataset.py +5 -1
- judgeval/data/datasets/eval_dataset_client.py +64 -5
- judgeval/data/example.py +1 -0
- judgeval/data/result.py +7 -6
- judgeval/data/sequence.py +55 -0
- judgeval/data/sequence_run.py +44 -0
- judgeval/evaluation_run.py +12 -7
- judgeval/integrations/langgraph.py +89 -72
- judgeval/judgment_client.py +70 -68
- judgeval/run_evaluation.py +87 -13
- judgeval/scorers/__init__.py +2 -0
- judgeval/scorers/judgeval_scorer.py +3 -0
- judgeval/scorers/judgeval_scorers/__init__.py +7 -0
- judgeval/scorers/judgeval_scorers/api_scorers/__init__.py +2 -1
- judgeval/scorers/judgeval_scorers/api_scorers/derailment_scorer.py +21 -0
- judgeval/scorers/score.py +6 -5
- judgeval/version_check.py +22 -0
- {judgeval-0.0.30.dist-info → judgeval-0.0.32.dist-info}/METADATA +1 -1
- {judgeval-0.0.30.dist-info → judgeval-0.0.32.dist-info}/RECORD +26 -22
- judgeval/data/custom_api_example.py +0 -91
- {judgeval-0.0.30.dist-info → judgeval-0.0.32.dist-info}/WHEEL +0 -0
- {judgeval-0.0.30.dist-info → judgeval-0.0.32.dist-info}/licenses/LICENSE.md +0 -0
judgeval/common/tracer.py
CHANGED
@@ -11,11 +11,12 @@ import time
|
|
11
11
|
import uuid
|
12
12
|
import warnings
|
13
13
|
import contextvars
|
14
|
+
import sys
|
14
15
|
from contextlib import contextmanager
|
15
16
|
from dataclasses import dataclass, field
|
16
17
|
from datetime import datetime
|
17
18
|
from http import HTTPStatus
|
18
|
-
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, TypeAlias, Union, Callable, Awaitable
|
19
|
+
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, TypeAlias, Union, Callable, Awaitable, Set
|
19
20
|
from rich import print as rprint
|
20
21
|
|
21
22
|
# Third-party imports
|
@@ -24,9 +25,10 @@ import requests
|
|
24
25
|
from litellm import cost_per_token
|
25
26
|
from pydantic import BaseModel
|
26
27
|
from rich import print as rprint
|
27
|
-
from openai import OpenAI
|
28
|
-
from together import Together
|
29
|
-
from anthropic import Anthropic
|
28
|
+
from openai import OpenAI, AsyncOpenAI
|
29
|
+
from together import Together, AsyncTogether
|
30
|
+
from anthropic import Anthropic, AsyncAnthropic
|
31
|
+
from google import genai
|
30
32
|
|
31
33
|
# Local application/library-specific imports
|
32
34
|
from judgeval.constants import (
|
@@ -37,7 +39,6 @@ from judgeval.constants import (
|
|
37
39
|
RABBITMQ_QUEUE,
|
38
40
|
JUDGMENT_TRACES_DELETE_API_URL,
|
39
41
|
JUDGMENT_PROJECT_DELETE_API_URL,
|
40
|
-
JUDGMENT_TRACES_ADD_TO_EVAL_QUEUE_API_URL
|
41
42
|
)
|
42
43
|
from judgeval.judgment_client import JudgmentClient
|
43
44
|
from judgeval.data import Example
|
@@ -51,10 +52,11 @@ import concurrent.futures
|
|
51
52
|
|
52
53
|
# Define context variables for tracking the current trace and the current span within a trace
|
53
54
|
current_trace_var = contextvars.ContextVar('current_trace', default=None)
|
54
|
-
current_span_var = contextvars.ContextVar('current_span', default=None) #
|
55
|
+
current_span_var = contextvars.ContextVar('current_span', default=None) # ContextVar for the active span name
|
56
|
+
in_traced_function_var = contextvars.ContextVar('in_traced_function', default=False) # Track if we're in a traced function
|
55
57
|
|
56
58
|
# Define type aliases for better code readability and maintainability
|
57
|
-
ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic] # Supported API clients
|
59
|
+
ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic, AsyncOpenAI, AsyncAnthropic, AsyncTogether, genai.Client, genai.client.AsyncClient] # Supported API clients
|
58
60
|
TraceEntryType = Literal['enter', 'exit', 'output', 'input', 'evaluation'] # Valid trace entry types
|
59
61
|
SpanType = Literal['span', 'tool', 'llm', 'evaluation', 'chain']
|
60
62
|
@dataclass
|
@@ -69,11 +71,11 @@ class TraceEntry:
|
|
69
71
|
- evaluation: Evaluation: (evaluation results)
|
70
72
|
"""
|
71
73
|
type: TraceEntryType
|
72
|
-
function: str # Name of the function being traced
|
73
74
|
span_id: str # Unique ID for this specific span instance
|
74
75
|
depth: int # Indentation level for nested calls
|
75
|
-
message: str # Human-readable description
|
76
76
|
created_at: float # Unix timestamp when entry was created, replacing the deprecated 'timestamp' field
|
77
|
+
function: Optional[str] = None # Name of the function being traced
|
78
|
+
message: Optional[str] = None # Human-readable description
|
77
79
|
duration: Optional[float] = None # Time taken (for exit/evaluation entries)
|
78
80
|
trace_id: str = None # ID of the trace this entry belongs to
|
79
81
|
output: Any = None # Function output value
|
@@ -229,6 +231,8 @@ class TraceManagerClient:
|
|
229
231
|
raise ValueError(f"Failed to fetch traces: {response.text}")
|
230
232
|
|
231
233
|
return response.json()
|
234
|
+
|
235
|
+
|
232
236
|
|
233
237
|
def save_trace(self, trace_data: dict):
|
234
238
|
"""
|
@@ -356,6 +360,18 @@ class TraceClient:
|
|
356
360
|
self.executed_tools = []
|
357
361
|
self.executed_node_tools = []
|
358
362
|
self._span_depths: Dict[str, int] = {} # NEW: To track depth of active spans
|
363
|
+
|
364
|
+
def get_current_span(self):
|
365
|
+
"""Get the current span from the context var"""
|
366
|
+
return current_span_var.get()
|
367
|
+
|
368
|
+
def set_current_span(self, span: Any):
|
369
|
+
"""Set the current span from the context var"""
|
370
|
+
return current_span_var.set(span)
|
371
|
+
|
372
|
+
def reset_current_span(self, token: Any):
|
373
|
+
"""Reset the current span from the context var"""
|
374
|
+
return current_span_var.reset(token)
|
359
375
|
|
360
376
|
@contextmanager
|
361
377
|
def span(self, name: str, span_type: SpanType = "span"):
|
@@ -874,27 +890,14 @@ class TraceClient:
|
|
874
890
|
"overwrite": overwrite,
|
875
891
|
"parent_trace_id": self.parent_trace_id,
|
876
892
|
"parent_name": self.parent_name
|
877
|
-
}
|
878
|
-
#
|
879
|
-
|
880
|
-
|
881
|
-
|
882
|
-
|
883
|
-
|
884
|
-
#
|
885
|
-
# headers={
|
886
|
-
# "Content-Type": "application/json",
|
887
|
-
# "Authorization": f"Bearer {self.tracer.api_key}",
|
888
|
-
# "X-Organization-Id": self.tracer.organization_id
|
889
|
-
# },
|
890
|
-
# verify=True
|
891
|
-
# )
|
892
|
-
|
893
|
-
# if response.status_code != HTTPStatus.OK:
|
894
|
-
# warnings.warn(f"Failed to add trace to evaluation queue: {response.text}")
|
895
|
-
# except Exception as e:
|
896
|
-
# warnings.warn(f"Error sending trace to evaluation queue: {str(e)}")
|
897
|
-
|
893
|
+
}
|
894
|
+
# --- Log trace data before saving ---
|
895
|
+
try:
|
896
|
+
rprint(f"[TraceClient.save] Saving trace data for trace_id {self.trace_id}:")
|
897
|
+
rprint(json.dumps(trace_data, indent=2))
|
898
|
+
except Exception as log_e:
|
899
|
+
rprint(f"[TraceClient.save] Error logging trace data: {log_e}")
|
900
|
+
# --- End logging ---
|
898
901
|
self.trace_manager_client.save_trace(trace_data)
|
899
902
|
|
900
903
|
return self.trace_id, trace_data
|
@@ -917,7 +920,8 @@ class Tracer:
|
|
917
920
|
rules: Optional[List[Rule]] = None, # Added rules parameter
|
918
921
|
organization_id: str = os.getenv("JUDGMENT_ORG_ID"),
|
919
922
|
enable_monitoring: bool = os.getenv("JUDGMENT_MONITORING", "true").lower() == "true",
|
920
|
-
enable_evaluations: bool = os.getenv("JUDGMENT_EVALUATIONS", "true").lower() == "true"
|
923
|
+
enable_evaluations: bool = os.getenv("JUDGMENT_EVALUATIONS", "true").lower() == "true",
|
924
|
+
deep_tracing: bool = True # NEW: Enable deep tracing by default
|
921
925
|
):
|
922
926
|
if not hasattr(self, 'initialized'):
|
923
927
|
if not api_key:
|
@@ -934,6 +938,7 @@ class Tracer:
|
|
934
938
|
self.initialized: bool = True
|
935
939
|
self.enable_monitoring: bool = enable_monitoring
|
936
940
|
self.enable_evaluations: bool = enable_evaluations
|
941
|
+
self.deep_tracing: bool = deep_tracing # NEW: Store deep tracing setting
|
937
942
|
elif hasattr(self, 'project_name') and self.project_name != project_name:
|
938
943
|
warnings.warn(
|
939
944
|
f"Attempting to initialize Tracer with project_name='{project_name}' but it was already initialized with "
|
@@ -941,7 +946,59 @@ class Tracer:
|
|
941
946
|
"To use a different project name, ensure the first Tracer initialization uses the desired project name.",
|
942
947
|
RuntimeWarning
|
943
948
|
)
|
949
|
+
|
950
|
+
def set_current_trace(self, trace: TraceClient):
|
951
|
+
"""
|
952
|
+
Set the current trace context in contextvars
|
953
|
+
"""
|
954
|
+
current_trace_var.set(trace)
|
955
|
+
|
956
|
+
def get_current_trace(self) -> Optional[TraceClient]:
|
957
|
+
"""
|
958
|
+
Get the current trace context from contextvars
|
959
|
+
"""
|
960
|
+
return current_trace_var.get()
|
961
|
+
|
962
|
+
def _apply_deep_tracing(self, func, span_type="span"):
|
963
|
+
"""
|
964
|
+
Apply deep tracing to all functions in the same module as the given function.
|
944
965
|
|
966
|
+
Args:
|
967
|
+
func: The function being traced
|
968
|
+
span_type: Type of span to use for traced functions
|
969
|
+
|
970
|
+
Returns:
|
971
|
+
A tuple of (module, original_functions_dict) where original_functions_dict
|
972
|
+
contains the original functions that were replaced with traced versions.
|
973
|
+
"""
|
974
|
+
module = inspect.getmodule(func)
|
975
|
+
if not module:
|
976
|
+
return None, {}
|
977
|
+
|
978
|
+
# Save original functions
|
979
|
+
original_functions = {}
|
980
|
+
|
981
|
+
# Find all functions in the module
|
982
|
+
for name, obj in inspect.getmembers(module, inspect.isfunction):
|
983
|
+
# Skip already wrapped functions
|
984
|
+
if hasattr(obj, '_judgment_traced'):
|
985
|
+
continue
|
986
|
+
|
987
|
+
# Create a traced version of the function
|
988
|
+
# Always use default span type "span" for child functions
|
989
|
+
traced_func = _create_deep_tracing_wrapper(obj, self, "span")
|
990
|
+
|
991
|
+
# Mark the function as traced to avoid double wrapping
|
992
|
+
traced_func._judgment_traced = True
|
993
|
+
|
994
|
+
# Save the original function
|
995
|
+
original_functions[name] = obj
|
996
|
+
|
997
|
+
# Replace with traced version
|
998
|
+
setattr(module, name, traced_func)
|
999
|
+
|
1000
|
+
return module, original_functions
|
1001
|
+
|
945
1002
|
@contextmanager
|
946
1003
|
def trace(
|
947
1004
|
self,
|
@@ -987,14 +1044,8 @@ class Tracer:
|
|
987
1044
|
finally:
|
988
1045
|
# Reset the context variable
|
989
1046
|
current_trace_var.reset(token)
|
990
|
-
|
991
|
-
def
|
992
|
-
"""
|
993
|
-
Get the current trace context from contextvars
|
994
|
-
"""
|
995
|
-
return current_trace_var.get()
|
996
|
-
|
997
|
-
def observe(self, func=None, *, name=None, span_type: SpanType = "span", project_name: str = None, overwrite: bool = False):
|
1047
|
+
|
1048
|
+
def observe(self, func=None, *, name=None, span_type: SpanType = "span", project_name: str = None, overwrite: bool = False, deep_tracing: bool = None):
|
998
1049
|
"""
|
999
1050
|
Decorator to trace function execution with detailed entry/exit information.
|
1000
1051
|
|
@@ -1004,20 +1055,37 @@ class Tracer:
|
|
1004
1055
|
span_type: Type of span (default "span")
|
1005
1056
|
project_name: Optional project name override
|
1006
1057
|
overwrite: Whether to overwrite existing traces
|
1058
|
+
deep_tracing: Whether to enable deep tracing for this function and all nested calls.
|
1059
|
+
If None, uses the tracer's default setting.
|
1007
1060
|
"""
|
1008
1061
|
# If monitoring is disabled, return the function as is
|
1009
1062
|
if not self.enable_monitoring:
|
1010
1063
|
return func if func else lambda f: f
|
1011
1064
|
|
1012
1065
|
if func is None:
|
1013
|
-
return lambda f: self.observe(f, name=name, span_type=span_type, project_name=project_name,
|
1066
|
+
return lambda f: self.observe(f, name=name, span_type=span_type, project_name=project_name,
|
1067
|
+
overwrite=overwrite, deep_tracing=deep_tracing)
|
1014
1068
|
|
1015
1069
|
# Use provided name or fall back to function name
|
1016
1070
|
span_name = name or func.__name__
|
1017
1071
|
|
1072
|
+
# Store custom attributes on the function object
|
1073
|
+
func._judgment_span_name = span_name
|
1074
|
+
func._judgment_span_type = span_type
|
1075
|
+
|
1076
|
+
# Use the provided deep_tracing value or fall back to the tracer's default
|
1077
|
+
use_deep_tracing = deep_tracing if deep_tracing is not None else self.deep_tracing
|
1078
|
+
|
1018
1079
|
if asyncio.iscoroutinefunction(func):
|
1019
1080
|
@functools.wraps(func)
|
1020
1081
|
async def async_wrapper(*args, **kwargs):
|
1082
|
+
# Check if we're already in a traced function
|
1083
|
+
if in_traced_function_var.get():
|
1084
|
+
return await func(*args, **kwargs)
|
1085
|
+
|
1086
|
+
# Set in_traced_function_var to True
|
1087
|
+
token = in_traced_function_var.set(True)
|
1088
|
+
|
1021
1089
|
# Get current trace from context
|
1022
1090
|
current_trace = current_trace_var.get()
|
1023
1091
|
|
@@ -1052,9 +1120,18 @@ class Tracer:
|
|
1052
1120
|
'kwargs': kwargs
|
1053
1121
|
})
|
1054
1122
|
|
1123
|
+
# If deep tracing is enabled, apply monkey patching
|
1124
|
+
if use_deep_tracing:
|
1125
|
+
module, original_functions = self._apply_deep_tracing(func, span_type)
|
1126
|
+
|
1055
1127
|
# Execute function
|
1056
1128
|
result = await func(*args, **kwargs)
|
1057
1129
|
|
1130
|
+
# Restore original functions if deep tracing was enabled
|
1131
|
+
if use_deep_tracing and module and 'original_functions' in locals():
|
1132
|
+
for name, obj in original_functions.items():
|
1133
|
+
setattr(module, name, obj)
|
1134
|
+
|
1058
1135
|
# Record output
|
1059
1136
|
span.record_output(result)
|
1060
1137
|
|
@@ -1064,29 +1141,52 @@ class Tracer:
|
|
1064
1141
|
finally:
|
1065
1142
|
# Reset trace context (span context resets automatically)
|
1066
1143
|
current_trace_var.reset(trace_token)
|
1144
|
+
# Reset in_traced_function_var
|
1145
|
+
in_traced_function_var.reset(token)
|
1067
1146
|
else:
|
1068
1147
|
# Already have a trace context, just create a span in it
|
1069
1148
|
# The span method handles current_span_var
|
1070
|
-
|
1071
|
-
|
1072
|
-
span
|
1073
|
-
|
1074
|
-
|
1075
|
-
|
1076
|
-
|
1077
|
-
|
1078
|
-
|
1079
|
-
|
1080
|
-
|
1081
|
-
|
1149
|
+
|
1150
|
+
try:
|
1151
|
+
with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
|
1152
|
+
# Record inputs
|
1153
|
+
span.record_input({
|
1154
|
+
'args': str(args),
|
1155
|
+
'kwargs': kwargs
|
1156
|
+
})
|
1157
|
+
|
1158
|
+
# If deep tracing is enabled, apply monkey patching
|
1159
|
+
if use_deep_tracing:
|
1160
|
+
module, original_functions = self._apply_deep_tracing(func, span_type)
|
1161
|
+
|
1162
|
+
# Execute function
|
1163
|
+
result = await func(*args, **kwargs)
|
1164
|
+
|
1165
|
+
# Restore original functions if deep tracing was enabled
|
1166
|
+
if use_deep_tracing and module and 'original_functions' in locals():
|
1167
|
+
for name, obj in original_functions.items():
|
1168
|
+
setattr(module, name, obj)
|
1169
|
+
|
1170
|
+
# Record output
|
1171
|
+
span.record_output(result)
|
1082
1172
|
|
1083
1173
|
return result
|
1084
|
-
|
1174
|
+
finally:
|
1175
|
+
# Reset in_traced_function_var
|
1176
|
+
in_traced_function_var.reset(token)
|
1177
|
+
|
1085
1178
|
return async_wrapper
|
1086
1179
|
else:
|
1087
|
-
# Non-async function implementation
|
1180
|
+
# Non-async function implementation with deep tracing
|
1088
1181
|
@functools.wraps(func)
|
1089
1182
|
def wrapper(*args, **kwargs):
|
1183
|
+
# Check if we're already in a traced function
|
1184
|
+
if in_traced_function_var.get():
|
1185
|
+
return func(*args, **kwargs)
|
1186
|
+
|
1187
|
+
# Set in_traced_function_var to True
|
1188
|
+
token = in_traced_function_var.set(True)
|
1189
|
+
|
1090
1190
|
# Get current trace from context
|
1091
1191
|
current_trace = current_trace_var.get()
|
1092
1192
|
|
@@ -1121,9 +1221,18 @@ class Tracer:
|
|
1121
1221
|
'kwargs': kwargs
|
1122
1222
|
})
|
1123
1223
|
|
1224
|
+
# If deep tracing is enabled, apply monkey patching
|
1225
|
+
if use_deep_tracing:
|
1226
|
+
module, original_functions = self._apply_deep_tracing(func, span_type)
|
1227
|
+
|
1124
1228
|
# Execute function
|
1125
1229
|
result = func(*args, **kwargs)
|
1126
1230
|
|
1231
|
+
# Restore original functions if deep tracing was enabled
|
1232
|
+
if use_deep_tracing and module and 'original_functions' in locals():
|
1233
|
+
for name, obj in original_functions.items():
|
1234
|
+
setattr(module, name, obj)
|
1235
|
+
|
1127
1236
|
# Record output
|
1128
1237
|
span.record_output(result)
|
1129
1238
|
|
@@ -1133,24 +1242,40 @@ class Tracer:
|
|
1133
1242
|
finally:
|
1134
1243
|
# Reset trace context (span context resets automatically)
|
1135
1244
|
current_trace_var.reset(trace_token)
|
1245
|
+
# Reset in_traced_function_var
|
1246
|
+
in_traced_function_var.reset(token)
|
1136
1247
|
else:
|
1137
1248
|
# Already have a trace context, just create a span in it
|
1138
1249
|
# The span method handles current_span_var
|
1139
|
-
|
1140
|
-
|
1141
|
-
span
|
1142
|
-
|
1143
|
-
|
1144
|
-
|
1145
|
-
|
1146
|
-
|
1147
|
-
|
1148
|
-
|
1149
|
-
|
1150
|
-
|
1250
|
+
|
1251
|
+
try:
|
1252
|
+
with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
|
1253
|
+
# Record inputs
|
1254
|
+
span.record_input({
|
1255
|
+
'args': str(args),
|
1256
|
+
'kwargs': kwargs
|
1257
|
+
})
|
1258
|
+
|
1259
|
+
# If deep tracing is enabled, apply monkey patching
|
1260
|
+
if use_deep_tracing:
|
1261
|
+
module, original_functions = self._apply_deep_tracing(func, span_type)
|
1262
|
+
|
1263
|
+
# Execute function
|
1264
|
+
result = func(*args, **kwargs)
|
1265
|
+
|
1266
|
+
# Restore original functions if deep tracing was enabled
|
1267
|
+
if use_deep_tracing and module and 'original_functions' in locals():
|
1268
|
+
for name, obj in original_functions.items():
|
1269
|
+
setattr(module, name, obj)
|
1270
|
+
|
1271
|
+
# Record output
|
1272
|
+
span.record_output(result)
|
1151
1273
|
|
1152
1274
|
return result
|
1153
|
-
|
1275
|
+
finally:
|
1276
|
+
# Reset in_traced_function_var
|
1277
|
+
in_traced_function_var.reset(token)
|
1278
|
+
|
1154
1279
|
return wrapper
|
1155
1280
|
|
1156
1281
|
def score(self, func=None, scorers: List[Union[APIJudgmentScorer, JudgevalScorer]] = None, model: str = None, log_results: bool = True, *, name: str = None, span_type: SpanType = "span"):
|
@@ -1199,34 +1324,69 @@ def wrap(client: Any) -> Any:
|
|
1199
1324
|
"""
|
1200
1325
|
# Get the appropriate configuration for this client type
|
1201
1326
|
span_name, original_create = _get_client_config(client)
|
1202
|
-
|
1203
|
-
|
1204
|
-
|
1205
|
-
|
1206
|
-
|
1207
|
-
|
1208
|
-
if not current_trace:
|
1209
|
-
return original_create(*args, **kwargs)
|
1210
|
-
|
1211
|
-
with current_trace.span(span_name, span_type="llm") as span:
|
1212
|
-
# Format and record the input parameters
|
1213
|
-
input_data = _format_input_data(client, **kwargs)
|
1214
|
-
span.record_input(input_data)
|
1215
|
-
|
1216
|
-
# Make the actual API call
|
1217
|
-
response = original_create(*args, **kwargs)
|
1327
|
+
|
1328
|
+
# Handle async clients differently than synchronous clients (need an async function for async clients)
|
1329
|
+
if (isinstance(client, (AsyncOpenAI, AsyncAnthropic, AsyncTogether, genai.client.AsyncClient))):
|
1330
|
+
async def traced_create(*args, **kwargs):
|
1331
|
+
# Get the current trace from contextvars
|
1332
|
+
current_trace = current_trace_var.get()
|
1218
1333
|
|
1219
|
-
#
|
1220
|
-
|
1221
|
-
|
1334
|
+
# Skip tracing if no active trace
|
1335
|
+
if not current_trace:
|
1336
|
+
return original_create(*args, **kwargs)
|
1337
|
+
|
1338
|
+
with current_trace.span(span_name, span_type="llm") as span:
|
1339
|
+
# Format and record the input parameters
|
1340
|
+
input_data = _format_input_data(client, **kwargs)
|
1341
|
+
span.record_input(input_data)
|
1342
|
+
|
1343
|
+
# Make the actual API call
|
1344
|
+
try:
|
1345
|
+
response = await original_create(*args, **kwargs)
|
1346
|
+
except Exception as e:
|
1347
|
+
print(f"Error during API call: {e}")
|
1348
|
+
raise
|
1349
|
+
|
1350
|
+
# Format and record the output
|
1351
|
+
output_data = _format_output_data(client, response)
|
1352
|
+
span.record_output(output_data)
|
1353
|
+
|
1354
|
+
return response
|
1355
|
+
else:
|
1356
|
+
def traced_create(*args, **kwargs):
|
1357
|
+
# Get the current trace from contextvars
|
1358
|
+
current_trace = current_trace_var.get()
|
1222
1359
|
|
1223
|
-
|
1360
|
+
# Skip tracing if no active trace
|
1361
|
+
if not current_trace:
|
1362
|
+
return original_create(*args, **kwargs)
|
1363
|
+
|
1364
|
+
with current_trace.span(span_name, span_type="llm") as span:
|
1365
|
+
# Format and record the input parameters
|
1366
|
+
input_data = _format_input_data(client, **kwargs)
|
1367
|
+
span.record_input(input_data)
|
1368
|
+
|
1369
|
+
# Make the actual API call
|
1370
|
+
try:
|
1371
|
+
response = original_create(*args, **kwargs)
|
1372
|
+
except Exception as e:
|
1373
|
+
print(f"Error during API call: {e}")
|
1374
|
+
raise
|
1375
|
+
|
1376
|
+
# Format and record the output
|
1377
|
+
output_data = _format_output_data(client, response)
|
1378
|
+
span.record_output(output_data)
|
1379
|
+
|
1380
|
+
return response
|
1381
|
+
|
1224
1382
|
|
1225
1383
|
# Replace the original method with our traced version
|
1226
|
-
if isinstance(client, (OpenAI, Together)):
|
1384
|
+
if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
|
1227
1385
|
client.chat.completions.create = traced_create
|
1228
|
-
elif isinstance(client, Anthropic):
|
1386
|
+
elif isinstance(client, (Anthropic, AsyncAnthropic)):
|
1229
1387
|
client.messages.create = traced_create
|
1388
|
+
elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
|
1389
|
+
client.models.generate_content = traced_create
|
1230
1390
|
|
1231
1391
|
return client
|
1232
1392
|
|
@@ -1246,12 +1406,14 @@ def _get_client_config(client: ApiClient) -> tuple[str, callable]:
|
|
1246
1406
|
Raises:
|
1247
1407
|
ValueError: If client type is not supported
|
1248
1408
|
"""
|
1249
|
-
if isinstance(client, OpenAI):
|
1409
|
+
if isinstance(client, (OpenAI, AsyncOpenAI)):
|
1250
1410
|
return "OPENAI_API_CALL", client.chat.completions.create
|
1251
|
-
elif isinstance(client, Together):
|
1411
|
+
elif isinstance(client, (Together, AsyncTogether)):
|
1252
1412
|
return "TOGETHER_API_CALL", client.chat.completions.create
|
1253
|
-
elif isinstance(client, Anthropic):
|
1413
|
+
elif isinstance(client, (Anthropic, AsyncAnthropic)):
|
1254
1414
|
return "ANTHROPIC_API_CALL", client.messages.create
|
1415
|
+
elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
|
1416
|
+
return "GOOGLE_API_CALL", client.models.generate_content
|
1255
1417
|
raise ValueError(f"Unsupported client type: {type(client)}")
|
1256
1418
|
|
1257
1419
|
def _format_input_data(client: ApiClient, **kwargs) -> dict:
|
@@ -1260,11 +1422,16 @@ def _format_input_data(client: ApiClient, **kwargs) -> dict:
|
|
1260
1422
|
Extracts relevant parameters from kwargs based on the client type
|
1261
1423
|
to ensure consistent tracing across different APIs.
|
1262
1424
|
"""
|
1263
|
-
if isinstance(client, (OpenAI, Together)):
|
1425
|
+
if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
|
1264
1426
|
return {
|
1265
1427
|
"model": kwargs.get("model"),
|
1266
1428
|
"messages": kwargs.get("messages"),
|
1267
1429
|
}
|
1430
|
+
elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
|
1431
|
+
return {
|
1432
|
+
"model": kwargs.get("model"),
|
1433
|
+
"contents": kwargs.get("contents")
|
1434
|
+
}
|
1268
1435
|
# Anthropic requires additional max_tokens parameter
|
1269
1436
|
return {
|
1270
1437
|
"model": kwargs.get("model"),
|
@@ -1283,7 +1450,7 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
|
|
1283
1450
|
- content: The generated text
|
1284
1451
|
- usage: Token usage statistics
|
1285
1452
|
"""
|
1286
|
-
if isinstance(client, (OpenAI, Together)):
|
1453
|
+
if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
|
1287
1454
|
return {
|
1288
1455
|
"content": response.choices[0].message.content,
|
1289
1456
|
"usage": {
|
@@ -1292,6 +1459,15 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
|
|
1292
1459
|
"total_tokens": response.usage.total_tokens
|
1293
1460
|
}
|
1294
1461
|
}
|
1462
|
+
elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
|
1463
|
+
return {
|
1464
|
+
"content": response.candidates[0].content.parts[0].text,
|
1465
|
+
"usage": {
|
1466
|
+
"prompt_tokens": response.usage_metadata.prompt_token_count,
|
1467
|
+
"completion_tokens": response.usage_metadata.candidates_token_count,
|
1468
|
+
"total_tokens": response.usage_metadata.total_token_count
|
1469
|
+
}
|
1470
|
+
}
|
1295
1471
|
# Anthropic has a different response structure
|
1296
1472
|
return {
|
1297
1473
|
"content": response.content[0].text,
|
@@ -1302,29 +1478,88 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
|
|
1302
1478
|
}
|
1303
1479
|
}
|
1304
1480
|
|
1305
|
-
# Add a
|
1306
|
-
|
1307
|
-
|
1308
|
-
|
1309
|
-
|
1310
|
-
|
1311
|
-
|
1312
|
-
|
1313
|
-
|
1314
|
-
|
1315
|
-
|
1316
|
-
|
1317
|
-
|
1318
|
-
|
1319
|
-
|
1320
|
-
|
1321
|
-
|
1322
|
-
|
1323
|
-
#
|
1324
|
-
|
1325
|
-
|
1326
|
-
#
|
1327
|
-
|
1481
|
+
# Add a new function for deep tracing at the module level
|
1482
|
+
def _create_deep_tracing_wrapper(func, tracer, span_type="span"):
|
1483
|
+
"""
|
1484
|
+
Creates a wrapper for a function that automatically traces it when called within a traced function.
|
1485
|
+
This enables deep tracing without requiring explicit @observe decorators on every function.
|
1486
|
+
|
1487
|
+
Args:
|
1488
|
+
func: The function to wrap
|
1489
|
+
tracer: The Tracer instance
|
1490
|
+
span_type: Type of span (default "span")
|
1491
|
+
|
1492
|
+
Returns:
|
1493
|
+
A wrapped function that will be traced when called
|
1494
|
+
"""
|
1495
|
+
# Skip wrapping if the function is not callable or is a built-in
|
1496
|
+
if not callable(func) or isinstance(func, type) or func.__module__ == 'builtins':
|
1497
|
+
return func
|
1498
|
+
|
1499
|
+
# Get function name for the span - check for custom name set by @observe
|
1500
|
+
func_name = getattr(func, '_judgment_span_name', func.__name__)
|
1501
|
+
|
1502
|
+
# Check for custom span_type set by @observe
|
1503
|
+
func_span_type = getattr(func, '_judgment_span_type', "span")
|
1504
|
+
|
1505
|
+
# Store original function to prevent losing reference
|
1506
|
+
original_func = func
|
1507
|
+
|
1508
|
+
# Create appropriate wrapper based on whether the function is async or not
|
1509
|
+
if asyncio.iscoroutinefunction(func):
|
1510
|
+
@functools.wraps(func)
|
1511
|
+
async def async_deep_wrapper(*args, **kwargs):
|
1512
|
+
# Get current trace from context
|
1513
|
+
current_trace = current_trace_var.get()
|
1514
|
+
|
1515
|
+
# If no trace context, just call the function
|
1516
|
+
if not current_trace:
|
1517
|
+
return await original_func(*args, **kwargs)
|
1518
|
+
|
1519
|
+
# Create a span for this function call - use custom span_type if available
|
1520
|
+
with current_trace.span(func_name, span_type=func_span_type) as span:
|
1521
|
+
# Record inputs
|
1522
|
+
span.record_input({
|
1523
|
+
'args': str(args),
|
1524
|
+
'kwargs': kwargs
|
1525
|
+
})
|
1526
|
+
|
1527
|
+
# Execute function
|
1528
|
+
result = await original_func(*args, **kwargs)
|
1529
|
+
|
1530
|
+
# Record output
|
1531
|
+
span.record_output(result)
|
1532
|
+
|
1533
|
+
return result
|
1534
|
+
|
1535
|
+
return async_deep_wrapper
|
1536
|
+
else:
|
1537
|
+
@functools.wraps(func)
|
1538
|
+
def deep_wrapper(*args, **kwargs):
|
1539
|
+
# Get current trace from context
|
1540
|
+
current_trace = current_trace_var.get()
|
1541
|
+
|
1542
|
+
# If no trace context, just call the function
|
1543
|
+
if not current_trace:
|
1544
|
+
return original_func(*args, **kwargs)
|
1545
|
+
|
1546
|
+
# Create a span for this function call - use custom span_type if available
|
1547
|
+
with current_trace.span(func_name, span_type=func_span_type) as span:
|
1548
|
+
# Record inputs
|
1549
|
+
span.record_input({
|
1550
|
+
'args': str(args),
|
1551
|
+
'kwargs': kwargs
|
1552
|
+
})
|
1553
|
+
|
1554
|
+
# Execute function
|
1555
|
+
result = original_func(*args, **kwargs)
|
1556
|
+
|
1557
|
+
# Record output
|
1558
|
+
span.record_output(result)
|
1559
|
+
|
1560
|
+
return result
|
1561
|
+
|
1562
|
+
return deep_wrapper
|
1328
1563
|
|
1329
1564
|
# Add the new TraceThreadPoolExecutor class
|
1330
1565
|
class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
|