pydantic-ai-slim 0.0.52__tar.gz → 0.0.54__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pydantic-ai-slim might be problematic.

Files changed (51)
  1. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/.gitignore +1 -0
  2. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/PKG-INFO +3 -3
  3. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/_agent_graph.py +9 -6
  4. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/_cli.py +3 -5
  5. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/_utils.py +5 -1
  6. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/agent.py +49 -9
  7. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/models/__init__.py +9 -0
  8. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/models/anthropic.py +1 -0
  9. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/models/bedrock.py +16 -14
  10. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/models/cohere.py +2 -1
  11. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/models/gemini.py +13 -2
  12. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/models/groq.py +1 -0
  13. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/models/mistral.py +2 -0
  14. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/models/openai.py +153 -7
  15. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/settings.py +13 -5
  16. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/tools.py +28 -5
  17. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/README.md +0 -0
  18. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/__init__.py +0 -0
  19. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/__main__.py +0 -0
  20. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/_griffe.py +0 -0
  21. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/_parts_manager.py +0 -0
  22. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/_pydantic.py +0 -0
  23. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/_result.py +0 -0
  24. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/_system_prompt.py +0 -0
  25. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/common_tools/__init__.py +0 -0
  26. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/common_tools/duckduckgo.py +0 -0
  27. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/common_tools/tavily.py +0 -0
  28. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/exceptions.py +0 -0
  29. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/format_as_xml.py +0 -0
  30. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/mcp.py +0 -0
  31. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/messages.py +0 -0
  32. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/models/fallback.py +0 -0
  33. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/models/function.py +0 -0
  34. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/models/instrumented.py +0 -0
  35. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/models/test.py +0 -0
  36. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/models/wrapper.py +0 -0
  37. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/providers/__init__.py +0 -0
  38. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/providers/anthropic.py +0 -0
  39. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/providers/azure.py +0 -0
  40. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/providers/bedrock.py +0 -0
  41. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/providers/cohere.py +0 -0
  42. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/providers/deepseek.py +0 -0
  43. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/providers/google_gla.py +0 -0
  44. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/providers/google_vertex.py +0 -0
  45. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/providers/groq.py +0 -0
  46. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/providers/mistral.py +0 -0
  47. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/providers/openai.py +0 -0
  48. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/py.typed +0 -0
  49. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/result.py +0 -0
  50. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pydantic_ai/usage.py +0 -0
  51. {pydantic_ai_slim-0.0.52 → pydantic_ai_slim-0.0.54}/pyproject.toml +0 -0

.gitignore
@@ -16,3 +16,4 @@ examples/pydantic_ai_examples/.chat_app_messages.sqlite
 /question_graph_history.json
 /docs-site/.wrangler/
 /CLAUDE.md
+node_modules/

PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.0.52
+Version: 0.0.54
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>
 License-Expression: MIT
@@ -29,7 +29,7 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: opentelemetry-api>=1.28.0
-Requires-Dist: pydantic-graph==0.0.52
+Requires-Dist: pydantic-graph==0.0.54
 Requires-Dist: pydantic>=2.10
 Requires-Dist: typing-inspection>=0.4.0
 Provides-Extra: anthropic
@@ -45,7 +45,7 @@ Requires-Dist: cohere>=5.13.11; (platform_system != 'Emscripten') and extra == '
 Provides-Extra: duckduckgo
 Requires-Dist: duckduckgo-search>=7.0.0; extra == 'duckduckgo'
 Provides-Extra: evals
-Requires-Dist: pydantic-evals==0.0.52; extra == 'evals'
+Requires-Dist: pydantic-evals==0.0.54; extra == 'evals'
 Provides-Extra: groq
 Requires-Dist: groq>=0.15.0; extra == 'groq'
 Provides-Extra: logfire

pydantic_ai/_agent_graph.py
@@ -79,7 +79,7 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):

     user_deps: DepsT

-    prompt: str | Sequence[_messages.UserContent]
+    prompt: str | Sequence[_messages.UserContent] | None
     new_message_index: int

     model: models.Model
@@ -124,7 +124,7 @@ def is_agent_node(

 @dataclasses.dataclass
 class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
-    user_prompt: str | Sequence[_messages.UserContent]
+    user_prompt: str | Sequence[_messages.UserContent] | None

     system_prompts: tuple[str, ...]
     system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]]
@@ -151,7 +151,7 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):

     async def _prepare_messages(
         self,
-        user_prompt: str | Sequence[_messages.UserContent],
+        user_prompt: str | Sequence[_messages.UserContent] | None,
         message_history: list[_messages.ModelMessage] | None,
         run_context: RunContext[DepsT],
     ) -> tuple[list[_messages.ModelMessage], _messages.ModelRequest]:
@@ -166,16 +166,18 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
             messages = ctx_messages.messages
             ctx_messages.used = True

+        parts: list[_messages.ModelRequestPart] = []
         if message_history:
             # Shallow copy messages
             messages.extend(message_history)
             # Reevaluate any dynamic system prompt parts
             await self._reevaluate_dynamic_prompts(messages, run_context)
-            return messages, _messages.ModelRequest([_messages.UserPromptPart(user_prompt)])
         else:
-            parts = await self._sys_parts(run_context)
+            parts.extend(await self._sys_parts(run_context))
+
+        if user_prompt is not None:
             parts.append(_messages.UserPromptPart(user_prompt))
-            return messages, _messages.ModelRequest(parts)
+        return messages, _messages.ModelRequest(parts)

     async def _reevaluate_dynamic_prompts(
         self, messages: list[_messages.ModelMessage], run_context: RunContext[DepsT]
@@ -311,6 +313,7 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
             return self._result

         model_settings, model_request_parameters = await self._prepare_request(ctx)
+        model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
         model_response, request_usage = await ctx.deps.model.request(
             ctx.state.message_history, model_settings, model_request_parameters
         )

pydantic_ai/_cli.py
@@ -15,7 +15,7 @@ from typing_inspection.introspection import get_literal_values

 from pydantic_ai.agent import Agent
 from pydantic_ai.exceptions import UserError
-from pydantic_ai.messages import ModelMessage, PartDeltaEvent, TextPartDelta
+from pydantic_ai.messages import ModelMessage
 from pydantic_ai.models import KnownModelName, infer_model

 try:
@@ -222,10 +222,8 @@ async def ask_agent(
         status.stop()  # stopping multiple times is idempotent
         stack.enter_context(live)  # entering multiple times is idempotent

-        async for event in handle_stream:
-            if isinstance(event, PartDeltaEvent) and isinstance(event.delta, TextPartDelta):
-                content += event.delta.content_delta
-                live.update(Markdown(content, code_theme=code_theme))
+        async for content in handle_stream.stream_output():
+            live.update(Markdown(content, code_theme=code_theme))

     assert agent_run.result is not None
     return agent_run.result.all_messages()
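
The CLI now renders the stream's aggregated text output directly instead of accumulating `TextPartDelta`s by hand. A minimal sketch of the analogous public streaming pattern (model name illustrative; assumes the documented `run_stream`/`stream_text` API):

```python
import asyncio

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')


async def main() -> None:
    async with agent.run_stream('Tell me a joke.') as result:
        # stream_text() yields progressively longer aggregates of the text,
        # much like stream_output() in the CLI change above.
        async for text in result.stream_text():
            print(text)


asyncio.run(main())
```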

pydantic_ai/_utils.py
@@ -50,7 +50,11 @@ def check_object_json_schema(schema: JsonSchemaValue) -> ObjectJsonSchema:
     if schema.get('type') == 'object':
         return schema
     elif schema.get('$ref') is not None:
-        return schema.get('$defs', {}).get(schema['$ref'][8:])  # This removes the initial "#/$defs/".
+        maybe_result = schema.get('$defs', {}).get(schema['$ref'][8:])  # This removes the initial "#/$defs/".
+
+        if "'$ref': '#/$defs/" in str(maybe_result):
+            return schema  # We can't remove the $defs because the schema contains other references
+        return maybe_result
     else:
         raise UserError('Schema must be an object')
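
To illustrate the new guard (schemas below are hypothetical): a top-level `$ref` is still inlined when its target is self-contained, but when the resolved definition itself contains `$ref`s into `$defs`, the schema is now returned unchanged rather than half-resolved:

```python
# Inlined as before: the referenced definition contains no nested $ref.
simple = {
    '$ref': '#/$defs/Foo',
    '$defs': {'Foo': {'type': 'object', 'properties': {'x': {'type': 'integer'}}}},
}

# Returned as-is now: inlining Foo would orphan its reference to Bar.
nested = {
    '$ref': '#/$defs/Foo',
    '$defs': {
        'Foo': {'type': 'object', 'properties': {'bar': {'$ref': '#/$defs/Bar'}}},
        'Bar': {'type': 'object', 'properties': {'y': {'type': 'integer'}}},
    },
}
```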
 

pydantic_ai/agent.py
@@ -242,7 +242,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @overload
     async def run(
         self,
-        user_prompt: str | Sequence[_messages.UserContent],
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         result_type: None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -257,7 +257,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @overload
     async def run(
         self,
-        user_prompt: str | Sequence[_messages.UserContent],
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         result_type: type[RunResultDataT],
         message_history: list[_messages.ModelMessage] | None = None,
@@ -271,7 +271,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):

     async def run(
         self,
-        user_prompt: str | Sequence[_messages.UserContent],
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         result_type: type[RunResultDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
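
Because `user_prompt` now defaults to `None`, a run can continue directly from existing message history without appending a new user message. A minimal sketch (model name illustrative):

```python
import asyncio

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')


async def main() -> None:
    first = await agent.run('What is the capital of France?')
    # Continue from the history alone; no new UserPromptPart is added.
    followup = await agent.run(message_history=first.all_messages())
    print(followup.data)


asyncio.run(main())
```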
@@ -335,7 +335,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @asynccontextmanager
     async def iter(
         self,
-        user_prompt: str | Sequence[_messages.UserContent],
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         result_type: type[RunResultDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -372,6 +372,12 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
             print(nodes)
             '''
             [
+                UserPromptNode(
+                    user_prompt='What is the capital of France?',
+                    system_prompts=(),
+                    system_prompt_functions=[],
+                    system_prompt_dynamic_functions={},
+                ),
                 ModelRequestNode(
                     request=ModelRequest(
                         parts=[
@@ -497,7 +503,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @overload
     def run_sync(
         self,
-        user_prompt: str | Sequence[_messages.UserContent],
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | str | None = None,
@@ -511,7 +517,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
     @overload
     def run_sync(
         self,
-        user_prompt: str | Sequence[_messages.UserContent],
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         result_type: type[RunResultDataT] | None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -525,7 +531,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):

     def run_sync(
         self,
-        user_prompt: str | Sequence[_messages.UserContent],
+        user_prompt: str | Sequence[_messages.UserContent] | None = None,
         *,
         result_type: type[RunResultDataT] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
@@ -940,6 +946,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         docstring_format: DocstringFormat = 'auto',
         require_parameter_descriptions: bool = False,
         schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
+        strict: bool | None = None,
    ) -> Callable[[ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams]]: ...

    def tool(
@@ -953,6 +960,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         docstring_format: DocstringFormat = 'auto',
         require_parameter_descriptions: bool = False,
         schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
+        strict: bool | None = None,
     ) -> Any:
         """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.

@@ -995,6 +1003,8 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
                 Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
             require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
             schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
+            strict: Whether to enforce JSON schema compliance (only affects OpenAI).
+                See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
         """
         if func is None:
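
For example (model name and tool body illustrative), opting a tool into OpenAI strict mode so tool-call arguments are guaranteed to match the generated schema:

```python
from pydantic_ai import Agent, RunContext

agent = Agent('openai:gpt-4o')


@agent.tool(strict=True)
async def get_price(ctx: RunContext[None], product: str) -> float:
    """Look up the price of a product."""
    return 42.0  # stand-in for a real lookup
```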
 
@@ -1011,6 +1021,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
                     docstring_format,
                     require_parameter_descriptions,
                     schema_generator,
+                    strict,
                 )
                 return func_

@@ -1018,7 +1029,15 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         else:
             # noinspection PyTypeChecker
             self._register_function(
-                func, True, name, retries, prepare, docstring_format, require_parameter_descriptions, schema_generator
+                func,
+                True,
+                name,
+                retries,
+                prepare,
+                docstring_format,
+                require_parameter_descriptions,
+                schema_generator,
+                strict,
             )
             return func

@@ -1036,6 +1055,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         docstring_format: DocstringFormat = 'auto',
         require_parameter_descriptions: bool = False,
         schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
+        strict: bool | None = None,
     ) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...

     def tool_plain(
@@ -1049,6 +1069,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         docstring_format: DocstringFormat = 'auto',
         require_parameter_descriptions: bool = False,
         schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
+        strict: bool | None = None,
     ) -> Any:
         """Decorator to register a tool function which DOES NOT take `RunContext` as an argument.

@@ -1091,6 +1112,8 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
                 Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
             require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
             schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
+            strict: Whether to enforce JSON schema compliance (only affects OpenAI).
+                See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
         """
         if func is None:

@@ -1105,13 +1128,22 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
                     docstring_format,
                     require_parameter_descriptions,
                     schema_generator,
+                    strict,
                 )
                 return func_

             return tool_decorator
         else:
             self._register_function(
-                func, False, name, retries, prepare, docstring_format, require_parameter_descriptions, schema_generator
+                func,
+                False,
+                name,
+                retries,
+                prepare,
+                docstring_format,
+                require_parameter_descriptions,
+                schema_generator,
+                strict,
             )
             return func

@@ -1125,6 +1157,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
         docstring_format: DocstringFormat,
         require_parameter_descriptions: bool,
         schema_generator: type[GenerateJsonSchema],
+        strict: bool | None,
     ) -> None:
         """Private utility to register a function as a tool."""
         retries_ = retries if retries is not None else self._default_retries
@@ -1137,6 +1170,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
             docstring_format=docstring_format,
             require_parameter_descriptions=require_parameter_descriptions,
             schema_generator=schema_generator,
+            strict=strict,
         )
         self._register_tool(tool)

@@ -1327,6 +1361,12 @@ class AgentRun(Generic[AgentDepsT, ResultDataT]):
             print(nodes)
             '''
             [
+                UserPromptNode(
+                    user_prompt='What is the capital of France?',
+                    system_prompts=(),
+                    system_prompt_functions=[],
+                    system_prompt_dynamic_functions={},
+                ),
                 ModelRequestNode(
                     request=ModelRequest(
                         parts=[

pydantic_ai/models/__init__.py
@@ -274,6 +274,15 @@ class Model(ABC):
         # noinspection PyUnreachableCode
         yield  # pragma: no cover

+    def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
+        """Customize the request parameters for the model.
+
+        This method can be overridden by subclasses to modify the request parameters before sending them to the model.
+        In particular, this method can be used to make modifications to the generated tool JSON schemas if necessary
+        for vendor/model-specific reasons.
+        """
+        return model_request_parameters
+
     @property
     @abstractmethod
     def model_name(self) -> str:
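
A sketch of how a subclass might use this hook (the subclass and its tweak are illustrative; the construction mirrors the Gemini and OpenAI overrides shown below):

```python
from dataclasses import replace

from pydantic_ai.models import ModelRequestParameters
from pydantic_ai.models.openai import OpenAIModel


class TruncatingOpenAIModel(OpenAIModel):
    """Hypothetical model that trims overlong tool descriptions per request."""

    def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
        def _trim(t):
            return replace(t, description=t.description[:1024])

        return ModelRequestParameters(
            function_tools=[_trim(t) for t in model_request_parameters.function_tools],
            allow_text_result=model_request_parameters.allow_text_result,
            result_tools=[_trim(t) for t in model_request_parameters.result_tools],
        )
```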

pydantic_ai/models/anthropic.py
@@ -226,6 +226,7 @@ class AnthropicModel(Model):
             tools=tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
+            stop_sequences=model_settings.get('stop_sequences', NOT_GIVEN),
             temperature=model_settings.get('temperature', NOT_GIVEN),
             top_p=model_settings.get('top_p', NOT_GIVEN),
             timeout=model_settings.get('timeout', NOT_GIVEN),

pydantic_ai/models/bedrock.py
@@ -42,12 +42,14 @@ if TYPE_CHECKING:
     from mypy_boto3_bedrock_runtime.type_defs import (
         ContentBlockOutputTypeDef,
         ContentBlockUnionTypeDef,
+        ConverseRequestTypeDef,
         ConverseResponseTypeDef,
         ConverseStreamMetadataEventTypeDef,
         ConverseStreamOutputTypeDef,
         ImageBlockTypeDef,
         InferenceConfigurationTypeDef,
         MessageUnionTypeDef,
+        SystemContentBlockTypeDef,
         ToolChoiceTypeDef,
         ToolTypeDef,
     )
@@ -258,20 +260,19 @@ class BedrockConverseModel(Model):
         else:
             tool_choice = {'auto': {}}

-        system_prompt, bedrock_messages = await self._map_message(messages)
+        system_prompt, bedrock_messages = await self._map_messages(messages)
         inference_config = self._map_inference_config(model_settings)

-        params = {
+        params: ConverseRequestTypeDef = {
             'modelId': self.model_name,
             'messages': bedrock_messages,
-            'system': [{'text': system_prompt}],
+            'system': system_prompt,
             'inferenceConfig': inference_config,
-            **(
-                {'toolConfig': {'tools': tools, **({'toolChoice': tool_choice} if tool_choice else {})}}
-                if tools
-                else {}
-            ),
         }
+        if tools:
+            params['toolConfig'] = {'tools': tools}
+            if tool_choice:
+                params['toolConfig']['toolChoice'] = tool_choice

         if stream:
             model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse_stream, **params))
@@ -293,21 +294,22 @@ class BedrockConverseModel(Model):
             inference_config['temperature'] = temperature
         if top_p := model_settings.get('top_p'):
             inference_config['topP'] = top_p
-        # TODO(Marcelo): This is not included in model_settings yet.
-        # if stop_sequences := model_settings.get('stop_sequences'):
-        #     inference_config['stopSequences'] = stop_sequences
+        if stop_sequences := model_settings.get('stop_sequences'):
+            inference_config['stopSequences'] = stop_sequences

         return inference_config

-    async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[MessageUnionTypeDef]]:
+    async def _map_messages(
+        self, messages: list[ModelMessage]
+    ) -> tuple[list[SystemContentBlockTypeDef], list[MessageUnionTypeDef]]:
         """Just maps a `pydantic_ai.Message` to the Bedrock `MessageUnionTypeDef`."""
-        system_prompt: str = ''
+        system_prompt: list[SystemContentBlockTypeDef] = []
         bedrock_messages: list[MessageUnionTypeDef] = []
         for m in messages:
             if isinstance(m, ModelRequest):
                 for part in m.parts:
                     if isinstance(part, SystemPromptPart):
-                        system_prompt += part.content
+                        system_prompt.append({'text': part.content})
                     elif isinstance(part, UserPromptPart):
                         bedrock_messages.extend(await self._map_user_prompt(part))
                     elif isinstance(part, ToolReturnPart):
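
The practical effect of the `_map_messages` change, as a standalone sketch (values illustrative): each `SystemPromptPart` becomes its own Converse system content block instead of being concatenated into one string:

```python
parts = ['You are terse.', 'Answer in French.']
system_prompt = [{'text': p} for p in parts]
assert system_prompt == [{'text': 'You are terse.'}, {'text': 'Answer in French.'}]
# The old code produced a single concatenated block instead:
# [{'text': 'You are terse.Answer in French.'}]
```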

pydantic_ai/models/cohere.py
@@ -118,7 +118,7 @@ class CohereModel(Model):
             'cohere' or an instance of `Provider[AsyncClientV2]`. If not provided, a new provider will be
             created using the other parameters.
         """
-        self._model_name: CohereModelName = model_name
+        self._model_name = model_name

         if isinstance(provider, str):
             provider = infer_provider(provider)
@@ -163,6 +163,7 @@ class CohereModel(Model):
             messages=cohere_messages,
             tools=tools or OMIT,
             max_tokens=model_settings.get('max_tokens', OMIT),
+            stop_sequences=model_settings.get('stop_sequences', OMIT),
             temperature=model_settings.get('temperature', OMIT),
             p=model_settings.get('top_p', OMIT),
             seed=model_settings.get('seed', OMIT),

pydantic_ai/models/gemini.py
@@ -5,7 +5,7 @@ import re
 from collections.abc import AsyncIterator, Sequence
 from contextlib import asynccontextmanager
 from copy import deepcopy
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
 from datetime import datetime
 from typing import Annotated, Any, Literal, Protocol, Union, cast
 from uuid import uuid4
@@ -152,6 +152,16 @@ class GeminiModel(Model):
         ) as http_response:
             yield await self._process_streamed_response(http_response)

+    def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
+        def _customize_tool_def(t: ToolDefinition):
+            return replace(t, parameters_json_schema=_GeminiJsonSchema(t.parameters_json_schema).simplify())
+
+        return ModelRequestParameters(
+            function_tools=[_customize_tool_def(tool) for tool in model_request_parameters.function_tools],
+            allow_text_result=model_request_parameters.allow_text_result,
+            result_tools=[_customize_tool_def(tool) for tool in model_request_parameters.result_tools],
+        )
+
     @property
     def model_name(self) -> GeminiModelName:
         """The model name."""
@@ -496,6 +506,7 @@ class _GeminiGenerationConfig(TypedDict, total=False):
     top_p: float
     presence_penalty: float
     frequency_penalty: float
+    stop_sequences: list[str]


 class _GeminiContent(TypedDict):
@@ -640,7 +651,7 @@ class _GeminiFunction(TypedDict):


 def _function_from_abstract_tool(tool: ToolDefinition) -> _GeminiFunction:
-    json_schema = _GeminiJsonSchema(tool.parameters_json_schema).simplify()
+    json_schema = tool.parameters_json_schema
     f = _GeminiFunction(name=tool.name, description=tool.description)
     if json_schema.get('properties'):
         f['parameters'] = json_schema

pydantic_ai/models/groq.py
@@ -208,6 +208,7 @@ class GroqModel(Model):
             parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
             tools=tools or NOT_GIVEN,
             tool_choice=tool_choice or NOT_GIVEN,
+            stop=model_settings.get('stop_sequences', NOT_GIVEN),
             stream=stream,
             max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
             temperature=model_settings.get('temperature', NOT_GIVEN),

pydantic_ai/models/mistral.py
@@ -199,6 +199,7 @@ class MistralModel(Model):
                 top_p=model_settings.get('top_p', 1),
                 timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
                 random_seed=model_settings.get('seed', UNSET),
+                stop=model_settings.get('stop_sequences', None),
             )
         except SDKError as e:
             if (status_code := e.status_code) >= 400:
@@ -236,6 +237,7 @@ class MistralModel(Model):
                 timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
                 presence_penalty=model_settings.get('presence_penalty'),
                 frequency_penalty=model_settings.get('frequency_penalty'),
+                stop=model_settings.get('stop_sequences', None),
             )

         elif model_request_parameters.result_tools:

pydantic_ai/models/openai.py
@@ -4,9 +4,9 @@ import base64
 import warnings
 from collections.abc import AsyncIterable, AsyncIterator, Sequence
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, replace
 from datetime import datetime, timezone
-from typing import Literal, Union, cast, overload
+from typing import Any, Literal, Union, cast, overload

 from typing_extensions import assert_never

@@ -150,7 +150,7 @@ class OpenAIModel(Model):
     """

     client: AsyncOpenAI = field(repr=False)
-    system_prompt_role: OpenAISystemPromptRole | None = field(default=None)
+    system_prompt_role: OpenAISystemPromptRole | None = field(default=None, repr=False)

     _model_name: OpenAIModelName = field(repr=False)
     _system: str = field(default='openai', repr=False)
@@ -208,6 +208,9 @@ class OpenAIModel(Model):
         async with response:
             yield await self._process_streamed_response(response)

+    def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
+        return _customize_request_parameters(model_request_parameters)
+
     @property
     def model_name(self) -> OpenAIModelName:
         """The model name."""
@@ -268,6 +271,7 @@ class OpenAIModel(Model):
             tool_choice=tool_choice or NOT_GIVEN,
             stream=stream,
             stream_options={'include_usage': True} if stream else NOT_GIVEN,
+            stop=model_settings.get('stop_sequences', NOT_GIVEN),
             max_completion_tokens=model_settings.get('max_tokens', NOT_GIVEN),
             temperature=model_settings.get('temperature', NOT_GIVEN),
             top_p=model_settings.get('top_p', NOT_GIVEN),
@@ -351,7 +355,7 @@ class OpenAIModel(Model):

     @staticmethod
     def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam:
-        return {
+        tool_param: chat.ChatCompletionToolParam = {
             'type': 'function',
             'function': {
                 'name': f.name,
@@ -359,6 +363,9 @@ class OpenAIModel(Model):
                 'parameters': f.parameters_json_schema,
             },
         }
+        if f.strict:
+            tool_param['function']['strict'] = f.strict
+        return tool_param

     async def _map_user_message(self, message: ModelRequest) -> AsyncIterable[chat.ChatCompletionMessageParam]:
         for part in message.parts:
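
For a tool whose `ToolDefinition` ends up with `strict=True`, the resulting Chat Completions tool param looks like this (names and schema illustrative):

```python
tool_param = {
    'type': 'function',
    'function': {
        'name': 'get_price',
        'description': 'Look up the price of a product.',
        'parameters': {
            'type': 'object',
            'properties': {'product': {'type': 'string'}},
            'required': ['product'],
            'additionalProperties': False,
        },
        'strict': True,  # only added when f.strict is truthy
    },
}
```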
@@ -522,6 +529,9 @@ class OpenAIResponsesModel(Model):
         async with response:
             yield await self._process_streamed_response(response)

+    def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
+        return _customize_request_parameters(model_request_parameters)
+
     def _process_response(self, response: responses.Response) -> ModelResponse:
         """Process a non-streamed response, and prepare a message to return."""
         timestamp = datetime.fromtimestamp(response.created_at, tz=timezone.utc)
@@ -602,7 +612,7 @@ class OpenAIResponsesModel(Model):
                 truncation=model_settings.get('openai_truncation', NOT_GIVEN),
                 timeout=model_settings.get('timeout', NOT_GIVEN),
                 reasoning=reasoning,
-                user=model_settings.get('user', NOT_GIVEN),
+                user=model_settings.get('openai_user', NOT_GIVEN),
             )
         except APIStatusError as e:
             if (status_code := e.status_code) >= 400:
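
This corrects the Responses API to read the vendor-prefixed settings key. A sketch of passing it (the key comes from the diff; the model constructor and value are illustrative):

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIResponsesModel

agent = Agent(OpenAIResponsesModel('gpt-4o'))
result = agent.run_sync(
    'Hello!',
    model_settings={'openai_user': 'user-1234'},  # forwarded as the OpenAI `user` field
)
```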
@@ -630,8 +640,8 @@ class OpenAIResponsesModel(Model):
             'parameters': f.parameters_json_schema,
             'type': 'function',
             'description': f.description,
-            # TODO(Marcelo): We should make this configurable, and if True, set `additionalProperties` to False.
-            'strict': False,
+            # NOTE: f.strict should already be a boolean thanks to customize_request_parameters
+            'strict': f.strict or False,
         }

     async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[responses.ResponseInputItemParam]]:
@@ -907,3 +917,139 @@ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.R
         total_tokens=response_usage.total_tokens,
         details=details,
     )
+
+
+class _StrictSchemaHelper:
+    def make_schema_strict(self, schema: dict[str, Any]) -> dict[str, Any]:
+        """Recursively handle the schema to make it compatible with OpenAI strict mode.
+
+        See https://platform.openai.com/docs/guides/function-calling?api-mode=responses#strict-mode for more details,
+        but this basically just requires:
+        * `additionalProperties` must be set to false for each object in the parameters
+        * all fields in properties must be marked as required
+        """
+        assert isinstance(schema, dict), 'Schema must be a dictionary, this is probably a bug'
+
+        # Create a copy to avoid modifying the original schema
+        schema = schema.copy()
+
+        # Handle $defs
+        if defs := schema.get('$defs'):
+            schema['$defs'] = {k: self.make_schema_strict(v) for k, v in defs.items()}
+
+        # Process schema based on its type
+        schema_type = schema.get('type')
+        if schema_type == 'object':
+            # Handle object type by setting additionalProperties to false
+            # and adding all properties to required list
+            self._make_object_schema_strict(schema)
+        elif schema_type == 'array':
+            # Handle array types by processing their items
+            if 'items' in schema:
+                items: Any = schema['items']
+                schema['items'] = self.make_schema_strict(items)
+            if 'prefixItems' in schema:
+                prefix_items: list[Any] = schema['prefixItems']
+                schema['prefixItems'] = [self.make_schema_strict(item) for item in prefix_items]
+
+        elif schema_type in {'string', 'number', 'integer', 'boolean', 'null'}:
+            pass  # Primitive types need no special handling
+        elif 'oneOf' in schema:
+            schema['oneOf'] = [self.make_schema_strict(item) for item in schema['oneOf']]
+        elif 'anyOf' in schema:
+            schema['anyOf'] = [self.make_schema_strict(item) for item in schema['anyOf']]
+
+        return schema
+
+    def _make_object_schema_strict(self, schema: dict[str, Any]) -> None:
+        schema['additionalProperties'] = False
+
+        # Handle patternProperties; note this may not be compatible with strict mode but is included for completeness
+        if 'patternProperties' in schema and isinstance(schema['patternProperties'], dict):
+            pattern_props: dict[str, Any] = schema['patternProperties']
+            schema['patternProperties'] = {str(k): self.make_schema_strict(v) for k, v in pattern_props.items()}
+
+        # Handle properties — update their schemas recursively, and make all properties required
+        if 'properties' in schema and isinstance(schema['properties'], dict):
+            properties: dict[str, Any] = schema['properties']
+            schema['properties'] = {k: self.make_schema_strict(v) for k, v in properties.items()}
+            schema['required'] = list(properties.keys())
+
+    def is_schema_strict(self, schema: dict[str, Any]) -> bool:
+        """Check if the schema is strict-mode-compatible.
+
+        A schema is compatible if:
+        * `additionalProperties` is set to false for each object in the parameters
+        * all fields in properties are marked as required
+
+        See https://platform.openai.com/docs/guides/function-calling?api-mode=responses#strict-mode for more details.
+        """
+        assert isinstance(schema, dict), 'Schema must be a dictionary, this is probably a bug'
+
+        # Note that checking the defs first is usually the fastest way to proceed, but
+        # it makes it hard/impossible to hit coverage below, hence all the pragma no covers.
+        # I still included the handling below because I'm not _confident_ those code paths can't be hit.
+        if defs := schema.get('$defs'):
+            if not all(self.is_schema_strict(v) for v in defs.values()):  # pragma: no branch
+                return False
+
+        schema_type = schema.get('type')
+        if schema_type == 'object':
+            if not self._is_object_schema_strict(schema):
+                return False
+        elif schema_type == 'array':
+            if 'items' in schema:
+                items: Any = schema['items']
+                if not self.is_schema_strict(items):  # pragma: no cover
+                    return False
+            if 'prefixItems' in schema:
+                prefix_items: list[Any] = schema['prefixItems']
+                if not all(self.is_schema_strict(item) for item in prefix_items):  # pragma: no cover
+                    return False
+        elif schema_type in {'string', 'number', 'integer', 'boolean', 'null'}:
+            pass
+        elif 'oneOf' in schema:  # pragma: no cover
+            if not all(self.is_schema_strict(item) for item in schema['oneOf']):
+                return False

+        elif 'anyOf' in schema:  # pragma: no cover
+            if not all(self.is_schema_strict(item) for item in schema['anyOf']):
+                return False
+
+        return True
+
+    def _is_object_schema_strict(self, schema: dict[str, Any]) -> bool:
+        """Check if the schema is an object and has additionalProperties set to false."""
+        if schema.get('additionalProperties') is not False:
+            return False
+        if 'properties' not in schema:  # pragma: no cover
+            return False
+        if 'required' not in schema:  # pragma: no cover
+            return False
+
+        for k, v in schema['properties'].items():
+            if k not in schema['required']:
+                return False
+            if not self.is_schema_strict(v):  # pragma: no cover
+                return False
+
+        return True
+
+
+def _customize_request_parameters(model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
+    """Customize the request parameters for OpenAI models."""
+
+    def _customize_tool_def(t: ToolDefinition):
+        if t.strict is True:
+            parameters_json_schema = _StrictSchemaHelper().make_schema_strict(t.parameters_json_schema)
+            return replace(t, parameters_json_schema=parameters_json_schema)
+        elif t.strict is None:
+            strict = _StrictSchemaHelper().is_schema_strict(t.parameters_json_schema)
+            return replace(t, strict=strict)
+        return t
+
+    return ModelRequestParameters(
+        function_tools=[_customize_tool_def(tool) for tool in model_request_parameters.function_tools],
+        allow_text_result=model_request_parameters.allow_text_result,
+        result_tools=[_customize_tool_def(tool) for tool in model_request_parameters.result_tools],
+    )

pydantic_ai/settings.py
@@ -1,13 +1,8 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING
-
 from httpx import Timeout
 from typing_extensions import TypedDict

-if TYPE_CHECKING:
-    pass
-

 class ModelSettings(TypedDict, total=False):
     """Settings to configure an LLM.
@@ -133,6 +128,19 @@ class ModelSettings(TypedDict, total=False):
     * Groq
     """

+    stop_sequences: list[str]
+    """Sequences that will cause the model to stop generating.
+
+    Supported by:
+
+    * OpenAI
+    * Anthropic
+    * Bedrock
+    * Mistral
+    * Groq
+    * Cohere
+    """
+

 def merge_model_settings(base: ModelSettings | None, overrides: ModelSettings | None) -> ModelSettings | None:
     """Merge two sets of model settings, preferring the overrides.

pydantic_ai/tools.py
@@ -48,7 +48,7 @@ class RunContext(Generic[AgentDepsT]):
     """The model used in this run."""
     usage: Usage
     """LLM usage associated with the run."""
-    prompt: str | Sequence[_messages.UserContent]
+    prompt: str | Sequence[_messages.UserContent] | None
     """The original user prompt passed to the run."""
     messages: list[_messages.ModelMessage] = field(default_factory=list)
     """Messages exchanged in the conversation so far."""
@@ -173,12 +173,18 @@ class Tool(Generic[AgentDepsT]):
     prepare: ToolPrepareFunc[AgentDepsT] | None
     docstring_format: DocstringFormat
     require_parameter_descriptions: bool
+    strict: bool | None
     _is_async: bool = field(init=False)
     _single_arg_name: str | None = field(init=False)
     _positional_fields: list[str] = field(init=False)
     _var_positional_field: str | None = field(init=False)
     _validator: SchemaValidator = field(init=False, repr=False)
-    _parameters_json_schema: ObjectJsonSchema = field(init=False)
+    _base_parameters_json_schema: ObjectJsonSchema = field(init=False)
+    """
+    The base JSON schema for the tool's parameters.
+
+    This schema may be modified by the `prepare` function or by the Model class prior to including it in an API request.
+    """

     # TODO: Move this state off the Tool class, which is otherwise stateless.
     # This should be tracked inside a specific agent run, not the tool.
@@ -196,6 +202,7 @@ class Tool(Generic[AgentDepsT]):
         docstring_format: DocstringFormat = 'auto',
         require_parameter_descriptions: bool = False,
         schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
+        strict: bool | None = None,
     ):
         """Create a new tool instance.

@@ -246,6 +253,8 @@ class Tool(Generic[AgentDepsT]):
                 Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
             require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
             schema_generator: The JSON schema generator class to use. Defaults to `GenerateToolJsonSchema`.
+            strict: Whether to enforce JSON schema compliance (only affects OpenAI).
+                See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
         """
         if takes_ctx is None:
             takes_ctx = _pydantic.takes_ctx(function)
@@ -261,12 +270,13 @@ class Tool(Generic[AgentDepsT]):
         self.prepare = prepare
         self.docstring_format = docstring_format
         self.require_parameter_descriptions = require_parameter_descriptions
+        self.strict = strict
         self._is_async = inspect.iscoroutinefunction(self.function)
         self._single_arg_name = f['single_arg_name']
         self._positional_fields = f['positional_fields']
         self._var_positional_field = f['var_positional_field']
         self._validator = f['validator']
-        self._parameters_json_schema = f['json_schema']
+        self._base_parameters_json_schema = f['json_schema']

     async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None:
         """Get the tool definition.
@@ -280,7 +290,8 @@ class Tool(Generic[AgentDepsT]):
         tool_def = ToolDefinition(
             name=self.name,
             description=self.description,
-            parameters_json_schema=self._parameters_json_schema,
+            parameters_json_schema=self._base_parameters_json_schema,
+            strict=self.strict,
         )
         if self.prepare is not None:
             return await self.prepare(ctx, tool_def)
@@ -400,7 +411,7 @@ With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `ext
 class ToolDefinition:
     """Definition of a tool passed to a model.

-    This is used for both function tools result tools.
+    This is used for both function tools and result tools.
     """

     name: str
@@ -417,3 +428,15 @@ class ToolDefinition:

     This will only be set for result tools which don't have an `object` JSON schema.
     """
+
+    strict: bool | None = None
+    """Whether to enforce (vendor-specific) strict JSON schema validation for tool calls.
+
+    Setting this to `True` while using a supported model generally imposes some restrictions on the tool's JSON schema
+    in exchange for guaranteeing the API responses strictly match that schema.
+
+    When `False`, the model may be free to generate other properties or types (depending on the vendor).
+    When `None` (the default), the value will be inferred based on the compatibility of the parameters_json_schema.
+
+    Note: this is currently only supported by OpenAI models.
+    """