pydantic-ai-slim 1.2.1__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -1,6 +1,7 @@
 from __future__ import annotations as _annotations
 
 import base64
+import json
 import warnings
 from collections.abc import AsyncIterable, AsyncIterator, Sequence
 from contextlib import asynccontextmanager
@@ -17,7 +18,7 @@ from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
 from .._run_context import RunContext
 from .._thinking_part import split_content_into_text_and_thinking
 from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime
-from ..builtin_tools import CodeExecutionTool, ImageGenerationTool, WebSearchTool
+from ..builtin_tools import CodeExecutionTool, ImageGenerationTool, MCPServerTool, WebSearchTool
 from ..exceptions import UserError
 from ..messages import (
     AudioUrl,
@@ -109,6 +110,11 @@ Using this more broad type for the model name instead of the ChatModel definition
 allows this model to be used more easily with other model types (ie, Ollama, Deepseek).
 """
 
+MCP_SERVER_TOOL_CONNECTOR_URI_SCHEME: Literal['x-openai-connector'] = 'x-openai-connector'
+"""
+URI scheme prefix for OpenAI connector IDs. OpenAI accepts either a server URL or a connector ID in MCP configuration;
+passing a URL of the form `x-openai-connector:<connector-id>` lets you supply a connector ID where a URL is expected.
+"""
 
 _CHAT_FINISH_REASON_MAP: dict[
     Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'], FinishReason
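
Note: this constant underpins connector support for the hosted MCP tool. A minimal usage sketch in Python; the `MCPServerTool` fields are those this diff reads, while the concrete IDs and token are placeholders:

    from pydantic_ai.builtin_tools import MCPServerTool

    # A hosted MCP server addressed by URL:
    docs_server = MCPServerTool(id='deepwiki', url='https://mcp.deepwiki.com/mcp')

    # An OpenAI connector addressed by ID via the x-openai-connector scheme;
    # the tool mapping later in this diff splits on ':' and sends `connector_id`:
    calendar = MCPServerTool(
        id='calendar',
        url='x-openai-connector:connector_googlecalendar',  # placeholder connector ID
        authorization_token='<oauth-access-token>',  # placeholder token
    )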
@@ -285,6 +291,8 @@ class OpenAIChatModel(Model):
             'vercel',
             'litellm',
             'nebius',
+            'ovhcloud',
+            'gateway',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
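
Note: a sketch of the new 'ovhcloud' provider string in use; the model name is a placeholder and assumes OVHcloud AI Endpoints credentials are configured for the provider:

    from pydantic_ai.models.openai import OpenAIChatModel

    model = OpenAIChatModel('Meta-Llama-3_3-70B-Instruct', provider='ovhcloud')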
@@ -314,6 +322,8 @@ class OpenAIChatModel(Model):
             'vercel',
             'litellm',
             'nebius',
+            'ovhcloud',
+            'gateway',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -342,6 +352,8 @@ class OpenAIChatModel(Model):
             'vercel',
             'litellm',
             'nebius',
+            'ovhcloud',
+            'gateway',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -363,7 +375,7 @@ class OpenAIChatModel(Model):
         self._model_name = model_name
 
         if isinstance(provider, str):
-            provider = infer_provider(provider)
+            provider = infer_provider('gateway/openai' if provider == 'gateway' else provider)
         self._provider = provider
         self.client = provider.client
 
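Note: with the new 'gateway' string, provider inference is redirected to `gateway_provider('openai')` (see the `infer_provider` change later in this diff). A minimal sketch, assuming gateway credentials are configured in the environment:

    from pydantic_ai.models.openai import OpenAIChatModel

    # 'gateway' is rewritten to 'gateway/openai' before inference
    model = OpenAIChatModel('gpt-4o', provider='gateway')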
@@ -559,24 +571,7 @@ class OpenAIChatModel(Model):
         # - https://openrouter.ai/docs/use-cases/reasoning-tokens#preserving-reasoning-blocks
         # If you need this, please file an issue.
 
-        vendor_details: dict[str, Any] = {}
-
-        # Add logprobs to vendor_details if available
-        if choice.logprobs is not None and choice.logprobs.content:
-            # Convert logprobs to a serializable format
-            vendor_details['logprobs'] = [
-                {
-                    'token': lp.token,
-                    'bytes': lp.bytes,
-                    'logprob': lp.logprob,
-                    'top_logprobs': [
-                        {'token': tlp.token, 'bytes': tlp.bytes, 'logprob': tlp.logprob} for tlp in lp.top_logprobs
-                    ],
-                }
-                for lp in choice.logprobs.content
-            ]
-
-        if choice.message.content is not None:
+        if choice.message.content:
             items.extend(
                 (replace(part, id='content', provider_name=self.system) if isinstance(part, ThinkingPart) else part)
                 for part in split_content_into_text_and_thinking(choice.message.content, self.profile.thinking_tags)
@@ -594,6 +589,23 @@ class OpenAIChatModel(Model):
                 part.tool_call_id = _guard_tool_call_id(part)
                 items.append(part)
 
+        vendor_details: dict[str, Any] = {}
+
+        # Add logprobs to vendor_details if available
+        if choice.logprobs is not None and choice.logprobs.content:
+            # Convert logprobs to a serializable format
+            vendor_details['logprobs'] = [
+                {
+                    'token': lp.token,
+                    'bytes': lp.bytes,
+                    'logprob': lp.logprob,
+                    'top_logprobs': [
+                        {'token': tlp.token, 'bytes': tlp.bytes, 'logprob': tlp.logprob} for tlp in lp.top_logprobs
+                    ],
+                }
+                for lp in choice.logprobs.content
+            ]
+
         raw_finish_reason = choice.finish_reason
         vendor_details['finish_reason'] = raw_finish_reason
         finish_reason = _CHAT_FINISH_REASON_MAP.get(raw_finish_reason)
@@ -903,7 +915,18 @@ class OpenAIResponsesModel(Model):
         self,
         model_name: OpenAIModelName,
         *,
-        provider: Literal['openai', 'deepseek', 'azure', 'openrouter', 'grok', 'fireworks', 'together', 'nebius']
+        provider: Literal[
+            'openai',
+            'deepseek',
+            'azure',
+            'openrouter',
+            'grok',
+            'fireworks',
+            'together',
+            'nebius',
+            'ovhcloud',
+            'gateway',
+        ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
         settings: ModelSettings | None = None,
@@ -919,7 +942,7 @@ class OpenAIResponsesModel(Model):
         self._model_name = model_name
 
         if isinstance(provider, str):
-            provider = infer_provider(provider)
+            provider = infer_provider('gateway/openai' if provider == 'gateway' else provider)
         self._provider = provider
         self.client = provider.client
 
@@ -1044,13 +1067,16 @@ class OpenAIResponsesModel(Model):
            elif isinstance(item, responses.ResponseFileSearchToolCall):  # pragma: no cover
                # Pydantic AI doesn't yet support the FileSearch built-in tool
                pass
-            elif isinstance(  # pragma: no cover
-                item,
-                responses.response_output_item.McpCall
-                | responses.response_output_item.McpListTools
-                | responses.response_output_item.McpApprovalRequest,
-            ):
-                # Pydantic AI supports MCP natively
+            elif isinstance(item, responses.response_output_item.McpCall):
+                call_part, return_part = _map_mcp_call(item, self.system)
+                items.append(call_part)
+                items.append(return_part)
+            elif isinstance(item, responses.response_output_item.McpListTools):
+                call_part, return_part = _map_mcp_list_tools(item, self.system)
+                items.append(call_part)
+                items.append(return_part)
+            elif isinstance(item, responses.response_output_item.McpApprovalRequest):  # pragma: no cover
+                # Pydantic AI doesn't yet support McpApprovalRequest (explicit tool usage approval)
                pass
 
        finish_reason: FinishReason | None = None
@@ -1239,6 +1265,32 @@ class OpenAIResponsesModel(Model):
            elif isinstance(tool, CodeExecutionTool):
                has_image_generating_tool = True
                tools.append({'type': 'code_interpreter', 'container': {'type': 'auto'}})
+            elif isinstance(tool, MCPServerTool):
+                mcp_tool = responses.tool_param.Mcp(
+                    type='mcp',
+                    server_label=tool.id,
+                    require_approval='never',
+                )
+
+                if tool.authorization_token:  # pragma: no branch
+                    mcp_tool['authorization'] = tool.authorization_token
+
+                if tool.allowed_tools is not None:  # pragma: no branch
+                    mcp_tool['allowed_tools'] = tool.allowed_tools
+
+                if tool.description:  # pragma: no branch
+                    mcp_tool['server_description'] = tool.description
+
+                if tool.headers:  # pragma: no branch
+                    mcp_tool['headers'] = tool.headers
+
+                if tool.url.startswith(MCP_SERVER_TOOL_CONNECTOR_URI_SCHEME + ':'):
+                    _, connector_id = tool.url.split(':', maxsplit=1)
+                    mcp_tool['connector_id'] = connector_id  # pyright: ignore[reportGeneralTypeIssues]
+                else:
+                    mcp_tool['server_url'] = tool.url
+
+                tools.append(mcp_tool)
            elif isinstance(tool, ImageGenerationTool):  # pragma: no branch
                has_image_generating_tool = True
                tools.append(
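
Note: for illustration, the branch above serializes an `MCPServerTool` into the Responses API `mcp` tool param; a sketch of the resulting payload, with made-up field values:

    tool = MCPServerTool(
        id='deepwiki',
        url='https://mcp.deepwiki.com/mcp',
        allowed_tools=['ask_question'],
    )
    # is sent to the API as:
    # {
    #     'type': 'mcp',
    #     'server_label': 'deepwiki',
    #     'require_approval': 'never',
    #     'allowed_tools': ['ask_question'],
    #     'server_url': 'https://mcp.deepwiki.com/mcp',
    # }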
@@ -1411,7 +1463,7 @@ class OpenAIResponsesModel(Model):
                            type='web_search_call',
                        )
                        openai_messages.append(web_search_item)
-                    elif item.tool_name == ImageGenerationTool.kind and item.tool_call_id:  # pragma: no branch
+                    elif item.tool_name == ImageGenerationTool.kind and item.tool_call_id:
                        # The cast is necessary because of https://github.com/openai/openai-python/issues/2648
                        image_generation_item = cast(
                            responses.response_input_item_param.ImageGenerationCall,
@@ -1421,6 +1473,37 @@ class OpenAIResponsesModel(Model):
                            },
                        )
                        openai_messages.append(image_generation_item)
+                    elif (  # pragma: no branch
+                        item.tool_name.startswith(MCPServerTool.kind)
+                        and item.tool_call_id
+                        and (server_id := item.tool_name.split(':', 1)[1])
+                        and (args := item.args_as_dict())
+                        and (action := args.get('action'))
+                    ):
+                        if action == 'list_tools':
+                            mcp_list_tools_item = responses.response_input_item_param.McpListTools(
+                                id=item.tool_call_id,
+                                type='mcp_list_tools',
+                                server_label=server_id,
+                                tools=[],  # These can be read server-side
+                            )
+                            openai_messages.append(mcp_list_tools_item)
+                        elif (  # pragma: no branch
+                            action == 'call_tool'
+                            and (tool_name := args.get('tool_name'))
+                            and (tool_args := args.get('tool_args'))
+                        ):
+                            mcp_call_item = responses.response_input_item_param.McpCall(
+                                id=item.tool_call_id,
+                                server_label=server_id,
+                                name=tool_name,
+                                arguments=to_json(tool_args).decode(),
+                                error=None,  # These can be read server-side
+                                output=None,  # These can be read server-side
+                                type='mcp_call',
+                            )
+                            openai_messages.append(mcp_call_item)
+
                elif isinstance(item, BuiltinToolReturnPart):
                    if item.provider_name == self.system and send_item_ids:
                        if (
@@ -1439,9 +1522,12 @@ class OpenAIResponsesModel(Model):
                            and (status := content.get('status'))
                        ):
                            web_search_item['status'] = status
-                        elif item.tool_name == ImageGenerationTool.kind:  # pragma: no branch
+                        elif item.tool_name == ImageGenerationTool.kind:
                            # Image generation result does not need to be sent back, just the `id` off of `BuiltinToolCallPart`.
                            pass
+                        elif item.tool_name.startswith(MCPServerTool.kind):  # pragma: no branch
+                            # MCP call result does not need to be sent back, just the fields off of `BuiltinToolCallPart`.
+                            pass
                elif isinstance(item, FilePart):
                    # This was generated by the `ImageGenerationTool` or `CodeExecutionTool`,
                    # and does not need to be sent back separately from the corresponding `BuiltinToolReturnPart`.
@@ -1616,21 +1702,6 @@ class OpenAIStreamedResponse(StreamedResponse):
                self.provider_details = {'finish_reason': raw_finish_reason}
                self.finish_reason = _CHAT_FINISH_REASON_MAP.get(raw_finish_reason)
 
-            # Handle the text part of the response
-            content = choice.delta.content
-            if content is not None:
-                maybe_event = self._parts_manager.handle_text_delta(
-                    vendor_part_id='content',
-                    content=content,
-                    thinking_tags=self._model_profile.thinking_tags,
-                    ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
-                )
-                if maybe_event is not None:  # pragma: no branch
-                    if isinstance(maybe_event, PartStartEvent) and isinstance(maybe_event.part, ThinkingPart):
-                        maybe_event.part.id = 'content'
-                        maybe_event.part.provider_name = self.provider_name
-                    yield maybe_event
-
            # The `reasoning_content` field is only present in DeepSeek models.
            # https://api-docs.deepseek.com/guides/reasoning_model
            if reasoning_content := getattr(choice.delta, 'reasoning_content', None):
@@ -1652,6 +1723,21 @@ class OpenAIStreamedResponse(StreamedResponse):
                    provider_name=self.provider_name,
                )
 
+            # Handle the text part of the response
+            content = choice.delta.content
+            if content:
+                maybe_event = self._parts_manager.handle_text_delta(
+                    vendor_part_id='content',
+                    content=content,
+                    thinking_tags=self._model_profile.thinking_tags,
+                    ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
+                )
+                if maybe_event is not None:  # pragma: no branch
+                    if isinstance(maybe_event, PartStartEvent) and isinstance(maybe_event.part, ThinkingPart):
+                        maybe_event.part.id = 'content'
+                        maybe_event.part.provider_name = self.provider_name
+                    yield maybe_event
+
            for dtc in choice.delta.tool_calls or []:
                maybe_event = self._parts_manager.handle_tool_call_delta(
                    vendor_part_id=dtc.index,
@@ -1755,7 +1841,7 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                    args_json = call_part.args_as_json_str()
                    # Drop the final `"}` so that we can add code deltas
                    args_json_delta = args_json[:-2]
-                    assert args_json_delta.endswith('code":"')
+                    assert args_json_delta.endswith('"code":"'), f'Expected {args_json_delta!r} to end in `"code":"`'
 
                    yield self._parts_manager.handle_part(
                        vendor_part_id=f'{chunk.item.id}-call', part=replace(call_part, args=None)
@@ -1769,7 +1855,28 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                elif isinstance(chunk.item, responses.response_output_item.ImageGenerationCall):
                    call_part, _, _ = _map_image_generation_tool_call(chunk.item, self.provider_name)
                    yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-call', part=call_part)
+                elif isinstance(chunk.item, responses.response_output_item.McpCall):
+                    call_part, _ = _map_mcp_call(chunk.item, self.provider_name)
 
+                    args_json = call_part.args_as_json_str()
+                    # Drop the final `{}}` so that we can add tool args deltas
+                    args_json_delta = args_json[:-3]
+                    assert args_json_delta.endswith('"tool_args":'), (
+                        f'Expected {args_json_delta!r} to end in `"tool_args":`'
+                    )
+
+                    yield self._parts_manager.handle_part(
+                        vendor_part_id=f'{chunk.item.id}-call', part=replace(call_part, args=None)
+                    )
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=f'{chunk.item.id}-call',
+                        args=args_json_delta,
+                    )
+                    if maybe_event is not None:  # pragma: no branch
+                        yield maybe_event
+                elif isinstance(chunk.item, responses.response_output_item.McpListTools):
+                    call_part, _ = _map_mcp_list_tools(chunk.item, self.provider_name)
+                    yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-call', part=call_part)
                else:
                    warnings.warn(  # pragma: no cover
                        f'Handling of this item type is not yet implemented. Please report on our GitHub: {chunk}',
@@ -1810,6 +1917,13 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                    yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-file', part=file_part)
                    yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-return', part=return_part)
 
+                elif isinstance(chunk.item, responses.response_output_item.McpCall):
+                    _, return_part = _map_mcp_call(chunk.item, self.provider_name)
+                    yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-return', part=return_part)
+                elif isinstance(chunk.item, responses.response_output_item.McpListTools):
+                    _, return_part = _map_mcp_list_tools(chunk.item, self.provider_name)
+                    yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-return', part=return_part)
+
            elif isinstance(chunk, responses.ResponseReasoningSummaryPartAddedEvent):
                yield self._parts_manager.handle_thinking_delta(
                    vendor_part_id=f'{chunk.item_id}-{chunk.summary_index}',
@@ -1904,6 +2018,40 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                )
                yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item_id}-file', part=file_part)
 
+            elif isinstance(chunk, responses.ResponseMcpCallArgumentsDoneEvent):
+                maybe_event = self._parts_manager.handle_tool_call_delta(
+                    vendor_part_id=f'{chunk.item_id}-call',
+                    args='}',
+                )
+                if maybe_event is not None:  # pragma: no branch
+                    yield maybe_event
+
+            elif isinstance(chunk, responses.ResponseMcpCallArgumentsDeltaEvent):
+                maybe_event = self._parts_manager.handle_tool_call_delta(
+                    vendor_part_id=f'{chunk.item_id}-call',
+                    args=chunk.delta,
+                )
+                if maybe_event is not None:  # pragma: no branch
+                    yield maybe_event
+
+            elif isinstance(chunk, responses.ResponseMcpListToolsInProgressEvent):
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseMcpListToolsCompletedEvent):
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseMcpListToolsFailedEvent):  # pragma: no cover
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseMcpCallInProgressEvent):
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseMcpCallFailedEvent):  # pragma: no cover
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseMcpCallCompletedEvent):
+                pass  # there's nothing we need to do here
+
            else:  # pragma: no cover
                warnings.warn(
                    f'Handling of this event type is not yet implemented. Please report on our GitHub: {chunk}',
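
Note: taken together, the MCP streaming events assemble the call args incrementally: the `McpCall` item start replays the serialized args minus the trailing `{}}`, each `ResponseMcpCallArgumentsDeltaEvent` appends raw JSON, and `ResponseMcpCallArgumentsDoneEvent` appends the closing `}`. A sketch with made-up values:

    # on McpCall item added:    '{"action":"call_tool","tool_name":"ask_question","tool_args":'
    # on ArgumentsDelta events: '{"question":"What is Pydantic AI?"}'  (appended piecewise)
    # on ArgumentsDone event:   '}'
    # reassembled args:
    # '{"action":"call_tool","tool_name":"ask_question","tool_args":{"question":"What is Pydantic AI?"}}'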
@@ -1973,7 +2121,6 @@ def _map_usage(
 def _split_combined_tool_call_id(combined_id: str) -> tuple[str, str | None]:
     # When reasoning, the Responses API requires the `ResponseFunctionToolCall` to be returned with both the `call_id` and `id` fields.
     # Before our `ToolCallPart` gained the `id` field alongside `tool_call_id` field, we combined the two fields into a single string stored on `tool_call_id`.
-
     if '|' in combined_id:
         call_id, id = combined_id.split('|', 1)
         return call_id, id
@@ -2013,7 +2160,7 @@ def _map_code_interpreter_tool_call(
            tool_call_id=item.id,
            args={
                'container_id': item.container_id,
-                'code': item.code,
+                'code': item.code or '',
            },
            provider_name=provider_name,
        ),
@@ -2105,3 +2252,50 @@ def _map_image_generation_tool_call(
         ),
         file_part,
     )
+
+
+def _map_mcp_list_tools(
+    item: responses.response_output_item.McpListTools, provider_name: str
+) -> tuple[BuiltinToolCallPart, BuiltinToolReturnPart]:
+    tool_name = ':'.join([MCPServerTool.kind, item.server_label])
+    return (
+        BuiltinToolCallPart(
+            tool_name=tool_name,
+            tool_call_id=item.id,
+            provider_name=provider_name,
+            args={'action': 'list_tools'},
+        ),
+        BuiltinToolReturnPart(
+            tool_name=tool_name,
+            tool_call_id=item.id,
+            content=item.model_dump(mode='json', include={'tools', 'error'}),
+            provider_name=provider_name,
+        ),
+    )
+
+
+def _map_mcp_call(
+    item: responses.response_output_item.McpCall, provider_name: str
+) -> tuple[BuiltinToolCallPart, BuiltinToolReturnPart]:
+    tool_name = ':'.join([MCPServerTool.kind, item.server_label])
+    return (
+        BuiltinToolCallPart(
+            tool_name=tool_name,
+            tool_call_id=item.id,
+            args={
+                'action': 'call_tool',
+                'tool_name': item.name,
+                'tool_args': json.loads(item.arguments) if item.arguments else {},
+            },
+            provider_name=provider_name,
+        ),
+        BuiltinToolReturnPart(
+            tool_name=tool_name,
+            tool_call_id=item.id,
+            content={
+                'output': item.output,
+                'error': item.error,
+            },
+            provider_name=provider_name,
+        ),
+    )
@@ -8,7 +8,7 @@ from __future__ import annotations as _annotations
 from abc import ABC, abstractmethod
 from typing import Any, Generic, TypeVar
 
-from pydantic_ai import ModelProfile
+from ..profiles import ModelProfile
 
 InterfaceClient = TypeVar('InterfaceClient')
 
@@ -53,7 +53,7 @@ class Provider(ABC, Generic[InterfaceClient]):
 
 def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
     """Infers the provider class from the provider name."""
-    if provider == 'openai':
+    if provider in ('openai', 'openai-chat', 'openai-responses'):
        from .openai import OpenAIProvider
 
        return OpenAIProvider
@@ -73,15 +73,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
        from .azure import AzureProvider
 
        return AzureProvider
-    elif provider == 'google-vertex':
-        from .google_vertex import GoogleVertexProvider  # type: ignore[reportDeprecated]
+    elif provider in ('google-vertex', 'google-gla'):
+        from .google import GoogleProvider
 
-        return GoogleVertexProvider  # type: ignore[reportDeprecated]
-    elif provider == 'google-gla':
-        from .google_gla import GoogleGLAProvider  # type: ignore[reportDeprecated]
-
-        return GoogleGLAProvider  # type: ignore[reportDeprecated]
-    # NOTE: We don't test because there are many ways the `boto3.client` can retrieve the credentials.
+        return GoogleProvider
    elif provider == 'bedrock':
        from .bedrock import BedrockProvider
 
@@ -146,11 +141,25 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
        from .nebius import NebiusProvider
 
        return NebiusProvider
+    elif provider == 'ovhcloud':
+        from .ovhcloud import OVHcloudProvider
+
+        return OVHcloudProvider
    else:  # pragma: no cover
        raise ValueError(f'Unknown provider: {provider}')
 
 
 def infer_provider(provider: str) -> Provider[Any]:
     """Infer the provider from the provider name."""
-    provider_class = infer_provider_class(provider)
-    return provider_class()
+    if provider.startswith('gateway/'):
+        from .gateway import gateway_provider
+
+        provider = provider.removeprefix('gateway/')
+        return gateway_provider(provider)
+    elif provider in ('google-vertex', 'google-gla'):
+        from .google import GoogleProvider
+
+        return GoogleProvider(vertexai=provider == 'google-vertex')
+    else:
+        provider_class = infer_provider_class(provider)
+        return provider_class()
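
Note: a sketch of the new inference paths, grounded in the code above:

    from pydantic_ai.providers import infer_provider

    gateway = infer_provider('gateway/openai')  # -> gateway_provider('openai')
    vertex = infer_provider('google-vertex')    # -> GoogleProvider(vertexai=True)
    ovh = infer_provider('ovhcloud')            # -> OVHcloudProvider()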
@@ -4,7 +4,7 @@ import os
 import re
 from collections.abc import Callable
 from dataclasses import dataclass
-from typing import Literal, overload
+from typing import Any, Literal, overload
 
 from pydantic_ai import ModelProfile
 from pydantic_ai.exceptions import UserError
@@ -21,6 +21,8 @@ try:
     from botocore.client import BaseClient
     from botocore.config import Config
     from botocore.exceptions import NoRegionError
+    from botocore.session import Session
+    from botocore.tokens import FrozenAuthToken
 except ImportError as _import_error:
     raise ImportError(
         'Please install the `boto3` package to use the Bedrock provider, '
@@ -117,10 +119,23 @@ class BedrockProvider(Provider[BaseClient]):
    def __init__(
        self,
        *,
+        api_key: str,
+        base_url: str | None = None,
        region_name: str | None = None,
+        profile_name: str | None = None,
+        aws_read_timeout: float | None = None,
+        aws_connect_timeout: float | None = None,
+    ) -> None: ...
+
+    @overload
+    def __init__(
+        self,
+        *,
        aws_access_key_id: str | None = None,
        aws_secret_access_key: str | None = None,
        aws_session_token: str | None = None,
+        base_url: str | None = None,
+        region_name: str | None = None,
        profile_name: str | None = None,
        aws_read_timeout: float | None = None,
        aws_connect_timeout: float | None = None,
@@ -130,11 +145,13 @@ class BedrockProvider(Provider[BaseClient]):
        self,
        *,
        bedrock_client: BaseClient | None = None,
-        region_name: str | None = None,
        aws_access_key_id: str | None = None,
        aws_secret_access_key: str | None = None,
        aws_session_token: str | None = None,
+        base_url: str | None = None,
+        region_name: str | None = None,
        profile_name: str | None = None,
+        api_key: str | None = None,
        aws_read_timeout: float | None = None,
        aws_connect_timeout: float | None = None,
    ) -> None:
@@ -142,10 +159,12 @@ class BedrockProvider(Provider[BaseClient]):
 
        Args:
            bedrock_client: A boto3 client for Bedrock Runtime. If provided, other arguments are ignored.
-            region_name: The AWS region name.
-            aws_access_key_id: The AWS access key ID.
-            aws_secret_access_key: The AWS secret access key.
-            aws_session_token: The AWS session token.
+            aws_access_key_id: The AWS access key ID. If not set, the `AWS_ACCESS_KEY_ID` environment variable will be used if available.
+            aws_secret_access_key: The AWS secret access key. If not set, the `AWS_SECRET_ACCESS_KEY` environment variable will be used if available.
+            aws_session_token: The AWS session token. If not set, the `AWS_SESSION_TOKEN` environment variable will be used if available.
+            api_key: The API key for the Bedrock client. Can be used instead of `aws_access_key_id`, `aws_secret_access_key`, and `aws_session_token`. If not set, the `AWS_BEARER_TOKEN_BEDROCK` environment variable will be used if available.
+            base_url: The base URL for the Bedrock client.
+            region_name: The AWS region name. If not set, the `AWS_DEFAULT_REGION` environment variable will be used if available.
            profile_name: The AWS profile name.
            aws_read_timeout: The read timeout for Bedrock client.
            aws_connect_timeout: The connect timeout for Bedrock client.
@@ -153,19 +172,44 @@ class BedrockProvider(Provider[BaseClient]):
        if bedrock_client is not None:
            self._client = bedrock_client
        else:
+            read_timeout = aws_read_timeout or float(os.getenv('AWS_READ_TIMEOUT', 300))
+            connect_timeout = aws_connect_timeout or float(os.getenv('AWS_CONNECT_TIMEOUT', 60))
+            config: dict[str, Any] = {
+                'read_timeout': read_timeout,
+                'connect_timeout': connect_timeout,
+            }
            try:
-                read_timeout = aws_read_timeout or float(os.getenv('AWS_READ_TIMEOUT', 300))
-                connect_timeout = aws_connect_timeout or float(os.getenv('AWS_CONNECT_TIMEOUT', 60))
-                session = boto3.Session(
-                    aws_access_key_id=aws_access_key_id,
-                    aws_secret_access_key=aws_secret_access_key,
-                    aws_session_token=aws_session_token,
-                    region_name=region_name,
-                    profile_name=profile_name,
-                )
+                if api_key is not None:
+                    session = boto3.Session(
+                        botocore_session=_BearerTokenSession(api_key),
+                        region_name=region_name,
+                        profile_name=profile_name,
+                    )
+                    config['signature_version'] = 'bearer'
+                else:
+                    session = boto3.Session(
+                        aws_access_key_id=aws_access_key_id,
+                        aws_secret_access_key=aws_secret_access_key,
+                        aws_session_token=aws_session_token,
+                        region_name=region_name,
+                        profile_name=profile_name,
+                    )
                self._client = session.client(  # type: ignore[reportUnknownMemberType]
                    'bedrock-runtime',
-                    config=Config(read_timeout=read_timeout, connect_timeout=connect_timeout),
+                    config=Config(**config),
+                    endpoint_url=base_url,
                )
            except NoRegionError as exc:  # pragma: no cover
                raise UserError('You must provide a `region_name` or a boto3 client for Bedrock Runtime.') from exc
+
+
+class _BearerTokenSession(Session):
+    def __init__(self, token: str):
+        super().__init__()
+        self.token = token
+
+    def get_auth_token(self, **_kwargs: Any) -> FrozenAuthToken:
+        return FrozenAuthToken(self.token)
+
+    def get_credentials(self) -> None:  # type: ignore[reportIncompatibleMethodOverride]
+        return None
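
Note: a minimal sketch of the new bearer-token path; the key and region are placeholders, and per the docstring above the `AWS_BEARER_TOKEN_BEDROCK` environment variable can be used instead:

    from pydantic_ai.providers.bedrock import BedrockProvider

    # api_key routes auth through _BearerTokenSession and sets signature_version='bearer'
    provider = BedrockProvider(
        api_key='my-bedrock-api-key',
        region_name='us-east-1',
    )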