braintrust 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,10 +1,11 @@
1
+ import asyncio
1
2
  import logging
2
3
  import sys
3
4
  import time
4
- from collections.abc import AsyncGenerator, Iterable
5
5
  from contextlib import AbstractAsyncContextManager
6
- from typing import Any, TypeVar
6
+ from typing import Any
7
7
 
8
+ from braintrust.bt_json import bt_safe_deep_copy
8
9
  from braintrust.logger import NOOP_SPAN, Attachment, current_span, init_logger, start_span
9
10
  from braintrust.span_types import SpanTypeAttribute
10
11
  from wrapt import wrap_function_wrapper
@@ -74,7 +75,7 @@ def wrap_agent(Agent: Any) -> Any:
74
75
  name=f"agent_run [{instance.name}]" if hasattr(instance, "name") and instance.name else "agent_run",
75
76
  type=SpanTypeAttribute.LLM,
76
77
  input=input_data if input_data else None,
77
- metadata=_try_dict(metadata),
78
+ metadata=metadata,
78
79
  ) as agent_span:
79
80
  start_time = time.time()
80
81
  result = await wrapped(*args, **kwargs)
@@ -98,7 +99,7 @@ def wrap_agent(Agent: Any) -> Any:
98
99
  else "agent_run_sync",
99
100
  type=SpanTypeAttribute.LLM,
100
101
  input=input_data if input_data else None,
101
- metadata=_try_dict(metadata),
102
+ metadata=metadata,
102
103
  ) as agent_span:
103
104
  start_time = time.time()
104
105
  result = wrapped(*args, **kwargs)
@@ -138,7 +139,7 @@ def wrap_agent(Agent: Any) -> Any:
138
139
  name=span_name,
139
140
  type=SpanTypeAttribute.LLM,
140
141
  input=input_data if input_data else None,
141
- metadata=_try_dict(metadata),
142
+ metadata=metadata,
142
143
  )
143
144
  span = span_cm.__enter__()
144
145
  start_time = time.time()
@@ -170,7 +171,7 @@ def wrap_agent(Agent: Any) -> Any:
170
171
  name=span_name,
171
172
  type=SpanTypeAttribute.LLM,
172
173
  input=input_data if input_data else None,
173
- metadata=_try_dict(metadata),
174
+ metadata=metadata,
174
175
  ) as agent_span:
175
176
  start_time = time.time()
176
177
  event_count = 0
@@ -216,7 +217,7 @@ def _create_direct_model_request_wrapper():
216
217
  name="model_request",
217
218
  type=SpanTypeAttribute.LLM,
218
219
  input=input_data,
219
- metadata=_try_dict(metadata),
220
+ metadata=metadata,
220
221
  ) as span:
221
222
  start_time = time.time()
222
223
  result = await wrapped(*args, **kwargs)
@@ -241,7 +242,7 @@ def _create_direct_model_request_sync_wrapper():
241
242
  name="model_request_sync",
242
243
  type=SpanTypeAttribute.LLM,
243
244
  input=input_data,
244
- metadata=_try_dict(metadata),
245
+ metadata=metadata,
245
246
  ) as span:
246
247
  start_time = time.time()
247
248
  result = wrapped(*args, **kwargs)
@@ -296,7 +297,7 @@ def wrap_model_request(original_func: Any) -> Any:
296
297
  name="model_request",
297
298
  type=SpanTypeAttribute.LLM,
298
299
  input=input_data,
299
- metadata=_try_dict(metadata),
300
+ metadata=metadata,
300
301
  ) as span:
301
302
  start_time = time.time()
302
303
  result = await original_func(*args, **kwargs)
@@ -319,7 +320,7 @@ def wrap_model_request_sync(original_func: Any) -> Any:
319
320
  name="model_request_sync",
320
321
  type=SpanTypeAttribute.LLM,
321
322
  input=input_data,
322
- metadata=_try_dict(metadata),
323
+ metadata=metadata,
323
324
  ) as span:
324
325
  start_time = time.time()
325
326
  result = original_func(*args, **kwargs)
