pydantic-ai-slim 0.0.21__py3-none-any.whl → 0.0.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_agent_graph.py +774 -0
- pydantic_ai/agent.py +183 -555
- pydantic_ai/models/__init__.py +43 -37
- pydantic_ai/models/anthropic.py +69 -66
- pydantic_ai/models/cohere.py +56 -68
- pydantic_ai/models/function.py +58 -60
- pydantic_ai/models/gemini.py +139 -100
- pydantic_ai/models/groq.py +79 -72
- pydantic_ai/models/mistral.py +72 -71
- pydantic_ai/models/openai.py +96 -71
- pydantic_ai/models/test.py +81 -93
- pydantic_ai/models/vertexai.py +38 -44
- pydantic_ai/settings.py +5 -0
- {pydantic_ai_slim-0.0.21.dist-info → pydantic_ai_slim-0.0.23.dist-info}/METADATA +3 -4
- pydantic_ai_slim-0.0.23.dist-info/RECORD +30 -0
- pydantic_ai_slim-0.0.21.dist-info/RECORD +0 -29
- {pydantic_ai_slim-0.0.21.dist-info → pydantic_ai_slim-0.0.23.dist-info}/WHEEL +0 -0
pydantic_ai/models/__init__.py
CHANGED

@@ -52,11 +52,15 @@ KnownModelName = Literal[
     'google-gla:gemini-1.5-flash-8b',
     'google-gla:gemini-1.5-pro',
     'google-gla:gemini-2.0-flash-exp',
+    'google-gla:gemini-2.0-flash-thinking-exp-01-21',
+    'google-gla:gemini-exp-1206',
     'google-vertex:gemini-1.0-pro',
     'google-vertex:gemini-1.5-flash',
     'google-vertex:gemini-1.5-flash-8b',
     'google-vertex:gemini-1.5-pro',
     'google-vertex:gemini-2.0-flash-exp',
+    'google-vertex:gemini-2.0-flash-thinking-exp-01-21',
+    'google-vertex:gemini-exp-1206',
     'gpt-3.5-turbo',
     'gpt-3.5-turbo-0125',
     'gpt-3.5-turbo-0301',

@@ -108,6 +112,8 @@ KnownModelName = Literal[
     'o1-mini-2024-09-12',
     'o1-preview',
     'o1-preview-2024-09-12',
+    'o3-mini',
+    'o3-mini-2025-01-31',
     'openai:chatgpt-4o-latest',
     'openai:gpt-3.5-turbo',
     'openai:gpt-3.5-turbo-0125',

@@ -145,6 +151,8 @@ KnownModelName = Literal[
     'openai:o1-mini-2024-09-12',
     'openai:o1-preview',
     'openai:o1-preview-2024-09-12',
+    'openai:o3-mini',
+    'openai:o3-mini-2025-01-31',
     'test',
 ]
 """Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
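These are plain additions to the `KnownModelName` literal, so the new names work anywhere a model string is accepted. A minimal usage sketch (assuming `OPENAI_API_KEY` is set in the environment; the prompt is invented):

from pydantic_ai import Agent

# 'openai:o3-mini' is one of the names added in this release; infer_model
# strips the 'openai:' prefix and hands the rest to OpenAIModel.
agent = Agent('openai:o3-mini')
result = agent.run_sync('What is the capital of France?')
print(result.data)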
@@ -153,49 +161,37 @@ KnownModelName = Literal[
 """


-class Model(ABC):
-    """Abstract class for a model."""
+@dataclass
+class ModelRequestParameters:
+    """Configuration for an agent's request to a model, specifically related to tools and result handling."""

-    @abstractmethod
-    async def agent_model(
-        self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
-        """Create an agent model, this is called for each step of an agent run.
-
-        This is async in case slow/async config checks need to be performed that can't be done in `__init__`.
-
-        Args:
-            function_tools: The tools available to the agent.
-            allow_text_result: Whether a plain text final response/result is permitted.
-            result_tools: Tool definitions for the final result tool(s), if any.
-
-        Returns:
-            An agent model.
-        """
-        raise NotImplementedError()
+    function_tools: list[ToolDefinition]
+    allow_text_result: bool
+    result_tools: list[ToolDefinition]

-    @abstractmethod
-    def name(self) -> str:
-        raise NotImplementedError()

+class Model(ABC):
+    """Abstract class for a model."""

-class AgentModel(ABC):
-    """…"""
+    _model_name: str
+    _system: str | None

     @abstractmethod
     async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> tuple[ModelResponse, Usage]:
         """Make a request to the model."""
         raise NotImplementedError()

     @asynccontextmanager
     async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[StreamedResponse]:
         """Make a request to the model and return a streaming response."""
         # This method is not required, but you need to implement it if you want to support streamed responses

@@ -204,6 +200,16 @@ class AgentModel(ABC):
         # noinspection PyUnreachableCode
         yield  # pragma: no cover

+    @property
+    def model_name(self) -> str:
+        """The model name."""
+        return self._model_name
+
+    @property
+    def system(self) -> str | None:
+        """The system / model provider, ex: openai."""
+        return self._system
+

 @dataclass
 class StreamedResponse(ABC):
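Net effect of the two hunks above: the old `AgentModel` indirection is gone; tool configuration now travels with every call as a `ModelRequestParameters` value, and `name()` is replaced by `model_name`/`system` properties backed by `_model_name`/`_system`. A minimal sketch of a custom model against the new interface — the class and its echo behaviour are invented for illustration, with import paths taken from the modules this diff touches:

from __future__ import annotations

from dataclasses import dataclass, field

from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models import Model, ModelRequestParameters
from pydantic_ai.settings import ModelSettings
from pydantic_ai.usage import Usage


@dataclass
class EchoModel(Model):
    """Hypothetical model that reports how many tools it was offered."""

    _model_name: str = field(default='echo')
    _system: str | None = field(default=None)

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, Usage]:
        # Tool definitions arrive with every request now, instead of being
        # baked into a per-run AgentModel at setup time.
        n = len(model_request_parameters.function_tools) + len(model_request_parameters.result_tools)
        return ModelResponse([TextPart(f'saw {n} tools')], model_name=self._model_name), Usage()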
@@ -266,7 +272,7 @@ def check_allow_model_requests() -> None:
     """Check if model requests are allowed.

     If you're defining your own models that have costs or latency associated with their use, you should call this in
-    [`Model.agent_model`][pydantic_ai.models.Model.agent_model].
+    [`Model.request`][pydantic_ai.models.Model.request] and [`Model.request_stream`][pydantic_ai.models.Model.request_stream].

     Raises:
         RuntimeError: If model requests are not allowed.
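For reference, the guard this docstring points at is driven by the module-level `ALLOW_MODEL_REQUESTS` flag; a small sketch of how a test suite might use it:

import pydantic_ai.models

# Block accidental real-model calls; any request/request_stream implementation
# that calls check_allow_model_requests() will then raise RuntimeError.
pydantic_ai.models.ALLOW_MODEL_REQUESTS = False

try:
    pydantic_ai.models.check_allow_model_requests()
except RuntimeError as exc:
    print(exc)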
@@ -307,33 +313,33 @@ def infer_model(model: Model | KnownModelName) -> Model:
         from .openai import OpenAIModel

         return OpenAIModel(model[7:])
-    elif model.startswith(('gpt', 'o1')):
+    elif model.startswith(('gpt', 'o1', 'o3')):
         from .openai import OpenAIModel

         return OpenAIModel(model)
     elif model.startswith('google-gla'):
         from .gemini import GeminiModel

-        return GeminiModel(model[11:])
+        return GeminiModel(model[11:])
     # backwards compatibility with old model names (ex, gemini-1.5-flash -> google-gla:gemini-1.5-flash)
     elif model.startswith('gemini'):
         from .gemini import GeminiModel

         # noinspection PyTypeChecker
-        return GeminiModel(model)
+        return GeminiModel(model)
     elif model.startswith('groq:'):
         from .groq import GroqModel

-        return GroqModel(model[5:])
+        return GroqModel(model[5:])
     elif model.startswith('google-vertex'):
         from .vertexai import VertexAIModel

-        return VertexAIModel(model[14:])
+        return VertexAIModel(model[14:])
     # backwards compatibility with old model names (ex, vertexai:gemini-1.5-flash -> google-vertex:gemini-1.5-flash)
     elif model.startswith('vertexai:'):
         from .vertexai import VertexAIModel

-        return VertexAIModel(model[9:])
+        return VertexAIModel(model[9:])
     elif model.startswith('mistral:'):
         from .mistral import MistralModel

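`infer_model` dispatches on string prefixes, so the widened `('gpt', 'o1', 'o3')` tuple is what lets a bare `o3-mini` (no `openai:` prefix) resolve to `OpenAIModel`; the slice indices (`model[7:]`, `model[11:]`, …) strip the provider prefix including the colon. A quick sketch (assumes the relevant API keys are set in the environment, since each constructor builds a client eagerly):

from pydantic_ai.models import infer_model

m1 = infer_model('openai:o3-mini')  # explicit prefix: model[7:] -> OpenAIModel('o3-mini')
m2 = infer_model('o3-mini')         # bare name, matched by startswith(('gpt', 'o1', 'o3'))
m3 = infer_model('google-gla:gemini-2.0-flash-exp')  # model[11:] strips 'google-gla:'
print(m1.model_name, m2.model_name, m3.model_name)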
pydantic_ai/models/anthropic.py
CHANGED

@@ -28,8 +28,8 @@ from ..messages import (
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
     Model,
+    ModelRequestParameters,
     StreamedResponse,
     cached_async_http_client,
     check_allow_model_requests,

@@ -68,14 +68,14 @@ LatestAnthropicModelNames = Literal[
     'claude-3-5-sonnet-latest',
     'claude-3-opus-latest',
 ]
-"""Latest …
+"""Latest Anthropic models."""

 AnthropicModelName = Union[str, LatestAnthropicModelNames]
 """Possible Anthropic model names.

 Since Anthropic supports a variety of date-stamped models, we explicitly list the latest models but
 allow any name in the type hints.
-
+See [the Anthropic docs](https://docs.anthropic.com/en/docs/about-claude/models) for a full list.
 """

@@ -101,9 +101,11 @@ class AnthropicModel(Model):
     We anticipate adding support for streaming responses in a near-term future release.
     """

-    model_name: AnthropicModelName
     client: AsyncAnthropic = field(repr=False)

+    _model_name: AnthropicModelName = field(repr=False)
+    _system: str | None = field(default='anthropic', repr=False)
+
     def __init__(
         self,
         model_name: AnthropicModelName,

@@ -124,7 +126,7 @@ class AnthropicModel(Model):
             client to use, if provided, `api_key` and `http_client` must be `None`.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
         """
-        self.model_name = model_name
+        self._model_name = model_name
         if anthropic_client is not None:
             assert http_client is None, 'Cannot provide both `anthropic_client` and `http_client`'
             assert api_key is None, 'Cannot provide both `anthropic_client` and `api_key`'

@@ -134,81 +136,67 @@ class AnthropicModel(Model):
         else:
             self.client = AsyncAnthropic(api_key=api_key, http_client=cached_async_http_client())

-    async def agent_model(
+    async def request(
         self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> tuple[ModelResponse, usage.Usage]:
         check_allow_model_requests()
-        tools = [self._map_tool_definition(r) for r in function_tools]
-        if result_tools:
-            tools += [self._map_tool_definition(r) for r in result_tools]
-        return AnthropicAgentModel(
-            self.client,
-            self.model_name,
-            allow_text_result,
-            tools,
+        response = await self._messages_create(
+            messages, False, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
         )
-
-    def name(self) -> str:
-        return f'anthropic:{self.model_name}'
-
-    @staticmethod
-    def _map_tool_definition(f: ToolDefinition) -> ToolParam:
-        return {
-            'name': f.name,
-            'description': f.description,
-            'input_schema': f.parameters_json_schema,
-        }
-
-
-@dataclass
-class AnthropicAgentModel(AgentModel):
-    """Implementation of `AgentModel` for Anthropic models."""
-
-    client: AsyncAnthropic
-    model_name: AnthropicModelName
-    allow_text_result: bool
-    tools: list[ToolParam]
-
-    async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
-    ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._messages_create(messages, False, cast(AnthropicModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)

     @asynccontextmanager
     async def request_stream(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[StreamedResponse]:
-        response = await self._messages_create(messages, True, cast(AnthropicModelSettings, model_settings or {}))
+        check_allow_model_requests()
+        response = await self._messages_create(
+            messages, True, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters
+        )
         async with response:
             yield await self._process_streamed_response(response)

     @overload
     async def _messages_create(
-        self, messages: list[ModelMessage], stream: Literal[True], model_settings: AnthropicModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[True],
+        model_settings: AnthropicModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> AsyncStream[RawMessageStreamEvent]:
         pass

     @overload
     async def _messages_create(
-        self, messages: list[ModelMessage], stream: Literal[False], model_settings: AnthropicModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: Literal[False],
+        model_settings: AnthropicModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> AnthropicMessage:
         pass

     async def _messages_create(
-        self, messages: list[ModelMessage], stream: bool, model_settings: AnthropicModelSettings
+        self,
+        messages: list[ModelMessage],
+        stream: bool,
+        model_settings: AnthropicModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
         # standalone function to make it easier to override
+        tools = self._get_tools(model_request_parameters)
         tool_choice: ToolChoiceParam | None

-        if not self.tools:
+        if not tools:
             tool_choice = None
         else:
-            if not self.allow_text_result:
+            if not model_request_parameters.allow_text_result:
                 tool_choice = {'type': 'any'}
             else:
                 tool_choice = {'type': 'auto'}

@@ -222,8 +210,8 @@ class AnthropicAgentModel(AgentModel):
             max_tokens=model_settings.get('max_tokens', 1024),
             system=system_prompt or NOT_GIVEN,
             messages=anthropic_messages,
-            model=self.model_name,
-            tools=self.tools or NOT_GIVEN,
+            model=self._model_name,
+            tools=tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
             temperature=model_settings.get('temperature', NOT_GIVEN),

@@ -248,7 +236,7 @@ class AnthropicAgentModel(AgentModel):
                 )
             )

-        return ModelResponse(items, model_name=self.model_name)
+        return ModelResponse(items, model_name=self._model_name)

     async def _process_streamed_response(self, response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
         peekable_response = _utils.PeekableAsyncStream(response)

@@ -258,10 +246,17 @@ class AnthropicAgentModel(AgentModel):

         # Since Anthropic doesn't provide a timestamp in the message, we'll use the current time
         timestamp = datetime.now(tz=timezone.utc)
-        return AnthropicStreamedResponse(_response=peekable_response, _timestamp=timestamp)
+        return AnthropicStreamedResponse(
+            _model_name=self._model_name, _response=peekable_response, _timestamp=timestamp
+        )

-
-    def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
+    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolParam]:
+        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
+        if model_request_parameters.result_tools:
+            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
+        return tools
+
+    def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
         """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
         system_prompt: str = ''
         anthropic_messages: list[MessageParam] = []

@@ -310,20 +305,28 @@ class AnthropicAgentModel(AgentModel):
                         content.append(TextBlockParam(text=item.content, type='text'))
                     else:
                         assert isinstance(item, ToolCallPart)
-                        content.append(_map_tool_call(item))
+                        content.append(self._map_tool_call(item))
                 anthropic_messages.append(MessageParam(role='assistant', content=content))
             else:
                 assert_never(m)
         return system_prompt, anthropic_messages

+    @staticmethod
+    def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
+        return ToolUseBlockParam(
+            id=_guard_tool_call_id(t=t, model_source='Anthropic'),
+            type='tool_use',
+            name=t.tool_name,
+            input=t.args_as_dict(),
+        )

-def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
-    return ToolUseBlockParam(
-        id=_guard_tool_call_id(t=t, model_source='Anthropic'),
-        type='tool_use',
-        name=t.tool_name,
-        input=t.args_as_dict(),
-    )
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> ToolParam:
+        return {
+            'name': f.name,
+            'description': f.description,
+            'input_schema': f.parameters_json_schema,
+        }


 def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage:
pydantic_ai/models/cohere.py
CHANGED

@@ -26,8 +26,8 @@ from ..messages import (
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
-    AgentModel,
     Model,
+    ModelRequestParameters,
     check_allow_model_requests,
 )

@@ -52,7 +52,7 @@ except ImportError as _import_error:
     "you can use the `cohere` optional group — `pip install 'pydantic-ai-slim[cohere]'`"
 ) from _import_error

-NamedCohereModels = Literal[
+LatestCohereModelNames = Literal[
     'c4ai-aya-expanse-32b',
     'c4ai-aya-expanse-8b',
     'command',

@@ -67,9 +67,15 @@ NamedCohereModels = Literal[
     'command-r-plus-08-2024',
     'command-r7b-12-2024',
 ]
-"""Latest …
+"""Latest Cohere models."""

-CohereModelName = Union[str, NamedCohereModels]
+CohereModelName = Union[str, LatestCohereModelNames]
+"""Possible Cohere model names.
+
+Since Cohere supports a variety of date-stamped models, we explicitly list the latest models but
+allow any name in the type hints.
+See [Cohere's docs](https://docs.cohere.com/v2/docs/models) for a list of all available models.
+"""


 class CohereModelSettings(ModelSettings):

@@ -88,9 +94,11 @@ class CohereModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """

-    model_name: CohereModelName
     client: AsyncClientV2 = field(repr=False)

+    _model_name: CohereModelName = field(repr=False)
+    _system: str | None = field(default='cohere', repr=False)
+
     def __init__(
         self,
         model_name: CohereModelName,

@@ -110,7 +118,7 @@ class CohereModel(Model):
             `api_key` and `http_client` must be `None`.
             http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
         """
-        self.model_name: CohereModelName = model_name
+        self._model_name: CohereModelName = model_name
         if cohere_client is not None:
             assert http_client is None, 'Cannot provide both `cohere_client` and `http_client`'
             assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'

@@ -118,64 +126,28 @@ class CohereModel(Model):
         else:
             self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client)  # type: ignore

-    async def agent_model(
-        self,
-        *,
-        function_tools: list[ToolDefinition],
-        allow_text_result: bool,
-        result_tools: list[ToolDefinition],
-    ) -> AgentModel:
-        check_allow_model_requests()
-        tools = [self._map_tool_definition(r) for r in function_tools]
-        if result_tools:
-            tools += [self._map_tool_definition(r) for r in result_tools]
-        return CohereAgentModel(
-            self.client,
-            self.model_name,
-            allow_text_result,
-            tools,
-        )
-
-    def name(self) -> str:
-        return f'cohere:{self.model_name}'
-
-    @staticmethod
-    def _map_tool_definition(f: ToolDefinition) -> ToolV2:
-        return ToolV2(
-            type='function',
-            function=ToolV2Function(
-                name=f.name,
-                description=f.description,
-                parameters=f.parameters_json_schema,
-            ),
-        )
-
-
-@dataclass
-class CohereAgentModel(AgentModel):
-    """Implementation of `AgentModel` for Cohere models."""
-
-    client: AsyncClientV2
-    model_name: CohereModelName
-    allow_text_result: bool
-    tools: list[ToolV2]
-
     async def request(
-        self, messages: list[ModelMessage], model_settings: ModelSettings | None
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> tuple[ModelResponse, result.Usage]:
-        response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}))
+        check_allow_model_requests()
+        response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}), model_request_parameters)
         return self._process_response(response), _map_usage(response)

     async def _chat(
         self,
         messages: list[ModelMessage],
         model_settings: CohereModelSettings,
+        model_request_parameters: ModelRequestParameters,
     ) -> ChatResponse:
+        tools = self._get_tools(model_request_parameters)
         cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
         return await self.client.chat(
-            model=self.model_name,
+            model=self._model_name,
             messages=cohere_messages,
-            tools=self.tools or OMIT,
+            tools=tools or OMIT,
             max_tokens=model_settings.get('max_tokens', OMIT),
             temperature=model_settings.get('temperature', OMIT),
             p=model_settings.get('top_p', OMIT),

@@ -201,13 +173,12 @@ class CohereAgentModel(AgentModel):
                     tool_call_id=c.id,
                 )
             )
-        return ModelResponse(parts=parts, model_name=self.model_name)
+        return ModelResponse(parts=parts, model_name=self._model_name)

-    @classmethod
-    def _map_message(cls, message: ModelMessage) -> Iterable[ChatMessageV2]:
+    def _map_message(self, message: ModelMessage) -> Iterable[ChatMessageV2]:
         """Just maps a `pydantic_ai.Message` to a `cohere.ChatMessageV2`."""
         if isinstance(message, ModelRequest):
-            yield from cls._map_user_message(message)
+            yield from self._map_user_message(message)
         elif isinstance(message, ModelResponse):
             texts: list[str] = []
             tool_calls: list[ToolCallV2] = []

@@ -215,7 +186,7 @@ class CohereAgentModel(AgentModel):
                 if isinstance(item, TextPart):
                     texts.append(item.content)
                 elif isinstance(item, ToolCallPart):
-                    tool_calls.append(_map_tool_call(item))
+                    tool_calls.append(self._map_tool_call(item))
                 else:
                     assert_never(item)
             message_param = AssistantChatMessageV2(role='assistant')

@@ -227,6 +198,34 @@ class CohereAgentModel(AgentModel):
         else:
             assert_never(message)

+    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolV2]:
+        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
+        if model_request_parameters.result_tools:
+            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
+        return tools
+
+    @staticmethod
+    def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
+        return ToolCallV2(
+            id=_guard_tool_call_id(t=t, model_source='Cohere'),
+            type='function',
+            function=ToolCallV2Function(
+                name=t.tool_name,
+                arguments=t.args_as_json_str(),
+            ),
+        )
+
+    @staticmethod
+    def _map_tool_definition(f: ToolDefinition) -> ToolV2:
+        return ToolV2(
+            type='function',
+            function=ToolV2Function(
+                name=f.name,
+                description=f.description,
+                parameters=f.parameters_json_schema,
+            ),
+        )
+
     @classmethod
     def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]:
         for part in message.parts:

@@ -253,17 +252,6 @@ class CohereAgentModel(AgentModel):
                 assert_never(part)


-def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
-    return ToolCallV2(
-        id=_guard_tool_call_id(t=t, model_source='Cohere'),
-        type='function',
-        function=ToolCallV2Function(
-            name=t.tool_name,
-            arguments=t.args_as_json_str(),
-        ),
-    )
-
-
 def _map_usage(response: ChatResponse) -> result.Usage:
     usage = response.usage
     if usage is None: