judgeval 0.0.31__py3-none-any.whl → 0.0.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. judgeval/__init__.py +3 -1
  2. judgeval/common/s3_storage.py +93 -0
  3. judgeval/common/tracer.py +869 -183
  4. judgeval/constants.py +1 -1
  5. judgeval/data/datasets/dataset.py +5 -1
  6. judgeval/data/datasets/eval_dataset_client.py +2 -2
  7. judgeval/data/sequence.py +16 -26
  8. judgeval/data/sequence_run.py +2 -0
  9. judgeval/judgment_client.py +44 -166
  10. judgeval/rules.py +4 -7
  11. judgeval/run_evaluation.py +2 -2
  12. judgeval/scorers/__init__.py +4 -4
  13. judgeval/scorers/judgeval_scorers/__init__.py +0 -176
  14. judgeval/version_check.py +22 -0
  15. {judgeval-0.0.31.dist-info → judgeval-0.0.34.dist-info}/METADATA +15 -2
  16. judgeval-0.0.34.dist-info/RECORD +63 -0
  17. judgeval/scorers/base_scorer.py +0 -58
  18. judgeval/scorers/judgeval_scorers/local_implementations/__init__.py +0 -27
  19. judgeval/scorers/judgeval_scorers/local_implementations/answer_correctness/__init__.py +0 -4
  20. judgeval/scorers/judgeval_scorers/local_implementations/answer_correctness/answer_correctness_scorer.py +0 -276
  21. judgeval/scorers/judgeval_scorers/local_implementations/answer_correctness/prompts.py +0 -169
  22. judgeval/scorers/judgeval_scorers/local_implementations/answer_relevancy/__init__.py +0 -4
  23. judgeval/scorers/judgeval_scorers/local_implementations/answer_relevancy/answer_relevancy_scorer.py +0 -298
  24. judgeval/scorers/judgeval_scorers/local_implementations/answer_relevancy/prompts.py +0 -174
  25. judgeval/scorers/judgeval_scorers/local_implementations/comparison/__init__.py +0 -0
  26. judgeval/scorers/judgeval_scorers/local_implementations/comparison/comparison_scorer.py +0 -161
  27. judgeval/scorers/judgeval_scorers/local_implementations/comparison/prompts.py +0 -222
  28. judgeval/scorers/judgeval_scorers/local_implementations/contextual_precision/__init__.py +0 -3
  29. judgeval/scorers/judgeval_scorers/local_implementations/contextual_precision/contextual_precision_scorer.py +0 -264
  30. judgeval/scorers/judgeval_scorers/local_implementations/contextual_precision/prompts.py +0 -106
  31. judgeval/scorers/judgeval_scorers/local_implementations/contextual_recall/__init__.py +0 -3
  32. judgeval/scorers/judgeval_scorers/local_implementations/contextual_recall/contextual_recall_scorer.py +0 -254
  33. judgeval/scorers/judgeval_scorers/local_implementations/contextual_recall/prompts.py +0 -142
  34. judgeval/scorers/judgeval_scorers/local_implementations/contextual_relevancy/__init__.py +0 -3
  35. judgeval/scorers/judgeval_scorers/local_implementations/contextual_relevancy/contextual_relevancy_scorer.py +0 -245
  36. judgeval/scorers/judgeval_scorers/local_implementations/contextual_relevancy/prompts.py +0 -121
  37. judgeval/scorers/judgeval_scorers/local_implementations/execution_order/__init__.py +0 -3
  38. judgeval/scorers/judgeval_scorers/local_implementations/execution_order/execution_order.py +0 -156
  39. judgeval/scorers/judgeval_scorers/local_implementations/faithfulness/__init__.py +0 -3
  40. judgeval/scorers/judgeval_scorers/local_implementations/faithfulness/faithfulness_scorer.py +0 -318
  41. judgeval/scorers/judgeval_scorers/local_implementations/faithfulness/prompts.py +0 -268
  42. judgeval/scorers/judgeval_scorers/local_implementations/hallucination/__init__.py +0 -3
  43. judgeval/scorers/judgeval_scorers/local_implementations/hallucination/hallucination_scorer.py +0 -264
  44. judgeval/scorers/judgeval_scorers/local_implementations/hallucination/prompts.py +0 -104
  45. judgeval/scorers/judgeval_scorers/local_implementations/instruction_adherence/instruction_adherence.py +0 -232
  46. judgeval/scorers/judgeval_scorers/local_implementations/instruction_adherence/prompt.py +0 -102
  47. judgeval/scorers/judgeval_scorers/local_implementations/json_correctness/__init__.py +0 -5
  48. judgeval/scorers/judgeval_scorers/local_implementations/json_correctness/json_correctness_scorer.py +0 -134
  49. judgeval/scorers/judgeval_scorers/local_implementations/summarization/__init__.py +0 -3
  50. judgeval/scorers/judgeval_scorers/local_implementations/summarization/prompts.py +0 -247
  51. judgeval/scorers/judgeval_scorers/local_implementations/summarization/summarization_scorer.py +0 -551
  52. judgeval-0.0.31.dist-info/RECORD +0 -96
  53. {judgeval-0.0.31.dist-info → judgeval-0.0.34.dist-info}/WHEEL +0 -0
  54. {judgeval-0.0.31.dist-info → judgeval-0.0.34.dist-info}/licenses/LICENSE.md +0 -0
judgeval/common/tracer.py CHANGED
@@ -11,12 +11,28 @@ import time
11
11
  import uuid
12
12
  import warnings
13
13
  import contextvars
14
- from contextlib import contextmanager
14
+ import sys
15
+ from contextlib import contextmanager, asynccontextmanager, AbstractAsyncContextManager, AbstractContextManager # Import context manager bases
15
16
  from dataclasses import dataclass, field
16
17
  from datetime import datetime
17
18
  from http import HTTPStatus
18
- from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, TypeAlias, Union, Callable, Awaitable
19
+ from typing import (
20
+ Any,
21
+ Callable,
22
+ Dict,
23
+ Generator,
24
+ List,
25
+ Literal,
26
+ Optional,
27
+ Tuple,
28
+ Type,
29
+ TypeVar,
30
+ Union,
31
+ AsyncGenerator,
32
+ TypeAlias,
33
+ )
19
34
  from rich import print as rprint
35
+ import types # <--- Add this import
20
36
 
21
37
  # Third-party imports
22
38
  import pika
@@ -27,6 +43,7 @@ from rich import print as rprint
27
43
  from openai import OpenAI, AsyncOpenAI
28
44
  from together import Together, AsyncTogether
29
45
  from anthropic import Anthropic, AsyncAnthropic
46
+ from google import genai
30
47
 
31
48
  # Local application/library-specific imports
32
49
  from judgeval.constants import (
@@ -40,20 +57,22 @@ from judgeval.constants import (
40
57
  )
41
58
  from judgeval.judgment_client import JudgmentClient
42
59
  from judgeval.data import Example
43
- from judgeval.scorers import APIJudgmentScorer, JudgevalScorer, ScorerWrapper
60
+ from judgeval.scorers import APIJudgmentScorer, JudgevalScorer
44
61
  from judgeval.rules import Rule
45
62
  from judgeval.evaluation_run import EvaluationRun
46
63
  from judgeval.data.result import ScoringResult
47
64
 
48
65
  # Standard library imports needed for the new class
49
66
  import concurrent.futures
67
+ from collections.abc import Iterator, AsyncIterator # Add Iterator and AsyncIterator
50
68
 
51
69
  # Define context variables for tracking the current trace and the current span within a trace
52
70
  current_trace_var = contextvars.ContextVar('current_trace', default=None)
53
- current_span_var = contextvars.ContextVar('current_span', default=None) # NEW: ContextVar for the active span name
71
+ current_span_var = contextvars.ContextVar('current_span', default=None) # ContextVar for the active span name
72
+ in_traced_function_var = contextvars.ContextVar('in_traced_function', default=False) # Track if we're in a traced function
54
73
 
55
74
  # Define type aliases for better code readability and maintainability
56
- ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic, AsyncOpenAI, AsyncAnthropic, AsyncTogether] # Supported API clients
75
+ ApiClient: TypeAlias = Union[OpenAI, Together, Anthropic, AsyncOpenAI, AsyncAnthropic, AsyncTogether, genai.Client, genai.client.AsyncClient] # Supported API clients
57
76
  TraceEntryType = Literal['enter', 'exit', 'output', 'input', 'evaluation'] # Valid trace entry types
58
77
  SpanType = Literal['span', 'tool', 'llm', 'evaluation', 'chain']
59
78
  @dataclass
@@ -170,7 +189,7 @@ class TraceEntry:
170
189
  "inputs": self._serialize_inputs(),
171
190
  "evaluation_runs": [evaluation_run.model_dump() for evaluation_run in self.evaluation_runs] if self.evaluation_runs else [],
172
191
  "span_type": self.span_type,
173
- "parent_span_id": self.parent_span_id
192
+ "parent_span_id": self.parent_span_id,
174
193
  }
175
194
 
176
195
  def _serialize_output(self) -> Any:
@@ -185,6 +204,15 @@ class TraceEntry:
185
204
  if isinstance(self.output, BaseModel):
186
205
  return self.output.model_dump()
187
206
 
207
+ # NEW check: If output is the dict structure from our stream wrapper
208
+ if isinstance(self.output, dict) and 'streamed' in self.output:
209
+ # Assume it's already JSON-serializable (content is string, usage is dict or None)
210
+ return self.output
211
+ # NEW check: If output is the placeholder string before stream completes
212
+ elif self.output == "<pending stream>":
213
+ # Represent this state clearly in the serialized data
214
+ return {"status": "pending stream"}
215
+
188
216
  try:
189
217
  # Try to serialize the output to verify it's JSON compatible
190
218
  json.dumps(self.output)
@@ -203,9 +231,10 @@ class TraceManagerClient:
203
231
  - Saving a trace
204
232
  - Deleting a trace
205
233
  """
206
- def __init__(self, judgment_api_key: str, organization_id: str):
234
+ def __init__(self, judgment_api_key: str, organization_id: str, tracer: Optional["Tracer"] = None):
207
235
  self.judgment_api_key = judgment_api_key
208
236
  self.organization_id = organization_id
237
+ self.tracer = tracer
209
238
 
210
239
  def fetch_trace(self, trace_id: str):
211
240
  """
@@ -233,12 +262,13 @@ class TraceManagerClient:
233
262
 
234
263
  def save_trace(self, trace_data: dict):
235
264
  """
236
- Saves a trace to the database
265
+ Saves a trace to the Judgment Supabase and optionally to S3 if configured.
237
266
 
238
267
  Args:
239
268
  trace_data: The trace data to save
240
269
  NOTE we save empty traces in order to properly handle async operations; we need something in the DB to associate the async results with
241
270
  """
271
+ # Save to Judgment API
242
272
  response = requests.post(
243
273
  JUDGMENT_TRACES_SAVE_API_URL,
244
274
  json=trace_data,
@@ -255,6 +285,18 @@ class TraceManagerClient:
255
285
  elif response.status_code != HTTPStatus.OK:
256
286
  raise ValueError(f"Failed to save trace data: {response.text}")
257
287
 
288
+ # If S3 storage is enabled, save to S3 as well
289
+ if self.tracer and self.tracer.use_s3:
290
+ try:
291
+ s3_key = self.tracer.s3_storage.save_trace(
292
+ trace_data=trace_data,
293
+ trace_id=trace_data["trace_id"],
294
+ project_name=trace_data["project_name"]
295
+ )
296
+ print(f"Trace also saved to S3 at key: {s3_key}")
297
+ except Exception as e:
298
+ warnings.warn(f"Failed to save trace to S3: {str(e)}")
299
+
258
300
  if "ui_results_url" in response.json():
259
301
  pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={response.json()['ui_results_url']}]View Trace[/link]\n"
260
302
  rprint(pretty_str)
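
The save path above calls `self.tracer.s3_storage.save_trace(...)` when S3 storage is enabled. The new `judgeval/common/s3_storage.py` module (+93 lines) is not shown in this hunk; the following is only a minimal sketch of an `S3Storage` class consistent with the constructor arguments and the `save_trace` signature visible in the tracer diff. The boto3 usage and key layout here are assumptions, not the published implementation.

```python
# Hypothetical reconstruction of judgeval/common/s3_storage.py; only the
# constructor arguments and the save_trace() signature come from the diff.
import json
from datetime import datetime
from typing import Optional

import boto3


class S3Storage:
    def __init__(self, bucket_name: str,
                 aws_access_key_id: Optional[str] = None,
                 aws_secret_access_key: Optional[str] = None,
                 region_name: Optional[str] = None):
        self.bucket_name = bucket_name
        # When explicit credentials are not passed, boto3 falls back to the
        # standard AWS environment variables / config chain.
        self.s3_client = boto3.client(
            "s3",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=region_name,
        )

    def save_trace(self, trace_data: dict, trace_id: str, project_name: str) -> str:
        # The key layout is an assumption; the real module may differ.
        key = f"traces/{project_name}/{trace_id}_{datetime.utcnow().isoformat()}.json"
        self.s3_client.put_object(
            Bucket=self.bucket_name,
            Key=key,
            Body=json.dumps(trace_data).encode("utf-8"),
        )
        return key
```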
@@ -352,7 +394,7 @@ class TraceClient:
352
394
  self.client: JudgmentClient = tracer.client
353
395
  self.entries: List[TraceEntry] = []
354
396
  self.start_time = time.time()
355
- self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id)
397
+ self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id, tracer)
356
398
  self.visited_nodes = []
357
399
  self.executed_tools = []
358
400
  self.executed_node_tools = []
@@ -390,13 +432,13 @@ class TraceClient:
390
432
  entry = TraceEntry(
391
433
  type="enter",
392
434
  function=name,
393
- span_id=span_id, # Use the generated span_id
394
- trace_id=self.trace_id, # Use the trace_id from the trace client
435
+ span_id=span_id,
436
+ trace_id=self.trace_id,
395
437
  depth=current_depth,
396
438
  message=name,
397
439
  created_at=start_time,
398
440
  span_type=span_type,
399
- parent_span_id=parent_span_id # Use the parent_id from context var
441
+ parent_span_id=parent_span_id,
400
442
  )
401
443
  self.add_entry(entry)
402
444
 
@@ -414,7 +456,7 @@ class TraceClient:
414
456
  message=f"← {name}",
415
457
  created_at=time.time(),
416
458
  duration=duration,
417
- span_type=span_type
459
+ span_type=span_type,
418
460
  ))
419
461
  # Clean up depth tracking for this span_id
420
462
  if span_id in self._span_depths:
@@ -451,47 +493,14 @@ class TraceClient:
451
493
  additional_metadata=additional_metadata,
452
494
  trace_id=self.trace_id
453
495
  )
454
- loaded_rules = None
455
- if self.rules:
456
- loaded_rules = []
457
- for rule in self.rules:
458
- processed_conditions = []
459
- for condition in rule.conditions:
460
- # Convert metric if it's a ScorerWrapper
461
- try:
462
- if isinstance(condition.metric, ScorerWrapper):
463
- condition_copy = condition.model_copy()
464
- condition_copy.metric = condition.metric.load_implementation(use_judgment=True)
465
- processed_conditions.append(condition_copy)
466
- else:
467
- processed_conditions.append(condition)
468
- except Exception as e:
469
- warnings.warn(f"Failed to convert ScorerWrapper in rule '{rule.name}', condition metric '{condition.metric_name}': {str(e)}")
470
- processed_conditions.append(condition) # Keep original condition as fallback
471
-
472
- # Create new rule with processed conditions
473
- new_rule = rule.model_copy()
474
- new_rule.conditions = processed_conditions
475
- loaded_rules.append(new_rule)
476
496
  try:
477
497
  # Load appropriate implementations for all scorers
478
- loaded_scorers: List[Union[JudgevalScorer, APIJudgmentScorer]] = []
479
- for scorer in scorers:
480
- try:
481
- if isinstance(scorer, ScorerWrapper):
482
- loaded_scorers.append(scorer.load_implementation(use_judgment=True))
483
- else:
484
- loaded_scorers.append(scorer)
485
- except Exception as e:
486
- warnings.warn(f"Failed to load implementation for scorer {scorer}: {str(e)}")
487
- # Skip this scorer
488
-
489
- if not loaded_scorers:
498
+ if not scorers:
490
499
  warnings.warn("No valid scorers available for evaluation")
491
500
  return
492
501
 
493
502
  # Prevent using JudgevalScorer with rules - only APIJudgmentScorer allowed with rules
494
- if loaded_rules and any(isinstance(scorer, JudgevalScorer) for scorer in loaded_scorers):
503
+ if self.rules and any(isinstance(scorer, JudgevalScorer) for scorer in scorers):
495
504
  raise ValueError("Cannot use Judgeval scorers, you can only use API scorers when using rules. Please either remove rules or use only APIJudgmentScorer types.")
496
505
 
497
506
  except Exception as e:
@@ -505,15 +514,15 @@ class TraceClient:
505
514
  project_name=self.project_name,
506
515
  eval_name=f"{self.name.capitalize()}-"
507
516
  f"{current_span_var.get()}-"
508
- f"[{','.join(scorer.score_type.capitalize() for scorer in loaded_scorers)}]",
517
+ f"[{','.join(scorer.score_type.capitalize() for scorer in scorers)}]",
509
518
  examples=[example],
510
- scorers=loaded_scorers,
519
+ scorers=scorers,
511
520
  model=model,
512
521
  metadata={},
513
522
  judgment_api_key=self.tracer.api_key,
514
523
  override=self.overwrite,
515
524
  trace_span_id=current_span_var.get(),
516
- rules=loaded_rules # Use the combined rules
525
+ rules=self.rules # Use the combined rules
517
526
  )
518
527
 
519
528
  self.add_eval_run(eval_run, start_time) # Pass start_time to record_evaluation
@@ -571,7 +580,7 @@ class TraceClient:
571
580
  message=f"Inputs to {function_name}",
572
581
  created_at=time.time(),
573
582
  inputs=inputs,
574
- span_type=entry_span_type
583
+ span_type=entry_span_type,
575
584
  ))
576
585
 
577
586
  async def _update_coroutine_output(self, entry: TraceEntry, coroutine: Any):
@@ -604,12 +613,15 @@ class TraceClient:
604
613
  message=f"Output from {function_name}",
605
614
  created_at=time.time(),
606
615
  output="<pending>" if inspect.iscoroutine(output) else output,
607
- span_type=entry_span_type
616
+ span_type=entry_span_type,
608
617
  )
609
618
  self.add_entry(entry)
610
619
 
611
620
  if inspect.iscoroutine(output):
612
621
  asyncio.create_task(self._update_coroutine_output(entry, output))
622
+
623
+ # Return the created entry
624
+ return entry
613
625
 
614
626
  def add_entry(self, entry: TraceEntry):
615
627
  """Add a trace entry to this trace context"""
@@ -821,8 +833,10 @@ class TraceClient:
821
833
  total_completion_tokens_cost = 0.0
822
834
  total_cost = 0.0
823
835
 
836
+ # Only count tokens for actual LLM API call spans
837
+ llm_span_names = {"OPENAI_API_CALL", "TOGETHER_API_CALL", "ANTHROPIC_API_CALL", "GOOGLE_API_CALL"}
824
838
  for entry in condensed_entries:
825
- if entry.get("span_type") == "llm" and isinstance(entry.get("output"), dict):
839
+ if entry.get("span_type") == "llm" and entry.get("function") in llm_span_names and isinstance(entry.get("output"), dict):
826
840
  output = entry["output"]
827
841
  usage = output.get("usage", {})
828
842
  model_name = entry.get("inputs", {}).get("model", "")
@@ -888,6 +902,13 @@ class TraceClient:
888
902
  "parent_trace_id": self.parent_trace_id,
889
903
  "parent_name": self.parent_name
890
904
  }
905
+ # --- Log trace data before saving ---
906
+ try:
907
+ rprint(f"[TraceClient.save] Saving trace data for trace_id {self.trace_id}:")
908
+ rprint(json.dumps(trace_data, indent=2))
909
+ except Exception as log_e:
910
+ rprint(f"[TraceClient.save] Error logging trace data: {log_e}")
911
+ # --- End logging ---
891
912
  self.trace_manager_client.save_trace(trace_data)
892
913
 
893
914
  return self.trace_id, trace_data
@@ -910,7 +931,14 @@ class Tracer:
910
931
  rules: Optional[List[Rule]] = None, # Added rules parameter
911
932
  organization_id: str = os.getenv("JUDGMENT_ORG_ID"),
912
933
  enable_monitoring: bool = os.getenv("JUDGMENT_MONITORING", "true").lower() == "true",
913
- enable_evaluations: bool = os.getenv("JUDGMENT_EVALUATIONS", "true").lower() == "true"
934
+ enable_evaluations: bool = os.getenv("JUDGMENT_EVALUATIONS", "true").lower() == "true",
935
+ # S3 configuration
936
+ use_s3: bool = False,
937
+ s3_bucket_name: Optional[str] = None,
938
+ s3_aws_access_key_id: Optional[str] = None,
939
+ s3_aws_secret_access_key: Optional[str] = None,
940
+ s3_region_name: Optional[str] = None,
941
+ deep_tracing: bool = True # NEW: Enable deep tracing by default
914
942
  ):
915
943
  if not hasattr(self, 'initialized'):
916
944
  if not api_key:
@@ -918,6 +946,13 @@ class Tracer:
918
946
 
919
947
  if not organization_id:
920
948
  raise ValueError("Tracer must be configured with an Organization ID")
949
+ if use_s3 and not s3_bucket_name:
950
+ raise ValueError("S3 bucket name must be provided when use_s3 is True")
951
+ if use_s3 and not (s3_aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")):
952
+ raise ValueError("AWS Access Key ID must be provided when use_s3 is True")
953
+ if use_s3 and not (s3_aws_secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY")):
954
+ raise ValueError("AWS Secret Access Key must be provided when use_s3 is True")
955
+
921
956
  self.api_key: str = api_key
922
957
  self.project_name: str = project_name
923
958
  self.client: JudgmentClient = JudgmentClient(judgment_api_key=api_key)
@@ -927,6 +962,19 @@ class Tracer:
927
962
  self.initialized: bool = True
928
963
  self.enable_monitoring: bool = enable_monitoring
929
964
  self.enable_evaluations: bool = enable_evaluations
965
+
966
+ # Initialize S3 storage if enabled
967
+ self.use_s3 = use_s3
968
+ if use_s3:
969
+ from judgeval.common.s3_storage import S3Storage
970
+ self.s3_storage = S3Storage(
971
+ bucket_name=s3_bucket_name,
972
+ aws_access_key_id=s3_aws_access_key_id,
973
+ aws_secret_access_key=s3_aws_secret_access_key,
974
+ region_name=s3_region_name
975
+ )
976
+ self.deep_tracing: bool = deep_tracing # NEW: Store deep tracing setting
977
+
930
978
  elif hasattr(self, 'project_name') and self.project_name != project_name:
931
979
  warnings.warn(
932
980
  f"Attempting to initialize Tracer with project_name='{project_name}' but it was already initialized with "
@@ -941,12 +989,52 @@ class Tracer:
941
989
  """
942
990
  current_trace_var.set(trace)
943
991
 
944
- def get_current_trace(self):
992
+ def get_current_trace(self) -> Optional[TraceClient]:
945
993
  """
946
994
  Get the current trace context from contextvars
947
995
  """
948
996
  return current_trace_var.get()
949
997
 
998
+ def _apply_deep_tracing(self, func, span_type="span"):
999
+ """
1000
+ Apply deep tracing to all functions in the same module as the given function.
1001
+
1002
+ Args:
1003
+ func: The function being traced
1004
+ span_type: Type of span to use for traced functions
1005
+
1006
+ Returns:
1007
+ A tuple of (module, original_functions_dict) where original_functions_dict
1008
+ contains the original functions that were replaced with traced versions.
1009
+ """
1010
+ module = inspect.getmodule(func)
1011
+ if not module:
1012
+ return None, {}
1013
+
1014
+ # Save original functions
1015
+ original_functions = {}
1016
+
1017
+ # Find all functions in the module
1018
+ for name, obj in inspect.getmembers(module, inspect.isfunction):
1019
+ # Skip already wrapped functions
1020
+ if hasattr(obj, '_judgment_traced'):
1021
+ continue
1022
+
1023
+ # Create a traced version of the function
1024
+ # Always use default span type "span" for child functions
1025
+ traced_func = _create_deep_tracing_wrapper(obj, self, "span")
1026
+
1027
+ # Mark the function as traced to avoid double wrapping
1028
+ traced_func._judgment_traced = True
1029
+
1030
+ # Save the original function
1031
+ original_functions[name] = obj
1032
+
1033
+ # Replace with traced version
1034
+ setattr(module, name, traced_func)
1035
+
1036
+ return module, original_functions
1037
+
950
1038
  @contextmanager
951
1039
  def trace(
952
1040
  self,
@@ -992,14 +1080,8 @@ class Tracer:
992
1080
  finally:
993
1081
  # Reset the context variable
994
1082
  current_trace_var.reset(token)
995
-
996
- def get_current_trace(self) -> Optional[TraceClient]:
997
- """
998
- Get the current trace context from contextvars
999
- """
1000
- return current_trace_var.get()
1001
-
1002
- def observe(self, func=None, *, name=None, span_type: SpanType = "span", project_name: str = None, overwrite: bool = False):
1083
+
1084
+ def observe(self, func=None, *, name=None, span_type: SpanType = "span", project_name: str = None, overwrite: bool = False, deep_tracing: bool = None):
1003
1085
  """
1004
1086
  Decorator to trace function execution with detailed entry/exit information.
1005
1087
 
@@ -1009,20 +1091,37 @@ class Tracer:
1009
1091
  span_type: Type of span (default "span")
1010
1092
  project_name: Optional project name override
1011
1093
  overwrite: Whether to overwrite existing traces
1094
+ deep_tracing: Whether to enable deep tracing for this function and all nested calls.
1095
+ If None, uses the tracer's default setting.
1012
1096
  """
1013
1097
  # If monitoring is disabled, return the function as is
1014
1098
  if not self.enable_monitoring:
1015
1099
  return func if func else lambda f: f
1016
1100
 
1017
1101
  if func is None:
1018
- return lambda f: self.observe(f, name=name, span_type=span_type, project_name=project_name, overwrite=overwrite)
1102
+ return lambda f: self.observe(f, name=name, span_type=span_type, project_name=project_name,
1103
+ overwrite=overwrite, deep_tracing=deep_tracing)
1019
1104
 
1020
1105
  # Use provided name or fall back to function name
1021
1106
  span_name = name or func.__name__
1022
1107
 
1108
+ # Store custom attributes on the function object
1109
+ func._judgment_span_name = span_name
1110
+ func._judgment_span_type = span_type
1111
+
1112
+ # Use the provided deep_tracing value or fall back to the tracer's default
1113
+ use_deep_tracing = deep_tracing if deep_tracing is not None else self.deep_tracing
1114
+
1023
1115
  if asyncio.iscoroutinefunction(func):
1024
1116
  @functools.wraps(func)
1025
1117
  async def async_wrapper(*args, **kwargs):
1118
+ # Check if we're already in a traced function
1119
+ if in_traced_function_var.get():
1120
+ return await func(*args, **kwargs)
1121
+
1122
+ # Set in_traced_function_var to True
1123
+ token = in_traced_function_var.set(True)
1124
+
1026
1125
  # Get current trace from context
1027
1126
  current_trace = current_trace_var.get()
1028
1127
 
@@ -1057,9 +1156,18 @@ class Tracer:
1057
1156
  'kwargs': kwargs
1058
1157
  })
1059
1158
 
1159
+ # If deep tracing is enabled, apply monkey patching
1160
+ if use_deep_tracing:
1161
+ module, original_functions = self._apply_deep_tracing(func, span_type)
1162
+
1060
1163
  # Execute function
1061
1164
  result = await func(*args, **kwargs)
1062
1165
 
1166
+ # Restore original functions if deep tracing was enabled
1167
+ if use_deep_tracing and module and 'original_functions' in locals():
1168
+ for name, obj in original_functions.items():
1169
+ setattr(module, name, obj)
1170
+
1063
1171
  # Record output
1064
1172
  span.record_output(result)
1065
1173
 
@@ -1069,29 +1177,52 @@ class Tracer:
1069
1177
  finally:
1070
1178
  # Reset trace context (span context resets automatically)
1071
1179
  current_trace_var.reset(trace_token)
1180
+ # Reset in_traced_function_var
1181
+ in_traced_function_var.reset(token)
1072
1182
  else:
1073
1183
  # Already have a trace context, just create a span in it
1074
1184
  # The span method handles current_span_var
1075
- with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
1076
- # Record inputs
1077
- span.record_input({
1078
- 'args': str(args),
1079
- 'kwargs': kwargs
1080
- })
1081
-
1082
- # Execute function
1083
- result = await func(*args, **kwargs)
1084
-
1085
- # Record output
1086
- span.record_output(result)
1185
+
1186
+ try:
1187
+ with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
1188
+ # Record inputs
1189
+ span.record_input({
1190
+ 'args': str(args),
1191
+ 'kwargs': kwargs
1192
+ })
1193
+
1194
+ # If deep tracing is enabled, apply monkey patching
1195
+ if use_deep_tracing:
1196
+ module, original_functions = self._apply_deep_tracing(func, span_type)
1197
+
1198
+ # Execute function
1199
+ result = await func(*args, **kwargs)
1200
+
1201
+ # Restore original functions if deep tracing was enabled
1202
+ if use_deep_tracing and module and 'original_functions' in locals():
1203
+ for name, obj in original_functions.items():
1204
+ setattr(module, name, obj)
1205
+
1206
+ # Record output
1207
+ span.record_output(result)
1087
1208
 
1088
1209
  return result
1089
-
1210
+ finally:
1211
+ # Reset in_traced_function_var
1212
+ in_traced_function_var.reset(token)
1213
+
1090
1214
  return async_wrapper
1091
1215
  else:
1092
- # Non-async function implementation remains unchanged
1216
+ # Non-async function implementation with deep tracing
1093
1217
  @functools.wraps(func)
1094
1218
  def wrapper(*args, **kwargs):
1219
+ # Check if we're already in a traced function
1220
+ if in_traced_function_var.get():
1221
+ return func(*args, **kwargs)
1222
+
1223
+ # Set in_traced_function_var to True
1224
+ token = in_traced_function_var.set(True)
1225
+
1095
1226
  # Get current trace from context
1096
1227
  current_trace = current_trace_var.get()
1097
1228
 
@@ -1126,9 +1257,18 @@ class Tracer:
1126
1257
  'kwargs': kwargs
1127
1258
  })
1128
1259
 
1260
+ # If deep tracing is enabled, apply monkey patching
1261
+ if use_deep_tracing:
1262
+ module, original_functions = self._apply_deep_tracing(func, span_type)
1263
+
1129
1264
  # Execute function
1130
1265
  result = func(*args, **kwargs)
1131
1266
 
1267
+ # Restore original functions if deep tracing was enabled
1268
+ if use_deep_tracing and module and 'original_functions' in locals():
1269
+ for name, obj in original_functions.items():
1270
+ setattr(module, name, obj)
1271
+
1132
1272
  # Record output
1133
1273
  span.record_output(result)
1134
1274
 
@@ -1138,24 +1278,40 @@ class Tracer:
1138
1278
  finally:
1139
1279
  # Reset trace context (span context resets automatically)
1140
1280
  current_trace_var.reset(trace_token)
1281
+ # Reset in_traced_function_var
1282
+ in_traced_function_var.reset(token)
1141
1283
  else:
1142
1284
  # Already have a trace context, just create a span in it
1143
1285
  # The span method handles current_span_var
1144
- with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
1145
- # Record inputs
1146
- span.record_input({
1147
- 'args': str(args),
1148
- 'kwargs': kwargs
1149
- })
1150
-
1151
- # Execute function
1152
- result = func(*args, **kwargs)
1153
-
1154
- # Record output
1155
- span.record_output(result)
1286
+
1287
+ try:
1288
+ with current_trace.span(span_name, span_type=span_type) as span: # MODIFIED: Use span_name directly
1289
+ # Record inputs
1290
+ span.record_input({
1291
+ 'args': str(args),
1292
+ 'kwargs': kwargs
1293
+ })
1294
+
1295
+ # If deep tracing is enabled, apply monkey patching
1296
+ if use_deep_tracing:
1297
+ module, original_functions = self._apply_deep_tracing(func, span_type)
1298
+
1299
+ # Execute function
1300
+ result = func(*args, **kwargs)
1301
+
1302
+ # Restore original functions if deep tracing was enabled
1303
+ if use_deep_tracing and module and 'original_functions' in locals():
1304
+ for name, obj in original_functions.items():
1305
+ setattr(module, name, obj)
1306
+
1307
+ # Record output
1308
+ span.record_output(result)
1156
1309
 
1157
1310
  return result
1158
-
1311
+ finally:
1312
+ # Reset in_traced_function_var
1313
+ in_traced_function_var.reset(token)
1314
+
1159
1315
  return wrapper
1160
1316
 
1161
1317
  def score(self, func=None, scorers: List[Union[APIJudgmentScorer, JudgevalScorer]] = None, model: str = None, log_results: bool = True, *, name: str = None, span_type: SpanType = "span"):
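
The reworked `@observe` decorator above adds a per-decorator `deep_tracing` flag and a re-entrancy guard (`in_traced_function_var`) so nested decorated calls do not re-run the tracing logic. A small usage sketch, assuming the `judgment` Tracer instance from the earlier example; the function names are illustrative. With deep tracing enabled, undecorated module-level helpers such as `fetch_documents` are temporarily monkey-patched and recorded as child spans while the decorated function runs.

```python
def fetch_documents(query: str):
    # Not decorated: deep tracing wraps this module-level helper automatically
    # and records it as a child span of answer().
    return ["doc-1", "doc-2"]

@judgment.observe(span_type="tool", deep_tracing=True)
def answer(query: str) -> str:
    docs = fetch_documents(query)
    return f"answer based on {len(docs)} documents"
```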
@@ -1200,96 +1356,192 @@ class Tracer:
1200
1356
  def wrap(client: Any) -> Any:
1201
1357
  """
1202
1358
  Wraps an API client to add tracing capabilities.
1203
- Supports OpenAI, Together, and Anthropic clients.
1359
+ Supports OpenAI, Together, Anthropic, and Google GenAI clients.
1360
+ Patches both '.create' and Anthropic's '.stream' methods using a wrapper class.
1204
1361
  """
1205
- # Get the appropriate configuration for this client type
1206
- span_name, original_create = _get_client_config(client)
1362
+ span_name, original_create, original_stream = _get_client_config(client)
1207
1363
 
1208
- # Handle async clients differently than synchronous clients (need an async function for async clients)
1209
- if (isinstance(client, (AsyncOpenAI, AsyncAnthropic, AsyncTogether))):
1210
- async def traced_create(*args, **kwargs):
1211
- # Get the current trace from contextvars
1212
- current_trace = current_trace_var.get()
1213
-
1214
- # Skip tracing if no active trace
1215
- if not current_trace:
1216
- return original_create(*args, **kwargs)
1364
+ # --- Define Traced Async Functions ---
1365
+ async def traced_create_async(*args, **kwargs):
1366
+ # [Existing logic - unchanged]
1367
+ current_trace = current_trace_var.get()
1368
+ if not current_trace:
1369
+ if asyncio.iscoroutinefunction(original_create):
1370
+ return await original_create(*args, **kwargs)
1371
+ else:
1372
+ return original_create(*args, **kwargs)
1373
+
1374
+ is_streaming = kwargs.get("stream", False)
1375
+
1376
+ with current_trace.span(span_name, span_type="llm") as span:
1377
+ input_data = _format_input_data(client, **kwargs)
1378
+ span.record_input(input_data)
1379
+
1380
+ # Warn about token counting limitations with streaming
1381
+ if isinstance(client, (AsyncOpenAI, OpenAI)) and is_streaming:
1382
+ if not kwargs.get("stream_options", {}).get("include_usage"):
1383
+ warnings.warn(
1384
+ "OpenAI streaming calls don't include token counts by default. "
1385
+ "To enable token counting with streams, set stream_options={'include_usage': True} "
1386
+ "in your API call arguments.",
1387
+ UserWarning
1388
+ )
1217
1389
 
1218
- with current_trace.span(span_name, span_type="llm") as span:
1219
- # Format and record the input parameters
1220
- input_data = _format_input_data(client, **kwargs)
1221
- span.record_input(input_data)
1222
-
1223
- # Make the actual API call
1224
- try:
1225
- response = await original_create(*args, **kwargs)
1226
- except Exception as e:
1227
- print(f"Error during API call: {e}")
1228
- raise
1229
-
1230
- # Format and record the output
1231
- output_data = _format_output_data(client, response)
1232
- span.record_output(output_data)
1233
-
1234
- return response
1235
- else:
1236
- def traced_create(*args, **kwargs):
1237
- # Get the current trace from contextvars
1238
- current_trace = current_trace_var.get()
1239
-
1240
- # Skip tracing if no active trace
1241
- if not current_trace:
1242
- return original_create(*args, **kwargs)
1390
+ try:
1391
+ if is_streaming:
1392
+ stream_iterator = await original_create(*args, **kwargs)
1393
+ output_entry = span.record_output("<pending stream>")
1394
+ return _async_stream_wrapper(stream_iterator, client, output_entry)
1395
+ else:
1396
+ awaited_response = await original_create(*args, **kwargs)
1397
+ output_data = _format_output_data(client, awaited_response)
1398
+ span.record_output(output_data)
1399
+ return awaited_response
1400
+ except Exception as e:
1401
+ print(f"Error during wrapped async API call ({span_name}): {e}")
1402
+ span.record_output({"error": str(e)})
1403
+ raise
1404
+
1405
+
1406
+ # Function replacing .stream() - NOW returns the wrapper class instance
1407
+ def traced_stream_async(*args, **kwargs):
1408
+ current_trace = current_trace_var.get()
1409
+ if not current_trace or not original_stream:
1410
+ return original_stream(*args, **kwargs)
1411
+ original_manager = original_stream(*args, **kwargs)
1412
+ wrapper_manager = _TracedAsyncStreamManagerWrapper(
1413
+ original_manager=original_manager,
1414
+ client=client,
1415
+ span_name=span_name,
1416
+ trace_client=current_trace,
1417
+ stream_wrapper_func=_async_stream_wrapper,
1418
+ input_kwargs=kwargs
1419
+ )
1420
+ return wrapper_manager
1421
+
1422
+ # --- Define Traced Sync Functions ---
1423
+ def traced_create_sync(*args, **kwargs):
1424
+ # [Existing logic - unchanged]
1425
+ current_trace = current_trace_var.get()
1426
+ if not current_trace:
1427
+ return original_create(*args, **kwargs)
1428
+
1429
+ is_streaming = kwargs.get("stream", False)
1430
+
1431
+ with current_trace.span(span_name, span_type="llm") as span:
1432
+ input_data = _format_input_data(client, **kwargs)
1433
+ span.record_input(input_data)
1434
+
1435
+ # Warn about token counting limitations with streaming
1436
+ if isinstance(client, (AsyncOpenAI, OpenAI)) and is_streaming:
1437
+ if not kwargs.get("stream_options", {}).get("include_usage"):
1438
+ warnings.warn(
1439
+ "OpenAI streaming calls don't include token counts by default. "
1440
+ "To enable token counting with streams, set stream_options={'include_usage': True} "
1441
+ "in your API call arguments.",
1442
+ UserWarning
1443
+ )
1444
+
1445
+ try:
1446
+ response_or_iterator = original_create(*args, **kwargs)
1447
+ except Exception as e:
1448
+ print(f"Error during wrapped sync API call ({span_name}): {e}")
1449
+ span.record_output({"error": str(e)})
1450
+ raise
1451
+
1452
+ if is_streaming:
1453
+ output_entry = span.record_output("<pending stream>")
1454
+ return _sync_stream_wrapper(response_or_iterator, client, output_entry)
1455
+ else:
1456
+ output_data = _format_output_data(client, response_or_iterator)
1457
+ span.record_output(output_data)
1458
+ return response_or_iterator
1459
+
1460
+
1461
+ # Function replacing sync .stream()
1462
+ def traced_stream_sync(*args, **kwargs):
1463
+ current_trace = current_trace_var.get()
1464
+ if not current_trace or not original_stream:
1465
+ return original_stream(*args, **kwargs)
1466
+ original_manager = original_stream(*args, **kwargs)
1467
+ wrapper_manager = _TracedSyncStreamManagerWrapper(
1468
+ original_manager=original_manager,
1469
+ client=client,
1470
+ span_name=span_name,
1471
+ trace_client=current_trace,
1472
+ stream_wrapper_func=_sync_stream_wrapper,
1473
+ input_kwargs=kwargs
1474
+ )
1475
+ return wrapper_manager
1476
+
1477
+
1478
+ # --- Assign Traced Methods to Client Instance ---
1479
+ # [Assignment logic remains the same]
1480
+ if isinstance(client, (AsyncOpenAI, AsyncTogether)):
1481
+ client.chat.completions.create = traced_create_async
1482
+ # Wrap the Responses API endpoint for OpenAI clients
1483
+ if hasattr(client, "responses") and hasattr(client.responses, "create"):
1484
+ # Capture the original responses.create
1485
+ original_responses_create = client.responses.create
1486
+ def traced_responses(*args, **kwargs):
1487
+ # Get the current trace from contextvars
1488
+ current_trace = current_trace_var.get()
1489
+ # If no active trace, call the original
1490
+ if not current_trace:
1491
+ return original_responses_create(*args, **kwargs)
1492
+ # Trace this responses.create call
1493
+ with current_trace.span(span_name, span_type="llm") as span:
1494
+ # Record raw input kwargs
1495
+ span.record_input(kwargs)
1496
+ # Make the actual API call
1497
+ response = original_responses_create(*args, **kwargs)
1498
+ # Record the output object
1499
+ span.record_output(response)
1500
+ return response
1501
+ # Assign the traced wrapper
1502
+ client.responses.create = traced_responses
1503
+ elif isinstance(client, AsyncAnthropic):
1504
+ client.messages.create = traced_create_async
1505
+ if original_stream:
1506
+ client.messages.stream = traced_stream_async
1507
+ elif isinstance(client, genai.client.AsyncClient):
1508
+ client.generate_content = traced_create_async
1509
+ elif isinstance(client, (OpenAI, Together)):
1510
+ client.chat.completions.create = traced_create_sync
1511
+ elif isinstance(client, Anthropic):
1512
+ client.messages.create = traced_create_sync
1513
+ if original_stream:
1514
+ client.messages.stream = traced_stream_sync
1515
+ elif isinstance(client, genai.Client):
1516
+ client.generate_content = traced_create_sync
1243
1517
 
1244
- with current_trace.span(span_name, span_type="llm") as span:
1245
- # Format and record the input parameters
1246
- input_data = _format_input_data(client, **kwargs)
1247
- span.record_input(input_data)
1248
-
1249
- # Make the actual API call
1250
- try:
1251
- response = original_create(*args, **kwargs)
1252
- except Exception as e:
1253
- print(f"Error during API call: {e}")
1254
- raise
1255
-
1256
- # Format and record the output
1257
- output_data = _format_output_data(client, response)
1258
- span.record_output(output_data)
1259
-
1260
- return response
1261
-
1262
-
1263
- # Replace the original method with our traced version
1264
- if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
1265
- client.chat.completions.create = traced_create
1266
- elif isinstance(client, (Anthropic, AsyncAnthropic)):
1267
- client.messages.create = traced_create
1268
-
1269
1518
  return client
1270
1519
 
1271
1520
  # Helper functions for client-specific operations
1272
1521
 
1273
- def _get_client_config(client: ApiClient) -> tuple[str, callable]:
1522
+ def _get_client_config(client: ApiClient) -> tuple[str, callable, Optional[callable]]:
1274
1523
  """Returns configuration tuple for the given API client.
1275
1524
 
1276
1525
  Args:
1277
1526
  client: An instance of OpenAI, Together, or Anthropic client
1278
1527
 
1279
1528
  Returns:
1280
- tuple: (span_name, create_method)
1529
+ tuple: (span_name, create_method, stream_method)
1281
1530
  - span_name: String identifier for tracing
1282
1531
  - create_method: Reference to the client's creation method
1532
+ - stream_method: Reference to the client's stream method (if applicable)
1283
1533
 
1284
1534
  Raises:
1285
1535
  ValueError: If client type is not supported
1286
1536
  """
1287
1537
  if isinstance(client, (OpenAI, AsyncOpenAI)):
1288
- return "OPENAI_API_CALL", client.chat.completions.create
1538
+ return "OPENAI_API_CALL", client.chat.completions.create, None
1289
1539
  elif isinstance(client, (Together, AsyncTogether)):
1290
- return "TOGETHER_API_CALL", client.chat.completions.create
1540
+ return "TOGETHER_API_CALL", client.chat.completions.create, None
1291
1541
  elif isinstance(client, (Anthropic, AsyncAnthropic)):
1292
- return "ANTHROPIC_API_CALL", client.messages.create
1542
+ return "ANTHROPIC_API_CALL", client.messages.create, client.messages.stream
1543
+ elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
1544
+ return "GOOGLE_API_CALL", client.models.generate_content, None
1293
1545
  raise ValueError(f"Unsupported client type: {type(client)}")
1294
1546
 
1295
1547
  def _format_input_data(client: ApiClient, **kwargs) -> dict:
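
The rewritten `wrap()` above patches `chat.completions.create` (and, for OpenAI clients, `responses.create`), Anthropic's `messages.create` / `messages.stream`, and Google GenAI's generate method, and now handles streaming responses. A hedged usage sketch for the OpenAI case, reusing the `judgment` Tracer from the earlier example; the model name and function are illustrative. Note the `stream_options={'include_usage': True}` argument, which the wrapper recommends so token counts are available for streamed calls.

```python
from openai import OpenAI
from judgeval.common.tracer import wrap

client = wrap(OpenAI())  # wrapped calls are only traced inside an active trace

@judgment.observe()
def ask(question: str) -> str:
    stream = client.chat.completions.create(
        model="gpt-4o-mini",   # illustrative model name
        messages=[{"role": "user", "content": question}],
        stream=True,
        # Without this, the wrapper warns that token counts are unavailable
        # for streamed OpenAI responses.
        stream_options={"include_usage": True},
    )
    parts = []
    for chunk in stream:
        # The final usage-only chunk has no choices, so guard before indexing.
        if chunk.choices and chunk.choices[0].delta.content:
            parts.append(chunk.choices[0].delta.content)
    return "".join(parts)
```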
@@ -1303,6 +1555,11 @@ def _format_input_data(client: ApiClient, **kwargs) -> dict:
1303
1555
  "model": kwargs.get("model"),
1304
1556
  "messages": kwargs.get("messages"),
1305
1557
  }
1558
+ elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
1559
+ return {
1560
+ "model": kwargs.get("model"),
1561
+ "contents": kwargs.get("contents")
1562
+ }
1306
1563
  # Anthropic requires additional max_tokens parameter
1307
1564
  return {
1308
1565
  "model": kwargs.get("model"),
@@ -1330,6 +1587,15 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
1330
1587
  "total_tokens": response.usage.total_tokens
1331
1588
  }
1332
1589
  }
1590
+ elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
1591
+ return {
1592
+ "content": response.candidates[0].content.parts[0].text,
1593
+ "usage": {
1594
+ "prompt_tokens": response.usage_metadata.prompt_token_count,
1595
+ "completion_tokens": response.usage_metadata.candidates_token_count,
1596
+ "total_tokens": response.usage_metadata.total_token_count
1597
+ }
1598
+ }
1333
1599
  # Anthropic has a different response structure
1334
1600
  return {
1335
1601
  "content": response.content[0].text,
@@ -1340,29 +1606,117 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
1340
1606
  }
1341
1607
  }
1342
1608
 
1343
- # Add a global context-preserving gather function
1344
- # async def trace_gather(*coroutines, return_exceptions=False): # REMOVED
1345
- # """ # REMOVED
1346
- # A wrapper around asyncio.gather that ensures the trace context # REMOVED
1347
- # is available within the gathered coroutines using contextvars.copy_context. # REMOVED
1348
- # """ # REMOVED
1349
- # # Get the original asyncio.gather (if we patched it) # REMOVED
1350
- # original_gather = getattr(asyncio, "_original_gather", asyncio.gather) # REMOVED
1351
- # # REMOVED
1352
- # # Use contextvars.copy_context() to ensure context propagation # REMOVED
1353
- # ctx = contextvars.copy_context() # REMOVED
1354
- # # REMOVED
1355
- # # Wrap the gather call within the copied context # REMOVED
1356
- # return await ctx.run(original_gather, *coroutines, return_exceptions=return_exceptions) # REMOVED
1357
-
1358
- # Store the original gather and apply the patch *once*
1359
- # global _original_gather_stored # REMOVED
1360
- # if not globals().get('_original_gather_stored'): # REMOVED
1361
- # # Check if asyncio.gather is already our wrapper to prevent double patching # REMOVED
1362
- # if asyncio.gather.__name__ != 'trace_gather': # REMOVED
1363
- # asyncio._original_gather = asyncio.gather # REMOVED
1364
- # asyncio.gather = trace_gather # REMOVED
1365
- # _original_gather_stored = True # REMOVED
1609
+ # Define a blocklist of functions that should not be traced
1610
+ # These are typically utility functions, print statements, logging, etc.
1611
+ _TRACE_BLOCKLIST = {
1612
+ # Built-in functions
1613
+ 'print', 'str', 'int', 'float', 'bool', 'list', 'dict', 'set', 'tuple',
1614
+ 'len', 'range', 'enumerate', 'zip', 'map', 'filter', 'sorted', 'reversed',
1615
+ 'min', 'max', 'sum', 'any', 'all', 'abs', 'round', 'format',
1616
+ # Logging functions
1617
+ 'debug', 'info', 'warning', 'error', 'critical', 'exception', 'log',
1618
+ # Common utility functions
1619
+ 'sleep', 'time', 'datetime', 'json', 'dumps', 'loads',
1620
+ # String operations
1621
+ 'join', 'split', 'strip', 'lstrip', 'rstrip', 'replace', 'lower', 'upper',
1622
+ # Dict operations
1623
+ 'get', 'items', 'keys', 'values', 'update',
1624
+ # List operations
1625
+ 'append', 'extend', 'insert', 'remove', 'pop', 'clear', 'index', 'count', 'sort',
1626
+ }
1627
+
1628
+
1629
+ # Add a new function for deep tracing at the module level
1630
+ def _create_deep_tracing_wrapper(func, tracer, span_type="span"):
1631
+ """
1632
+ Creates a wrapper for a function that automatically traces it when called within a traced function.
1633
+ This enables deep tracing without requiring explicit @observe decorators on every function.
1634
+
1635
+ Args:
1636
+ func: The function to wrap
1637
+ tracer: The Tracer instance
1638
+ span_type: Type of span (default "span")
1639
+
1640
+ Returns:
1641
+ A wrapped function that will be traced when called
1642
+ """
1643
+ # Skip wrapping if the function is not callable or is a built-in
1644
+ if not callable(func) or isinstance(func, type) or func.__module__ == 'builtins':
1645
+ return func
1646
+
1647
+ # Skip functions in the blocklist
1648
+ if func.__name__ in _TRACE_BLOCKLIST:
1649
+ return func
1650
+
1651
+ # Skip functions from certain modules (logging, sys, etc.)
1652
+ if func.__module__ and any(func.__module__.startswith(m) for m in ['logging', 'sys', 'os', 'json', 'time', 'datetime']):
1653
+ return func
1654
+
1655
+
1656
+ # Get function name for the span - check for custom name set by @observe
1657
+ func_name = getattr(func, '_judgment_span_name', func.__name__)
1658
+
1659
+ # Check for custom span_type set by @observe
1660
+ func_span_type = getattr(func, '_judgment_span_type', "span")
1661
+
1662
+ # Store original function to prevent losing reference
1663
+ original_func = func
1664
+
1665
+ # Create appropriate wrapper based on whether the function is async or not
1666
+ if asyncio.iscoroutinefunction(func):
1667
+ @functools.wraps(func)
1668
+ async def async_deep_wrapper(*args, **kwargs):
1669
+ # Get current trace from context
1670
+ current_trace = current_trace_var.get()
1671
+
1672
+ # If no trace context, just call the function
1673
+ if not current_trace:
1674
+ return await original_func(*args, **kwargs)
1675
+
1676
+ # Create a span for this function call - use custom span_type if available
1677
+ with current_trace.span(func_name, span_type=func_span_type) as span:
1678
+ # Record inputs
1679
+ span.record_input({
1680
+ 'args': str(args),
1681
+ 'kwargs': kwargs
1682
+ })
1683
+
1684
+ # Execute function
1685
+ result = await original_func(*args, **kwargs)
1686
+
1687
+ # Record output
1688
+ span.record_output(result)
1689
+
1690
+ return result
1691
+
1692
+ return async_deep_wrapper
1693
+ else:
1694
+ @functools.wraps(func)
1695
+ def deep_wrapper(*args, **kwargs):
1696
+ # Get current trace from context
1697
+ current_trace = current_trace_var.get()
1698
+
1699
+ # If no trace context, just call the function
1700
+ if not current_trace:
1701
+ return original_func(*args, **kwargs)
1702
+
1703
+ # Create a span for this function call - use custom span_type if available
1704
+ with current_trace.span(func_name, span_type=func_span_type) as span:
1705
+ # Record inputs
1706
+ span.record_input({
1707
+ 'args': str(args),
1708
+ 'kwargs': kwargs
1709
+ })
1710
+
1711
+ # Execute function
1712
+ result = original_func(*args, **kwargs)
1713
+
1714
+ # Record output
1715
+ span.record_output(result)
1716
+
1717
+ return result
1718
+
1719
+ return deep_wrapper
1366
1720
 
1367
1721
  # Add the new TraceThreadPoolExecutor class
1368
1722
  class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
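
`TraceThreadPoolExecutor`, shown at the end of this hunk and in the hunk below, overrides `submit()` to copy the caller's contextvars into the worker thread, so the active trace remains visible there. A hedged usage sketch; the `judgment` Tracer and the worker function are illustrative.

```python
from judgeval.common.tracer import TraceThreadPoolExecutor

def score_document(doc: str) -> int:
    # Runs in a worker thread; because submit() copied the caller's context,
    # current_trace_var (and current_span_var) are still set here, so wrapped
    # LLM clients or manual spans can attach to the active trace.
    return len(doc)

@judgment.observe()
def score_corpus(docs: list) -> list:
    with TraceThreadPoolExecutor(max_workers=4) as pool:
        futures = [pool.submit(score_document, d) for d in docs]
        return [f.result() for f in futures]
```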
@@ -1393,4 +1747,336 @@ class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
1393
1747
  return super().submit(ctx.run, func_with_bound_args)
1394
1748
 
1395
1749
  # Note: The `map` method would also need to be overridden for full context
1396
- # propagation if users rely on it, but `submit` is the most common use case.
1750
+ # propagation if users rely on it, but `submit` is the most common use case.
1751
+
1752
+ # Helper functions for stream processing
1753
+ # ---------------------------------------
1754
+
1755
+ def _extract_content_from_chunk(client: ApiClient, chunk: Any) -> Optional[str]:
1756
+ """Extracts the text content from a stream chunk based on the client type."""
1757
+ try:
1758
+ if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
1759
+ return chunk.choices[0].delta.content
1760
+ elif isinstance(client, (Anthropic, AsyncAnthropic)):
1761
+ # Anthropic streams various event types, we only care for content blocks
1762
+ if chunk.type == "content_block_delta":
1763
+ return chunk.delta.text
1764
+ elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
1765
+ # Google streams Candidate objects
1766
+ if chunk.candidates and chunk.candidates[0].content and chunk.candidates[0].content.parts:
1767
+ return chunk.candidates[0].content.parts[0].text
1768
+ except (AttributeError, IndexError, KeyError):
1769
+ # Handle cases where chunk structure is unexpected or doesn't contain content
1770
+ pass # Return None
1771
+ return None
1772
+
1773
+ def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[Dict[str, int]]:
1774
+ """Extracts usage data if present in the *final* chunk (client-specific)."""
1775
+ try:
1776
+ # OpenAI/Together include usage in the *last* chunk's `usage` attribute if available
1777
+ # This typically requires specific API versions or settings. Often usage is *not* streamed.
1778
+ if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
1779
+ # Check if usage is directly on the chunk (some models might do this)
1780
+ if hasattr(chunk, 'usage') and chunk.usage:
1781
+ return {
1782
+ "prompt_tokens": chunk.usage.prompt_tokens,
1783
+ "completion_tokens": chunk.usage.completion_tokens,
1784
+ "total_tokens": chunk.usage.total_tokens
1785
+ }
1786
+ # Check if usage is nested within choices (less common for final chunk, but check)
1787
+ elif chunk.choices and hasattr(chunk.choices[0], 'usage') and chunk.choices[0].usage:
1788
+ usage = chunk.choices[0].usage
1789
+ return {
1790
+ "prompt_tokens": usage.prompt_tokens,
1791
+ "completion_tokens": usage.completion_tokens,
1792
+ "total_tokens": usage.total_tokens
1793
+ }
1794
+ # Anthropic includes usage in the 'message_stop' event type
1795
+ elif isinstance(client, (Anthropic, AsyncAnthropic)):
1796
+ if chunk.type == "message_stop":
1797
+ # Anthropic final usage is often attached to the *message* object, not the chunk directly
1798
+ # The API might provide a way to get the final message object, but typically not in the stream itself.
1799
+ # Let's assume for now usage might appear in the final *chunk* metadata if supported.
1800
+ # This is a placeholder - Anthropic usage typically needs a separate call or context.
1801
+ pass
1802
+ elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
1803
+ # Google provides usage metadata on the full response object, not typically streamed per chunk.
1804
+ # It might be in the *last* chunk's usage_metadata if the stream implementation supports it.
1805
+ if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
1806
+ return {
1807
+ "prompt_tokens": chunk.usage_metadata.prompt_token_count,
1808
+ "completion_tokens": chunk.usage_metadata.candidates_token_count,
1809
+ "total_tokens": chunk.usage_metadata.total_token_count
1810
+ }
1811
+
1812
+ except (AttributeError, IndexError, KeyError, TypeError):
1813
+ # Handle cases where usage data is missing or malformed
1814
+ pass # Return None
1815
+ return None
1816
+
1817
+
1818
+ # --- Sync Stream Wrapper ---
1819
+ def _sync_stream_wrapper(
1820
+ original_stream: Iterator,
1821
+ client: ApiClient,
1822
+ output_entry: TraceEntry
1823
+ ) -> Generator[Any, None, None]:
1824
+ """Wraps a synchronous stream iterator to capture content and update the trace."""
1825
+ content_parts = [] # Use a list instead of string concatenation
1826
+ final_usage = None
1827
+ last_chunk = None
1828
+ try:
1829
+ for chunk in original_stream:
1830
+ content_part = _extract_content_from_chunk(client, chunk)
1831
+ if content_part:
1832
+ content_parts.append(content_part) # Append to list instead of concatenating
1833
+ last_chunk = chunk # Keep track of the last chunk for potential usage data
1834
+ yield chunk # Pass the chunk to the caller
1835
+ finally:
1836
+ # Attempt to extract usage from the last chunk received
1837
+ if last_chunk:
1838
+ final_usage = _extract_usage_from_final_chunk(client, last_chunk)
1839
+
1840
+ # Update the trace entry with the accumulated content and usage
1841
+ output_entry.output = {
1842
+ "content": "".join(content_parts), # Join list at the end
1843
+ "usage": final_usage if final_usage else {"info": "Usage data not available in stream."}, # Provide placeholder if None
1844
+ "streamed": True
1845
+ }
1846
+ # Note: We might need to adjust _serialize_output if this dict causes issues,
1847
+ # but Pydantic's model_dump should handle dicts.
1848
+
1849
+ # --- Async Stream Wrapper ---
1850
+ async def _async_stream_wrapper(
1851
+ original_stream: AsyncIterator,
1852
+ client: ApiClient,
1853
+ output_entry: TraceEntry
1854
+ ) -> AsyncGenerator[Any, None]:
1855
+ # [Existing logic - unchanged]
1856
+ content_parts = [] # Use a list instead of string concatenation
1857
+ final_usage_data = None
1858
+ last_content_chunk = None
1859
+ anthropic_input_tokens = 0
1860
+ anthropic_output_tokens = 0
1861
+
1862
+ target_span_id = getattr(output_entry, 'span_id', 'UNKNOWN')
1863
+
1864
+ try:
1865
+ async for chunk in original_stream:
1866
+ # Check for OpenAI's final usage chunk
1867
+ if isinstance(client, (AsyncOpenAI, OpenAI)) and hasattr(chunk, 'usage') and chunk.usage is not None:
1868
+ final_usage_data = {
1869
+ "prompt_tokens": chunk.usage.prompt_tokens,
1870
+ "completion_tokens": chunk.usage.completion_tokens,
1871
+ "total_tokens": chunk.usage.total_tokens
1872
+ }
1873
+ yield chunk
1874
+ continue
1875
+
1876
+ if isinstance(client, (AsyncAnthropic, Anthropic)) and hasattr(chunk, 'type'):
1877
+ if chunk.type == "message_start":
1878
+ if hasattr(chunk, 'message') and hasattr(chunk.message, 'usage') and hasattr(chunk.message.usage, 'input_tokens'):
1879
+ anthropic_input_tokens = chunk.message.usage.input_tokens
1880
+ elif chunk.type == "message_delta":
1881
+ if hasattr(chunk, 'usage') and hasattr(chunk.usage, 'output_tokens'):
1882
+ anthropic_output_tokens += chunk.usage.output_tokens
1883
+
1884
+ content_part = _extract_content_from_chunk(client, chunk)
1885
+ if content_part:
1886
+ content_parts.append(content_part) # Append to list instead of concatenating
1887
+ last_content_chunk = chunk
1888
+
1889
+ yield chunk
1890
+ finally:
1891
+ anthropic_final_usage = None
1892
+ if isinstance(client, (AsyncAnthropic, Anthropic)) and (anthropic_input_tokens > 0 or anthropic_output_tokens > 0):
1893
+ anthropic_final_usage = {
1894
+ "input_tokens": anthropic_input_tokens,
1895
+ "output_tokens": anthropic_output_tokens,
1896
+ "total_tokens": anthropic_input_tokens + anthropic_output_tokens
1897
+ }
1898
+
1899
+ usage_info = None
1900
+ if final_usage_data:
1901
+ usage_info = final_usage_data
1902
+ elif anthropic_final_usage:
1903
+ usage_info = anthropic_final_usage
1904
+ elif last_content_chunk:
1905
+ usage_info = _extract_usage_from_final_chunk(client, last_content_chunk)
1906
+
1907
+ if output_entry and hasattr(output_entry, 'output'):
1908
+ output_entry.output = {
1909
+ "content": "".join(content_parts), # Join list at the end
1910
+ "usage": usage_info if usage_info else {"info": "Usage data not available in stream."},
1911
+ "streamed": True
1912
+ }
1913
+ start_ts = getattr(output_entry, 'created_at', time.time())
1914
+ output_entry.duration = time.time() - start_ts
1915
+ # else: # Handle error case if necessary, but remove debug print
1916
+
1917
+ # --- Define Context Manager Wrapper Classes ---
+ class _TracedAsyncStreamManagerWrapper(AbstractAsyncContextManager):
+     """Wraps an original async stream manager to add tracing."""
+     def __init__(self, original_manager, client, span_name, trace_client, stream_wrapper_func, input_kwargs):
+         self._original_manager = original_manager
+         self._client = client
+         self._span_name = span_name
+         self._trace_client = trace_client
+         self._stream_wrapper_func = stream_wrapper_func
+         self._input_kwargs = input_kwargs
+         self._parent_span_id_at_entry = None
+
+     async def __aenter__(self):
+         self._parent_span_id_at_entry = current_span_var.get()
+         if not self._trace_client:
+             # No active trace: delegate straight to the original manager
+             return await self._original_manager.__aenter__()
+
+         # --- Manually create the 'enter' entry ---
+         start_time = time.time()
+         span_id = str(uuid.uuid4())
+         current_depth = 0
+         if self._parent_span_id_at_entry and self._parent_span_id_at_entry in self._trace_client._span_depths:
+             current_depth = self._trace_client._span_depths[self._parent_span_id_at_entry] + 1
+         self._trace_client._span_depths[span_id] = current_depth
+         enter_entry = TraceEntry(
+             type="enter", function=self._span_name, span_id=span_id,
+             trace_id=self._trace_client.trace_id, depth=current_depth, message=self._span_name,
+             created_at=start_time, span_type="llm", parent_span_id=self._parent_span_id_at_entry
+         )
+         self._trace_client.add_entry(enter_entry)
+
+         # Set the current span ID in contextvars
+         self._span_context_token = current_span_var.set(span_id)
+
+         # Manually create the 'input' entry
+         input_data = _format_input_data(self._client, **self._input_kwargs)
+         input_entry = TraceEntry(
+             type="input", function=self._span_name, span_id=span_id,
+             trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Inputs to {self._span_name}",
+             created_at=time.time(), inputs=input_data, span_type="llm"
+         )
+         self._trace_client.add_entry(input_entry)
+
+         # Call the original __aenter__ to obtain the raw stream iterator
+         raw_iterator = await self._original_manager.__aenter__()
+
+         # Manually create a pending 'output' entry; the stream wrapper
+         # fills it in once the stream is exhausted
+         output_entry = TraceEntry(
+             type="output", function=self._span_name, span_id=span_id,
+             trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Output from {self._span_name}",
+             created_at=time.time(), output="<pending stream>", span_type="llm"
+         )
+         self._trace_client.add_entry(output_entry)
+
+         # Wrap the raw iterator so content and usage are captured
+         wrapped_iterator = self._stream_wrapper_func(raw_iterator, self._client, output_entry)
+         return wrapped_iterator
+
+     async def __aexit__(self, exc_type, exc_val, exc_tb):
+         # Manually create the 'exit' entry
+         if hasattr(self, '_span_context_token'):
+             span_id = current_span_var.get()
+             start_time_for_duration = 0
+             for entry in reversed(self._trace_client.entries):
+                 if entry.span_id == span_id and entry.type == 'enter':
+                     start_time_for_duration = entry.created_at
+                     break
+             duration = time.time() - start_time_for_duration if start_time_for_duration else None
+             exit_depth = self._trace_client._span_depths.get(span_id, 0)
+             exit_entry = TraceEntry(
+                 type="exit", function=self._span_name, span_id=span_id,
+                 trace_id=self._trace_client.trace_id, depth=exit_depth, message=f"← {self._span_name}",
+                 created_at=time.time(), duration=duration, span_type="llm"
+             )
+             self._trace_client.add_entry(exit_entry)
+             if span_id in self._trace_client._span_depths:
+                 del self._trace_client._span_depths[span_id]
+             current_span_var.reset(self._span_context_token)
+             delattr(self, '_span_context_token')
+
+         # Delegate __aexit__ to the original manager
+         if hasattr(self._original_manager, "__aexit__"):
+             return await self._original_manager.__aexit__(exc_type, exc_val, exc_tb)
+         return None
+
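This wrapper exists so the provider's `async with client.messages.stream(...) as stream:` calling convention keeps working once the client is traced, while `__aenter__` swaps in the traced iterator. A reduced sketch of the same delegation pattern, with the trace bookkeeping replaced by prints — all names here are illustrative, not judgeval's API:

    import asyncio
    import time
    from contextlib import AbstractAsyncContextManager, asynccontextmanager
    from typing import Any, AsyncGenerator, AsyncIterator

    class LoggingStreamManager(AbstractAsyncContextManager):
        """Delegating wrapper in the style of _TracedAsyncStreamManagerWrapper."""
        def __init__(self, original_manager, name: str):
            self._original_manager = original_manager
            self._name = name
            self._start = 0.0

        async def __aenter__(self) -> AsyncGenerator[Any, None]:
            self._start = time.time()
            print(f"enter {self._name}")
            raw = await self._original_manager.__aenter__()
            return self._wrap(raw)  # caller receives the wrapped iterator

        async def _wrap(self, raw: AsyncIterator[Any]) -> AsyncGenerator[Any, None]:
            async for chunk in raw:
                yield chunk

        async def __aexit__(self, exc_type, exc_val, exc_tb):
            print(f"exit {self._name} after {time.time() - self._start:.3f}s")
            return await self._original_manager.__aexit__(exc_type, exc_val, exc_tb)

    @asynccontextmanager
    async def toy_stream():
        async def gen():
            for i in range(3):
                yield i
        yield gen()

    async def main():
        async with LoggingStreamManager(toy_stream(), "llm-call") as stream:
            async for item in stream:
                print("got", item)

    asyncio.run(main())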
+
2003
+ class _TracedSyncStreamManagerWrapper(AbstractContextManager):
2004
+ """Wraps an original sync stream manager to add tracing."""
2005
+ def __init__(self, original_manager, client, span_name, trace_client, stream_wrapper_func, input_kwargs):
2006
+ self._original_manager = original_manager
2007
+ self._client = client
2008
+ self._span_name = span_name
2009
+ self._trace_client = trace_client
2010
+ self._stream_wrapper_func = stream_wrapper_func
2011
+ self._input_kwargs = input_kwargs
2012
+ self._parent_span_id_at_entry = None
2013
+
2014
+ def __enter__(self):
2015
+ self._parent_span_id_at_entry = current_span_var.get()
2016
+ if not self._trace_client:
2017
+ return self._original_manager.__enter__()
2018
+
2019
+ # Manually create 'enter' entry
2020
+ start_time = time.time()
2021
+ span_id = str(uuid.uuid4())
2022
+ current_depth = 0
2023
+ if self._parent_span_id_at_entry and self._parent_span_id_at_entry in self._trace_client._span_depths:
2024
+ current_depth = self._trace_client._span_depths[self._parent_span_id_at_entry] + 1
2025
+ self._trace_client._span_depths[span_id] = current_depth
2026
+ enter_entry = TraceEntry(
2027
+ type="enter", function=self._span_name, span_id=span_id,
2028
+ trace_id=self._trace_client.trace_id, depth=current_depth, message=self._span_name,
2029
+ created_at=start_time, span_type="llm", parent_span_id=self._parent_span_id_at_entry
2030
+ )
2031
+ self._trace_client.add_entry(enter_entry)
2032
+ self._span_context_token = current_span_var.set(span_id)
2033
+
2034
+ # Manually create 'input' entry
2035
+ input_data = _format_input_data(self._client, **self._input_kwargs)
2036
+ input_entry = TraceEntry(
2037
+ type="input", function=self._span_name, span_id=span_id,
2038
+ trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Inputs to {self._span_name}",
2039
+ created_at=time.time(), inputs=input_data, span_type="llm"
2040
+ )
2041
+ self._trace_client.add_entry(input_entry)
2042
+
2043
+ # Call original __enter__
2044
+ raw_iterator = self._original_manager.__enter__()
2045
+
2046
+ # Manually create 'output' entry (pending)
2047
+ output_entry = TraceEntry(
2048
+ type="output", function=self._span_name, span_id=span_id,
2049
+ trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Output from {self._span_name}",
2050
+ created_at=time.time(), output="<pending stream>", span_type="llm"
2051
+ )
2052
+ self._trace_client.add_entry(output_entry)
2053
+
2054
+ # Wrap the raw iterator
2055
+ wrapped_iterator = self._stream_wrapper_func(raw_iterator, self._client, output_entry)
2056
+ return wrapped_iterator
2057
+
2058
+ def __exit__(self, exc_type, exc_val, exc_tb):
2059
+ # Manually create 'exit' entry
2060
+ if hasattr(self, '_span_context_token'):
2061
+ span_id = current_span_var.get()
2062
+ start_time_for_duration = 0
2063
+ for entry in reversed(self._trace_client.entries):
2064
+ if entry.span_id == span_id and entry.type == 'enter':
2065
+ start_time_for_duration = entry.created_at
2066
+ break
2067
+ duration = time.time() - start_time_for_duration if start_time_for_duration else None
2068
+ exit_depth = self._trace_client._span_depths.get(span_id, 0)
2069
+ exit_entry = TraceEntry(
2070
+ type="exit", function=self._span_name, span_id=span_id,
2071
+ trace_id=self._trace_client.trace_id, depth=exit_depth, message=f"← {self._span_name}",
2072
+ created_at=time.time(), duration=duration, span_type="llm"
2073
+ )
2074
+ self._trace_client.add_entry(exit_entry)
2075
+ if span_id in self._trace_client._span_depths: del self._trace_client._span_depths[span_id]
2076
+ current_span_var.reset(self._span_context_token)
2077
+ delattr(self, '_span_context_token')
2078
+
2079
+ # Delegate __exit__
2080
+ if hasattr(self._original_manager, "__exit__"):
2081
+ return self._original_manager.__exit__(exc_type, exc_val, exc_tb)
2082
+ return None
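Both wrapper classes rely on the same `contextvars` discipline: capture the parent span before overwriting the variable, then `reset()` with the saved token on exit so nesting unwinds correctly even across interleaved tasks. Reduced to its essentials — the variable and helper names below are illustrative, not judgeval's:

    import contextvars
    import uuid

    current_span = contextvars.ContextVar("current_span", default=None)

    def open_span():
        span_id = str(uuid.uuid4())
        parent = current_span.get()        # parent captured before overwrite
        token = current_span.set(span_id)  # children now see this span
        print(f"enter {span_id[:8]} (parent={parent and parent[:8]})")
        return span_id, token

    def close_span(span_id, token):
        current_span.reset(token)          # restore the parent span
        print(f"exit  {span_id[:8]} -> back to {current_span.get()}")

    outer_id, outer_tok = open_span()
    inner_id, inner_tok = open_span()      # nested span sees outer as parent
    close_span(inner_id, inner_tok)
    close_span(outer_id, outer_tok)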