@@ -391,7 +392,7 @@ def _build_model_class_input_and_metadata(instance: Any, args: Any, kwargs: Any)
391
392
  Tuple of (model_name, display_name, input_data, metadata)
392
393
  """
393
394
  model_name, provider = _extract_model_info_from_model_instance(instance)
394
- display_name = model_name or str(instance)
395
+ display_name = model_name or type(instance).__name__
395
396
 
396
397
  messages = args[0] if len(args) > 0 else kwargs.get("messages")
397
398
  model_settings = args[1] if len(args) > 1 else kwargs.get("model_settings")
@@ -400,7 +401,7 @@ def _build_model_class_input_and_metadata(instance: Any, args: Any, kwargs: Any)
400
401
 
401
402
  input_data = {"messages": serialized_messages}
402
403
  if model_settings is not None:
403
- input_data["model_settings"] = _try_dict(model_settings)
404
+ input_data["model_settings"] = bt_safe_deep_copy(model_settings)
404
405
 
405
406
  metadata = _build_model_metadata(model_name, provider, model_settings=None)
406
407
 
@@ -419,7 +420,7 @@ def _wrap_concrete_model_class(model_class: Any):
419
420
  name=f"chat {display_name}",
420
421
  type=SpanTypeAttribute.LLM,
421
422
  input=input_data,
422
- metadata=_try_dict(metadata),
423
+ metadata=metadata,
423
424
  ) as span:
424
425
  start_time = time.time()
425
426
  result = await wrapped(*args, **kwargs)
@@ -457,22 +458,28 @@ class _AgentStreamWrapper(AbstractAsyncContextManager):
457
458
  self.span_cm = None
458
459
  self.start_time = None
459
460
  self.stream_result = None
461
+ self._enter_task = None
462
+ self._first_token_time = None
460
463
 
461
464
  async def __aenter__(self):
465
+ self._enter_task = asyncio.current_task()
466
+
462
467
  # Use context manager properly so span stays current
463
468
  # DON'T pass start_time here - we'll set it via metrics in __aexit__
464
469
  self.span_cm = start_span(
465
470
  name=self.span_name,
466
471
  type=SpanTypeAttribute.LLM,
467
472
  input=self.input_data if self.input_data else None,
468
- metadata=_try_dict(self.metadata),
473
+ metadata=self.metadata,
469
474
  )
470
- span = self.span_cm.__enter__()
475
+ self.span_cm.__enter__()
471
476
 
472
477
  # Capture start time right before entering the stream (API call initiation)
473
478
  self.start_time = time.time()
474
479
  self.stream_result = await self.stream_cm.__aenter__()
475
- return self.stream_result # Return actual stream result object
480
+
481
+ # Wrap the stream result to capture first token time
482
+ return _StreamResultProxy(self.stream_result, self)
476
483
 
477
484
  async def __aexit__(self, exc_type, exc_val, exc_tb):
478
485
  try:
@@ -482,16 +489,47 @@ class _AgentStreamWrapper(AbstractAsyncContextManager):
482
489
  end_time = time.time()
483
490
 
484
491
  output = _serialize_stream_output(self.stream_result)
485
- metrics = _extract_stream_usage_metrics(self.stream_result, self.start_time, end_time, None)
492
+ metrics = _extract_stream_usage_metrics(
493
+ self.stream_result, self.start_time, end_time, self._first_token_time
494
+ )
486
495
  self.span_cm.log(output=output, metrics=metrics)
487
496
 
488
- # Always clean up span context
497
+ # Clean up span context
489
498
  if self.span_cm:
490
- self.span_cm.__exit__(None, None, None)
499
+ if asyncio.current_task() is self._enter_task:
500
+ self.span_cm.__exit__(None, None, None)
501
+ else:
502
+ self.span_cm.end()
491
503
 
492
504
  return False
493
505
 
494
506
 
507
class _StreamResultProxy:
    """Proxy for an agent stream result that captures time-to-first-token.

    Any attribute not defined on the proxy is looked up on the wrapped stream
    result. The streaming methods ``stream_text`` and ``stream_output`` are
    replaced by async generators. Each records ``time.time()`` into
    ``wrapper._first_token_time`` when the first item arrives, but only if that
    field is still ``None``. Every item is then yielded unchanged.
    """

    def __init__(self, stream_result: Any, wrapper: "_AgentStreamWrapper"):
        # String annotation: the class does not depend on where
        # ``_AgentStreamWrapper`` is defined relative to it.
        self._stream_result = stream_result
        self._wrapper = wrapper

    def __getattr__(self, name: str):
        """Delegate attribute access to the wrapped stream result.

        Args:
            name: Attribute name that was not found on the proxy itself.

        Returns:
            The wrapped result's attribute. For ``stream_text`` and
            ``stream_output``, returns a timing wrapper around that attribute
            instead.
        """
        attr = getattr(self._stream_result, name)

        if not (callable(attr) and name in ("stream_text", "stream_output")):
            return attr

        async def wrapped_method(*args, **kwargs):
            inner = attr(*args, **kwargs)
            try:
                async for item in inner:
                    if self._wrapper._first_token_time is None:
                        self._wrapper._first_token_time = time.time()
                    yield item
            finally:
                # Close the underlying iterator right away when the consumer
                # stops early (break / aclose / exception). Otherwise it is only
                # closed later by the event loop's async-generator finalizer,
                # as a separately scheduled task in a possibly different context.
                aclose = getattr(inner, "aclose", None)
                if aclose is not None:
                    await aclose()

        return wrapped_method
531
+
532
+
495
533
  class _DirectStreamWrapper(AbstractAsyncContextManager):
496
534
  """Wrapper for model_request_stream() that adds tracing while passing through the stream."""
497
535
 
@@ -503,22 +541,28 @@ class _DirectStreamWrapper(AbstractAsyncContextManager):
503
541
  self.span_cm = None
504
542
  self.start_time = None
505
543
  self.stream = None
544
+ self._enter_task = None
545
+ self._first_token_time = None
506
546
 
507
547
  async def __aenter__(self):
548
+ self._enter_task = asyncio.current_task()
549
+
508
550
  # Use context manager properly so span stays current
509
551
  # DON'T pass start_time here - we'll set it via metrics in __aexit__
510
552
  self.span_cm = start_span(
511
553
  name=self.span_name,
512
554
  type=SpanTypeAttribute.LLM,
513
555
  input=self.input_data if self.input_data else None,
514
- metadata=_try_dict(self.metadata),
556
+ metadata=self.metadata,
515
557
  )
516
- span = self.span_cm.__enter__()
558
+ self.span_cm.__enter__()
517
559
 
518
560
  # Capture start time right before entering the stream (API call initiation)
519
561
  self.start_time = time.time()
520
562
  self.stream = await self.stream_cm.__aenter__()
521
- return self.stream # Return actual stream object
563
+
564
+ # Wrap the stream to capture first token time
565
+ return _DirectStreamIteratorProxy(self.stream, self)
522
566
 
523
567
  async def __aexit__(self, exc_type, exc_val, exc_tb):
524
568
  try:
@@ -530,18 +574,53 @@ class _DirectStreamWrapper(AbstractAsyncContextManager):
530
574
  try:
531
575
  final_response = self.stream.get()
532
576
  output = _serialize_model_response(final_response)
533
- metrics = _extract_response_metrics(final_response, self.start_time, end_time, None)
577
+ metrics = _extract_response_metrics(
578
+ final_response, self.start_time, end_time, self._first_token_time
579
+ )
534
580
  self.span_cm.log(output=output, metrics=metrics)
535
581
  except Exception as e:
536
582
  logger.debug(f"Failed to extract stream output/metrics: {e}")
537
583
 
538
- # Always clean up span context
584
+ # Clean up span context
539
585
  if self.span_cm:
540
- self.span_cm.__exit__(None, None, None)
586
+ if asyncio.current_task() is self._enter_task:
587
+ self.span_cm.__exit__(None, None, None)
588
+ else:
589
+ self.span_cm.end()
541
590
 
542
591
  return False
543
592
 
544
593
 
594
class _DirectStreamIteratorProxy:
    """Async-iterable stand-in for a direct model stream that notes first-token time.

    Any attribute the proxy does not define is looked up on the wrapped stream.
    Iterating the proxy pulls items from the stream's own async iterator, or from
    the stream itself when it has no ``__aiter__``. When the first item is
    received while ``wrapper._first_token_time`` is still ``None``, the proxy
    stores ``time.time()`` there.
    """

    def __init__(self, stream: Any, wrapper: _DirectStreamWrapper):
        self._stream = stream
        self._wrapper = wrapper
        # Set by __aiter__, or lazily by the first __anext__ call.
        self._iterator = None

    def __getattr__(self, name: str):
        """Forward unknown attribute lookups to the underlying stream."""
        return getattr(self._stream, name)

    def __aiter__(self):
        """Acquire the stream's async iterator and return this proxy as the iterator."""
        source = self._stream
        if hasattr(source, "__aiter__"):
            self._iterator = source.__aiter__()
        else:
            # Without __aiter__, assume the stream already is an async iterator.
            self._iterator = source
        return self

    async def __anext__(self):
        """Fetch the next item, recording the first-token time on the first one."""
        if self._iterator is None:
            # Iteration began without a prior __aiter__ call.
            self.__aiter__()

        item = await self._iterator.__anext__()
        owner = self._wrapper
        if owner._first_token_time is None:
            owner._first_token_time = time.time()
        return item
