braintrust 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Files changed (83)
  1. braintrust/__init__.py +4 -0
  2. braintrust/_generated_types.py +1200 -611
  3. braintrust/audit.py +2 -2
  4. braintrust/cli/eval.py +6 -7
  5. braintrust/cli/push.py +11 -11
  6. braintrust/conftest.py +1 -0
  7. braintrust/context.py +12 -17
  8. braintrust/contrib/temporal/__init__.py +16 -27
  9. braintrust/contrib/temporal/test_temporal.py +8 -3
  10. braintrust/devserver/auth.py +8 -8
  11. braintrust/devserver/cache.py +3 -4
  12. braintrust/devserver/cors.py +8 -7
  13. braintrust/devserver/dataset.py +3 -5
  14. braintrust/devserver/eval_hooks.py +7 -6
  15. braintrust/devserver/schemas.py +22 -19
  16. braintrust/devserver/server.py +19 -12
  17. braintrust/devserver/test_cached_login.py +4 -4
  18. braintrust/framework.py +128 -140
  19. braintrust/framework2.py +88 -87
  20. braintrust/functions/invoke.py +93 -53
  21. braintrust/functions/stream.py +3 -2
  22. braintrust/generated_types.py +17 -1
  23. braintrust/git_fields.py +11 -11
  24. braintrust/gitutil.py +2 -3
  25. braintrust/graph_util.py +10 -10
  26. braintrust/id_gen.py +2 -2
  27. braintrust/logger.py +346 -357
  28. braintrust/merge_row_batch.py +10 -9
  29. braintrust/oai.py +107 -24
  30. braintrust/otel/__init__.py +49 -49
  31. braintrust/otel/context.py +16 -30
  32. braintrust/otel/test_distributed_tracing.py +14 -11
  33. braintrust/otel/test_otel_bt_integration.py +32 -31
  34. braintrust/parameters.py +8 -8
  35. braintrust/prompt.py +14 -14
  36. braintrust/prompt_cache/disk_cache.py +5 -4
  37. braintrust/prompt_cache/lru_cache.py +3 -2
  38. braintrust/prompt_cache/prompt_cache.py +13 -14
  39. braintrust/queue.py +4 -4
  40. braintrust/score.py +4 -4
  41. braintrust/serializable_data_class.py +4 -4
  42. braintrust/span_identifier_v1.py +1 -2
  43. braintrust/span_identifier_v2.py +3 -4
  44. braintrust/span_identifier_v3.py +23 -20
  45. braintrust/span_identifier_v4.py +34 -25
  46. braintrust/test_framework.py +16 -6
  47. braintrust/test_helpers.py +5 -5
  48. braintrust/test_id_gen.py +2 -3
  49. braintrust/test_otel.py +61 -53
  50. braintrust/test_queue.py +0 -1
  51. braintrust/test_score.py +1 -3
  52. braintrust/test_span_components.py +29 -44
  53. braintrust/util.py +9 -8
  54. braintrust/version.py +2 -2
  55. braintrust/wrappers/_anthropic_utils.py +4 -4
  56. braintrust/wrappers/agno/__init__.py +3 -4
  57. braintrust/wrappers/agno/agent.py +1 -2
  58. braintrust/wrappers/agno/function_call.py +1 -2
  59. braintrust/wrappers/agno/model.py +1 -2
  60. braintrust/wrappers/agno/team.py +1 -2
  61. braintrust/wrappers/agno/utils.py +12 -12
  62. braintrust/wrappers/anthropic.py +7 -8
  63. braintrust/wrappers/claude_agent_sdk/__init__.py +3 -4
  64. braintrust/wrappers/claude_agent_sdk/_wrapper.py +29 -27
  65. braintrust/wrappers/dspy.py +15 -17
  66. braintrust/wrappers/google_genai/__init__.py +16 -16
  67. braintrust/wrappers/langchain.py +22 -24
  68. braintrust/wrappers/litellm.py +4 -3
  69. braintrust/wrappers/openai.py +15 -15
  70. braintrust/wrappers/pydantic_ai.py +1204 -0
  71. braintrust/wrappers/test_agno.py +0 -1
  72. braintrust/wrappers/test_dspy.py +0 -1
  73. braintrust/wrappers/test_google_genai.py +2 -3
  74. braintrust/wrappers/test_litellm.py +0 -1
  75. braintrust/wrappers/test_oai_attachments.py +322 -0
  76. braintrust/wrappers/test_pydantic_ai_integration.py +1788 -0
  77. braintrust/wrappers/{test_pydantic_ai.py → test_pydantic_ai_wrap_openai.py} +1 -2
  78. {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/METADATA +3 -2
  79. braintrust-0.4.0.dist-info/RECORD +120 -0
  80. braintrust-0.3.14.dist-info/RECORD +0 -117
  81. {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/WHEEL +0 -0
  82. {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/entry_points.txt +0 -0
  83. {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/top_level.txt +0 -0
braintrust/wrappers/pydantic_ai.py (new file)
@@ -0,0 +1,1204 @@
+import logging
+import sys
+import time
+from collections.abc import AsyncGenerator, Iterable
+from contextlib import AbstractAsyncContextManager
+from typing import Any, TypeVar
+
+from braintrust.logger import NOOP_SPAN, Attachment, current_span, init_logger, start_span
+from braintrust.span_types import SpanTypeAttribute
+from wrapt import wrap_function_wrapper
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["setup_pydantic_ai"]
+
+
+def setup_pydantic_ai(
+    api_key: str | None = None,
+    project_id: str | None = None,
+    project_name: str | None = None,
+) -> bool:
+    """
+    Setup Braintrust integration with Pydantic AI. Will automatically patch Pydantic AI Agents and direct API functions for automatic tracing.
+
+    Args:
+        api_key (Optional[str]): Braintrust API key.
+        project_id (Optional[str]): Braintrust project ID.
+        project_name (Optional[str]): Braintrust project name.
+
+    Returns:
+        bool: True if setup was successful, False otherwise.
+    """
+    span = current_span()
+    if span == NOOP_SPAN:
+        init_logger(project=project_name, api_key=api_key, project_id=project_id)
+
+    try:
+        import pydantic_ai.direct as direct_module
+        from pydantic_ai import Agent
+
+        Agent = wrap_agent(Agent)
+
+        wrap_function_wrapper(direct_module, "model_request", _create_direct_model_request_wrapper())
+        wrap_function_wrapper(direct_module, "model_request_sync", _create_direct_model_request_sync_wrapper())
+        wrap_function_wrapper(direct_module, "model_request_stream", _create_direct_model_request_stream_wrapper())
+        wrap_function_wrapper(
+            direct_module, "model_request_stream_sync", _create_direct_model_request_stream_sync_wrapper()
+        )
+
+        wrap_model_classes()
+
+        return True
+    except ImportError as e:
+        logger.error(f"Failed to import Pydantic AI: {e}")
+        logger.error("Pydantic AI is not installed. Please install it with: pip install pydantic-ai-slim")
+        return False
+
+
+def wrap_agent(Agent: Any) -> Any:
+    if _is_patched(Agent):
+        return Agent
+
+    def _ensure_model_wrapped(instance: Any):
+        """Ensure the agent's model class is wrapped (lazy wrapping)."""
+        if hasattr(instance, "_model"):
+            model_class = type(instance._model)
+            _wrap_concrete_model_class(model_class)
+
+    async def agent_run_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+        _ensure_model_wrapped(instance)
+        input_data, metadata = _build_agent_input_and_metadata(args, kwargs, instance)
+
+        with start_span(
+            name=f"agent_run [{instance.name}]" if hasattr(instance, "name") and instance.name else "agent_run",
+            type=SpanTypeAttribute.LLM,
+            input=input_data if input_data else None,
+            metadata=_try_dict(metadata),
+        ) as agent_span:
+            start_time = time.time()
+            result = await wrapped(*args, **kwargs)
+            end_time = time.time()
+
+            output = _serialize_result_output(result)
+            metrics = _extract_usage_metrics(result, start_time, end_time)
+
+            agent_span.log(output=output, metrics=metrics)
+            return result
+
+    wrap_function_wrapper(Agent, "run", agent_run_wrapper)
+
+    def agent_run_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+        _ensure_model_wrapped(instance)
+        input_data, metadata = _build_agent_input_and_metadata(args, kwargs, instance)
+
+        with start_span(
+            name=f"agent_run_sync [{instance.name}]"
+            if hasattr(instance, "name") and instance.name
+            else "agent_run_sync",
+            type=SpanTypeAttribute.LLM,
+            input=input_data if input_data else None,
+            metadata=_try_dict(metadata),
+        ) as agent_span:
+            start_time = time.time()
+            result = wrapped(*args, **kwargs)
+            end_time = time.time()
+
+            output = _serialize_result_output(result)
+            metrics = _extract_usage_metrics(result, start_time, end_time)
+
+            agent_span.log(output=output, metrics=metrics)
+            return result
+
+    wrap_function_wrapper(Agent, "run_sync", agent_run_sync_wrapper)
+
+    def agent_run_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+        _ensure_model_wrapped(instance)
+        input_data, metadata = _build_agent_input_and_metadata(args, kwargs, instance)
+        agent_name = instance.name if hasattr(instance, "name") else None
+        span_name = f"agent_run_stream [{agent_name}]" if agent_name else "agent_run_stream"
+
+        return _AgentStreamWrapper(
+            wrapped(*args, **kwargs),
+            span_name,
+            input_data,
+            metadata,
+        )
+
+    wrap_function_wrapper(Agent, "run_stream", agent_run_stream_wrapper)
+
+    def agent_run_stream_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+        _ensure_model_wrapped(instance)
+        input_data, metadata = _build_agent_input_and_metadata(args, kwargs, instance)
+        agent_name = instance.name if hasattr(instance, "name") else None
+        span_name = f"agent_run_stream_sync [{agent_name}]" if agent_name else "agent_run_stream_sync"
+
+        # Create span context BEFORE calling wrapped function so internal spans nest under it
+        span_cm = start_span(
+            name=span_name,
+            type=SpanTypeAttribute.LLM,
+            input=input_data if input_data else None,
+            metadata=_try_dict(metadata),
+        )
+        span = span_cm.__enter__()
+        start_time = time.time()
+
+        try:
+            # Call the original function within the span context
+            stream_result = wrapped(*args, **kwargs)
+            return _AgentStreamResultSyncProxy(
+                stream_result,
+                span,
+                span_cm,
+                start_time,
+            )
+        except Exception:
+            # Clean up span on error
+            span_cm.__exit__(*sys.exc_info())
+            raise
+
+    wrap_function_wrapper(Agent, "run_stream_sync", agent_run_stream_sync_wrapper)
+
+    async def agent_run_stream_events_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+        _ensure_model_wrapped(instance)
+        input_data, metadata = _build_agent_input_and_metadata(args, kwargs, instance)
+
+        agent_name = instance.name if hasattr(instance, "name") else None
+        span_name = f"agent_run_stream_events [{agent_name}]" if agent_name else "agent_run_stream_events"
+
+        with start_span(
+            name=span_name,
+            type=SpanTypeAttribute.LLM,
+            input=input_data if input_data else None,
+            metadata=_try_dict(metadata),
+        ) as agent_span:
+            start_time = time.time()
+            event_count = 0
+            final_result = None
+
+            async for event in wrapped(*args, **kwargs):
+                event_count += 1
+                if hasattr(event, "output"):
+                    final_result = event
+                yield event
+
+            end_time = time.time()
+
+            output = None
+            metrics = {
+                "start": start_time,
+                "end": end_time,
+                "duration": end_time - start_time,
+                "event_count": event_count,
+            }
+
+            if final_result:
+                output = _serialize_result_output(final_result)
+                usage_metrics = _extract_usage_metrics(final_result, start_time, end_time)
+                metrics.update(usage_metrics)
+
+            agent_span.log(output=output, metrics=metrics)
+
+    wrap_function_wrapper(Agent, "run_stream_events", agent_run_stream_events_wrapper)
+
+    Agent._braintrust_patched = True
+
+    return Agent
+
+
+def _create_direct_model_request_wrapper():
+    """Create wrapper for direct.model_request()."""
+
+    async def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+        input_data, metadata = _build_direct_model_input_and_metadata(args, kwargs)
+
+        with start_span(
+            name="model_request",
+            type=SpanTypeAttribute.LLM,
+            input=input_data,
+            metadata=_try_dict(metadata),
+        ) as span:
+            start_time = time.time()
+            result = await wrapped(*args, **kwargs)
+            end_time = time.time()
+
+            output = _serialize_model_response(result)
+            metrics = _extract_response_metrics(result, start_time, end_time)
+
+            span.log(output=output, metrics=metrics)
+            return result
+
+    return wrapper
+
+
+def _create_direct_model_request_sync_wrapper():
+    """Create wrapper for direct.model_request_sync()."""
+
+    def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+        input_data, metadata = _build_direct_model_input_and_metadata(args, kwargs)
+
+        with start_span(
+            name="model_request_sync",
+            type=SpanTypeAttribute.LLM,
+            input=input_data,
+            metadata=_try_dict(metadata),
+        ) as span:
+            start_time = time.time()
+            result = wrapped(*args, **kwargs)
+            end_time = time.time()
+
+            output = _serialize_model_response(result)
+            metrics = _extract_response_metrics(result, start_time, end_time)
+
+            span.log(output=output, metrics=metrics)
+            return result
+
+    return wrapper
+
+
+def _create_direct_model_request_stream_wrapper():
+    """Create wrapper for direct.model_request_stream()."""
+
+    def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+        input_data, metadata = _build_direct_model_input_and_metadata(args, kwargs)
+
+        return _DirectStreamWrapper(
+            wrapped(*args, **kwargs),
+            "model_request_stream",
+            input_data,
+            metadata,
+        )
+
+    return wrapper
+
+
+def _create_direct_model_request_stream_sync_wrapper():
+    """Create wrapper for direct.model_request_stream_sync()."""
+
+    def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+        input_data, metadata = _build_direct_model_input_and_metadata(args, kwargs)
+
+        return _DirectStreamWrapperSync(
+            wrapped(*args, **kwargs),
+            "model_request_stream_sync",
+            input_data,
+            metadata,
+        )
+
+    return wrapper
+
+
+def wrap_model_request(original_func: Any) -> Any:
+    async def wrapper(*args, **kwargs):
+        input_data, metadata = _build_direct_model_input_and_metadata(args, kwargs)
+
+        with start_span(
+            name="model_request",
+            type=SpanTypeAttribute.LLM,
+            input=input_data,
+            metadata=_try_dict(metadata),
+        ) as span:
+            start_time = time.time()
+            result = await original_func(*args, **kwargs)
+            end_time = time.time()
+
+            output = _serialize_model_response(result)
+            metrics = _extract_response_metrics(result, start_time, end_time)
+
+            span.log(output=output, metrics=metrics)
+            return result
+
+    return wrapper
+
+
+def wrap_model_request_sync(original_func: Any) -> Any:
+    def wrapper(*args, **kwargs):
+        input_data, metadata = _build_direct_model_input_and_metadata(args, kwargs)
+
+        with start_span(
+            name="model_request_sync",
+            type=SpanTypeAttribute.LLM,
+            input=input_data,
+            metadata=_try_dict(metadata),
+        ) as span:
+            start_time = time.time()
+            result = original_func(*args, **kwargs)
+            end_time = time.time()
+
+            output = _serialize_model_response(result)
+            metrics = _extract_response_metrics(result, start_time, end_time)
+
+            span.log(output=output, metrics=metrics)
+            return result
+
+    return wrapper
+
+
+def wrap_model_request_stream(original_func: Any) -> Any:
+    def wrapper(*args, **kwargs):
+        input_data, metadata = _build_direct_model_input_and_metadata(args, kwargs)
+
+        return _DirectStreamWrapper(
+            original_func(*args, **kwargs),
+            "model_request_stream",
+            input_data,
+            metadata,
+        )
+
+    return wrapper
+
+
+def wrap_model_request_stream_sync(original_func: Any) -> Any:
+    def wrapper(*args, **kwargs):
+        input_data, metadata = _build_direct_model_input_and_metadata(args, kwargs)
+
+        return _DirectStreamWrapperSync(
+            original_func(*args, **kwargs),
+            "model_request_stream_sync",
+            input_data,
+            metadata,
+        )
+
+    return wrapper
+
+
+def wrap_model_classes():
+    """Wrap Model classes to capture internal model requests made by agents."""
+    try:
+        from pydantic_ai.models import Model
+
+        def wrap_all_subclasses(base_class):
+            """Recursively wrap all subclasses of a base class."""
+            for subclass in base_class.__subclasses__():
+                if not getattr(subclass, "__abstractmethods__", None):
+                    try:
+                        _wrap_concrete_model_class(subclass)
+                    except Exception as e:
+                        logger.debug(f"Could not wrap {subclass.__name__}: {e}")
+
+                wrap_all_subclasses(subclass)
+
+        wrap_all_subclasses(Model)
+
+    except Exception as e:
+        logger.warning(f"Failed to wrap Model classes: {e}")
+
+
+def _build_model_class_input_and_metadata(instance: Any, args: Any, kwargs: Any):
+    """Build input data and metadata for model class request wrappers.
+
+    Returns:
+        Tuple of (model_name, display_name, input_data, metadata)
+    """
+    model_name, provider = _extract_model_info_from_model_instance(instance)
+    display_name = model_name or str(instance)
+
+    messages = args[0] if len(args) > 0 else kwargs.get("messages")
+    model_settings = args[1] if len(args) > 1 else kwargs.get("model_settings")
+
+    serialized_messages = _serialize_messages(messages)
+
+    input_data = {"messages": serialized_messages}
+    if model_settings is not None:
+        input_data["model_settings"] = _try_dict(model_settings)
+
+    metadata = _build_model_metadata(model_name, provider, model_settings=None)
+
+    return model_name, display_name, input_data, metadata
+
+
+def _wrap_concrete_model_class(model_class: Any):
+    """Wrap a concrete model class to trace its request methods."""
+    if _is_patched(model_class):
+        return
+
+    async def model_request_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+        model_name, display_name, input_data, metadata = _build_model_class_input_and_metadata(instance, args, kwargs)
+
+        with start_span(
+            name=f"chat {display_name}",
+            type=SpanTypeAttribute.LLM,
+            input=input_data,
+            metadata=_try_dict(metadata),
+        ) as span:
+            start_time = time.time()
+            result = await wrapped(*args, **kwargs)
+            end_time = time.time()
+
+            output = _serialize_model_response(result)
+            metrics = _extract_response_metrics(result, start_time, end_time)
+
+            span.log(output=output, metrics=metrics)
+            return result
+
+    def model_request_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+        model_name, display_name, input_data, metadata = _build_model_class_input_and_metadata(instance, args, kwargs)
+
+        return _DirectStreamWrapper(
+            wrapped(*args, **kwargs),
+            f"chat {display_name}",
+            input_data,
+            metadata,
+        )
+
+    wrap_function_wrapper(model_class, "request", model_request_wrapper)
+    wrap_function_wrapper(model_class, "request_stream", model_request_stream_wrapper)
+    model_class._braintrust_patched = True
+
+
+class _AgentStreamWrapper(AbstractAsyncContextManager):
+    """Wrapper for agent.run_stream() that adds tracing while passing through the stream result."""
+
+    def __init__(self, stream_cm: Any, span_name: str, input_data: Any, metadata: Any):
+        self.stream_cm = stream_cm
+        self.span_name = span_name
+        self.input_data = input_data
+        self.metadata = metadata
+        self.span_cm = None
+        self.start_time = None
+        self.stream_result = None
+
+    async def __aenter__(self):
+        # Use context manager properly so span stays current
+        # DON'T pass start_time here - we'll set it via metrics in __aexit__
+        self.span_cm = start_span(
+            name=self.span_name,
+            type=SpanTypeAttribute.LLM,
+            input=self.input_data if self.input_data else None,
+            metadata=_try_dict(self.metadata),
+        )
+        span = self.span_cm.__enter__()
+
+        # Capture start time right before entering the stream (API call initiation)
+        self.start_time = time.time()
+        self.stream_result = await self.stream_cm.__aenter__()
+        return self.stream_result  # Return actual stream result object
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        try:
+            await self.stream_cm.__aexit__(exc_type, exc_val, exc_tb)
+        finally:
+            if self.span_cm and self.start_time and self.stream_result:
+                end_time = time.time()
+
+                output = _serialize_stream_output(self.stream_result)
+                metrics = _extract_stream_usage_metrics(self.stream_result, self.start_time, end_time, None)
+                self.span_cm.log(output=output, metrics=metrics)
+
+            # Always clean up span context
+            if self.span_cm:
+                self.span_cm.__exit__(None, None, None)
+
+        return False
+
+
+class _DirectStreamWrapper(AbstractAsyncContextManager):
+    """Wrapper for model_request_stream() that adds tracing while passing through the stream."""
+
+    def __init__(self, stream_cm: Any, span_name: str, input_data: Any, metadata: Any):
+        self.stream_cm = stream_cm
+        self.span_name = span_name
+        self.input_data = input_data
+        self.metadata = metadata
+        self.span_cm = None
+        self.start_time = None
+        self.stream = None
+
+    async def __aenter__(self):
+        # Use context manager properly so span stays current
+        # DON'T pass start_time here - we'll set it via metrics in __aexit__
+        self.span_cm = start_span(
+            name=self.span_name,
+            type=SpanTypeAttribute.LLM,
+            input=self.input_data if self.input_data else None,
+            metadata=_try_dict(self.metadata),
+        )
+        span = self.span_cm.__enter__()
+
+        # Capture start time right before entering the stream (API call initiation)
+        self.start_time = time.time()
+        self.stream = await self.stream_cm.__aenter__()
+        return self.stream  # Return actual stream object
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        try:
+            await self.stream_cm.__aexit__(exc_type, exc_val, exc_tb)
+        finally:
+            if self.span_cm and self.start_time and self.stream:
+                end_time = time.time()
+
+                try:
+                    final_response = self.stream.get()
+                    output = _serialize_model_response(final_response)
+                    metrics = _extract_response_metrics(final_response, self.start_time, end_time, None)
+                    self.span_cm.log(output=output, metrics=metrics)
+                except Exception as e:
+                    logger.debug(f"Failed to extract stream output/metrics: {e}")
+
+            # Always clean up span context
+            if self.span_cm:
+                self.span_cm.__exit__(None, None, None)
+
+        return False
+
+
+class _AgentStreamResultSyncProxy:
+    """Proxy for agent.run_stream_sync() result that adds tracing while delegating to actual stream result."""
+
+    def __init__(self, stream_result: Any, span: Any, span_cm: Any, start_time: float):
+        self._stream_result = stream_result
+        self._span = span
+        self._span_cm = span_cm
+        self._start_time = start_time
+        self._logged = False
+        self._finalize_on_del = True
+
+    def __getattr__(self, name: str):
+        """Delegate all attribute access to the wrapped stream result."""
+        attr = getattr(self._stream_result, name)
+
+        # Wrap any method that returns an iterator to auto-finalize when exhausted
+        if callable(attr) and name in ('stream_text', 'stream_output', '__iter__'):
+            def wrapped_method(*args, **kwargs):
+                try:
+                    iterator = attr(*args, **kwargs)
+                    # If it's an iterator, wrap it
+                    if hasattr(iterator, '__iter__') or hasattr(iterator, '__next__'):
+                        try:
+                            yield from iterator
+                        finally:
+                            self._finalize()
+                            self._finalize_on_del = False  # Don't finalize again in __del__
+                    else:
+                        return iterator
+                except Exception:
+                    self._finalize()
+                    self._finalize_on_del = False
+                    raise
+            return wrapped_method
+
+        return attr
+
+    def _finalize(self):
+        """Log metrics and close span."""
+        if self._span and not self._logged and self._stream_result:
+            try:
+                end_time = time.time()
+                output = _serialize_stream_output(self._stream_result)
+                metrics = _extract_stream_usage_metrics(self._stream_result, self._start_time, end_time, None)
+                self._span.log(output=output, metrics=metrics)
+                self._logged = True
+            finally:
+                try:
+                    self._span_cm.__exit__(None, None, None)
+                except Exception:
+                    pass
+
+    def __del__(self):
+        """Ensure span is closed when proxy is destroyed."""
+        if self._finalize_on_del:
+            self._finalize()
+
+
+class _DirectStreamWrapperSync:
+    """Wrapper for model_request_stream_sync() that adds tracing while passing through the stream."""
+
+    def __init__(self, stream_cm: Any, span_name: str, input_data: Any, metadata: Any):
+        self.stream_cm = stream_cm
+        self.span_name = span_name
+        self.input_data = input_data
+        self.metadata = metadata
+        self.span_cm = None
+        self.start_time = None
+        self.stream = None
+
+    def __enter__(self):
+        # Use context manager properly so span stays current
+        # DON'T pass start_time here - we'll set it via metrics in __exit__
+        self.span_cm = start_span(
+            name=self.span_name,
+            type=SpanTypeAttribute.LLM,
+            input=self.input_data if self.input_data else None,
+            metadata=_try_dict(self.metadata),
+        )
+        span = self.span_cm.__enter__()
+
+        # Capture start time right before entering the stream (API call initiation)
+        self.start_time = time.time()
+        self.stream = self.stream_cm.__enter__()
+        return self.stream  # Return actual stream object
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        try:
+            self.stream_cm.__exit__(exc_type, exc_val, exc_tb)
+        finally:
+            if self.span_cm and self.start_time and self.stream:
+                end_time = time.time()
+
+                try:
+                    final_response = self.stream.get()
+                    output = _serialize_model_response(final_response)
+                    metrics = _extract_response_metrics(final_response, self.start_time, end_time, None)
+                    self.span_cm.log(output=output, metrics=metrics)
+                except Exception as e:
+                    logger.debug(f"Failed to extract stream output/metrics: {e}")
+
+            # Always clean up span context
+            if self.span_cm:
+                self.span_cm.__exit__(None, None, None)
+
+        return False
+
+
+def _serialize_user_prompt(user_prompt: Any) -> Any:
+    """Serialize user prompt, handling BinaryContent and other types."""
+    if user_prompt is None:
+        return None
+
+    if isinstance(user_prompt, str):
+        return user_prompt
+
+    if isinstance(user_prompt, list):
+        return [_serialize_content_part(part) for part in user_prompt]
+
+    return _serialize_content_part(user_prompt)
+
+
+def _serialize_content_part(part: Any) -> Any:
+    """Serialize a content part, handling BinaryContent specially."""
+    if part is None:
+        return None
+
+    if hasattr(part, "data") and hasattr(part, "media_type") and hasattr(part, "kind"):
+        if part.kind == "binary":
+            data = part.data
+            media_type = part.media_type
+
+            extension = media_type.split("/")[1] if "/" in media_type else "bin"
+            filename = f"file.{extension}"
+
+            attachment = Attachment(data=data, filename=filename, content_type=media_type)
+            return {"type": "binary", "attachment": attachment, "media_type": media_type}
+
+    if isinstance(part, str):
+        return part
+
+    return _try_dict(part)
+
+
+def _serialize_messages(messages: Any) -> Any:
+    """Serialize messages list."""
+    if not messages:
+        return []
+
+    result = []
+    for msg in messages:
+        serialized_msg = _try_dict(msg)
+
+        if isinstance(serialized_msg, dict) and "parts" in serialized_msg:
+            serialized_msg["parts"] = [_serialize_content_part(p) for p in msg.parts]
+
+        result.append(serialized_msg)
+
+    return result
+
+
+def _serialize_result_output(result: Any) -> Any:
+    """Serialize agent run result output."""
+    if not result:
+        return None
+
+    output_dict = {}
+
+    if hasattr(result, "output"):
+        output_dict["output"] = _try_dict(result.output)
+
+    if hasattr(result, "response"):
+        output_dict["response"] = _serialize_model_response(result.response)
+
+    return output_dict if output_dict else _try_dict(result)
+
+
+def _serialize_stream_output(stream_result: Any) -> Any:
+    """Serialize stream result output."""
+    if not stream_result:
+        return None
+
+    output_dict = {}
+
+    if hasattr(stream_result, "response"):
+        output_dict["response"] = _serialize_model_response(stream_result.response)
+
+    return output_dict if output_dict else None
+
+
+def _serialize_model_response(response: Any) -> Any:
+    """Serialize a model response."""
+    if not response:
+        return None
+
+    response_dict = _try_dict(response)
+
+    if isinstance(response_dict, dict) and "parts" in response_dict:
+        if hasattr(response, "parts"):
+            response_dict["parts"] = [_serialize_content_part(p) for p in response.parts]
+
+    return response_dict
+
+
+def _extract_model_info_from_model_instance(model: Any) -> tuple[str | None, str | None]:
+    """Extract model name and provider from a model instance.
+
+    Args:
+        model: A Pydantic AI model instance (OpenAIChatModel, AnthropicModel, etc.)
+
+    Returns:
+        Tuple of (model_name, provider)
+    """
+    if not model:
+        return None, None
+
+    if isinstance(model, str):
+        return _parse_model_string(model)
+
+    if hasattr(model, "model_name"):
+        model_name = model.model_name
+        class_name = type(model).__name__
+        provider = None
+        if "OpenAI" in class_name:
+            provider = "openai"
+        elif "Anthropic" in class_name:
+            provider = "anthropic"
+        elif "Gemini" in class_name:
+            provider = "gemini"
+        elif "Groq" in class_name:
+            provider = "groq"
+        elif "Mistral" in class_name:
+            provider = "mistral"
+        elif "VertexAI" in class_name:
+            provider = "vertexai"
+
+        return model_name, provider
+
+    if hasattr(model, "name"):
+        return _parse_model_string(model.name)
+
+    return None, None
+
+
+def _extract_model_info(agent: Any) -> tuple[str | None, str | None]:
+    """Extract model name and provider from agent.
+
+    Args:
+        agent: A Pydantic AI Agent instance
+
+    Returns:
+        Tuple of (model_name, provider)
+    """
+    if not hasattr(agent, "model"):
+        return None, None
+
+    return _extract_model_info_from_model_instance(agent.model)
+
+
+def _build_model_metadata(
+    model_name: str | None, provider: str | None, model_settings: Any = None
+) -> dict[str, Any]:
+    """Build metadata dictionary with model info.
+
+    Args:
+        model_name: The model name (e.g., "gpt-4o")
+        provider: The provider (e.g., "openai")
+        model_settings: Optional model settings to include
+
+    Returns:
+        Dictionary of metadata
+    """
+    metadata = {}
+    if model_name:
+        metadata["model"] = model_name
+    if provider:
+        metadata["provider"] = provider
+    if model_settings:
+        metadata["model_settings"] = _try_dict(model_settings)
+    return metadata
+
+
+def _parse_model_string(model: Any) -> tuple[str | None, str | None]:
+    """Parse model string to extract provider and model name.
+
+    Pydantic AI uses format: "provider:model-name" (e.g., "openai:gpt-4o")
+    """
+    if not model:
+        return None, None
+
+    model_str = str(model)
+
+    if ":" in model_str:
+        parts = model_str.split(":", 1)
+        return parts[1], parts[0]  # (model_name, provider)
+
+    return model_str, None
+
+
+def _extract_usage_metrics(result: Any, start_time: float, end_time: float) -> dict[str, float] | None:
+    """Extract usage metrics from agent run result."""
+    metrics: dict[str, float] = {}
+
+    metrics["start"] = start_time
+    metrics["end"] = end_time
+    metrics["duration"] = end_time - start_time
+
+    usage = None
+    if hasattr(result, "response"):
+        try:
+            response = result.response
+            if hasattr(response, "usage"):
+                usage = response.usage
+        except (AttributeError, ValueError):
+            pass
+
+    if usage is None and hasattr(result, "usage"):
+        usage = result.usage
+
+    if usage is None:
+        return metrics
+
+    if hasattr(usage, "input_tokens"):
+        input_tokens = usage.input_tokens
+        if input_tokens is not None:
+            metrics["prompt_tokens"] = float(input_tokens)
+
+    if hasattr(usage, "output_tokens"):
+        output_tokens = usage.output_tokens
+        if output_tokens is not None:
+            metrics["completion_tokens"] = float(output_tokens)
+
+    if hasattr(usage, "total_tokens"):
+        total_tokens = usage.total_tokens
+        if total_tokens is not None:
+            metrics["tokens"] = float(total_tokens)
+
+    if hasattr(usage, "cache_read_tokens") and usage.cache_read_tokens is not None:
+        metrics["prompt_cached_tokens"] = float(usage.cache_read_tokens)
+
+    if hasattr(usage, "cache_write_tokens") and usage.cache_write_tokens is not None:
+        metrics["prompt_cache_creation_tokens"] = float(usage.cache_write_tokens)
+
+    if hasattr(usage, "input_audio_tokens") and usage.input_audio_tokens is not None:
+        metrics["prompt_audio_tokens"] = float(usage.input_audio_tokens)
+
+    if hasattr(usage, "output_audio_tokens") and usage.output_audio_tokens is not None:
+        metrics["completion_audio_tokens"] = float(usage.output_audio_tokens)
+
+    if hasattr(usage, "details") and isinstance(usage.details, dict):
+        details = usage.details
+
+        if "reasoning_tokens" in details:
+            metrics["completion_reasoning_tokens"] = float(details["reasoning_tokens"])
+
+        if "cached_tokens" in details:
+            metrics["prompt_cached_tokens"] = float(details["cached_tokens"])
+
+    return metrics if metrics else None
+
+
+def _extract_stream_usage_metrics(
+    stream_result: Any, start_time: float, end_time: float, first_token_time: float | None
+) -> dict[str, float] | None:
+    """Extract usage metrics from stream result."""
+    metrics: dict[str, float] = {}
+
+    metrics["start"] = start_time
+    metrics["end"] = end_time
+    metrics["duration"] = end_time - start_time
+
+    if first_token_time:
+        metrics["time_to_first_token"] = first_token_time - start_time
+
+    if hasattr(stream_result, "usage"):
+        usage_func = stream_result.usage
+        if callable(usage_func):
+            usage = usage_func()
+        else:
+            usage = usage_func
+
+        if usage:
+            if hasattr(usage, "input_tokens") and usage.input_tokens is not None:
+                metrics["prompt_tokens"] = float(usage.input_tokens)
+
+            if hasattr(usage, "output_tokens") and usage.output_tokens is not None:
+                metrics["completion_tokens"] = float(usage.output_tokens)
+
+            if hasattr(usage, "total_tokens") and usage.total_tokens is not None:
+                metrics["tokens"] = float(usage.total_tokens)
+
+            if hasattr(usage, "cache_read_tokens") and usage.cache_read_tokens is not None:
+                metrics["prompt_cached_tokens"] = float(usage.cache_read_tokens)
+
+            if hasattr(usage, "cache_write_tokens") and usage.cache_write_tokens is not None:
+                metrics["prompt_cache_creation_tokens"] = float(usage.cache_write_tokens)
+
+    return metrics if metrics else None
+
+
+def _extract_response_metrics(
+    response: Any, start_time: float, end_time: float, first_token_time: float | None = None
+) -> dict[str, float] | None:
+    """Extract metrics from model response."""
+    metrics: dict[str, float] = {}
+
+    metrics["start"] = start_time
+    metrics["end"] = end_time
+    metrics["duration"] = end_time - start_time
+
+    if first_token_time:
+        metrics["time_to_first_token"] = first_token_time - start_time
+
+    if hasattr(response, "usage") and response.usage:
+        usage = response.usage
+
+        if hasattr(usage, "input_tokens") and usage.input_tokens is not None:
+            metrics["prompt_tokens"] = float(usage.input_tokens)
+
+        if hasattr(usage, "output_tokens") and usage.output_tokens is not None:
+            metrics["completion_tokens"] = float(usage.output_tokens)
+
+        if hasattr(usage, "total_tokens") and usage.total_tokens is not None:
+            metrics["tokens"] = float(usage.total_tokens)
+
+        if hasattr(usage, "cache_read_tokens") and usage.cache_read_tokens is not None:
+            metrics["prompt_cached_tokens"] = float(usage.cache_read_tokens)
+
+        if hasattr(usage, "cache_write_tokens") and usage.cache_write_tokens is not None:
+            metrics["prompt_cache_creation_tokens"] = float(usage.cache_write_tokens)
+
+        # Extract reasoning tokens for reasoning models (o1/o3)
+        if hasattr(usage, "details") and usage.details is not None:
+            if hasattr(usage.details, "reasoning_tokens") and usage.details.reasoning_tokens is not None:
+                metrics["completion_reasoning_tokens"] = float(usage.details.reasoning_tokens)
+
+    return metrics if metrics else None
+
+
+def _is_patched(obj: Any) -> bool:
+    """Check if object is already patched."""
+    return getattr(obj, "_braintrust_patched", False)
+
+
+def _try_dict(obj: Any) -> Iterable[Any] | dict[str, Any]:
+    """Try to convert object to dict, handling Pydantic models and circular references."""
+    if hasattr(obj, "model_dump"):
+        try:
+            obj = obj.model_dump(exclude_none=True)
+        except ValueError as e:
+            if "Circular reference" in str(e):
+                return {}
+            raise
+
+    if isinstance(obj, dict):
+        return {k: _try_dict(v) for k, v in obj.items()}
+    elif isinstance(obj, (list, tuple)):
+        return [_try_dict(item) for item in obj]
+
+    return obj
+
+
+def _serialize_type(obj: Any) -> Any:
+    """Serialize a type/class for logging, handling Pydantic models and other types.
+
+    This is useful for output_type, toolsets, and similar type parameters.
+    Returns full JSON schema for Pydantic models so engineers can see exactly
+    what structured output schema was used.
+    """
+    import inspect
+
+    # For sequences of types (like Union types or list of models)
+    if isinstance(obj, (list, tuple)):
+        return [_serialize_type(item) for item in obj]
+
+    # Handle Pydantic AI's output wrappers (ToolOutput, NativeOutput, PromptedOutput, TextOutput)
+    if hasattr(obj, "output"):
+        # These are wrapper classes with an 'output' field containing the actual type
+        wrapper_info = {"wrapper": type(obj).__name__}
+        if hasattr(obj, "name") and obj.name:
+            wrapper_info["name"] = obj.name
+        if hasattr(obj, "description") and obj.description:
+            wrapper_info["description"] = obj.description
+        wrapper_info["output"] = _serialize_type(obj.output)
+        return wrapper_info
+
+    # If it's a Pydantic model class, return its full JSON schema
+    if inspect.isclass(obj):
+        try:
+            from pydantic import BaseModel
+
+            if issubclass(obj, BaseModel):
+                # Return the full JSON schema - includes all field info, descriptions, constraints, etc.
+                return obj.model_json_schema()
+        except (ImportError, AttributeError, TypeError):
+            pass
+
+        # Not a Pydantic model, return class name
+        return obj.__name__
+
+    # If it has a __name__ attribute (like functions), use that
+    if hasattr(obj, "__name__"):
+        return obj.__name__
+
+    # Try standard serialization
+    return _try_dict(obj)
+
+
+G = TypeVar("G", bound=AsyncGenerator[Any, None])
+
+
+class aclosing(AbstractAsyncContextManager[G]):
+    """Context manager for closing async generators."""
+
+    def __init__(self, async_generator: G):
+        self.async_generator = async_generator
+
+    async def __aenter__(self):
+        return self.async_generator
+
+    async def __aexit__(self, *exc_info: Any):
+        try:
+            await self.async_generator.aclose()
+        except ValueError as e:
+            if "was created in a different Context" not in str(e):
+                raise
+            else:
+                logger.debug(
+                    f"Suppressed ContextVar error during async cleanup: {e}. "
+                    "This is expected when async generators yield across context boundaries."
+                )
+
+
+def _build_agent_input_and_metadata(args: Any, kwargs: Any, instance: Any) -> tuple[dict[str, Any], dict[str, Any]]:
+    """Build input data and metadata for agent wrappers.
+
+    Returns:
+        Tuple of (input_data, metadata)
+    """
+    input_data = {}
+
+    user_prompt = args[0] if len(args) > 0 else kwargs.get("user_prompt")
+    if user_prompt is not None:
+        input_data["user_prompt"] = _serialize_user_prompt(user_prompt)
+
+    for key, value in kwargs.items():
+        if key == "deps":
+            continue
+        elif key == "message_history":
+            input_data[key] = _serialize_messages(value) if value is not None else None
+        elif key in ("output_type", "toolsets"):
+            # These often contain types/classes, use special serialization
+            input_data[key] = _serialize_type(value) if value is not None else None
+        elif key == "model_settings":
+            # model_settings passed to run() goes in INPUT (it's a run() parameter)
+            input_data[key] = _try_dict(value) if value is not None else None
+        else:
+            input_data[key] = _try_dict(value) if value is not None else None
+
+    if "model" in kwargs:
+        model_name, provider = _parse_model_string(kwargs["model"])
+    else:
+        model_name, provider = _extract_model_info(instance)
+
+    # Extract agent-level configuration for metadata
+    # Only add to metadata if NOT explicitly passed in kwargs (those go in input)
+    agent_model_settings = None
+    if "model_settings" not in kwargs and hasattr(instance, "model_settings") and instance.model_settings is not None:
+        agent_model_settings = instance.model_settings
+
+    metadata = _build_model_metadata(model_name, provider, agent_model_settings)
+
+    # Extract additional agent configuration (only if not passed as kwargs)
+    if "name" not in kwargs and hasattr(instance, "name") and instance.name is not None:
+        metadata["agent_name"] = instance.name
+
+    if "end_strategy" not in kwargs and hasattr(instance, "end_strategy") and instance.end_strategy is not None:
+        metadata["end_strategy"] = str(instance.end_strategy)
+
+    # Extract output_type if set on agent and not passed as kwarg
+    # output_type can be a Pydantic model, str, or other types that get converted to JSON schema
+    if "output_type" not in kwargs and hasattr(instance, "output_type") and instance.output_type is not None:
+        try:
+            metadata["output_type"] = _serialize_type(instance.output_type)
+        except Exception as e:
+            logger.debug(f"Failed to extract output_type from agent: {e}")
+
+    # Extract toolsets if set on agent and not passed as kwarg
+    # Toolsets go in INPUT (not metadata) because agent.run() accepts toolsets parameter
+    if "toolsets" not in kwargs and hasattr(instance, "toolsets"):
+        try:
+            toolsets = instance.toolsets
+            if toolsets:
+                # Convert toolsets to a list with FULL tool schemas for input
+                serialized_toolsets = []
+                for ts in toolsets:
+                    ts_info = {
+                        "id": getattr(ts, "id", str(type(ts).__name__)),
+                        "label": getattr(ts, "label", None),
+                    }
+                    # Add full tool schemas (not just names) since toolsets can be passed to agent.run()
+                    if hasattr(ts, "tools") and ts.tools:
+                        tools_list = []
+                        tools_dict = ts.tools
+                        # tools is a dict mapping tool name -> Tool object
+                        for tool_name, tool_obj in tools_dict.items():
+                            tool_dict = {
+                                "name": tool_name,
+                            }
+                            # Extract description
+                            if hasattr(tool_obj, "description") and tool_obj.description:
+                                tool_dict["description"] = tool_obj.description
+                            # Extract JSON schema for parameters
+                            if hasattr(tool_obj, "function_schema") and hasattr(tool_obj.function_schema, "json_schema"):
+                                tool_dict["parameters"] = tool_obj.function_schema.json_schema
+                            tools_list.append(tool_dict)
+                        ts_info["tools"] = tools_list
+                    serialized_toolsets.append(ts_info)
+                input_data["toolsets"] = serialized_toolsets
+        except Exception as e:
+            logger.debug(f"Failed to extract toolsets from agent: {e}")
+
+    # Extract system_prompt from agent if not passed as kwarg
+    # Note: system_prompt goes in input (not metadata) because it's semantically part of the LLM input
+    # Pydantic AI doesn't expose a public API for this, so we access the private _system_prompts
+    # attribute. This is wrapped in try/except to gracefully handle if the internal structure changes.
+    if "system_prompt" not in kwargs:
+        try:
+            if hasattr(instance, "_system_prompts") and instance._system_prompts:
+                input_data["system_prompt"] = "\n\n".join(instance._system_prompts)
+        except Exception as e:
+            logger.debug(f"Failed to extract system_prompt from agent: {e}")
+
+    return input_data, metadata
+
+
+def _build_direct_model_input_and_metadata(args: Any, kwargs: Any) -> tuple[dict[str, Any], dict[str, Any]]:
+    """Build input data and metadata for direct model request wrappers.
+
+    Returns:
+        Tuple of (input_data, metadata)
+    """
+    input_data = {}
+
+    model = args[0] if len(args) > 0 else kwargs.get("model")
+    if model is not None:
+        input_data["model"] = str(model)
+
+    messages = args[1] if len(args) > 1 else kwargs.get("messages", [])
+    if messages:
+        input_data["messages"] = _serialize_messages(messages)
+
+    for key, value in kwargs.items():
+        if key not in ["model", "messages"]:
+            input_data[key] = _try_dict(value) if value is not None else None
+
+    model_name, provider = _parse_model_string(model)
+    metadata = _build_model_metadata(model_name, provider)

+    return input_data, metadata
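
For context, here is a minimal usage sketch of the new wrapper (not part of the diff itself). It assumes braintrust and pydantic-ai are installed; the project name, model string, and prompts below are placeholders.

    from pydantic_ai import Agent
    from braintrust.wrappers.pydantic_ai import setup_pydantic_ai

    # Patches Agent.run/run_sync/run_stream* and the pydantic_ai.direct functions
    # in place; also initializes a Braintrust logger when no span is active.
    setup_pydantic_ai(project_name="my-project")  # "my-project" is a placeholder

    agent = Agent("openai:gpt-4o", system_prompt="Be concise.")
    result = agent.run_sync("What is the capital of France?")  # traced as an LLM span
    print(result.output)

Because wrap_agent and _wrap_concrete_model_class check the _braintrust_patched flag before patching, calling setup_pydantic_ai more than once is a no-op for already-wrapped classes.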