judgeval 0.0.32__py3-none-any.whl → 0.0.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. judgeval/common/s3_storage.py +93 -0
  2. judgeval/common/tracer.py +612 -123
  3. judgeval/data/sequence.py +4 -10
  4. judgeval/judgment_client.py +25 -86
  5. judgeval/rules.py +4 -7
  6. judgeval/run_evaluation.py +1 -1
  7. judgeval/scorers/__init__.py +4 -4
  8. judgeval/scorers/judgeval_scorers/__init__.py +0 -176
  9. {judgeval-0.0.32.dist-info → judgeval-0.0.33.dist-info}/METADATA +15 -2
  10. judgeval-0.0.33.dist-info/RECORD +63 -0
  11. judgeval/scorers/base_scorer.py +0 -58
  12. judgeval/scorers/judgeval_scorers/local_implementations/__init__.py +0 -27
  13. judgeval/scorers/judgeval_scorers/local_implementations/answer_correctness/__init__.py +0 -4
  14. judgeval/scorers/judgeval_scorers/local_implementations/answer_correctness/answer_correctness_scorer.py +0 -276
  15. judgeval/scorers/judgeval_scorers/local_implementations/answer_correctness/prompts.py +0 -169
  16. judgeval/scorers/judgeval_scorers/local_implementations/answer_relevancy/__init__.py +0 -4
  17. judgeval/scorers/judgeval_scorers/local_implementations/answer_relevancy/answer_relevancy_scorer.py +0 -298
  18. judgeval/scorers/judgeval_scorers/local_implementations/answer_relevancy/prompts.py +0 -174
  19. judgeval/scorers/judgeval_scorers/local_implementations/comparison/__init__.py +0 -0
  20. judgeval/scorers/judgeval_scorers/local_implementations/comparison/comparison_scorer.py +0 -161
  21. judgeval/scorers/judgeval_scorers/local_implementations/comparison/prompts.py +0 -222
  22. judgeval/scorers/judgeval_scorers/local_implementations/contextual_precision/__init__.py +0 -3
  23. judgeval/scorers/judgeval_scorers/local_implementations/contextual_precision/contextual_precision_scorer.py +0 -264
  24. judgeval/scorers/judgeval_scorers/local_implementations/contextual_precision/prompts.py +0 -106
  25. judgeval/scorers/judgeval_scorers/local_implementations/contextual_recall/__init__.py +0 -3
  26. judgeval/scorers/judgeval_scorers/local_implementations/contextual_recall/contextual_recall_scorer.py +0 -254
  27. judgeval/scorers/judgeval_scorers/local_implementations/contextual_recall/prompts.py +0 -142
  28. judgeval/scorers/judgeval_scorers/local_implementations/contextual_relevancy/__init__.py +0 -3
  29. judgeval/scorers/judgeval_scorers/local_implementations/contextual_relevancy/contextual_relevancy_scorer.py +0 -245
  30. judgeval/scorers/judgeval_scorers/local_implementations/contextual_relevancy/prompts.py +0 -121
  31. judgeval/scorers/judgeval_scorers/local_implementations/execution_order/__init__.py +0 -3
  32. judgeval/scorers/judgeval_scorers/local_implementations/execution_order/execution_order.py +0 -156
  33. judgeval/scorers/judgeval_scorers/local_implementations/faithfulness/__init__.py +0 -3
  34. judgeval/scorers/judgeval_scorers/local_implementations/faithfulness/faithfulness_scorer.py +0 -318
  35. judgeval/scorers/judgeval_scorers/local_implementations/faithfulness/prompts.py +0 -268
  36. judgeval/scorers/judgeval_scorers/local_implementations/hallucination/__init__.py +0 -3
  37. judgeval/scorers/judgeval_scorers/local_implementations/hallucination/hallucination_scorer.py +0 -264
  38. judgeval/scorers/judgeval_scorers/local_implementations/hallucination/prompts.py +0 -104
  39. judgeval/scorers/judgeval_scorers/local_implementations/instruction_adherence/instruction_adherence.py +0 -232
  40. judgeval/scorers/judgeval_scorers/local_implementations/instruction_adherence/prompt.py +0 -102
  41. judgeval/scorers/judgeval_scorers/local_implementations/json_correctness/__init__.py +0 -5
  42. judgeval/scorers/judgeval_scorers/local_implementations/json_correctness/json_correctness_scorer.py +0 -134
  43. judgeval/scorers/judgeval_scorers/local_implementations/summarization/__init__.py +0 -3
  44. judgeval/scorers/judgeval_scorers/local_implementations/summarization/prompts.py +0 -247
  45. judgeval/scorers/judgeval_scorers/local_implementations/summarization/summarization_scorer.py +0 -551
  46. judgeval-0.0.32.dist-info/RECORD +0 -97
  47. {judgeval-0.0.32.dist-info → judgeval-0.0.33.dist-info}/WHEEL +0 -0
  48. {judgeval-0.0.32.dist-info → judgeval-0.0.33.dist-info}/licenses/LICENSE.md +0 -0
judgeval/common/tracer.py CHANGED
@@ -12,12 +12,27 @@ import uuid
 import warnings
 import contextvars
 import sys
-from contextlib import contextmanager
+from contextlib import contextmanager, asynccontextmanager, AbstractAsyncContextManager, AbstractContextManager # Import context manager bases
 from dataclasses import dataclass, field
 from datetime import datetime
 from http import HTTPStatus
-from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, TypeAlias, Union, Callable, Awaitable, Set
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    AsyncGenerator,
+    TypeAlias,
+)
 from rich import print as rprint
+import types # <--- Add this import
 
 # Third-party imports
 import pika
@@ -42,13 +57,14 @@ from judgeval.constants import (
 )
 from judgeval.judgment_client import JudgmentClient
 from judgeval.data import Example
-from judgeval.scorers import APIJudgmentScorer, JudgevalScorer, ScorerWrapper
+from judgeval.scorers import APIJudgmentScorer, JudgevalScorer
 from judgeval.rules import Rule
 from judgeval.evaluation_run import EvaluationRun
 from judgeval.data.result import ScoringResult
 
 # Standard library imports needed for the new class
 import concurrent.futures
+from collections.abc import Iterator, AsyncIterator # Add Iterator and AsyncIterator
 
 # Define context variables for tracking the current trace and the current span within a trace
 current_trace_var = contextvars.ContextVar('current_trace', default=None)
@@ -173,7 +189,7 @@ class TraceEntry:
             "inputs": self._serialize_inputs(),
             "evaluation_runs": [evaluation_run.model_dump() for evaluation_run in self.evaluation_runs] if self.evaluation_runs else [],
             "span_type": self.span_type,
-            "parent_span_id": self.parent_span_id
+            "parent_span_id": self.parent_span_id,
         }
 
     def _serialize_output(self) -> Any:
@@ -188,6 +204,15 @@ class TraceEntry:
         if isinstance(self.output, BaseModel):
             return self.output.model_dump()
 
+        # NEW check: If output is the dict structure from our stream wrapper
+        if isinstance(self.output, dict) and 'streamed' in self.output:
+            # Assume it's already JSON-serializable (content is string, usage is dict or None)
+            return self.output
+        # NEW check: If output is the placeholder string before stream completes
+        elif self.output == "<pending stream>":
+            # Represent this state clearly in the serialized data
+            return {"status": "pending stream"}
+
         try:
             # Try to serialize the output to verify it's JSON compatible
             json.dumps(self.output)
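Note on the two new checks above: they give streamed outputs a stable serialized shape. A minimal illustration of the intended behavior (not part of the diff; the entry values are made up):

    entry.output = "<pending stream>"
    entry._serialize_output()   # -> {"status": "pending stream"}

    entry.output = {"streamed": True, "content": "Hello", "usage": None}
    entry._serialize_output()   # -> dict returned as-is (assumed already JSON-serializable)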
@@ -206,9 +231,10 @@ class TraceManagerClient:
     - Saving a trace
     - Deleting a trace
     """
-    def __init__(self, judgment_api_key: str, organization_id: str):
+    def __init__(self, judgment_api_key: str, organization_id: str, tracer: Optional["Tracer"] = None):
         self.judgment_api_key = judgment_api_key
         self.organization_id = organization_id
+        self.tracer = tracer
 
     def fetch_trace(self, trace_id: str):
         """
@@ -236,12 +262,13 @@
 
     def save_trace(self, trace_data: dict):
         """
-        Saves a trace to the database
+        Saves a trace to the Judgment Supabase and optionally to S3 if configured.
 
         Args:
             trace_data: The trace data to save
         NOTE we save empty traces in order to properly handle async operations; we need something in the DB to associate the async results with
         """
+        # Save to Judgment API
         response = requests.post(
             JUDGMENT_TRACES_SAVE_API_URL,
             json=trace_data,
@@ -258,6 +285,18 @@
         elif response.status_code != HTTPStatus.OK:
             raise ValueError(f"Failed to save trace data: {response.text}")
 
+        # If S3 storage is enabled, save to S3 as well
+        if self.tracer and self.tracer.use_s3:
+            try:
+                s3_key = self.tracer.s3_storage.save_trace(
+                    trace_data=trace_data,
+                    trace_id=trace_data["trace_id"],
+                    project_name=trace_data["project_name"]
+                )
+                print(f"Trace also saved to S3 at key: {s3_key}")
+            except Exception as e:
+                warnings.warn(f"Failed to save trace to S3: {str(e)}")
+
         if "ui_results_url" in response.json():
             pretty_str = f"\n🔍 You can view your trace data here: [rgb(106,0,255)][link={response.json()['ui_results_url']}]View Trace[/link]\n"
             rprint(pretty_str)
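The s3_storage helper called above lives in the new judgeval/common/s3_storage.py (+93 lines, not shown in this diff). For orientation only, a hypothetical sketch of what such a helper could look like, inferred purely from the constructor arguments and the save_trace(trace_data=..., trace_id=..., project_name=...) call visible here; the key layout and the boto3 usage are assumptions, not the package's actual implementation:

    import json
    from datetime import datetime, timezone
    from typing import Optional

    import boto3  # assumed dependency

    class S3Storage:
        def __init__(self, bucket_name: str,
                     aws_access_key_id: Optional[str] = None,
                     aws_secret_access_key: Optional[str] = None,
                     region_name: Optional[str] = None):
            self.bucket_name = bucket_name
            # boto3 falls back to environment/instance credentials when the
            # explicit keys are None, which matches the Tracer validation below.
            self.client = boto3.client(
                "s3",
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                region_name=region_name,
            )

        def save_trace(self, trace_data: dict, trace_id: str, project_name: str) -> str:
            # Hypothetical key layout; only the argument names come from the diff.
            timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
            s3_key = f"traces/{project_name}/{trace_id}_{timestamp}.json"
            self.client.put_object(
                Bucket=self.bucket_name,
                Key=s3_key,
                Body=json.dumps(trace_data).encode("utf-8"),
                ContentType="application/json",
            )
            return s3_key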
@@ -355,7 +394,7 @@
         self.client: JudgmentClient = tracer.client
         self.entries: List[TraceEntry] = []
         self.start_time = time.time()
-        self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id)
+        self.trace_manager_client = TraceManagerClient(tracer.api_key, tracer.organization_id, tracer)
         self.visited_nodes = []
         self.executed_tools = []
         self.executed_node_tools = []
@@ -393,13 +432,13 @@
         entry = TraceEntry(
             type="enter",
             function=name,
-            span_id=span_id, # Use the generated span_id
-            trace_id=self.trace_id, # Use the trace_id from the trace client
+            span_id=span_id,
+            trace_id=self.trace_id,
             depth=current_depth,
             message=name,
             created_at=start_time,
             span_type=span_type,
-            parent_span_id=parent_span_id # Use the parent_id from context var
+            parent_span_id=parent_span_id,
         )
         self.add_entry(entry)
 
@@ -417,7 +456,7 @@
             message=f"← {name}",
             created_at=time.time(),
             duration=duration,
-            span_type=span_type
+            span_type=span_type,
         ))
         # Clean up depth tracking for this span_id
         if span_id in self._span_depths:
@@ -454,47 +493,14 @@
             additional_metadata=additional_metadata,
             trace_id=self.trace_id
         )
-        loaded_rules = None
-        if self.rules:
-            loaded_rules = []
-            for rule in self.rules:
-                processed_conditions = []
-                for condition in rule.conditions:
-                    # Convert metric if it's a ScorerWrapper
-                    try:
-                        if isinstance(condition.metric, ScorerWrapper):
-                            condition_copy = condition.model_copy()
-                            condition_copy.metric = condition.metric.load_implementation(use_judgment=True)
-                            processed_conditions.append(condition_copy)
-                        else:
-                            processed_conditions.append(condition)
-                    except Exception as e:
-                        warnings.warn(f"Failed to convert ScorerWrapper in rule '{rule.name}', condition metric '{condition.metric_name}': {str(e)}")
-                        processed_conditions.append(condition) # Keep original condition as fallback
-
-                # Create new rule with processed conditions
-                new_rule = rule.model_copy()
-                new_rule.conditions = processed_conditions
-                loaded_rules.append(new_rule)
         try:
             # Load appropriate implementations for all scorers
-            loaded_scorers: List[Union[JudgevalScorer, APIJudgmentScorer]] = []
-            for scorer in scorers:
-                try:
-                    if isinstance(scorer, ScorerWrapper):
-                        loaded_scorers.append(scorer.load_implementation(use_judgment=True))
-                    else:
-                        loaded_scorers.append(scorer)
-                except Exception as e:
-                    warnings.warn(f"Failed to load implementation for scorer {scorer}: {str(e)}")
-                    # Skip this scorer
-
-            if not loaded_scorers:
+            if not scorers:
                 warnings.warn("No valid scorers available for evaluation")
                 return
 
             # Prevent using JudgevalScorer with rules - only APIJudgmentScorer allowed with rules
-            if loaded_rules and any(isinstance(scorer, JudgevalScorer) for scorer in loaded_scorers):
+            if self.rules and any(isinstance(scorer, JudgevalScorer) for scorer in scorers):
                 raise ValueError("Cannot use Judgeval scorers, you can only use API scorers when using rules. Please either remove rules or use only APIJudgmentScorer types.")
 
         except Exception as e:
@@ -508,15 +514,15 @@
             project_name=self.project_name,
             eval_name=f"{self.name.capitalize()}-"
                       f"{current_span_var.get()}-"
-                      f"[{','.join(scorer.score_type.capitalize() for scorer in loaded_scorers)}]",
+                      f"[{','.join(scorer.score_type.capitalize() for scorer in scorers)}]",
             examples=[example],
-            scorers=loaded_scorers,
+            scorers=scorers,
             model=model,
             metadata={},
             judgment_api_key=self.tracer.api_key,
             override=self.overwrite,
             trace_span_id=current_span_var.get(),
-            rules=loaded_rules # Use the combined rules
+            rules=self.rules # Use the combined rules
         )
 
         self.add_eval_run(eval_run, start_time) # Pass start_time to record_evaluation
@@ -574,7 +580,7 @@
             message=f"Inputs to {function_name}",
             created_at=time.time(),
             inputs=inputs,
-            span_type=entry_span_type
+            span_type=entry_span_type,
         ))
 
     async def _update_coroutine_output(self, entry: TraceEntry, coroutine: Any):
@@ -607,12 +613,15 @@
             message=f"Output from {function_name}",
             created_at=time.time(),
             output="<pending>" if inspect.iscoroutine(output) else output,
-            span_type=entry_span_type
+            span_type=entry_span_type,
         )
         self.add_entry(entry)
 
         if inspect.iscoroutine(output):
             asyncio.create_task(self._update_coroutine_output(entry, output))
+
+        # Return the created entry
+        return entry
 
     def add_entry(self, entry: TraceEntry):
         """Add a trace entry to this trace context"""
@@ -824,8 +833,10 @@
         total_completion_tokens_cost = 0.0
         total_cost = 0.0
 
+        # Only count tokens for actual LLM API call spans
+        llm_span_names = {"OPENAI_API_CALL", "TOGETHER_API_CALL", "ANTHROPIC_API_CALL", "GOOGLE_API_CALL"}
         for entry in condensed_entries:
-            if entry.get("span_type") == "llm" and isinstance(entry.get("output"), dict):
+            if entry.get("span_type") == "llm" and entry.get("function") in llm_span_names and isinstance(entry.get("output"), dict):
                 output = entry["output"]
                 usage = output.get("usage", {})
                 model_name = entry.get("inputs", {}).get("model", "")
@@ -921,6 +932,12 @@
         organization_id: str = os.getenv("JUDGMENT_ORG_ID"),
         enable_monitoring: bool = os.getenv("JUDGMENT_MONITORING", "true").lower() == "true",
         enable_evaluations: bool = os.getenv("JUDGMENT_EVALUATIONS", "true").lower() == "true",
+        # S3 configuration
+        use_s3: bool = False,
+        s3_bucket_name: Optional[str] = None,
+        s3_aws_access_key_id: Optional[str] = None,
+        s3_aws_secret_access_key: Optional[str] = None,
+        s3_region_name: Optional[str] = None,
         deep_tracing: bool = True # NEW: Enable deep tracing by default
     ):
         if not hasattr(self, 'initialized'):
@@ -929,6 +946,13 @@
 
             if not organization_id:
                 raise ValueError("Tracer must be configured with an Organization ID")
+            if use_s3 and not s3_bucket_name:
+                raise ValueError("S3 bucket name must be provided when use_s3 is True")
+            if use_s3 and not (s3_aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")):
+                raise ValueError("AWS Access Key ID must be provided when use_s3 is True")
+            if use_s3 and not (s3_aws_secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY")):
+                raise ValueError("AWS Secret Access Key must be provided when use_s3 is True")
+
             self.api_key: str = api_key
             self.project_name: str = project_name
             self.client: JudgmentClient = JudgmentClient(judgment_api_key=api_key)
@@ -938,7 +962,19 @@
             self.initialized: bool = True
             self.enable_monitoring: bool = enable_monitoring
             self.enable_evaluations: bool = enable_evaluations
+
+            # Initialize S3 storage if enabled
+            self.use_s3 = use_s3
+            if use_s3:
+                from judgeval.common.s3_storage import S3Storage
+                self.s3_storage = S3Storage(
+                    bucket_name=s3_bucket_name,
+                    aws_access_key_id=s3_aws_access_key_id,
+                    aws_secret_access_key=s3_aws_secret_access_key,
+                    region_name=s3_region_name
+                )
             self.deep_tracing: bool = deep_tracing # NEW: Store deep tracing setting
+
         elif hasattr(self, 'project_name') and self.project_name != project_name:
             warnings.warn(
                 f"Attempting to initialize Tracer with project_name='{project_name}' but it was already initialized with "
@@ -1320,100 +1356,192 @@
 def wrap(client: Any) -> Any:
     """
     Wraps an API client to add tracing capabilities.
-    Supports OpenAI, Together, and Anthropic clients.
+    Supports OpenAI, Together, Anthropic, and Google GenAI clients.
+    Patches both '.create' and Anthropic's '.stream' methods using a wrapper class.
     """
-    # Get the appropriate configuration for this client type
-    span_name, original_create = _get_client_config(client)
+    span_name, original_create, original_stream = _get_client_config(client)
 
-    # Handle async clients differently than synchronous clients (need an async function for async clients)
-    if (isinstance(client, (AsyncOpenAI, AsyncAnthropic, AsyncTogether, genai.client.AsyncClient))):
-        async def traced_create(*args, **kwargs):
-            # Get the current trace from contextvars
-            current_trace = current_trace_var.get()
-
-            # Skip tracing if no active trace
-            if not current_trace:
-                return original_create(*args, **kwargs)
+    # --- Define Traced Async Functions ---
+    async def traced_create_async(*args, **kwargs):
+        # [Existing logic - unchanged]
+        current_trace = current_trace_var.get()
+        if not current_trace:
+            if asyncio.iscoroutinefunction(original_create):
+                return await original_create(*args, **kwargs)
+            else:
+                return original_create(*args, **kwargs)
+
+        is_streaming = kwargs.get("stream", False)
+
+        with current_trace.span(span_name, span_type="llm") as span:
+            input_data = _format_input_data(client, **kwargs)
+            span.record_input(input_data)
+
+            # Warn about token counting limitations with streaming
+            if isinstance(client, (AsyncOpenAI, OpenAI)) and is_streaming:
+                if not kwargs.get("stream_options", {}).get("include_usage"):
+                    warnings.warn(
+                        "OpenAI streaming calls don't include token counts by default. "
+                        "To enable token counting with streams, set stream_options={'include_usage': True} "
+                        "in your API call arguments.",
+                        UserWarning
+                    )
 
-            with current_trace.span(span_name, span_type="llm") as span:
-                # Format and record the input parameters
-                input_data = _format_input_data(client, **kwargs)
-                span.record_input(input_data)
-
-                # Make the actual API call
-                try:
-                    response = await original_create(*args, **kwargs)
-                except Exception as e:
-                    print(f"Error during API call: {e}")
-                    raise
-
-                # Format and record the output
-                output_data = _format_output_data(client, response)
-                span.record_output(output_data)
-
-                return response
-    else:
-        def traced_create(*args, **kwargs):
-            # Get the current trace from contextvars
-            current_trace = current_trace_var.get()
-
-            # Skip tracing if no active trace
-            if not current_trace:
-                return original_create(*args, **kwargs)
+            try:
+                if is_streaming:
+                    stream_iterator = await original_create(*args, **kwargs)
+                    output_entry = span.record_output("<pending stream>")
+                    return _async_stream_wrapper(stream_iterator, client, output_entry)
+                else:
+                    awaited_response = await original_create(*args, **kwargs)
+                    output_data = _format_output_data(client, awaited_response)
+                    span.record_output(output_data)
+                    return awaited_response
+            except Exception as e:
+                print(f"Error during wrapped async API call ({span_name}): {e}")
+                span.record_output({"error": str(e)})
+                raise
+
+
+    # Function replacing .stream() - NOW returns the wrapper class instance
+    def traced_stream_async(*args, **kwargs):
+        current_trace = current_trace_var.get()
+        if not current_trace or not original_stream:
+            return original_stream(*args, **kwargs)
+        original_manager = original_stream(*args, **kwargs)
+        wrapper_manager = _TracedAsyncStreamManagerWrapper(
+            original_manager=original_manager,
+            client=client,
+            span_name=span_name,
+            trace_client=current_trace,
+            stream_wrapper_func=_async_stream_wrapper,
+            input_kwargs=kwargs
+        )
+        return wrapper_manager
+
+    # --- Define Traced Sync Functions ---
+    def traced_create_sync(*args, **kwargs):
+        # [Existing logic - unchanged]
+        current_trace = current_trace_var.get()
+        if not current_trace:
+            return original_create(*args, **kwargs)
+
+        is_streaming = kwargs.get("stream", False)
+
+        with current_trace.span(span_name, span_type="llm") as span:
+            input_data = _format_input_data(client, **kwargs)
+            span.record_input(input_data)
+
+            # Warn about token counting limitations with streaming
+            if isinstance(client, (AsyncOpenAI, OpenAI)) and is_streaming:
+                if not kwargs.get("stream_options", {}).get("include_usage"):
+                    warnings.warn(
+                        "OpenAI streaming calls don't include token counts by default. "
+                        "To enable token counting with streams, set stream_options={'include_usage': True} "
+                        "in your API call arguments.",
+                        UserWarning
+                    )
+
+            try:
+                response_or_iterator = original_create(*args, **kwargs)
+            except Exception as e:
+                print(f"Error during wrapped sync API call ({span_name}): {e}")
+                span.record_output({"error": str(e)})
+                raise
+
+            if is_streaming:
+                output_entry = span.record_output("<pending stream>")
+                return _sync_stream_wrapper(response_or_iterator, client, output_entry)
+            else:
+                output_data = _format_output_data(client, response_or_iterator)
+                span.record_output(output_data)
+                return response_or_iterator
+
+
+    # Function replacing sync .stream()
+    def traced_stream_sync(*args, **kwargs):
+        current_trace = current_trace_var.get()
+        if not current_trace or not original_stream:
+            return original_stream(*args, **kwargs)
+        original_manager = original_stream(*args, **kwargs)
+        wrapper_manager = _TracedSyncStreamManagerWrapper(
+            original_manager=original_manager,
+            client=client,
+            span_name=span_name,
+            trace_client=current_trace,
+            stream_wrapper_func=_sync_stream_wrapper,
+            input_kwargs=kwargs
+        )
+        return wrapper_manager
+
+
+    # --- Assign Traced Methods to Client Instance ---
+    # [Assignment logic remains the same]
+    if isinstance(client, (AsyncOpenAI, AsyncTogether)):
+        client.chat.completions.create = traced_create_async
+        # Wrap the Responses API endpoint for OpenAI clients
+        if hasattr(client, "responses") and hasattr(client.responses, "create"):
+            # Capture the original responses.create
+            original_responses_create = client.responses.create
+            def traced_responses(*args, **kwargs):
+                # Get the current trace from contextvars
+                current_trace = current_trace_var.get()
+                # If no active trace, call the original
+                if not current_trace:
+                    return original_responses_create(*args, **kwargs)
+                # Trace this responses.create call
+                with current_trace.span(span_name, span_type="llm") as span:
+                    # Record raw input kwargs
+                    span.record_input(kwargs)
+                    # Make the actual API call
+                    response = original_responses_create(*args, **kwargs)
+                    # Record the output object
+                    span.record_output(response)
+                    return response
+            # Assign the traced wrapper
+            client.responses.create = traced_responses
+    elif isinstance(client, AsyncAnthropic):
+        client.messages.create = traced_create_async
+        if original_stream:
+            client.messages.stream = traced_stream_async
+    elif isinstance(client, genai.client.AsyncClient):
+        client.generate_content = traced_create_async
+    elif isinstance(client, (OpenAI, Together)):
+        client.chat.completions.create = traced_create_sync
+    elif isinstance(client, Anthropic):
+        client.messages.create = traced_create_sync
+        if original_stream:
+            client.messages.stream = traced_stream_sync
+    elif isinstance(client, genai.Client):
+        client.generate_content = traced_create_sync
 
-            with current_trace.span(span_name, span_type="llm") as span:
-                # Format and record the input parameters
-                input_data = _format_input_data(client, **kwargs)
-                span.record_input(input_data)
-
-                # Make the actual API call
-                try:
-                    response = original_create(*args, **kwargs)
-                except Exception as e:
-                    print(f"Error during API call: {e}")
-                    raise
-
-                # Format and record the output
-                output_data = _format_output_data(client, response)
-                span.record_output(output_data)
-
-                return response
-
-
-    # Replace the original method with our traced version
-    if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
-        client.chat.completions.create = traced_create
-    elif isinstance(client, (Anthropic, AsyncAnthropic)):
-        client.messages.create = traced_create
-    elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
-        client.models.generate_content = traced_create
-
     return client
 
 # Helper functions for client-specific operations
 
-def _get_client_config(client: ApiClient) -> tuple[str, callable]:
+def _get_client_config(client: ApiClient) -> tuple[str, callable, Optional[callable]]:
     """Returns configuration tuple for the given API client.
 
     Args:
         client: An instance of OpenAI, Together, or Anthropic client
 
     Returns:
-        tuple: (span_name, create_method)
+        tuple: (span_name, create_method, stream_method)
             - span_name: String identifier for tracing
             - create_method: Reference to the client's creation method
+            - stream_method: Reference to the client's stream method (if applicable)
 
     Raises:
         ValueError: If client type is not supported
     """
     if isinstance(client, (OpenAI, AsyncOpenAI)):
-        return "OPENAI_API_CALL", client.chat.completions.create
+        return "OPENAI_API_CALL", client.chat.completions.create, None
     elif isinstance(client, (Together, AsyncTogether)):
-        return "TOGETHER_API_CALL", client.chat.completions.create
+        return "TOGETHER_API_CALL", client.chat.completions.create, None
     elif isinstance(client, (Anthropic, AsyncAnthropic)):
-        return "ANTHROPIC_API_CALL", client.messages.create
+        return "ANTHROPIC_API_CALL", client.messages.create, client.messages.stream
     elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
-        return "GOOGLE_API_CALL", client.models.generate_content
+        return "GOOGLE_API_CALL", client.models.generate_content, None
     raise ValueError(f"Unsupported client type: {type(client)}")
 
 def _format_input_data(client: ApiClient, **kwargs) -> dict:
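For reference, a sketch of how the rewritten wrap() is exercised with OpenAI streaming. Per the warning emitted in traced_create_sync/async above, stream_options={'include_usage': True} is needed for token counts; the model and messages are placeholders:

    from openai import OpenAI
    from judgeval.common.tracer import wrap

    client = wrap(OpenAI())

    # Streaming: the traced call returns the _sync_stream_wrapper generator;
    # the span's "<pending stream>" output is filled in once the stream ends.
    stream = client.chat.completions.create(
        model="gpt-4o-mini",                             # placeholder
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
        stream_options={"include_usage": True},          # enables the final usage chunk
    )
    for chunk in stream:
        pass  # consume chunks as usual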
@@ -1478,6 +1606,26 @@ def _format_output_data(client: ApiClient, response: Any) -> dict:
             }
         }
 
+# Define a blocklist of functions that should not be traced
+# These are typically utility functions, print statements, logging, etc.
+_TRACE_BLOCKLIST = {
+    # Built-in functions
+    'print', 'str', 'int', 'float', 'bool', 'list', 'dict', 'set', 'tuple',
+    'len', 'range', 'enumerate', 'zip', 'map', 'filter', 'sorted', 'reversed',
+    'min', 'max', 'sum', 'any', 'all', 'abs', 'round', 'format',
+    # Logging functions
+    'debug', 'info', 'warning', 'error', 'critical', 'exception', 'log',
+    # Common utility functions
+    'sleep', 'time', 'datetime', 'json', 'dumps', 'loads',
+    # String operations
+    'join', 'split', 'strip', 'lstrip', 'rstrip', 'replace', 'lower', 'upper',
+    # Dict operations
+    'get', 'items', 'keys', 'values', 'update',
+    # List operations
+    'append', 'extend', 'insert', 'remove', 'pop', 'clear', 'index', 'count', 'sort',
+}
+
+
 # Add a new function for deep tracing at the module level
 def _create_deep_tracing_wrapper(func, tracer, span_type="span"):
     """
@@ -1496,6 +1644,15 @@ def _create_deep_tracing_wrapper(func, tracer, span_type="span"):
     if not callable(func) or isinstance(func, type) or func.__module__ == 'builtins':
         return func
 
+    # Skip functions in the blocklist
+    if func.__name__ in _TRACE_BLOCKLIST:
+        return func
+
+    # Skip functions from certain modules (logging, sys, etc.)
+    if func.__module__ and any(func.__module__.startswith(m) for m in ['logging', 'sys', 'os', 'json', 'time', 'datetime']):
+        return func
+
+
     # Get function name for the span - check for custom name set by @observe
     func_name = getattr(func, '_judgment_span_name', func.__name__)
 
@@ -1590,4 +1747,336 @@ class TraceThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
         return super().submit(ctx.run, func_with_bound_args)
 
     # Note: The `map` method would also need to be overridden for full context
-    # propagation if users rely on it, but `submit` is the most common use case.
+    # propagation if users rely on it, but `submit` is the most common use case.
+
+# Helper functions for stream processing
+# ---------------------------------------
+
+def _extract_content_from_chunk(client: ApiClient, chunk: Any) -> Optional[str]:
+    """Extracts the text content from a stream chunk based on the client type."""
+    try:
+        if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
+            return chunk.choices[0].delta.content
+        elif isinstance(client, (Anthropic, AsyncAnthropic)):
+            # Anthropic streams various event types, we only care for content blocks
+            if chunk.type == "content_block_delta":
+                return chunk.delta.text
+        elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
+            # Google streams Candidate objects
+            if chunk.candidates and chunk.candidates[0].content and chunk.candidates[0].content.parts:
+                return chunk.candidates[0].content.parts[0].text
+    except (AttributeError, IndexError, KeyError):
+        # Handle cases where chunk structure is unexpected or doesn't contain content
+        pass # Return None
+    return None
+
+def _extract_usage_from_final_chunk(client: ApiClient, chunk: Any) -> Optional[Dict[str, int]]:
+    """Extracts usage data if present in the *final* chunk (client-specific)."""
+    try:
+        # OpenAI/Together include usage in the *last* chunk's `usage` attribute if available
+        # This typically requires specific API versions or settings. Often usage is *not* streamed.
+        if isinstance(client, (OpenAI, Together, AsyncOpenAI, AsyncTogether)):
+            # Check if usage is directly on the chunk (some models might do this)
+            if hasattr(chunk, 'usage') and chunk.usage:
+                return {
+                    "prompt_tokens": chunk.usage.prompt_tokens,
+                    "completion_tokens": chunk.usage.completion_tokens,
+                    "total_tokens": chunk.usage.total_tokens
+                }
+            # Check if usage is nested within choices (less common for final chunk, but check)
+            elif chunk.choices and hasattr(chunk.choices[0], 'usage') and chunk.choices[0].usage:
+                usage = chunk.choices[0].usage
+                return {
+                    "prompt_tokens": usage.prompt_tokens,
+                    "completion_tokens": usage.completion_tokens,
+                    "total_tokens": usage.total_tokens
+                }
+        # Anthropic includes usage in the 'message_stop' event type
+        elif isinstance(client, (Anthropic, AsyncAnthropic)):
+            if chunk.type == "message_stop":
+                # Anthropic final usage is often attached to the *message* object, not the chunk directly
+                # The API might provide a way to get the final message object, but typically not in the stream itself.
+                # Let's assume for now usage might appear in the final *chunk* metadata if supported.
+                # This is a placeholder - Anthropic usage typically needs a separate call or context.
+                pass
+        elif isinstance(client, (genai.Client, genai.client.AsyncClient)):
+            # Google provides usage metadata on the full response object, not typically streamed per chunk.
+            # It might be in the *last* chunk's usage_metadata if the stream implementation supports it.
+            if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
+                return {
+                    "prompt_tokens": chunk.usage_metadata.prompt_token_count,
+                    "completion_tokens": chunk.usage_metadata.candidates_token_count,
+                    "total_tokens": chunk.usage_metadata.total_token_count
+                }
+
+    except (AttributeError, IndexError, KeyError, TypeError):
+        # Handle cases where usage data is missing or malformed
+        pass # Return None
+    return None
+
+
+# --- Sync Stream Wrapper ---
+def _sync_stream_wrapper(
+    original_stream: Iterator,
+    client: ApiClient,
+    output_entry: TraceEntry
+) -> Generator[Any, None, None]:
+    """Wraps a synchronous stream iterator to capture content and update the trace."""
+    content_parts = [] # Use a list instead of string concatenation
+    final_usage = None
+    last_chunk = None
+    try:
+        for chunk in original_stream:
+            content_part = _extract_content_from_chunk(client, chunk)
+            if content_part:
+                content_parts.append(content_part) # Append to list instead of concatenating
+            last_chunk = chunk # Keep track of the last chunk for potential usage data
+            yield chunk # Pass the chunk to the caller
+    finally:
+        # Attempt to extract usage from the last chunk received
+        if last_chunk:
+            final_usage = _extract_usage_from_final_chunk(client, last_chunk)
+
+        # Update the trace entry with the accumulated content and usage
+        output_entry.output = {
+            "content": "".join(content_parts), # Join list at the end
+            "usage": final_usage if final_usage else {"info": "Usage data not available in stream."}, # Provide placeholder if None
+            "streamed": True
+        }
+        # Note: We might need to adjust _serialize_output if this dict causes issues,
+        # but Pydantic's model_dump should handle dicts.
+
+# --- Async Stream Wrapper ---
+async def _async_stream_wrapper(
+    original_stream: AsyncIterator,
+    client: ApiClient,
+    output_entry: TraceEntry
+) -> AsyncGenerator[Any, None]:
+    # [Existing logic - unchanged]
+    content_parts = [] # Use a list instead of string concatenation
+    final_usage_data = None
+    last_content_chunk = None
+    anthropic_input_tokens = 0
+    anthropic_output_tokens = 0
+
+    target_span_id = getattr(output_entry, 'span_id', 'UNKNOWN')
+
+    try:
+        async for chunk in original_stream:
+            # Check for OpenAI's final usage chunk
+            if isinstance(client, (AsyncOpenAI, OpenAI)) and hasattr(chunk, 'usage') and chunk.usage is not None:
+                final_usage_data = {
+                    "prompt_tokens": chunk.usage.prompt_tokens,
+                    "completion_tokens": chunk.usage.completion_tokens,
+                    "total_tokens": chunk.usage.total_tokens
+                }
+                yield chunk
+                continue
+
+            if isinstance(client, (AsyncAnthropic, Anthropic)) and hasattr(chunk, 'type'):
+                if chunk.type == "message_start":
+                    if hasattr(chunk, 'message') and hasattr(chunk.message, 'usage') and hasattr(chunk.message.usage, 'input_tokens'):
+                        anthropic_input_tokens = chunk.message.usage.input_tokens
+                elif chunk.type == "message_delta":
+                    if hasattr(chunk, 'usage') and hasattr(chunk.usage, 'output_tokens'):
+                        anthropic_output_tokens += chunk.usage.output_tokens
+
+            content_part = _extract_content_from_chunk(client, chunk)
+            if content_part:
+                content_parts.append(content_part) # Append to list instead of concatenating
+                last_content_chunk = chunk
+
+            yield chunk
+    finally:
+        anthropic_final_usage = None
+        if isinstance(client, (AsyncAnthropic, Anthropic)) and (anthropic_input_tokens > 0 or anthropic_output_tokens > 0):
+            anthropic_final_usage = {
+                "input_tokens": anthropic_input_tokens,
+                "output_tokens": anthropic_output_tokens,
+                "total_tokens": anthropic_input_tokens + anthropic_output_tokens
+            }
+
+        usage_info = None
+        if final_usage_data:
+            usage_info = final_usage_data
+        elif anthropic_final_usage:
+            usage_info = anthropic_final_usage
+        elif last_content_chunk:
+            usage_info = _extract_usage_from_final_chunk(client, last_content_chunk)
+
+        if output_entry and hasattr(output_entry, 'output'):
+            output_entry.output = {
+                "content": "".join(content_parts), # Join list at the end
+                "usage": usage_info if usage_info else {"info": "Usage data not available in stream."},
+                "streamed": True
+            }
+            start_ts = getattr(output_entry, 'created_at', time.time())
+            output_entry.duration = time.time() - start_ts
+        # else: # Handle error case if necessary, but remove debug print
+
+# --- Define Context Manager Wrapper Classes ---
+class _TracedAsyncStreamManagerWrapper(AbstractAsyncContextManager):
+    """Wraps an original async stream manager to add tracing."""
+    def __init__(self, original_manager, client, span_name, trace_client, stream_wrapper_func, input_kwargs):
+        self._original_manager = original_manager
+        self._client = client
+        self._span_name = span_name
+        self._trace_client = trace_client
+        self._stream_wrapper_func = stream_wrapper_func
+        self._input_kwargs = input_kwargs
+        self._parent_span_id_at_entry = None
+
+    async def __aenter__(self):
+        self._parent_span_id_at_entry = current_span_var.get()
+        if not self._trace_client:
+            # If no trace, just delegate to the original manager
+            return await self._original_manager.__aenter__()
+
+        # --- Manually create the 'enter' entry ---
+        start_time = time.time()
+        span_id = str(uuid.uuid4())
+        current_depth = 0
+        if self._parent_span_id_at_entry and self._parent_span_id_at_entry in self._trace_client._span_depths:
+            current_depth = self._trace_client._span_depths[self._parent_span_id_at_entry] + 1
+        self._trace_client._span_depths[span_id] = current_depth
+        enter_entry = TraceEntry(
+            type="enter", function=self._span_name, span_id=span_id,
+            trace_id=self._trace_client.trace_id, depth=current_depth, message=self._span_name,
+            created_at=start_time, span_type="llm", parent_span_id=self._parent_span_id_at_entry
+        )
+        self._trace_client.add_entry(enter_entry)
+        # --- End manual 'enter' entry ---
+
+        # Set the current span ID in contextvars
+        self._span_context_token = current_span_var.set(span_id)
+
+        # Manually create 'input' entry
+        input_data = _format_input_data(self._client, **self._input_kwargs)
+        input_entry = TraceEntry(
+            type="input", function=self._span_name, span_id=span_id,
+            trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Inputs to {self._span_name}",
+            created_at=time.time(), inputs=input_data, span_type="llm"
+        )
+        self._trace_client.add_entry(input_entry)
+
+        # Call the original __aenter__
+        raw_iterator = await self._original_manager.__aenter__()
+
+        # Manually create pending 'output' entry
+        output_entry = TraceEntry(
+            type="output", function=self._span_name, span_id=span_id,
+            trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Output from {self._span_name}",
+            created_at=time.time(), output="<pending stream>", span_type="llm"
+        )
+        self._trace_client.add_entry(output_entry)
+
+        # Wrap the raw iterator
+        wrapped_iterator = self._stream_wrapper_func(raw_iterator, self._client, output_entry)
+        return wrapped_iterator
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        # Manually create the 'exit' entry
+        if hasattr(self, '_span_context_token'):
+            span_id = current_span_var.get()
+            start_time_for_duration = 0
+            for entry in reversed(self._trace_client.entries):
+                if entry.span_id == span_id and entry.type == 'enter':
+                    start_time_for_duration = entry.created_at
+                    break
+            duration = time.time() - start_time_for_duration if start_time_for_duration else None
+            exit_depth = self._trace_client._span_depths.get(span_id, 0)
+            exit_entry = TraceEntry(
+                type="exit", function=self._span_name, span_id=span_id,
+                trace_id=self._trace_client.trace_id, depth=exit_depth, message=f"← {self._span_name}",
+                created_at=time.time(), duration=duration, span_type="llm"
+            )
+            self._trace_client.add_entry(exit_entry)
+            if span_id in self._trace_client._span_depths: del self._trace_client._span_depths[span_id]
+            current_span_var.reset(self._span_context_token)
+            delattr(self, '_span_context_token')
+
+        # Delegate __aexit__
+        if hasattr(self._original_manager, "__aexit__"):
+            return await self._original_manager.__aexit__(exc_type, exc_val, exc_tb)
+        return None
+
+class _TracedSyncStreamManagerWrapper(AbstractContextManager):
+    """Wraps an original sync stream manager to add tracing."""
+    def __init__(self, original_manager, client, span_name, trace_client, stream_wrapper_func, input_kwargs):
+        self._original_manager = original_manager
+        self._client = client
+        self._span_name = span_name
+        self._trace_client = trace_client
+        self._stream_wrapper_func = stream_wrapper_func
+        self._input_kwargs = input_kwargs
+        self._parent_span_id_at_entry = None
+
+    def __enter__(self):
+        self._parent_span_id_at_entry = current_span_var.get()
+        if not self._trace_client:
+            return self._original_manager.__enter__()
+
+        # Manually create 'enter' entry
+        start_time = time.time()
+        span_id = str(uuid.uuid4())
+        current_depth = 0
+        if self._parent_span_id_at_entry and self._parent_span_id_at_entry in self._trace_client._span_depths:
+            current_depth = self._trace_client._span_depths[self._parent_span_id_at_entry] + 1
+        self._trace_client._span_depths[span_id] = current_depth
+        enter_entry = TraceEntry(
+            type="enter", function=self._span_name, span_id=span_id,
+            trace_id=self._trace_client.trace_id, depth=current_depth, message=self._span_name,
+            created_at=start_time, span_type="llm", parent_span_id=self._parent_span_id_at_entry
+        )
+        self._trace_client.add_entry(enter_entry)
+        self._span_context_token = current_span_var.set(span_id)
+
+        # Manually create 'input' entry
+        input_data = _format_input_data(self._client, **self._input_kwargs)
+        input_entry = TraceEntry(
+            type="input", function=self._span_name, span_id=span_id,
+            trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Inputs to {self._span_name}",
+            created_at=time.time(), inputs=input_data, span_type="llm"
+        )
+        self._trace_client.add_entry(input_entry)
+
+        # Call original __enter__
+        raw_iterator = self._original_manager.__enter__()
+
+        # Manually create 'output' entry (pending)
+        output_entry = TraceEntry(
+            type="output", function=self._span_name, span_id=span_id,
+            trace_id=self._trace_client.trace_id, depth=current_depth, message=f"Output from {self._span_name}",
+            created_at=time.time(), output="<pending stream>", span_type="llm"
+        )
+        self._trace_client.add_entry(output_entry)
+
+        # Wrap the raw iterator
+        wrapped_iterator = self._stream_wrapper_func(raw_iterator, self._client, output_entry)
+        return wrapped_iterator
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        # Manually create 'exit' entry
+        if hasattr(self, '_span_context_token'):
+            span_id = current_span_var.get()
+            start_time_for_duration = 0
+            for entry in reversed(self._trace_client.entries):
+                if entry.span_id == span_id and entry.type == 'enter':
+                    start_time_for_duration = entry.created_at
+                    break
+            duration = time.time() - start_time_for_duration if start_time_for_duration else None
+            exit_depth = self._trace_client._span_depths.get(span_id, 0)
+            exit_entry = TraceEntry(
+                type="exit", function=self._span_name, span_id=span_id,
+                trace_id=self._trace_client.trace_id, depth=exit_depth, message=f"← {self._span_name}",
+                created_at=time.time(), duration=duration, span_type="llm"
+            )
+            self._trace_client.add_entry(exit_entry)
+            if span_id in self._trace_client._span_depths: del self._trace_client._span_depths[span_id]
+            current_span_var.reset(self._span_context_token)
+            delattr(self, '_span_context_token')
+
+        # Delegate __exit__
+        if hasattr(self._original_manager, "__exit__"):
+            return self._original_manager.__exit__(exc_type, exc_val, exc_tb)
+        return None
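
For reference, the sync wrapper above is what an Anthropic caller goes through once wrap() has replaced messages.stream. A usage sketch (model and inputs are placeholders); note that the wrapper's __enter__ returns the wrapped raw event iterator rather than Anthropic's MessageStream helper, so the caller iterates events directly:

    from anthropic import Anthropic
    from judgeval.common.tracer import wrap

    client = wrap(Anthropic())  # messages.stream is now traced_stream_sync

    # __enter__ emits the enter/input/pending-output entries; the wrapped
    # iterator fills in the output entry as events are consumed; __exit__
    # emits the exit entry with the span duration.
    with client.messages.stream(
        model="claude-3-5-sonnet-latest",                # placeholder
        max_tokens=256,
        messages=[{"role": "user", "content": "Hello"}],
    ) as stream:
        for event in stream:
            pass  # raw stream events; text arrives in content_block_delta events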