622
+
623
+
545
624
  class _AgentStreamResultSyncProxy:
546
625
  """Proxy for agent.run_stream_sync() result that adds tracing while delegating to actual stream result."""
547
626
 
@@ -552,20 +631,25 @@ class _AgentStreamResultSyncProxy:
552
631
  self._start_time = start_time
553
632
  self._logged = False
554
633
  self._finalize_on_del = True
634
+ self._first_token_time = None
555
635
 
556
636
  def __getattr__(self, name: str):
557
637
  """Delegate all attribute access to the wrapped stream result."""
558
638
  attr = getattr(self._stream_result, name)
559
639
 
560
640
  # Wrap any method that returns an iterator to auto-finalize when exhausted
561
- if callable(attr) and name in ('stream_text', 'stream_output', '__iter__'):
641
+ if callable(attr) and name in ("stream_text", "stream_output", "__iter__"):
642
+
562
643
  def wrapped_method(*args, **kwargs):
563
644
  try:
564
645
  iterator = attr(*args, **kwargs)
565
646
  # If it's an iterator, wrap it
566
- if hasattr(iterator, '__iter__') or hasattr(iterator, '__next__'):
647
+ if hasattr(iterator, "__iter__") or hasattr(iterator, "__next__"):
567
648
  try:
568
- yield from iterator
649
+ for item in iterator:
650
+ if self._first_token_time is None:
651
+ self._first_token_time = time.time()
652
+ yield item
569
653
  finally:
