braintrust 0.4.3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. braintrust/__init__.py +3 -0
  2. braintrust/_generated_types.py +106 -6
  3. braintrust/auto.py +179 -0
  4. braintrust/conftest.py +23 -4
  5. braintrust/framework.py +113 -3
  6. braintrust/functions/invoke.py +3 -1
  7. braintrust/functions/test_invoke.py +61 -0
  8. braintrust/generated_types.py +7 -1
  9. braintrust/logger.py +127 -45
  10. braintrust/oai.py +51 -0
  11. braintrust/span_cache.py +337 -0
  12. braintrust/span_identifier_v3.py +21 -0
  13. braintrust/test_bt_json.py +0 -5
  14. braintrust/test_framework.py +37 -0
  15. braintrust/test_http.py +444 -0
  16. braintrust/test_logger.py +295 -5
  17. braintrust/test_span_cache.py +344 -0
  18. braintrust/test_trace.py +267 -0
  19. braintrust/test_util.py +58 -1
  20. braintrust/trace.py +385 -0
  21. braintrust/util.py +20 -0
  22. braintrust/version.py +2 -2
  23. braintrust/wrappers/agno/__init__.py +2 -3
  24. braintrust/wrappers/anthropic.py +64 -0
  25. braintrust/wrappers/claude_agent_sdk/__init__.py +2 -3
  26. braintrust/wrappers/claude_agent_sdk/_wrapper.py +48 -6
  27. braintrust/wrappers/claude_agent_sdk/test_wrapper.py +115 -0
  28. braintrust/wrappers/dspy.py +52 -1
  29. braintrust/wrappers/google_genai/__init__.py +9 -6
  30. braintrust/wrappers/litellm.py +6 -43
  31. braintrust/wrappers/pydantic_ai.py +2 -3
  32. braintrust/wrappers/test_agno.py +9 -0
  33. braintrust/wrappers/test_anthropic.py +156 -0
  34. braintrust/wrappers/test_dspy.py +117 -0
  35. braintrust/wrappers/test_google_genai.py +9 -0
  36. braintrust/wrappers/test_litellm.py +57 -55
  37. braintrust/wrappers/test_openai.py +253 -1
  38. braintrust/wrappers/test_pydantic_ai_integration.py +9 -0
  39. braintrust/wrappers/test_utils.py +79 -0
  40. {braintrust-0.4.3.dist-info → braintrust-0.5.2.dist-info}/METADATA +1 -1
  41. {braintrust-0.4.3.dist-info → braintrust-0.5.2.dist-info}/RECORD +44 -37
  42. {braintrust-0.4.3.dist-info → braintrust-0.5.2.dist-info}/WHEEL +1 -1
  43. {braintrust-0.4.3.dist-info → braintrust-0.5.2.dist-info}/entry_points.txt +0 -0
  44. {braintrust-0.4.3.dist-info → braintrust-0.5.2.dist-info}/top_level.txt +0 -0
braintrust/trace.py ADDED
@@ -0,0 +1,385 @@
1
+ """
2
+ Trace objects for accessing spans in evaluations.
3
+
4
+ This module provides the LocalTrace class which allows scorers to access
5
+ spans from the current evaluation task without making server round-trips.
6
+ """
7
+
8
+ import asyncio
9
+ from typing import Any, Awaitable, Callable, Optional, Protocol
10
+
11
+ from braintrust.logger import BraintrustState, ObjectFetcher
12
+
13
+
14
class SpanData:
    """Span data returned by get_spans()."""

    def __init__(
        self,
        input: Optional[Any] = None,
        output: Optional[Any] = None,
        metadata: Optional[dict[str, Any]] = None,
        span_id: Optional[str] = None,
        span_parents: Optional[list[str]] = None,
        span_attributes: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ):
        self.input = input
        self.output = output
        self.metadata = metadata
        self.span_id = span_id
        self.span_parents = span_parents
        self.span_attributes = span_attributes
        # Any extra fields (e.g. id, _xact_id) become plain attributes too.
        for extra_name, extra_value in kwargs.items():
            setattr(self, extra_name, extra_value)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SpanData":
        """Create SpanData from a dictionary."""
        return cls(**data)

    def to_dict(self) -> dict[str, Any]:
        """Convert to a dictionary, dropping None-valued attributes."""
        return {name: value for name, value in self.__dict__.items() if value is not None}
