pydantic-ai-slim 0.0.20__py3-none-any.whl → 0.0.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic.
- pydantic_ai/_agent_graph.py +770 -0
- pydantic_ai/_parts_manager.py +1 -1
- pydantic_ai/_result.py +3 -7
- pydantic_ai/_utils.py +1 -56
- pydantic_ai/agent.py +192 -560
- pydantic_ai/messages.py +21 -46
- pydantic_ai/models/__init__.py +104 -57
- pydantic_ai/models/anthropic.py +17 -10
- pydantic_ai/models/cohere.py +37 -25
- pydantic_ai/models/gemini.py +27 -7
- pydantic_ai/models/groq.py +19 -17
- pydantic_ai/models/mistral.py +22 -23
- pydantic_ai/models/openai.py +25 -12
- pydantic_ai/models/test.py +37 -22
- pydantic_ai/result.py +1 -1
- pydantic_ai/settings.py +46 -1
- pydantic_ai/tools.py +11 -8
- {pydantic_ai_slim-0.0.20.dist-info → pydantic_ai_slim-0.0.22.dist-info}/METADATA +2 -3
- pydantic_ai_slim-0.0.22.dist-info/RECORD +30 -0
- pydantic_ai/models/ollama.py +0 -123
- pydantic_ai_slim-0.0.20.dist-info/RECORD +0 -30
- {pydantic_ai_slim-0.0.20.dist-info → pydantic_ai_slim-0.0.22.dist-info}/WHEEL +0 -0
pydantic_ai/messages.py
CHANGED
@@ -6,7 +6,6 @@ from typing import Annotated, Any, Literal, Union, cast, overload
 
 import pydantic
 import pydantic_core
-from typing_extensions import Self, assert_never
 
 from ._utils import now_utc as _now_utc
 from .exceptions import UnexpectedModelBehavior
@@ -168,22 +167,6 @@ class TextPart:
         return bool(self.content)
 
 
-@dataclass
-class ArgsJson:
-    """Tool arguments as a JSON string."""
-
-    args_json: str
-    """A JSON string of arguments."""
-
-
-@dataclass
-class ArgsDict:
-    """Tool arguments as a Python dictionary."""
-
-    args_dict: dict[str, Any]
-    """A python dictionary of arguments."""
-
-
 @dataclass
 class ToolCallPart:
     """A tool call from a model."""
@@ -191,10 +174,10 @@ class ToolCallPart:
     tool_name: str
    """The name of the tool to call."""
 
-    args: ArgsJson | ArgsDict
+    args: str | dict[str, Any]
     """The arguments to pass to the tool.
 
-    Either as JSON or a Python dictionary depending on how data was returned.
+    This is stored either as a JSON string or a Python dictionary depending on how data was received.
     """
 
     tool_call_id: str | None = None
@@ -203,24 +186,14 @@ class ToolCallPart:
     part_kind: Literal['tool-call'] = 'tool-call'
     """Part type identifier, this is available on all parts as a discriminator."""
 
-    @classmethod
-    def from_raw_args(cls, tool_name: str, args: str | dict[str, Any], tool_call_id: str | None = None) -> Self:
-        """Create a `ToolCallPart` from raw arguments, converting them to `ArgsJson` or `ArgsDict`."""
-        if isinstance(args, str):
-            return cls(tool_name, ArgsJson(args), tool_call_id)
-        elif isinstance(args, dict):
-            return cls(tool_name, ArgsDict(args), tool_call_id)
-        else:
-            assert_never(args)
-
     def args_as_dict(self) -> dict[str, Any]:
         """Return the arguments as a Python dictionary.
 
         This is just for convenience with models that require dicts as input.
         """
-        if isinstance(self.args, ArgsDict):
-            return self.args.args_dict
-        args = pydantic_core.from_json(self.args.args_json)
+        if isinstance(self.args, dict):
+            return self.args
+        args = pydantic_core.from_json(self.args)
         assert isinstance(args, dict), 'args should be a dict'
         return cast(dict[str, Any], args)
@@ -229,16 +202,18 @@ class ToolCallPart:
 
         This is just for convenience with models that require JSON strings as input.
         """
-        if isinstance(self.args, ArgsJson):
-            return self.args.args_json
-        return pydantic_core.to_json(self.args.args_dict).decode()
+        if isinstance(self.args, str):
+            return self.args
+        return pydantic_core.to_json(self.args).decode()
 
     def has_content(self) -> bool:
         """Return `True` if the arguments contain any data."""
-        if isinstance(self.args, ArgsDict):
-            return any(self.args.args_dict.values())
+        if isinstance(self.args, dict):
+            # TODO: This should probably return True if you have the value False, or 0, etc.
+            # It makes sense to me to ignore empty strings, but not sure about empty lists or dicts
+            return any(self.args.values())
         else:
-            return bool(self.args.args_json)
+            return bool(self.args)
 
 
 ModelResponsePart = Annotated[Union[TextPart, ToolCallPart], pydantic.Discriminator('part_kind')]
@@ -331,7 +306,7 @@ class ToolCallPartDelta:
         if self.tool_name_delta is None or self.args_delta is None:
             return None
 
-        return ToolCallPart.from_raw_args(
+        return ToolCallPart(
             self.tool_name_delta,
             self.args_delta,
             self.tool_call_id,
@@ -396,7 +371,7 @@ class ToolCallPartDelta:
 
         # If we now have enough data to create a full ToolCallPart, do so
         if delta.tool_name_delta is not None and delta.args_delta is not None:
-            return ToolCallPart.from_raw_args(
+            return ToolCallPart(
                 delta.tool_name_delta,
                 delta.args_delta,
                 delta.tool_call_id,
@@ -412,15 +387,15 @@ class ToolCallPartDelta:
             part = replace(part, tool_name=tool_name)
 
         if isinstance(self.args_delta, str):
-            if not isinstance(part.args, ArgsJson):
+            if not isinstance(part.args, str):
                 raise UnexpectedModelBehavior(f'Cannot apply JSON deltas to non-JSON tool arguments ({part=}, {self=})')
-            updated_json = part.args.args_json + self.args_delta
-            part = replace(part, args=ArgsJson(updated_json))
+            updated_json = part.args + self.args_delta
+            part = replace(part, args=updated_json)
         elif isinstance(self.args_delta, dict):
-            if not isinstance(part.args, ArgsDict):
+            if not isinstance(part.args, dict):
                 raise UnexpectedModelBehavior(f'Cannot apply dict deltas to non-dict tool arguments ({part=}, {self=})')
-            updated_dict = {**(part.args.args_dict or {}), **self.args_delta}
-            part = replace(part, args=ArgsDict(updated_dict))
+            updated_dict = {**(part.args or {}), **self.args_delta}
+            part = replace(part, args=updated_dict)
 
         if self.tool_call_id:
             # Replace the tool_call_id entirely if given
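Net effect of this file's changes: the `ArgsJson`/`ArgsDict` wrapper classes and the `ToolCallPart.from_raw_args` classmethod are gone, and `ToolCallPart.args` holds the raw `str | dict[str, Any]` union directly. A minimal usage sketch against the 0.0.22 API shown above (the tool name and argument values are illustrative):

from pydantic_ai.messages import ToolCallPart

# 0.0.20 required ToolCallPart.from_raw_args(...) to wrap args in ArgsJson/ArgsDict;
# 0.0.22 accepts the raw JSON string or dict directly in the constructor.
json_call = ToolCallPart(tool_name='get_weather', args='{"city": "Paris"}')
dict_call = ToolCallPart(tool_name='get_weather', args={'city': 'Paris'})

# The accessors normalize either representation, per args_as_dict/args_as_json_str above.
assert json_call.args_as_dict() == {'city': 'Paris'}
assert dict_call.args_as_json_str() == '{"city":"Paris"}'  # pydantic_core.to_json emits compact JSON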
pydantic_ai/models/__init__.py
CHANGED
@@ -12,9 +12,10 @@ from contextlib import asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
 from functools import cache
-from typing import TYPE_CHECKING, Literal
+from typing import TYPE_CHECKING
 
 import httpx
+from typing_extensions import Literal
 
 from .._parts_manager import ModelResponsePartsManager
 from ..exceptions import UserError
@@ -27,58 +28,6 @@ if TYPE_CHECKING:
 
 
 KnownModelName = Literal[
-    'openai:gpt-4o',
-    'openai:gpt-4o-mini',
-    'openai:gpt-4-turbo',
-    'openai:gpt-4',
-    'openai:o1-preview',
-    'openai:o1-mini',
-    'openai:o1',
-    'openai:gpt-3.5-turbo',
-    'groq:llama-3.3-70b-versatile',
-    'groq:llama-3.1-70b-versatile',
-    'groq:llama3-groq-70b-8192-tool-use-preview',
-    'groq:llama3-groq-8b-8192-tool-use-preview',
-    'groq:llama-3.1-70b-specdec',
-    'groq:llama-3.1-8b-instant',
-    'groq:llama-3.2-1b-preview',
-    'groq:llama-3.2-3b-preview',
-    'groq:llama-3.2-11b-vision-preview',
-    'groq:llama-3.2-90b-vision-preview',
-    'groq:llama3-70b-8192',
-    'groq:llama3-8b-8192',
-    'groq:mixtral-8x7b-32768',
-    'groq:gemma2-9b-it',
-    'groq:gemma-7b-it',
-    'google-gla:gemini-1.5-flash',
-    'google-gla:gemini-1.5-pro',
-    'google-gla:gemini-2.0-flash-exp',
-    'google-vertex:gemini-1.5-flash',
-    'google-vertex:gemini-1.5-pro',
-    'google-vertex:gemini-2.0-flash-exp',
-    'mistral:mistral-small-latest',
-    'mistral:mistral-large-latest',
-    'mistral:codestral-latest',
-    'mistral:mistral-moderation-latest',
-    'ollama:codellama',
-    'ollama:deepseek-r1',
-    'ollama:gemma',
-    'ollama:gemma2',
-    'ollama:llama3',
-    'ollama:llama3.1',
-    'ollama:llama3.2',
-    'ollama:llama3.2-vision',
-    'ollama:llama3.3',
-    'ollama:mistral',
-    'ollama:mistral-nemo',
-    'ollama:mixtral',
-    'ollama:phi3',
-    'ollama:phi4',
-    'ollama:qwq',
-    'ollama:qwen',
-    'ollama:qwen2',
-    'ollama:qwen2.5',
-    'ollama:starcoder2',
     'anthropic:claude-3-5-haiku-latest',
     'anthropic:claude-3-5-sonnet-latest',
     'anthropic:claude-3-opus-latest',
@@ -98,6 +47,108 @@ KnownModelName = Literal[
     'cohere:command-r-plus-04-2024',
     'cohere:command-r-plus-08-2024',
     'cohere:command-r7b-12-2024',
+    'google-gla:gemini-1.0-pro',
+    'google-gla:gemini-1.5-flash',
+    'google-gla:gemini-1.5-flash-8b',
+    'google-gla:gemini-1.5-pro',
+    'google-gla:gemini-2.0-flash-exp',
+    'google-gla:gemini-2.0-flash-thinking-exp-01-21',
+    'google-gla:gemini-exp-1206',
+    'google-vertex:gemini-1.0-pro',
+    'google-vertex:gemini-1.5-flash',
+    'google-vertex:gemini-1.5-flash-8b',
+    'google-vertex:gemini-1.5-pro',
+    'google-vertex:gemini-2.0-flash-exp',
+    'google-vertex:gemini-2.0-flash-thinking-exp-01-21',
+    'google-vertex:gemini-exp-1206',
+    'gpt-3.5-turbo',
+    'gpt-3.5-turbo-0125',
+    'gpt-3.5-turbo-0301',
+    'gpt-3.5-turbo-0613',
+    'gpt-3.5-turbo-1106',
+    'gpt-3.5-turbo-16k',
+    'gpt-3.5-turbo-16k-0613',
+    'gpt-4',
+    'gpt-4-0125-preview',
+    'gpt-4-0314',
+    'gpt-4-0613',
+    'gpt-4-1106-preview',
+    'gpt-4-32k',
+    'gpt-4-32k-0314',
+    'gpt-4-32k-0613',
+    'gpt-4-turbo',
+    'gpt-4-turbo-2024-04-09',
+    'gpt-4-turbo-preview',
+    'gpt-4-vision-preview',
+    'gpt-4o',
+    'gpt-4o-2024-05-13',
+    'gpt-4o-2024-08-06',
+    'gpt-4o-2024-11-20',
+    'gpt-4o-audio-preview',
+    'gpt-4o-audio-preview-2024-10-01',
+    'gpt-4o-audio-preview-2024-12-17',
+    'gpt-4o-mini',
+    'gpt-4o-mini-2024-07-18',
+    'gpt-4o-mini-audio-preview',
+    'gpt-4o-mini-audio-preview-2024-12-17',
+    'groq:gemma2-9b-it',
+    'groq:llama-3.1-8b-instant',
+    'groq:llama-3.2-11b-vision-preview',
+    'groq:llama-3.2-1b-preview',
+    'groq:llama-3.2-3b-preview',
+    'groq:llama-3.2-90b-vision-preview',
+    'groq:llama-3.3-70b-specdec',
+    'groq:llama-3.3-70b-versatile',
+    'groq:llama3-70b-8192',
+    'groq:llama3-8b-8192',
+    'groq:mixtral-8x7b-32768',
+    'mistral:codestral-latest',
+    'mistral:mistral-large-latest',
+    'mistral:mistral-moderation-latest',
+    'mistral:mistral-small-latest',
+    'o1',
+    'o1-2024-12-17',
+    'o1-mini',
+    'o1-mini-2024-09-12',
+    'o1-preview',
+    'o1-preview-2024-09-12',
+    'openai:chatgpt-4o-latest',
+    'openai:gpt-3.5-turbo',
+    'openai:gpt-3.5-turbo-0125',
+    'openai:gpt-3.5-turbo-0301',
+    'openai:gpt-3.5-turbo-0613',
+    'openai:gpt-3.5-turbo-1106',
+    'openai:gpt-3.5-turbo-16k',
+    'openai:gpt-3.5-turbo-16k-0613',
+    'openai:gpt-4',
+    'openai:gpt-4-0125-preview',
+    'openai:gpt-4-0314',
+    'openai:gpt-4-0613',
+    'openai:gpt-4-1106-preview',
+    'openai:gpt-4-32k',
+    'openai:gpt-4-32k-0314',
+    'openai:gpt-4-32k-0613',
+    'openai:gpt-4-turbo',
+    'openai:gpt-4-turbo-2024-04-09',
+    'openai:gpt-4-turbo-preview',
+    'openai:gpt-4-vision-preview',
+    'openai:gpt-4o',
+    'openai:gpt-4o-2024-05-13',
+    'openai:gpt-4o-2024-08-06',
+    'openai:gpt-4o-2024-11-20',
+    'openai:gpt-4o-audio-preview',
+    'openai:gpt-4o-audio-preview-2024-10-01',
+    'openai:gpt-4o-audio-preview-2024-12-17',
+    'openai:gpt-4o-mini',
+    'openai:gpt-4o-mini-2024-07-18',
+    'openai:gpt-4o-mini-audio-preview',
+    'openai:gpt-4o-mini-audio-preview-2024-12-17',
+    'openai:o1',
+    'openai:o1-2024-12-17',
+    'openai:o1-mini',
+    'openai:o1-mini-2024-09-12',
+    'openai:o1-preview',
+    'openai:o1-preview-2024-09-12',
     'test',
 ]
 """Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -291,10 +342,6 @@ def infer_model(model: Model | KnownModelName) -> Model:
         from .mistral import MistralModel
 
         return MistralModel(model[8:])
-    elif model.startswith('ollama:'):
-        from .ollama import OllamaModel
-
-        return OllamaModel(model[7:])
    elif model.startswith('anthropic'):
        from .anthropic import AnthropicModel
 
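Beyond the alphabetized and expanded `KnownModelName` literal, the `ollama:` prefix is dropped from both the literal and `infer_model`, matching the deletion of `pydantic_ai/models/ollama.py` in the file list above. A sketch of the resulting usage; the Ollama `base_url` shown is an assumption based on Ollama's OpenAI-compatible endpoint, not something this diff confirms:

from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel

# Prefixed names still dispatch through infer_model(), e.g.:
agent = Agent('openai:gpt-4o')

# 'ollama:llama3.2' no longer resolves in 0.0.22; a local Ollama server can
# instead be reached via its OpenAI-compatible endpoint (assumed URL; the
# underlying OpenAI client may still require an api_key value to be set).
ollama_model = OpenAIModel('llama3.2', base_url='http://localhost:11434/v1')
local_agent = Agent(ollama_model)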
pydantic_ai/models/anthropic.py
CHANGED
@@ -13,7 +13,6 @@ from typing_extensions import assert_never
 from .. import UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
-    ArgsDict,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -41,6 +40,7 @@ try:
     from anthropic.types import (
         Message as AnthropicMessage,
         MessageParam,
+        MetadataParam,
         RawContentBlockDeltaEvent,
         RawContentBlockStartEvent,
         RawContentBlockStopEvent,
@@ -79,6 +79,15 @@ Since [the Anthropic docs](https://docs.anthropic.com/en/docs/about-claude/model
 """
 
 
+class AnthropicModelSettings(ModelSettings):
+    """Settings used for an Anthropic model request."""
+
+    anthropic_metadata: MetadataParam
+    """An object describing metadata about the request.
+
+    Contains `user_id`, an external identifier for the user who is associated with the request."""
+
+
 @dataclass(init=False)
 class AnthropicModel(Model):
     """A model that uses the Anthropic API.
@@ -167,35 +176,33 @@ class AnthropicAgentModel(AgentModel):
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._messages_create(messages, False, model_settings)
+        response = await self._messages_create(messages, False, cast(AnthropicModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)
 
     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
-        response = await self._messages_create(messages, True, model_settings)
+        response = await self._messages_create(messages, True, cast(AnthropicModelSettings, model_settings or {}))
         async with response:
             yield await self._process_streamed_response(response)
 
     @overload
     async def _messages_create(
-        self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: AnthropicModelSettings
     ) -> AsyncStream[RawMessageStreamEvent]:
         pass
 
     @overload
     async def _messages_create(
-        self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: AnthropicModelSettings
     ) -> AnthropicMessage:
         pass
 
     async def _messages_create(
-        self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], stream: bool, model_settings: AnthropicModelSettings
     ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
         # standalone function to make it easier to override
-        model_settings = model_settings or {}
-
         tool_choice: ToolChoiceParam | None
 
         if not self.tools:
@@ -222,6 +229,7 @@ class AnthropicAgentModel(AgentModel):
             temperature=model_settings.get('temperature', NOT_GIVEN),
             top_p=model_settings.get('top_p', NOT_GIVEN),
             timeout=model_settings.get('timeout', NOT_GIVEN),
+            metadata=model_settings.get('anthropic_metadata', NOT_GIVEN),
         )
 
     def _process_response(self, response: AnthropicMessage) -> ModelResponse:
@@ -233,7 +241,7 @@ class AnthropicAgentModel(AgentModel):
             else:
                 assert isinstance(item, ToolUseBlock), 'unexpected item type'
                 items.append(
-                    ToolCallPart.from_raw_args(
+                    ToolCallPart(
                         tool_name=item.name,
                         args=cast(dict[str, Any], item.input),
                         tool_call_id=item.id,
@@ -310,7 +318,6 @@ class AnthropicAgentModel(AgentModel):
 
 
 def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
-    assert isinstance(t.args, ArgsDict), f'Expected ArgsDict, got {t.args}'
     return ToolUseBlockParam(
         id=_guard_tool_call_id(t=t, model_source='Anthropic'),
         type='tool_use',
pydantic_ai/models/cohere.py
CHANGED
@@ -3,9 +3,10 @@ from __future__ import annotations as _annotations
 from collections.abc import Iterable
 from dataclasses import dataclass, field
 from itertools import chain
-from typing import Literal, Union
+from typing import Literal, Union, cast
 
 from cohere import TextAssistantMessageContentItem
+from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never
 
 from .. import result
@@ -51,24 +52,30 @@ except ImportError as _import_error:
     "you can use the `cohere` optional group — `pip install 'pydantic-ai-slim[cohere]'`"
 ) from _import_error
 
-CohereModelName = Union[
-    str,
-    Literal[
-        'c4ai-aya-expanse-32b',
-        'c4ai-aya-expanse-8b',
-        'command',
-        'command-light',
-        'command-light-nightly',
-        'command-nightly',
-        'command-r',
-        'command-r-03-2024',
-        'command-r-08-2024',
-        'command-r-plus',
-        'command-r-plus-04-2024',
-        'command-r-plus-08-2024',
-        'command-r7b-12-2024',
-    ],
+NamedCohereModels = Literal[
+    'c4ai-aya-expanse-32b',
+    'c4ai-aya-expanse-8b',
+    'command',
+    'command-light',
+    'command-light-nightly',
+    'command-nightly',
+    'command-r',
+    'command-r-03-2024',
+    'command-r-08-2024',
+    'command-r-plus',
+    'command-r-plus-04-2024',
+    'command-r-plus-08-2024',
+    'command-r7b-12-2024',
 ]
+"""Latest / most popular named Cohere models."""
+
+CohereModelName = Union[NamedCohereModels, str]
+
+
+class CohereModelSettings(ModelSettings):
+    """Settings used for a Cohere model request."""
+
+    # This class is a placeholder for any future cohere-specific settings
 
 
 @dataclass(init=False)
@@ -90,6 +97,7 @@ class CohereModel(Model):
         *,
         api_key: str | None = None,
         cohere_client: AsyncClientV2 | None = None,
+        http_client: AsyncHTTPClient | None = None,
     ):
         """Initialize an Cohere model.
 
@@ -97,16 +105,18 @@ class CohereModel(Model):
             model_name: The name of the Cohere model to use. List of model names
                 available [here](https://docs.cohere.com/docs/models#command).
             api_key: The API key to use for authentication, if not provided, the
-                `COHERE_API_KEY` environment variable will be used if available.
+                `CO_API_KEY` environment variable will be used if available.
             cohere_client: An existing Cohere async client to use. If provided,
-                `api_key` must be `None`.
+                `api_key` and `http_client` must be `None`.
+            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
         """
         self.model_name: CohereModelName = model_name
         if cohere_client is not None:
+            assert http_client is None, 'Cannot provide both `cohere_client` and `http_client`'
             assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
             self.client = cohere_client
         else:
-            self.client = AsyncClientV2(api_key=api_key)  # type: ignore
+            self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client)  # type: ignore
 
     async def agent_model(
         self,
@@ -153,16 +163,15 @@ class CohereAgentModel(AgentModel):
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, result.Usage]:
-        response = await self._chat(messages, model_settings)
+        response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)
 
     async def _chat(
         self,
         messages: list[ModelMessage],
-        model_settings: ModelSettings | None,
+        model_settings: CohereModelSettings,
     ) -> ChatResponse:
         cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
-        model_settings = model_settings or {}
         return await self.client.chat(
             model=self.model_name,
             messages=cohere_messages,
@@ -170,6 +179,9 @@ class CohereAgentModel(AgentModel):
             max_tokens=model_settings.get('max_tokens', OMIT),
             temperature=model_settings.get('temperature', OMIT),
             p=model_settings.get('top_p', OMIT),
+            seed=model_settings.get('seed', OMIT),
+            presence_penalty=model_settings.get('presence_penalty', OMIT),
+            frequency_penalty=model_settings.get('frequency_penalty', OMIT),
         )
 
     def _process_response(self, response: ChatResponse) -> ModelResponse:
@@ -183,7 +195,7 @@ class CohereAgentModel(AgentModel):
         for c in response.message.tool_calls or []:
             if c.function and c.function.name and c.function.arguments:
                 parts.append(
-                    ToolCallPart.from_raw_args(
+                    ToolCallPart(
                        tool_name=c.function.name,
                        args=c.function.arguments,
                        tool_call_id=c.id,
pydantic_ai/models/gemini.py
CHANGED
@@ -7,7 +7,7 @@ from contextlib import asynccontextmanager
 from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Annotated, Any, Literal, Protocol, Union
+from typing import Annotated, Any, Literal, Protocol, Union, cast
 from uuid import uuid4
 
 import pydantic
@@ -40,7 +40,13 @@ from . import (
 )
 
 GeminiModelName = Literal[
-    'gemini-1.5-flash',
+    'gemini-1.5-flash',
+    'gemini-1.5-flash-8b',
+    'gemini-1.5-pro',
+    'gemini-1.0-pro',
+    'gemini-2.0-flash-exp',
+    'gemini-2.0-flash-thinking-exp-01-21',
+    'gemini-exp-1206',
 ]
 """Named Gemini models.
 
@@ -48,6 +54,12 @@ See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#mo
 """
 
 
+class GeminiModelSettings(ModelSettings):
+    """Settings used for a Gemini model request."""
+
+    # This class is a placeholder for any future gemini-specific settings
+
+
 @dataclass(init=False)
 class GeminiModel(Model):
     """A model that uses Gemini via `generativelanguage.googleapis.com` API.
@@ -171,7 +183,9 @@ class GeminiAgentModel(AgentModel):
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, usage.Usage]:
-        async with self._make_request(messages, False, model_settings) as http_response:
+        async with self._make_request(
+            messages, False, cast(GeminiModelSettings, model_settings or {})
+        ) as http_response:
             response = _gemini_response_ta.validate_json(await http_response.aread())
             return self._process_response(response), _metadata_as_usage(response)
 
@@ -179,12 +193,12 @@ class GeminiAgentModel(AgentModel):
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
-        async with self._make_request(messages, True, model_settings) as http_response:
+        async with self._make_request(messages, True, cast(GeminiModelSettings, model_settings or {})) as http_response:
             yield await self._process_streamed_response(http_response)
 
     @asynccontextmanager
     async def _make_request(
-        self, messages: list[ModelMessage], streamed: bool, model_settings: ModelSettings | None
+        self, messages: list[ModelMessage], streamed: bool, model_settings: GeminiModelSettings
     ) -> AsyncIterator[HTTPResponse]:
         sys_prompt_parts, contents = self._message_to_gemini_content(messages)
 
@@ -204,6 +218,10 @@ class GeminiAgentModel(AgentModel):
             generation_config['temperature'] = temperature
         if (top_p := model_settings.get('top_p')) is not None:
             generation_config['top_p'] = top_p
+        if (presence_penalty := model_settings.get('presence_penalty')) is not None:
+            generation_config['presence_penalty'] = presence_penalty
+        if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
+            generation_config['frequency_penalty'] = frequency_penalty
         if generation_config:
             request_data['generation_config'] = generation_config
 
@@ -222,7 +240,7 @@ class GeminiAgentModel(AgentModel):
             url,
             content=request_json,
             headers=headers,
-            timeout=(model_settings or {}).get('timeout', USE_CLIENT_DEFAULT),
+            timeout=model_settings.get('timeout', USE_CLIENT_DEFAULT),
         ) as r:
             if r.status_code != 200:
                 await r.aread()
@@ -398,6 +416,8 @@ class _GeminiGenerationConfig(TypedDict, total=False):
     max_output_tokens: int
     temperature: float
     top_p: float
+    presence_penalty: float
+    frequency_penalty: float
 
 
 class _GeminiContent(TypedDict):
@@ -439,7 +459,7 @@ def _process_response_from_parts(
             items.append(TextPart(content=part['text']))
         elif 'function_call' in part:
             items.append(
-                ToolCallPart.from_raw_args(
+                ToolCallPart(
                     tool_name=part['function_call']['name'],
                     args=part['function_call']['args'],
                 )
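With `presence_penalty` and `frequency_penalty` now mapped into `generation_config` (and matching fields added to `_GeminiGenerationConfig`), they can be supplied like any other setting. A sketch; the values are illustrative, and this assumes the base `ModelSettings` in this release carries the two penalty keys, consistent with the `settings.py +46 -1` change listed above:

from pydantic_ai import Agent

agent = Agent('google-gla:gemini-1.5-flash')
result = agent.run_sync(
    'Name three rivers.',
    # Forwarded into generation_config per the _make_request change above.
    model_settings={'presence_penalty': 0.5, 'frequency_penalty': 0.3},
)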