570
654
  self._finalize()
571
655
  self._finalize_on_del = False # Don't finalize again in __del__
@@ -575,6 +659,7 @@ class _AgentStreamResultSyncProxy:
575
659
  self._finalize()
576
660
  self._finalize_on_del = False
577
661
  raise
662
+
578
663
  return wrapped_method
579
664
 
580
665
  return attr
@@ -585,7 +670,9 @@ class _AgentStreamResultSyncProxy:
585
670
  try:
586
671
  end_time = time.time()
587
672
  output = _serialize_stream_output(self._stream_result)
588
- metrics = _extract_stream_usage_metrics(self._stream_result, self._start_time, end_time, None)
673
+ metrics = _extract_stream_usage_metrics(
674
+ self._stream_result, self._start_time, end_time, self._first_token_time
675
+ )
589
676
  self._span.log(output=output, metrics=metrics)
590
677
  self._logged = True
591
678
  finally:
@@ -611,6 +698,7 @@ class _DirectStreamWrapperSync:
611
698
  self.span_cm = None
612
699
  self.start_time = None
613
700
  self.stream = None
701
+ self._first_token_time = None
614
702
 
615
703
  def __enter__(self):
616
704
  # Use context manager properly so span stays current
@@ -619,14 +707,16 @@ class _DirectStreamWrapperSync:
619
707
  name=self.span_name,
