braintrust 0.4.3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braintrust/__init__.py +3 -0
- braintrust/_generated_types.py +106 -6
- braintrust/auto.py +179 -0
- braintrust/conftest.py +23 -4
- braintrust/framework.py +113 -3
- braintrust/functions/invoke.py +3 -1
- braintrust/functions/test_invoke.py +61 -0
- braintrust/generated_types.py +7 -1
- braintrust/logger.py +127 -45
- braintrust/oai.py +51 -0
- braintrust/span_cache.py +337 -0
- braintrust/span_identifier_v3.py +21 -0
- braintrust/test_bt_json.py +0 -5
- braintrust/test_framework.py +37 -0
- braintrust/test_http.py +444 -0
- braintrust/test_logger.py +295 -5
- braintrust/test_span_cache.py +344 -0
- braintrust/test_trace.py +267 -0
- braintrust/test_util.py +58 -1
- braintrust/trace.py +385 -0
- braintrust/util.py +20 -0
- braintrust/version.py +2 -2
- braintrust/wrappers/agno/__init__.py +2 -3
- braintrust/wrappers/anthropic.py +64 -0
- braintrust/wrappers/claude_agent_sdk/__init__.py +2 -3
- braintrust/wrappers/claude_agent_sdk/_wrapper.py +48 -6
- braintrust/wrappers/claude_agent_sdk/test_wrapper.py +115 -0
- braintrust/wrappers/dspy.py +52 -1
- braintrust/wrappers/google_genai/__init__.py +9 -6
- braintrust/wrappers/litellm.py +6 -43
- braintrust/wrappers/pydantic_ai.py +2 -3
- braintrust/wrappers/test_agno.py +9 -0
- braintrust/wrappers/test_anthropic.py +156 -0
- braintrust/wrappers/test_dspy.py +117 -0
- braintrust/wrappers/test_google_genai.py +9 -0
- braintrust/wrappers/test_litellm.py +57 -55
- braintrust/wrappers/test_openai.py +253 -1
- braintrust/wrappers/test_pydantic_ai_integration.py +9 -0
- braintrust/wrappers/test_utils.py +79 -0
- {braintrust-0.4.3.dist-info → braintrust-0.5.2.dist-info}/METADATA +1 -1
- {braintrust-0.4.3.dist-info → braintrust-0.5.2.dist-info}/RECORD +44 -37
- {braintrust-0.4.3.dist-info → braintrust-0.5.2.dist-info}/WHEEL +1 -1
- {braintrust-0.4.3.dist-info → braintrust-0.5.2.dist-info}/entry_points.txt +0 -0
- {braintrust-0.4.3.dist-info → braintrust-0.5.2.dist-info}/top_level.txt +0 -0
braintrust/trace.py
ADDED
|
@@ -0,0 +1,385 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Trace objects for accessing spans in evaluations.
|
|
3
|
+
|
|
4
|
+
This module provides the LocalTrace class which allows scorers to access
|
|
5
|
+
spans from the current evaluation task without making server round-trips.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import asyncio
|
|
9
|
+
from typing import Any, Awaitable, Callable, Optional, Protocol
|
|
10
|
+
|
|
11
|
+
from braintrust.logger import BraintrustState, ObjectFetcher
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SpanData:
    """Span data returned by get_spans().

    Holds the well-known span columns as attributes; any extra keyword
    arguments become plain attributes so arbitrary row columns round-trip
    through from_dict()/to_dict().
    """

    def __init__(
        self,
        input: Optional[Any] = None,
        output: Optional[Any] = None,
        metadata: Optional[dict[str, Any]] = None,
        span_id: Optional[str] = None,
        span_parents: Optional[list[str]] = None,
        span_attributes: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ):
        self.input = input
        self.output = output
        self.metadata = metadata
        self.span_id = span_id
        self.span_parents = span_parents
        self.span_attributes = span_attributes
        # Any additional fields are attached dynamically.
        for extra_name, extra_value in kwargs.items():
            setattr(self, extra_name, extra_value)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SpanData":
        """Create SpanData from a dictionary."""
        return cls(**data)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary, omitting None-valued fields."""
        return {name: value for name, value in self.__dict__.items() if value is not None}
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class SpanFetcher(ObjectFetcher[dict[str, Any]]):
    """
    Fetcher for spans by root_span_id, using the ObjectFetcher pattern.
    Handles pagination automatically via cursor-based iteration.
    """

    def __init__(
        self,
        object_type: str,  # Literal["experiment", "project_logs", "playground_logs"]
        object_id: str,
        root_span_id: str,
        state: BraintrustState,
        span_type_filter: Optional[list[str]] = None,
    ):
        # The BTQL filter pins down root_span_id and, optionally, span type.
        super().__init__(
            object_type=object_type,
            _internal_btql={"filter": self._build_filter(root_span_id, span_type_filter)},
        )
        self._object_id = object_id
        self._state = state

    @staticmethod
    def _build_filter(root_span_id: str, span_type_filter: Optional[list[str]] = None) -> dict[str, Any]:
        """Build BTQL filter expression."""
        # Base predicate: root_span_id = '<value>'.
        match_root = {
            "op": "eq",
            "left": {"op": "ident", "name": ["root_span_id"]},
            "right": {"op": "literal", "value": root_span_id},
        }
        # Exclude scorer spans: span_attributes.purpose is null OR != 'scorer'.
        exclude_scorer = {
            "op": "or",
            "children": [
                {
                    "op": "isnull",
                    "expr": {"op": "ident", "name": ["span_attributes", "purpose"]},
                },
                {
                    "op": "ne",
                    "left": {"op": "ident", "name": ["span_attributes", "purpose"]},
                    "right": {"op": "literal", "value": "scorer"},
                },
            ],
        }
        children = [match_root, exclude_scorer]

        # Narrow to the requested span types, when any were given.
        if span_type_filter:
            children.append(
                {
                    "op": "in",
                    "left": {"op": "ident", "name": ["span_attributes", "type"]},
                    "right": {"op": "literal", "value": span_type_filter},
                }
            )

        return {"op": "and", "children": children}

    @property
    def id(self) -> str:
        return self._object_id

    def _get_state(self) -> BraintrustState:
        return self._state
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
SpanFetchFn = Callable[[Optional[list[str]]], Awaitable[list[SpanData]]]
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class CachedSpanFetcher:
    """
    Cached span fetcher that handles fetching and caching spans by type.

    Caching strategy:
    - Cache spans keyed by span type (dict[spanType, list[SpanData]])
    - Track whether every span has been fetched (_all_fetched flag)
    - When filtering by spanType, only fetch types not already cached
    """

    def __init__(
        self,
        object_type: Optional[str] = None,  # Literal["experiment", "project_logs", "playground_logs"]
        object_id: Optional[str] = None,
        root_span_id: Optional[str] = None,
        get_state: Optional[Callable[[], Awaitable[BraintrustState]]] = None,
        fetch_fn: Optional[SpanFetchFn] = None,
    ):
        self._span_cache: dict[str, list[SpanData]] = {}
        self._all_fetched = False

        if fetch_fn is not None:
            # Direct fetch-function injection (used by tests).
            self._fetch_fn = fetch_fn
            return

        # Standard path: build a fetch function around SpanFetcher.
        if object_type is None or object_id is None or root_span_id is None or get_state is None:
            raise ValueError("Must provide either fetch_fn or all of object_type, object_id, root_span_id, get_state")

        async def _fetch_fn(span_type: Optional[list[str]]) -> list[SpanData]:
            state = await get_state()
            fetcher = SpanFetcher(
                object_type=object_type,
                object_id=object_id,
                root_span_id=root_span_id,
                state=state,
                span_type_filter=span_type,
            )
            rows = list(fetcher.fetch())

            def _is_scorer(row: dict[str, Any]) -> bool:
                # A scorer span carries span_attributes.purpose == 'scorer'.
                attrs = row.get("span_attributes")
                return isinstance(attrs, dict) and row.get("span_attributes", {}).get("purpose") == "scorer"

            return [
                SpanData(
                    input=row.get("input"),
                    output=row.get("output"),
                    metadata=row.get("metadata"),
                    span_id=row.get("span_id"),
                    span_parents=row.get("span_parents"),
                    span_attributes=row.get("span_attributes"),
                    id=row.get("id"),
                    _xact_id=row.get("_xact_id"),
                    _pagination_key=row.get("_pagination_key"),
                    root_span_id=row.get("root_span_id"),
                )
                for row in rows
                if not _is_scorer(row)
            ]

        self._fetch_fn = _fetch_fn

    async def get_spans(self, span_type: Optional[list[str]] = None) -> list[SpanData]:
        """
        Get spans, using cache when possible.

        Args:
            span_type: Optional list of span types to filter by

        Returns:
            List of matching spans
        """
        # Everything already fetched: serve straight from memory.
        if self._all_fetched:
            return self._get_from_cache(span_type)

        # No type filter: a single unfiltered fetch fills the whole cache.
        if not span_type:
            await self._fetch_spans(None)
            self._all_fetched = True
            return self._get_from_cache(None)

        # Only fetch the types the cache does not hold yet.
        missing = [t for t in span_type if t not in self._span_cache]
        if missing:
            await self._fetch_spans(missing)
        return self._get_from_cache(span_type)

    async def _fetch_spans(self, span_type: Optional[list[str]]) -> None:
        """Fetch spans from the server and file them into the cache by type."""
        for span in await self._fetch_fn(span_type):
            type_key = (span.span_attributes or {}).get("type", "")
            self._span_cache.setdefault(type_key, []).append(span)

    def _get_from_cache(self, span_type: Optional[list[str]]) -> list[SpanData]:
        """Get spans from cache, optionally filtering by type."""
        if not span_type:
            # No filter: flatten every cached bucket.
            everything: list[SpanData] = []
            for bucket in self._span_cache.values():
                everything.extend(bucket)
            return everything

        # Only the requested types, in the order they were asked for.
        return [span for t in span_type if t in self._span_cache for span in self._span_cache[t]]
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
class Trace(Protocol):
    """
    Interface for trace objects that can be used by scorers.
    Both the SDK's LocalTrace class and the API wrapper's WrapperTrace implement this.

    NOTE: this is a structural (typing.Protocol) interface — implementers do
    not need to subclass Trace, only to provide matching methods.
    """

    def get_configuration(self) -> dict[str, str]:
        """Get the trace configuration (object_type, object_id, root_span_id)."""
        ...

    async def get_spans(self, span_type: Optional[list[str]] = None) -> list[SpanData]:
        """
        Fetch all spans for this root span.

        Args:
            span_type: Optional list of span types to filter by

        Returns:
            List of matching spans
        """
        ...
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
class LocalTrace(dict):
    """
    SDK implementation of Trace that uses local span cache and falls back to BTQL.
    Carries identifying information about the evaluation so scorers can perform
    richer logging or side effects.

    Inherits from dict so that it serializes to {"trace_ref": {...}} when passed
    to json.dumps(). This allows LocalTrace to be transparently serialized when
    passed through invoke() or other JSON-serializing code paths.
    """

    def __init__(
        self,
        object_type: str,  # Literal["experiment", "project_logs", "playground_logs"]
        object_id: str,
        root_span_id: str,
        ensure_spans_flushed: Optional[Callable[[], Awaitable[None]]],
        state: BraintrustState,
    ):
        # Initialize dict with trace_ref for JSON serialization.
        super().__init__({
            "trace_ref": {
                "object_type": object_type,
                "object_id": object_id,
                "root_span_id": root_span_id,
            }
        })

        self._object_type = object_type
        self._object_id = object_id
        self._root_span_id = root_span_id
        self._ensure_spans_flushed = ensure_spans_flushed
        self._state = state
        self._spans_flushed = False
        self._spans_flush_promise: Optional[asyncio.Task[None]] = None

        async def get_state() -> BraintrustState:
            await self._ensure_spans_ready()
            # Ensure state is logged in. login() may block on network I/O, so
            # run it in the default executor rather than on the event loop
            # thread. get_running_loop() is the non-deprecated replacement for
            # get_event_loop() inside a coroutine.
            await asyncio.get_running_loop().run_in_executor(None, lambda: state.login())
            return state

        self._cached_fetcher = CachedSpanFetcher(
            object_type=object_type,
            object_id=object_id,
            root_span_id=root_span_id,
            get_state=get_state,
        )

    def get_configuration(self) -> dict[str, str]:
        """Get the trace configuration."""
        return {
            "object_type": self._object_type,
            "object_id": self._object_id,
            "root_span_id": self._root_span_id,
        }

    async def get_spans(self, span_type: Optional[list[str]] = None) -> list[SpanData]:
        """
        Fetch all rows for this root span from its parent object (experiment or project logs).
        First checks the local span cache for recently logged spans, then falls
        back to CachedSpanFetcher which handles BTQL fetching and caching.

        Args:
            span_type: Optional list of span types to filter by

        Returns:
            List of matching spans
        """
        # Try the local span cache first (recently logged spans not yet flushed).
        cached_spans = self._state.span_cache.get_by_root_span_id(self._root_span_id)
        if cached_spans:
            # Exclude scorer spans (span_attributes.purpose == 'scorer').
            spans = [span for span in cached_spans if (span.span_attributes or {}).get("purpose") != "scorer"]

            # Narrow to the requested span types, when any were given.
            if span_type:
                spans = [span for span in spans if (span.span_attributes or {}).get("type", "") in span_type]

            # Convert to SpanData.
            return [
                SpanData(
                    input=span.input,
                    output=span.output,
                    metadata=span.metadata,
                    span_id=span.span_id,
                    span_parents=span.span_parents,
                    span_attributes=span.span_attributes,
                )
                for span in spans
            ]

        # Fall back to CachedSpanFetcher for BTQL fetching with caching.
        return await self._cached_fetcher.get_spans(span_type)

    async def _ensure_spans_ready(self) -> None:
        """Ensure spans are flushed before fetching.

        Concurrent callers share a single flush task; on failure the task
        handle is cleared so a later caller retries the flush.
        """
        if self._spans_flushed or not self._ensure_spans_flushed:
            return

        if self._spans_flush_promise is None:

            async def flush_and_mark():
                try:
                    await self._ensure_spans_flushed()
                    self._spans_flushed = True
                except Exception:
                    # Reset so the next caller retries instead of awaiting a
                    # permanently failed task; bare `raise` keeps the original
                    # traceback intact.
                    self._spans_flush_promise = None
                    raise

            self._spans_flush_promise = asyncio.create_task(flush_and_mark())

        await self._spans_flush_promise
|
braintrust/util.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
import inspect
|
|
2
2
|
import json
|
|
3
|
+
import math
|
|
4
|
+
import os
|
|
3
5
|
import sys
|
|
4
6
|
import threading
|
|
5
7
|
import urllib.parse
|
|
@@ -9,6 +11,24 @@ from typing import Any, Generic, Literal, TypedDict, TypeVar, Union
|
|
|
9
11
|
|
|
10
12
|
from requests import HTTPError, Response
|
|
11
13
|
|
|
14
|
+
|
|
15
|
+
def parse_env_var_float(name: str, default: float) -> float:
    """Parse a float from an environment variable, returning default if invalid.

    Returns the default value if the env var is missing, empty, not a valid
    float, NaN, or infinity.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        parsed = float(raw)
    except (ValueError, TypeError):
        return default
    # Reject NaN and +/-infinity; only finite values are usable.
    return parsed if math.isfinite(parsed) else default
