pydantic-ai-slim 0.0.32__tar.gz → 0.0.33__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/PKG-INFO +3 -2
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/_pydantic.py +4 -4
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/_result.py +7 -18
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/models/__init__.py +36 -36
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/models/gemini.py +51 -14
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/models/openai.py +56 -15
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/models/vertexai.py +9 -1
- pydantic_ai_slim-0.0.33/pydantic_ai/providers/__init__.py +64 -0
- pydantic_ai_slim-0.0.33/pydantic_ai/providers/deepseek.py +68 -0
- pydantic_ai_slim-0.0.33/pydantic_ai/providers/google_gla.py +44 -0
- pydantic_ai_slim-0.0.33/pydantic_ai/providers/google_vertex.py +200 -0
- pydantic_ai_slim-0.0.33/pydantic_ai/providers/openai.py +72 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pyproject.toml +3 -2
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/.gitignore +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/README.md +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/__init__.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/_agent_graph.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/_griffe.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/_parts_manager.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/_system_prompt.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/_utils.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/agent.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/common_tools/__init__.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/common_tools/duckduckgo.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/common_tools/tavily.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/exceptions.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/format_as_xml.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/messages.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/models/anthropic.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/models/cohere.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/models/fallback.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/models/function.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/models/groq.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/models/instrumented.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/models/mistral.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/models/test.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/models/wrapper.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/py.typed +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/result.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/settings.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/tools.py +0 -0
- {pydantic_ai_slim-0.0.32 → pydantic_ai_slim-0.0.33}/pydantic_ai/usage.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pydantic-ai-slim
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.33
|
|
4
4
|
Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
|
|
5
5
|
Author-email: Samuel Colvin <samuel@pydantic.dev>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -29,8 +29,9 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
|
|
|
29
29
|
Requires-Dist: griffe>=1.3.2
|
|
30
30
|
Requires-Dist: httpx>=0.27
|
|
31
31
|
Requires-Dist: opentelemetry-api>=1.28.0
|
|
32
|
-
Requires-Dist: pydantic-graph==0.0.
|
|
32
|
+
Requires-Dist: pydantic-graph==0.0.33
|
|
33
33
|
Requires-Dist: pydantic>=2.10
|
|
34
|
+
Requires-Dist: typing-inspection>=0.4.0
|
|
34
35
|
Provides-Extra: anthropic
|
|
35
36
|
Requires-Dist: anthropic>=0.49.0; extra == 'anthropic'
|
|
36
37
|
Provides-Extra: cohere
|
|
@@ -6,7 +6,7 @@ This module has to use numerous internal Pydantic APIs and is therefore brittle
|
|
|
6
6
|
from __future__ import annotations as _annotations
|
|
7
7
|
|
|
8
8
|
from inspect import Parameter, signature
|
|
9
|
-
from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast
|
|
9
|
+
from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast
|
|
10
10
|
|
|
11
11
|
from pydantic import ConfigDict
|
|
12
12
|
from pydantic._internal import _decorators, _generate_schema, _typing_extra
|
|
@@ -15,6 +15,7 @@ from pydantic.fields import FieldInfo
|
|
|
15
15
|
from pydantic.json_schema import GenerateJsonSchema
|
|
16
16
|
from pydantic.plugin._schema_validator import create_schema_validator
|
|
17
17
|
from pydantic_core import SchemaValidator, core_schema
|
|
18
|
+
from typing_extensions import get_origin
|
|
18
19
|
|
|
19
20
|
from ._griffe import doc_descriptions
|
|
20
21
|
from ._utils import check_object_json_schema, is_model_like
|
|
@@ -223,8 +224,7 @@ def _build_schema(
|
|
|
223
224
|
|
|
224
225
|
|
|
225
226
|
def _is_call_ctx(annotation: Any) -> bool:
|
|
227
|
+
"""Return whether the annotation is the `RunContext` class, parameterized or not."""
|
|
226
228
|
from .tools import RunContext
|
|
227
229
|
|
|
228
|
-
return annotation is RunContext or (
|
|
229
|
-
_typing_extra.is_generic_alias(annotation) and get_origin(annotation) is RunContext
|
|
230
|
-
)
|
|
230
|
+
return annotation is RunContext or get_origin(annotation) is RunContext
|
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
3
|
import inspect
|
|
4
|
-
import sys
|
|
5
|
-
import types
|
|
6
4
|
from collections.abc import Awaitable, Iterable, Iterator
|
|
7
5
|
from dataclasses import dataclass, field
|
|
8
|
-
from typing import Any, Callable, Generic, Literal, Union, cast
|
|
6
|
+
from typing import Any, Callable, Generic, Literal, Union, cast
|
|
9
7
|
|
|
10
8
|
from pydantic import TypeAdapter, ValidationError
|
|
11
|
-
from typing_extensions import
|
|
9
|
+
from typing_extensions import TypedDict, TypeVar, get_args, get_origin
|
|
10
|
+
from typing_inspection import typing_objects
|
|
11
|
+
from typing_inspection.introspection import is_union_origin
|
|
12
12
|
|
|
13
13
|
from . import _utils, messages as _messages
|
|
14
14
|
from .exceptions import ModelRetry
|
|
@@ -248,23 +248,12 @@ def extract_str_from_union(response_type: Any) -> _utils.Option[Any]:
|
|
|
248
248
|
|
|
249
249
|
|
|
250
250
|
def get_union_args(tp: Any) -> tuple[Any, ...]:
|
|
251
|
-
"""Extract the arguments of a Union type if `response_type` is a union, otherwise return an empty
|
|
252
|
-
if
|
|
251
|
+
"""Extract the arguments of a Union type if `response_type` is a union, otherwise return an empty tuple."""
|
|
252
|
+
if typing_objects.is_typealiastype(tp):
|
|
253
253
|
tp = tp.__value__
|
|
254
254
|
|
|
255
255
|
origin = get_origin(tp)
|
|
256
|
-
if
|
|
256
|
+
if is_union_origin(origin):
|
|
257
257
|
return get_args(tp)
|
|
258
258
|
else:
|
|
259
259
|
return ()
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
if sys.version_info < (3, 10):
|
|
263
|
-
|
|
264
|
-
def origin_is_union(tp: type[Any] | None) -> bool:
|
|
265
|
-
return tp is Union
|
|
266
|
-
|
|
267
|
-
else:
|
|
268
|
-
|
|
269
|
-
def origin_is_union(tp: type[Any] | None) -> bool:
|
|
270
|
-
return tp is Union or tp is types.UnionType
|
|
@@ -49,6 +49,8 @@ KnownModelName = Literal[
|
|
|
49
49
|
'cohere:command-r-plus-04-2024',
|
|
50
50
|
'cohere:command-r-plus-08-2024',
|
|
51
51
|
'cohere:command-r7b-12-2024',
|
|
52
|
+
'deepseek:deepseek-chat',
|
|
53
|
+
'deepseek:deepseek-reasoner',
|
|
52
54
|
'google-gla:gemini-1.0-pro',
|
|
53
55
|
'google-gla:gemini-1.5-flash',
|
|
54
56
|
'google-gla:gemini-1.5-flash-8b',
|
|
@@ -320,54 +322,52 @@ def infer_model(model: Model | KnownModelName) -> Model:
|
|
|
320
322
|
from .test import TestModel
|
|
321
323
|
|
|
322
324
|
return TestModel()
|
|
323
|
-
elif model.startswith('cohere:'):
|
|
324
|
-
from .cohere import CohereModel
|
|
325
325
|
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
326
|
+
try:
|
|
327
|
+
provider, model_name = model.split(':')
|
|
328
|
+
except ValueError:
|
|
329
|
+
model_name = model
|
|
330
|
+
# TODO(Marcelo): We should deprecate this way.
|
|
331
|
+
if model_name.startswith(('gpt', 'o1', 'o3')):
|
|
332
|
+
provider = 'openai'
|
|
333
|
+
elif model_name.startswith('claude'):
|
|
334
|
+
provider = 'anthropic'
|
|
335
|
+
elif model_name.startswith('gemini'):
|
|
336
|
+
provider = 'google-gla'
|
|
337
|
+
else:
|
|
338
|
+
raise UserError(f'Unknown model: {model}')
|
|
339
|
+
|
|
340
|
+
if provider == 'vertexai':
|
|
341
|
+
provider = 'google-vertex'
|
|
342
|
+
|
|
343
|
+
if provider == 'cohere':
|
|
344
|
+
from .cohere import CohereModel
|
|
329
345
|
|
|
330
|
-
|
|
331
|
-
|
|
346
|
+
# TODO(Marcelo): Missing provider API.
|
|
347
|
+
return CohereModel(model_name)
|
|
348
|
+
elif provider in ('deepseek', 'openai'):
|
|
332
349
|
from .openai import OpenAIModel
|
|
333
350
|
|
|
334
|
-
return OpenAIModel(
|
|
335
|
-
elif
|
|
336
|
-
from .gemini import GeminiModel
|
|
337
|
-
|
|
338
|
-
return GeminiModel(model[11:])
|
|
339
|
-
# backwards compatibility with old model names (ex, gemini-1.5-flash -> google-gla:gemini-1.5-flash)
|
|
340
|
-
elif model.startswith('gemini'):
|
|
351
|
+
return OpenAIModel(model_name, provider=provider)
|
|
352
|
+
elif provider in ('google-gla', 'google-vertex'):
|
|
341
353
|
from .gemini import GeminiModel
|
|
342
354
|
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
elif model.startswith('groq:'):
|
|
355
|
+
return GeminiModel(model_name, provider=provider)
|
|
356
|
+
elif provider == 'groq':
|
|
346
357
|
from .groq import GroqModel
|
|
347
358
|
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
return VertexAIModel(model[14:])
|
|
353
|
-
# backwards compatibility with old model names (ex, vertexai:gemini-1.5-flash -> google-vertex:gemini-1.5-flash)
|
|
354
|
-
elif model.startswith('vertexai:'):
|
|
355
|
-
from .vertexai import VertexAIModel
|
|
356
|
-
|
|
357
|
-
return VertexAIModel(model[9:])
|
|
358
|
-
elif model.startswith('mistral:'):
|
|
359
|
+
# TODO(Marcelo): Missing provider API.
|
|
360
|
+
return GroqModel(model_name)
|
|
361
|
+
elif provider == 'mistral':
|
|
359
362
|
from .mistral import MistralModel
|
|
360
363
|
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
return AnthropicModel(model[10:])
|
|
366
|
-
# backwards compatibility with old model names (ex, claude-3-5-sonnet-latest -> anthropic:claude-3-5-sonnet-latest)
|
|
367
|
-
elif model.startswith('claude'):
|
|
364
|
+
# TODO(Marcelo): Missing provider API.
|
|
365
|
+
return MistralModel(model_name)
|
|
366
|
+
elif provider == 'anthropic':
|
|
368
367
|
from .anthropic import AnthropicModel
|
|
369
368
|
|
|
370
|
-
|
|
369
|
+
# TODO(Marcelo): Missing provider API.
|
|
370
|
+
return AnthropicModel(model_name)
|
|
371
371
|
else:
|
|
372
372
|
raise UserError(f'Unknown model: {model}')
|
|
373
373
|
|
|
@@ -8,12 +8,14 @@ from contextlib import asynccontextmanager
|
|
|
8
8
|
from copy import deepcopy
|
|
9
9
|
from dataclasses import dataclass, field
|
|
10
10
|
from datetime import datetime
|
|
11
|
-
from typing import Annotated, Any, Literal, Protocol, Union, cast
|
|
11
|
+
from typing import Annotated, Any, Literal, Protocol, Union, cast, overload
|
|
12
12
|
from uuid import uuid4
|
|
13
13
|
|
|
14
14
|
import pydantic
|
|
15
15
|
from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
|
|
16
|
-
from typing_extensions import NotRequired, TypedDict, assert_never
|
|
16
|
+
from typing_extensions import NotRequired, TypedDict, assert_never, deprecated
|
|
17
|
+
|
|
18
|
+
from pydantic_ai.providers import Provider, infer_provider
|
|
17
19
|
|
|
18
20
|
from .. import ModelHTTPError, UnexpectedModelBehavior, UserError, _utils, usage
|
|
19
21
|
from ..messages import (
|
|
@@ -82,17 +84,39 @@ class GeminiModel(Model):
|
|
|
82
84
|
Apart from `__init__`, all methods are private or match those of the base class.
|
|
83
85
|
"""
|
|
84
86
|
|
|
85
|
-
|
|
87
|
+
client: AsyncHTTPClient = field(repr=False)
|
|
86
88
|
|
|
87
89
|
_model_name: GeminiModelName = field(repr=False)
|
|
90
|
+
_provider: Literal['google-gla', 'google-vertex'] | Provider[AsyncHTTPClient] | None = field(repr=False)
|
|
88
91
|
_auth: AuthProtocol | None = field(repr=False)
|
|
89
92
|
_url: str | None = field(repr=False)
|
|
90
93
|
_system: str | None = field(default='google-gla', repr=False)
|
|
91
94
|
|
|
95
|
+
@overload
|
|
96
|
+
def __init__(
|
|
97
|
+
self,
|
|
98
|
+
model_name: GeminiModelName,
|
|
99
|
+
*,
|
|
100
|
+
provider: Literal['google-gla', 'google-vertex'] | Provider[AsyncHTTPClient] = 'google-gla',
|
|
101
|
+
) -> None: ...
|
|
102
|
+
|
|
103
|
+
@deprecated('Use the `provider` argument instead of the `api_key`, `http_client`, and `url_template` arguments.')
|
|
104
|
+
@overload
|
|
92
105
|
def __init__(
|
|
93
106
|
self,
|
|
94
107
|
model_name: GeminiModelName,
|
|
95
108
|
*,
|
|
109
|
+
provider: None = None,
|
|
110
|
+
api_key: str | None = None,
|
|
111
|
+
http_client: AsyncHTTPClient | None = None,
|
|
112
|
+
url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:',
|
|
113
|
+
) -> None: ...
|
|
114
|
+
|
|
115
|
+
def __init__(
|
|
116
|
+
self,
|
|
117
|
+
model_name: GeminiModelName,
|
|
118
|
+
*,
|
|
119
|
+
provider: Literal['google-gla', 'google-vertex'] | Provider[AsyncHTTPClient] | None = None,
|
|
96
120
|
api_key: str | None = None,
|
|
97
121
|
http_client: AsyncHTTPClient | None = None,
|
|
98
122
|
url_template: str = 'https://generativelanguage.googleapis.com/v1beta/models/{model}:',
|
|
@@ -101,6 +125,7 @@ class GeminiModel(Model):
|
|
|
101
125
|
|
|
102
126
|
Args:
|
|
103
127
|
model_name: The name of the model to use.
|
|
128
|
+
provider: The provider to use for the model.
|
|
104
129
|
api_key: The API key to use for authentication, if not provided, the `GEMINI_API_KEY` environment variable
|
|
105
130
|
will be used if available.
|
|
106
131
|
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
|
|
@@ -109,14 +134,24 @@ class GeminiModel(Model):
|
|
|
109
134
|
`model` is substituted with the model name, and `function` is added to the end of the URL.
|
|
110
135
|
"""
|
|
111
136
|
self._model_name = model_name
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
137
|
+
self._provider = provider
|
|
138
|
+
|
|
139
|
+
if provider is not None:
|
|
140
|
+
if isinstance(provider, str):
|
|
141
|
+
self._system = provider
|
|
142
|
+
self.client = infer_provider(provider).client
|
|
115
143
|
else:
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
144
|
+
self._system = provider.name
|
|
145
|
+
self.client = provider.client
|
|
146
|
+
else:
|
|
147
|
+
if api_key is None:
|
|
148
|
+
if env_api_key := os.getenv('GEMINI_API_KEY'):
|
|
149
|
+
api_key = env_api_key
|
|
150
|
+
else:
|
|
151
|
+
raise UserError('API key must be provided or set in the GEMINI_API_KEY environment variable')
|
|
152
|
+
self.client = http_client or cached_async_http_client()
|
|
153
|
+
self._auth = ApiKeyAuth(api_key)
|
|
154
|
+
self._url = url_template.format(model=model_name)
|
|
120
155
|
|
|
121
156
|
@property
|
|
122
157
|
def auth(self) -> AuthProtocol:
|
|
@@ -217,17 +252,19 @@ class GeminiModel(Model):
|
|
|
217
252
|
if generation_config:
|
|
218
253
|
request_data['generation_config'] = generation_config
|
|
219
254
|
|
|
220
|
-
url = self.url + ('streamGenerateContent' if streamed else 'generateContent')
|
|
221
|
-
|
|
222
255
|
headers = {
|
|
223
256
|
'Content-Type': 'application/json',
|
|
224
257
|
'User-Agent': get_user_agent(),
|
|
225
|
-
**await self.auth.headers(),
|
|
226
258
|
}
|
|
259
|
+
if self._provider is None: # pragma: no cover
|
|
260
|
+
url = self.url + ('streamGenerateContent' if streamed else 'generateContent')
|
|
261
|
+
headers.update(await self.auth.headers())
|
|
262
|
+
else:
|
|
263
|
+
url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'
|
|
227
264
|
|
|
228
265
|
request_json = _gemini_request_ta.dump_json(request_data, by_alias=True)
|
|
229
266
|
|
|
230
|
-
async with self.
|
|
267
|
+
async with self.client.stream(
|
|
231
268
|
'POST',
|
|
232
269
|
url,
|
|
233
270
|
content=request_json,
|
|
@@ -9,7 +9,9 @@ from datetime import datetime, timezone
|
|
|
9
9
|
from typing import Literal, Union, cast, overload
|
|
10
10
|
|
|
11
11
|
from httpx import AsyncClient as AsyncHTTPClient
|
|
12
|
-
from typing_extensions import assert_never
|
|
12
|
+
from typing_extensions import assert_never, deprecated
|
|
13
|
+
|
|
14
|
+
from pydantic_ai.providers import Provider, infer_provider
|
|
13
15
|
|
|
14
16
|
from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
|
|
15
17
|
from .._utils import guard_tool_call_id as _guard_tool_call_id
|
|
@@ -98,10 +100,36 @@ class OpenAIModel(Model):
|
|
|
98
100
|
_model_name: OpenAIModelName = field(repr=False)
|
|
99
101
|
_system: str | None = field(repr=False)
|
|
100
102
|
|
|
103
|
+
@overload
|
|
101
104
|
def __init__(
|
|
102
105
|
self,
|
|
103
106
|
model_name: OpenAIModelName,
|
|
104
107
|
*,
|
|
108
|
+
provider: Literal['openai', 'deepseek'] | Provider[AsyncOpenAI] = 'openai',
|
|
109
|
+
system_prompt_role: OpenAISystemPromptRole | None = None,
|
|
110
|
+
system: str | None = 'openai',
|
|
111
|
+
) -> None: ...
|
|
112
|
+
|
|
113
|
+
@deprecated('Use the `provider` parameter instead of `base_url`, `api_key`, `openai_client` and `http_client`.')
|
|
114
|
+
@overload
|
|
115
|
+
def __init__(
|
|
116
|
+
self,
|
|
117
|
+
model_name: OpenAIModelName,
|
|
118
|
+
*,
|
|
119
|
+
provider: None = None,
|
|
120
|
+
base_url: str | None = None,
|
|
121
|
+
api_key: str | None = None,
|
|
122
|
+
openai_client: AsyncOpenAI | None = None,
|
|
123
|
+
http_client: AsyncHTTPClient | None = None,
|
|
124
|
+
system_prompt_role: OpenAISystemPromptRole | None = None,
|
|
125
|
+
system: str | None = 'openai',
|
|
126
|
+
) -> None: ...
|
|
127
|
+
|
|
128
|
+
def __init__(
|
|
129
|
+
self,
|
|
130
|
+
model_name: OpenAIModelName,
|
|
131
|
+
*,
|
|
132
|
+
provider: Literal['openai', 'deepseek'] | Provider[AsyncOpenAI] | None = None,
|
|
105
133
|
base_url: str | None = None,
|
|
106
134
|
api_key: str | None = None,
|
|
107
135
|
openai_client: AsyncOpenAI | None = None,
|
|
@@ -115,6 +143,7 @@ class OpenAIModel(Model):
|
|
|
115
143
|
model_name: The name of the OpenAI model to use. List of model names available
|
|
116
144
|
[here](https://github.com/openai/openai-python/blob/v1.54.3/src/openai/types/chat_model.py#L7)
|
|
117
145
|
(Unfortunately, despite being ask to do so, OpenAI do not provide `.inv` files for their API).
|
|
146
|
+
provider: The provider to use. Defaults to `'openai'`.
|
|
118
147
|
base_url: The base url for the OpenAI requests. If not provided, the `OPENAI_BASE_URL` environment variable
|
|
119
148
|
will be used if available. Otherwise, defaults to OpenAI's base url.
|
|
120
149
|
api_key: The API key to use for authentication, if not provided, the `OPENAI_API_KEY` environment variable
|
|
@@ -129,20 +158,32 @@ class OpenAIModel(Model):
|
|
|
129
158
|
customize the `base_url` and `api_key` to use a different provider.
|
|
130
159
|
"""
|
|
131
160
|
self._model_name = model_name
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
161
|
+
|
|
162
|
+
if provider is not None:
|
|
163
|
+
if isinstance(provider, str):
|
|
164
|
+
self.client = infer_provider(provider).client
|
|
165
|
+
else:
|
|
166
|
+
self.client = provider.client
|
|
167
|
+
else: # pragma: no cover
|
|
168
|
+
# This is a workaround for the OpenAI client requiring an API key, whilst locally served,
|
|
169
|
+
# openai compatible models do not always need an API key, but a placeholder (non-empty) key is required.
|
|
170
|
+
if (
|
|
171
|
+
api_key is None
|
|
172
|
+
and 'OPENAI_API_KEY' not in os.environ
|
|
173
|
+
and base_url is not None
|
|
174
|
+
and openai_client is None
|
|
175
|
+
):
|
|
176
|
+
api_key = 'api-key-not-set'
|
|
177
|
+
|
|
178
|
+
if openai_client is not None:
|
|
179
|
+
assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
|
|
180
|
+
assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
|
|
181
|
+
assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
|
|
182
|
+
self.client = openai_client
|
|
183
|
+
elif http_client is not None:
|
|
184
|
+
self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
|
|
185
|
+
else:
|
|
186
|
+
self.client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=cached_async_http_client())
|
|
146
187
|
self.system_prompt_role = system_prompt_role
|
|
147
188
|
self._system = system
|
|
148
189
|
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations as _annotations
|
|
2
2
|
|
|
3
|
+
import warnings
|
|
3
4
|
from collections.abc import AsyncIterator
|
|
4
5
|
from contextlib import asynccontextmanager
|
|
5
6
|
from dataclasses import dataclass, field
|
|
@@ -8,6 +9,7 @@ from pathlib import Path
|
|
|
8
9
|
from typing import Literal
|
|
9
10
|
|
|
10
11
|
from httpx import AsyncClient as AsyncHTTPClient
|
|
12
|
+
from typing_extensions import deprecated
|
|
11
13
|
|
|
12
14
|
from .. import usage
|
|
13
15
|
from .._utils import run_in_executor
|
|
@@ -55,6 +57,7 @@ The template is used thus:
|
|
|
55
57
|
"""
|
|
56
58
|
|
|
57
59
|
|
|
60
|
+
@deprecated('Please use `GeminiModel(provider=GoogleVertexProvider(...))` instead.')
|
|
58
61
|
@dataclass(init=False)
|
|
59
62
|
class VertexAIModel(GeminiModel):
|
|
60
63
|
"""A model that uses Gemini via the `*-aiplatform.googleapis.com` VertexAI API."""
|
|
@@ -103,11 +106,16 @@ class VertexAIModel(GeminiModel):
|
|
|
103
106
|
self.project_id = project_id
|
|
104
107
|
self.region = region
|
|
105
108
|
self.model_publisher = model_publisher
|
|
106
|
-
self.
|
|
109
|
+
self.client = http_client or cached_async_http_client()
|
|
107
110
|
self.url_template = url_template
|
|
108
111
|
|
|
109
112
|
self._auth = None
|
|
110
113
|
self._url = None
|
|
114
|
+
warnings.warn(
|
|
115
|
+
'VertexAIModel is deprecated, please use `GeminiModel(provider=GoogleVertexProvider(...))` instead.',
|
|
116
|
+
DeprecationWarning,
|
|
117
|
+
)
|
|
118
|
+
self._provider = None
|
|
111
119
|
|
|
112
120
|
async def ainit(self) -> None:
|
|
113
121
|
"""Initialize the model, setting the URL and auth.
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
"""Providers for the API clients.
|
|
2
|
+
|
|
3
|
+
The providers are in charge of providing an authenticated client to the API.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations as _annotations
|
|
7
|
+
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
from typing import Any, Generic, TypeVar
|
|
10
|
+
|
|
11
|
+
InterfaceClient = TypeVar('InterfaceClient')
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Provider(ABC, Generic[InterfaceClient]):
|
|
15
|
+
"""Abstract class for a provider.
|
|
16
|
+
|
|
17
|
+
The provider is in charge of providing an authenticated client to the API.
|
|
18
|
+
|
|
19
|
+
Each provider only supports a specific interface. A interface can be supported by multiple providers.
|
|
20
|
+
|
|
21
|
+
For example, the OpenAIModel interface can be supported by the OpenAIProvider and the DeepSeekProvider.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
_client: InterfaceClient
|
|
25
|
+
|
|
26
|
+
@property
|
|
27
|
+
@abstractmethod
|
|
28
|
+
def name(self) -> str:
|
|
29
|
+
"""The provider name."""
|
|
30
|
+
raise NotImplementedError()
|
|
31
|
+
|
|
32
|
+
@property
|
|
33
|
+
@abstractmethod
|
|
34
|
+
def base_url(self) -> str:
|
|
35
|
+
"""The base URL for the provider API."""
|
|
36
|
+
raise NotImplementedError()
|
|
37
|
+
|
|
38
|
+
@property
|
|
39
|
+
@abstractmethod
|
|
40
|
+
def client(self) -> InterfaceClient:
|
|
41
|
+
"""The client for the provider."""
|
|
42
|
+
raise NotImplementedError()
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def infer_provider(provider: str) -> Provider[Any]:
|
|
46
|
+
"""Infer the provider from the provider name."""
|
|
47
|
+
if provider == 'openai':
|
|
48
|
+
from .openai import OpenAIProvider
|
|
49
|
+
|
|
50
|
+
return OpenAIProvider()
|
|
51
|
+
elif provider == 'deepseek':
|
|
52
|
+
from .deepseek import DeepSeekProvider
|
|
53
|
+
|
|
54
|
+
return DeepSeekProvider()
|
|
55
|
+
elif provider == 'google-vertex':
|
|
56
|
+
from .google_vertex import GoogleVertexProvider
|
|
57
|
+
|
|
58
|
+
return GoogleVertexProvider()
|
|
59
|
+
elif provider == 'google-gla':
|
|
60
|
+
from .google_gla import GoogleGLAProvider
|
|
61
|
+
|
|
62
|
+
return GoogleGLAProvider()
|
|
63
|
+
else: # pragma: no cover
|
|
64
|
+
raise ValueError(f'Unknown provider: {provider}')
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import overload
|
|
5
|
+
|
|
6
|
+
from httpx import AsyncClient as AsyncHTTPClient
|
|
7
|
+
from openai import AsyncOpenAI
|
|
8
|
+
|
|
9
|
+
from pydantic_ai.models import cached_async_http_client
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
from openai import AsyncOpenAI
|
|
13
|
+
except ImportError as _import_error: # pragma: no cover
|
|
14
|
+
raise ImportError(
|
|
15
|
+
'Please install `openai` to use the DeepSeek provider, '
|
|
16
|
+
"you can use the `openai` optional group — `pip install 'pydantic-ai-slim[openai]'`"
|
|
17
|
+
) from _import_error
|
|
18
|
+
|
|
19
|
+
from . import Provider
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class DeepSeekProvider(Provider[AsyncOpenAI]):
|
|
23
|
+
"""Provider for DeepSeek API."""
|
|
24
|
+
|
|
25
|
+
@property
|
|
26
|
+
def name(self) -> str:
|
|
27
|
+
return 'deepseek'
|
|
28
|
+
|
|
29
|
+
@property
|
|
30
|
+
def base_url(self) -> str:
|
|
31
|
+
return 'https://api.deepseek.com'
|
|
32
|
+
|
|
33
|
+
@property
|
|
34
|
+
def client(self) -> AsyncOpenAI:
|
|
35
|
+
return self._client
|
|
36
|
+
|
|
37
|
+
@overload
|
|
38
|
+
def __init__(self) -> None: ...
|
|
39
|
+
|
|
40
|
+
@overload
|
|
41
|
+
def __init__(self, *, api_key: str) -> None: ...
|
|
42
|
+
|
|
43
|
+
@overload
|
|
44
|
+
def __init__(self, *, api_key: str, http_client: AsyncHTTPClient) -> None: ...
|
|
45
|
+
|
|
46
|
+
@overload
|
|
47
|
+
def __init__(self, *, openai_client: AsyncOpenAI | None = None) -> None: ...
|
|
48
|
+
|
|
49
|
+
def __init__(
|
|
50
|
+
self,
|
|
51
|
+
*,
|
|
52
|
+
api_key: str | None = None,
|
|
53
|
+
openai_client: AsyncOpenAI | None = None,
|
|
54
|
+
http_client: AsyncHTTPClient | None = None,
|
|
55
|
+
) -> None:
|
|
56
|
+
api_key = api_key or os.getenv('DEEPSEEK_API_KEY')
|
|
57
|
+
if api_key is None and openai_client is None:
|
|
58
|
+
raise ValueError(
|
|
59
|
+
'Set the `DEEPSEEK_API_KEY` environment variable or pass it via `DeepSeekProvider(api_key=...)`'
|
|
60
|
+
'to use the DeepSeek provider.'
|
|
61
|
+
)
|
|
62
|
+
|
|
63
|
+
if openai_client is not None:
|
|
64
|
+
self._client = openai_client
|
|
65
|
+
elif http_client is not None:
|
|
66
|
+
self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
|
|
67
|
+
else:
|
|
68
|
+
self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=cached_async_http_client())
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
import httpx
|
|
6
|
+
|
|
7
|
+
from pydantic_ai.models import cached_async_http_client
|
|
8
|
+
from pydantic_ai.providers import Provider
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class GoogleGLAProvider(Provider[httpx.AsyncClient]):
|
|
12
|
+
"""Provider for Google Generative Language AI API."""
|
|
13
|
+
|
|
14
|
+
@property
|
|
15
|
+
def name(self):
|
|
16
|
+
return 'google-gla'
|
|
17
|
+
|
|
18
|
+
@property
|
|
19
|
+
def base_url(self) -> str:
|
|
20
|
+
return 'https://generativelanguage.googleapis.com/v1beta/models/'
|
|
21
|
+
|
|
22
|
+
@property
|
|
23
|
+
def client(self) -> httpx.AsyncClient:
|
|
24
|
+
return self._client
|
|
25
|
+
|
|
26
|
+
def __init__(self, api_key: str | None = None, http_client: httpx.AsyncClient | None = None) -> None:
|
|
27
|
+
"""Create a new Google GLA provider.
|
|
28
|
+
|
|
29
|
+
Args:
|
|
30
|
+
api_key: The API key to use for authentication, if not provided, the `GEMINI_API_KEY` environment variable
|
|
31
|
+
will be used if available.
|
|
32
|
+
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
|
|
33
|
+
"""
|
|
34
|
+
api_key = api_key or os.environ.get('GEMINI_API_KEY')
|
|
35
|
+
if api_key is None:
|
|
36
|
+
raise ValueError(
|
|
37
|
+
'Set the `GEMINI_API_KEY` environment variable or pass it via `GoogleGLAProvider(api_key=...)`'
|
|
38
|
+
'to use the Google GLA provider.'
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
self._client = http_client or cached_async_http_client()
|
|
42
|
+
self._client.base_url = self.base_url
|
|
43
|
+
# https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest
|
|
44
|
+
self._client.headers['X-Goog-Api-Key'] = api_key
|
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
from collections.abc import AsyncGenerator
|
|
5
|
+
from datetime import datetime, timedelta
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Literal
|
|
8
|
+
|
|
9
|
+
import anyio.to_thread
|
|
10
|
+
import httpx
|
|
11
|
+
|
|
12
|
+
from pydantic_ai.exceptions import UserError
|
|
13
|
+
|
|
14
|
+
from ..models import cached_async_http_client
|
|
15
|
+
from . import Provider
|
|
16
|
+
|
|
17
|
+
try:
|
|
18
|
+
import google.auth
|
|
19
|
+
from google.auth.credentials import Credentials as BaseCredentials
|
|
20
|
+
from google.auth.transport.requests import Request
|
|
21
|
+
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
|
|
22
|
+
except ImportError as _import_error:
|
|
23
|
+
raise ImportError(
|
|
24
|
+
'Please install `google-auth` to use the Google Vertex AI provider, '
|
|
25
|
+
"you can use the `vertexai` optional group — `pip install 'pydantic-ai-slim[vertexai]'`"
|
|
26
|
+
) from _import_error
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
__all__ = ('GoogleVertexProvider',)
|
|
30
|
+
|
|
31
|
+
# default expiry is 3600 seconds
|
|
32
|
+
MAX_TOKEN_AGE = timedelta(seconds=3000)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class GoogleVertexProvider(Provider[httpx.AsyncClient]):
    """Provider for the Google Vertex AI API.

    Configures an `httpx.AsyncClient` with Vertex AI's regional base URL and an
    auth hook that lazily resolves Google credentials on the first request.
    """

    @property
    def name(self) -> str:
        return 'google-vertex'

    @property
    def base_url(self) -> str:
        # NOTE: `project_id` may still be `None` here; `_VertexAIAuth` patches the
        # 'projects/None' segment once the project id is known from the credentials.
        return (
            f'https://{self.region}-aiplatform.googleapis.com/v1'
            f'/projects/{self.project_id}/locations/{self.region}'
            f'/publishers/{self.model_publisher}/models/'
        )

    @property
    def client(self) -> httpx.AsyncClient:
        return self._client

    def __init__(
        self,
        service_account_file: Path | str | None = None,
        project_id: str | None = None,
        region: VertexAiRegion = 'us-central1',
        model_publisher: str = 'google',
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Create a new Vertex AI provider.

        Args:
            service_account_file: Path to a service account file.
                If not provided, the default environment credentials will be used.
            project_id: The project ID to use, if not provided it will be taken from the credentials.
            region: The region to make requests to.
            model_publisher: The model publisher to use, I couldn't find a good list of available publishers,
                and from trial and error it seems non-google models don't work with the `generateContent` and
                `streamGenerateContent` functions, hence only `google` is currently supported.
                Please create an issue or PR if you know how to use other publishers.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
        """
        self.service_account_file = service_account_file
        self.project_id = project_id
        self.region = region
        self.model_publisher = model_publisher

        self._client = http_client or cached_async_http_client()
        self._client.auth = _VertexAIAuth(service_account_file, project_id, region)
        self._client.base_url = self.base_url
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class _VertexAIAuth(httpx.Auth):
    """Auth class for Vertex AI API.

    Implements the `httpx.Auth` async flow: lazily loads Google credentials on the
    first request, refreshes the bearer token when it is near expiry, and injects
    the `Authorization` header into every outgoing request.
    """

    # Lazily resolved on the first request; either ADC credentials or
    # service-account credentials, depending on how the provider was configured.
    credentials: BaseCredentials | ServiceAccountCredentials | None

    def __init__(
        self,
        service_account_file: Path | str | None = None,
        project_id: str | None = None,
        region: VertexAiRegion = 'us-central1',
    ) -> None:
        self.service_account_file = service_account_file
        self.project_id = project_id
        self.region = region

        self.credentials = None
        # Wall-clock time of the last token refresh; `None` until first refresh.
        # Compared against MAX_TOKEN_AGE, which is deliberately shorter than the
        # 3600s default expiry so we refresh before the token actually lapses.
        self.token_created: datetime | None = None

    async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
        # Resolve credentials once; this may also populate `self.project_id`
        # (see `_get_credentials`), which the URL rewrite below relies on.
        if self.credentials is None:
            self.credentials = await self._get_credentials()
        if self.credentials.token is None or self._token_expired():  # type: ignore[reportUnknownMemberType]
            # `credentials.refresh` does blocking network I/O, so run it in a thread.
            await anyio.to_thread.run_sync(self._refresh_token)
            self.token_created = datetime.now()
        request.headers['Authorization'] = f'Bearer {self.credentials.token}'  # type: ignore[reportUnknownMemberType]

        # NOTE: This workaround is in place because we might get the project_id from the credentials.
        request.url = httpx.URL(str(request.url).replace('projects/None', f'projects/{self.project_id}'))
        yield request

    async def _get_credentials(self) -> BaseCredentials | ServiceAccountCredentials:
        """Load credentials from the service account file if given, else via ADC.

        Side effect: fills in `self.project_id` from the credentials when the
        caller did not supply one; raises `UserError` if neither source has it.
        """
        if self.service_account_file is not None:
            creds = await _creds_from_file(self.service_account_file)
            assert creds.project_id is None or isinstance(creds.project_id, str)  # type: ignore[reportUnknownMemberType]
            creds_project_id: str | None = creds.project_id
            creds_source = 'service account file'
        else:
            creds, creds_project_id = await _async_google_auth()
            creds_source = '`google.auth.default()`'

        if self.project_id is None:
            if creds_project_id is None:
                raise UserError(f'No project_id provided and none found in {creds_source}')
            self.project_id = creds_project_id
        return creds

    def _token_expired(self) -> bool:
        """Return True if the token is missing or older than MAX_TOKEN_AGE."""
        if self.token_created is None:
            return True
        else:
            return (datetime.now() - self.token_created) > MAX_TOKEN_AGE

    def _refresh_token(self) -> str:  # pragma: no cover
        """Blocking token refresh; must be called via a worker thread."""
        assert self.credentials is not None
        self.credentials.refresh(Request())  # type: ignore[reportUnknownMemberType]
        assert isinstance(self.credentials.token, str), f'Expected token to be a string, got {self.credentials.token}'  # type: ignore[reportUnknownMemberType]
        return self.credentials.token
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
async def _async_google_auth() -> tuple[BaseCredentials, str | None]:
    """Run `google.auth.default()` in a worker thread, since it performs blocking I/O.

    Returns the `(credentials, project_id)` pair produced by ADC resolution.
    """
    scopes = ['https://www.googleapis.com/auth/cloud-platform']
    credentials, project_id = await anyio.to_thread.run_sync(google.auth.default, scopes)  # type: ignore
    return credentials, project_id
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
async def _creds_from_file(service_account_file: str | Path) -> ServiceAccountCredentials:
    """Load service-account credentials from a file, off the event loop.

    The file read is blocking, so it is dispatched to a worker thread.
    """

    def _load(path: str) -> ServiceAccountCredentials:
        return ServiceAccountCredentials.from_service_account_file(  # type: ignore[reportUnknownMemberType]
            path,
            scopes=['https://www.googleapis.com/auth/cloud-platform'],
        )

    return await anyio.to_thread.run_sync(_load, str(service_account_file))
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
# Closed set of regional endpoints accepted by `GoogleVertexProvider(region=...)`;
# each value maps to a `https://{region}-aiplatform.googleapis.com` host.
VertexAiRegion = Literal[
    'us-central1',
    'us-east1',
    'us-east4',
    'us-south1',
    'us-west1',
    'us-west2',
    'us-west3',
    'us-west4',
    'us-east5',
    'europe-central2',
    'europe-north1',
    'europe-southwest1',
    'europe-west1',
    'europe-west2',
    'europe-west3',
    'europe-west4',
    'europe-west6',
    'europe-west8',
    'europe-west9',
    'europe-west12',
    'africa-south1',
    'asia-east1',
    'asia-east2',
    'asia-northeast1',
    'asia-northeast2',
    'asia-northeast3',
    'asia-south1',
    'asia-southeast1',
    'asia-southeast2',
    'australia-southeast1',
    'australia-southeast2',
    'me-central1',
    'me-central2',
    'me-west1',
    'northamerica-northeast1',
    'northamerica-northeast2',
    'southamerica-east1',
    'southamerica-west1',
]
"""Regions available for Vertex AI.

More details [here](https://cloud.google.com/vertex-ai/docs/reference/rest#rest_endpoints).
"""
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
from __future__ import annotations as _annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import TypeVar
|
|
5
|
+
|
|
6
|
+
import httpx
|
|
7
|
+
|
|
8
|
+
from pydantic_ai.models import cached_async_http_client
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
from openai import AsyncOpenAI
|
|
12
|
+
except ImportError as _import_error: # pragma: no cover
|
|
13
|
+
raise ImportError(
|
|
14
|
+
'Please install `openai` to use the OpenAI provider, '
|
|
15
|
+
"you can use the `openai` optional group — `pip install 'pydantic-ai-slim[openai]'`"
|
|
16
|
+
) from _import_error
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
from . import Provider
|
|
20
|
+
|
|
21
|
+
InterfaceClient = TypeVar('InterfaceClient')
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class OpenAIProvider(Provider[AsyncOpenAI]):
    """Provider for OpenAI API."""

    @property
    def name(self) -> str:
        return 'openai'  # pragma: no cover

    @property
    def base_url(self) -> str:
        return self._base_url

    @property
    def client(self) -> AsyncOpenAI:
        return self._client

    def __init__(
        self,
        base_url: str | None = None,
        api_key: str | None = None,
        openai_client: AsyncOpenAI | None = None,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Create a new OpenAI provider.

        Args:
            base_url: The base url for the OpenAI requests. If not provided, the `OPENAI_BASE_URL` environment variable
                will be used if available. Otherwise, defaults to OpenAI's base url.
            api_key: The API key to use for authentication, if not provided, the `OPENAI_API_KEY` environment variable
                will be used if available.
            openai_client: An existing
                [`AsyncOpenAI`](https://github.com/openai/openai-python?tab=readme-ov-file#async-usage)
                client to use. If provided, `base_url`, `api_key`, and `http_client` must be `None`.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
        """
        # Fix: the docstring promises an `OPENAI_BASE_URL` fallback, but the previous
        # implementation jumped straight to the hard-coded default, which also prevented
        # the OpenAI client's own env-var fallback from ever firing.
        self._base_url = base_url or os.environ.get('OPENAI_BASE_URL') or 'https://api.openai.com/v1'
        # This is a workaround for the OpenAI client requiring an API key, whilst locally served,
        # openai compatible models do not always need an API key, but a placeholder (non-empty) key is required.
        if api_key is None and 'OPENAI_API_KEY' not in os.environ and openai_client is None:
            api_key = 'api-key-not-set'

        if openai_client is not None:
            # A pre-built client is mutually exclusive with the other connection settings.
            assert base_url is None, 'Cannot provide both `openai_client` and `base_url`'
            assert http_client is None, 'Cannot provide both `openai_client` and `http_client`'
            assert api_key is None, 'Cannot provide both `openai_client` and `api_key`'
            self._client = openai_client
        elif http_client is not None:
            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=http_client)
        else:
            self._client = AsyncOpenAI(base_url=self.base_url, api_key=api_key, http_client=cached_async_http_client())
|
|
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "pydantic-ai-slim"
|
|
7
|
-
version = "0.0.
|
|
7
|
+
version = "0.0.33"
|
|
8
8
|
description = "Agent Framework / shim to use Pydantic with LLMs, slim package"
|
|
9
9
|
authors = [{ name = "Samuel Colvin", email = "samuel@pydantic.dev" }]
|
|
10
10
|
license = "MIT"
|
|
@@ -36,9 +36,10 @@ dependencies = [
|
|
|
36
36
|
"griffe>=1.3.2",
|
|
37
37
|
"httpx>=0.27",
|
|
38
38
|
"pydantic>=2.10",
|
|
39
|
-
"pydantic-graph==0.0.
|
|
39
|
+
"pydantic-graph==0.0.33",
|
|
40
40
|
"exceptiongroup; python_version < '3.11'",
|
|
41
41
|
"opentelemetry-api>=1.28.0",
|
|
42
|
+
"typing-inspection>=0.4.0",
|
|
42
43
|
]
|
|
43
44
|
|
|
44
45
|
[project.optional-dependencies]
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|