pydantic-ai-slim 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl

This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. (The original page links to more details here.)

@@ -37,7 +37,8 @@ def format_as_xml(
37
37
  none_str: String to use for `None` values.
38
38
  indent: Indentation string to use for pretty printing.
39
39
 
40
- Returns: XML representation of the object.
40
+ Returns:
41
+ XML representation of the object.
41
42
 
42
43
  Example:
43
44
  ```python {title="format_as_xml_example.py" lint="skip"}
pydantic_ai/messages.py CHANGED
@@ -1,14 +1,15 @@
1
1
  from __future__ import annotations as _annotations
2
2
 
3
- from dataclasses import dataclass, field
3
+ from dataclasses import dataclass, field, replace
4
4
  from datetime import datetime
5
- from typing import Annotated, Any, Literal, Union, cast
5
+ from typing import Annotated, Any, Literal, Union, cast, overload
6
6
 
7
7
  import pydantic
8
8
  import pydantic_core
9
9
  from typing_extensions import Self, assert_never
10
10
 
11
11
  from ._utils import now_utc as _now_utc
12
+ from .exceptions import UnexpectedModelBehavior
12
13
 
13
14
 
14
15
  @dataclass
@@ -21,6 +22,12 @@ class SystemPromptPart:
21
22
  content: str
22
23
  """The content of the prompt."""
23
24
 
25
+ dynamic_ref: str | None = None
26
+ """The ref of the dynamic system prompt function that generated this part.
27
+
28
+ Only set if system prompt is dynamic, see [`system_prompt`][pydantic_ai.Agent.system_prompt] for more information.
29
+ """
30
+
24
31
  part_kind: Literal['system-prompt'] = 'system-prompt'
25
32
  """Part type identifier, this is available on all parts as a discriminator."""
26
33
 
@@ -66,12 +73,14 @@ class ToolReturnPart:
66
73
  """Part type identifier, this is available on all parts as a discriminator."""
67
74
 
68
75
  def model_response_str(self) -> str:
76
+ """Return a string representation of the content for the model."""
69
77
  if isinstance(self.content, str):
70
78
  return self.content
71
79
  else:
72
80
  return tool_return_ta.dump_json(self.content).decode()
73
81
 
74
82
  def model_response_object(self) -> dict[str, Any]:
83
+ """Return a dictionary representation of the content, wrapping non-dict types appropriately."""
75
84
  # gemini supports JSON dict return values, but no other JSON types, hence we wrap anything else in a dict
76
85
  if isinstance(self.content, dict):
77
86
  return tool_return_ta.dump_python(self.content, mode='json') # pyright: ignore[reportUnknownMemberType]
@@ -118,6 +127,7 @@ class RetryPromptPart:
118
127
  """Part type identifier, this is available on all parts as a discriminator."""
119
128
 
120
129
  def model_response(self) -> str:
130
+ """Return a string message describing why the retry is requested."""
121
131
  if isinstance(self.content, str):
122
132
  description = self.content
123
133
  else:
@@ -153,6 +163,10 @@ class TextPart:
153
163
  part_kind: Literal['text'] = 'text'
154
164
  """Part type identifier, this is available on all parts as a discriminator."""
155
165
 
166
+ def has_content(self) -> bool:
167
+ """Return `True` if the text content is non-empty."""
168
+ return bool(self.content)
169
+
156
170
 
157
171
  @dataclass
158
172
  class ArgsJson:
@@ -191,7 +205,7 @@ class ToolCallPart:
191
205
 
192
206
  @classmethod
193
207
  def from_raw_args(cls, tool_name: str, args: str | dict[str, Any], tool_call_id: str | None = None) -> Self:
194
- """Create a `ToolCallPart` from raw arguments."""
208
+ """Create a `ToolCallPart` from raw arguments, converting them to `ArgsJson` or `ArgsDict`."""
195
209
  if isinstance(args, str):
196
210
  return cls(tool_name, ArgsJson(args), tool_call_id)
197
211
  elif isinstance(args, dict):
@@ -220,6 +234,7 @@ class ToolCallPart:
220
234
  return pydantic_core.to_json(self.args.args_dict).decode()
221
235
 
222
236
  def has_content(self) -> bool:
237
+ """Return `True` if the arguments contain any data."""
223
238
  if isinstance(self.args, ArgsDict):
224
239
  return any(self.args.args_dict.values())
225
240
  else:
@@ -248,17 +263,217 @@ class ModelResponse:
248
263
 
249
264
  @classmethod
250
265
  def from_text(cls, content: str, timestamp: datetime | None = None) -> Self:
251
- return cls([TextPart(content)], timestamp=timestamp or _now_utc())
266
+ """Create a `ModelResponse` containing a single `TextPart`."""
267
+ return cls([TextPart(content=content)], timestamp=timestamp or _now_utc())
252
268
 
253
269
  @classmethod
254
270
  def from_tool_call(cls, tool_call: ToolCallPart) -> Self:
271
+ """Create a `ModelResponse` containing a single `ToolCallPart`."""
255
272
  return cls([tool_call])
256
273
 
257
274
 
258
- ModelMessage = Union[ModelRequest, ModelResponse]
259
- """Any message send to or returned by a model."""
275
+ ModelMessage = Annotated[Union[ModelRequest, ModelResponse], pydantic.Discriminator('kind')]
276
+ """Any message sent to or returned by a model."""
260
277
 
261
- ModelMessagesTypeAdapter = pydantic.TypeAdapter(
262
- list[Annotated[ModelMessage, pydantic.Discriminator('kind')]], config=pydantic.ConfigDict(defer_build=True)
263
- )
278
+ ModelMessagesTypeAdapter = pydantic.TypeAdapter(list[ModelMessage], config=pydantic.ConfigDict(defer_build=True))
264
279
  """Pydantic [`TypeAdapter`][pydantic.type_adapter.TypeAdapter] for (de)serializing messages."""
280
+
281
+
282
+ @dataclass
283
+ class TextPartDelta:
284
+ """A partial update (delta) for a `TextPart` to append new text content."""
285
+
286
+ content_delta: str
287
+ """The incremental text content to add to the existing `TextPart` content."""
288
+
289
+ part_delta_kind: Literal['text'] = 'text'
290
+ """Part delta type identifier, used as a discriminator."""
291
+
292
+ def apply(self, part: ModelResponsePart) -> TextPart:
293
+ """Apply this text delta to an existing `TextPart`.
294
+
295
+ Args:
296
+ part: The existing model response part, which must be a `TextPart`.
297
+
298
+ Returns:
299
+ A new `TextPart` with updated text content.
300
+
301
+ Raises:
302
+ ValueError: If `part` is not a `TextPart`.
303
+ """
304
+ if not isinstance(part, TextPart):
305
+ raise ValueError('Cannot apply TextPartDeltas to non-TextParts')
306
+ return replace(part, content=part.content + self.content_delta)
307
+
308
+
309
+ @dataclass
310
+ class ToolCallPartDelta:
311
+ """A partial update (delta) for a `ToolCallPart` to modify tool name, arguments, or tool call ID."""
312
+
313
+ tool_name_delta: str | None = None
314
+ """Incremental text to add to the existing tool name, if any."""
315
+
316
+ args_delta: str | dict[str, Any] | None = None
317
+ """Incremental data to add to the tool arguments.
318
+
319
+ If this is a string, it will be appended to existing JSON arguments.
320
+ If this is a dict, it will be merged with existing dict arguments.
321
+ """
322
+
323
+ tool_call_id: str | None = None
324
+ """Optional tool call identifier, this is used by some models including OpenAI.
325
+
326
+ Note this is never treated as a delta — it can replace None, but otherwise if a
327
+ non-matching value is provided an error will be raised."""
328
+
329
+ part_delta_kind: Literal['tool_call'] = 'tool_call'
330
+ """Part delta type identifier, used as a discriminator."""
331
+
332
+ def as_part(self) -> ToolCallPart | None:
333
+ """Convert this delta to a fully formed `ToolCallPart` if possible, otherwise return `None`.
334
+
335
+ Returns:
336
+ A `ToolCallPart` if both `tool_name_delta` and `args_delta` are set, otherwise `None`.
337
+ """
338
+ if self.tool_name_delta is None or self.args_delta is None:
339
+ return None
340
+
341
+ return ToolCallPart.from_raw_args(
342
+ self.tool_name_delta,
343
+ self.args_delta,
344
+ self.tool_call_id,
345
+ )
346
+
347
+ @overload
348
+ def apply(self, part: ModelResponsePart) -> ToolCallPart: ...
349
+
350
+ @overload
351
+ def apply(self, part: ModelResponsePart | ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta: ...
352
+
353
+ def apply(self, part: ModelResponsePart | ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta:
354
+ """Apply this delta to a part or delta, returning a new part or delta with the changes applied.
355
+
356
+ Args:
357
+ part: The existing model response part or delta to update.
358
+
359
+ Returns:
360
+ Either a new `ToolCallPart` or an updated `ToolCallPartDelta`.
361
+
362
+ Raises:
363
+ ValueError: If `part` is neither a `ToolCallPart` nor a `ToolCallPartDelta`.
364
+ UnexpectedModelBehavior: If applying JSON deltas to dict arguments or vice versa.
365
+ """
366
+ if isinstance(part, ToolCallPart):
367
+ return self._apply_to_part(part)
368
+
369
+ if isinstance(part, ToolCallPartDelta):
370
+ return self._apply_to_delta(part)
371
+
372
+ raise ValueError(f'Can only apply ToolCallPartDeltas to ToolCallParts or ToolCallPartDeltas, not {part}')
373
+
374
+ def _apply_to_delta(self, delta: ToolCallPartDelta) -> ToolCallPart | ToolCallPartDelta:
375
+ """Internal helper to apply this delta to another delta."""
376
+ if self.tool_name_delta:
377
+ # Append incremental text to the existing tool_name_delta
378
+ updated_tool_name_delta = (delta.tool_name_delta or '') + self.tool_name_delta
379
+ delta = replace(delta, tool_name_delta=updated_tool_name_delta)
380
+
381
+ if isinstance(self.args_delta, str):
382
+ if isinstance(delta.args_delta, dict):
383
+ raise UnexpectedModelBehavior(
384
+ f'Cannot apply JSON deltas to non-JSON tool arguments ({delta=}, {self=})'
385
+ )
386
+ updated_args_delta = (delta.args_delta or '') + self.args_delta
387
+ delta = replace(delta, args_delta=updated_args_delta)
388
+ elif isinstance(self.args_delta, dict):
389
+ if isinstance(delta.args_delta, str):
390
+ raise UnexpectedModelBehavior(
391
+ f'Cannot apply dict deltas to non-dict tool arguments ({delta=}, {self=})'
392
+ )
393
+ updated_args_delta = {**(delta.args_delta or {}), **self.args_delta}
394
+ delta = replace(delta, args_delta=updated_args_delta)
395
+
396
+ if self.tool_call_id:
397
+ # Set the tool_call_id if it wasn't present, otherwise error if it has changed
398
+ if delta.tool_call_id is not None and delta.tool_call_id != self.tool_call_id:
399
+ raise UnexpectedModelBehavior(
400
+ f'Cannot apply a new tool_call_id to a ToolCallPartDelta that already has one ({delta=}, {self=})'
401
+ )
402
+ delta = replace(delta, tool_call_id=self.tool_call_id)
403
+
404
+ # If we now have enough data to create a full ToolCallPart, do so
405
+ if delta.tool_name_delta is not None and delta.args_delta is not None:
406
+ return ToolCallPart.from_raw_args(
407
+ delta.tool_name_delta,
408
+ delta.args_delta,
409
+ delta.tool_call_id,
410
+ )
411
+
412
+ return delta
413
+
414
+ def _apply_to_part(self, part: ToolCallPart) -> ToolCallPart:
415
+ """Internal helper to apply this delta directly to a `ToolCallPart`."""
416
+ if self.tool_name_delta:
417
+ # Append incremental text to the existing tool_name
418
+ tool_name = part.tool_name + self.tool_name_delta
419
+ part = replace(part, tool_name=tool_name)
420
+
421
+ if isinstance(self.args_delta, str):
422
+ if not isinstance(part.args, ArgsJson):
423
+ raise UnexpectedModelBehavior(f'Cannot apply JSON deltas to non-JSON tool arguments ({part=}, {self=})')
424
+ updated_json = part.args.args_json + self.args_delta
425
+ part = replace(part, args=ArgsJson(updated_json))
426
+ elif isinstance(self.args_delta, dict):
427
+ if not isinstance(part.args, ArgsDict):
428
+ raise UnexpectedModelBehavior(f'Cannot apply dict deltas to non-dict tool arguments ({part=}, {self=})')
429
+ updated_dict = {**(part.args.args_dict or {}), **self.args_delta}
430
+ part = replace(part, args=ArgsDict(updated_dict))
431
+
432
+ if self.tool_call_id:
433
+ # Replace the tool_call_id entirely if given
434
+ if part.tool_call_id is not None and part.tool_call_id != self.tool_call_id:
435
+ raise UnexpectedModelBehavior(
436
+ f'Cannot apply a new tool_call_id to a ToolCallPartDelta that already has one ({part=}, {self=})'
437
+ )
438
+ part = replace(part, tool_call_id=self.tool_call_id)
439
+ return part
440
+
441
+
442
+ ModelResponsePartDelta = Annotated[Union[TextPartDelta, ToolCallPartDelta], pydantic.Discriminator('part_delta_kind')]
443
+ """A partial update (delta) for any model response part."""
444
+
445
+
446
+ @dataclass
447
+ class PartStartEvent:
448
+ """An event indicating that a new part has started.
449
+
450
+ If multiple `PartStartEvent`s are received with the same index,
451
+ the new one should fully replace the old one.
452
+ """
453
+
454
+ index: int
455
+ """The index of the part within the overall response parts list."""
456
+
457
+ part: ModelResponsePart
458
+ """The newly started `ModelResponsePart`."""
459
+
460
+ event_kind: Literal['part_start'] = 'part_start'
461
+ """Event type identifier, used as a discriminator."""
462
+
463
+
464
+ @dataclass
465
+ class PartDeltaEvent:
466
+ """An event indicating a delta update for an existing part."""
467
+
468
+ index: int
469
+ """The index of the part within the overall response parts list."""
470
+
471
+ delta: ModelResponsePartDelta
472
+ """The delta to apply to the specified part."""
473
+
474
+ event_kind: Literal['part_delta'] = 'part_delta'
475
+ """Event type identifier, used as a discriminator."""
476
+
477
+
478
+ ModelResponseStreamEvent = Annotated[Union[PartStartEvent, PartDeltaEvent], pydantic.Discriminator('event_kind')]
479
+ """An event in the model response stream, either starting a new part or applying a delta to an existing one."""
@@ -7,20 +7,22 @@ specific LLM being used.
7
7
  from __future__ import annotations as _annotations
8
8
 
9
9
  from abc import ABC, abstractmethod
10
- from collections.abc import AsyncIterator, Iterable, Iterator
10
+ from collections.abc import AsyncIterator, Iterator
11
11
  from contextlib import asynccontextmanager, contextmanager
12
+ from dataclasses import dataclass, field
12
13
  from datetime import datetime
13
14
  from functools import cache
14
- from typing import TYPE_CHECKING, Literal, Union
15
+ from typing import TYPE_CHECKING, Literal
15
16
 
16
17
  import httpx
17
18
 
19
+ from .._parts_manager import ModelResponsePartsManager
18
20
  from ..exceptions import UserError
19
- from ..messages import ModelMessage, ModelResponse
21
+ from ..messages import ModelMessage, ModelResponse, ModelResponseStreamEvent
20
22
  from ..settings import ModelSettings
23
+ from ..usage import Usage
21
24
 
22
25
  if TYPE_CHECKING:
23
- from ..result import Usage
24
26
  from ..tools import ToolDefinition
25
27
 
26
28
 
@@ -48,13 +50,12 @@ KnownModelName = Literal[
48
50
  'groq:mixtral-8x7b-32768',
49
51
  'groq:gemma2-9b-it',
50
52
  'groq:gemma-7b-it',
51
- 'gemini-1.5-flash',
52
- 'gemini-1.5-pro',
53
- 'gemini-2.0-flash-exp',
54
- 'vertexai:gemini-1.5-flash',
55
- 'vertexai:gemini-1.5-pro',
56
- # since mistral models are supported by other providers (e.g. ollama), and some of their models (e.g. "codestral")
57
- # don't start with "mistral", we add the "mistral:" prefix to all to be explicit
53
+ 'google-gla:gemini-1.5-flash',
54
+ 'google-gla:gemini-1.5-pro',
55
+ 'google-gla:gemini-2.0-flash-exp',
56
+ 'google-vertex:gemini-1.5-flash',
57
+ 'google-vertex:gemini-1.5-pro',
58
+ 'google-vertex:gemini-2.0-flash-exp',
58
59
  'mistral:mistral-small-latest',
59
60
  'mistral:mistral-large-latest',
60
61
  'mistral:codestral-latest',
@@ -71,14 +72,15 @@ KnownModelName = Literal[
71
72
  'ollama:mistral-nemo',
72
73
  'ollama:mixtral',
73
74
  'ollama:phi3',
75
+ 'ollama:phi4',
74
76
  'ollama:qwq',
75
77
  'ollama:qwen',
76
78
  'ollama:qwen2',
77
79
  'ollama:qwen2.5',
78
80
  'ollama:starcoder2',
79
- 'claude-3-5-haiku-latest',
80
- 'claude-3-5-sonnet-latest',
81
- 'claude-3-opus-latest',
81
+ 'anthropic:claude-3-5-haiku-latest',
82
+ 'anthropic:claude-3-5-sonnet-latest',
83
+ 'anthropic:claude-3-opus-latest',
82
84
  'test',
83
85
  ]
84
86
  """Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -130,88 +132,47 @@ class AgentModel(ABC):
130
132
  @asynccontextmanager
131
133
  async def request_stream(
132
134
  self, messages: list[ModelMessage], model_settings: ModelSettings | None
133
- ) -> AsyncIterator[EitherStreamedResponse]:
135
+ ) -> AsyncIterator[StreamedResponse]:
134
136
  """Make a request to the model and return a streaming response."""
137
+ # This method is not required, but you need to implement it if you want to support streamed responses
135
138
  raise NotImplementedError(f'Streamed requests not supported by this {self.__class__.__name__}')
136
139
  # yield is required to make this a generator for type checking
137
140
  # noinspection PyUnreachableCode
138
141
  yield # pragma: no cover
139
142
 
140
143
 
141
- class StreamTextResponse(ABC):
142
- """Streamed response from an LLM when returning text."""
143
-
144
- def __aiter__(self) -> AsyncIterator[None]:
145
- """Stream the response as an async iterable, building up the text as it goes.
146
-
147
- This is an async iterator that yields `None` to avoid doing the work of validating the input and
148
- extracting the text field when it will often be thrown away.
149
- """
150
- return self
151
-
152
- @abstractmethod
153
- async def __anext__(self) -> None:
154
- """Process the next chunk of the response, see above for why this returns `None`."""
155
- raise NotImplementedError()
156
-
157
- @abstractmethod
158
- def get(self, *, final: bool = False) -> Iterable[str]:
159
- """Returns an iterable of text since the last call to `get()` — e.g. the text delta.
160
-
161
- Args:
162
- final: If True, this is the final call, after iteration is complete, the response should be fully validated
163
- and all text extracted.
164
- """
165
- raise NotImplementedError()
166
-
167
- @abstractmethod
168
- def usage(self) -> Usage:
169
- """Return the usage of the request.
170
-
171
- NOTE: this won't return the full usage until the stream is finished.
172
- """
173
- raise NotImplementedError()
174
-
175
- @abstractmethod
176
- def timestamp(self) -> datetime:
177
- """Get the timestamp of the response."""
178
- raise NotImplementedError()
179
-
180
-
181
- class StreamStructuredResponse(ABC):
144
+ @dataclass
145
+ class StreamedResponse(ABC):
182
146
  """Streamed response from an LLM when calling a tool."""
183
147
 
184
- def __aiter__(self) -> AsyncIterator[None]:
185
- """Stream the response as an async iterable, building up the tool call as it goes.
186
-
187
- This is an async iterator that yields `None` to avoid doing the work of building the final tool call when
188
- it will often be thrown away.
189
- """
190
- return self
148
+ _usage: Usage = field(default_factory=Usage, init=False)
149
+ _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
150
+ _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
191
151
 
192
- @abstractmethod
193
- async def __anext__(self) -> None:
194
- """Process the next chunk of the response, see above for why this returns `None`."""
195
- raise NotImplementedError()
152
+ def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]:
153
+ """Stream the response as an async iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s."""
154
+ if self._event_iterator is None:
155
+ self._event_iterator = self._get_event_iterator()
156
+ return self._event_iterator
196
157
 
197
158
  @abstractmethod
198
- def get(self, *, final: bool = False) -> ModelResponse:
199
- """Get the `ModelResponse` at this point.
200
-
201
- The `ModelResponse` may or may not be complete, depending on whether the stream is finished.
159
+ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
160
+ """Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.
202
161
 
203
- Args:
204
- final: If True, this is the final call, after iteration is complete, the response should be fully validated.
162
+ This method should be implemented by subclasses to translate the vendor-specific stream of events into
163
+ pydantic_ai-format events.
205
164
  """
206
165
  raise NotImplementedError()
166
+ # noinspection PyUnreachableCode
167
+ yield
207
168
 
208
- @abstractmethod
209
- def usage(self) -> Usage:
210
- """Get the usage of the request.
169
+ def get(self) -> ModelResponse:
170
+ """Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far."""
171
+ return ModelResponse(parts=self._parts_manager.get_parts(), timestamp=self.timestamp())
211
172
 
212
- NOTE: this won't return the full usage until the stream is finished.
213
- """
214
- raise NotImplementedError()
173
+ def usage(self) -> Usage:
174
+ """Get the usage of the response so far. This will not be the final usage until the stream is exhausted."""
175
+ return self._usage
215
176
 
216
177
  @abstractmethod
217
178
  def timestamp(self) -> datetime:
@@ -219,9 +180,6 @@ class StreamStructuredResponse(ABC):
219
180
  raise NotImplementedError()
220
181
 
221
182
 
222
- EitherStreamedResponse = Union[StreamTextResponse, StreamStructuredResponse]
223
-
224
-
225
183
  ALLOW_MODEL_REQUESTS = True
226
184
  """Whether to allow requests to models.
227
185
 
@@ -274,6 +232,15 @@ def infer_model(model: Model | KnownModelName) -> Model:
274
232
  from .openai import OpenAIModel
275
233
 
276
234
  return OpenAIModel(model[7:])
235
+ elif model.startswith(('gpt', 'o1')):
236
+ from .openai import OpenAIModel
237
+
238
+ return OpenAIModel(model)
239
+ elif model.startswith('google-gla'):
240
+ from .gemini import GeminiModel
241
+
242
+ return GeminiModel(model[11:]) # pyright: ignore[reportArgumentType]
243
+ # backwards compatibility with old model names (ex, gemini-1.5-flash -> google-gla:gemini-1.5-flash)
277
244
  elif model.startswith('gemini'):
278
245
  from .gemini import GeminiModel
279
246
 
@@ -283,6 +250,11 @@ def infer_model(model: Model | KnownModelName) -> Model:
283
250
  from .groq import GroqModel
284
251
 
285
252
  return GroqModel(model[5:]) # pyright: ignore[reportArgumentType]
253
+ elif model.startswith('google-vertex'):
254
+ from .vertexai import VertexAIModel
255
+
256
+ return VertexAIModel(model[14:]) # pyright: ignore[reportArgumentType]
257
+ # backwards compatibility with old model names (ex, vertexai:gemini-1.5-flash -> google-vertex:gemini-1.5-flash)
286
258
  elif model.startswith('vertexai:'):
287
259
  from .vertexai import VertexAIModel
288
260
 
@@ -295,6 +267,11 @@ def infer_model(model: Model | KnownModelName) -> Model:
295
267
  from .ollama import OllamaModel
296
268
 
297
269
  return OllamaModel(model[7:])
270
+ elif model.startswith('anthropic'):
271
+ from .anthropic import AnthropicModel
272
+
273
+ return AnthropicModel(model[10:])
274
+ # backwards compatibility with old model names (ex, claude-3-5-sonnet-latest -> anthropic:claude-3-5-sonnet-latest)
298
275
  elif model.startswith('claude'):
299
276
  from .anthropic import AnthropicModel
300
277
 
@@ -8,7 +8,7 @@ from typing import Any, Literal, Union, cast, overload
8
8
  from httpx import AsyncClient as AsyncHTTPClient
9
9
  from typing_extensions import assert_never
10
10
 
11
- from .. import result
11
+ from .. import usage
12
12
  from .._utils import guard_tool_call_id as _guard_tool_call_id
13
13
  from ..messages import (
14
14
  ArgsDict,
@@ -27,8 +27,8 @@ from ..settings import ModelSettings
27
27
  from ..tools import ToolDefinition
28
28
  from . import (
29
29
  AgentModel,
30
- EitherStreamedResponse,
31
30
  Model,
31
+ StreamedResponse,
32
32
  cached_async_http_client,
33
33
  check_allow_model_requests,
34
34
  )
@@ -136,7 +136,7 @@ class AnthropicModel(Model):
136
136
  )
137
137
 
138
138
  def name(self) -> str:
139
- return self.model_name
139
+ return f'anthropic:{self.model_name}'
140
140
 
141
141
  @staticmethod
142
142
  def _map_tool_definition(f: ToolDefinition) -> ToolParam:
@@ -158,14 +158,14 @@ class AnthropicAgentModel(AgentModel):
158
158
 
159
159
  async def request(
160
160
  self, messages: list[ModelMessage], model_settings: ModelSettings | None
161
- ) -> tuple[ModelResponse, result.Usage]:
161
+ ) -> tuple[ModelResponse, usage.Usage]:
162
162
  response = await self._messages_create(messages, False, model_settings)
163
163
  return self._process_response(response), _map_usage(response)
164
164
 
165
165
  @asynccontextmanager
166
166
  async def request_stream(
167
167
  self, messages: list[ModelMessage], model_settings: ModelSettings | None
168
- ) -> AsyncIterator[EitherStreamedResponse]:
168
+ ) -> AsyncIterator[StreamedResponse]:
169
169
  response = await self._messages_create(messages, True, model_settings)
170
170
  async with response:
171
171
  yield await self._process_streamed_response(response)
@@ -216,28 +216,28 @@ class AnthropicAgentModel(AgentModel):
216
216
  items: list[ModelResponsePart] = []
217
217
  for item in response.content:
218
218
  if isinstance(item, TextBlock):
219
- items.append(TextPart(item.text))
219
+ items.append(TextPart(content=item.text))
220
220
  else:
221
221
  assert isinstance(item, ToolUseBlock), 'unexpected item type'
222
222
  items.append(
223
223
  ToolCallPart.from_raw_args(
224
- item.name,
225
- cast(dict[str, Any], item.input),
226
- item.id,
224
+ tool_name=item.name,
225
+ args=cast(dict[str, Any], item.input),
226
+ tool_call_id=item.id,
227
227
  )
228
228
  )
229
229
 
230
230
  return ModelResponse(items)
231
231
 
232
232
  @staticmethod
233
- async def _process_streamed_response(response: AsyncStream[RawMessageStreamEvent]) -> EitherStreamedResponse:
233
+ async def _process_streamed_response(response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
234
234
  """TODO: Process a streamed response, and prepare a streaming response to return."""
235
235
  # We don't yet support streamed responses from Anthropic, so we raise an error here for now.
236
236
  # Streamed responses will be supported in a future release.
237
237
 
238
238
  raise RuntimeError('Streamed responses are not yet supported for Anthropic models.')
239
239
 
240
- # Should be returning some sort of AnthropicStreamTextResponse or AnthropicStreamStructuredResponse
240
+ # Should be returning some sort of AnthropicStreamTextResponse or AnthropicStreamedResponse
241
241
  # depending on the type of chunk we get, but we need to establish how we handle (and when we get) the following:
242
242
  # RawMessageStartEvent
243
243
  # RawMessageDeltaEvent
@@ -315,30 +315,30 @@ def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
315
315
  )
316
316
 
317
317
 
318
- def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> result.Usage:
318
+ def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage:
319
319
  if isinstance(message, AnthropicMessage):
320
- usage = message.usage
320
+ response_usage = message.usage
321
321
  else:
322
322
  if isinstance(message, RawMessageStartEvent):
323
- usage = message.message.usage
323
+ response_usage = message.message.usage
324
324
  elif isinstance(message, RawMessageDeltaEvent):
325
- usage = message.usage
325
+ response_usage = message.usage
326
326
  else:
327
327
  # No usage information provided in:
328
328
  # - RawMessageStopEvent
329
329
  # - RawContentBlockStartEvent
330
330
  # - RawContentBlockDeltaEvent
331
331
  # - RawContentBlockStopEvent
332
- usage = None
332
+ response_usage = None
333
333
 
334
- if usage is None:
335
- return result.Usage()
334
+ if response_usage is None:
335
+ return usage.Usage()
336
336
 
337
- request_tokens = getattr(usage, 'input_tokens', None)
337
+ request_tokens = getattr(response_usage, 'input_tokens', None)
338
338
 
339
- return result.Usage(
339
+ return usage.Usage(
340
340
  # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence this getattr
341
341
  request_tokens=request_tokens,
342
- response_tokens=usage.output_tokens,
343
- total_tokens=(request_tokens or 0) + usage.output_tokens,
342
+ response_tokens=response_usage.output_tokens,
343
+ total_tokens=(request_tokens or 0) + response_usage.output_tokens,
344
344
  )