braintrust 0.4.2__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. braintrust/_generated_types.py +328 -126
  2. braintrust/cli/install/api.py +1 -1
  3. braintrust/conftest.py +24 -0
  4. braintrust/devserver/test_server_integration.py +0 -11
  5. braintrust/framework.py +98 -1
  6. braintrust/functions/invoke.py +4 -9
  7. braintrust/functions/test_invoke.py +61 -0
  8. braintrust/generated_types.py +13 -7
  9. braintrust/logger.py +107 -66
  10. braintrust/prompt_cache/test_disk_cache.py +3 -3
  11. braintrust/span_cache.py +337 -0
  12. braintrust/span_identifier_v3.py +21 -0
  13. braintrust/span_types.py +3 -0
  14. braintrust/test_bt_json.py +23 -19
  15. braintrust/test_logger.py +116 -0
  16. braintrust/test_span_cache.py +344 -0
  17. braintrust/test_trace.py +267 -0
  18. braintrust/trace.py +385 -0
  19. braintrust/version.py +2 -2
  20. braintrust/wrappers/claude_agent_sdk/_wrapper.py +48 -6
  21. braintrust/wrappers/claude_agent_sdk/test_wrapper.py +106 -0
  22. braintrust/wrappers/langsmith_wrapper.py +517 -0
  23. braintrust/wrappers/test_agno.py +0 -12
  24. braintrust/wrappers/test_anthropic.py +1 -11
  25. braintrust/wrappers/test_dspy.py +0 -11
  26. braintrust/wrappers/test_google_genai.py +6 -1
  27. braintrust/wrappers/test_langsmith_wrapper.py +338 -0
  28. braintrust/wrappers/test_litellm.py +0 -10
  29. braintrust/wrappers/test_oai_attachments.py +0 -10
  30. braintrust/wrappers/test_openai.py +3 -12
  31. braintrust/wrappers/test_openrouter.py +0 -9
  32. braintrust/wrappers/test_pydantic_ai_integration.py +0 -11
  33. braintrust/wrappers/test_pydantic_ai_wrap_openai.py +2 -0
  34. {braintrust-0.4.2.dist-info → braintrust-0.5.0.dist-info}/METADATA +1 -1
  35. {braintrust-0.4.2.dist-info → braintrust-0.5.0.dist-info}/RECORD +38 -31
  36. {braintrust-0.4.2.dist-info → braintrust-0.5.0.dist-info}/WHEEL +1 -1
  37. {braintrust-0.4.2.dist-info → braintrust-0.5.0.dist-info}/entry_points.txt +0 -0
  38. {braintrust-0.4.2.dist-info → braintrust-0.5.0.dist-info}/top_level.txt +0 -0
braintrust/trace.py ADDED
@@ -0,0 +1,385 @@
1
+ """
2
+ Trace objects for accessing spans in evaluations.
3
+
4
+ This module provides the LocalTrace class which allows scorers to access
5
+ spans from the current evaluation task without making server round-trips.
6
+ """
7
+
8
+ import asyncio
9
+ from typing import Any, Awaitable, Callable, Optional, Protocol
10
+
11
+ from braintrust.logger import BraintrustState, ObjectFetcher
12
+
13
+
14
class SpanData:
    """Span data returned by get_spans()."""

    def __init__(
        self,
        input: Optional[Any] = None,
        output: Optional[Any] = None,
        metadata: Optional[dict[str, Any]] = None,
        span_id: Optional[str] = None,
        span_parents: Optional[list[str]] = None,
        span_attributes: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ):
        # Well-known span fields.
        self.input = input
        self.output = output
        self.metadata = metadata
        self.span_id = span_id
        self.span_parents = span_parents
        self.span_attributes = span_attributes
        # Any extra fields (e.g. id, _xact_id) become attributes as well.
        self.__dict__.update(kwargs)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SpanData":
        """Create SpanData from a dictionary."""
        return cls(**data)

    def to_dict(self) -> dict[str, Any]:
        """Convert to a dictionary, omitting fields that are None."""
        return {key: value for key, value in self.__dict__.items() if value is not None}
