braintrust 0.3.14__py3-none-any.whl → 0.3.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1203 @@
+ import logging
+ import sys
+ import time
+ from contextlib import AbstractAsyncContextManager
+ from typing import Any, AsyncGenerator, Dict, Iterable, Optional, TypeVar, Union
+
+ from braintrust.logger import NOOP_SPAN, Attachment, current_span, init_logger, start_span
+ from braintrust.span_types import SpanTypeAttribute
+ from wrapt import wrap_function_wrapper
+
+ logger = logging.getLogger(__name__)
+
+ __all__ = ["setup_pydantic_ai"]
+
+
+ def setup_pydantic_ai(
+     api_key: Optional[str] = None,
+     project_id: Optional[str] = None,
+     project_name: Optional[str] = None,
+ ) -> bool:
+     """
+     Set up the Braintrust integration with Pydantic AI. Automatically patches Pydantic AI Agents and the direct API functions for tracing.
+
+     Args:
+         api_key (Optional[str]): Braintrust API key.
+         project_id (Optional[str]): Braintrust project ID.
+         project_name (Optional[str]): Braintrust project name.
+
+     Returns:
+         bool: True if setup was successful, False otherwise.
+     """
+     span = current_span()
+     if span == NOOP_SPAN:
+         init_logger(project=project_name, api_key=api_key, project_id=project_id)
+
+     try:
+         import pydantic_ai.direct as direct_module
+         from pydantic_ai import Agent
+
+         Agent = wrap_agent(Agent)
+
+         wrap_function_wrapper(direct_module, "model_request", _create_direct_model_request_wrapper())
+         wrap_function_wrapper(direct_module, "model_request_sync", _create_direct_model_request_sync_wrapper())
+         wrap_function_wrapper(direct_module, "model_request_stream", _create_direct_model_request_stream_wrapper())
+         wrap_function_wrapper(
+             direct_module, "model_request_stream_sync", _create_direct_model_request_stream_sync_wrapper()
+         )
+
+         wrap_model_classes()
+
+         return True
+     except ImportError as e:
+         logger.error(f"Failed to import Pydantic AI: {e}")
+         logger.error("Pydantic AI is not installed. Please install it with: pip install pydantic-ai-slim")
+         return False
+
+
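A minimal usage sketch of the setup function above. Hedged: the project name, model string, and system prompt are illustrative; `result.output` assumes a recent pydantic-ai release; the import path for `setup_pydantic_ai` itself is not shown in this diff.

    from pydantic_ai import Agent

    setup_pydantic_ai(project_name="my-project")  # patch once, at startup

    agent = Agent("openai:gpt-4o", system_prompt="Be concise.")
    result = agent.run_sync("What is the capital of France?")
    print(result.output)  # the run and its internal model request are logged as spans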
+ def wrap_agent(Agent: Any) -> Any:
+     if _is_patched(Agent):
+         return Agent
+
+     def _ensure_model_wrapped(instance: Any):
+         """Ensure the agent's model class is wrapped (lazy wrapping)."""
+         if hasattr(instance, "_model"):
+             model_class = type(instance._model)
+             _wrap_concrete_model_class(model_class)
+
+     async def agent_run_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+         _ensure_model_wrapped(instance)
+         input_data, metadata = _build_agent_input_and_metadata(args, kwargs, instance)
+
+         with start_span(
+             name=f"agent_run [{instance.name}]" if hasattr(instance, "name") and instance.name else "agent_run",
+             type=SpanTypeAttribute.LLM,
+             input=input_data if input_data else None,
+             metadata=_try_dict(metadata),
+         ) as agent_span:
+             start_time = time.time()
+             result = await wrapped(*args, **kwargs)
+             end_time = time.time()
+
+             output = _serialize_result_output(result)
+             metrics = _extract_usage_metrics(result, start_time, end_time)
+
+             agent_span.log(output=output, metrics=metrics)
+             return result
+
+     wrap_function_wrapper(Agent, "run", agent_run_wrapper)
+
+     def agent_run_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+         _ensure_model_wrapped(instance)
+         input_data, metadata = _build_agent_input_and_metadata(args, kwargs, instance)
+
+         with start_span(
+             name=f"agent_run_sync [{instance.name}]"
+             if hasattr(instance, "name") and instance.name
+             else "agent_run_sync",
+             type=SpanTypeAttribute.LLM,
+             input=input_data if input_data else None,
+             metadata=_try_dict(metadata),
+         ) as agent_span:
+             start_time = time.time()
+             result = wrapped(*args, **kwargs)
+             end_time = time.time()
+
+             output = _serialize_result_output(result)
+             metrics = _extract_usage_metrics(result, start_time, end_time)
+
+             agent_span.log(output=output, metrics=metrics)
+             return result
+
+     wrap_function_wrapper(Agent, "run_sync", agent_run_sync_wrapper)
+
+     def agent_run_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+         _ensure_model_wrapped(instance)
+         input_data, metadata = _build_agent_input_and_metadata(args, kwargs, instance)
+         agent_name = instance.name if hasattr(instance, "name") else None
+         span_name = f"agent_run_stream [{agent_name}]" if agent_name else "agent_run_stream"
+
+         return _AgentStreamWrapper(
+             wrapped(*args, **kwargs),
+             span_name,
+             input_data,
+             metadata,
+         )
+
+     wrap_function_wrapper(Agent, "run_stream", agent_run_stream_wrapper)
+
+     def agent_run_stream_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+         _ensure_model_wrapped(instance)
+         input_data, metadata = _build_agent_input_and_metadata(args, kwargs, instance)
+         agent_name = instance.name if hasattr(instance, "name") else None
+         span_name = f"agent_run_stream_sync [{agent_name}]" if agent_name else "agent_run_stream_sync"
+
+         # Create span context BEFORE calling wrapped function so internal spans nest under it
+         span_cm = start_span(
+             name=span_name,
+             type=SpanTypeAttribute.LLM,
+             input=input_data if input_data else None,
+             metadata=_try_dict(metadata),
+         )
+         span = span_cm.__enter__()
+         start_time = time.time()
+
+         try:
+             # Call the original function within the span context
+             stream_result = wrapped(*args, **kwargs)
+             return _AgentStreamResultSyncProxy(
+                 stream_result,
+                 span,
+                 span_cm,
+                 start_time,
+             )
+         except Exception:
+             # Clean up span on error
+             span_cm.__exit__(*sys.exc_info())
+             raise
+
+     wrap_function_wrapper(Agent, "run_stream_sync", agent_run_stream_sync_wrapper)
+
+     async def agent_run_stream_events_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+         _ensure_model_wrapped(instance)
+         input_data, metadata = _build_agent_input_and_metadata(args, kwargs, instance)
+
+         agent_name = instance.name if hasattr(instance, "name") else None
+         span_name = f"agent_run_stream_events [{agent_name}]" if agent_name else "agent_run_stream_events"
+
+         with start_span(
+             name=span_name,
+             type=SpanTypeAttribute.LLM,
+             input=input_data if input_data else None,
+             metadata=_try_dict(metadata),
+         ) as agent_span:
+             start_time = time.time()
+             event_count = 0
+             final_result = None
+
+             async for event in wrapped(*args, **kwargs):
+                 event_count += 1
+                 if hasattr(event, "output"):
+                     final_result = event
+                 yield event
+
+             end_time = time.time()
+
+             output = None
+             metrics = {
+                 "start": start_time,
+                 "end": end_time,
+                 "duration": end_time - start_time,
+                 "event_count": event_count,
+             }
+
+             if final_result:
+                 output = _serialize_result_output(final_result)
+                 usage_metrics = _extract_usage_metrics(final_result, start_time, end_time)
+                 metrics.update(usage_metrics)
+
+             agent_span.log(output=output, metrics=metrics)
+
+     wrap_function_wrapper(Agent, "run_stream_events", agent_run_stream_events_wrapper)
+
+     Agent._braintrust_patched = True
+
+     return Agent
+
+
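Every wrapper above follows wrapt's `(wrapped, instance, args, kwargs)` calling convention. A self-contained sketch of that mechanism; the `Greeter` class is purely illustrative:

    from wrapt import wrap_function_wrapper

    class Greeter:
        def greet(self, name):
            return f"hello {name}"

    def traced(wrapped, instance, args, kwargs):
        # wrapped: the original bound method; instance: the Greeter object
        print(f"calling {wrapped.__name__} on {type(instance).__name__}")
        return wrapped(*args, **kwargs)

    wrap_function_wrapper(Greeter, "greet", traced)
    print(Greeter().greet("world"))  # prints the trace line, then "hello world"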
+ def _create_direct_model_request_wrapper():
+     """Create wrapper for direct.model_request()."""
+
+     async def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+         input_data, metadata = _build_direct_model_input_and_metadata(args, kwargs)
+
+         with start_span(
+             name="model_request",
+             type=SpanTypeAttribute.LLM,
+             input=input_data,
+             metadata=_try_dict(metadata),
+         ) as span:
+             start_time = time.time()
+             result = await wrapped(*args, **kwargs)
+             end_time = time.time()
+
+             output = _serialize_model_response(result)
+             metrics = _extract_response_metrics(result, start_time, end_time)
+
+             span.log(output=output, metrics=metrics)
+             return result
+
+     return wrapper
+
+
+ def _create_direct_model_request_sync_wrapper():
+     """Create wrapper for direct.model_request_sync()."""
+
+     def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+         input_data, metadata = _build_direct_model_input_and_metadata(args, kwargs)
+
+         with start_span(
+             name="model_request_sync",
+             type=SpanTypeAttribute.LLM,
+             input=input_data,
+             metadata=_try_dict(metadata),
+         ) as span:
+             start_time = time.time()
+             result = wrapped(*args, **kwargs)
+             end_time = time.time()
+
+             output = _serialize_model_response(result)
+             metrics = _extract_response_metrics(result, start_time, end_time)
+
+             span.log(output=output, metrics=metrics)
+             return result
+
+     return wrapper
+
+
+ def _create_direct_model_request_stream_wrapper():
+     """Create wrapper for direct.model_request_stream()."""
+
+     def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+         input_data, metadata = _build_direct_model_input_and_metadata(args, kwargs)
+
+         return _DirectStreamWrapper(
+             wrapped(*args, **kwargs),
+             "model_request_stream",
+             input_data,
+             metadata,
+         )
+
+     return wrapper
+
+
+ def _create_direct_model_request_stream_sync_wrapper():
+     """Create wrapper for direct.model_request_stream_sync()."""
+
+     def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+         input_data, metadata = _build_direct_model_input_and_metadata(args, kwargs)
+
+         return _DirectStreamWrapperSync(
+             wrapped(*args, **kwargs),
+             "model_request_stream_sync",
+             input_data,
+             metadata,
+         )
+
+     return wrapper
+
+
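Once these wrappers are installed, calls to the direct API are traced without any caller changes. A hedged sketch based on pydantic-ai's documented direct API; the model string is illustrative:

    from pydantic_ai.direct import model_request_sync
    from pydantic_ai.messages import ModelRequest

    # After setup_pydantic_ai(), this call runs inside a "model_request_sync" LLM span.
    response = model_request_sync(
        "openai:gpt-4o",
        [ModelRequest.user_text_prompt("What is the capital of Minnesota?")],
    )
    print(response.parts[0].content)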
+ def wrap_model_request(original_func: Any) -> Any:
+     async def wrapper(*args, **kwargs):
+         input_data, metadata = _build_direct_model_input_and_metadata(args, kwargs)
+
+         with start_span(
+             name="model_request",
+             type=SpanTypeAttribute.LLM,
+             input=input_data,
+             metadata=_try_dict(metadata),
+         ) as span:
+             start_time = time.time()
+             result = await original_func(*args, **kwargs)
+             end_time = time.time()
+
+             output = _serialize_model_response(result)
+             metrics = _extract_response_metrics(result, start_time, end_time)
+
+             span.log(output=output, metrics=metrics)
+             return result
+
+     return wrapper
+
+
+ def wrap_model_request_sync(original_func: Any) -> Any:
+     def wrapper(*args, **kwargs):
+         input_data, metadata = _build_direct_model_input_and_metadata(args, kwargs)
+
+         with start_span(
+             name="model_request_sync",
+             type=SpanTypeAttribute.LLM,
+             input=input_data,
+             metadata=_try_dict(metadata),
+         ) as span:
+             start_time = time.time()
+             result = original_func(*args, **kwargs)
+             end_time = time.time()
+
+             output = _serialize_model_response(result)
+             metrics = _extract_response_metrics(result, start_time, end_time)
+
+             span.log(output=output, metrics=metrics)
+             return result
+
+     return wrapper
+
+
+ def wrap_model_request_stream(original_func: Any) -> Any:
+     def wrapper(*args, **kwargs):
+         input_data, metadata = _build_direct_model_input_and_metadata(args, kwargs)
+
+         return _DirectStreamWrapper(
+             original_func(*args, **kwargs),
+             "model_request_stream",
+             input_data,
+             metadata,
+         )
+
+     return wrapper
+
+
+ def wrap_model_request_stream_sync(original_func: Any) -> Any:
+     def wrapper(*args, **kwargs):
+         input_data, metadata = _build_direct_model_input_and_metadata(args, kwargs)
+
+         return _DirectStreamWrapperSync(
+             original_func(*args, **kwargs),
+             "model_request_stream_sync",
+             input_data,
+             metadata,
+         )
+
+     return wrapper
+
+
+ def wrap_model_classes():
+     """Wrap Model classes to capture internal model requests made by agents."""
+     try:
+         from pydantic_ai.models import Model
+
+         def wrap_all_subclasses(base_class):
+             """Recursively wrap all subclasses of a base class."""
+             for subclass in base_class.__subclasses__():
+                 if not getattr(subclass, "__abstractmethods__", None):
+                     try:
+                         _wrap_concrete_model_class(subclass)
+                     except Exception as e:
+                         logger.debug(f"Could not wrap {subclass.__name__}: {e}")
+
+                 wrap_all_subclasses(subclass)
+
+         wrap_all_subclasses(Model)
+
+     except Exception as e:
+         logger.warning(f"Failed to wrap Model classes: {e}")
+
+
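The recursion above only reaches classes that have already been imported, since `__subclasses__()` reflects the class hierarchy at call time. A minimal illustration with stand-in classes:

    class Base: ...
    class Mid(Base): ...
    class Leaf(Mid): ...

    def all_subclasses(cls):
        for sub in cls.__subclasses__():
            yield sub
            yield from all_subclasses(sub)

    print([c.__name__ for c in all_subclasses(Base)])  # ['Mid', 'Leaf']

This is also why `wrap_agent` wraps lazily via `_ensure_model_wrapped`: model classes imported after `setup_pydantic_ai()` runs would otherwise be missed.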
+ def _build_model_class_input_and_metadata(instance: Any, args: Any, kwargs: Any):
+     """Build input data and metadata for model class request wrappers.
+
+     Returns:
+         Tuple of (model_name, display_name, input_data, metadata)
+     """
+     model_name, provider = _extract_model_info_from_model_instance(instance)
+     display_name = model_name or str(instance)
+
+     messages = args[0] if len(args) > 0 else kwargs.get("messages")
+     model_settings = args[1] if len(args) > 1 else kwargs.get("model_settings")
+
+     serialized_messages = _serialize_messages(messages)
+
+     input_data = {"messages": serialized_messages}
+     if model_settings is not None:
+         input_data["model_settings"] = _try_dict(model_settings)
+
+     metadata = _build_model_metadata(model_name, provider, model_settings=None)
+
+     return model_name, display_name, input_data, metadata
+
+
+ def _wrap_concrete_model_class(model_class: Any):
+     """Wrap a concrete model class to trace its request methods."""
+     if _is_patched(model_class):
+         return
+
+     async def model_request_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+         model_name, display_name, input_data, metadata = _build_model_class_input_and_metadata(instance, args, kwargs)
+
+         with start_span(
+             name=f"chat {display_name}",
+             type=SpanTypeAttribute.LLM,
+             input=input_data,
+             metadata=_try_dict(metadata),
+         ) as span:
+             start_time = time.time()
+             result = await wrapped(*args, **kwargs)
+             end_time = time.time()
+
+             output = _serialize_model_response(result)
+             metrics = _extract_response_metrics(result, start_time, end_time)
+
+             span.log(output=output, metrics=metrics)
+             return result
+
+     def model_request_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+         model_name, display_name, input_data, metadata = _build_model_class_input_and_metadata(instance, args, kwargs)
+
+         return _DirectStreamWrapper(
+             wrapped(*args, **kwargs),
+             f"chat {display_name}",
+             input_data,
+             metadata,
+         )
+
+     wrap_function_wrapper(model_class, "request", model_request_wrapper)
+     wrap_function_wrapper(model_class, "request_stream", model_request_stream_wrapper)
+     model_class._braintrust_patched = True
+
+
+ class _AgentStreamWrapper(AbstractAsyncContextManager):
+     """Wrapper for agent.run_stream() that adds tracing while passing through the stream result."""
+
+     def __init__(self, stream_cm: Any, span_name: str, input_data: Any, metadata: Any):
+         self.stream_cm = stream_cm
+         self.span_name = span_name
+         self.input_data = input_data
+         self.metadata = metadata
+         self.span_cm = None
+         self.start_time = None
+         self.stream_result = None
+
+     async def __aenter__(self):
+         # Use context manager properly so span stays current
+         # DON'T pass start_time here - we'll set it via metrics in __aexit__
+         self.span_cm = start_span(
+             name=self.span_name,
+             type=SpanTypeAttribute.LLM,
+             input=self.input_data if self.input_data else None,
+             metadata=_try_dict(self.metadata),
+         )
+         span = self.span_cm.__enter__()
+
+         # Capture start time right before entering the stream (API call initiation)
+         self.start_time = time.time()
+         self.stream_result = await self.stream_cm.__aenter__()
+         return self.stream_result  # Return actual stream result object
+
+     async def __aexit__(self, exc_type, exc_val, exc_tb):
+         try:
+             await self.stream_cm.__aexit__(exc_type, exc_val, exc_tb)
+         finally:
+             if self.span_cm and self.start_time and self.stream_result:
+                 end_time = time.time()
+
+                 output = _serialize_stream_output(self.stream_result)
+                 metrics = _extract_stream_usage_metrics(self.stream_result, self.start_time, end_time, None)
+                 self.span_cm.log(output=output, metrics=metrics)
+
+             # Always clean up span context
+             if self.span_cm:
+                 self.span_cm.__exit__(None, None, None)
+
+         return False
+
+
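Because `__aenter__` returns the inner stream result, calling code is unchanged. A hedged sketch (assumes `Agent` imported from pydantic_ai, `stream_text()` on its streamed result, and an illustrative model string):

    async def main():
        agent = Agent("openai:gpt-4o")
        async with agent.run_stream("Tell me a joke") as stream:  # span opens here
            async for chunk in stream.stream_text():
                print(chunk)
        # span closes in __aexit__, after usage metrics are read off the stream result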
+ class _DirectStreamWrapper(AbstractAsyncContextManager):
+     """Wrapper for model_request_stream() that adds tracing while passing through the stream."""
+
+     def __init__(self, stream_cm: Any, span_name: str, input_data: Any, metadata: Any):
+         self.stream_cm = stream_cm
+         self.span_name = span_name
+         self.input_data = input_data
+         self.metadata = metadata
+         self.span_cm = None
+         self.start_time = None
+         self.stream = None
+
+     async def __aenter__(self):
+         # Use context manager properly so span stays current
+         # DON'T pass start_time here - we'll set it via metrics in __aexit__
+         self.span_cm = start_span(
+             name=self.span_name,
+             type=SpanTypeAttribute.LLM,
+             input=self.input_data if self.input_data else None,
+             metadata=_try_dict(self.metadata),
+         )
+         span = self.span_cm.__enter__()
+
+         # Capture start time right before entering the stream (API call initiation)
+         self.start_time = time.time()
+         self.stream = await self.stream_cm.__aenter__()
+         return self.stream  # Return actual stream object
+
+     async def __aexit__(self, exc_type, exc_val, exc_tb):
+         try:
+             await self.stream_cm.__aexit__(exc_type, exc_val, exc_tb)
+         finally:
+             if self.span_cm and self.start_time and self.stream:
+                 end_time = time.time()
+
+                 try:
+                     final_response = self.stream.get()
+                     output = _serialize_model_response(final_response)
+                     metrics = _extract_response_metrics(final_response, self.start_time, end_time, None)
+                     self.span_cm.log(output=output, metrics=metrics)
+                 except Exception as e:
+                     logger.debug(f"Failed to extract stream output/metrics: {e}")
+
+             # Always clean up span context
+             if self.span_cm:
+                 self.span_cm.__exit__(None, None, None)
+
+         return False
+
+
+ class _AgentStreamResultSyncProxy:
+     """Proxy for agent.run_stream_sync() result that adds tracing while delegating to actual stream result."""
+
+     def __init__(self, stream_result: Any, span: Any, span_cm: Any, start_time: float):
+         self._stream_result = stream_result
+         self._span = span
+         self._span_cm = span_cm
+         self._start_time = start_time
+         self._logged = False
+         self._finalize_on_del = True
+
+     def __getattr__(self, name: str):
+         """Delegate all attribute access to the wrapped stream result."""
+         attr = getattr(self._stream_result, name)
+
+         # Wrap any method that returns an iterator to auto-finalize when exhausted.
+         # Note: wrapped_method must stay a plain function (no yield) so that the
+         # non-iterator branch can return its value normally; the generator that
+         # drives finalization is the nested _iterate.
+         if callable(attr) and name in ("stream_text", "stream_output", "__iter__"):
+
+             def wrapped_method(*args, **kwargs):
+                 try:
+                     result = attr(*args, **kwargs)
+                 except Exception:
+                     self._finalize()
+                     self._finalize_on_del = False
+                     raise
+
+                 # If it's not an iterator, pass it through unchanged
+                 if not (hasattr(result, "__iter__") or hasattr(result, "__next__")):
+                     return result
+
+                 # Wrap the iterator so we finalize once it is exhausted
+                 def _iterate():
+                     try:
+                         yield from result
+                     finally:
+                         self._finalize()
+                         self._finalize_on_del = False  # Don't finalize again in __del__
+
+                 return _iterate()
+
+             return wrapped_method
+
+         return attr
+
+     def _finalize(self):
+         """Log metrics and close span."""
+         if self._span and not self._logged and self._stream_result:
+             try:
+                 end_time = time.time()
+                 output = _serialize_stream_output(self._stream_result)
+                 metrics = _extract_stream_usage_metrics(self._stream_result, self._start_time, end_time, None)
+                 self._span.log(output=output, metrics=metrics)
+                 self._logged = True
+             finally:
+                 try:
+                     self._span_cm.__exit__(None, None, None)
+                 except Exception:
+                     pass
+
+     def __del__(self):
+         """Ensure span is closed when proxy is destroyed."""
+         if self._finalize_on_del:
+             self._finalize()
+
+
+ class _DirectStreamWrapperSync:
+     """Wrapper for model_request_stream_sync() that adds tracing while passing through the stream."""
+
+     def __init__(self, stream_cm: Any, span_name: str, input_data: Any, metadata: Any):
+         self.stream_cm = stream_cm
+         self.span_name = span_name
+         self.input_data = input_data
+         self.metadata = metadata
+         self.span_cm = None
+         self.start_time = None
+         self.stream = None
+
+     def __enter__(self):
+         # Use context manager properly so span stays current
+         # DON'T pass start_time here - we'll set it via metrics in __exit__
+         self.span_cm = start_span(
+             name=self.span_name,
+             type=SpanTypeAttribute.LLM,
+             input=self.input_data if self.input_data else None,
+             metadata=_try_dict(self.metadata),
+         )
+         span = self.span_cm.__enter__()
+
+         # Capture start time right before entering the stream (API call initiation)
+         self.start_time = time.time()
+         self.stream = self.stream_cm.__enter__()
+         return self.stream  # Return actual stream object
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         try:
+             self.stream_cm.__exit__(exc_type, exc_val, exc_tb)
+         finally:
+             if self.span_cm and self.start_time and self.stream:
+                 end_time = time.time()
+
+                 try:
+                     final_response = self.stream.get()
+                     output = _serialize_model_response(final_response)
+                     metrics = _extract_response_metrics(final_response, self.start_time, end_time, None)
+                     self.span_cm.log(output=output, metrics=metrics)
+                 except Exception as e:
+                     logger.debug(f"Failed to extract stream output/metrics: {e}")
+
+             # Always clean up span context
+             if self.span_cm:
+                 self.span_cm.__exit__(None, None, None)
+
+         return False
+
+
+ def _serialize_user_prompt(user_prompt: Any) -> Any:
+     """Serialize user prompt, handling BinaryContent and other types."""
+     if user_prompt is None:
+         return None
+
+     if isinstance(user_prompt, str):
+         return user_prompt
+
+     if isinstance(user_prompt, list):
+         return [_serialize_content_part(part) for part in user_prompt]
+
+     return _serialize_content_part(user_prompt)
+
+
+ def _serialize_content_part(part: Any) -> Any:
+     """Serialize a content part, handling BinaryContent specially."""
+     if part is None:
+         return None
+
+     if hasattr(part, "data") and hasattr(part, "media_type") and hasattr(part, "kind"):
+         if part.kind == "binary":
+             data = part.data
+             media_type = part.media_type
+
+             extension = media_type.split("/")[1] if "/" in media_type else "bin"
+             filename = f"file.{extension}"
+
+             attachment = Attachment(data=data, filename=filename, content_type=media_type)
+             return {"type": "binary", "attachment": attachment, "media_type": media_type}
+
+     if isinstance(part, str):
+         return part
+
+     return _try_dict(part)
+
+
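A sketch of the binary path above, assuming pydantic-ai's `BinaryContent` (which carries `data`, `media_type`, and `kind == "binary"`); the Attachment repr in the comment is illustrative:

    from pydantic_ai.messages import BinaryContent

    part = BinaryContent(data=b"\x89PNG\r\n...", media_type="image/png")
    serialized = _serialize_content_part(part)
    # -> {"type": "binary", "attachment": <Attachment "file.png">, "media_type": "image/png"}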
+ def _serialize_messages(messages: Any) -> Any:
+     """Serialize messages list."""
+     if not messages:
+         return []
+
+     result = []
+     for msg in messages:
+         serialized_msg = _try_dict(msg)
+
+         if isinstance(serialized_msg, dict) and "parts" in serialized_msg:
+             serialized_msg["parts"] = [_serialize_content_part(p) for p in msg.parts]
+
+         result.append(serialized_msg)
+
+     return result
+
+
+ def _serialize_result_output(result: Any) -> Any:
+     """Serialize agent run result output."""
+     if not result:
+         return None
+
+     output_dict = {}
+
+     if hasattr(result, "output"):
+         output_dict["output"] = _try_dict(result.output)
+
+     if hasattr(result, "response"):
+         output_dict["response"] = _serialize_model_response(result.response)
+
+     return output_dict if output_dict else _try_dict(result)
+
+
+ def _serialize_stream_output(stream_result: Any) -> Any:
+     """Serialize stream result output."""
+     if not stream_result:
+         return None
+
+     output_dict = {}
+
+     if hasattr(stream_result, "response"):
+         output_dict["response"] = _serialize_model_response(stream_result.response)
+
+     return output_dict if output_dict else None
+
+
+ def _serialize_model_response(response: Any) -> Any:
+     """Serialize a model response."""
+     if not response:
+         return None
+
+     response_dict = _try_dict(response)
+
+     if isinstance(response_dict, dict) and "parts" in response_dict:
+         if hasattr(response, "parts"):
+             response_dict["parts"] = [_serialize_content_part(p) for p in response.parts]
+
+     return response_dict
+
+
+ def _extract_model_info_from_model_instance(model: Any) -> tuple[Optional[str], Optional[str]]:
+     """Extract model name and provider from a model instance.
+
+     Args:
+         model: A Pydantic AI model instance (OpenAIChatModel, AnthropicModel, etc.)
+
+     Returns:
+         Tuple of (model_name, provider)
+     """
+     if not model:
+         return None, None
+
+     if isinstance(model, str):
+         return _parse_model_string(model)
+
+     if hasattr(model, "model_name"):
+         model_name = model.model_name
+         class_name = type(model).__name__
+         provider = None
+         if "OpenAI" in class_name:
+             provider = "openai"
+         elif "Anthropic" in class_name:
+             provider = "anthropic"
+         elif "Gemini" in class_name:
+             provider = "gemini"
+         elif "Groq" in class_name:
+             provider = "groq"
+         elif "Mistral" in class_name:
+             provider = "mistral"
+         elif "VertexAI" in class_name:
+             provider = "vertexai"
+
+         return model_name, provider
+
+     if hasattr(model, "name"):
+         return _parse_model_string(model.name)
+
+     return None, None
+
+
+ def _extract_model_info(agent: Any) -> tuple[Optional[str], Optional[str]]:
+     """Extract model name and provider from agent.
+
+     Args:
+         agent: A Pydantic AI Agent instance
+
+     Returns:
+         Tuple of (model_name, provider)
+     """
+     if not hasattr(agent, "model"):
+         return None, None
+
+     return _extract_model_info_from_model_instance(agent.model)
+
+
+ def _build_model_metadata(
+     model_name: Optional[str], provider: Optional[str], model_settings: Any = None
+ ) -> Dict[str, Any]:
+     """Build metadata dictionary with model info.
+
+     Args:
+         model_name: The model name (e.g., "gpt-4o")
+         provider: The provider (e.g., "openai")
+         model_settings: Optional model settings to include
+
+     Returns:
+         Dictionary of metadata
+     """
+     metadata = {}
+     if model_name:
+         metadata["model"] = model_name
+     if provider:
+         metadata["provider"] = provider
+     if model_settings:
+         metadata["model_settings"] = _try_dict(model_settings)
+     return metadata
+
+
+ def _parse_model_string(model: Any) -> tuple[Optional[str], Optional[str]]:
+     """Parse model string to extract provider and model name.
+
+     Pydantic AI uses format: "provider:model-name" (e.g., "openai:gpt-4o")
+     """
+     if not model:
+         return None, None
+
+     model_str = str(model)
+
+     if ":" in model_str:
+         parts = model_str.split(":", 1)
+         return parts[1], parts[0]  # (model_name, provider)
+
+     return model_str, None
+
+
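Concretely, the split behaves as follows:

    print(_parse_model_string("openai:gpt-4o"))  # ('gpt-4o', 'openai')
    print(_parse_model_string("gpt-4o"))         # ('gpt-4o', None)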
+ def _extract_usage_metrics(result: Any, start_time: float, end_time: float) -> Optional[Dict[str, float]]:
+     """Extract usage metrics from agent run result."""
+     metrics: Dict[str, float] = {}
+
+     metrics["start"] = start_time
+     metrics["end"] = end_time
+     metrics["duration"] = end_time - start_time
+
+     usage = None
+     if hasattr(result, "response"):
+         try:
+             response = result.response
+             if hasattr(response, "usage"):
+                 usage = response.usage
+         except (AttributeError, ValueError):
+             pass
+
+     if usage is None and hasattr(result, "usage"):
+         usage = result.usage
+
+     if usage is None:
+         return metrics
+
+     if hasattr(usage, "input_tokens"):
+         input_tokens = usage.input_tokens
+         if input_tokens is not None:
+             metrics["prompt_tokens"] = float(input_tokens)
+
+     if hasattr(usage, "output_tokens"):
+         output_tokens = usage.output_tokens
+         if output_tokens is not None:
+             metrics["completion_tokens"] = float(output_tokens)
+
+     if hasattr(usage, "total_tokens"):
+         total_tokens = usage.total_tokens
+         if total_tokens is not None:
+             metrics["tokens"] = float(total_tokens)
+
+     if hasattr(usage, "cache_read_tokens") and usage.cache_read_tokens is not None:
+         metrics["prompt_cached_tokens"] = float(usage.cache_read_tokens)
+
+     if hasattr(usage, "cache_write_tokens") and usage.cache_write_tokens is not None:
+         metrics["prompt_cache_creation_tokens"] = float(usage.cache_write_tokens)
+
+     if hasattr(usage, "input_audio_tokens") and usage.input_audio_tokens is not None:
+         metrics["prompt_audio_tokens"] = float(usage.input_audio_tokens)
+
+     if hasattr(usage, "output_audio_tokens") and usage.output_audio_tokens is not None:
+         metrics["completion_audio_tokens"] = float(usage.output_audio_tokens)
+
+     if hasattr(usage, "details") and isinstance(usage.details, dict):
+         details = usage.details
+
+         if "reasoning_tokens" in details:
+             metrics["completion_reasoning_tokens"] = float(details["reasoning_tokens"])
+
+         if "cached_tokens" in details:
+             metrics["prompt_cached_tokens"] = float(details["cached_tokens"])
+
+     return metrics if metrics else None
+
+
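For orientation, a hedged sketch of what this function produces for a typical non-streaming run; all values are invented:

    metrics = _extract_usage_metrics(result, start, end)
    # e.g. {"start": 1712000000.0, "end": 1712000001.4, "duration": 1.4,
    #       "prompt_tokens": 52.0, "completion_tokens": 11.0, "tokens": 63.0}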
+ def _extract_stream_usage_metrics(
+     stream_result: Any, start_time: float, end_time: float, first_token_time: Optional[float]
+ ) -> Optional[Dict[str, float]]:
+     """Extract usage metrics from stream result."""
+     metrics: Dict[str, float] = {}
+
+     metrics["start"] = start_time
+     metrics["end"] = end_time
+     metrics["duration"] = end_time - start_time
+
+     if first_token_time:
+         metrics["time_to_first_token"] = first_token_time - start_time
+
+     if hasattr(stream_result, "usage"):
+         usage_func = stream_result.usage
+         if callable(usage_func):
+             usage = usage_func()
+         else:
+             usage = usage_func
+
+         if usage:
+             if hasattr(usage, "input_tokens") and usage.input_tokens is not None:
+                 metrics["prompt_tokens"] = float(usage.input_tokens)
+
+             if hasattr(usage, "output_tokens") and usage.output_tokens is not None:
+                 metrics["completion_tokens"] = float(usage.output_tokens)
+
+             if hasattr(usage, "total_tokens") and usage.total_tokens is not None:
+                 metrics["tokens"] = float(usage.total_tokens)
+
+             if hasattr(usage, "cache_read_tokens") and usage.cache_read_tokens is not None:
+                 metrics["prompt_cached_tokens"] = float(usage.cache_read_tokens)
+
+             if hasattr(usage, "cache_write_tokens") and usage.cache_write_tokens is not None:
+                 metrics["prompt_cache_creation_tokens"] = float(usage.cache_write_tokens)
+
+     return metrics if metrics else None
+
+
+ def _extract_response_metrics(
+     response: Any, start_time: float, end_time: float, first_token_time: Optional[float] = None
+ ) -> Optional[Dict[str, float]]:
+     """Extract metrics from model response."""
+     metrics: Dict[str, float] = {}
+
+     metrics["start"] = start_time
+     metrics["end"] = end_time
+     metrics["duration"] = end_time - start_time
+
+     if first_token_time:
+         metrics["time_to_first_token"] = first_token_time - start_time
+
+     if hasattr(response, "usage") and response.usage:
+         usage = response.usage
+
+         if hasattr(usage, "input_tokens") and usage.input_tokens is not None:
+             metrics["prompt_tokens"] = float(usage.input_tokens)
+
+         if hasattr(usage, "output_tokens") and usage.output_tokens is not None:
+             metrics["completion_tokens"] = float(usage.output_tokens)
+
+         if hasattr(usage, "total_tokens") and usage.total_tokens is not None:
+             metrics["tokens"] = float(usage.total_tokens)
+
+         if hasattr(usage, "cache_read_tokens") and usage.cache_read_tokens is not None:
+             metrics["prompt_cached_tokens"] = float(usage.cache_read_tokens)
+
+         if hasattr(usage, "cache_write_tokens") and usage.cache_write_tokens is not None:
+             metrics["prompt_cache_creation_tokens"] = float(usage.cache_write_tokens)
+
+         # Extract reasoning tokens for reasoning models (o1/o3)
+         if hasattr(usage, "details") and usage.details is not None:
+             if hasattr(usage.details, "reasoning_tokens") and usage.details.reasoning_tokens is not None:
+                 metrics["completion_reasoning_tokens"] = float(usage.details.reasoning_tokens)
+
+     return metrics if metrics else None
+
+
+ def _is_patched(obj: Any) -> bool:
+     """Check if object is already patched."""
+     return getattr(obj, "_braintrust_patched", False)
+
+
+ def _try_dict(obj: Any) -> Union[Iterable[Any], Dict[str, Any]]:
+     """Try to convert object to dict, handling Pydantic models and circular references."""
+     if hasattr(obj, "model_dump"):
+         try:
+             obj = obj.model_dump(exclude_none=True)
+         except ValueError as e:
+             if "Circular reference" in str(e):
+                 return {}
+             raise
+
+     if isinstance(obj, dict):
+         return {k: _try_dict(v) for k, v in obj.items()}
+     elif isinstance(obj, (list, tuple)):
+         return [_try_dict(item) for item in obj]
+
+     return obj
+
+
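A behavior sketch for `_try_dict`; the `Point` model is illustrative:

    from pydantic import BaseModel

    class Point(BaseModel):
        x: int
        y: Optional[int] = None

    print(_try_dict(Point(x=1)))                    # {'x': 1}  (exclude_none drops y)
    print(_try_dict({"pts": [Point(x=1), "raw"]}))  # {'pts': [{'x': 1}, 'raw']}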
+ def _serialize_type(obj: Any) -> Any:
+     """Serialize a type/class for logging, handling Pydantic models and other types.
+
+     This is useful for output_type, toolsets, and similar type parameters.
+     Returns full JSON schema for Pydantic models so engineers can see exactly
+     what structured output schema was used.
+     """
+     import inspect
+
+     # For sequences of types (like Union types or list of models)
+     if isinstance(obj, (list, tuple)):
+         return [_serialize_type(item) for item in obj]
+
+     # Handle Pydantic AI's output wrappers (ToolOutput, NativeOutput, PromptedOutput, TextOutput)
+     if hasattr(obj, "output"):
+         # These are wrapper classes with an 'output' field containing the actual type
+         wrapper_info = {"wrapper": type(obj).__name__}
+         if hasattr(obj, "name") and obj.name:
+             wrapper_info["name"] = obj.name
+         if hasattr(obj, "description") and obj.description:
+             wrapper_info["description"] = obj.description
+         wrapper_info["output"] = _serialize_type(obj.output)
+         return wrapper_info
+
+     # If it's a Pydantic model class, return its full JSON schema
+     if inspect.isclass(obj):
+         try:
+             from pydantic import BaseModel
+
+             if issubclass(obj, BaseModel):
+                 # Return the full JSON schema - includes all field info, descriptions, constraints, etc.
+                 return obj.model_json_schema()
+         except (ImportError, AttributeError, TypeError):
+             pass
+
+         # Not a Pydantic model, return class name
+         return obj.__name__
+
+     # If it has a __name__ attribute (like functions), use that
+     if hasattr(obj, "__name__"):
+         return obj.__name__
+
+     # Try standard serialization
+     return _try_dict(obj)
+
+
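A behavior sketch for `_serialize_type`; the `City` model is illustrative and the schema in the comment is abbreviated:

    from pydantic import BaseModel

    class City(BaseModel):
        name: str

    print(_serialize_type(City))         # {'properties': {'name': {...}}, 'title': 'City', ...}
    print(_serialize_type(len))          # 'len'
    print(_serialize_type([City, int]))  # [<City schema>, 'int']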
+ G = TypeVar("G", bound=AsyncGenerator[Any, None])
+
+
+ class aclosing(AbstractAsyncContextManager[G]):
+     """Context manager for closing async generators."""
+
+     def __init__(self, async_generator: G):
+         self.async_generator = async_generator
+
+     async def __aenter__(self):
+         return self.async_generator
+
+     async def __aexit__(self, *exc_info: Any):
+         try:
+             await self.async_generator.aclose()
+         except ValueError as e:
+             if "was created in a different Context" not in str(e):
+                 raise
+             else:
+                 logger.debug(
+                     f"Suppressed ContextVar error during async cleanup: {e}. "
+                     "This is expected when async generators yield across context boundaries."
+                 )
+
+
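A usage sketch: this mirrors `contextlib.aclosing` but tolerates cross-context cleanup errors.

    async def numbers():
        yield 1
        yield 2

    async def main():
        async with aclosing(numbers()) as gen:
            async for n in gen:
                print(n)  # the generator is aclose()d on exit, even on early break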
+ def _build_agent_input_and_metadata(args: Any, kwargs: Any, instance: Any) -> tuple[Dict[str, Any], Dict[str, Any]]:
+     """Build input data and metadata for agent wrappers.
+
+     Returns:
+         Tuple of (input_data, metadata)
+     """
+     input_data = {}
+
+     user_prompt = args[0] if len(args) > 0 else kwargs.get("user_prompt")
+     if user_prompt is not None:
+         input_data["user_prompt"] = _serialize_user_prompt(user_prompt)
+
+     for key, value in kwargs.items():
+         if key == "deps":
+             continue
+         elif key == "message_history":
+             input_data[key] = _serialize_messages(value) if value is not None else None
+         elif key in ("output_type", "toolsets"):
+             # These often contain types/classes, use special serialization
+             input_data[key] = _serialize_type(value) if value is not None else None
+         elif key == "model_settings":
+             # model_settings passed to run() goes in INPUT (it's a run() parameter)
+             input_data[key] = _try_dict(value) if value is not None else None
+         else:
+             input_data[key] = _try_dict(value) if value is not None else None
+
+     if "model" in kwargs:
+         model_name, provider = _parse_model_string(kwargs["model"])
+     else:
+         model_name, provider = _extract_model_info(instance)
+
+     # Extract agent-level configuration for metadata
+     # Only add to metadata if NOT explicitly passed in kwargs (those go in input)
+     agent_model_settings = None
+     if "model_settings" not in kwargs and hasattr(instance, "model_settings") and instance.model_settings is not None:
+         agent_model_settings = instance.model_settings
+
+     metadata = _build_model_metadata(model_name, provider, agent_model_settings)
+
+     # Extract additional agent configuration (only if not passed as kwargs)
+     if "name" not in kwargs and hasattr(instance, "name") and instance.name is not None:
+         metadata["agent_name"] = instance.name
+
+     if "end_strategy" not in kwargs and hasattr(instance, "end_strategy") and instance.end_strategy is not None:
+         metadata["end_strategy"] = str(instance.end_strategy)
+
+     # Extract output_type if set on agent and not passed as kwarg
+     # output_type can be a Pydantic model, str, or other types that get converted to JSON schema
+     if "output_type" not in kwargs and hasattr(instance, "output_type") and instance.output_type is not None:
+         try:
+             metadata["output_type"] = _serialize_type(instance.output_type)
+         except Exception as e:
+             logger.debug(f"Failed to extract output_type from agent: {e}")
+
+     # Extract toolsets if set on agent and not passed as kwarg
+     # Toolsets go in INPUT (not metadata) because agent.run() accepts toolsets parameter
+     if "toolsets" not in kwargs and hasattr(instance, "toolsets"):
+         try:
+             toolsets = instance.toolsets
+             if toolsets:
+                 # Convert toolsets to a list with FULL tool schemas for input
+                 serialized_toolsets = []
+                 for ts in toolsets:
+                     ts_info = {
+                         "id": getattr(ts, "id", str(type(ts).__name__)),
+                         "label": getattr(ts, "label", None),
+                     }
+                     # Add full tool schemas (not just names) since toolsets can be passed to agent.run()
+                     if hasattr(ts, "tools") and ts.tools:
+                         tools_list = []
+                         tools_dict = ts.tools
+                         # tools is a dict mapping tool name -> Tool object
+                         for tool_name, tool_obj in tools_dict.items():
+                             tool_dict = {
+                                 "name": tool_name,
+                             }
+                             # Extract description
+                             if hasattr(tool_obj, "description") and tool_obj.description:
+                                 tool_dict["description"] = tool_obj.description
+                             # Extract JSON schema for parameters
+                             if hasattr(tool_obj, "function_schema") and hasattr(tool_obj.function_schema, "json_schema"):
+                                 tool_dict["parameters"] = tool_obj.function_schema.json_schema
+                             tools_list.append(tool_dict)
+                         ts_info["tools"] = tools_list
+                     serialized_toolsets.append(ts_info)
+                 input_data["toolsets"] = serialized_toolsets
+         except Exception as e:
+             logger.debug(f"Failed to extract toolsets from agent: {e}")
+
+     # Extract system_prompt from agent if not passed as kwarg
+     # Note: system_prompt goes in input (not metadata) because it's semantically part of the LLM input
+     # Pydantic AI doesn't expose a public API for this, so we access the private _system_prompts
+     # attribute. This is wrapped in try/except to gracefully handle if the internal structure changes.
+     if "system_prompt" not in kwargs:
+         try:
+             if hasattr(instance, "_system_prompts") and instance._system_prompts:
+                 input_data["system_prompt"] = "\n\n".join(instance._system_prompts)
+         except Exception as e:
+             logger.debug(f"Failed to extract system_prompt from agent: {e}")
+
+     return input_data, metadata
+
+
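A hedged sketch of the resulting input/metadata split for a typical call; the field values are invented:

    # agent = Agent("openai:gpt-4o", name="helper", system_prompt="Be brief.")
    # agent.run_sync("hi", model_settings={"temperature": 0})
    #
    # input    -> {"user_prompt": "hi",
    #              "model_settings": {"temperature": 0},
    #              "system_prompt": "Be brief."}
    # metadata -> {"model": "gpt-4o", "provider": "openai", "agent_name": "helper"}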
+ def _build_direct_model_input_and_metadata(args: Any, kwargs: Any) -> tuple[Dict[str, Any], Dict[str, Any]]:
+     """Build input data and metadata for direct model request wrappers.
+
+     Returns:
+         Tuple of (input_data, metadata)
+     """
+     input_data = {}
+
+     model = args[0] if len(args) > 0 else kwargs.get("model")
+     if model is not None:
+         input_data["model"] = str(model)
+
+     messages = args[1] if len(args) > 1 else kwargs.get("messages", [])
+     if messages:
+         input_data["messages"] = _serialize_messages(messages)
+
+     for key, value in kwargs.items():
+         if key not in ["model", "messages"]:
+             input_data[key] = _try_dict(value) if value is not None else None
+
+     model_name, provider = _parse_model_string(model)
+     metadata = _build_model_metadata(model_name, provider)
+
+     return input_data, metadata