620
708
  type=SpanTypeAttribute.LLM,
621
709
  input=self.input_data if self.input_data else None,
622
- metadata=_try_dict(self.metadata),
710
+ metadata=self.metadata,
623
711
  )
624
712
  span = self.span_cm.__enter__()
625
713
 
626
714
  # Capture start time right before entering the stream (API call initiation)
627
715
  self.start_time = time.time()
628
716
  self.stream = self.stream_cm.__enter__()
629
- return self.stream # Return actual stream object
717
+
718
+ # Wrap the stream to capture first token time
719
+ return _DirectStreamIteratorSyncProxy(self.stream, self)
630
720
 
631
721
  def __exit__(self, exc_type, exc_val, exc_tb):
632
722
  try:
@@ -638,7 +728,9 @@ class _DirectStreamWrapperSync:
638
728
  try:
639
729
  final_response = self.stream.get()
640
730
  output = _serialize_model_response(final_response)
641
- metrics = _extract_response_metrics(final_response, self.start_time, end_time, None)
731
+ metrics = _extract_response_metrics(
732
+ final_response, self.start_time, end_time, self._first_token_time
733
+ )
642
734
  self.span_cm.log(output=output, metrics=metrics)
643
735
  except Exception as e:
644
736
  logger.debug(f"Failed to extract stream output/metrics: {e}")
@@ -650,6 +742,36 @@ class _DirectStreamWrapperSync:
650
742
  return False
651
743
 
652
744
 
745
class _DirectStreamIteratorSyncProxy:
    """Synchronous direct-stream proxy that notes first-token time.

    Any attribute the proxy does not define is looked up on the wrapped stream.
    Iterating the proxy draws items from the stream's own iterator, or from the
    stream itself when it has no ``__iter__``. When the first item is received
    while ``wrapper._first_token_time`` is still ``None``, the proxy stores
    ``time.time()`` there.
    """

    def __init__(self, stream: Any, wrapper: _DirectStreamWrapperSync):
        self._stream = stream
        self._wrapper = wrapper
        # Set by __iter__, or lazily by the first __next__ call.
        self._iterator = None

    def __getattr__(self, name: str):
        """Forward unknown attribute lookups to the underlying stream."""
        return getattr(self._stream, name)

    def __iter__(self):
        """Acquire the stream's iterator and hand back this proxy as the iterator."""
        source = self._stream
        if hasattr(source, "__iter__"):
            self._iterator = source.__iter__()
        else:
            # Without __iter__, assume the stream already is an iterator.
            self._iterator = source
        return self

    def __next__(self):
        """Return the next item, recording the first-token time on the first one."""
        if self._iterator is None:
            # Iteration began without a prior __iter__ call.
            self.__iter__()

        item = self._iterator.__next__()
        owner = self._wrapper
        if owner._first_token_time is None:
            owner._first_token_time = time.time()
        return item
773
+
774
+
653
775
  def _serialize_user_prompt(user_prompt: Any) -> Any:
654
776
  """Serialize user prompt, handling BinaryContent and other types."""
655
777
  if user_prompt is None:
@@ -665,7 +787,14 @@ def _serialize_user_prompt(user_prompt: Any) -> Any:
665
787
 
666
788
 
667
789
  def _serialize_content_part(part: Any) -> Any:
668
- """Serialize a content part, handling BinaryContent specially."""
790
+ """Serialize a content part, handling BinaryContent specially.
791
+
792
+ This function handles:
793
+ - BinaryContent: converts to Braintrust Attachment
794
+ - Parts with nested content (UserPromptPart): recursively serializes content items
795
+ - Strings: passes through unchanged
796
+ - Other objects: converts to dict via model_dump
797
+ """
669
798
  if part is None:
670
799
  return None
671
800
 
@@ -680,10 +809,25 @@ def _serialize_content_part(part: Any) -> Any:
680
809
  attachment = Attachment(data=data, filename=filename, content_type=media_type)
681
810
  return {"type": "binary", "attachment": attachment, "media_type": media_type}
