braintrust 0.3.15__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braintrust/_generated_types.py +737 -672
- braintrust/audit.py +2 -2
- braintrust/bt_json.py +178 -19
- braintrust/cli/eval.py +6 -7
- braintrust/cli/push.py +11 -11
- braintrust/context.py +12 -17
- braintrust/contrib/temporal/__init__.py +16 -27
- braintrust/contrib/temporal/test_temporal.py +8 -3
- braintrust/devserver/auth.py +8 -8
- braintrust/devserver/cache.py +3 -4
- braintrust/devserver/cors.py +8 -7
- braintrust/devserver/dataset.py +3 -5
- braintrust/devserver/eval_hooks.py +7 -6
- braintrust/devserver/schemas.py +22 -19
- braintrust/devserver/server.py +19 -12
- braintrust/devserver/test_cached_login.py +4 -4
- braintrust/framework.py +139 -142
- braintrust/framework2.py +88 -87
- braintrust/functions/invoke.py +66 -59
- braintrust/functions/stream.py +3 -2
- braintrust/generated_types.py +3 -1
- braintrust/git_fields.py +11 -11
- braintrust/gitutil.py +2 -3
- braintrust/graph_util.py +10 -10
- braintrust/id_gen.py +2 -2
- braintrust/logger.py +373 -471
- braintrust/merge_row_batch.py +10 -9
- braintrust/oai.py +21 -20
- braintrust/otel/__init__.py +49 -49
- braintrust/otel/context.py +16 -30
- braintrust/otel/test_distributed_tracing.py +14 -11
- braintrust/otel/test_otel_bt_integration.py +32 -31
- braintrust/parameters.py +8 -8
- braintrust/prompt.py +14 -14
- braintrust/prompt_cache/disk_cache.py +5 -4
- braintrust/prompt_cache/lru_cache.py +3 -2
- braintrust/prompt_cache/prompt_cache.py +13 -14
- braintrust/queue.py +4 -4
- braintrust/score.py +4 -4
- braintrust/serializable_data_class.py +4 -4
- braintrust/span_identifier_v1.py +1 -2
- braintrust/span_identifier_v2.py +3 -4
- braintrust/span_identifier_v3.py +23 -20
- braintrust/span_identifier_v4.py +34 -25
- braintrust/test_bt_json.py +644 -0
- braintrust/test_framework.py +72 -6
- braintrust/test_helpers.py +5 -5
- braintrust/test_id_gen.py +2 -3
- braintrust/test_logger.py +211 -107
- braintrust/test_otel.py +61 -53
- braintrust/test_queue.py +0 -1
- braintrust/test_score.py +1 -3
- braintrust/test_span_components.py +29 -44
- braintrust/util.py +9 -8
- braintrust/version.py +2 -2
- braintrust/wrappers/_anthropic_utils.py +4 -4
- braintrust/wrappers/agno/__init__.py +3 -4
- braintrust/wrappers/agno/agent.py +1 -2
- braintrust/wrappers/agno/function_call.py +1 -2
- braintrust/wrappers/agno/model.py +1 -2
- braintrust/wrappers/agno/team.py +1 -2
- braintrust/wrappers/agno/utils.py +12 -12
- braintrust/wrappers/anthropic.py +7 -8
- braintrust/wrappers/claude_agent_sdk/__init__.py +3 -4
- braintrust/wrappers/claude_agent_sdk/_wrapper.py +29 -27
- braintrust/wrappers/dspy.py +15 -17
- braintrust/wrappers/google_genai/__init__.py +17 -30
- braintrust/wrappers/langchain.py +22 -24
- braintrust/wrappers/litellm.py +4 -3
- braintrust/wrappers/openai.py +15 -15
- braintrust/wrappers/pydantic_ai.py +225 -110
- braintrust/wrappers/test_agno.py +0 -1
- braintrust/wrappers/test_dspy.py +0 -1
- braintrust/wrappers/test_google_genai.py +64 -4
- braintrust/wrappers/test_litellm.py +0 -1
- braintrust/wrappers/test_pydantic_ai_integration.py +819 -22
- {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/METADATA +3 -2
- braintrust-0.4.1.dist-info/RECORD +121 -0
- braintrust-0.3.15.dist-info/RECORD +0 -120
- {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/WHEEL +0 -0
- {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/entry_points.txt +0 -0
- {braintrust-0.3.15.dist-info → braintrust-0.4.1.dist-info}/top_level.txt +0 -0
|
@@ -1,9 +1,11 @@
|
|
|
1
|
+
import asyncio
|
|
1
2
|
import logging
|
|
2
3
|
import sys
|
|
3
4
|
import time
|
|
4
5
|
from contextlib import AbstractAsyncContextManager
|
|
5
|
-
from typing import Any
|
|
6
|
+
from typing import Any
|
|
6
7
|
|
|
8
|
+
from braintrust.bt_json import bt_safe_deep_copy
|
|
7
9
|
from braintrust.logger import NOOP_SPAN, Attachment, current_span, init_logger, start_span
|
|
8
10
|
from braintrust.span_types import SpanTypeAttribute
|
|
9
11
|
from wrapt import wrap_function_wrapper
|
|
@@ -14,9 +16,9 @@ __all__ = ["setup_pydantic_ai"]
|
|
|
14
16
|
|
|
15
17
|
|
|
16
18
|
def setup_pydantic_ai(
|
|
17
|
-
api_key:
|
|
18
|
-
project_id:
|
|
19
|
-
project_name:
|
|
19
|
+
api_key: str | None = None,
|
|
20
|
+
project_id: str | None = None,
|
|
21
|
+
project_name: str | None = None,
|
|
20
22
|
) -> bool:
|
|
21
23
|
"""
|
|
22
24
|
Setup Braintrust integration with Pydantic AI. Will automatically patch Pydantic AI Agents and direct API functions for automatic tracing.
|
|
@@ -73,7 +75,7 @@ def wrap_agent(Agent: Any) -> Any:
|
|
|
73
75
|
name=f"agent_run [{instance.name}]" if hasattr(instance, "name") and instance.name else "agent_run",
|
|
74
76
|
type=SpanTypeAttribute.LLM,
|
|
75
77
|
input=input_data if input_data else None,
|
|
76
|
-
metadata=
|
|
78
|
+
metadata=metadata,
|
|
77
79
|
) as agent_span:
|
|
78
80
|
start_time = time.time()
|
|
79
81
|
result = await wrapped(*args, **kwargs)
|
|
@@ -97,7 +99,7 @@ def wrap_agent(Agent: Any) -> Any:
|
|
|
97
99
|
else "agent_run_sync",
|
|
98
100
|
type=SpanTypeAttribute.LLM,
|
|
99
101
|
input=input_data if input_data else None,
|
|
100
|
-
metadata=
|
|
102
|
+
metadata=metadata,
|
|
101
103
|
) as agent_span:
|
|
102
104
|
start_time = time.time()
|
|
103
105
|
result = wrapped(*args, **kwargs)
|
|
@@ -137,7 +139,7 @@ def wrap_agent(Agent: Any) -> Any:
|
|
|
137
139
|
name=span_name,
|
|
138
140
|
type=SpanTypeAttribute.LLM,
|
|
139
141
|
input=input_data if input_data else None,
|
|
140
|
-
metadata=
|
|
142
|
+
metadata=metadata,
|
|
141
143
|
)
|
|
142
144
|
span = span_cm.__enter__()
|
|
143
145
|
start_time = time.time()
|
|
@@ -169,7 +171,7 @@ def wrap_agent(Agent: Any) -> Any:
|
|
|
169
171
|
name=span_name,
|
|
170
172
|
type=SpanTypeAttribute.LLM,
|
|
171
173
|
input=input_data if input_data else None,
|
|
172
|
-
metadata=
|
|
174
|
+
metadata=metadata,
|
|
173
175
|
) as agent_span:
|
|
174
176
|
start_time = time.time()
|
|
175
177
|
event_count = 0
|
|
@@ -215,7 +217,7 @@ def _create_direct_model_request_wrapper():
|
|
|
215
217
|
name="model_request",
|
|
216
218
|
type=SpanTypeAttribute.LLM,
|
|
217
219
|
input=input_data,
|
|
218
|
-
metadata=
|
|
220
|
+
metadata=metadata,
|
|
219
221
|
) as span:
|
|
220
222
|
start_time = time.time()
|
|
221
223
|
result = await wrapped(*args, **kwargs)
|
|
@@ -240,7 +242,7 @@ def _create_direct_model_request_sync_wrapper():
|
|
|
240
242
|
name="model_request_sync",
|
|
241
243
|
type=SpanTypeAttribute.LLM,
|
|
242
244
|
input=input_data,
|
|
243
|
-
metadata=
|
|
245
|
+
metadata=metadata,
|
|
244
246
|
) as span:
|
|
245
247
|
start_time = time.time()
|
|
246
248
|
result = wrapped(*args, **kwargs)
|
|
@@ -295,7 +297,7 @@ def wrap_model_request(original_func: Any) -> Any:
|
|
|
295
297
|
name="model_request",
|
|
296
298
|
type=SpanTypeAttribute.LLM,
|
|
297
299
|
input=input_data,
|
|
298
|
-
metadata=
|
|
300
|
+
metadata=metadata,
|
|
299
301
|
) as span:
|
|
300
302
|
start_time = time.time()
|
|
301
303
|
result = await original_func(*args, **kwargs)
|
|
@@ -318,7 +320,7 @@ def wrap_model_request_sync(original_func: Any) -> Any:
|
|
|
318
320
|
name="model_request_sync",
|
|
319
321
|
type=SpanTypeAttribute.LLM,
|
|
320
322
|
input=input_data,
|
|
321
|
-
metadata=
|
|
323
|
+
metadata=metadata,
|
|
322
324
|
) as span:
|
|
323
325
|
start_time = time.time()
|
|
324
326
|
result = original_func(*args, **kwargs)
|
|
@@ -390,7 +392,7 @@ def _build_model_class_input_and_metadata(instance: Any, args: Any, kwargs: Any)
|
|
|
390
392
|
Tuple of (model_name, display_name, input_data, metadata)
|
|
391
393
|
"""
|
|
392
394
|
model_name, provider = _extract_model_info_from_model_instance(instance)
|
|
393
|
-
display_name = model_name or
|
|
395
|
+
display_name = model_name or type(instance).__name__
|
|
394
396
|
|
|
395
397
|
messages = args[0] if len(args) > 0 else kwargs.get("messages")
|
|
396
398
|
model_settings = args[1] if len(args) > 1 else kwargs.get("model_settings")
|
|
@@ -399,7 +401,7 @@ def _build_model_class_input_and_metadata(instance: Any, args: Any, kwargs: Any)
|
|
|
399
401
|
|
|
400
402
|
input_data = {"messages": serialized_messages}
|
|
401
403
|
if model_settings is not None:
|
|
402
|
-
input_data["model_settings"] =
|
|
404
|
+
input_data["model_settings"] = bt_safe_deep_copy(model_settings)
|
|
403
405
|
|
|
404
406
|
metadata = _build_model_metadata(model_name, provider, model_settings=None)
|
|
405
407
|
|
|
@@ -418,7 +420,7 @@ def _wrap_concrete_model_class(model_class: Any):
|
|
|
418
420
|
name=f"chat {display_name}",
|
|
419
421
|
type=SpanTypeAttribute.LLM,
|
|
420
422
|
input=input_data,
|
|
421
|
-
metadata=
|
|
423
|
+
metadata=metadata,
|
|
422
424
|
) as span:
|
|
423
425
|
start_time = time.time()
|
|
424
426
|
result = await wrapped(*args, **kwargs)
|
|
@@ -456,22 +458,28 @@ class _AgentStreamWrapper(AbstractAsyncContextManager):
|
|
|
456
458
|
self.span_cm = None
|
|
457
459
|
self.start_time = None
|
|
458
460
|
self.stream_result = None
|
|
461
|
+
self._enter_task = None
|
|
462
|
+
self._first_token_time = None
|
|
459
463
|
|
|
460
464
|
async def __aenter__(self):
|
|
465
|
+
self._enter_task = asyncio.current_task()
|
|
466
|
+
|
|
461
467
|
# Use context manager properly so span stays current
|
|
462
468
|
# DON'T pass start_time here - we'll set it via metrics in __aexit__
|
|
463
469
|
self.span_cm = start_span(
|
|
464
470
|
name=self.span_name,
|
|
465
471
|
type=SpanTypeAttribute.LLM,
|
|
466
472
|
input=self.input_data if self.input_data else None,
|
|
467
|
-
metadata=
|
|
473
|
+
metadata=self.metadata,
|
|
468
474
|
)
|
|
469
|
-
|
|
475
|
+
self.span_cm.__enter__()
|
|
470
476
|
|
|
471
477
|
# Capture start time right before entering the stream (API call initiation)
|
|
472
478
|
self.start_time = time.time()
|
|
473
479
|
self.stream_result = await self.stream_cm.__aenter__()
|
|
474
|
-
|
|
480
|
+
|
|
481
|
+
# Wrap the stream result to capture first token time
|
|
482
|
+
return _StreamResultProxy(self.stream_result, self)
|
|
475
483
|
|
|
476
484
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
477
485
|
try:
|
|
@@ -481,16 +489,47 @@ class _AgentStreamWrapper(AbstractAsyncContextManager):
|
|
|
481
489
|
end_time = time.time()
|
|
482
490
|
|
|
483
491
|
output = _serialize_stream_output(self.stream_result)
|
|
484
|
-
metrics = _extract_stream_usage_metrics(
|
|
492
|
+
metrics = _extract_stream_usage_metrics(
|
|
493
|
+
self.stream_result, self.start_time, end_time, self._first_token_time
|
|
494
|
+
)
|
|
485
495
|
self.span_cm.log(output=output, metrics=metrics)
|
|
486
496
|
|
|
487
|
-
#
|
|
497
|
+
# Clean up span context
|
|
488
498
|
if self.span_cm:
|
|
489
|
-
|
|
499
|
+
if asyncio.current_task() is self._enter_task:
|
|
500
|
+
self.span_cm.__exit__(None, None, None)
|
|
501
|
+
else:
|
|
502
|
+
self.span_cm.end()
|
|
490
503
|
|
|
491
504
|
return False
|
|
492
505
|
|
|
493
506
|
|
|
507
|
+
class _StreamResultProxy:
|
|
508
|
+
"""Proxy for stream result that captures first token time."""
|
|
509
|
+
|
|
510
|
+
def __init__(self, stream_result: Any, wrapper: _AgentStreamWrapper):
|
|
511
|
+
self._stream_result = stream_result
|
|
512
|
+
self._wrapper = wrapper
|
|
513
|
+
|
|
514
|
+
def __getattr__(self, name: str):
|
|
515
|
+
"""Delegate all attribute access to the wrapped stream result."""
|
|
516
|
+
attr = getattr(self._stream_result, name)
|
|
517
|
+
|
|
518
|
+
# Wrap streaming methods to capture first token time
|
|
519
|
+
if callable(attr) and name in ("stream_text", "stream_output"):
|
|
520
|
+
|
|
521
|
+
async def wrapped_method(*args, **kwargs):
|
|
522
|
+
result = attr(*args, **kwargs)
|
|
523
|
+
async for item in result:
|
|
524
|
+
if self._wrapper._first_token_time is None:
|
|
525
|
+
self._wrapper._first_token_time = time.time()
|
|
526
|
+
yield item
|
|
527
|
+
|
|
528
|
+
return wrapped_method
|
|
529
|
+
|
|
530
|
+
return attr
|
|
531
|
+
|
|
532
|
+
|
|
494
533
|
class _DirectStreamWrapper(AbstractAsyncContextManager):
|
|
495
534
|
"""Wrapper for model_request_stream() that adds tracing while passing through the stream."""
|
|
496
535
|
|
|
@@ -502,22 +541,28 @@ class _DirectStreamWrapper(AbstractAsyncContextManager):
|
|
|
502
541
|
self.span_cm = None
|
|
503
542
|
self.start_time = None
|
|
504
543
|
self.stream = None
|
|
544
|
+
self._enter_task = None
|
|
545
|
+
self._first_token_time = None
|
|
505
546
|
|
|
506
547
|
async def __aenter__(self):
|
|
548
|
+
self._enter_task = asyncio.current_task()
|
|
549
|
+
|
|
507
550
|
# Use context manager properly so span stays current
|
|
508
551
|
# DON'T pass start_time here - we'll set it via metrics in __aexit__
|
|
509
552
|
self.span_cm = start_span(
|
|
510
553
|
name=self.span_name,
|
|
511
554
|
type=SpanTypeAttribute.LLM,
|
|
512
555
|
input=self.input_data if self.input_data else None,
|
|
513
|
-
metadata=
|
|
556
|
+
metadata=self.metadata,
|
|
514
557
|
)
|
|
515
|
-
|
|
558
|
+
self.span_cm.__enter__()
|
|
516
559
|
|
|
517
560
|
# Capture start time right before entering the stream (API call initiation)
|
|
518
561
|
self.start_time = time.time()
|
|
519
562
|
self.stream = await self.stream_cm.__aenter__()
|
|
520
|
-
|
|
563
|
+
|
|
564
|
+
# Wrap the stream to capture first token time
|
|
565
|
+
return _DirectStreamIteratorProxy(self.stream, self)
|
|
521
566
|
|
|
522
567
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
523
568
|
try:
|
|
@@ -529,18 +574,53 @@ class _DirectStreamWrapper(AbstractAsyncContextManager):
|
|
|
529
574
|
try:
|
|
530
575
|
final_response = self.stream.get()
|
|
531
576
|
output = _serialize_model_response(final_response)
|
|
532
|
-
metrics = _extract_response_metrics(
|
|
577
|
+
metrics = _extract_response_metrics(
|
|
578
|
+
final_response, self.start_time, end_time, self._first_token_time
|
|
579
|
+
)
|
|
533
580
|
self.span_cm.log(output=output, metrics=metrics)
|
|
534
581
|
except Exception as e:
|
|
535
582
|
logger.debug(f"Failed to extract stream output/metrics: {e}")
|
|
536
583
|
|
|
537
|
-
#
|
|
584
|
+
# Clean up span context
|
|
538
585
|
if self.span_cm:
|
|
539
|
-
|
|
586
|
+
if asyncio.current_task() is self._enter_task:
|
|
587
|
+
self.span_cm.__exit__(None, None, None)
|
|
588
|
+
else:
|
|
589
|
+
self.span_cm.end()
|
|
540
590
|
|
|
541
591
|
return False
|
|
542
592
|
|
|
543
593
|
|
|
594
|
+
class _DirectStreamIteratorProxy:
|
|
595
|
+
"""Proxy for direct stream that captures first token time."""
|
|
596
|
+
|
|
597
|
+
def __init__(self, stream: Any, wrapper: _DirectStreamWrapper):
|
|
598
|
+
self._stream = stream
|
|
599
|
+
self._wrapper = wrapper
|
|
600
|
+
self._iterator = None
|
|
601
|
+
|
|
602
|
+
def __getattr__(self, name: str):
|
|
603
|
+
"""Delegate all attribute access to the wrapped stream."""
|
|
604
|
+
return getattr(self._stream, name)
|
|
605
|
+
|
|
606
|
+
def __aiter__(self):
|
|
607
|
+
"""Return async iterator that captures first token time."""
|
|
608
|
+
# Get the actual async iterator from the stream
|
|
609
|
+
self._iterator = self._stream.__aiter__() if hasattr(self._stream, "__aiter__") else self._stream
|
|
610
|
+
return self
|
|
611
|
+
|
|
612
|
+
async def __anext__(self):
|
|
613
|
+
"""Capture first token time on first iteration."""
|
|
614
|
+
if self._iterator is None:
|
|
615
|
+
# In case __aiter__ wasn't called, initialize it
|
|
616
|
+
self._iterator = self._stream.__aiter__() if hasattr(self._stream, "__aiter__") else self._stream
|
|
617
|
+
|
|
618
|
+
item = await self._iterator.__anext__()
|
|
619
|
+
if self._wrapper._first_token_time is None:
|
|
620
|
+
self._wrapper._first_token_time = time.time()
|
|
621
|
+
return item
|
|
622
|
+
|
|
623
|
+
|
|
544
624
|
class _AgentStreamResultSyncProxy:
|
|
545
625
|
"""Proxy for agent.run_stream_sync() result that adds tracing while delegating to actual stream result."""
|
|
546
626
|
|
|
@@ -551,20 +631,25 @@ class _AgentStreamResultSyncProxy:
|
|
|
551
631
|
self._start_time = start_time
|
|
552
632
|
self._logged = False
|
|
553
633
|
self._finalize_on_del = True
|
|
634
|
+
self._first_token_time = None
|
|
554
635
|
|
|
555
636
|
def __getattr__(self, name: str):
|
|
556
637
|
"""Delegate all attribute access to the wrapped stream result."""
|
|
557
638
|
attr = getattr(self._stream_result, name)
|
|
558
639
|
|
|
559
640
|
# Wrap any method that returns an iterator to auto-finalize when exhausted
|
|
560
|
-
if callable(attr) and name in (
|
|
641
|
+
if callable(attr) and name in ("stream_text", "stream_output", "__iter__"):
|
|
642
|
+
|
|
561
643
|
def wrapped_method(*args, **kwargs):
|
|
562
644
|
try:
|
|
563
645
|
iterator = attr(*args, **kwargs)
|
|
564
646
|
# If it's an iterator, wrap it
|
|
565
|
-
if hasattr(iterator,
|
|
647
|
+
if hasattr(iterator, "__iter__") or hasattr(iterator, "__next__"):
|
|
566
648
|
try:
|
|
567
|
-
|
|
649
|
+
for item in iterator:
|
|
650
|
+
if self._first_token_time is None:
|
|
651
|
+
self._first_token_time = time.time()
|
|
652
|
+
yield item
|
|
568
653
|
finally:
|
|
569
654
|
self._finalize()
|
|
570
655
|
self._finalize_on_del = False # Don't finalize again in __del__
|
|
@@ -574,6 +659,7 @@ class _AgentStreamResultSyncProxy:
|
|
|
574
659
|
self._finalize()
|
|
575
660
|
self._finalize_on_del = False
|
|
576
661
|
raise
|
|
662
|
+
|
|
577
663
|
return wrapped_method
|
|
578
664
|
|
|
579
665
|
return attr
|
|
@@ -584,7 +670,9 @@ class _AgentStreamResultSyncProxy:
|
|
|
584
670
|
try:
|
|
585
671
|
end_time = time.time()
|
|
586
672
|
output = _serialize_stream_output(self._stream_result)
|
|
587
|
-
metrics = _extract_stream_usage_metrics(
|
|
673
|
+
metrics = _extract_stream_usage_metrics(
|
|
674
|
+
self._stream_result, self._start_time, end_time, self._first_token_time
|
|
675
|
+
)
|
|
588
676
|
self._span.log(output=output, metrics=metrics)
|
|
589
677
|
self._logged = True
|
|
590
678
|
finally:
|
|
@@ -610,6 +698,7 @@ class _DirectStreamWrapperSync:
|
|
|
610
698
|
self.span_cm = None
|
|
611
699
|
self.start_time = None
|
|
612
700
|
self.stream = None
|
|
701
|
+
self._first_token_time = None
|
|
613
702
|
|
|
614
703
|
def __enter__(self):
|
|
615
704
|
# Use context manager properly so span stays current
|
|
@@ -618,14 +707,16 @@ class _DirectStreamWrapperSync:
|
|
|
618
707
|
name=self.span_name,
|
|
619
708
|
type=SpanTypeAttribute.LLM,
|
|
620
709
|
input=self.input_data if self.input_data else None,
|
|
621
|
-
metadata=
|
|
710
|
+
metadata=self.metadata,
|
|
622
711
|
)
|
|
623
712
|
span = self.span_cm.__enter__()
|
|
624
713
|
|
|
625
714
|
# Capture start time right before entering the stream (API call initiation)
|
|
626
715
|
self.start_time = time.time()
|
|
627
716
|
self.stream = self.stream_cm.__enter__()
|
|
628
|
-
|
|
717
|
+
|
|
718
|
+
# Wrap the stream to capture first token time
|
|
719
|
+
return _DirectStreamIteratorSyncProxy(self.stream, self)
|
|
629
720
|
|
|
630
721
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
631
722
|
try:
|
|
@@ -637,7 +728,9 @@ class _DirectStreamWrapperSync:
|
|
|
637
728
|
try:
|
|
638
729
|
final_response = self.stream.get()
|
|
639
730
|
output = _serialize_model_response(final_response)
|
|
640
|
-
metrics = _extract_response_metrics(
|
|
731
|
+
metrics = _extract_response_metrics(
|
|
732
|
+
final_response, self.start_time, end_time, self._first_token_time
|
|
733
|
+
)
|
|
641
734
|
self.span_cm.log(output=output, metrics=metrics)
|
|
642
735
|
except Exception as e:
|
|
643
736
|
logger.debug(f"Failed to extract stream output/metrics: {e}")
|
|
@@ -649,6 +742,36 @@ class _DirectStreamWrapperSync:
|
|
|
649
742
|
return False
|
|
650
743
|
|
|
651
744
|
|
|
745
|
+
class _DirectStreamIteratorSyncProxy:
|
|
746
|
+
"""Proxy for direct stream (sync) that captures first token time."""
|
|
747
|
+
|
|
748
|
+
def __init__(self, stream: Any, wrapper: _DirectStreamWrapperSync):
|
|
749
|
+
self._stream = stream
|
|
750
|
+
self._wrapper = wrapper
|
|
751
|
+
self._iterator = None
|
|
752
|
+
|
|
753
|
+
def __getattr__(self, name: str):
|
|
754
|
+
"""Delegate all attribute access to the wrapped stream."""
|
|
755
|
+
return getattr(self._stream, name)
|
|
756
|
+
|
|
757
|
+
def __iter__(self):
|
|
758
|
+
"""Return iterator that captures first token time."""
|
|
759
|
+
# Get the actual iterator from the stream
|
|
760
|
+
self._iterator = self._stream.__iter__() if hasattr(self._stream, "__iter__") else self._stream
|
|
761
|
+
return self
|
|
762
|
+
|
|
763
|
+
def __next__(self):
|
|
764
|
+
"""Capture first token time on first iteration."""
|
|
765
|
+
if self._iterator is None:
|
|
766
|
+
# In case __iter__ wasn't called, initialize it
|
|
767
|
+
self._iterator = self._stream.__iter__() if hasattr(self._stream, "__iter__") else self._stream
|
|
768
|
+
|
|
769
|
+
item = self._iterator.__next__()
|
|
770
|
+
if self._wrapper._first_token_time is None:
|
|
771
|
+
self._wrapper._first_token_time = time.time()
|
|
772
|
+
return item
|
|
773
|
+
|
|
774
|
+
|
|
652
775
|
def _serialize_user_prompt(user_prompt: Any) -> Any:
|
|
653
776
|
"""Serialize user prompt, handling BinaryContent and other types."""
|
|
654
777
|
if user_prompt is None:
|
|
@@ -664,7 +787,14 @@ def _serialize_user_prompt(user_prompt: Any) -> Any:
|
|
|
664
787
|
|
|
665
788
|
|
|
666
789
|
def _serialize_content_part(part: Any) -> Any:
|
|
667
|
-
"""Serialize a content part, handling BinaryContent specially.
|
|
790
|
+
"""Serialize a content part, handling BinaryContent specially.
|
|
791
|
+
|
|
792
|
+
This function handles:
|
|
793
|
+
- BinaryContent: converts to Braintrust Attachment
|
|
794
|
+
- Parts with nested content (UserPromptPart): recursively serializes content items
|
|
795
|
+
- Strings: passes through unchanged
|
|
796
|
+
- Other objects: converts to dict via model_dump
|
|
797
|
+
"""
|
|
668
798
|
if part is None:
|
|
669
799
|
return None
|
|
670
800
|
|
|
@@ -679,10 +809,25 @@ def _serialize_content_part(part: Any) -> Any:
|
|
|
679
809
|
attachment = Attachment(data=data, filename=filename, content_type=media_type)
|
|
680
810
|
return {"type": "binary", "attachment": attachment, "media_type": media_type}
|
|
681
811
|
|
|
812
|
+
if hasattr(part, "content"):
|
|
813
|
+
content = part.content
|
|
814
|
+
if isinstance(content, list):
|
|
815
|
+
serialized_content = [_serialize_content_part(item) for item in content]
|
|
816
|
+
result = bt_safe_deep_copy(part)
|
|
817
|
+
if isinstance(result, dict):
|
|
818
|
+
result["content"] = serialized_content
|
|
819
|
+
return result
|
|
820
|
+
elif content is not None:
|
|
821
|
+
serialized_content = _serialize_content_part(content)
|
|
822
|
+
result = bt_safe_deep_copy(part)
|
|
823
|
+
if isinstance(result, dict):
|
|
824
|
+
result["content"] = serialized_content
|
|
825
|
+
return result
|
|
826
|
+
|
|
682
827
|
if isinstance(part, str):
|
|
683
828
|
return part
|
|
684
829
|
|
|
685
|
-
return
|
|
830
|
+
return bt_safe_deep_copy(part)
|
|
686
831
|
|
|
687
832
|
|
|
688
833
|
def _serialize_messages(messages: Any) -> Any:
|
|
@@ -692,10 +837,24 @@ def _serialize_messages(messages: Any) -> Any:
|
|
|
692
837
|
|
|
693
838
|
result = []
|
|
694
839
|
for msg in messages:
|
|
695
|
-
|
|
840
|
+
if hasattr(msg, "parts") and msg.parts:
|
|
841
|
+
original_parts = msg.parts
|
|
842
|
+
serialized_parts = [_serialize_content_part(p) for p in original_parts]
|
|
696
843
|
|
|
697
|
-
|
|
698
|
-
|
|
844
|
+
# Use model_dump with exclude to avoid serializing parts field prematurely
|
|
845
|
+
if hasattr(msg, "model_dump"):
|
|
846
|
+
try:
|
|
847
|
+
serialized_msg = msg.model_dump(exclude={"parts"}, exclude_none=True)
|
|
848
|
+
except (TypeError, ValueError):
|
|
849
|
+
# If exclude parameter not supported, fall back to bt_safe_deep_copy
|
|
850
|
+
serialized_msg = bt_safe_deep_copy(msg)
|
|
851
|
+
else:
|
|
852
|
+
serialized_msg = bt_safe_deep_copy(msg)
|
|
853
|
+
|
|
854
|
+
if isinstance(serialized_msg, dict):
|
|
855
|
+
serialized_msg["parts"] = serialized_parts
|
|
856
|
+
else:
|
|
857
|
+
serialized_msg = bt_safe_deep_copy(msg)
|
|
699
858
|
|
|
700
859
|
result.append(serialized_msg)
|
|
701
860
|
|
|
@@ -710,12 +869,12 @@ def _serialize_result_output(result: Any) -> Any:
|
|
|
710
869
|
output_dict = {}
|
|
711
870
|
|
|
712
871
|
if hasattr(result, "output"):
|
|
713
|
-
output_dict["output"] =
|
|
872
|
+
output_dict["output"] = bt_safe_deep_copy(result.output)
|
|
714
873
|
|
|
715
874
|
if hasattr(result, "response"):
|
|
716
875
|
output_dict["response"] = _serialize_model_response(result.response)
|
|
717
876
|
|
|
718
|
-
return output_dict if output_dict else
|
|
877
|
+
return output_dict if output_dict else bt_safe_deep_copy(result)
|
|
719
878
|
|
|
720
879
|
|
|
721
880
|
def _serialize_stream_output(stream_result: Any) -> Any:
|
|
@@ -736,16 +895,15 @@ def _serialize_model_response(response: Any) -> Any:
|
|
|
736
895
|
if not response:
|
|
737
896
|
return None
|
|
738
897
|
|
|
739
|
-
response_dict =
|
|
898
|
+
response_dict = bt_safe_deep_copy(response)
|
|
740
899
|
|
|
741
|
-
if
|
|
742
|
-
|
|
743
|
-
response_dict["parts"] = [_serialize_content_part(p) for p in response.parts]
|
|
900
|
+
if hasattr(response, "parts") and isinstance(response_dict, dict):
|
|
901
|
+
response_dict["parts"] = [_serialize_content_part(p) for p in response.parts]
|
|
744
902
|
|
|
745
903
|
return response_dict
|
|
746
904
|
|
|
747
905
|
|
|
748
|
-
def _extract_model_info_from_model_instance(model: Any) -> tuple[
|
|
906
|
+
def _extract_model_info_from_model_instance(model: Any) -> tuple[str | None, str | None]:
|
|
749
907
|
"""Extract model name and provider from a model instance.
|
|
750
908
|
|
|
751
909
|
Args:
|
|
@@ -785,7 +943,7 @@ def _extract_model_info_from_model_instance(model: Any) -> tuple[Optional[str],
|
|
|
785
943
|
return None, None
|
|
786
944
|
|
|
787
945
|
|
|
788
|
-
def _extract_model_info(agent: Any) -> tuple[
|
|
946
|
+
def _extract_model_info(agent: Any) -> tuple[str | None, str | None]:
|
|
789
947
|
"""Extract model name and provider from agent.
|
|
790
948
|
|
|
791
949
|
Args:
|
|
@@ -800,9 +958,7 @@ def _extract_model_info(agent: Any) -> tuple[Optional[str], Optional[str]]:
|
|
|
800
958
|
return _extract_model_info_from_model_instance(agent.model)
|
|
801
959
|
|
|
802
960
|
|
|
803
|
-
def _build_model_metadata(
|
|
804
|
-
model_name: Optional[str], provider: Optional[str], model_settings: Any = None
|
|
805
|
-
) -> Dict[str, Any]:
|
|
961
|
+
def _build_model_metadata(model_name: str | None, provider: str | None, model_settings: Any = None) -> dict[str, Any]:
|
|
806
962
|
"""Build metadata dictionary with model info.
|
|
807
963
|
|
|
808
964
|
Args:
|
|
@@ -819,11 +975,11 @@ def _build_model_metadata(
|
|
|
819
975
|
if provider:
|
|
820
976
|
metadata["provider"] = provider
|
|
821
977
|
if model_settings:
|
|
822
|
-
metadata["model_settings"] =
|
|
978
|
+
metadata["model_settings"] = bt_safe_deep_copy(model_settings)
|
|
823
979
|
return metadata
|
|
824
980
|
|
|
825
981
|
|
|
826
|
-
def _parse_model_string(model: Any) -> tuple[
|
|
982
|
+
def _parse_model_string(model: Any) -> tuple[str | None, str | None]:
|
|
827
983
|
"""Parse model string to extract provider and model name.
|
|
828
984
|
|
|
829
985
|
Pydantic AI uses format: "provider:model-name" (e.g., "openai:gpt-4o")
|
|
@@ -840,9 +996,9 @@ def _parse_model_string(model: Any) -> tuple[Optional[str], Optional[str]]:
|
|
|
840
996
|
return model_str, None
|
|
841
997
|
|
|
842
998
|
|
|
843
|
-
def _extract_usage_metrics(result: Any, start_time: float, end_time: float) ->
|
|
999
|
+
def _extract_usage_metrics(result: Any, start_time: float, end_time: float) -> dict[str, float] | None:
|
|
844
1000
|
"""Extract usage metrics from agent run result."""
|
|
845
|
-
metrics:
|
|
1001
|
+
metrics: dict[str, float] = {}
|
|
846
1002
|
|
|
847
1003
|
metrics["start"] = start_time
|
|
848
1004
|
metrics["end"] = end_time
|
|
@@ -903,10 +1059,10 @@ def _extract_usage_metrics(result: Any, start_time: float, end_time: float) -> O
|
|
|
903
1059
|
|
|
904
1060
|
|
|
905
1061
|
def _extract_stream_usage_metrics(
|
|
906
|
-
stream_result: Any, start_time: float, end_time: float, first_token_time:
|
|
907
|
-
) ->
|
|
1062
|
+
stream_result: Any, start_time: float, end_time: float, first_token_time: float | None
|
|
1063
|
+
) -> dict[str, float] | None:
|
|
908
1064
|
"""Extract usage metrics from stream result."""
|
|
909
|
-
metrics:
|
|
1065
|
+
metrics: dict[str, float] = {}
|
|
910
1066
|
|
|
911
1067
|
metrics["start"] = start_time
|
|
912
1068
|
metrics["end"] = end_time
|
|
@@ -942,10 +1098,10 @@ def _extract_stream_usage_metrics(
|
|
|
942
1098
|
|
|
943
1099
|
|
|
944
1100
|
def _extract_response_metrics(
|
|
945
|
-
response: Any, start_time: float, end_time: float, first_token_time:
|
|
946
|
-
) ->
|
|
1101
|
+
response: Any, start_time: float, end_time: float, first_token_time: float | None = None
|
|
1102
|
+
) -> dict[str, float] | None:
|
|
947
1103
|
"""Extract metrics from model response."""
|
|
948
|
-
metrics:
|
|
1104
|
+
metrics: dict[str, float] = {}
|
|
949
1105
|
|
|
950
1106
|
metrics["start"] = start_time
|
|
951
1107
|
metrics["end"] = end_time
|
|
@@ -985,24 +1141,6 @@ def _is_patched(obj: Any) -> bool:
|
|
|
985
1141
|
return getattr(obj, "_braintrust_patched", False)
|
|
986
1142
|
|
|
987
1143
|
|
|
988
|
-
def _try_dict(obj: Any) -> Union[Iterable[Any], Dict[str, Any]]:
|
|
989
|
-
"""Try to convert object to dict, handling Pydantic models and circular references."""
|
|
990
|
-
if hasattr(obj, "model_dump"):
|
|
991
|
-
try:
|
|
992
|
-
obj = obj.model_dump(exclude_none=True)
|
|
993
|
-
except ValueError as e:
|
|
994
|
-
if "Circular reference" in str(e):
|
|
995
|
-
return {}
|
|
996
|
-
raise
|
|
997
|
-
|
|
998
|
-
if isinstance(obj, dict):
|
|
999
|
-
return {k: _try_dict(v) for k, v in obj.items()}
|
|
1000
|
-
elif isinstance(obj, (list, tuple)):
|
|
1001
|
-
return [_try_dict(item) for item in obj]
|
|
1002
|
-
|
|
1003
|
-
return obj
|
|
1004
|
-
|
|
1005
|
-
|
|
1006
1144
|
def _serialize_type(obj: Any) -> Any:
|
|
1007
1145
|
"""Serialize a type/class for logging, handling Pydantic models and other types.
|
|
1008
1146
|
|
|
@@ -1046,35 +1184,10 @@ def _serialize_type(obj: Any) -> Any:
|
|
|
1046
1184
|
return obj.__name__
|
|
1047
1185
|
|
|
1048
1186
|
# Try standard serialization
|
|
1049
|
-
return
|
|
1050
|
-
|
|
1051
|
-
|
|
1052
|
-
G = TypeVar("G", bound=AsyncGenerator[Any, None])
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
class aclosing(AbstractAsyncContextManager[G]):
|
|
1056
|
-
"""Context manager for closing async generators."""
|
|
1057
|
-
|
|
1058
|
-
def __init__(self, async_generator: G):
|
|
1059
|
-
self.async_generator = async_generator
|
|
1060
|
-
|
|
1061
|
-
async def __aenter__(self):
|
|
1062
|
-
return self.async_generator
|
|
1063
|
-
|
|
1064
|
-
async def __aexit__(self, *exc_info: Any):
|
|
1065
|
-
try:
|
|
1066
|
-
await self.async_generator.aclose()
|
|
1067
|
-
except ValueError as e:
|
|
1068
|
-
if "was created in a different Context" not in str(e):
|
|
1069
|
-
raise
|
|
1070
|
-
else:
|
|
1071
|
-
logger.debug(
|
|
1072
|
-
f"Suppressed ContextVar error during async cleanup: {e}. "
|
|
1073
|
-
"This is expected when async generators yield across context boundaries."
|
|
1074
|
-
)
|
|
1187
|
+
return bt_safe_deep_copy(obj)
|
|
1075
1188
|
|
|
1076
1189
|
|
|
1077
|
-
def _build_agent_input_and_metadata(args: Any, kwargs: Any, instance: Any) -> tuple[
|
|
1190
|
+
def _build_agent_input_and_metadata(args: Any, kwargs: Any, instance: Any) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
1078
1191
|
"""Build input data and metadata for agent wrappers.
|
|
1079
1192
|
|
|
1080
1193
|
Returns:
|
|
@@ -1096,9 +1209,9 @@ def _build_agent_input_and_metadata(args: Any, kwargs: Any, instance: Any) -> tu
|
|
|
1096
1209
|
input_data[key] = _serialize_type(value) if value is not None else None
|
|
1097
1210
|
elif key == "model_settings":
|
|
1098
1211
|
# model_settings passed to run() goes in INPUT (it's a run() parameter)
|
|
1099
|
-
input_data[key] =
|
|
1212
|
+
input_data[key] = bt_safe_deep_copy(value) if value is not None else None
|
|
1100
1213
|
else:
|
|
1101
|
-
input_data[key] =
|
|
1214
|
+
input_data[key] = bt_safe_deep_copy(value) if value is not None else None
|
|
1102
1215
|
|
|
1103
1216
|
if "model" in kwargs:
|
|
1104
1217
|
model_name, provider = _parse_model_string(kwargs["model"])
|
|
@@ -1154,7 +1267,9 @@ def _build_agent_input_and_metadata(args: Any, kwargs: Any, instance: Any) -> tu
|
|
|
1154
1267
|
if hasattr(tool_obj, "description") and tool_obj.description:
|
|
1155
1268
|
tool_dict["description"] = tool_obj.description
|
|
1156
1269
|
# Extract JSON schema for parameters
|
|
1157
|
-
if hasattr(tool_obj, "function_schema") and hasattr(
|
|
1270
|
+
if hasattr(tool_obj, "function_schema") and hasattr(
|
|
1271
|
+
tool_obj.function_schema, "json_schema"
|
|
1272
|
+
):
|
|
1158
1273
|
tool_dict["parameters"] = tool_obj.function_schema.json_schema
|
|
1159
1274
|
tools_list.append(tool_dict)
|
|
1160
1275
|
ts_info["tools"] = tools_list
|
|
@@ -1177,7 +1292,7 @@ def _build_agent_input_and_metadata(args: Any, kwargs: Any, instance: Any) -> tu
|
|
|
1177
1292
|
return input_data, metadata
|
|
1178
1293
|
|
|
1179
1294
|
|
|
1180
|
-
def _build_direct_model_input_and_metadata(args: Any, kwargs: Any) -> tuple[
|
|
1295
|
+
def _build_direct_model_input_and_metadata(args: Any, kwargs: Any) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
1181
1296
|
"""Build input data and metadata for direct model request wrappers.
|
|
1182
1297
|
|
|
1183
1298
|
Returns:
|
|
@@ -1195,7 +1310,7 @@ def _build_direct_model_input_and_metadata(args: Any, kwargs: Any) -> tuple[Dict
|
|
|
1195
1310
|
|
|
1196
1311
|
for key, value in kwargs.items():
|
|
1197
1312
|
if key not in ["model", "messages"]:
|
|
1198
|
-
input_data[key] =
|
|
1313
|
+
input_data[key] = bt_safe_deep_copy(value) if value is not None else None
|
|
1199
1314
|
|
|
1200
1315
|
model_name, provider = _parse_model_string(model)
|
|
1201
1316
|
metadata = _build_model_metadata(model_name, provider)
|