braintrust 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braintrust/__init__.py +4 -0
- braintrust/_generated_types.py +1200 -611
- braintrust/audit.py +2 -2
- braintrust/cli/eval.py +6 -7
- braintrust/cli/push.py +11 -11
- braintrust/conftest.py +1 -0
- braintrust/context.py +12 -17
- braintrust/contrib/temporal/__init__.py +16 -27
- braintrust/contrib/temporal/test_temporal.py +8 -3
- braintrust/devserver/auth.py +8 -8
- braintrust/devserver/cache.py +3 -4
- braintrust/devserver/cors.py +8 -7
- braintrust/devserver/dataset.py +3 -5
- braintrust/devserver/eval_hooks.py +7 -6
- braintrust/devserver/schemas.py +22 -19
- braintrust/devserver/server.py +19 -12
- braintrust/devserver/test_cached_login.py +4 -4
- braintrust/framework.py +128 -140
- braintrust/framework2.py +88 -87
- braintrust/functions/invoke.py +93 -53
- braintrust/functions/stream.py +3 -2
- braintrust/generated_types.py +17 -1
- braintrust/git_fields.py +11 -11
- braintrust/gitutil.py +2 -3
- braintrust/graph_util.py +10 -10
- braintrust/id_gen.py +2 -2
- braintrust/logger.py +346 -357
- braintrust/merge_row_batch.py +10 -9
- braintrust/oai.py +107 -24
- braintrust/otel/__init__.py +49 -49
- braintrust/otel/context.py +16 -30
- braintrust/otel/test_distributed_tracing.py +14 -11
- braintrust/otel/test_otel_bt_integration.py +32 -31
- braintrust/parameters.py +8 -8
- braintrust/prompt.py +14 -14
- braintrust/prompt_cache/disk_cache.py +5 -4
- braintrust/prompt_cache/lru_cache.py +3 -2
- braintrust/prompt_cache/prompt_cache.py +13 -14
- braintrust/queue.py +4 -4
- braintrust/score.py +4 -4
- braintrust/serializable_data_class.py +4 -4
- braintrust/span_identifier_v1.py +1 -2
- braintrust/span_identifier_v2.py +3 -4
- braintrust/span_identifier_v3.py +23 -20
- braintrust/span_identifier_v4.py +34 -25
- braintrust/test_framework.py +16 -6
- braintrust/test_helpers.py +5 -5
- braintrust/test_id_gen.py +2 -3
- braintrust/test_otel.py +61 -53
- braintrust/test_queue.py +0 -1
- braintrust/test_score.py +1 -3
- braintrust/test_span_components.py +29 -44
- braintrust/util.py +9 -8
- braintrust/version.py +2 -2
- braintrust/wrappers/_anthropic_utils.py +4 -4
- braintrust/wrappers/agno/__init__.py +3 -4
- braintrust/wrappers/agno/agent.py +1 -2
- braintrust/wrappers/agno/function_call.py +1 -2
- braintrust/wrappers/agno/model.py +1 -2
- braintrust/wrappers/agno/team.py +1 -2
- braintrust/wrappers/agno/utils.py +12 -12
- braintrust/wrappers/anthropic.py +7 -8
- braintrust/wrappers/claude_agent_sdk/__init__.py +3 -4
- braintrust/wrappers/claude_agent_sdk/_wrapper.py +29 -27
- braintrust/wrappers/dspy.py +15 -17
- braintrust/wrappers/google_genai/__init__.py +16 -16
- braintrust/wrappers/langchain.py +22 -24
- braintrust/wrappers/litellm.py +4 -3
- braintrust/wrappers/openai.py +15 -15
- braintrust/wrappers/pydantic_ai.py +1204 -0
- braintrust/wrappers/test_agno.py +0 -1
- braintrust/wrappers/test_dspy.py +0 -1
- braintrust/wrappers/test_google_genai.py +2 -3
- braintrust/wrappers/test_litellm.py +0 -1
- braintrust/wrappers/test_oai_attachments.py +322 -0
- braintrust/wrappers/test_pydantic_ai_integration.py +1788 -0
- braintrust/wrappers/{test_pydantic_ai.py → test_pydantic_ai_wrap_openai.py} +1 -2
- {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/METADATA +3 -2
- braintrust-0.4.0.dist-info/RECORD +120 -0
- braintrust-0.3.14.dist-info/RECORD +0 -117
- {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/WHEEL +0 -0
- {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/entry_points.txt +0 -0
- {braintrust-0.3.14.dist-info → braintrust-0.4.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1204 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import sys
|
|
3
|
+
import time
|
|
4
|
+
from collections.abc import AsyncGenerator, Iterable
|
|
5
|
+
from contextlib import AbstractAsyncContextManager
|
|
6
|
+
from typing import Any, TypeVar
|
|
7
|
+
|
|
8
|
+
from braintrust.logger import NOOP_SPAN, Attachment, current_span, init_logger, start_span
|
|
9
|
+
from braintrust.span_types import SpanTypeAttribute
|
|
10
|
+
from wrapt import wrap_function_wrapper
|
|
11
|
+
|
|
12
|
+
# Module-level logger; the wrappers below report non-fatal instrumentation problems through it.
logger = logging.getLogger(__name__)

# Only the setup entry point is exported by `from ... import *`.
__all__ = ["setup_pydantic_ai"]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def setup_pydantic_ai(
    api_key: str | None = None,
    project_id: str | None = None,
    project_name: str | None = None,
) -> bool:
    """
    Setup Braintrust integration with Pydantic AI. Will automatically patch Pydantic AI Agents and direct API functions for automatic tracing.

    A logger is initialized only when no span is currently active. Calling this
    function more than once is safe: already-patched targets are not wrapped again.

    Args:
        api_key (str | None): Braintrust API key.
        project_id (str | None): Braintrust project ID.
        project_name (str | None): Braintrust project name.

    Returns:
        bool: True if setup was successful, False if Pydantic AI could not be imported.
    """
    span = current_span()
    if span == NOOP_SPAN:
        init_logger(project=project_name, api_key=api_key, project_id=project_id)

    try:
        import pydantic_ai.direct as direct_module
        from pydantic_ai import Agent

        # Patches the class in place; a no-op if it is already patched.
        wrap_agent(Agent)

        # wrap_function_wrapper replaces the module attribute with a wrapper around
        # whatever is currently there, so wrapping again on a repeated setup call
        # would nest wrappers and emit duplicate spans per request.
        if not getattr(direct_module, "_braintrust_patched", False):
            wrap_function_wrapper(direct_module, "model_request", _create_direct_model_request_wrapper())
            wrap_function_wrapper(direct_module, "model_request_sync", _create_direct_model_request_sync_wrapper())
            wrap_function_wrapper(direct_module, "model_request_stream", _create_direct_model_request_stream_wrapper())
            wrap_function_wrapper(
                direct_module, "model_request_stream_sync", _create_direct_model_request_stream_sync_wrapper()
            )
            direct_module._braintrust_patched = True

        wrap_model_classes()

        return True
    except ImportError as e:
        logger.error(f"Failed to import Pydantic AI: {e}")
        logger.error("Pydantic AI is not installed. Please install it with: pip install pydantic-ai-slim")
        return False
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def wrap_agent(Agent: Any) -> Any:
    """Patch a Pydantic AI ``Agent`` class in place so its run methods emit Braintrust spans.

    Wraps ``run``, ``run_sync``, ``run_stream``, ``run_stream_sync`` and
    ``run_stream_events``. A class that is already patched is returned unchanged.

    Args:
        Agent: The agent class to patch.

    Returns:
        The same class object, now patched.
    """
    if _is_patched(Agent):
        return Agent

    def _span_name(prefix: str, instance: Any) -> str:
        """Return ``"<prefix> [<agent name>]"`` when the agent has a truthy ``name``, else ``prefix``."""
        agent_name = getattr(instance, "name", None)
        return f"{prefix} [{agent_name}]" if agent_name else prefix

    def _ensure_model_wrapped(instance: Any) -> None:
        """Lazily wrap the class of the agent's configured model (best effort).

        ``_model`` can be absent or ``None`` (no default model; the model is given per
        run) and possibly an unresolved model-name string. Those have no ``request``
        methods to patch, so they are skipped. Any other wrapping failure is logged
        and never breaks the agent run.
        """
        model = getattr(instance, "_model", None)
        if model is None or isinstance(model, str):
            return
        try:
            _wrap_concrete_model_class(type(model))
        except Exception as e:
            logger.debug(f"Could not wrap model class {type(model).__name__}: {e}")

    async def agent_run_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
        _ensure_model_wrapped(instance)
        input_data, metadata = _build_agent_input_and_metadata(args, kwargs, instance)

        with start_span(
            name=_span_name("agent_run", instance),
            type=SpanTypeAttribute.LLM,
            input=input_data if input_data else None,
            metadata=_try_dict(metadata),
        ) as agent_span:
            start_time = time.time()
            result = await wrapped(*args, **kwargs)
            end_time = time.time()

            output = _serialize_result_output(result)
            metrics = _extract_usage_metrics(result, start_time, end_time)

            agent_span.log(output=output, metrics=metrics)
            return result

    wrap_function_wrapper(Agent, "run", agent_run_wrapper)

    def agent_run_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
        _ensure_model_wrapped(instance)
        input_data, metadata = _build_agent_input_and_metadata(args, kwargs, instance)

        with start_span(
            name=_span_name("agent_run_sync", instance),
            type=SpanTypeAttribute.LLM,
            input=input_data if input_data else None,
            metadata=_try_dict(metadata),
        ) as agent_span:
            start_time = time.time()
            result = wrapped(*args, **kwargs)
            end_time = time.time()

            output = _serialize_result_output(result)
            metrics = _extract_usage_metrics(result, start_time, end_time)

            agent_span.log(output=output, metrics=metrics)
            return result

    wrap_function_wrapper(Agent, "run_sync", agent_run_sync_wrapper)

    def agent_run_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
        _ensure_model_wrapped(instance)
        input_data, metadata = _build_agent_input_and_metadata(args, kwargs, instance)
        span_name = _span_name("agent_run_stream", instance)

        # The span itself is opened by _AgentStreamWrapper when the caller enters it.
        return _AgentStreamWrapper(
            wrapped(*args, **kwargs),
            span_name,
            input_data,
            metadata,
        )

    wrap_function_wrapper(Agent, "run_stream", agent_run_stream_wrapper)

    def agent_run_stream_sync_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
        _ensure_model_wrapped(instance)
        input_data, metadata = _build_agent_input_and_metadata(args, kwargs, instance)
        span_name = _span_name("agent_run_stream_sync", instance)

        # Create span context BEFORE calling wrapped function so internal spans nest under it
        span_cm = start_span(
            name=span_name,
            type=SpanTypeAttribute.LLM,
            input=input_data if input_data else None,
            metadata=_try_dict(metadata),
        )
        span = span_cm.__enter__()
        start_time = time.time()

        try:
            # The proxy takes ownership of the open span and closes it when the
            # stream is consumed (or when the proxy is garbage collected).
            stream_result = wrapped(*args, **kwargs)
            return _AgentStreamResultSyncProxy(
                stream_result,
                span,
                span_cm,
                start_time,
            )
        except BaseException:
            # Close the span (recording the error) on any failure, including
            # KeyboardInterrupt, so it does not stay open and current.
            span_cm.__exit__(*sys.exc_info())
            raise

    wrap_function_wrapper(Agent, "run_stream_sync", agent_run_stream_sync_wrapper)

    async def agent_run_stream_events_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
        _ensure_model_wrapped(instance)
        input_data, metadata = _build_agent_input_and_metadata(args, kwargs, instance)

        with start_span(
            name=_span_name("agent_run_stream_events", instance),
            type=SpanTypeAttribute.LLM,
            input=input_data if input_data else None,
            metadata=_try_dict(metadata),
        ) as agent_span:
            start_time = time.time()
            event_count = 0
            final_result = None

            async for event in wrapped(*args, **kwargs):
                event_count += 1
                # Remember the latest object exposing the run ``output``: the event
                # itself, or a run result nested under ``event.result``.
                # NOTE(review): the nested form is assumed from pydantic_ai's terminal
                # run-result event; confirm against the installed version.
                candidate = event if hasattr(event, "output") else getattr(event, "result", None)
                if candidate is not None and hasattr(candidate, "output"):
                    final_result = candidate
                yield event

            end_time = time.time()

            output = None
            metrics = {
                "start": start_time,
                "end": end_time,
                "duration": end_time - start_time,
                "event_count": event_count,
            }

            if final_result is not None:
                output = _serialize_result_output(final_result)
                metrics.update(_extract_usage_metrics(final_result, start_time, end_time))

            agent_span.log(output=output, metrics=metrics)

    wrap_function_wrapper(Agent, "run_stream_events", agent_run_stream_events_wrapper)

    Agent._braintrust_patched = True

    return Agent
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def _create_direct_model_request_wrapper():
|
|
210
|
+
"""Create wrapper for direct.model_request()."""
|
|
211
|
+
|
|
212
|
+
async def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
|
|
213
|
+
input_data, metadata = _build_direct_model_input_and_metadata(args, kwargs)
|
|
214
|
+
|
|
215
|
+
with start_span(
|
|
216
|
+
name="model_request",
|
|
217
|
+
type=SpanTypeAttribute.LLM,
|
|
218
|
+
input=input_data,
|
|
219
|
+
metadata=_try_dict(metadata),
|
|
220
|
+
) as span:
|
|
221
|
+
start_time = time.time()
|
|
222
|
+
result = await wrapped(*args, **kwargs)
|
|
223
|
+
end_time = time.time()
|
|
224
|
+
|
|
225
|
+
output = _serialize_model_response(result)
|
|
226
|
+
metrics = _extract_response_metrics(result, start_time, end_time)
|
|
227
|
+
|
|
228
|
+
span.log(output=output, metrics=metrics)
|
|
229
|
+
return result
|
|
230
|
+
|
|
231
|
+
return wrapper
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def _create_direct_model_request_sync_wrapper():
|
|
235
|
+
"""Create wrapper for direct.model_request_sync()."""
|
|
236
|
+
|
|
237
|
+
def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
|
|
238
|
+
input_data, metadata = _build_direct_model_input_and_metadata(args, kwargs)
|
|
239
|
+
|
|
240
|
+
with start_span(
|
|
241
|
+
name="model_request_sync",
|
|
242
|
+
type=SpanTypeAttribute.LLM,
|
|
243
|
+
input=input_data,
|
|
244
|
+
metadata=_try_dict(metadata),
|
|
245
|
+
) as span:
|
|
246
|
+
start_time = time.time()
|
|
247
|
+
result = wrapped(*args, **kwargs)
|
|
248
|
+
end_time = time.time()
|
|
249
|
+
|
|
250
|
+
output = _serialize_model_response(result)
|
|
251
|
+
metrics = _extract_response_metrics(result, start_time, end_time)
|
|
252
|
+
|
|
253
|
+
span.log(output=output, metrics=metrics)
|
|
254
|
+
return result
|
|
255
|
+
|
|
256
|
+
return wrapper
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def _create_direct_model_request_stream_wrapper():
|
|
260
|
+
"""Create wrapper for direct.model_request_stream()."""
|
|
261
|
+
|
|
262
|
+
def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
|
|
263
|
+
input_data, metadata = _build_direct_model_input_and_metadata(args, kwargs)
|
|
264
|
+
|
|
265
|
+
return _DirectStreamWrapper(
|
|
266
|
+
wrapped(*args, **kwargs),
|
|
267
|
+
"model_request_stream",
|
|
268
|
+
input_data,
|
|
269
|
+
metadata,
|
|
270
|
+
)
|
|
271
|
+
|
|
272
|
+
return wrapper
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def _create_direct_model_request_stream_sync_wrapper():
|
|
276
|
+
"""Create wrapper for direct.model_request_stream_sync()."""
|
|
277
|
+
|
|
278
|
+
def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
|
|
279
|
+
input_data, metadata = _build_direct_model_input_and_metadata(args, kwargs)
|
|
280
|
+
|
|
281
|
+
return _DirectStreamWrapperSync(
|
|
282
|
+
wrapped(*args, **kwargs),
|
|
283
|
+
"model_request_stream_sync",
|
|
284
|
+
input_data,
|
|
285
|
+
metadata,
|
|
286
|
+
)
|
|
287
|
+
|
|
288
|
+
return wrapper
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def wrap_model_request(original_func: Any) -> Any:
    """Wrap an awaitable ``model_request``-style callable so each call is traced.

    Args:
        original_func: Coroutine function that is awaited with the caller's arguments.

    Returns:
        A coroutine function that runs ``original_func`` inside a ``model_request``
        LLM span and logs the serialized response and response metrics.
    """

    async def traced(*args, **kwargs):
        request_input, request_metadata = _build_direct_model_input_and_metadata(args, kwargs)

        with start_span(
            name="model_request",
            type=SpanTypeAttribute.LLM,
            input=request_input,
            metadata=_try_dict(request_metadata),
        ) as span:
            t0 = time.time()
            response = await original_func(*args, **kwargs)
            t1 = time.time()

            serialized = _serialize_model_response(response)
            span.log(output=serialized, metrics=_extract_response_metrics(response, t0, t1))
            return response

    return traced
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def wrap_model_request_sync(original_func: Any) -> Any:
    """Wrap a blocking ``model_request_sync``-style callable so each call is traced.

    Args:
        original_func: Callable invoked with the caller's arguments.

    Returns:
        A function that runs ``original_func`` inside a ``model_request_sync`` LLM
        span and logs the serialized response and response metrics.
    """

    def traced(*args, **kwargs):
        request_input, request_metadata = _build_direct_model_input_and_metadata(args, kwargs)

        with start_span(
            name="model_request_sync",
            type=SpanTypeAttribute.LLM,
            input=request_input,
            metadata=_try_dict(request_metadata),
        ) as span:
            t0 = time.time()
            response = original_func(*args, **kwargs)
            t1 = time.time()

            serialized = _serialize_model_response(response)
            span.log(output=serialized, metrics=_extract_response_metrics(response, t0, t1))
            return response

    return traced
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
def wrap_model_request_stream(original_func: Any) -> Any:
    """Wrap a ``model_request_stream``-style callable so its stream is traced.

    Args:
        original_func: Callable returning an async context manager for the stream.

    Returns:
        A function returning a ``_DirectStreamWrapper`` that opens a
        ``model_request_stream`` span when entered.
    """

    def traced(*args, **kwargs):
        request_input, request_metadata = _build_direct_model_input_and_metadata(args, kwargs)
        inner = original_func(*args, **kwargs)
        return _DirectStreamWrapper(inner, "model_request_stream", request_input, request_metadata)

    return traced
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
def wrap_model_request_stream_sync(original_func: Any) -> Any:
    """Wrap a ``model_request_stream_sync``-style callable so its stream is traced.

    Args:
        original_func: Callable returning a (sync) context manager for the stream.

    Returns:
        A function returning a ``_DirectStreamWrapperSync`` that opens a
        ``model_request_stream_sync`` span when entered.
    """

    def traced(*args, **kwargs):
        request_input, request_metadata = _build_direct_model_input_and_metadata(args, kwargs)
        inner = original_func(*args, **kwargs)
        return _DirectStreamWrapperSync(inner, "model_request_stream_sync", request_input, request_metadata)

    return traced
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
def wrap_model_classes():
    """Wrap Model classes to capture internal model requests made by agents.

    Walks every (transitive) subclass of ``pydantic_ai.models.Model`` currently
    defined, depth-first with parents before their children. Each class without
    abstract methods is patched. Per-class failures are logged at debug level; any
    other failure (including a missing ``pydantic_ai``) is logged as a warning.
    """
    try:
        from pydantic_ai.models import Model

        # Explicit stack for a pre-order traversal. Children are pushed reversed so
        # they are visited in __subclasses__() order, as a recursive walk would.
        pending = list(reversed(Model.__subclasses__()))
        while pending:
            cls = pending.pop()
            if not getattr(cls, "__abstractmethods__", None):
                try:
                    _wrap_concrete_model_class(cls)
                except Exception as e:
                    logger.debug(f"Could not wrap {cls.__name__}: {e}")
            pending.extend(reversed(cls.__subclasses__()))

    except Exception as e:
        logger.warning(f"Failed to wrap Model classes: {e}")
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
def _build_model_class_input_and_metadata(instance: Any, args: Any, kwargs: Any):
    """Collect span input and metadata for a ``Model.request``/``request_stream`` call.

    ``messages`` and ``model_settings`` are taken from the first and second
    positional arguments when present, otherwise from the keyword arguments.

    Returns:
        Tuple of (model_name, display_name, input_data, metadata). ``display_name``
        falls back to ``str(instance)`` when no model name is available.
    """
    model_name, provider = _extract_model_info_from_model_instance(instance)
    display_name = model_name or str(instance)

    def positional_or_keyword(index: int, key: str) -> Any:
        return args[index] if len(args) > index else kwargs.get(key)

    messages = positional_or_keyword(0, "messages")
    model_settings = positional_or_keyword(1, "model_settings")

    input_data = {"messages": _serialize_messages(messages)}
    if model_settings is not None:
        input_data["model_settings"] = _try_dict(model_settings)

    # model_settings is deliberately passed as None here; the settings are recorded
    # in input_data instead.
    metadata = _build_model_metadata(model_name, provider, model_settings=None)

    return model_name, display_name, input_data, metadata
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
def _wrap_concrete_model_class(model_class: Any):
    """Wrap a concrete model class to trace its request methods.

    Patches ``request`` (one ``chat <model>`` LLM span per awaited call) and
    ``request_stream`` (span handled by ``_DirectStreamWrapper``). A class already
    marked as patched is skipped.

    Raises:
        Whatever ``wrap_function_wrapper`` raises if the class lacks a method to patch.
    """
    if _is_patched(model_class):
        return

    async def model_request_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
        _, display_name, input_data, metadata = _build_model_class_input_and_metadata(instance, args, kwargs)

        with start_span(
            name=f"chat {display_name}",
            type=SpanTypeAttribute.LLM,
            input=input_data,
            metadata=_try_dict(metadata),
        ) as span:
            start_time = time.time()
            result = await wrapped(*args, **kwargs)
            end_time = time.time()

            output = _serialize_model_response(result)
            metrics = _extract_response_metrics(result, start_time, end_time)

            span.log(output=output, metrics=metrics)
            return result

    def model_request_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
        _, display_name, input_data, metadata = _build_model_class_input_and_metadata(instance, args, kwargs)

        return _DirectStreamWrapper(
            wrapped(*args, **kwargs),
            f"chat {display_name}",
            input_data,
            metadata,
        )

    wrap_function_wrapper(model_class, "request", model_request_wrapper)
    # ``request`` is now traced. Mark the class before patching ``request_stream`` so
    # that, if that step fails, a later call does not wrap ``request`` a second time
    # (nested wrappers would emit duplicate spans).
    model_class._braintrust_patched = True
    wrap_function_wrapper(model_class, "request_stream", model_request_stream_wrapper)
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
class _AgentStreamWrapper(AbstractAsyncContextManager):
    """Wrapper for agent.run_stream() that adds tracing while passing through the stream result.

    Entering opens the span (which stays current until exit) and then enters the
    underlying stream, returning its result object. Exiting logs the stream output
    and usage metrics (best effort) and always closes the span, recording any
    exception that ended the block.
    """

    def __init__(self, stream_cm: Any, span_name: str, input_data: Any, metadata: Any):
        self.stream_cm = stream_cm
        self.span_name = span_name
        self.input_data = input_data
        self.metadata = metadata
        self.span_cm = None
        self.start_time = None
        self.stream_result = None

    async def __aenter__(self):
        # Use the span's context manager so the span stays current while streaming.
        # start_time is not passed here; it is reported via metrics in __aexit__.
        self.span_cm = start_span(
            name=self.span_name,
            type=SpanTypeAttribute.LLM,
            input=self.input_data if self.input_data else None,
            metadata=_try_dict(self.metadata),
        )
        self.span_cm.__enter__()

        # Capture start time right before entering the stream (API call initiation)
        self.start_time = time.time()
        try:
            self.stream_result = await self.stream_cm.__aenter__()
        except BaseException:
            # __aexit__ is not invoked when __aenter__ raises, so close the span here
            # (recording the error) instead of leaving it open and current.
            self.span_cm.__exit__(*sys.exc_info())
            self.span_cm = None
            raise
        return self.stream_result  # Return actual stream result object

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        try:
            await self.stream_cm.__aexit__(exc_type, exc_val, exc_tb)
        finally:
            if self.span_cm is not None:
                try:
                    if self.start_time is not None and self.stream_result is not None:
                        end_time = time.time()
                        output = _serialize_stream_output(self.stream_result)
                        metrics = _extract_stream_usage_metrics(self.stream_result, self.start_time, end_time, None)
                        self.span_cm.log(output=output, metrics=metrics)
                except Exception as e:
                    # Logging is best effort and must not prevent the span from closing.
                    logger.debug(f"Failed to extract stream output/metrics: {e}")
                finally:
                    # Always clean up span context, propagating any error to the span.
                    self.span_cm.__exit__(exc_type, exc_val, exc_tb)

        return False
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
class _DirectStreamWrapper(AbstractAsyncContextManager):
    """Wrapper for model_request_stream() that adds tracing while passing through the stream.

    Entering opens the span (which stays current until exit) and then enters the
    underlying stream, returning it. Exiting logs the final response obtained via
    ``stream.get()`` (best effort) and always closes the span, recording any
    exception that ended the block.
    """

    def __init__(self, stream_cm: Any, span_name: str, input_data: Any, metadata: Any):
        self.stream_cm = stream_cm
        self.span_name = span_name
        self.input_data = input_data
        self.metadata = metadata
        self.span_cm = None
        self.start_time = None
        self.stream = None

    async def __aenter__(self):
        # Use the span's context manager so the span stays current while streaming.
        # start_time is not passed here; it is reported via metrics in __aexit__.
        self.span_cm = start_span(
            name=self.span_name,
            type=SpanTypeAttribute.LLM,
            input=self.input_data if self.input_data else None,
            metadata=_try_dict(self.metadata),
        )
        self.span_cm.__enter__()

        # Capture start time right before entering the stream (API call initiation)
        self.start_time = time.time()
        try:
            self.stream = await self.stream_cm.__aenter__()
        except BaseException:
            # __aexit__ is not invoked when __aenter__ raises, so close the span here
            # (recording the error) instead of leaving it open and current.
            self.span_cm.__exit__(*sys.exc_info())
            self.span_cm = None
            raise
        return self.stream  # Return actual stream object

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        try:
            await self.stream_cm.__aexit__(exc_type, exc_val, exc_tb)
        finally:
            if self.span_cm is not None:
                try:
                    if self.start_time is not None and self.stream is not None:
                        end_time = time.time()
                        final_response = self.stream.get()
                        output = _serialize_model_response(final_response)
                        metrics = _extract_response_metrics(final_response, self.start_time, end_time, None)
                        self.span_cm.log(output=output, metrics=metrics)
                except Exception as e:
                    logger.debug(f"Failed to extract stream output/metrics: {e}")
                finally:
                    # Always clean up span context, propagating any error to the span.
                    self.span_cm.__exit__(exc_type, exc_val, exc_tb)

        return False
|
|
543
|
+
|
|
544
|
+
|
|
545
|
+
class _AgentStreamResultSyncProxy:
|
|
546
|
+
"""Proxy for agent.run_stream_sync() result that adds tracing while delegating to actual stream result."""
|
|
547
|
+
|
|
548
|
+
def __init__(self, stream_result: Any, span: Any, span_cm: Any, start_time: float):
|
|
549
|
+
self._stream_result = stream_result
|
|
550
|
+
self._span = span
|
|
551
|
+
self._span_cm = span_cm
|
|
552
|
+
self._start_time = start_time
|
|
553
|
+
self._logged = False
|
|
554
|
+
self._finalize_on_del = True
|
|
555
|
+
|
|
556
|
+
def __getattr__(self, name: str):
|
|
557
|
+
"""Delegate all attribute access to the wrapped stream result."""
|
|
558
|
+
attr = getattr(self._stream_result, name)
|
|
559
|
+
|
|
560
|
+
# Wrap any method that returns an iterator to auto-finalize when exhausted
|
|
561
|
+
if callable(attr) and name in ('stream_text', 'stream_output', '__iter__'):
|
|
562
|
+
def wrapped_method(*args, **kwargs):
|
|
563
|
+
try:
|
|
564
|
+
iterator = attr(*args, **kwargs)
|
|
565
|
+
# If it's an iterator, wrap it
|
|
566
|
+
if hasattr(iterator, '__iter__') or hasattr(iterator, '__next__'):
|
|
567
|
+
try:
|
|
568
|
+
yield from iterator
|
|
569
|
+
finally:
|
|
570
|
+
self._finalize()
|
|
571
|
+
self._finalize_on_del = False # Don't finalize again in __del__
|
|
572
|
+
else:
|
|
573
|
+
return iterator
|
|
574
|
+
except Exception:
|
|
575
|
+
self._finalize()
|
|
576
|
+
self._finalize_on_del = False
|
|
577
|
+
raise
|
|
578
|
+
return wrapped_method
|
|
579
|
+
|
|
580
|
+
return attr
|
|
581
|
+
|
|
582
|
+
def _finalize(self):
|
|
583
|
+
"""Log metrics and close span."""
|
|
584
|
+
if self._span and not self._logged and self._stream_result:
|
|
585
|
+
try:
|
|
586
|
+
end_time = time.time()
|
|
587
|
+
output = _serialize_stream_output(self._stream_result)
|
|
588
|
+
metrics = _extract_stream_usage_metrics(self._stream_result, self._start_time, end_time, None)
|
|
589
|
+
self._span.log(output=output, metrics=metrics)
|
|
590
|
+
self._logged = True
|
|
591
|
+
finally:
|
|
592
|
+
try:
|
|
593
|
+
self._span_cm.__exit__(None, None, None)
|
|
594
|
+
except Exception:
|
|
595
|
+
pass
|
|
596
|
+
|
|
597
|
+
def __del__(self):
|
|
598
|
+
"""Ensure span is closed when proxy is destroyed."""
|
|
599
|
+
if self._finalize_on_del:
|
|
600
|
+
self._finalize()
|
|
601
|
+
|
|
602
|
+
|
|
603
|
+
class _DirectStreamWrapperSync:
    """Wrapper for model_request_stream_sync() that adds tracing while passing through the stream.

    Entering opens the span (which stays current until exit) and then enters the
    underlying stream, returning it. Exiting logs the final response obtained via
    ``stream.get()`` (best effort) and always closes the span, recording any
    exception that ended the block.
    """

    def __init__(self, stream_cm: Any, span_name: str, input_data: Any, metadata: Any):
        self.stream_cm = stream_cm
        self.span_name = span_name
        self.input_data = input_data
        self.metadata = metadata
        self.span_cm = None
        self.start_time = None
        self.stream = None

    def __enter__(self):
        # Use the span's context manager so the span stays current while streaming.
        # start_time is not passed here; it is reported via metrics in __exit__.
        self.span_cm = start_span(
            name=self.span_name,
            type=SpanTypeAttribute.LLM,
            input=self.input_data if self.input_data else None,
            metadata=_try_dict(self.metadata),
        )
        self.span_cm.__enter__()

        # Capture start time right before entering the stream (API call initiation)
        self.start_time = time.time()
        try:
            self.stream = self.stream_cm.__enter__()
        except BaseException:
            # __exit__ is not invoked when __enter__ raises, so close the span here
            # (recording the error) instead of leaving it open and current.
            self.span_cm.__exit__(*sys.exc_info())
            self.span_cm = None
            raise
        return self.stream  # Return actual stream object

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            self.stream_cm.__exit__(exc_type, exc_val, exc_tb)
        finally:
            if self.span_cm is not None:
                try:
                    if self.start_time is not None and self.stream is not None:
                        end_time = time.time()
                        final_response = self.stream.get()
                        output = _serialize_model_response(final_response)
                        metrics = _extract_response_metrics(final_response, self.start_time, end_time, None)
                        self.span_cm.log(output=output, metrics=metrics)
                except Exception as e:
                    logger.debug(f"Failed to extract stream output/metrics: {e}")
                finally:
                    # Always clean up span context, propagating any error to the span.
                    self.span_cm.__exit__(exc_type, exc_val, exc_tb)

        return False
|
|
651
|
+
|
|
652
|
+
|
|
653
|
+
def _serialize_user_prompt(user_prompt: Any) -> Any:
|
|
654
|
+
"""Serialize user prompt, handling BinaryContent and other types."""
|
|
655
|
+
if user_prompt is None:
|
|
656
|
+
return None
|
|
657
|
+
|
|
658
|
+
if isinstance(user_prompt, str):
|
|
659
|
+
return user_prompt
|
|
660
|
+
|
|
661
|
+
if isinstance(user_prompt, list):
|
|
662
|
+
return [_serialize_content_part(part) for part in user_prompt]
|
|
663
|
+
|
|
664
|
+
return _serialize_content_part(user_prompt)
|
|
665
|
+
|
|
666
|
+
|
|
667
|
+
def _serialize_content_part(part: Any) -> Any:
|
|
668
|
+
"""Serialize a content part, handling BinaryContent specially."""
|
|
669
|
+
if part is None:
|
|
670
|
+
return None
|
|
671
|
+
|
|
672
|
+
if hasattr(part, "data") and hasattr(part, "media_type") and hasattr(part, "kind"):
|
|
673
|
+
if part.kind == "binary":
|
|
674
|
+
data = part.data
|
|
675
|
+
media_type = part.media_type
|
|
676
|
+
|
|
677
|
+
extension = media_type.split("/")[1] if "/" in media_type else "bin"
|
|
678
|
+
filename = f"file.{extension}"
|
|
679
|
+
|
|
680
|
+
attachment = Attachment(data=data, filename=filename, content_type=media_type)
|
|
681
|
+
return {"type": "binary", "attachment": attachment, "media_type": media_type}
|
|
682
|
+
|
|
683
|
+
if isinstance(part, str):
|
|
684
|
+
return part
|
|
685
|
+
|
|
686
|
+
return _try_dict(part)
|
|
687
|
+
|
|
688
|
+
|
|
689
|
+
def _serialize_messages(messages: Any) -> Any:
|
|
690
|
+
"""Serialize messages list."""
|
|
691
|
+
if not messages:
|
|
692
|
+
return []
|
|
693
|
+
|
|
694
|
+
result = []
|
|
695
|
+
for msg in messages:
|
|
696
|
+
serialized_msg = _try_dict(msg)
|
|
697
|
+
|
|
698
|
+
if isinstance(serialized_msg, dict) and "parts" in serialized_msg:
|
|
699
|
+
serialized_msg["parts"] = [_serialize_content_part(p) for p in msg.parts]
|
|
700
|
+
|
|
701
|
+
result.append(serialized_msg)
|
|
702
|
+
|
|
703
|
+
return result
|
|
704
|
+
|
|
705
|
+
|
|
706
|
+
def _serialize_result_output(result: Any) -> Any:
|
|
707
|
+
"""Serialize agent run result output."""
|
|
708
|
+
if not result:
|
|
709
|
+
return None
|
|
710
|
+
|
|
711
|
+
output_dict = {}
|
|
712
|
+
|
|
713
|
+
if hasattr(result, "output"):
|
|
714
|
+
output_dict["output"] = _try_dict(result.output)
|
|
715
|
+
|
|
716
|
+
if hasattr(result, "response"):
|
|
717
|
+
output_dict["response"] = _serialize_model_response(result.response)
|
|
718
|
+
|
|
719
|
+
return output_dict if output_dict else _try_dict(result)
|
|
720
|
+
|
|
721
|
+
|
|
722
|
+
def _serialize_stream_output(stream_result: Any) -> Any:
|
|
723
|
+
"""Serialize stream result output."""
|
|
724
|
+
if not stream_result:
|
|
725
|
+
return None
|
|
726
|
+
|
|
727
|
+
output_dict = {}
|
|
728
|
+
|
|
729
|
+
if hasattr(stream_result, "response"):
|
|
730
|
+
output_dict["response"] = _serialize_model_response(stream_result.response)
|
|
731
|
+
|
|
732
|
+
return output_dict if output_dict else None
|
|
733
|
+
|
|
734
|
+
|
|
735
|
+
def _serialize_model_response(response: Any) -> Any:
|
|
736
|
+
"""Serialize a model response."""
|
|
737
|
+
if not response:
|
|
738
|
+
return None
|
|
739
|
+
|
|
740
|
+
response_dict = _try_dict(response)
|
|
741
|
+
|
|
742
|
+
if isinstance(response_dict, dict) and "parts" in response_dict:
|
|
743
|
+
if hasattr(response, "parts"):
|
|
744
|
+
response_dict["parts"] = [_serialize_content_part(p) for p in response.parts]
|
|
745
|
+
|
|
746
|
+
return response_dict
|
|
747
|
+
|
|
748
|
+
|
|
749
|
+
def _extract_model_info_from_model_instance(model: Any) -> tuple[str | None, str | None]:
|
|
750
|
+
"""Extract model name and provider from a model instance.
|
|
751
|
+
|
|
752
|
+
Args:
|
|
753
|
+
model: A Pydantic AI model instance (OpenAIChatModel, AnthropicModel, etc.)
|
|
754
|
+
|
|
755
|
+
Returns:
|
|
756
|
+
Tuple of (model_name, provider)
|
|
757
|
+
"""
|
|
758
|
+
if not model:
|
|
759
|
+
return None, None
|
|
760
|
+
|
|
761
|
+
if isinstance(model, str):
|
|
762
|
+
return _parse_model_string(model)
|
|
763
|
+
|
|
764
|
+
if hasattr(model, "model_name"):
|
|
765
|
+
model_name = model.model_name
|
|
766
|
+
class_name = type(model).__name__
|
|
767
|
+
provider = None
|
|
768
|
+
if "OpenAI" in class_name:
|
|
769
|
+
provider = "openai"
|
|
770
|
+
elif "Anthropic" in class_name:
|
|
771
|
+
provider = "anthropic"
|
|
772
|
+
elif "Gemini" in class_name:
|
|
773
|
+
provider = "gemini"
|
|
774
|
+
elif "Groq" in class_name:
|
|
775
|
+
provider = "groq"
|
|
776
|
+
elif "Mistral" in class_name:
|
|
777
|
+
provider = "mistral"
|
|
778
|
+
elif "VertexAI" in class_name:
|
|
779
|
+
provider = "vertexai"
|
|
780
|
+
|
|
781
|
+
return model_name, provider
|
|
782
|
+
|
|
783
|
+
if hasattr(model, "name"):
|
|
784
|
+
return _parse_model_string(model.name)
|
|
785
|
+
|
|
786
|
+
return None, None
|
|
787
|
+
|
|
788
|
+
|
|
789
|
+
def _extract_model_info(agent: Any) -> tuple[str | None, str | None]:
|
|
790
|
+
"""Extract model name and provider from agent.
|
|
791
|
+
|
|
792
|
+
Args:
|
|
793
|
+
agent: A Pydantic AI Agent instance
|
|
794
|
+
|
|
795
|
+
Returns:
|
|
796
|
+
Tuple of (model_name, provider)
|
|
797
|
+
"""
|
|
798
|
+
if not hasattr(agent, "model"):
|
|
799
|
+
return None, None
|
|
800
|
+
|
|
801
|
+
return _extract_model_info_from_model_instance(agent.model)
|
|
802
|
+
|
|
803
|
+
|
|
804
|
+
def _build_model_metadata(
|
|
805
|
+
model_name: str | None, provider: str | None, model_settings: Any = None
|
|
806
|
+
) -> dict[str, Any]:
|
|
807
|
+
"""Build metadata dictionary with model info.
|
|
808
|
+
|
|
809
|
+
Args:
|
|
810
|
+
model_name: The model name (e.g., "gpt-4o")
|
|
811
|
+
provider: The provider (e.g., "openai")
|
|
812
|
+
model_settings: Optional model settings to include
|
|
813
|
+
|
|
814
|
+
Returns:
|
|
815
|
+
Dictionary of metadata
|
|
816
|
+
"""
|
|
817
|
+
metadata = {}
|
|
818
|
+
if model_name:
|
|
819
|
+
metadata["model"] = model_name
|
|
820
|
+
if provider:
|
|
821
|
+
metadata["provider"] = provider
|
|
822
|
+
if model_settings:
|
|
823
|
+
metadata["model_settings"] = _try_dict(model_settings)
|
|
824
|
+
return metadata
|
|
825
|
+
|
|
826
|
+
|
|
827
|
+
def _parse_model_string(model: Any) -> tuple[str | None, str | None]:
|
|
828
|
+
"""Parse model string to extract provider and model name.
|
|
829
|
+
|
|
830
|
+
Pydantic AI uses format: "provider:model-name" (e.g., "openai:gpt-4o")
|
|
831
|
+
"""
|
|
832
|
+
if not model:
|
|
833
|
+
return None, None
|
|
834
|
+
|
|
835
|
+
model_str = str(model)
|
|
836
|
+
|
|
837
|
+
if ":" in model_str:
|
|
838
|
+
parts = model_str.split(":", 1)
|
|
839
|
+
return parts[1], parts[0] # (model_name, provider)
|
|
840
|
+
|
|
841
|
+
return model_str, None
|
|
842
|
+
|
|
843
|
+
|
|
844
|
+
def _extract_usage_metrics(result: Any, start_time: float, end_time: float) -> dict[str, float] | None:
|
|
845
|
+
"""Extract usage metrics from agent run result."""
|
|
846
|
+
metrics: dict[str, float] = {}
|
|
847
|
+
|
|
848
|
+
metrics["start"] = start_time
|
|
849
|
+
metrics["end"] = end_time
|
|
850
|
+
metrics["duration"] = end_time - start_time
|
|
851
|
+
|
|
852
|
+
usage = None
|
|
853
|
+
if hasattr(result, "response"):
|
|
854
|
+
try:
|
|
855
|
+
response = result.response
|
|
856
|
+
if hasattr(response, "usage"):
|
|
857
|
+
usage = response.usage
|
|
858
|
+
except (AttributeError, ValueError):
|
|
859
|
+
pass
|
|
860
|
+
|
|
861
|
+
if usage is None and hasattr(result, "usage"):
|
|
862
|
+
usage = result.usage
|
|
863
|
+
|
|
864
|
+
if usage is None:
|
|
865
|
+
return metrics
|
|
866
|
+
|
|
867
|
+
if hasattr(usage, "input_tokens"):
|
|
868
|
+
input_tokens = usage.input_tokens
|
|
869
|
+
if input_tokens is not None:
|
|
870
|
+
metrics["prompt_tokens"] = float(input_tokens)
|
|
871
|
+
|
|
872
|
+
if hasattr(usage, "output_tokens"):
|
|
873
|
+
output_tokens = usage.output_tokens
|
|
874
|
+
if output_tokens is not None:
|
|
875
|
+
metrics["completion_tokens"] = float(output_tokens)
|
|
876
|
+
|
|
877
|
+
if hasattr(usage, "total_tokens"):
|
|
878
|
+
total_tokens = usage.total_tokens
|
|
879
|
+
if total_tokens is not None:
|
|
880
|
+
metrics["tokens"] = float(total_tokens)
|
|
881
|
+
|
|
882
|
+
if hasattr(usage, "cache_read_tokens") and usage.cache_read_tokens is not None:
|
|
883
|
+
metrics["prompt_cached_tokens"] = float(usage.cache_read_tokens)
|
|
884
|
+
|
|
885
|
+
if hasattr(usage, "cache_write_tokens") and usage.cache_write_tokens is not None:
|
|
886
|
+
metrics["prompt_cache_creation_tokens"] = float(usage.cache_write_tokens)
|
|
887
|
+
|
|
888
|
+
if hasattr(usage, "input_audio_tokens") and usage.input_audio_tokens is not None:
|
|
889
|
+
metrics["prompt_audio_tokens"] = float(usage.input_audio_tokens)
|
|
890
|
+
|
|
891
|
+
if hasattr(usage, "output_audio_tokens") and usage.output_audio_tokens is not None:
|
|
892
|
+
metrics["completion_audio_tokens"] = float(usage.output_audio_tokens)
|
|
893
|
+
|
|
894
|
+
if hasattr(usage, "details") and isinstance(usage.details, dict):
|
|
895
|
+
details = usage.details
|
|
896
|
+
|
|
897
|
+
if "reasoning_tokens" in details:
|
|
898
|
+
metrics["completion_reasoning_tokens"] = float(details["reasoning_tokens"])
|
|
899
|
+
|
|
900
|
+
if "cached_tokens" in details:
|
|
901
|
+
metrics["prompt_cached_tokens"] = float(details["cached_tokens"])
|
|
902
|
+
|
|
903
|
+
return metrics if metrics else None
|
|
904
|
+
|
|
905
|
+
|
|
906
|
+
def _extract_stream_usage_metrics(
|
|
907
|
+
stream_result: Any, start_time: float, end_time: float, first_token_time: float | None
|
|
908
|
+
) -> dict[str, float] | None:
|
|
909
|
+
"""Extract usage metrics from stream result."""
|
|
910
|
+
metrics: dict[str, float] = {}
|
|
911
|
+
|
|
912
|
+
metrics["start"] = start_time
|
|
913
|
+
metrics["end"] = end_time
|
|
914
|
+
metrics["duration"] = end_time - start_time
|
|
915
|
+
|
|
916
|
+
if first_token_time:
|
|
917
|
+
metrics["time_to_first_token"] = first_token_time - start_time
|
|
918
|
+
|
|
919
|
+
if hasattr(stream_result, "usage"):
|
|
920
|
+
usage_func = stream_result.usage
|
|
921
|
+
if callable(usage_func):
|
|
922
|
+
usage = usage_func()
|
|
923
|
+
else:
|
|
924
|
+
usage = usage_func
|
|
925
|
+
|
|
926
|
+
if usage:
|
|
927
|
+
if hasattr(usage, "input_tokens") and usage.input_tokens is not None:
|
|
928
|
+
metrics["prompt_tokens"] = float(usage.input_tokens)
|
|
929
|
+
|
|
930
|
+
if hasattr(usage, "output_tokens") and usage.output_tokens is not None:
|
|
931
|
+
metrics["completion_tokens"] = float(usage.output_tokens)
|
|
932
|
+
|
|
933
|
+
if hasattr(usage, "total_tokens") and usage.total_tokens is not None:
|
|
934
|
+
metrics["tokens"] = float(usage.total_tokens)
|
|
935
|
+
|
|
936
|
+
if hasattr(usage, "cache_read_tokens") and usage.cache_read_tokens is not None:
|
|
937
|
+
metrics["prompt_cached_tokens"] = float(usage.cache_read_tokens)
|
|
938
|
+
|
|
939
|
+
if hasattr(usage, "cache_write_tokens") and usage.cache_write_tokens is not None:
|
|
940
|
+
metrics["prompt_cache_creation_tokens"] = float(usage.cache_write_tokens)
|
|
941
|
+
|
|
942
|
+
return metrics if metrics else None
|
|
943
|
+
|
|
944
|
+
|
|
945
|
+
def _extract_response_metrics(
|
|
946
|
+
response: Any, start_time: float, end_time: float, first_token_time: float | None = None
|
|
947
|
+
) -> dict[str, float] | None:
|
|
948
|
+
"""Extract metrics from model response."""
|
|
949
|
+
metrics: dict[str, float] = {}
|
|
950
|
+
|
|
951
|
+
metrics["start"] = start_time
|
|
952
|
+
metrics["end"] = end_time
|
|
953
|
+
metrics["duration"] = end_time - start_time
|
|
954
|
+
|
|
955
|
+
if first_token_time:
|
|
956
|
+
metrics["time_to_first_token"] = first_token_time - start_time
|
|
957
|
+
|
|
958
|
+
if hasattr(response, "usage") and response.usage:
|
|
959
|
+
usage = response.usage
|
|
960
|
+
|
|
961
|
+
if hasattr(usage, "input_tokens") and usage.input_tokens is not None:
|
|
962
|
+
metrics["prompt_tokens"] = float(usage.input_tokens)
|
|
963
|
+
|
|
964
|
+
if hasattr(usage, "output_tokens") and usage.output_tokens is not None:
|
|
965
|
+
metrics["completion_tokens"] = float(usage.output_tokens)
|
|
966
|
+
|
|
967
|
+
if hasattr(usage, "total_tokens") and usage.total_tokens is not None:
|
|
968
|
+
metrics["tokens"] = float(usage.total_tokens)
|
|
969
|
+
|
|
970
|
+
if hasattr(usage, "cache_read_tokens") and usage.cache_read_tokens is not None:
|
|
971
|
+
metrics["prompt_cached_tokens"] = float(usage.cache_read_tokens)
|
|
972
|
+
|
|
973
|
+
if hasattr(usage, "cache_write_tokens") and usage.cache_write_tokens is not None:
|
|
974
|
+
metrics["prompt_cache_creation_tokens"] = float(usage.cache_write_tokens)
|
|
975
|
+
|
|
976
|
+
# Extract reasoning tokens for reasoning models (o1/o3)
|
|
977
|
+
if hasattr(usage, "details") and usage.details is not None:
|
|
978
|
+
if hasattr(usage.details, "reasoning_tokens") and usage.details.reasoning_tokens is not None:
|
|
979
|
+
metrics["completion_reasoning_tokens"] = float(usage.details.reasoning_tokens)
|
|
980
|
+
|
|
981
|
+
return metrics if metrics else None
|
|
982
|
+
|
|
983
|
+
|
|
984
|
+
def _is_patched(obj: Any) -> bool:
|
|
985
|
+
"""Check if object is already patched."""
|
|
986
|
+
return getattr(obj, "_braintrust_patched", False)
|
|
987
|
+
|
|
988
|
+
|
|
989
|
+
def _try_dict(obj: Any) -> Iterable[Any] | dict[str, Any]:
|
|
990
|
+
"""Try to convert object to dict, handling Pydantic models and circular references."""
|
|
991
|
+
if hasattr(obj, "model_dump"):
|
|
992
|
+
try:
|
|
993
|
+
obj = obj.model_dump(exclude_none=True)
|
|
994
|
+
except ValueError as e:
|
|
995
|
+
if "Circular reference" in str(e):
|
|
996
|
+
return {}
|
|
997
|
+
raise
|
|
998
|
+
|
|
999
|
+
if isinstance(obj, dict):
|
|
1000
|
+
return {k: _try_dict(v) for k, v in obj.items()}
|
|
1001
|
+
elif isinstance(obj, (list, tuple)):
|
|
1002
|
+
return [_try_dict(item) for item in obj]
|
|
1003
|
+
|
|
1004
|
+
return obj
|
|
1005
|
+
|
|
1006
|
+
|
|
1007
|
+
def _serialize_type(obj: Any) -> Any:
|
|
1008
|
+
"""Serialize a type/class for logging, handling Pydantic models and other types.
|
|
1009
|
+
|
|
1010
|
+
This is useful for output_type, toolsets, and similar type parameters.
|
|
1011
|
+
Returns full JSON schema for Pydantic models so engineers can see exactly
|
|
1012
|
+
what structured output schema was used.
|
|
1013
|
+
"""
|
|
1014
|
+
import inspect
|
|
1015
|
+
|
|
1016
|
+
# For sequences of types (like Union types or list of models)
|
|
1017
|
+
if isinstance(obj, (list, tuple)):
|
|
1018
|
+
return [_serialize_type(item) for item in obj]
|
|
1019
|
+
|
|
1020
|
+
# Handle Pydantic AI's output wrappers (ToolOutput, NativeOutput, PromptedOutput, TextOutput)
|
|
1021
|
+
if hasattr(obj, "output"):
|
|
1022
|
+
# These are wrapper classes with an 'output' field containing the actual type
|
|
1023
|
+
wrapper_info = {"wrapper": type(obj).__name__}
|
|
1024
|
+
if hasattr(obj, "name") and obj.name:
|
|
1025
|
+
wrapper_info["name"] = obj.name
|
|
1026
|
+
if hasattr(obj, "description") and obj.description:
|
|
1027
|
+
wrapper_info["description"] = obj.description
|
|
1028
|
+
wrapper_info["output"] = _serialize_type(obj.output)
|
|
1029
|
+
return wrapper_info
|
|
1030
|
+
|
|
1031
|
+
# If it's a Pydantic model class, return its full JSON schema
|
|
1032
|
+
if inspect.isclass(obj):
|
|
1033
|
+
try:
|
|
1034
|
+
from pydantic import BaseModel
|
|
1035
|
+
|
|
1036
|
+
if issubclass(obj, BaseModel):
|
|
1037
|
+
# Return the full JSON schema - includes all field info, descriptions, constraints, etc.
|
|
1038
|
+
return obj.model_json_schema()
|
|
1039
|
+
except (ImportError, AttributeError, TypeError):
|
|
1040
|
+
pass
|
|
1041
|
+
|
|
1042
|
+
# Not a Pydantic model, return class name
|
|
1043
|
+
return obj.__name__
|
|
1044
|
+
|
|
1045
|
+
# If it has a __name__ attribute (like functions), use that
|
|
1046
|
+
if hasattr(obj, "__name__"):
|
|
1047
|
+
return obj.__name__
|
|
1048
|
+
|
|
1049
|
+
# Try standard serialization
|
|
1050
|
+
return _try_dict(obj)
|
|
1051
|
+
|
|
1052
|
+
|
|
1053
|
+
G = TypeVar("G", bound=AsyncGenerator[Any, None])


class aclosing(AbstractAsyncContextManager[G]):
    """Async context manager that closes an async generator on exit.

    ``__aenter__`` hands back the wrapped generator unchanged; ``__aexit__``
    awaits its ``aclose()``. A ``ValueError`` whose message contains
    "was created in a different Context" (a ContextVar token error) is logged
    at debug level and suppressed; any other ``ValueError`` is re-raised.
    Exceptions from the ``async with`` body are never suppressed.
    """

    def __init__(self, async_generator: G):
        self.async_generator = async_generator

    async def __aenter__(self):
        return self.async_generator

    async def __aexit__(self, *exc_info: Any):
        generator = self.async_generator
        try:
            await generator.aclose()
        except ValueError as e:
            if "was created in a different Context" in str(e):
                logger.debug(
                    f"Suppressed ContextVar error during async cleanup: {e}. "
                    "This is expected when async generators yield across context boundaries."
                )
            else:
                raise
|
|
1076
|
+
|
|
1077
|
+
|
|
1078
|
+
def _serialize_agent_toolsets(toolsets: Any) -> list[dict[str, Any]]:
    """Serialize agent toolsets into dicts with ``id``, ``label`` and full tool schemas."""
    serialized_toolsets: list[dict[str, Any]] = []
    for ts in toolsets:
        ts_info: dict[str, Any] = {
            "id": getattr(ts, "id", type(ts).__name__),
            "label": getattr(ts, "label", None),
        }
        # Add full tool schemas (not just names) since toolsets can be passed to agent.run()
        if hasattr(ts, "tools") and ts.tools:
            tools_list = []
            # tools is expected to map tool name -> Tool object
            for tool_name, tool_obj in ts.tools.items():
                tool_dict: dict[str, Any] = {"name": tool_name}
                if hasattr(tool_obj, "description") and tool_obj.description:
                    tool_dict["description"] = tool_obj.description
                if hasattr(tool_obj, "function_schema") and hasattr(tool_obj.function_schema, "json_schema"):
                    tool_dict["parameters"] = tool_obj.function_schema.json_schema
                tools_list.append(tool_dict)
            ts_info["tools"] = tools_list
        serialized_toolsets.append(ts_info)
    return serialized_toolsets


def _build_agent_input_and_metadata(args: Any, kwargs: Any, instance: Any) -> tuple[dict[str, Any], dict[str, Any]]:
    """Build input data and metadata for agent wrappers.

    ``input_data`` receives:
    - the user prompt (``args[0]`` or ``kwargs["user_prompt"]``);
    - every other keyword argument except ``deps``;
    - the agent's toolsets and system prompts, when not passed explicitly.

    ``metadata`` receives the model name/provider plus agent-level
    configuration (model settings, name, end strategy, output type) that was not
    overridden by a keyword argument.

    Args:
        args: Positional arguments of the wrapped call.
        kwargs: Keyword arguments of the wrapped call.
        instance: The Agent instance being run.

    Returns:
        Tuple of (input_data, metadata)
    """
    input_data: dict[str, Any] = {}

    user_prompt = args[0] if len(args) > 0 else kwargs.get("user_prompt")
    if user_prompt is not None:
        input_data["user_prompt"] = _serialize_user_prompt(user_prompt)

    for key, value in kwargs.items():
        if key == "deps":
            continue
        if key == "user_prompt" and "user_prompt" in input_data:
            # Already serialized above (with binary-content handling); the
            # generic _try_dict form below would overwrite it.
            continue
        if value is None:
            input_data[key] = None
        elif key == "message_history":
            input_data[key] = _serialize_messages(value)
        elif key in ("output_type", "toolsets"):
            # These often contain types/classes, use special serialization
            input_data[key] = _serialize_type(value)
        else:
            # Includes model_settings passed to run(): it is a run() parameter,
            # so it belongs in INPUT.
            input_data[key] = _try_dict(value)

    # A run-level ``model`` may be a model instance as well as a
    # "provider:name" string, so use the instance-aware extractor (it delegates
    # strings to _parse_model_string). When no model override is given (absent
    # or None), report the agent's own model.
    run_model = kwargs.get("model")
    if run_model is not None:
        model_name, provider = _extract_model_info_from_model_instance(run_model)
    else:
        model_name, provider = _extract_model_info(instance)

    # Extract agent-level configuration for metadata
    # Only add to metadata if NOT explicitly passed in kwargs (those go in input)
    agent_model_settings = None
    if "model_settings" not in kwargs and hasattr(instance, "model_settings") and instance.model_settings is not None:
        agent_model_settings = instance.model_settings

    metadata = _build_model_metadata(model_name, provider, agent_model_settings)

    if "name" not in kwargs and hasattr(instance, "name") and instance.name is not None:
        metadata["agent_name"] = instance.name

    if "end_strategy" not in kwargs and hasattr(instance, "end_strategy") and instance.end_strategy is not None:
        metadata["end_strategy"] = str(instance.end_strategy)

    # output_type can be a Pydantic model, str, or other types
    if "output_type" not in kwargs and hasattr(instance, "output_type") and instance.output_type is not None:
        try:
            metadata["output_type"] = _serialize_type(instance.output_type)
        except Exception as e:
            logger.debug(f"Failed to extract output_type from agent: {e}")

    # Toolsets go in INPUT (not metadata) because agent.run() accepts toolsets parameter
    if "toolsets" not in kwargs and hasattr(instance, "toolsets"):
        try:
            toolsets = instance.toolsets
            if toolsets:
                input_data["toolsets"] = _serialize_agent_toolsets(toolsets)
        except Exception as e:
            logger.debug(f"Failed to extract toolsets from agent: {e}")

    # system_prompt goes in input (not metadata) because it's semantically part
    # of the LLM input. It is read from the private ``_system_prompts``
    # attribute, guarded so that a change in that internal structure only logs.
    if "system_prompt" not in kwargs:
        try:
            if hasattr(instance, "_system_prompts") and instance._system_prompts:
                input_data["system_prompt"] = "\n\n".join(instance._system_prompts)
        except Exception as e:
            logger.debug(f"Failed to extract system_prompt from agent: {e}")

    return input_data, metadata
|
|
1179
|
+
|
|
1180
|
+
|
|
1181
|
+
def _build_direct_model_input_and_metadata(args: Any, kwargs: Any) -> tuple[dict[str, Any], dict[str, Any]]:
    """Build input data and metadata for direct model request wrappers.

    The model is ``args[0]`` or ``kwargs["model"]`` and the messages are
    ``args[1]`` or ``kwargs["messages"]``. Every other keyword argument is
    recorded in the input via ``_try_dict``.

    Returns:
        Tuple of (input_data, metadata)
    """
    input_data: dict[str, Any] = {}

    model = args[0] if len(args) > 0 else kwargs.get("model")
    if model is not None:
        input_data["model"] = str(model)

    messages = args[1] if len(args) > 1 else kwargs.get("messages", [])
    if messages:
        input_data["messages"] = _serialize_messages(messages)

    for key, value in kwargs.items():
        if key not in ("model", "messages"):
            input_data[key] = _try_dict(value) if value is not None else None

    # ``model`` may be a model instance rather than a "provider:name" string.
    # str() of an instance is generally not a parseable model id, so use the
    # instance-aware extractor (it still parses strings the same way).
    model_name, provider = _extract_model_info_from_model_instance(model)
    metadata = _build_model_metadata(model_name, provider)

    return input_data, metadata
|