anthropic 0.72.1__py3-none-any.whl → 0.73.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. anthropic/__init__.py +2 -0
  2. anthropic/_compat.py +6 -0
  3. anthropic/_models.py +13 -1
  4. anthropic/_utils/_transform.py +1 -1
  5. anthropic/_version.py +1 -1
  6. anthropic/lib/_parse/_response.py +44 -0
  7. anthropic/lib/_parse/_transform.py +167 -0
  8. anthropic/lib/streaming/__init__.py +14 -4
  9. anthropic/lib/streaming/_beta_messages.py +82 -43
  10. anthropic/lib/streaming/_beta_types.py +21 -13
  11. anthropic/lib/tools/_beta_runner.py +102 -101
  12. anthropic/resources/beta/messages/batches.py +12 -12
  13. anthropic/resources/beta/messages/messages.py +365 -29
  14. anthropic/resources/messages/batches.py +12 -12
  15. anthropic/resources/messages/messages.py +14 -8
  16. anthropic/types/beta/__init__.py +1 -0
  17. anthropic/types/beta/beta_code_execution_tool_20250522_param.py +2 -0
  18. anthropic/types/beta/beta_code_execution_tool_20250825_param.py +2 -0
  19. anthropic/types/beta/beta_json_output_format_param.py +15 -0
  20. anthropic/types/beta/beta_memory_tool_20250818_param.py +2 -0
  21. anthropic/types/beta/beta_tool_bash_20241022_param.py +2 -0
  22. anthropic/types/beta/beta_tool_bash_20250124_param.py +2 -0
  23. anthropic/types/beta/beta_tool_computer_use_20241022_param.py +2 -0
  24. anthropic/types/beta/beta_tool_computer_use_20250124_param.py +2 -0
  25. anthropic/types/beta/beta_tool_param.py +2 -0
  26. anthropic/types/beta/beta_tool_text_editor_20241022_param.py +2 -0
  27. anthropic/types/beta/beta_tool_text_editor_20250124_param.py +2 -0
  28. anthropic/types/beta/beta_tool_text_editor_20250429_param.py +2 -0
  29. anthropic/types/beta/beta_tool_text_editor_20250728_param.py +2 -0
  30. anthropic/types/beta/beta_web_fetch_tool_20250910_param.py +2 -0
  31. anthropic/types/beta/beta_web_search_tool_20250305_param.py +2 -0
  32. anthropic/types/beta/message_count_tokens_params.py +4 -0
  33. anthropic/types/beta/message_create_params.py +24 -2
  34. anthropic/types/beta/messages/batch_create_params.py +8 -2
  35. anthropic/types/beta/parsed_beta_message.py +68 -0
  36. anthropic/types/messages/batch_create_params.py +0 -1
  37. {anthropic-0.72.1.dist-info → anthropic-0.73.0.dist-info}/METADATA +1 -1
  38. {anthropic-0.72.1.dist-info → anthropic-0.73.0.dist-info}/RECORD +40 -36
  39. {anthropic-0.72.1.dist-info → anthropic-0.73.0.dist-info}/WHEEL +0 -0
  40. {anthropic-0.72.1.dist-info → anthropic-0.73.0.dist-info}/licenses/LICENSE +0 -0
anthropic/__init__.py CHANGED
@@ -44,6 +44,7 @@ from ._exceptions import (
44
44
  )
45
45
  from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient
46
46
  from ._utils._logs import setup_logging as _setup_logging
47
+ from .lib._parse._transform import transform_schema
47
48
 
