pydantic-ai-slim 1.2.1__py3-none-any.whl → 1.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. pydantic_ai/__init__.py +6 -0
  2. pydantic_ai/_agent_graph.py +67 -20
  3. pydantic_ai/_cli.py +2 -2
  4. pydantic_ai/_output.py +20 -12
  5. pydantic_ai/_run_context.py +6 -2
  6. pydantic_ai/_utils.py +26 -8
  7. pydantic_ai/ag_ui.py +50 -696
  8. pydantic_ai/agent/__init__.py +13 -25
  9. pydantic_ai/agent/abstract.py +146 -9
  10. pydantic_ai/builtin_tools.py +106 -4
  11. pydantic_ai/direct.py +16 -4
  12. pydantic_ai/durable_exec/dbos/_agent.py +3 -0
  13. pydantic_ai/durable_exec/prefect/_agent.py +3 -0
  14. pydantic_ai/durable_exec/temporal/__init__.py +11 -0
  15. pydantic_ai/durable_exec/temporal/_agent.py +3 -0
  16. pydantic_ai/durable_exec/temporal/_function_toolset.py +23 -72
  17. pydantic_ai/durable_exec/temporal/_mcp_server.py +30 -30
  18. pydantic_ai/durable_exec/temporal/_run_context.py +7 -2
  19. pydantic_ai/durable_exec/temporal/_toolset.py +67 -3
  20. pydantic_ai/exceptions.py +6 -1
  21. pydantic_ai/mcp.py +1 -22
  22. pydantic_ai/messages.py +46 -8
  23. pydantic_ai/models/__init__.py +87 -38
  24. pydantic_ai/models/anthropic.py +132 -11
  25. pydantic_ai/models/bedrock.py +4 -4
  26. pydantic_ai/models/cohere.py +0 -7
  27. pydantic_ai/models/gemini.py +9 -2
  28. pydantic_ai/models/google.py +26 -23
  29. pydantic_ai/models/groq.py +13 -5
  30. pydantic_ai/models/huggingface.py +2 -2
  31. pydantic_ai/models/openai.py +251 -52
  32. pydantic_ai/models/outlines.py +563 -0
  33. pydantic_ai/models/test.py +6 -3
  34. pydantic_ai/profiles/openai.py +7 -0
  35. pydantic_ai/providers/__init__.py +25 -12
  36. pydantic_ai/providers/anthropic.py +2 -2
  37. pydantic_ai/providers/bedrock.py +60 -16
  38. pydantic_ai/providers/gateway.py +60 -72
  39. pydantic_ai/providers/google.py +91 -24
  40. pydantic_ai/providers/openrouter.py +3 -0
  41. pydantic_ai/providers/outlines.py +40 -0
  42. pydantic_ai/providers/ovhcloud.py +95 -0
  43. pydantic_ai/result.py +173 -8
  44. pydantic_ai/run.py +40 -24
  45. pydantic_ai/settings.py +8 -0
  46. pydantic_ai/tools.py +10 -6
  47. pydantic_ai/toolsets/fastmcp.py +215 -0
  48. pydantic_ai/ui/__init__.py +16 -0
  49. pydantic_ai/ui/_adapter.py +386 -0
  50. pydantic_ai/ui/_event_stream.py +591 -0
  51. pydantic_ai/ui/_messages_builder.py +28 -0
  52. pydantic_ai/ui/ag_ui/__init__.py +9 -0
  53. pydantic_ai/ui/ag_ui/_adapter.py +187 -0
  54. pydantic_ai/ui/ag_ui/_event_stream.py +236 -0
  55. pydantic_ai/ui/ag_ui/app.py +148 -0
  56. pydantic_ai/ui/vercel_ai/__init__.py +16 -0
  57. pydantic_ai/ui/vercel_ai/_adapter.py +199 -0
  58. pydantic_ai/ui/vercel_ai/_event_stream.py +187 -0
  59. pydantic_ai/ui/vercel_ai/_utils.py +16 -0
  60. pydantic_ai/ui/vercel_ai/request_types.py +275 -0
  61. pydantic_ai/ui/vercel_ai/response_types.py +230 -0
  62. pydantic_ai/usage.py +13 -2
  63. {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/METADATA +23 -5
  64. {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/RECORD +67 -49
  65. {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/WHEEL +0 -0
  66. {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/entry_points.txt +0 -0
  67. {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.10.0.dist-info}/licenses/LICENSE +0 -0
@@ -4,7 +4,7 @@ import os
4
4
  import re
5
5
  from collections.abc import Callable
6
6
  from dataclasses import dataclass
7
- from typing import Literal, overload
7
+ from typing import Any, Literal, overload
8
8
 
9
9
  from pydantic_ai import ModelProfile
10
10
  from pydantic_ai.exceptions import UserError
@@ -21,6 +21,8 @@ try:
21
21
  from botocore.client import BaseClient
22
22
  from botocore.config import Config
23
23
  from botocore.exceptions import NoRegionError
24
+ from botocore.session import Session
25
+ from botocore.tokens import FrozenAuthToken
24
26
  except ImportError as _import_error:
25
27
  raise ImportError(
26
28
  'Please install the `boto3` package to use the Bedrock provider, '
@@ -117,10 +119,23 @@ class BedrockProvider(Provider[BaseClient]):
117
119
  def __init__(
118
120
  self,
119
121
  *,
122
+ api_key: str,
123
+ base_url: str | None = None,
120
124
  region_name: str | None = None,
125
+ profile_name: str | None = None,
126
+ aws_read_timeout: float | None = None,
127
+ aws_connect_timeout: float | None = None,
128
+ ) -> None: ...
129
+
130
+ @overload
131
+ def __init__(
132
+ self,
133
+ *,
121
134
  aws_access_key_id: str | None = None,
122
135
  aws_secret_access_key: str | None = None,
123
136
  aws_session_token: str | None = None,
137
+ base_url: str | None = None,
138
+ region_name: str | None = None,
124
139
  profile_name: str | None = None,
125
140
  aws_read_timeout: float | None = None,
126
141
  aws_connect_timeout: float | None = None,
@@ -130,11 +145,13 @@ class BedrockProvider(Provider[BaseClient]):
130
145
  self,
131
146
  *,
132
147
  bedrock_client: BaseClient | None = None,
133
- region_name: str | None = None,
134
148
  aws_access_key_id: str | None = None,
135
149
  aws_secret_access_key: str | None = None,
136
150
  aws_session_token: str | None = None,
151
+ base_url: str | None = None,
152
+ region_name: str | None = None,
137
153
  profile_name: str | None = None,
154
+ api_key: str | None = None,
138
155
  aws_read_timeout: float | None = None,
139
156
  aws_connect_timeout: float | None = None,
140
157
  ) -> None:
@@ -142,10 +159,12 @@ class BedrockProvider(Provider[BaseClient]):
142
159
 
143
160
  Args:
144
161
  bedrock_client: A boto3 client for Bedrock Runtime. If provided, other arguments are ignored.
145
- region_name: The AWS region name.
146
- aws_access_key_id: The AWS access key ID.
147
- aws_secret_access_key: The AWS secret access key.
148
- aws_session_token: The AWS session token.
162
+ aws_access_key_id: The AWS access key ID. If not set, the `AWS_ACCESS_KEY_ID` environment variable will be used if available.
163
+ aws_secret_access_key: The AWS secret access key. If not set, the `AWS_SECRET_ACCESS_KEY` environment variable will be used if available.
164
+ aws_session_token: The AWS session token. If not set, the `AWS_SESSION_TOKEN` environment variable will be used if available.
165
+ api_key: The API key for Bedrock client. Can be used instead of `aws_access_key_id`, `aws_secret_access_key`, and `aws_session_token`. If not set, the `AWS_BEARER_TOKEN_BEDROCK` environment variable will be used if available.
166
+ base_url: The base URL for the Bedrock client.
167
+ region_name: The AWS region name. If not set, the `AWS_DEFAULT_REGION` environment variable will be used if available.
149
168
  profile_name: The AWS profile name.
150
169
  aws_read_timeout: The read timeout for Bedrock client.
151
170
  aws_connect_timeout: The connect timeout for Bedrock client.
@@ -153,19 +172,44 @@ class BedrockProvider(Provider[BaseClient]):
153
172
  if bedrock_client is not None:
154
173
  self._client = bedrock_client
155
174
  else:
175
+ read_timeout = aws_read_timeout or float(os.getenv('AWS_READ_TIMEOUT', 300))
176
+ connect_timeout = aws_connect_timeout or float(os.getenv('AWS_CONNECT_TIMEOUT', 60))
177
+ config: dict[str, Any] = {
178
+ 'read_timeout': read_timeout,
179
+ 'connect_timeout': connect_timeout,
180
+ }
156
181
  try:
157
- read_timeout = aws_read_timeout or float(os.getenv('AWS_READ_TIMEOUT', 300))
158
- connect_timeout = aws_connect_timeout or float(os.getenv('AWS_CONNECT_TIMEOUT', 60))
159
- session = boto3.Session(
160
- aws_access_key_id=aws_access_key_id,
161
- aws_secret_access_key=aws_secret_access_key,
162
- aws_session_token=aws_session_token,
163
- region_name=region_name,
164
- profile_name=profile_name,
165
- )
182
+ if api_key is not None:
183
+ session = boto3.Session(
184
+ botocore_session=_BearerTokenSession(api_key),
185
+ region_name=region_name,
186
+ profile_name=profile_name,
187
+ )
188
+ config['signature_version'] = 'bearer'
189
+ else:
190
+ session = boto3.Session(
191
+ aws_access_key_id=aws_access_key_id,
192
+ aws_secret_access_key=aws_secret_access_key,
193
+ aws_session_token=aws_session_token,
194
+ region_name=region_name,
195
+ profile_name=profile_name,
196
+ )
166
197
  self._client = session.client( # type: ignore[reportUnknownMemberType]
167
198
  'bedrock-runtime',
168
- config=Config(read_timeout=read_timeout, connect_timeout=connect_timeout),
199
+ config=Config(**config),
200
+ endpoint_url=base_url,
169
201
  )
170
202
  except NoRegionError as exc: # pragma: no cover
171
203
  raise UserError('You must provide a `region_name` or a boto3 client for Bedrock Runtime.') from exc
204
+
205
+
206
+ class _BearerTokenSession(Session):
207
+ def __init__(self, token: str):
208
+ super().__init__()
209
+ self.token = token
210
+
211
+ def get_auth_token(self, **_kwargs: Any) -> FrozenAuthToken:
212
+ return FrozenAuthToken(self.token)
213
+
214
+ def get_credentials(self) -> None: # type: ignore[reportIncompatibleMethodOverride]
215
+ return None
@@ -3,14 +3,16 @@
3
3
  from __future__ import annotations as _annotations
4
4
 
5
5
  import os
6
+ from collections.abc import Awaitable, Callable
6
7
  from typing import TYPE_CHECKING, Any, Literal, overload
7
8
 
8
9
  import httpx
9
10
 
10
11
  from pydantic_ai.exceptions import UserError
11
- from pydantic_ai.models import Model, cached_async_http_client, get_user_agent
12
+ from pydantic_ai.models import cached_async_http_client
12
13
 
13
14
  if TYPE_CHECKING:
15
+ from botocore.client import BaseClient
14
16
  from google.genai import Client as GoogleClient
15
17
  from groq import AsyncGroq
16
18
  from openai import AsyncOpenAI
@@ -18,6 +20,8 @@ if TYPE_CHECKING:
18
20
  from pydantic_ai.models.anthropic import AsyncAnthropicClient
19
21
  from pydantic_ai.providers import Provider
20
22
 
23
+ GATEWAY_BASE_URL = 'https://gateway.pydantic.dev/proxy'
24
+
21
25
 
22
26
  @overload
23
27
  def gateway_provider(
@@ -57,13 +61,34 @@ def gateway_provider(
57
61
  ) -> Provider[AsyncAnthropicClient]: ...
58
62
 
59
63
 
64
+ @overload
65
+ def gateway_provider(
66
+ upstream_provider: Literal['bedrock'],
67
+ *,
68
+ api_key: str | None = None,
69
+ base_url: str | None = None,
70
+ ) -> Provider[BaseClient]: ...
71
+
72
+
73
+ @overload
60
74
  def gateway_provider(
61
- upstream_provider: Literal['openai', 'openai-chat', 'openai-responses', 'groq', 'google-vertex', 'anthropic'] | str,
75
+ upstream_provider: str,
76
+ *,
77
+ api_key: str | None = None,
78
+ base_url: str | None = None,
79
+ ) -> Provider[Any]: ...
80
+
81
+
82
+ UpstreamProvider = Literal['openai', 'openai-chat', 'openai-responses', 'groq', 'google-vertex', 'anthropic', 'bedrock']
83
+
84
+
85
+ def gateway_provider(
86
+ upstream_provider: UpstreamProvider | str,
62
87
  *,
63
88
  # Every provider
64
89
  api_key: str | None = None,
65
90
  base_url: str | None = None,
66
- # OpenAI & Groq
91
+ # OpenAI, Groq & Anthropic
67
92
  http_client: httpx.AsyncClient | None = None,
68
93
  ) -> Provider[Any]:
69
94
  """Create a new Gateway provider.
@@ -73,25 +98,21 @@ def gateway_provider(
73
98
  api_key: The API key to use for authentication. If not provided, the `PYDANTIC_AI_GATEWAY_API_KEY`
74
99
  environment variable will be used if available.
75
100
  base_url: The base URL to use for the Gateway. If not provided, the `PYDANTIC_AI_GATEWAY_BASE_URL`
76
- environment variable will be used if available. Otherwise, defaults to `http://localhost:8787/`.
101
+ environment variable will be used if available. Otherwise, defaults to `https://gateway.pydantic.dev/proxy`.
77
102
  http_client: The HTTP client to use for the Gateway.
78
103
  """
79
104
  api_key = api_key or os.getenv('PYDANTIC_AI_GATEWAY_API_KEY')
80
105
  if not api_key:
81
106
  raise UserError(
82
- 'Set the `PYDANTIC_AI_GATEWAY_API_KEY` environment variable or pass it via `gateway_provider(api_key=...)`'
107
+ 'Set the `PYDANTIC_AI_GATEWAY_API_KEY` environment variable or pass it via `gateway_provider(..., api_key=...)`'
83
108
  ' to use the Pydantic AI Gateway provider.'
84
109
  )
85
110
 
86
- base_url = base_url or os.getenv('PYDANTIC_AI_GATEWAY_BASE_URL', 'https://gateway.pydantic.dev/proxy')
87
- http_client = http_client or cached_async_http_client(provider=f'gateway-{upstream_provider}')
88
- http_client.event_hooks = {'request': [_request_hook]}
89
-
90
- if upstream_provider in ('openai', 'openai-chat'):
91
- from .openai import OpenAIProvider
111
+ base_url = base_url or os.getenv('PYDANTIC_AI_GATEWAY_BASE_URL', GATEWAY_BASE_URL)
112
+ http_client = http_client or cached_async_http_client(provider=f'gateway/{upstream_provider}')
113
+ http_client.event_hooks = {'request': [_request_hook(api_key)]}
92
114
 
93
- return OpenAIProvider(api_key=api_key, base_url=_merge_url_path(base_url, 'openai'), http_client=http_client)
94
- elif upstream_provider == 'openai-responses':
115
+ if upstream_provider in ('openai', 'openai-chat', 'openai-responses'):
95
116
  from .openai import OpenAIProvider
96
117
 
97
118
  return OpenAIProvider(api_key=api_key, base_url=_merge_url_path(base_url, 'openai'), http_client=http_client)
@@ -111,79 +132,46 @@ def gateway_provider(
111
132
  http_client=http_client,
112
133
  )
113
134
  )
114
- elif upstream_provider == 'google-vertex':
115
- from google.genai import Client as GoogleClient
135
+ elif upstream_provider == 'bedrock':
136
+ from .bedrock import BedrockProvider
116
137
 
138
+ return BedrockProvider(
139
+ api_key=api_key,
140
+ base_url=_merge_url_path(base_url, 'bedrock'),
141
+ region_name='pydantic-ai-gateway', # Fake region name to avoid NoRegionError
142
+ )
143
+ elif upstream_provider == 'google-vertex':
117
144
  from .google import GoogleProvider
118
145
 
119
146
  return GoogleProvider(
120
- client=GoogleClient(
121
- vertexai=True,
122
- api_key='unset',
123
- http_options={
124
- 'base_url': _merge_url_path(base_url, 'google-vertex'),
125
- 'headers': {'User-Agent': get_user_agent(), 'Authorization': api_key},
126
- # TODO(Marcelo): Until https://github.com/googleapis/python-genai/issues/1357 is solved.
127
- 'async_client_args': {
128
- 'transport': httpx.AsyncHTTPTransport(),
129
- 'event_hooks': {'request': [_request_hook]},
130
- },
131
- },
132
- )
147
+ vertexai=True,
148
+ api_key=api_key,
149
+ base_url=_merge_url_path(base_url, 'google-vertex'),
150
+ http_client=http_client,
133
151
  )
134
- else: # pragma: no cover
135
- raise UserError(f'Unknown provider: {upstream_provider}')
152
+ else:
153
+ raise UserError(f'Unknown upstream provider: {upstream_provider}')
136
154
 
137
155
 
138
- def infer_model(model_name: str) -> Model:
139
- """Infer the model class that will be used to make requests to the gateway.
140
-
141
- Args:
142
- model_name: The name of the model to infer. Must be in the format "provider/model_name".
156
+ def _request_hook(api_key: str) -> Callable[[httpx.Request], Awaitable[httpx.Request]]:
157
+ """Request hook for the gateway provider.
143
158
 
144
- Returns:
145
- The model class that will be used to make requests to the gateway.
159
+ It adds the `"traceparent"` and `"Authorization"` headers to the request.
146
160
  """
147
- try:
148
- upstream_provider, model_name = model_name.split('/', 1)
149
- except ValueError:
150
- raise UserError(f'The model name "{model_name}" is not in the format "provider/model_name".')
151
161
 
152
- if upstream_provider in ('openai', 'openai-chat'):
153
- from pydantic_ai.models.openai import OpenAIChatModel
162
+ async def _hook(request: httpx.Request) -> httpx.Request:
163
+ from opentelemetry.propagate import inject
154
164
 
155
- return OpenAIChatModel(model_name, provider=gateway_provider('openai'))
156
- elif upstream_provider == 'openai-responses':
157
- from pydantic_ai.models.openai import OpenAIResponsesModel
158
-
159
- return OpenAIResponsesModel(model_name, provider=gateway_provider('openai'))
160
- elif upstream_provider == 'groq':
161
- from pydantic_ai.models.groq import GroqModel
165
+ headers: dict[str, Any] = {}
166
+ inject(headers)
167
+ request.headers.update(headers)
162
168
 
163
- return GroqModel(model_name, provider=gateway_provider('groq'))
164
- elif upstream_provider == 'anthropic':
165
- from pydantic_ai.models.anthropic import AnthropicModel
166
-
167
- return AnthropicModel(model_name, provider=gateway_provider('anthropic'))
168
- elif upstream_provider == 'google-vertex':
169
- from pydantic_ai.models.google import GoogleModel
170
-
171
- return GoogleModel(model_name, provider=gateway_provider('google-vertex'))
172
- raise UserError(f'Unknown upstream provider: {upstream_provider}')
173
-
174
-
175
- async def _request_hook(request: httpx.Request) -> httpx.Request:
176
- """Request hook for the gateway provider.
177
-
178
- It adds the `"traceparent"` header to the request.
179
- """
180
- from opentelemetry.propagate import inject
169
+ if 'Authorization' not in request.headers:
170
+ request.headers['Authorization'] = f'Bearer {api_key}'
181
171
 
182
- headers: dict[str, Any] = {}
183
- inject(headers)
184
- request.headers.update(headers)
172
+ return request
185
173
 
186
- return request
174
+ return _hook
187
175
 
188
176
 
189
177
  def _merge_url_path(base_url: str, path: str) -> str:
@@ -7,14 +7,15 @@ import httpx
7
7
 
8
8
  from pydantic_ai import ModelProfile
9
9
  from pydantic_ai.exceptions import UserError
10
- from pydantic_ai.models import get_user_agent
10
+ from pydantic_ai.models import cached_async_http_client, get_user_agent
11
11
  from pydantic_ai.profiles.google import google_model_profile
12
12
  from pydantic_ai.providers import Provider
13
13
 
14
14
  try:
15
15
  from google.auth.credentials import Credentials
16
- from google.genai import Client
17
- from google.genai.types import HttpOptionsDict
16
+ from google.genai._api_client import BaseApiClient
17
+ from google.genai.client import Client, DebugConfig
18
+ from google.genai.types import HttpOptions
18
19
  except ImportError as _import_error:
19
20
  raise ImportError(
20
21
  'Please install the `google-genai` package to use the Google provider, '
@@ -41,7 +42,9 @@ class GoogleProvider(Provider[Client]):
41
42
  return google_model_profile(model_name)
42
43
 
43
44
  @overload
44
- def __init__(self, *, api_key: str) -> None: ...
45
+ def __init__(
46
+ self, *, api_key: str, http_client: httpx.AsyncClient | None = None, base_url: str | None = None
47
+ ) -> None: ...
45
48
 
46
49
  @overload
47
50
  def __init__(
@@ -49,14 +52,23 @@ class GoogleProvider(Provider[Client]):
49
52
  *,
50
53
  credentials: Credentials | None = None,
51
54
  project: str | None = None,
52
- location: VertexAILocation | Literal['global'] | None = None,
55
+ location: VertexAILocation | Literal['global'] | str | None = None,
56
+ http_client: httpx.AsyncClient | None = None,
57
+ base_url: str | None = None,
53
58
  ) -> None: ...
54
59
 
55
60
  @overload
56
61
  def __init__(self, *, client: Client) -> None: ...
57
62
 
58
63
  @overload
59
- def __init__(self, *, vertexai: bool = False) -> None: ...
64
+ def __init__(
65
+ self,
66
+ *,
67
+ vertexai: bool = False,
68
+ api_key: str | None = None,
69
+ http_client: httpx.AsyncClient | None = None,
70
+ base_url: str | None = None,
71
+ ) -> None: ...
60
72
 
61
73
  def __init__(
62
74
  self,
@@ -64,16 +76,17 @@ class GoogleProvider(Provider[Client]):
64
76
  api_key: str | None = None,
65
77
  credentials: Credentials | None = None,
66
78
  project: str | None = None,
67
- location: VertexAILocation | Literal['global'] | None = None,
68
- client: Client | None = None,
79
+ location: VertexAILocation | Literal['global'] | str | None = None,
69
80
  vertexai: bool | None = None,
81
+ client: Client | None = None,
82
+ http_client: httpx.AsyncClient | None = None,
83
+ base_url: str | None = None,
70
84
  ) -> None:
71
85
  """Create a new Google provider.
72
86
 
73
87
  Args:
74
88
  api_key: The `API key <https://ai.google.dev/gemini-api/docs/api-key>`_ to
75
89
  use for authentication. It can also be set via the `GOOGLE_API_KEY` environment variable.
76
- Applies to the Gemini Developer API only.
77
90
  credentials: The credentials to use for authentication when calling the Vertex AI APIs. Credentials can be
78
91
  obtained from environment variables and default credentials. For more information, see Set up
79
92
  Application Default Credentials. Applies to the Vertex AI API only.
@@ -81,43 +94,60 @@ class GoogleProvider(Provider[Client]):
81
94
  (for example, GOOGLE_CLOUD_PROJECT). Applies to the Vertex AI API only.
82
95
  location: The location to send API requests to (for example, us-central1). Can be obtained from environment variables.
83
96
  Applies to the Vertex AI API only.
84
- client: A pre-initialized client to use.
85
97
  vertexai: Force the use of the Vertex AI API. If `False`, the Google Generative Language API will be used.
86
- Defaults to `False`.
98
+ Defaults to `False` unless `location`, `project`, or `credentials` are provided.
99
+ client: A pre-initialized client to use.
100
+ http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
101
+ base_url: The base URL for the Google API.
87
102
  """
88
103
  if client is None:
89
104
  # NOTE: We are keeping GEMINI_API_KEY for backwards compatibility.
90
105
  api_key = api_key or os.getenv('GOOGLE_API_KEY') or os.getenv('GEMINI_API_KEY')
91
106
 
107
+ vertex_ai_args_used = bool(location or project or credentials)
92
108
  if vertexai is None:
93
- vertexai = bool(location or project or credentials)
94
-
95
- http_options: HttpOptionsDict = {
96
- 'headers': {'User-Agent': get_user_agent()},
97
- 'async_client_args': {'transport': httpx.AsyncHTTPTransport()},
98
- }
109
+ vertexai = vertex_ai_args_used
110
+
111
+ http_client = http_client or cached_async_http_client(
112
+ provider='google-vertex' if vertexai else 'google-gla'
113
+ )
114
+ http_options = HttpOptions(
115
+ base_url=base_url,
116
+ headers={'User-Agent': get_user_agent()},
117
+ httpx_async_client=http_client,
118
+ # TODO: Remove once https://github.com/googleapis/python-genai/issues/1565 is solved.
119
+ async_client_args={'transport': httpx.AsyncHTTPTransport()},
120
+ )
99
121
  if not vertexai:
100
122
  if api_key is None:
101
- raise UserError( # pragma: no cover
123
+ raise UserError(
102
124
  'Set the `GOOGLE_API_KEY` environment variable or pass it via `GoogleProvider(api_key=...)`'
103
125
  'to use the Google Generative Language API.'
104
126
  )
105
- self._client = Client(vertexai=vertexai, api_key=api_key, http_options=http_options)
127
+ self._client = _SafelyClosingClient(vertexai=False, api_key=api_key, http_options=http_options)
106
128
  else:
107
- self._client = Client(
108
- vertexai=vertexai,
109
- project=project or os.getenv('GOOGLE_CLOUD_PROJECT'),
129
+ if vertex_ai_args_used:
130
+ api_key = None
131
+
132
+ if api_key is None:
133
+ project = project or os.getenv('GOOGLE_CLOUD_PROJECT')
110
134
  # From https://github.com/pydantic/pydantic-ai/pull/2031/files#r2169682149:
111
135
  # Currently `us-central1` supports the most models by far of any region including `global`, but not
112
136
  # all of them. `us-central1` has all google models but is missing some Anthropic partner models,
113
137
  # which use `us-east5` instead. `global` has fewer models but higher availability.
114
138
  # For more details, check: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions
115
- location=location or os.getenv('GOOGLE_CLOUD_LOCATION') or 'us-central1',
139
+ location = location or os.getenv('GOOGLE_CLOUD_LOCATION') or 'us-central1'
140
+
141
+ self._client = _SafelyClosingClient(
142
+ vertexai=True,
143
+ api_key=api_key,
144
+ project=project,
145
+ location=location,
116
146
  credentials=credentials,
117
147
  http_options=http_options,
118
148
  )
119
149
  else:
120
- self._client = client
150
+ self._client = client # pragma: no cover
121
151
 
122
152
 
123
153
  VertexAILocation = Literal[
@@ -154,3 +184,40 @@ VertexAILocation = Literal[
154
184
  """Regions available for Vertex AI.
155
185
  More details [here](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#genai-locations).
156
186
  """
187
+
188
+
189
+ class _SafelyClosingClient(Client):
190
+ @staticmethod
191
+ def _get_api_client(
192
+ vertexai: bool | None = None,
193
+ api_key: str | None = None,
194
+ credentials: Credentials | None = None,
195
+ project: str | None = None,
196
+ location: str | None = None,
197
+ debug_config: DebugConfig | None = None,
198
+ http_options: HttpOptions | None = None,
199
+ ) -> BaseApiClient:
200
+ return _NonClosingApiClient(
201
+ vertexai=vertexai,
202
+ api_key=api_key,
203
+ credentials=credentials,
204
+ project=project,
205
+ location=location,
206
+ http_options=http_options,
207
+ )
208
+
209
+ def close(self) -> None:
210
+ # This is called from `Client.__del__`, even if `Client.__init__` raised an error before `self._api_client` is set, which would raise an `AttributeError` here.
211
+ # TODO: Remove once https://github.com/googleapis/python-genai/issues/1567 is solved.
212
+ try:
213
+ super().close()
214
+ except AttributeError:
215
+ pass
216
+
217
+
218
+ class _NonClosingApiClient(BaseApiClient):
219
+ async def aclose(self) -> None:
220
+ # The original implementation also calls `await self._async_httpx_client.aclose()`, but we don't want to close our `cached_async_http_client` or the one the user passed in.
221
+ # TODO: Remove once https://github.com/googleapis/python-genai/issues/1566 is solved.
222
+ if self._aiohttp_session:
223
+ await self._aiohttp_session.close() # pragma: no cover
@@ -81,6 +81,9 @@ class OpenRouterProvider(Provider[AsyncOpenAI]):
81
81
  @overload
82
82
  def __init__(self, *, api_key: str, http_client: httpx.AsyncClient) -> None: ...
83
83
 
84
+ @overload
85
+ def __init__(self, *, http_client: httpx.AsyncClient) -> None: ...
86
+
84
87
  @overload
85
88
  def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
86
89
 
@@ -0,0 +1,40 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ from typing import Any
4
+
5
+ from pydantic_ai.profiles import ModelProfile
6
+ from pydantic_ai.providers import Provider
7
+
8
+
9
+ class OutlinesProvider(Provider[Any]):
10
+ """Provider for Outlines API."""
11
+
12
+ @property
13
+ def name(self) -> str:
14
+ """The provider name."""
15
+ return 'outlines'
16
+
17
+ @property
18
+ def base_url(self) -> str:
19
+ """The base URL for the provider API."""
20
+ raise NotImplementedError(
21
+ 'The Outlines provider does not have a set base URL as it functions '
22
+ + 'with a set of different underlying models.'
23
+ )
24
+
25
+ @property
26
+ def client(self) -> Any:
27
+ """The client for the provider."""
28
+ raise NotImplementedError(
29
+ 'The Outlines provider does not have a set client as it functions '
30
+ + 'with a set of different underlying models.'
31
+ )
32
+
33
+ def model_profile(self, model_name: str) -> ModelProfile | None:
34
+ """The model profile for the named model, if available."""
35
+ return ModelProfile(
36
+ supports_tools=False,
37
+ supports_json_schema_output=True,
38
+ supports_json_object_output=True,
39
+ default_structured_output_mode='native',
40
+ )
@@ -0,0 +1,95 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ import os
4
+ from typing import overload
5
+
6
+ import httpx
7
+
8
+ from pydantic_ai import ModelProfile
9
+ from pydantic_ai.exceptions import UserError
10
+ from pydantic_ai.models import cached_async_http_client
11
+ from pydantic_ai.profiles.deepseek import deepseek_model_profile
12
+ from pydantic_ai.profiles.harmony import harmony_model_profile
13
+ from pydantic_ai.profiles.meta import meta_model_profile
14
+ from pydantic_ai.profiles.mistral import mistral_model_profile
15
+ from pydantic_ai.profiles.openai import OpenAIJsonSchemaTransformer, OpenAIModelProfile
16
+ from pydantic_ai.profiles.qwen import qwen_model_profile
17
+ from pydantic_ai.providers import Provider
18
+
19
+ try:
20
+ from openai import AsyncOpenAI
21
+ except ImportError as _import_error: # pragma: no cover
22
+ raise ImportError(
23
+ 'Please install the `openai` package to use OVHcloud AI Endpoints provider.'
24
+ 'You can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
25
+ ) from _import_error
26
+
27
+
28
+ class OVHcloudProvider(Provider[AsyncOpenAI]):
29
+ """Provider for OVHcloud AI Endpoints."""
30
+
31
+ @property
32
+ def name(self) -> str:
33
+ return 'ovhcloud'
34
+
35
+ @property
36
+ def base_url(self) -> str:
37
+ return 'https://oai.endpoints.kepler.ai.cloud.ovh.net/v1'
38
+
39
+ @property
40
+ def client(self) -> AsyncOpenAI:
41
+ return self._client
42
+
43
+ def model_profile(self, model_name: str) -> ModelProfile | None:
44
+ model_name = model_name.lower()
45
+
46
+ prefix_to_profile = {
47
+ 'llama': meta_model_profile,
48
+ 'meta-': meta_model_profile,
49
+ 'deepseek': deepseek_model_profile,
50
+ 'mistral': mistral_model_profile,
51
+ 'gpt': harmony_model_profile,
52
+ 'qwen': qwen_model_profile,
53
+ }
54
+
55
+ profile = None
56
+ for prefix, profile_func in prefix_to_profile.items():
57
+ if model_name.startswith(prefix):
58
+ profile = profile_func(model_name)
59
+
60
+ # As the OVHcloud AI Endpoints API is OpenAI-compatible, let's assume we also need OpenAIJsonSchemaTransformer.
61
+ return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile)
62
+
63
+ @overload
64
+ def __init__(self) -> None: ...
65
+
66
+ @overload
67
+ def __init__(self, *, api_key: str) -> None: ...
68
+
69
+ @overload
70
+ def __init__(self, *, api_key: str, http_client: httpx.AsyncClient) -> None: ...
71
+
72
+ @overload
73
+ def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
74
+
75
+ def __init__(
76
+ self,
77
+ *,
78
+ api_key: str | None = None,
79
+ openai_client: AsyncOpenAI | None = None,
80
+ http_client: httpx.AsyncClient | None = None,
81
+ ) -> None:
82
+ api_key = api_key or os.getenv('OVHCLOUD_API_KEY')
83
+ if not api_key and openai_client is None:
84
+ raise UserError(
85
+ 'Set the `OVHCLOUD_API_KEY` environment variable or pass it via '
86
+ '`OVHcloudProvider(api_key=...)` to use OVHcloud AI Endpoints provider.'
87
+ )
88
+
89
+ if openai_client is not None:
90
+ self._client = openai_client
91
+ elif http_client is not None:
92
+ self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
93
+ else:
94
+ http_client = cached_async_http_client(provider='ovhcloud')
95
+ self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)