49
+
50
+
51
class SpanFetcher(ObjectFetcher[dict[str, Any]]):
    """
    Fetcher for spans by root_span_id, using the ObjectFetcher pattern.
    Handles pagination automatically via cursor-based iteration.
    """

    def __init__(
        self,
        object_type: str,  # Literal["experiment", "project_logs", "playground_logs"]
        object_id: str,
        root_span_id: str,
        state: BraintrustState,
        span_type_filter: Optional[list[str]] = None,
    ):
        # The BTQL filter restricts rows to this root span (and optionally
        # to a set of span types), excluding scorer spans.
        super().__init__(
            object_type=object_type,
            _internal_btql={"filter": self._build_filter(root_span_id, span_type_filter)},
        )
        self._object_id = object_id
        self._state = state

    @staticmethod
    def _build_filter(root_span_id: str, span_type_filter: Optional[list[str]] = None) -> dict[str, Any]:
        """Build BTQL filter expression."""
        # root_span_id = '<value>'
        match_root = {
            "op": "eq",
            "left": {"op": "ident", "name": ["root_span_id"]},
            "right": {"op": "literal", "value": root_span_id},
        }
        # span_attributes.purpose IS NULL OR span_attributes.purpose != 'scorer'
        exclude_scorers = {
            "op": "or",
            "children": [
                {
                    "op": "isnull",
                    "expr": {"op": "ident", "name": ["span_attributes", "purpose"]},
                },
                {
                    "op": "ne",
                    "left": {"op": "ident", "name": ["span_attributes", "purpose"]},
                    "right": {"op": "literal", "value": "scorer"},
                },
            ],
        }
        clauses = [match_root, exclude_scorers]

        # span_attributes.type IN (<filter>) when a type filter is given.
        if span_type_filter:
            clauses.append(
                {
                    "op": "in",
                    "left": {"op": "ident", "name": ["span_attributes", "type"]},
                    "right": {"op": "literal", "value": span_type_filter},
                }
            )

        return {"op": "and", "children": clauses}

    @property
    def id(self) -> str:
        return self._object_id

    def _get_state(self) -> BraintrustState:
        return self._state