48
49
  __all__ = [
49
50
  "types",
@@ -91,6 +92,7 @@ __all__ = [
91
92
  "AI_PROMPT",
92
93
  "beta_tool",
93
94
  "beta_async_tool",
95
+ "transform_schema",
94
96
  ]
95
97
 
96
98
  if not _t.TYPE_CHECKING:
anthropic/_compat.py CHANGED
@@ -131,6 +131,12 @@ def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:
131
131
  return model.model_dump_json(indent=indent)
132
132
 
133
133
 
134
+ def model_parse_json(model: type[_ModelT], data: str | bytes) -> _ModelT:
135
+ if PYDANTIC_V1:
136
+ return model.parse_raw(data) # pyright: ignore[reportDeprecated]
137
+ return model.model_validate_json(data)
138
+
139
+
134
140
  def model_dump(
135
141
  model: pydantic.BaseModel,
136
142
  *,
anthropic/_models.py CHANGED
@@ -774,7 +774,7 @@ else:
774
774
 
775
775
 
776
776
  if not PYDANTIC_V1:
777
- from pydantic import TypeAdapter as _TypeAdapter
777
+ from pydantic import TypeAdapter as _TypeAdapter, computed_field as computed_field
778
778
 
779
779
  _CachedTypeAdapter = cast("TypeAdapter[object]", lru_cache(maxsize=None)(_TypeAdapter))
780
780
 
@@ -811,6 +811,18 @@ elif not TYPE_CHECKING: # TODO: condition is weird
811
811
  def TypeAdapter(*_args: Any, **_kwargs: Any) -> Any:
812
812
  raise RuntimeError("attempted to use TypeAdapter in pydantic v1")
813
813
 
814
+ def computed_field(func: Any | None = None, /, **__: Any) -> Any:
815
+ def _exc_func(*_: Any, **__: Any) -> Any:
816
+ raise RuntimeError("attempted to use computed_field in pydantic v1")
817
+
818
+ def _dec(*_: Any, **__: Any) -> Any:
819
+ return _exc_func
820
+
821
+ if func is not None:
822
+ return _dec(func)
823
+ else:
824
+ return _dec
825
+
814
826
 
815
827
  class FinalRequestOptionsInput(TypedDict, total=False):
816
828
  method: Required[str]
anthropic/_utils/_transform.py CHANGED
@@ -218,7 +218,7 @@ def _transform_recursive(
218
218
  return data
219
219
 
220
220
  if isinstance(data, pydantic.BaseModel):
221
- return model_dump(data, exclude_unset=True, mode="json")
221
+ return model_dump(data, exclude_unset=True, mode="json", exclude=getattr(data, "__api_exclude__", None))
222
222
 
223
223
  annotated_type = _get_annotated_type(annotation)
224
224
  if annotated_type is None:
anthropic/_version.py CHANGED
@@ -1,4 +1,4 @@
1
1
  # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
2
 
3
3
  __title__ = "anthropic"
4
- __version__ = "0.72.1" # x-release-please-version
4
+ __version__ = "0.73.0" # x-release-please-version
anthropic/lib/_parse/_response.py ADDED
@@ -0,0 +1,44 @@
1
+ from __future__ import annotations
2
+
3
+ from typing_extensions import TypeVar
4
+
5
+ from ..._types import NotGiven
6
+ from ..._models import TypeAdapter, construct_type_unchecked
7
+ from ..._utils._utils import is_given
8
+ from ...types.beta.beta_message import BetaMessage
9
+ from ...types.beta.parsed_beta_message import ParsedBetaMessage, ParsedBetaTextBlock, ParsedBetaContentBlock
10
+
11
+ ResponseFormatT = TypeVar("ResponseFormatT", default=None)
12
+
13
+
14
+ def parse_text(text: str, output_format: ResponseFormatT | NotGiven) -> ResponseFormatT | None:
15
+ if is_given(output_format):
16
+ adapted_type: TypeAdapter[ResponseFormatT] = TypeAdapter(output_format)
17
+ return adapted_type.validate_json(text)
18
+ return None
19
+
20
+
21
+ def parse_response(
22
+ *,
23
+ output_format: ResponseFormatT | NotGiven,
24
+ response: BetaMessage,
25
+ ) -> ParsedBetaMessage[ResponseFormatT]:
26
+ content_list: list[ParsedBetaContentBlock[ResponseFormatT]] = []
27
+ for content in response.content:
28
+ if content.type == "text":
29
+ content_list.append(
30
+ construct_type_unchecked(
31
+ type_=ParsedBetaTextBlock[ResponseFormatT],
32
+ value={**content.to_dict(), "parsed_output": parse_text(content.text, output_format)},
33
+ )
34
+ )
35
+ else:
36
+ content_list.append(content) # type: ignore
37
+
38
+ return construct_type_unchecked(
39
+ type_=ParsedBetaMessage[ResponseFormatT],
40
+ value={
41
+ **response.to_dict(),
42
+ "content": content_list,
43
+ },
44
+ )
anthropic/lib/_parse/_transform.py ADDED
@@ -0,0 +1,167 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ from typing import Any, Literal, Optional, cast
5
+ from typing_extensions import assert_never
6
+
7
+ import pydantic
8
+
9
+ from ..._utils import is_list
10
+
11
+ SupportedTypes = Literal[
12
+ "object",
13
+ "array",
14
+ "string",
15
+ "integer",
16
+ "number",
17
+ "boolean",
18
+ "null",
19
+ ]
20
+
21
+ SupportedStringFormats = {
22
+ "date-time",
23
+ "time",
24
+ "date",
25
+ "duration",
26
+ "email",
27
+ "hostname",
28
+ "uri",
29
+ "ipv4",
30
+ "ipv6",
31
+ "uuid",
32
+ }
33
+
34
+
35
+ def get_transformed_string(
36
+ schema: dict[str, Any],
37
+ ) -> dict[str, Any]:
38
+ """Transforms a JSON schema of type string to ensure it conforms to the API's expectations.
39
+
40
+ Specifically, it ensures that if the schema is of type "string" and does not already
41
+ specify a "format", it sets the format to "text".
42
+
43
+ Args:
44
+ schema: The original JSON schema.
45
+
46
+ Returns:
47
+ The transformed JSON schema.
48
+ """
49
+ if schema.get("type") == "string" and "format" not in schema:
50
+ schema["format"] = "text"
51
+ return schema
52
+
53
+
54
+ def transform_schema(
55
+ json_schema: type[pydantic.BaseModel] | dict[str, Any],
56
+ ) -> dict[str, Any]:
57
+ """
58
+ Transforms a JSON schema to ensure it conforms to the API's expectations.
59
+
60
+ Args:
61
+ json_schema (Dict[str, Any]): The original JSON schema.
62
+
63
+ Returns:
64
+ The transformed JSON schema.
65
+
66
+ Examples:
67
+ >>> transform_schema(
68
+ ... {
69
+ ... "type": "integer",
70
+ ... "minimum": 1,
71
+ ... "maximum": 10,
72
+ ... "description": "A number",
73
+ ... }
74
+ ... )
75
+ {'type': 'integer', 'description': 'A number\n\n{minimum: 1, maximum: 10}'}
76
+ """
77
+ if inspect.isclass(json_schema) and issubclass(json_schema, pydantic.BaseModel): # pyright: ignore[reportUnnecessaryIsInstance]
78
+ json_schema = json_schema.model_json_schema()
79
+
80
+ strict_schema: dict[str, Any] = {}
81
+ json_schema = {**json_schema}
82
+
83
+ ref = json_schema.pop("$ref", None)
84
+ if ref is not None:
85
+ strict_schema["$ref"] = ref
86
+ return strict_schema
87
+
88
+ defs = json_schema.pop("$defs", None)
89
+ if defs is not None:
90
+ strict_defs: dict[str, Any] = {}
91
+ strict_schema["$defs"] = strict_defs
92
+
93
+ for name, schema in defs.items():
94
+ strict_defs[name] = transform_schema(schema)
95
+
96
+ type_: Optional[SupportedTypes] = json_schema.pop("type", None)
97
+ any_of = json_schema.pop("anyOf", None)
98
+ one_of = json_schema.pop("oneOf", None)
99
+ all_of = json_schema.pop("allOf", None)
100
+
101
+ if is_list(any_of):
102
+ strict_schema["anyOf"] = [transform_schema(cast("dict[str, Any]", variant)) for variant in any_of]
103
+ elif is_list(one_of):
104
+ strict_schema["anyOf"] = [transform_schema(cast("dict[str, Any]", variant)) for variant in one_of]
105
+ elif is_list(all_of):
106
+ strict_schema["allOf"] = [transform_schema(cast("dict[str, Any]", variant)) for variant in all_of]
107
+ else:
108
+ if type_ is None:
109
+ raise ValueError("Schema must have a 'type', 'anyOf', 'oneOf', or 'allOf' field.")
110
+
111
+ strict_schema["type"] = type_
112
+
113
+ description = json_schema.pop("description", None)
114
+ if description is not None:
115
+ strict_schema["description"] = description
116
+
117
+ title = json_schema.pop("title", None)
118
+ if title is not None:
119
+ strict_schema["title"] = title
120
+
121
+ if type_ == "object":
122
+ strict_schema["properties"] = {
123
+ key: transform_schema(prop_schema) for key, prop_schema in json_schema.pop("properties", {}).items()
124
+ }
125
+ json_schema.pop("additionalProperties", None)
126
+ strict_schema["additionalProperties"] = False
127
+
128
+ required = json_schema.pop("required", None)
129
+ if required is not None:
130
+ strict_schema["required"] = required
131
+
132
+ elif type_ == "string":
133
+ format = json_schema.pop("format", None)
134
+ if format and format in SupportedStringFormats:
135
+ strict_schema["format"] = format
136
+ elif format:
137
+ # add it back so its treated as an extra property and appended to the description
138
+ json_schema["format"] = format
139
+ elif type_ == "array":
140
+ items = json_schema.pop("items", None)
141
+ if items is not None:
142
+ strict_schema["items"] = transform_schema(items)
143
+
144
+ min_items = json_schema.pop("minItems", None)
145
+ if min_items is not None and min_items == 0 or min_items == 1:
146
+ strict_schema["minItems"] = min_items
147
+ elif min_items is not None:
148
+ # add it back so its treated as an extra property and appended to the description
149
+ json_schema["minItems"] = min_items
150
+
151
+ elif type_ == "boolean" or type_ == "integer" or type_ == "number" or type_ == "null" or type_ is None:
152
+ pass
153
+ else:
154
+ assert_never(type_)
155
+
156
+ # if there are any propes leftover then they aren't supported, so we add them to the description
157
+ # so that the model *might* follow them.
158
+ if json_schema:
159
+ description = strict_schema.get("description")
160
+ strict_schema["description"] = (
161
+ (description + "\n\n" if description is not None else "")
162
+ + "{"
163
+ + ", ".join(f"{key}: {value}" for key, value in json_schema.items())
164
+ + "}"
165
+ )
166
+
167
+ return strict_schema
anthropic/lib/streaming/__init__.py CHANGED
@@ -1,3 +1,5 @@
1
+ from typing_extensions import TypeAlias
2
+
1
3
  from ._types import (
2
4
  TextEvent as TextEvent,
3
5
  InputJsonEvent as InputJsonEvent,
@@ -12,12 +14,20 @@ from ._messages import (
12
14
  AsyncMessageStreamManager as AsyncMessageStreamManager,
13
15
  )
14
16
  from ._beta_types import (
15
- BetaTextEvent as BetaTextEvent,
16
17
  BetaInputJsonEvent as BetaInputJsonEvent,
17
- BetaMessageStopEvent as BetaMessageStopEvent,
18
- BetaMessageStreamEvent as BetaMessageStreamEvent,
19
- BetaContentBlockStopEvent as BetaContentBlockStopEvent,
18
+ ParsedBetaTextEvent as ParsedBetaTextEvent,
19
+ ParsedBetaMessageStopEvent as ParsedBetaMessageStopEvent,
20
+ ParsedBetaMessageStreamEvent as ParsedBetaMessageStreamEvent,
21
+ ParsedBetaContentBlockStopEvent as ParsedBetaContentBlockStopEvent,
20
22
  )
23
+
24
+ # For backwards compatibility
25
+ BetaTextEvent: TypeAlias = ParsedBetaTextEvent
26
+ BetaMessageStopEvent: TypeAlias = ParsedBetaMessageStopEvent[object]
27
+ BetaMessageStreamEvent: TypeAlias = ParsedBetaMessageStreamEvent
28
+ BetaContentBlockStopEvent: TypeAlias = ParsedBetaContentBlockStopEvent[object]
29
+
30
+
21
31
  from ._beta_messages import (
22
32
  BetaMessageStream as BetaMessageStream,
23
33
  BetaAsyncMessageStream as BetaAsyncMessageStream,
anthropic/lib/streaming/_beta_messages.py CHANGED
@@ -1,7 +1,8 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import builtins
3
4
  from types import TracebackType
4
- from typing import TYPE_CHECKING, Any, Type, Callable, cast
5
+ from typing import TYPE_CHECKING, Any, Type, Generic, Callable, cast
5
6
  from typing_extensions import Self, Iterator, Awaitable, AsyncIterator, assert_never
6
7
 
7
8
  import httpx
@@ -11,23 +12,27 @@ from anthropic.types.beta.beta_tool_use_block import BetaToolUseBlock
11
12
  from anthropic.types.beta.beta_mcp_tool_use_block import BetaMCPToolUseBlock
12
13
  from anthropic.types.beta.beta_server_tool_use_block import BetaServerToolUseBlock
13
14
 
15
+ from ..._types import NOT_GIVEN, NotGiven
14
16
  from ..._utils import consume_sync_iterator, consume_async_iterator
15
17
  from ..._models import build, construct_type, construct_type_unchecked
16
18
  from ._beta_types import (
17
- BetaTextEvent,
18
19
  BetaCitationEvent,
19
20
  BetaThinkingEvent,
20
21
  BetaInputJsonEvent,
21
22
  BetaSignatureEvent,
22
- BetaMessageStopEvent,
23
- BetaMessageStreamEvent,
24
- BetaContentBlockStopEvent,
23
+ ParsedBetaTextEvent,
24
+ ParsedBetaMessageStopEvent,
25
+ ParsedBetaMessageStreamEvent,
26
+ ParsedBetaContentBlockStopEvent,
25
27
  )
26
28
  from ..._streaming import Stream, AsyncStream
27
- from ...types.beta import BetaMessage, BetaContentBlock, BetaRawMessageStreamEvent
29
+ from ...types.beta import BetaRawMessageStreamEvent
30
+ from ..._utils._utils import is_given
31
+ from .._parse._response import ResponseFormatT, parse_text
32
+ from ...types.beta.parsed_beta_message import ParsedBetaMessage, ParsedBetaContentBlock
28
33
 
29
34
 
30
- class BetaMessageStream:
35
+ class BetaMessageStream(Generic[ResponseFormatT]):
31
36
  text_stream: Iterator[str]
32
37
  """Iterator over just the text deltas in the stream.
33
38
 
@@ -38,11 +43,16 @@ class BetaMessageStream:
38
43
  ```
39
44
  """
40
45
 
41
- def __init__(self, raw_stream: Stream[BetaRawMessageStreamEvent]) -> None:
46
+ def __init__(
47
+ self,
48
+ raw_stream: Stream[BetaRawMessageStreamEvent],
49
+ output_format: ResponseFormatT | NotGiven,
50
+ ) -> None:
42
51
  self._raw_stream = raw_stream
43
52
  self.text_stream = self.__stream_text__()
44
53
  self._iterator = self.__stream__()
45
- self.__final_message_snapshot: BetaMessage | None = None
54
+ self.__final_message_snapshot: ParsedBetaMessage[ResponseFormatT] | None = None
55
+ self.__output_format = output_format
46
56
 
47
57
  @property
48
58
  def response(self) -> httpx.Response:
@@ -52,10 +62,10 @@ class BetaMessageStream:
52
62
  def request_id(self) -> str | None:
53
63
  return self.response.headers.get("request-id") # type: ignore[no-any-return]
54
64
 
55
- def __next__(self) -> BetaMessageStreamEvent:
65
+ def __next__(self) -> ParsedBetaMessageStreamEvent[ResponseFormatT]:
56
66
  return self._iterator.__next__()
57
67
 
58
- def __iter__(self) -> Iterator[BetaMessageStreamEvent]:
68
+ def __iter__(self) -> Iterator[ParsedBetaMessageStreamEvent[ResponseFormatT]]:
59
69
  for item in self._iterator:
60
70
  yield item
61
71
 
@@ -78,7 +88,7 @@ class BetaMessageStream:
78
88
  """
79
89
  self._raw_stream.close()
80
90
 
81
- def get_final_message(self) -> BetaMessage:
91
+ def get_final_message(self) -> ParsedBetaMessage[ResponseFormatT]:
82
92
  """Waits until the stream has been read to completion and returns
83
93
  the accumulated `Message` object.
84
94
  """
@@ -113,16 +123,17 @@ class BetaMessageStream:
113
123
 
114
124
  # properties
115
125
  @property
116
- def current_message_snapshot(self) -> BetaMessage:
126
+ def current_message_snapshot(self) -> ParsedBetaMessage[ResponseFormatT]:
117
127
  assert self.__final_message_snapshot is not None
118
128
  return self.__final_message_snapshot
119
129
 
120
- def __stream__(self) -> Iterator[BetaMessageStreamEvent]:
130
+ def __stream__(self) -> Iterator[ParsedBetaMessageStreamEvent[ResponseFormatT]]:
121
131
  for sse_event in self._raw_stream:
122
132
  self.__final_message_snapshot = accumulate_event(
123
133
  event=sse_event,
124
134
  current_snapshot=self.__final_message_snapshot,
125
135
  request_headers=self.response.request.headers,
136
+ output_format=self.__output_format,
126
137
  )
127
138
 
128
139
  events_to_fire = build_events(event=sse_event, message_snapshot=self.current_message_snapshot)
@@ -135,7 +146,7 @@ class BetaMessageStream:
135
146
  yield chunk.delta.text
136
147
 
137
148
 
138
- class BetaMessageStreamManager:
149
+ class BetaMessageStreamManager(Generic[ResponseFormatT]):
139
150
  """Wrapper over MessageStream that is returned by `.stream()`.
140
151
 
141
152
  ```py
@@ -148,13 +159,16 @@ class BetaMessageStreamManager:
148
159
  def __init__(
149
160
  self,
150
161
  api_request: Callable[[], Stream[BetaRawMessageStreamEvent]],
162
+ *,
163
+ output_format: ResponseFormatT | NotGiven,
151
164
  ) -> None:
152
- self.__stream: BetaMessageStream | None = None
165
+ self.__stream: BetaMessageStream[ResponseFormatT] | None = None
153
166
  self.__api_request = api_request
167
+ self.__output_format = output_format
154
168
 
155
- def __enter__(self) -> BetaMessageStream:
169
+ def __enter__(self) -> BetaMessageStream[ResponseFormatT]:
156
170
  raw_stream = self.__api_request()
157
- self.__stream = BetaMessageStream(raw_stream)
171
+ self.__stream = BetaMessageStream(raw_stream, output_format=self.__output_format)
158
172
  return self.__stream
159
173
 
160
174
  def __exit__(
@@ -167,7 +181,7 @@ class BetaMessageStreamManager:
167
181
  self.__stream.close()
168
182
 
169
183
 
170
- class BetaAsyncMessageStream:
184
+ class BetaAsyncMessageStream(Generic[ResponseFormatT]):
171
185
  text_stream: AsyncIterator[str]
172
186
  """Async iterator over just the text deltas in the stream.
173
187
 
@@ -178,11 +192,16 @@ class BetaAsyncMessageStream:
178
192
  ```
179
193
  """
180
194
 
181
- def __init__(self, raw_stream: AsyncStream[BetaRawMessageStreamEvent]) -> None:
195
+ def __init__(
196
+ self,
197
+ raw_stream: AsyncStream[BetaRawMessageStreamEvent],
198
+ output_format: ResponseFormatT | NotGiven,
199
+ ) -> None:
182
200
  self._raw_stream = raw_stream
183
201
  self.text_stream = self.__stream_text__()
184
202
  self._iterator = self.__stream__()
185
- self.__final_message_snapshot: BetaMessage | None = None
203
+ self.__final_message_snapshot: ParsedBetaMessage[ResponseFormatT] | None = None
204
+ self.__output_format = output_format
186
205
 
187
206
  @property
188
207
  def response(self) -> httpx.Response:
@@ -192,10 +211,10 @@ class BetaAsyncMessageStream:
192
211
  def request_id(self) -> str | None:
193
212
  return self.response.headers.get("request-id") # type: ignore[no-any-return]
194
213
 
195
- async def __anext__(self) -> BetaMessageStreamEvent:
214
+ async def __anext__(self) -> ParsedBetaMessageStreamEvent[ResponseFormatT]:
196
215
  return await self._iterator.__anext__()
197
216
 
198
- async def __aiter__(self) -> AsyncIterator[BetaMessageStreamEvent]:
217
+ async def __aiter__(self) -> AsyncIterator[ParsedBetaMessageStreamEvent[ResponseFormatT]]:
199
218
  async for item in self._iterator:
200
219
  yield item
201
220
 
@@ -218,7 +237,7 @@ class BetaAsyncMessageStream:
218
237
  """
219
238
  await self._raw_stream.close()
220
239
 
221
- async def get_final_message(self) -> BetaMessage:
240
+ async def get_final_message(self) -> ParsedBetaMessage[ResponseFormatT]:
222
241
  """Waits until the stream has been read to completion and returns
223
242
  the accumulated `Message` object.
224
243
  """
@@ -253,16 +272,17 @@ class BetaAsyncMessageStream:
253
272
 
254
273
  # properties
255
274
  @property
256
- def current_message_snapshot(self) -> BetaMessage:
275
+ def current_message_snapshot(self) -> ParsedBetaMessage[ResponseFormatT]:
257
276
  assert self.__final_message_snapshot is not None
258
277
  return self.__final_message_snapshot
259
278
 
260
- async def __stream__(self) -> AsyncIterator[BetaMessageStreamEvent]:
279
+ async def __stream__(self) -> AsyncIterator[ParsedBetaMessageStreamEvent[ResponseFormatT]]:
261
280
  async for sse_event in self._raw_stream:
262
281
  self.__final_message_snapshot = accumulate_event(
263
282
  event=sse_event,
264
283
  current_snapshot=self.__final_message_snapshot,
265
284
  request_headers=self.response.request.headers,
285
+ output_format=self.__output_format,
266
286
  )
267
287
 
268
288
  events_to_fire = build_events(event=sse_event, message_snapshot=self.current_message_snapshot)
@@ -275,7 +295,7 @@ class BetaAsyncMessageStream:
275
295
  yield chunk.delta.text
276
296
 
277
297
 
278
- class BetaAsyncMessageStreamManager:
298
+ class BetaAsyncMessageStreamManager(Generic[ResponseFormatT]):
279
299
  """Wrapper over BetaAsyncMessageStream that is returned by `.stream()`
280
300
  so that an async context manager can be used without `await`ing the
281
301
  original client call.
@@ -290,13 +310,16 @@ class BetaAsyncMessageStreamManager:
290
310
  def __init__(
291
311
  self,
292
312
  api_request: Awaitable[AsyncStream[BetaRawMessageStreamEvent]],
313
+ *,
314
+ output_format: ResponseFormatT | NotGiven = NOT_GIVEN,
293
315
  ) -> None:
294
- self.__stream: BetaAsyncMessageStream | None = None
316
+ self.__stream: BetaAsyncMessageStream[ResponseFormatT] | None = None
295
317
  self.__api_request = api_request
318
+ self.__output_format = output_format
296
319
 
297
- async def __aenter__(self) -> BetaAsyncMessageStream:
320
+ async def __aenter__(self) -> BetaAsyncMessageStream[ResponseFormatT]:
298
321
  raw_stream = await self.__api_request
299
- self.__stream = BetaAsyncMessageStream(raw_stream)
322
+ self.__stream = BetaAsyncMessageStream(raw_stream, output_format=self.__output_format)
300
323
  return self.__stream
301
324
 
302
325
  async def __aexit__(
@@ -312,16 +335,18 @@ class BetaAsyncMessageStreamManager:
312
335
  def build_events(
313
336
  *,
314
337
  event: BetaRawMessageStreamEvent,
315
- message_snapshot: BetaMessage,
316
- ) -> list[BetaMessageStreamEvent]:
317
- events_to_fire: list[BetaMessageStreamEvent] = []
338
+ message_snapshot: ParsedBetaMessage[ResponseFormatT],
339
+ ) -> list[ParsedBetaMessageStreamEvent[ResponseFormatT]]:
340
+ events_to_fire: list[ParsedBetaMessageStreamEvent[ResponseFormatT]] = []
318
341
 
319
342
  if event.type == "message_start":
320
343
  events_to_fire.append(event)
321
344
  elif event.type == "message_delta":
322
345
  events_to_fire.append(event)
323
346
  elif event.type == "message_stop":
324
- events_to_fire.append(build(BetaMessageStopEvent, type="message_stop", message=message_snapshot))
347
+ events_to_fire.append(
348
+ build(ParsedBetaMessageStopEvent[ResponseFormatT], type="message_stop", message=message_snapshot)
349
+ )
325
350
  elif event.type == "content_block_start":
326
351
  events_to_fire.append(event)
327
352
  elif event.type == "content_block_delta":
@@ -332,7 +357,7 @@ def build_events(
332
357
  if content_block.type == "text":
333
358
  events_to_fire.append(
334
359
  build(
335
- BetaTextEvent,
360
+ ParsedBetaTextEvent,
336
361
  type="text",
337
362
  text=event.delta.text,
338
363
  snapshot=content_block.text,
@@ -385,9 +410,14 @@ def build_events(
385
410
  elif event.type == "content_block_stop":
386
411
  content_block = message_snapshot.content[event.index]
387
412
 
388
- events_to_fire.append(
389
- build(BetaContentBlockStopEvent, type="content_block_stop", index=event.index, content_block=content_block),
413
+ event_to_fire = build(
414
+ ParsedBetaContentBlockStopEvent,
415
+ type="content_block_stop",
416
+ index=event.index,
417
+ content_block=content_block,
390
418
  )
419
+
420
+ events_to_fire.append(event_to_fire)
391
421
  else:
392
422
  # we only want exhaustive checking for linters, not at runtime
393
423
  if TYPE_CHECKING: # type: ignore[unreachable]
@@ -408,9 +438,10 @@ TRACKS_TOOL_INPUT = (
408
438
  def accumulate_event(
409
439
  *,
410
440
  event: BetaRawMessageStreamEvent,
411
- current_snapshot: BetaMessage | None,
441
+ current_snapshot: ParsedBetaMessage[ResponseFormatT] | None,
412
442
  request_headers: httpx.Headers,
413
- ) -> BetaMessage:
443
+ output_format: ResponseFormatT | NotGiven = NOT_GIVEN,
444
+ ) -> ParsedBetaMessage[ResponseFormatT]:
414
445
  if not isinstance(cast(Any, event), BaseModel):
415
446
  event = cast( # pyright: ignore[reportUnnecessaryCast]
416
447
  BetaRawMessageStreamEvent,
@@ -420,11 +451,15 @@ def accumulate_event(
420
451
  ),
421
452
  )
422
453
  if not isinstance(cast(Any, event), BaseModel):
423
- raise TypeError(f"Unexpected event runtime type, after deserialising twice - {event} - {type(event)}")
454
+ raise TypeError(
455
+ f"Unexpected event runtime type, after deserialising twice - {event} - {builtins.type(event)}"
456
+ )
424
457
 
425
458
  if current_snapshot is None:
426
459
  if event.type == "message_start":
427
- return BetaMessage.construct(**cast(Any, event.message.to_dict()))
460
+ return cast(
461
+ ParsedBetaMessage[ResponseFormatT], ParsedBetaMessage.construct(**cast(Any, event.message.to_dict()))
462
+ )
428
463
 
429
464
  raise RuntimeError(f'Unexpected event order, got {event.type} before "message_start"')
430
465
 
@@ -432,8 +467,8 @@ def accumulate_event(
432
467
  # TODO: check index
433
468
  current_snapshot.content.append(
434
469
  cast(
435
- BetaContentBlock,
436
- construct_type(type_=BetaContentBlock, value=event.content_block.model_dump()),
470
+ Any, # Pydantic does not support generic unions at runtime
471
+ construct_type(type_=ParsedBetaContentBlock, value=event.content_block.model_dump()),
437
472
  ),
438
473
  )
439
474
  elif event.type == "content_block_delta":
@@ -481,6 +516,10 @@ def accumulate_event(
481
516
  # we only want exhaustive checking for linters, not at runtime
482
517
  if TYPE_CHECKING: # type: ignore[unreachable]
483
518
  assert_never(event.delta)
519
+ elif event.type == "content_block_stop":
520
+ content_block = current_snapshot.content[event.index]
521
+ if content_block.type == "text" and is_given(output_format):
522
+ content_block.parsed_output = parse_text(content_block.text, output_format)
484
523
  elif event.type == "message_delta":
485
524
  current_snapshot.container = event.delta.container
486
525
  current_snapshot.stop_reason = event.delta.stop_reason