pydantic-ai-slim 0.0.34__tar.gz → 0.0.36__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45):
  1. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/PKG-INFO +4 -2
  2. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/_pydantic.py +2 -2
  3. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/common_tools/duckduckgo.py +1 -1
  4. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/common_tools/tavily.py +2 -1
  5. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/models/__init__.py +48 -1
  6. pydantic_ai_slim-0.0.36/pydantic_ai/models/bedrock.py +451 -0
  7. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/models/vertexai.py +15 -24
  8. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/providers/__init__.py +5 -0
  9. pydantic_ai_slim-0.0.36/pydantic_ai/providers/bedrock.py +76 -0
  10. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/providers/google_vertex.py +15 -24
  11. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/settings.py +3 -0
  12. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pyproject.toml +4 -2
  13. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/.gitignore +0 -0
  14. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/README.md +0 -0
  15. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/__init__.py +0 -0
  16. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/_agent_graph.py +0 -0
  17. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/_cli.py +0 -0
  18. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/_griffe.py +0 -0
  19. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/_parts_manager.py +0 -0
  20. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/_result.py +0 -0
  21. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/_system_prompt.py +0 -0
  22. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/_utils.py +0 -0
  23. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/agent.py +0 -0
  24. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/common_tools/__init__.py +0 -0
  25. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/exceptions.py +0 -0
  26. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/format_as_xml.py +0 -0
  27. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/messages.py +0 -0
  28. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/models/anthropic.py +0 -0
  29. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/models/cohere.py +0 -0
  30. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/models/fallback.py +0 -0
  31. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/models/function.py +0 -0
  32. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/models/gemini.py +0 -0
  33. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/models/groq.py +0 -0
  34. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/models/instrumented.py +0 -0
  35. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/models/mistral.py +0 -0
  36. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/models/openai.py +0 -0
  37. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/models/test.py +0 -0
  38. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/models/wrapper.py +0 -0
  39. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/providers/deepseek.py +0 -0
  40. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/providers/google_gla.py +0 -0
  41. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/providers/openai.py +0 -0
  42. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/py.typed +0 -0
  43. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/result.py +0 -0
  44. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/tools.py +0 -0
  45. {pydantic_ai_slim-0.0.34 → pydantic_ai_slim-0.0.36}/pydantic_ai/usage.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pydantic-ai-slim
3
- Version: 0.0.34
3
+ Version: 0.0.36
4
4
  Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
5
5
  Author-email: Samuel Colvin <samuel@pydantic.dev>
6
6
  License-Expression: MIT
@@ -29,11 +29,13 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
29
29
  Requires-Dist: griffe>=1.3.2
30
30
  Requires-Dist: httpx>=0.27
31
31
  Requires-Dist: opentelemetry-api>=1.28.0
32
- Requires-Dist: pydantic-graph==0.0.34
32
+ Requires-Dist: pydantic-graph==0.0.36
33
33
  Requires-Dist: pydantic>=2.10
34
34
  Requires-Dist: typing-inspection>=0.4.0
35
35
  Provides-Extra: anthropic
36
36
  Requires-Dist: anthropic>=0.49.0; extra == 'anthropic'
37
+ Provides-Extra: bedrock
38
+ Requires-Dist: boto3>=1.34.116; extra == 'bedrock'
37
39
  Provides-Extra: cli
38
40
  Requires-Dist: argcomplete>=3.5.0; extra == 'cli'
39
41
  Requires-Dist: prompt-toolkit>=3; extra == 'cli'
@@ -6,7 +6,7 @@ This module has to use numerous internal Pydantic APIs and is therefore brittle
6
6
  from __future__ import annotations as _annotations
7
7
 
8
8
  from inspect import Parameter, signature
9
- from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast
9
+ from typing import TYPE_CHECKING, Any, Callable, cast
10
10
 
11
11
  from pydantic import ConfigDict
12
12
  from pydantic._internal import _decorators, _generate_schema, _typing_extra
@@ -15,7 +15,7 @@ from pydantic.fields import FieldInfo
15
15
  from pydantic.json_schema import GenerateJsonSchema
16
16
  from pydantic.plugin._schema_validator import create_schema_validator
17
17
  from pydantic_core import SchemaValidator, core_schema
18
- from typing_extensions import get_origin
18
+ from typing_extensions import TypedDict, get_origin
19
19
 
20
20
  from ._griffe import doc_descriptions
21
21
  from ._utils import check_object_json_schema, is_model_like
@@ -1,10 +1,10 @@
1
1
  import functools
2
2
  from dataclasses import dataclass
3
- from typing import TypedDict
4
3
 
5
4
  import anyio
6
5
  import anyio.to_thread
7
6
  from pydantic import TypeAdapter
7
+ from typing_extensions import TypedDict
8
8
 
9
9
  from pydantic_ai.tools import Tool
