pydantic-ai-slim 0.0.22__py3-none-any.whl → 0.0.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_agent_graph.py +12 -8
- pydantic_ai/agent.py +5 -5
- pydantic_ai/models/__init__.py +52 -45
- pydantic_ai/models/anthropic.py +87 -66
- pydantic_ai/models/cohere.py +65 -67
- pydantic_ai/models/function.py +76 -60
- pydantic_ai/models/gemini.py +153 -99
- pydantic_ai/models/groq.py +97 -72
- pydantic_ai/models/mistral.py +90 -71
- pydantic_ai/models/openai.py +110 -71
- pydantic_ai/models/test.py +99 -94
- pydantic_ai/models/vertexai.py +48 -44
- pydantic_ai/result.py +2 -2
- {pydantic_ai_slim-0.0.22.dist-info → pydantic_ai_slim-0.0.24.dist-info}/METADATA +3 -3
- pydantic_ai_slim-0.0.24.dist-info/RECORD +30 -0
- pydantic_ai_slim-0.0.22.dist-info/RECORD +0 -30
- {pydantic_ai_slim-0.0.22.dist-info → pydantic_ai_slim-0.0.24.dist-info}/WHEEL +0 -0
pydantic_ai/models/cohere.py
CHANGED
|
@@ -26,8 +26,8 @@ from ..messages import (
|
|
|
26
26
|
from ..settings import ModelSettings
|
|
27
27
|
from ..tools import ToolDefinition
|
|
28
28
|
from . import (
|
|
29
|
-
AgentModel,
|
|
30
29
|
Model,
|
|
30
|
+
ModelRequestParameters,
|
|
31
31
|
check_allow_model_requests,
|
|
32
32
|
)
|
|
33
33
|
|
|
@@ -52,7 +52,7 @@ except ImportError as _import_error:
|
|
|
52
52
|
"you can use the `cohere` optional group — `pip install 'pydantic-ai-slim[cohere]'`"
|
|
53
53
|
) from _import_error
|
|
54
54
|
|
|
55
|
-
|
|
55
|
+
LatestCohereModelNames = Literal[
|
|
56
56
|
'c4ai-aya-expanse-32b',
|
|
57
57
|
'c4ai-aya-expanse-8b',
|
|
58
58
|
'command',
|
|
@@ -67,9 +67,15 @@ NamedCohereModels = Literal[
|
|
|
67
67
|
'command-r-plus-08-2024',
|
|
68
68
|
'command-r7b-12-2024',
|
|
69
69
|
]
|
|
70
|
-
"""Latest
|
|
70
|
+
"""Latest Cohere models."""
|
|
71
71
|
|
|
72
|
-
CohereModelName = Union[NamedCohereModels, str]
|
|
72
|
+
CohereModelName = Union[str, LatestCohereModelNames]
|
|
73
|
+
"""Possible Cohere model names.
|
|
74
|
+
|
|
75
|
+
Since Cohere supports a variety of date-stamped models, we explicitly list the latest models but
|
|
76
|
+
allow any name in the type hints.
|
|
77
|
+
See [Cohere's docs](https://docs.cohere.com/v2/docs/models) for a list of all available models.
|
|
78
|
+
"""
|
|
73
79
|
|
|
74
80
|
|
|
75
81
|
class CohereModelSettings(ModelSettings):
|
|
@@ -88,9 +94,11 @@ class CohereModel(Model):
|
|
|
88
94
|
Apart from `__init__`, all methods are private or match those of the base class.
|
|
89
95
|
"""
|
|
90
96
|
|
|
91
|
-
model_name: CohereModelName
|
|
92
97
|
client: AsyncClientV2 = field(repr=False)
|
|
93
98
|
|
|
99
|
+
_model_name: CohereModelName = field(repr=False)
|
|
100
|
+
_system: str | None = field(default='cohere', repr=False)
|
|
101
|
+
|
|
94
102
|
def __init__(
|
|
95
103
|
self,
|
|
96
104
|
model_name: CohereModelName,
|
|
@@ -110,7 +118,7 @@ class CohereModel(Model):
|
|
|
110
118
|
`api_key` and `http_client` must be `None`.
|
|
111
119
|
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
|
|
112
120
|
"""
|
|
113
|
-
self.model_name: CohereModelName = model_name
|
|
121
|
+
self._model_name: CohereModelName = model_name
|
|
114
122
|
if cohere_client is not None:
|
|
115
123
|
assert http_client is None, 'Cannot provide both `cohere_client` and `http_client`'
|
|
116
124
|
assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
|
|
@@ -118,64 +126,38 @@ class CohereModel(Model):
|
|
|
118
126
|
else:
|
|
119
127
|
self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client) # type: ignore
|
|
120
128
|
|
|
121
|
-
async def
|
|
129
|
+
async def request(
|
|
122
130
|
self,
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
) -> AgentModel:
|
|
131
|
+
messages: list[ModelMessage],
|
|
132
|
+
model_settings: ModelSettings | None,
|
|
133
|
+
model_request_parameters: ModelRequestParameters,
|
|
134
|
+
) -> tuple[ModelResponse, result.Usage]:
|
|
128
135
|
check_allow_model_requests()
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
tools += [self._map_tool_definition(r) for r in result_tools]
|
|
132
|
-
return CohereAgentModel(
|
|
133
|
-
self.client,
|
|
134
|
-
self.model_name,
|
|
135
|
-
allow_text_result,
|
|
136
|
-
tools,
|
|
137
|
-
)
|
|
138
|
-
|
|
139
|
-
def name(self) -> str:
|
|
140
|
-
return f'cohere:{self.model_name}'
|
|
141
|
-
|
|
142
|
-
@staticmethod
|
|
143
|
-
def _map_tool_definition(f: ToolDefinition) -> ToolV2:
|
|
144
|
-
return ToolV2(
|
|
145
|
-
type='function',
|
|
146
|
-
function=ToolV2Function(
|
|
147
|
-
name=f.name,
|
|
148
|
-
description=f.description,
|
|
149
|
-
parameters=f.parameters_json_schema,
|
|
150
|
-
),
|
|
151
|
-
)
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
@dataclass
|
|
155
|
-
class CohereAgentModel(AgentModel):
|
|
156
|
-
"""Implementation of `AgentModel` for Cohere models."""
|
|
136
|
+
response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}), model_request_parameters)
|
|
137
|
+
return self._process_response(response), _map_usage(response)
|
|
157
138
|
|
|
158
|
-
|
|
159
|
-
model_name
|
|
160
|
-
|
|
161
|
-
|
|
139
|
+
@property
|
|
140
|
+
def model_name(self) -> CohereModelName:
|
|
141
|
+
"""The model name."""
|
|
142
|
+
return self._model_name
|
|
162
143
|
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
return self._process_response(response), _map_usage(response)
|
|
144
|
+
@property
|
|
145
|
+
def system(self) -> str | None:
|
|
146
|
+
"""The system / model provider."""
|
|
147
|
+
return self._system
|
|
168
148
|
|
|
169
149
|
async def _chat(
|
|
170
150
|
self,
|
|
171
151
|
messages: list[ModelMessage],
|
|
172
152
|
model_settings: CohereModelSettings,
|
|
153
|
+
model_request_parameters: ModelRequestParameters,
|
|
173
154
|
) -> ChatResponse:
|
|
155
|
+
tools = self._get_tools(model_request_parameters)
|
|
174
156
|
cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
|
|
175
157
|
return await self.client.chat(
|
|
176
|
-
model=self.model_name,
|
|
158
|
+
model=self._model_name,
|
|
177
159
|
messages=cohere_messages,
|
|
178
|
-
tools=
|
|
160
|
+
tools=tools or OMIT,
|
|
179
161
|
max_tokens=model_settings.get('max_tokens', OMIT),
|
|
180
162
|
temperature=model_settings.get('temperature', OMIT),
|
|
181
163
|
p=model_settings.get('top_p', OMIT),
|
|
@@ -201,13 +183,12 @@ class CohereAgentModel(AgentModel):
|
|
|
201
183
|
tool_call_id=c.id,
|
|
202
184
|
)
|
|
203
185
|
)
|
|
204
|
-
return ModelResponse(parts=parts, model_name=self.model_name)
|
|
186
|
+
return ModelResponse(parts=parts, model_name=self._model_name)
|
|
205
187
|
|
|
206
|
-
|
|
207
|
-
def _map_message(cls, message: ModelMessage) -> Iterable[ChatMessageV2]:
|
|
188
|
+
def _map_message(self, message: ModelMessage) -> Iterable[ChatMessageV2]:
|
|
208
189
|
"""Just maps a `pydantic_ai.Message` to a `cohere.ChatMessageV2`."""
|
|
209
190
|
if isinstance(message, ModelRequest):
|
|
210
|
-
yield from cls._map_user_message(message)
|
|
191
|
+
yield from self._map_user_message(message)
|
|
211
192
|
elif isinstance(message, ModelResponse):
|
|
212
193
|
texts: list[str] = []
|
|
213
194
|
tool_calls: list[ToolCallV2] = []
|
|
@@ -215,7 +196,7 @@ class CohereAgentModel(AgentModel):
|
|
|
215
196
|
if isinstance(item, TextPart):
|
|
216
197
|
texts.append(item.content)
|
|
217
198
|
elif isinstance(item, ToolCallPart):
|
|
218
|
-
tool_calls.append(_map_tool_call(item))
|
|
199
|
+
tool_calls.append(self._map_tool_call(item))
|
|
219
200
|
else:
|
|
220
201
|
assert_never(item)
|
|
221
202
|
message_param = AssistantChatMessageV2(role='assistant')
|
|
@@ -227,6 +208,34 @@ class CohereAgentModel(AgentModel):
|
|
|
227
208
|
else:
|
|
228
209
|
assert_never(message)
|
|
229
210
|
|
|
211
|
+
def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolV2]:
|
|
212
|
+
tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
|
|
213
|
+
if model_request_parameters.result_tools:
|
|
214
|
+
tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
|
|
215
|
+
return tools
|
|
216
|
+
|
|
217
|
+
@staticmethod
|
|
218
|
+
def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
|
|
219
|
+
return ToolCallV2(
|
|
220
|
+
id=_guard_tool_call_id(t=t, model_source='Cohere'),
|
|
221
|
+
type='function',
|
|
222
|
+
function=ToolCallV2Function(
|
|
223
|
+
name=t.tool_name,
|
|
224
|
+
arguments=t.args_as_json_str(),
|
|
225
|
+
),
|
|
226
|
+
)
|
|
227
|
+
|
|
228
|
+
@staticmethod
|
|
229
|
+
def _map_tool_definition(f: ToolDefinition) -> ToolV2:
|
|
230
|
+
return ToolV2(
|
|
231
|
+
type='function',
|
|
232
|
+
function=ToolV2Function(
|
|
233
|
+
name=f.name,
|
|
234
|
+
description=f.description,
|
|
235
|
+
parameters=f.parameters_json_schema,
|
|
236
|
+
),
|
|
237
|
+
)
|
|
238
|
+
|
|
230
239
|
@classmethod
|
|
231
240
|
def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]:
|
|
232
241
|
for part in message.parts:
|
|
@@ -253,17 +262,6 @@ class CohereAgentModel(AgentModel):
|
|
|
253
262
|
assert_never(part)
|
|
254
263
|
|
|
255
264
|
|
|
256
|
-
def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
|
|
257
|
-
return ToolCallV2(
|
|
258
|
-
id=_guard_tool_call_id(t=t, model_source='Cohere'),
|
|
259
|
-
type='function',
|
|
260
|
-
function=ToolCallV2Function(
|
|
261
|
-
name=t.tool_name,
|
|
262
|
-
arguments=t.args_as_json_str(),
|
|
263
|
-
),
|
|
264
|
-
)
|
|
265
|
-
|
|
266
|
-
|
|
267
265
|
def _map_usage(response: ChatResponse) -> result.Usage:
|
|
268
266
|
usage = response.usage
|
|
269
267
|
if usage is None:
|
pydantic_ai/models/function.py
CHANGED
|
@@ -4,7 +4,7 @@ import inspect
|
|
|
4
4
|
import re
|
|
5
5
|
from collections.abc import AsyncIterator, Awaitable, Iterable
|
|
6
6
|
from contextlib import asynccontextmanager
|
|
7
|
-
from dataclasses import dataclass, field, replace
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
8
|
from datetime import datetime
|
|
9
9
|
from itertools import chain
|
|
10
10
|
from typing import Callable, Union
|
|
@@ -27,7 +27,7 @@ from ..messages import (
|
|
|
27
27
|
)
|
|
28
28
|
from ..settings import ModelSettings
|
|
29
29
|
from ..tools import ToolDefinition
|
|
30
|
-
from . import AgentModel, Model, StreamedResponse
|
|
30
|
+
from . import Model, ModelRequestParameters, StreamedResponse
|
|
31
31
|
|
|
32
32
|
|
|
33
33
|
@dataclass(init=False)
|
|
@@ -40,6 +40,9 @@ class FunctionModel(Model):
|
|
|
40
40
|
function: FunctionDef | None = None
|
|
41
41
|
stream_function: StreamFunctionDef | None = None
|
|
42
42
|
|
|
43
|
+
_model_name: str = field(repr=False)
|
|
44
|
+
_system: str | None = field(default=None, repr=False)
|
|
45
|
+
|
|
43
46
|
@overload
|
|
44
47
|
def __init__(self, function: FunctionDef) -> None: ...
|
|
45
48
|
|
|
@@ -63,23 +66,70 @@ class FunctionModel(Model):
|
|
|
63
66
|
self.function = function
|
|
64
67
|
self.stream_function = stream_function
|
|
65
68
|
|
|
66
|
-
|
|
69
|
+
function_name = self.function.__name__ if self.function is not None else ''
|
|
70
|
+
stream_function_name = self.stream_function.__name__ if self.stream_function is not None else ''
|
|
71
|
+
self._model_name = f'function:{function_name}:{stream_function_name}'
|
|
72
|
+
|
|
73
|
+
async def request(
|
|
67
74
|
self,
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
75
|
+
messages: list[ModelMessage],
|
|
76
|
+
model_settings: ModelSettings | None,
|
|
77
|
+
model_request_parameters: ModelRequestParameters,
|
|
78
|
+
) -> tuple[ModelResponse, usage.Usage]:
|
|
79
|
+
agent_info = AgentInfo(
|
|
80
|
+
model_request_parameters.function_tools,
|
|
81
|
+
model_request_parameters.allow_text_result,
|
|
82
|
+
model_request_parameters.result_tools,
|
|
83
|
+
model_settings,
|
|
77
84
|
)
|
|
78
85
|
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
86
|
+
assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
|
|
87
|
+
|
|
88
|
+
if inspect.iscoroutinefunction(self.function):
|
|
89
|
+
response = await self.function(messages, agent_info)
|
|
90
|
+
else:
|
|
91
|
+
response_ = await _utils.run_in_executor(self.function, messages, agent_info)
|
|
92
|
+
assert isinstance(response_, ModelResponse), response_
|
|
93
|
+
response = response_
|
|
94
|
+
response.model_name = f'function:{self.function.__name__}'
|
|
95
|
+
# TODO is `messages` right here? Should it just be new messages?
|
|
96
|
+
return response, _estimate_usage(chain(messages, [response]))
|
|
97
|
+
|
|
98
|
+
@asynccontextmanager
|
|
99
|
+
async def request_stream(
|
|
100
|
+
self,
|
|
101
|
+
messages: list[ModelMessage],
|
|
102
|
+
model_settings: ModelSettings | None,
|
|
103
|
+
model_request_parameters: ModelRequestParameters,
|
|
104
|
+
) -> AsyncIterator[StreamedResponse]:
|
|
105
|
+
agent_info = AgentInfo(
|
|
106
|
+
model_request_parameters.function_tools,
|
|
107
|
+
model_request_parameters.allow_text_result,
|
|
108
|
+
model_request_parameters.result_tools,
|
|
109
|
+
model_settings,
|
|
110
|
+
)
|
|
111
|
+
|
|
112
|
+
assert (
|
|
113
|
+
self.stream_function is not None
|
|
114
|
+
), 'FunctionModel must receive a `stream_function` to support streamed requests'
|
|
115
|
+
|
|
116
|
+
response_stream = PeekableAsyncStream(self.stream_function(messages, agent_info))
|
|
117
|
+
|
|
118
|
+
first = await response_stream.peek()
|
|
119
|
+
if isinstance(first, _utils.Unset):
|
|
120
|
+
raise ValueError('Stream function must return at least one item')
|
|
121
|
+
|
|
122
|
+
yield FunctionStreamedResponse(_model_name=f'function:{self.stream_function.__name__}', _iter=response_stream)
|
|
123
|
+
|
|
124
|
+
@property
|
|
125
|
+
def model_name(self) -> str:
|
|
126
|
+
"""The model name."""
|
|
127
|
+
return self._model_name
|
|
128
|
+
|
|
129
|
+
@property
|
|
130
|
+
def system(self) -> str | None:
|
|
131
|
+
"""The system / model provider."""
|
|
132
|
+
return self._system
|
|
83
133
|
|
|
84
134
|
|
|
85
135
|
@dataclass(frozen=True)
|
|
@@ -119,9 +169,11 @@ class DeltaToolCall:
|
|
|
119
169
|
DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
|
|
120
170
|
"""A mapping of tool call IDs to incremental changes."""
|
|
121
171
|
|
|
172
|
+
# TODO: Change the signature to Callable[[list[ModelMessage], ModelSettings, ModelRequestParameters], ...]
|
|
122
173
|
FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], Union[ModelResponse, Awaitable[ModelResponse]]]
|
|
123
174
|
"""A function used to generate a non-streamed response."""
|
|
124
175
|
|
|
176
|
+
# TODO: Change signature as indicated above
|
|
125
177
|
StreamFunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls]]]
|
|
126
178
|
"""A function used to generate a streamed response.
|
|
127
179
|
|
|
@@ -132,54 +184,11 @@ E.g. you need to yield all text or all `DeltaToolCalls`, not mix them.
|
|
|
132
184
|
"""
|
|
133
185
|
|
|
134
186
|
|
|
135
|
-
@dataclass
|
|
136
|
-
class FunctionAgentModel(AgentModel):
|
|
137
|
-
"""Implementation of `AgentModel` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
|
|
138
|
-
|
|
139
|
-
function: FunctionDef | None
|
|
140
|
-
stream_function: StreamFunctionDef | None
|
|
141
|
-
agent_info: AgentInfo
|
|
142
|
-
|
|
143
|
-
async def request(
|
|
144
|
-
self, messages: list[ModelMessage], model_settings: ModelSettings | None
|
|
145
|
-
) -> tuple[ModelResponse, usage.Usage]:
|
|
146
|
-
agent_info = replace(self.agent_info, model_settings=model_settings)
|
|
147
|
-
|
|
148
|
-
assert self.function is not None, 'FunctionModel must receive a `function` to support non-streamed requests'
|
|
149
|
-
model_name = f'function:{self.function.__name__}'
|
|
150
|
-
|
|
151
|
-
if inspect.iscoroutinefunction(self.function):
|
|
152
|
-
response = await self.function(messages, agent_info)
|
|
153
|
-
else:
|
|
154
|
-
response_ = await _utils.run_in_executor(self.function, messages, agent_info)
|
|
155
|
-
assert isinstance(response_, ModelResponse), response_
|
|
156
|
-
response = response_
|
|
157
|
-
response.model_name = model_name
|
|
158
|
-
# TODO is `messages` right here? Should it just be new messages?
|
|
159
|
-
return response, _estimate_usage(chain(messages, [response]))
|
|
160
|
-
|
|
161
|
-
@asynccontextmanager
|
|
162
|
-
async def request_stream(
|
|
163
|
-
self, messages: list[ModelMessage], model_settings: ModelSettings | None
|
|
164
|
-
) -> AsyncIterator[StreamedResponse]:
|
|
165
|
-
assert (
|
|
166
|
-
self.stream_function is not None
|
|
167
|
-
), 'FunctionModel must receive a `stream_function` to support streamed requests'
|
|
168
|
-
model_name = f'function:{self.stream_function.__name__}'
|
|
169
|
-
|
|
170
|
-
response_stream = PeekableAsyncStream(self.stream_function(messages, self.agent_info))
|
|
171
|
-
|
|
172
|
-
first = await response_stream.peek()
|
|
173
|
-
if isinstance(first, _utils.Unset):
|
|
174
|
-
raise ValueError('Stream function must return at least one item')
|
|
175
|
-
|
|
176
|
-
yield FunctionStreamedResponse(_model_name=model_name, _iter=response_stream)
|
|
177
|
-
|
|
178
|
-
|
|
179
187
|
@dataclass
|
|
180
188
|
class FunctionStreamedResponse(StreamedResponse):
|
|
181
189
|
"""Implementation of `StreamedResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel]."""
|
|
182
190
|
|
|
191
|
+
_model_name: str
|
|
183
192
|
_iter: AsyncIterator[str | DeltaToolCalls]
|
|
184
193
|
_timestamp: datetime = field(default_factory=_utils.now_utc)
|
|
185
194
|
|
|
@@ -207,7 +216,14 @@ class FunctionStreamedResponse(StreamedResponse):
|
|
|
207
216
|
if maybe_event is not None:
|
|
208
217
|
yield maybe_event
|
|
209
218
|
|
|
219
|
+
@property
|
|
220
|
+
def model_name(self) -> str:
|
|
221
|
+
"""Get the model name of the response."""
|
|
222
|
+
return self._model_name
|
|
223
|
+
|
|
224
|
+
@property
|
|
210
225
|
def timestamp(self) -> datetime:
|
|
226
|
+
"""Get the timestamp of the response."""
|
|
211
227
|
return self._timestamp
|
|
212
228
|
|
|
213
229
|
|