682
811
 
812
+ if hasattr(part, "content"):
813
+ content = part.content
814
+ if isinstance(content, list):
815
+ serialized_content = [_serialize_content_part(item) for item in content]
816
+ result = bt_safe_deep_copy(part)
817
+ if isinstance(result, dict):
818
+ result["content"] = serialized_content
819
+ return result
820
+ elif content is not None:
821
+ serialized_content = _serialize_content_part(content)
822
+ result = bt_safe_deep_copy(part)
823
+ if isinstance(result, dict):
824
+ result["content"] = serialized_content
825
+ return result
826
+
683
827
  if isinstance(part, str):
684
828
  return part
685
829
 
686
- return _try_dict(part)
830
+ return bt_safe_deep_copy(part)
687
831
 
688
832
 
689
833
  def _serialize_messages(messages: Any) -> Any:
@@ -693,10 +837,24 @@ def _serialize_messages(messages: Any) -> Any:
693
837
 
694
838
  result = []
695
839
  for msg in messages:
696
- serialized_msg = _try_dict(msg)
840
+ if hasattr(msg, "parts") and msg.parts:
841
+ original_parts = msg.parts
842
+ serialized_parts = [_serialize_content_part(p) for p in original_parts]
697
843
 
698
- if isinstance(serialized_msg, dict) and "parts" in serialized_msg:
699
- serialized_msg["parts"] = [_serialize_content_part(p) for p in msg.parts]
844
+ # Use model_dump with exclude to avoid serializing parts field prematurely
845
+ if hasattr(msg, "model_dump"):
846
+ try:
847
+ serialized_msg = msg.model_dump(exclude={"parts"}, exclude_none=True)
848
+ except (TypeError, ValueError):
849
+ # If exclude parameter not supported, fall back to bt_safe_deep_copy
850
+ serialized_msg = bt_safe_deep_copy(msg)
851
+ else:
852
+ serialized_msg = bt_safe_deep_copy(msg)
853
+
854
+ if isinstance(serialized_msg, dict):
855
+ serialized_msg["parts"] = serialized_parts
856
+ else:
857
+ serialized_msg = bt_safe_deep_copy(msg)
700
858
 
701
859
  result.append(serialized_msg)
702
860
 
@@ -711,12 +869,12 @@ def _serialize_result_output(result: Any) -> Any:
711
869
  output_dict = {}
712
870
 
713
871
  if hasattr(result, "output"):
714
- output_dict["output"] = _try_dict(result.output)
872
+ output_dict["output"] = bt_safe_deep_copy(result.output)
715
873
 
716
874
  if hasattr(result, "response"):
717
875
  output_dict["response"] = _serialize_model_response(result.response)
718
876
 
719
- return output_dict if output_dict else _try_dict(result)
877
+ return output_dict if output_dict else bt_safe_deep_copy(result)
720
878
 
721
879
 
722
880
  def _serialize_stream_output(stream_result: Any) -> Any:
@@ -737,11 +895,10 @@ def _serialize_model_response(response: Any) -> Any:
737
895
  if not response:
738
896
  return None
739
897
 
740
- response_dict = _try_dict(response)
898
+ response_dict = bt_safe_deep_copy(response)
741
899
 
742
- if isinstance(response_dict, dict) and "parts" in response_dict:
743
- if hasattr(response, "parts"):
744
- response_dict["parts"] = [_serialize_content_part(p) for p in response.parts]
900
+ if hasattr(response, "parts") and isinstance(response_dict, dict):
901
+ response_dict["parts"] = [_serialize_content_part(p) for p in response.parts]
745
902
 
746
903
  return response_dict
747
904
 
@@ -801,9 +958,7 @@ def _extract_model_info(agent: Any) -> tuple[str | None, str | None]:
801
958
  return _extract_model_info_from_model_instance(agent.model)
802
959
 
803
960
 