120
+
121
+
122
# Signature of a span-fetching coroutine: takes an optional list of span
# types to filter by and resolves to the matching SpanData rows.
SpanFetchFn = Callable[[Optional[list[str]]], Awaitable[list[SpanData]]]
123
+
124
+
125
class CachedSpanFetcher:
    """
    Cached span fetcher that handles fetching and caching spans by type.

    Caching strategy:
    - Cache spans by span type (dict[spanType, list[SpanData]])
    - Track if all spans have been fetched (all_fetched flag)
    - When filtering by spanType, only fetch types not already in cache
    - Requested types that come back empty are cached as empty lists, so a
      type with zero spans is not refetched on every call
    """

    def __init__(
        self,
        object_type: Optional[str] = None,  # Literal["experiment", "project_logs", "playground_logs"]
        object_id: Optional[str] = None,
        root_span_id: Optional[str] = None,
        get_state: Optional[Callable[[], Awaitable[BraintrustState]]] = None,
        fetch_fn: Optional[SpanFetchFn] = None,
    ):
        """Create a fetcher either from a raw fetch_fn (used by tests) or
        from the object coordinates plus a state getter (production path).

        Raises:
            ValueError: if neither fetch_fn nor the full set of
                object_type/object_id/root_span_id/get_state is provided.
        """
        # span type -> spans of that type fetched so far
        self._span_cache: dict[str, list[SpanData]] = {}
        # True once an unfiltered fetch has populated the cache completely
        self._all_fetched = False

        if fetch_fn is not None:
            # Direct fetch function injection (for testing)
            self._fetch_fn = fetch_fn
        else:
            # Standard constructor with SpanFetcher
            if object_type is None or object_id is None or root_span_id is None or get_state is None:
                raise ValueError("Must provide either fetch_fn or all of object_type, object_id, root_span_id, get_state")

            async def _fetch_fn(span_type: Optional[list[str]]) -> list[SpanData]:
                state = await get_state()
                fetcher = SpanFetcher(
                    object_type=object_type,
                    object_id=object_id,
                    root_span_id=root_span_id,
                    state=state,
                    span_type_filter=span_type,
                )
                rows = list(fetcher.fetch())
                # Filter out scorer spans. The BTQL filter should already
                # exclude them; this is a defensive second pass.
                filtered = [
                    row
                    for row in rows
                    if not (
                        isinstance(row.get("span_attributes"), dict)
                        and row.get("span_attributes", {}).get("purpose") == "scorer"
                    )
                ]
                return [
                    SpanData(
                        input=row.get("input"),
                        output=row.get("output"),
                        metadata=row.get("metadata"),
                        span_id=row.get("span_id"),
                        span_parents=row.get("span_parents"),
                        span_attributes=row.get("span_attributes"),
                        id=row.get("id"),
                        _xact_id=row.get("_xact_id"),
                        _pagination_key=row.get("_pagination_key"),
                        root_span_id=row.get("root_span_id"),
                    )
                    for row in filtered
                ]

            self._fetch_fn = _fetch_fn

    async def get_spans(self, span_type: Optional[list[str]] = None) -> list[SpanData]:
        """
        Get spans, using cache when possible.

        Args:
            span_type: Optional list of span types to filter by

        Returns:
            List of matching spans
        """
        # If we've fetched all spans, just filter from cache
        if self._all_fetched:
            return self._get_from_cache(span_type)

        # If no filter requested, fetch everything
        if not span_type:
            await self._fetch_spans(None)
            self._all_fetched = True
            return self._get_from_cache(None)

        # Find which span types we don't have in cache yet
        missing_types = [t for t in span_type if t not in self._span_cache]

        # If all requested types are cached, return from cache
        if not missing_types:
            return self._get_from_cache(span_type)

        # Fetch only the missing types
        await self._fetch_spans(missing_types)
        return self._get_from_cache(span_type)

    async def _fetch_spans(self, span_type: Optional[list[str]]) -> None:
        """Fetch spans from the server and merge them into the cache.

        Every explicitly requested type is seeded with an empty list first,
        so types with zero matching spans still count as fetched and won't
        be re-requested by later get_spans() calls (previously an empty type
        was refetched on every call).
        """
        spans = await self._fetch_fn(span_type)

        if span_type:
            for requested_type in span_type:
                self._span_cache.setdefault(requested_type, [])

        for span in spans:
            span_attrs = span.span_attributes or {}
            span_type_str = span_attrs.get("type", "")
            self._span_cache.setdefault(span_type_str, []).append(span)

    def _get_from_cache(self, span_type: Optional[list[str]]) -> list[SpanData]:
        """Get spans from cache, optionally filtering by type."""
        if not span_type:
            # Return all spans across every cached type
            result: list[SpanData] = []
            for spans in self._span_cache.values():
                result.extend(spans)
            return result

        # Return only requested types, in the requested order
        result = []
        for type_str in span_type:
            result.extend(self._span_cache.get(type_str, []))
        return result
248
+
249
+
250
class Trace(Protocol):
    """
    Interface for trace objects that can be used by scorers.
    Both the SDK's LocalTrace class and the API wrapper's WrapperTrace implement this.
    """

    # NOTE: this is a typing.Protocol, so conformance is structural
    # (duck-typed) — implementers do not need to subclass Trace.

    def get_configuration(self) -> dict[str, str]:
        """Get the trace configuration (object_type, object_id, root_span_id)."""
        ...

    async def get_spans(self, span_type: Optional[list[str]] = None) -> list[SpanData]:
        """
        Fetch all spans for this root span.

        Args:
            span_type: Optional list of span types to filter by

        Returns:
            List of matching spans
        """
        ...