10
10
 
@@ -1,7 +1,8 @@
1
1
  from dataclasses import dataclass
2
- from typing import Literal, TypedDict
2
+ from typing import Literal
3
3
 
4
4
  from pydantic import TypeAdapter
5
+ from typing_extensions import TypedDict
5
6
 
6
7
  from pydantic_ai.tools import Tool
7
8
 
@@ -34,6 +34,49 @@ KnownModelName = Literal[
34
34
  'anthropic:claude-3-opus-latest',
35
35
  'claude-3-7-sonnet-latest',
36
36
  'claude-3-5-haiku-latest',
37
+ 'bedrock:amazon.titan-tg1-large',
38
+ 'bedrock:amazon.titan-text-lite-v1',
39
+ 'bedrock:amazon.titan-text-express-v1',
40
+ 'bedrock:us.amazon.nova-pro-v1:0',
41
+ 'bedrock:us.amazon.nova-lite-v1:0',
42
+ 'bedrock:us.amazon.nova-micro-v1:0',
43
+ 'bedrock:anthropic.claude-3-5-sonnet-20241022-v2:0',
44
+ 'bedrock:us.anthropic.claude-3-5-sonnet-20241022-v2:0',
45
+ 'bedrock:anthropic.claude-3-5-haiku-20241022-v1:0',
46
+ 'bedrock:us.anthropic.claude-3-5-haiku-20241022-v1:0',
47
+ 'bedrock:anthropic.claude-instant-v1',
48
+ 'bedrock:anthropic.claude-v2:1',
49
+ 'bedrock:anthropic.claude-v2',
50
+ 'bedrock:anthropic.claude-3-sonnet-20240229-v1:0',
51
+ 'bedrock:us.anthropic.claude-3-sonnet-20240229-v1:0',
52
+ 'bedrock:anthropic.claude-3-haiku-20240307-v1:0',
53
+ 'bedrock:us.anthropic.claude-3-haiku-20240307-v1:0',
54
+ 'bedrock:anthropic.claude-3-opus-20240229-v1:0',
55
+ 'bedrock:us.anthropic.claude-3-opus-20240229-v1:0',
56
+ 'bedrock:anthropic.claude-3-5-sonnet-20240620-v1:0',
57
+ 'bedrock:us.anthropic.claude-3-5-sonnet-20240620-v1:0',
58
+ 'bedrock:anthropic.claude-3-7-sonnet-20250219-v1:0',
59
+ 'bedrock:us.anthropic.claude-3-7-sonnet-20250219-v1:0',
60
+ 'bedrock:cohere.command-text-v14',
61
+ 'bedrock:cohere.command-r-v1:0',
62
+ 'bedrock:cohere.command-r-plus-v1:0',
63
+ 'bedrock:cohere.command-light-text-v14',
64
+ 'bedrock:meta.llama3-8b-instruct-v1:0',
65
+ 'bedrock:meta.llama3-70b-instruct-v1:0',
66
+ 'bedrock:meta.llama3-1-8b-instruct-v1:0',
67
+ 'bedrock:us.meta.llama3-1-8b-instruct-v1:0',
68
+ 'bedrock:meta.llama3-1-70b-instruct-v1:0',
69
+ 'bedrock:us.meta.llama3-1-70b-instruct-v1:0',
70
+ 'bedrock:meta.llama3-1-405b-instruct-v1:0',
71
+ 'bedrock:us.meta.llama3-2-11b-instruct-v1:0',
72
+ 'bedrock:us.meta.llama3-2-90b-instruct-v1:0',
73
+ 'bedrock:us.meta.llama3-2-1b-instruct-v1:0',
74
+ 'bedrock:us.meta.llama3-2-3b-instruct-v1:0',
75
+ 'bedrock:us.meta.llama3-3-70b-instruct-v1:0',
76
+ 'bedrock:mistral.mistral-7b-instruct-v0:2',
77
+ 'bedrock:mistral.mixtral-8x7b-instruct-v0:1',
78
+ 'bedrock:mistral.mistral-large-2402-v1:0',
79
+ 'bedrock:mistral.mistral-large-2407-v1:0',
37
80
  'claude-3-5-sonnet-latest',
38
81
  'claude-3-opus-latest',
39
82
  'cohere:c4ai-aya-expanse-32b',
@@ -324,7 +367,7 @@ def infer_model(model: Model | KnownModelName) -> Model:
324
367
  return TestModel()
325
368
 
326
369
  try:
327
- provider, model_name = model.split(':')
370
+ provider, model_name = model.split(':', maxsplit=1)
328
371
  except ValueError:
329
372
  model_name = model
330
373
  # TODO(Marcelo): We should deprecate this way.
@@ -368,6 +411,10 @@ def infer_model(model: Model | KnownModelName) -> Model:
368
411
 
369
412
  # TODO(Marcelo): Missing provider API.
370
413
  return AnthropicModel(model_name)
414
+ elif provider == 'bedrock':
415
+ from .bedrock import BedrockConverseModel
416
+
417
+ return BedrockConverseModel(model_name)
371
418
  else:
372
419
  raise UserError(f'Unknown model: {model}')
373
420
 
@@ -0,0 +1,451 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ import typing
5
+ from collections.abc import AsyncIterator, Iterable
6
+ from contextlib import asynccontextmanager
7
+ from dataclasses import dataclass, field
8
+ from datetime import datetime
9
+ from typing import TYPE_CHECKING, Generic, Literal, Union, cast, overload
10
+
11
+ import anyio
12
+ import anyio.to_thread
13
+ from typing_extensions import ParamSpec, assert_never
14
+
15
+ from pydantic_ai import _utils, result
16
+ from pydantic_ai.messages import (
17
+ ModelMessage,
18
+ ModelRequest,
19
+ ModelResponse,
20
+ ModelResponsePart,
21
+ ModelResponseStreamEvent,
22
+ RetryPromptPart,
23
+ SystemPromptPart,
24
+ TextPart,
25
+ ToolCallPart,
26
+ ToolReturnPart,
27
+ UserPromptPart,
28
+ )
29
+ from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
30
+ from pydantic_ai.providers import Provider, infer_provider
31
+ from pydantic_ai.settings import ModelSettings
32
+ from pydantic_ai.tools import ToolDefinition
33
+
34
+ if TYPE_CHECKING:
35
+ from botocore.client import BaseClient
36
+ from botocore.eventstream import EventStream
37
+ from mypy_boto3_bedrock_runtime import BedrockRuntimeClient
38
+ from mypy_boto3_bedrock_runtime.type_defs import (
39
+ ContentBlockOutputTypeDef,
40
+ ConverseResponseTypeDef,
41
+ ConverseStreamMetadataEventTypeDef,
42
+ ConverseStreamOutputTypeDef,
43
+ InferenceConfigurationTypeDef,
44
+ MessageUnionTypeDef,
45
+ ToolChoiceTypeDef,
46
+ ToolTypeDef,
47
+ )
48
+
49
+
50
+ LatestBedrockModelNames = Literal[
51
+ 'amazon.titan-tg1-large',
52
+ 'amazon.titan-text-lite-v1',
53
+ 'amazon.titan-text-express-v1',
54
+ 'us.amazon.nova-pro-v1:0',
55
+ 'us.amazon.nova-lite-v1:0',
56
+ 'us.amazon.nova-micro-v1:0',
57
+ 'anthropic.claude-3-5-sonnet-20241022-v2:0',
58
+ 'us.anthropic.claude-3-5-sonnet-20241022-v2:0',
59
+ 'anthropic.claude-3-5-haiku-20241022-v1:0',
60
+ 'us.anthropic.claude-3-5-haiku-20241022-v1:0',
61
+ 'anthropic.claude-instant-v1',
62
+ 'anthropic.claude-v2:1',
63
+ 'anthropic.claude-v2',
64
+ 'anthropic.claude-3-sonnet-20240229-v1:0',
65
+ 'us.anthropic.claude-3-sonnet-20240229-v1:0',
66
+ 'anthropic.claude-3-haiku-20240307-v1:0',
67
+ 'us.anthropic.claude-3-haiku-20240307-v1:0',
68
+ 'anthropic.claude-3-opus-20240229-v1:0',
69
+ 'us.anthropic.claude-3-opus-20240229-v1:0',
70
+ 'anthropic.claude-3-5-sonnet-20240620-v1:0',
71
+ 'us.anthropic.claude-3-5-sonnet-20240620-v1:0',
72
+ 'anthropic.claude-3-7-sonnet-20250219-v1:0',
73
+ 'us.anthropic.claude-3-7-sonnet-20250219-v1:0',
74
+ 'cohere.command-text-v14',
75
+ 'cohere.command-r-v1:0',
76
+ 'cohere.command-r-plus-v1:0',
77
+ 'cohere.command-light-text-v14',
78
+ 'meta.llama3-8b-instruct-v1:0',
79
+ 'meta.llama3-70b-instruct-v1:0',
80
+ 'meta.llama3-1-8b-instruct-v1:0',
81
+ 'us.meta.llama3-1-8b-instruct-v1:0',
82
+ 'meta.llama3-1-70b-instruct-v1:0',
83
+ 'us.meta.llama3-1-70b-instruct-v1:0',
84
+ 'meta.llama3-1-405b-instruct-v1:0',
85
+ 'us.meta.llama3-2-11b-instruct-v1:0',
86
+ 'us.meta.llama3-2-90b-instruct-v1:0',
87
+ 'us.meta.llama3-2-1b-instruct-v1:0',
88
+ 'us.meta.llama3-2-3b-instruct-v1:0',
89
+ 'us.meta.llama3-3-70b-instruct-v1:0',
90
+ 'mistral.mistral-7b-instruct-v0:2',
91
+ 'mistral.mixtral-8x7b-instruct-v0:1',
92
+ 'mistral.mistral-large-2402-v1:0',
93
+ 'mistral.mistral-large-2407-v1:0',
94
+ ]
95
+ """Latest Bedrock models."""
96
+
97
+ BedrockModelName = Union[str, LatestBedrockModelNames]
98
+ """Possible Bedrock model names.
99
+
100
+ Since Bedrock supports a variety of date-stamped models, we explicitly list the latest models but allow any name in the type hints.
101
+ See [the Bedrock docs](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html) for a full list.
102
+ """
103
+
104
+
105
+ P = ParamSpec('P')
106
+ T = typing.TypeVar('T')
107
+
108
+
@dataclass(init=False)
class BedrockConverseModel(Model):
    """A model that uses the Bedrock Converse API.

    The boto3 client is synchronous, so each `converse` / `converse_stream` call is dispatched to a
    worker thread with `anyio.to_thread.run_sync`.
    """

    client: BedrockRuntimeClient

    _model_name: BedrockModelName = field(repr=False)
    _system: str | None = field(default='bedrock', repr=False)

    @property
    def model_name(self) -> str:
        """The model name."""
        return self._model_name

    @property
    def system(self) -> str | None:
        """The system / model provider, ex: openai."""
        return self._system

    def __init__(
        self,
        model_name: BedrockModelName,
        *,
        provider: Literal['bedrock'] | Provider[BaseClient] = 'bedrock',
    ):
        """Initialize a Bedrock model.

        Args:
            model_name: The name of the Bedrock model to use. List of model names available
                [here](https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html).
            provider: The provider to use. Either the string `'bedrock'` (the provider is then built by
                `infer_provider`) or a `Provider` instance whose `client` is used as-is.
                Defaults to `'bedrock'`.
        """
        self._model_name = model_name

        if isinstance(provider, str):
            self.client = infer_provider(provider).client
        else:
            self.client = cast('BedrockRuntimeClient', provider.client)

    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ToolTypeDef]:
        """Return the Bedrock tool specs for the function tools followed by the result tools."""
        tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools]
        if model_request_parameters.result_tools:
            tools += [self._map_tool_definition(r) for r in model_request_parameters.result_tools]
        return tools

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> ToolTypeDef:
        """Map a `ToolDefinition` to the Bedrock `toolSpec` structure."""
        return {
            'toolSpec': {
                'name': f.name,
                'description': f.description,
                'inputSchema': {'json': f.parameters_json_schema},
            }
        }

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, result.Usage]:
        """Make a non-streaming Converse request and return the parsed response with its usage."""
        response = await self._messages_create(messages, False, model_settings, model_request_parameters)
        return await self._process_response(response)

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
        """Make a streaming Converse request and yield a `BedrockStreamedResponse` over its event stream."""
        response = await self._messages_create(messages, True, model_settings, model_request_parameters)
        yield BedrockStreamedResponse(_model_name=self.model_name, _event_stream=response)

    async def _process_response(self, response: ConverseResponseTypeDef) -> tuple[ModelResponse, result.Usage]:
        """Convert a Converse response into a `ModelResponse` and a `Usage`.

        Every content item of the output message becomes either a `TextPart` or a `ToolCallPart`.
        An item without a non-empty `text` must carry a `toolUse`, otherwise an `AssertionError` is raised.
        """
        items: list[ModelResponsePart] = []
        if message := response['output'].get('message'):
            for item in message['content']:
                if text := item.get('text'):
                    items.append(TextPart(content=text))
                else:
                    tool_use = item.get('toolUse')
                    assert tool_use is not None, f'Found a content that is not a text or tool use: {item}'
                    items.append(
                        ToolCallPart(
                            tool_name=tool_use['name'],
                            args=tool_use['input'],
                            tool_call_id=tool_use['toolUseId'],
                        ),
                    )
        usage = result.Usage(
            request_tokens=response['usage']['inputTokens'],
            response_tokens=response['usage']['outputTokens'],
            total_tokens=response['usage']['totalTokens'],
        )
        return ModelResponse(items, model_name=self.model_name), usage

    @overload
    async def _messages_create(
        self,
        messages: list[ModelMessage],
        stream: Literal[True],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> EventStream[ConverseStreamOutputTypeDef]:
        pass

    @overload
    async def _messages_create(
        self,
        messages: list[ModelMessage],
        stream: Literal[False],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ConverseResponseTypeDef:
        pass

    async def _messages_create(
        self,
        messages: list[ModelMessage],
        stream: bool,
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ConverseResponseTypeDef | EventStream[ConverseStreamOutputTypeDef]:
        """Build the Converse request parameters and call the Bedrock client in a worker thread.

        Args:
            messages: The conversation history to send.
            stream: When true, `converse_stream` is called and its `'stream'` event stream is returned;
                otherwise `converse` is called and the full response is returned.
            model_settings: Optional settings mapped to the Bedrock `inferenceConfig`.
            model_request_parameters: Supplies the tool definitions and whether a plain text result is allowed.

        Returns:
            The Converse response, or the event stream when `stream` is true.
        """
        tools = self._get_tools(model_request_parameters)
        # A tool choice is only sent for model names starting with an Anthropic prefix.
        support_tools_choice = self.model_name.startswith(('anthropic', 'us.anthropic'))
        if not tools or not support_tools_choice:
            tool_choice: ToolChoiceTypeDef = {}
        elif not model_request_parameters.allow_text_result:
            tool_choice = {'any': {}}
        else:
            tool_choice = {'auto': {}}

        system_prompt, bedrock_messages = self._map_message(messages)
        inference_config = self._map_inference_config(model_settings)

        params: dict[str, typing.Any] = {
            'modelId': self.model_name,
            'messages': bedrock_messages,
            'inferenceConfig': inference_config,
        }
        # The Converse API requires system text blocks to be non-empty, so the `system` parameter is
        # left out entirely when the conversation carries no system prompt.
        if system_prompt:
            params['system'] = [{'text': system_prompt}]
        if tools:
            tool_config: dict[str, typing.Any] = {'tools': tools}
            if tool_choice:
                tool_config['toolChoice'] = tool_choice
            params['toolConfig'] = tool_config

        if stream:
            model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse_stream, **params))
            model_response = model_response['stream']
        else:
            model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse, **params))
        return model_response

    @staticmethod
    def _map_inference_config(
        model_settings: ModelSettings | None,
    ) -> InferenceConfigurationTypeDef:
        """Map `ModelSettings` to the Bedrock `inferenceConfig` structure.

        Only `max_tokens`, `temperature` and `top_p` are forwarded; settings that are absent are omitted.
        """
        model_settings = model_settings or {}
        inference_config: InferenceConfigurationTypeDef = {}

        if max_tokens := model_settings.get('max_tokens'):
            inference_config['maxTokens'] = max_tokens
        # `0.0` is a meaningful value for both sampling settings, so test for presence, not truthiness.
        if (temperature := model_settings.get('temperature')) is not None:
            inference_config['temperature'] = temperature
        if (top_p := model_settings.get('top_p')) is not None:
            inference_config['topP'] = top_p
        # TODO(Marcelo): This is not included in model_settings yet.
        # if stop_sequences := model_settings.get('stop_sequences'):
        #     inference_config['stopSequences'] = stop_sequences

        return inference_config

    def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageUnionTypeDef]]:
        """Map `pydantic_ai` messages to the Bedrock `MessageUnionTypeDef` structure.

        Returns:
            A tuple of the concatenated system prompt text (empty when there is none) and the list of
            Bedrock messages. Tool returns and tool retries are sent as `user` messages holding a
            `toolResult` with status `success` and `error` respectively.

        Raises:
            NotImplementedError: If a user prompt's content is not a string.
        """
        system_prompt: str = ''
        bedrock_messages: list[MessageUnionTypeDef] = []
        for m in messages:
            if isinstance(m, ModelRequest):
                for part in m.parts:
                    if isinstance(part, SystemPromptPart):
                        system_prompt += part.content
                    elif isinstance(part, UserPromptPart):
                        if isinstance(part.content, str):
                            bedrock_messages.append({'role': 'user', 'content': [{'text': part.content}]})
                        else:
                            raise NotImplementedError('User prompt can only be a string for now.')
                    elif isinstance(part, ToolReturnPart):
                        assert part.tool_call_id is not None
                        bedrock_messages.append(
                            {
                                'role': 'user',
                                'content': [
                                    {
                                        'toolResult': {
                                            'toolUseId': part.tool_call_id,
                                            'content': [{'text': part.model_response_str()}],
                                            'status': 'success',
                                        }
                                    }
                                ],
                            }
                        )
                    elif isinstance(part, RetryPromptPart):
                        # TODO(Marcelo): We need to add a test here.
                        if part.tool_name is None:  # pragma: no cover
                            bedrock_messages.append({'role': 'user', 'content': [{'text': part.model_response()}]})
                        else:
                            assert part.tool_call_id is not None
                            bedrock_messages.append(
                                {
                                    'role': 'user',
                                    'content': [
                                        {
                                            'toolResult': {
                                                'toolUseId': part.tool_call_id,
                                                'content': [{'text': part.model_response()}],
                                                'status': 'error',
                                            }
                                        }
                                    ],
                                }
                            )
            elif isinstance(m, ModelResponse):
                content: list[ContentBlockOutputTypeDef] = []
                for item in m.parts:
                    if isinstance(item, TextPart):
                        content.append({'text': item.content})
                    else:
                        assert isinstance(item, ToolCallPart)
                        content.append(self._map_tool_call(item))  # FIXME: MISSING key
                bedrock_messages.append({'role': 'assistant', 'content': content})
            else:
                assert_never(m)
        return system_prompt, bedrock_messages

    @staticmethod
    def _map_tool_call(t: ToolCallPart) -> ContentBlockOutputTypeDef:
        """Map a `ToolCallPart` to a Bedrock `toolUse` content block; the part must have a tool call id."""
        assert t.tool_call_id is not None
        return {
            'toolUse': {
                'toolUseId': t.tool_call_id,
                'name': t.tool_name,
                'input': t.args_as_dict(),
            }
        }
