braintrust 0.4.2__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braintrust/_generated_types.py +328 -126
- braintrust/cli/install/api.py +1 -1
- braintrust/conftest.py +24 -0
- braintrust/devserver/test_server_integration.py +0 -11
- braintrust/framework.py +98 -1
- braintrust/functions/invoke.py +4 -9
- braintrust/functions/test_invoke.py +61 -0
- braintrust/generated_types.py +13 -7
- braintrust/logger.py +107 -66
- braintrust/prompt_cache/test_disk_cache.py +3 -3
- braintrust/span_cache.py +337 -0
- braintrust/span_identifier_v3.py +21 -0
- braintrust/span_types.py +3 -0
- braintrust/test_bt_json.py +23 -19
- braintrust/test_logger.py +116 -0
- braintrust/test_span_cache.py +344 -0
- braintrust/test_trace.py +267 -0
- braintrust/trace.py +385 -0
- braintrust/version.py +2 -2
- braintrust/wrappers/claude_agent_sdk/_wrapper.py +48 -6
- braintrust/wrappers/claude_agent_sdk/test_wrapper.py +106 -0
- braintrust/wrappers/langsmith_wrapper.py +517 -0
- braintrust/wrappers/test_agno.py +0 -12
- braintrust/wrappers/test_anthropic.py +1 -11
- braintrust/wrappers/test_dspy.py +0 -11
- braintrust/wrappers/test_google_genai.py +6 -1
- braintrust/wrappers/test_langsmith_wrapper.py +338 -0
- braintrust/wrappers/test_litellm.py +0 -10
- braintrust/wrappers/test_oai_attachments.py +0 -10
- braintrust/wrappers/test_openai.py +3 -12
- braintrust/wrappers/test_openrouter.py +0 -9
- braintrust/wrappers/test_pydantic_ai_integration.py +0 -11
- braintrust/wrappers/test_pydantic_ai_wrap_openai.py +2 -0
- {braintrust-0.4.2.dist-info → braintrust-0.5.0.dist-info}/METADATA +1 -1
- {braintrust-0.4.2.dist-info → braintrust-0.5.0.dist-info}/RECORD +38 -31
- {braintrust-0.4.2.dist-info → braintrust-0.5.0.dist-info}/WHEEL +1 -1
- {braintrust-0.4.2.dist-info → braintrust-0.5.0.dist-info}/entry_points.txt +0 -0
- {braintrust-0.4.2.dist-info → braintrust-0.5.0.dist-info}/top_level.txt +0 -0
braintrust/trace.py
ADDED
|
@@ -0,0 +1,385 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Trace objects for accessing spans in evaluations.
|
|
3
|
+
|
|
4
|
+
This module provides the LocalTrace class which allows scorers to access
|
|
5
|
+
spans from the current evaluation task without making server round-trips.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import asyncio
|
|
9
|
+
from typing import Any, Awaitable, Callable, Optional, Protocol
|
|
10
|
+
|
|
11
|
+
from braintrust.logger import BraintrustState, ObjectFetcher
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SpanData:
    """A single span row as returned by get_spans().

    Known span fields become regular attributes; any extra server-provided
    columns passed via **kwargs are attached as attributes as well.
    """

    def __init__(
        self,
        input: Optional[Any] = None,
        output: Optional[Any] = None,
        metadata: Optional[dict[str, Any]] = None,
        span_id: Optional[str] = None,
        span_parents: Optional[list[str]] = None,
        span_attributes: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ):
        self.input = input
        self.output = output
        self.metadata = metadata
        self.span_id = span_id
        self.span_parents = span_parents
        self.span_attributes = span_attributes
        # Preserve any additional columns as plain attributes.
        for extra_name, extra_value in kwargs.items():
            setattr(self, extra_name, extra_value)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SpanData":
        """Build a SpanData from a raw row dictionary."""
        return cls(**data)

    def to_dict(self) -> dict[str, Any]:
        """Return this span as a dictionary, omitting None-valued fields."""
        return {name: value for name, value in vars(self).items() if value is not None}
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class SpanFetcher(ObjectFetcher[dict[str, Any]]):
    """
    Fetches the spans that belong to a single root_span_id via the
    ObjectFetcher machinery, which follows pagination cursors automatically.
    """

    def __init__(
        self,
        object_type: str,  # Literal["experiment", "project_logs", "playground_logs"]
        object_id: str,
        root_span_id: str,
        state: BraintrustState,
        span_type_filter: Optional[list[str]] = None,
    ):
        # The BTQL filter restricts rows to this root span (and optionally to
        # specific span types) on the server side.
        super().__init__(
            object_type=object_type,
            _internal_btql={"filter": self._build_filter(root_span_id, span_type_filter)},
        )
        self._object_id = object_id
        self._state = state

    @staticmethod
    def _build_filter(root_span_id: str, span_type_filter: Optional[list[str]] = None) -> dict[str, Any]:
        """Build the BTQL filter expression for this fetcher."""

        def ident(*path: str) -> dict[str, Any]:
            return {"op": "ident", "name": list(path)}

        def literal(value: Any) -> dict[str, Any]:
            return {"op": "literal", "value": value}

        conditions: list[dict[str, Any]] = [
            # Match only rows belonging to the requested root span.
            {
                "op": "eq",
                "left": ident("root_span_id"),
                "right": literal(root_span_id),
            },
            # Exclude scorer spans: purpose must be unset or != 'scorer'.
            {
                "op": "or",
                "children": [
                    {
                        "op": "isnull",
                        "expr": ident("span_attributes", "purpose"),
                    },
                    {
                        "op": "ne",
                        "left": ident("span_attributes", "purpose"),
                        "right": literal("scorer"),
                    },
                ],
            },
        ]

        # Optionally restrict to the requested span types.
        if span_type_filter:
            conditions.append(
                {
                    "op": "in",
                    "left": ident("span_attributes", "type"),
                    "right": literal(span_type_filter),
                }
            )

        return {"op": "and", "children": conditions}

    @property
    def id(self) -> str:
        # ObjectFetcher addresses the underlying object by this id.
        return self._object_id

    def _get_state(self) -> BraintrustState:
        return self._state
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
# A coroutine that fetches spans: given an optional list of span types to
# filter by, it resolves to the matching SpanData rows.
SpanFetchFn = Callable[[Optional[list[str]]], Awaitable[list[SpanData]]]
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class CachedSpanFetcher:
    """
    Fetches spans for a root span and caches them, bucketed by span type.

    Caching strategy:
    - Spans are bucketed by their ``span_attributes.type`` value.
    - ``_all_fetched`` records that an unfiltered fetch has happened; after
      that, every request is served from the cache.
    - A filtered request only hits the server for types not yet cached
      (including types previously fetched and found empty).

    Construct with either ``fetch_fn`` (used directly; handy for tests) or
    all of ``object_type``/``object_id``/``root_span_id``/``get_state``,
    which build a BTQL-backed SpanFetcher.
    """

    def __init__(
        self,
        object_type: Optional[str] = None,  # Literal["experiment", "project_logs", "playground_logs"]
        object_id: Optional[str] = None,
        root_span_id: Optional[str] = None,
        get_state: "Optional[Callable[[], Awaitable[BraintrustState]]]" = None,
        fetch_fn: "Optional[SpanFetchFn]" = None,
    ):
        # span type -> spans of that type fetched so far.
        self._span_cache: "dict[str, list[SpanData]]" = {}
        self._all_fetched = False

        if fetch_fn is not None:
            # Direct fetch function injection (for testing)
            self._fetch_fn = fetch_fn
        else:
            # Standard constructor with SpanFetcher
            if object_type is None or object_id is None or root_span_id is None or get_state is None:
                raise ValueError("Must provide either fetch_fn or all of object_type, object_id, root_span_id, get_state")

            async def _fetch_fn(span_type: "Optional[list[str]]") -> "list[SpanData]":
                state = await get_state()
                fetcher = SpanFetcher(
                    object_type=object_type,
                    object_id=object_id,
                    root_span_id=root_span_id,
                    state=state,
                    span_type_filter=span_type,
                )
                rows = list(fetcher.fetch())
                # Filter out scorer spans (belt-and-suspenders: the BTQL
                # filter already excludes them server-side).
                filtered = [
                    row
                    for row in rows
                    if not (
                        isinstance(row.get("span_attributes"), dict)
                        and row.get("span_attributes", {}).get("purpose") == "scorer"
                    )
                ]
                return [
                    SpanData(
                        input=row.get("input"),
                        output=row.get("output"),
                        metadata=row.get("metadata"),
                        span_id=row.get("span_id"),
                        span_parents=row.get("span_parents"),
                        span_attributes=row.get("span_attributes"),
                        id=row.get("id"),
                        _xact_id=row.get("_xact_id"),
                        _pagination_key=row.get("_pagination_key"),
                        root_span_id=row.get("root_span_id"),
                    )
                    for row in filtered
                ]

            self._fetch_fn = _fetch_fn

    async def get_spans(self, span_type: Optional[list[str]] = None) -> "list[SpanData]":
        """
        Get spans, using the cache when possible.

        Args:
            span_type: Optional list of span types to filter by

        Returns:
            List of matching spans
        """
        # Everything is already local: serve straight from the cache.
        if self._all_fetched:
            return self._get_from_cache(span_type)

        # No filter requested: fetch everything once, then rely on the cache.
        if not span_type:
            await self._fetch_spans(None)
            self._all_fetched = True
            return self._get_from_cache(None)

        # Only hit the server for types we have never fetched.
        missing_types = [t for t in span_type if t not in self._span_cache]
        if missing_types:
            await self._fetch_spans(missing_types)
        return self._get_from_cache(span_type)

    async def _fetch_spans(self, span_type: "Optional[list[str]]") -> None:
        """Fetch spans from the server and (re)populate their cache buckets."""
        spans = await self._fetch_fn(span_type)

        fresh: "dict[str, list[SpanData]]" = {}
        for span in spans:
            type_str = (span.span_attributes or {}).get("type", "")
            fresh.setdefault(type_str, []).append(span)

        # Overwrite the affected buckets rather than appending: refetching a
        # type (e.g. a filtered fetch followed by an unfiltered one) must not
        # duplicate spans already in the cache.
        self._span_cache.update(fresh)

        # Record explicitly-requested types that came back empty so they are
        # not refetched on every subsequent call.
        if span_type:
            for t in span_type:
                self._span_cache.setdefault(t, [])

    def _get_from_cache(self, span_type: "Optional[list[str]]") -> "list[SpanData]":
        """Read spans from the cache, optionally restricted to given types."""
        if not span_type:
            # Return all spans across every bucket.
            return [span for spans in self._span_cache.values() for span in spans]
        return [span for t in span_type for span in self._span_cache.get(t, [])]
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
class Trace(Protocol):
    """
    Structural interface for trace objects that scorers consume.
    Both the SDK's LocalTrace class and the API wrapper's WrapperTrace
    implement this; scorers should depend only on this interface.
    """

    def get_configuration(self) -> dict[str, str]:
        """Get the trace configuration (object_type, object_id, root_span_id)."""
        ...

    async def get_spans(self, span_type: Optional[list[str]] = None) -> list[SpanData]:
        """
        Fetch all spans for this root span.

        Args:
            span_type: Optional list of span types to filter by

        Returns:
            List of matching spans
        """
        ...
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
class LocalTrace(dict):
    """
    SDK implementation of Trace that uses the local span cache and falls
    back to BTQL. Carries identifying information about the evaluation so
    scorers can perform richer logging or side effects.

    Inherits from dict so that it serializes to {"trace_ref": {...}} when
    passed to json.dumps(). This allows LocalTrace to be transparently
    serialized when passed through invoke() or other JSON-serializing code
    paths.
    """

    def __init__(
        self,
        object_type: str,  # Literal["experiment", "project_logs", "playground_logs"]
        object_id: str,
        root_span_id: str,
        ensure_spans_flushed: Optional[Callable[[], Awaitable[None]]],
        state: BraintrustState,
    ):
        # Initialize dict with trace_ref for JSON serialization
        super().__init__({
            "trace_ref": {
                "object_type": object_type,
                "object_id": object_id,
                "root_span_id": root_span_id,
            }
        })

        self._object_type = object_type
        self._object_id = object_id
        self._root_span_id = root_span_id
        self._ensure_spans_flushed = ensure_spans_flushed
        self._state = state
        self._spans_flushed = False
        self._spans_flush_promise: Optional[asyncio.Task[None]] = None

        async def get_state() -> BraintrustState:
            # Make sure pending spans are flushed before any server fetch.
            await self._ensure_spans_ready()
            # login() blocks, so run it off the event loop thread. Use
            # get_running_loop(): get_event_loop() inside a coroutine is
            # deprecated since Python 3.10.
            await asyncio.get_running_loop().run_in_executor(None, lambda: state.login())
            return state

        self._cached_fetcher = CachedSpanFetcher(
            object_type=object_type,
            object_id=object_id,
            root_span_id=root_span_id,
            get_state=get_state,
        )

    def get_configuration(self) -> dict[str, str]:
        """Get the trace configuration (object_type, object_id, root_span_id)."""
        return {
            "object_type": self._object_type,
            "object_id": self._object_id,
            "root_span_id": self._root_span_id,
        }

    async def get_spans(self, span_type: Optional[list[str]] = None) -> list[SpanData]:
        """
        Fetch all rows for this root span from its parent object (experiment or project logs).
        First checks the local span cache for recently logged spans, then falls
        back to CachedSpanFetcher which handles BTQL fetching and caching.

        Args:
            span_type: Optional list of span types to filter by

        Returns:
            List of matching spans
        """
        # Try local span cache first (for recently logged spans not yet flushed)
        cached_spans = self._state.span_cache.get_by_root_span_id(self._root_span_id)
        if cached_spans:
            # Exclude scorer spans, mirroring the server-side BTQL filter.
            spans = [span for span in cached_spans if (span.span_attributes or {}).get("purpose") != "scorer"]

            # Filter by span type if requested
            if span_type:
                spans = [span for span in spans if (span.span_attributes or {}).get("type", "") in span_type]

            # Convert to SpanData
            return [
                SpanData(
                    input=span.input,
                    output=span.output,
                    metadata=span.metadata,
                    span_id=span.span_id,
                    span_parents=span.span_parents,
                    span_attributes=span.span_attributes,
                )
                for span in spans
            ]

        # Fall back to CachedSpanFetcher for BTQL fetching with caching
        return await self._cached_fetcher.get_spans(span_type)

    async def _ensure_spans_ready(self) -> None:
        """Ensure spans are flushed (at most once) before fetching.

        Concurrent callers share a single flush task; a failed flush clears
        the shared task so a later call can retry.
        """
        if self._spans_flushed or not self._ensure_spans_flushed:
            return

        if self._spans_flush_promise is None:

            async def flush_and_mark() -> None:
                try:
                    await self._ensure_spans_flushed()
                    self._spans_flushed = True
                except Exception:
                    # Clear the shared task so the next caller retries; bare
                    # raise preserves the original traceback.
                    self._spans_flush_promise = None
                    raise

            self._spans_flush_promise = asyncio.create_task(flush_and_mark())

        await self._spans_flush_promise