49
+
50
+
51
class SpanFetcher(ObjectFetcher[dict[str, Any]]):
    """
    Fetcher for spans by root_span_id, built on the ObjectFetcher pattern.
    Pagination is handled automatically via cursor-based iteration.
    """

    def __init__(
        self,
        object_type: str,  # Literal["experiment", "project_logs", "playground_logs"]
        object_id: str,
        root_span_id: str,
        state: BraintrustState,
        span_type_filter: Optional[list[str]] = None,
    ):
        # The BTQL filter restricts rows to this root span (and, optionally,
        # a set of span types) before ObjectFetcher paginates through them.
        super().__init__(
            object_type=object_type,
            _internal_btql={"filter": self._build_filter(root_span_id, span_type_filter)},
        )
        self._object_id = object_id
        self._state = state

    @staticmethod
    def _build_filter(root_span_id: str, span_type_filter: Optional[list[str]] = None) -> dict[str, Any]:
        """Build BTQL filter expression."""
        # Base clause: root_span_id = '<value>'
        match_root_span = {
            "op": "eq",
            "left": {"op": "ident", "name": ["root_span_id"]},
            "right": {"op": "literal", "value": root_span_id},
        }
        # Exclude rows where span_attributes.purpose = 'scorer'
        exclude_scorers = {
            "op": "or",
            "children": [
                {
                    "op": "isnull",
                    "expr": {"op": "ident", "name": ["span_attributes", "purpose"]},
                },
                {
                    "op": "ne",
                    "left": {"op": "ident", "name": ["span_attributes", "purpose"]},
                    "right": {"op": "literal", "value": "scorer"},
                },
            ],
        }
        clauses: list[dict[str, Any]] = [match_root_span, exclude_scorers]

        # Narrow to the requested span types, when any were given.
        if span_type_filter:
            clauses.append(
                {
                    "op": "in",
                    "left": {"op": "ident", "name": ["span_attributes", "type"]},
                    "right": {"op": "literal", "value": span_type_filter},
                }
            )

        return {"op": "and", "children": clauses}

    @property
    def id(self) -> str:
        return self._object_id

    def _get_state(self) -> BraintrustState:
        return self._state
