braintrust 0.4.0__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braintrust/bt_json.py +178 -19
- braintrust/framework.py +11 -2
- braintrust/logger.py +30 -117
- braintrust/test_bt_json.py +644 -0
- braintrust/test_framework.py +56 -0
- braintrust/test_logger.py +211 -107
- braintrust/version.py +2 -2
- braintrust/wrappers/google_genai/__init__.py +2 -15
- braintrust/wrappers/pydantic_ai.py +209 -95
- braintrust/wrappers/test_google_genai.py +62 -1
- braintrust/wrappers/test_pydantic_ai_integration.py +819 -22
- {braintrust-0.4.0.dist-info → braintrust-0.4.1.dist-info}/METADATA +1 -1
- {braintrust-0.4.0.dist-info → braintrust-0.4.1.dist-info}/RECORD +16 -15
- {braintrust-0.4.0.dist-info → braintrust-0.4.1.dist-info}/WHEEL +0 -0
- {braintrust-0.4.0.dist-info → braintrust-0.4.1.dist-info}/entry_points.txt +0 -0
- {braintrust-0.4.0.dist-info → braintrust-0.4.1.dist-info}/top_level.txt +0 -0
|
@@ -1,10 +1,11 @@
|
|
|
1
|
+
import asyncio
|
|
1
2
|
import logging
|
|
2
3
|
import sys
|
|
3
4
|
import time
|
|
4
|
-
from collections.abc import AsyncGenerator, Iterable
|
|
5
5
|
from contextlib import AbstractAsyncContextManager
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
7
7
|
|
|
8
|
+
from braintrust.bt_json import bt_safe_deep_copy
|
|
8
9
|
from braintrust.logger import NOOP_SPAN, Attachment, current_span, init_logger, start_span
|
|
9
10
|
from braintrust.span_types import SpanTypeAttribute
|
|
10
11
|
from wrapt import wrap_function_wrapper
|
|
@@ -74,7 +75,7 @@ def wrap_agent(Agent: Any) -> Any:
|
|
|
74
75
|
name=f"agent_run [{instance.name}]" if hasattr(instance, "name") and instance.name else "agent_run",
|
|
75
76
|
type=SpanTypeAttribute.LLM,
|
|
76
77
|
input=input_data if input_data else None,
|
|
77
|
-
metadata=
|
|
78
|
+
metadata=metadata,
|
|
78
79
|
) as agent_span:
|
|
79
80
|
start_time = time.time()
|
|
80
81
|
result = await wrapped(*args, **kwargs)
|
|
@@ -98,7 +99,7 @@ def wrap_agent(Agent: Any) -> Any:
|
|
|
98
99
|
else "agent_run_sync",
|
|
99
100
|
type=SpanTypeAttribute.LLM,
|
|
100
101
|
input=input_data if input_data else None,
|
|
101
|
-
metadata=
|
|
102
|
+
metadata=metadata,
|
|
102
103
|
) as agent_span:
|
|
103
104
|
start_time = time.time()
|
|
104
105
|
result = wrapped(*args, **kwargs)
|
|
@@ -138,7 +139,7 @@ def wrap_agent(Agent: Any) -> Any:
|
|
|
138
139
|
name=span_name,
|
|
139
140
|
type=SpanTypeAttribute.LLM,
|
|
140
141
|
input=input_data if input_data else None,
|
|
141
|
-
metadata=
|
|
142
|
+
metadata=metadata,
|
|
142
143
|
)
|
|
143
144
|
span = span_cm.__enter__()
|
|
144
145
|
start_time = time.time()
|
|
@@ -170,7 +171,7 @@ def wrap_agent(Agent: Any) -> Any:
|
|
|
170
171
|
name=span_name,
|
|
171
172
|
type=SpanTypeAttribute.LLM,
|
|
172
173
|
input=input_data if input_data else None,
|
|
173
|
-
metadata=
|
|
174
|
+
metadata=metadata,
|
|
174
175
|
) as agent_span:
|
|
175
176
|
start_time = time.time()
|
|
176
177
|
event_count = 0
|
|
@@ -216,7 +217,7 @@ def _create_direct_model_request_wrapper():
|
|
|
216
217
|
name="model_request",
|
|
217
218
|
type=SpanTypeAttribute.LLM,
|
|
218
219
|
input=input_data,
|
|
219
|
-
metadata=
|
|
220
|
+
metadata=metadata,
|
|
220
221
|
) as span:
|
|
221
222
|
start_time = time.time()
|
|
222
223
|
result = await wrapped(*args, **kwargs)
|
|
@@ -241,7 +242,7 @@ def _create_direct_model_request_sync_wrapper():
|
|
|
241
242
|
name="model_request_sync",
|
|
242
243
|
type=SpanTypeAttribute.LLM,
|
|
243
244
|
input=input_data,
|
|
244
|
-
metadata=
|
|
245
|
+
metadata=metadata,
|
|
245
246
|
) as span:
|
|
246
247
|
start_time = time.time()
|
|
247
248
|
result = wrapped(*args, **kwargs)
|
|
@@ -296,7 +297,7 @@ def wrap_model_request(original_func: Any) -> Any:
|
|
|
296
297
|
name="model_request",
|
|
297
298
|
type=SpanTypeAttribute.LLM,
|
|
298
299
|
input=input_data,
|
|
299
|
-
metadata=
|
|
300
|
+
metadata=metadata,
|
|
300
301
|
) as span:
|
|
301
302
|
start_time = time.time()
|
|
302
303
|
result = await original_func(*args, **kwargs)
|
|
@@ -319,7 +320,7 @@ def wrap_model_request_sync(original_func: Any) -> Any:
|
|
|
319
320
|
name="model_request_sync",
|
|
320
321
|
type=SpanTypeAttribute.LLM,
|
|
321
322
|
input=input_data,
|
|
322
|
-
metadata=
|
|
323
|
+
metadata=metadata,
|
|
323
324
|
) as span:
|
|
324
325
|
start_time = time.time()
|
|
325
326
|
result = original_func(*args, **kwargs)
|
|
@@ -391,7 +392,7 @@ def _build_model_class_input_and_metadata(instance: Any, args: Any, kwargs: Any)
|
|
|
391
392
|
Tuple of (model_name, display_name, input_data, metadata)
|
|
392
393
|
"""
|
|
393
394
|
model_name, provider = _extract_model_info_from_model_instance(instance)
|
|
394
|
-
display_name = model_name or
|
|
395
|
+
display_name = model_name or type(instance).__name__
|
|
395
396
|
|
|
396
397
|
messages = args[0] if len(args) > 0 else kwargs.get("messages")
|
|
397
398
|
model_settings = args[1] if len(args) > 1 else kwargs.get("model_settings")
|
|
@@ -400,7 +401,7 @@ def _build_model_class_input_and_metadata(instance: Any, args: Any, kwargs: Any)
|
|
|
400
401
|
|
|
401
402
|
input_data = {"messages": serialized_messages}
|
|
402
403
|
if model_settings is not None:
|
|
403
|
-
input_data["model_settings"] =
|
|
404
|
+
input_data["model_settings"] = bt_safe_deep_copy(model_settings)
|
|
404
405
|
|
|
405
406
|
metadata = _build_model_metadata(model_name, provider, model_settings=None)
|
|
406
407
|
|
|
@@ -419,7 +420,7 @@ def _wrap_concrete_model_class(model_class: Any):
|
|
|
419
420
|
name=f"chat {display_name}",
|
|
420
421
|
type=SpanTypeAttribute.LLM,
|
|
421
422
|
input=input_data,
|
|
422
|
-
metadata=
|
|
423
|
+
metadata=metadata,
|
|
423
424
|
) as span:
|
|
424
425
|
start_time = time.time()
|
|
425
426
|
result = await wrapped(*args, **kwargs)
|
|
@@ -457,22 +458,28 @@ class _AgentStreamWrapper(AbstractAsyncContextManager):
|
|
|
457
458
|
self.span_cm = None
|
|
458
459
|
self.start_time = None
|
|
459
460
|
self.stream_result = None
|
|
461
|
+
self._enter_task = None
|
|
462
|
+
self._first_token_time = None
|
|
460
463
|
|
|
461
464
|
async def __aenter__(self):
|
|
465
|
+
self._enter_task = asyncio.current_task()
|
|
466
|
+
|
|
462
467
|
# Use context manager properly so span stays current
|
|
463
468
|
# DON'T pass start_time here - we'll set it via metrics in __aexit__
|
|
464
469
|
self.span_cm = start_span(
|
|
465
470
|
name=self.span_name,
|
|
466
471
|
type=SpanTypeAttribute.LLM,
|
|
467
472
|
input=self.input_data if self.input_data else None,
|
|
468
|
-
metadata=
|
|
473
|
+
metadata=self.metadata,
|
|
469
474
|
)
|
|
470
|
-
|
|
475
|
+
self.span_cm.__enter__()
|
|
471
476
|
|
|
472
477
|
# Capture start time right before entering the stream (API call initiation)
|
|
473
478
|
self.start_time = time.time()
|
|
474
479
|
self.stream_result = await self.stream_cm.__aenter__()
|
|
475
|
-
|
|
480
|
+
|
|
481
|
+
# Wrap the stream result to capture first token time
|
|
482
|
+
return _StreamResultProxy(self.stream_result, self)
|
|
476
483
|
|
|
477
484
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
478
485
|
try:
|
|
@@ -482,16 +489,47 @@ class _AgentStreamWrapper(AbstractAsyncContextManager):
|
|
|
482
489
|
end_time = time.time()
|
|
483
490
|
|
|
484
491
|
output = _serialize_stream_output(self.stream_result)
|
|
485
|
-
metrics = _extract_stream_usage_metrics(
|
|
492
|
+
metrics = _extract_stream_usage_metrics(
|
|
493
|
+
self.stream_result, self.start_time, end_time, self._first_token_time
|
|
494
|
+
)
|
|
486
495
|
self.span_cm.log(output=output, metrics=metrics)
|
|
487
496
|
|
|
488
|
-
#
|
|
497
|
+
# Clean up span context
|
|
489
498
|
if self.span_cm:
|
|
490
|
-
|
|
499
|
+
if asyncio.current_task() is self._enter_task:
|
|
500
|
+
self.span_cm.__exit__(None, None, None)
|
|
501
|
+
else:
|
|
502
|
+
self.span_cm.end()
|
|
491
503
|
|
|
492
504
|
return False
|
|
493
505
|
|
|
494
506
|
|
|
507
|
+
class _StreamResultProxy:
    """Proxy around an agent stream result that records first-token time.

    Attribute access is forwarded to the wrapped stream result. The
    ``stream_text`` and ``stream_output`` methods are additionally wrapped so
    that the time at which the first item is yielded is stored on the owning
    wrapper's ``_first_token_time`` attribute.
    """

    def __init__(self, stream_result: Any, wrapper: "_AgentStreamWrapper"):
        """Remember the wrapped stream result and the wrapper holding timing state.

        Args:
            stream_result: Object whose attributes are exposed through this proxy.
            wrapper: Owner whose ``_first_token_time`` is set on the first item.
        """
        self._stream_result = stream_result
        self._wrapper = wrapper

    def __getattr__(self, name: str):
        """Delegate attribute access to the wrapped stream result.

        Args:
            name: Attribute that normal lookup on the proxy failed to find.

        Returns:
            The wrapped object's attribute, or a timing wrapper around it for
            ``stream_text`` / ``stream_output``.

        Raises:
            AttributeError: If the wrapped object lacks ``name``, or if the
                proxy's own fields have not been initialised yet.
        """
        # __getattr__ only runs after normal lookup failed, so being asked for
        # one of our own fields means __init__ never ran (copy/pickle/__new__).
        # Forwarding would re-enter __getattr__ via self._stream_result forever.
        if name in ("_stream_result", "_wrapper"):
            raise AttributeError(name)

        attr = getattr(self._stream_result, name)

        # Wrap streaming methods to capture first token time
        if callable(attr) and name in ("stream_text", "stream_output"):

            async def wrapped_method(*args, **kwargs):
                async for item in attr(*args, **kwargs):
                    # Only the very first yielded item stamps the time.
                    if self._wrapper._first_token_time is None:
                        self._wrapper._first_token_time = time.time()
                    yield item

            return wrapped_method

        return attr
|
|
531
|
+
|
|
532
|
+
|
|
495
533
|
class _DirectStreamWrapper(AbstractAsyncContextManager):
|
|
496
534
|
"""Wrapper for model_request_stream() that adds tracing while passing through the stream."""
|
|
497
535
|
|
|
@@ -503,22 +541,28 @@ class _DirectStreamWrapper(AbstractAsyncContextManager):
|
|
|
503
541
|
self.span_cm = None
|
|
504
542
|
self.start_time = None
|
|
505
543
|
self.stream = None
|
|
544
|
+
self._enter_task = None
|
|
545
|
+
self._first_token_time = None
|
|
506
546
|
|
|
507
547
|
async def __aenter__(self):
|
|
548
|
+
self._enter_task = asyncio.current_task()
|
|
549
|
+
|
|
508
550
|
# Use context manager properly so span stays current
|
|
509
551
|
# DON'T pass start_time here - we'll set it via metrics in __aexit__
|
|
510
552
|
self.span_cm = start_span(
|
|
511
553
|
name=self.span_name,
|
|
512
554
|
type=SpanTypeAttribute.LLM,
|
|
513
555
|
input=self.input_data if self.input_data else None,
|
|
514
|
-
metadata=
|
|
556
|
+
metadata=self.metadata,
|
|
515
557
|
)
|
|
516
|
-
|
|
558
|
+
self.span_cm.__enter__()
|
|
517
559
|
|
|
518
560
|
# Capture start time right before entering the stream (API call initiation)
|
|
519
561
|
self.start_time = time.time()
|
|
520
562
|
self.stream = await self.stream_cm.__aenter__()
|
|
521
|
-
|
|
563
|
+
|
|
564
|
+
# Wrap the stream to capture first token time
|
|
565
|
+
return _DirectStreamIteratorProxy(self.stream, self)
|
|
522
566
|
|
|
523
567
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
524
568
|
try:
|
|
@@ -530,18 +574,53 @@ class _DirectStreamWrapper(AbstractAsyncContextManager):
|
|
|
530
574
|
try:
|
|
531
575
|
final_response = self.stream.get()
|
|
532
576
|
output = _serialize_model_response(final_response)
|
|
533
|
-
metrics = _extract_response_metrics(
|
|
577
|
+
metrics = _extract_response_metrics(
|
|
578
|
+
final_response, self.start_time, end_time, self._first_token_time
|
|
579
|
+
)
|
|
534
580
|
self.span_cm.log(output=output, metrics=metrics)
|
|
535
581
|
except Exception as e:
|
|
536
582
|
logger.debug(f"Failed to extract stream output/metrics: {e}")
|
|
537
583
|
|
|
538
|
-
#
|
|
584
|
+
# Clean up span context
|
|
539
585
|
if self.span_cm:
|
|
540
|
-
|
|
586
|
+
if asyncio.current_task() is self._enter_task:
|
|
587
|
+
self.span_cm.__exit__(None, None, None)
|
|
588
|
+
else:
|
|
589
|
+
self.span_cm.end()
|
|
541
590
|
|
|
542
591
|
return False
|
|
543
592
|
|
|
544
593
|
|
|
594
|
+
class _DirectStreamIteratorProxy:
    """Async-iterable proxy for a direct stream that records first-token time.

    Attribute access is forwarded to the wrapped stream. Iterating the proxy
    pulls items from the wrapped stream and, on the first item, stores the
    current time on the owning wrapper's ``_first_token_time`` attribute.
    """

    def __init__(self, stream: Any, wrapper: "_DirectStreamWrapper"):
        """Remember the wrapped stream and the wrapper holding timing state.

        Args:
            stream: Object being proxied; iterated via ``__aiter__`` when it has one.
            wrapper: Owner whose ``_first_token_time`` is set on the first item.
        """
        self._stream = stream
        self._wrapper = wrapper
        self._iterator = None

    def __getattr__(self, name: str):
        """Delegate attribute access to the wrapped stream.

        Raises:
            AttributeError: If the wrapped stream lacks ``name``, or if the
                proxy's own fields have not been initialised yet.
        """
        # Reaching __getattr__ for our own fields means __init__ never ran
        # (copy/pickle/__new__); forwarding would recurse via self._stream.
        if name in ("_stream", "_wrapper", "_iterator"):
            raise AttributeError(name)
        return getattr(self._stream, name)

    def _resolve_iterator(self):
        """Return the wrapped stream's async iterator, or the stream itself if it has no ``__aiter__``."""
        stream = self._stream
        return stream.__aiter__() if hasattr(stream, "__aiter__") else stream

    def __aiter__(self):
        """Start a fresh iteration over the wrapped stream and return this proxy."""
        self._iterator = self._resolve_iterator()
        return self

    async def __anext__(self):
        """Return the next item, stamping first-token time on the first one.

        ``StopAsyncIteration`` from the underlying iterator propagates unchanged.
        """
        if self._iterator is None:
            # In case __aiter__ wasn't called, initialize it
            self._iterator = self._resolve_iterator()

        item = await self._iterator.__anext__()
        if self._wrapper._first_token_time is None:
            self._wrapper._first_token_time = time.time()
        return item
