pydantic-ai-slim 0.0.31__py3-none-any.whl → 0.0.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

@@ -6,7 +6,6 @@ from contextlib import asynccontextmanager, contextmanager
6
6
  from dataclasses import dataclass, field
7
7
  from typing import Any, Callable, Literal
8
8
 
9
- import logfire_api
10
9
  from opentelemetry._events import Event, EventLogger, EventLoggerProvider, get_event_logger_provider
11
10
  from opentelemetry.trace import Span, Tracer, TracerProvider, get_tracer_provider
12
11
  from opentelemetry.util.types import AttributeValue
@@ -59,27 +58,15 @@ class InstrumentedModel(WrapperModel):
59
58
  event_logger_provider: EventLoggerProvider | None = None,
60
59
  event_mode: Literal['attributes', 'logs'] = 'attributes',
61
60
  ):
61
+ from pydantic_ai import __version__
62
+
62
63
  super().__init__(wrapped)
63
64
  tracer_provider = tracer_provider or get_tracer_provider()
64
65
  event_logger_provider = event_logger_provider or get_event_logger_provider()
65
- self.tracer = tracer_provider.get_tracer('pydantic-ai')
66
- self.event_logger = event_logger_provider.get_event_logger('pydantic-ai')
66
+ self.tracer = tracer_provider.get_tracer('pydantic-ai', __version__)
67
+ self.event_logger = event_logger_provider.get_event_logger('pydantic-ai', __version__)
67
68
  self.event_mode = event_mode
68
69
 