804
- def _build_model_metadata(
805
- model_name: str | None, provider: str | None, model_settings: Any = None
806
- ) -> dict[str, Any]:
961
+ def _build_model_metadata(model_name: str | None, provider: str | None, model_settings: Any = None) -> dict[str, Any]:
807
962
  """Build metadata dictionary with model info.
808
963
 
809
964
  Args:
@@ -820,7 +975,7 @@ def _build_model_metadata(
820
975
  if provider:
821
976
  metadata["provider"] = provider
822
977
  if model_settings:
823
- metadata["model_settings"] = _try_dict(model_settings)
978
+ metadata["model_settings"] = bt_safe_deep_copy(model_settings)
824
979
  return metadata
825
980
 
826
981
 
@@ -986,24 +1141,6 @@ def _is_patched(obj: Any) -> bool:
986
1141
  return getattr(obj, "_braintrust_patched", False)
987
1142
 
988
1143
 
989
- def _try_dict(obj: Any) -> Iterable[Any] | dict[str, Any]:
990
- """Try to convert object to dict, handling Pydantic models and circular references."""
991
- if hasattr(obj, "model_dump"):
992
- try:
993
- obj = obj.model_dump(exclude_none=True)
994
- except ValueError as e:
995
- if "Circular reference" in str(e):
996
- return {}
997
- raise
998
-
999
- if isinstance(obj, dict):
1000
- return {k: _try_dict(v) for k, v in obj.items()}
1001
- elif isinstance(obj, (list, tuple)):
1002
- return [_try_dict(item) for item in obj]
1003
-
1004
- return obj
1005
-
1006
-
1007
1144
  def _serialize_type(obj: Any) -> Any:
1008
1145
  """Serialize a type/class for logging, handling Pydantic models and other types.
1009
1146
 
@@ -1047,32 +1184,7 @@ def _serialize_type(obj: Any) -> Any:
1047
1184
  return obj.__name__
1048
1185
 
1049
1186
  # Try standard serialization
1050
- return _try_dict(obj)
1051
-
1052
-
1053
- G = TypeVar("G", bound=AsyncGenerator[Any, None])
1054
-
1055
-
1056
- class aclosing(AbstractAsyncContextManager[G]):
1057
- """Context manager for closing async generators."""
1058
-
1059
- def __init__(self, async_generator: G):
1060
- self.async_generator = async_generator
1061
-
1062
- async def __aenter__(self):
1063
- return self.async_generator
1064
-
1065
- async def __aexit__(self, *exc_info: Any):
1066
- try:
1067
- await self.async_generator.aclose()
1068
- except ValueError as e:
1069
- if "was created in a different Context" not in str(e):
1070
- raise
1071
- else:
1072
- logger.debug(
1073
- f"Suppressed ContextVar error during async cleanup: {e}. "
1074
- "This is expected when async generators yield across context boundaries."
1075
- )
1187
+ return bt_safe_deep_copy(obj)
1076
1188
 
1077
1189
 
1078
1190
  def _build_agent_input_and_metadata(args: Any, kwargs: Any, instance: Any) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -1097,9 +1209,9 @@ def _build_agent_input_and_metadata(args: Any, kwargs: Any, instance: Any) -> tu
1097
1209
  input_data[key] = _serialize_type(value) if value is not None else None
1098
1210
  elif key == "model_settings":
1099
1211
  # model_settings passed to run() goes in INPUT (it's a run() parameter)
1100
- input_data[key] = _try_dict(value) if value is not None else None
1212
+ input_data[key] = bt_safe_deep_copy(value) if value is not None else None
1101
1213
  else:
1102
- input_data[key] = _try_dict(value) if value is not None else None
1214
+ input_data[key] = bt_safe_deep_copy(value) if value is not None else None
1103
1215
 
1104
1216
  if "model" in kwargs:
1105
1217
  model_name, provider = _parse_model_string(kwargs["model"])
@@ -1155,7 +1267,9 @@ def _build_agent_input_and_metadata(args: Any, kwargs: Any, instance: Any) -> tu
1155
1267
  if hasattr(tool_obj, "description") and tool_obj.description:
1156
1268
  tool_dict["description"] = tool_obj.description
1157
1269
  # Extract JSON schema for parameters
