pydantic-ai-slim 1.2.1__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff compares the content of publicly available package versions as released to their public registries. It is provided for informational purposes only.
Potentially problematic release: this version of pydantic-ai-slim has been flagged as possibly problematic.
- pydantic_ai/__init__.py +4 -0
- pydantic_ai/_agent_graph.py +41 -8
- pydantic_ai/agent/__init__.py +11 -19
- pydantic_ai/builtin_tools.py +106 -4
- pydantic_ai/exceptions.py +5 -0
- pydantic_ai/mcp.py +1 -22
- pydantic_ai/models/__init__.py +45 -37
- pydantic_ai/models/anthropic.py +132 -11
- pydantic_ai/models/bedrock.py +4 -4
- pydantic_ai/models/cohere.py +0 -7
- pydantic_ai/models/gemini.py +9 -2
- pydantic_ai/models/google.py +31 -21
- pydantic_ai/models/groq.py +4 -4
- pydantic_ai/models/huggingface.py +2 -2
- pydantic_ai/models/openai.py +243 -49
- pydantic_ai/providers/__init__.py +21 -12
- pydantic_ai/providers/bedrock.py +60 -16
- pydantic_ai/providers/gateway.py +60 -72
- pydantic_ai/providers/google.py +61 -23
- pydantic_ai/providers/ovhcloud.py +95 -0
- pydantic_ai/usage.py +13 -2
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.4.0.dist-info}/METADATA +5 -5
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.4.0.dist-info}/RECORD +26 -25
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.4.0.dist-info}/WHEEL +0 -0
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.4.0.dist-info}/entry_points.txt +0 -0
- {pydantic_ai_slim-1.2.1.dist-info → pydantic_ai_slim-1.4.0.dist-info}/licenses/LICENSE +0 -0
pydantic_ai/models/openai.py
CHANGED
@@ -1,6 +1,7 @@
 from __future__ import annotations as _annotations
 
 import base64
+import json
 import warnings
 from collections.abc import AsyncIterable, AsyncIterator, Sequence
 from contextlib import asynccontextmanager
@@ -17,7 +18,7 @@ from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
 from .._run_context import RunContext
 from .._thinking_part import split_content_into_text_and_thinking
 from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime
-from ..builtin_tools import CodeExecutionTool, ImageGenerationTool, WebSearchTool
+from ..builtin_tools import CodeExecutionTool, ImageGenerationTool, MCPServerTool, WebSearchTool
 from ..exceptions import UserError
 from ..messages import (
     AudioUrl,
@@ -109,6 +110,11 @@ Using this more broad type for the model name instead of the ChatModel definition
 allows this model to be used more easily with other model types (ie, Ollama, Deepseek).
 """
 
+MCP_SERVER_TOOL_CONNECTOR_URI_SCHEME: Literal['x-openai-connector'] = 'x-openai-connector'
+"""
+Prefix for OpenAI connector IDs. OpenAI supports either a URL or a connector ID when passing MCP configuration to a model,
+by using that prefix like `x-openai-connector:<connector-id>` in a URL, you can pass a connector ID to a model.
+"""
 
 _CHAT_FINISH_REASON_MAP: dict[
     Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'], FinishReason
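As a usage sketch of this scheme (the connector ID and OAuth token below are hypothetical; the field names follow the `MCPServerTool` attributes read later in this diff):

    from pydantic_ai.builtin_tools import MCPServerTool

    # Hypothetical connector ID and token, for illustration only.
    calendar = MCPServerTool(
        id='google-calendar',
        url='x-openai-connector:connector_googlecalendar',
        authorization_token='<oauth-access-token>',
    )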
@@ -285,6 +291,8 @@ class OpenAIChatModel(Model):
             'vercel',
             'litellm',
             'nebius',
+            'ovhcloud',
+            'gateway',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -314,6 +322,8 @@ class OpenAIChatModel(Model):
             'vercel',
             'litellm',
             'nebius',
+            'ovhcloud',
+            'gateway',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -342,6 +352,8 @@ class OpenAIChatModel(Model):
             'vercel',
             'litellm',
             'nebius',
+            'ovhcloud',
+            'gateway',
         ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
@@ -363,7 +375,7 @@ class OpenAIChatModel(Model):
         self._model_name = model_name
 
         if isinstance(provider, str):
-            provider = infer_provider(provider)
+            provider = infer_provider('gateway/openai' if provider == 'gateway' else provider)
         self._provider = provider
         self.client = provider.client
 
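A minimal sketch of the new `'gateway'` shorthand (model name illustrative, and gateway credentials are assumed to be configured): the string is rewritten to `'gateway/openai'` before provider inference, so the model is routed through the gateway's OpenAI-compatible endpoint.

    from pydantic_ai.models.openai import OpenAIChatModel

    # Equivalent to inferring the 'gateway/openai' provider.
    model = OpenAIChatModel('gpt-4o', provider='gateway')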
@@ -559,24 +571,7 @@ class OpenAIChatModel(Model):
         # - https://openrouter.ai/docs/use-cases/reasoning-tokens#preserving-reasoning-blocks
         # If you need this, please file an issue.
 
-
-
-        # Add logprobs to vendor_details if available
-        if choice.logprobs is not None and choice.logprobs.content:
-            # Convert logprobs to a serializable format
-            vendor_details['logprobs'] = [
-                {
-                    'token': lp.token,
-                    'bytes': lp.bytes,
-                    'logprob': lp.logprob,
-                    'top_logprobs': [
-                        {'token': tlp.token, 'bytes': tlp.bytes, 'logprob': tlp.logprob} for tlp in lp.top_logprobs
-                    ],
-                }
-                for lp in choice.logprobs.content
-            ]
-
-        if choice.message.content is not None:
+        if choice.message.content:
             items.extend(
                 (replace(part, id='content', provider_name=self.system) if isinstance(part, ThinkingPart) else part)
                 for part in split_content_into_text_and_thinking(choice.message.content, self.profile.thinking_tags)
@@ -594,6 +589,23 @@ class OpenAIChatModel(Model):
                 part.tool_call_id = _guard_tool_call_id(part)
                 items.append(part)
 
+        vendor_details: dict[str, Any] = {}
+
+        # Add logprobs to vendor_details if available
+        if choice.logprobs is not None and choice.logprobs.content:
+            # Convert logprobs to a serializable format
+            vendor_details['logprobs'] = [
+                {
+                    'token': lp.token,
+                    'bytes': lp.bytes,
+                    'logprob': lp.logprob,
+                    'top_logprobs': [
+                        {'token': tlp.token, 'bytes': tlp.bytes, 'logprob': tlp.logprob} for tlp in lp.top_logprobs
+                    ],
+                }
+                for lp in choice.logprobs.content
+            ]
+
         raw_finish_reason = choice.finish_reason
         vendor_details['finish_reason'] = raw_finish_reason
         finish_reason = _CHAT_FINISH_REASON_MAP.get(raw_finish_reason)
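Together, these two hunks move the `vendor_details` construction to after the tool-call loop, so logprobs are still recorded when the response contains tool calls, and `finish_reason` is always attached. A sketch of the resulting structure (token values illustrative):

    # Shape of vendor_details after the block above:
    vendor_details = {
        'logprobs': [
            {
                'token': 'Hello',
                'bytes': [72, 101, 108, 108, 111],
                'logprob': -0.0023,
                'top_logprobs': [{'token': 'Hello', 'bytes': [72, 101, 108, 108, 111], 'logprob': -0.0023}],
            },
        ],
        'finish_reason': 'stop',
    }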
@@ -903,7 +915,18 @@ class OpenAIResponsesModel(Model):
         self,
         model_name: OpenAIModelName,
         *,
-        provider: Literal[
+        provider: Literal[
+            'openai',
+            'deepseek',
+            'azure',
+            'openrouter',
+            'grok',
+            'fireworks',
+            'together',
+            'nebius',
+            'ovhcloud',
+            'gateway',
+        ]
         | Provider[AsyncOpenAI] = 'openai',
         profile: ModelProfileSpec | None = None,
         settings: ModelSettings | None = None,
@@ -919,7 +942,7 @@ class OpenAIResponsesModel(Model):
         self._model_name = model_name
 
         if isinstance(provider, str):
-            provider = infer_provider(provider)
+            provider = infer_provider('gateway/openai' if provider == 'gateway' else provider)
         self._provider = provider
         self.client = provider.client
 
@@ -1044,13 +1067,16 @@ class OpenAIResponsesModel(Model):
             elif isinstance(item, responses.ResponseFileSearchToolCall):  # pragma: no cover
                 # Pydantic AI doesn't yet support the FileSearch built-in tool
                 pass
-            elif isinstance(
-                item,
-
-
-
-
-
+            elif isinstance(item, responses.response_output_item.McpCall):
+                call_part, return_part = _map_mcp_call(item, self.system)
+                items.append(call_part)
+                items.append(return_part)
+            elif isinstance(item, responses.response_output_item.McpListTools):
+                call_part, return_part = _map_mcp_list_tools(item, self.system)
+                items.append(call_part)
+                items.append(return_part)
+            elif isinstance(item, responses.response_output_item.McpApprovalRequest):  # pragma: no cover
+                # Pydantic AI doesn't yet support McpApprovalRequest (explicit tool usage approval)
                 pass
 
         finish_reason: FinishReason | None = None
@@ -1239,6 +1265,32 @@ class OpenAIResponsesModel(Model):
             elif isinstance(tool, CodeExecutionTool):
                 has_image_generating_tool = True
                 tools.append({'type': 'code_interpreter', 'container': {'type': 'auto'}})
+            elif isinstance(tool, MCPServerTool):
+                mcp_tool = responses.tool_param.Mcp(
+                    type='mcp',
+                    server_label=tool.id,
+                    require_approval='never',
+                )
+
+                if tool.authorization_token:  # pragma: no branch
+                    mcp_tool['authorization'] = tool.authorization_token
+
+                if tool.allowed_tools is not None:  # pragma: no branch
+                    mcp_tool['allowed_tools'] = tool.allowed_tools
+
+                if tool.description:  # pragma: no branch
+                    mcp_tool['server_description'] = tool.description
+
+                if tool.headers:  # pragma: no branch
+                    mcp_tool['headers'] = tool.headers
+
+                if tool.url.startswith(MCP_SERVER_TOOL_CONNECTOR_URI_SCHEME + ':'):
+                    _, connector_id = tool.url.split(':', maxsplit=1)
+                    mcp_tool['connector_id'] = connector_id  # pyright: ignore[reportGeneralTypeIssues]
+                else:
+                    mcp_tool['server_url'] = tool.url
+
+                tools.append(mcp_tool)
             elif isinstance(tool, ImageGenerationTool):  # pragma: no branch
                 has_image_generating_tool = True
                 tools.append(
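A hedged sketch of how this branch gets exercised from user code (model name, server URL, and tool names are placeholders; the `MCPServerTool` fields mirror the attributes read above):

    from pydantic_ai import Agent
    from pydantic_ai.builtin_tools import MCPServerTool

    agent = Agent(
        'openai-responses:gpt-4.1',  # placeholder model name
        builtin_tools=[
            MCPServerTool(
                id='deepwiki',
                url='https://mcp.deepwiki.com/mcp',  # placeholder server URL
                allowed_tools=['ask_question'],
            )
        ],
    )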
@@ -1411,7 +1463,7 @@ class OpenAIResponsesModel(Model):
                         type='web_search_call',
                     )
                     openai_messages.append(web_search_item)
-                elif item.tool_name == ImageGenerationTool.kind and item.tool_call_id:
+                elif item.tool_name == ImageGenerationTool.kind and item.tool_call_id:
                     # The cast is necessary because of https://github.com/openai/openai-python/issues/2648
                     image_generation_item = cast(
                         responses.response_input_item_param.ImageGenerationCall,
@@ -1421,6 +1473,37 @@ class OpenAIResponsesModel(Model):
                         },
                     )
                     openai_messages.append(image_generation_item)
+                elif (  # pragma: no branch
+                    item.tool_name.startswith(MCPServerTool.kind)
+                    and item.tool_call_id
+                    and (server_id := item.tool_name.split(':', 1)[1])
+                    and (args := item.args_as_dict())
+                    and (action := args.get('action'))
+                ):
+                    if action == 'list_tools':
+                        mcp_list_tools_item = responses.response_input_item_param.McpListTools(
+                            id=item.tool_call_id,
+                            type='mcp_list_tools',
+                            server_label=server_id,
+                            tools=[],  # These can be read server-side
+                        )
+                        openai_messages.append(mcp_list_tools_item)
+                    elif (  # pragma: no branch
+                        action == 'call_tool'
+                        and (tool_name := args.get('tool_name'))
+                        and (tool_args := args.get('tool_args'))
+                    ):
+                        mcp_call_item = responses.response_input_item_param.McpCall(
+                            id=item.tool_call_id,
+                            server_label=server_id,
+                            name=tool_name,
+                            arguments=to_json(tool_args).decode(),
+                            error=None,  # These can be read server-side
+                            output=None,  # These can be read server-side
+                            type='mcp_call',
+                        )
+                        openai_messages.append(mcp_call_item)
+
             elif isinstance(item, BuiltinToolReturnPart):
                 if item.provider_name == self.system and send_item_ids:
                     if (
@@ -1439,9 +1522,12 @@ class OpenAIResponsesModel(Model):
                         and (status := content.get('status'))
                     ):
                         web_search_item['status'] = status
-                elif item.tool_name == ImageGenerationTool.kind:
+                elif item.tool_name == ImageGenerationTool.kind:
                     # Image generation result does not need to be sent back, just the `id` off of `BuiltinToolCallPart`.
                     pass
+                elif item.tool_name.startswith(MCPServerTool.kind):  # pragma: no branch
+                    # MCP call result does not need to be sent back, just the fields off of `BuiltinToolCallPart`.
+                    pass
             elif isinstance(item, FilePart):
                 # This was generated by the `ImageGenerationTool` or `CodeExecutionTool`,
                 # and does not need to be sent back separately from the corresponding `BuiltinToolReturnPart`.
@@ -1616,21 +1702,6 @@ class OpenAIStreamedResponse(StreamedResponse):
             self.provider_details = {'finish_reason': raw_finish_reason}
             self.finish_reason = _CHAT_FINISH_REASON_MAP.get(raw_finish_reason)
 
-            # Handle the text part of the response
-            content = choice.delta.content
-            if content is not None:
-                maybe_event = self._parts_manager.handle_text_delta(
-                    vendor_part_id='content',
-                    content=content,
-                    thinking_tags=self._model_profile.thinking_tags,
-                    ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
-                )
-                if maybe_event is not None:  # pragma: no branch
-                    if isinstance(maybe_event, PartStartEvent) and isinstance(maybe_event.part, ThinkingPart):
-                        maybe_event.part.id = 'content'
-                        maybe_event.part.provider_name = self.provider_name
-                    yield maybe_event
-
             # The `reasoning_content` field is only present in DeepSeek models.
             # https://api-docs.deepseek.com/guides/reasoning_model
             if reasoning_content := getattr(choice.delta, 'reasoning_content', None):
@@ -1652,6 +1723,21 @@ class OpenAIStreamedResponse(StreamedResponse):
                     provider_name=self.provider_name,
                 )
 
+            # Handle the text part of the response
+            content = choice.delta.content
+            if content:
+                maybe_event = self._parts_manager.handle_text_delta(
+                    vendor_part_id='content',
+                    content=content,
+                    thinking_tags=self._model_profile.thinking_tags,
+                    ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
+                )
+                if maybe_event is not None:  # pragma: no branch
+                    if isinstance(maybe_event, PartStartEvent) and isinstance(maybe_event.part, ThinkingPart):
+                        maybe_event.part.id = 'content'
+                        maybe_event.part.provider_name = self.provider_name
+                    yield maybe_event
+
             for dtc in choice.delta.tool_calls or []:
                 maybe_event = self._parts_manager.handle_tool_call_delta(
                     vendor_part_id=dtc.index,
@@ -1755,7 +1841,7 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                     args_json = call_part.args_as_json_str()
                     # Drop the final `"}` so that we can add code deltas
                     args_json_delta = args_json[:-2]
-                    assert args_json_delta.endswith('code":"')
+                    assert args_json_delta.endswith('"code":"'), f'Expected {args_json_delta!r} to end in `"code":"`'
 
                     yield self._parts_manager.handle_part(
                         vendor_part_id=f'{chunk.item.id}-call', part=replace(call_part, args=None)
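The tightened assertion encodes a small invariant: the args serialize compactly, so dropping the trailing `"}` leaves a prefix ending right after the empty `code` string, ready to receive streamed code deltas. A self-contained check (assuming compact JSON separators, matching `to_json`-style serialization):

    import json

    args_json = json.dumps({'container_id': 'c_123', 'code': ''}, separators=(',', ':'))
    args_json_delta = args_json[:-2]  # drop the final '"}'
    assert args_json_delta.endswith('"code":"')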
@@ -1769,7 +1855,28 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                 elif isinstance(chunk.item, responses.response_output_item.ImageGenerationCall):
                     call_part, _, _ = _map_image_generation_tool_call(chunk.item, self.provider_name)
                     yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-call', part=call_part)
+                elif isinstance(chunk.item, responses.response_output_item.McpCall):
+                    call_part, _ = _map_mcp_call(chunk.item, self.provider_name)
 
+                    args_json = call_part.args_as_json_str()
+                    # Drop the final `{}}` so that we can add tool args deltas
+                    args_json_delta = args_json[:-3]
+                    assert args_json_delta.endswith('"tool_args":'), (
+                        f'Expected {args_json_delta!r} to end in `"tool_args":"`'
+                    )
+
+                    yield self._parts_manager.handle_part(
+                        vendor_part_id=f'{chunk.item.id}-call', part=replace(call_part, args=None)
+                    )
+                    maybe_event = self._parts_manager.handle_tool_call_delta(
+                        vendor_part_id=f'{chunk.item.id}-call',
+                        args=args_json_delta,
+                    )
+                    if maybe_event is not None:  # pragma: no branch
+                        yield maybe_event
+                elif isinstance(chunk.item, responses.response_output_item.McpListTools):
+                    call_part, _ = _map_mcp_list_tools(chunk.item, self.provider_name)
+                    yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-call', part=call_part)
                 else:
                     warnings.warn(  # pragma: no cover
                         f'Handling of this item type is not yet implemented. Please report on our GitHub: {chunk}',
@@ -1810,6 +1917,13 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                     yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-file', part=file_part)
                     yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-return', part=return_part)
 
+                elif isinstance(chunk.item, responses.response_output_item.McpCall):
+                    _, return_part = _map_mcp_call(chunk.item, self.provider_name)
+                    yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-return', part=return_part)
+                elif isinstance(chunk.item, responses.response_output_item.McpListTools):
+                    _, return_part = _map_mcp_list_tools(chunk.item, self.provider_name)
+                    yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-return', part=return_part)
+
             elif isinstance(chunk, responses.ResponseReasoningSummaryPartAddedEvent):
                 yield self._parts_manager.handle_thinking_delta(
                     vendor_part_id=f'{chunk.item_id}-{chunk.summary_index}',
@@ -1904,6 +2018,40 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
                 )
                 yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item_id}-file', part=file_part)
 
+            elif isinstance(chunk, responses.ResponseMcpCallArgumentsDoneEvent):
+                maybe_event = self._parts_manager.handle_tool_call_delta(
+                    vendor_part_id=f'{chunk.item_id}-call',
+                    args='}',
+                )
+                if maybe_event is not None:  # pragma: no branch
+                    yield maybe_event
+
+            elif isinstance(chunk, responses.ResponseMcpCallArgumentsDeltaEvent):
+                maybe_event = self._parts_manager.handle_tool_call_delta(
+                    vendor_part_id=f'{chunk.item_id}-call',
+                    args=chunk.delta,
+                )
+                if maybe_event is not None:  # pragma: no branch
+                    yield maybe_event
+
+            elif isinstance(chunk, responses.ResponseMcpListToolsInProgressEvent):
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseMcpListToolsCompletedEvent):
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseMcpListToolsFailedEvent):  # pragma: no cover
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseMcpCallInProgressEvent):
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseMcpCallFailedEvent):  # pragma: no cover
+                pass  # there's nothing we need to do here
+
+            elif isinstance(chunk, responses.ResponseMcpCallCompletedEvent):
+                pass  # there's nothing we need to do here
+
             else:  # pragma: no cover
                 warnings.warn(
                     f'Handling of this event type is not yet implemented. Please report on our GitHub: {chunk}',
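Putting the streaming pieces together: the `McpCall` branch in the `-1769` hunk above emits an args prefix ending at `"tool_args":`, the delta events append the tool-argument JSON, and the done event appends the closing `}`. A minimal worked example (tool name and arguments hypothetical):

    import json

    prefix = '{"action":"call_tool","tool_name":"search","tool_args":'  # emitted up front
    deltas = ['{"query":', '"pydantic"}']  # ResponseMcpCallArgumentsDeltaEvent chunks
    suffix = '}'  # appended on ResponseMcpCallArgumentsDoneEvent

    assert json.loads(prefix + ''.join(deltas) + suffix) == {
        'action': 'call_tool',
        'tool_name': 'search',
        'tool_args': {'query': 'pydantic'},
    }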
@@ -1973,7 +2121,6 @@ def _map_usage(
 def _split_combined_tool_call_id(combined_id: str) -> tuple[str, str | None]:
     # When reasoning, the Responses API requires the `ResponseFunctionToolCall` to be returned with both the `call_id` and `id` fields.
     # Before our `ToolCallPart` gained the `id` field alongside `tool_call_id` field, we combined the two fields into a single string stored on `tool_call_id`.
-
     if '|' in combined_id:
         call_id, id = combined_id.split('|', 1)
         return call_id, id
@@ -2013,7 +2160,7 @@ def _map_code_interpreter_tool_call(
             tool_call_id=item.id,
             args={
                 'container_id': item.container_id,
-                'code': item.code,
+                'code': item.code or '',
             },
             provider_name=provider_name,
         ),
@@ -2105,3 +2252,50 @@ def _map_image_generation_tool_call(
         ),
         file_part,
     )
+
+
+def _map_mcp_list_tools(
+    item: responses.response_output_item.McpListTools, provider_name: str
+) -> tuple[BuiltinToolCallPart, BuiltinToolReturnPart]:
+    tool_name = ':'.join([MCPServerTool.kind, item.server_label])
+    return (
+        BuiltinToolCallPart(
+            tool_name=tool_name,
+            tool_call_id=item.id,
+            provider_name=provider_name,
+            args={'action': 'list_tools'},
+        ),
+        BuiltinToolReturnPart(
+            tool_name=tool_name,
+            tool_call_id=item.id,
+            content=item.model_dump(mode='json', include={'tools', 'error'}),
+            provider_name=provider_name,
+        ),
+    )
+
+
+def _map_mcp_call(
+    item: responses.response_output_item.McpCall, provider_name: str
+) -> tuple[BuiltinToolCallPart, BuiltinToolReturnPart]:
+    tool_name = ':'.join([MCPServerTool.kind, item.server_label])
+    return (
+        BuiltinToolCallPart(
+            tool_name=tool_name,
+            tool_call_id=item.id,
+            args={
+                'action': 'call_tool',
+                'tool_name': item.name,
+                'tool_args': json.loads(item.arguments) if item.arguments else {},
+            },
+            provider_name=provider_name,
+        ),
+        BuiltinToolReturnPart(
+            tool_name=tool_name,
+            tool_call_id=item.id,
+            content={
+                'output': item.output,
+                'error': item.error,
+            },
+            provider_name=provider_name,
+        ),
+    )
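Both helpers pack the server label into the part's `tool_name`, which the request-mapping code earlier in this diff splits back out. Assuming `MCPServerTool.kind` is the string `'mcp_server'` (its value is not shown in this diff), the round trip is:

    kind = 'mcp_server'  # assumed value of MCPServerTool.kind
    tool_name = ':'.join([kind, 'deepwiki'])  # 'mcp_server:deepwiki'
    assert tool_name.startswith(kind)
    assert tool_name.split(':', 1)[1] == 'deepwiki'  # server_label recovered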
pydantic_ai/providers/__init__.py
CHANGED
@@ -8,7 +8,7 @@ from __future__ import annotations as _annotations
 from abc import ABC, abstractmethod
 from typing import Any, Generic, TypeVar
 
-from
+from ..profiles import ModelProfile
 
 InterfaceClient = TypeVar('InterfaceClient')
 
@@ -53,7 +53,7 @@ class Provider(ABC, Generic[InterfaceClient]):
 
 def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
     """Infers the provider class from the provider name."""
-    if provider
+    if provider in ('openai', 'openai-chat', 'openai-responses'):
         from .openai import OpenAIProvider
 
         return OpenAIProvider
@@ -73,15 +73,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
         from .azure import AzureProvider
 
         return AzureProvider
-    elif provider
-        from .
+    elif provider in ('google-vertex', 'google-gla'):
+        from .google import GoogleProvider
 
-        return
-    elif provider == 'google-gla':
-        from .google_gla import GoogleGLAProvider  # type: ignore[reportDeprecated]
-
-        return GoogleGLAProvider  # type: ignore[reportDeprecated]
-    # NOTE: We don't test because there are many ways the `boto3.client` can retrieve the credentials.
+        return GoogleProvider
     elif provider == 'bedrock':
         from .bedrock import BedrockProvider
 
@@ -146,11 +141,25 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
         from .nebius import NebiusProvider
 
         return NebiusProvider
+    elif provider == 'ovhcloud':
+        from .ovhcloud import OVHcloudProvider
+
+        return OVHcloudProvider
     else:  # pragma: no cover
         raise ValueError(f'Unknown provider: {provider}')
 
 
 def infer_provider(provider: str) -> Provider[Any]:
     """Infer the provider from the provider name."""
-
-
+    if provider.startswith('gateway/'):
+        from .gateway import gateway_provider
+
+        provider = provider.removeprefix('gateway/')
+        return gateway_provider(provider)
+    elif provider in ('google-vertex', 'google-gla'):
+        from .google import GoogleProvider
+
+        return GoogleProvider(vertexai=provider == 'google-vertex')
+    else:
+        provider_class = infer_provider_class(provider)
+        return provider_class()
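A sketch of the new dispatch (each call still needs the corresponding API keys or credentials configured at runtime):

    from pydantic_ai.providers import infer_provider

    infer_provider('gateway/openai')  # -> gateway_provider('openai')
    infer_provider('google-vertex')   # -> GoogleProvider(vertexai=True)
    infer_provider('google-gla')      # -> GoogleProvider(vertexai=False)
    infer_provider('ovhcloud')        # -> OVHcloudProvider()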
pydantic_ai/providers/bedrock.py
CHANGED
@@ -4,7 +4,7 @@ import os
 import re
 from collections.abc import Callable
 from dataclasses import dataclass
-from typing import Literal, overload
+from typing import Any, Literal, overload
 
 from pydantic_ai import ModelProfile
 from pydantic_ai.exceptions import UserError
@@ -21,6 +21,8 @@ try:
     from botocore.client import BaseClient
     from botocore.config import Config
     from botocore.exceptions import NoRegionError
+    from botocore.session import Session
+    from botocore.tokens import FrozenAuthToken
 except ImportError as _import_error:
     raise ImportError(
         'Please install the `boto3` package to use the Bedrock provider, '
@@ -117,10 +119,23 @@ class BedrockProvider(Provider[BaseClient]):
     def __init__(
         self,
         *,
+        api_key: str,
+        base_url: str | None = None,
         region_name: str | None = None,
+        profile_name: str | None = None,
+        aws_read_timeout: float | None = None,
+        aws_connect_timeout: float | None = None,
+    ) -> None: ...
+
+    @overload
+    def __init__(
+        self,
+        *,
         aws_access_key_id: str | None = None,
         aws_secret_access_key: str | None = None,
         aws_session_token: str | None = None,
+        base_url: str | None = None,
+        region_name: str | None = None,
         profile_name: str | None = None,
         aws_read_timeout: float | None = None,
         aws_connect_timeout: float | None = None,
@@ -130,11 +145,13 @@ class BedrockProvider(Provider[BaseClient]):
         self,
         *,
         bedrock_client: BaseClient | None = None,
-        region_name: str | None = None,
         aws_access_key_id: str | None = None,
         aws_secret_access_key: str | None = None,
         aws_session_token: str | None = None,
+        base_url: str | None = None,
+        region_name: str | None = None,
         profile_name: str | None = None,
+        api_key: str | None = None,
         aws_read_timeout: float | None = None,
         aws_connect_timeout: float | None = None,
     ) -> None:
@@ -142,10 +159,12 @@ class BedrockProvider(Provider[BaseClient]):
 
         Args:
             bedrock_client: A boto3 client for Bedrock Runtime. If provided, other arguments are ignored.
-
-
-
-
+            aws_access_key_id: The AWS access key ID. If not set, the `AWS_ACCESS_KEY_ID` environment variable will be used if available.
+            aws_secret_access_key: The AWS secret access key. If not set, the `AWS_SECRET_ACCESS_KEY` environment variable will be used if available.
+            aws_session_token: The AWS session token. If not set, the `AWS_SESSION_TOKEN` environment variable will be used if available.
+            api_key: The API key for Bedrock client. Can be used instead of `aws_access_key_id`, `aws_secret_access_key`, and `aws_session_token`. If not set, the `AWS_BEARER_TOKEN_BEDROCK` environment variable will be used if available.
+            base_url: The base URL for the Bedrock client.
+            region_name: The AWS region name. If not set, the `AWS_DEFAULT_REGION` environment variable will be used if available.
             profile_name: The AWS profile name.
             aws_read_timeout: The read timeout for Bedrock client.
             aws_connect_timeout: The connect timeout for Bedrock client.
@@ -153,19 +172,44 @@ class BedrockProvider(Provider[BaseClient]):
         if bedrock_client is not None:
             self._client = bedrock_client
         else:
+            read_timeout = aws_read_timeout or float(os.getenv('AWS_READ_TIMEOUT', 300))
+            connect_timeout = aws_connect_timeout or float(os.getenv('AWS_CONNECT_TIMEOUT', 60))
+            config: dict[str, Any] = {
+                'read_timeout': read_timeout,
+                'connect_timeout': connect_timeout,
+            }
             try:
-
-
-
-
-
-
-
-
-
+                if api_key is not None:
+                    session = boto3.Session(
+                        botocore_session=_BearerTokenSession(api_key),
+                        region_name=region_name,
+                        profile_name=profile_name,
+                    )
+                    config['signature_version'] = 'bearer'
+                else:
+                    session = boto3.Session(
+                        aws_access_key_id=aws_access_key_id,
+                        aws_secret_access_key=aws_secret_access_key,
+                        aws_session_token=aws_session_token,
+                        region_name=region_name,
+                        profile_name=profile_name,
+                    )
                 self._client = session.client(  # type: ignore[reportUnknownMemberType]
                     'bedrock-runtime',
-                    config=Config(
+                    config=Config(**config),
+                    endpoint_url=base_url,
                 )
             except NoRegionError as exc:  # pragma: no cover
                 raise UserError('You must provide a `region_name` or a boto3 client for Bedrock Runtime.') from exc
+
+
+class _BearerTokenSession(Session):
+    def __init__(self, token: str):
+        super().__init__()
+        self.token = token
+
+    def get_auth_token(self, **_kwargs: Any) -> FrozenAuthToken:
+        return FrozenAuthToken(self.token)
+
+    def get_credentials(self) -> None:  # type: ignore[reportIncompatibleMethodOverride]
+        return None