357
+
358
+
@dataclass
class BedrockStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for Bedrock models."""

    _model_name: BedrockModelName
    _event_stream: EventStream[ConverseStreamOutputTypeDef]
    _timestamp: datetime = field(default_factory=_utils.now_utc)

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        """Translate the Bedrock event stream into `ModelResponseStreamEvent`s.

        `messageStart` and `messageStop` events are skipped, `metadata` events only update the usage,
        and content block start / delta events are forwarded to the parts manager.
        """
        event: ConverseStreamOutputTypeDef
        # Id of the most recently started tool-use block; reused for the deltas that follow it.
        current_tool_id: str | None = None
        async for event in _AsyncIteratorWrapper(self._event_stream):
            # TODO(Marcelo): Switch this to `match` when we drop Python 3.9 support.
            if 'messageStart' in event or 'messageStop' in event:
                continue
            if 'metadata' in event:
                metadata = event['metadata']
                if 'usage' in metadata:
                    self._usage += self._map_usage(metadata)
                continue
            if 'contentBlockStart' in event:
                block_start = event['contentBlockStart']
                block_index = block_start['contentBlockIndex']
                start_payload = block_start['start']
                if 'toolUse' in start_payload:
                    started_tool = start_payload['toolUse']
                    current_tool_id = started_tool['toolUseId']
                    started_event = self._parts_manager.handle_tool_call_delta(
                        vendor_part_id=block_index,
                        tool_name=started_tool['name'],
                        args=None,
                        tool_call_id=current_tool_id,
                    )
                    if started_event:
                        yield started_event
            if 'contentBlockDelta' in event:
                block_delta = event['contentBlockDelta']
                block_index = block_delta['contentBlockIndex']
                delta_payload = block_delta['delta']
                if 'text' in delta_payload:
                    yield self._parts_manager.handle_text_delta(
                        vendor_part_id=block_index, content=delta_payload['text']
                    )
                if 'toolUse' in delta_payload:
                    tool_delta = delta_payload['toolUse']
                    delta_event = self._parts_manager.handle_tool_call_delta(
                        vendor_part_id=block_index,
                        tool_name=tool_delta.get('name'),
                        args=tool_delta.get('input'),
                        tool_call_id=current_tool_id,
                    )
                    if delta_event:
                        yield delta_event

    @property
    def timestamp(self) -> datetime:
        """Get the timestamp stored on this response (defaults to its creation time)."""
        return self._timestamp

    @property
    def model_name(self) -> str:
        """Get the model name of the response."""
        return self._model_name

    def _map_usage(self, metadata: ConverseStreamMetadataEventTypeDef) -> result.Usage:
        """Build a `Usage` from the token counts of a stream `metadata` event."""
        token_counts = metadata['usage']
        return result.Usage(
            request_tokens=token_counts['inputTokens'],
            response_tokens=token_counts['outputTokens'],
            total_tokens=token_counts['totalTokens'],
        )
431
+
432
+
class _AsyncIteratorWrapper(Generic[T]):
    """Wrap a synchronous iterator in an async iterator.

    Each item is fetched by calling `next()` in a worker thread, so a blocking iterator does not block
    the event loop.
    """

    # Unique marker returned by `next()` once the wrapped iterator is exhausted. Using a default value
    # keeps `StopIteration` from ever crossing the thread boundary, so end-of-stream detection does not
    # depend on how the async backend re-wraps that exception.
    _EXHAUSTED: typing.ClassVar[object] = object()

    def __init__(self, sync_iterator: Iterable[T]):
        self.sync_iterator = iter(sync_iterator)

    def __aiter__(self):
        return self

    async def __anext__(self) -> T:
        """Return the next item, raising `StopAsyncIteration` when the wrapped iterator is exhausted.

        Any exception raised by the wrapped iterator itself propagates unchanged.
        """
        # Run the synchronous next() call in a thread pool
        item = await anyio.to_thread.run_sync(next, self.sync_iterator, self._EXHAUSTED)
        if item is self._EXHAUSTED:
            raise StopAsyncIteration
        return cast('T', item)
@@ -224,15 +224,13 @@ class BearerTokenAuth:
224
224
 
225
225
 
226
226
  VertexAiRegion = Literal[
227
- 'us-central1',
228
- 'us-east1',
229
- 'us-east4',
230
- 'us-south1',
231
- 'us-west1',
232
- 'us-west2',
233
- 'us-west3',
234
- 'us-west4',
235
- 'us-east5',
227
+ 'asia-east1',
228
+ 'asia-east2',
229
+ 'asia-northeast1',
230
+ 'asia-northeast3',
231
+ 'asia-south1',
232
+ 'asia-southeast1',
233
+ 'australia-southeast1',
236
234
  'europe-central2',
237
235
  'europe-north1',
238
236
  'europe-southwest1',
@@ -243,27 +241,20 @@ VertexAiRegion = Literal[
243
241
  'europe-west6',
244
242
  'europe-west8',
245
243
  'europe-west9',
246
- 'europe-west12',
247
- 'africa-south1',
248
- 'asia-east1',
249
- 'asia-east2',
250
- 'asia-northeast1',
251
- 'asia-northeast2',
252
- 'asia-northeast3',
253
- 'asia-south1',
254
- 'asia-southeast1',
255
- 'asia-southeast2',
256
- 'australia-southeast1',
257
- 'australia-southeast2',
258
244
  'me-central1',
259
245
  'me-central2',
260
246
  'me-west1',
261
247
  'northamerica-northeast1',
262
- 'northamerica-northeast2',
263
248
  'southamerica-east1',
264
- 'southamerica-west1',
249
+ 'us-central1',
250
+ 'us-east1',
251
+ 'us-east4',
252
+ 'us-east5',
253
+ 'us-south1',
254
+ 'us-west1',
255
+ 'us-west4',
265
256
  ]
266
257
  """Regions available for Vertex AI.
267
258
 
268
- More details [here](https://cloud.google.com/vertex-ai/docs/reference/rest#rest_endpoints).
259
+ More details [here](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#genai-locations).
269
260
  """
@@ -60,5 +60,10 @@ def infer_provider(provider: str) -> Provider[Any]:
60
60
  from .google_gla import GoogleGLAProvider
61
61
 
62
62
  return GoogleGLAProvider()
63
+ # NOTE: We don't test because there are many ways the `boto3.client` can retrieve the credentials.
64
+ elif provider == 'bedrock': # pragma: no cover
65
+ from .bedrock import BedrockProvider
66
+
67
+ return BedrockProvider()
63
68
  else: # pragma: no cover
64
69
  raise ValueError(f'Unknown provider: {provider}')
@@ -0,0 +1,76 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ from typing import overload
4
+
5
+ from pydantic_ai.providers import Provider
6
+
7
+ try:
8
+ import boto3
9
+ from botocore.client import BaseClient
10
+ from botocore.exceptions import NoRegionError
11
+ except ImportError as _import_error:
12
+ raise ImportError(
13
+ 'Please install `boto3` to use the Bedrock provider, '
14
+ "you can use the `bedrock` optional group — `pip install 'pydantic-ai-slim[bedrock]'`"
15
+ ) from _import_error
16
+
17
+
class BedrockProvider(Provider[BaseClient]):
    """Provider for AWS Bedrock."""

    @property
    def name(self) -> str:
        """The provider name, always `'bedrock'`."""
        return 'bedrock'

    @property
    def base_url(self) -> str:
        """The endpoint URL reported by the underlying client's metadata."""
        return self._client.meta.endpoint_url

    @property
    def client(self) -> BaseClient:
        """The client held by this provider."""
        return self._client

    @overload
    def __init__(self, *, bedrock_client: BaseClient) -> None: ...

    @overload
    def __init__(
        self,
        *,
        region_name: str | None = None,
        aws_access_key_id: str | None = None,
        aws_secret_access_key: str | None = None,
        aws_session_token: str | None = None,
    ) -> None: ...

    def __init__(
        self,
        *,
        bedrock_client: BaseClient | None = None,
        region_name: str | None = None,
        aws_access_key_id: str | None = None,
        aws_secret_access_key: str | None = None,
        aws_session_token: str | None = None,
    ) -> None:
        """Initialize the Bedrock provider.

        Args:
            bedrock_client: A boto3 client for Bedrock Runtime. If provided, other arguments are ignored.
            region_name: The AWS region name.
            aws_access_key_id: The AWS access key ID.
            aws_secret_access_key: The AWS secret access key.
            aws_session_token: The AWS session token.

        Raises:
            ValueError: If no client is given and boto3 cannot determine a region.
        """
        # A ready-made client wins over every credential argument.
        if bedrock_client is not None:
            self._client = bedrock_client
            return

        try:
            self._client = boto3.client(  # type: ignore[reportUnknownMemberType]
                'bedrock-runtime',
                region_name=region_name,
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                aws_session_token=aws_session_token,
            )
        except NoRegionError as exc:  # pragma: no cover
            raise ValueError('You must provide a `region_name` or a boto3 client for Bedrock Runtime.') from exc
@@ -155,15 +155,13 @@ async def _creds_from_file(service_account_file: str | Path) -> ServiceAccountCr
155
155
 
156
156
 
157
157
  VertexAiRegion = Literal[
158
- 'us-central1',
159
- 'us-east1',
160
- 'us-east4',
161
- 'us-south1',
162
- 'us-west1',
163
- 'us-west2',
164
- 'us-west3',
165
- 'us-west4',
166
- 'us-east5',
158
+ 'asia-east1',
159
+ 'asia-east2',
160
+ 'asia-northeast1',
161
+ 'asia-northeast3',
162
+ 'asia-south1',
163
+ 'asia-southeast1',
164
+ 'australia-southeast1',
167
165
  'europe-central2',
168
166
  'europe-north1',
169
167
  'europe-southwest1',
@@ -174,27 +172,20 @@ VertexAiRegion = Literal[
174
172
  'europe-west6',
175
173
  'europe-west8',
176
174
  'europe-west9',
177
- 'europe-west12',
178
- 'africa-south1',
179
- 'asia-east1',
180
- 'asia-east2',
181
- 'asia-northeast1',
182
- 'asia-northeast2',
183
- 'asia-northeast3',
184
- 'asia-south1',
185
- 'asia-southeast1',
186
- 'asia-southeast2',
187
- 'australia-southeast1',
188
- 'australia-southeast2',
189
175
  'me-central1',
190
176
  'me-central2',
191
177
  'me-west1',
192
178
  'northamerica-northeast1',
193
- 'northamerica-northeast2',
194
179
  'southamerica-east1',
195
- 'southamerica-west1',
180
+ 'us-central1',
181
+ 'us-east1',
182
+ 'us-east4',
183
+ 'us-east5',
184
+ 'us-south1',
185
+ 'us-west1',
186
+ 'us-west4',
196
187
  ]
197
188
  """Regions available for Vertex AI.
198
189
 
199
- More details [here](https://cloud.google.com/vertex-ai/docs/reference/rest#rest_endpoints).
190
+ More details [here](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#genai-locations).
200
191
  """
@@ -27,6 +27,7 @@ class ModelSettings(TypedDict, total=False):
27
27
  * Groq
28
28
  * Cohere
29
29
  * Mistral
30
+ * Bedrock
30
31
  """
31
32
 
32
33
  temperature: float
@@ -45,6 +46,7 @@ class ModelSettings(TypedDict, total=False):
45
46
  * Groq
46
47
  * Cohere
47
48
  * Mistral
49
+ * Bedrock
48
50
  """
49
51
 
50
52
  top_p: float
@@ -62,6 +64,7 @@ class ModelSettings(TypedDict, total=False):
62
64
  * Groq
63
65
  * Cohere
64
66
  * Mistral
67
+ * Bedrock
65
68
  """
66
69
 
67
70
  timeout: float | Timeout
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "pydantic-ai-slim"
7
- version = "0.0.34"
7
+ version = "0.0.36"
8
8
  description = "Agent Framework / shim to use Pydantic with LLMs, slim package"
9
9
  authors = [{ name = "Samuel Colvin", email = "samuel@pydantic.dev" }]
10
10
  license = "MIT"
@@ -36,7 +36,7 @@ dependencies = [
36
36
  "griffe>=1.3.2",
37
37
  "httpx>=0.27",
38
38
  "pydantic>=2.10",
39
- "pydantic-graph==0.0.34",
39
+ "pydantic-graph==0.0.36",
40
40
  "exceptiongroup; python_version < '3.11'",
41
41
  "opentelemetry-api>=1.28.0",
42
42
  "typing-inspection>=0.4.0",
@@ -52,6 +52,7 @@ vertexai = ["google-auth>=2.36.0", "requests>=2.32.3"]
52
52
  anthropic = ["anthropic>=0.49.0"]
53
53
  groq = ["groq>=0.12.0"]
54
54
  mistral = ["mistralai>=1.2.5"]
55
+ bedrock = ["boto3>=1.34.116"]
55
56
  # Tools
56
57
  duckduckgo = ["duckduckgo-search>=7.0.0"]
57
58
  tavily = ["tavily-python>=0.5.0"]
@@ -71,6 +72,7 @@ dev = [
71
72
  "pytest-pretty>=1.2.0",
72
73
  "pytest-recording>=0.13.2",
73
74
  "diff-cover>=9.2.0",
75
+ "boto3-stubs[bedrock-runtime]",
74
76
  ]
75
77
 
76
78
  [project.scripts]