pydantic-ai-slim 0.0.43__py3-none-any.whl → 0.0.45__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/_cli.py +1 -1
- pydantic_ai/_griffe.py +29 -2
- pydantic_ai/_parts_manager.py +7 -1
- pydantic_ai/_utils.py +12 -6
- pydantic_ai/agent.py +2 -2
- pydantic_ai/exceptions.py +2 -2
- pydantic_ai/messages.py +15 -27
- pydantic_ai/models/__init__.py +15 -14
- pydantic_ai/models/anthropic.py +7 -46
- pydantic_ai/models/bedrock.py +7 -11
- pydantic_ai/models/cohere.py +14 -20
- pydantic_ai/models/gemini.py +18 -73
- pydantic_ai/models/groq.py +9 -53
- pydantic_ai/models/instrumented.py +14 -3
- pydantic_ai/models/mistral.py +12 -51
- pydantic_ai/models/openai.py +17 -75
- pydantic_ai/providers/__init__.py +4 -0
- pydantic_ai/providers/anthropic.py +4 -5
- pydantic_ai/providers/azure.py +8 -9
- pydantic_ai/providers/bedrock.py +2 -1
- pydantic_ai/providers/cohere.py +71 -0
- pydantic_ai/providers/deepseek.py +4 -4
- pydantic_ai/providers/google_gla.py +3 -2
- pydantic_ai/providers/google_vertex.py +2 -3
- pydantic_ai/providers/groq.py +4 -5
- pydantic_ai/providers/mistral.py +4 -5
- pydantic_ai/providers/openai.py +5 -8
- {pydantic_ai_slim-0.0.43.dist-info → pydantic_ai_slim-0.0.45.dist-info}/METADATA +3 -3
- pydantic_ai_slim-0.0.45.dist-info/RECORD +50 -0
- pydantic_ai/models/vertexai.py +0 -260
- pydantic_ai_slim-0.0.43.dist-info/RECORD +0 -50
- {pydantic_ai_slim-0.0.43.dist-info → pydantic_ai_slim-0.0.45.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.0.43.dist-info → pydantic_ai_slim-0.0.45.dist-info}/entry_points.txt +0 -0
pydantic_ai/models/gemini.py
CHANGED
@@ -1,19 +1,19 @@
 from __future__ import annotations as _annotations

 import base64
-import os
 import re
 from collections.abc import AsyncIterator, Sequence
 from contextlib import asynccontextmanager
 from copy import deepcopy
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Annotated, Any, Literal, Protocol, Union, cast, overload
+from typing import Annotated, Any, Literal, Protocol, Union, cast
 from uuid import uuid4

+import httpx
 import pydantic
-from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
-from typing_extensions import NotRequired, TypedDict, assert_never, deprecated
+from httpx import USE_CLIENT_DEFAULT, Response as HTTPResponse
+from typing_extensions import NotRequired, TypedDict, assert_never

 from pydantic_ai.providers import Provider, infer_provider

@@ -85,78 +85,36 @@ class GeminiModel(Model):
     Apart from `__init__`, all methods are private or match those of the base class.
     """

-    client: AsyncHTTPClient = field(repr=False)
+    client: httpx.AsyncClient = field(repr=False)

     _model_name: GeminiModelName = field(repr=False)
-    _provider: Literal['google-gla', 'google-vertex'] | Provider[AsyncHTTPClient] | None = field(repr=False)
+    _provider: Literal['google-gla', 'google-vertex'] | Provider[httpx.AsyncClient] | None = field(repr=False)
     _auth: AuthProtocol | None = field(repr=False)
     _url: str | None = field(repr=False)
     _system: str = field(default='gemini', repr=False)

-    @overload
     def __init__(
         self,
         model_name: GeminiModelName,
         *,
-        provider: Literal['google-gla', 'google-vertex'] | Provider[AsyncHTTPClient] = 'google-gla',
-    ) -> None: ...
-
-    @deprecated('Use the `provider` argument instead of the `api_key`, `http_client`, and `url_template` arguments.')
-    @overload
-    def __init__(
-        self,
-        model_name: GeminiModelName,
-        *,
-        provider: None = None,
-        api_key: str | None = None,
-        http_client: AsyncHTTPClient | None = None,
-        url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:',
-    ) -> None: ...
-
-    def __init__(
-        self,
-        model_name: GeminiModelName,
-        *,
-        provider: Literal['google-gla', 'google-vertex'] | Provider[AsyncHTTPClient] | None = None,
-        api_key: str | None = None,
-        http_client: AsyncHTTPClient | None = None,
-        url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:',
+        provider: Literal['google-gla', 'google-vertex'] | Provider[httpx.AsyncClient] = 'google-gla',
     ):
         """Initialize a Gemini model.

         Args:
             model_name: The name of the model to use.
-            provider: The provider to use for the Gemini API.
-            api_key: The API key to use for authentication, if not provided, the `GEMINI_API_KEY` environment variable
-                will be used if available.
-            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
-            url_template: The URL template to use for making requests, you shouldn't need to change this,
-                docs [here](https://ai.google.dev/gemini-api/docs/quickstart?lang=rest#make-first-request),
-                `model` is substituted with the model name, and `function` is added to the end of the URL.
+            provider: The provider to use for authentication and API access. Can be either the string
+                'google-gla' or 'google-vertex' or an instance of `Provider[httpx.AsyncClient]`.
+                If not provided, a new provider will be created using the other parameters.
         """
         self._model_name = model_name
         self._provider = provider

-        if provider is not None:
-            if isinstance(provider, str):
-                provider = infer_provider(provider)
-            self._system = provider.name
-            self.client = provider.client
-            self._url = str(self.client.base_url)
-        else:
-            if api_key is None:
-                if env_api_key := os.getenv('GEMINI_API_KEY'):
-                    api_key = env_api_key
-                else:
-                    raise UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
-            self.client = http_client or cached_async_http_client()
-            self._auth = ApiKeyAuth(api_key)
-            self._url = url_template.format(model=model_name)
-
-    @property
-    def auth(self) -> AuthProtocol:
-        assert self._auth is not None, 'Auth not initialized'
-        return self._auth
+        if isinstance(provider, str):
+            provider = infer_provider(provider)
+        self._system = provider.name
+        self.client = provider.client
+        self._url = str(self.client.base_url)

     @property
     def base_url(self) -> str:

@@ -252,18 +210,10 @@ class GeminiModel(Model):
         if generation_config:
             request_data['generation_config'] = generation_config

-        headers = {
-            'Content-Type': 'application/json',
-            'User-Agent': get_user_agent(),
-        }
-        if self._provider is None:  # pragma: no cover
-            url = self.base_url + ('streamGenerateContent' if streamed else 'generateContent')
-            headers.update(await self.auth.headers())
-        else:
-            url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'
+        headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
+        url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'

         request_json = _gemini_request_ta.dump_json(request_data, by_alias=True)
-
         async with self.client.stream(
             'POST',
             url,

@@ -603,12 +553,7 @@ def _process_response_from_parts(
         if 'text' in part:
            items.append(TextPart(content=part['text']))
         elif 'function_call' in part:
-            items.append(
-                ToolCallPart(
-                    tool_name=part['function_call']['name'],
-                    args=part['function_call']['args'],
-                )
-            )
+            items.append(ToolCallPart(tool_name=part['function_call']['name'], args=part['function_call']['args']))
         elif 'function_response' in part:
             raise UnexpectedModelBehavior(
                 f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
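The net effect for this file: the deprecated `api_key`, `http_client`, and `url_template` arguments (and the `@overload`s that supported them) are removed, and a `GeminiModel` is always configured through a provider. A minimal before/after sketch of the migration (the model name is illustrative; the 'google-gla' provider resolves its own credentials, e.g. from `GEMINI_API_KEY`):

```python
from pydantic_ai.models.gemini import GeminiModel

# 0.0.43 style (removed in 0.0.45): credentials were passed straight to the model.
# model = GeminiModel('gemini-1.5-flash', api_key='...')

# 0.0.45 style: pass a provider string ('google-gla' or 'google-vertex') or a
# Provider[httpx.AsyncClient] instance; the provider owns the HTTP client,
# base URL, and authentication.
model = GeminiModel('gemini-1.5-flash', provider='google-gla')
```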
pydantic_ai/models/groq.py
CHANGED
@@ -8,8 +8,7 @@ from datetime import datetime, timezone
 from itertools import chain
 from typing import Literal, Union, cast, overload

-from httpx import AsyncClient as AsyncHTTPClient
-from typing_extensions import assert_never, deprecated
+from typing_extensions import assert_never

 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id

@@ -32,7 +31,7 @@ from ..messages import (
 from ..providers import Provider, infer_provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
-from . import Model, ModelRequestParameters, StreamedResponse, cached_async_http_client, check_allow_model_requests
+from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests

 try:
     from groq import NOT_GIVEN, APIStatusError, AsyncGroq, AsyncStream

@@ -90,35 +89,7 @@ class GroqModel(Model):
     _model_name: GroqModelName = field(repr=False)
     _system: str = field(default='groq', repr=False)

-    @overload
-    def __init__(
-        self,
-        model_name: GroqModelName,
-        *,
-        provider: Literal['groq'] | Provider[AsyncGroq] = 'groq',
-    ) -> None: ...
-
-    @deprecated('Use the `provider` parameter instead of `api_key`, `groq_client`, and `http_client`.')
-    @overload
-    def __init__(
-        self,
-        model_name: GroqModelName,
-        *,
-        provider: None = None,
-        api_key: str | None = None,
-        groq_client: AsyncGroq | None = None,
-        http_client: AsyncHTTPClient | None = None,
-    ) -> None: ...
-
-    def __init__(
-        self,
-        model_name: GroqModelName,
-        *,
-        provider: Literal['groq'] | Provider[AsyncGroq] | None = None,
-        api_key: str | None = None,
-        groq_client: AsyncGroq | None = None,
-        http_client: AsyncHTTPClient | None = None,
-    ):
+    def __init__(self, model_name: GroqModelName, *, provider: Literal['groq'] | Provider[AsyncGroq] = 'groq'):
         """Initialize a Groq model.

         Args:

@@ -127,27 +98,12 @@ class GroqModel(Model):
             provider: The provider to use for authentication and API access. Can be either the string
                 'groq' or an instance of `Provider[AsyncGroq]`. If not provided, a new provider will be
                 created using the other parameters.
-            api_key: The API key to use for authentication, if not provided, the `GROQ_API_KEY` environment variable
-                will be used if available.
-            groq_client: An existing
-                [`AsyncGroq`](https://github.com/groq/groq-python?tab=readme-ov-file#async-usage)
-                client to use, if provided, `api_key` and `http_client` must be `None`.
-            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
         """
         self._model_name = model_name

-        if provider is not None:
-            if isinstance(provider, str):
-                provider = infer_provider(provider)
-            self.client = provider.client
-        elif groq_client is not None:
-            assert http_client is None, 'Cannot provide both `groq_client` and `http_client`'
-            assert api_key is None, 'Cannot provide both `groq_client` and `api_key`'
-            self.client = groq_client
-        elif http_client is not None:
-            self.client = AsyncGroq(api_key=api_key, http_client=http_client)
-        else:
-            self.client = AsyncGroq(api_key=api_key, http_client=cached_async_http_client())
+        if isinstance(provider, str):
+            provider = infer_provider(provider)
+        self.client = provider.client

     @property
     def base_url(self) -> str:

@@ -309,7 +265,7 @@ class GroqModel(Model):
     @staticmethod
     def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
         return chat.ChatCompletionMessageToolCallParam(
-            id=_guard_tool_call_id(t=t, model_source='Groq'),
+            id=_guard_tool_call_id(t=t),
             type='function',
             function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
         )

@@ -335,7 +291,7 @@ class GroqModel(Model):
         elif isinstance(part, ToolReturnPart):
             yield chat.ChatCompletionToolMessageParam(
                 role='tool',
-                tool_call_id=_guard_tool_call_id(t=part, model_source='Groq'),
+                tool_call_id=_guard_tool_call_id(t=part),
                 content=part.model_response_str(),
             )
         elif isinstance(part, RetryPromptPart):

@@ -344,7 +300,7 @@ class GroqModel(Model):
             else:
                 yield chat.ChatCompletionToolMessageParam(
                     role='tool',
-                    tool_call_id=_guard_tool_call_id(t=part, model_source='Groq'),
+                    tool_call_id=_guard_tool_call_id(t=part),
                     content=part.model_response(),
                 )
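As with Gemini, `GroqModel.__init__` collapses to a single provider-based signature. A sketch of the two remaining construction styles (model name illustrative; this assumes `GroqProvider` accepts an `api_key` argument, mirroring the `AnthropicProvider` shown later in this diff):

```python
from pydantic_ai.models.groq import GroqModel
from pydantic_ai.providers.groq import GroqProvider

# String form: the provider is inferred and reads GROQ_API_KEY itself.
model = GroqModel('llama-3.3-70b-versatile', provider='groq')

# Instance form: build the provider yourself for a custom key or HTTP client.
model = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(api_key='...'))
```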
pydantic_ai/models/instrumented.py
CHANGED

@@ -118,7 +118,7 @@ class InstrumentedModel(WrapperModel):
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
     ) -> tuple[ModelResponse, Usage]:
-        with self._instrument(messages, model_settings) as finish:
+        with self._instrument(messages, model_settings, model_request_parameters) as finish:
             response, usage = await super().request(messages, model_settings, model_request_parameters)
             finish(response, usage)
             return response, usage

@@ -130,7 +130,7 @@ class InstrumentedModel(WrapperModel):
         model_settings: ModelSettings | None,
         model_request_parameters: ModelRequestParameters,
     ) -> AsyncIterator[StreamedResponse]:
-        with self._instrument(messages, model_settings) as finish:
+        with self._instrument(messages, model_settings, model_request_parameters) as finish:
             response_stream: StreamedResponse | None = None
             try:
                 async with super().request_stream(

@@ -146,6 +146,7 @@ class InstrumentedModel(WrapperModel):
         self,
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
     ) -> Iterator[Callable[[ModelResponse, Usage], None]]:
         operation = 'chat'
         span_name = f'{operation} {self.model_name}'

@@ -155,6 +156,13 @@ class InstrumentedModel(WrapperModel):
         attributes: dict[str, AttributeValue] = {
             'gen_ai.operation.name': operation,
             **self.model_attributes(self.wrapped),
+            'model_request_parameters': json.dumps(InstrumentedModel.serialize_any(model_request_parameters)),
+            'logfire.json_schema': json.dumps(
+                {
+                    'type': 'object',
+                    'properties': {'model_request_parameters': {'type': 'object'}},
+                }
+            ),
         }

         if model_settings:

@@ -207,7 +215,10 @@ class InstrumentedModel(WrapperModel):
             'logfire.json_schema': json.dumps(
                 {
                     'type': 'object',
-                    'properties': {attr_name: {'type': 'array'}},
+                    'properties': {
+                        attr_name: {'type': 'array'},
+                        'model_request_parameters': {'type': 'object'},
+                    },
                 }
             ),
         }
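`InstrumentedModel` now threads `model_request_parameters` into `_instrument`, so every request span carries the serialized request parameters alongside the existing attributes. A sketch of the resulting attribute payload (field names inside the nested dict are a hedged guess at `ModelRequestParameters` as of this version; values are illustrative):

```python
import json

# Approximate shape of the new span attributes added by _instrument; the
# parameters are serialized with InstrumentedModel.serialize_any and attached
# as a JSON string, with a matching logfire.json_schema entry.
attributes = {
    'gen_ai.operation.name': 'chat',
    'model_request_parameters': json.dumps(
        {'function_tools': [], 'allow_text_result': True, 'result_tools': []}
    ),
    'logfire.json_schema': json.dumps(
        {'type': 'object', 'properties': {'model_request_parameters': {'type': 'object'}}}
    ),
}
```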
pydantic_ai/models/mistral.py
CHANGED
@@ -1,20 +1,19 @@
 from __future__ import annotations as _annotations

 import base64
-import os
 from collections.abc import AsyncIterable, AsyncIterator, Iterable
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from itertools import chain
-from typing import Any, Callable, Literal, Union, cast, overload
+from typing import Any, Literal, Union, cast

 import pydantic_core
-from httpx import AsyncClient as AsyncHTTPClient, Timeout
-from typing_extensions import assert_never, deprecated
+from httpx import Timeout
+from typing_extensions import assert_never

 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils
-from .._utils import now_utc as _now_utc
+from .._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc
 from ..messages import (
     BinaryContent,
     DocumentUrl,

@@ -39,7 +38,6 @@ from . import (
     Model,
     ModelRequestParameters,
     StreamedResponse,
-    cached_async_http_client,
     check_allow_model_requests,
 )

@@ -113,65 +111,28 @@ class MistralModel(Model):
     _model_name: MistralModelName = field(repr=False)
     _system: str = field(default='mistral_ai', repr=False)

-    @overload
     def __init__(
         self,
         model_name: MistralModelName,
         *,
         provider: Literal['mistral'] | Provider[Mistral] = 'mistral',
         json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n""",
-    ) -> None: ...
-
-    @overload
-    @deprecated('Use the `provider` parameter instead of `api_key`, `client` and `http_client`.')
-    def __init__(
-        self,
-        model_name: MistralModelName,
-        *,
-        provider: None = None,
-        api_key: str | Callable[[], str | None] | None = None,
-        client: Mistral | None = None,
-        http_client: AsyncHTTPClient | None = None,
-        json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n""",
-    ) -> None: ...
-
-    def __init__(
-        self,
-        model_name: MistralModelName,
-        *,
-        provider: Literal['mistral'] | Provider[Mistral] | None = None,
-        api_key: str | Callable[[], str | None] | None = None,
-        client: Mistral | None = None,
-        http_client: AsyncHTTPClient | None = None,
-        json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n""",
     ):
         """Initialize a Mistral model.

         Args:
+            model_name: The name of the model to use.
             provider: The provider to use for authentication and API access. Can be either the string
                 'mistral' or an instance of `Provider[Mistral]`. If not provided, a new provider will be
                 created using the other parameters.
-            model_name: The name of the model to use.
-            api_key: The API key to use for authentication, if unset uses `MISTRAL_API_KEY` environment variable.
-            client: An existing `Mistral` client to use, if provided, `api_key` and `http_client` must be `None`.
-            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
             json_mode_schema_prompt: The prompt to show when the model expects a JSON object as input.
         """
         self._model_name = model_name
         self.json_mode_schema_prompt = json_mode_schema_prompt

-        if provider is not None:
-            if isinstance(provider, str):
-                provider = infer_provider(provider)  # pragma: no cover
-            self.client = provider.client
-        elif client is not None:
-            assert http_client is None, 'Cannot provide both `mistral_client` and `http_client`'
-            assert api_key is None, 'Cannot provide both `mistral_client` and `api_key`'
-            self.client = client
-        else:
-            api_key = api_key or os.getenv('MISTRAL_API_KEY')
-            self.client = Mistral(api_key=api_key, async_client=http_client or cached_async_http_client())
+        if isinstance(provider, str):
+            provider = infer_provider(provider)
+        self.client = provider.client

     @property
     def base_url(self) -> str:

@@ -380,16 +341,16 @@ class MistralModel(Model):
     @staticmethod
     def _map_mistral_to_pydantic_tool_call(tool_call: MistralToolCall) -> ToolCallPart:
         """Maps a MistralToolCall to a ToolCall."""
-        tool_call_id = tool_call.id or None
+        tool_call_id = tool_call.id or _generate_tool_call_id()
         func_call = tool_call.function

         return ToolCallPart(func_call.name, func_call.arguments, tool_call_id)

     @staticmethod
-    def _map_pydantic_to_mistral_tool_call(t: ToolCallPart) -> MistralToolCall:
+    def _map_tool_call(t: ToolCallPart) -> MistralToolCall:
         """Maps a pydantic-ai ToolCall to a MistralToolCall."""
         return MistralToolCall(
-            id=t.tool_call_id,
+            id=_utils.guard_tool_call_id(t=t),
             type='function',
             function=MistralFunctionCall(name=t.tool_name, arguments=t.args),
         )

@@ -502,7 +463,7 @@ class MistralModel(Model):
         if isinstance(part, TextPart):
             content_chunks.append(MistralTextChunk(text=part.content))
         elif isinstance(part, ToolCallPart):
-            tool_calls.append(cls._map_pydantic_to_mistral_tool_call(part))
+            tool_calls.append(cls._map_tool_call(part))
         else:
             assert_never(part)
     yield MistralAssistantMessage(content=content_chunks, tool_calls=tool_calls)
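Two behavioral notes fall out of this file: construction is provider-only (the same pattern as the models above), and a Mistral tool call that arrives without an `id` now gets one from `_utils.generate_tool_call_id` instead of a bare fallback, so `ToolCallPart.tool_call_id` is always populated. A minimal sketch (model name illustrative; the 'mistral' provider resolves `MISTRAL_API_KEY` itself):

```python
from pydantic_ai.models.mistral import MistralModel

# 0.0.45: the deprecated api_key/client/http_client arguments are gone.
model = MistralModel('mistral-large-latest', provider='mistral')
```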
pydantic_ai/models/openai.py
CHANGED
@@ -1,15 +1,13 @@
 from __future__ import annotations as _annotations

 import base64
-import os
 from collections.abc import AsyncIterable, AsyncIterator
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
 from typing import Literal, Union, cast, overload

-from httpx import AsyncClient as AsyncHTTPClient
-from typing_extensions import assert_never, deprecated
+from typing_extensions import assert_never

 from pydantic_ai.providers import Provider, infer_provider

@@ -75,7 +73,7 @@ allows this model to be used more easily with other model types (ie, Ollama, Dee
 OpenAISystemPromptRole = Literal['system', 'developer', 'user']


-class OpenAIModelSettings(ModelSettings):
+class OpenAIModelSettings(ModelSettings, total=False):
     """Settings used for an OpenAI model request."""

     openai_reasoning_effort: chat.ChatCompletionReasoningEffort

@@ -85,6 +83,12 @@ class OpenAIModelSettings(ModelSettings):
     result in faster responses and fewer tokens used on reasoning in a response.
     """

+    user: str
+    """A unique identifier representing the end-user, which can help OpenAI monitor and detect abuse.
+
+    See [OpenAI's safety best practices](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids) for more details.
+    """
+

 @dataclass(init=False)
 class OpenAIModel(Model):

@@ -99,44 +103,14 @@ class OpenAIModel(Model):
     system_prompt_role: OpenAISystemPromptRole | None = field(default=None)

     _model_name: OpenAIModelName = field(repr=False)
-    _system: str = field(repr=False)
+    _system: str = field(default='openai', repr=False)

-    @overload
     def __init__(
         self,
         model_name: OpenAIModelName,
         *,
         provider: Literal['openai', 'deepseek', 'azure'] | Provider[AsyncOpenAI] = 'openai',
         system_prompt_role: OpenAISystemPromptRole | None = None,
-        system: str = 'openai',
-    ) -> None: ...
-
-    @deprecated('Use the `provider` parameter instead of `base_url`, `api_key`, `openai_client` and `http_client`.')
-    @overload
-    def __init__(
-        self,
-        model_name: OpenAIModelName,
-        *,
-        provider: None = None,
-        base_url: str | None = None,
-        api_key: str | None = None,
-        openai_client: AsyncOpenAI | None = None,
-        http_client: AsyncHTTPClient | None = None,
-        system_prompt_role: OpenAISystemPromptRole | None = None,
-        system: str = 'openai',
-    ) -> None: ...
-
-    def __init__(
-        self,
-        model_name: OpenAIModelName,
-        *,
-        provider: Literal['openai', 'deepseek', 'azure'] | Provider[AsyncOpenAI] | None = None,
-        base_url: str | None = None,
-        api_key: str | None = None,
-        openai_client: AsyncOpenAI | None = None,
-        http_client: AsyncHTTPClient | None = None,
-        system_prompt_role: OpenAISystemPromptRole | None = None,
-        system: str = 'openai',
     ):
         """Initialize an OpenAI model.

@@ -145,47 +119,14 @@ class OpenAIModel(Model):
                 [here](https://github.com/openai/openai-python/blob/v1.54.3/src/openai/types/chat_model.py#L7)
                 (Unfortunately, despite being ask to do so, OpenAI do not provide `.inv` files for their API).
             provider: The provider to use. Defaults to `'openai'`.
-            base_url: The base url for the OpenAI requests. If not provided, the `OPENAI_BASE_URL` environment variable
-                will be used if available. Otherwise, defaults to OpenAI's base url.
-            api_key: The API key to use for authentication, if not provided, the `OPENAI_API_KEY` environment variable
-                will be used if available.
-            openai_client: An existing
-                [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
-                client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`.
-            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
             system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
                 In the future, this may be inferred from the model name.
-            system: The model provider used, defaults to `openai`. This is for observability purposes, you must
-                customize the `base_url` and `api_key` to use a different provider.
         """
         self._model_name = model_name
-
-        if provider is not None:
-            if isinstance(provider, str):
-                provider = infer_provider(provider)
-            self.client = provider.client
-        else:  # pragma: no cover
-            # This is a workaround for the OpenAI client requiring an API key, whilst locally served,
-            # openai compatible models do not always need an API key, but a placeholder (non-empty) key is required.
-            if (
-                api_key is None
-                and 'OPENAI_API_KEY' not in os.environ
-                and base_url is not None
-                and openai_client is None
-            ):
-                api_key = 'api-key-not-set'
-
-            if openai_client is not None:
-                assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
-                assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
-                assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
-                self.client = openai_client
-            elif http_client is not None:
-                self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
-            else:
-                self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
+        if isinstance(provider, str):
+            provider = infer_provider(provider)
+        self.client = provider.client
         self.system_prompt_role = system_prompt_role
-        self._system = system

     @property
     def base_url(self) -> str:

@@ -279,7 +220,7 @@ class OpenAIModel(Model):
                 tool_choice=tool_choice or NOT_GIVEN,
                 stream=stream,
                 stream_options={'include_usage': True} if stream else NOT_GIVEN,
-                max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
+                max_completion_tokens=model_settings.get('max_tokens', NOT_GIVEN),
                 temperature=model_settings.get('temperature', NOT_GIVEN),
                 top_p=model_settings.get('top_p', NOT_GIVEN),
                 timeout=model_settings.get('timeout', NOT_GIVEN),

@@ -288,6 +229,7 @@ class OpenAIModel(Model):
                 frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
                 logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
                 reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
+                user=model_settings.get('user', NOT_GIVEN),
             )
         except APIStatusError as e:
             if (status_code := e.status_code) >= 400:

@@ -354,7 +296,7 @@ class OpenAIModel(Model):
     @staticmethod
     def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
         return chat.ChatCompletionMessageToolCallParam(
-            id=_guard_tool_call_id(t=t, model_source='OpenAI'),
+            id=_guard_tool_call_id(t=t),
             type='function',
             function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
         )

@@ -384,7 +326,7 @@ class OpenAIModel(Model):
         elif isinstance(part, ToolReturnPart):
             yield chat.ChatCompletionToolMessageParam(
                 role='tool',
-                tool_call_id=_guard_tool_call_id(t=part, model_source='OpenAI'),
+                tool_call_id=_guard_tool_call_id(t=part),
                 content=part.model_response_str(),
             )
         elif isinstance(part, RetryPromptPart):

@@ -393,7 +335,7 @@ class OpenAIModel(Model):
             else:
                 yield chat.ChatCompletionToolMessageParam(
                     role='tool',
-                    tool_call_id=_guard_tool_call_id(t=part, model_source='OpenAI'),
+                    tool_call_id=_guard_tool_call_id(t=part),
                     content=part.model_response(),
                 )
         else:
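Beyond the same provider migration, `OpenAIModelSettings` is now declared with `total=False` (so every key is optional) and gains a `user` field that is forwarded to the API, and the generic `max_tokens` setting is now sent as `max_completion_tokens`. A sketch of the new setting (identifier value illustrative):

```python
from pydantic_ai.models.openai import OpenAIModel, OpenAIModelSettings

model = OpenAIModel('gpt-4o', provider='openai')

# `user` is passed through to OpenAI to help with end-user abuse monitoring;
# with total=False, any subset of settings keys is valid.
settings = OpenAIModelSettings(user='end-user-1234')
```

These settings are then supplied wherever `ModelSettings` is accepted, for example an agent's `model_settings` argument.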
pydantic_ai/providers/__init__.py
CHANGED

@@ -77,5 +77,9 @@ def infer_provider(provider: str) -> Provider[Any]:
         from .mistral import MistralProvider

         return MistralProvider()
+    elif provider == 'cohere':
+        from .cohere import CohereProvider
+
+        return CohereProvider()
     else:  # pragma: no cover
         raise ValueError(f'Unknown provider: {provider}')
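This new branch makes the string 'cohere' resolvable, backed by the `CohereProvider` added in `pydantic_ai/providers/cohere.py` (+71 lines). A quick sketch of the lookup:

```python
from pydantic_ai.providers import infer_provider

# New in 0.0.45: previously 'cohere' fell through to the ValueError branch.
provider = infer_provider('cohere')
client = provider.client  # the underlying Cohere async client
```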
pydantic_ai/providers/anthropic.py
CHANGED

@@ -5,7 +5,9 @@ from typing import overload

 import httpx

+from pydantic_ai.exceptions import UserError
 from pydantic_ai.models import cached_async_http_client
+from pydantic_ai.providers import Provider

 try:
     from anthropic import AsyncAnthropic

@@ -16,9 +18,6 @@ except ImportError as _import_error:  # pragma: no cover
     ) from _import_error


-from . import Provider
-
-
 class AnthropicProvider(Provider[AsyncAnthropic]):
     """Provider for Anthropic API."""

@@ -62,8 +61,8 @@ class AnthropicProvider(Provider[AsyncAnthropic]):
             self._client = anthropic_client
         else:
             api_key = api_key or os.environ.get('ANTHROPIC_API_KEY')
-            if api_key is None:
-                raise ValueError(
+            if not api_key:
+                raise UserError(
                     'Set the `ANTHROPIC_API_KEY` environment variable or pass it via `AnthropicProvider(api_key=...)`'
                     'to use the Anthropic provider.'
                 )