1158
- if hasattr(tool_obj, "function_schema") and hasattr(tool_obj.function_schema, "json_schema"):
1270
+ if hasattr(tool_obj, "function_schema") and hasattr(
1271
+ tool_obj.function_schema, "json_schema"
1272
+ ):
1159
1273
  tool_dict["parameters"] = tool_obj.function_schema.json_schema
1160
1274
  tools_list.append(tool_dict)
1161
1275
  ts_info["tools"] = tools_list
@@ -1196,7 +1310,7 @@ def _build_direct_model_input_and_metadata(args: Any, kwargs: Any) -> tuple[dict
1196
1310
 
1197
1311
  for key, value in kwargs.items():
1198
1312
  if key not in ["model", "messages"]:
1199
- input_data[key] = _try_dict(value) if value is not None else None
1313
+ input_data[key] = bt_safe_deep_copy(value) if value is not None else None
1200
1314
 
1201
1315
  model_name, provider = _parse_model_string(model)
1202
1316
  metadata = _build_model_metadata(model_name, provider)
@@ -547,7 +547,6 @@ def test_error_handling(memory_logger):
547
547
  assert log["error"]
548
548
 
549
549
 
550
- # Test 9: Stop Sequences
551
550
  @pytest.mark.vcr
552
551
  def test_stop_sequences(memory_logger):
553
552
  """Test stop sequences parameter."""
@@ -571,3 +570,65 @@ def test_stop_sequences(memory_logger):
571
570
  assert len(spans) == 1
572
571
  span = spans[0]
573
572
  assert span["metadata"]["model"] == MODEL
573
+
574
+
575
def test_attachment_in_config(memory_logger):
    """bt_safe_deep_copy must keep an Attachment inside a config dict by identity."""
    from braintrust.bt_json import bt_safe_deep_copy
    from braintrust.logger import Attachment

    config_attachment = Attachment(data=b"config data", filename="config.txt", content_type="text/plain")

    # A config mapping that mixes plain scalars with an attachment.
    original_config = {
        "temperature": 0.5,
        "context_file": config_attachment,
        "max_output_tokens": 100,
    }

    result = bt_safe_deep_copy(original_config)

    # The attachment object itself must survive the copy untouched.
    assert result["context_file"] is config_attachment
    assert result["temperature"] == 0.5
589
+
590
+
591
def test_nested_attachments_in_contents(memory_logger):
    """bt_safe_deep_copy must keep attachments nested inside message parts by identity."""
    from braintrust.bt_json import bt_safe_deep_copy
    from braintrust.logger import Attachment, ExternalAttachment

    local_file = Attachment(data=b"file1", filename="file1.txt", content_type="text/plain")
    remote_file = ExternalAttachment(
        url="s3://bucket/file2.pdf",
        filename="file2.pdf",
        content_type="application/pdf",
    )

    # Two chat-style messages whose second part carries an attachment.
    user_message = {"role": "user", "parts": [{"text": "Check these files"}, {"file": local_file}]}
    model_message = {"role": "model", "parts": [{"text": "Analyzed"}, {"result_file": remote_file}]}

    result = bt_safe_deep_copy([user_message, model_message])
    copied_user = result[0]
    copied_model = result[1]

    assert copied_user["parts"][1]["file"] is local_file
    assert copied_model["parts"][1]["result_file"] is remote_file
610
+
611
+
612
def test_attachment_with_pydantic_model(memory_logger):
    """Pydantic models become dicts while a sibling attachment survives by identity."""
    from braintrust.bt_json import bt_safe_deep_copy
    from braintrust.logger import Attachment
    from pydantic import BaseModel

    class SampleModel(BaseModel):
        name: str
        value: int

    model_file = Attachment(data=b"model data", filename="model.txt", content_type="text/plain")
    payload = {"model_config": SampleModel(name="test", value=42), "context_file": model_file}

    result = bt_safe_deep_copy(payload)

    dumped_model = result["model_config"]
    # The Pydantic instance is expected to be converted to a plain dict...
    assert isinstance(dumped_model, dict)
    assert dumped_model["name"] == "test"

    # ...while the attachment is passed through untouched.
    assert result["context_file"] is model_file