69
- @classmethod
70
- def from_logfire(
71
- cls,
72
- wrapped: Model | KnownModelName,
73
- logfire_instance: logfire_api.Logfire = logfire_api.DEFAULT_LOGFIRE_INSTANCE,
74
- event_mode: Literal['attributes', 'logs'] = 'attributes',
75
- ) -> InstrumentedModel:
76
- if hasattr(logfire_instance.config, 'get_event_logger_provider'):
77
- event_provider = logfire_instance.config.get_event_logger_provider()
78
- else:
79
- event_provider = None
80
- tracer_provider = logfire_instance.config.get_tracer_provider()
81
- return cls(wrapped, tracer_provider, event_provider, event_mode)
82
-
83
70
  async def request(
84
71
  self,
85
72
  messages: list[ModelMessage],
@@ -199,19 +186,30 @@ class InstrumentedModel(WrapperModel):
199
186
  @staticmethod
200
187
  def messages_to_otel_events(messages: list[ModelMessage]) -> list[Event]:
201
188
  result: list[Event] = []
202
- for message in messages:
189
+ for message_index, message in enumerate(messages):
190
+ message_events: list[Event] = []
203
191
  if isinstance(message, ModelRequest):
204
192
  for part in message.parts:
205
193
  if hasattr(part, 'otel_event'):
206
- result.append(part.otel_event())
194
+ message_events.append(part.otel_event())
207
195
  elif isinstance(message, ModelResponse):
208
- result.extend(message.otel_events())
196
+ message_events = message.otel_events()
197
+ for event in message_events:
198
+ event.attributes = {
199
+ 'gen_ai.message.index': message_index,
200
+ **(event.attributes or {}),
201
+ }
202
+ result.extend(message_events)
209
203
  for event in result:
210
- try:
211
- event.body = ANY_ADAPTER.dump_python(event.body, mode='json')
212
- except Exception:
213
- try:
214
- event.body = str(event.body)
215
- except Exception:
216
- event.body = 'Unable to serialize event body'
204
+ event.body = InstrumentedModel.serialize_any(event.body)
217
205
  return result
206
+
207
+ @staticmethod
208
+ def serialize_any(value: Any) -> str:
209
+ try:
210
+ return ANY_ADAPTER.dump_python(value, mode='json')
211
+ except Exception:
212
+ try:
213
+ return str(value)
214
+ except Exception as e:
215
+ return f'Unable to serialize: {e}'
@@ -9,7 +9,9 @@ from datetime import datetime, timezone
9
9
  from typing import Literal, Union, cast, overload
10
10
 
11
11
  from httpx import AsyncClient as AsyncHTTPClient
12
- from typing_extensions import assert_never
12
+ from typing_extensions import assert_never, deprecated
13
+
14
+ from pydantic_ai.providers import Provider, infer_provider
13
15
 
14
16
  from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
15
17
  from .._utils import guard_tool_call_id as _guard_tool_call_id
@@ -98,10 +100,36 @@ class OpenAIModel(Model):
98
100
  _model_name: OpenAIModelName = field(repr=False)
99
101
  _system: str | None = field(repr=False)
100
102
 
103
+ @overload
101
104
  def __init__(
102
105
  self,
103
106
  model_name: OpenAIModelName,
104
107
  *,
108
+ provider: Literal['openai', 'deepseek'] | Provider[AsyncOpenAI] = 'openai',
109
+ system_prompt_role: OpenAISystemPromptRole | None = None,
110
+ system: str | None = 'openai',
111
+ ) -> None: ...
112
+
113
+ @deprecated('Use the `provider` parameter instead of `base_url`, `api_key`, `openai_client` and `http_client`.')
114
+ @overload
115
+ def __init__(
116
+ self,
117
+ model_name: OpenAIModelName,
118
+ *,
119
+ provider: None = None,
120
+ base_url: str | None = None,
121
+ api_key: str | None = None,
122
+ openai_client: AsyncOpenAI | None = None,
123
+ http_client: AsyncHTTPClient | None = None,
124
+ system_prompt_role: OpenAISystemPromptRole | None = None,
125
+ system: str | None = 'openai',
126
+ ) -> None: ...
127
+
128
+ def __init__(
129
+ self,
130
+ model_name: OpenAIModelName,
131
+ *,
132
+ provider: Literal['openai', 'deepseek'] | Provider[AsyncOpenAI] | None = None,
105
133
  base_url: str | None = None,
106
134
  api_key: str | None = None,
107
135
  openai_client: AsyncOpenAI | None = None,
@@ -115,6 +143,7 @@ class OpenAIModel(Model):
115
143
  model_name: The name of the OpenAI model to use. List of model names available
116
144
  [here](https://github.com/openai/openai-python/blob/v1.54.3/src/openai/types/chat_model.py#L7)
117
145
  (Unfortunately, despite being ask to do so, OpenAI do not provide `.inv` files for their API).
146
+ provider: The provider to use. Defaults to `'openai'`.
118
147
  base_url: The base url for the OpenAI requests. If not provided, the `OPENAI_BASE_URL` environment variable
119
148
  will be used if available. Otherwise, defaults to OpenAI's base url.
120
149
  api_key: The API key to use for authentication, if not provided, the `OPENAI_API_KEY` environment variable
@@ -129,20 +158,32 @@ class OpenAIModel(Model):
129
158
  customize the `base_url` and `api_key` to use a different provider.
130
159
  """
131
160
  self._model_name = model_name
132
- # This is a workaround for the OpenAI client requiring an API key, whilst locally served,
133
- # openai compatible models do not always need an API key, but a placeholder (non-empty) key is required.
134
- if api_key is None and 'OPENAI_API_KEY' not in os.environ and base_url is not None and openai_client is None:
135
- api_key = 'api-key-not-set'
136
-
137
- if openai_client is not None:
138
- assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
139
- assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
140
- assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
141
- self.client = openai_client
142
- elif http_client is not None:
143
- self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
144
- else:
145
- self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
161
+
162
+ if provider is not None:
163
+ if isinstance(provider, str):
164
+ self.client = infer_provider(provider).client
165
+ else:
166
+ self.client = provider.client
167
+ else: # pragma: no cover
168
+ # This is a workaround for the OpenAI client requiring an API key, whilst locally served,
169
+ # openai compatible models do not always need an API key, but a placeholder (non-empty) key is required.
170
+ if (
171
+ api_key is None
172
+ and 'OPENAI_API_KEY' not in os.environ
173
+ and base_url is not None
174
+ and openai_client is None
175
+ ):
176
+ api_key = 'api-key-not-set'
177
+
178
+ if openai_client is not None:
179
+ assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
180
+ assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
181
+ assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
182
+ self.client = openai_client
183
+ elif http_client is not None:
184
+ self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
185
+ else:
186
+ self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
146
187
  self.system_prompt_role = system_prompt_role
147
188
  self._system = system
148
189
 
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations as _annotations
2
2
 
3
+ import warnings
3
4
  from collections.abc import AsyncIterator
4
5
  from contextlib import asynccontextmanager
5
6
  from dataclasses import dataclass, field
@@ -8,6 +9,7 @@ from pathlib import Path
8
9
  from typing import Literal
9
10
 
10
11
  from httpx import AsyncClient as AsyncHTTPClient
12
+ from typing_extensions import deprecated
11
13
 
12
14
  from .. import usage
13
15
  from .._utils import run_in_executor
@@ -55,6 +57,7 @@ The template is used thus:
55
57
  """
56
58
 
57
59
 
60
+ @deprecated('Please use `GeminiModel(provider=GoogleVertexProvider(...))` instead.')
58
61
  @dataclass(init=False)
59
62
  class VertexAIModel(GeminiModel):
60
63
  """A model that uses Gemini via the `*-aiplatform.googleapis.com` VertexAI API."""
@@ -103,11 +106,16 @@ class VertexAIModel(GeminiModel):
103
106
  self.project_id = project_id
104
107
  self.region = region
105
108
  self.model_publisher = model_publisher
106
- self.http_client = http_client or cached_async_http_client()
109
+ self.client = http_client or cached_async_http_client()
107
110
  self.url_template = url_template
108
111
 
109
112
  self._auth = None
110
113
  self._url = None
114
+ warnings.warn(
115
+ 'VertexAIModel is deprecated, please use `GeminiModel(provider=GoogleVertexProvider(...))` instead.',
116
+ DeprecationWarning,
117
+ )
118
+ self._provider = None
111
119
 
112
120
  async def ainit(self) -> None:
113
121
  """Initialize the model, setting the URL and auth.
@@ -0,0 +1,64 @@
1
"""Providers for the API clients.

The providers are in charge of providing an authenticated client to the API.
"""

from __future__ import annotations as _annotations

from abc import ABC, abstractmethod
from typing import Any, Generic, TypeVar

InterfaceClient = TypeVar('InterfaceClient')


class Provider(ABC, Generic[InterfaceClient]):
    """Abstract base class for an API provider.

    A provider supplies an authenticated client for a single client interface.
    One interface may be served by several providers — e.g. the OpenAI interface
    is served by both the OpenAI and DeepSeek providers.
    """

    # The authenticated client instance; set by concrete subclasses.
    _client: InterfaceClient

    @property
    @abstractmethod
    def name(self) -> str:
        """Name of this provider."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def base_url(self) -> str:
        """Root URL used for requests to this provider's API."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def client(self) -> InterfaceClient:
        """Authenticated client for this provider."""
        raise NotImplementedError()


def infer_provider(provider: str) -> Provider[Any]:
    """Infer the provider from the provider name."""
    # Imports are deferred so that each provider's optional dependencies are
    # only required when that provider is actually requested.
    if provider == 'openai':
        from .openai import OpenAIProvider

        return OpenAIProvider()
    if provider == 'deepseek':
        from .deepseek import DeepSeekProvider

        return DeepSeekProvider()
    if provider == 'google-vertex':
        from .google_vertex import GoogleVertexProvider

        return GoogleVertexProvider()
    if provider == 'google-gla':
        from .google_gla import GoogleGLAProvider

        return GoogleGLAProvider()
    raise ValueError(f'Unknown provider: {provider}')  # pragma: no cover
@@ -0,0 +1,68 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ import os
4
+ from typing import overload
5
+
6
+ from httpx import AsyncClient as AsyncHTTPClient
7
+ from openai import AsyncOpenAI
8
+
9
+ from pydantic_ai.models import cached_async_http_client
10
+
11
+ try:
12
+ from openai import AsyncOpenAI
13
+ except ImportError as _import_error: # pragma: no cover
14
+ raise ImportError(
15
+ 'Please install `openai` to use the DeepSeek provider, '
16
+ "you can use the `openai` optional group — `pip install 'pydantic-ai-slim[openai]'`"
17
+ ) from _import_error
18
+
19
+ from . import Provider
20
+
21
+
22
+ class DeepSeekProvider(Provider[AsyncOpenAI]):
23
+ """Provider for DeepSeek API."""
24
+
25
+ @property
26
+ def name(self) -> str:
27
+ return 'deepseek'
28
+
29
+ @property
30
+ def base_url(self) -> str:
31
+ return 'https://api.deepseek.com'
32
+
33
+ @property
34
+ def client(self) -> AsyncOpenAI:
35
+ return self._client
36
+
37
+ @overload
38
+ def __init__(self) -> None: ...
39
+
40
+ @overload
41
+ def __init__(self, *, api_key: str) -> None: ...
42
+
43
+ @overload
44
+ def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...
45
+
46
+ @overload
47
+ def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
48
+
49
+ def __init__(
50
+ self,
51
+ *,
52
+ api_key: str | None = None,
53
+ openai_client: AsyncOpenAI | None = None,
54
+ http_client: AsyncHTTPClient | None = None,
55
+ ) -> None:
56
+ api_key = api_key or os.getenv('DEEPSEEK_API_KEY')
57
+ if api_key is None and openai_client is None:
58
+ raise ValueError(
59
+ 'Set the `DEEPSEEK_API_KEY` environment variable or pass it via `DeepSeekProvider(api_key=...)`'
60
+ 'to use the DeepSeek provider.'
61
+ )
62
+
63
+ if openai_client is not None:
64
+ self._client = openai_client
65
+ elif http_client is not None:
66
+ self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
67
+ else:
68
+ self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=cached_async_http_client())
@@ -0,0 +1,44 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ import os
4
+
5
+ import httpx
6
+
7
+ from pydantic_ai.models import cached_async_http_client
8
+ from pydantic_ai.providers import Provider
9
+
10
+
11
+ class GoogleGLAProvider(Provider[httpx.AsyncClient]):
12
+ """Provider for Google Generative Language AI API."""
13
+
14
+ @property
15
+ def name(self):
16
+ return 'google-gla'
17
+
18
+ @property
19
+ def base_url(self) -> str:
20
+ return 'https://generativelanguage.googleapis.com/v1beta/models/'
21
+
22
+ @property
23
+ def client(self) -> httpx.AsyncClient:
24
+ return self._client
25
+
26
+ def __init__(self, api_key: str | None = None, http_client: httpx.AsyncClient | None = None) -> None:
27
+ """Create a new Google GLA provider.
28
+
29
+ Args:
30
+ api_key: The API key to use for authentication, if not provided, the `GEMINI_API_KEY` environment variable
31
+ will be used if available.
32
+ http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
33
+ """
34
+ api_key = api_key or os.environ.get('GEMINI_API_KEY')
35
+ if api_key is None:
36
+ raise ValueError(
37
+ 'Set the `GEMINI_API_KEY` environment variable or pass it via `GoogleGLAProvider(api_key=...)`'
38
+ 'to use the Google GLA provider.'
39
+ )
40
+
41
+ self._client = http_client or cached_async_http_client()
42
+ self._client.base_url = self.base_url
43
+ # https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest
44
+ self._client.headers['X-Goog-Api-Key'] = api_key
@@ -0,0 +1,200 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ import functools
4
+ from collections.abc import AsyncGenerator
5
+ from datetime import datetime, timedelta
6
+ from pathlib import Path
7
+ from typing import Literal
8
+
9
+ import anyio.to_thread
10
+ import httpx
11
+
12
+ from pydantic_ai.exceptions import UserError
13
+
14
+ from ..models import cached_async_http_client
15
+ from . import Provider
16
+
17
+ try:
18
+ import google.auth
19
+ from google.auth.credentials import Credentials as BaseCredentials
20
+ from google.auth.transport.requests import Request
21
+ from google.oauth2.service_account import Credentials as ServiceAccountCredentials
22
+ except ImportError as _import_error:
23
+ raise ImportError(
24
+ 'Please install `google-auth` to use the Google Vertex AI provider, '
25
+ "you can use the `vertexai` optional group — `pip install 'pydantic-ai-slim[vertexai]'`"
26
+ ) from _import_error
27
+
28
+
29
+ __all__ = ('GoogleVertexProvider',)
30
+
31
+ # default expiry is 3600 seconds
32
+ MAX_TOKEN_AGE = timedelta(seconds=3000)
33
+
34
+
35
class GoogleVertexProvider(Provider[httpx.AsyncClient]):
    """Provider for Vertex AI API."""

    @property
    def name(self) -> str:
        return 'google-vertex'

    @property
    def base_url(self) -> str:
        # NOTE: `self.project_id` may still be None here when it is to be taken from
        # the credentials; `_VertexAIAuth.async_auth_flow` rewrites 'projects/None'
        # in the request URL once the real project id is known.
        return (
            f'https://{self.region}-aiplatform.googleapis.com/v1'
            f'/projects/{self.project_id}'
            f'/locations/{self.region}'
            f'/publishers/{self.model_publisher}/models/'
        )

    @property
    def client(self) -> httpx.AsyncClient:
        return self._client

    def __init__(
        self,
        service_account_file: Path | str | None = None,
        project_id: str | None = None,
        region: VertexAiRegion = 'us-central1',
        model_publisher: str = 'google',
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Create a new Vertex AI provider.

        Args:
            service_account_file: Path to a service account file.
                If not provided, the default environment credentials will be used.
            project_id: The project ID to use, if not provided it will be taken from the credentials.
            region: The region to make requests to.
            model_publisher: The model publisher to use, I couldn't find a good list of available publishers,
                and from trial and error it seems non-google models don't work with the `generateContent` and
                `streamGenerateContent` functions, hence only `google` is currently supported.
                Please create an issue or PR if you know how to use other publishers.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
        """
        # NOTE(review): when no `http_client` is passed, the process-wide cached client
        # is mutated below (auth + base_url are set on the shared instance) — confirm
        # this sharing is intended.
        self._client = http_client or cached_async_http_client()
        self.service_account_file = service_account_file
        self.project_id = project_id
        self.region = region
        self.model_publisher = model_publisher

        # Authentication (credential loading and token refresh) is performed lazily,
        # per request, by the custom httpx auth flow below.
        self._client.auth = _VertexAIAuth(service_account_file, project_id, region)
        self._client.base_url = self.base_url
84
+
85
+
86
class _VertexAIAuth(httpx.Auth):
    """Auth class for Vertex AI API."""

    # Google credentials, loaded lazily on the first request (see `async_auth_flow`).
    credentials: BaseCredentials | ServiceAccountCredentials | None

    def __init__(
        self,
        service_account_file: Path | str | None = None,
        project_id: str | None = None,
        region: VertexAiRegion = 'us-central1',
    ) -> None:
        self.service_account_file = service_account_file
        self.project_id = project_id
        self.region = region

        self.credentials = None
        # Timestamp of the last token refresh; used by `_token_expired` to force a
        # refresh after MAX_TOKEN_AGE.
        self.token_created: datetime | None = None

    async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
        """Attach a (possibly freshly refreshed) bearer token to `request`."""
        if self.credentials is None:
            self.credentials = await self._get_credentials()
        if self.credentials.token is None or self._token_expired():  # type: ignore[reportUnknownMemberType]
            # The google-auth refresh call is blocking, so run it in a worker thread.
            await anyio.to_thread.run_sync(self._refresh_token)
            self.token_created = datetime.now()
        request.headers['Authorization'] = f'Bearer {self.credentials.token}'  # type: ignore[reportUnknownMemberType]

        # NOTE: This workaround is in place because we might get the project_id from the credentials.
        request.url = httpx.URL(str(request.url).replace('projects/None', f'projects/{self.project_id}'))
        yield request

    async def _get_credentials(self) -> BaseCredentials | ServiceAccountCredentials:
        """Load credentials from the service account file or the environment.

        Also resolves `self.project_id` from the credentials if it wasn't provided.
        """
        if self.service_account_file is not None:
            creds = await _creds_from_file(self.service_account_file)
            assert creds.project_id is None or isinstance(creds.project_id, str)  # type: ignore[reportUnknownMemberType]
            creds_project_id: str | None = creds.project_id
            creds_source = 'service account file'
        else:
            creds, creds_project_id = await _async_google_auth()
            creds_source = '`google.auth.default()`'

        if self.project_id is None:
            if creds_project_id is None:
                raise UserError(f'No project_id provided and none found in {creds_source}')
            self.project_id = creds_project_id
        return creds

    def _token_expired(self) -> bool:
        # A token that has never been created counts as expired.
        if self.token_created is None:
            return True
        else:
            return (datetime.now() - self.token_created) > MAX_TOKEN_AGE

    def _refresh_token(self) -> str:  # pragma: no cover
        # Blocking call; always dispatched via `anyio.to_thread.run_sync`.
        assert self.credentials is not None
        self.credentials.refresh(Request())  # type: ignore[reportUnknownMemberType]
        assert isinstance(self.credentials.token, str), f'Expected token to be a string, got {self.credentials.token}'  # type: ignore[reportUnknownMemberType]
        return self.credentials.token
143
+
144
+
145
async def _async_google_auth() -> tuple[BaseCredentials, str | None]:
    """Run the blocking `google.auth.default()` lookup in a worker thread."""
    return await anyio.to_thread.run_sync(google.auth.default, ['https://www.googleapis.com/auth/cloud-platform'])  # type: ignore


async def _creds_from_file(service_account_file: str | Path) -> ServiceAccountCredentials:
    """Load service-account credentials from a file, in a worker thread."""
    service_account_credentials_from_file = functools.partial(
        ServiceAccountCredentials.from_service_account_file,  # type: ignore[reportUnknownMemberType]
        scopes=['https://www.googleapis.com/auth/cloud-platform'],
    )
    return await anyio.to_thread.run_sync(service_account_credentials_from_file, str(service_account_file))
155
+
156
+
157
# Closed set of region identifiers accepted by `GoogleVertexProvider(region=...)`.
VertexAiRegion = Literal[
    'us-central1',
    'us-east1',
    'us-east4',
    'us-south1',
    'us-west1',
    'us-west2',
    'us-west3',
    'us-west4',
    'us-east5',
    'europe-central2',
    'europe-north1',
    'europe-southwest1',
    'europe-west1',
    'europe-west2',
    'europe-west3',
    'europe-west4',
    'europe-west6',
    'europe-west8',
    'europe-west9',
    'europe-west12',
    'africa-south1',
    'asia-east1',
    'asia-east2',
    'asia-northeast1',
    'asia-northeast2',
    'asia-northeast3',
    'asia-south1',
    'asia-southeast1',
    'asia-southeast2',
    'australia-southeast1',
    'australia-southeast2',
    'me-central1',
    'me-central2',
    'me-west1',
    'northamerica-northeast1',
    'northamerica-northeast2',
    'southamerica-east1',
    'southamerica-west1',
]
"""Regions available for Vertex AI.

More details [here](https://cloud.google.com/vertex-ai/docs/reference/rest#rest_endpoints).
"""
@@ -0,0 +1,72 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ import os
4
+ from typing import TypeVar
5
+
6
+ import httpx
7
+
8
+ from pydantic_ai.models import cached_async_http_client
9
+
10
+ try:
11
+ from openai import AsyncOpenAI
12
+ except ImportError as _import_error: # pragma: no cover
13
+ raise ImportError(
14
+ 'Please install `openai` to use the OpenAI provider, '
15
+ "you can use the `openai` optional group — `pip install 'pydantic-ai-slim[openai]'`"
16
+ ) from _import_error
17
+
18
+
19
+ from . import Provider
20
+
21
+ InterfaceClient = TypeVar('InterfaceClient')
22
+
23
+
24
+ class OpenAIProvider(Provider[AsyncOpenAI]):
25
+ """Provider for OpenAI API."""
26
+
27
+ @property
28
+ def name(self) -> str:
29
+ return 'openai' # pragma: no cover
30
+
31
+ @property
32
+ def base_url(self) -> str:
33
+ return self._base_url
34
+
35
+ @property
36
+ def client(self) -> AsyncOpenAI:
37
+ return self._client
38
+
39
+ def __init__(
40
+ self,
41
+ base_url: str | None = None,
42
+ api_key: str | None = None,
43
+ openai_client: AsyncOpenAI | None = None,
44
+ http_client: httpx.AsyncClient | None = None,
45
+ ) -> None:
46
+ """Create a new OpenAI provider.
47
+
48
+ Args:
49
+ base_url: The base url for the OpenAI requests. If not provided, the `OPENAI_BASE_URL` environment variable
50
+ will be used if available. Otherwise, defaults to OpenAI's base url.
51
+ api_key: The API key to use for authentication, if not provided, the `OPENAI_API_KEY` environment variable
52
+ will be used if available.
53
+ openai_client: An existing
54
+ [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
55
+ client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`.
56
+ http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
57
+ """
58
+ self._base_url = base_url or 'https://api.openai.com/v1'
59
+ # This is a workaround for the OpenAI client requiring an API key, whilst locally served,
60
+ # openai compatible models do not always need an API key, but a placeholder (non-empty) key is required.
61
+ if api_key is None and 'OPENAI_API_KEY' not in os.environ and openai_client is None:
62
+ api_key = 'api-key-not-set'
63
+
64
+ if openai_client is not None:
65
+ assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
66
+ assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
67
+ assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
68
+ self._client = openai_client
69
+ elif http_client is not None:
70
+ self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
71
+ else:
72
+ self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=cached_async_http_client())