|
|
31
|
+
|
|
12
32
|
GLOBAL_PROJECT = "Global"
|
|
13
33
|
BT_IS_ASYNC_ATTRIBUTE = "_BT_IS_ASYNC"
|
|
14
34
|
|
braintrust/wrappers/agno/__init__.py
CHANGED
|
@@ -62,7 +62,6 @@ def setup_agno(
|
|
|
62
62
|
models.base.Model = wrap_model(models.base.Model) # pyright: ignore[reportUnknownMemberType]
|
|
63
63
|
tools.function.FunctionCall = wrap_function_call(tools.function.FunctionCall) # pyright: ignore[reportUnknownMemberType]
|
|
64
64
|
return True
|
|
65
|
-
except ImportError
|
|
66
|
-
|
|
67
|
-
logger.error("Agno is not installed. Please install it with: pip install agno")
|
|
65
|
+
except ImportError:
|
|
66
|
+
# Not installed - this is expected when using auto_instrument()
|
|
68
67
|
return False
|
braintrust/wrappers/anthropic.py
CHANGED
|
@@ -5,6 +5,7 @@ from contextlib import contextmanager
|
|
|
5
5
|
|
|
6
6
|
from braintrust.logger import NOOP_SPAN, log_exc_info_to_span, start_span
|
|
7
7
|
from braintrust.wrappers._anthropic_utils import Wrapper, extract_anthropic_usage, finalize_anthropic_tokens
|
|
8
|
+
from wrapt import wrap_function_wrapper
|
|
8
9
|
|
|
9
10
|
log = logging.getLogger(__name__)
|
|
10
11
|
|
|
@@ -358,3 +359,66 @@ def wrap_anthropic(client):
|
|
|
358
359
|
|
|
359
360
|
def wrap_anthropic_client(client):
|
|
360
361
|
return wrap_anthropic(client)
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
def _apply_anthropic_wrapper(client):
    """Apply tracing wrapper to an Anthropic client instance in-place."""
    traced = wrap_anthropic(client)
    client.messages = traced.messages
    # The beta namespace is not present in every SDK version.
    if hasattr(traced, "beta"):
        client.beta = traced.beta
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
def _apply_async_anthropic_wrapper(client):
    """Apply tracing wrapper to an AsyncAnthropic client instance in-place."""
    traced = wrap_anthropic(client)
    client.messages = traced.messages
    # The beta namespace is not present in every SDK version.
    if hasattr(traced, "beta"):
        client.beta = traced.beta
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
def _anthropic_init_wrapper(wrapped, instance, args, kwargs):
    """Wrapper for Anthropic.__init__ that applies tracing after initialization."""
    # Run the real constructor first, then patch the fully-built client in-place.
    wrapped(*args, **kwargs)
    _apply_anthropic_wrapper(instance)
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
def _async_anthropic_init_wrapper(wrapped, instance, args, kwargs):
    """Wrapper for AsyncAnthropic.__init__ that applies tracing after initialization."""
    # Run the real constructor first, then patch the fully-built client in-place.
    wrapped(*args, **kwargs)
    _apply_async_anthropic_wrapper(instance)
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
def patch_anthropic() -> bool:
    """
    Patch Anthropic to add Braintrust tracing globally.

    After calling this, all new Anthropic() and AsyncAnthropic() clients
    will automatically have tracing enabled.

    Returns:
        True if Anthropic was patched (or already patched), False if Anthropic is not installed.

    Example:
        ```python
        import braintrust
        braintrust.patch_anthropic()

        import anthropic
        client = anthropic.Anthropic()
        # All calls are now traced!
        ```
    """
    try:
        import anthropic

        # Idempotent: calling this twice is a no-op.
        if getattr(anthropic, "__braintrust_wrapped__", False):
            return True

        for target, handler in (
            ("Anthropic.__init__", _anthropic_init_wrapper),
            ("AsyncAnthropic.__init__", _async_anthropic_init_wrapper),
        ):
            wrap_function_wrapper("anthropic", target, handler)
        anthropic.__braintrust_wrapped__ = True
        return True

    except ImportError:
        return False
|
|
@@ -105,7 +105,6 @@ def setup_claude_agent_sdk(
|
|
|
105
105
|
setattr(module, "tool", wrapped_tool_fn)
|
|
106
106
|
|
|
107
107
|
return True
|
|
108
|
-
except ImportError
|
|
109
|
-
|
|
110
|
-
logger.error("claude-agent-sdk is not installed. Please install it with: pip install claude-agent-sdk")
|
|
108
|
+
except ImportError:
|
|
109
|
+
# Not installed - this is expected when using auto_instrument()
|
|
111
110
|
return False
|
|
@@ -2,7 +2,7 @@ import dataclasses
|
|
|
2
2
|
import logging
|
|
3
3
|
import threading
|
|
4
4
|
import time
|
|
5
|
-
from collections.abc import AsyncGenerator, Callable
|
|
5
|
+
from collections.abc import AsyncGenerator, AsyncIterable, Callable
|
|
6
6
|
from typing import Any
|
|
7
7
|
|
|
8
8
|
from braintrust.logger import start_span
|
|
@@ -191,17 +191,38 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
|
|
|
191
191
|
self.__client = client
|
|
192
192
|
self.__last_prompt: str | None = None
|
|
193
193
|
self.__query_start_time: float | None = None
|
|
194
|
+
self.__captured_messages: list[dict[str, Any]] | None = None
|
|
194
195
|
|
|
195
196
|
async def query(self, *args: Any, **kwargs: Any) -> Any:
|
|
196
197
|
"""Wrap query to capture the prompt and start time for tracing."""
|
|
197
198
|
# Capture the time when query is called (when LLM call starts)
|
|
198
199
|
self.__query_start_time = time.time()
|
|
200
|
+
self.__captured_messages = None
|
|
199
201
|
|
|
200
202
|
# Capture the prompt for use in receive_response
|
|
201
|
-
if args
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
203
|
+
prompt = args[0] if args else kwargs.get("prompt")
|
|
204
|
+
|
|
205
|
+
if prompt is not None:
|
|
206
|
+
if isinstance(prompt, str):
|
|
207
|
+
self.__last_prompt = prompt
|
|
208
|
+
elif isinstance(prompt, AsyncIterable):
|
|
209
|
+
# AsyncIterable[dict] - wrap it to capture messages as they're yielded
|
|
210
|
+
captured: list[dict[str, Any]] = []
|
|
211
|
+
self.__captured_messages = captured
|
|
212
|
+
self.__last_prompt = None # Will be set after messages are captured
|
|
213
|
+
|
|
214
|
+
async def capturing_wrapper() -> AsyncGenerator[dict[str, Any], None]:
|
|
215
|
+
async for msg in prompt:
|
|
216
|
+
captured.append(msg)
|
|
217
|
+
yield msg
|
|
218
|
+
|
|
219
|
+
# Replace the prompt with our capturing wrapper
|
|
220
|
+
if args:
|
|
221
|
+
args = (capturing_wrapper(),) + args[1:]
|
|
222
|
+
else:
|
|
223
|
+
kwargs["prompt"] = capturing_wrapper()
|
|
224
|
+
else:
|
|
225
|
+
self.__last_prompt = str(prompt)
|
|
205
226
|
|
|
206
227
|
return await self.__client.query(*args, **kwargs)
|
|
207
228
|
|
|
@@ -215,11 +236,16 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
|
|
|
215
236
|
"""
|
|
216
237
|
generator = self.__client.receive_response()
|
|
217
238
|
|
|
239
|
+
# Determine the initial input - may be updated later if using async generator
|
|
240
|
+
initial_input = self.__last_prompt if self.__last_prompt else None
|
|
241
|
+
|
|
218
242
|
with start_span(
|
|
219
243
|
name="Claude Agent",
|
|
220
244
|
span_attributes={"type": SpanTypeAttribute.TASK},
|
|
221
|
-
input=
|
|
245
|
+
input=initial_input,
|
|
222
246
|
) as span:
|
|
247
|
+
# If we're capturing async messages, we'll update input after they're consumed
|
|
248
|
+
input_needs_update = self.__captured_messages is not None
|
|
223
249
|
# Store the parent span export in thread-local storage for tool handlers
|
|
224
250
|
_thread_local.parent_span_export = span.export()
|
|
225
251
|
|
|
@@ -228,6 +254,13 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
|
|
|
228
254
|
|
|
229
255
|
try:
|
|
230
256
|
async for message in generator:
|
|
257
|
+
# Update input from captured async messages (once, after they're consumed)
|
|
258
|
+
if input_needs_update and self.__captured_messages:
|
|
259
|
+
captured_input = _format_captured_messages(self.__captured_messages)
|
|
260
|
+
if captured_input:
|
|
261
|
+
span.log(input=captured_input)
|
|
262
|
+
input_needs_update = False
|
|
263
|
+
|
|
231
264
|
message_type = type(message).__name__
|
|
232
265
|
|
|
233
266
|
if message_type == "AssistantMessage":
|
|
@@ -390,3 +423,12 @@ def _build_llm_input(prompt: Any, conversation_history: list[dict[str, Any]]) ->
|
|
|
390
423
|
return [{"content": prompt, "role": "user"}] + conversation_history
|
|
391
424
|
|
|
392
425
|
return conversation_history if conversation_history else None
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
def _format_captured_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
|
429
|
+
"""Formats captured async generator messages into structured input.
|
|
430
|
+
|
|
431
|
+
Returns the messages as-is to preserve structure for tracing.
|
|
432
|
+
Empty list returns empty list.
|
|
433
|
+
"""
|
|
434
|
+
return messages if messages else []
|