pydantic-ai-slim 0.0.12__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


pydantic_ai/messages.py CHANGED
@@ -6,14 +6,13 @@ from typing import Annotated, Any, Literal, Union
 
 import pydantic
 import pydantic_core
-from pydantic import TypeAdapter
+from typing_extensions import Self
 
-from . import _pydantic
 from ._utils import now_utc as _now_utc
 
 
 @dataclass
-class SystemPrompt:
+class SystemPromptPart:
     """A system prompt, generally written by the application developer.
 
     This gives the model context and guidance on how to respond.
@@ -21,12 +20,13 @@ class SystemPrompt:
 
     content: str
     """The content of the prompt."""
-    role: Literal['system'] = 'system'
-    """Message type identifier, this type is available on all message as a discriminator."""
+
+    part_kind: Literal['system-prompt'] = 'system-prompt'
+    """Part type identifier, this is available on all parts as a discriminator."""
 
 
 @dataclass
-class UserPrompt:
+class UserPromptPart:
     """A user prompt, generally written by the end user.
 
     Content comes from the `user_prompt` parameter of [`Agent.run`][pydantic_ai.Agent.run],
@@ -35,29 +35,35 @@ class UserPrompt:
 
     content: str
     """The content of the prompt."""
+
     timestamp: datetime = field(default_factory=_now_utc)
     """The timestamp of the prompt."""
-    role: Literal['user'] = 'user'
-    """Message type identifier, this type is available on all message as a discriminator."""
+
+    part_kind: Literal['user-prompt'] = 'user-prompt'
+    """Part type identifier, this is available on all parts as a discriminator."""
 
 
-tool_return_ta: TypeAdapter[Any] = TypeAdapter(Any)
+tool_return_ta: pydantic.TypeAdapter[Any] = pydantic.TypeAdapter(Any, config=pydantic.ConfigDict(defer_build=True))
 
 
 @dataclass
-class ToolReturn:
+class ToolReturnPart:
     """A tool return message, this encodes the result of running a tool."""
 
     tool_name: str
     """The name of the "tool" was called."""
+
     content: Any
     """The return value."""
-    tool_id: str | None = None
-    """Optional tool identifier, this is used by some models including OpenAI."""
+
+    tool_call_id: str | None = None
+    """Optional tool call identifier, this is used by some models including OpenAI."""
+
     timestamp: datetime = field(default_factory=_now_utc)
     """The timestamp, when the tool returned."""
-    role: Literal['tool-return'] = 'tool-return'
-    """Message type identifier, this type is available on all message as a discriminator."""
+
+    part_kind: Literal['tool-return'] = 'tool-return'
+    """Part type identifier, this is available on all parts as a discriminator."""
 
     def model_response_str(self) -> str:
         if isinstance(self.content, str):
@@ -73,11 +79,11 @@ class ToolReturn:
         return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}
 
 
-ErrorDetailsTa = _pydantic.LazyTypeAdapter(list[pydantic_core.ErrorDetails])
+error_details_ta = pydantic.TypeAdapter(list[pydantic_core.ErrorDetails], config=pydantic.ConfigDict(defer_build=True))
 
 
 @dataclass
-class RetryPrompt:
+class RetryPromptPart:
     """A message back to a model asking it to try again.
 
     This can be sent for a number of reasons:
@@ -98,37 +104,54 @@ class RetryPrompt:
     If the retry was triggered by a [`ValidationError`][pydantic_core.ValidationError], this will be a list of
     error details.
     """
+
     tool_name: str | None = None
     """The name of the tool that was called, if any."""
-    tool_id: str | None = None
-    """The tool identifier, if any."""
+
+    tool_call_id: str | None = None
+    """Optional tool call identifier, this is used by some models including OpenAI."""
+
     timestamp: datetime = field(default_factory=_now_utc)
     """The timestamp, when the retry was triggered."""
-    role: Literal['retry-prompt'] = 'retry-prompt'
-    """Message type identifier, this type is available on all message as a discriminator."""
+
+    part_kind: Literal['retry-prompt'] = 'retry-prompt'
+    """Part type identifier, this is available on all parts as a discriminator."""
 
     def model_response(self) -> str:
         if isinstance(self.content, str):
             description = self.content
         else:
-            json_errors = ErrorDetailsTa.dump_json(self.content, exclude={'__all__': {'ctx'}}, indent=2)
+            json_errors = error_details_ta.dump_json(self.content, exclude={'__all__': {'ctx'}}, indent=2)
             description = f'{len(self.content)} validation errors: {json_errors.decode()}'
         return f'{description}\n\nFix the errors and try again.'
 
 
+ModelRequestPart = Annotated[
+    Union[SystemPromptPart, UserPromptPart, ToolReturnPart, RetryPromptPart], pydantic.Discriminator('part_kind')
+]
+"""A message part sent by PydanticAI to a model."""
+
+
 @dataclass
-class ModelTextResponse:
+class ModelRequest:
+    """A request generated by PydanticAI and sent to a model, e.g. a message from the PydanticAI app to the model."""
+
+    parts: list[ModelRequestPart]
+    """The parts of the user message."""
+
+    kind: Literal['request'] = 'request'
+    """Message type identifier, this is available on all parts as a discriminator."""
+
+
+@dataclass
+class TextPart:
     """A plain text response from a model."""
 
     content: str
     """The text content of the response."""
-    timestamp: datetime = field(default_factory=_now_utc)
-    """The timestamp of the response.
 
-    If the model provides a timestamp in the response (as OpenAI does) that will be used.
-    """
-    role: Literal['model-text-response'] = 'model-text-response'
-    """Message type identifier, this type is available on all message as a discriminator."""
+    part_kind: Literal['text'] = 'text'
+    """Part type identifier, this is available on all parts as a discriminator."""
 
 
 @dataclass
@@ -148,26 +171,31 @@ class ArgsDict:
 
 
 @dataclass
-class ToolCall:
-    """Either a tool call from the agent."""
+class ToolCallPart:
+    """A tool call from a model."""
 
     tool_name: str
     """The name of the tool to call."""
+
     args: ArgsJson | ArgsDict
     """The arguments to pass to the tool.
 
    Either as JSON or a Python dictionary depending on how data was returned.
    """
-    tool_id: str | None = None
-    """Optional tool identifier, this is used by some models including OpenAI."""
+
+    tool_call_id: str | None = None
+    """Optional tool call identifier, this is used by some models including OpenAI."""
+
+    part_kind: Literal['tool-call'] = 'tool-call'
+    """Part type identifier, this is available on all parts as a discriminator."""
 
     @classmethod
-    def from_json(cls, tool_name: str, args_json: str, tool_id: str | None = None) -> ToolCall:
-        return cls(tool_name, ArgsJson(args_json), tool_id)
+    def from_json(cls, tool_name: str, args_json: str, tool_call_id: str | None = None) -> Self:
+        return cls(tool_name, ArgsJson(args_json), tool_call_id)
 
     @classmethod
-    def from_dict(cls, tool_name: str, args_dict: dict[str, Any]) -> ToolCall:
-        return cls(tool_name, ArgsDict(args_dict))
+    def from_dict(cls, tool_name: str, args_dict: dict[str, Any], tool_call_id: str | None = None) -> Self:
+        return cls(tool_name, ArgsDict(args_dict), tool_call_id)
 
     def has_content(self) -> bool:
         if isinstance(self.args, ArgsDict):
@@ -176,28 +204,39 @@ class ToolCall:
         return bool(self.args.args_json)
 
 
+ModelResponsePart = Annotated[Union[TextPart, ToolCallPart], pydantic.Discriminator('part_kind')]
+"""A message part returned by a model."""
+
+
 @dataclass
-class ModelStructuredResponse:
-    """A structured response from a model.
+class ModelResponse:
+    """A response from a model, e.g. a message from the model to the PydanticAI app."""
 
-    This is used either to call a tool or to return a structured response from an agent run.
-    """
+    parts: list[ModelResponsePart]
+    """The parts of the model message."""
 
-    calls: list[ToolCall]
-    """The tool calls being made."""
     timestamp: datetime = field(default_factory=_now_utc)
     """The timestamp of the response.
 
     If the model provides a timestamp in the response (as OpenAI does) that will be used.
     """
-    role: Literal['model-structured-response'] = 'model-structured-response'
-    """Message type identifier, this type is available on all message as a discriminator."""
+
+    kind: Literal['response'] = 'response'
+    """Message type identifier, this is available on all parts as a discriminator."""
+
+    @classmethod
+    def from_text(cls, content: str, timestamp: datetime | None = None) -> Self:
+        return cls([TextPart(content)], timestamp=timestamp or _now_utc())
+
+    @classmethod
+    def from_tool_call(cls, tool_call: ToolCallPart) -> Self:
+        return cls([tool_call])
 
 
-ModelAnyResponse = Union[ModelTextResponse, ModelStructuredResponse]
-"""Any response from a model."""
-Message = Union[SystemPrompt, UserPrompt, ToolReturn, RetryPrompt, ModelTextResponse, ModelStructuredResponse]
+ModelMessage = Union[ModelRequest, ModelResponse]
 """Any message send to or returned by a model."""
 
-MessagesTypeAdapter = _pydantic.LazyTypeAdapter(list[Annotated[Message, pydantic.Field(discriminator='role')]])
+ModelMessagesTypeAdapter = pydantic.TypeAdapter(
+    list[Annotated[ModelMessage, pydantic.Discriminator('kind')]], config=pydantic.ConfigDict(defer_build=True)
+)
 """Pydantic [`TypeAdapter`][pydantic.type_adapter.TypeAdapter] for (de)serializing messages."""
pydantic_ai/models/__init__.py CHANGED
@@ -16,7 +16,8 @@ from typing import TYPE_CHECKING, Literal, Union
 import httpx
 
 from ..exceptions import UserError
-from ..messages import Message, ModelAnyResponse, ModelStructuredResponse
+from ..messages import ModelMessage, ModelResponse
+from ..settings import ModelSettings
 
 if TYPE_CHECKING:
     from ..result import Cost
@@ -31,6 +32,7 @@ KnownModelName = Literal[
     'openai:o1-preview',
     'openai:o1-mini',
     'openai:gpt-3.5-turbo',
+    'groq:llama-3.3-70b-versatile',
     'groq:llama-3.1-70b-versatile',
     'groq:llama3-groq-70b-8192-tool-use-preview',
     'groq:llama3-groq-8b-8192-tool-use-preview',
@@ -47,8 +49,15 @@ KnownModelName = Literal[
     'groq:gemma-7b-it',
     'gemini-1.5-flash',
     'gemini-1.5-pro',
+    'gemini-2.0-flash-exp',
     'vertexai:gemini-1.5-flash',
     'vertexai:gemini-1.5-pro',
+    # since mistral models are supported by other providers (e.g. ollama), and some of their models (e.g. "codestral")
+    # don't start with "mistral", we add the "mistral:" prefix to all to be explicit
+    'mistral:mistral-small-latest',
+    'mistral:mistral-large-latest',
+    'mistral:codestral-latest',
+    'mistral:mistral-moderation-latest',
     'ollama:codellama',
     'ollama:gemma',
     'ollama:gemma2',
@@ -66,6 +75,9 @@ KnownModelName = Literal[
     'ollama:qwen2',
     'ollama:qwen2.5',
     'ollama:starcoder2',
+    'claude-3-5-haiku-latest',
+    'claude-3-5-sonnet-latest',
+    'claude-3-opus-latest',
     'test',
 ]
 """Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -108,12 +120,16 @@ class AgentModel(ABC):
     """Model configured for each step of an Agent run."""
 
     @abstractmethod
-    async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, Cost]:
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, Cost]:
         """Make a request to the model."""
         raise NotImplementedError()
 
     @asynccontextmanager
-    async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
         """Make a request to the model and return a streaming response."""
         raise NotImplementedError(f'Streamed requests not supported by this {self.__class__.__name__}')
         # yield is required to make this a generator for type checking
@@ -178,10 +194,10 @@ class StreamStructuredResponse(ABC):
         raise NotImplementedError()
 
     @abstractmethod
-    def get(self, *, final: bool = False) -> ModelStructuredResponse:
-        """Get the `ModelStructuredResponse` at this point.
+    def get(self, *, final: bool = False) -> ModelResponse:
+        """Get the `ModelResponse` at this point.
 
-        The `ModelStructuredResponse` may or may not be complete, depending on whether the stream is finished.
+        The `ModelResponse` may or may not be complete, depending on whether the stream is finished.
 
         Args:
             final: If True, this is the final call, after iteration is complete, the response should be fully validated.
@@ -270,10 +286,18 @@ def infer_model(model: Model | KnownModelName) -> Model:
         from .vertexai import VertexAIModel
 
         return VertexAIModel(model[9:])  # pyright: ignore[reportArgumentType]
+    elif model.startswith('mistral:'):
+        from .mistral import MistralModel
+
+        return MistralModel(model[8:])
     elif model.startswith('ollama:'):
         from .ollama import OllamaModel
 
         return OllamaModel(model[7:])
+    elif model.startswith('claude'):
+        from .anthropic import AnthropicModel
+
+        return AnthropicModel(model)
     else:
         raise UserError(f'Unknown model: {model}')
 
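The net effect of the `infer_model` changes: a `mistral:` prefix is stripped (`model[8:]`) before constructing `MistralModel`, while bare `claude...` names are passed through whole to `AnthropicModel`. A quick sketch under those assumptions, using names from the list above:

from pydantic_ai.models import infer_model

mistral = infer_model('mistral:mistral-large-latest')  # -> MistralModel('mistral-large-latest')
claude = infer_model('claude-3-5-sonnet-latest')       # -> AnthropicModel('claude-3-5-sonnet-latest')
ollama = infer_model('ollama:qwen2.5')                 # unchanged: -> OllamaModel('qwen2.5')

Note also that `request` and `request_stream` now take a `model_settings` mapping (or None), which implementations thread through to the provider client; the Anthropic implementation below uses it for `max_tokens`, `temperature`, `top_p` and `timeout`.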
pydantic_ai/models/anthropic.py ADDED
@@ -0,0 +1,344 @@
+from __future__ import annotations as _annotations
+
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+from dataclasses import dataclass, field
+from typing import Any, Literal, Union, cast, overload
+
+from httpx import AsyncClient as AsyncHTTPClient
+from typing_extensions import assert_never
+
+from .. import result
+from .._utils import guard_tool_call_id as _guard_tool_call_id
+from ..messages import (
+    ArgsDict,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
+)
+from ..settings import ModelSettings
+from ..tools import ToolDefinition
+from . import (
+    AgentModel,
+    EitherStreamedResponse,
+    Model,
+    cached_async_http_client,
+    check_allow_model_requests,
+)
+
+try:
+    from anthropic import NOT_GIVEN, AsyncAnthropic, AsyncStream
+    from anthropic.types import (
+        Message as AnthropicMessage,
+        MessageParam,
+        RawMessageDeltaEvent,
+        RawMessageStartEvent,
+        RawMessageStreamEvent,
+        TextBlock,
+        TextBlockParam,
+        ToolChoiceParam,
+        ToolParam,
+        ToolResultBlockParam,
+        ToolUseBlock,
+        ToolUseBlockParam,
+    )
+except ImportError as _import_error:
+    raise ImportError(
+        'Please install `anthropic` to use the Anthropic model, '
+        "you can use the `anthropic` optional group — `pip install 'pydantic-ai-slim[anthropic]'`"
+    ) from _import_error
+
+LatestAnthropicModelNames = Literal[
+    'claude-3-5-haiku-latest',
+    'claude-3-5-sonnet-latest',
+    'claude-3-opus-latest',
+]
+"""Latest named Anthropic models."""
+
+AnthropicModelName = Union[str, LatestAnthropicModelNames]
+"""Possible Anthropic model names.
+
+Since Anthropic supports a variety of date-stamped models, we explicitly list the latest models but
+allow any name in the type hints.
+Since [the Anthropic docs](https://docs.anthropic.com/en/docs/about-claude/models) for a full list.
+"""
+
+
+@dataclass(init=False)
+class AnthropicModel(Model):
+    """A model that uses the Anthropic API.
+
+    Internally, this uses the [Anthropic Python client](https://github.com/anthropics/anthropic-sdk-python) to interact with the API.
+
+    Apart from `__init__`, all methods are private or match those of the base class.
+
+    !!! note
+        The `AnthropicModel` class does not yet support streaming responses.
+        We anticipate adding support for streaming responses in a near-term future release.
+    """
+
+    model_name: AnthropicModelName
+    client: AsyncAnthropic = field(repr=False)
+
+    def __init__(
+        self,
+        model_name: AnthropicModelName,
+        *,
+        api_key: str | None = None,
+        anthropic_client: AsyncAnthropic | None = None,
+        http_client: AsyncHTTPClient | None = None,
+    ):
+        """Initialize an Anthropic model.
+
+        Args:
+            model_name: The name of the Anthropic model to use. List of model names available
+                [here](https://docs.anthropic.com/en/docs/about-claude/models).
+            api_key: The API key to use for authentication, if not provided, the `ANTHROPIC_API_KEY` environment variable
+                will be used if available.
+            anthropic_client: An existing
+                [`AsyncAnthropic`](https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#async-usage)
+                client to use, if provided, `api_key` and `http_client` must be `None`.
+            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+        """
+        self.model_name = model_name
+        if anthropic_client is not None:
+            assert http_client is None, 'Cannot provide both `anthropic_client` and `http_client`'
+            assert api_key is None, 'Cannot provide both `anthropic_client` and `api_key`'
+            self.client = anthropic_client
+        elif http_client is not None:
+            self.client = AsyncAnthropic(api_key=api_key, http_client=http_client)
+        else:
+            self.client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client())
+
+    async def agent_model(
+        self,
+        *,
+        function_tools: list[ToolDefinition],
+        allow_text_result: bool,
+        result_tools: list[ToolDefinition],
+    ) -> AgentModel:
+        check_allow_model_requests()
+        tools = [self._map_tool_definition(r) for r in function_tools]
+        if result_tools:
+            tools += [self._map_tool_definition(r) for r in result_tools]
+        return AnthropicAgentModel(
+            self.client,
+            self.model_name,
+            allow_text_result,
+            tools,
+        )
+
+    def name(self) -> str:
+        return self.model_name
+
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> ToolParam:
+        return {
+            'name': f.name,
+            'description': f.description,
+            'input_schema': f.parameters_json_schema,
+        }
+
+
+@dataclass
+class AnthropicAgentModel(AgentModel):
+    """Implementation of `AgentModel` for Anthropic models."""
+
+    client: AsyncAnthropic
+    model_name: str
+    allow_text_result: bool
+    tools: list[ToolParam]
+
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Cost]:
+        response = await self._messages_create(messages, False, model_settings)
+        return self._process_response(response), _map_cost(response)
+
+    @asynccontextmanager
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
+        response = await self._messages_create(messages, True, model_settings)
+        async with response:
+            yield await self._process_streamed_response(response)
+
+    @overload
+    async def _messages_create(
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
+    ) -> AsyncStream[RawMessageStreamEvent]:
+        pass
+
+    @overload
+    async def _messages_create(
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+    ) -> AnthropicMessage:
+        pass
+
+    async def _messages_create(
+        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
+    ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
+        # standalone function to make it easier to override
+        if not self.tools:
+            tool_choice: ToolChoiceParam | None = None
+        elif not self.allow_text_result:
+            tool_choice = {'type': 'any'}
+        else:
+            tool_choice = {'type': 'auto'}
+
+        system_prompt, anthropic_messages = self._map_message(messages)
+
+        model_settings = model_settings or {}
+
+        return await self.client.messages.create(
+            max_tokens=model_settings.get('max_tokens', 1024),
+            system=system_prompt or NOT_GIVEN,
+            messages=anthropic_messages,
+            model=self.model_name,
+            tools=self.tools or NOT_GIVEN,
+            tool_choice=tool_choice or NOT_GIVEN,
+            stream=stream,
+            temperature=model_settings.get('temperature', NOT_GIVEN),
+            top_p=model_settings.get('top_p', NOT_GIVEN),
+            timeout=model_settings.get('timeout', NOT_GIVEN),
+        )
+
+    @staticmethod
+    def _process_response(response: AnthropicMessage) -> ModelResponse:
+        """Process a non-streamed response, and prepare a message to return."""
+        items: list[ModelResponsePart] = []
+        for item in response.content:
+            if isinstance(item, TextBlock):
+                items.append(TextPart(item.text))
+            else:
+                assert isinstance(item, ToolUseBlock), 'unexpected item type'
+                items.append(
+                    ToolCallPart.from_dict(
+                        item.name,
+                        cast(dict[str, Any], item.input),
+                        item.id,
+                    )
+                )
+
+        return ModelResponse(items)
+
+    @staticmethod
+    async def _process_streamed_response(response: AsyncStream[RawMessageStreamEvent]) -> EitherStreamedResponse:
+        """TODO: Process a streamed response, and prepare a streaming response to return."""
+        # We don't yet support streamed responses from Anthropic, so we raise an error here for now.
+        # Streamed responses will be supported in a future release.
+
+        raise RuntimeError('Streamed responses are not yet supported for Anthropic models.')
+
+        # Should be returning some sort of AnthropicStreamTextResponse or AnthropicStreamStructuredResponse
+        # depending on the type of chunk we get, but we need to establish how we handle (and when we get) the following:
+        # RawMessageStartEvent
+        # RawMessageDeltaEvent
+        # RawMessageStopEvent
+        # RawContentBlockStartEvent
+        # RawContentBlockDeltaEvent
+        # RawContentBlockDeltaEvent
+        #
+        # We might refactor streaming internally before we implement this...
+
+    @staticmethod
+    def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
+        """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
+        system_prompt: str = ''
+        anthropic_messages: list[MessageParam] = []
+        for m in messages:
+            if isinstance(m, ModelRequest):
+                for part in m.parts:
+                    if isinstance(part, SystemPromptPart):
+                        system_prompt += part.content
+                    elif isinstance(part, UserPromptPart):
+                        anthropic_messages.append(MessageParam(role='user', content=part.content))
+                    elif isinstance(part, ToolReturnPart):
+                        anthropic_messages.append(
+                            MessageParam(
+                                role='user',
+                                content=[
+                                    ToolResultBlockParam(
+                                        tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
+                                        type='tool_result',
+                                        content=part.model_response_str(),
+                                        is_error=False,
+                                    )
+                                ],
+                            )
+                        )
+                    elif isinstance(part, RetryPromptPart):
+                        if part.tool_name is None:
+                            anthropic_messages.append(MessageParam(role='user', content=part.model_response()))
+                        else:
+                            anthropic_messages.append(
+                                MessageParam(
+                                    role='user',
+                                    content=[
+                                        ToolUseBlockParam(
+                                            id=_guard_tool_call_id(t=part, model_source='Anthropic'),
+                                            input=part.model_response(),
+                                            name=part.tool_name,
+                                            type='tool_use',
+                                        ),
+                                    ],
+                                )
+                            )
+            elif isinstance(m, ModelResponse):
+                content: list[TextBlockParam | ToolUseBlockParam] = []
+                for item in m.parts:
+                    if isinstance(item, TextPart):
+                        content.append(TextBlockParam(text=item.content, type='text'))
+                    else:
+                        assert isinstance(item, ToolCallPart)
+                        content.append(_map_tool_call(item))
+                anthropic_messages.append(MessageParam(role='assistant', content=content))
+            else:
+                assert_never(m)
+        return system_prompt, anthropic_messages
+
+
+def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
+    assert isinstance(t.args, ArgsDict), f'Expected ArgsDict, got {t.args}'
+    return ToolUseBlockParam(
+        id=_guard_tool_call_id(t=t, model_source='Anthropic'),
+        type='tool_use',
+        name=t.tool_name,
+        input=t.args.args_dict,
+    )
+
+
+def _map_cost(message: AnthropicMessage | RawMessageStreamEvent) -> result.Cost:
+    if isinstance(message, AnthropicMessage):
+        usage = message.usage
+    else:
+        if isinstance(message, RawMessageStartEvent):
+            usage = message.message.usage
+        elif isinstance(message, RawMessageDeltaEvent):
+            usage = message.usage
+        else:
+            # No usage information provided in:
+            # - RawMessageStopEvent
+            # - RawContentBlockStartEvent
+            # - RawContentBlockDeltaEvent
+            # - RawContentBlockStopEvent
+            usage = None
+
+    if usage is None:
+        return result.Cost()
+
+    request_tokens = getattr(usage, 'input_tokens', None)
+
+    return result.Cost(
+        # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence this getattr
+        request_tokens=request_tokens,
+        response_tokens=usage.output_tokens,
+        total_tokens=(request_tokens or 0) + usage.output_tokens,
+    )