pydantic-ai-slim 0.0.19__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pydantic-ai-slim has been flagged as potentially problematic.
- pydantic_ai/_parts_manager.py +1 -1
- pydantic_ai/_pydantic.py +1 -0
- pydantic_ai/_result.py +29 -28
- pydantic_ai/_system_prompt.py +4 -4
- pydantic_ai/_utils.py +1 -56
- pydantic_ai/agent.py +137 -113
- pydantic_ai/messages.py +24 -56
- pydantic_ai/models/__init__.py +122 -51
- pydantic_ai/models/anthropic.py +109 -38
- pydantic_ai/models/cohere.py +290 -0
- pydantic_ai/models/function.py +12 -8
- pydantic_ai/models/gemini.py +29 -15
- pydantic_ai/models/groq.py +27 -23
- pydantic_ai/models/mistral.py +34 -29
- pydantic_ai/models/openai.py +45 -23
- pydantic_ai/models/test.py +47 -24
- pydantic_ai/models/vertexai.py +2 -1
- pydantic_ai/result.py +45 -26
- pydantic_ai/settings.py +58 -1
- pydantic_ai/tools.py +29 -26
- {pydantic_ai_slim-0.0.19.dist-info → pydantic_ai_slim-0.0.21.dist-info}/METADATA +6 -4
- pydantic_ai_slim-0.0.21.dist-info/RECORD +29 -0
- pydantic_ai/models/ollama.py +0 -120
- pydantic_ai_slim-0.0.19.dist-info/RECORD +0 -29
- {pydantic_ai_slim-0.0.19.dist-info → pydantic_ai_slim-0.0.21.dist-info}/WHEEL +0 -0
pydantic_ai/messages.py
CHANGED
@@ -6,7 +6,6 @@ from typing import Annotated, Any, Literal, Union, cast, overload

 import pydantic
 import pydantic_core
-from typing_extensions import Self, assert_never

 from ._utils import now_utc as _now_utc
 from .exceptions import UnexpectedModelBehavior
@@ -168,22 +167,6 @@ class TextPart:
         return bool(self.content)


-@dataclass
-class ArgsJson:
-    """Tool arguments as a JSON string."""
-
-    args_json: str
-    """A JSON string of arguments."""
-
-
-@dataclass
-class ArgsDict:
-    """Tool arguments as a Python dictionary."""
-
-    args_dict: dict[str, Any]
-    """A python dictionary of arguments."""
-
-
 @dataclass
 class ToolCallPart:
     """A tool call from a model."""
@@ -191,10 +174,10 @@ class ToolCallPart:
     tool_name: str
     """The name of the tool to call."""

-    args: ArgsJson | ArgsDict
+    args: str | dict[str, Any]
     """The arguments to pass to the tool.

-    …
+    This is stored either as a JSON string or a Python dictionary depending on how data was received.
     """

     tool_call_id: str | None = None
@@ -203,24 +186,14 @@ class ToolCallPart:
     part_kind: Literal['tool-call'] = 'tool-call'
     """Part type identifier, this is available on all parts as a discriminator."""

-    @classmethod
-    def from_raw_args(cls, tool_name: str, args: str | dict[str, Any], tool_call_id: str | None = None) -> Self:
-        """Create a `ToolCallPart` from raw arguments, converting them to `ArgsJson` or `ArgsDict`."""
-        if isinstance(args, str):
-            return cls(tool_name, ArgsJson(args), tool_call_id)
-        elif isinstance(args, dict):
-            return cls(tool_name, ArgsDict(args), tool_call_id)
-        else:
-            assert_never(args)
-
     def args_as_dict(self) -> dict[str, Any]:
         """Return the arguments as a Python dictionary.

         This is just for convenience with models that require dicts as input.
         """
-        if isinstance(self.args, ArgsDict):
-            return self.args.args_dict
-        args = pydantic_core.from_json(self.args.args_json)
+        if isinstance(self.args, dict):
+            return self.args
+        args = pydantic_core.from_json(self.args)
         assert isinstance(args, dict), 'args should be a dict'
         return cast(dict[str, Any], args)
@@ -229,16 +202,18 @@ class ToolCallPart:

         This is just for convenience with models that require JSON strings as input.
         """
-        if isinstance(self.args, ArgsJson):
-            return self.args.args_json
-        return pydantic_core.to_json(self.args.args_dict).decode()
+        if isinstance(self.args, str):
+            return self.args
+        return pydantic_core.to_json(self.args).decode()

     def has_content(self) -> bool:
         """Return `True` if the arguments contain any data."""
-        if isinstance(self.args, ArgsDict):
-            return any(self.args.args_dict.values())
+        if isinstance(self.args, dict):
+            # TODO: This should probably return True if you have the value False, or 0, etc.
+            # It makes sense to me to ignore empty strings, but not sure about empty lists or dicts
+            return any(self.args.values())
         else:
-            return bool(self.args.args_json)
+            return bool(self.args)


 ModelResponsePart = Annotated[Union[TextPart, ToolCallPart], pydantic.Discriminator('part_kind')]
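The `ArgsJson`/`ArgsDict` wrappers are gone: `ToolCallPart.args` is now a plain `str | dict[str, Any]`, and the helper methods convert between the two representations on demand. A minimal sketch of the new call sites (the tool name and arguments here are illustrative):

from pydantic_ai.messages import ToolCallPart

# args can now be passed directly as a dict...
part = ToolCallPart(tool_name='get_weather', args={'city': 'London'})
assert part.args_as_dict() == {'city': 'London'}

# ...or as a raw JSON string, converted only when a dict is needed
part = ToolCallPart(tool_name='get_weather', args='{"city": "London"}')
assert part.args_as_dict() == {'city': 'London'}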
@@ -252,6 +227,9 @@ class ModelResponse:
     parts: list[ModelResponsePart]
     """The parts of the model message."""

+    model_name: str | None = None
+    """The name of the model that generated the response."""
+
     timestamp: datetime = field(default_factory=_now_utc)
     """The timestamp of the response.

@@ -261,16 +239,6 @@ class ModelResponse:
     kind: Literal['response'] = 'response'
     """Message type identifier, this is available on all parts as a discriminator."""

-    @classmethod
-    def from_text(cls, content: str, timestamp: datetime | None = None) -> Self:
-        """Create a `ModelResponse` containing a single `TextPart`."""
-        return cls([TextPart(content=content)], timestamp=timestamp or _now_utc())
-
-    @classmethod
-    def from_tool_call(cls, tool_call: ToolCallPart) -> Self:
-        """Create a `ModelResponse` containing a single `ToolCallPart`."""
-        return cls([tool_call])
-

 ModelMessage = Annotated[Union[ModelRequest, ModelResponse], pydantic.Discriminator('kind')]
 """Any message sent to or returned by a model."""
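With `from_text` and `from_tool_call` removed, responses are built with the plain constructor, which also accepts the new `model_name` field. A sketch of the replacement (the model name is illustrative):

from pydantic_ai.messages import ModelResponse, TextPart, ToolCallPart

# Before (0.0.19): ModelResponse.from_text('hello world')
# After (0.0.21): build the parts list directly
response = ModelResponse(parts=[TextPart(content='hello world')], model_name='gpt-4o')

# Before: ModelResponse.from_tool_call(tool_call)
# After:
tool_call = ToolCallPart(tool_name='get_weather', args={'city': 'London'})
response = ModelResponse(parts=[tool_call])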
@@ -338,7 +306,7 @@ class ToolCallPartDelta:
         if self.tool_name_delta is None or self.args_delta is None:
             return None

-        return ToolCallPart.from_raw_args(
+        return ToolCallPart(
             self.tool_name_delta,
             self.args_delta,
             self.tool_call_id,
@@ -403,7 +371,7 @@ class ToolCallPartDelta:

         # If we now have enough data to create a full ToolCallPart, do so
         if delta.tool_name_delta is not None and delta.args_delta is not None:
-            return ToolCallPart.from_raw_args(
+            return ToolCallPart(
                 delta.tool_name_delta,
                 delta.args_delta,
                 delta.tool_call_id,
@@ -419,15 +387,15 @@ class ToolCallPartDelta:
             part = replace(part, tool_name=tool_name)

         if isinstance(self.args_delta, str):
-            if not isinstance(part.args, ArgsJson):
+            if not isinstance(part.args, str):
                 raise UnexpectedModelBehavior(f'Cannot apply JSON deltas to non-JSON tool arguments ({part=}, {self=})')
-            updated_json = part.args.args_json + self.args_delta
-            part = replace(part, args=ArgsJson(updated_json))
+            updated_json = part.args + self.args_delta
+            part = replace(part, args=updated_json)
         elif isinstance(self.args_delta, dict):
-            if not isinstance(part.args, ArgsDict):
+            if not isinstance(part.args, dict):
                 raise UnexpectedModelBehavior(f'Cannot apply dict deltas to non-dict tool arguments ({part=}, {self=})')
-            updated_dict = {**part.args.args_dict, **self.args_delta}
-            part = replace(part, args=ArgsDict(updated_dict))
+            updated_dict = {**(part.args or {}), **self.args_delta}
+            part = replace(part, args=updated_dict)

         if self.tool_call_id:
             # Replace the tool_call_id entirely if given
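Delta application follows the same simplification: string deltas concatenate onto a JSON-string `args`, and dict deltas merge into a dict `args`, with no wrapper round-trips. A sketch using the public `apply` method, assuming it dispatches to the logic in the hunk above:

from pydantic_ai.messages import ToolCallPart, ToolCallPartDelta

# a partial tool call whose JSON arguments are still streaming in
part = ToolCallPart(tool_name='get_weather', args='{"city": "Lon')
part = ToolCallPartDelta(args_delta='don"}').apply(part)
assert part.args == '{"city": "London"}'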
pydantic_ai/models/__init__.py
CHANGED
@@ -12,9 +12,10 @@ from contextlib import asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
 from functools import cache
-from typing import TYPE_CHECKING, Literal
+from typing import TYPE_CHECKING

 import httpx
+from typing_extensions import Literal

 from .._parts_manager import ModelResponsePartsManager
 from ..exceptions import UserError
@@ -27,60 +28,123 @@ if TYPE_CHECKING:


 KnownModelName = Literal[
-    '…'
+    'anthropic:claude-3-5-haiku-latest',
+    'anthropic:claude-3-5-sonnet-latest',
+    'anthropic:claude-3-opus-latest',
+    'claude-3-5-haiku-latest',
+    'claude-3-5-sonnet-latest',
+    'claude-3-opus-latest',
+    'cohere:c4ai-aya-expanse-32b',
+    'cohere:c4ai-aya-expanse-8b',
+    'cohere:command',
+    'cohere:command-light',
+    'cohere:command-light-nightly',
+    'cohere:command-nightly',
+    'cohere:command-r',
+    'cohere:command-r-03-2024',
+    'cohere:command-r-08-2024',
+    'cohere:command-r-plus',
+    'cohere:command-r-plus-04-2024',
+    'cohere:command-r-plus-08-2024',
+    'cohere:command-r7b-12-2024',
+    'google-gla:gemini-1.0-pro',
+    'google-gla:gemini-1.5-flash',
+    'google-gla:gemini-1.5-flash-8b',
+    'google-gla:gemini-1.5-pro',
+    'google-gla:gemini-2.0-flash-exp',
+    'google-vertex:gemini-1.0-pro',
+    'google-vertex:gemini-1.5-flash',
+    'google-vertex:gemini-1.5-flash-8b',
+    'google-vertex:gemini-1.5-pro',
+    'google-vertex:gemini-2.0-flash-exp',
+    'gpt-3.5-turbo',
+    'gpt-3.5-turbo-0125',
+    'gpt-3.5-turbo-0301',
+    'gpt-3.5-turbo-0613',
+    'gpt-3.5-turbo-1106',
+    'gpt-3.5-turbo-16k',
+    'gpt-3.5-turbo-16k-0613',
+    'gpt-4',
+    'gpt-4-0125-preview',
+    'gpt-4-0314',
+    'gpt-4-0613',
+    'gpt-4-1106-preview',
+    'gpt-4-32k',
+    'gpt-4-32k-0314',
+    'gpt-4-32k-0613',
+    'gpt-4-turbo',
+    'gpt-4-turbo-2024-04-09',
+    'gpt-4-turbo-preview',
+    'gpt-4-vision-preview',
+    'gpt-4o',
+    'gpt-4o-2024-05-13',
+    'gpt-4o-2024-08-06',
+    'gpt-4o-2024-11-20',
+    'gpt-4o-audio-preview',
+    'gpt-4o-audio-preview-2024-10-01',
+    'gpt-4o-audio-preview-2024-12-17',
+    'gpt-4o-mini',
+    'gpt-4o-mini-2024-07-18',
+    'gpt-4o-mini-audio-preview',
+    'gpt-4o-mini-audio-preview-2024-12-17',
+    'groq:gemma2-9b-it',
     'groq:llama-3.1-8b-instant',
+    'groq:llama-3.2-11b-vision-preview',
     'groq:llama-3.2-1b-preview',
     'groq:llama-3.2-3b-preview',
-    'groq:llama-3.2-11b-vision-preview',
     'groq:llama-3.2-90b-vision-preview',
+    'groq:llama-3.3-70b-specdec',
+    'groq:llama-3.3-70b-versatile',
     'groq:llama3-70b-8192',
     'groq:llama3-8b-8192',
     'groq:mixtral-8x7b-32768',
-    'groq:gemma2-9b-it',
-    'groq:gemma-7b-it',
-    'google-gla:gemini-1.5-flash',
-    'google-gla:gemini-1.5-pro',
-    'google-gla:gemini-2.0-flash-exp',
-    'google-vertex:gemini-1.5-flash',
-    'google-vertex:gemini-1.5-pro',
-    'google-vertex:gemini-2.0-flash-exp',
-    'mistral:mistral-small-latest',
-    'mistral:mistral-large-latest',
     'mistral:codestral-latest',
+    'mistral:mistral-large-latest',
     'mistral:mistral-moderation-latest',
-    '…'
+    'mistral:mistral-small-latest',
+    'o1',
+    'o1-2024-12-17',
+    'o1-mini',
+    'o1-mini-2024-09-12',
+    'o1-preview',
+    'o1-preview-2024-09-12',
+    'openai:chatgpt-4o-latest',
+    'openai:gpt-3.5-turbo',
+    'openai:gpt-3.5-turbo-0125',
+    'openai:gpt-3.5-turbo-0301',
+    'openai:gpt-3.5-turbo-0613',
+    'openai:gpt-3.5-turbo-1106',
+    'openai:gpt-3.5-turbo-16k',
+    'openai:gpt-3.5-turbo-16k-0613',
+    'openai:gpt-4',
+    'openai:gpt-4-0125-preview',
+    'openai:gpt-4-0314',
+    'openai:gpt-4-0613',
+    'openai:gpt-4-1106-preview',
+    'openai:gpt-4-32k',
+    'openai:gpt-4-32k-0314',
+    'openai:gpt-4-32k-0613',
+    'openai:gpt-4-turbo',
+    'openai:gpt-4-turbo-2024-04-09',
+    'openai:gpt-4-turbo-preview',
+    'openai:gpt-4-vision-preview',
+    'openai:gpt-4o',
+    'openai:gpt-4o-2024-05-13',
+    'openai:gpt-4o-2024-08-06',
+    'openai:gpt-4o-2024-11-20',
+    'openai:gpt-4o-audio-preview',
+    'openai:gpt-4o-audio-preview-2024-10-01',
+    'openai:gpt-4o-audio-preview-2024-12-17',
+    'openai:gpt-4o-mini',
+    'openai:gpt-4o-mini-2024-07-18',
+    'openai:gpt-4o-mini-audio-preview',
+    'openai:gpt-4o-mini-audio-preview-2024-12-17',
+    'openai:o1',
+    'openai:o1-2024-12-17',
+    'openai:o1-mini',
+    'openai:o1-mini-2024-09-12',
+    'openai:o1-preview',
+    'openai:o1-preview-2024-09-12',
     'test',
 ]
 """Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
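Besides the new `cohere:` entries, OpenAI and Anthropic model names are now also listed without a provider prefix. Assuming `infer_model` resolves the bare forms to the matching provider (only the prefixed handling is visible in this diff), both spellings should be equivalent:

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')             # prefixed form
agent = Agent('gpt-4o')                    # bare form, new in this release
agent = Agent('claude-3-5-sonnet-latest')  # bare Anthropic form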
@@ -145,6 +209,7 @@ class AgentModel(ABC):
 class StreamedResponse(ABC):
     """Streamed response from an LLM when calling a tool."""

+    _model_name: str
     _usage: Usage = field(default_factory=Usage, init=False)
     _parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
     _event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
@@ -168,7 +233,13 @@ class StreamedResponse(ABC):

     def get(self) -> ModelResponse:
         """Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far."""
-        return ModelResponse(parts=self._parts_manager.get_parts(), timestamp=self.timestamp())
+        return ModelResponse(
+            parts=self._parts_manager.get_parts(), model_name=self._model_name, timestamp=self.timestamp()
+        )
+
+    def model_name(self) -> str:
+        """Get the model name of the response."""
+        return self._model_name

     def usage(self) -> Usage:
         """Get the usage of the response so far. This will not be the final usage until the stream is exhausted."""
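Every concrete `StreamedResponse` now supplies `_model_name` at construction so that `get()` can stamp it onto the `ModelResponse`. A minimal hypothetical subclass showing the pattern (the vendor implementations in this release, such as `AnthropicStreamedResponse` below, follow the same shape):

from dataclasses import dataclass
from datetime import datetime, timezone

from pydantic_ai.models import StreamedResponse

@dataclass
class EchoStreamedResponse(StreamedResponse):
    _timestamp: datetime

    async def _get_event_iterator(self):
        # a single fixed text chunk, just to illustrate the parts-manager plumbing
        yield self._parts_manager.handle_text_delta(vendor_part_id='content', content='hello')

    def timestamp(self) -> datetime:
        return self._timestamp

stream = EchoStreamedResponse(_model_name='echo-model', _timestamp=datetime.now(tz=timezone.utc))
assert stream.model_name() == 'echo-model'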
@@ -228,6 +299,10 @@ def infer_model(model: Model | KnownModelName) -> Model:
         from .test import TestModel

         return TestModel()
+    elif model.startswith('cohere:'):
+        from .cohere import CohereModel
+
+        return CohereModel(model[7:])
     elif model.startswith('openai:'):
         from .openai import OpenAIModel

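The slice `model[7:]` strips the `'cohere:'` prefix, so the new provider plugs into the same naming scheme as the others. A usage sketch, assuming the Cohere dependency is installed (the METADATA change suggests a new optional dependency group):

from pydantic_ai import Agent

agent = Agent('cohere:command-r-plus')  # resolves to CohereModel('command-r-plus')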
@@ -263,10 +338,6 @@ def infer_model(model: Model | KnownModelName) -> Model:
         from .mistral import MistralModel

         return MistralModel(model[8:])
-    elif model.startswith('ollama:'):
-        from .ollama import OllamaModel
-
-        return OllamaModel(model[7:])
     elif model.startswith('anthropic'):
         from .anthropic import AnthropicModel

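`OllamaModel` and the `ollama:` prefix are removed outright (see `pydantic_ai/models/ollama.py +0 -120` above). Since Ollama serves an OpenAI-compatible API, the presumable replacement is pointing `OpenAIModel` at the local server; the parameter names below assume the 0.0.21 `OpenAIModel` accepts a `base_url`:

from pydantic_ai.models.openai import OpenAIModel

# hypothetical local setup: llama3.2 served by Ollama on its default port
model = OpenAIModel(model_name='llama3.2', base_url='http://localhost:11434/v1')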
pydantic_ai/models/anthropic.py
CHANGED
@@ -1,21 +1,23 @@
 from __future__ import annotations as _annotations

-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterable, AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from json import JSONDecodeError, loads as json_loads
 from typing import Any, Literal, Union, cast, overload

 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never

-from .. import usage
+from .. import UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
-    ArgsDict,
     ModelMessage,
     ModelRequest,
     ModelResponse,
     ModelResponsePart,
+    ModelResponseStreamEvent,
     RetryPromptPart,
     SystemPromptPart,
     TextPart,
@@ -38,11 +40,17 @@ try:
     from anthropic.types import (
         Message as AnthropicMessage,
         MessageParam,
+        MetadataParam,
+        RawContentBlockDeltaEvent,
+        RawContentBlockStartEvent,
+        RawContentBlockStopEvent,
         RawMessageDeltaEvent,
         RawMessageStartEvent,
+        RawMessageStopEvent,
         RawMessageStreamEvent,
         TextBlock,
         TextBlockParam,
+        TextDelta,
         ToolChoiceParam,
         ToolParam,
         ToolResultBlockParam,
@@ -71,6 +79,15 @@ Since [the Anthropic docs](https://docs.anthropic.com/en/docs/about-claude/model
 """


+class AnthropicModelSettings(ModelSettings):
+    """Settings used for an Anthropic model request."""
+
+    anthropic_metadata: MetadataParam
+    """An object describing metadata about the request.
+
+    Contains `user_id`, an external identifier for the user who is associated with the request."""
+
+
 @dataclass(init=False)
 class AnthropicModel(Model):
     """A model that uses the Anthropic API.
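`AnthropicModelSettings` extends the base `ModelSettings` typed dict, so the new key rides along with the existing `model_settings` argument at the call site. A sketch (the user id is illustrative):

from pydantic_ai import Agent

agent = Agent('anthropic:claude-3-5-sonnet-latest')
result = agent.run_sync(
    'Hello!',
    # forwarded to messages.create(metadata=...) as shown further down
    model_settings={'anthropic_metadata': {'user_id': 'user-123'}},
)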
@@ -152,50 +169,54 @@ class AnthropicAgentModel(AgentModel):
     """Implementation of `AgentModel` for Anthropic models."""

     client: AsyncAnthropic
-    model_name: str
+    model_name: AnthropicModelName
     allow_text_result: bool
     tools: list[ToolParam]

     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._messages_create(messages, False, model_settings)
+        response = await self._messages_create(messages, False, cast(AnthropicModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)

     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
-        response = await self._messages_create(messages, True, model_settings)
+        response = await self._messages_create(messages, True, cast(AnthropicModelSettings, model_settings or {}))
         async with response:
             yield await self._process_streamed_response(response)

     @overload
     async def _messages_create(
-        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: AnthropicModelSettings
     ) -> AsyncStream[RawMessageStreamEvent]:
         pass

     @overload
     async def _messages_create(
-        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: AnthropicModelSettings
     ) -> AnthropicMessage:
         pass

     async def _messages_create(
-        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: bool, model_settings: AnthropicModelSettings
     ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
         # standalone function to make it easier to override
+        tool_choice: ToolChoiceParam | None
+
         if not self.tools:
-            tool_choice: ToolChoiceParam | None = None
-        elif not self.allow_text_result:
-            tool_choice = {'type': 'any'}
+            tool_choice = None
         else:
-            tool_choice = {'type': 'auto'}
+            if not self.allow_text_result:
+                tool_choice = {'type': 'any'}
+            else:
+                tool_choice = {'type': 'auto'}
+
+            if (allow_parallel_tool_calls := model_settings.get('parallel_tool_calls')) is not None:
+                tool_choice['disable_parallel_tool_use'] = not allow_parallel_tool_calls

-        …
+        system_prompt, anthropic_messages = self._map_message(messages)

         return await self.client.messages.create(
             max_tokens=model_settings.get('max_tokens', 1024),
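The walrus expression means `parallel_tool_calls` only takes effect when explicitly set and tools are present, and it maps to Anthropic's inverted `disable_parallel_tool_use` flag. Assuming `parallel_tool_calls` is among the new keys added to `ModelSettings` in `settings.py` (+58 lines in this release), reusing the agent from the previous sketch:

result = agent.run_sync(
    'Compare the weather in London and Paris.',
    model_settings={'parallel_tool_calls': False},  # sent as disable_parallel_tool_use=True
)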
@@ -208,10 +229,10 @@ class AnthropicAgentModel(AgentModel):
             temperature=model_settings.get('temperature', NOT_GIVEN),
             top_p=model_settings.get('top_p', NOT_GIVEN),
             timeout=model_settings.get('timeout', NOT_GIVEN),
+            metadata=model_settings.get('anthropic_metadata', NOT_GIVEN),
         )

-    @staticmethod
-    def _process_response(response: AnthropicMessage) -> ModelResponse:
+    def _process_response(self, response: AnthropicMessage) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         items: list[ModelResponsePart] = []
         for item in response.content:
@@ -220,33 +241,24 @@ class AnthropicAgentModel(AgentModel):
         else:
             assert isinstance(item, ToolUseBlock), 'unexpected item type'
             items.append(
-                ToolCallPart.from_raw_args(
+                ToolCallPart(
                     tool_name=item.name,
                     args=cast(dict[str, Any], item.input),
                     tool_call_id=item.id,
                 )
             )

-        return ModelResponse(items)
+        return ModelResponse(items, model_name=self.model_name)

-        …
-        # depending on the type of chunk we get, but we need to establish how we handle (and when we get) the following:
-        # RawMessageStartEvent
-        # RawMessageDeltaEvent
-        # RawMessageStopEvent
-        # RawContentBlockStartEvent
-        # RawContentBlockDeltaEvent
-        # RawContentBlockDeltaEvent
-        #
-        # We might refactor streaming internally before we implement this...
+    async def _process_streamed_response(self, response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
+        peekable_response = _utils.PeekableAsyncStream(response)
+        first_chunk = await peekable_response.peek()
+        if isinstance(first_chunk, _utils.Unset):
+            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')
+
+        # Since Anthropic doesn't provide a timestamp in the message, we'll use the current time
+        timestamp = datetime.now(tz=timezone.utc)
+        return AnthropicStreamedResponse(_model_name=self.model_name, _response=peekable_response, _timestamp=timestamp)

     @staticmethod
     def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
@@ -306,7 +318,6 @@ class AnthropicAgentModel(AgentModel):


 def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
-    assert isinstance(t.args, ArgsDict), f'Expected ArgsDict, got {t.args}'
     return ToolUseBlockParam(
         id=_guard_tool_call_id(t=t, model_source='Anthropic'),
         type='tool_use',
@@ -342,3 +353,63 @@ def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage
         response_tokens=response_usage.output_tokens,
         total_tokens=(request_tokens or 0) + response_usage.output_tokens,
     )
+
+
+@dataclass
+class AnthropicStreamedResponse(StreamedResponse):
+    """Implementation of `StreamedResponse` for Anthropic models."""
+
+    _response: AsyncIterable[RawMessageStreamEvent]
+    _timestamp: datetime
+
+    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
+        current_block: TextBlock | ToolUseBlock | None = None
+        current_json: str = ''
+
+        async for event in self._response:
+            self._usage += _map_usage(event)
+
+            if isinstance(event, RawContentBlockStartEvent):
+                current_block = event.content_block
+                if isinstance(current_block, TextBlock) and current_block.text:
+                    yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=current_block.text)
+                elif isinstance(current_block, ToolUseBlock):
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=current_block.id,
+                        tool_name=current_block.name,
+                        args=cast(dict[str, Any], current_block.input),
+                        tool_call_id=current_block.id,
+                    )
+                    if maybe_event is not None:
+                        yield maybe_event
+
+            elif isinstance(event, RawContentBlockDeltaEvent):
+                if isinstance(event.delta, TextDelta):
+                    yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=event.delta.text)
+                elif (
+                    current_block and event.delta.type == 'input_json_delta' and isinstance(current_block, ToolUseBlock)
+                ):
+                    # Try to parse the JSON immediately, otherwise cache the value for later. This handles
+                    # cases where the JSON is not currently valid but will be valid once we stream more tokens.
+                    try:
+                        parsed_args = json_loads(current_json + event.delta.partial_json)
+                        current_json = ''
+                    except JSONDecodeError:
+                        current_json += event.delta.partial_json
+                        continue
+
+                    # For tool calls, we need to handle partial JSON updates
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=current_block.id,
+                        tool_name='',
+                        args=parsed_args,
+                        tool_call_id=current_block.id,
+                    )
+                    if maybe_event is not None:
+                        yield maybe_event
+
+            elif isinstance(event, (RawContentBlockStopEvent, RawMessageStopEvent)):
+                current_block = None
+
+    def timestamp(self) -> datetime:
+        return self._timestamp
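`AnthropicStreamedResponse` replaces the old not-yet-implemented stub: it buffers `input_json_delta` fragments in `current_json` until they parse as valid JSON before handing them to the parts manager. With it in place, streaming should now work end to end for Anthropic models; a sketch:

import asyncio

from pydantic_ai import Agent

agent = Agent('anthropic:claude-3-5-haiku-latest')

async def main():
    # streamed text arrives via the TextDelta branch above
    async with agent.run_stream('Tell me a short joke.') as result:
        async for text in result.stream_text():
            print(text)

asyncio.run(main())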