|
braintrust/wrappers/claude_agent_sdk/_wrapper.py
CHANGED
|
@@ -2,7 +2,7 @@ import dataclasses
|
|
|
2
2
|
import logging
|
|
3
3
|
import threading
|
|
4
4
|
import time
|
|
5
|
-
from collections.abc import AsyncGenerator, Callable
|
|
5
|
+
from collections.abc import AsyncGenerator, AsyncIterable, Callable
|
|
6
6
|
from typing import Any
|
|
7
7
|
|
|
8
8
|
from braintrust.logger import start_span
|
|
@@ -191,17 +191,38 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
|
|
|
191
191
|
self.__client = client
|
|
192
192
|
self.__last_prompt: str | None = None
|
|
193
193
|
self.__query_start_time: float | None = None
|
|
194
|
+
self.__captured_messages: list[dict[str, Any]] | None = None
|
|
194
195
|
|
|
195
196
|
async def query(self, *args: Any, **kwargs: Any) -> Any:
    """Wrap query to capture the prompt and start time for tracing."""
    # Capture the time when query is called (when LLM call starts)
    self.__query_start_time = time.time()
    # Reset per-call state; only populated for AsyncIterable prompts.
    self.__captured_messages = None

    # Capture the prompt for use in receive_response
    prompt = args[0] if args else kwargs.get("prompt")

    if prompt is not None:
        if isinstance(prompt, str):
            self.__last_prompt = prompt
        elif isinstance(prompt, AsyncIterable):
            # AsyncIterable[dict] - wrap it to capture messages as they're yielded
            captured: list[dict[str, Any]] = []
            self.__captured_messages = captured
            self.__last_prompt = None  # Will be set after messages are captured

            async def capturing_wrapper() -> AsyncGenerator[dict[str, Any], None]:
                # Tee each message into `captured` while forwarding it unchanged.
                async for msg in prompt:
                    captured.append(msg)
                    yield msg

            # Replace the prompt with our capturing wrapper
            if args:
                args = (capturing_wrapper(),) + args[1:]
            else:
                kwargs["prompt"] = capturing_wrapper()
        else:
            # Fallback for unknown prompt types: best-effort string form.
            self.__last_prompt = str(prompt)

    return await self.__client.query(*args, **kwargs)