120
+
121
+
122
+ SpanFetchFn = Callable[[Optional[list[str]]], Awaitable[list[SpanData]]]
123
+
124
+
125
class CachedSpanFetcher:
    """
    Cached span fetcher that handles fetching and caching spans by type.

    Caching strategy:
    - Cache spans by span type (dict[spanType, list[SpanData]])
    - Track if all spans have been fetched (all_fetched flag)
    - When filtering by spanType, only fetch types not already in cache
    """

    def __init__(
        self,
        object_type: Optional[str] = None,  # Literal["experiment", "project_logs", "playground_logs"]
        object_id: Optional[str] = None,
        root_span_id: Optional[str] = None,
        get_state: Optional[Callable[[], Awaitable[BraintrustState]]] = None,
        fetch_fn: Optional[SpanFetchFn] = None,
    ):
        """
        Args:
            object_type: Parent object type ("experiment", "project_logs", "playground_logs").
            object_id: ID of the parent object.
            root_span_id: Root span whose rows should be fetched.
            get_state: Async callable returning a logged-in BraintrustState.
            fetch_fn: Direct fetch function (for testing); when provided, the
                other arguments are ignored.

        Raises:
            ValueError: If neither fetch_fn nor the full set of
                object_type/object_id/root_span_id/get_state is provided.
        """
        self._span_cache: dict[str, list[SpanData]] = {}
        self._all_fetched = False

        if fetch_fn is not None:
            # Direct fetch function injection (for testing)
            self._fetch_fn = fetch_fn
        else:
            # Standard constructor with SpanFetcher
            if object_type is None or object_id is None or root_span_id is None or get_state is None:
                raise ValueError("Must provide either fetch_fn or all of object_type, object_id, root_span_id, get_state")

            async def _fetch_fn(span_type: Optional[list[str]]) -> list[SpanData]:
                state = await get_state()
                fetcher = SpanFetcher(
                    object_type=object_type,
                    object_id=object_id,
                    root_span_id=root_span_id,
                    state=state,
                    span_type_filter=span_type,
                )
                rows = list(fetcher.fetch())
                # Filter out scorer spans (defense in depth; the BTQL filter
                # built by SpanFetcher already excludes them server-side).
                filtered = [
                    row
                    for row in rows
                    if not (
                        isinstance(row.get("span_attributes"), dict)
                        and row.get("span_attributes", {}).get("purpose") == "scorer"
                    )
                ]
                return [
                    SpanData(
                        input=row.get("input"),
                        output=row.get("output"),
                        metadata=row.get("metadata"),
                        span_id=row.get("span_id"),
                        span_parents=row.get("span_parents"),
                        span_attributes=row.get("span_attributes"),
                        id=row.get("id"),
                        _xact_id=row.get("_xact_id"),
                        _pagination_key=row.get("_pagination_key"),
                        root_span_id=row.get("root_span_id"),
                    )
                    for row in filtered
                ]

            self._fetch_fn = _fetch_fn

    async def get_spans(self, span_type: Optional[list[str]] = None) -> list[SpanData]:
        """
        Get spans, using cache when possible.

        Args:
            span_type: Optional list of span types to filter by

        Returns:
            List of matching spans
        """
        # If we've fetched all spans, just filter from cache
        if self._all_fetched:
            return self._get_from_cache(span_type)

        # If no filter requested, fetch everything
        if not span_type:
            await self._fetch_spans(None)
            self._all_fetched = True
            return self._get_from_cache(None)

        # Find which span types we don't have in cache yet
        missing_types = [t for t in span_type if t not in self._span_cache]

        # If all requested types are cached, return from cache
        if not missing_types:
            return self._get_from_cache(span_type)

        # Fetch only the missing types
        await self._fetch_spans(missing_types)
        return self._get_from_cache(span_type)

    async def _fetch_spans(self, span_type: Optional[list[str]]) -> None:
        """Fetch spans from the server and fold them into the type-keyed cache."""
        spans = await self._fetch_fn(span_type)

        # BUGFIX: mark every explicitly requested type as cached, even if the
        # server returned no spans of that type. Previously a type with zero
        # spans never got a cache entry, so every subsequent get_spans() call
        # for it triggered another server fetch.
        for requested_type in span_type or []:
            self._span_cache.setdefault(requested_type, [])

        for span in spans:
            span_attrs = span.span_attributes or {}
            span_type_str = span_attrs.get("type", "")
            self._span_cache.setdefault(span_type_str, []).append(span)

    def _get_from_cache(self, span_type: Optional[list[str]]) -> list[SpanData]:
        """Get spans from cache, optionally filtering by type."""
        if not span_type:
            # Return all cached spans, across every type bucket.
            result: list[SpanData] = []
            for spans in self._span_cache.values():
                result.extend(spans)
            return result

        # Return only requested types
        result = []
        for type_str in span_type:
            if type_str in self._span_cache:
                result.extend(self._span_cache[type_str])
        return result
248
+
249
+
250
class Trace(Protocol):
    """
    Interface for trace objects that can be used by scorers.
    Both the SDK's LocalTrace class and the API wrapper's WrapperTrace implement this.

    This is a typing.Protocol, so conforming classes are matched structurally
    and do not need to subclass Trace explicitly.
    """

    def get_configuration(self) -> dict[str, str]:
        """Get the trace configuration (object_type, object_id, root_span_id)."""
        ...

    async def get_spans(self, span_type: Optional[list[str]] = None) -> list[SpanData]:
        """
        Fetch all spans for this root span.

        Args:
            span_type: Optional list of span types to filter by

        Returns:
            List of matching spans
        """
        ...
