pydantic-ai-slim 0.0.52__tar.gz → 0.0.54__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pydantic-ai-slim might be problematic. Click here for more details.
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/.gitignore +1 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/PKG-INFO +3 -3
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/_agent_graph.py +9 -6
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/_cli.py +3 -5
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/_utils.py +5 -1
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/agent.py +49 -9
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/models/__init__.py +9 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/models/anthropic.py +1 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/models/bedrock.py +16 -14
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/models/cohere.py +2 -1
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/models/gemini.py +13 -2
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/models/groq.py +1 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/models/mistral.py +2 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/models/openai.py +153 -7
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/settings.py +13 -5
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/tools.py +28 -5
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/README.md +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/__init__.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/__main__.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/_griffe.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/_parts_manager.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/_pydantic.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/_result.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/_system_prompt.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/common_tools/__init__.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/common_tools/duckduckgo.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/common_tools/tavily.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/exceptions.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/format_as_xml.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/mcp.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/messages.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/models/fallback.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/models/function.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/models/instrumented.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/models/test.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/models/wrapper.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/providers/__init__.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/providers/anthropic.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/providers/azure.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/providers/bedrock.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/providers/cohere.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/providers/deepseek.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/providers/google_gla.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/providers/google_vertex.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/providers/groq.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/providers/mistral.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/providers/openai.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/py.typed +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/result.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/usage.py +0 -0
- {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pyproject.toml +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pydantic-ai-slim
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.54
|
|
4
4
|
Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
|
|
5
5
|
Author-email: Samuel Colvin <samuel@pydantic.dev>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -29,7 +29,7 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
|
|
|
29
29
|
Requires-Dist: griffe>=1.3.2
|
|
30
30
|
Requires-Dist: httpx>=0.27
|
|
31
31
|
Requires-Dist: opentelemetry-api>=1.28.0
|
|
32
|
-
Requires-Dist: pydantic-graph==0.0.
|
|
32
|
+
Requires-Dist: pydantic-graph==0.0.54
|
|
33
33
|
Requires-Dist: pydantic>=2.10
|
|
34
34
|
Requires-Dist: typing-inspection>=0.4.0
|
|
35
35
|
Provides-Extra: anthropic
|
|
@@ -45,7 +45,7 @@ Requires-Dist: cohere>=5.13.11; (platform_system != 'Emscripten') and extra == '
|
|
|
45
45
|
Provides-Extra: duckduckgo
|
|
46
46
|
Requires-Dist: duckduckgo-search>=7.0.0; extra == 'duckduckgo'
|
|
47
47
|
Provides-Extra: evals
|
|
48
|
-
Requires-Dist: pydantic-evals==0.0.
|
|
48
|
+
Requires-Dist: pydantic-evals==0.0.54; extra == 'evals'
|
|
49
49
|
Provides-Extra: groq
|
|
50
50
|
Requires-Dist: groq>=0.15.0; extra == 'groq'
|
|
51
51
|
Provides-Extra: logfire
|
|
@@ -79,7 +79,7 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
|
|
|
79
79
|
|
|
80
80
|
user_deps: DepsT
|
|
81
81
|
|
|
82
|
-
prompt: str | Sequence[_messages.UserContent]
|
|
82
|
+
prompt: str | Sequence[_messages.UserContent] | None
|
|
83
83
|
new_message_index: int
|
|
84
84
|
|
|
85
85
|
model: models.Model
|
|
@@ -124,7 +124,7 @@ def is_agent_node(
|
|
|
124
124
|
|
|
125
125
|
@dataclasses.dataclass
|
|
126
126
|
class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
|
|
127
|
-
user_prompt: str | Sequence[_messages.UserContent]
|
|
127
|
+
user_prompt: str | Sequence[_messages.UserContent] | None
|
|
128
128
|
|
|
129
129
|
system_prompts: tuple[str, ...]
|
|
130
130
|
system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
|
|
@@ -151,7 +151,7 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
|
|
|
151
151
|
|
|
152
152
|
async def _prepare_messages(
|
|
153
153
|
self,
|
|
154
|
-
user_prompt: str | Sequence[_messages.UserContent],
|
|
154
|
+
user_prompt: str | Sequence[_messages.UserContent] | None,
|
|
155
155
|
message_history: list[_messages.ModelMessage] | None,
|
|
156
156
|
run_context: RunContext[DepsT],
|
|
157
157
|
) -> tuple[list[_messages.ModelMessage], _messages.ModelRequest]:
|
|
@@ -166,16 +166,18 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
|
|
|
166
166
|
messages = ctx_messages.messages
|
|
167
167
|
ctx_messages.used = True
|
|
168
168
|
|
|
169
|
+
parts: list[_messages.ModelRequestPart] = []
|
|
169
170
|
if message_history:
|
|
170
171
|
# Shallow copy messages
|
|
171
172
|
messages.extend(message_history)
|
|
172
173
|
# Reevaluate any dynamic system prompt parts
|
|
173
174
|
await self._reevaluate_dynamic_prompts(messages, run_context)
|
|
174
|
-
return messages, _messages.ModelRequest([_messages.UserPromptPart(user_prompt)])
|
|
175
175
|
else:
|
|
176
|
-
parts
|
|
176
|
+
parts.extend(await self._sys_parts(run_context))
|
|
177
|
+
|
|
178
|
+
if user_prompt is not None:
|
|
177
179
|
parts.append(_messages.UserPromptPart(user_prompt))
|
|
178
|
-
|
|
180
|
+
return messages, _messages.ModelRequest(parts)
|
|
179
181
|
|
|
180
182
|
async def _reevaluate_dynamic_prompts(
|
|
181
183
|
self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT]
|
|
@@ -311,6 +313,7 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
|
|
|
311
313
|
return self._result
|
|
312
314
|
|
|
313
315
|
model_settings, model_request_parameters = await self._prepare_request(ctx)
|
|
316
|
+
model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
|
|
314
317
|
model_response, request_usage = await ctx.deps.model.request(
|
|
315
318
|
ctx.state.message_history, model_settings, model_request_parameters
|
|
316
319
|
)
|
|
@@ -15,7 +15,7 @@ from typing_inspection.introspection import get_literal_values
|
|
|
15
15
|
|
|
16
16
|
from pydantic_ai.agent import Agent
|
|
17
17
|
from pydantic_ai.exceptions import UserError
|
|
18
|
-
from pydantic_ai.messages import ModelMessage
|
|
18
|
+
from pydantic_ai.messages import ModelMessage
|
|
19
19
|
from pydantic_ai.models import KnownModelName, infer_model
|
|
20
20
|
|
|
21
21
|
try:
|
|
@@ -222,10 +222,8 @@ async def ask_agent(
|
|
|
222
222
|
status.stop() # stopping multiple times is idempotent
|
|
223
223
|
stack.enter_context(live) # entering multiple times is idempotent
|
|
224
224
|
|
|
225
|
-
async for
|
|
226
|
-
|
|
227
|
-
content += event.delta.content_delta
|
|
228
|
-
live.update(Markdown(content, code_theme=code_theme))
|
|
225
|
+
async for content in handle_stream.stream_output():
|
|
226
|
+
live.update(Markdown(content, code_theme=code_theme))
|
|
229
227
|
|
|
230
228
|
assert agent_run.result is not None
|
|
231
229
|
return agent_run.result.all_messages()
|
|
@@ -50,7 +50,11 @@ def check_object_json_schema(schema: JsonSchemaValue) -> ObjectJsonSchema:
|
|
|
50
50
|
if schema.get('type') == 'object':
|
|
51
51
|
return schema
|
|
52
52
|
elif schema.get('$ref') is not None:
|
|
53
|
-
|
|
53
|
+
maybe_result = schema.get('$defs', {}).get(schema['$ref'][8:]) # This removes the initial "#/$defs/".
|
|
54
|
+
|
|
55
|
+
if "'$ref': '#/$defs/" in str(maybe_result):
|
|
56
|
+
return schema # We can't remove the $defs because the schema contains other references
|
|
57
|
+
return maybe_result
|
|
54
58
|
else:
|
|
55
59
|
raise UserError('Schema must be an object')
|
|
56
60
|
|
|
@@ -242,7 +242,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
242
242
|
@overload
|
|
243
243
|
async def run(
|
|
244
244
|
self,
|
|
245
|
-
user_prompt: str | Sequence[_messages.UserContent],
|
|
245
|
+
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
246
246
|
*,
|
|
247
247
|
result_type: None = None,
|
|
248
248
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
@@ -257,7 +257,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
257
257
|
@overload
|
|
258
258
|
async def run(
|
|
259
259
|
self,
|
|
260
|
-
user_prompt: str | Sequence[_messages.UserContent],
|
|
260
|
+
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
261
261
|
*,
|
|
262
262
|
result_type: type[RunResultDataT],
|
|
263
263
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
@@ -271,7 +271,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
271
271
|
|
|
272
272
|
async def run(
|
|
273
273
|
self,
|
|
274
|
-
user_prompt: str | Sequence[_messages.UserContent],
|
|
274
|
+
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
275
275
|
*,
|
|
276
276
|
result_type: type[RunResultDataT] | None = None,
|
|
277
277
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
@@ -335,7 +335,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
335
335
|
@asynccontextmanager
|
|
336
336
|
async def iter(
|
|
337
337
|
self,
|
|
338
|
-
user_prompt: str | Sequence[_messages.UserContent],
|
|
338
|
+
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
339
339
|
*,
|
|
340
340
|
result_type: type[RunResultDataT] | None = None,
|
|
341
341
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
@@ -372,6 +372,12 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
372
372
|
print(nodes)
|
|
373
373
|
'''
|
|
374
374
|
[
|
|
375
|
+
UserPromptNode(
|
|
376
|
+
user_prompt='What is the capital of France?',
|
|
377
|
+
system_prompts=(),
|
|
378
|
+
system_prompt_functions=[],
|
|
379
|
+
system_prompt_dynamic_functions={},
|
|
380
|
+
),
|
|
375
381
|
ModelRequestNode(
|
|
376
382
|
request=ModelRequest(
|
|
377
383
|
parts=[
|
|
@@ -497,7 +503,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
497
503
|
@overload
|
|
498
504
|
def run_sync(
|
|
499
505
|
self,
|
|
500
|
-
user_prompt: str | Sequence[_messages.UserContent],
|
|
506
|
+
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
501
507
|
*,
|
|
502
508
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
503
509
|
model: models.Model | models.KnownModelName | str | None = None,
|
|
@@ -511,7 +517,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
511
517
|
@overload
|
|
512
518
|
def run_sync(
|
|
513
519
|
self,
|
|
514
|
-
user_prompt: str | Sequence[_messages.UserContent],
|
|
520
|
+
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
515
521
|
*,
|
|
516
522
|
result_type: type[RunResultDataT] | None,
|
|
517
523
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
@@ -525,7 +531,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
525
531
|
|
|
526
532
|
def run_sync(
|
|
527
533
|
self,
|
|
528
|
-
user_prompt: str | Sequence[_messages.UserContent],
|
|
534
|
+
user_prompt: str | Sequence[_messages.UserContent] | None = None,
|
|
529
535
|
*,
|
|
530
536
|
result_type: type[RunResultDataT] | None = None,
|
|
531
537
|
message_history: list[_messages.ModelMessage] | None = None,
|
|
@@ -940,6 +946,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
940
946
|
docstring_format: DocstringFormat = 'auto',
|
|
941
947
|
require_parameter_descriptions: bool = False,
|
|
942
948
|
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
|
|
949
|
+
strict: bool | None = None,
|
|
943
950
|
) -> Callable[[ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams]]: ...
|
|
944
951
|
|
|
945
952
|
def tool(
|
|
@@ -953,6 +960,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
953
960
|
docstring_format: DocstringFormat = 'auto',
|
|
954
961
|
require_parameter_descriptions: bool = False,
|
|
955
962
|
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
|
|
963
|
+
strict: bool | None = None,
|
|
956
964
|
) -> Any:
|
|
957
965
|
"""Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
|
|
958
966
|
|
|
@@ -995,6 +1003,8 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
995
1003
|
Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
|
|
996
1004
|
require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
|
|
997
1005
|
schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
|
|
1006
|
+
strict: Whether to enforce JSON schema compliance (only affects OpenAI).
|
|
1007
|
+
See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
|
|
998
1008
|
"""
|
|
999
1009
|
if func is None:
|
|
1000
1010
|
|
|
@@ -1011,6 +1021,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1011
1021
|
docstring_format,
|
|
1012
1022
|
require_parameter_descriptions,
|
|
1013
1023
|
schema_generator,
|
|
1024
|
+
strict,
|
|
1014
1025
|
)
|
|
1015
1026
|
return func_
|
|
1016
1027
|
|
|
@@ -1018,7 +1029,15 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1018
1029
|
else:
|
|
1019
1030
|
# noinspection PyTypeChecker
|
|
1020
1031
|
self._register_function(
|
|
1021
|
-
func,
|
|
1032
|
+
func,
|
|
1033
|
+
True,
|
|
1034
|
+
name,
|
|
1035
|
+
retries,
|
|
1036
|
+
prepare,
|
|
1037
|
+
docstring_format,
|
|
1038
|
+
require_parameter_descriptions,
|
|
1039
|
+
schema_generator,
|
|
1040
|
+
strict,
|
|
1022
1041
|
)
|
|
1023
1042
|
return func
|
|
1024
1043
|
|
|
@@ -1036,6 +1055,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1036
1055
|
docstring_format: DocstringFormat = 'auto',
|
|
1037
1056
|
require_parameter_descriptions: bool = False,
|
|
1038
1057
|
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
|
|
1058
|
+
strict: bool | None = None,
|
|
1039
1059
|
) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...
|
|
1040
1060
|
|
|
1041
1061
|
def tool_plain(
|
|
@@ -1049,6 +1069,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1049
1069
|
docstring_format: DocstringFormat = 'auto',
|
|
1050
1070
|
require_parameter_descriptions: bool = False,
|
|
1051
1071
|
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
|
|
1072
|
+
strict: bool | None = None,
|
|
1052
1073
|
) -> Any:
|
|
1053
1074
|
"""Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
|
|
1054
1075
|
|
|
@@ -1091,6 +1112,8 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1091
1112
|
Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
|
|
1092
1113
|
require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
|
|
1093
1114
|
schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
|
|
1115
|
+
strict: Whether to enforce JSON schema compliance (only affects OpenAI).
|
|
1116
|
+
See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
|
|
1094
1117
|
"""
|
|
1095
1118
|
if func is None:
|
|
1096
1119
|
|
|
@@ -1105,13 +1128,22 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1105
1128
|
docstring_format,
|
|
1106
1129
|
require_parameter_descriptions,
|
|
1107
1130
|
schema_generator,
|
|
1131
|
+
strict,
|
|
1108
1132
|
)
|
|
1109
1133
|
return func_
|
|
1110
1134
|
|
|
1111
1135
|
return tool_decorator
|
|
1112
1136
|
else:
|
|
1113
1137
|
self._register_function(
|
|
1114
|
-
func,
|
|
1138
|
+
func,
|
|
1139
|
+
False,
|
|
1140
|
+
name,
|
|
1141
|
+
retries,
|
|
1142
|
+
prepare,
|
|
1143
|
+
docstring_format,
|
|
1144
|
+
require_parameter_descriptions,
|
|
1145
|
+
schema_generator,
|
|
1146
|
+
strict,
|
|
1115
1147
|
)
|
|
1116
1148
|
return func
|
|
1117
1149
|
|
|
@@ -1125,6 +1157,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1125
1157
|
docstring_format: DocstringFormat,
|
|
1126
1158
|
require_parameter_descriptions: bool,
|
|
1127
1159
|
schema_generator: type[GenerateJsonSchema],
|
|
1160
|
+
strict: bool | None,
|
|
1128
1161
|
) -> None:
|
|
1129
1162
|
"""Private utility to register a function as a tool."""
|
|
1130
1163
|
retries_ = retries if retries is not None else self._default_retries
|
|
@@ -1137,6 +1170,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
|
|
|
1137
1170
|
docstring_format=docstring_format,
|
|
1138
1171
|
require_parameter_descriptions=require_parameter_descriptions,
|
|
1139
1172
|
schema_generator=schema_generator,
|
|
1173
|
+
strict=strict,
|
|
1140
1174
|
)
|
|
1141
1175
|
self._register_tool(tool)
|
|
1142
1176
|
|
|
@@ -1327,6 +1361,12 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
|
|
|
1327
1361
|
print(nodes)
|
|
1328
1362
|
'''
|
|
1329
1363
|
[
|
|
1364
|
+
UserPromptNode(
|
|
1365
|
+
user_prompt='What is the capital of France?',
|
|
1366
|
+
system_prompts=(),
|
|
1367
|
+
system_prompt_functions=[],
|
|
1368
|
+
system_prompt_dynamic_functions={},
|
|
1369
|
+
),
|
|
1330
1370
|
ModelRequestNode(
|
|
1331
1371
|
request=ModelRequest(
|
|
1332
1372
|
parts=[
|
|
@@ -274,6 +274,15 @@ class Model(ABC):
|
|
|
274
274
|
# noinspection PyUnreachableCode
|
|
275
275
|
yield # pragma: no cover
|
|
276
276
|
|
|
277
|
+
def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
|
|
278
|
+
"""Customize the request parameters for the model.
|
|
279
|
+
|
|
280
|
+
This method can be overridden by subclasses to modify the request parameters before sending them to the model.
|
|
281
|
+
In particular, this method can be used to make modifications to the generated tool JSON schemas if necessary
|
|
282
|
+
for vendor/model-specific reasons.
|
|
283
|
+
"""
|
|
284
|
+
return model_request_parameters
|
|
285
|
+
|
|
277
286
|
@property
|
|
278
287
|
@abstractmethod
|
|
279
288
|
def model_name(self) -> str:
|
|
@@ -226,6 +226,7 @@ class AnthropicModel(Model):
|
|
|
226
226
|
tools=tools or NOT_GIVEN,
|
|
227
227
|
tool_choice=tool_choice or NOT_GIVEN,
|
|
228
228
|
stream=stream,
|
|
229
|
+
stop_sequences=model_settings.get('stop_sequences', NOT_GIVEN),
|
|
229
230
|
temperature=model_settings.get('temperature', NOT_GIVEN),
|
|
230
231
|
top_p=model_settings.get('top_p', NOT_GIVEN),
|
|
231
232
|
timeout=model_settings.get('timeout', NOT_GIVEN),
|
|
@@ -42,12 +42,14 @@ if TYPE_CHECKING:
|
|
|
42
42
|
from mypy_boto3_bedrock_runtime.type_defs import (
|
|
43
43
|
ContentBlockOutputTypeDef,
|
|
44
44
|
ContentBlockUnionTypeDef,
|
|
45
|
+
ConverseRequestTypeDef,
|
|
45
46
|
ConverseResponseTypeDef,
|
|
46
47
|
ConverseStreamMetadataEventTypeDef,
|
|
47
48
|
ConverseStreamOutputTypeDef,
|
|
48
49
|
ImageBlockTypeDef,
|
|
49
50
|
InferenceConfigurationTypeDef,
|
|
50
51
|
MessageUnionTypeDef,
|
|
52
|
+
SystemContentBlockTypeDef,
|
|
51
53
|
ToolChoiceTypeDef,
|
|
52
54
|
ToolTypeDef,
|
|
53
55
|
)
|
|
@@ -258,20 +260,19 @@ class BedrockConverseModel(Model):
|
|
|
258
260
|
else:
|
|
259
261
|
tool_choice = {'auto': {}}
|
|
260
262
|
|
|
261
|
-
system_prompt, bedrock_messages = await self.
|
|
263
|
+
system_prompt, bedrock_messages = await self._map_messages(messages)
|
|
262
264
|
inference_config = self._map_inference_config(model_settings)
|
|
263
265
|
|
|
264
|
-
params = {
|
|
266
|
+
params: ConverseRequestTypeDef = {
|
|
265
267
|
'modelId': self.model_name,
|
|
266
268
|
'messages': bedrock_messages,
|
|
267
|
-
'system':
|
|
269
|
+
'system': system_prompt,
|
|
268
270
|
'inferenceConfig': inference_config,
|
|
269
|
-
**(
|
|
270
|
-
{'toolConfig': {'tools': tools, **({'toolChoice': tool_choice} if tool_choice else {})}}
|
|
271
|
-
if tools
|
|
272
|
-
else {}
|
|
273
|
-
),
|
|
274
271
|
}
|
|
272
|
+
if tools:
|
|
273
|
+
params['toolConfig'] = {'tools': tools}
|
|
274
|
+
if tool_choice:
|
|
275
|
+
params['toolConfig']['toolChoice'] = tool_choice
|
|
275
276
|
|
|
276
277
|
if stream:
|
|
277
278
|
model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse_stream, **params))
|
|
@@ -293,21 +294,22 @@ class BedrockConverseModel(Model):
|
|
|
293
294
|
inference_config['temperature'] = temperature
|
|
294
295
|
if top_p := model_settings.get('top_p'):
|
|
295
296
|
inference_config['topP'] = top_p
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
# inference_config['stopSequences'] = stop_sequences
|
|
297
|
+
if stop_sequences := model_settings.get('stop_sequences'):
|
|
298
|
+
inference_config['stopSequences'] = stop_sequences
|
|
299
299
|
|
|
300
300
|
return inference_config
|
|
301
301
|
|
|
302
|
-
async def
|
|
302
|
+
async def _map_messages(
|
|
303
|
+
self, messages: list[ModelMessage]
|
|
304
|
+
) -> tuple[list[SystemContentBlockTypeDef], list[MessageUnionTypeDef]]:
|
|
303
305
|
"""Just maps a `pydantic_ai.Message` to the Bedrock `MessageUnionTypeDef`."""
|
|
304
|
-
system_prompt:
|
|
306
|
+
system_prompt: list[SystemContentBlockTypeDef] = []
|
|
305
307
|
bedrock_messages: list[MessageUnionTypeDef] = []
|
|
306
308
|
for m in messages:
|
|
307
309
|
if isinstance(m, ModelRequest):
|
|
308
310
|
for part in m.parts:
|
|
309
311
|
if isinstance(part, SystemPromptPart):
|
|
310
|
-
system_prompt
|
|
312
|
+
system_prompt.append({'text': part.content})
|
|
311
313
|
elif isinstance(part, UserPromptPart):
|
|
312
314
|
bedrock_messages.extend(await self._map_user_prompt(part))
|
|
313
315
|
elif isinstance(part, ToolReturnPart):
|
|
@@ -118,7 +118,7 @@ class CohereModel(Model):
|
|
|
118
118
|
'cohere' or an instance of `Provider[AsyncClientV2]`. If not provided, a new provider will be
|
|
119
119
|
created using the other parameters.
|
|
120
120
|
"""
|
|
121
|
-
self._model_name
|
|
121
|
+
self._model_name = model_name
|
|
122
122
|
|
|
123
123
|
if isinstance(provider, str):
|
|
124
124
|
provider = infer_provider(provider)
|
|
@@ -163,6 +163,7 @@ class CohereModel(Model):
|
|
|
163
163
|
messages=cohere_messages,
|
|
164
164
|
tools=tools or OMIT,
|
|
165
165
|
max_tokens=model_settings.get('max_tokens', OMIT),
|
|
166
|
+
stop_sequences=model_settings.get('stop_sequences', OMIT),
|
|
166
167
|
temperature=model_settings.get('temperature', OMIT),
|
|
167
168
|
p=model_settings.get('top_p', OMIT),
|
|
168
169
|
seed=model_settings.get('seed', OMIT),
|
|
@@ -5,7 +5,7 @@ import re
|
|
|
5
5
|
from collections.abc import AsyncIterator, Sequence
|
|
6
6
|
from contextlib import asynccontextmanager
|
|
7
7
|
from copy import deepcopy
|
|
8
|
-
from dataclasses import dataclass, field
|
|
8
|
+
from dataclasses import dataclass, field, replace
|
|
9
9
|
from datetime import datetime
|
|
10
10
|
from typing import Annotated, Any, Literal, Protocol, Union, cast
|
|
11
11
|
from uuid import uuid4
|
|
@@ -152,6 +152,16 @@ class GeminiModel(Model):
|
|
|
152
152
|
) as http_response:
|
|
153
153
|
yield await self._process_streamed_response(http_response)
|
|
154
154
|
|
|
155
|
+
def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
|
|
156
|
+
def _customize_tool_def(t: ToolDefinition):
|
|
157
|
+
return replace(t, parameters_json_schema=_GeminiJsonSchema(t.parameters_json_schema).simplify())
|
|
158
|
+
|
|
159
|
+
return ModelRequestParameters(
|
|
160
|
+
function_tools=[_customize_tool_def(tool) for tool in model_request_parameters.function_tools],
|
|
161
|
+
allow_text_result=model_request_parameters.allow_text_result,
|
|
162
|
+
result_tools=[_customize_tool_def(tool) for tool in model_request_parameters.result_tools],
|
|
163
|
+
)
|
|
164
|
+
|
|
155
165
|
@property
|
|
156
166
|
def model_name(self) -> GeminiModelName:
|
|
157
167
|
"""The model name."""
|
|
@@ -496,6 +506,7 @@ class _GeminiGenerationConfig(TypedDict, total=False):
|
|
|
496
506
|
top_p: float
|
|
497
507
|
presence_penalty: float
|
|
498
508
|
frequency_penalty: float
|
|
509
|
+
stop_sequences: list[str]
|
|
499
510
|
|
|
500
511
|
|
|
501
512
|
class _GeminiContent(TypedDict):
|
|
@@ -640,7 +651,7 @@ class _GeminiFunction(TypedDict):
|
|
|
640
651
|
|
|
641
652
|
|
|
642
653
|
def _function_from_abstract_tool(tool: ToolDefinition) -> _GeminiFunction:
|
|
643
|
-
json_schema =
|
|
654
|
+
json_schema = tool.parameters_json_schema
|
|
644
655
|
f = _GeminiFunction(name=tool.name, description=tool.description)
|
|
645
656
|
if json_schema.get('properties'):
|
|
646
657
|
f['parameters'] = json_schema
|
|
@@ -208,6 +208,7 @@ class GroqModel(Model):
|
|
|
208
208
|
parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
|
|
209
209
|
tools=tools or NOT_GIVEN,
|
|
210
210
|
tool_choice=tool_choice or NOT_GIVEN,
|
|
211
|
+
stop=model_settings.get('stop_sequences', NOT_GIVEN),
|
|
211
212
|
stream=stream,
|
|
212
213
|
max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
|
|
213
214
|
temperature=model_settings.get('temperature', NOT_GIVEN),
|
|
@@ -199,6 +199,7 @@ class MistralModel(Model):
|
|
|
199
199
|
top_p=model_settings.get('top_p', 1),
|
|
200
200
|
timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
|
|
201
201
|
random_seed=model_settings.get('seed', UNSET),
|
|
202
|
+
stop=model_settings.get('stop_sequences', None),
|
|
202
203
|
)
|
|
203
204
|
except SDKError as e:
|
|
204
205
|
if (status_code := e.status_code) >= 400:
|
|
@@ -236,6 +237,7 @@ class MistralModel(Model):
|
|
|
236
237
|
timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
|
|
237
238
|
presence_penalty=model_settings.get('presence_penalty'),
|
|
238
239
|
frequency_penalty=model_settings.get('frequency_penalty'),
|
|
240
|
+
stop=model_settings.get('stop_sequences', None),
|
|
239
241
|
)
|
|
240
242
|
|
|
241
243
|
elif model_request_parameters.result_tools:
|
|
@@ -4,9 +4,9 @@ import base64
|
|
|
4
4
|
import warnings
|
|
5
5
|
from collections.abc import AsyncIterable, AsyncIterator, Sequence
|
|
6
6
|
from contextlib import asynccontextmanager
|
|
7
|
-
from dataclasses import dataclass, field
|
|
7
|
+
from dataclasses import dataclass, field, replace
|
|
8
8
|
from datetime import datetime, timezone
|
|
9
|
-
from typing import Literal, Union, cast, overload
|
|
9
|
+
from typing import Any, Literal, Union, cast, overload
|
|
10
10
|
|
|
11
11
|
from typing_extensions import assert_never
|
|
12
12
|
|
|
@@ -150,7 +150,7 @@ class OpenAIModel(Model):
|
|
|
150
150
|
"""
|
|
151
151
|
|
|
152
152
|
client: AsyncOpenAI = field(repr=False)
|
|
153
|
-
system_prompt_role: OpenAISystemPromptRole | None = field(default=None)
|
|
153
|
+
system_prompt_role: OpenAISystemPromptRole | None = field(default=None, repr=False)
|
|
154
154
|
|
|
155
155
|
_model_name: OpenAIModelName = field(repr=False)
|
|
156
156
|
_system: str = field(default='openai', repr=False)
|
|
@@ -208,6 +208,9 @@ class OpenAIModel(Model):
|
|
|
208
208
|
async with response:
|
|
209
209
|
yield await self._process_streamed_response(response)
|
|
210
210
|
|
|
211
|
+
def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
|
|
212
|
+
return _customize_request_parameters(model_request_parameters)
|
|
213
|
+
|
|
211
214
|
@property
|
|
212
215
|
def model_name(self) -> OpenAIModelName:
|
|
213
216
|
"""The model name."""
|
|
@@ -268,6 +271,7 @@ class OpenAIModel(Model):
|
|
|
268
271
|
tool_choice=tool_choice or NOT_GIVEN,
|
|
269
272
|
stream=stream,
|
|
270
273
|
stream_options={'include_usage': True} if stream else NOT_GIVEN,
|
|
274
|
+
stop=model_settings.get('stop_sequences', NOT_GIVEN),
|
|
271
275
|
max_completion_tokens=model_settings.get('max_tokens', NOT_GIVEN),
|
|
272
276
|
temperature=model_settings.get('temperature', NOT_GIVEN),
|
|
273
277
|
top_p=model_settings.get('top_p', NOT_GIVEN),
|
|
@@ -351,7 +355,7 @@ class OpenAIModel(Model):
|
|
|
351
355
|
|
|
352
356
|
@staticmethod
|
|
353
357
|
def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
|
|
354
|
-
|
|
358
|
+
tool_param: chat.ChatCompletionToolParam = {
|
|
355
359
|
'type': 'function',
|
|
356
360
|
'function': {
|
|
357
361
|
'name': f.name,
|
|
@@ -359,6 +363,9 @@ class OpenAIModel(Model):
|
|
|
359
363
|
'parameters': f.parameters_json_schema,
|
|
360
364
|
},
|
|
361
365
|
}
|
|
366
|
+
if f.strict:
|
|
367
|
+
tool_param['function']['strict'] = f.strict
|
|
368
|
+
return tool_param
|
|
362
369
|
|
|
363
370
|
async def _map_user_message(self, message: ModelRequest) -> AsyncIterable[chat.ChatCompletionMessageParam]:
|
|
364
371
|
for part in message.parts:
|
|
@@ -522,6 +529,9 @@ class OpenAIResponsesModel(Model):
|
|
|
522
529
|
async with response:
|
|
523
530
|
yield await self._process_streamed_response(response)
|
|
524
531
|
|
|
532
|
+
def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
|
|
533
|
+
return _customize_request_parameters(model_request_parameters)
|
|
534
|
+
|
|
525
535
|
def _process_response(self, response: responses.Response) -> ModelResponse:
|
|
526
536
|
"""Process a non-streamed response, and prepare a message to return."""
|
|
527
537
|
timestamp = datetime.fromtimestamp(response.created_at, tz=timezone.utc)
|
|
@@ -602,7 +612,7 @@ class OpenAIResponsesModel(Model):
|
|
|
602
612
|
truncation=model_settings.get('openai_truncation', NOT_GIVEN),
|
|
603
613
|
timeout=model_settings.get('timeout', NOT_GIVEN),
|
|
604
614
|
reasoning=reasoning,
|
|
605
|
-
user=model_settings.get('
|
|
615
|
+
user=model_settings.get('openai_user', NOT_GIVEN),
|
|
606
616
|
)
|
|
607
617
|
except APIStatusError as e:
|
|
608
618
|
if (status_code := e.status_code) >= 400:
|
|
@@ -630,8 +640,8 @@ class OpenAIResponsesModel(Model):
|
|
|
630
640
|
'parameters': f.parameters_json_schema,
|
|
631
641
|
'type': 'function',
|
|
632
642
|
'description': f.description,
|
|
633
|
-
#
|
|
634
|
-
'strict': False,
|
|
643
|
+
# NOTE: f.strict should already be a boolean thanks to customize_request_parameters
|
|
644
|
+
'strict': f.strict or False,
|
|
635
645
|
}
|
|
636
646
|
|
|
637
647
|
async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[responses.ResponseInputItemParam]]:
|
|
@@ -907,3 +917,139 @@ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.R
|
|
|
907
917
|
total_tokens=response_usage.total_tokens,
|
|
908
918
|
details=details,
|
|
909
919
|
)
|
|
920
|
+
|
|
921
|
+
|
|
922
|
+
class _StrictSchemaHelper:
|
|
923
|
+
def make_schema_strict(self, schema: dict[str, Any]) -> dict[str, Any]:
|
|
924
|
+
"""Recursively handle the schema to make it compatible with OpenAI strict mode.
|
|
925
|
+
|
|
926
|
+
See https://platform.openai.com/docs/guides/function-calling?api-mode=responses#strict-mode for more details,
|
|
927
|
+
but this basically just requires:
|
|
928
|
+
* `additionalProperties` must be set to false for each object in the parameters
|
|
929
|
+
* all fields in properties must be marked as required
|
|
930
|
+
"""
|
|
931
|
+
assert isinstance(schema, dict), 'Schema must be a dictionary, this is probably a bug'
|
|
932
|
+
|
|
933
|
+
# Create a copy to avoid modifying the original schema
|
|
934
|
+
schema = schema.copy()
|
|
935
|
+
|
|
936
|
+
# Handle $defs
|
|
937
|
+
if defs := schema.get('$defs'):
|
|
938
|
+
schema['$defs'] = {k: self.make_schema_strict(v) for k, v in defs.items()}
|
|
939
|
+
|
|
940
|
+
# Process schema based on its type
|
|
941
|
+
schema_type = schema.get('type')
|
|
942
|
+
if schema_type == 'object':
|
|
943
|
+
# Handle object type by setting additionalProperties to false
|
|
944
|
+
# and adding all properties to required list
|
|
945
|
+
self._make_object_schema_strict(schema)
|
|
946
|
+
elif schema_type == 'array':
|
|
947
|
+
# Handle array types by processing their items
|
|
948
|
+
if 'items' in schema:
|
|
949
|
+
items: Any = schema['items']
|
|
950
|
+
schema['items'] = self.make_schema_strict(items)
|
|
951
|
+
if 'prefixItems' in schema:
|
|
952
|
+
prefix_items: list[Any] = schema['prefixItems']
|
|
953
|
+
schema['prefixItems'] = [self.make_schema_strict(item) for item in prefix_items]
|
|
954
|
+
|
|
955
|
+
elif schema_type in {'string', 'number', 'integer', 'boolean', 'null'}:
|
|
956
|
+
pass # Primitive types need no special handling
|
|
957
|
+
elif 'oneOf' in schema:
|
|
958
|
+
schema['oneOf'] = [self.make_schema_strict(item) for item in schema['oneOf']]
|
|
959
|
+
elif 'anyOf' in schema:
|
|
960
|
+
schema['anyOf'] = [self.make_schema_strict(item) for item in schema['anyOf']]
|
|
961
|
+
|
|
962
|
+
return schema
|
|
963
|
+
|
|
964
|
+
def _make_object_schema_strict(self, schema: dict[str, Any]) -> None:
|
|
965
|
+
schema['additionalProperties'] = False
|
|
966
|
+
|
|
967
|
+
# Handle patternProperties; note this may not be compatible with strict mode but is included for completeness
|
|
968
|
+
if 'patternProperties' in schema and isinstance(schema['patternProperties'], dict):
|
|
969
|
+
pattern_props: dict[str, Any] = schema['patternProperties']
|
|
970
|
+
schema['patternProperties'] = {str(k): self.make_schema_strict(v) for k, v in pattern_props.items()}
|
|
971
|
+
|
|
972
|
+
# Handle properties — update their schemas recursively, and make all properties required
|
|
973
|
+
if 'properties' in schema and isinstance(schema['properties'], dict):
|
|
974
|
+
properties: dict[str, Any] = schema['properties']
|
|
975
|
+
schema['properties'] = {k: self.make_schema_strict(v) for k, v in properties.items()}
|
|
976
|
+
schema['required'] = list(properties.keys())
|
|
977
|
+
|
|
978
|
+
def is_schema_strict(self, schema: dict[str, Any]) -> bool:
|
|
979
|
+
"""Check if the schema is strict-mode-compatible.
|
|
980
|
+
|
|
981
|
+
A schema is compatible if:
|
|
982
|
+
* `additionalProperties` is set to false for each object in the parameters
|
|
983
|
+
* all fields in properties are marked as required
|
|
984
|
+
|
|
985
|
+
See https://platform.openai.com/docs/guides/function-calling?api-mode=responses#strict-mode for more details.
|
|
986
|
+
"""
|
|
987
|
+
assert isinstance(schema, dict), 'Schema must be a dictionary, this is probably a bug'
|
|
988
|
+
|
|
989
|
+
# Note that checking the defs first is usually the fastest way to proceed, but
|
|
990
|
+
# it makes it hard/impossible to hit coverage below, hence all the pragma no covers.
|
|
991
|
+
# I still included the handling below because I'm not _confident_ those code paths can't be hit.
|
|
992
|
+
if defs := schema.get('$defs'):
|
|
993
|
+
if not all(self.is_schema_strict(v) for v in defs.values()): # pragma: no branch
|
|
994
|
+
return False
|
|
995
|
+
|
|
996
|
+
schema_type = schema.get('type')
|
|
997
|
+
if schema_type == 'object':
|
|
998
|
+
if not self._is_object_schema_strict(schema):
|
|
999
|
+
return False
|
|
1000
|
+
elif schema_type == 'array':
|
|
1001
|
+
if 'items' in schema:
|
|
1002
|
+
items: Any = schema['items']
|
|
1003
|
+
if not self.is_schema_strict(items): # pragma: no cover
|
|
1004
|
+
return False
|
|
1005
|
+
if 'prefixItems' in schema:
|
|
1006
|
+
prefix_items: list[Any] = schema['prefixItems']
|
|
1007
|
+
if not all(self.is_schema_strict(item) for item in prefix_items): # pragma: no cover
|
|
1008
|
+
return False
|
|
1009
|
+
elif schema_type in {'string', 'number', 'integer', 'boolean', 'null'}:
|
|
1010
|
+
pass
|
|
1011
|
+
elif 'oneOf' in schema: # pragma: no cover
|
|
1012
|
+
if not all(self.is_schema_strict(item) for item in schema['oneOf']):
|
|
1013
|
+
return False
|
|
1014
|
+
|
|
1015
|
+
elif 'anyOf' in schema: # pragma: no cover
|
|
1016
|
+
if not all(self.is_schema_strict(item) for item in schema['anyOf']):
|
|
1017
|
+
return False
|
|
1018
|
+
|
|
1019
|
+
return True
|
|
1020
|
+
|
|
1021
|
+
def _is_object_schema_strict(self, schema: dict[str, Any]) -> bool:
|
|
1022
|
+
"""Check if the schema is an object and has additionalProperties set to false."""
|
|
1023
|
+
if schema.get('additionalProperties') is not False:
|
|
1024
|
+
return False
|
|
1025
|
+
if 'properties' not in schema: # pragma: no cover
|
|
1026
|
+
return False
|
|
1027
|
+
if 'required' not in schema: # pragma: no cover
|
|
1028
|
+
return False
|
|
1029
|
+
|
|
1030
|
+
for k, v in schema['properties'].items():
|
|
1031
|
+
if k not in schema['required']:
|
|
1032
|
+
return False
|
|
1033
|
+
if not self.is_schema_strict(v): # pragma: no cover
|
|
1034
|
+
return False
|
|
1035
|
+
|
|
1036
|
+
return True
|
|
1037
|
+
|
|
1038
|
+
|
|
1039
|
+
def _customize_request_parameters(model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
    """Customize the request parameters for OpenAI models.

    Every function tool and result tool definition is adjusted for OpenAI's strict mode:

    * `strict is True`: the parameters JSON schema is rewritten to be strict-mode compatible.
    * `strict is None`: `strict` is resolved to whether the existing schema is already strict-mode compatible.
    * otherwise (`False`): the definition is passed through unchanged.

    Args:
        model_request_parameters: The request parameters to customize.

    Returns:
        A new `ModelRequestParameters` holding the customized tool definitions; the input object is not modified.
    """
    # A single helper suffices: `_StrictSchemaHelper` holds no per-call state.
    helper = _StrictSchemaHelper()

    def _customize_tool_def(t: ToolDefinition) -> ToolDefinition:
        if t.strict is True:
            return replace(t, parameters_json_schema=helper.make_schema_strict(t.parameters_json_schema))
        elif t.strict is None:
            return replace(t, strict=helper.is_schema_strict(t.parameters_json_schema))
        return t

    # Use `replace` rather than rebuilding the dataclass field by field, so that every other field
    # (e.g. `allow_text_result`) is carried over automatically instead of being silently dropped.
    return replace(
        model_request_parameters,
        function_tools=[_customize_tool_def(tool) for tool in model_request_parameters.function_tools],
        result_tools=[_customize_tool_def(tool) for tool in model_request_parameters.result_tools],
    )
|
|
@@ -1,13 +1,8 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from typing import TYPE_CHECKING
|
|
4
|
-
|
|
5
3
|
from httpx import Timeout
|
|
6
4
|
from typing_extensions import TypedDict
|
|
7
5
|
|
|
8
|
-
if TYPE_CHECKING:
|
|
9
|
-
pass
|
|
10
|
-
|
|
11
6
|
|
|
12
7
|
class ModelSettings(TypedDict, total=False):
|
|
13
8
|
"""Settings to configure an LLM.
|
|
@@ -133,6 +128,19 @@ class ModelSettings(TypedDict, total=False):
|
|
|
133
128
|
* Groq
|
|
134
129
|
"""
|
|
135
130
|
|
|
131
|
+
stop_sequences: list[str]
|
|
132
|
+
"""Sequences that will cause the model to stop generating.
|
|
133
|
+
|
|
134
|
+
Supported by:
|
|
135
|
+
|
|
136
|
+
* OpenAI
|
|
137
|
+
* Anthropic
|
|
138
|
+
* Bedrock
|
|
139
|
+
* Mistral
|
|
140
|
+
* Groq
|
|
141
|
+
* Cohere
|
|
142
|
+
"""
|
|
143
|
+
|
|
136
144
|
|
|
137
145
|
def merge_model_settings(base: ModelSettings | None, overrides: ModelSettings | None) -> ModelSettings | None:
|
|
138
146
|
"""Merge two sets of model settings, preferring the overrides.
|
|
@@ -48,7 +48,7 @@ class RunContext(Generic[AgentDepsT]):
|
|
|
48
48
|
"""The model used in this run."""
|
|
49
49
|
usage: Usage
|
|
50
50
|
"""LLM usage associated with the run."""
|
|
51
|
-
prompt: str | Sequence[_messages.UserContent]
|
|
51
|
+
prompt: str | Sequence[_messages.UserContent] | None
|
|
52
52
|
"""The original user prompt passed to the run."""
|
|
53
53
|
messages: list[_messages.ModelMessage] = field(default_factory=list)
|
|
54
54
|
"""Messages exchanged in the conversation so far."""
|
|
@@ -173,12 +173,18 @@ class Tool(Generic[AgentDepsT]):
|
|
|
173
173
|
prepare: ToolPrepareFunc[AgentDepsT] | None
|
|
174
174
|
docstring_format: DocstringFormat
|
|
175
175
|
require_parameter_descriptions: bool
|
|
176
|
+
strict: bool | None
|
|
176
177
|
_is_async: bool = field(init=False)
|
|
177
178
|
_single_arg_name: str | None = field(init=False)
|
|
178
179
|
_positional_fields: list[str] = field(init=False)
|
|
179
180
|
_var_positional_field: str | None = field(init=False)
|
|
180
181
|
_validator: SchemaValidator = field(init=False, repr=False)
|
|
181
|
-
|
|
182
|
+
_base_parameters_json_schema: ObjectJsonSchema = field(init=False)
|
|
183
|
+
"""
|
|
184
|
+
The base JSON schema for the tool's parameters.
|
|
185
|
+
|
|
186
|
+
This schema may be modified by the `prepare` function or by the Model class prior to including it in an API request.
|
|
187
|
+
"""
|
|
182
188
|
|
|
183
189
|
# TODO: Move this state off the Tool class, which is otherwise stateless.
|
|
184
190
|
# This should be tracked inside a specific agent run, not the tool.
|
|
@@ -196,6 +202,7 @@ class Tool(Generic[AgentDepsT]):
|
|
|
196
202
|
docstring_format: DocstringFormat = 'auto',
|
|
197
203
|
require_parameter_descriptions: bool = False,
|
|
198
204
|
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
|
|
205
|
+
strict: bool | None = None,
|
|
199
206
|
):
|
|
200
207
|
"""Create a new tool instance.
|
|
201
208
|
|
|
@@ -246,6 +253,8 @@ class Tool(Generic[AgentDepsT]):
|
|
|
246
253
|
Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
|
|
247
254
|
require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
|
|
248
255
|
schema_generator: The JSON schema generator class to use. Defaults to `GenerateToolJsonSchema`.
|
|
256
|
+
strict: Whether to enforce JSON schema compliance (only affects OpenAI).
|
|
257
|
+
See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
|
|
249
258
|
"""
|
|
250
259
|
if takes_ctx is None:
|
|
251
260
|
takes_ctx = _pydantic.takes_ctx(function)
|
|
@@ -261,12 +270,13 @@ class Tool(Generic[AgentDepsT]):
|
|
|
261
270
|
self.prepare = prepare
|
|
262
271
|
self.docstring_format = docstring_format
|
|
263
272
|
self.require_parameter_descriptions = require_parameter_descriptions
|
|
273
|
+
self.strict = strict
|
|
264
274
|
self._is_async = inspect.iscoroutinefunction(self.function)
|
|
265
275
|
self._single_arg_name = f['single_arg_name']
|
|
266
276
|
self._positional_fields = f['positional_fields']
|
|
267
277
|
self._var_positional_field = f['var_positional_field']
|
|
268
278
|
self._validator = f['validator']
|
|
269
|
-
self.
|
|
279
|
+
self._base_parameters_json_schema = f['json_schema']
|
|
270
280
|
|
|
271
281
|
async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None:
|
|
272
282
|
"""Get the tool definition.
|
|
@@ -280,7 +290,8 @@ class Tool(Generic[AgentDepsT]):
|
|
|
280
290
|
tool_def = ToolDefinition(
|
|
281
291
|
name=self.name,
|
|
282
292
|
description=self.description,
|
|
283
|
-
parameters_json_schema=self.
|
|
293
|
+
parameters_json_schema=self._base_parameters_json_schema,
|
|
294
|
+
strict=self.strict,
|
|
284
295
|
)
|
|
285
296
|
if self.prepare is not None:
|
|
286
297
|
return await self.prepare(ctx, tool_def)
|
|
@@ -400,7 +411,7 @@ With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `ext
|
|
|
400
411
|
class ToolDefinition:
|
|
401
412
|
"""Definition of a tool passed to a model.
|
|
402
413
|
|
|
403
|
-
This is used for both function tools result tools.
|
|
414
|
+
This is used for both function tools and result tools.
|
|
404
415
|
"""
|
|
405
416
|
|
|
406
417
|
name: str
|
|
@@ -417,3 +428,15 @@ class ToolDefinition:
|
|
|
417
428
|
|
|
418
429
|
This will only be set for result tools which don't have an `object` JSON schema.
|
|
419
430
|
"""
|
|
431
|
+
|
|
432
|
+
strict: bool | None = None
|
|
433
|
+
"""Whether to enforce (vendor-specific) strict JSON schema validation for tool calls.
|
|
434
|
+
|
|
435
|
+
Setting this to `True` while using a supported model generally imposes some restrictions on the tool's JSON schema
|
|
436
|
+
in exchange for guaranteeing the API responses strictly match that schema.
|
|
437
|
+
|
|
438
|
+
When `False`, the model may be free to generate other properties or types (depending on the vendor).
|
|
439
|
+
When `None` (the default), the value will be inferred based on the compatibility of the parameters_json_schema.
|
|
440
|
+
|
|
441
|
+
Note: this is currently only supported by OpenAI models.
|
|
442
|
+
"""
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|