271
+
272
+
273
class LocalTrace(dict):
    """
    SDK implementation of Trace that uses local span cache and falls back to BTQL.
    Carries identifying information about the evaluation so scorers can perform
    richer logging or side effects.

    Inherits from dict so that it serializes to {"trace_ref": {...}} when passed
    to json.dumps(). This allows LocalTrace to be transparently serialized when
    passed through invoke() or other JSON-serializing code paths.
    """

    def __init__(
        self,
        object_type: str,  # Literal["experiment", "project_logs", "playground_logs"]
        object_id: str,
        root_span_id: str,
        ensure_spans_flushed: Optional[Callable[[], Awaitable[None]]],
        state: BraintrustState,
    ):
        # Initialize dict with trace_ref for JSON serialization
        super().__init__(
            {
                "trace_ref": {
                    "object_type": object_type,
                    "object_id": object_id,
                    "root_span_id": root_span_id,
                }
            }
        )

        self._object_type = object_type
        self._object_id = object_id
        self._root_span_id = root_span_id
        self._ensure_spans_flushed = ensure_spans_flushed
        self._state = state
        self._spans_flushed = False
        self._spans_flush_promise: Optional[asyncio.Task[None]] = None

        async def get_state() -> BraintrustState:
            # Make sure buffered spans reach the server before querying it.
            await self._ensure_spans_ready()
            # login() may block on network I/O, so run it in the default
            # executor. get_running_loop() replaces the deprecated
            # get_event_loop() for use inside coroutines.
            await asyncio.get_running_loop().run_in_executor(None, state.login)
            return state

        self._cached_fetcher = CachedSpanFetcher(
            object_type=object_type,
            object_id=object_id,
            root_span_id=root_span_id,
            get_state=get_state,
        )

    def get_configuration(self) -> dict[str, str]:
        """Get the trace configuration."""
        return {
            "object_type": self._object_type,
            "object_id": self._object_id,
            "root_span_id": self._root_span_id,
        }

    async def get_spans(self, span_type: Optional[list[str]] = None) -> list[SpanData]:
        """
        Fetch all rows for this root span from its parent object (experiment or project logs).
        First checks the local span cache for recently logged spans, then falls
        back to CachedSpanFetcher which handles BTQL fetching and caching.

        Args:
            span_type: Optional list of span types to filter by

        Returns:
            List of matching spans
        """
        # Try local span cache first (for recently logged spans not yet flushed)
        cached_spans = self._state.span_cache.get_by_root_span_id(self._root_span_id)
        if cached_spans:
            # Drop scorer spans so scorers don't see their own output
            spans = [
                span
                for span in cached_spans
                if (span.span_attributes or {}).get("purpose") != "scorer"
            ]

            # Filter by span type if requested
            if span_type:
                spans = [span for span in spans if (span.span_attributes or {}).get("type", "") in span_type]

            # Convert to SpanData
            return [
                SpanData(
                    input=span.input,
                    output=span.output,
                    metadata=span.metadata,
                    span_id=span.span_id,
                    span_parents=span.span_parents,
                    span_attributes=span.span_attributes,
                )
                for span in spans
            ]

        # Fall back to CachedSpanFetcher for BTQL fetching with caching
        return await self._cached_fetcher.get_spans(span_type)

    async def _ensure_spans_ready(self) -> None:
        """Ensure spans are flushed (at most once) before fetching.

        Concurrent callers share a single flush task; on failure the task
        is cleared so the next caller retries the flush.
        """
        if self._spans_flushed or not self._ensure_spans_flushed:
            return

        if self._spans_flush_promise is None:

            async def flush_and_mark() -> None:
                try:
                    await self._ensure_spans_flushed()
                    self._spans_flushed = True
                except Exception:
                    # Clear the task so a later call can retry, then re-raise
                    # with the original traceback.
                    self._spans_flush_promise = None
                    raise

            self._spans_flush_promise = asyncio.create_task(flush_and_mark())

        await self._spans_flush_promise
braintrust/version.py CHANGED
@@ -1,4 +1,4 @@
1
- VERSION = "0.4.2"
1
+ VERSION = "0.5.0"
2
2
 
3
3
  # this will be templated during the build
4
- GIT_COMMIT = "3ca420e53e77d4665b91ccc7631c95dc97ce566d"
4
+ GIT_COMMIT = "617d9b730b37e96b7d05a099b95f5387944d0951"
@@ -2,7 +2,7 @@ import dataclasses
2
2
  import logging
3
3
  import threading
4
4
  import time
5
- from collections.abc import AsyncGenerator, Callable
5
+ from collections.abc import AsyncGenerator, AsyncIterable, Callable
6
6
  from typing import Any