271
+
272
+
273
class LocalTrace(dict):
    """
    SDK implementation of Trace that uses local span cache and falls back to BTQL.
    Carries identifying information about the evaluation so scorers can perform
    richer logging or side effects.

    Inherits from dict so that it serializes to {"trace_ref": {...}} when passed
    to json.dumps(). This allows LocalTrace to be transparently serialized when
    passed through invoke() or other JSON-serializing code paths.
    """

    def __init__(
        self,
        object_type: str,  # Literal["experiment", "project_logs", "playground_logs"]
        object_id: str,
        root_span_id: str,
        ensure_spans_flushed: Optional[Callable[[], Awaitable[None]]],
        state: BraintrustState,
    ):
        # Initialize dict with trace_ref for JSON serialization
        super().__init__(
            {
                "trace_ref": {
                    "object_type": object_type,
                    "object_id": object_id,
                    "root_span_id": root_span_id,
                }
            }
        )

        self._object_type = object_type
        self._object_id = object_id
        self._root_span_id = root_span_id
        self._ensure_spans_flushed = ensure_spans_flushed
        self._state = state
        self._spans_flushed = False
        self._spans_flush_promise: Optional[asyncio.Task[None]] = None

        async def get_state() -> BraintrustState:
            # Make sure buffered spans have reached the server before querying.
            await self._ensure_spans_ready()
            # state.login() is blocking, so run it in the default executor.
            # FIX: use get_running_loop() — get_event_loop() is deprecated when
            # called from inside a coroutine; a loop is guaranteed running here.
            await asyncio.get_running_loop().run_in_executor(None, state.login)
            return state

        self._cached_fetcher = CachedSpanFetcher(
            object_type=object_type,
            object_id=object_id,
            root_span_id=root_span_id,
            get_state=get_state,
        )

    def get_configuration(self) -> dict[str, str]:
        """Get the trace configuration."""
        return {
            "object_type": self._object_type,
            "object_id": self._object_id,
            "root_span_id": self._root_span_id,
        }

    async def get_spans(self, span_type: Optional[list[str]] = None) -> list[SpanData]:
        """
        Fetch all rows for this root span from its parent object (experiment or project logs).
        First checks the local span cache for recently logged spans, then falls
        back to CachedSpanFetcher which handles BTQL fetching and caching.

        Args:
            span_type: Optional list of span types to filter by

        Returns:
            List of matching spans
        """
        # Try local span cache first (for recently logged spans not yet flushed)
        cached_spans = self._state.span_cache.get_by_root_span_id(self._root_span_id)
        if cached_spans:
            # Exclude scorer spans so a scorer never sees scorer output.
            spans = [span for span in cached_spans if (span.span_attributes or {}).get("purpose") != "scorer"]

            # Filter by span type if requested
            if span_type:
                spans = [span for span in spans if (span.span_attributes or {}).get("type", "") in span_type]

            # Convert to SpanData
            return [
                SpanData(
                    input=span.input,
                    output=span.output,
                    metadata=span.metadata,
                    span_id=span.span_id,
                    span_parents=span.span_parents,
                    span_attributes=span.span_attributes,
                )
                for span in spans
            ]

        # Fall back to CachedSpanFetcher for BTQL fetching with caching
        return await self._cached_fetcher.get_spans(span_type)

    async def _ensure_spans_ready(self) -> None:
        """Ensure spans are flushed (at most one flush in flight) before fetching."""
        if self._spans_flushed or not self._ensure_spans_flushed:
            return

        if self._spans_flush_promise is None:

            async def flush_and_mark():
                try:
                    await self._ensure_spans_flushed()
                    self._spans_flushed = True
                except Exception:
                    # Reset so a later call can retry the flush; bare `raise`
                    # preserves the original traceback.
                    self._spans_flush_promise = None
                    raise

            self._spans_flush_promise = asyncio.create_task(flush_and_mark())

        # Concurrent callers all await the same in-flight flush task.
        await self._spans_flush_promise
braintrust/util.py CHANGED
@@ -1,5 +1,7 @@
1
1
  import inspect
2
2
  import json
3
+ import math
4
+ import os
3
5
  import sys
4
6
  import threading
5
7
  import urllib.parse
@@ -9,6 +11,24 @@ from typing import Any, Generic, Literal, TypedDict, TypeVar, Union
9
11
 
10
12
  from requests import HTTPError, Response
11
13
 
14
+
15
def parse_env_var_float(name: str, default: float) -> float:
    """Parse a float from an environment variable, returning default if invalid.

    Returns the default value if the env var is missing, empty, not a valid
    float, NaN, or infinity.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        parsed = float(raw)
    except (ValueError, TypeError):
        # Empty string or non-numeric text falls back to the default.
        return default
    # Reject NaN and +/-infinity; only finite values are usable.
    return parsed if math.isfinite(parsed) else default
31
+
12
32
  GLOBAL_PROJECT = "Global"
13
33
  BT_IS_ASYNC_ATTRIBUTE = "_BT_IS_ASYNC"
14
34
 
braintrust/version.py CHANGED
@@ -1,4 +1,4 @@
1
- VERSION = "0.4.3"
1
+ VERSION = "0.5.2"
2
2
 
3
3
  # this will be templated during the build
4
- GIT_COMMIT = "d734e8ffc272ee65fe0588df00fd9390614ccd2e"
4
+ GIT_COMMIT = "25868bc58450dad2058b6499ce3bb9400330fbd1"
@@ -62,7 +62,6 @@ def setup_agno(
62
62
  models.base.Model = wrap_model(models.base.Model) # pyright: ignore[reportUnknownMemberType]
63
63
  tools.function.FunctionCall = wrap_function_call(tools.function.FunctionCall) # pyright: ignore[reportUnknownMemberType]
64
64
  return True
65
- except ImportError as e:
66
- logger.error(f"Failed to import Agno: {e}")
67
- logger.error("Agno is not installed. Please install it with: pip install agno")
65
+ except ImportError:
66
+ # Not installed - this is expected when using auto_instrument()
68
67
  return False
@@ -5,6 +5,7 @@ from contextlib import contextmanager
5
5
 
6
6
  from braintrust.logger import NOOP_SPAN, log_exc_info_to_span, start_span
7
7
  from braintrust.wrappers._anthropic_utils import Wrapper, extract_anthropic_usage, finalize_anthropic_tokens
8
+ from wrapt import wrap_function_wrapper
8
9
 
9
10
  log = logging.getLogger(__name__)
10
11
 
@@ -358,3 +359,66 @@ def wrap_anthropic(client):
358
359
 
359
360
  def wrap_anthropic_client(client):
360
361
  return wrap_anthropic(client)
362
+
363
+
364
def _apply_anthropic_wrapper(client):
    """Apply tracing wrapper to an Anthropic client instance in-place."""
    traced = wrap_anthropic(client)
    # Graft the traced sub-APIs onto the original client object.
    client.messages = traced.messages
    if hasattr(traced, "beta"):
        client.beta = traced.beta
370
+
371
+
372
def _apply_async_anthropic_wrapper(client):
    """Apply tracing wrapper to an AsyncAnthropic client instance in-place.

    The mechanics are identical to the sync case — wrap_anthropic() is used
    for both client flavors — so delegate to the shared helper instead of
    duplicating its body.
    """
    _apply_anthropic_wrapper(client)
378
+
379
+
380
def _anthropic_init_wrapper(wrapped, instance, args, kwargs):
    """Wrapper for Anthropic.__init__ that applies tracing after initialization.

    Signature follows the wrapt wrapper convention used with
    wrap_function_wrapper: `wrapped` is the original __init__, `instance` is
    the client being constructed, and args/kwargs are the caller's arguments.
    """
    # Let the real __init__ finish so the client is fully constructed ...
    wrapped(*args, **kwargs)
    # ... then swap its sub-APIs for traced versions in place.
    _apply_anthropic_wrapper(instance)
384
+
385
+
386
def _async_anthropic_init_wrapper(wrapped, instance, args, kwargs):
    """Wrapper for AsyncAnthropic.__init__ that applies tracing after initialization.

    Same wrapt calling convention as _anthropic_init_wrapper: `wrapped` is the
    original __init__ and `instance` is the client under construction.
    """
    # Construct the client first, then apply the tracing wrapper in place.
    wrapped(*args, **kwargs)
    _apply_async_anthropic_wrapper(instance)
390
+
391
+
392
def patch_anthropic() -> bool:
    """
    Patch Anthropic to add Braintrust tracing globally.

    After calling this, all new Anthropic() and AsyncAnthropic() clients
    will automatically have tracing enabled.

    Returns:
        True if Anthropic was patched (or already patched), False if Anthropic is not installed.

    Example:
        ```python
        import braintrust
        braintrust.patch_anthropic()

        import anthropic
        client = anthropic.Anthropic()
        # All calls are now traced!
        ```
    """
    try:
        import anthropic

        # Idempotency guard: patching twice would wrap __init__ twice and
        # double-trace every call.
        if getattr(anthropic, "__braintrust_wrapped__", False):
            return True  # Already patched

        # Patch __init__ on the classes themselves (via wrapt) so every
        # future instance — sync or async — is traced at construction time.
        wrap_function_wrapper("anthropic", "Anthropic.__init__", _anthropic_init_wrapper)
        wrap_function_wrapper("anthropic", "AsyncAnthropic.__init__", _async_anthropic_init_wrapper)
        # Flag checked by the guard above on subsequent calls.
        anthropic.__braintrust_wrapped__ = True
        return True

    except ImportError:
        return False
@@ -105,7 +105,6 @@ def setup_claude_agent_sdk(
105
105
  setattr(module, "tool", wrapped_tool_fn)
106
106
 
107
107
  return True
108
- except ImportError as e:
109
- logger.error(f"Failed to import Claude Agent SDK: {e}")
110
- logger.error("claude-agent-sdk is not installed. Please install it with: pip install claude-agent-sdk")
108
+ except ImportError:
109
+ # Not installed - this is expected when using auto_instrument()
111
110
  return False
@@ -2,7 +2,7 @@ import dataclasses
2
2
  import logging
3
3
  import threading
4
4
  import time
5
- from collections.abc import AsyncGenerator, Callable
5
+ from collections.abc import AsyncGenerator, AsyncIterable, Callable
6
6
  from typing import Any
7
7
 
8
8
  from braintrust.logger import start_span
@@ -191,17 +191,38 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
191
191
  self.__client = client
192
192
  self.__last_prompt: str | None = None
193
193
  self.__query_start_time: float | None = None
194
+ self.__captured_messages: list[dict[str, Any]] | None = None
194
195
 
195
196
  async def query(self, *args: Any, **kwargs: Any) -> Any:
196
197
  """Wrap query to capture the prompt and start time for tracing."""
197
198
  # Capture the time when query is called (when LLM call starts)
198
199
  self.__query_start_time = time.time()
200
+ self.__captured_messages = None
199
201
 
200
202
  # Capture the prompt for use in receive_response
201
- if args:
202
- self.__last_prompt = str(args[0])
203
- elif "prompt" in kwargs:
204
- self.__last_prompt = str(kwargs["prompt"])
203
+ prompt = args[0] if args else kwargs.get("prompt")
204
+
205
+ if prompt is not None:
206
+ if isinstance(prompt, str):
207
+ self.__last_prompt = prompt
208
+ elif isinstance(prompt, AsyncIterable):
209
+ # AsyncIterable[dict] - wrap it to capture messages as they're yielded
210
+ captured: list[dict[str, Any]] = []
211
+ self.__captured_messages = captured
212
+ self.__last_prompt = None # Will be set after messages are captured
213
+
214
+ async def capturing_wrapper() -> AsyncGenerator[dict[str, Any], None]:
215
+ async for msg in prompt:
216
+ captured.append(msg)
217
+ yield msg
218
+
219
+ # Replace the prompt with our capturing wrapper
220
+ if args:
221
+ args = (capturing_wrapper(),) + args[1:]
222
+ else:
223
+ kwargs["prompt"] = capturing_wrapper()
224
+ else:
225
+ self.__last_prompt = str(prompt)
205
226
 
206
227
  return await self.__client.query(*args, **kwargs)
207
228
 
@@ -215,11 +236,16 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
215
236
  """