|
|
622
|
+
|
|
623
|
+
|
|
545
624
|
class _AgentStreamResultSyncProxy:
|
|
546
625
|
"""Proxy for agent.run_stream_sync() result that adds tracing while delegating to actual stream result."""
|
|
547
626
|
|
|
@@ -552,20 +631,25 @@ class _AgentStreamResultSyncProxy:
|
|
|
552
631
|
self._start_time = start_time
|
|
553
632
|
self._logged = False
|
|
554
633
|
self._finalize_on_del = True
|
|
634
|
+
self._first_token_time = None
|
|
555
635
|
|
|
556
636
|
def __getattr__(self, name: str):
|
|
557
637
|
"""Delegate all attribute access to the wrapped stream result."""
|
|
558
638
|
attr = getattr(self._stream_result, name)
|
|
559
639
|
|
|
560
640
|
# Wrap any method that returns an iterator to auto-finalize when exhausted
|
|
561
|
-
if callable(attr) and name in (
|
|
641
|
+
if callable(attr) and name in ("stream_text", "stream_output", "__iter__"):
|
|
642
|
+
|
|
562
643
|
def wrapped_method(*args, **kwargs):
|
|
563
644
|
try:
|
|
564
645
|
iterator = attr(*args, **kwargs)
|
|
565
646
|
# If it's an iterator, wrap it
|
|
566
|
-
if hasattr(iterator,
|
|
647
|
+
if hasattr(iterator, "__iter__") or hasattr(iterator, "__next__"):
|
|
567
648
|
try:
|
|
568
|
-
|
|
649
|
+
for item in iterator:
|
|
650
|
+
if self._first_token_time is None:
|
|
651
|
+
self._first_token_time = time.time()
|
|
652
|
+
yield item
|
|
569
653
|
finally:
|
|
570
654
|
self._finalize()
|
|
571
655
|
self._finalize_on_del = False # Don't finalize again in __del__
|
|
@@ -575,6 +659,7 @@ class _AgentStreamResultSyncProxy:
|
|
|
575
659
|
self._finalize()
|
|
576
660
|
self._finalize_on_del = False
|
|
577
661
|
raise
|
|
662
|
+
|
|
578
663
|
return wrapped_method
|
|
579
664
|
|
|
580
665
|
return attr
|
|
@@ -585,7 +670,9 @@ class _AgentStreamResultSyncProxy:
|
|
|
585
670
|
try:
|
|
586
671
|
end_time = time.time()
|
|
587
672
|
output = _serialize_stream_output(self._stream_result)
|
|
588
|
-
metrics = _extract_stream_usage_metrics(
|
|
673
|
+
metrics = _extract_stream_usage_metrics(
|
|
674
|
+
self._stream_result, self._start_time, end_time, self._first_token_time
|
|
675
|
+
)
|
|
589
676
|
self._span.log(output=output, metrics=metrics)
|
|
590
677
|
self._logged = True
|
|
591
678
|
finally:
|
|
@@ -611,6 +698,7 @@ class _DirectStreamWrapperSync:
|
|
|
611
698
|
self.span_cm = None
|
|
612
699
|
self.start_time = None
|
|
613
700
|
self.stream = None
|
|
701
|
+
self._first_token_time = None
|
|
614
702
|
|
|
615
703
|
def __enter__(self):
|
|
616
704
|
# Use context manager properly so span stays current
|
|
@@ -619,14 +707,16 @@ class _DirectStreamWrapperSync:
|
|
|
619
707
|
name=self.span_name,
|
|
620
708
|
type=SpanTypeAttribute.LLM,
|
|
621
709
|
input=self.input_data if self.input_data else None,
|
|
622
|
-
metadata=
|
|
710
|
+
metadata=self.metadata,
|
|
623
711
|
)
|
|
624
712
|
span = self.span_cm.__enter__()
|
|
625
713
|
|
|
626
714
|
# Capture start time right before entering the stream (API call initiation)
|
|
627
715
|
self.start_time = time.time()
|
|
628
716
|
self.stream = self.stream_cm.__enter__()
|
|
629
|
-
|
|
717
|
+
|
|
718
|
+
# Wrap the stream to capture first token time
|
|
719
|
+
return _DirectStreamIteratorSyncProxy(self.stream, self)
|
|
630
720
|
|
|
631
721
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
632
722
|
try:
|
|
@@ -638,7 +728,9 @@ class _DirectStreamWrapperSync:
|
|
|
638
728
|
try:
|
|
639
729
|
final_response = self.stream.get()
|
|
640
730
|
output = _serialize_model_response(final_response)
|
|
641
|
-
metrics = _extract_response_metrics(
|
|
731
|
+
metrics = _extract_response_metrics(
|
|
732
|
+
final_response, self.start_time, end_time, self._first_token_time
|
|
733
|
+
)
|
|
642
734
|
self.span_cm.log(output=output, metrics=metrics)
|
|
643
735
|
except Exception as e:
|
|
644
736
|
logger.debug(f"Failed to extract stream output/metrics: {e}")
|
|
@@ -650,6 +742,36 @@ class _DirectStreamWrapperSync:
|
|
|
650
742
|
return False
|
|
651
743
|
|
|
652
744
|
|
|
745
|
+
class _DirectStreamIteratorSyncProxy:
    """Iterable proxy for a direct (sync) stream that records first-token time.

    Attribute access is forwarded to the wrapped stream. Iterating the proxy
    pulls items from the wrapped stream and, on the first item, stores the
    current time on the owning wrapper's ``_first_token_time`` attribute.
    """

    def __init__(self, stream: Any, wrapper: "_DirectStreamWrapperSync"):
        """Remember the wrapped stream and the wrapper holding timing state.

        Args:
            stream: Object being proxied; iterated via ``__iter__`` when it has one.
            wrapper: Owner whose ``_first_token_time`` is set on the first item.
        """
        self._stream = stream
        self._wrapper = wrapper
        self._iterator = None

    def __getattr__(self, name: str):
        """Delegate attribute access to the wrapped stream.

        Raises:
            AttributeError: If the wrapped stream lacks ``name``, or if the
                proxy's own fields have not been initialised yet.
        """
        # Reaching __getattr__ for our own fields means __init__ never ran
        # (copy/pickle/__new__); forwarding would recurse via self._stream.
        if name in ("_stream", "_wrapper", "_iterator"):
            raise AttributeError(name)
        return getattr(self._stream, name)

    def _resolve_iterator(self):
        """Return the wrapped stream's iterator, or the stream itself if it has no ``__iter__``."""
        stream = self._stream
        return stream.__iter__() if hasattr(stream, "__iter__") else stream

    def __iter__(self):
        """Start a fresh iteration over the wrapped stream and return this proxy."""
        self._iterator = self._resolve_iterator()
        return self

    def __next__(self):
        """Return the next item, stamping first-token time on the first one.

        ``StopIteration`` from the underlying iterator propagates unchanged.
        """
        if self._iterator is None:
            # In case __iter__ wasn't called, initialize it
            self._iterator = self._resolve_iterator()

        item = self._iterator.__next__()
        if self._wrapper._first_token_time is None:
            self._wrapper._first_token_time = time.time()
        return item
|
|
773
|
+
|
|
774
|
+
|
|
653
775
|
def _serialize_user_prompt(user_prompt: Any) -> Any:
|
|
654
776
|
"""Serialize user prompt, handling BinaryContent and other types."""
|
|
655
777
|
if user_prompt is None:
|
|
@@ -665,7 +787,14 @@ def _serialize_user_prompt(user_prompt: Any) -> Any:
|
|
|
665
787
|
|
|
666
788
|
|
|
667
789
|
def _serialize_content_part(part: Any) -> Any:
|
|
668
|
-
"""Serialize a content part, handling BinaryContent specially.
|
|
790
|
+
"""Serialize a content part, handling BinaryContent specially.
|
|
791
|
+
|
|
792
|
+
This function handles:
|
|
793
|
+
- BinaryContent: converts to Braintrust Attachment
|
|
794
|
+
- Parts with nested content (UserPromptPart): recursively serializes content items
|
|
795
|
+
- Strings: passes through unchanged
|
|
796
|
+
- Other objects: converts to dict via model_dump
|
|
797
|
+
"""
|
|
669
798
|
if part is None:
|
|
670
799
|
return None
|
|
671
800
|
|
|
@@ -680,10 +809,25 @@ def _serialize_content_part(part: Any) -> Any:
|
|
|
680
809
|
attachment = Attachment(data=data, filename=filename, content_type=media_type)
|
|
681
810
|
return {"type": "binary", "attachment": attachment, "media_type": media_type}
|
|
682
811
|
|
|
812
|
+
if hasattr(part, "content"):
|
|
813
|
+
content = part.content
|
|
814
|
+
if isinstance(content, list):
|
|
815
|
+
serialized_content = [_serialize_content_part(item) for item in content]
|
|
816
|
+
result = bt_safe_deep_copy(part)
|
|
817
|
+
if isinstance(result, dict):
|
|
818
|
+
result["content"] = serialized_content
|
|
819
|
+
return result
|
|
820
|
+
elif content is not None:
|
|
821
|
+
serialized_content = _serialize_content_part(content)
|
|
822
|
+
result = bt_safe_deep_copy(part)
|
|
823
|
+
if isinstance(result, dict):
|
|
824
|
+
result["content"] = serialized_content
|
|
825
|
+
return result
|
|
826
|
+
|
|
683
827
|
if isinstance(part, str):
|
|
684
828
|
return part
|
|
685
829
|
|
|
686
|
-
return
|
|
830
|
+
return bt_safe_deep_copy(part)
|
|
687
831
|
|
|
688
832
|
|
|
689
833
|
def _serialize_messages(messages: Any) -> Any:
|
|
@@ -693,10 +837,24 @@ def _serialize_messages(messages: Any) -> Any:
|
|
|
693
837
|
|
|
694
838
|
result = []
|
|
695
839
|
for msg in messages:
|
|
696
|
-
|
|
840
|
+
if hasattr(msg, "parts") and msg.parts:
|
|
841
|
+
original_parts = msg.parts
|
|
842
|
+
serialized_parts = [_serialize_content_part(p) for p in original_parts]
|
|
697
843
|
|
|
698
|
-
|
|
699
|
-
|
|
844
|
+
# Use model_dump with exclude to avoid serializing parts field prematurely
|
|
845
|
+
if hasattr(msg, "model_dump"):
|
|
846
|
+
try:
|
|
847
|
+
serialized_msg = msg.model_dump(exclude={"parts"}, exclude_none=True)
|
|
848
|
+
except (TypeError, ValueError):
|
|
849
|
+
# If exclude parameter not supported, fall back to bt_safe_deep_copy
|
|
850
|
+
serialized_msg = bt_safe_deep_copy(msg)
|
|
851
|
+
else:
|
|
852
|
+
serialized_msg = bt_safe_deep_copy(msg)
|
|
853
|
+
|
|
854
|
+
if isinstance(serialized_msg, dict):
|
|
855
|
+
serialized_msg["parts"] = serialized_parts
|
|
856
|
+
else:
|
|
857
|
+
serialized_msg = bt_safe_deep_copy(msg)
|
|
700
858
|
|
|
701
859
|
result.append(serialized_msg)
|
|
702
860
|
|
|
@@ -711,12 +869,12 @@ def _serialize_result_output(result: Any) -> Any:
|
|
|
711
869
|
output_dict = {}
|
|
712
870
|
|
|
713
871
|
if hasattr(result, "output"):
|
|
714
|
-
output_dict["output"] =
|
|
872
|
+
output_dict["output"] = bt_safe_deep_copy(result.output)
|
|
715
873
|
|
|
716
874
|
if hasattr(result, "response"):
|
|
717
875
|
output_dict["response"] = _serialize_model_response(result.response)
|
|
718
876
|
|
|
719
|
-
return output_dict if output_dict else
|
|
877
|
+
return output_dict if output_dict else bt_safe_deep_copy(result)
|
|
720
878
|
|
|
721
879
|
|
|
722
880
|
def _serialize_stream_output(stream_result: Any) -> Any:
|
|
@@ -737,11 +895,10 @@ def _serialize_model_response(response: Any) -> Any:
|
|
|
737
895
|
if not response:
|
|
738
896
|
return None
|
|
739
897
|
|
|
740
|
-
response_dict =
|
|
898
|
+
response_dict = bt_safe_deep_copy(response)
|
|
741
899
|
|
|
742
|
-
if
|
|
743
|
-
|
|
744
|
-
response_dict["parts"] = [_serialize_content_part(p) for p in response.parts]
|
|
900
|
+
if hasattr(response, "parts") and isinstance(response_dict, dict):
|
|
901
|
+
response_dict["parts"] = [_serialize_content_part(p) for p in response.parts]
|
|
745
902
|
|
|
746
903
|
return response_dict
|
|
747
904
|
|
|
@@ -801,9 +958,7 @@ def _extract_model_info(agent: Any) -> tuple[str | None, str | None]:
|
|
|
801
958
|
return _extract_model_info_from_model_instance(agent.model)
|
|
802
959
|
|
|
803
960
|
|
|
804
|
-
def _build_model_metadata(
|
|
805
|
-
model_name: str | None, provider: str | None, model_settings: Any = None
|
|
806
|
-
) -> dict[str, Any]:
|
|
961
|
+
def _build_model_metadata(model_name: str | None, provider: str | None, model_settings: Any = None) -> dict[str, Any]:
|
|
807
962
|
"""Build metadata dictionary with model info.
|
|
808
963
|
|
|
809
964
|
Args:
|
|
@@ -820,7 +975,7 @@ def _build_model_metadata(
|
|
|
820
975
|
if provider:
|
|
821
976
|
metadata["provider"] = provider
|
|
822
977
|
if model_settings:
|
|
823
|
-
metadata["model_settings"] =
|
|
978
|
+
metadata["model_settings"] = bt_safe_deep_copy(model_settings)
|
|
824
979
|
return metadata
|
|
825
980
|
|
|
826
981
|
|
|
@@ -986,24 +1141,6 @@ def _is_patched(obj: Any) -> bool:
|
|
|
986
1141
|
return getattr(obj, "_braintrust_patched", False)
|
|
987
1142
|
|
|
988
1143
|
|
|
989
|
-
def _try_dict(obj: Any) -> Iterable[Any] | dict[str, Any]:
|
|
990
|
-
"""Try to convert object to dict, handling Pydantic models and circular references."""
|
|
991
|
-
if hasattr(obj, "model_dump"):
|
|
992
|
-
try:
|
|
993
|
-
obj = obj.model_dump(exclude_none=True)
|
|
994
|
-
except ValueError as e:
|
|
995
|
-
if "Circular reference" in str(e):
|
|
996
|
-
return {}
|
|
997
|
-
raise
|
|
998
|
-
|
|
999
|
-
if isinstance(obj, dict):
|
|
1000
|
-
return {k: _try_dict(v) for k, v in obj.items()}
|
|
1001
|
-
elif isinstance(obj, (list, tuple)):
|
|
1002
|
-
return [_try_dict(item) for item in obj]
|
|
1003
|
-
|
|
1004
|
-
return obj
|
|
1005
|
-
|
|
1006
|
-
|
|
1007
1144
|
def _serialize_type(obj: Any) -> Any:
|
|
1008
1145
|
"""Serialize a type/class for logging, handling Pydantic models and other types.
|
|
1009
1146
|
|
|
@@ -1047,32 +1184,7 @@ def _serialize_type(obj: Any) -> Any:
|
|
|
1047
1184
|
return obj.__name__
|
|
1048
1185
|
|
|
1049
1186
|
# Try standard serialization
|
|
1050
|
-
return
|
|
1051
|
-
|
|
1052
|
-
|
|
1053
|
-
G = TypeVar("G", bound=AsyncGenerator[Any, None])
|
|
1054
|
-
|
|
1055
|
-
|
|
1056
|
-
class aclosing(AbstractAsyncContextManager[G]):
|
|
1057
|
-
"""Context manager for closing async generators."""
|
|
1058
|
-
|
|
1059
|
-
def __init__(self, async_generator: G):
|
|
1060
|
-
self.async_generator = async_generator
|
|
1061
|
-
|
|
1062
|
-
async def __aenter__(self):
|
|
1063
|
-
return self.async_generator
|
|
1064
|
-
|
|
1065
|
-
async def __aexit__(self, *exc_info: Any):
|
|
1066
|
-
try:
|
|
1067
|
-
await self.async_generator.aclose()
|
|
1068
|
-
except ValueError as e:
|
|
1069
|
-
if "was created in a different Context" not in str(e):
|
|
1070
|
-
raise
|
|
1071
|
-
else:
|
|
1072
|
-
logger.debug(
|
|
1073
|
-
f"Suppressed ContextVar error during async cleanup: {e}. "
|
|
1074
|
-
"This is expected when async generators yield across context boundaries."
|
|
1075
|
-
)
|
|
1187
|
+
return bt_safe_deep_copy(obj)
|
|
1076
1188
|
|
|
1077
1189
|
|
|
1078
1190
|
def _build_agent_input_and_metadata(args: Any, kwargs: Any, instance: Any) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
@@ -1097,9 +1209,9 @@ def _build_agent_input_and_metadata(args: Any, kwargs: Any, instance: Any) -> tu
|
|
|
1097
1209
|
input_data[key] = _serialize_type(value) if value is not None else None
|
|
1098
1210
|
elif key == "model_settings":
|
|
1099
1211
|
# model_settings passed to run() goes in INPUT (it's a run() parameter)
|
|
1100
|
-
input_data[key] =
|
|
1212
|
+
input_data[key] = bt_safe_deep_copy(value) if value is not None else None
|
|
1101
1213
|
else:
|
|
1102
|
-
input_data[key] =
|
|
1214
|
+
input_data[key] = bt_safe_deep_copy(value) if value is not None else None
|
|
1103
1215
|
|
|
1104
1216
|
if "model" in kwargs:
|
|
1105
1217
|
model_name, provider = _parse_model_string(kwargs["model"])
|
|
@@ -1155,7 +1267,9 @@ def _build_agent_input_and_metadata(args: Any, kwargs: Any, instance: Any) -> tu
|
|
|
1155
1267
|
if hasattr(tool_obj, "description") and tool_obj.description:
|
|
1156
1268
|
tool_dict["description"] = tool_obj.description
|
|
1157
1269
|
# Extract JSON schema for parameters
|
|
1158
|
-
if hasattr(tool_obj, "function_schema") and hasattr(
|
|
1270
|
+
if hasattr(tool_obj, "function_schema") and hasattr(
|
|
1271
|
+
tool_obj.function_schema, "json_schema"
|
|
1272
|
+
):
|
|
1159
1273
|
tool_dict["parameters"] = tool_obj.function_schema.json_schema
|
|
1160
1274
|
tools_list.append(tool_dict)
|
|
1161
1275
|
ts_info["tools"] = tools_list
|
|
@@ -1196,7 +1310,7 @@ def _build_direct_model_input_and_metadata(args: Any, kwargs: Any) -> tuple[dict
|
|
|
1196
1310
|
|
|
1197
1311
|
for key, value in kwargs.items():
|
|
1198
1312
|
if key not in ["model", "messages"]:
|
|
1199
|
-
input_data[key] =
|
|
1313
|
+
input_data[key] = bt_safe_deep_copy(value) if value is not None else None
|
|
1200
1314
|
|
|
1201
1315
|
model_name, provider = _parse_model_string(model)
|
|
1202
1316
|
metadata = _build_model_metadata(model_name, provider)
|
|
@@ -547,7 +547,6 @@ def test_error_handling(memory_logger):
|
|
|
547
547
|
assert log["error"]
|
|
548
548
|
|
|
549
549
|
|
|
550
|
-
# Test 9: Stop Sequences
|
|
551
550
|
@pytest.mark.vcr
|
|
552
551
|
def test_stop_sequences(memory_logger):
|
|
553
552
|
"""Test stop sequences parameter."""
|
|
@@ -571,3 +570,65 @@ def test_stop_sequences(memory_logger):
|
|
|
571
570
|
assert len(spans) == 1
|
|
572
571
|
span = spans[0]
|
|
573
572
|
assert span["metadata"]["model"] == MODEL
|
|
573
|
+
|
|
574
|
+
|
|
575
|
+
def test_attachment_in_config(memory_logger):
    """A config dict holding an Attachment keeps that exact object after bt_safe_deep_copy."""
    from braintrust.bt_json import bt_safe_deep_copy
    from braintrust.logger import Attachment

    context_file = Attachment(data=b"config data", filename="config.txt", content_type="text/plain")

    # Config shaped like a generation config, with an attachment mixed in
    original = {
        "temperature": 0.5,
        "context_file": context_file,
        "max_output_tokens": 100,
    }

    result = bt_safe_deep_copy(original)

    # The attachment must be passed through by identity, not copied or serialized
    assert result["context_file"] is context_file
    assert result["temperature"] == 0.5
|
|
589
|
+
|
|
590
|
+
|
|
591
|
+
def test_nested_attachments_in_contents(memory_logger):
    """Attachments nested inside message parts survive bt_safe_deep_copy by identity."""
    from braintrust.bt_json import bt_safe_deep_copy
    from braintrust.logger import Attachment, ExternalAttachment

    inline_file = Attachment(data=b"file1", filename="file1.txt", content_type="text/plain")
    external_file = ExternalAttachment(
        url="s3://bucket/file2.pdf", filename="file2.pdf", content_type="application/pdf"
    )

    # Two messages, each carrying an attachment inside its second part
    user_message = {"role": "user", "parts": [{"text": "Check these files"}, {"file": inline_file}]}
    model_message = {"role": "model", "parts": [{"text": "Analyzed"}, {"result_file": external_file}]}

    result = bt_safe_deep_copy([user_message, model_message])

    # Both attachment kinds must come back as the very same objects
    assert result[0]["parts"][1]["file"] is inline_file
    assert result[1]["parts"][1]["result_file"] is external_file
|
|
610
|
+
|
|
611
|
+
|
|
612
|
+
def test_attachment_with_pydantic_model(memory_logger):
    """bt_safe_deep_copy turns a Pydantic model into a dict while leaving a sibling Attachment untouched."""
    from braintrust.bt_json import bt_safe_deep_copy
    from braintrust.logger import Attachment
    from pydantic import BaseModel

    class SampleModel(BaseModel):
        name: str
        value: int

    context_file = Attachment(data=b"model data", filename="model.txt", content_type="text/plain")
    model_instance = SampleModel(name="test", value=42)

    # One structure containing both a Pydantic model and an attachment
    result = bt_safe_deep_copy({"model_config": model_instance, "context_file": context_file})

    # The model is expected to be dumped to a plain dict
    dumped = result["model_config"]
    assert isinstance(dumped, dict)
    assert dumped["name"] == "test"

    # The attachment is expected to be the same object
    assert result["context_file"] is context_file
|