pydantic-ai-slim 1.2.1__py3-none-any.whl → 1.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai/__init__.py +6 -0
- pydantic_ai/_agent_graph.py +67 -20
- pydantic_ai/_cli.py +2 -2
- pydantic_ai/_output.py +20 -12
- pydantic_ai/_run_context.py +6 -2
- pydantic_ai/_utils.py +26 -8
- pydantic_ai/ag_ui.py +50 -696
- pydantic_ai/agent/__init__.py +13 -25
- pydantic_ai/agent/abstract.py +146 -9
- pydantic_ai/builtin_tools.py +106 -4
- pydantic_ai/direct.py +16 -4
- pydantic_ai/durable_exec/dbos/_agent.py +3 -0
- pydantic_ai/durable_exec/prefect/_agent.py +3 -0
- pydantic_ai/durable_exec/temporal/__init__.py +11 -0
- pydantic_ai/durable_exec/temporal/_agent.py +3 -0
- pydantic_ai/durable_exec/temporal/_function_toolset.py +23 -72
- pydantic_ai/durable_exec/temporal/_mcp_server.py +30 -30
- pydantic_ai/durable_exec/temporal/_run_context.py +7 -2
- pydantic_ai/durable_exec/temporal/_toolset.py +67 -3
- pydantic_ai/exceptions.py +6 -1
- pydantic_ai/mcp.py +1 -22
- pydantic_ai/messages.py +46 -8
- pydantic_ai/models/__init__.py +87 -38
- pydantic_ai/models/anthropic.py +132 -11
- pydantic_ai/models/bedrock.py +4 -4
- pydantic_ai/models/cohere.py +0 -7
- pydantic_ai/models/gemini.py +9 -2
- pydantic_ai/models/google.py +26 -23
- pydantic_ai/models/groq.py +13 -5
- pydantic_ai/models/huggingface.py +2 -2
- pydantic_ai/models/openai.py +251 -52
- pydantic_ai/models/outlines.py +563 -0
- pydantic_ai/models/test.py +6 -3
- pydantic_ai/profiles/openai.py +7 -0
- pydantic_ai/providers/__init__.py +25 -12
- pydantic_ai/providers/anthropic.py +2 -2
- pydantic_ai/providers/bedrock.py +60 -16
- pydantic_ai/providers/gateway.py +60 -72
- pydantic_ai/providers/google.py +91 -24
- pydantic_ai/providers/openrouter.py +3 -0
- pydantic_ai/providers/outlines.py +40 -0
- pydantic_ai/providers/ovhcloud.py +95 -0
- pydantic_ai/result.py +173 -8
- pydantic_ai/run.py +40 -24
- pydantic_ai/settings.py +8 -0
- pydantic_ai/tools.py +10 -6
- pydantic_ai/toolsets/fastmcp.py +215 -0
- pydantic_ai/ui/__init__.py +16 -0
- pydantic_ai/ui/_adapter.py +386 -0
- pydantic_ai/ui/_event_stream.py +591 -0
- pydantic_ai/ui/_messages_builder.py +28 -0
- pydantic_ai/ui/ag_ui/__init__.py +9 -0
- pydantic_ai/ui/ag_ui/_adapter.py +187 -0
- pydantic_ai/ui/ag_ui/_event_stream.py +236 -0
- pydantic_ai/ui/ag_ui/app.py +148 -0
- pydantic_ai/ui/vercel_ai/__init__.py +16 -0
- pydantic_ai/ui/vercel_ai/_adapter.py +199 -0
- pydantic_ai/ui/vercel_ai/_event_stream.py +187 -0
- pydantic_ai/ui/vercel_ai/_utils.py +16 -0
- pydantic_ai/ui/vercel_ai/request_types.py +275 -0
- pydantic_ai/ui/vercel_ai/response_types.py +230 -0
- pydantic_ai/usage.py +13 -2
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/METADATA +23 -5
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/RECORD +67 -49
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/providers/bedrock.py
CHANGED
|
@@ -4,7 +4,7 @@ import os
|
|
|
4
4
|
import re
|
|
5
5
|
from collections.abc import Callable
|
|
6
6
|
from dataclasses import dataclass
|
|
7
|
-
from typing import Literal, overload
|
|
7
|
+
from typing import Any, Literal, overload
|
|
8
8
|
|
|
9
9
|
from pydantic_ai import ModelProfile
|
|
10
10
|
from pydantic_ai.exceptions import UserError
|
|
@@ -21,6 +21,8 @@ try:
|
|
|
21
21
|
from botocore.client import BaseClient
|
|
22
22
|
from botocore.config import Config
|
|
23
23
|
from botocore.exceptions import NoRegionError
|
|
24
|
+
from botocore.session import Session
|
|
25
|
+
from botocore.tokens import FrozenAuthToken
|
|
24
26
|
except ImportError as _import_error:
|
|
25
27
|
raise ImportError(
|
|
26
28
|
'Please install the `boto3` package to use the Bedrock provider, '
|
|
@@ -117,10 +119,23 @@ class BedrockProvider(Provider[BaseClient]):
|
|
|
117
119
|
def __init__(
|
|
118
120
|
self,
|
|
119
121
|
*,
|
|
122
|
+
api_key: str,
|
|
123
|
+
base_url: str | None = None,
|
|
120
124
|
region_name: str | None = None,
|
|
125
|
+
profile_name: str | None = None,
|
|
126
|
+
aws_read_timeout: float | None = None,
|
|
127
|
+
aws_connect_timeout: float | None = None,
|
|
128
|
+
) -> None: ...
|
|
129
|
+
|
|
130
|
+
@overload
|
|
131
|
+
def __init__(
|
|
132
|
+
self,
|
|
133
|
+
*,
|
|
121
134
|
aws_access_key_id: str | None = None,
|
|
122
135
|
aws_secret_access_key: str | None = None,
|
|
123
136
|
aws_session_token: str | None = None,
|
|
137
|
+
base_url: str | None = None,
|
|
138
|
+
region_name: str | None = None,
|
|
124
139
|
profile_name: str | None = None,
|
|
125
140
|
aws_read_timeout: float | None = None,
|
|
126
141
|
aws_connect_timeout: float | None = None,
|
|
@@ -130,11 +145,13 @@ class BedrockProvider(Provider[BaseClient]):
|
|
|
130
145
|
self,
|
|
131
146
|
*,
|
|
132
147
|
bedrock_client: BaseClient | None = None,
|
|
133
|
-
region_name: str | None = None,
|
|
134
148
|
aws_access_key_id: str | None = None,
|
|
135
149
|
aws_secret_access_key: str | None = None,
|
|
136
150
|
aws_session_token: str | None = None,
|
|
151
|
+
base_url: str | None = None,
|
|
152
|
+
region_name: str | None = None,
|
|
137
153
|
profile_name: str | None = None,
|
|
154
|
+
api_key: str | None = None,
|
|
138
155
|
aws_read_timeout: float | None = None,
|
|
139
156
|
aws_connect_timeout: float | None = None,
|
|
140
157
|
) -> None:
|
|
@@ -142,10 +159,12 @@ class BedrockProvider(Provider[BaseClient]):
|
|
|
142
159
|
|
|
143
160
|
Args:
|
|
144
161
|
bedrock_client: A boto3 client for Bedrock Runtime. If provided, other arguments are ignored.
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
162
|
+
aws_access_key_id: The AWS access key ID. If not set, the `AWS_ACCESS_KEY_ID` environment variable will be used if available.
|
|
163
|
+
aws_secret_access_key: The AWS secret access key. If not set, the `AWS_SECRET_ACCESS_KEY` environment variable will be used if available.
|
|
164
|
+
aws_session_token: The AWS session token. If not set, the `AWS_SESSION_TOKEN` environment variable will be used if available.
|
|
165
|
+
api_key: The API key for Bedrock client. Can be used instead of `aws_access_key_id`, `aws_secret_access_key`, and `aws_session_token`. If not set, the `AWS_BEARER_TOKEN_BEDROCK` environment variable will be used if available.
|
|
166
|
+
base_url: The base URL for the Bedrock client.
|
|
167
|
+
region_name: The AWS region name. If not set, the `AWS_DEFAULT_REGION` environment variable will be used if available.
|
|
149
168
|
profile_name: The AWS profile name.
|
|
150
169
|
aws_read_timeout: The read timeout for Bedrock client.
|
|
151
170
|
aws_connect_timeout: The connect timeout for Bedrock client.
|
|
@@ -153,19 +172,44 @@ class BedrockProvider(Provider[BaseClient]):
|
|
|
153
172
|
if bedrock_client is not None:
|
|
154
173
|
self._client = bedrock_client
|
|
155
174
|
else:
|
|
175
|
+
read_timeout = aws_read_timeout or float(os.getenv('AWS_READ_TIMEOUT', 300))
|
|
176
|
+
connect_timeout = aws_connect_timeout or float(os.getenv('AWS_CONNECT_TIMEOUT', 60))
|
|
177
|
+
config: dict[str, Any] = {
|
|
178
|
+
'read_timeout': read_timeout,
|
|
179
|
+
'connect_timeout': connect_timeout,
|
|
180
|
+
}
|
|
156
181
|
try:
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
182
|
+
if api_key is not None:
|
|
183
|
+
session = boto3.Session(
|
|
184
|
+
botocore_session=_BearerTokenSession(api_key),
|
|
185
|
+
region_name=region_name,
|
|
186
|
+
profile_name=profile_name,
|
|
187
|
+
)
|
|
188
|
+
config['signature_version'] = 'bearer'
|
|
189
|
+
else:
|
|
190
|
+
session = boto3.Session(
|
|
191
|
+
aws_access_key_id=aws_access_key_id,
|
|
192
|
+
aws_secret_access_key=aws_secret_access_key,
|
|
193
|
+
aws_session_token=aws_session_token,
|
|
194
|
+
region_name=region_name,
|
|
195
|
+
profile_name=profile_name,
|
|
196
|
+
)
|
|
166
197
|
self._client = session.client( # type: ignore[reportUnknownMemberType]
|
|
167
198
|
'bedrock-runtime',
|
|
168
|
-
config=Config(
|
|
199
|
+
config=Config(**config),
|
|
200
|
+
endpoint_url=base_url,
|
|
169
201
|
)
|
|
170
202
|
except NoRegionError as exc: # pragma: no cover
|
|
171
203
|
raise UserError('You must provide a `region_name` or a boto3 client for Bedrock Runtime.') from exc
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class _BearerTokenSession(Session):
|
|
207
|
+
def __init__(self, token: str):
|
|
208
|
+
super().__init__()
|
|
209
|
+
self.token = token
|
|
210
|
+
|
|
211
|
+
def get_auth_token(self, **_kwargs: Any) -> FrozenAuthToken:
|
|
212
|
+
return FrozenAuthToken(self.token)
|
|
213
|
+
|
|
214
|
+
def get_credentials(self) -> None: # type: ignore[reportIncompatibleMethodOverride]
|
|
215
|
+
return None
|
pydantic_ai/providers/gateway.py
CHANGED
|
@@ -3,14 +3,16 @@
|
|
|
3
3
|
from __future__ import annotations as _annotations
|
|
4
4
|
|
|
5
5
|
import os
|
|
6
|
+
from collections.abc import Awaitable, Callable
|
|
6
7
|
from typing import TYPE_CHECKING, Any, Literal, overload
|
|
7
8
|
|
|
8
9
|
import httpx
|
|
9
10
|
|
|
10
11
|
from pydantic_ai.exceptions import UserError
|
|
11
|
-
from pydantic_ai.models import
|
|
12
|
+
from pydantic_ai.models import cached_async_http_client
|
|
12
13
|
|
|
13
14
|
if TYPE_CHECKING:
|
|
15
|
+
from botocore.client import BaseClient
|
|
14
16
|
from google.genai import Client as GoogleClient
|
|
15
17
|
from groq import AsyncGroq
|
|
16
18
|
from openai import AsyncOpenAI
|
|
@@ -18,6 +20,8 @@ if TYPE_CHECKING:
|
|
|
18
20
|
from pydantic_ai.models.anthropic import AsyncAnthropicClient
|
|
19
21
|
from pydantic_ai.providers import Provider
|
|
20
22
|
|
|
23
|
+
GATEWAY_BASE_URL = 'https://gateway.pydantic.dev/proxy'
|
|
24
|
+
|
|
21
25
|
|
|
22
26
|
@overload
|
|
23
27
|
def gateway_provider(
|
|
@@ -57,13 +61,34 @@ def gateway_provider(
|
|
|
57
61
|
) -> Provider[AsyncAnthropicClient]: ...
|
|
58
62
|
|
|
59
63
|
|
|
64
|
+
@overload
|
|
65
|
+
def gateway_provider(
|
|
66
|
+
upstream_provider: Literal['bedrock'],
|
|
67
|
+
*,
|
|
68
|
+
api_key: str | None = None,
|
|
69
|
+
base_url: str | None = None,
|
|
70
|
+
) -> Provider[BaseClient]: ...
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@overload
|
|
60
74
|
def gateway_provider(
|
|
61
|
-
upstream_provider:
|
|
75
|
+
upstream_provider: str,
|
|
76
|
+
*,
|
|
77
|
+
api_key: str | None = None,
|
|
78
|
+
base_url: str | None = None,
|
|
79
|
+
) -> Provider[Any]: ...
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
UpstreamProvider = Literal['openai', 'openai-chat', 'openai-responses', 'groq', 'google-vertex', 'anthropic', 'bedrock']
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def gateway_provider(
|
|
86
|
+
upstream_provider: UpstreamProvider | str,
|
|
62
87
|
*,
|
|
63
88
|
# Every provider
|
|
64
89
|
api_key: str | None = None,
|
|
65
90
|
base_url: str | None = None,
|
|
66
|
-
# OpenAI &
|
|
91
|
+
# OpenAI, Groq & Anthropic
|
|
67
92
|
http_client: httpx.AsyncClient | None = None,
|
|
68
93
|
) -> Provider[Any]:
|
|
69
94
|
"""Create a new Gateway provider.
|
|
@@ -73,25 +98,21 @@ def gateway_provider(
|
|
|
73
98
|
api_key: The API key to use for authentication. If not provided, the `PYDANTIC_AI_GATEWAY_API_KEY`
|
|
74
99
|
environment variable will be used if available.
|
|
75
100
|
base_url: The base URL to use for the Gateway. If not provided, the `PYDANTIC_AI_GATEWAY_BASE_URL`
|
|
76
|
-
environment variable will be used if available. Otherwise, defaults to `
|
|
101
|
+
environment variable will be used if available. Otherwise, defaults to `https://gateway.pydantic.dev/proxy`.
|
|
77
102
|
http_client: The HTTP client to use for the Gateway.
|
|
78
103
|
"""
|
|
79
104
|
api_key = api_key or os.getenv('PYDANTIC_AI_GATEWAY_API_KEY')
|
|
80
105
|
if not api_key:
|
|
81
106
|
raise UserError(
|
|
82
|
-
'Set the `PYDANTIC_AI_GATEWAY_API_KEY` environment variable or pass it via `gateway_provider(api_key=...)`'
|
|
107
|
+
'Set the `PYDANTIC_AI_GATEWAY_API_KEY` environment variable or pass it via `gateway_provider(..., api_key=...)`'
|
|
83
108
|
' to use the Pydantic AI Gateway provider.'
|
|
84
109
|
)
|
|
85
110
|
|
|
86
|
-
base_url = base_url or os.getenv('PYDANTIC_AI_GATEWAY_BASE_URL',
|
|
87
|
-
http_client = http_client or cached_async_http_client(provider=f'gateway
|
|
88
|
-
http_client.event_hooks = {'request': [_request_hook]}
|
|
89
|
-
|
|
90
|
-
if upstream_provider in ('openai', 'openai-chat'):
|
|
91
|
-
from .openai import OpenAIProvider
|
|
111
|
+
base_url = base_url or os.getenv('PYDANTIC_AI_GATEWAY_BASE_URL', GATEWAY_BASE_URL)
|
|
112
|
+
http_client = http_client or cached_async_http_client(provider=f'gateway/{upstream_provider}')
|
|
113
|
+
http_client.event_hooks = {'request': [_request_hook(api_key)]}
|
|
92
114
|
|
|
93
|
-
|
|
94
|
-
elif upstream_provider == 'openai-responses':
|
|
115
|
+
if upstream_provider in ('openai', 'openai-chat', 'openai-responses'):
|
|
95
116
|
from .openai import OpenAIProvider
|
|
96
117
|
|
|
97
118
|
return OpenAIProvider(api_key=api_key, base_url=_merge_url_path(base_url, 'openai'), http_client=http_client)
|
|
@@ -111,79 +132,46 @@ def gateway_provider(
|
|
|
111
132
|
http_client=http_client,
|
|
112
133
|
)
|
|
113
134
|
)
|
|
114
|
-
elif upstream_provider == '
|
|
115
|
-
from
|
|
135
|
+
elif upstream_provider == 'bedrock':
|
|
136
|
+
from .bedrock import BedrockProvider
|
|
116
137
|
|
|
138
|
+
return BedrockProvider(
|
|
139
|
+
api_key=api_key,
|
|
140
|
+
base_url=_merge_url_path(base_url, 'bedrock'),
|
|
141
|
+
region_name='pydantic-ai-gateway', # Fake region name to avoid NoRegionError
|
|
142
|
+
)
|
|
143
|
+
elif upstream_provider == 'google-vertex':
|
|
117
144
|
from .google import GoogleProvider
|
|
118
145
|
|
|
119
146
|
return GoogleProvider(
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
'base_url': _merge_url_path(base_url, 'google-vertex'),
|
|
125
|
-
'headers': {'User-Agent': get_user_agent(), 'Authorization': api_key},
|
|
126
|
-
# TODO(Marcelo): Until https://github.com/googleapis/python-genai/issues/1357 is solved.
|
|
127
|
-
'async_client_args': {
|
|
128
|
-
'transport': httpx.AsyncHTTPTransport(),
|
|
129
|
-
'event_hooks': {'request': [_request_hook]},
|
|
130
|
-
},
|
|
131
|
-
},
|
|
132
|
-
)
|
|
147
|
+
vertexai=True,
|
|
148
|
+
api_key=api_key,
|
|
149
|
+
base_url=_merge_url_path(base_url, 'google-vertex'),
|
|
150
|
+
http_client=http_client,
|
|
133
151
|
)
|
|
134
|
-
else:
|
|
135
|
-
raise UserError(f'Unknown provider: {upstream_provider}')
|
|
152
|
+
else:
|
|
153
|
+
raise UserError(f'Unknown upstream provider: {upstream_provider}')
|
|
136
154
|
|
|
137
155
|
|
|
138
|
-
def
|
|
139
|
-
"""
|
|
140
|
-
|
|
141
|
-
Args:
|
|
142
|
-
model_name: The name of the model to infer. Must be in the format "provider/model_name".
|
|
156
|
+
def _request_hook(api_key: str) -> Callable[[httpx.Request], Awaitable[httpx.Request]]:
|
|
157
|
+
"""Request hook for the gateway provider.
|
|
143
158
|
|
|
144
|
-
|
|
145
|
-
The model class that will be used to make requests to the gateway.
|
|
159
|
+
It adds the `"traceparent"` and `"Authorization"` headers to the request.
|
|
146
160
|
"""
|
|
147
|
-
try:
|
|
148
|
-
upstream_provider, model_name = model_name.split('/', 1)
|
|
149
|
-
except ValueError:
|
|
150
|
-
raise UserError(f'The model name "{model_name}" is not in the format "provider/model_name".')
|
|
151
161
|
|
|
152
|
-
|
|
153
|
-
from
|
|
162
|
+
async def _hook(request: httpx.Request) -> httpx.Request:
|
|
163
|
+
from opentelemetry.propagate import inject
|
|
154
164
|
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
return OpenAIResponsesModel(model_name, provider=gateway_provider('openai'))
|
|
160
|
-
elif upstream_provider == 'groq':
|
|
161
|
-
from pydantic_ai.models.groq import GroqModel
|
|
165
|
+
headers: dict[str, Any] = {}
|
|
166
|
+
inject(headers)
|
|
167
|
+
request.headers.update(headers)
|
|
162
168
|
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
from pydantic_ai.models.anthropic import AnthropicModel
|
|
166
|
-
|
|
167
|
-
return AnthropicModel(model_name, provider=gateway_provider('anthropic'))
|
|
168
|
-
elif upstream_provider == 'google-vertex':
|
|
169
|
-
from pydantic_ai.models.google import GoogleModel
|
|
170
|
-
|
|
171
|
-
return GoogleModel(model_name, provider=gateway_provider('google-vertex'))
|
|
172
|
-
raise UserError(f'Unknown upstream provider: {upstream_provider}')
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
async def _request_hook(request: httpx.Request) -> httpx.Request:
|
|
176
|
-
"""Request hook for the gateway provider.
|
|
177
|
-
|
|
178
|
-
It adds the `"traceparent"` header to the request.
|
|
179
|
-
"""
|
|
180
|
-
from opentelemetry.propagate import inject
|
|
169
|
+
if 'Authorization' not in request.headers:
|
|
170
|
+
request.headers['Authorization'] = f'Bearer {api_key}'
|
|
181
171
|
|
|
182
|
-
|
|
183
|
-
inject(headers)
|
|
184
|
-
request.headers.update(headers)
|
|
172
|
+
return request
|
|
185
173
|
|
|
186
|
-
return
|
|
174
|
+
return _hook
|
|
187
175
|
|
|
188
176
|
|
|
189
177
|
def _merge_url_path(base_url: str, path: str) -> str:
|
pydantic_ai/providers/google.py
CHANGED
|
@@ -7,14 +7,15 @@ import httpx
|
|
|
7
7
|
|
|
8
8
|
from pydantic_ai import ModelProfile
|
|
9
9
|
from pydantic_ai.exceptions import UserError
|
|
10
|
-
from pydantic_ai.models import get_user_agent
|
|
10
|
+
from pydantic_ai.models import cached_async_http_client, get_user_agent
|
|
11
11
|
from pydantic_ai.profiles.google import google_model_profile
|
|
12
12
|
from pydantic_ai.providers import Provider
|
|
13
13
|
|
|
14
14
|
try:
|
|
15
15
|
from google.auth.credentials import Credentials
|
|
16
|
-
from google.genai import
|
|
17
|
-
from google.genai.
|
|
16
|
+
from google.genai._api_client import BaseApiClient
|
|
17
|
+
from google.genai.client import Client, DebugConfig
|
|
18
|
+
from google.genai.types import HttpOptions
|
|
18
19
|
except ImportError as _import_error:
|
|
19
20
|
raise ImportError(
|
|
20
21
|
'Please install the `google-genai` package to use the Google provider, '
|
|
@@ -41,7 +42,9 @@ class GoogleProvider(Provider[Client]):
|
|
|
41
42
|
return google_model_profile(model_name)
|
|
42
43
|
|
|
43
44
|
@overload
|
|
44
|
-
def __init__(
|
|
45
|
+
def __init__(
|
|
46
|
+
self, *, api_key: str, http_client: httpx.AsyncClient | None = None, base_url: str | None = None
|
|
47
|
+
) -> None: ...
|
|
45
48
|
|
|
46
49
|
@overload
|
|
47
50
|
def __init__(
|
|
@@ -49,14 +52,23 @@ class GoogleProvider(Provider[Client]):
|
|
|
49
52
|
*,
|
|
50
53
|
credentials: Credentials | None = None,
|
|
51
54
|
project: str | None = None,
|
|
52
|
-
location: VertexAILocation | Literal['global'] | None = None,
|
|
55
|
+
location: VertexAILocation | Literal['global'] | str | None = None,
|
|
56
|
+
http_client: httpx.AsyncClient | None = None,
|
|
57
|
+
base_url: str | None = None,
|
|
53
58
|
) -> None: ...
|
|
54
59
|
|
|
55
60
|
@overload
|
|
56
61
|
def __init__(self, *, client: Client) -> None: ...
|
|
57
62
|
|
|
58
63
|
@overload
|
|
59
|
-
def __init__(
|
|
64
|
+
def __init__(
|
|
65
|
+
self,
|
|
66
|
+
*,
|
|
67
|
+
vertexai: bool = False,
|
|
68
|
+
api_key: str | None = None,
|
|
69
|
+
http_client: httpx.AsyncClient | None = None,
|
|
70
|
+
base_url: str | None = None,
|
|
71
|
+
) -> None: ...
|
|
60
72
|
|
|
61
73
|
def __init__(
|
|
62
74
|
self,
|
|
@@ -64,16 +76,17 @@ class GoogleProvider(Provider[Client]):
|
|
|
64
76
|
api_key: str | None = None,
|
|
65
77
|
credentials: Credentials | None = None,
|
|
66
78
|
project: str | None = None,
|
|
67
|
-
location: VertexAILocation | Literal['global'] | None = None,
|
|
68
|
-
client: Client | None = None,
|
|
79
|
+
location: VertexAILocation | Literal['global'] | str | None = None,
|
|
69
80
|
vertexai: bool | None = None,
|
|
81
|
+
client: Client | None = None,
|
|
82
|
+
http_client: httpx.AsyncClient | None = None,
|
|
83
|
+
base_url: str | None = None,
|
|
70
84
|
) -> None:
|
|
71
85
|
"""Create a new Google provider.
|
|
72
86
|
|
|
73
87
|
Args:
|
|
74
88
|
api_key: The `API key <https://ai.google.dev/gemini-api/docs/api-key>`_ to
|
|
75
89
|
use for authentication. It can also be set via the `GOOGLE_API_KEY` environment variable.
|
|
76
|
-
Applies to the Gemini Developer API only.
|
|
77
90
|
credentials: The credentials to use for authentication when calling the Vertex AI APIs. Credentials can be
|
|
78
91
|
obtained from environment variables and default credentials. For more information, see Set up
|
|
79
92
|
Application Default Credentials. Applies to the Vertex AI API only.
|
|
@@ -81,43 +94,60 @@ class GoogleProvider(Provider[Client]):
|
|
|
81
94
|
(for example, GOOGLE_CLOUD_PROJECT). Applies to the Vertex AI API only.
|
|
82
95
|
location: The location to send API requests to (for example, us-central1). Can be obtained from environment variables.
|
|
83
96
|
Applies to the Vertex AI API only.
|
|
84
|
-
client: A pre-initialized client to use.
|
|
85
97
|
vertexai: Force the use of the Vertex AI API. If `False`, the Google Generative Language API will be used.
|
|
86
|
-
Defaults to `False
|
|
98
|
+
Defaults to `False` unless `location`, `project`, or `credentials` are provided.
|
|
99
|
+
client: A pre-initialized client to use.
|
|
100
|
+
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
|
|
101
|
+
base_url: The base URL for the Google API.
|
|
87
102
|
"""
|
|
88
103
|
if client is None:
|
|
89
104
|
# NOTE: We are keeping GEMINI_API_KEY for backwards compatibility.
|
|
90
105
|
api_key = api_key or os.getenv('GOOGLE_API_KEY') or os.getenv('GEMINI_API_KEY')
|
|
91
106
|
|
|
107
|
+
vertex_ai_args_used = bool(location or project or credentials)
|
|
92
108
|
if vertexai is None:
|
|
93
|
-
vertexai =
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
'
|
|
97
|
-
|
|
98
|
-
|
|
109
|
+
vertexai = vertex_ai_args_used
|
|
110
|
+
|
|
111
|
+
http_client = http_client or cached_async_http_client(
|
|
112
|
+
provider='google-vertex' if vertexai else 'google-gla'
|
|
113
|
+
)
|
|
114
|
+
http_options = HttpOptions(
|
|
115
|
+
base_url=base_url,
|
|
116
|
+
headers={'User-Agent': get_user_agent()},
|
|
117
|
+
httpx_async_client=http_client,
|
|
118
|
+
# TODO: Remove once https://github.com/googleapis/python-genai/issues/1565 is solved.
|
|
119
|
+
async_client_args={'transport': httpx.AsyncHTTPTransport()},
|
|
120
|
+
)
|
|
99
121
|
if not vertexai:
|
|
100
122
|
if api_key is None:
|
|
101
|
-
raise UserError(
|
|
123
|
+
raise UserError(
|
|
102
124
|
'Set the `GOOGLE_API_KEY` environment variable or pass it via `GoogleProvider(api_key=...)`'
|
|
103
125
|
'to use the Google Generative Language API.'
|
|
104
126
|
)
|
|
105
|
-
self._client =
|
|
127
|
+
self._client = _SafelyClosingClient(vertexai=False, api_key=api_key, http_options=http_options)
|
|
106
128
|
else:
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
129
|
+
if vertex_ai_args_used:
|
|
130
|
+
api_key = None
|
|
131
|
+
|
|
132
|
+
if api_key is None:
|
|
133
|
+
project = project or os.getenv('GOOGLE_CLOUD_PROJECT')
|
|
110
134
|
# From https://github.com/pydantic/pydantic-ai/pull/2031/files#r2169682149:
|
|
111
135
|
# Currently `us-central1` supports the most models by far of any region including `global`, but not
|
|
112
136
|
# all of them. `us-central1` has all google models but is missing some Anthropic partner models,
|
|
113
137
|
# which use `us-east5` instead. `global` has fewer models but higher availability.
|
|
114
138
|
# For more details, check: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions
|
|
115
|
-
location=location or os.getenv('GOOGLE_CLOUD_LOCATION') or 'us-central1'
|
|
139
|
+
location = location or os.getenv('GOOGLE_CLOUD_LOCATION') or 'us-central1'
|
|
140
|
+
|
|
141
|
+
self._client = _SafelyClosingClient(
|
|
142
|
+
vertexai=True,
|
|
143
|
+
api_key=api_key,
|
|
144
|
+
project=project,
|
|
145
|
+
location=location,
|
|
116
146
|
credentials=credentials,
|
|
117
147
|
http_options=http_options,
|
|
118
148
|
)
|
|
119
149
|
else:
|
|
120
|
-
self._client = client
|
|
150
|
+
self._client = client # pragma: no cover
|
|
121
151
|
|
|
122
152
|
|
|
123
153
|
VertexAILocation = Literal[
|
|
@@ -154,3 +184,40 @@ VertexAILocation = Literal[
|
|
|
154
184
|
"""Regions available for Vertex AI.
|
|
155
185
|
More details [here](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#genai-locations).
|
|
156
186
|
"""
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
class _SafelyClosingClient(Client):
|
|
190
|
+
@staticmethod
|
|
191
|
+
def _get_api_client(
|
|
192
|
+
vertexai: bool | None = None,
|
|
193
|
+
api_key: str | None = None,
|
|
194
|
+
credentials: Credentials | None = None,
|
|
195
|
+
project: str | None = None,
|
|
196
|
+
location: str | None = None,
|
|
197
|
+
debug_config: DebugConfig | None = None,
|
|
198
|
+
http_options: HttpOptions | None = None,
|
|
199
|
+
) -> BaseApiClient:
|
|
200
|
+
return _NonClosingApiClient(
|
|
201
|
+
vertexai=vertexai,
|
|
202
|
+
api_key=api_key,
|
|
203
|
+
credentials=credentials,
|
|
204
|
+
project=project,
|
|
205
|
+
location=location,
|
|
206
|
+
http_options=http_options,
|
|
207
|
+
)
|
|
208
|
+
|
|
209
|
+
def close(self) -> None:
|
|
210
|
+
# This is called from `Client.__del__`, even if `Client.__init__` raised an error before `self._api_client` is set, which would raise an `AttributeError` here.
|
|
211
|
+
# TODO: Remove once https://github.com/googleapis/python-genai/issues/1567 is solved.
|
|
212
|
+
try:
|
|
213
|
+
super().close()
|
|
214
|
+
except AttributeError:
|
|
215
|
+
pass
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
class _NonClosingApiClient(BaseApiClient):
|
|
219
|
+
async def aclose(self) -> None:
|
|
220
|
+
# The original implementation also calls `await self._async_httpx_client.aclose()`, but we don't want to close our `cached_async_http_client` or the one the user passed in.
|
|
221
|
+
# TODO: Remove once https://github.com/googleapis/python-genai/issues/1566 is solved.
|
|
222
|
+
if self._aiohttp_session:
|
|
223
|
+
await self._aiohttp_session.close() # pragma: no cover
|
|
pydantic_ai/providers/openrouter.py
CHANGED
|
@@ -81,6 +81,9 @@ class OpenRouterProvider(Provider[AsyncOpenAI]):
|
|
|
81
81
|
@overload
|
|
82
82
|
def __init__(self, *, api_key: str, http_client: httpx.AsyncClient) -> None: ...
|
|
83
83
|
|
|
84
|
+
@overload
|
|
85
|
+
def __init__(self, *, http_client: httpx.AsyncClient) -> None: ...
|
|
86
|
+
|
|
84
87
|
@overload
|
|
85
88
|
def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
|
|
86
89
|
|
|
pydantic_ai/providers/outlines.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from pydantic_ai.profiles import ModelProfile
|
|
6
|
+
from pydantic_ai.providers import Provider
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class OutlinesProvider(Provider[Any]):
|
|
10
|
+
"""Provider for Outlines API."""
|
|
11
|
+
|
|
12
|
+
@property
|
|
13
|
+
def name(self) -> str:
|
|
14
|
+
"""The provider name."""
|
|
15
|
+
return 'outlines'
|
|
16
|
+
|
|
17
|
+
@property
|
|
18
|
+
def base_url(self) -> str:
|
|
19
|
+
"""The base URL for the provider API."""
|
|
20
|
+
raise NotImplementedError(
|
|
21
|
+
'The Outlines provider does not have a set base URL as it functions '
|
|
22
|
+
+ 'with a set of different underlying models.'
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
@property
|
|
26
|
+
def client(self) -> Any:
|
|
27
|
+
"""The client for the provider."""
|
|
28
|
+
raise NotImplementedError(
|
|
29
|
+
'The Outlines provider does not have a set client as it functions '
|
|
30
|
+
+ 'with a set of different underlying models.'
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
def model_profile(self, model_name: str) -> ModelProfile | None:
|
|
34
|
+
"""The model profile for the named model, if available."""
|
|
35
|
+
return ModelProfile(
|
|
36
|
+
supports_tools=False,
|
|
37
|
+
supports_json_schema_output=True,
|
|
38
|
+
supports_json_object_output=True,
|
|
39
|
+
default_structured_output_mode='native',
|
|
40
|
+
)
|
|
pydantic_ai/providers/ovhcloud.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import overload
|
|
5
|
+
|
|
6
|
+
import httpx
|
|
7
|
+
|
|
8
|
+
from pydantic_ai import ModelProfile
|
|
9
|
+
from pydantic_ai.exceptions import UserError
|
|
10
|
+
from pydantic_ai.models import cached_async_http_client
|
|
11
|
+
from pydantic_ai.profiles.deepseek import deepseek_model_profile
|
|
12
|
+
from pydantic_ai.profiles.harmony import harmony_model_profile
|
|
13
|
+
from pydantic_ai.profiles.meta import meta_model_profile
|
|
14
|
+
from pydantic_ai.profiles.mistral import mistral_model_profile
|
|
15
|
+
from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
|
|
16
|
+
from pydantic_ai.profiles.qwen import qwen_model_profile
|
|
17
|
+
from pydantic_ai.providers import Provider
|
|
18
|
+
|
|
19
|
+
try:
|
|
20
|
+
from openai import AsyncOpenAI
|
|
21
|
+
except ImportError as _import_error: # pragma: no cover
|
|
22
|
+
raise ImportError(
|
|
23
|
+
'Please install the `openai` package to use OVHcloud AI Endpoints provider.'
|
|
24
|
+
'You can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
|
|
25
|
+
) from _import_error
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class OVHcloudProvider(Provider[AsyncOpenAI]):
|
|
29
|
+
"""Provider for OVHcloud AI Endpoints."""
|
|
30
|
+
|
|
31
|
+
@property
|
|
32
|
+
def name(self) -> str:
|
|
33
|
+
return 'ovhcloud'
|
|
34
|
+
|
|
35
|
+
@property
|
|
36
|
+
def base_url(self) -> str:
|
|
37
|
+
return 'https://oai.endpoints.kepler.ai.cloud.ovh.net/v1'
|
|
38
|
+
|
|
39
|
+
@property
|
|
40
|
+
def client(self) -> AsyncOpenAI:
|
|
41
|
+
return self._client
|
|
42
|
+
|
|
43
|
+
def model_profile(self, model_name: str) -> ModelProfile | None:
|
|
44
|
+
model_name = model_name.lower()
|
|
45
|
+
|
|
46
|
+
prefix_to_profile = {
|
|
47
|
+
'llama': meta_model_profile,
|
|
48
|
+
'meta-': meta_model_profile,
|
|
49
|
+
'deepseek': deepseek_model_profile,
|
|
50
|
+
'mistral': mistral_model_profile,
|
|
51
|
+
'gpt': harmony_model_profile,
|
|
52
|
+
'qwen': qwen_model_profile,
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
profile = None
|
|
56
|
+
for prefix, profile_func in prefix_to_profile.items():
|
|
57
|
+
if model_name.startswith(prefix):
|
|
58
|
+
profile = profile_func(model_name)
|
|
59
|
+
|
|
60
|
+
# As the OVHcloud AI Endpoints API is OpenAI-compatible, let's assume we also need OpenAIJsonSchemaTransformer.
|
|
61
|
+
return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)
|
|
62
|
+
|
|
63
|
+
@overload
|
|
64
|
+
def __init__(self) -> None: ...
|
|
65
|
+
|
|
66
|
+
@overload
|
|
67
|
+
def __init__(self, *, api_key: str) -> None: ...
|
|
68
|
+
|
|
69
|
+
@overload
|
|
70
|
+
def __init__(self, *, api_key: str, http_client: httpx.AsyncClient) -> None: ...
|
|
71
|
+
|
|
72
|
+
@overload
|
|
73
|
+
def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
|
|
74
|
+
|
|
75
|
+
def __init__(
|
|
76
|
+
self,
|
|
77
|
+
*,
|
|
78
|
+
api_key: str | None = None,
|
|
79
|
+
openai_client: AsyncOpenAI | None = None,
|
|
80
|
+
http_client: httpx.AsyncClient | None = None,
|
|
81
|
+
) -> None:
|
|
82
|
+
api_key = api_key or os.getenv('OVHCLOUD_API_KEY')
|
|
83
|
+
if not api_key and openai_client is None:
|
|
84
|
+
raise UserError(
|
|
85
|
+
'Set the `OVHCLOUD_API_KEY` environment variable or pass it via '
|
|
86
|
+
'`OVHcloudProvider(api_key=...)` to use OVHcloud AI Endpoints provider.'
|
|
87
|
+
)
|
|
88
|
+
|
|
89
|
+
if openai_client is not None:
|
|
90
|
+
self._client = openai_client
|
|
91
|
+
elif http_client is not None:
|
|
92
|
+
self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
|
|
93
|
+
else:
|
|
94
|
+
http_client = cached_async_http_client(provider='ovhcloud')
|
|
95
|
+
self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
|