pydantic-ai-slim 0.0.44__py3-none-any.whl → 0.0.46__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_parts_manager.py +7 -1
- pydantic_ai/_utils.py +12 -6
- pydantic_ai/agent.py +2 -2
- pydantic_ai/exceptions.py +2 -2
- pydantic_ai/mcp.py +25 -1
- pydantic_ai/messages.py +15 -27
- pydantic_ai/models/__init__.py +15 -6
- pydantic_ai/models/anthropic.py +7 -46
- pydantic_ai/models/bedrock.py +7 -11
- pydantic_ai/models/cohere.py +10 -50
- pydantic_ai/models/gemini.py +18 -73
- pydantic_ai/models/groq.py +9 -53
- pydantic_ai/models/mistral.py +12 -51
- pydantic_ai/models/openai.py +15 -67
- pydantic_ai/providers/anthropic.py +6 -6
- pydantic_ai/providers/azure.py +9 -10
- pydantic_ai/providers/bedrock.py +2 -1
- pydantic_ai/providers/cohere.py +6 -8
- pydantic_ai/providers/deepseek.py +6 -5
- pydantic_ai/providers/google_gla.py +4 -3
- pydantic_ai/providers/google_vertex.py +3 -4
- pydantic_ai/providers/groq.py +6 -8
- pydantic_ai/providers/mistral.py +6 -6
- pydantic_ai/providers/openai.py +6 -8
- {pydantic_ai_slim-0.0.44.dist-info → pydantic_ai_slim-0.0.46.dist-info}/METADATA +2 -2
- pydantic_ai_slim-0.0.46.dist-info/RECORD +50 -0
- pydantic_ai/models/vertexai.py +0 -260
- pydantic_ai_slim-0.0.44.dist-info/RECORD +0 -51
- {pydantic_ai_slim-0.0.44.dist-info → pydantic_ai_slim-0.0.46.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-0.0.44.dist-info → pydantic_ai_slim-0.0.46.dist-info}/entry_points.txt +0 -0
pydantic_ai/models/gemini.py
CHANGED
|
@@ -1,19 +1,19 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
3
|
import base64
|
|
4
|
-
import os
|
|
5
4
|
import re
|
|
6
5
|
from collections.abc import AsyncIterator, Sequence
|
|
7
6
|
from contextlib import asynccontextmanager
|
|
8
7
|
from copy import deepcopy
|
|
9
8
|
from dataclasses import dataclass, field
|
|
10
9
|
from datetime import datetime
|
|
11
|
-
from typing import Annotated, Any, Literal, Protocol, Union, cast
|
|
10
|
+
from typing import Annotated, Any, Literal, Protocol, Union, cast
|
|
12
11
|
from uuid import uuid4
|
|
13
12
|
|
|
13
|
+
import httpx
|
|
14
14
|
import pydantic
|
|
15
|
-
from httpx import USE_CLIENT_DEFAULT,
|
|
16
|
-
from typing_extensions import NotRequired, TypedDict, assert_never
|
|
15
|
+
from httpx import USE_CLIENT_DEFAULT, Response as HTTPResponse
|
|
16
|
+
from typing_extensions import NotRequired, TypedDict, assert_never
|
|
17
17
|
|
|
18
18
|
from pydantic_ai.providers import Provider, infer_provider
|
|
19
19
|
|
|
@@ -85,78 +85,36 @@ class GeminiModel(Model):
|
|
|
85
85
|
Apart from `__init__`, all methods are private or match those of the base class.
|
|
86
86
|
"""
|
|
87
87
|
|
|
88
|
-
client:
|
|
88
|
+
client: httpx.AsyncClient = field(repr=False)
|
|
89
89
|
|
|
90
90
|
_model_name: GeminiModelName = field(repr=False)
|
|
91
|
-
_provider: Literal['google-gla', 'google-vertex'] | Provider[
|
|
91
|
+
_provider: Literal['google-gla', 'google-vertex'] | Provider[httpx.AsyncClient] | None = field(repr=False)
|
|
92
92
|
_auth: AuthProtocol | None = field(repr=False)
|
|
93
93
|
_url: str | None = field(repr=False)
|
|
94
94
|
_system: str = field(default='gemini', repr=False)
|
|
95
95
|
|
|
96
|
-
@overload
|
|
97
96
|
def __init__(
|
|
98
97
|
self,
|
|
99
98
|
model_name: GeminiModelName,
|
|
100
99
|
*,
|
|
101
|
-
provider: Literal['google-gla', 'google-vertex'] | Provider[
|
|
102
|
-
) -> None: ...
|
|
103
|
-
|
|
104
|
-
@deprecated('Use the `provider` argument instead of the `api_key`, `http_client`, and `url_template` arguments.')
|
|
105
|
-
@overload
|
|
106
|
-
def __init__(
|
|
107
|
-
self,
|
|
108
|
-
model_name: GeminiModelName,
|
|
109
|
-
*,
|
|
110
|
-
provider: None = None,
|
|
111
|
-
api_key: str | None = None,
|
|
112
|
-
http_client: AsyncHTTPClient | None = None,
|
|
113
|
-
url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:',
|
|
114
|
-
) -> None: ...
|
|
115
|
-
|
|
116
|
-
def __init__(
|
|
117
|
-
self,
|
|
118
|
-
model_name: GeminiModelName,
|
|
119
|
-
*,
|
|
120
|
-
provider: Literal['google-gla', 'google-vertex'] | Provider[AsyncHTTPClient] | None = None,
|
|
121
|
-
api_key: str | None = None,
|
|
122
|
-
http_client: AsyncHTTPClient | None = None,
|
|
123
|
-
url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:',
|
|
100
|
+
provider: Literal['google-gla', 'google-vertex'] | Provider[httpx.AsyncClient] = 'google-gla',
|
|
124
101
|
):
|
|
125
102
|
"""Initialize a Gemini model.
|
|
126
103
|
|
|
127
104
|
Args:
|
|
128
105
|
model_name: The name of the model to use.
|
|
129
|
-
provider: The provider to use for the
|
|
130
|
-
|
|
131
|
-
will be
|
|
132
|
-
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
|
|
133
|
-
url_template: The URL template to use for making requests, you shouldn't need to change this,
|
|
134
|
-
docs [here](https://ai.google.dev/gemini-api/docs/quickstart?lang=rest#make-first-request),
|
|
135
|
-
`model` is substituted with the model name, and `function` is added to the end of the URL.
|
|
106
|
+
provider: The provider to use for authentication and API access. Can be either the string
|
|
107
|
+
'google-gla' or 'google-vertex' or an instance of `Provider[httpx.AsyncClient]`.
|
|
108
|
+
If not provided, a new provider will be created using the other parameters.
|
|
136
109
|
"""
|
|
137
110
|
self._model_name = model_name
|
|
138
111
|
self._provider = provider
|
|
139
112
|
|
|
140
|
-
if provider
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
self._url = str(self.client.base_url)
|
|
146
|
-
else:
|
|
147
|
-
if api_key is None:
|
|
148
|
-
if env_api_key := os.getenv('GEMINI_API_KEY'):
|
|
149
|
-
api_key = env_api_key
|
|
150
|
-
else:
|
|
151
|
-
raise UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
|
|
152
|
-
self.client = http_client or cached_async_http_client()
|
|
153
|
-
self._auth = ApiKeyAuth(api_key)
|
|
154
|
-
self._url = url_template.format(model=model_name)
|
|
155
|
-
|
|
156
|
-
@property
|
|
157
|
-
def auth(self) -> AuthProtocol:
|
|
158
|
-
assert self._auth is not None, 'Auth not initialized'
|
|
159
|
-
return self._auth
|
|
113
|
+
if isinstance(provider, str):
|
|
114
|
+
provider = infer_provider(provider)
|
|
115
|
+
self._system = provider.name
|
|
116
|
+
self.client = provider.client
|
|
117
|
+
self._url = str(self.client.base_url)
|
|
160
118
|
|
|
161
119
|
@property
|
|
162
120
|
def base_url(self) -> str:
|
|
@@ -252,18 +210,10 @@ class GeminiModel(Model):
|
|
|
252
210
|
if generation_config:
|
|
253
211
|
request_data['generation_config'] = generation_config
|
|
254
212
|
|
|
255
|
-
headers = {
|
|
256
|
-
|
|
257
|
-
'User-Agent': get_user_agent(),
|
|
258
|
-
}
|
|
259
|
-
if self._provider is None: # pragma: no cover
|
|
260
|
-
url = self.base_url + ('streamGenerateContent' if streamed else 'generateContent')
|
|
261
|
-
headers.update(await self.auth.headers())
|
|
262
|
-
else:
|
|
263
|
-
url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'
|
|
213
|
+
headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
|
|
214
|
+
url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'
|
|
264
215
|
|
|
265
216
|
request_json = _gemini_request_ta.dump_json(request_data, by_alias=True)
|
|
266
|
-
|
|
267
217
|
async with self.client.stream(
|
|
268
218
|
'POST',
|
|
269
219
|
url,
|
|
@@ -603,12 +553,7 @@ def _process_response_from_parts(
|
|
|
603
553
|
if 'text' in part:
|
|
604
554
|
items.append(TextPart(content=part['text']))
|
|
605
555
|
elif 'function_call' in part:
|
|
606
|
-
items.append(
|
|
607
|
-
ToolCallPart(
|
|
608
|
-
tool_name=part['function_call']['name'],
|
|
609
|
-
args=part['function_call']['args'],
|
|
610
|
-
)
|
|
611
|
-
)
|
|
556
|
+
items.append(ToolCallPart(tool_name=part['function_call']['name'], args=part['function_call']['args']))
|
|
612
557
|
elif 'function_response' in part:
|
|
613
558
|
raise UnexpectedModelBehavior(
|
|
614
559
|
f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
|
pydantic_ai/models/groq.py
CHANGED
|
@@ -8,8 +8,7 @@ from datetime import datetime, timezone
|
|
|
8
8
|
from itertools import chain
|
|
9
9
|
from typing import Literal, Union, cast, overload
|
|
10
10
|
|
|
11
|
-
from
|
|
12
|
-
from typing_extensions import assert_never, deprecated
|
|
11
|
+
from typing_extensions import assert_never
|
|
13
12
|
|
|
14
13
|
from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
|
|
15
14
|
from .._utils import guard_tool_call_id as _guard_tool_call_id
|
|
@@ -32,7 +31,7 @@ from ..messages import (
|
|
|
32
31
|
from ..providers import Provider, infer_provider
|
|
33
32
|
from ..settings import ModelSettings
|
|
34
33
|
from ..tools import ToolDefinition
|
|
35
|
-
from . import Model, ModelRequestParameters, StreamedResponse,
|
|
34
|
+
from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests
|
|
36
35
|
|
|
37
36
|
try:
|
|
38
37
|
from groq import NOT_GIVEN, APIStatusError, AsyncGroq, AsyncStream
|
|
@@ -90,35 +89,7 @@ class GroqModel(Model):
|
|
|
90
89
|
_model_name: GroqModelName = field(repr=False)
|
|
91
90
|
_system: str = field(default='groq', repr=False)
|
|
92
91
|
|
|
93
|
-
|
|
94
|
-
def __init__(
|
|
95
|
-
self,
|
|
96
|
-
model_name: GroqModelName,
|
|
97
|
-
*,
|
|
98
|
-
provider: Literal['groq'] | Provider[AsyncGroq] = 'groq',
|
|
99
|
-
) -> None: ...
|
|
100
|
-
|
|
101
|
-
@deprecated('Use the `provider` parameter instead of `api_key`, `groq_client`, and `http_client`.')
|
|
102
|
-
@overload
|
|
103
|
-
def __init__(
|
|
104
|
-
self,
|
|
105
|
-
model_name: GroqModelName,
|
|
106
|
-
*,
|
|
107
|
-
provider: None = None,
|
|
108
|
-
api_key: str | None = None,
|
|
109
|
-
groq_client: AsyncGroq | None = None,
|
|
110
|
-
http_client: AsyncHTTPClient | None = None,
|
|
111
|
-
) -> None: ...
|
|
112
|
-
|
|
113
|
-
def __init__(
|
|
114
|
-
self,
|
|
115
|
-
model_name: GroqModelName,
|
|
116
|
-
*,
|
|
117
|
-
provider: Literal['groq'] | Provider[AsyncGroq] | None = None,
|
|
118
|
-
api_key: str | None = None,
|
|
119
|
-
groq_client: AsyncGroq | None = None,
|
|
120
|
-
http_client: AsyncHTTPClient | None = None,
|
|
121
|
-
):
|
|
92
|
+
def __init__(self, model_name: GroqModelName, *, provider: Literal['groq'] | Provider[AsyncGroq] = 'groq'):
|
|
122
93
|
"""Initialize a Groq model.
|
|
123
94
|
|
|
124
95
|
Args:
|
|
@@ -127,27 +98,12 @@ class GroqModel(Model):
|
|
|
127
98
|
provider: The provider to use for authentication and API access. Can be either the string
|
|
128
99
|
'groq' or an instance of `Provider[AsyncGroq]`. If not provided, a new provider will be
|
|
129
100
|
created using the other parameters.
|
|
130
|
-
api_key: The API key to use for authentication, if not provided, the `GROQ_API_KEY` environment variable
|
|
131
|
-
will be used if available.
|
|
132
|
-
groq_client: An existing
|
|
133
|
-
[`AsyncGroq`](https://github.com/groq/groq-python?tab=readme-ov-file#async-usage)
|
|
134
|
-
client to use, if provided, `api_key` and `http_client` must be `None`.
|
|
135
|
-
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
|
|
136
101
|
"""
|
|
137
102
|
self._model_name = model_name
|
|
138
103
|
|
|
139
|
-
if provider
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
self.client = provider.client
|
|
143
|
-
elif groq_client is not None:
|
|
144
|
-
assert http_client is None, 'Cannot provide both `groq_client` and `http_client`'
|
|
145
|
-
assert api_key is None, 'Cannot provide both `groq_client` and `api_key`'
|
|
146
|
-
self.client = groq_client
|
|
147
|
-
elif http_client is not None:
|
|
148
|
-
self.client = AsyncGroq(api_key=api_key, http_client=http_client)
|
|
149
|
-
else:
|
|
150
|
-
self.client = AsyncGroq(api_key=api_key, http_client=cached_async_http_client())
|
|
104
|
+
if isinstance(provider, str):
|
|
105
|
+
provider = infer_provider(provider)
|
|
106
|
+
self.client = provider.client
|
|
151
107
|
|
|
152
108
|
@property
|
|
153
109
|
def base_url(self) -> str:
|
|
@@ -309,7 +265,7 @@ class GroqModel(Model):
|
|
|
309
265
|
@staticmethod
|
|
310
266
|
def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
|
|
311
267
|
return chat.ChatCompletionMessageToolCallParam(
|
|
312
|
-
id=_guard_tool_call_id(t=t
|
|
268
|
+
id=_guard_tool_call_id(t=t),
|
|
313
269
|
type='function',
|
|
314
270
|
function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
|
|
315
271
|
)
|
|
@@ -335,7 +291,7 @@ class GroqModel(Model):
|
|
|
335
291
|
elif isinstance(part, ToolReturnPart):
|
|
336
292
|
yield chat.ChatCompletionToolMessageParam(
|
|
337
293
|
role='tool',
|
|
338
|
-
tool_call_id=_guard_tool_call_id(t=part
|
|
294
|
+
tool_call_id=_guard_tool_call_id(t=part),
|
|
339
295
|
content=part.model_response_str(),
|
|
340
296
|
)
|
|
341
297
|
elif isinstance(part, RetryPromptPart):
|
|
@@ -344,7 +300,7 @@ class GroqModel(Model):
|
|
|
344
300
|
else:
|
|
345
301
|
yield chat.ChatCompletionToolMessageParam(
|
|
346
302
|
role='tool',
|
|
347
|
-
tool_call_id=_guard_tool_call_id(t=part
|
|
303
|
+
tool_call_id=_guard_tool_call_id(t=part),
|
|
348
304
|
content=part.model_response(),
|
|
349
305
|
)
|
|
350
306
|
|
pydantic_ai/models/mistral.py
CHANGED
|
@@ -1,20 +1,19 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
3
|
import base64
|
|
4
|
-
import os
|
|
5
4
|
from collections.abc import AsyncIterable, AsyncIterator, Iterable
|
|
6
5
|
from contextlib import asynccontextmanager
|
|
7
6
|
from dataclasses import dataclass, field
|
|
8
7
|
from datetime import datetime, timezone
|
|
9
8
|
from itertools import chain
|
|
10
|
-
from typing import Any,
|
|
9
|
+
from typing import Any, Literal, Union, cast
|
|
11
10
|
|
|
12
11
|
import pydantic_core
|
|
13
|
-
from httpx import
|
|
14
|
-
from typing_extensions import assert_never
|
|
12
|
+
from httpx import Timeout
|
|
13
|
+
from typing_extensions import assert_never
|
|
15
14
|
|
|
16
15
|
from .. import ModelHTTPError, UnexpectedModelBehavior, _utils
|
|
17
|
-
from .._utils import now_utc as _now_utc
|
|
16
|
+
from .._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc
|
|
18
17
|
from ..messages import (
|
|
19
18
|
BinaryContent,
|
|
20
19
|
DocumentUrl,
|
|
@@ -39,7 +38,6 @@ from . import (
|
|
|
39
38
|
Model,
|
|
40
39
|
ModelRequestParameters,
|
|
41
40
|
StreamedResponse,
|
|
42
|
-
cached_async_http_client,
|
|
43
41
|
check_allow_model_requests,
|
|
44
42
|
)
|
|
45
43
|
|
|
@@ -113,65 +111,28 @@ class MistralModel(Model):
|
|
|
113
111
|
_model_name: MistralModelName = field(repr=False)
|
|
114
112
|
_system: str = field(default='mistral_ai', repr=False)
|
|
115
113
|
|
|
116
|
-
@overload
|
|
117
114
|
def __init__(
|
|
118
115
|
self,
|
|
119
116
|
model_name: MistralModelName,
|
|
120
117
|
*,
|
|
121
118
|
provider: Literal['mistral'] | Provider[Mistral] = 'mistral',
|
|
122
119
|
json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n""",
|
|
123
|
-
) -> None: ...
|
|
124
|
-
|
|
125
|
-
@overload
|
|
126
|
-
@deprecated('Use the `provider` parameter instead of `api_key`, `client` and `http_client`.')
|
|
127
|
-
def __init__(
|
|
128
|
-
self,
|
|
129
|
-
model_name: MistralModelName,
|
|
130
|
-
*,
|
|
131
|
-
provider: None = None,
|
|
132
|
-
api_key: str | Callable[[], str | None] | None = None,
|
|
133
|
-
client: Mistral | None = None,
|
|
134
|
-
http_client: AsyncHTTPClient | None = None,
|
|
135
|
-
json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n""",
|
|
136
|
-
) -> None: ...
|
|
137
|
-
|
|
138
|
-
def __init__(
|
|
139
|
-
self,
|
|
140
|
-
model_name: MistralModelName,
|
|
141
|
-
*,
|
|
142
|
-
provider: Literal['mistral'] | Provider[Mistral] | None = None,
|
|
143
|
-
api_key: str | Callable[[], str | None] | None = None,
|
|
144
|
-
client: Mistral | None = None,
|
|
145
|
-
http_client: AsyncHTTPClient | None = None,
|
|
146
|
-
json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n""",
|
|
147
120
|
):
|
|
148
121
|
"""Initialize a Mistral model.
|
|
149
122
|
|
|
150
123
|
Args:
|
|
124
|
+
model_name: The name of the model to use.
|
|
151
125
|
provider: The provider to use for authentication and API access. Can be either the string
|
|
152
126
|
'mistral' or an instance of `Provider[Mistral]`. If not provided, a new provider will be
|
|
153
127
|
created using the other parameters.
|
|
154
|
-
model_name: The name of the model to use.
|
|
155
|
-
api_key: The API key to use for authentication, if unset uses `MISTRAL_API_KEY` environment variable.
|
|
156
|
-
client: An existing `Mistral` client to use, if provided, `api_key` and `http_client` must be `None`.
|
|
157
|
-
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
|
|
158
128
|
json_mode_schema_prompt: The prompt to show when the model expects a JSON object as input.
|
|
159
129
|
"""
|
|
160
130
|
self._model_name = model_name
|
|
161
131
|
self.json_mode_schema_prompt = json_mode_schema_prompt
|
|
162
132
|
|
|
163
|
-
if provider
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
provider = infer_provider(provider) # pragma: no cover
|
|
167
|
-
self.client = provider.client
|
|
168
|
-
elif client is not None:
|
|
169
|
-
assert http_client is None, 'Cannot provide both `mistral_client` and `http_client`'
|
|
170
|
-
assert api_key is None, 'Cannot provide both `mistral_client` and `api_key`'
|
|
171
|
-
self.client = client
|
|
172
|
-
else:
|
|
173
|
-
api_key = api_key or os.getenv('MISTRAL_API_KEY')
|
|
174
|
-
self.client = Mistral(api_key=api_key, async_client=http_client or cached_async_http_client())
|
|
133
|
+
if isinstance(provider, str):
|
|
134
|
+
provider = infer_provider(provider)
|
|
135
|
+
self.client = provider.client
|
|
175
136
|
|
|
176
137
|
@property
|
|
177
138
|
def base_url(self) -> str:
|
|
@@ -380,16 +341,16 @@ class MistralModel(Model):
|
|
|
380
341
|
@staticmethod
|
|
381
342
|
def _map_mistral_to_pydantic_tool_call(tool_call: MistralToolCall) -> ToolCallPart:
|
|
382
343
|
"""Maps a MistralToolCall to a ToolCall."""
|
|
383
|
-
tool_call_id = tool_call.id or
|
|
344
|
+
tool_call_id = tool_call.id or _generate_tool_call_id()
|
|
384
345
|
func_call = tool_call.function
|
|
385
346
|
|
|
386
347
|
return ToolCallPart(func_call.name, func_call.arguments, tool_call_id)
|
|
387
348
|
|
|
388
349
|
@staticmethod
|
|
389
|
-
def
|
|
350
|
+
def _map_tool_call(t: ToolCallPart) -> MistralToolCall:
|
|
390
351
|
"""Maps a pydantic-ai ToolCall to a MistralToolCall."""
|
|
391
352
|
return MistralToolCall(
|
|
392
|
-
id=t
|
|
353
|
+
id=_utils.guard_tool_call_id(t=t),
|
|
393
354
|
type='function',
|
|
394
355
|
function=MistralFunctionCall(name=t.tool_name, arguments=t.args),
|
|
395
356
|
)
|
|
@@ -502,7 +463,7 @@ class MistralModel(Model):
|
|
|
502
463
|
if isinstance(part, TextPart):
|
|
503
464
|
content_chunks.append(MistralTextChunk(text=part.content))
|
|
504
465
|
elif isinstance(part, ToolCallPart):
|
|
505
|
-
tool_calls.append(cls.
|
|
466
|
+
tool_calls.append(cls._map_tool_call(part))
|
|
506
467
|
else:
|
|
507
468
|
assert_never(part)
|
|
508
469
|
yield MistralAssistantMessage(content=content_chunks, tool_calls=tool_calls)
|
pydantic_ai/models/openai.py
CHANGED
|
@@ -1,15 +1,13 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
3
|
import base64
|
|
4
|
-
import os
|
|
5
4
|
from collections.abc import AsyncIterable, AsyncIterator
|
|
6
5
|
from contextlib import asynccontextmanager
|
|
7
6
|
from dataclasses import dataclass, field
|
|
8
7
|
from datetime import datetime, timezone
|
|
9
8
|
from typing import Literal, Union, cast, overload
|
|
10
9
|
|
|
11
|
-
from
|
|
12
|
-
from typing_extensions import assert_never, deprecated
|
|
10
|
+
from typing_extensions import assert_never
|
|
13
11
|
|
|
14
12
|
from pydantic_ai.providers import Provider, infer_provider
|
|
15
13
|
|
|
@@ -75,7 +73,7 @@ allows this model to be used more easily with other model types (ie, Ollama, Dee
|
|
|
75
73
|
OpenAISystemPromptRole = Literal['system', 'developer', 'user']
|
|
76
74
|
|
|
77
75
|
|
|
78
|
-
class OpenAIModelSettings(ModelSettings):
|
|
76
|
+
class OpenAIModelSettings(ModelSettings, total=False):
|
|
79
77
|
"""Settings used for an OpenAI model request."""
|
|
80
78
|
|
|
81
79
|
openai_reasoning_effort: chat.ChatCompletionReasoningEffort
|
|
@@ -85,6 +83,12 @@ class OpenAIModelSettings(ModelSettings):
|
|
|
85
83
|
result in faster responses and fewer tokens used on reasoning in a response.
|
|
86
84
|
"""
|
|
87
85
|
|
|
86
|
+
user: str
|
|
87
|
+
"""A unique identifier representing the end-user, which can help OpenAI monitor and detect abuse.
|
|
88
|
+
|
|
89
|
+
See [OpenAI's safety best practices](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids) for more details.
|
|
90
|
+
"""
|
|
91
|
+
|
|
88
92
|
|
|
89
93
|
@dataclass(init=False)
|
|
90
94
|
class OpenAIModel(Model):
|
|
@@ -101,39 +105,12 @@ class OpenAIModel(Model):
|
|
|
101
105
|
_model_name: OpenAIModelName = field(repr=False)
|
|
102
106
|
_system: str = field(default='openai', repr=False)
|
|
103
107
|
|
|
104
|
-
@overload
|
|
105
108
|
def __init__(
|
|
106
109
|
self,
|
|
107
110
|
model_name: OpenAIModelName,
|
|
108
111
|
*,
|
|
109
112
|
provider: Literal['openai', 'deepseek', 'azure'] | Provider[AsyncOpenAI] = 'openai',
|
|
110
113
|
system_prompt_role: OpenAISystemPromptRole | None = None,
|
|
111
|
-
) -> None: ...
|
|
112
|
-
|
|
113
|
-
@deprecated('Use the `provider` parameter instead of `base_url`, `api_key`, `openai_client` and `http_client`.')
|
|
114
|
-
@overload
|
|
115
|
-
def __init__(
|
|
116
|
-
self,
|
|
117
|
-
model_name: OpenAIModelName,
|
|
118
|
-
*,
|
|
119
|
-
provider: None = None,
|
|
120
|
-
base_url: str | None = None,
|
|
121
|
-
api_key: str | None = None,
|
|
122
|
-
openai_client: AsyncOpenAI | None = None,
|
|
123
|
-
http_client: AsyncHTTPClient | None = None,
|
|
124
|
-
system_prompt_role: OpenAISystemPromptRole | None = None,
|
|
125
|
-
) -> None: ...
|
|
126
|
-
|
|
127
|
-
def __init__(
|
|
128
|
-
self,
|
|
129
|
-
model_name: OpenAIModelName,
|
|
130
|
-
*,
|
|
131
|
-
provider: Literal['openai', 'deepseek', 'azure'] | Provider[AsyncOpenAI] | None = None,
|
|
132
|
-
base_url: str | None = None,
|
|
133
|
-
api_key: str | None = None,
|
|
134
|
-
openai_client: AsyncOpenAI | None = None,
|
|
135
|
-
http_client: AsyncHTTPClient | None = None,
|
|
136
|
-
system_prompt_role: OpenAISystemPromptRole | None = None,
|
|
137
114
|
):
|
|
138
115
|
"""Initialize an OpenAI model.
|
|
139
116
|
|
|
@@ -142,43 +119,13 @@ class OpenAIModel(Model):
|
|
|
142
119
|
[here](https://github.com/openai/openai-python/blob/v1.54.3/src/openai/types/chat_model.py#L7)
|
|
143
120
|
(Unfortunately, despite being ask to do so, OpenAI do not provide `.inv` files for their API).
|
|
144
121
|
provider: The provider to use. Defaults to `'openai'`.
|
|
145
|
-
base_url: The base url for the OpenAI requests. If not provided, the `OPENAI_BASE_URL` environment variable
|
|
146
|
-
will be used if available. Otherwise, defaults to OpenAI's base url.
|
|
147
|
-
api_key: The API key to use for authentication, if not provided, the `OPENAI_API_KEY` environment variable
|
|
148
|
-
will be used if available.
|
|
149
|
-
openai_client: An existing
|
|
150
|
-
[`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
|
|
151
|
-
client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`.
|
|
152
|
-
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
|
|
153
122
|
system_prompt_role: The role to use for the system prompt message. If not provided, defaults to `'system'`.
|
|
154
123
|
In the future, this may be inferred from the model name.
|
|
155
124
|
"""
|
|
156
125
|
self._model_name = model_name
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
provider = infer_provider(provider)
|
|
161
|
-
self.client = provider.client
|
|
162
|
-
else: # pragma: no cover
|
|
163
|
-
# This is a workaround for the OpenAI client requiring an API key, whilst locally served,
|
|
164
|
-
# openai compatible models do not always need an API key, but a placeholder (non-empty) key is required.
|
|
165
|
-
if (
|
|
166
|
-
api_key is None
|
|
167
|
-
and 'OPENAI_API_KEY' not in os.environ
|
|
168
|
-
and base_url is not None
|
|
169
|
-
and openai_client is None
|
|
170
|
-
):
|
|
171
|
-
api_key = 'api-key-not-set'
|
|
172
|
-
|
|
173
|
-
if openai_client is not None:
|
|
174
|
-
assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
|
|
175
|
-
assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
|
|
176
|
-
assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
|
|
177
|
-
self.client = openai_client
|
|
178
|
-
elif http_client is not None:
|
|
179
|
-
self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
|
|
180
|
-
else:
|
|
181
|
-
self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
|
|
126
|
+
if isinstance(provider, str):
|
|
127
|
+
provider = infer_provider(provider)
|
|
128
|
+
self.client = provider.client
|
|
182
129
|
self.system_prompt_role = system_prompt_role
|
|
183
130
|
|
|
184
131
|
@property
|
|
@@ -282,6 +229,7 @@ class OpenAIModel(Model):
|
|
|
282
229
|
frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
|
|
283
230
|
logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
|
|
284
231
|
reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
|
|
232
|
+
user=model_settings.get('user', NOT_GIVEN),
|
|
285
233
|
)
|
|
286
234
|
except APIStatusError as e:
|
|
287
235
|
if (status_code := e.status_code) >= 400:
|
|
@@ -348,7 +296,7 @@ class OpenAIModel(Model):
|
|
|
348
296
|
@staticmethod
|
|
349
297
|
def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
|
|
350
298
|
return chat.ChatCompletionMessageToolCallParam(
|
|
351
|
-
id=_guard_tool_call_id(t=t
|
|
299
|
+
id=_guard_tool_call_id(t=t),
|
|
352
300
|
type='function',
|
|
353
301
|
function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
|
|
354
302
|
)
|
|
@@ -378,7 +326,7 @@ class OpenAIModel(Model):
|
|
|
378
326
|
elif isinstance(part, ToolReturnPart):
|
|
379
327
|
yield chat.ChatCompletionToolMessageParam(
|
|
380
328
|
role='tool',
|
|
381
|
-
tool_call_id=_guard_tool_call_id(t=part
|
|
329
|
+
tool_call_id=_guard_tool_call_id(t=part),
|
|
382
330
|
content=part.model_response_str(),
|
|
383
331
|
)
|
|
384
332
|
elif isinstance(part, RetryPromptPart):
|
|
@@ -387,7 +335,7 @@ class OpenAIModel(Model):
|
|
|
387
335
|
else:
|
|
388
336
|
yield chat.ChatCompletionToolMessageParam(
|
|
389
337
|
role='tool',
|
|
390
|
-
tool_call_id=_guard_tool_call_id(t=part
|
|
338
|
+
tool_call_id=_guard_tool_call_id(t=part),
|
|
391
339
|
content=part.model_response(),
|
|
392
340
|
)
|
|
393
341
|
else:
|
pydantic_ai/providers/anthropic.py
CHANGED
|
@@ -5,7 +5,9 @@ from typing import overload
|
|
|
5
5
|
|
|
6
6
|
import httpx
|
|
7
7
|
|
|
8
|
+
from pydantic_ai.exceptions import UserError
|
|
8
9
|
from pydantic_ai.models import cached_async_http_client
|
|
10
|
+
from pydantic_ai.providers import Provider
|
|
9
11
|
|
|
10
12
|
try:
|
|
11
13
|
from anthropic import AsyncAnthropic
|
|
@@ -16,9 +18,6 @@ except ImportError as _import_error: # pragma: no cover
|
|
|
16
18
|
) from _import_error
|
|
17
19
|
|
|
18
20
|
|
|
19
|
-
from . import Provider
|
|
20
|
-
|
|
21
|
-
|
|
22
21
|
class AnthropicProvider(Provider[AsyncAnthropic]):
|
|
23
22
|
"""Provider for Anthropic API."""
|
|
24
23
|
|
|
@@ -62,8 +61,8 @@ class AnthropicProvider(Provider[AsyncAnthropic]):
|
|
|
62
61
|
self._client = anthropic_client
|
|
63
62
|
else:
|
|
64
63
|
api_key = api_key or os.environ.get('ANTHROPIC_API_KEY')
|
|
65
|
-
if api_key
|
|
66
|
-
raise
|
|
64
|
+
if not api_key:
|
|
65
|
+
raise UserError(
|
|
67
66
|
'Set the `ANTHROPIC_API_KEY` environment variable or pass it via `AnthropicProvider(api_key=...)`'
|
|
68
67
|
'to use the Anthropic provider.'
|
|
69
68
|
)
|
|
@@ -71,4 +70,5 @@ class AnthropicProvider(Provider[AsyncAnthropic]):
|
|
|
71
70
|
if http_client is not None:
|
|
72
71
|
self._client = AsyncAnthropic(api_key=api_key, http_client=http_client)
|
|
73
72
|
else:
|
|
74
|
-
|
|
73
|
+
http_client = cached_async_http_client(provider='anthropic')
|
|
74
|
+
self._client = AsyncAnthropic(api_key=api_key, http_client=http_client)
|
pydantic_ai/providers/azure.py
CHANGED
|
@@ -6,7 +6,9 @@ from typing import overload
|
|
|
6
6
|
import httpx
|
|
7
7
|
from openai import AsyncOpenAI
|
|
8
8
|
|
|
9
|
+
from pydantic_ai.exceptions import UserError
|
|
9
10
|
from pydantic_ai.models import cached_async_http_client
|
|
11
|
+
from pydantic_ai.providers import Provider
|
|
10
12
|
|
|
11
13
|
try:
|
|
12
14
|
from openai import AsyncAzureOpenAI
|
|
@@ -17,9 +19,6 @@ except ImportError as _import_error: # pragma: no cover
|
|
|
17
19
|
) from _import_error
|
|
18
20
|
|
|
19
21
|
|
|
20
|
-
from . import Provider
|
|
21
|
-
|
|
22
|
-
|
|
23
22
|
class AzureProvider(Provider[AsyncOpenAI]):
|
|
24
23
|
"""Provider for Azure OpenAI API.
|
|
25
24
|
|
|
@@ -83,22 +82,22 @@ class AzureProvider(Provider[AsyncOpenAI]):
|
|
|
83
82
|
self._client = openai_client
|
|
84
83
|
else:
|
|
85
84
|
azure_endpoint = azure_endpoint or os.getenv('AZURE_OPENAI_ENDPOINT')
|
|
86
|
-
if azure_endpoint
|
|
87
|
-
raise
|
|
85
|
+
if not azure_endpoint: # pragma: no cover
|
|
86
|
+
raise UserError(
|
|
88
87
|
'Must provide one of the `azure_endpoint` argument or the `AZURE_OPENAI_ENDPOINT` environment variable'
|
|
89
88
|
)
|
|
90
89
|
|
|
91
|
-
if api_key
|
|
92
|
-
raise
|
|
90
|
+
if not api_key and 'OPENAI_API_KEY' not in os.environ: # pragma: no cover
|
|
91
|
+
raise UserError(
|
|
93
92
|
'Must provide one of the `api_key` argument or the `OPENAI_API_KEY` environment variable'
|
|
94
93
|
)
|
|
95
94
|
|
|
96
|
-
if api_version
|
|
97
|
-
raise
|
|
95
|
+
if not api_version and 'OPENAI_API_VERSION' not in os.environ: # pragma: no cover
|
|
96
|
+
raise UserError(
|
|
98
97
|
'Must provide one of the `api_version` argument or the `OPENAI_API_VERSION` environment variable'
|
|
99
98
|
)
|
|
100
99
|
|
|
101
|
-
http_client = http_client or cached_async_http_client()
|
|
100
|
+
http_client = http_client or cached_async_http_client(provider='azure')
|
|
102
101
|
self._client = AsyncAzureOpenAI(
|
|
103
102
|
azure_endpoint=azure_endpoint,
|
|
104
103
|
api_key=api_key,
|
pydantic_ai/providers/bedrock.py
CHANGED
|
@@ -2,6 +2,7 @@ from __future__ import annotations as _annotations
|
|
|
2
2
|
|
|
3
3
|
from typing import overload
|
|
4
4
|
|
|
5
|
+
from pydantic_ai.exceptions import UserError
|
|
5
6
|
from pydantic_ai.providers import Provider
|
|
6
7
|
|
|
7
8
|
try:
|
|
@@ -73,4 +74,4 @@ class BedrockProvider(Provider[BaseClient]):
|
|
|
73
74
|
region_name=region_name,
|
|
74
75
|
)
|
|
75
76
|
except NoRegionError as exc: # pragma: no cover
|
|
76
|
-
raise
|
|
77
|
+
raise UserError('You must provide a `region_name` or a boto3 client for Bedrock Runtime.') from exc
|