216
237
  generator = self.__client.receive_response()
217
238
 
239
+ # Determine the initial input - may be updated later if using async generator
240
+ initial_input = self.__last_prompt if self.__last_prompt else None
241
+
218
242
  with start_span(
219
243
  name="Claude Agent",
220
244
  span_attributes={"type": SpanTypeAttribute.TASK},
221
- input=self.__last_prompt if self.__last_prompt else None,
245
+ input=initial_input,
222
246
  ) as span:
247
+ # If we're capturing async messages, we'll update input after they're consumed
248
+ input_needs_update = self.__captured_messages is not None
223
249
  # Store the parent span export in thread-local storage for tool handlers
224
250
  _thread_local.parent_span_export = span.export()
225
251
 
@@ -228,6 +254,13 @@ def _create_client_wrapper_class(original_client_class: Any) -> Any:
228
254
 
229
255
  try:
230
256
  async for message in generator:
257
+ # Update input from captured async messages (once, after they're consumed)
258
+ if input_needs_update and self.__captured_messages:
259
+ captured_input = _format_captured_messages(self.__captured_messages)
260
+ if captured_input:
261
+ span.log(input=captured_input)
262
+ input_needs_update = False
263
+
231
264
  message_type = type(message).__name__
232
265
 
233
266
  if message_type == "AssistantMessage":
@@ -390,3 +423,12 @@ def _build_llm_input(prompt: Any, conversation_history: list[dict[str, Any]]) ->
390
423
  return [{"content": prompt, "role": "user"}] + conversation_history
391
424
 
392
425
  return conversation_history if conversation_history else None
426
+
427
+
428
+ def _format_captured_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
429
+ """Formats captured async generator messages into structured input.
430
+
431
+ Returns the messages as-is to preserve structure for tracing.
432
+ Empty list returns empty list.
433
+ """
434
+ return messages if messages else []