|
|
207
228
|
|
|
@@ -215,11 +236,16 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
|
|
|
215
236
|
"""
|
|
216
237
|
generator = self.__client.receive_response()
|
|
217
238
|
|
|
239
|
+
# Determine the initial input - may be updated later if using async generator
|
|
240
|
+
initial_input = self.__last_prompt if self.__last_prompt else None
|
|
241
|
+
|
|
218
242
|
with start_span(
|
|
219
243
|
name="Claude Agent",
|
|
220
244
|
span_attributes={"type": SpanTypeAttribute.TASK},
|
|
221
|
-
input=
|
|
245
|
+
input=initial_input,
|
|
222
246
|
) as span:
|
|
247
|
+
# If we're capturing async messages, we'll update input after they're consumed
|
|
248
|
+
input_needs_update = self.__captured_messages is not None
|
|
223
249
|
# Store the parent span export in thread-local storage for tool handlers
|
|
224
250
|
_thread_local.parent_span_export = span.export()
|
|
225
251
|
|
|
@@ -228,6 +254,13 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
|
|
|
228
254
|
|
|
229
255
|
try:
|
|
230
256
|
async for message in generator:
|
|
257
|
+
# Update input from captured async messages (once, after they're consumed)
|
|
258
|
+
if input_needs_update and self.__captured_messages:
|
|
259
|
+
captured_input = _format_captured_messages(self.__captured_messages)
|
|
260
|
+
if captured_input:
|
|
261
|
+
span.log(input=captured_input)
|
|
262
|
+
input_needs_update = False
|
|
263
|
+
|
|
231
264
|
message_type = type(message).__name__
|
|
232
265
|
|
|
233
266
|
if message_type == "AssistantMessage":
|
|
@@ -390,3 +423,12 @@ def _build_llm_input(prompt: Any, conversation_history: list[dict[str, Any]]) ->
|
|
|
390
423
|
return [{"content": prompt, "role": "user"}] + conversation_history
|
|
391
424
|
|
|
392
425
|
return conversation_history if conversation_history else None
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
def _format_captured_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
|
429
|
+
"""Formats captured async generator messages into structured input.
|
|
430
|
+
|
|
431
|
+
Returns the messages as-is to preserve structure for tracing.
|
|
432
|
+
Empty list returns empty list.
|
|
433
|
+
"""
|
|
434
|
+
return messages if messages else []
|
|
@@ -177,3 +177,109 @@ async def test_calculator_with_multiple_operations(memory_logger):
|
|
|
177
177
|
if span["span_id"] != root_span_id:
|
|
178
178
|
assert span["root_span_id"] == root_span_id
|
|
179
179
|
assert root_span_id in span["span_parents"]
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def _make_message(content: str) -> dict:
|
|
183
|
+
"""Create a streaming format message dict."""
|
|
184
|
+
return {"type": "user", "message": {"role": "user", "content": content}}
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def _assert_structured_input(task_span: dict, expected_contents: list[str]) -> None:
|
|
188
|
+
"""Assert that task span input is a structured list with expected content."""
|
|
189
|
+
inp = task_span.get("input")
|
|
190
|
+
assert isinstance(inp, list), f"Expected list input, got {type(inp).__name__}: {inp}"
|
|
191
|
+
assert [x["message"]["content"] for x in inp] == expected_contents
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
class CustomAsyncIterable:
    """Custom AsyncIterable class (not a generator) for testing."""

    def __init__(self, messages: list[dict]):
        self._messages = messages

    def __aiter__(self):
        # Hand back a fresh iterator over the stored messages each time.
        return CustomAsyncIterator(self._messages)
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
class CustomAsyncIterator:
    """Iterator for CustomAsyncIterable."""

    def __init__(self, messages: list[dict]):
        self._messages = messages
        self._index = 0

    async def __anext__(self):
        # EAFP: index past the end signals exhaustion.
        try:
            current = self._messages[self._index]
        except IndexError:
            raise StopAsyncIteration from None
        self._index += 1
        return current
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
@pytest.mark.skipif(not CLAUDE_SDK_AVAILABLE, reason="Claude Agent SDK not installed")
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "input_factory,expected_contents",
    [
        # Plain async generator yielding a single message.
        pytest.param(
            lambda: (msg async for msg in _single_message_generator()),
            ["What is 2 + 2?"],
            id="asyncgen_single",
        ),
        # Async generator yielding a multi-part prompt.
        pytest.param(
            lambda: (msg async for msg in _multi_message_generator()),
            ["Part 1", "Part 2"],
            id="asyncgen_multi",
        ),
        # Non-generator object implementing the AsyncIterable protocol.
        pytest.param(
            lambda: CustomAsyncIterable([_make_message("Custom 1"), _make_message("Custom 2")]),
            ["Custom 1", "Custom 2"],
            id="custom_async_iterable",
        ),
    ],
)
async def test_query_async_iterable(memory_logger, input_factory, expected_contents):
    """Test that async iterable inputs are captured as structured lists.

    Verifies that passing AsyncIterable[dict] to query() results in the span
    input showing the structured message list, not a flattened string or repr.
    """
    assert not memory_logger.pop()

    # Swap in the traced wrapper class; restored in the finally block below.
    original_client = claude_agent_sdk.ClaudeSDKClient
    claude_agent_sdk.ClaudeSDKClient = _create_client_wrapper_class(original_client)

    try:
        options = claude_agent_sdk.ClaudeAgentOptions(model=TEST_MODEL)

        async with claude_agent_sdk.ClaudeSDKClient(options=options) as client:
            await client.query(input_factory())
            # Drain the response until the final ResultMessage arrives.
            async for message in client.receive_response():
                if type(message).__name__ == "ResultMessage":
                    break

        spans = memory_logger.pop()

        task_spans = [s for s in spans if s["span_attributes"]["type"] == SpanTypeAttribute.TASK]
        assert len(task_spans) >= 1, f"Should have at least one task span, got {len(task_spans)}"

        # Prefer the "Claude Agent" task span; fall back to the first one.
        task_span = next(
            (s for s in task_spans if s["span_attributes"]["name"] == "Claude Agent"),
            task_spans[0],
        )

        _assert_structured_input(task_span, expected_contents)

    finally:
        claude_agent_sdk.ClaudeSDKClient = original_client
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
async def _single_message_generator():
    """Yield exactly one user message with a simple arithmetic question."""
    message = _make_message("What is 2 + 2?")
    yield message
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
async def _multi_message_generator():
    """Yield a two-part user prompt, one message per part."""
    for part in ("Part 1", "Part 2"):
        yield _make_message(part)
|