7
7
 
8
8
  from braintrust.logger import start_span
@@ -191,17 +191,38 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
191
191
  self.__client = client
192
192
  self.__last_prompt: str | None = None
193
193
  self.__query_start_time: float | None = None
194
+ self.__captured_messages: list[dict[str, Any]] | None = None
194
195
 
195
196
async def query(self, *args: Any, **kwargs: Any) -> Any:
    """Wrap query to capture the prompt and start time for tracing.

    Delegates to the wrapped client's query() unchanged; the captured
    prompt/start time are consumed later by receive_response() when it
    logs the trace span.
    """
    # Capture the time when query is called (when LLM call starts)
    self.__query_start_time = time.time()
    # Reset per-call capture state left over from any previous query.
    self.__captured_messages = None

    # Capture the prompt for use in receive_response. The prompt arrives
    # either positionally or as the `prompt` keyword.
    prompt = args[0] if args else kwargs.get("prompt")

    if prompt is not None:
        if isinstance(prompt, str):
            self.__last_prompt = prompt
        elif isinstance(prompt, AsyncIterable):
            # AsyncIterable[dict] - wrap it to capture messages as they're yielded
            captured: list[dict[str, Any]] = []
            self.__captured_messages = captured
            self.__last_prompt = None  # Will be set after messages are captured

            async def capturing_wrapper() -> AsyncGenerator[dict[str, Any], None]:
                # Tee each message into `captured` while forwarding it to
                # the underlying client untouched.
                async for msg in prompt:
                    captured.append(msg)
                    yield msg

            # Replace the prompt with our capturing wrapper
            if args:
                args = (capturing_wrapper(),) + args[1:]
            else:
                kwargs["prompt"] = capturing_wrapper()
        else:
            # Fallback for any other prompt type: keep its string form.
            self.__last_prompt = str(prompt)

    return await self.__client.query(*args, **kwargs)
207
228
 
