pydantic-ai-slim 0.0.20__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this version of pydantic-ai-slim has been flagged as a potentially problematic release.
- pydantic_ai/_parts_manager.py +1 -1
- pydantic_ai/_result.py +3 -7
- pydantic_ai/_utils.py +1 -56
- pydantic_ai/agent.py +34 -30
- pydantic_ai/messages.py +21 -46
- pydantic_ai/models/__init__.py +100 -57
- pydantic_ai/models/anthropic.py +17 -10
- pydantic_ai/models/cohere.py +37 -25
- pydantic_ai/models/gemini.py +20 -6
- pydantic_ai/models/groq.py +19 -17
- pydantic_ai/models/mistral.py +22 -23
- pydantic_ai/models/openai.py +19 -11
- pydantic_ai/models/test.py +37 -22
- pydantic_ai/result.py +1 -1
- pydantic_ai/settings.py +41 -1
- pydantic_ai/tools.py +11 -8
- {pydantic_ai_slim-0.0.20.dist-info → pydantic_ai_slim-0.0.21.dist-info}/METADATA +2 -2
- pydantic_ai_slim-0.0.21.dist-info/RECORD +29 -0
- pydantic_ai/models/ollama.py +0 -123
- pydantic_ai_slim-0.0.20.dist-info/RECORD +0 -30
- {pydantic_ai_slim-0.0.20.dist-info → pydantic_ai_slim-0.0.21.dist-info}/WHEEL +0 -0
pydantic_ai/models/anthropic.py
CHANGED
@@ -13,7 +13,6 @@ from typing_extensions import assert_never
 from .. import UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
-    ArgsDict,
     ModelMessage,
     ModelRequest,
     ModelResponse,

@@ -41,6 +40,7 @@ try:
     from anthropic.types import (
         Message as AnthropicMessage,
         MessageParam,
+        MetadataParam,
         RawContentBlockDeltaEvent,
         RawContentBlockStartEvent,
         RawContentBlockStopEvent,

@@ -79,6 +79,15 @@ Since [the Anthropic docs](https://docs.anthropic.com/en/docs/about-claude/model
 """


+class AnthropicModelSettings(ModelSettings):
+    """Settings used for an Anthropic model request."""
+
+    anthropic_metadata: MetadataParam
+    """An object describing metadata about the request.
+
+    Contains `user_id`, an external identifier for the user who is associated with the request."""
+
+
 @dataclass(init=False)
 class AnthropicModel(Model):
     """A model that uses the Anthropic API.

@@ -167,35 +176,33 @@ class AnthropicAgentModel(AgentModel):
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._messages_create(messages, False, model_settings)
+        response = await self._messages_create(messages, False, cast(AnthropicModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)

     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
-        response = await self._messages_create(messages, True, model_settings)
+        response = await self._messages_create(messages, True, cast(AnthropicModelSettings, model_settings or {}))
         async with response:
             yield await self._process_streamed_response(response)

     @overload
     async def _messages_create(
-        self, messages: list[ModelMessage], stream: Literal[True], model_settings:
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: AnthropicModelSettings
     ) -> AsyncStream[RawMessageStreamEvent]:
         pass

     @overload
     async def _messages_create(
-        self, messages: list[ModelMessage], stream: Literal[False], model_settings:
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: AnthropicModelSettings
     ) -> AnthropicMessage:
         pass

     async def _messages_create(
-        self, messages: list[ModelMessage], stream: bool, model_settings:
+        self, messages: list[ModelMessage], stream: bool, model_settings: AnthropicModelSettings
     ) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
         # standalone function to make it easier to override
-        model_settings = model_settings or {}
-
         tool_choice: ToolChoiceParam | None

         if not self.tools:

@@ -222,6 +229,7 @@ class AnthropicAgentModel(AgentModel):
             temperature=model_settings.get('temperature', NOT_GIVEN),
             top_p=model_settings.get('top_p', NOT_GIVEN),
             timeout=model_settings.get('timeout', NOT_GIVEN),
+            metadata=model_settings.get('anthropic_metadata', NOT_GIVEN),
         )

     def _process_response(self, response: AnthropicMessage) -> ModelResponse:

@@ -233,7 +241,7 @@ class AnthropicAgentModel(AgentModel):
             else:
                 assert isinstance(item, ToolUseBlock), 'unexpected item type'
                 items.append(
-                    ToolCallPart
+                    ToolCallPart(
                         tool_name=item.name,
                         args=cast(dict[str, Any], item.input),
                         tool_call_id=item.id,

@@ -310,7 +318,6 @@ class AnthropicAgentModel(AgentModel):


 def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
-    assert isinstance(t.args, ArgsDict), f'Expected ArgsDict, got {t.args}'
     return ToolUseBlockParam(
         id=_guard_tool_call_id(t=t, model_source='Anthropic'),
         type='tool_use',
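The main user-facing change here is the new `AnthropicModelSettings` TypedDict, whose `anthropic_metadata` entry is forwarded to the Messages API `metadata` field. A minimal sketch of how a caller might use it (the model name, prompt and `user_id` value are illustrative, and `Agent.run_sync(..., model_settings=...)` is assumed to accept provider-specific settings as elsewhere in pydantic-ai):

from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel, AnthropicModelSettings

agent = Agent(AnthropicModel('claude-3-5-sonnet-latest'))

# AnthropicModelSettings is a TypedDict, so it can be built inline;
# anthropic_metadata ends up on the Anthropic request as `metadata`.
settings = AnthropicModelSettings(anthropic_metadata={'user_id': 'user-123'})
result = agent.run_sync('Hello!', model_settings=settings)
print(result.data)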
pydantic_ai/models/cohere.py
CHANGED
@@ -3,9 +3,10 @@ from __future__ import annotations as _annotations
 from collections.abc import Iterable
 from dataclasses import dataclass, field
 from itertools import chain
-from typing import Literal,
+from typing import Literal, Union, cast

 from cohere import TextAssistantMessageContentItem
+from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never

 from .. import result

@@ -51,24 +52,30 @@ except ImportError as _import_error:
     "you can use the `cohere` optional group — `pip install 'pydantic-ai-slim[cohere]'`"
 ) from _import_error

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    'command-r-plus-08-2024',
-    'command-r7b-12-2024',
-],
+NamedCohereModels = Literal[
+    'c4ai-aya-expanse-32b',
+    'c4ai-aya-expanse-8b',
+    'command',
+    'command-light',
+    'command-light-nightly',
+    'command-nightly',
+    'command-r',
+    'command-r-03-2024',
+    'command-r-08-2024',
+    'command-r-plus',
+    'command-r-plus-04-2024',
+    'command-r-plus-08-2024',
+    'command-r7b-12-2024',
 ]
+"""Latest / most popular named Cohere models."""
+
+CohereModelName = Union[NamedCohereModels, str]
+
+
+class CohereModelSettings(ModelSettings):
+    """Settings used for a Cohere model request."""
+
+    # This class is a placeholder for any future cohere-specific settings


 @dataclass(init=False)

@@ -90,6 +97,7 @@ class CohereModel(Model):
         *,
         api_key: str | None = None,
         cohere_client: AsyncClientV2 | None = None,
+        http_client: AsyncHTTPClient | None = None,
     ):
         """Initialize an Cohere model.

@@ -97,16 +105,18 @@ class CohereModel(Model):
             model_name: The name of the Cohere model to use. List of model names
                 available [here](https://docs.cohere.com/docs/models#command).
             api_key: The API key to use for authentication, if not provided, the
-                `
+                `CO_API_KEY` environment variable will be used if available.
             cohere_client: An existing Cohere async client to use. If provided,
-                `api_key` must be `None`.
+                `api_key` and `http_client` must be `None`.
+            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
         """
         self.model_name: CohereModelName = model_name
         if cohere_client is not None:
+            assert http_client is None, 'Cannot provide both `cohere_client` and `http_client`'
             assert api_key is None, 'Cannot provide both `cohere_client` and `api_key`'
             self.client = cohere_client
         else:
-            self.client = AsyncClientV2(api_key=api_key)  # type: ignore
+            self.client = AsyncClientV2(api_key=api_key, httpx_client=http_client)  # type: ignore

     async def agent_model(
         self,

@@ -153,16 +163,15 @@ class CohereAgentModel(AgentModel):
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, result.Usage]:
-        response = await self._chat(messages, model_settings)
+        response = await self._chat(messages, cast(CohereModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)

     async def _chat(
         self,
         messages: list[ModelMessage],
-        model_settings:
+        model_settings: CohereModelSettings,
     ) -> ChatResponse:
         cohere_messages = list(chain(*(self._map_message(m) for m in messages)))
-        model_settings = model_settings or {}
         return await self.client.chat(
             model=self.model_name,
             messages=cohere_messages,

@@ -170,6 +179,9 @@ class CohereAgentModel(AgentModel):
             max_tokens=model_settings.get('max_tokens', OMIT),
             temperature=model_settings.get('temperature', OMIT),
             p=model_settings.get('top_p', OMIT),
+            seed=model_settings.get('seed', OMIT),
+            presence_penalty=model_settings.get('presence_penalty', OMIT),
+            frequency_penalty=model_settings.get('frequency_penalty', OMIT),
         )

     def _process_response(self, response: ChatResponse) -> ModelResponse:

@@ -183,7 +195,7 @@ class CohereAgentModel(AgentModel):
         for c in response.message.tool_calls or []:
             if c.function and c.function.name and c.function.arguments:
                 parts.append(
-                    ToolCallPart
+                    ToolCallPart(
                         tool_name=c.function.name,
                         args=c.function.arguments,
                         tool_call_id=c.id,
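`CohereModel` now accepts an existing `httpx.AsyncClient`, matching the other providers, and passing both `cohere_client` and `http_client` trips an assertion. A rough sketch under the constructor shown above (the API key and timeout values are placeholders):

import httpx

from pydantic_ai.models.cohere import CohereModel

# Reuse a single pre-configured httpx.AsyncClient (timeouts, proxies, etc.)
# for the underlying cohere.AsyncClientV2.
http_client = httpx.AsyncClient(timeout=30)
model = CohereModel('command-r-plus', api_key='my-cohere-api-key', http_client=http_client)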
pydantic_ai/models/gemini.py
CHANGED
@@ -7,7 +7,7 @@ from contextlib import asynccontextmanager
 from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Annotated, Any, Literal, Protocol, Union
+from typing import Annotated, Any, Literal, Protocol, Union, cast
 from uuid import uuid4

 import pydantic

@@ -48,6 +48,12 @@ See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#mo
 """


+class GeminiModelSettings(ModelSettings):
+    """Settings used for a Gemini model request."""
+
+    # This class is a placeholder for any future gemini-specific settings
+
+
 @dataclass(init=False)
 class GeminiModel(Model):
     """A model that uses Gemini via `generativelanguage.googleapis.com` API.

@@ -171,7 +177,9 @@ class GeminiAgentModel(AgentModel):
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, usage.Usage]:
-        async with self._make_request(
+        async with self._make_request(
+            messages, False, cast(GeminiModelSettings, model_settings or {})
+        ) as http_response:
             response = _gemini_response_ta.validate_json(await http_response.aread())
             return self._process_response(response), _metadata_as_usage(response)

@@ -179,12 +187,12 @@ class GeminiAgentModel(AgentModel):
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
-        async with self._make_request(messages, True, model_settings) as http_response:
+        async with self._make_request(messages, True, cast(GeminiModelSettings, model_settings or {})) as http_response:
             yield await self._process_streamed_response(http_response)

     @asynccontextmanager
     async def _make_request(
-        self, messages: list[ModelMessage], streamed: bool, model_settings:
+        self, messages: list[ModelMessage], streamed: bool, model_settings: GeminiModelSettings
     ) -> AsyncIterator[HTTPResponse]:
         sys_prompt_parts, contents = self._message_to_gemini_content(messages)

@@ -204,6 +212,10 @@ class GeminiAgentModel(AgentModel):
             generation_config['temperature'] = temperature
         if (top_p := model_settings.get('top_p')) is not None:
             generation_config['top_p'] = top_p
+        if (presence_penalty := model_settings.get('presence_penalty')) is not None:
+            generation_config['presence_penalty'] = presence_penalty
+        if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
+            generation_config['frequency_penalty'] = frequency_penalty
         if generation_config:
             request_data['generation_config'] = generation_config

@@ -222,7 +234,7 @@ class GeminiAgentModel(AgentModel):
             url,
             content=request_json,
             headers=headers,
-            timeout=
+            timeout=model_settings.get('timeout', USE_CLIENT_DEFAULT),
         ) as r:
             if r.status_code != 200:
                 await r.aread()

@@ -398,6 +410,8 @@ class _GeminiGenerationConfig(TypedDict, total=False):
     max_output_tokens: int
     temperature: float
     top_p: float
+    presence_penalty: float
+    frequency_penalty: float


 class _GeminiContent(TypedDict):

@@ -439,7 +453,7 @@ def _process_response_from_parts(
             items.append(TextPart(content=part['text']))
         elif 'function_call' in part:
             items.append(
-                ToolCallPart
+                ToolCallPart(
                     tool_name=part['function_call']['name'],
                     args=part['function_call']['args'],
                 )
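`presence_penalty` and `frequency_penalty` now flow from the generic settings into Gemini's `generation_config`. A short sketch, assuming the base `ModelSettings` gains these keys in this release (the settings.py diff is not shown here) and that `GeminiModel` picks up its API key from the environment:

from pydantic_ai import Agent
from pydantic_ai.models.gemini import GeminiModel

agent = Agent(GeminiModel('gemini-1.5-flash'))

# Both penalties are copied into the request's generation_config when set.
result = agent.run_sync(
    'Write a haiku about diffs.',
    model_settings={'presence_penalty': 0.3, 'frequency_penalty': 0.3},
)
print(result.data)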
pydantic_ai/models/groq.py
CHANGED
@@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from itertools import chain
-from typing import Literal, overload
+from typing import Literal, cast, overload

 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never

@@ -47,10 +47,7 @@ except ImportError as _import_error:

 GroqModelName = Literal[
     'llama-3.3-70b-versatile',
-    'llama-3.
-    'llama3-groq-70b-8192-tool-use-preview',
-    'llama3-groq-8b-8192-tool-use-preview',
-    'llama-3.1-70b-specdec',
+    'llama-3.3-70b-specdec',
     'llama-3.1-8b-instant',
     'llama-3.2-1b-preview',
     'llama-3.2-3b-preview',

@@ -60,7 +57,6 @@ GroqModelName = Literal[
     'llama3-8b-8192',
     'mixtral-8x7b-32768',
     'gemma2-9b-it',
-    'gemma-7b-it',
 ]
 """Named Groq models.

@@ -68,6 +64,12 @@ See [the Groq docs](https://console.groq.com/docs/models) for a full list.
 """


+class GroqModelSettings(ModelSettings):
+    """Settings used for a Groq model request."""
+
+    # This class is a placeholder for any future groq-specific settings
+
+
 @dataclass(init=False)
 class GroqModel(Model):
     """A model that uses the Groq API.

@@ -155,31 +157,31 @@ class GroqAgentModel(AgentModel):
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._completions_create(messages, False, model_settings)
+        response = await self._completions_create(messages, False, cast(GroqModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)

     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
-        response = await self._completions_create(messages, True, model_settings)
+        response = await self._completions_create(messages, True, cast(GroqModelSettings, model_settings or {}))
         async with response:
             yield await self._process_streamed_response(response)

     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[True], model_settings:
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: GroqModelSettings
     ) -> AsyncStream[ChatCompletionChunk]:
         pass

     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[False], model_settings:
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: GroqModelSettings
     ) -> chat.ChatCompletion:
         pass

     async def _completions_create(
-        self, messages: list[ModelMessage], stream: bool, model_settings:
+        self, messages: list[ModelMessage], stream: bool, model_settings: GroqModelSettings
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
         # standalone function to make it easier to override
         if not self.tools:

@@ -191,13 +193,11 @@ class GroqAgentModel(AgentModel):

         groq_messages = list(chain(*(self._map_message(m) for m in messages)))

-        model_settings = model_settings or {}
-
         return await self.client.chat.completions.create(
             model=str(self.model_name),
             messages=groq_messages,
             n=1,
-            parallel_tool_calls=model_settings.get('parallel_tool_calls',
+            parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
             tools=self.tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,

@@ -205,6 +205,10 @@ class GroqAgentModel(AgentModel):
             temperature=model_settings.get('temperature', NOT_GIVEN),
             top_p=model_settings.get('top_p', NOT_GIVEN),
             timeout=model_settings.get('timeout', NOT_GIVEN),
+            seed=model_settings.get('seed', NOT_GIVEN),
+            presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
+            frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
+            logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
         )

     def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:

@@ -216,9 +220,7 @@ class GroqAgentModel(AgentModel):
             items.append(TextPart(content=choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
-                items.append(
-                    ToolCallPart.from_raw_args(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id)
-                )
+                items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
         return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)

     async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:
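Groq requests now pass `seed`, the two penalties and `logit_bias` straight through to `chat.completions.create`. An illustrative sketch, assuming these keys are part of the shared `ModelSettings` added in this release:

from pydantic_ai import Agent
from pydantic_ai.models.groq import GroqModel

agent = Agent(GroqModel('llama-3.3-70b-versatile'))

# seed makes sampling more reproducible; presence_penalty discourages repetition.
result = agent.run_sync(
    'Summarise this release in one sentence.',
    model_settings={'seed': 42, 'presence_penalty': 0.1},
)
print(result.data)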
pydantic_ai/models/mistral.py
CHANGED
@@ -6,7 +6,7 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from itertools import chain
-from typing import Any, Callable, Literal, Union
+from typing import Any, Callable, Literal, Union, cast

 import pydantic_core
 from httpx import AsyncClient as AsyncHTTPClient, Timeout

@@ -15,7 +15,6 @@ from typing_extensions import assert_never
 from .. import UnexpectedModelBehavior, _utils
 from .._utils import now_utc as _now_utc
 from ..messages import (
-    ArgsJson,
     ModelMessage,
     ModelRequest,
     ModelResponse,

@@ -85,6 +84,12 @@ Since [the Mistral docs](https://docs.mistral.ai/getting-started/models/models_o
 """


+class MistralModelSettings(ModelSettings):
+    """Settings used for a Mistral model request."""
+
+    # This class is a placeholder for any future mistral-specific settings
+
+
 @dataclass(init=False)
 class MistralModel(Model):
     """A model that uses Mistral.

@@ -159,7 +164,7 @@ class MistralAgentModel(AgentModel):
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, Usage]:
         """Make a non-streaming request to the model from Pydantic AI call."""
-        response = await self._completions_create(messages, model_settings)
+        response = await self._completions_create(messages, cast(MistralModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)

     @asynccontextmanager

@@ -167,15 +172,14 @@ class MistralAgentModel(AgentModel):
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
         """Make a streaming request to the model from Pydantic AI call."""
-        response = await self._stream_completions_create(messages, model_settings)
+        response = await self._stream_completions_create(messages, cast(MistralModelSettings, model_settings or {}))
         async with response:
             yield await self._process_streamed_response(self.result_tools, response)

     async def _completions_create(
-        self, messages: list[ModelMessage], model_settings:
+        self, messages: list[ModelMessage], model_settings: MistralModelSettings
     ) -> MistralChatCompletionResponse:
         """Make a non-streaming request to the model."""
-        model_settings = model_settings or {}
         response = await self.client.chat.complete_async(
             model=str(self.model_name),
             messages=list(chain(*(self._map_message(m) for m in messages))),

@@ -187,6 +191,7 @@ class MistralAgentModel(AgentModel):
             temperature=model_settings.get('temperature', UNSET),
             top_p=model_settings.get('top_p', 1),
             timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
+            random_seed=model_settings.get('seed', UNSET),
         )
         assert response, 'A unexpected empty response from Mistral.'
         return response

@@ -194,12 +199,11 @@ class MistralAgentModel(AgentModel):
     async def _stream_completions_create(
         self,
         messages: list[ModelMessage],
-        model_settings:
+        model_settings: MistralModelSettings,
     ) -> MistralEventStreamAsync[MistralCompletionEvent]:
         """Create a streaming completion request to the Mistral model."""
         response: MistralEventStreamAsync[MistralCompletionEvent] | None
         mistral_messages = list(chain(*(self._map_message(m) for m in messages)))
-        model_settings = model_settings or {}

         if self.result_tools and self.function_tools or self.function_tools:
             # Function Calling

@@ -213,6 +217,8 @@ class MistralAgentModel(AgentModel):
                 top_p=model_settings.get('top_p', 1),
                 max_tokens=model_settings.get('max_tokens', UNSET),
                 timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
+                presence_penalty=model_settings.get('presence_penalty'),
+                frequency_penalty=model_settings.get('frequency_penalty'),
             )

         elif self.result_tools:

@@ -317,18 +323,11 @@ class MistralAgentModel(AgentModel):
     @staticmethod
     def _map_to_mistral_tool_call(t: ToolCallPart) -> MistralToolCall:
         """Maps a pydantic-ai ToolCall to a MistralToolCall."""
-
-
-
-
-
-            )
-        else:
-            return MistralToolCall(
-                id=t.tool_call_id,
-                type='function',
-                function=MistralFunctionCall(name=t.tool_name, arguments=t.args.args_dict),
-            )
+        return MistralToolCall(
+            id=t.tool_call_id,
+            type='function',
+            function=MistralFunctionCall(name=t.tool_name, arguments=t.args),
+        )

     def _generate_user_output_format(self, schemas: list[dict[str, Any]]) -> MistralUserMessage:
         """Get a message with an example of the expected output format."""

@@ -511,7 +510,7 @@ class MistralStreamedResponse(StreamedResponse):
                 continue

         # The following part_id will be thrown away
-        return ToolCallPart
+        return ToolCallPart(tool_name=result_tool.name, args=output_json)

     @staticmethod
     def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool:

@@ -569,7 +568,7 @@ def _map_mistral_to_pydantic_tool_call(tool_call: MistralToolCall) -> ToolCallPa
     tool_call_id = tool_call.id or None
     func_call = tool_call.function

-    return ToolCallPart
+    return ToolCallPart(func_call.name, func_call.arguments, tool_call_id)


 def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Usage:

@@ -600,7 +599,7 @@ def _map_content(content: MistralOptionalNullable[MistralContent]) -> str | None
     elif isinstance(content, str):
         result = content

-    # Note: Check len to handle potential mismatch between function calls and responses from the API. (`msg: not the same number of function class and
+    # Note: Check len to handle potential mismatch between function calls and responses from the API. (`msg: not the same number of function class and responses`)
     if result and len(result) == 0:
         result = None
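For Mistral, the generic `seed` setting is translated to the SDK's `random_seed` argument on non-streaming requests, while `presence_penalty` and `frequency_penalty` are only forwarded on the streaming function-calling path shown above. A small sketch (the model name and prompt are illustrative):

from pydantic_ai import Agent
from pydantic_ai.models.mistral import MistralModel

agent = Agent(MistralModel('mistral-small-latest'))

# `seed` is mapped onto Mistral's `random_seed` parameter by _completions_create.
result = agent.run_sync('Reply with a single word.', model_settings={'seed': 7})
print(result.data)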
pydantic_ai/models/openai.py
CHANGED
@@ -5,7 +5,7 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from itertools import chain
-from typing import Literal, Union, overload
+from typing import Literal, Union, cast, overload

 from httpx import AsyncClient as AsyncHTTPClient
 from typing_extensions import assert_never

@@ -48,12 +48,18 @@ except ImportError as _import_error:
 OpenAIModelName = Union[ChatModel, str]
 """
 Using this more broad type for the model name instead of the ChatModel definition
-allows this model to be used more easily with other model types (ie, Ollama)
+allows this model to be used more easily with other model types (ie, Ollama, Deepseek)
 """

 OpenAISystemPromptRole = Literal['system', 'developer', 'user']


+class OpenAIModelSettings(ModelSettings):
+    """Settings used for an OpenAI model request."""
+
+    # This class is a placeholder for any future openai-specific settings
+
+
 @dataclass(init=False)
 class OpenAIModel(Model):
     """A model that uses the OpenAI API.

@@ -153,31 +159,31 @@ class OpenAIAgentModel(AgentModel):
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> tuple[ModelResponse, usage.Usage]:
-        response = await self._completions_create(messages, False, model_settings)
+        response = await self._completions_create(messages, False, cast(OpenAIModelSettings, model_settings or {}))
         return self._process_response(response), _map_usage(response)

     @asynccontextmanager
     async def request_stream(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
     ) -> AsyncIterator[StreamedResponse]:
-        response = await self._completions_create(messages, True, model_settings)
+        response = await self._completions_create(messages, True, cast(OpenAIModelSettings, model_settings or {}))
         async with response:
             yield await self._process_streamed_response(response)

     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[True], model_settings:
+        self, messages: list[ModelMessage], stream: Literal[True], model_settings: OpenAIModelSettings
     ) -> AsyncStream[ChatCompletionChunk]:
         pass

     @overload
     async def _completions_create(
-        self, messages: list[ModelMessage], stream: Literal[False], model_settings:
+        self, messages: list[ModelMessage], stream: Literal[False], model_settings: OpenAIModelSettings
     ) -> chat.ChatCompletion:
         pass

     async def _completions_create(
-        self, messages: list[ModelMessage], stream: bool, model_settings:
+        self, messages: list[ModelMessage], stream: bool, model_settings: OpenAIModelSettings
     ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
         # standalone function to make it easier to override
         if not self.tools:

@@ -189,13 +195,11 @@ class OpenAIAgentModel(AgentModel):

         openai_messages = list(chain(*(self._map_message(m) for m in messages)))

-        model_settings = model_settings or {}
-
         return await self.client.chat.completions.create(
             model=self.model_name,
             messages=openai_messages,
             n=1,
-            parallel_tool_calls=model_settings.get('parallel_tool_calls',
+            parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
             tools=self.tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,

@@ -204,6 +208,10 @@ class OpenAIAgentModel(AgentModel):
             temperature=model_settings.get('temperature', NOT_GIVEN),
             top_p=model_settings.get('top_p', NOT_GIVEN),
             timeout=model_settings.get('timeout', NOT_GIVEN),
+            seed=model_settings.get('seed', NOT_GIVEN),
+            presence_penalty=model_settings.get('presence_penalty', NOT_GIVEN),
+            frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
+            logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
         )

     def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:

@@ -215,7 +223,7 @@ class OpenAIAgentModel(AgentModel):
             items.append(TextPart(choice.message.content))
         if choice.message.tool_calls is not None:
             for c in choice.message.tool_calls:
-                items.append(ToolCallPart
+                items.append(ToolCallPart(c.function.name, c.function.arguments, c.id))
         return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)

     async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
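OpenAI requests gain the same pass-throughs (`seed`, the two penalties and `logit_bias`), and the broad `OpenAIModelName` type keeps the class usable against OpenAI-compatible backends such as Ollama or Deepseek. A hedged sketch; the model name, prompt and the token id used in `logit_bias` are illustrative:

from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel

agent = Agent(OpenAIModel('gpt-4o-mini'))

# All four new settings are forwarded to chat.completions.create when present.
result = agent.run_sync(
    'Give me a short, deterministic-ish answer.',
    model_settings={
        'seed': 123,
        'presence_penalty': 0.0,
        'frequency_penalty': 0.2,
        'logit_bias': {'50256': -100},  # illustrative token-id bias
    },
)
print(result.data)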