pydantic-ai-slim 1.2.0-py3-none-any.whl → 1.3.0-py3-none-any.whl
This diff shows the content of publicly available package versions as published to their respective public registries; it is provided for informational purposes only.
- pydantic_ai/__init__.py +2 -0
- pydantic_ai/_agent_graph.py +31 -6
- pydantic_ai/agent/__init__.py +8 -8
- pydantic_ai/builtin_tools.py +35 -4
- pydantic_ai/exceptions.py +5 -0
- pydantic_ai/mcp.py +1 -22
- pydantic_ai/models/__init__.py +34 -34
- pydantic_ai/models/anthropic.py +2 -2
- pydantic_ai/models/bedrock.py +4 -4
- pydantic_ai/models/cohere.py +0 -7
- pydantic_ai/models/gemini.py +9 -2
- pydantic_ai/models/google.py +5 -7
- pydantic_ai/models/groq.py +4 -4
- pydantic_ai/models/huggingface.py +2 -2
- pydantic_ai/models/openai.py +53 -36
- pydantic_ai/providers/__init__.py +21 -12
- pydantic_ai/providers/bedrock.py +60 -16
- pydantic_ai/providers/gateway.py +60 -72
- pydantic_ai/providers/google.py +61 -23
- pydantic_ai/providers/ovhcloud.py +95 -0
- pydantic_ai/usage.py +13 -2
- {pydantic_ai_slim-1.2.0.dist-info → pydantic_ai_slim-1.3.0.dist-info}/METADATA +5 -5
- {pydantic_ai_slim-1.2.0.dist-info → pydantic_ai_slim-1.3.0.dist-info}/RECORD +26 -25
- {pydantic_ai_slim-1.2.0.dist-info → pydantic_ai_slim-1.3.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.2.0.dist-info → pydantic_ai_slim-1.3.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.2.0.dist-info → pydantic_ai_slim-1.3.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/openai.py
CHANGED
@@ -285,6 +285,8 @@ class OpenAIChatModel(Model):
             'vercel',
             'litellm',
             'nebius',
+            'ovhcloud',
+            'gateway',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -314,6 +316,8 @@ class OpenAIChatModel(Model):
             'vercel',
             'litellm',
             'nebius',
+            'ovhcloud',
+            'gateway',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -342,6 +346,8 @@ class OpenAIChatModel(Model):
             'vercel',
             'litellm',
             'nebius',
+            'ovhcloud',
+            'gateway',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -363,7 +369,7 @@ class OpenAIChatModel(Model):
         self._model_name = model_name

         if isinstance(provider, str):
-            provider = infer_provider(provider)
+            provider = infer_provider('gateway/openai' if provider == 'gateway' else provider)
         self._provider = provider
         self.client = provider.client

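Together with the new 'ovhcloud' and 'gateway' literals above, this makes the bare provider string 'gateway' a shorthand for the gateway's OpenAI route. A minimal usage sketch (assumes pydantic-ai 1.3.0 and a valid PYDANTIC_AI_GATEWAY_API_KEY in the environment; the model name is illustrative):

    from pydantic_ai.models.openai import OpenAIChatModel

    # 'gateway' is rewritten to 'gateway/openai' before provider inference, so
    # this resolves to an OpenAI provider pointed at the Pydantic AI Gateway.
    model = OpenAIChatModel('gpt-4o', provider='gateway')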
@@ -559,24 +565,7 @@ class OpenAIChatModel(Model):
         # - https://openrouter.ai/docs/use-cases/reasoning-tokens#preserving-reasoning-blocks
         # If you need this, please file an issue.

-
-
-        # Add logprobs to vendor_details if available
-        if choice.logprobs is not None and choice.logprobs.content:
-            # Convert logprobs to a serializable format
-            vendor_details['logprobs'] = [
-                {
-                    'token': lp.token,
-                    'bytes': lp.bytes,
-                    'logprob': lp.logprob,
-                    'top_logprobs': [
-                        {'token': tlp.token, 'bytes': tlp.bytes, 'logprob': tlp.logprob} for tlp in lp.top_logprobs
-                    ],
-                }
-                for lp in choice.logprobs.content
-            ]
-
-        if choice.message.content is not None:
+        if choice.message.content:
             items.extend(
                 (replace(part, id='content', provider_name=self.system) if isinstance(part, ThinkingPart) else part)
                 for part in split_content_into_text_and_thinking(choice.message.content, self.profile.thinking_tags)
@@ -594,6 +583,23 @@ class OpenAIChatModel(Model):
                 part.tool_call_id = _guard_tool_call_id(part)
                 items.append(part)

+        vendor_details: dict[str, Any] = {}
+
+        # Add logprobs to vendor_details if available
+        if choice.logprobs is not None and choice.logprobs.content:
+            # Convert logprobs to a serializable format
+            vendor_details['logprobs'] = [
+                {
+                    'token': lp.token,
+                    'bytes': lp.bytes,
+                    'logprob': lp.logprob,
+                    'top_logprobs': [
+                        {'token': tlp.token, 'bytes': tlp.bytes, 'logprob': tlp.logprob} for tlp in lp.top_logprobs
+                    ],
+                }
+                for lp in choice.logprobs.content
+            ]
+
         raw_finish_reason = choice.finish_reason
         vendor_details['finish_reason'] = raw_finish_reason
         finish_reason = _CHAT_FINISH_REASON_MAP.get(raw_finish_reason)
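The two hunks above move the logprobs extraction so that it runs after all message parts are processed, with `vendor_details` now initialized next to the finish reason; the serialized shape is unchanged. A hedged sketch of reading it back from a finished run (assumes logprobs were requested from OpenAI via model settings; `provider_details` is where `vendor_details` surfaces on the final `ModelResponse`):

    from pydantic_ai import Agent

    agent = Agent('openai:gpt-4o')  # model name is illustrative
    result = agent.run_sync('hi')

    response = result.all_messages()[-1]      # the final ModelResponse
    details = response.provider_details or {}
    print(details.get('finish_reason'))       # e.g. 'stop'
    for lp in details.get('logprobs', []):    # present only when logprobs were requested
        print(lp['token'], lp['logprob'])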
@@ -903,7 +909,18 @@ class OpenAIResponsesModel(Model):
         self,
         model_name: OpenAIModelName,
         *,
-        provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'nebius']
+        provider: Literal[
+            'openai',
+            'deepseek',
+            'azure',
+            'openrouter',
+            'grok',
+            'fireworks',
+            'together',
+            'nebius',
+            'ovhcloud',
+            'gateway',
+        ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
         settings: ModelSettings | None = None,
@@ -919,7 +936,7 @@ class OpenAIResponsesModel(Model):
         self._model_name = model_name

         if isinstance(provider, str):
-            provider = infer_provider(provider)
+            provider = infer_provider('gateway/openai' if provider == 'gateway' else provider)
         self._provider = provider
         self.client = provider.client

@@ -1616,21 +1633,6 @@ class OpenAIStreamedResponse(StreamedResponse):
                 self.provider_details = {'finish_reason': raw_finish_reason}
                 self.finish_reason = _CHAT_FINISH_REASON_MAP.get(raw_finish_reason)

-            # Handle the text part of the response
-            content = choice.delta.content
-            if content is not None:
-                maybe_event = self._parts_manager.handle_text_delta(
-                    vendor_part_id='content',
-                    content=content,
-                    thinking_tags=self._model_profile.thinking_tags,
-                    ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
-                )
-                if maybe_event is not None:  # pragma: no branch
-                    if isinstance(maybe_event, PartStartEvent) and isinstance(maybe_event.part, ThinkingPart):
-                        maybe_event.part.id = 'content'
-                        maybe_event.part.provider_name = self.provider_name
-                    yield maybe_event
-
             # The `reasoning_content` field is only present in DeepSeek models.
             # https://api-docs.deepseek.com/guides/reasoning_model
             if reasoning_content := getattr(choice.delta, 'reasoning_content', None):
@@ -1652,6 +1654,21 @@ class OpenAIStreamedResponse(StreamedResponse):
                     provider_name=self.provider_name,
                 )

+            # Handle the text part of the response
+            content = choice.delta.content
+            if content:
+                maybe_event = self._parts_manager.handle_text_delta(
+                    vendor_part_id='content',
+                    content=content,
+                    thinking_tags=self._model_profile.thinking_tags,
+                    ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
+                )
+                if maybe_event is not None:  # pragma: no branch
+                    if isinstance(maybe_event, PartStartEvent) and isinstance(maybe_event.part, ThinkingPart):
+                        maybe_event.part.id = 'content'
+                        maybe_event.part.provider_name = self.provider_name
+                    yield maybe_event
+
             for dtc in choice.delta.tool_calls or []:
                 maybe_event = self._parts_manager.handle_tool_call_delta(
                     vendor_part_id=dtc.index,
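The streaming path mirrors the non-streaming change: text-delta handling moves below the DeepSeek `reasoning_content` handling, and the guard tightens from `is not None` to truthiness so empty-string deltas no longer emit events. Consumption is unaffected; a minimal sketch (assumes an OpenAI chat model and an OPENAI_API_KEY in the environment; model name is illustrative):

    import asyncio

    from pydantic_ai import Agent

    agent = Agent('openai:gpt-4o')

    async def main() -> None:
        # Streamed text arrives the same way; only internal event ordering changed.
        async with agent.run_stream('hello') as run:
            async for chunk in run.stream_text():
                print(chunk)

    asyncio.run(main())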
pydantic_ai/providers/__init__.py
CHANGED
@@ -8,7 +8,7 @@ from __future__ import annotations as _annotations
 from abc import ABC, abstractmethod
 from typing import Any, Generic, TypeVar

-from pydantic_ai.profiles import ModelProfile
+from ..profiles import ModelProfile

 InterfaceClient = TypeVar('InterfaceClient')

@@ -53,7 +53,7 @@ class Provider(ABC, Generic[InterfaceClient]):

 def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
     """Infers the provider class from the provider name."""
-    if provider == 'openai':
+    if provider in ('openai', 'openai-chat', 'openai-responses'):
         from .openai import OpenAIProvider

         return OpenAIProvider
@@ -73,15 +73,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
         from .azure import AzureProvider

         return AzureProvider
-    elif provider == 'google-vertex':
-        from .google_vertex import GoogleVertexProvider  # type: ignore[reportDeprecated]
+    elif provider in ('google-vertex', 'google-gla'):
+        from .google import GoogleProvider

-        return GoogleVertexProvider  # type: ignore[reportDeprecated]
-    elif provider == 'google-gla':
-        from .google_gla import GoogleGLAProvider  # type: ignore[reportDeprecated]
-
-        return GoogleGLAProvider  # type: ignore[reportDeprecated]
-    # NOTE: We don't test because there are many ways the `boto3.client` can retrieve the credentials.
+        return GoogleProvider
     elif provider == 'bedrock':
         from .bedrock import BedrockProvider

@@ -146,11 +141,25 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
         from .nebius import NebiusProvider

         return NebiusProvider
+    elif provider == 'ovhcloud':
+        from .ovhcloud import OVHcloudProvider
+
+        return OVHcloudProvider
     else:  # pragma: no cover
         raise ValueError(f'Unknown provider: {provider}')


 def infer_provider(provider: str) -> Provider[Any]:
     """Infer the provider from the provider name."""
-    provider_class = infer_provider_class(provider)
-    return provider_class()
+    if provider.startswith('gateway/'):
+        from .gateway import gateway_provider
+
+        provider = provider.removeprefix('gateway/')
+        return gateway_provider(provider)
+    elif provider in ('google-vertex', 'google-gla'):
+        from .google import GoogleProvider
+
+        return GoogleProvider(vertexai=provider == 'google-vertex')
+    else:
+        provider_class = infer_provider_class(provider)
+        return provider_class()
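A sketch of what `infer_provider` now accepts (each constructor still needs its own credentials, e.g. an OVHcloud API key or PYDANTIC_AI_GATEWAY_API_KEY, otherwise it raises at construction time):

    from pydantic_ai.providers import infer_provider

    infer_provider('ovhcloud')       # -> OVHcloudProvider()
    infer_provider('gateway/groq')   # -> gateway_provider('groq')
    infer_provider('google-vertex')  # -> GoogleProvider(vertexai=True)
    infer_provider('google-gla')     # -> GoogleProvider(vertexai=False)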
pydantic_ai/providers/bedrock.py
CHANGED
@@ -4,7 +4,7 @@ import os
 import re
 from collections.abc import Callable
 from dataclasses import dataclass
-from typing import Literal, overload
+from typing import Any, Literal, overload

 from pydantic_ai import ModelProfile
 from pydantic_ai.exceptions import UserError
@@ -21,6 +21,8 @@ try:
     from botocore.client import BaseClient
     from botocore.config import Config
     from botocore.exceptions import NoRegionError
+    from botocore.session import Session
+    from botocore.tokens import FrozenAuthToken
 except ImportError as _import_error:
     raise ImportError(
         'Please install the `boto3` package to use the Bedrock provider, '
@@ -117,10 +119,23 @@ class BedrockProvider(Provider[BaseClient]):
     def __init__(
         self,
         *,
+        api_key: str,
+        base_url: str | None = None,
         region_name: str | None = None,
+        profile_name: str | None = None,
+        aws_read_timeout: float | None = None,
+        aws_connect_timeout: float | None = None,
+    ) -> None: ...
+
+    @overload
+    def __init__(
+        self,
+        *,
         aws_access_key_id: str | None = None,
         aws_secret_access_key: str | None = None,
         aws_session_token: str | None = None,
+        base_url: str | None = None,
+        region_name: str | None = None,
         profile_name: str | None = None,
         aws_read_timeout: float | None = None,
         aws_connect_timeout: float | None = None,
@@ -130,11 +145,13 @@ class BedrockProvider(Provider[BaseClient]):
         self,
         *,
         bedrock_client: BaseClient | None = None,
-        region_name: str | None = None,
         aws_access_key_id: str | None = None,
         aws_secret_access_key: str | None = None,
         aws_session_token: str | None = None,
+        base_url: str | None = None,
+        region_name: str | None = None,
         profile_name: str | None = None,
+        api_key: str | None = None,
         aws_read_timeout: float | None = None,
         aws_connect_timeout: float | None = None,
     ) -> None:
@@ -142,10 +159,12 @@ class BedrockProvider(Provider[BaseClient]):

         Args:
             bedrock_client: A boto3 client for Bedrock Runtime. If provided, other arguments are ignored.
-            region_name: The AWS region name.
-            aws_access_key_id: The AWS access key ID.
-            aws_secret_access_key: The AWS secret access key.
-            aws_session_token: The AWS session token.
+            aws_access_key_id: The AWS access key ID. If not set, the `AWS_ACCESS_KEY_ID` environment variable will be used if available.
+            aws_secret_access_key: The AWS secret access key. If not set, the `AWS_SECRET_ACCESS_KEY` environment variable will be used if available.
+            aws_session_token: The AWS session token. If not set, the `AWS_SESSION_TOKEN` environment variable will be used if available.
+            api_key: The API key for Bedrock client. Can be used instead of `aws_access_key_id`, `aws_secret_access_key`, and `aws_session_token`. If not set, the `AWS_BEARER_TOKEN_BEDROCK` environment variable will be used if available.
+            base_url: The base URL for the Bedrock client.
+            region_name: The AWS region name. If not set, the `AWS_DEFAULT_REGION` environment variable will be used if available.
             profile_name: The AWS profile name.
             aws_read_timeout: The read timeout for Bedrock client.
             aws_connect_timeout: The connect timeout for Bedrock client.
@@ -153,19 +172,44 @@ class BedrockProvider(Provider[BaseClient]):
         if bedrock_client is not None:
             self._client = bedrock_client
         else:
+            read_timeout = aws_read_timeout or float(os.getenv('AWS_READ_TIMEOUT', 300))
+            connect_timeout = aws_connect_timeout or float(os.getenv('AWS_CONNECT_TIMEOUT', 60))
+            config: dict[str, Any] = {
+                'read_timeout': read_timeout,
+                'connect_timeout': connect_timeout,
+            }
             try:
-                session = boto3.Session(
-                    aws_access_key_id=aws_access_key_id,
-                    aws_secret_access_key=aws_secret_access_key,
-                    aws_session_token=aws_session_token,
-                    region_name=region_name,
-                    profile_name=profile_name,
-                )
-                read_timeout = aws_read_timeout or float(os.getenv('AWS_READ_TIMEOUT', 300))
-                connect_timeout = aws_connect_timeout or float(os.getenv('AWS_CONNECT_TIMEOUT', 60))
+                if api_key is not None:
+                    session = boto3.Session(
+                        botocore_session=_BearerTokenSession(api_key),
+                        region_name=region_name,
+                        profile_name=profile_name,
+                    )
+                    config['signature_version'] = 'bearer'
+                else:
+                    session = boto3.Session(
+                        aws_access_key_id=aws_access_key_id,
+                        aws_secret_access_key=aws_secret_access_key,
+                        aws_session_token=aws_session_token,
+                        region_name=region_name,
+                        profile_name=profile_name,
+                    )
                 self._client = session.client(  # type: ignore[reportUnknownMemberType]
                     'bedrock-runtime',
-                    config=Config(read_timeout=read_timeout, connect_timeout=connect_timeout),
+                    config=Config(**config),
+                    endpoint_url=base_url,
                 )
             except NoRegionError as exc:  # pragma: no cover
                 raise UserError('You must provide a `region_name` or a boto3 client for Bedrock Runtime.') from exc
+
+
+class _BearerTokenSession(Session):
+    def __init__(self, token: str):
+        super().__init__()
+        self.token = token
+
+    def get_auth_token(self, **_kwargs: Any) -> FrozenAuthToken:
+        return FrozenAuthToken(self.token)
+
+    def get_credentials(self) -> None:  # type: ignore[reportIncompatibleMethodOverride]
+        return None
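With `_BearerTokenSession` in place, bearer-token auth becomes a one-liner. A minimal sketch (token and model id are illustrative; per the docstring, omitting `api_key` falls back to the `AWS_BEARER_TOKEN_BEDROCK` environment variable):

    from pydantic_ai.models.bedrock import BedrockConverseModel
    from pydantic_ai.providers.bedrock import BedrockProvider

    provider = BedrockProvider(api_key='my-bedrock-bearer-token', region_name='us-east-1')
    model = BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=provider)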
pydantic_ai/providers/gateway.py
CHANGED
@@ -3,14 +3,16 @@
 from __future__ import annotations as _annotations

 import os
+from collections.abc import Awaitable, Callable
 from typing import TYPE_CHECKING, Any, Literal, overload

 import httpx

 from pydantic_ai.exceptions import UserError
-from pydantic_ai.models import cached_async_http_client, get_user_agent
+from pydantic_ai.models import cached_async_http_client

 if TYPE_CHECKING:
+    from botocore.client import BaseClient
     from google.genai import Client as GoogleClient
     from groq import AsyncGroq
     from openai import AsyncOpenAI
@@ -18,6 +20,8 @@ if TYPE_CHECKING:
    from pydantic_ai.models.anthropic import AsyncAnthropicClient
    from pydantic_ai.providers import Provider

+GATEWAY_BASE_URL = 'https://gateway.pydantic.dev/proxy'
+

 @overload
 def gateway_provider(
@@ -57,13 +61,34 @@
 ) -> Provider[AsyncAnthropicClient]: ...


+@overload
+def gateway_provider(
+    upstream_provider: Literal['bedrock'],
+    *,
+    api_key: str | None = None,
+    base_url: str | None = None,
+) -> Provider[BaseClient]: ...
+
+
+@overload
 def gateway_provider(
-    upstream_provider: Literal['openai', 'openai-chat', 'openai-responses', 'groq', 'google-vertex', 'anthropic'],
+    upstream_provider: str,
+    *,
+    api_key: str | None = None,
+    base_url: str | None = None,
+) -> Provider[Any]: ...
+
+
+UpstreamProvider = Literal['openai', 'openai-chat', 'openai-responses', 'groq', 'google-vertex', 'anthropic', 'bedrock']
+
+
+def gateway_provider(
+    upstream_provider: UpstreamProvider | str,
     *,
     # Every provider
     api_key: str | None = None,
     base_url: str | None = None,
-    # OpenAI & Groq
+    # OpenAI, Groq & Anthropic
     http_client: httpx.AsyncClient | None = None,
 ) -> Provider[Any]:
     """Create a new Gateway provider.
@@ -73,25 +98,21 @@
         api_key: The API key to use for authentication. If not provided, the `PYDANTIC_AI_GATEWAY_API_KEY`
             environment variable will be used if available.
         base_url: The base URL to use for the Gateway. If not provided, the `PYDANTIC_AI_GATEWAY_BASE_URL`
-            environment variable will be used if available. Otherwise, defaults to `…`.
+            environment variable will be used if available. Otherwise, defaults to `https://gateway.pydantic.dev/proxy`.
         http_client: The HTTP client to use for the Gateway.
     """
     api_key = api_key or os.getenv('PYDANTIC_AI_GATEWAY_API_KEY')
     if not api_key:
         raise UserError(
-            'Set the `PYDANTIC_AI_GATEWAY_API_KEY` environment variable or pass it via `gateway_provider(api_key=...)`'
+            'Set the `PYDANTIC_AI_GATEWAY_API_KEY` environment variable or pass it via `gateway_provider(..., api_key=...)`'
             ' to use the Pydantic AI Gateway provider.'
         )

-    base_url = base_url or os.getenv('PYDANTIC_AI_GATEWAY_BASE_URL', …)
-    http_client = http_client or cached_async_http_client(provider=f'gateway…')
-    http_client.event_hooks = {'request': [_request_hook]}
-
-    if upstream_provider in ('openai', 'openai-chat'):
-        from .openai import OpenAIProvider
+    base_url = base_url or os.getenv('PYDANTIC_AI_GATEWAY_BASE_URL', GATEWAY_BASE_URL)
+    http_client = http_client or cached_async_http_client(provider=f'gateway/{upstream_provider}')
+    http_client.event_hooks = {'request': [_request_hook(api_key)]}

-
-    elif upstream_provider == 'openai-responses':
+    if upstream_provider in ('openai', 'openai-chat', 'openai-responses'):
         from .openai import OpenAIProvider

         return OpenAIProvider(api_key=api_key, base_url=_merge_url_path(base_url, 'openai'), http_client=http_client)
@@ -111,79 +132,46 @@
                 http_client=http_client,
             )
         )
-    elif upstream_provider == 'google-vertex':
-        from google.genai import Client as GoogleClient
+    elif upstream_provider == 'bedrock':
+        from .bedrock import BedrockProvider

+        return BedrockProvider(
+            api_key=api_key,
+            base_url=_merge_url_path(base_url, 'bedrock'),
+            region_name='pydantic-ai-gateway',  # Fake region name to avoid NoRegionError
+        )
+    elif upstream_provider == 'google-vertex':
         from .google import GoogleProvider

         return GoogleProvider(
-            client=GoogleClient(
-                vertexai=True,
-                api_key=api_key,
-                http_options={
-                    'base_url': _merge_url_path(base_url, 'google-vertex'),
-                    'headers': {'User-Agent': get_user_agent(), 'Authorization': api_key},
-                    # TODO(Marcelo): Until https://github.com/googleapis/python-genai/issues/1357 is solved.
-                    'async_client_args': {
-                        'transport': httpx.AsyncHTTPTransport(),
-                        'event_hooks': {'request': [_request_hook]},
-                    },
-                },
-            )
+            vertexai=True,
+            api_key=api_key,
+            base_url=_merge_url_path(base_url, 'google-vertex'),
+            http_client=http_client,
         )
-    else:
-        raise UserError(f'Unknown provider: {upstream_provider}')
+    else:
+        raise UserError(f'Unknown upstream provider: {upstream_provider}')


-def infer_model(model_name: str) -> Model:
-    """Infer the model from the model name.
-
-    Args:
-        model_name: The name of the model to infer. Must be in the format "provider/model_name".
+def _request_hook(api_key: str) -> Callable[[httpx.Request], Awaitable[httpx.Request]]:
+    """Request hook for the gateway provider.

-    Returns:
-        The model class that will be used to make requests to the gateway.
+    It adds the `"traceparent"` and `"Authorization"` headers to the request.
     """
-    try:
-        upstream_provider, model_name = model_name.split('/', 1)
-    except ValueError:
-        raise UserError(f'The model name "{model_name}" is not in the format "provider/model_name".')

-    if upstream_provider in ('openai', 'openai-chat'):
-        from pydantic_ai.models.openai import OpenAIChatModel
+    async def _hook(request: httpx.Request) -> httpx.Request:
+        from opentelemetry.propagate import inject

-        return OpenAIChatModel(model_name, provider=gateway_provider('openai'))
-    elif upstream_provider == 'openai-responses':
-        from pydantic_ai.models.openai import OpenAIResponsesModel
-
-        return OpenAIResponsesModel(model_name, provider=gateway_provider('openai'))
-    elif upstream_provider == 'groq':
-        from pydantic_ai.models.groq import GroqModel
+        headers: dict[str, Any] = {}
+        inject(headers)
+        request.headers.update(headers)

-        return GroqModel(model_name, provider=gateway_provider('groq'))
-    elif upstream_provider == 'anthropic':
-        from pydantic_ai.models.anthropic import AnthropicModel
-
-        return AnthropicModel(model_name, provider=gateway_provider('anthropic'))
-    elif upstream_provider == 'google-vertex':
-        from pydantic_ai.models.google import GoogleModel
-
-        return GoogleModel(model_name, provider=gateway_provider('google-vertex'))
-    raise UserError(f'Unknown upstream provider: {upstream_provider}')
-
-
-async def _request_hook(request: httpx.Request) -> httpx.Request:
-    """Request hook for the gateway provider.
-
-    It adds the `"traceparent"` header to the request.
-    """
-    from opentelemetry.propagate import inject
+        if 'Authorization' not in request.headers:
+            request.headers['Authorization'] = f'Bearer {api_key}'

-    headers: dict[str, Any] = {}
-    inject(headers)
-    request.headers.update(headers)
+        return request

-    return request
+    return _hook


 def _merge_url_path(base_url: str, path: str) -> str: