pydantic-ai-slim 1.2.0 (py3-none-any wheel) → 1.3.0 (py3-none-any wheel)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim has been flagged as potentially problematic; see the package's registry page for details.

@@ -285,6 +285,8 @@ class OpenAIChatModel(Model):
285
285
  'vercel',
286
286
  'litellm',
287
287
  'nebius',
288
+ 'ovhcloud',
289
+ 'gateway',
288
290
  ]
289
291
  | Provider[AsyncOpenAI] = 'openai',
290
292
  profile: ModelProfileSpec | None = None,
@@ -314,6 +316,8 @@ class OpenAIChatModel(Model):
314
316
  'vercel',
315
317
  'litellm',
316
318
  'nebius',
319
+ 'ovhcloud',
320
+ 'gateway',
317
321
  ]
318
322
  | Provider[AsyncOpenAI] = 'openai',
319
323
  profile: ModelProfileSpec | None = None,
@@ -342,6 +346,8 @@ class OpenAIChatModel(Model):
342
346
  'vercel',
343
347
  'litellm',
344
348
  'nebius',
349
+ 'ovhcloud',
350
+ 'gateway',
345
351
  ]
346
352
  | Provider[AsyncOpenAI] = 'openai',
347
353
  profile: ModelProfileSpec | None = None,
@@ -363,7 +369,7 @@ class OpenAIChatModel(Model):
363
369
  self._model_name = model_name
364
370
 
365
371
  if isinstance(provider, str):
366
- provider = infer_provider(provider)
372
+ provider = infer_provider('gateway/openai' if provider == 'gateway' else provider)
367
373
  self._provider = provider
368
374
  self.client = provider.client
369
375
 
@@ -559,24 +565,7 @@ class OpenAIChatModel(Model):
559
565
  # - https://openrouter.ai/docs/use-cases/reasoning-tokens#preserving-reasoning-blocks
560
566
  # If you need this, please file an issue.
561
567
 
562
- vendor_details: dict[str, Any] = {}
563
-
564
- # Add logprobs to vendor_details if available
565
- if choice.logprobs is not None and choice.logprobs.content:
566
- # Convert logprobs to a serializable format
567
- vendor_details['logprobs'] = [
568
- {
569
- 'token': lp.token,
570
- 'bytes': lp.bytes,
571
- 'logprob': lp.logprob,
572
- 'top_logprobs': [
573
- {'token': tlp.token, 'bytes': tlp.bytes, 'logprob': tlp.logprob} for tlp in lp.top_logprobs
574
- ],
575
- }
576
- for lp in choice.logprobs.content
577
- ]
578
-
579
- if choice.message.content is not None:
568
+ if choice.message.content:
580
569
  items.extend(
581
570
  (replace(part, id='content', provider_name=self.system) if isinstance(part, ThinkingPart) else part)
582
571
  for part in split_content_into_text_and_thinking(choice.message.content, self.profile.thinking_tags)
@@ -594,6 +583,23 @@ class OpenAIChatModel(Model):
594
583
  part.tool_call_id = _guard_tool_call_id(part)
595
584
  items.append(part)
596
585
 
586
+ vendor_details: dict[str, Any] = {}
587
+
588
+ # Add logprobs to vendor_details if available
589
+ if choice.logprobs is not None and choice.logprobs.content:
590
+ # Convert logprobs to a serializable format
591
+ vendor_details['logprobs'] = [
592
+ {
593
+ 'token': lp.token,
594
+ 'bytes': lp.bytes,
595
+ 'logprob': lp.logprob,
596
+ 'top_logprobs': [
597
+ {'token': tlp.token, 'bytes': tlp.bytes, 'logprob': tlp.logprob} for tlp in lp.top_logprobs
598
+ ],
599
+ }
600
+ for lp in choice.logprobs.content
601
+ ]
602
+
597
603
  raw_finish_reason = choice.finish_reason
598
604
  vendor_details['finish_reason'] = raw_finish_reason
599
605
  finish_reason = _CHAT_FINISH_REASON_MAP.get(raw_finish_reason)
@@ -903,7 +909,18 @@ class OpenAIResponsesModel(Model):
903
909
  self,
904
910
  model_name: OpenAIModelName,
905
911
  *,
906
- provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'nebius']
912
+ provider: Literal[
913
+ 'openai',
914
+ 'deepseek',
915
+ 'azure',
916
+ 'openrouter',
917
+ 'grok',
918
+ 'fireworks',
919
+ 'together',
920
+ 'nebius',
921
+ 'ovhcloud',
922
+ 'gateway',
923
+ ]
907
924
  | Provider[AsyncOpenAI] = 'openai',
908
925
  profile: ModelProfileSpec | None = None,
909
926
  settings: ModelSettings | None = None,
@@ -919,7 +936,7 @@ class OpenAIResponsesModel(Model):
919
936
  self._model_name = model_name
920
937
 
921
938
  if isinstance(provider, str):
922
- provider = infer_provider(provider)
939
+ provider = infer_provider('gateway/openai' if provider == 'gateway' else provider)
923
940
  self._provider = provider
924
941
  self.client = provider.client
925
942
 
@@ -1616,21 +1633,6 @@ class OpenAIStreamedResponse(StreamedResponse):
1616
1633
  self.provider_details = {'finish_reason': raw_finish_reason}
1617
1634
  self.finish_reason = _CHAT_FINISH_REASON_MAP.get(raw_finish_reason)
1618
1635
 
1619
- # Handle the text part of the response
1620
- content = choice.delta.content
1621
- if content is not None:
1622
- maybe_event = self._parts_manager.handle_text_delta(
1623
- vendor_part_id='content',
1624
- content=content,
1625
- thinking_tags=self._model_profile.thinking_tags,
1626
- ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
1627
- )
1628
- if maybe_event is not None: # pragma: no branch
1629
- if isinstance(maybe_event, PartStartEvent) and isinstance(maybe_event.part, ThinkingPart):
1630
- maybe_event.part.id = 'content'
1631
- maybe_event.part.provider_name = self.provider_name
1632
- yield maybe_event
1633
-
1634
1636
  # The `reasoning_content` field is only present in DeepSeek models.
1635
1637
  # https://api-docs.deepseek.com/guides/reasoning_model
1636
1638
  if reasoning_content := getattr(choice.delta, 'reasoning_content', None):
@@ -1652,6 +1654,21 @@ class OpenAIStreamedResponse(StreamedResponse):
1652
1654
  provider_name=self.provider_name,
1653
1655
  )
1654
1656
 
1657
+ # Handle the text part of the response
1658
+ content = choice.delta.content
1659
+ if content:
1660
+ maybe_event = self._parts_manager.handle_text_delta(
1661
+ vendor_part_id='content',
1662
+ content=content,
1663
+ thinking_tags=self._model_profile.thinking_tags,
1664
+ ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
1665
+ )
1666
+ if maybe_event is not None: # pragma: no branch
1667
+ if isinstance(maybe_event, PartStartEvent) and isinstance(maybe_event.part, ThinkingPart):
1668
+ maybe_event.part.id = 'content'
1669
+ maybe_event.part.provider_name = self.provider_name
1670
+ yield maybe_event
1671
+
1655
1672
  for dtc in choice.delta.tool_calls or []:
1656
1673
  maybe_event = self._parts_manager.handle_tool_call_delta(
1657
1674
  vendor_part_id=dtc.index,
@@ -8,7 +8,7 @@ from __future__ import annotations as _annotations
8
8
  from abc import ABC, abstractmethod
9
9
  from typing import Any, Generic, TypeVar
10
10
 
11
- from pydantic_ai import ModelProfile
11
+ from ..profiles import ModelProfile
12
12
 
13
13
  InterfaceClient = TypeVar('InterfaceClient')
14
14
 
@@ -53,7 +53,7 @@ class Provider(ABC, Generic[InterfaceClient]):
53
53
 
54
54
  def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901
55
55
  """Infers the provider class from the provider name."""
56
- if provider == 'openai':
56
+ if provider in ('openai', 'openai-chat', 'openai-responses'):
57
57
  from .openai import OpenAIProvider
58
58
 
59
59
  return OpenAIProvider
@@ -73,15 +73,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901
73
73
  from .azure import AzureProvider
74
74
 
75
75
  return AzureProvider
76
- elif provider == 'google-vertex':
77
- from .google_vertex import GoogleVertexProvider # type: ignore[reportDeprecated]
76
+ elif provider in ('google-vertex', 'google-gla'):
77
+ from .google import GoogleProvider
78
78
 
79
- return GoogleVertexProvider # type: ignore[reportDeprecated]
80
- elif provider == 'google-gla':
81
- from .google_gla import GoogleGLAProvider # type: ignore[reportDeprecated]
82
-
83
- return GoogleGLAProvider # type: ignore[reportDeprecated]
84
- # NOTE: We don't test because there are many ways the `boto3.client` can retrieve the credentials.
79
+ return GoogleProvider
85
80
  elif provider == 'bedrock':
86
81
  from .bedrock import BedrockProvider
87
82
 
@@ -146,11 +141,25 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901
146
141
  from .nebius import NebiusProvider
147
142
 
148
143
  return NebiusProvider
144
+ elif provider == 'ovhcloud':
145
+ from .ovhcloud import OVHcloudProvider
146
+
147
+ return OVHcloudProvider
149
148
  else: # pragma: no cover
150
149
  raise ValueError(f'Unknown provider: {provider}')
151
150
 
152
151
 
153
152
  def infer_provider(provider: str) -> Provider[Any]:
154
153
  """Infer the provider from the provider name."""
155
- provider_class = infer_provider_class(provider)
156
- return provider_class()
154
+ if provider.startswith('gateway/'):
155
+ from .gateway import gateway_provider
156
+
157
+ provider = provider.removeprefix('gateway/')
158
+ return gateway_provider(provider)
159
+ elif provider in ('google-vertex', 'google-gla'):
160
+ from .google import GoogleProvider
161
+
162
+ return GoogleProvider(vertexai=provider == 'google-vertex')
163
+ else:
164
+ provider_class = infer_provider_class(provider)
165
+ return provider_class()
@@ -4,7 +4,7 @@ import os
4
4
  import re
5
5
  from collections.abc import Callable
6
6
  from dataclasses import dataclass
7
- from typing import Literal, overload
7
+ from typing import Any, Literal, overload
8
8
 
9
9
  from pydantic_ai import ModelProfile
10
10
  from pydantic_ai.exceptions import UserError
@@ -21,6 +21,8 @@ try:
21
21
  from botocore.client import BaseClient
22
22
  from botocore.config import Config
23
23
  from botocore.exceptions import NoRegionError
24
+ from botocore.session import Session
25
+ from botocore.tokens import FrozenAuthToken
24
26
  except ImportError as _import_error:
25
27
  raise ImportError(
26
28
  'Please install the `boto3` package to use the Bedrock provider, '
@@ -117,10 +119,23 @@ class BedrockProvider(Provider[BaseClient]):
117
119
  def __init__(
118
120
  self,
119
121
  *,
122
+ api_key: str,
123
+ base_url: str | None = None,
120
124
  region_name: str | None = None,
125
+ profile_name: str | None = None,
126
+ aws_read_timeout: float | None = None,
127
+ aws_connect_timeout: float | None = None,
128
+ ) -> None: ...
129
+
130
+ @overload
131
+ def __init__(
132
+ self,
133
+ *,
121
134
  aws_access_key_id: str | None = None,
122
135
  aws_secret_access_key: str | None = None,
123
136
  aws_session_token: str | None = None,
137
+ base_url: str | None = None,
138
+ region_name: str | None = None,
124
139
  profile_name: str | None = None,
125
140
  aws_read_timeout: float | None = None,
126
141
  aws_connect_timeout: float | None = None,
@@ -130,11 +145,13 @@ class BedrockProvider(Provider[BaseClient]):
130
145
  self,
131
146
  *,
132
147
  bedrock_client: BaseClient | None = None,
133
- region_name: str | None = None,
134
148
  aws_access_key_id: str | None = None,
135
149
  aws_secret_access_key: str | None = None,
136
150
  aws_session_token: str | None = None,
151
+ base_url: str | None = None,
152
+ region_name: str | None = None,
137
153
  profile_name: str | None = None,
154
+ api_key: str | None = None,
138
155
  aws_read_timeout: float | None = None,
139
156
  aws_connect_timeout: float | None = None,
140
157
  ) -> None:
@@ -142,10 +159,12 @@ class BedrockProvider(Provider[BaseClient]):
142
159
 
143
160
  Args:
144
161
  bedrock_client: A boto3 client for Bedrock Runtime. If provided, other arguments are ignored.
145
- region_name: The AWS region name.
146
- aws_access_key_id: The AWS access key ID.
147
- aws_secret_access_key: The AWS secret access key.
148
- aws_session_token: The AWS session token.
162
+ aws_access_key_id: The AWS access key ID. If not set, the `AWS_ACCESS_KEY_ID` environment variable will be used if available.
163
+ aws_secret_access_key: The AWS secret access key. If not set, the `AWS_SECRET_ACCESS_KEY` environment variable will be used if available.
164
+ aws_session_token: The AWS session token. If not set, the `AWS_SESSION_TOKEN` environment variable will be used if available.
165
+ api_key: The API key for Bedrock client. Can be used instead of `aws_access_key_id`, `aws_secret_access_key`, and `aws_session_token`. If not set, the `AWS_BEARER_TOKEN_BEDROCK` environment variable will be used if available.
166
+ base_url: The base URL for the Bedrock client.
167
+ region_name: The AWS region name. If not set, the `AWS_DEFAULT_REGION` environment variable will be used if available.
149
168
  profile_name: The AWS profile name.
150
169
  aws_read_timeout: The read timeout for Bedrock client.
151
170
  aws_connect_timeout: The connect timeout for Bedrock client.
@@ -153,19 +172,44 @@ class BedrockProvider(Provider[BaseClient]):
153
172
  if bedrock_client is not None:
154
173
  self._client = bedrock_client
155
174
  else:
175
+ read_timeout = aws_read_timeout or float(os.getenv('AWS_READ_TIMEOUT', 300))
176
+ connect_timeout = aws_connect_timeout or float(os.getenv('AWS_CONNECT_TIMEOUT', 60))
177
+ config: dict[str, Any] = {
178
+ 'read_timeout': read_timeout,
179
+ 'connect_timeout': connect_timeout,
180
+ }
156
181
  try:
157
- read_timeout = aws_read_timeout or float(os.getenv('AWS_READ_TIMEOUT', 300))
158
- connect_timeout = aws_connect_timeout or float(os.getenv('AWS_CONNECT_TIMEOUT', 60))
159
- session = boto3.Session(
160
- aws_access_key_id=aws_access_key_id,
161
- aws_secret_access_key=aws_secret_access_key,
162
- aws_session_token=aws_session_token,
163
- region_name=region_name,
164
- profile_name=profile_name,
165
- )
182
+ if api_key is not None:
183
+ session = boto3.Session(
184
+ botocore_session=_BearerTokenSession(api_key),
185
+ region_name=region_name,
186
+ profile_name=profile_name,
187
+ )
188
+ config['signature_version'] = 'bearer'
189
+ else:
190
+ session = boto3.Session(
191
+ aws_access_key_id=aws_access_key_id,
192
+ aws_secret_access_key=aws_secret_access_key,
193
+ aws_session_token=aws_session_token,
194
+ region_name=region_name,
195
+ profile_name=profile_name,
196
+ )
166
197
  self._client = session.client( # type: ignore[reportUnknownMemberType]
167
198
  'bedrock-runtime',
168
- config=Config(read_timeout=read_timeout, connect_timeout=connect_timeout),
199
+ config=Config(**config),
200
+ endpoint_url=base_url,
169
201
  )
170
202
  except NoRegionError as exc: # pragma: no cover
171
203
  raise UserError('You must provide a `region_name` or a boto3 client for Bedrock Runtime.') from exc
204
+
205
+
206
+ class _BearerTokenSession(Session):
207
+ def __init__(self, token: str):
208
+ super().__init__()
209
+ self.token = token
210
+
211
+ def get_auth_token(self, **_kwargs: Any) -> FrozenAuthToken:
212
+ return FrozenAuthToken(self.token)
213
+
214
+ def get_credentials(self) -> None: # type: ignore[reportIncompatibleMethodOverride]
215
+ return None
@@ -3,14 +3,16 @@
3
3
  from __future__ import annotations as _annotations
4
4
 
5
5
  import os
6
+ from collections.abc import Awaitable, Callable
6
7
  from typing import TYPE_CHECKING, Any, Literal, overload
7
8
 
8
9
  import httpx
9
10
 
10
11
  from pydantic_ai.exceptions import UserError
11
- from pydantic_ai.models import Model, cached_async_http_client, get_user_agent
12
+ from pydantic_ai.models import cached_async_http_client
12
13
 
13
14
  if TYPE_CHECKING:
15
+ from botocore.client import BaseClient
14
16
  from google.genai import Client as GoogleClient
15
17
  from groq import AsyncGroq
16
18
  from openai import AsyncOpenAI
@@ -18,6 +20,8 @@ if TYPE_CHECKING:
18
20
  from pydantic_ai.models.anthropic import AsyncAnthropicClient
19
21
  from pydantic_ai.providers import Provider
20
22
 
23
+ GATEWAY_BASE_URL = 'https://gateway.pydantic.dev/proxy'
24
+
21
25
 
22
26
  @overload
23
27
  def gateway_provider(
@@ -57,13 +61,34 @@ def gateway_provider(
57
61
  ) -> Provider[AsyncAnthropicClient]: ...
58
62
 
59
63
 
64
+ @overload
65
+ def gateway_provider(
66
+ upstream_provider: Literal['bedrock'],
67
+ *,
68
+ api_key: str | None = None,
69
+ base_url: str | None = None,
70
+ ) -> Provider[BaseClient]: ...
71
+
72
+
73
+ @overload
60
74
  def gateway_provider(
61
- upstream_provider: Literal['openai', 'openai-chat', 'openai-responses', 'groq', 'google-vertex', 'anthropic'] | str,
75
+ upstream_provider: str,
76
+ *,
77
+ api_key: str | None = None,
78
+ base_url: str | None = None,
79
+ ) -> Provider[Any]: ...
80
+
81
+
82
+ UpstreamProvider = Literal['openai', 'openai-chat', 'openai-responses', 'groq', 'google-vertex', 'anthropic', 'bedrock']
83
+
84
+
85
+ def gateway_provider(
86
+ upstream_provider: UpstreamProvider | str,
62
87
  *,
63
88
  # Every provider
64
89
  api_key: str | None = None,
65
90
  base_url: str | None = None,
66
- # OpenAI & Groq
91
+ # OpenAI, Groq & Anthropic
67
92
  http_client: httpx.AsyncClient | None = None,
68
93
  ) -> Provider[Any]:
69
94
  """Create a new Gateway provider.
@@ -73,25 +98,21 @@ def gateway_provider(
73
98
  api_key: The API key to use for authentication. If not provided, the `PYDANTIC_AI_GATEWAY_API_KEY`
74
99
  environment variable will be used if available.
75
100
  base_url: The base URL to use for the Gateway. If not provided, the `PYDANTIC_AI_GATEWAY_BASE_URL`
76
- environment variable will be used if available. Otherwise, defaults to `http://localhost:8787/`.
101
+ environment variable will be used if available. Otherwise, defaults to `https://gateway.pydantic.dev/proxy`.
77
102
  http_client: The HTTP client to use for the Gateway.
78
103
  """
79
104
  api_key = api_key or os.getenv('PYDANTIC_AI_GATEWAY_API_KEY')
80
105
  if not api_key:
81
106
  raise UserError(
82
- 'Set the `PYDANTIC_AI_GATEWAY_API_KEY` environment variable or pass it via `gateway_provider(api_key=...)`'
107
+ 'Set the `PYDANTIC_AI_GATEWAY_API_KEY` environment variable or pass it via `gateway_provider(..., api_key=...)`'
83
108
  ' to use the Pydantic AI Gateway provider.'
84
109
  )
85
110
 
86
- base_url = base_url or os.getenv('PYDANTIC_AI_GATEWAY_BASE_URL', 'https://gateway.pydantic.dev/proxy')
87
- http_client = http_client or cached_async_http_client(provider=f'gateway-{upstream_provider}')
88
- http_client.event_hooks = {'request': [_request_hook]}
89
-
90
- if upstream_provider in ('openai', 'openai-chat'):
91
- from .openai import OpenAIProvider
111
+ base_url = base_url or os.getenv('PYDANTIC_AI_GATEWAY_BASE_URL', GATEWAY_BASE_URL)
112
+ http_client = http_client or cached_async_http_client(provider=f'gateway/{upstream_provider}')
113
+ http_client.event_hooks = {'request': [_request_hook(api_key)]}
92
114
 
93
- return OpenAIProvider(api_key=api_key, base_url=_merge_url_path(base_url, 'openai'), http_client=http_client)
94
- elif upstream_provider == 'openai-responses':
115
+ if upstream_provider in ('openai', 'openai-chat', 'openai-responses'):
95
116
  from .openai import OpenAIProvider
96
117
 
97
118
  return OpenAIProvider(api_key=api_key, base_url=_merge_url_path(base_url, 'openai'), http_client=http_client)
@@ -111,79 +132,46 @@ def gateway_provider(
111
132
  http_client=http_client,
112
133
  )
113
134
  )
114
- elif upstream_provider == 'google-vertex':
115
- from google.genai import Client as GoogleClient
135
+ elif upstream_provider == 'bedrock':
136
+ from .bedrock import BedrockProvider
116
137
 
138
+ return BedrockProvider(
139
+ api_key=api_key,
140
+ base_url=_merge_url_path(base_url, 'bedrock'),
141
+ region_name='pydantic-ai-gateway', # Fake region name to avoid NoRegionError
142
+ )
143
+ elif upstream_provider == 'google-vertex':
117
144
  from .google import GoogleProvider
118
145
 
119
146
  return GoogleProvider(
120
- client=GoogleClient(
121
- vertexai=True,
122
- api_key='unset',
123
- http_options={
124
- 'base_url': _merge_url_path(base_url, 'google-vertex'),
125
- 'headers': {'User-Agent': get_user_agent(), 'Authorization': api_key},
126
- # TODO(Marcelo): Until https://github.com/googleapis/python-genai/issues/1357 is solved.
127
- 'async_client_args': {
128
- 'transport': httpx.AsyncHTTPTransport(),
129
- 'event_hooks': {'request': [_request_hook]},
130
- },
131
- },
132
- )
147
+ vertexai=True,
148
+ api_key=api_key,
149
+ base_url=_merge_url_path(base_url, 'google-vertex'),
150
+ http_client=http_client,
133
151
  )
134
- else: # pragma: no cover
135
- raise UserError(f'Unknown provider: {upstream_provider}')
152
+ else:
153
+ raise UserError(f'Unknown upstream provider: {upstream_provider}')
136
154
 
137
155
 
138
- def infer_model(model_name: str) -> Model:
139
- """Infer the model class that will be used to make requests to the gateway.
140
-
141
- Args:
142
- model_name: The name of the model to infer. Must be in the format "provider/model_name".
156
+ def _request_hook(api_key: str) -> Callable[[httpx.Request], Awaitable[httpx.Request]]:
157
+ """Request hook for the gateway provider.
143
158
 
144
- Returns:
145
- The model class that will be used to make requests to the gateway.
159
+ It adds the `"traceparent"` and `"Authorization"` headers to the request.
146
160
  """
147
- try:
148
- upstream_provider, model_name = model_name.split('/', 1)
149
- except ValueError:
150
- raise UserError(f'The model name "{model_name}" is not in the format "provider/model_name".')
151
161
 
152
- if upstream_provider in ('openai', 'openai-chat'):
153
- from pydantic_ai.models.openai import OpenAIChatModel
162
+ async def _hook(request: httpx.Request) -> httpx.Request:
163
+ from opentelemetry.propagate import inject
154
164
 
155
- return OpenAIChatModel(model_name, provider=gateway_provider('openai'))
156
- elif upstream_provider == 'openai-responses':
157
- from pydantic_ai.models.openai import OpenAIResponsesModel
158
-
159
- return OpenAIResponsesModel(model_name, provider=gateway_provider('openai'))
160
- elif upstream_provider == 'groq':
161
- from pydantic_ai.models.groq import GroqModel
165
+ headers: dict[str, Any] = {}
166
+ inject(headers)
167
+ request.headers.update(headers)
162
168
 
163
- return GroqModel(model_name, provider=gateway_provider('groq'))
164
- elif upstream_provider == 'anthropic':
165
- from pydantic_ai.models.anthropic import AnthropicModel
166
-
167
- return AnthropicModel(model_name, provider=gateway_provider('anthropic'))
168
- elif upstream_provider == 'google-vertex':
169
- from pydantic_ai.models.google import GoogleModel
170
-
171
- return GoogleModel(model_name, provider=gateway_provider('google-vertex'))
172
- raise UserError(f'Unknown upstream provider: {upstream_provider}')
173
-
174
-
175
- async def _request_hook(request: httpx.Request) -> httpx.Request:
176
- """Request hook for the gateway provider.
177
-
178
- It adds the `"traceparent"` header to the request.
179
- """
180
- from opentelemetry.propagate import inject
169
+ if 'Authorization' not in request.headers:
170
+ request.headers['Authorization'] = f'Bearer {api_key}'
181
171
 
182
- headers: dict[str, Any] = {}
183
- inject(headers)
184
- request.headers.update(headers)
172
+ return request
185
173
 
186
- return request
174
+ return _hook
187
175
 
188
176
 
189
177
  def _merge_url_path(base_url: str, path: str) -> str: