pydantic-ai-slim 0.0.19__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

pydantic_ai/messages.py CHANGED
@@ -6,7 +6,6 @@ from typing import Annotated, Any, Literal, Union, cast, overload
6
6
 
7
7
  import pydantic
8
8
  import pydantic_core
9
- from typing_extensions import Self, assert_never
10
9
 
11
10
  from ._utils import now_utc as _now_utc
12
11
  from .exceptions import UnexpectedModelBehavior
@@ -168,22 +167,6 @@ class TextPart:
168
167
  return bool(self.content)
169
168
 
170
169
 
171
- @dataclass
172
- class ArgsJson:
173
- """Tool arguments as a JSON string."""
174
-
175
- args_json: str
176
- """A JSON string of arguments."""
177
-
178
-
179
- @dataclass
180
- class ArgsDict:
181
- """Tool arguments as a Python dictionary."""
182
-
183
- args_dict: dict[str, Any]
184
- """A python dictionary of arguments."""
185
-
186
-
187
170
  @dataclass
188
171
  class ToolCallPart:
189
172
  """A tool call from a model."""
@@ -191,10 +174,10 @@ class ToolCallPart:
191
174
  tool_name: str
192
175
  """The name of the tool to call."""
193
176
 
194
- args: ArgsJson | ArgsDict
177
+ args: str | dict[str, Any]
195
178
  """The arguments to pass to the tool.
196
179
 
197
- Either as JSON or a Python dictionary depending on how data was returned.
180
+ This is stored either as a JSON string or a Python dictionary depending on how data was received.
198
181
  """
199
182
 
200
183
  tool_call_id: str | None = None
@@ -203,24 +186,14 @@ class ToolCallPart:
203
186
  part_kind: Literal['tool-call'] = 'tool-call'
204
187
  """Part type identifier, this is available on all parts as a discriminator."""
205
188
 
206
- @classmethod
207
- def from_raw_args(cls, tool_name: str, args: str | dict[str, Any], tool_call_id: str | None = None) -> Self:
208
- """Create a `ToolCallPart` from raw arguments, converting them to `ArgsJson` or `ArgsDict`."""
209
- if isinstance(args, str):
210
- return cls(tool_name, ArgsJson(args), tool_call_id)
211
- elif isinstance(args, dict):
212
- return cls(tool_name, ArgsDict(args), tool_call_id)
213
- else:
214
- assert_never(args)
215
-
216
189
  def args_as_dict(self) -> dict[str, Any]:
217
190
  """Return the arguments as a Python dictionary.
218
191
 
219
192
  This is just for convenience with models that require dicts as input.
220
193
  """
221
- if isinstance(self.args, ArgsDict):
222
- return self.args.args_dict
223
- args = pydantic_core.from_json(self.args.args_json)
194
+ if isinstance(self.args, dict):
195
+ return self.args
196
+ args = pydantic_core.from_json(self.args)
224
197
  assert isinstance(args, dict), 'args should be a dict'
225
198
  return cast(dict[str, Any], args)
226
199
 
@@ -229,16 +202,18 @@ class ToolCallPart:
229
202
 
230
203
  This is just for convenience with models that require JSON strings as input.
231
204
  """
232
- if isinstance(self.args, ArgsJson):
233
- return self.args.args_json
234
- return pydantic_core.to_json(self.args.args_dict).decode()
205
+ if isinstance(self.args, str):
206
+ return self.args
207
+ return pydantic_core.to_json(self.args).decode()
235
208
 
236
209
  def has_content(self) -> bool:
237
210
  """Return `True` if the arguments contain any data."""
238
- if isinstance(self.args, ArgsDict):
239
- return any(self.args.args_dict.values())
211
+ if isinstance(self.args, dict):
212
+ # TODO: This should probably return True if you have the value False, or 0, etc.
213
+ # It makes sense to me to ignore empty strings, but not sure about empty lists or dicts
214
+ return any(self.args.values())
240
215
  else:
241
- return bool(self.args.args_json)
216
+ return bool(self.args)
242
217
 
243
218
 
244
219
  ModelResponsePart = Annotated[Union[TextPart, ToolCallPart], pydantic.Discriminator('part_kind')]
@@ -252,6 +227,9 @@ class ModelResponse:
252
227
  parts: list[ModelResponsePart]
253
228
  """The parts of the model message."""
254
229
 
230
+ model_name: str | None = None
231
+ """The name of the model that generated the response."""
232
+
255
233
  timestamp: datetime = field(default_factory=_now_utc)
256
234
  """The timestamp of the response.
257
235
 
@@ -261,16 +239,6 @@ class ModelResponse:
261
239
  kind: Literal['response'] = 'response'
262
240
  """Message type identifier, this is available on all parts as a discriminator."""
263
241
 
264
- @classmethod
265
- def from_text(cls, content: str, timestamp: datetime | None = None) -> Self:
266
- """Create a `ModelResponse` containing a single `TextPart`."""
267
- return cls([TextPart(content=content)], timestamp=timestamp or _now_utc())
268
-
269
- @classmethod
270
- def from_tool_call(cls, tool_call: ToolCallPart) -> Self:
271
- """Create a `ModelResponse` containing a single `ToolCallPart`."""
272
- return cls([tool_call])
273
-
274
242
 
275
243
  ModelMessage = Annotated[Union[ModelRequest, ModelResponse], pydantic.Discriminator('kind')]
276
244
  """Any message sent to or returned by a model."""
@@ -338,7 +306,7 @@ class ToolCallPartDelta:
338
306
  if self.tool_name_delta is None or self.args_delta is None:
339
307
  return None
340
308
 
341
- return ToolCallPart.from_raw_args(
309
+ return ToolCallPart(
342
310
  self.tool_name_delta,
343
311
  self.args_delta,
344
312
  self.tool_call_id,
@@ -403,7 +371,7 @@ class ToolCallPartDelta:
403
371
 
404
372
  # If we now have enough data to create a full ToolCallPart, do so
405
373
  if delta.tool_name_delta is not None and delta.args_delta is not None:
406
- return ToolCallPart.from_raw_args(
374
+ return ToolCallPart(
407
375
  delta.tool_name_delta,
408
376
  delta.args_delta,
409
377
  delta.tool_call_id,
@@ -419,15 +387,15 @@ class ToolCallPartDelta:
419
387
  part = replace(part, tool_name=tool_name)
420
388
 
421
389
  if isinstance(self.args_delta, str):
422
- if not isinstance(part.args, ArgsJson):
390
+ if not isinstance(part.args, str):
423
391
  raise UnexpectedModelBehavior(f'Cannot apply JSON deltas to non-JSON tool arguments ({part=}, {self=})')
424
- updated_json = part.args.args_json + self.args_delta
425
- part = replace(part, args=ArgsJson(updated_json))
392
+ updated_json = part.args + self.args_delta
393
+ part = replace(part, args=updated_json)
426
394
  elif isinstance(self.args_delta, dict):
427
- if not isinstance(part.args, ArgsDict):
395
+ if not isinstance(part.args, dict):
428
396
  raise UnexpectedModelBehavior(f'Cannot apply dict deltas to non-dict tool arguments ({part=}, {self=})')
429
- updated_dict = {**(part.args.args_dict or {}), **self.args_delta}
430
- part = replace(part, args=ArgsDict(updated_dict))
397
+ updated_dict = {**(part.args or {}), **self.args_delta}
398
+ part = replace(part, args=updated_dict)
431
399
 
432
400
  if self.tool_call_id:
433
401
  # Replace the tool_call_id entirely if given
pydantic_ai/models/__init__.py CHANGED
@@ -12,9 +12,10 @@ from contextlib import asynccontextmanager, contextmanager
12
12
  from dataclasses import dataclass, field
13
13
  from datetime import datetime
14
14
  from functools import cache
15
- from typing import TYPE_CHECKING, Literal
15
+ from typing import TYPE_CHECKING
16
16
 
17
17
  import httpx
18
+ from typing_extensions import Literal
18
19
 
19
20
  from .._parts_manager import ModelResponsePartsManager
20
21
  from ..exceptions import UserError
@@ -27,60 +28,123 @@ if TYPE_CHECKING:
27
28
 
28
29
 
29
30
  KnownModelName = Literal[
30
- 'openai:gpt-4o',
31
- 'openai:gpt-4o-mini',
32
- 'openai:gpt-4-turbo',
33
- 'openai:gpt-4',
34
- 'openai:o1-preview',
35
- 'openai:o1-mini',
36
- 'openai:o1',
37
- 'openai:gpt-3.5-turbo',
38
- 'groq:llama-3.3-70b-versatile',
39
- 'groq:llama-3.1-70b-versatile',
40
- 'groq:llama3-groq-70b-8192-tool-use-preview',
41
- 'groq:llama3-groq-8b-8192-tool-use-preview',
42
- 'groq:llama-3.1-70b-specdec',
31
+ 'anthropic:claude-3-5-haiku-latest',
32
+ 'anthropic:claude-3-5-sonnet-latest',
33
+ 'anthropic:claude-3-opus-latest',
34
+ 'claude-3-5-haiku-latest',
35
+ 'claude-3-5-sonnet-latest',
36
+ 'claude-3-opus-latest',
37
+ 'cohere:c4ai-aya-expanse-32b',
38
+ 'cohere:c4ai-aya-expanse-8b',
39
+ 'cohere:command',
40
+ 'cohere:command-light',
41
+ 'cohere:command-light-nightly',
42
+ 'cohere:command-nightly',
43
+ 'cohere:command-r',
44
+ 'cohere:command-r-03-2024',
45
+ 'cohere:command-r-08-2024',
46
+ 'cohere:command-r-plus',
47
+ 'cohere:command-r-plus-04-2024',
48
+ 'cohere:command-r-plus-08-2024',
49
+ 'cohere:command-r7b-12-2024',
50
+ 'google-gla:gemini-1.0-pro',
51
+ 'google-gla:gemini-1.5-flash',
52
+ 'google-gla:gemini-1.5-flash-8b',
53
+ 'google-gla:gemini-1.5-pro',
54
+ 'google-gla:gemini-2.0-flash-exp',
55
+ 'google-vertex:gemini-1.0-pro',
56
+ 'google-vertex:gemini-1.5-flash',
57
+ 'google-vertex:gemini-1.5-flash-8b',
58
+ 'google-vertex:gemini-1.5-pro',
59
+ 'google-vertex:gemini-2.0-flash-exp',
60
+ 'gpt-3.5-turbo',
61
+ 'gpt-3.5-turbo-0125',
62
+ 'gpt-3.5-turbo-0301',
63
+ 'gpt-3.5-turbo-0613',
64
+ 'gpt-3.5-turbo-1106',
65
+ 'gpt-3.5-turbo-16k',
66
+ 'gpt-3.5-turbo-16k-0613',
67
+ 'gpt-4',
68
+ 'gpt-4-0125-preview',
69
+ 'gpt-4-0314',
70
+ 'gpt-4-0613',
71
+ 'gpt-4-1106-preview',
72
+ 'gpt-4-32k',
73
+ 'gpt-4-32k-0314',
74
+ 'gpt-4-32k-0613',
75
+ 'gpt-4-turbo',
76
+ 'gpt-4-turbo-2024-04-09',
77
+ 'gpt-4-turbo-preview',
78
+ 'gpt-4-vision-preview',
79
+ 'gpt-4o',
80
+ 'gpt-4o-2024-05-13',
81
+ 'gpt-4o-2024-08-06',
82
+ 'gpt-4o-2024-11-20',
83
+ 'gpt-4o-audio-preview',
84
+ 'gpt-4o-audio-preview-2024-10-01',
85
+ 'gpt-4o-audio-preview-2024-12-17',
86
+ 'gpt-4o-mini',
87
+ 'gpt-4o-mini-2024-07-18',
88
+ 'gpt-4o-mini-audio-preview',
89
+ 'gpt-4o-mini-audio-preview-2024-12-17',
90
+ 'groq:gemma2-9b-it',
43
91
  'groq:llama-3.1-8b-instant',
92
+ 'groq:llama-3.2-11b-vision-preview',
44
93
  'groq:llama-3.2-1b-preview',
45
94
  'groq:llama-3.2-3b-preview',
46
- 'groq:llama-3.2-11b-vision-preview',
47
95
  'groq:llama-3.2-90b-vision-preview',
96
+ 'groq:llama-3.3-70b-specdec',
97
+ 'groq:llama-3.3-70b-versatile',
48
98
  'groq:llama3-70b-8192',
49
99
  'groq:llama3-8b-8192',
50
100
  'groq:mixtral-8x7b-32768',
51
- 'groq:gemma2-9b-it',
52
- 'groq:gemma-7b-it',
53
- 'google-gla:gemini-1.5-flash',
54
- 'google-gla:gemini-1.5-pro',
55
- 'google-gla:gemini-2.0-flash-exp',
56
- 'google-vertex:gemini-1.5-flash',
57
- 'google-vertex:gemini-1.5-pro',
58
- 'google-vertex:gemini-2.0-flash-exp',
59
- 'mistral:mistral-small-latest',
60
- 'mistral:mistral-large-latest',
61
101
  'mistral:codestral-latest',
102
+ 'mistral:mistral-large-latest',
62
103
  'mistral:mistral-moderation-latest',
63
- 'ollama:codellama',
64
- 'ollama:gemma',
65
- 'ollama:gemma2',
66
- 'ollama:llama3',
67
- 'ollama:llama3.1',
68
- 'ollama:llama3.2',
69
- 'ollama:llama3.2-vision',
70
- 'ollama:llama3.3',
71
- 'ollama:mistral',
72
- 'ollama:mistral-nemo',
73
- 'ollama:mixtral',
74
- 'ollama:phi3',
75
- 'ollama:phi4',
76
- 'ollama:qwq',
77
- 'ollama:qwen',
78
- 'ollama:qwen2',
79
- 'ollama:qwen2.5',
80
- 'ollama:starcoder2',
81
- 'anthropic:claude-3-5-haiku-latest',
82
- 'anthropic:claude-3-5-sonnet-latest',
83
- 'anthropic:claude-3-opus-latest',
104
+ 'mistral:mistral-small-latest',
105
+ 'o1',
106
+ 'o1-2024-12-17',
107
+ 'o1-mini',
108
+ 'o1-mini-2024-09-12',
109
+ 'o1-preview',
110
+ 'o1-preview-2024-09-12',
111
+ 'openai:chatgpt-4o-latest',
112
+ 'openai:gpt-3.5-turbo',
113
+ 'openai:gpt-3.5-turbo-0125',
114
+ 'openai:gpt-3.5-turbo-0301',
115
+ 'openai:gpt-3.5-turbo-0613',
116
+ 'openai:gpt-3.5-turbo-1106',
117
+ 'openai:gpt-3.5-turbo-16k',
118
+ 'openai:gpt-3.5-turbo-16k-0613',
119
+ 'openai:gpt-4',
120
+ 'openai:gpt-4-0125-preview',
121
+ 'openai:gpt-4-0314',
122
+ 'openai:gpt-4-0613',
123
+ 'openai:gpt-4-1106-preview',
124
+ 'openai:gpt-4-32k',
125
+ 'openai:gpt-4-32k-0314',
126
+ 'openai:gpt-4-32k-0613',
127
+ 'openai:gpt-4-turbo',
128
+ 'openai:gpt-4-turbo-2024-04-09',
129
+ 'openai:gpt-4-turbo-preview',
130
+ 'openai:gpt-4-vision-preview',
131
+ 'openai:gpt-4o',
132
+ 'openai:gpt-4o-2024-05-13',
133
+ 'openai:gpt-4o-2024-08-06',
134
+ 'openai:gpt-4o-2024-11-20',
135
+ 'openai:gpt-4o-audio-preview',
136
+ 'openai:gpt-4o-audio-preview-2024-10-01',
137
+ 'openai:gpt-4o-audio-preview-2024-12-17',
138
+ 'openai:gpt-4o-mini',
139
+ 'openai:gpt-4o-mini-2024-07-18',
140
+ 'openai:gpt-4o-mini-audio-preview',
141
+ 'openai:gpt-4o-mini-audio-preview-2024-12-17',
142
+ 'openai:o1',
143
+ 'openai:o1-2024-12-17',
144
+ 'openai:o1-mini',
145
+ 'openai:o1-mini-2024-09-12',
146
+ 'openai:o1-preview',
147
+ 'openai:o1-preview-2024-09-12',
84
148
  'test',
85
149
  ]
86
150
  """Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -145,6 +209,7 @@ class AgentModel(ABC):
145
209
  class StreamedResponse(ABC):
146
210
  """Streamed response from an LLM when calling a tool."""
147
211
 
212
+ _model_name: str
148
213
  _usage: Usage = field(default_factory=Usage, init=False)
149
214
  _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
150
215
  _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
@@ -168,7 +233,13 @@ class StreamedResponse(ABC):
168
233
 
169
234
  def get(self) -> ModelResponse:
170
235
  """Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far."""
171
- return ModelResponse(parts=self._parts_manager.get_parts(), timestamp=self.timestamp())
236
+ return ModelResponse(
237
+ parts=self._parts_manager.get_parts(), model_name=self._model_name, timestamp=self.timestamp()
238
+ )
239
+
240
+ def model_name(self) -> str:
241
+ """Get the model name of the response."""
242
+ return self._model_name
172
243
 
173
244
  def usage(self) -> Usage:
174
245
  """Get the usage of the response so far. This will not be the final usage until the stream is exhausted."""
@@ -228,6 +299,10 @@ def infer_model(model: Model | KnownModelName) -> Model:
228
299
  from .test import TestModel
229
300
 
230
301
  return TestModel()
302
+ elif model.startswith('cohere:'):
303
+ from .cohere import CohereModel
304
+
305
+ return CohereModel(model[7:])
231
306
  elif model.startswith('openai:'):
232
307
  from .openai import OpenAIModel
233
308
 
@@ -263,10 +338,6 @@ def infer_model(model: Model | KnownModelName) -> Model:
263
338
  from .mistral import MistralModel
264
339
 
265
340
  return MistralModel(model[8:])
266
- elif model.startswith('ollama:'):
267
- from .ollama import OllamaModel
268
-
269
- return OllamaModel(model[7:])
270
341
  elif model.startswith('anthropic'):
271
342
  from .anthropic import AnthropicModel
272
343
 
pydantic_ai/models/anthropic.py CHANGED
@@ -1,21 +1,23 @@
1
1
  from __future__ import annotations as _annotations
2
2
 
3
- from collections.abc import AsyncIterator
3
+ from collections.abc import AsyncIterable, AsyncIterator
4
4
  from contextlib import asynccontextmanager
5
5
  from dataclasses import dataclass, field
6
+ from datetime import datetime, timezone
7
+ from json import JSONDecodeError, loads as json_loads
6
8
  from typing import Any, Literal, Union, cast, overload
7
9
 
8
10
  from httpx import AsyncClient as AsyncHTTPClient
9
11
  from typing_extensions import assert_never
10
12
 
11
- from .. import usage
13
+ from .. import UnexpectedModelBehavior, _utils, usage
12
14
  from .._utils import guard_tool_call_id as _guard_tool_call_id
13
15
  from ..messages import (
14
- ArgsDict,
15
16
  ModelMessage,
16
17
  ModelRequest,
17
18
  ModelResponse,
18
19
  ModelResponsePart,
20
+ ModelResponseStreamEvent,
19
21
  RetryPromptPart,
20
22
  SystemPromptPart,
21
23
  TextPart,
@@ -38,11 +40,17 @@ try:
38
40
  from anthropic.types import (
39
41
  Message as AnthropicMessage,
40
42
  MessageParam,
43
+ MetadataParam,
44
+ RawContentBlockDeltaEvent,
45
+ RawContentBlockStartEvent,
46
+ RawContentBlockStopEvent,
41
47
  RawMessageDeltaEvent,
42
48
  RawMessageStartEvent,
49
+ RawMessageStopEvent,
43
50
  RawMessageStreamEvent,
44
51
  TextBlock,
45
52
  TextBlockParam,
53
+ TextDelta,
46
54
  ToolChoiceParam,
47
55
  ToolParam,
48
56
  ToolResultBlockParam,
@@ -71,6 +79,15 @@ Since [the Anthropic docs](https://docs.anthropic.com/en/docs/about-claude/model
71
79
  """
72
80
 
73
81
 
82
+ class AnthropicModelSettings(ModelSettings):
83
+ """Settings used for an Anthropic model request."""
84
+
85
+ anthropic_metadata: MetadataParam
86
+ """An object describing metadata about the request.
87
+
88
+ Contains `user_id`, an external identifier for the user who is associated with the request."""
89
+
90
+
74
91
  @dataclass(init=False)
75
92
  class AnthropicModel(Model):
76
93
  """A model that uses the Anthropic API.
@@ -152,50 +169,54 @@ class AnthropicAgentModel(AgentModel):
152
169
  """Implementation of `AgentModel` for Anthropic models."""
153
170
 
154
171
  client: AsyncAnthropic
155
- model_name: str
172
+ model_name: AnthropicModelName
156
173
  allow_text_result: bool
157
174
  tools: list[ToolParam]
158
175
 
159
176
  async def request(
160
177
  self, messages: list[ModelMessage], model_settings: ModelSettings | None
161
178
  ) -> tuple[ModelResponse, usage.Usage]:
162
- response = await self._messages_create(messages, False, model_settings)
179
+ response = await self._messages_create(messages, False, cast(AnthropicModelSettings, model_settings or {}))
163
180
  return self._process_response(response), _map_usage(response)
164
181
 
165
182
  @asynccontextmanager
166
183
  async def request_stream(
167
184
  self, messages: list[ModelMessage], model_settings: ModelSettings | None
168
185
  ) -> AsyncIterator[StreamedResponse]:
169
- response = await self._messages_create(messages, True, model_settings)
186
+ response = await self._messages_create(messages, True, cast(AnthropicModelSettings, model_settings or {}))
170
187
  async with response:
171
188
  yield await self._process_streamed_response(response)
172
189
 
173
190
  @overload
174
191
  async def _messages_create(
175
- self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
192
+ self, messages: list[ModelMessage], stream: Literal[True], model_settings: AnthropicModelSettings
176
193
  ) -> AsyncStream[RawMessageStreamEvent]:
177
194
  pass
178
195
 
179
196
  @overload
180
197
  async def _messages_create(
181
- self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
198
+ self, messages: list[ModelMessage], stream: Literal[False], model_settings: AnthropicModelSettings
182
199
  ) -> AnthropicMessage:
183
200
  pass
184
201
 
185
202
  async def _messages_create(
186
- self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
203
+ self, messages: list[ModelMessage], stream: bool, model_settings: AnthropicModelSettings
187
204
  ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
188
205
  # standalone function to make it easier to override
206
+ tool_choice: ToolChoiceParam | None
207
+
189
208
  if not self.tools:
190
- tool_choice: ToolChoiceParam | None = None
191
- elif not self.allow_text_result:
192
- tool_choice = {'type': 'any'}
209
+ tool_choice = None
193
210
  else:
194
- tool_choice = {'type': 'auto'}
211
+ if not self.allow_text_result:
212
+ tool_choice = {'type': 'any'}
213
+ else:
214
+ tool_choice = {'type': 'auto'}
195
215
 
196
- system_prompt, anthropic_messages = self._map_message(messages)
216
+ if (allow_parallel_tool_calls := model_settings.get('parallel_tool_calls')) is not None:
217
+ tool_choice['disable_parallel_tool_use'] = not allow_parallel_tool_calls
197
218
 
198
- model_settings = model_settings or {}
219
+ system_prompt, anthropic_messages = self._map_message(messages)
199
220
 
200
221
  return await self.client.messages.create(
201
222
  max_tokens=model_settings.get('max_tokens', 1024),
@@ -208,10 +229,10 @@ class AnthropicAgentModel(AgentModel):
208
229
  temperature=model_settings.get('temperature', NOT_GIVEN),
209
230
  top_p=model_settings.get('top_p', NOT_GIVEN),
210
231
  timeout=model_settings.get('timeout', NOT_GIVEN),
232
+ metadata=model_settings.get('anthropic_metadata', NOT_GIVEN),
211
233
  )
212
234
 
213
- @staticmethod
214
- def _process_response(response: AnthropicMessage) -> ModelResponse:
235
+ def _process_response(self, response: AnthropicMessage) -> ModelResponse:
215
236
  """Process a non-streamed response, and prepare a message to return."""
216
237
  items: list[ModelResponsePart] = []
217
238
  for item in response.content:
@@ -220,33 +241,24 @@ class AnthropicAgentModel(AgentModel):
220
241
  else:
221
242
  assert isinstance(item, ToolUseBlock), 'unexpected item type'
222
243
  items.append(
223
- ToolCallPart.from_raw_args(
244
+ ToolCallPart(
224
245
  tool_name=item.name,
225
246
  args=cast(dict[str, Any], item.input),
226
247
  tool_call_id=item.id,
227
248
  )
228
249
  )
229
250
 
230
- return ModelResponse(items)
251
+ return ModelResponse(items, model_name=self.model_name)
231
252
 
232
- @staticmethod
233
- async def _process_streamed_response(response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
234
- """TODO: Process a streamed response, and prepare a streaming response to return."""
235
- # We don't yet support streamed responses from Anthropic, so we raise an error here for now.
236
- # Streamed responses will be supported in a future release.
237
-
238
- raise RuntimeError('Streamed responses are not yet supported for Anthropic models.')
239
-
240
- # Should be returning some sort of AnthropicStreamTextResponse or AnthropicStreamedResponse
241
- # depending on the type of chunk we get, but we need to establish how we handle (and when we get) the following:
242
- # RawMessageStartEvent
243
- # RawMessageDeltaEvent
244
- # RawMessageStopEvent
245
- # RawContentBlockStartEvent
246
- # RawContentBlockDeltaEvent
247
- # RawContentBlockDeltaEvent
248
- #
249
- # We might refactor streaming internally before we implement this...
253
+ async def _process_streamed_response(self, response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
254
+ peekable_response = _utils.PeekableAsyncStream(response)
255
+ first_chunk = await peekable_response.peek()
256
+ if isinstance(first_chunk, _utils.Unset):
257
+ raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
258
+
259
+ # Since Anthropic doesn't provide a timestamp in the message, we'll use the current time
260
+ timestamp = datetime.now(tz=timezone.utc)
261
+ return AnthropicStreamedResponse(_model_name=self.model_name, _response=peekable_response, _timestamp=timestamp)
250
262
 
251
263
  @staticmethod
252
264
  def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
@@ -306,7 +318,6 @@ class AnthropicAgentModel(AgentModel):
306
318
 
307
319
 
308
320
  def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
309
- assert isinstance(t.args, ArgsDict), f'Expected ArgsDict, got {t.args}'
310
321
  return ToolUseBlockParam(
311
322
  id=_guard_tool_call_id(t=t, model_source='Anthropic'),
312
323
  type='tool_use',
@@ -342,3 +353,63 @@ def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage
342
353
  response_tokens=response_usage.output_tokens,
343
354
  total_tokens=(request_tokens or 0) + response_usage.output_tokens,
344
355
  )
356
+
357
+
358
+ @dataclass
359
+ class AnthropicStreamedResponse(StreamedResponse):
360
+ """Implementation of `StreamedResponse` for Anthropic models."""
361
+
362
+ _response: AsyncIterable[RawMessageStreamEvent]
363
+ _timestamp: datetime
364
+
365
+ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
366
+ current_block: TextBlock | ToolUseBlock | None = None
367
+ current_json: str = ''
368
+
369
+ async for event in self._response:
370
+ self._usage += _map_usage(event)
371
+
372
+ if isinstance(event, RawContentBlockStartEvent):
373
+ current_block = event.content_block
374
+ if isinstance(current_block, TextBlock) and current_block.text:
375
+ yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=current_block.text)
376
+ elif isinstance(current_block, ToolUseBlock):
377
+ maybe_event = self._parts_manager.handle_tool_call_delta(
378
+ vendor_part_id=current_block.id,
379
+ tool_name=current_block.name,
380
+ args=cast(dict[str, Any], current_block.input),
381
+ tool_call_id=current_block.id,
382
+ )
383
+ if maybe_event is not None:
384
+ yield maybe_event
385
+
386
+ elif isinstance(event, RawContentBlockDeltaEvent):
387
+ if isinstance(event.delta, TextDelta):
388
+ yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=event.delta.text)
389
+ elif (
390
+ current_block and event.delta.type == 'input_json_delta' and isinstance(current_block, ToolUseBlock)
391
+ ):
392
+ # Try to parse the JSON immediately, otherwise cache the value for later. This handles
393
+ # cases where the JSON is not currently valid but will be valid once we stream more tokens.
394
+ try:
395
+ parsed_args = json_loads(current_json + event.delta.partial_json)
396
+ current_json = ''
397
+ except JSONDecodeError:
398
+ current_json += event.delta.partial_json
399
+ continue
400
+
401
+ # For tool calls, we need to handle partial JSON updates
402
+ maybe_event = self._parts_manager.handle_tool_call_delta(
403
+ vendor_part_id=current_block.id,
404
+ tool_name='',
405
+ args=parsed_args,
406
+ tool_call_id=current_block.id,
407
+ )
408
+ if maybe_event is not None:
409
+ yield maybe_event
410
+
411
+ elif isinstance(event, (RawContentBlockStopEvent, RawMessageStopEvent)):
412
+ current_block = None
413
+
414
+ def timestamp(self) -> datetime:
415
+ return self._timestamp