pydantic-ai-slim 0.0.12__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
Potentially problematic release: this version of pydantic-ai-slim might be problematic.
- pydantic_ai/_pydantic.py +7 -25
- pydantic_ai/_result.py +34 -16
- pydantic_ai/_system_prompt.py +1 -1
- pydantic_ai/_utils.py +9 -2
- pydantic_ai/agent.py +333 -148
- pydantic_ai/messages.py +87 -48
- pydantic_ai/models/__init__.py +30 -6
- pydantic_ai/models/anthropic.py +344 -0
- pydantic_ai/models/function.py +59 -31
- pydantic_ai/models/gemini.py +150 -108
- pydantic_ai/models/groq.py +94 -74
- pydantic_ai/models/mistral.py +680 -0
- pydantic_ai/models/ollama.py +1 -1
- pydantic_ai/models/openai.py +102 -76
- pydantic_ai/models/test.py +62 -51
- pydantic_ai/models/vertexai.py +7 -3
- pydantic_ai/result.py +35 -37
- pydantic_ai/settings.py +72 -0
- pydantic_ai/tools.py +28 -18
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.13.dist-info}/METADATA +8 -3
- pydantic_ai_slim-0.0.13.dist-info/RECORD +26 -0
- {pydantic_ai_slim-0.0.12.dist-info → pydantic_ai_slim-0.0.13.dist-info}/WHEEL +1 -1
- pydantic_ai_slim-0.0.12.dist-info/RECORD +0 -23
pydantic_ai/messages.py
CHANGED
@@ -6,14 +6,13 @@ from typing import Annotated, Any, Literal, Union
 
 import pydantic
 import pydantic_core
-from pydantic import TypeAdapter
+from typing_extensions import Self
 
-from . import _pydantic
 from ._utils import now_utc as _now_utc
 
 
 @dataclass
-class SystemPrompt:
+class SystemPromptPart:
     """A system prompt, generally written by the application developer.
 
     This gives the model context and guidance on how to respond.
@@ -21,12 +20,13 @@ class SystemPrompt:
 
     content: str
     """The content of the prompt."""
-    role: Literal['system'] = 'system'
-    """Message type identifier, this type is available on all message types as a discriminator."""
+
+    part_kind: Literal['system-prompt'] = 'system-prompt'
+    """Part type identifier, this is available on all parts as a discriminator."""
 
 
 @dataclass
-class UserPrompt:
+class UserPromptPart:
     """A user prompt, generally written by the end user.
 
     Content comes from the `user_prompt` parameter of [`Agent.run`][pydantic_ai.Agent.run],
@@ -35,29 +35,35 @@ class UserPrompt:
 
     content: str
     """The content of the prompt."""
+
     timestamp: datetime = field(default_factory=_now_utc)
     """The timestamp of the prompt."""
-    role: Literal['user'] = 'user'
-    """Message type identifier, this type is available on all message types as a discriminator."""
+
+    part_kind: Literal['user-prompt'] = 'user-prompt'
+    """Part type identifier, this is available on all parts as a discriminator."""
 
 
-tool_return_ta: TypeAdapter[Any] = TypeAdapter(Any)
+tool_return_ta: pydantic.TypeAdapter[Any] = pydantic.TypeAdapter(Any, config=pydantic.ConfigDict(defer_build=True))
 
 
 @dataclass
-class ToolReturn:
+class ToolReturnPart:
     """A tool return message, this encodes the result of running a tool."""
 
     tool_name: str
     """The name of the "tool" was called."""
+
     content: Any
     """The return value."""
-    tool_id: str | None = None
-    """Optional tool identifier, this is used by some models including OpenAI."""
+
+    tool_call_id: str | None = None
+    """Optional tool call identifier, this is used by some models including OpenAI."""
+
     timestamp: datetime = field(default_factory=_now_utc)
     """The timestamp, when the tool returned."""
-    role: Literal['tool-return'] = 'tool-return'
-    """Message type identifier, this type is available on all message types as a discriminator."""
+
+    part_kind: Literal['tool-return'] = 'tool-return'
+    """Part type identifier, this is available on all parts as a discriminator."""
 
     def model_response_str(self) -> str:
         if isinstance(self.content, str):
@@ -73,11 +79,11 @@ class ToolReturn:
         return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}
 
 
-error_details_ta = TypeAdapter(list[pydantic_core.ErrorDetails])
+error_details_ta = pydantic.TypeAdapter(list[pydantic_core.ErrorDetails], config=pydantic.ConfigDict(defer_build=True))
 
 
 @dataclass
-class RetryPrompt:
+class RetryPromptPart:
     """A message back to a model asking it to try again.
 
     This can be sent for a number of reasons:
@@ -98,37 +104,54 @@ class RetryPrompt:
     If the retry was triggered by a [`ValidationError`][pydantic_core.ValidationError], this will be a list of
     error details.
     """
+
     tool_name: str | None = None
     """The name of the tool that was called, if any."""
-    tool_id: str | None = None
-    """Optional tool identifier, this is used by some models including OpenAI."""
+
+    tool_call_id: str | None = None
+    """Optional tool call identifier, this is used by some models including OpenAI."""
+
     timestamp: datetime = field(default_factory=_now_utc)
     """The timestamp, when the retry was triggered."""
-    role: Literal['retry-prompt'] = 'retry-prompt'
-    """Message type identifier, this type is available on all message types as a discriminator."""
+
+    part_kind: Literal['retry-prompt'] = 'retry-prompt'
+    """Part type identifier, this is available on all parts as a discriminator."""
 
     def model_response(self) -> str:
         if isinstance(self.content, str):
             description = self.content
         else:
-            json_errors = …
+            json_errors = error_details_ta.dump_json(self.content, exclude={'__all__': {'ctx'}}, indent=2)
             description = f'{len(self.content)} validation errors: {json_errors.decode()}'
         return f'{description}\n\nFix the errors and try again.'
 
 
+ModelRequestPart = Annotated[
+    Union[SystemPromptPart, UserPromptPart, ToolReturnPart, RetryPromptPart], pydantic.Discriminator('part_kind')
+]
+"""A message part sent by PydanticAI to a model."""
+
+
 @dataclass
-class ModelTextResponse:
+class ModelRequest:
+    """A request generated by PydanticAI and sent to a model, e.g. a message from the PydanticAI app to the model."""
+
+    parts: list[ModelRequestPart]
+    """The parts of the user message."""
+
+    kind: Literal['request'] = 'request'
+    """Message type identifier, this is available on all parts as a discriminator."""
+
+
+@dataclass
+class TextPart:
     """A plain text response from a model."""
 
     content: str
     """The text content of the response."""
-    timestamp: datetime = field(default_factory=_now_utc)
-    """The timestamp of the response.
 
-    If the model provides a timestamp in the response (as OpenAI does) that will be used.
-    """
-    role: Literal['model-text-response'] = 'model-text-response'
-    """Message type identifier, this type is available on all message as a discriminator."""
+    part_kind: Literal['text'] = 'text'
+    """Part type identifier, this is available on all parts as a discriminator."""
 
 
 @dataclass
@@ -148,26 +171,31 @@ class ArgsDict:
 
 
 @dataclass
-class ToolCall:
-    """…
+class ToolCallPart:
+    """A tool call from a model."""
 
     tool_name: str
    """The name of the tool to call."""
+
     args: ArgsJson | ArgsDict
     """The arguments to pass to the tool.
 
     Either as JSON or a Python dictionary depending on how data was returned.
     """
-    tool_id: str | None = None
-    """Optional tool identifier, this is used by some models including OpenAI."""
+
+    tool_call_id: str | None = None
+    """Optional tool call identifier, this is used by some models including OpenAI."""
+
+    part_kind: Literal['tool-call'] = 'tool-call'
+    """Part type identifier, this is available on all parts as a discriminator."""
 
     @classmethod
-    def from_json(cls, tool_name: str, args_json: str, tool_id: str | None = None) -> ToolCall:
-        return cls(tool_name, ArgsJson(args_json), tool_id)
+    def from_json(cls, tool_name: str, args_json: str, tool_call_id: str | None = None) -> Self:
+        return cls(tool_name, ArgsJson(args_json), tool_call_id)
 
     @classmethod
-    def from_dict(cls, tool_name: str, args_dict: dict[str, Any]) -> ToolCall:
-        return cls(tool_name, ArgsDict(args_dict))
+    def from_dict(cls, tool_name: str, args_dict: dict[str, Any], tool_call_id: str | None = None) -> Self:
+        return cls(tool_name, ArgsDict(args_dict), tool_call_id)
 
     def has_content(self) -> bool:
         if isinstance(self.args, ArgsDict):
@@ -176,28 +204,39 @@ class ToolCall:
         return bool(self.args.args_json)
 
 
+ModelResponsePart = Annotated[Union[TextPart, ToolCallPart], pydantic.Discriminator('part_kind')]
+"""A message part returned by a model."""
+
+
 @dataclass
-class ModelStructuredResponse:
-    """A structured response from a model.
+class ModelResponse:
+    """A response from a model, e.g. a message from the model to the PydanticAI app."""
 
-    This is used either to call a tool or to return a structured response from an agent run.
-    """
+    parts: list[ModelResponsePart]
+    """The parts of the model message."""
 
-    calls: list[ToolCall]
-    """The tool calls being made."""
     timestamp: datetime = field(default_factory=_now_utc)
     """The timestamp of the response.
 
     If the model provides a timestamp in the response (as OpenAI does) that will be used.
     """
-    role: Literal['model-structured-response'] = 'model-structured-response'
-    """Message type identifier, this type is available on all message as a discriminator."""
+
+    kind: Literal['response'] = 'response'
+    """Message type identifier, this is available on all parts as a discriminator."""
+
+    @classmethod
+    def from_text(cls, content: str, timestamp: datetime | None = None) -> Self:
+        return cls([TextPart(content)], timestamp=timestamp or _now_utc())
+
+    @classmethod
+    def from_tool_call(cls, tool_call: ToolCallPart) -> Self:
+        return cls([tool_call])
 
 
-ModelAnyResponse = Union[ModelTextResponse, ModelStructuredResponse]
-"""Any response from a model."""
-Message = Union[SystemPrompt, UserPrompt, ToolReturn, RetryPrompt, ModelTextResponse, ModelStructuredResponse]
+ModelMessage = Union[ModelRequest, ModelResponse]
 """Any message send to or returned by a model."""
 
-MessagesTypeAdapter = _pydantic.LazyTypeAdapter(list[Annotated[Message, pydantic.Discriminator('role')]])
+ModelMessagesTypeAdapter = pydantic.TypeAdapter(
+    list[Annotated[ModelMessage, pydantic.Discriminator('kind')]], config=pydantic.ConfigDict(defer_build=True)
+)
 """Pydantic [`TypeAdapter`][pydantic.type_adapter.TypeAdapter] for (de)serializing messages."""
pydantic_ai/models/__init__.py
CHANGED
@@ -16,7 +16,8 @@ from typing import TYPE_CHECKING, Literal, Union
 import httpx
 
 from ..exceptions import UserError
-from ..messages import Message, ModelAnyResponse, ModelStructuredResponse
+from ..messages import ModelMessage, ModelResponse
+from ..settings import ModelSettings
 
 if TYPE_CHECKING:
     from ..result import Cost
@@ -31,6 +32,7 @@ KnownModelName = Literal[
     'openai:o1-preview',
     'openai:o1-mini',
     'openai:gpt-3.5-turbo',
+    'groq:llama-3.3-70b-versatile',
     'groq:llama-3.1-70b-versatile',
     'groq:llama3-groq-70b-8192-tool-use-preview',
     'groq:llama3-groq-8b-8192-tool-use-preview',
@@ -47,8 +49,15 @@ KnownModelName = Literal[
     'groq:gemma-7b-it',
     'gemini-1.5-flash',
     'gemini-1.5-pro',
+    'gemini-2.0-flash-exp',
     'vertexai:gemini-1.5-flash',
     'vertexai:gemini-1.5-pro',
+    # since mistral models are supported by other providers (e.g. ollama), and some of their models (e.g. "codestral")
+    # don't start with "mistral", we add the "mistral:" prefix to all to be explicit
+    'mistral:mistral-small-latest',
+    'mistral:mistral-large-latest',
+    'mistral:codestral-latest',
+    'mistral:mistral-moderation-latest',
     'ollama:codellama',
     'ollama:gemma',
     'ollama:gemma2',
@@ -66,6 +75,9 @@ KnownModelName = Literal[
     'ollama:qwen2',
     'ollama:qwen2.5',
     'ollama:starcoder2',
+    'claude-3-5-haiku-latest',
+    'claude-3-5-sonnet-latest',
+    'claude-3-opus-latest',
     'test',
 ]
 """Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -108,12 +120,16 @@ class AgentModel(ABC):
     """Model configured for each step of an Agent run."""
 
     @abstractmethod
-    async def request(self, messages: list[Message]) -> tuple[ModelAnyResponse, Cost]:
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, Cost]:
         """Make a request to the model."""
         raise NotImplementedError()
 
     @asynccontextmanager
-    async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]:
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
         """Make a request to the model and return a streaming response."""
         raise NotImplementedError(f'Streamed requests not supported by this {self.__class__.__name__}')
         # yield is required to make this a generator for type checking
@@ -178,10 +194,10 @@ class StreamStructuredResponse(ABC):
         raise NotImplementedError()
 
     @abstractmethod
-    def get(self, *, final: bool = False) -> ModelStructuredResponse:
-        """Get the `ModelStructuredResponse` at this point.
+    def get(self, *, final: bool = False) -> ModelResponse:
+        """Get the `ModelResponse` at this point.
 
-        The `ModelStructuredResponse` may or may not be complete, depending on whether the stream is finished.
+        The `ModelResponse` may or may not be complete, depending on whether the stream is finished.
 
         Args:
             final: If True, this is the final call, after iteration is complete, the response should be fully validated.
@@ -270,10 +286,18 @@ def infer_model(model: Model | KnownModelName) -> Model:
         from .vertexai import VertexAIModel
 
         return VertexAIModel(model[9:])  # pyright: ignore[reportArgumentType]
+    elif model.startswith('mistral:'):
+        from .mistral import MistralModel
+
+        return MistralModel(model[8:])
     elif model.startswith('ollama:'):
         from .ollama import OllamaModel
 
         return OllamaModel(model[7:])
+    elif model.startswith('claude'):
+        from .anthropic import AnthropicModel
+
+        return AnthropicModel(model)
     else:
         raise UserError(f'Unknown model: {model}')
 
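The `infer_model` additions mean model-name strings now dispatch to the two new providers. A short sketch of the prefix handling (assuming the `mistral` and `anthropic` optional groups are installed and API keys are configured):

    from pydantic_ai.models import infer_model

    # The 'mistral:' prefix is stripped (model[8:]) before constructing MistralModel;
    # bare 'claude...' names are passed through to AnthropicModel unchanged.
    mistral = infer_model('mistral:mistral-large-latest')  # MistralModel('mistral-large-latest')
    claude = infer_model('claude-3-5-sonnet-latest')       # AnthropicModel('claude-3-5-sonnet-latest')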
pydantic_ai/models/anthropic.py
ADDED
@@ -0,0 +1,344 @@
+from __future__ import annotations as _annotations
+
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+from dataclasses import dataclass, field
+from typing import Any, Literal, Union, cast, overload
+
+from httpx import AsyncClient as AsyncHTTPClient
+from typing_extensions import assert_never
+
+from .. import result
+from .._utils import guard_tool_call_id as _guard_tool_call_id
+from ..messages import (
+    ArgsDict,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    RetryPromptPart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
+)
+from ..settings import ModelSettings
+from ..tools import ToolDefinition
+from . import (
+    AgentModel,
+    EitherStreamedResponse,
+    Model,
+    cached_async_http_client,
+    check_allow_model_requests,
+)
+
+try:
+    from anthropic import NOT_GIVEN, AsyncAnthropic, AsyncStream
+    from anthropic.types import (
+        Message as AnthropicMessage,
+        MessageParam,
+        RawMessageDeltaEvent,
+        RawMessageStartEvent,
+        RawMessageStreamEvent,
+        TextBlock,
+        TextBlockParam,
+        ToolChoiceParam,
+        ToolParam,
+        ToolResultBlockParam,
+        ToolUseBlock,
+        ToolUseBlockParam,
+    )
+except ImportError as _import_error:
+    raise ImportError(
+        'Please install `anthropic` to use the Anthropic model, '
+        "you can use the `anthropic` optional group — `pip install 'pydantic-ai-slim[anthropic]'`"
+    ) from _import_error
+
+LatestAnthropicModelNames = Literal[
+    'claude-3-5-haiku-latest',
+    'claude-3-5-sonnet-latest',
+    'claude-3-opus-latest',
+]
+"""Latest named Anthropic models."""
+
+AnthropicModelName = Union[str, LatestAnthropicModelNames]
+"""Possible Anthropic model names.
+
+Since Anthropic supports a variety of date-stamped models, we explicitly list the latest models but
+allow any name in the type hints.
+Since [the Anthropic docs](https://docs.anthropic.com/en/docs/about-claude/models) for a full list.
+"""
+
+
+@dataclass(init=False)
+class AnthropicModel(Model):
+    """A model that uses the Anthropic API.
+
+    Internally, this uses the [Anthropic Python client](https://github.com/anthropics/anthropic-sdk-python) to interact with the API.
+
+    Apart from `__init__`, all methods are private or match those of the base class.
+
+    !!! note
+        The `AnthropicModel` class does not yet support streaming responses.
+        We anticipate adding support for streaming responses in a near-term future release.
+    """
+
+    model_name: AnthropicModelName
+    client: AsyncAnthropic = field(repr=False)
+
+    def __init__(
+        self,
+        model_name: AnthropicModelName,
+        *,
+        api_key: str | None = None,
+        anthropic_client: AsyncAnthropic | None = None,
+        http_client: AsyncHTTPClient | None = None,
+    ):
+        """Initialize an Anthropic model.
+
+        Args:
+            model_name: The name of the Anthropic model to use. List of model names available
+                [here](https://docs.anthropic.com/en/docs/about-claude/models).
+            api_key: The API key to use for authentication, if not provided, the `ANTHROPIC_API_KEY` environment variable
+                will be used if available.
+            anthropic_client: An existing
+                [`AsyncAnthropic`](https://github.com/anthropics/anthropic-sdk-python?tab=readme-ov-file#async-usage)
+                client to use, if provided, `api_key` and `http_client` must be `None`.
+            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
+        """
+        self.model_name = model_name
+        if anthropic_client is not None:
+            assert http_client is None, 'Cannot provide both `anthropic_client` and `http_client`'
+            assert api_key is None, 'Cannot provide both `anthropic_client` and `api_key`'
+            self.client = anthropic_client
+        elif http_client is not None:
+            self.client = AsyncAnthropic(api_key=api_key, http_client=http_client)
+        else:
+            self.client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client())
+
+    async def agent_model(
+        self,
+        *,
+        function_tools: list[ToolDefinition],
+        allow_text_result: bool,
+        result_tools: list[ToolDefinition],
+    ) -> AgentModel:
+        check_allow_model_requests()
+        tools = [self._map_tool_definition(r) for r in function_tools]
+        if result_tools:
+            tools += [self._map_tool_definition(r) for r in result_tools]
+        return AnthropicAgentModel(
+            self.client,
+            self.model_name,
+            allow_text_result,
+            tools,
+        )
+
+    def name(self) -> str:
+        return self.model_name
+
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> ToolParam:
+        return {
+            'name': f.name,
+            'description': f.description,
+            'input_schema': f.parameters_json_schema,
+        }
+
+
+@dataclass
+class AnthropicAgentModel(AgentModel):
+    """Implementation of `AgentModel` for Anthropic models."""
+
+    client: AsyncAnthropic
+    model_name: str
+    allow_text_result: bool
+    tools: list[ToolParam]
+
+    async def request(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> tuple[ModelResponse, result.Cost]:
+        response = await self._messages_create(messages, False, model_settings)
+        return self._process_response(response), _map_cost(response)
+
+    @asynccontextmanager
+    async def request_stream(
+        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+    ) -> AsyncIterator[EitherStreamedResponse]:
+        response = await self._messages_create(messages, True, model_settings)
+        async with response:
+            yield await self._process_streamed_response(response)
+
+    @overload
+    async def _messages_create(
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
+    ) -> AsyncStream[RawMessageStreamEvent]:
+        pass
+
+    @overload
+    async def _messages_create(
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+    ) -> AnthropicMessage:
+        pass
+
+    async def _messages_create(
+        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
+    ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
+        # standalone function to make it easier to override
+        if not self.tools:
+            tool_choice: ToolChoiceParam | None = None
+        elif not self.allow_text_result:
+            tool_choice = {'type': 'any'}
+        else:
+            tool_choice = {'type': 'auto'}
+
+        system_prompt, anthropic_messages = self._map_message(messages)
+
+        model_settings = model_settings or {}
+
+        return await self.client.messages.create(
+            max_tokens=model_settings.get('max_tokens', 1024),
+            system=system_prompt or NOT_GIVEN,
+            messages=anthropic_messages,
+            model=self.model_name,
+            tools=self.tools or NOT_GIVEN,
+            tool_choice=tool_choice or NOT_GIVEN,
+            stream=stream,
+            temperature=model_settings.get('temperature', NOT_GIVEN),
+            top_p=model_settings.get('top_p', NOT_GIVEN),
+            timeout=model_settings.get('timeout', NOT_GIVEN),
+        )
+
+    @staticmethod
+    def _process_response(response: AnthropicMessage) -> ModelResponse:
+        """Process a non-streamed response, and prepare a message to return."""
+        items: list[ModelResponsePart] = []
+        for item in response.content:
+            if isinstance(item, TextBlock):
+                items.append(TextPart(item.text))
+            else:
+                assert isinstance(item, ToolUseBlock), 'unexpected item type'
+                items.append(
+                    ToolCallPart.from_dict(
+                        item.name,
+                        cast(dict[str, Any], item.input),
+                        item.id,
+                    )
+                )
+
+        return ModelResponse(items)
+
+    @staticmethod
+    async def _process_streamed_response(response: AsyncStream[RawMessageStreamEvent]) -> EitherStreamedResponse:
+        """TODO: Process a streamed response, and prepare a streaming response to return."""
+        # We don't yet support streamed responses from Anthropic, so we raise an error here for now.
+        # Streamed responses will be supported in a future release.
+
+        raise RuntimeError('Streamed responses are not yet supported for Anthropic models.')
+
+        # Should be returning some sort of AnthropicStreamTextResponse or AnthropicStreamStructuredResponse
+        # depending on the type of chunk we get, but we need to establish how we handle (and when we get) the following:
+        # RawMessageStartEvent
+        # RawMessageDeltaEvent
+        # RawMessageStopEvent
+        # RawContentBlockStartEvent
+        # RawContentBlockDeltaEvent
+        # RawContentBlockDeltaEvent
+        #
+        # We might refactor streaming internally before we implement this...
+
+    @staticmethod
+    def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
+        """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
+        system_prompt: str = ''
+        anthropic_messages: list[MessageParam] = []
+        for m in messages:
+            if isinstance(m, ModelRequest):
+                for part in m.parts:
+                    if isinstance(part, SystemPromptPart):
+                        system_prompt += part.content
+                    elif isinstance(part, UserPromptPart):
+                        anthropic_messages.append(MessageParam(role='user', content=part.content))
+                    elif isinstance(part, ToolReturnPart):
+                        anthropic_messages.append(
+                            MessageParam(
+                                role='user',
+                                content=[
+                                    ToolResultBlockParam(
+                                        tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
+                                        type='tool_result',
+                                        content=part.model_response_str(),
+                                        is_error=False,
+                                    )
+                                ],
+                            )
+                        )
+                    elif isinstance(part, RetryPromptPart):
+                        if part.tool_name is None:
+                            anthropic_messages.append(MessageParam(role='user', content=part.model_response()))
+                        else:
+                            anthropic_messages.append(
+                                MessageParam(
+                                    role='user',
+                                    content=[
+                                        ToolUseBlockParam(
+                                            id=_guard_tool_call_id(t=part, model_source='Anthropic'),
+                                            input=part.model_response(),
+                                            name=part.tool_name,
+                                            type='tool_use',
+                                        ),
+                                    ],
+                                )
+                            )
+            elif isinstance(m, ModelResponse):
+                content: list[TextBlockParam | ToolUseBlockParam] = []
+                for item in m.parts:
+                    if isinstance(item, TextPart):
+                        content.append(TextBlockParam(text=item.content, type='text'))
+                    else:
+                        assert isinstance(item, ToolCallPart)
+                        content.append(_map_tool_call(item))
+                anthropic_messages.append(MessageParam(role='assistant', content=content))
+            else:
+                assert_never(m)
+        return system_prompt, anthropic_messages
+
+
+def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
+    assert isinstance(t.args, ArgsDict), f'Expected ArgsDict, got {t.args}'
+    return ToolUseBlockParam(
+        id=_guard_tool_call_id(t=t, model_source='Anthropic'),
+        type='tool_use',
+        name=t.tool_name,
+        input=t.args.args_dict,
+    )
+
+
+def _map_cost(message: AnthropicMessage | RawMessageStreamEvent) -> result.Cost:
+    if isinstance(message, AnthropicMessage):
+        usage = message.usage
+    else:
+        if isinstance(message, RawMessageStartEvent):
+            usage = message.message.usage
+        elif isinstance(message, RawMessageDeltaEvent):
+            usage = message.usage
+        else:
+            # No usage information provided in:
+            # - RawMessageStopEvent
+            # - RawContentBlockStartEvent
+            # - RawContentBlockDeltaEvent
+            # - RawContentBlockStopEvent
+            usage = None
+
+    if usage is None:
+        return result.Cost()
+
+    request_tokens = getattr(usage, 'input_tokens', None)
+
+    return result.Cost(
+        # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence this getattr
+        request_tokens=request_tokens,
+        response_tokens=usage.output_tokens,
+        total_tokens=(request_tokens or 0) + usage.output_tokens,
+    )