pydantic-ai-slim 0.0.32__py3-none-any.whl → 0.0.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- pydantic_ai/_cli.py +225 -0
- pydantic_ai/_pydantic.py +4 -4
- pydantic_ai/_result.py +7 -18
- pydantic_ai/agent.py +29 -9
- pydantic_ai/messages.py +11 -2
- pydantic_ai/models/__init__.py +36 -36
- pydantic_ai/models/gemini.py +51 -14
- pydantic_ai/models/instrumented.py +43 -9
- pydantic_ai/models/openai.py +56 -15
- pydantic_ai/models/vertexai.py +9 -1
- pydantic_ai/providers/__init__.py +64 -0
- pydantic_ai/providers/deepseek.py +68 -0
- pydantic_ai/providers/google_gla.py +44 -0
- pydantic_ai/providers/google_vertex.py +200 -0
- pydantic_ai/providers/openai.py +72 -0
- {pydantic_ai_slim-0.0.32.dist-info → pydantic_ai_slim-0.0.34.dist-info}/METADATA +7 -2
- {pydantic_ai_slim-0.0.32.dist-info → pydantic_ai_slim-0.0.34.dist-info}/RECORD +19 -12
- pydantic_ai_slim-0.0.34.dist-info/entry_points.txt +2 -0
- {pydantic_ai_slim-0.0.32.dist-info → pydantic_ai_slim-0.0.34.dist-info}/WHEEL +0 -0
|
@@ -43,9 +43,16 @@ MODEL_SETTING_ATTRIBUTES: tuple[
|
|
|
43
43
|
ANY_ADAPTER = TypeAdapter[Any](Any)
|
|
44
44
|
|
|
45
45
|
|
|
46
|
-
@dataclass
|
|
47
|
-
class
|
|
48
|
-
"""
|
|
46
|
+
@dataclass(init=False)
|
|
47
|
+
class InstrumentationSettings:
|
|
48
|
+
"""Options for instrumenting models and agents with OpenTelemetry.
|
|
49
|
+
|
|
50
|
+
Used in:
|
|
51
|
+
|
|
52
|
+
- `Agent(instrument=...)`
|
|
53
|
+
- [`Agent.instrument_all()`][pydantic_ai.agent.Agent.instrument_all]
|
|
54
|
+
- `InstrumentedModel`
|
|
55
|
+
"""
|
|
49
56
|
|
|
50
57
|
tracer: Tracer = field(repr=False)
|
|
51
58
|
event_logger: EventLogger = field(repr=False)
|
|
@@ -53,20 +60,47 @@ class InstrumentedModel(WrapperModel):
|
|
|
53
60
|
|
|
54
61
|
def __init__(
|
|
55
62
|
self,
|
|
56
|
-
|
|
63
|
+
*,
|
|
64
|
+
event_mode: Literal['attributes', 'logs'] = 'attributes',
|
|
57
65
|
tracer_provider: TracerProvider | None = None,
|
|
58
66
|
event_logger_provider: EventLoggerProvider | None = None,
|
|
59
|
-
event_mode: Literal['attributes', 'logs'] = 'attributes',
|
|
60
67
|
):
|
|
68
|
+
"""Create instrumentation options.
|
|
69
|
+
|
|
70
|
+
Args:
|
|
71
|
+
event_mode: The mode for emitting events. If `'attributes'`, events are attached to the span as attributes.
|
|
72
|
+
If `'logs'`, events are emitted as OpenTelemetry log-based events.
|
|
73
|
+
tracer_provider: The OpenTelemetry tracer provider to use.
|
|
74
|
+
If not provided, the global tracer provider is used.
|
|
75
|
+
Calling `logfire.configure()` sets the global tracer provider, so most users don't need this.
|
|
76
|
+
event_logger_provider: The OpenTelemetry event logger provider to use.
|
|
77
|
+
If not provided, the global event logger provider is used.
|
|
78
|
+
Calling `logfire.configure()` sets the global event logger provider, so most users don't need this.
|
|
79
|
+
This is only used if `event_mode='logs'`.
|
|
80
|
+
"""
|
|
61
81
|
from pydantic_ai import __version__
|
|
62
82
|
|
|
63
|
-
super().__init__(wrapped)
|
|
64
83
|
tracer_provider = tracer_provider or get_tracer_provider()
|
|
65
84
|
event_logger_provider = event_logger_provider or get_event_logger_provider()
|
|
66
85
|
self.tracer = tracer_provider.get_tracer('pydantic-ai', __version__)
|
|
67
86
|
self.event_logger = event_logger_provider.get_event_logger('pydantic-ai', __version__)
|
|
68
87
|
self.event_mode = event_mode
|
|
69
88
|
|
|
89
|
+
|
|
90
|
+
@dataclass
|
|
91
|
+
class InstrumentedModel(WrapperModel):
|
|
92
|
+
"""Model which is instrumented with OpenTelemetry."""
|
|
93
|
+
|
|
94
|
+
options: InstrumentationSettings
|
|
95
|
+
|
|
96
|
+
def __init__(
|
|
97
|
+
self,
|
|
98
|
+
wrapped: Model | KnownModelName,
|
|
99
|
+
options: InstrumentationSettings | None = None,
|
|
100
|
+
) -> None:
|
|
101
|
+
super().__init__(wrapped)
|
|
102
|
+
self.options = options or InstrumentationSettings()
|
|
103
|
+
|
|
70
104
|
async def request(
|
|
71
105
|
self,
|
|
72
106
|
messages: list[ModelMessage],
|
|
@@ -123,7 +157,7 @@ class InstrumentedModel(WrapperModel):
|
|
|
123
157
|
if isinstance(value := model_settings.get(key), (float, int)):
|
|
124
158
|
attributes[f'gen_ai.request.{key}'] = value
|
|
125
159
|
|
|
126
|
-
with self.tracer.start_as_current_span(span_name, attributes=attributes) as span:
|
|
160
|
+
with self.options.tracer.start_as_current_span(span_name, attributes=attributes) as span:
|
|
127
161
|
|
|
128
162
|
def finish(response: ModelResponse, usage: Usage):
|
|
129
163
|
if not span.is_recording():
|
|
@@ -156,9 +190,9 @@ class InstrumentedModel(WrapperModel):
|
|
|
156
190
|
def _emit_events(self, system: str, span: Span, events: list[Event]) -> None:
|
|
157
191
|
for event in events:
|
|
158
192
|
event.attributes = {'gen_ai.system': system, **(event.attributes or {})}
|
|
159
|
-
if self.event_mode == 'logs':
|
|
193
|
+
if self.options.event_mode == 'logs':
|
|
160
194
|
for event in events:
|
|
161
|
-
self.event_logger.emit(event)
|
|
195
|
+
self.options.event_logger.emit(event)
|
|
162
196
|
else:
|
|
163
197
|
attr_name = 'events'
|
|
164
198
|
span.set_attributes(
|
pydantic_ai/models/openai.py
CHANGED
|
@@ -9,7 +9,9 @@ from datetime import datetime, timezone
|
|
|
9
9
|
from typing import Literal, Union, cast, overload
|
|
10
10
|
|
|
11
11
|
from httpx import AsyncClient as AsyncHTTPClient
|
|
12
|
-
from typing_extensions import assert_never
|
|
12
|
+
from typing_extensions import assert_never, deprecated
|
|
13
|
+
|
|
14
|
+
from pydantic_ai.providers import Provider, infer_provider
|
|
13
15
|
|
|
14
16
|
from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
|
|
15
17
|
from .._utils import guard_tool_call_id as _guard_tool_call_id
|
|
@@ -98,10 +100,36 @@ class OpenAIModel(Model):
|
|
|
98
100
|
_model_name: OpenAIModelName = field(repr=False)
|
|
99
101
|
_system: str | None = field(repr=False)
|
|
100
102
|
|
|
103
|
+
@overload
|
|
101
104
|
def __init__(
|
|
102
105
|
self,
|
|
103
106
|
model_name: OpenAIModelName,
|
|
104
107
|
*,
|
|
108
|
+
provider: Literal['openai', 'deepseek'] | Provider[AsyncOpenAI] = 'openai',
|
|
109
|
+
system_prompt_role: OpenAISystemPromptRole | None = None,
|
|
110
|
+
system: str | None = 'openai',
|
|
111
|
+
) -> None: ...
|
|
112
|
+
|
|
113
|
+
@deprecated('Use the `provider` parameter instead of `base_url`, `api_key`, `openai_client` and `http_client`.')
|
|
114
|
+
@overload
|
|
115
|
+
def __init__(
|
|
116
|
+
self,
|
|
117
|
+
model_name: OpenAIModelName,
|
|
118
|
+
*,
|
|
119
|
+
provider: None = None,
|
|
120
|
+
base_url: str | None = None,
|
|
121
|
+
api_key: str | None = None,
|
|
122
|
+
openai_client: AsyncOpenAI | None = None,
|
|
123
|
+
http_client: AsyncHTTPClient | None = None,
|
|
124
|
+
system_prompt_role: OpenAISystemPromptRole | None = None,
|
|
125
|
+
system: str | None = 'openai',
|
|
126
|
+
) -> None: ...
|
|
127
|
+
|
|
128
|
+
def __init__(
|
|
129
|
+
self,
|
|
130
|
+
model_name: OpenAIModelName,
|
|
131
|
+
*,
|
|
132
|
+
provider: Literal['openai', 'deepseek'] | Provider[AsyncOpenAI] | None = None,
|
|
105
133
|
base_url: str | None = None,
|
|
106
134
|
api_key: str | None = None,
|
|
107
135
|
openai_client: AsyncOpenAI | None = None,
|
|
@@ -115,6 +143,7 @@ class OpenAIModel(Model):
|
|
|
115
143
|
model_name: The name of the OpenAI model to use. List of model names available
|
|
116
144
|
[here](https://github.com/openai/openai-python/blob/v1.54.3/src/openai/types/chat_model.py#L7)
|
|
117
145
|
(Unfortunately, despite being asked to do so, OpenAI do not provide `.inv` files for their API).
|
|
146
|
+
provider: The provider to use. Defaults to `'openai'`.
|
|
118
147
|
base_url: The base url for the OpenAI requests. If not provided, the `OPENAI_BASE_URL` environment variable
|
|
119
148
|
will be used if available. Otherwise, defaults to OpenAI's base url.
|
|
120
149
|
api_key: The API key to use for authentication, if not provided, the `OPENAI_API_KEY` environment variable
|
|
@@ -129,20 +158,32 @@ class OpenAIModel(Model):
|
|
|
129
158
|
customize the `base_url` and `api_key` to use a different provider.
|
|
130
159
|
"""
|
|
131
160
|
self._model_name = model_name
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
161
|
+
|
|
162
|
+
if provider is not None:
|
|
163
|
+
if isinstance(provider, str):
|
|
164
|
+
self.client = infer_provider(provider).client
|
|
165
|
+
else:
|
|
166
|
+
self.client = provider.client
|
|
167
|
+
else: # pragma: no cover
|
|
168
|
+
# This is a workaround for the OpenAI client requiring an API key, whilst locally served,
|
|
169
|
+
# openai compatible models do not always need an API key, but a placeholder (non-empty) key is required.
|
|
170
|
+
if (
|
|
171
|
+
api_key is None
|
|
172
|
+
and 'OPENAI_API_KEY' not in os.environ
|
|
173
|
+
and base_url is not None
|
|
174
|
+
and openai_client is None
|
|
175
|
+
):
|
|
176
|
+
api_key = 'api-key-not-set'
|
|
177
|
+
|
|
178
|
+
if openai_client is not None:
|
|
179
|
+
assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
|
|
180
|
+
assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
|
|
181
|
+
assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
|
|
182
|
+
self.client = openai_client
|
|
183
|
+
elif http_client is not None:
|
|
184
|
+
self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
|
|
185
|
+
else:
|
|
186
|
+
self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
|
|
146
187
|
self.system_prompt_role = system_prompt_role
|
|
147
188
|
self._system = system
|
|
148
189
|
|
pydantic_ai/models/vertexai.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
|
+
import warnings
|
|
3
4
|
from collections.abc import AsyncIterator
|
|
4
5
|
from contextlib import asynccontextmanager
|
|
5
6
|
from dataclasses import dataclass, field
|
|
@@ -8,6 +9,7 @@ from pathlib import Path
|
|
|
8
9
|
from typing import Literal
|
|
9
10
|
|
|
10
11
|
from httpx import AsyncClient as AsyncHTTPClient
|
|
12
|
+
from typing_extensions import deprecated
|
|
11
13
|
|
|
12
14
|
from .. import usage
|
|
13
15
|
from .._utils import run_in_executor
|
|
@@ -55,6 +57,7 @@ The template is used thus:
|
|
|
55
57
|
"""
|
|
56
58
|
|
|
57
59
|
|
|
60
|
+
@deprecated('Please use `GeminiModel(provider=GoogleVertexProvider(...))` instead.')
|
|
58
61
|
@dataclass(init=False)
|
|
59
62
|
class VertexAIModel(GeminiModel):
|
|
60
63
|
"""A model that uses Gemini via the `*-aiplatform.googleapis.com` VertexAI API."""
|
|
@@ -103,11 +106,16 @@ class VertexAIModel(GeminiModel):
|
|
|
103
106
|
self.project_id = project_id
|
|
104
107
|
self.region = region
|
|
105
108
|
self.model_publisher = model_publisher
|
|
106
|
-
self.
|
|
109
|
+
self.client = http_client or cached_async_http_client()
|
|
107
110
|
self.url_template = url_template
|
|
108
111
|
|
|
109
112
|
self._auth = None
|
|
110
113
|
self._url = None
|
|
114
|
+
warnings.warn(
|
|
115
|
+
'VertexAIModel is deprecated, please use `GeminiModel(provider=GoogleVertexProvider(...))` instead.',
|
|
116
|
+
DeprecationWarning,
|
|
117
|
+
)
|
|
118
|
+
self._provider = None
|
|
111
119
|
|
|
112
120
|
async def ainit(self) -> None:
|
|
113
121
|
"""Initialize the model, setting the URL and auth.
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
"""Providers for the API clients.
|
|
2
|
+
|
|
3
|
+
The providers are in charge of providing an authenticated client to the API.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations as _annotations
|
|
7
|
+
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
from typing import Any, Generic, TypeVar
|
|
10
|
+
|
|
11
|
+
InterfaceClient = TypeVar('InterfaceClient')
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Provider(ABC, Generic[InterfaceClient]):
|
|
15
|
+
"""Abstract class for a provider.
|
|
16
|
+
|
|
17
|
+
The provider is in charge of providing an authenticated client to the API.
|
|
18
|
+
|
|
19
|
+
Each provider only supports a specific interface. An interface can be supported by multiple providers.
|
|
20
|
+
|
|
21
|
+
For example, the OpenAIModel interface can be supported by the OpenAIProvider and the DeepSeekProvider.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
_client: InterfaceClient
|
|
25
|
+
|
|
26
|
+
@property
|
|
27
|
+
@abstractmethod
|
|
28
|
+
def name(self) -> str:
|
|
29
|
+
"""The provider name."""
|
|
30
|
+
raise NotImplementedError()
|
|
31
|
+
|
|
32
|
+
@property
|
|
33
|
+
@abstractmethod
|
|
34
|
+
def base_url(self) -> str:
|
|
35
|
+
"""The base URL for the provider API."""
|
|
36
|
+
raise NotImplementedError()
|
|
37
|
+
|
|
38
|
+
@property
|
|
39
|
+
@abstractmethod
|
|
40
|
+
def client(self) -> InterfaceClient:
|
|
41
|
+
"""The client for the provider."""
|
|
42
|
+
raise NotImplementedError()
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def infer_provider(provider: str) -> Provider[Any]:
|
|
46
|
+
"""Infer the provider from the provider name."""
|
|
47
|
+
if provider == 'openai':
|
|
48
|
+
from .openai import OpenAIProvider
|
|
49
|
+
|
|
50
|
+
return OpenAIProvider()
|
|
51
|
+
elif provider == 'deepseek':
|
|
52
|
+
from .deepseek import DeepSeekProvider
|
|
53
|
+
|
|
54
|
+
return DeepSeekProvider()
|
|
55
|
+
elif provider == 'google-vertex':
|
|
56
|
+
from .google_vertex import GoogleVertexProvider
|
|
57
|
+
|
|
58
|
+
return GoogleVertexProvider()
|
|
59
|
+
elif provider == 'google-gla':
|
|
60
|
+
from .google_gla import GoogleGLAProvider
|
|
61
|
+
|
|
62
|
+
return GoogleGLAProvider()
|
|
63
|
+
else: # pragma: no cover
|
|
64
|
+
raise ValueError(f'Unknown provider: {provider}')
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import overload
|
|
5
|
+
|
|
6
|
+
from httpx import AsyncClient as AsyncHTTPClient
|
|
7
|
+
from openai import AsyncOpenAI
|
|
8
|
+
|
|
9
|
+
from pydantic_ai.models import cached_async_http_client
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
from openai import AsyncOpenAI
|
|
13
|
+
except ImportError as _import_error: # pragma: no cover
|
|
14
|
+
raise ImportError(
|
|
15
|
+
'Please install `openai` to use the DeepSeek provider, '
|
|
16
|
+
"you can use the `openai` optional group — `pip install 'pydantic-ai-slim[openai]'`"
|
|
17
|
+
) from _import_error
|
|
18
|
+
|
|
19
|
+
from . import Provider
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class DeepSeekProvider(Provider[AsyncOpenAI]):
|
|
23
|
+
"""Provider for DeepSeek API."""
|
|
24
|
+
|
|
25
|
+
@property
|
|
26
|
+
def name(self) -> str:
|
|
27
|
+
return 'deepseek'
|
|
28
|
+
|
|
29
|
+
@property
|
|
30
|
+
def base_url(self) -> str:
|
|
31
|
+
return 'https://api.deepseek.com'
|
|
32
|
+
|
|
33
|
+
@property
|
|
34
|
+
def client(self) -> AsyncOpenAI:
|
|
35
|
+
return self._client
|
|
36
|
+
|
|
37
|
+
@overload
|
|
38
|
+
def __init__(self) -> None: ...
|
|
39
|
+
|
|
40
|
+
@overload
|
|
41
|
+
def __init__(self, *, api_key: str) -> None: ...
|
|
42
|
+
|
|
43
|
+
@overload
|
|
44
|
+
def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...
|
|
45
|
+
|
|
46
|
+
@overload
|
|
47
|
+
def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
|
|
48
|
+
|
|
49
|
+
def __init__(
|
|
50
|
+
self,
|
|
51
|
+
*,
|
|
52
|
+
api_key: str | None = None,
|
|
53
|
+
openai_client: AsyncOpenAI | None = None,
|
|
54
|
+
http_client: AsyncHTTPClient | None = None,
|
|
55
|
+
) -> None:
|
|
56
|
+
api_key = api_key or os.getenv('DEEPSEEK_API_KEY')
|
|
57
|
+
if api_key is None and openai_client is None:
|
|
58
|
+
raise ValueError(
|
|
59
|
+
'Set the `DEEPSEEK_API_KEY` environment variable or pass it via `DeepSeekProvider(api_key=...)`'
|
|
60
|
+
'to use the DeepSeek provider.'
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
if openai_client is not None:
|
|
64
|
+
self._client = openai_client
|
|
65
|
+
elif http_client is not None:
|
|
66
|
+
self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
|
|
67
|
+
else:
|
|
68
|
+
self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=cached_async_http_client())
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
import httpx
|
|
6
|
+
|
|
7
|
+
from pydantic_ai.models import cached_async_http_client
|
|
8
|
+
from pydantic_ai.providers import Provider
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class GoogleGLAProvider(Provider[httpx.AsyncClient]):
|
|
12
|
+
"""Provider for Google Generative Language AI API."""
|
|
13
|
+
|
|
14
|
+
@property
|
|
15
|
+
def name(self):
|
|
16
|
+
return 'google-gla'
|
|
17
|
+
|
|
18
|
+
@property
|
|
19
|
+
def base_url(self) -> str:
|
|
20
|
+
return 'https://generativelanguage.googleapis.com/v1beta/models/'
|
|
21
|
+
|
|
22
|
+
@property
|
|
23
|
+
def client(self) -> httpx.AsyncClient:
|
|
24
|
+
return self._client
|
|
25
|
+
|
|
26
|
+
def __init__(self, api_key: str | None = None, http_client: httpx.AsyncClient | None = None) -> None:
|
|
27
|
+
"""Create a new Google GLA provider.
|
|
28
|
+
|
|
29
|
+
Args:
|
|
30
|
+
api_key: The API key to use for authentication, if not provided, the `GEMINI_API_KEY` environment variable
|
|
31
|
+
will be used if available.
|
|
32
|
+
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
|
|
33
|
+
"""
|
|
34
|
+
api_key = api_key or os.environ.get('GEMINI_API_KEY')
|
|
35
|
+
if api_key is None:
|
|
36
|
+
raise ValueError(
|
|
37
|
+
'Set the `GEMINI_API_KEY` environment variable or pass it via `GoogleGLAProvider(api_key=...)`'
|
|
38
|
+
'to use the Google GLA provider.'
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
self._client = http_client or cached_async_http_client()
|
|
42
|
+
self._client.base_url = self.base_url
|
|
43
|
+
# https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest
|
|
44
|
+
self._client.headers['X-Goog-Api-Key'] = api_key
|
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
from collections.abc import AsyncGenerator
|
|
5
|
+
from datetime import datetime, timedelta
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Literal
|
|
8
|
+
|
|
9
|
+
import anyio.to_thread
|
|
10
|
+
import httpx
|
|
11
|
+
|
|
12
|
+
from pydantic_ai.exceptions import UserError
|
|
13
|
+
|
|
14
|
+
from ..models import cached_async_http_client
|
|
15
|
+
from . import Provider
|
|
16
|
+
|
|
17
|
+
try:
|
|
18
|
+
import google.auth
|
|
19
|
+
from google.auth.credentials import Credentials as BaseCredentials
|
|
20
|
+
from google.auth.transport.requests import Request
|
|
21
|
+
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
|
|
22
|
+
except ImportError as _import_error:
|
|
23
|
+
raise ImportError(
|
|
24
|
+
'Please install `google-auth` to use the Google Vertex AI provider, '
|
|
25
|
+
"you can use the `vertexai` optional group — `pip install 'pydantic-ai-slim[vertexai]'`"
|
|
26
|
+
) from _import_error
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
__all__ = ('GoogleVertexProvider',)
|
|
30
|
+
|
|
31
|
+
# default expiry is 3600 seconds
|
|
32
|
+
MAX_TOKEN_AGE = timedelta(seconds=3000)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class GoogleVertexProvider(Provider[httpx.AsyncClient]):
|
|
36
|
+
"""Provider for Vertex AI API."""
|
|
37
|
+
|
|
38
|
+
@property
|
|
39
|
+
def name(self) -> str:
|
|
40
|
+
return 'google-vertex'
|
|
41
|
+
|
|
42
|
+
@property
|
|
43
|
+
def base_url(self) -> str:
|
|
44
|
+
return (
|
|
45
|
+
f'https://{self.region}-aiplatform.googleapis.com/v1'
|
|
46
|
+
f'/projects/{self.project_id}'
|
|
47
|
+
f'/locations/{self.region}'
|
|
48
|
+
f'/publishers/{self.model_publisher}/models/'
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
@property
|
|
52
|
+
def client(self) -> httpx.AsyncClient:
|
|
53
|
+
return self._client
|
|
54
|
+
|
|
55
|
+
def __init__(
|
|
56
|
+
self,
|
|
57
|
+
service_account_file: Path | str | None = None,
|
|
58
|
+
project_id: str | None = None,
|
|
59
|
+
region: VertexAiRegion = 'us-central1',
|
|
60
|
+
model_publisher: str = 'google',
|
|
61
|
+
http_client: httpx.AsyncClient | None = None,
|
|
62
|
+
) -> None:
|
|
63
|
+
"""Create a new Vertex AI provider.
|
|
64
|
+
|
|
65
|
+
Args:
|
|
66
|
+
service_account_file: Path to a service account file.
|
|
67
|
+
If not provided, the default environment credentials will be used.
|
|
68
|
+
project_id: The project ID to use, if not provided it will be taken from the credentials.
|
|
69
|
+
region: The region to make requests to.
|
|
70
|
+
model_publisher: The model publisher to use, I couldn't find a good list of available publishers,
|
|
71
|
+
and from trial and error it seems non-google models don't work with the `generateContent` and
|
|
72
|
+
`streamGenerateContent` functions, hence only `google` is currently supported.
|
|
73
|
+
Please create an issue or PR if you know how to use other publishers.
|
|
74
|
+
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
|
|
75
|
+
"""
|
|
76
|
+
self._client = http_client or cached_async_http_client()
|
|
77
|
+
self.service_account_file = service_account_file
|
|
78
|
+
self.project_id = project_id
|
|
79
|
+
self.region = region
|
|
80
|
+
self.model_publisher = model_publisher
|
|
81
|
+
|
|
82
|
+
self._client.auth = _VertexAIAuth(service_account_file, project_id, region)
|
|
83
|
+
self._client.base_url = self.base_url
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class _VertexAIAuth(httpx.Auth):
|
|
87
|
+
"""Auth class for Vertex AI API."""
|
|
88
|
+
|
|
89
|
+
credentials: BaseCredentials | ServiceAccountCredentials | None
|
|
90
|
+
|
|
91
|
+
def __init__(
|
|
92
|
+
self,
|
|
93
|
+
service_account_file: Path | str | None = None,
|
|
94
|
+
project_id: str | None = None,
|
|
95
|
+
region: VertexAiRegion = 'us-central1',
|
|
96
|
+
) -> None:
|
|
97
|
+
self.service_account_file = service_account_file
|
|
98
|
+
self.project_id = project_id
|
|
99
|
+
self.region = region
|
|
100
|
+
|
|
101
|
+
self.credentials = None
|
|
102
|
+
self.token_created: datetime | None = None
|
|
103
|
+
|
|
104
|
+
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
|
|
105
|
+
if self.credentials is None:
|
|
106
|
+
self.credentials = await self._get_credentials()
|
|
107
|
+
if self.credentials.token is None or self._token_expired(): # type: ignore[reportUnknownMemberType]
|
|
108
|
+
await anyio.to_thread.run_sync(self._refresh_token)
|
|
109
|
+
self.token_created = datetime.now()
|
|
110
|
+
request.headers['Authorization'] = f'Bearer {self.credentials.token}' # type: ignore[reportUnknownMemberType]
|
|
111
|
+
|
|
112
|
+
# NOTE: This workaround is in place because we might get the project_id from the credentials.
|
|
113
|
+
request.url = httpx.URL(str(request.url).replace('projects/None', f'projects/{self.project_id}'))
|
|
114
|
+
yield request
|
|
115
|
+
|
|
116
|
+
async def _get_credentials(self) -> BaseCredentials | ServiceAccountCredentials:
|
|
117
|
+
if self.service_account_file is not None:
|
|
118
|
+
creds = await _creds_from_file(self.service_account_file)
|
|
119
|
+
assert creds.project_id is None or isinstance(creds.project_id, str) # type: ignore[reportUnknownMemberType]
|
|
120
|
+
creds_project_id: str | None = creds.project_id
|
|
121
|
+
creds_source = 'service account file'
|
|
122
|
+
else:
|
|
123
|
+
creds, creds_project_id = await _async_google_auth()
|
|
124
|
+
creds_source = '`google.auth.default()`'
|
|
125
|
+
|
|
126
|
+
if self.project_id is None:
|
|
127
|
+
if creds_project_id is None:
|
|
128
|
+
raise UserError(f'No project_id provided and none found in {creds_source}')
|
|
129
|
+
self.project_id = creds_project_id
|
|
130
|
+
return creds
|
|
131
|
+
|
|
132
|
+
def _token_expired(self) -> bool:
|
|
133
|
+
if self.token_created is None:
|
|
134
|
+
return True
|
|
135
|
+
else:
|
|
136
|
+
return (datetime.now() - self.token_created) > MAX_TOKEN_AGE
|
|
137
|
+
|
|
138
|
+
def _refresh_token(self) -> str: # pragma: no cover
|
|
139
|
+
assert self.credentials is not None
|
|
140
|
+
self.credentials.refresh(Request()) # type: ignore[reportUnknownMemberType]
|
|
141
|
+
assert isinstance(self.credentials.token, str), f'Expected token to be a string, got {self.credentials.token}' # type: ignore[reportUnknownMemberType]
|
|
142
|
+
return self.credentials.token
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
async def _async_google_auth() -> tuple[BaseCredentials, str | None]:
|
|
146
|
+
return await anyio.to_thread.run_sync(google.auth.default, ['https://www.googleapis.com/auth/cloud-platform']) # type: ignore
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
async def _creds_from_file(service_account_file: str | Path) -> ServiceAccountCredentials:
|
|
150
|
+
service_account_credentials_from_file = functools.partial(
|
|
151
|
+
ServiceAccountCredentials.from_service_account_file, # type: ignore[reportUnknownMemberType]
|
|
152
|
+
scopes=['https://www.googleapis.com/auth/cloud-platform'],
|
|
153
|
+
)
|
|
154
|
+
return await anyio.to_thread.run_sync(service_account_credentials_from_file, str(service_account_file))
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
VertexAiRegion = Literal[
|
|
158
|
+
'us-central1',
|
|
159
|
+
'us-east1',
|
|
160
|
+
'us-east4',
|
|
161
|
+
'us-south1',
|
|
162
|
+
'us-west1',
|
|
163
|
+
'us-west2',
|
|
164
|
+
'us-west3',
|
|
165
|
+
'us-west4',
|
|
166
|
+
'us-east5',
|
|
167
|
+
'europe-central2',
|
|
168
|
+
'europe-north1',
|
|
169
|
+
'europe-southwest1',
|
|
170
|
+
'europe-west1',
|
|
171
|
+
'europe-west2',
|
|
172
|
+
'europe-west3',
|
|
173
|
+
'europe-west4',
|
|
174
|
+
'europe-west6',
|
|
175
|
+
'europe-west8',
|
|
176
|
+
'europe-west9',
|
|
177
|
+
'europe-west12',
|
|
178
|
+
'africa-south1',
|
|
179
|
+
'asia-east1',
|
|
180
|
+
'asia-east2',
|
|
181
|
+
'asia-northeast1',
|
|
182
|
+
'asia-northeast2',
|
|
183
|
+
'asia-northeast3',
|
|
184
|
+
'asia-south1',
|
|
185
|
+
'asia-southeast1',
|
|
186
|
+
'asia-southeast2',
|
|
187
|
+
'australia-southeast1',
|
|
188
|
+
'australia-southeast2',
|
|
189
|
+
'me-central1',
|
|
190
|
+
'me-central2',
|
|
191
|
+
'me-west1',
|
|
192
|
+
'northamerica-northeast1',
|
|
193
|
+
'northamerica-northeast2',
|
|
194
|
+
'southamerica-east1',
|
|
195
|
+
'southamerica-west1',
|
|
196
|
+
]
|
|
197
|
+
"""Regions available for Vertex AI.
|
|
198
|
+
|
|
199
|
+
More details [here](https://cloud.google.com/vertex-ai/docs/reference/rest#rest_endpoints).
|
|
200
|
+
"""
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import TypeVar
|
|
5
|
+
|
|
6
|
+
import httpx
|
|
7
|
+
|
|
8
|
+
from pydantic_ai.models import cached_async_http_client
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
from openai import AsyncOpenAI
|
|
12
|
+
except ImportError as _import_error: # pragma: no cover
|
|
13
|
+
raise ImportError(
|
|
14
|
+
'Please install `openai` to use the OpenAI provider, '
|
|
15
|
+
"you can use the `openai` optional group — `pip install 'pydantic-ai-slim[openai]'`"
|
|
16
|
+
) from _import_error
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
from . import Provider
|
|
20
|
+
|
|
21
|
+
InterfaceClient = TypeVar('InterfaceClient')
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class OpenAIProvider(Provider[AsyncOpenAI]):
|
|
25
|
+
"""Provider for OpenAI API."""
|
|
26
|
+
|
|
27
|
+
@property
|
|
28
|
+
def name(self) -> str:
|
|
29
|
+
return 'openai' # pragma: no cover
|
|
30
|
+
|
|
31
|
+
@property
|
|
32
|
+
def base_url(self) -> str:
|
|
33
|
+
return self._base_url
|
|
34
|
+
|
|
35
|
+
@property
|
|
36
|
+
def client(self) -> AsyncOpenAI:
|
|
37
|
+
return self._client
|
|
38
|
+
|
|
39
|
+
def __init__(
|
|
40
|
+
self,
|
|
41
|
+
base_url: str | None = None,
|
|
42
|
+
api_key: str | None = None,
|
|
43
|
+
openai_client: AsyncOpenAI | None = None,
|
|
44
|
+
http_client: httpx.AsyncClient | None = None,
|
|
45
|
+
) -> None:
|
|
46
|
+
"""Create a new OpenAI provider.
|
|
47
|
+
|
|
48
|
+
Args:
|
|
49
|
+
base_url: The base url for the OpenAI requests. If not provided, the `OPENAI_BASE_URL` environment variable
|
|
50
|
+
will be used if available. Otherwise, defaults to OpenAI's base url.
|
|
51
|
+
api_key: The API key to use for authentication, if not provided, the `OPENAI_API_KEY` environment variable
|
|
52
|
+
will be used if available.
|
|
53
|
+
openai_client: An existing
|
|
54
|
+
[`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
|
|
55
|
+
client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`.
|
|
56
|
+
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
|
|
57
|
+
"""
|
|
58
|
+
self._base_url = base_url or 'https://api.openai.com/v1'
|
|
59
|
+
# This is a workaround for the OpenAI client requiring an API key, whilst locally served,
|
|
60
|
+
# openai compatible models do not always need an API key, but a placeholder (non-empty) key is required.
|
|
61
|
+
if api_key is None and 'OPENAI_API_KEY' not in os.environ and openai_client is None:
|
|
62
|
+
api_key = 'api-key-not-set'
|
|
63
|
+
|
|
64
|
+
if openai_client is not None:
|
|
65
|
+
assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
|
|
66
|
+
assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
|
|
67
|
+
assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
|
|
68
|
+
self._client = openai_client
|
|
69
|
+
elif http_client is not None:
|
|
70
|
+
self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
|
|
71
|
+
else:
|
|
72
|
+
self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=cached_async_http_client())
|