@@ -215,11 +236,16 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
215
236
  """
216
237
  generator = self.__client.receive_response()
217
238
 
239
+ # Determine the initial input - may be updated later if using async generator
240
+ initial_input = self.__last_prompt if self.__last_prompt else None
241
+
218
242
  with start_span(
219
243
  name="Claude Agent",
220
244
  span_attributes={"type": SpanTypeAttribute.TASK},
221
- input=self.__last_prompt if self.__last_prompt else None,
245
+ input=initial_input,
222
246
  ) as span:
247
+ # If we're capturing async messages, we'll update input after they're consumed
248
+ input_needs_update = self.__captured_messages is not None
223
249
  # Store the parent span export in thread-local storage for tool handlers
224
250
  _thread_local.parent_span_export = span.export()
225
251
 
@@ -228,6 +254,13 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
228
254
 
229
255
  try:
230
256
  async for message in generator:
257
+ # Update input from captured async messages (once, after they're consumed)
258
+ if input_needs_update and self.__captured_messages:
259
+ captured_input = _format_captured_messages(self.__captured_messages)
260
+ if captured_input:
261
+ span.log(input=captured_input)
262
+ input_needs_update = False
263
+
231
264
  message_type = type(message).__name__
232
265
 
233
266
  if message_type == "AssistantMessage":
@@ -390,3 +423,12 @@ def _build_llm_input(prompt: Any, conversation_history: list[dict[str, Any]]) ->
390
423
  return [{"content": prompt, "role": "user"}] + conversation_history
391
424
 
392
425
  return conversation_history if conversation_history else None
426
+
427
+
428
+ def _format_captured_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
429
+ """Formats captured async generator messages into structured input.
430
+
431
+ Returns the messages as-is to preserve structure for tracing.
432
+ Empty list returns empty list.
433
+ """
434
+ return messages if messages else []
@@ -177,3 +177,109 @@ async def test_calculator_with_multiple_operations(memory_logger):
177
177
  if span["span_id"] != root_span_id:
178
178
  assert span["root_span_id"] == root_span_id
179
179
  assert root_span_id in span["span_parents"]
180
+
181
+
182
+ def _make_message(content: str) -> dict:
183
+ """Create a streaming format message dict."""
184
+ return {"type": "user", "message": {"role": "user", "content": content}}
185
+
186
+
187
+ def _assert_structured_input(task_span: dict, expected_contents: list[str]) -> None:
188
+ """Assert that task span input is a structured list with expected content."""
189
+ inp = task_span.get("input")
190
+ assert isinstance(inp, list), f"Expected list input, got {type(inp).__name__}: {inp}"
191
+ assert [x["message"]["content"] for x in inp] == expected_contents
192
+
193
+
194
class CustomAsyncIterable:
    """Custom AsyncIterable class (not a generator) for testing."""

    def __init__(self, messages: list[dict]):
        self._messages = messages

    def __aiter__(self):
        # Each iteration gets its own fresh iterator over the same messages.
        return CustomAsyncIterator(self._messages)


class CustomAsyncIterator:
    """Iterator for CustomAsyncIterable."""

    def __init__(self, messages: list[dict]):
        self._messages = messages
        self._index = 0

    async def __anext__(self):
        try:
            msg = self._messages[self._index]
        except IndexError:
            # Exhausted: signal the end of async iteration.
            raise StopAsyncIteration from None
        self._index += 1
        return msg
217
+
218
+
219
@pytest.mark.skipif(not CLAUDE_SDK_AVAILABLE, reason="Claude Agent SDK not installed")
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "input_factory,expected_contents",
    [
        # Factories are lambdas so every parametrized run builds a fresh,
        # unconsumed async iterable.
        pytest.param(
            lambda: (msg async for msg in _single_message_generator()),
            ["What is 2 + 2?"],
            id="asyncgen_single",
        ),
        pytest.param(
            lambda: (msg async for msg in _multi_message_generator()),
            ["Part 1", "Part 2"],
            id="asyncgen_multi",
        ),
        pytest.param(
            lambda: CustomAsyncIterable([_make_message("Custom 1"), _make_message("Custom 2")]),
            ["Custom 1", "Custom 2"],
            id="custom_async_iterable",
        ),
    ],
)
async def test_query_async_iterable(memory_logger, input_factory, expected_contents):
    """Test that async iterable inputs are captured as structured lists.

    Verifies that passing AsyncIterable[dict] to query() results in the span
    input showing the structured message list, not a flattened string or repr.
    """
    # Sanity check: no spans buffered from earlier tests.
    assert not memory_logger.pop()

    # Monkeypatch the SDK client with the tracing wrapper; restored in finally.
    original_client = claude_agent_sdk.ClaudeSDKClient
    claude_agent_sdk.ClaudeSDKClient = _create_client_wrapper_class(original_client)

    try:
        options = claude_agent_sdk.ClaudeAgentOptions(model=TEST_MODEL)

        async with claude_agent_sdk.ClaudeSDKClient(options=options) as client:
            await client.query(input_factory())
            # Drain the response stream; ResultMessage marks the end.
            async for message in client.receive_response():
                if type(message).__name__ == "ResultMessage":
                    break

        spans = memory_logger.pop()

        task_spans = [s for s in spans if s["span_attributes"]["type"] == SpanTypeAttribute.TASK]
        assert len(task_spans) >= 1, f"Should have at least one task span, got {len(task_spans)}"

        # Prefer the span named "Claude Agent"; fall back to the first task span.
        task_span = next(
            (s for s in task_spans if s["span_attributes"]["name"] == "Claude Agent"),
            task_spans[0],
        )

        _assert_structured_input(task_span, expected_contents)

    finally:
        claude_agent_sdk.ClaudeSDKClient = original_client
275
+
276
+
277
async def _single_message_generator():
    """Generator yielding a single message."""
    # Streaming-format user message (the _make_message helper inlined).
    yield {"type": "user", "message": {"role": "user", "content": "What is 2 + 2?"}}
280
+
281
+
282
async def _multi_message_generator():
    """Generator yielding multiple messages."""
    # Streaming-format user messages (the _make_message helper inlined).
    for content in ("Part 1", "Part 2"):
        yield {"type": "user", "message": {"role": "user", "content": content}}