judgeval 0.0.30__py3-none-any.whl → 0.0.32__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registry.
judgeval/common/tracer.py CHANGED
@@ -11,11 +11,12 @@ import time
  import uuid
  import warnings
  import contextvars
+ import sys
  from contextlib import contextmanager
  from dataclasses import dataclass, field
  from datetime import datetime
  from http import HTTPStatus
- from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, TypeAlias, Union, Callable, Awaitable
+ from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, TypeAlias, Union, Callable, Awaitable, Set
  from rich import print as rprint

  # Third-party imports
@@ -24,9 +25,10 @@ import requests
  from litellm import cost_per_token
  from pydantic import BaseModel
  from rich import print as rprint
- from openai import OpenAI
- from together import Together
- from anthropic import Anthropic
+ from openai import OpenAI, AsyncOpenAI
+ from together import Together, AsyncTogether
+ from anthropic import Anthropic, AsyncAnthropic
+ from google import genai

  # Local application/library-specific imports
  from judgeval.constants import (
@@ -37,7 +39,6 @@ from judgeval.constants import (
  RABBITMQ_QUEUE,
  JUDGMENT_TRACES_DELETE_API_URL,
  JUDGMENT_PROJECT_DELETE_API_URL,
- JUDGMENT_TRACES_ADD_TO_EVAL_QUEUE_API_URL
  )
  from judgeval.judgment_client import JudgmentClient
  from judgeval.data import Example
@@ -51,10 +52,11 @@ import concurrent.futures

  # Define context variables for tracking the current trace and the current span within a trace
  current_trace_var = contextvars.ContextVar('current_trace', default=None)
- current_span_var = contextvars.ContextVar('current_span', default=None) # NEW: ContextVar for the active span name
+ current_span_var = contextvars.ContextVar('current_span', default=None) # ContextVar for the active span name
+ in_traced_function_var = contextvars.ContextVar('in_traced_function', default=False) # Track if we're in a traced function

  # Define type aliases for better code readability and maintainability
- ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic] # Supported API clients
+ ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic, AsyncOpenAI, AsyncAnthropic, AsyncTogether, genai.Client, genai.client.AsyncClient] # Supported API clients
  TraceEntryType = Literal['enter', 'exit', 'output', 'input', 'evaluation'] # Valid trace entry types
  SpanType = Literal['span', 'tool', 'llm', 'evaluation', 'chain']
  @dataclass
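The new in_traced_function_var acts as a re-entrancy guard for the @observe decorator defined later in this file: the outermost observed call sets it to True, and nested observed calls see the flag and run unwrapped instead of re-entering the tracing logic. A minimal, self-contained sketch of that pattern (the names below are illustrative, not taken from the package):

import contextvars
import functools

_in_traced = contextvars.ContextVar('in_traced', default=False)  # mirrors in_traced_function_var

def traced(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if _in_traced.get():                 # already inside a traced call: plain passthrough
            return func(*args, **kwargs)
        token = _in_traced.set(True)         # mark this context as "tracing"
        try:
            return func(*args, **kwargs)     # real tracing work would happen around this call
        finally:
            _in_traced.reset(token)          # restore the previous value, even on error
    return wrapper

@traced
def inner():
    return "ok"

@traced
def outer():
    return inner()                           # the guard makes the inner wrapper a no-op

print(outer())                               # -> ok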
@@ -69,11 +71,11 @@ class TraceEntry:
  - evaluation: Evaluation: (evaluation results)
  """
  type: TraceEntryType
- function: str # Name of the function being traced
  span_id: str # Unique ID for this specific span instance
  depth: int # Indentation level for nested calls
- message: str # Human-readable description
  created_at: float # Unix timestamp when entry was created, replacing the deprecated 'timestamp' field
+ function: Optional[str] = None # Name of the function being traced
+ message: Optional[str] = None # Human-readable description
  duration: Optional[float] = None # Time taken (for exit/evaluation entries)
  trace_id: str = None # ID of the trace this entry belongs to
  output: Any = None # Function output value
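function and message now default to None, which also forces them below the required fields: in a dataclass, fields with defaults must come after fields without them, so the positional constructor order of TraceEntry changes along with the optionality. An illustrative sketch of that rule (a hypothetical stand-in, not the real TraceEntry definition):

from dataclasses import dataclass
from typing import Optional

@dataclass
class Entry:                         # hypothetical stand-in for TraceEntry
    type: str                        # required fields first
    span_id: str
    depth: int
    created_at: float
    function: Optional[str] = None   # defaulted fields must follow the required ones
    message: Optional[str] = None

e = Entry("enter", "span-123", 0, 1700000000.0, function="my_tool")
print(e.function, e.message)         # -> my_tool None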
@@ -229,6 +231,8 @@ class TraceManagerClient:
  raise ValueError(f"Failed to fetch traces: {response.text}")

  return response.json()
+
+

  def save_trace(self, trace_data: dict):
  """
@@ -356,6 +360,18 @@ class TraceClient:
  self.executed_tools = []
  self.executed_node_tools = []
  self._span_depths: Dict[str, int] = {} # NEW: To track depth of active spans
+
+ def get_current_span(self):
+ """Get the current span from the context var"""
+ return current_span_var.get()
+
+ def set_current_span(self, span: Any):
+ """Set the current span from the context var"""
+ return current_span_var.set(span)
+
+ def reset_current_span(self, token: Any):
+ """Reset the current span from the context var"""
+ return current_span_var.reset(token)

  @contextmanager
  def span(self, name: str, span_type: SpanType = "span"):
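These helpers expose the standard contextvars token protocol: set() returns a Token that remembers the previous value, and reset(token) restores it, which is what lets nested spans unwind in the right order. A standalone illustration (variable names are illustrative):

import contextvars

current_span = contextvars.ContextVar('current_span', default=None)

outer_token = current_span.set("span-A")   # returns a Token capturing the old value
inner_token = current_span.set("span-B")   # nest a second span
print(current_span.get())                  # -> span-B
current_span.reset(inner_token)            # back to span-A
current_span.reset(outer_token)            # back to the default
print(current_span.get())                  # -> None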
@@ -874,27 +890,14 @@ class TraceClient:
  "overwrite": overwrite,
  "parent_trace_id": self.parent_trace_id,
  "parent_name": self.parent_name
- }
- # Execute asynchrous evaluation in the background
- # if not empty_save: # Only send to RabbitMQ if the trace is not empty
- # # Send trace data to evaluation queue via API
- # try:
- # response = requests.post(
- # JUDGMENT_TRACES_ADD_TO_EVAL_QUEUE_API_URL,
- # json=trace_data,
- # headers={
- # "Content-Type": "application/json",
- # "Authorization": f"Bearer {self.tracer.api_key}",
- # "X-Organization-Id": self.tracer.organization_id
- # },
- # verify=True
- # )
-
- # if response.status_code != HTTPStatus.OK:
- # warnings.warn(f"Failed to add trace to evaluation queue: {response.text}")
- # except Exception as e:
- # warnings.warn(f"Error sending trace to evaluation queue: {str(e)}")
-
+ }
+ # --- Log trace data before saving ---
+ try:
+ rprint(f"[TraceClient.save] Saving trace data for trace_id {self.trace_id}:")
+ rprint(json.dumps(trace_data, indent=2))
+ except Exception as log_e:
+ rprint(f"[TraceClient.save] Error logging trace data: {log_e}")
+ # --- End logging ---
  self.trace_manager_client.save_trace(trace_data)

  return self.trace_id, trace_data
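The json.dumps call sits inside try/except because a trace payload can contain values the json module cannot serialize, such as datetimes or arbitrary recorded outputs. A small standalone illustration of that failure mode and of a more tolerant dump via default=str (the payload shown is hypothetical):

import json
from datetime import datetime

trace_data = {"trace_id": "abc", "created_at": datetime.now(), "entries": []}

try:
    print(json.dumps(trace_data, indent=2))             # raises TypeError on the datetime
except TypeError as e:
    print(f"Could not serialize trace data: {e}")

print(json.dumps(trace_data, indent=2, default=str))    # lossy, but handles the datetime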
@@ -917,7 +920,8 @@ class Tracer:
  rules: Optional[List[Rule]] = None, # Added rules parameter
  organization_id: str = os.getenv("JUDGMENT_ORG_ID"),
  enable_monitoring: bool = os.getenv("JUDGMENT_MONITORING", "true").lower() == "true",
- enable_evaluations: bool = os.getenv("JUDGMENT_EVALUATIONS", "true").lower() == "true"
+ enable_evaluations: bool = os.getenv("JUDGMENT_EVALUATIONS", "true").lower() == "true",
+ deep_tracing: bool = True # NEW: Enable deep tracing by default
  ):
  if not hasattr(self, 'initialized'):
  if not api_key:
@@ -934,6 +938,7 @@ class Tracer:
  self.initialized: bool = True
  self.enable_monitoring: bool = enable_monitoring
  self.enable_evaluations: bool = enable_evaluations
+ self.deep_tracing: bool = deep_tracing # NEW: Store deep tracing setting
  elif hasattr(self, 'project_name') and self.project_name != project_name:
  warnings.warn(
  f"Attempting to initialize Tracer with project_name='{project_name}' but it was already initialized with "
@@ -941,7 +946,59 @@ class Tracer:
  "To use a different project name, ensure the first Tracer initialization uses the desired project name.",
  RuntimeWarning
  )
+
+ def set_current_trace(self, trace: TraceClient):
+ """
+ Set the current trace context in contextvars
+ """
+ current_trace_var.set(trace)
+
+ def get_current_trace(self) -> Optional[TraceClient]:
+ """
+ Get the current trace context from contextvars
+ """
+ return current_trace_var.get()
+
+ def _apply_deep_tracing(self, func, span_type="span"):
+ """
+ Apply deep tracing to all functions in the same module as the given function.

+ Args:
+ func: The function being traced
+ span_type: Type of span to use for traced functions
+
+ Returns:
+ A tuple of (module, original_functions_dict) where original_functions_dict
+ contains the original functions that were replaced with traced versions.
+ """
+ module = inspect.getmodule(func)
+ if not module:
+ return None, {}
+
+ # Save original functions
+ original_functions = {}
+
+ # Find all functions in the module
+ for name, obj in inspect.getmembers(module, inspect.isfunction):
+ # Skip already wrapped functions
+ if hasattr(obj, '_judgment_traced'):
+ continue
+
+ # Create a traced version of the function
+ # Always use default span type "span" for child functions
+ traced_func = _create_deep_tracing_wrapper(obj, self, "span")
+
+ # Mark the function as traced to avoid double wrapping
+ traced_func._judgment_traced = True
+
+ # Save the original function
+ original_functions[name] = obj
+
+ # Replace with traced version
+ setattr(module, name, traced_func)
+
+ return module, original_functions
+
  @contextmanager
  def trace(
  self,
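_apply_deep_tracing relies on module-level monkey patching: it snapshots every function in the decorated function's module, swaps each one for a traced wrapper, and hands the originals back so the @observe wrapper can restore them once the call finishes. A standalone sketch of that swap-and-restore mechanism against a throwaway module (the wrapper below only prints; it does not create spans):

import inspect
import types

def make_wrapper(fn):
    def wrapped(*args, **kwargs):
        print(f"entering {fn.__name__}")
        return fn(*args, **kwargs)
    return wrapped

def patch_module(module):
    originals = {}
    for name, obj in inspect.getmembers(module, inspect.isfunction):
        originals[name] = obj                       # remember the original
        setattr(module, name, make_wrapper(obj))    # swap in the wrapper
    return originals

def restore_module(module, originals):
    for name, obj in originals.items():
        setattr(module, name, obj)                  # put the originals back

demo = types.ModuleType("demo")                     # throwaway module with one helper
exec("def helper(x):\n    return x + 1\n", demo.__dict__)

saved = patch_module(demo)
print(demo.helper(1))           # prints "entering helper", then 2
restore_module(demo, saved)
print(demo.helper(1))           # prints just 2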
@@ -987,14 +1044,8 @@ class Tracer:
  finally:
  # Reset the context variable
  current_trace_var.reset(token)
-
- def get_current_trace(self) -> Optional[TraceClient]:
- """
- Get the current trace context from contextvars
- """
- return current_trace_var.get()
-
- def observe(self, func=None, *, name=None, span_type: SpanType = "span", project_name: str = None, overwrite: bool = False):
+
+ def observe(self, func=None, *, name=None, span_type: SpanType = "span", project_name: str = None, overwrite: bool = False, deep_tracing: bool = None):
  """
  Decorator to trace function execution with detailed entry/exit information.

@@ -1004,20 +1055,37 @@ class Tracer:
  span_type: Type of span (default "span")
  project_name: Optional project name override
  overwrite: Whether to overwrite existing traces
+ deep_tracing: Whether to enable deep tracing for this function and all nested calls.
+ If None, uses the tracer's default setting.
  """
  # If monitoring is disabled, return the function as is
  if not self.enable_monitoring:
  return func if func else lambda f: f

  if func is None:
- return lambda f: self.observe(f, name=name, span_type=span_type, project_name=project_name, overwrite=overwrite)
+ return lambda f: self.observe(f, name=name, span_type=span_type, project_name=project_name,
+ overwrite=overwrite, deep_tracing=deep_tracing)

  # Use provided name or fall back to function name
  span_name = name or func.__name__

+ # Store custom attributes on the function object
+ func._judgment_span_name = span_name
+ func._judgment_span_type = span_type
+
+ # Use the provided deep_tracing value or fall back to the tracer's default
+ use_deep_tracing = deep_tracing if deep_tracing is not None else self.deep_tracing
+
  if asyncio.iscoroutinefunction(func):
  @functools.wraps(func)
  async def async_wrapper(*args, **kwargs):
+ # Check if we're already in a traced function
+ if in_traced_function_var.get():
+ return await func(*args, **kwargs)
+
+ # Set in_traced_function_var to True
+ token = in_traced_function_var.set(True)
+
  # Get current trace from context
  current_trace = current_trace_var.get()

@@ -1052,9 +1120,18 @@ class Tracer:
  'kwargs': kwargs
  })

+ # If deep tracing is enabled, apply monkey patching
+ if use_deep_tracing:
+ module, original_functions = self._apply_deep_tracing(func, span_type)
+
  # Execute function
  result = await func(*args, **kwargs)

+ # Restore original functions if deep tracing was enabled
+ if use_deep_tracing and module and 'original_functions' in locals():
+ for name, obj in original_functions.items():
+ setattr(module, name, obj)
+
  # Record output
  span.record_output(result)

@@ -1064,29 +1141,52 @@ class Tracer:
  finally:
  # Reset trace context (span context resets automatically)
  current_trace_var.reset(trace_token)
+ # Reset in_traced_function_var
+ in_traced_function_var.reset(token)
  else:
  # Already have a trace context, just create a span in it
  # The span method handles current_span_var
- with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
- # Record inputs
- span.record_input({
- 'args': str(args),
- 'kwargs': kwargs
- })
-
- # Execute function
- result = await func(*args, **kwargs)
-
- # Record output
- span.record_output(result)
+
+ try:
+ with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
+ # Record inputs
+ span.record_input({
+ 'args': str(args),
+ 'kwargs': kwargs
+ })
+
+ # If deep tracing is enabled, apply monkey patching
+ if use_deep_tracing:
+ module, original_functions = self._apply_deep_tracing(func, span_type)
+
+ # Execute function
+ result = await func(*args, **kwargs)
+
+ # Restore original functions if deep tracing was enabled
+ if use_deep_tracing and module and 'original_functions' in locals():
+ for name, obj in original_functions.items():
+ setattr(module, name, obj)
+
+ # Record output
+ span.record_output(result)

  return result
-
+ finally:
+ # Reset in_traced_function_var
+ in_traced_function_var.reset(token)
+
  return async_wrapper
  else:
- # Non-async function implementation remains unchanged
+ # Non-async function implementation with deep tracing
  @functools.wraps(func)
  def wrapper(*args, **kwargs):
+ # Check if we're already in a traced function
+ if in_traced_function_var.get():
+ return func(*args, **kwargs)
+
+ # Set in_traced_function_var to True
+ token = in_traced_function_var.set(True)
+
  # Get current trace from context
  current_trace = current_trace_var.get()

@@ -1121,9 +1221,18 @@ class Tracer:
  'kwargs': kwargs
  })

+ # If deep tracing is enabled, apply monkey patching
+ if use_deep_tracing:
+ module, original_functions = self._apply_deep_tracing(func, span_type)
+
  # Execute function
  result = func(*args, **kwargs)

+ # Restore original functions if deep tracing was enabled
+ if use_deep_tracing and module and 'original_functions' in locals():
+ for name, obj in original_functions.items():
+ setattr(module, name, obj)
+
  # Record output
  span.record_output(result)

@@ -1133,24 +1242,40 @@ class Tracer:
  finally:
  # Reset trace context (span context resets automatically)
  current_trace_var.reset(trace_token)
+ # Reset in_traced_function_var
+ in_traced_function_var.reset(token)
  else:
  # Already have a trace context, just create a span in it
  # The span method handles current_span_var
- with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
- # Record inputs
- span.record_input({
- 'args': str(args),
- 'kwargs': kwargs
- })
-
- # Execute function
- result = func(*args, **kwargs)
-
- # Record output
- span.record_output(result)
+
+ try:
+ with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
+ # Record inputs
+ span.record_input({
+ 'args': str(args),
+ 'kwargs': kwargs
+ })
+
+ # If deep tracing is enabled, apply monkey patching
+ if use_deep_tracing:
+ module, original_functions = self._apply_deep_tracing(func, span_type)
+
+ # Execute function
+ result = func(*args, **kwargs)
+
+ # Restore original functions if deep tracing was enabled
+ if use_deep_tracing and module and 'original_functions' in locals():
+ for name, obj in original_functions.items():
+ setattr(module, name, obj)
+
+ # Record output
+ span.record_output(result)

  return result
-
+ finally:
+ # Reset in_traced_function_var
+ in_traced_function_var.reset(token)
+
  return wrapper

  def score(self, func=None, scorers: List[Union[APIJudgmentScorer, JudgevalScorer]] = None, model: str = None, log_results: bool = True, *, name: str = None, span_type: SpanType = "span"):
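Taken together, the decorator now only needs to be applied at the entry point of a workflow: with deep tracing on, undecorated helpers in the same module are temporarily swapped for traced wrappers while the observed function runs. A hedged usage sketch based on the code above; the Tracer arguments (api_key, project_name values) are illustrative placeholders, not a verified configuration:

from judgeval.common.tracer import Tracer

judgment = Tracer(api_key="YOUR_API_KEY", project_name="demo-project")

def lookup(city: str) -> str:
    # Undecorated helper: with deep tracing enabled it still shows up as a child
    # span, because the module is monkey-patched while run() executes.
    return f"weather in {city}"

@judgment.observe(span_type="tool", deep_tracing=True)
def run(city: str) -> str:
    # If run() called another @observe function, in_traced_function_var would make
    # that inner wrapper return early instead of re-entering the tracing logic.
    return lookup(city).upper()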
@@ -1199,34 +1324,69 @@ def wrap(client: Any) -> Any:
  """
  # Get the appropriate configuration for this client type
  span_name, original_create = _get_client_config(client)
-
- def traced_create(*args, **kwargs):
- # Get the current trace from contextvars
- current_trace = current_trace_var.get()
-
- # Skip tracing if no active trace
- if not current_trace:
- return original_create(*args, **kwargs)
-
- with current_trace.span(span_name, span_type="llm") as span:
- # Format and record the input parameters
- input_data = _format_input_data(client, **kwargs)
- span.record_input(input_data)
-
- # Make the actual API call
- response = original_create(*args, **kwargs)
+
+ # Handle async clients differently than synchronous clients (need an async function for async clients)
+ if (isinstance(client, (AsyncOpenAI, AsyncAnthropic, AsyncTogether, genai.client.AsyncClient))):
+ async def traced_create(*args, **kwargs):
+ # Get the current trace from contextvars
+ current_trace = current_trace_var.get()

- # Format and record the output
- output_data = _format_output_data(client, response)
- span.record_output(output_data)
+ # Skip tracing if no active trace
+ if not current_trace:
+ return original_create(*args, **kwargs)
+
+ with current_trace.span(span_name, span_type="llm") as span:
+ # Format and record the input parameters
+ input_data = _format_input_data(client, **kwargs)
+ span.record_input(input_data)
+
+ # Make the actual API call
+ try:
+ response = await original_create(*args, **kwargs)
+ except Exception as e:
+ print(f"Error during API call: {e}")
+ raise
+
+ # Format and record the output
+ output_data = _format_output_data(client, response)
+ span.record_output(output_data)
+
+ return response
+ else:
+ def traced_create(*args, **kwargs):
+ # Get the current trace from contextvars
+ current_trace = current_trace_var.get()

- return response
+ # Skip tracing if no active trace
+ if not current_trace:
+ return original_create(*args, **kwargs)
+
+ with current_trace.span(span_name, span_type="llm") as span:
+ # Format and record the input parameters
+ input_data = _format_input_data(client, **kwargs)
+ span.record_input(input_data)
+
+ # Make the actual API call
+ try:
+ response = original_create(*args, **kwargs)
+ except Exception as e:
+ print(f"Error during API call: {e}")
+ raise
+
+ # Format and record the output
+ output_data = _format_output_data(client, response)
+ span.record_output(output_data)
+
+ return response
+

  # Replace the original method with our traced version
- if isinstance(client, (OpenAI, Together)):
+ if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
  client.chat.completions.create = traced_create
- elif isinstance(client, Anthropic):
+ elif isinstance(client, (Anthropic, AsyncAnthropic)):
  client.messages.create = traced_create
+ elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
+ client.models.generate_content = traced_create

  return client
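wrap() patches the client's completion method in place, so existing call sites keep their normal signature and only gain an "llm" span when a trace is active. A hedged usage sketch for the new async support; the client construction and model name are illustrative and need valid credentials to actually run:

from openai import AsyncOpenAI
from judgeval.common.tracer import Tracer, wrap

judgment = Tracer(api_key="YOUR_API_KEY", project_name="demo-project")
client = wrap(AsyncOpenAI())    # replaces client.chat.completions.create with traced_create

@judgment.observe(span_type="tool")
async def ask(question: str) -> str:
    # Inside an active trace, the patched create() records the formatted input
    # (model, messages) and output (content, token usage) on an "llm" span.
    response = await client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": question}],
    )
    return response.choices[0].message.content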
1232
1392
 
@@ -1246,12 +1406,14 @@ def _get_client_config(client: ApiClient) -> tuple[str, callable]:
1246
1406
  Raises:
1247
1407
  ValueError: If client type is not supported
1248
1408
  """
1249
- if isinstance(client, OpenAI):
1409
+ if isinstance(client, (OpenAI, AsyncOpenAI)):
1250
1410
  return "OPENAI_API_CALL", client.chat.completions.create
1251
- elif isinstance(client, Together):
1411
+ elif isinstance(client, (Together, AsyncTogether)):
1252
1412
  return "TOGETHER_API_CALL", client.chat.completions.create
1253
- elif isinstance(client, Anthropic):
1413
+ elif isinstance(client, (Anthropic, AsyncAnthropic)):
1254
1414
  return "ANTHROPIC_API_CALL", client.messages.create
1415
+ elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
1416
+ return "GOOGLE_API_CALL", client.models.generate_content
1255
1417
  raise ValueError(f"Unsupported client type: {type(client)}")
1256
1418
 
1257
1419
  def _format_input_data(client: ApiClient, **kwargs) -> dict:
@@ -1260,11 +1422,16 @@ def _format_input_data(client: ApiClient, **kwargs) -> dict:
1260
1422
  Extracts relevant parameters from kwargs based on the client type
1261
1423
  to ensure consistent tracing across different APIs.
1262
1424
  """
1263
- if isinstance(client, (OpenAI, Together)):
1425
+ if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
1264
1426
  return {
1265
1427
  "model": kwargs.get("model"),
1266
1428
  "messages": kwargs.get("messages"),
1267
1429
  }
1430
+ elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
1431
+ return {
1432
+ "model": kwargs.get("model"),
1433
+ "contents": kwargs.get("contents")
1434
+ }
1268
1435
  # Anthropic requires additional max_tokens parameter
1269
1436
  return {
1270
1437
  "model": kwargs.get("model"),
@@ -1283,7 +1450,7 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
1283
1450
  - content: The generated text
1284
1451
  - usage: Token usage statistics
1285
1452
  """
1286
- if isinstance(client, (OpenAI, Together)):
1453
+ if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
1287
1454
  return {
1288
1455
  "content": response.choices[0].message.content,
1289
1456
  "usage": {
@@ -1292,6 +1459,15 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
1292
1459
  "total_tokens": response.usage.total_tokens
1293
1460
  }
1294
1461
  }
1462
+ elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
1463
+ return {
1464
+ "content": response.candidates[0].content.parts[0].text,
1465
+ "usage": {
1466
+ "prompt_tokens": response.usage_metadata.prompt_token_count,
1467
+ "completion_tokens": response.usage_metadata.candidates_token_count,
1468
+ "total_tokens": response.usage_metadata.total_token_count
1469
+ }
1470
+ }
1295
1471
  # Anthropic has a different response structure
1296
1472
  return {
1297
1473
  "content": response.content[0].text,
@@ -1302,29 +1478,88 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
  }
  }

- # Add a global context-preserving gather function
- # async def trace_gather(*coroutines, return_exceptions=False): # REMOVED
- # """ # REMOVED
- # A wrapper around asyncio.gather that ensures the trace context # REMOVED
- # is available within the gathered coroutines using contextvars.copy_context. # REMOVED
- # """ # REMOVED
- # # Get the original asyncio.gather (if we patched it) # REMOVED
- # original_gather = getattr(asyncio, "_original_gather", asyncio.gather) # REMOVED
- # # REMOVED
- # # Use contextvars.copy_context() to ensure context propagation # REMOVED
- # ctx = contextvars.copy_context() # REMOVED
- # # REMOVED
- # # Wrap the gather call within the copied context # REMOVED
- # return await ctx.run(original_gather, *coroutines, return_exceptions=return_exceptions) # REMOVED
-
- # Store the original gather and apply the patch *once*
- # global _original_gather_stored # REMOVED
- # if not globals().get('_original_gather_stored'): # REMOVED
- # # Check if asyncio.gather is already our wrapper to prevent double patching # REMOVED
- # if asyncio.gather.__name__ != 'trace_gather': # REMOVED
- # asyncio._original_gather = asyncio.gather # REMOVED
- # asyncio.gather = trace_gather # REMOVED
- # _original_gather_stored = True # REMOVED
+ # Add a new function for deep tracing at the module level
+ def _create_deep_tracing_wrapper(func, tracer, span_type="span"):
+ """
+ Creates a wrapper for a function that automatically traces it when called within a traced function.
+ This enables deep tracing without requiring explicit @observe decorators on every function.
+
+ Args:
+ func: The function to wrap
+ tracer: The Tracer instance
+ span_type: Type of span (default "span")
+
+ Returns:
+ A wrapped function that will be traced when called
+ """
+ # Skip wrapping if the function is not callable or is a built-in
+ if not callable(func) or isinstance(func, type) or func.__module__ == 'builtins':
+ return func
+
+ # Get function name for the span - check for custom name set by @observe
+ func_name = getattr(func, '_judgment_span_name', func.__name__)
+
+ # Check for custom span_type set by @observe
+ func_span_type = getattr(func, '_judgment_span_type', "span")
+
+ # Store original function to prevent losing reference
+ original_func = func
+
+ # Create appropriate wrapper based on whether the function is async or not
+ if asyncio.iscoroutinefunction(func):
+ @functools.wraps(func)
+ async def async_deep_wrapper(*args, **kwargs):
+ # Get current trace from context
+ current_trace = current_trace_var.get()
+
+ # If no trace context, just call the function
+ if not current_trace:
+ return await original_func(*args, **kwargs)
+
+ # Create a span for this function call - use custom span_type if available
+ with current_trace.span(func_name, span_type=func_span_type) as span:
+ # Record inputs
+ span.record_input({
+ 'args': str(args),
+ 'kwargs': kwargs
+ })
+
+ # Execute function
+ result = await original_func(*args, **kwargs)
+
+ # Record output
+ span.record_output(result)
+
+ return result
+
+ return async_deep_wrapper
+ else:
+ @functools.wraps(func)
+ def deep_wrapper(*args, **kwargs):
+ # Get current trace from context
+ current_trace = current_trace_var.get()
+
+ # If no trace context, just call the function
+ if not current_trace:
+ return original_func(*args, **kwargs)
+
+ # Create a span for this function call - use custom span_type if available
+ with current_trace.span(func_name, span_type=func_span_type) as span:
+ # Record inputs
+ span.record_input({
+ 'args': str(args),
+ 'kwargs': kwargs
+ })
+
+ # Execute function
+ result = original_func(*args, **kwargs)
+
+ # Record output
+ span.record_output(result)
+
+ return result
+
+ return deep_wrapper

  # Add the new TraceThreadPoolExecutor class
  class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
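The diff ends at the declaration of TraceThreadPoolExecutor; its body is not shown in this hunk. The usual way to make a ThreadPoolExecutor carry contextvars such as current_trace_var into worker threads is to run each submitted callable inside a copy of the caller's context. A generic sketch of that technique, under the assumption that this is the intent of the class, not a reproduction of its actual implementation:

import concurrent.futures
import contextvars

request_id = contextvars.ContextVar("request_id", default=None)

class ContextPropagatingExecutor(concurrent.futures.ThreadPoolExecutor):
    def submit(self, fn, *args, **kwargs):
        # Snapshot the submitting thread's contextvars and run the callable
        # inside that snapshot on the worker thread.
        ctx = contextvars.copy_context()
        return super().submit(ctx.run, fn, *args, **kwargs)

def job():
    return request_id.get()        # sees the value set by the submitting thread

request_id.set("abc-123")
with ContextPropagatingExecutor(max_workers=2) as pool:
    print(pool.submit(job).result())   # -> abc-123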