pydantic-ai-slim 0.0.31__py3-none-any.whl → 0.0.33__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_agent_graph.py +39 -38
- pydantic_ai/_pydantic.py +4 -4
- pydantic_ai/_result.py +7 -18
- pydantic_ai/agent.py +24 -21
- pydantic_ai/models/__init__.py +40 -36
- pydantic_ai/models/anthropic.py +3 -1
- pydantic_ai/models/gemini.py +52 -14
- pydantic_ai/models/instrumented.py +25 -27
- pydantic_ai/models/openai.py +56 -15
- pydantic_ai/models/vertexai.py +9 -1
- pydantic_ai/providers/__init__.py +64 -0
- pydantic_ai/providers/deepseek.py +68 -0
- pydantic_ai/providers/google_gla.py +44 -0
- pydantic_ai/providers/google_vertex.py +200 -0
- pydantic_ai/providers/openai.py +72 -0
- pydantic_ai/result.py +19 -27
- {pydantic_ai_slim-0.0.31.dist-info → pydantic_ai_slim-0.0.33.dist-info}/METADATA +4 -4
- {pydantic_ai_slim-0.0.31.dist-info → pydantic_ai_slim-0.0.33.dist-info}/RECORD +19 -14
- {pydantic_ai_slim-0.0.31.dist-info → pydantic_ai_slim-0.0.33.dist-info}/WHEEL +0 -0
|
@@ -6,7 +6,6 @@ from contextlib import asynccontextmanager, contextmanager
|
|
|
6
6
|
from dataclasses import dataclass, field
|
|
7
7
|
from typing import Any, Callable, Literal
|
|
8
8
|
|
|
9
|
-
import logfire_api
|
|
10
9
|
from opentelemetry._events import Event, EventLogger, EventLoggerProvider, get_event_logger_provider
|
|
11
10
|
from opentelemetry.trace import Span, Tracer, TracerProvider, get_tracer_provider
|
|
12
11
|
from opentelemetry.util.types import AttributeValue
|
|
@@ -59,27 +58,15 @@ class InstrumentedModel(WrapperModel):
|
|
|
59
58
|
event_logger_provider: EventLoggerProvider | None = None,
|
|
60
59
|
event_mode: Literal['attributes', 'logs'] = 'attributes',
|
|
61
60
|
):
|
|
61
|
+
from pydantic_ai import __version__
|
|
62
|
+
|
|
62
63
|
super().__init__(wrapped)
|
|
63
64
|
tracer_provider = tracer_provider or get_tracer_provider()
|
|
64
65
|
event_logger_provider = event_logger_provider or get_event_logger_provider()
|
|
65
|
-
self.tracer = tracer_provider.get_tracer('pydantic-ai')
|
|
66
|
-
self.event_logger = event_logger_provider.get_event_logger('pydantic-ai')
|
|
66
|
+
self.tracer = tracer_provider.get_tracer('pydantic-ai', __version__)
|
|
67
|
+
self.event_logger = event_logger_provider.get_event_logger('pydantic-ai', __version__)
|
|
67
68
|
self.event_mode = event_mode
|
|
68
69
|
|
|
69
|
-
@classmethod
|
|
70
|
-
def from_logfire(
|
|
71
|
-
cls,
|
|
72
|
-
wrapped: Model | KnownModelName,
|
|
73
|
-
logfire_instance: logfire_api.Logfire = logfire_api.DEFAULT_LOGFIRE_INSTANCE,
|
|
74
|
-
event_mode: Literal['attributes', 'logs'] = 'attributes',
|
|
75
|
-
) -> InstrumentedModel:
|
|
76
|
-
if hasattr(logfire_instance.config, 'get_event_logger_provider'):
|
|
77
|
-
event_provider = logfire_instance.config.get_event_logger_provider()
|
|
78
|
-
else:
|
|
79
|
-
event_provider = None
|
|
80
|
-
tracer_provider = logfire_instance.config.get_tracer_provider()
|
|
81
|
-
return cls(wrapped, tracer_provider, event_provider, event_mode)
|
|
82
|
-
|
|
83
70
|
async def request(
|
|
84
71
|
self,
|
|
85
72
|
messages: list[ModelMessage],
|
|
@@ -199,19 +186,30 @@ class InstrumentedModel(WrapperModel):
|
|
|
199
186
|
@staticmethod
|
|
200
187
|
def messages_to_otel_events(messages: list[ModelMessage]) -> list[Event]:
|
|
201
188
|
result: list[Event] = []
|
|
202
|
-
for message in messages:
|
|
189
|
+
for message_index, message in enumerate(messages):
|
|
190
|
+
message_events: list[Event] = []
|
|
203
191
|
if isinstance(message, ModelRequest):
|
|
204
192
|
for part in message.parts:
|
|
205
193
|
if hasattr(part, 'otel_event'):
|
|
206
|
-
|
|
194
|
+
message_events.append(part.otel_event())
|
|
207
195
|
elif isinstance(message, ModelResponse):
|
|
208
|
-
|
|
196
|
+
message_events = message.otel_events()
|
|
197
|
+
for event in message_events:
|
|
198
|
+
event.attributes = {
|
|
199
|
+
'gen_ai.message.index': message_index,
|
|
200
|
+
**(event.attributes or {}),
|
|
201
|
+
}
|
|
202
|
+
result.extend(message_events)
|
|
209
203
|
for event in result:
|
|
210
|
-
|
|
211
|
-
event.body = ANY_ADAPTER.dump_python(event.body, mode='json')
|
|
212
|
-
except Exception:
|
|
213
|
-
try:
|
|
214
|
-
event.body = str(event.body)
|
|
215
|
-
except Exception:
|
|
216
|
-
event.body = 'Unable to serialize event body'
|
|
204
|
+
event.body = InstrumentedModel.serialize_any(event.body)
|
|
217
205
|
return result
|
|
206
|
+
|
|
207
|
+
@staticmethod
|
|
208
|
+
def serialize_any(value: Any) -> str:
|
|
209
|
+
try:
|
|
210
|
+
return ANY_ADAPTER.dump_python(value, mode='json')
|
|
211
|
+
except Exception:
|
|
212
|
+
try:
|
|
213
|
+
return str(value)
|
|
214
|
+
except Exception as e:
|
|
215
|
+
return f'Unable to serialize: {e}'
|
pydantic_ai/models/openai.py
CHANGED
|
@@ -9,7 +9,9 @@ from datetime import datetime, timezone
|
|
|
9
9
|
from typing import Literal, Union, cast, overload
|
|
10
10
|
|
|
11
11
|
from httpx import AsyncClient as AsyncHTTPClient
|
|
12
|
-
from typing_extensions import assert_never
|
|
12
|
+
from typing_extensions import assert_never, deprecated
|
|
13
|
+
|
|
14
|
+
from pydantic_ai.providers import Provider, infer_provider
|
|
13
15
|
|
|
14
16
|
from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
|
|
15
17
|
from .._utils import guard_tool_call_id as _guard_tool_call_id
|
|
@@ -98,10 +100,36 @@ class OpenAIModel(Model):
|
|
|
98
100
|
_model_name: OpenAIModelName = field(repr=False)
|
|
99
101
|
_system: str | None = field(repr=False)
|
|
100
102
|
|
|
103
|
+
@overload
|
|
101
104
|
def __init__(
|
|
102
105
|
self,
|
|
103
106
|
model_name: OpenAIModelName,
|
|
104
107
|
*,
|
|
108
|
+
provider: Literal['openai', 'deepseek'] | Provider[AsyncOpenAI] = 'openai',
|
|
109
|
+
system_prompt_role: OpenAISystemPromptRole | None = None,
|
|
110
|
+
system: str | None = 'openai',
|
|
111
|
+
) -> None: ...
|
|
112
|
+
|
|
113
|
+
@deprecated('Use the `provider` parameter instead of `base_url`, `api_key`, `openai_client` and `http_client`.')
|
|
114
|
+
@overload
|
|
115
|
+
def __init__(
|
|
116
|
+
self,
|
|
117
|
+
model_name: OpenAIModelName,
|
|
118
|
+
*,
|
|
119
|
+
provider: None = None,
|
|
120
|
+
base_url: str | None = None,
|
|
121
|
+
api_key: str | None = None,
|
|
122
|
+
openai_client: AsyncOpenAI | None = None,
|
|
123
|
+
http_client: AsyncHTTPClient | None = None,
|
|
124
|
+
system_prompt_role: OpenAISystemPromptRole | None = None,
|
|
125
|
+
system: str | None = 'openai',
|
|
126
|
+
) -> None: ...
|
|
127
|
+
|
|
128
|
+
def __init__(
|
|
129
|
+
self,
|
|
130
|
+
model_name: OpenAIModelName,
|
|
131
|
+
*,
|
|
132
|
+
provider: Literal['openai', 'deepseek'] | Provider[AsyncOpenAI] | None = None,
|
|
105
133
|
base_url: str | None = None,
|
|
106
134
|
api_key: str | None = None,
|
|
107
135
|
openai_client: AsyncOpenAI | None = None,
|
|
@@ -115,6 +143,7 @@ class OpenAIModel(Model):
|
|
|
115
143
|
model_name: The name of the OpenAI model to use. List of model names available
|
|
116
144
|
[here](https://github.com/openai/openai-python/blob/v1.54.3/src/openai/types/chat_model.py#L7)
|
|
117
145
|
(Unfortunately, despite being ask to do so, OpenAI do not provide `.inv` files for their API).
|
|
146
|
+
provider: The provider to use. Defaults to `'openai'`.
|
|
118
147
|
base_url: The base url for the OpenAI requests. If not provided, the `OPENAI_BASE_URL` environment variable
|
|
119
148
|
will be used if available. Otherwise, defaults to OpenAI's base url.
|
|
120
149
|
api_key: The API key to use for authentication, if not provided, the `OPENAI_API_KEY` environment variable
|
|
@@ -129,20 +158,32 @@ class OpenAIModel(Model):
|
|
|
129
158
|
customize the `base_url` and `api_key` to use a different provider.
|
|
130
159
|
"""
|
|
131
160
|
self._model_name = model_name
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
161
|
+
|
|
162
|
+
if provider is not None:
|
|
163
|
+
if isinstance(provider, str):
|
|
164
|
+
self.client = infer_provider(provider).client
|
|
165
|
+
else:
|
|
166
|
+
self.client = provider.client
|
|
167
|
+
else: # pragma: no cover
|
|
168
|
+
# This is a workaround for the OpenAI client requiring an API key, whilst locally served,
|
|
169
|
+
# openai compatible models do not always need an API key, but a placeholder (non-empty) key is required.
|
|
170
|
+
if (
|
|
171
|
+
api_key is None
|
|
172
|
+
and 'OPENAI_API_KEY' not in os.environ
|
|
173
|
+
and base_url is not None
|
|
174
|
+
and openai_client is None
|
|
175
|
+
):
|
|
176
|
+
api_key = 'api-key-not-set'
|
|
177
|
+
|
|
178
|
+
if openai_client is not None:
|
|
179
|
+
assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
|
|
180
|
+
assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
|
|
181
|
+
assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
|
|
182
|
+
self.client = openai_client
|
|
183
|
+
elif http_client is not None:
|
|
184
|
+
self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
|
|
185
|
+
else:
|
|
186
|
+
self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
|
|
146
187
|
self.system_prompt_role = system_prompt_role
|
|
147
188
|
self._system = system
|
|
148
189
|
|
pydantic_ai/models/vertexai.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
|
+
import warnings
|
|
3
4
|
from collections.abc import AsyncIterator
|
|
4
5
|
from contextlib import asynccontextmanager
|
|
5
6
|
from dataclasses import dataclass, field
|
|
@@ -8,6 +9,7 @@ from pathlib import Path
|
|
|
8
9
|
from typing import Literal
|
|
9
10
|
|
|
10
11
|
from httpx import AsyncClient as AsyncHTTPClient
|
|
12
|
+
from typing_extensions import deprecated
|
|
11
13
|
|
|
12
14
|
from .. import usage
|
|
13
15
|
from .._utils import run_in_executor
|
|
@@ -55,6 +57,7 @@ The template is used thus:
|
|
|
55
57
|
"""
|
|
56
58
|
|
|
57
59
|
|
|
60
|
+
@deprecated('Please use `GeminiModel(provider=GoogleVertexProvider(...))` instead.')
|
|
58
61
|
@dataclass(init=False)
|
|
59
62
|
class VertexAIModel(GeminiModel):
|
|
60
63
|
"""A model that uses Gemini via the `*-aiplatform.googleapis.com` VertexAI API."""
|
|
@@ -103,11 +106,16 @@ class VertexAIModel(GeminiModel):
|
|
|
103
106
|
self.project_id = project_id
|
|
104
107
|
self.region = region
|
|
105
108
|
self.model_publisher = model_publisher
|
|
106
|
-
self.
|
|
109
|
+
self.client = http_client or cached_async_http_client()
|
|
107
110
|
self.url_template = url_template
|
|
108
111
|
|
|
109
112
|
self._auth = None
|
|
110
113
|
self._url = None
|
|
114
|
+
warnings.warn(
|
|
115
|
+
'VertexAIModel is deprecated, please use `GeminiModel(provider=GoogleVertexProvider(...))` instead.',
|
|
116
|
+
DeprecationWarning,
|
|
117
|
+
)
|
|
118
|
+
self._provider = None
|
|
111
119
|
|
|
112
120
|
async def ainit(self) -> None:
|
|
113
121
|
"""Initialize the model, setting the URL and auth.
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
"""Providers for the API clients.
|
|
2
|
+
|
|
3
|
+
The providers are in charge of providing an authenticated client to the API.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations as _annotations
|
|
7
|
+
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
from typing import Any, Generic, TypeVar
|
|
10
|
+
|
|
11
|
+
# Type parameter of `Provider`: the type of API client a provider exposes via its `client`
# property (e.g. `AsyncOpenAI` for OpenAI/DeepSeek, `httpx.AsyncClient` for the Google providers).
InterfaceClient = TypeVar('InterfaceClient')
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Provider(ABC, Generic[InterfaceClient]):
    """Abstract base class for providers.

    A provider is responsible for handing out a client that is already authenticated
    against a particular API.

    Every provider implements exactly one interface, but one interface may be served by
    several providers. For instance, the `OpenAIModel` interface is backed by both
    `OpenAIProvider` and `DeepSeekProvider`.
    """

    # Concrete providers store the client they construct here and return it from `client`.
    _client: InterfaceClient

    @property
    @abstractmethod
    def name(self) -> str:
        """Short identifier of the provider (e.g. `'openai'`)."""
        raise NotImplementedError

    @property
    @abstractmethod
    def base_url(self) -> str:
        """Root URL that requests to the provider's API are made against."""
        raise NotImplementedError

    @property
    @abstractmethod
    def client(self) -> InterfaceClient:
        """The authenticated client used to talk to the provider."""
        raise NotImplementedError
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def infer_provider(provider: str) -> Provider[Any]:
    """Build a default-configured provider from its short name.

    Each provider module is imported only when its name is requested, so the optional
    dependencies of the other providers (e.g. `openai`, `google-auth`) are not required.

    Args:
        provider: One of `'openai'`, `'deepseek'`, `'google-vertex'` or `'google-gla'`.

    Returns:
        A new provider instance created with its default arguments.

    Raises:
        ValueError: If `provider` is not a recognised provider name.
    """
    if provider == 'openai':
        from .openai import OpenAIProvider

        return OpenAIProvider()

    if provider == 'deepseek':
        from .deepseek import DeepSeekProvider

        return DeepSeekProvider()

    if provider == 'google-vertex':
        from .google_vertex import GoogleVertexProvider

        return GoogleVertexProvider()

    if provider == 'google-gla':
        from .google_gla import GoogleGLAProvider

        return GoogleGLAProvider()

    raise ValueError(f'Unknown provider: {provider}')  # pragma: no cover
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import overload
|
|
5
|
+
|
|
6
|
+
from httpx import AsyncClient as AsyncHTTPClient
|
|
7
|
+
from openai import AsyncOpenAI
|
|
8
|
+
|
|
9
|
+
from pydantic_ai.models import cached_async_http_client
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
from openai import AsyncOpenAI
|
|
13
|
+
except ImportError as _import_error: # pragma: no cover
|
|
14
|
+
raise ImportError(
|
|
15
|
+
'Please install `openai` to use the DeepSeek provider, '
|
|
16
|
+
"you can use the `openai` optional group — `pip install 'pydantic-ai-slim[openai]'`"
|
|
17
|
+
) from _import_error
|
|
18
|
+
|
|
19
|
+
from . import Provider
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class DeepSeekProvider(Provider[AsyncOpenAI]):
    """Provider for DeepSeek API."""

    @property
    def name(self) -> str:
        return 'deepseek'

    @property
    def base_url(self) -> str:
        return 'https://api.deepseek.com'

    @property
    def client(self) -> AsyncOpenAI:
        return self._client

    @overload
    def __init__(self) -> None: ...

    @overload
    def __init__(self, *, api_key: str) -> None: ...

    @overload
    def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...

    @overload
    def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...

    def __init__(
        self,
        *,
        api_key: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        http_client: AsyncHTTPClient | None = None,
    ) -> None:
        """Create a new DeepSeek provider.

        Args:
            api_key: The DeepSeek API key. If not provided, the `DEEPSEEK_API_KEY` environment variable is used.
            openai_client: An existing `AsyncOpenAI` client to use as-is. When given, `api_key` and
                `http_client` are not used.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests. If omitted, the
                client returned by `cached_async_http_client()` is used.

        Raises:
            ValueError: If no `openai_client` is given and no non-empty API key is available.
        """
        api_key = api_key or os.getenv('DEEPSEEK_API_KEY')
        # An empty key can never authenticate, so treat it the same as a missing one.
        if not api_key and openai_client is None:
            raise ValueError(
                'Set the `DEEPSEEK_API_KEY` environment variable or pass it via `DeepSeekProvider(api_key=...)` '
                'to use the DeepSeek provider.'
            )

        if openai_client is not None:
            self._client = openai_client
        else:
            self._client = AsyncOpenAI(
                base_url=self.base_url,
                api_key=api_key,
                http_client=http_client or cached_async_http_client(),
            )
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
import httpx
|
|
6
|
+
|
|
7
|
+
from pydantic_ai.models import cached_async_http_client
|
|
8
|
+
from pydantic_ai.providers import Provider
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class GoogleGLAProvider(Provider[httpx.AsyncClient]):
    """Provider for Google Generative Language AI API."""

    @property
    def name(self) -> str:
        return 'google-gla'

    @property
    def base_url(self) -> str:
        return 'https://generativelanguage.googleapis.com/v1beta/models/'

    @property
    def client(self) -> httpx.AsyncClient:
        return self._client

    def __init__(self, api_key: str | None = None, http_client: httpx.AsyncClient | None = None) -> None:
        """Create a new Google GLA provider.

        Args:
            api_key: The API key to use for authentication, if not provided, the `GEMINI_API_KEY` environment variable
                will be used if available.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests. Its `base_url` and
                `X-Goog-Api-Key` header are overwritten in place. If omitted, a dedicated client is created.

        Raises:
            ValueError: If no non-empty API key is given and `GEMINI_API_KEY` is unset or empty.
        """
        api_key = api_key or os.environ.get('GEMINI_API_KEY')
        if not api_key:
            raise ValueError(
                'Set the `GEMINI_API_KEY` environment variable or pass it via `GoogleGLAProvider(api_key=...)` '
                'to use the Google GLA provider.'
            )

        if http_client is None:
            # `cached_async_http_client()` hands out a cached client that other models/providers
            # (e.g. the OpenAI and DeepSeek providers) also use. Setting `base_url` and the API-key
            # header on that shared instance would send the Gemini key to every host it talks to,
            # so build a private client that copies its timeout and default headers instead.
            shared_client = cached_async_http_client()
            http_client = httpx.AsyncClient(timeout=shared_client.timeout, headers=shared_client.headers)

        self._client = http_client
        self._client.base_url = self.base_url
        # https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest
        self._client.headers['X-Goog-Api-Key'] = api_key
|
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
from collections.abc import AsyncGenerator
|
|
5
|
+
from datetime import datetime, timedelta
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Literal
|
|
8
|
+
|
|
9
|
+
import anyio.to_thread
|
|
10
|
+
import httpx
|
|
11
|
+
|
|
12
|
+
from pydantic_ai.exceptions import UserError
|
|
13
|
+
|
|
14
|
+
from ..models import cached_async_http_client
|
|
15
|
+
from . import Provider
|
|
16
|
+
|
|
17
|
+
try:
|
|
18
|
+
import google.auth
|
|
19
|
+
from google.auth.credentials import Credentials as BaseCredentials
|
|
20
|
+
from google.auth.transport.requests import Request
|
|
21
|
+
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
|
|
22
|
+
except ImportError as _import_error:
|
|
23
|
+
raise ImportError(
|
|
24
|
+
'Please install `google-auth` to use the Google Vertex AI provider, '
|
|
25
|
+
"you can use the `vertexai` optional group — `pip install 'pydantic-ai-slim[vertexai]'`"
|
|
26
|
+
) from _import_error
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
__all__ = ('GoogleVertexProvider',)

# default expiry is 3600 seconds
# `_VertexAIAuth` refreshes the access token once it is older than this; using 3000s rather
# than the full 3600s presumably leaves a safety margin before Google's expiry.
MAX_TOKEN_AGE = timedelta(seconds=3000)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class GoogleVertexProvider(Provider[httpx.AsyncClient]):
    """Provider for Vertex AI API."""

    @property
    def name(self) -> str:
        return 'google-vertex'

    @property
    def base_url(self) -> str:
        # `project_id` may still be None here; `_VertexAIAuth` later substitutes the project id
        # discovered from the credentials into the `projects/None` segment of each request URL.
        return (
            f'https://{self.region}-aiplatform.googleapis.com/v1'
            f'/projects/{self.project_id}'
            f'/locations/{self.region}'
            f'/publishers/{self.model_publisher}/models/'
        )

    @property
    def client(self) -> httpx.AsyncClient:
        return self._client

    def __init__(
        self,
        service_account_file: Path | str | None = None,
        project_id: str | None = None,
        region: VertexAiRegion = 'us-central1',
        model_publisher: str = 'google',
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Create a new Vertex AI provider.

        Args:
            service_account_file: Path to a service account file.
                If not provided, the default environment credentials will be used.
            project_id: The project ID to use, if not provided it will be taken from the credentials.
            region: The region to make requests to.
            model_publisher: The model publisher to use, I couldn't find a good list of available publishers,
                and from trial and error it seems non-google models don't work with the `generateContent` and
                `streamGenerateContent` functions, hence only `google` is currently supported.
                Please create an issue or PR if you know how to use other publishers.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests. Its `auth` and
                `base_url` are overwritten in place. If omitted, a dedicated client is created.
        """
        if http_client is None:
            # `cached_async_http_client()` hands out a cached client that other models/providers
            # (e.g. the OpenAI and DeepSeek providers) also use. Installing the Google auth flow on
            # that shared instance would put a Google bearer token on every request it sends
            # (overwriting their own `Authorization` header), so build a private client that copies
            # its timeout and default headers instead.
            shared_client = cached_async_http_client()
            http_client = httpx.AsyncClient(timeout=shared_client.timeout, headers=shared_client.headers)

        self._client = http_client
        self.service_account_file = service_account_file
        self.project_id = project_id
        self.region = region
        self.model_publisher = model_publisher

        self._client.auth = _VertexAIAuth(service_account_file, project_id, region)
        self._client.base_url = self.base_url
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class _VertexAIAuth(httpx.Auth):
    """Auth class for Vertex AI API.

    On the first request it loads Google credentials, either from a service account file or
    from `google.auth.default()`. It refreshes the access token when the token is missing or
    older than `MAX_TOKEN_AGE`, and sets an `Authorization: Bearer ...` header on every request.
    """

    # Loaded lazily by `_get_credentials` on the first request; `None` until then.
    credentials: BaseCredentials | ServiceAccountCredentials | None

    def __init__(
        self,
        service_account_file: Path | str | None = None,
        project_id: str | None = None,
        region: VertexAiRegion = 'us-central1',
    ) -> None:
        self.service_account_file = service_account_file
        # May be None; if so it is filled in from the credentials by `_get_credentials`.
        self.project_id = project_id
        self.region = region

        self.credentials = None
        # Time of the last token refresh (naive local time); None until the first refresh.
        self.token_created: datetime | None = None

    async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
        # Credential loading must happen before the URL patch below, since it may set `self.project_id`.
        if self.credentials is None:
            self.credentials = await self._get_credentials()
        if self.credentials.token is None or self._token_expired():  # type: ignore[reportUnknownMemberType]
            # `_refresh_token` uses the synchronous `google.auth.transport.requests` transport,
            # so it is run in a worker thread to avoid blocking the event loop.
            await anyio.to_thread.run_sync(self._refresh_token)
            self.token_created = datetime.now()
        request.headers['Authorization'] = f'Bearer {self.credentials.token}'  # type: ignore[reportUnknownMemberType]

        # NOTE: This workaround is in place because we might get the project_id from the credentials.
        request.url = httpx.URL(str(request.url).replace('projects/None', f'projects/{self.project_id}'))
        yield request

    async def _get_credentials(self) -> BaseCredentials | ServiceAccountCredentials:
        """Load credentials and, if no project id was given, take it from the credentials.

        Raises:
            UserError: If no project id was provided and the credentials do not carry one.
        """
        if self.service_account_file is not None:
            creds = await _creds_from_file(self.service_account_file)
            assert creds.project_id is None or isinstance(creds.project_id, str)  # type: ignore[reportUnknownMemberType]
            creds_project_id: str | None = creds.project_id
            creds_source = 'service account file'
        else:
            creds, creds_project_id = await _async_google_auth()
            creds_source = '`google.auth.default()`'

        if self.project_id is None:
            if creds_project_id is None:
                raise UserError(f'No project_id provided and none found in {creds_source}')
            self.project_id = creds_project_id
        return creds

    def _token_expired(self) -> bool:
        """Return True if no token has been refreshed yet or the last refresh is older than `MAX_TOKEN_AGE`."""
        if self.token_created is None:
            return True
        else:
            return (datetime.now() - self.token_created) > MAX_TOKEN_AGE

    def _refresh_token(self) -> str:  # pragma: no cover
        """Refresh the access token (blocking; run from a worker thread) and return it."""
        assert self.credentials is not None
        self.credentials.refresh(Request())  # type: ignore[reportUnknownMemberType]
        assert isinstance(self.credentials.token, str), f'Expected token to be a string, got {self.credentials.token}'  # type: ignore[reportUnknownMemberType]
        return self.credentials.token
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
async def _async_google_auth() -> tuple[BaseCredentials, str | None]:
    """Resolve the default Google credentials and project id without blocking the event loop.

    `google.auth.default` is invoked in a worker thread with the cloud-platform scope.
    """
    cloud_platform_scopes = ['https://www.googleapis.com/auth/cloud-platform']
    return await anyio.to_thread.run_sync(google.auth.default, cloud_platform_scopes)  # type: ignore
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
async def _creds_from_file(service_account_file: str | Path) -> ServiceAccountCredentials:
    """Load service-account credentials from a JSON key file in a worker thread.

    The credentials are scoped to `https://www.googleapis.com/auth/cloud-platform`.
    """
    path_str = str(service_account_file)

    def _load() -> ServiceAccountCredentials:
        return ServiceAccountCredentials.from_service_account_file(  # type: ignore[reportUnknownMemberType]
            path_str,
            scopes=['https://www.googleapis.com/auth/cloud-platform'],
        )

    return await anyio.to_thread.run_sync(_load)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
# Used as the type of the `region` argument of `GoogleVertexProvider` and `_VertexAIAuth`.
# Only type checkers enforce it: no runtime validation of `region` happens in this module.
# NOTE(review): hard-coded list; it may lag behind the regions Google actually offers.
VertexAiRegion = Literal[
    'us-central1',
    'us-east1',
    'us-east4',
    'us-south1',
    'us-west1',
    'us-west2',
    'us-west3',
    'us-west4',
    'us-east5',
    'europe-central2',
    'europe-north1',
    'europe-southwest1',
    'europe-west1',
    'europe-west2',
    'europe-west3',
    'europe-west4',
    'europe-west6',
    'europe-west8',
    'europe-west9',
    'europe-west12',
    'africa-south1',
    'asia-east1',
    'asia-east2',
    'asia-northeast1',
    'asia-northeast2',
    'asia-northeast3',
    'asia-south1',
    'asia-southeast1',
    'asia-southeast2',
    'australia-southeast1',
    'australia-southeast2',
    'me-central1',
    'me-central2',
    'me-west1',
    'northamerica-northeast1',
    'northamerica-northeast2',
    'southamerica-east1',
    'southamerica-west1',
]
"""Regions available for Vertex AI.

More details [here](https://cloud.google.com/vertex-ai/docs/reference/rest#rest_endpoints).
"""
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import TypeVar
|
|
5
|
+
|
|
6
|
+
import httpx
|
|
7
|
+
|
|
8
|
+
from pydantic_ai.models import cached_async_http_client
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
from openai import AsyncOpenAI
|
|
12
|
+
except ImportError as _import_error: # pragma: no cover
|
|
13
|
+
raise ImportError(
|
|
14
|
+
'Please install `openai` to use the OpenAI provider, '
|
|
15
|
+
"you can use the `openai` optional group — `pip install 'pydantic-ai-slim[openai]'`"
|
|
16
|
+
) from _import_error
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
from . import Provider
|
|
20
|
+
|
|
21
|
+
# NOTE(review): this TypeVar is not referenced in this module (`OpenAIProvider` subclasses
# `Provider[AsyncOpenAI]` directly); it looks like a leftover and could likely be removed.
InterfaceClient = TypeVar('InterfaceClient')
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class OpenAIProvider(Provider[AsyncOpenAI]):
    """Provider for OpenAI API."""

    @property
    def name(self) -> str:
        return 'openai'  # pragma: no cover

    @property
    def base_url(self) -> str:
        return self._base_url

    @property
    def client(self) -> AsyncOpenAI:
        return self._client

    def __init__(
        self,
        base_url: str | None = None,
        api_key: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Create a new OpenAI provider.

        Args:
            base_url: The base url for the OpenAI requests. If not provided, the `OPENAI_BASE_URL` environment variable
                will be used if available. Otherwise, defaults to OpenAI's base url.
            api_key: The API key to use for authentication, if not provided, the `OPENAI_API_KEY` environment variable
                will be used if available. When a custom base url is in use (argument or `OPENAI_BASE_URL`) and no
                key is available at all, a placeholder key is sent, since local OpenAI-compatible servers often
                need none.
            openai_client: An existing
                [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
                client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
        """
        if openai_client is not None:
            assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
            assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
            assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
            self._client = openai_client
            # Report the URL the supplied client actually targets rather than OpenAI's default.
            self._base_url = str(openai_client.base_url)
            return

        # Resolve `OPENAI_BASE_URL` here: we pass an explicit `base_url` to `AsyncOpenAI` below,
        # which would otherwise bypass the environment variable the docstring promises to honour.
        custom_base_url = base_url or os.getenv('OPENAI_BASE_URL')
        self._base_url = custom_base_url or 'https://api.openai.com/v1'

        # This is a workaround for the OpenAI client requiring an API key, whilst locally served,
        # openai compatible models do not always need an API key, but a placeholder (non-empty) key is required.
        # Only applied to custom endpoints (matching `OpenAIModel`'s legacy path), so a missing key for the
        # official API still fails fast instead of being masked by a placeholder.
        if api_key is None and 'OPENAI_API_KEY' not in os.environ and custom_base_url:
            api_key = 'api-key-not-set'

        self._client = AsyncOpenAI(
            base_url=self._base_url,
            api_key=api_key,
            http_client=http_client or cached_async_http_client(),
        )
|