pydantic-ai-slim 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

pydantic_ai/agent.py CHANGED
@@ -4,13 +4,13 @@ import asyncio
 import dataclasses
 import inspect
 from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
-from contextlib import asynccontextmanager, contextmanager
+from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from types import FrameType
 from typing import Any, Callable, Generic, Literal, cast, final, overload

 import logfire_api
-from typing_extensions import assert_never, deprecated
+from typing_extensions import TypeVar, assert_never, deprecated

 from . import (
     _result,
@@ -26,6 +26,7 @@ from .result import ResultData
 from .settings import ModelSettings, merge_model_settings
 from .tools import (
     AgentDeps,
+    DocstringFormat,
     RunContext,
     Tool,
     ToolDefinition,
@@ -57,6 +58,8 @@ EndStrategy = Literal['early', 'exhaustive']
 - `'early'`: Stop processing other tool calls once a final result is found
 - `'exhaustive'`: Process all tool calls even after finding a final result
 """
+RunResultData = TypeVar('RunResultData')
+"""Type variable for the result data of a run where `result_type` was customized on the run call."""


 @final
@@ -99,14 +102,17 @@ class Agent(Generic[AgentDeps, ResultData]):
     Note, if `model_settings` is provided by `run`, `run_sync`, or `run_stream`, those settings will
     be merged with this value, with the runtime argument taking priority.
     """
-
+    _result_tool_name: str = dataclasses.field(repr=False)
+    _result_tool_description: str | None = dataclasses.field(repr=False)
     _result_schema: _result.ResultSchema[ResultData] | None = dataclasses.field(repr=False)
     _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = dataclasses.field(repr=False)
-    _allow_text_result: bool = dataclasses.field(repr=False)
     _system_prompts: tuple[str, ...] = dataclasses.field(repr=False)
     _function_tools: dict[str, Tool[AgentDeps]] = dataclasses.field(repr=False)
     _default_retries: int = dataclasses.field(repr=False)
     _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field(repr=False)
+    _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field(
+        repr=False
+    )
     _deps_type: type[AgentDeps] = dataclasses.field(repr=False)
     _max_result_retries: int = dataclasses.field(repr=False)
     _override_deps: _utils.Option[AgentDeps] = dataclasses.field(default=None, repr=False)
@@ -166,11 +172,11 @@ class Agent(Generic[AgentDeps, ResultData]):
         self.end_strategy = end_strategy
         self.name = name
         self.model_settings = model_settings
+        self._result_tool_name = result_tool_name
+        self._result_tool_description = result_tool_description
         self._result_schema = _result.ResultSchema[result_type].build(
             result_type, result_tool_name, result_tool_description
         )
-        # if the result tool is None, or its schema allows `str`, we allow plain text results
-        self._allow_text_result = self._result_schema is None or self._result_schema.allow_text_result

         self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
         self._function_tools = {}
@@ -182,13 +188,31 @@ class Agent(Generic[AgentDeps, ResultData]):
                 self._register_tool(Tool(tool))
         self._deps_type = deps_type
         self._system_prompt_functions = []
+        self._system_prompt_dynamic_functions = {}
         self._max_result_retries = result_retries if result_retries is not None else retries
         self._result_validators = []

+    @overload
+    async def run(
+        self,
+        user_prompt: str,
+        *,
+        result_type: None = None,
+        message_history: list[_messages.ModelMessage] | None = None,
+        model: models.Model | models.KnownModelName | None = None,
+        deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
+        infer_name: bool = True,
+    ) -> result.RunResult[ResultData]: ...
+
+    @overload
     async def run(
         self,
         user_prompt: str,
         *,
+        result_type: type[RunResultData],
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
@@ -196,7 +220,21 @@ class Agent(Generic[AgentDeps, ResultData]):
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
-    ) -> result.RunResult[ResultData]:
+    ) -> result.RunResult[RunResultData]: ...
+
+    async def run(
+        self,
+        user_prompt: str,
+        *,
+        message_history: list[_messages.ModelMessage] | None = None,
+        model: models.Model | models.KnownModelName | None = None,
+        deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
+        result_type: type[RunResultData] | None = None,
+        infer_name: bool = True,
+    ) -> result.RunResult[Any]:
         """Run the agent with a user prompt in async mode.

         Example:
@@ -205,12 +243,15 @@ class Agent(Generic[AgentDeps, ResultData]):

         agent = Agent('openai:gpt-4o')

-        result_sync = agent.run_sync('What is the capital of Italy?')
-        print(result_sync.data)
-        #> Rome
+        async def main():
+            result = await agent.run('What is the capital of France?')
+            print(result.data)
+            #> Paris
         ```

         Args:
+            result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no
+                result validators since result validators would expect an argument that matches the agent's result type.
             user_prompt: User input to start/continue the conversation.
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
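
The overloads above add a per-run `result_type` override to `run`. A minimal sketch of how a caller might use it — the model name and prompt are placeholders, and the agent is assumed to have no result validators:

```python
from pydantic import BaseModel

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')  # default result type is `str`


class CityLocation(BaseModel):
    city: str
    country: str


async def main():
    # Overriding `result_type` for one run makes it return `RunResult[CityLocation]`,
    # while other runs of the same agent keep returning plain text.
    result = await agent.run('Where were the 2012 Olympics held?', result_type=CityLocation)
    print(result.data.city)
```
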
@@ -229,6 +270,7 @@ class Agent(Generic[AgentDeps, ResultData]):

         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
+        result_schema = self._prepare_result_schema(result_type)

         with _logfire.span(
             '{agent_name} run {prompt=}',
@@ -252,7 +294,7 @@ class Agent(Generic[AgentDeps, ResultData]):

                 run_context.run_step += 1
                 with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
-                    agent_model = await self._prepare_model(run_context)
+                    agent_model = await self._prepare_model(run_context, result_schema)

                 with _logfire.span('model request', run_step=run_context.run_step) as model_req_span:
                     model_response, request_usage = await agent_model.request(messages, model_settings)
@@ -264,7 +306,9 @@ class Agent(Generic[AgentDeps, ResultData]):
                 usage_limits.check_tokens(run_context.usage)

                 with _logfire.span('handle model response', run_step=run_context.run_step) as handle_span:
-                    final_result, tool_responses = await self._handle_model_response(model_response, run_context)
+                    final_result, tool_responses = await self._handle_model_response(
+                        model_response, run_context, result_schema
+                    )

                     if tool_responses:
                         # Add parts to the conversation as a new message
@@ -287,10 +331,26 @@ class Agent(Generic[AgentDeps, ResultData]):
                         tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
                         handle_span.message = f'handle model response -> {tool_responses_str}'

+    @overload
+    def run_sync(
+        self,
+        user_prompt: str,
+        *,
+        message_history: list[_messages.ModelMessage] | None = None,
+        model: models.Model | models.KnownModelName | None = None,
+        deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
+        infer_name: bool = True,
+    ) -> result.RunResult[ResultData]: ...
+
+    @overload
     def run_sync(
         self,
         user_prompt: str,
         *,
+        result_type: type[RunResultData] | None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
@@ -298,7 +358,21 @@ class Agent(Generic[AgentDeps, ResultData]):
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
-    ) -> result.RunResult[ResultData]:
+    ) -> result.RunResult[RunResultData]: ...
+
+    def run_sync(
+        self,
+        user_prompt: str,
+        *,
+        result_type: type[RunResultData] | None = None,
+        message_history: list[_messages.ModelMessage] | None = None,
+        model: models.Model | models.KnownModelName | None = None,
+        deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
+        infer_name: bool = True,
+    ) -> result.RunResult[Any]:
         """Run the agent with a user prompt synchronously.

         This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`.
@@ -310,13 +384,14 @@ class Agent(Generic[AgentDeps, ResultData]):

         agent = Agent('openai:gpt-4o')

-        async def main():
-            result = await agent.run('What is the capital of France?')
-            print(result.data)
-            #> Paris
+        result_sync = agent.run_sync('What is the capital of Italy?')
+        print(result_sync.data)
+        #> Rome
         ```

         Args:
+            result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no
+                result validators since result validators would expect an argument that matches the agent's result type.
             user_prompt: User input to start/continue the conversation.
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
@@ -334,6 +409,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         return asyncio.get_event_loop().run_until_complete(
             self.run(
                 user_prompt,
+                result_type=result_type,
                 message_history=message_history,
                 model=model,
                 deps=deps,
@@ -344,11 +420,42 @@ class Agent(Generic[AgentDeps, ResultData]):
             )
         )

+    @overload
+    def run_stream(
+        self,
+        user_prompt: str,
+        *,
+        result_type: None = None,
+        message_history: list[_messages.ModelMessage] | None = None,
+        model: models.Model | models.KnownModelName | None = None,
+        deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
+        infer_name: bool = True,
+    ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDeps, ResultData]]: ...
+
+    @overload
+    def run_stream(
+        self,
+        user_prompt: str,
+        *,
+        result_type: type[RunResultData],
+        message_history: list[_messages.ModelMessage] | None = None,
+        model: models.Model | models.KnownModelName | None = None,
+        deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
+        infer_name: bool = True,
+    ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDeps, RunResultData]]: ...
+
     @asynccontextmanager
     async def run_stream(
         self,
         user_prompt: str,
         *,
+        result_type: type[RunResultData] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
@@ -356,7 +463,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         usage_limits: _usage.UsageLimits | None = None,
         usage: _usage.Usage | None = None,
         infer_name: bool = True,
-    ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
+    ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, Any]]:
         """Run the agent with a user prompt in async mode, returning a streamed response.

         Example:
@@ -372,6 +479,8 @@ class Agent(Generic[AgentDeps, ResultData]):
         ```

         Args:
+            result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no
+                result validators since result validators would expect an argument that matches the agent's result type.
             user_prompt: User input to start/continue the conversation.
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
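
The same per-run `result_type` override applies to streaming runs. A hedged sketch — the model name and prompt are placeholders, and `get_data()` is assumed from the `StreamedRunResult` API of this version:

```python
from pydantic_ai import Agent

agent = Agent('openai:gpt-4o')


async def main():
    # The overload above makes the context manager yield StreamedRunResult[..., list[str]]
    async with agent.run_stream('List three whale species.', result_type=list[str]) as response:
        # Wait for the stream to finish and validate the structured result.
        species = await response.get_data()
        print(species)
```
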
@@ -392,6 +501,7 @@ class Agent(Generic[AgentDeps, ResultData]):

         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
+        result_schema = self._prepare_result_schema(result_type)

         with _logfire.span(
             '{agent_name} run stream {prompt=}',
@@ -415,7 +525,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                 usage_limits.check_before_request(run_context.usage)

                 with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
-                    agent_model = await self._prepare_model(run_context)
+                    agent_model = await self._prepare_model(run_context, result_schema)

                 with _logfire.span('model request {run_step=}', run_step=run_context.run_step) as model_req_span:
                     async with agent_model.request_stream(messages, model_settings) as model_response:
@@ -426,7 +536,9 @@ class Agent(Generic[AgentDeps, ResultData]):
                         model_req_span.__exit__(None, None, None)

                         with _logfire.span('handle model response') as handle_span:
-                            maybe_final_result = await self._handle_streamed_model_response(model_response, run_context)
+                            maybe_final_result = await self._handle_streamed_response(
+                                model_response, run_context, result_schema
+                            )

                             # Check if we got a final result
                             if isinstance(maybe_final_result, _MarkFinalResult):
@@ -446,7 +558,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                                        part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
                                    ]
                                    parts = await self._process_function_tools(
-                                        tool_calls, result_tool_name, run_context
+                                        tool_calls, result_tool_name, run_context, result_schema
                                    )
                                    if parts:
                                        messages.append(_messages.ModelRequest(parts))
@@ -457,7 +569,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                                    new_message_index,
                                    usage_limits,
                                    result_stream,
-                                    self._result_schema,
+                                    result_schema,
                                    run_context,
                                    self._result_validators,
                                    result_tool_name,
@@ -535,17 +647,37 @@ class Agent(Generic[AgentDeps, ResultData]):
     @overload
     def system_prompt(self, func: Callable[[], Awaitable[str]], /) -> Callable[[], Awaitable[str]]: ...

+    @overload
+    def system_prompt(
+        self, /, *, dynamic: bool = False
+    ) -> Callable[[_system_prompt.SystemPromptFunc[AgentDeps]], _system_prompt.SystemPromptFunc[AgentDeps]]: ...
+
     def system_prompt(
-        self, func: _system_prompt.SystemPromptFunc[AgentDeps], /
-    ) -> _system_prompt.SystemPromptFunc[AgentDeps]:
+        self,
+        func: _system_prompt.SystemPromptFunc[AgentDeps] | None = None,
+        /,
+        *,
+        dynamic: bool = False,
+    ) -> (
+        Callable[[_system_prompt.SystemPromptFunc[AgentDeps]], _system_prompt.SystemPromptFunc[AgentDeps]]
+        | _system_prompt.SystemPromptFunc[AgentDeps]
+    ):
         """Decorator to register a system prompt function.

         Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument.
         Can decorate a sync or async functions.

+        The decorator can be used either bare (`agent.system_prompt`) or as a function call
+        (`agent.system_prompt(...)`), see the examples below.
+
         Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
         the type of the function, see `tests/typed_agent.py` for tests.

+        Args:
+            func: The function to decorate
+            dynamic: If True, the system prompt will be reevaluated even when `messages_history` is provided,
+                see [`SystemPromptPart.dynamic_ref`][pydantic_ai.messages.SystemPromptPart.dynamic_ref]
+
         Example:
         ```python
         from pydantic_ai import Agent, RunContext
@@ -556,17 +688,27 @@ class Agent(Generic[AgentDeps, ResultData]):
         def simple_system_prompt() -> str:
             return 'foobar'

-        @agent.system_prompt
+        @agent.system_prompt(dynamic=True)
         async def async_system_prompt(ctx: RunContext[str]) -> str:
             return f'{ctx.deps} is the best'
-
-        result = agent.run_sync('foobar', deps='spam')
-        print(result.data)
-        #> success (no tool calls)
         ```
         """
-        self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func))
-        return func
+        if func is None:
+
+            def decorator(
+                func_: _system_prompt.SystemPromptFunc[AgentDeps],
+            ) -> _system_prompt.SystemPromptFunc[AgentDeps]:
+                runner = _system_prompt.SystemPromptRunner(func_, dynamic=dynamic)
+                self._system_prompt_functions.append(runner)
+                if dynamic:
+                    self._system_prompt_dynamic_functions[func_.__qualname__] = runner
+                return func_
+
+            return decorator
+        else:
+            assert not dynamic, "dynamic can't be True in this case"
+            self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func, dynamic=dynamic))
+            return func

     @overload
     def result_validator(
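
A hedged sketch of what `dynamic=True` enables together with `_reevaluate_dynamic_prompts` further down: when `message_history` is passed to a later run, the dynamic prompt part is re-run against the current dependencies instead of being replayed verbatim. Model name, deps and prompts are placeholders:

```python
from pydantic_ai import Agent, RunContext

agent = Agent('openai:gpt-4o', deps_type=str)


@agent.system_prompt(dynamic=True)
def user_name_prompt(ctx: RunContext[str]) -> str:
    return f"The user's name is {ctx.deps}."


def demo():
    first = agent.run_sync('Hi!', deps='Alice')
    # Because the prompt is dynamic, it is reevaluated with the new deps here rather than
    # copied unchanged out of `first.all_messages()`.
    second = agent.run_sync('Who am I?', deps='Bob', message_history=first.all_messages())
    return second.data
```
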
@@ -633,6 +775,8 @@ class Agent(Generic[AgentDeps, ResultData]):
         *,
         retries: int | None = None,
         prepare: ToolPrepareFunc[AgentDeps] | None = None,
+        docstring_format: DocstringFormat = 'auto',
+        require_parameter_descriptions: bool = False,
     ) -> Callable[[ToolFuncContext[AgentDeps, ToolParams]], ToolFuncContext[AgentDeps, ToolParams]]: ...

     def tool(
@@ -642,6 +786,8 @@ class Agent(Generic[AgentDeps, ResultData]):
         *,
         retries: int | None = None,
         prepare: ToolPrepareFunc[AgentDeps] | None = None,
+        docstring_format: DocstringFormat = 'auto',
+        require_parameter_descriptions: bool = False,
     ) -> Any:
         """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.

@@ -679,6 +825,9 @@ class Agent(Generic[AgentDeps, ResultData]):
             prepare: custom method to prepare the tool definition for each step, return `None` to omit this
                 tool from a given step. This is useful if you want to customise a tool at call time,
                 or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
+            docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
+                Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
+            require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
         """
         if func is None:

@@ -686,13 +835,13 @@ class Agent(Generic[AgentDeps, ResultData]):
                 func_: ToolFuncContext[AgentDeps, ToolParams],
             ) -> ToolFuncContext[AgentDeps, ToolParams]:
                 # noinspection PyTypeChecker
-                self._register_function(func_, True, retries, prepare)
+                self._register_function(func_, True, retries, prepare, docstring_format, require_parameter_descriptions)
                 return func_

             return tool_decorator
         else:
             # noinspection PyTypeChecker
-            self._register_function(func, True, retries, prepare)
+            self._register_function(func, True, retries, prepare, docstring_format, require_parameter_descriptions)
             return func

     @overload
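
The `tool` decorator changes above thread `docstring_format` and `require_parameter_descriptions` through to `Tool`. A hedged sketch of how they might be used — the tool body is a placeholder, and `'google'` is assumed to be one of the `DocstringFormat` literals alongside `'auto'`:

```python
from pydantic_ai import Agent, RunContext

agent = Agent('test', deps_type=str)


@agent.tool(docstring_format='google', require_parameter_descriptions=True)
def greet(ctx: RunContext[str], name: str) -> str:
    """Greet somebody by name.

    Args:
        name: Name of the person to greet.
    """
    # With require_parameter_descriptions=True, omitting the `name:` description above
    # would raise an error when the tool's schema is built.
    return f'hello {name} from {ctx.deps}'
```
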
@@ -705,6 +854,8 @@ class Agent(Generic[AgentDeps, ResultData]):
         *,
         retries: int | None = None,
         prepare: ToolPrepareFunc[AgentDeps] | None = None,
+        docstring_format: DocstringFormat = 'auto',
+        require_parameter_descriptions: bool = False,
     ) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...

     def tool_plain(
@@ -714,6 +865,8 @@ class Agent(Generic[AgentDeps, ResultData]):
         *,
         retries: int | None = None,
         prepare: ToolPrepareFunc[AgentDeps] | None = None,
+        docstring_format: DocstringFormat = 'auto',
+        require_parameter_descriptions: bool = False,
     ) -> Any:
         """Decorator to register a tool function which DOES NOT take `RunContext` as an argument.

@@ -751,17 +904,22 @@ class Agent(Generic[AgentDeps, ResultData]):
             prepare: custom method to prepare the tool definition for each step, return `None` to omit this
                 tool from a given step. This is useful if you want to customise a tool at call time,
                 or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc].
+            docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
+                Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
+            require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
         """
         if func is None:

             def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
                 # noinspection PyTypeChecker
-                self._register_function(func_, False, retries, prepare)
+                self._register_function(
+                    func_, False, retries, prepare, docstring_format, require_parameter_descriptions
+                )
                 return func_

             return tool_decorator
         else:
-            self._register_function(func, False, retries, prepare)
+            self._register_function(func, False, retries, prepare, docstring_format, require_parameter_descriptions)
             return func

     def _register_function(
@@ -770,10 +928,19 @@ class Agent(Generic[AgentDeps, ResultData]):
         takes_ctx: bool,
         retries: int | None,
         prepare: ToolPrepareFunc[AgentDeps] | None,
+        docstring_format: DocstringFormat,
+        require_parameter_descriptions: bool,
     ) -> None:
         """Private utility to register a function as a tool."""
         retries_ = retries if retries is not None else self._default_retries
-        tool = Tool(func, takes_ctx=takes_ctx, max_retries=retries_, prepare=prepare)
+        tool = Tool(
+            func,
+            takes_ctx=takes_ctx,
+            max_retries=retries_,
+            prepare=prepare,
+            docstring_format=docstring_format,
+            require_parameter_descriptions=require_parameter_descriptions,
+        )
         self._register_tool(tool)

     def _register_tool(self, tool: Tool[AgentDeps]) -> None:
@@ -818,7 +985,9 @@ class Agent(Generic[AgentDeps, ResultData]):

         return model_

-    async def _prepare_model(self, run_context: RunContext[AgentDeps]) -> models.AgentModel:
+    async def _prepare_model(
+        self, run_context: RunContext[AgentDeps], result_schema: _result.ResultSchema[RunResultData] | None
+    ) -> models.AgentModel:
         """Build tools and create an agent model."""
         function_tools: list[ToolDefinition] = []

@@ -831,10 +1000,39 @@ class Agent(Generic[AgentDeps, ResultData]):

         return await run_context.model.agent_model(
             function_tools=function_tools,
-            allow_text_result=self._allow_text_result,
-            result_tools=self._result_schema.tool_defs() if self._result_schema is not None else [],
+            allow_text_result=self._allow_text_result(result_schema),
+            result_tools=result_schema.tool_defs() if result_schema is not None else [],
         )

+    async def _reevaluate_dynamic_prompts(
+        self, messages: list[_messages.ModelMessage], run_context: RunContext[AgentDeps]
+    ) -> None:
+        """Reevaluate any `SystemPromptPart` with dynamic_ref in the provided messages by running the associated runner function."""
+        # Only proceed if there's at least one dynamic runner.
+        if self._system_prompt_dynamic_functions:
+            for msg in messages:
+                if isinstance(msg, _messages.ModelRequest):
+                    for i, part in enumerate(msg.parts):
+                        if isinstance(part, _messages.SystemPromptPart) and part.dynamic_ref:
+                            # Look up the runner by its ref
+                            if runner := self._system_prompt_dynamic_functions.get(part.dynamic_ref):
+                                updated_part_content = await runner.run(run_context)
+                                msg.parts[i] = _messages.SystemPromptPart(
+                                    updated_part_content, dynamic_ref=part.dynamic_ref
+                                )
+
+    def _prepare_result_schema(
+        self, result_type: type[RunResultData] | None
+    ) -> _result.ResultSchema[RunResultData] | None:
+        if result_type is not None:
+            if self._result_validators:
+                raise exceptions.UserError('Cannot set a custom run `result_type` when the agent has result validators')
+            return _result.ResultSchema[result_type].build(
+                result_type, self._result_tool_name, self._result_tool_description
+            )
+        else:
+            return self._result_schema  # pyright: ignore[reportReturnType]
+
     async def _prepare_messages(
         self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
     ) -> list[_messages.ModelMessage]:
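
A hedged sketch of the guard in `_prepare_result_schema` above: combining a registered result validator with a per-run `result_type` is rejected with a `UserError` before any model call. Model name and prompts are placeholders:

```python
from pydantic_ai import Agent
from pydantic_ai.exceptions import UserError

agent = Agent('test')


@agent.result_validator
def check(value: str) -> str:
    return value.strip()


def demo():
    try:
        agent.run_sync('hi', result_type=int)
    except UserError as e:
        # 'Cannot set a custom run `result_type` when the agent has result validators'
        print(e)
```
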
@@ -850,8 +1048,10 @@ class Agent(Generic[AgentDeps, ResultData]):
                 ctx_messages.used = True

         if message_history:
-            # shallow copy messages
+            # Shallow copy messages
             messages.extend(message_history)
+            # Reevaluate any dynamic system prompt parts
+            await self._reevaluate_dynamic_prompts(messages, run_context)
             messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)]))
         else:
             parts = await self._sys_parts(run_context)
@@ -861,8 +1061,11 @@ class Agent(Generic[AgentDeps, ResultData]):
         return messages

     async def _handle_model_response(
-        self, model_response: _messages.ModelResponse, run_context: RunContext[AgentDeps]
-    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
+        self,
+        model_response: _messages.ModelResponse,
+        run_context: RunContext[AgentDeps],
+        result_schema: _result.ResultSchema[RunResultData] | None,
+    ) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
         """Process a non-streamed response from the model.

         Returns:
@@ -883,19 +1086,19 @@ class Agent(Generic[AgentDeps, ResultData]):
         # This accounts for cases like anthropic returns that might contain a text response
         # and a tool call response, where the text response just indicates the tool call will happen.
         if tool_calls:
-            return await self._handle_structured_response(tool_calls, run_context)
+            return await self._handle_structured_response(tool_calls, run_context, result_schema)
         elif texts:
             text = '\n\n'.join(texts)
-            return await self._handle_text_response(text, run_context)
+            return await self._handle_text_response(text, run_context, result_schema)
         else:
             raise exceptions.UnexpectedModelBehavior('Received empty model response')

     async def _handle_text_response(
-        self, text: str, run_context: RunContext[AgentDeps]
-    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
+        self, text: str, run_context: RunContext[AgentDeps], result_schema: _result.ResultSchema[RunResultData] | None
+    ) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
         """Handle a plain text response from the model for non-streaming responses."""
-        if self._allow_text_result:
-            result_data_input = cast(ResultData, text)
+        if self._allow_text_result(result_schema):
+            result_data_input = cast(RunResultData, text)
             try:
                 result_data = await self._validate_result(result_data_input, run_context, None)
             except _result.ToolRetryError as e:
@@ -911,16 +1114,19 @@ class Agent(Generic[AgentDeps, ResultData]):
             return None, [response]

     async def _handle_structured_response(
-        self, tool_calls: list[_messages.ToolCallPart], run_context: RunContext[AgentDeps]
-    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
+        self,
+        tool_calls: list[_messages.ToolCallPart],
+        run_context: RunContext[AgentDeps],
+        result_schema: _result.ResultSchema[RunResultData] | None,
+    ) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
         """Handle a structured response containing tool calls from the model for non-streaming responses."""
         assert tool_calls, 'Expected at least one tool call'

         # first look for the result tool call
-        final_result: _MarkFinalResult[ResultData] | None = None
+        final_result: _MarkFinalResult[RunResultData] | None = None

         parts: list[_messages.ModelRequestPart] = []
-        if result_schema := self._result_schema:
+        if result_schema is not None:
             if match := result_schema.find_tool(tool_calls):
                 call, result_tool = match
                 try:
@@ -933,7 +1139,9 @@ class Agent(Generic[AgentDeps, ResultData]):
                     final_result = _MarkFinalResult(result_data, call.tool_name)

         # Then build the other request parts based on end strategy
-        parts += await self._process_function_tools(tool_calls, final_result and final_result.tool_name, run_context)
+        parts += await self._process_function_tools(
+            tool_calls, final_result and final_result.tool_name, run_context, result_schema
+        )

         return final_result, parts

@@ -942,6 +1150,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         tool_calls: list[_messages.ToolCallPart],
         result_tool_name: str | None,
         run_context: RunContext[AgentDeps],
+        result_schema: _result.ResultSchema[RunResultData] | None,
     ) -> list[_messages.ModelRequestPart]:
         """Process function (non-result) tool calls in parallel.

@@ -975,7 +1184,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                     )
                 else:
                     tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
-            elif self._result_schema is not None and call.tool_name in self._result_schema.tools:
+            elif result_schema is not None and call.tool_name in result_schema.tools:
                 # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
                 # validation, we don't add another part here
                 if result_tool_name is not None:
@@ -987,7 +1196,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                         )
                     )
             else:
-                parts.append(self._unknown_tool(call.tool_name, run_context))
+                parts.append(self._unknown_tool(call.tool_name, run_context, result_schema))

         # Run all tool tasks in parallel
         if tasks:
@@ -996,85 +1205,72 @@ class Agent(Generic[AgentDeps, ResultData]):
                 parts.extend(task_results)
         return parts

-    async def _handle_streamed_model_response(
+    async def _handle_streamed_response(
         self,
-        model_response: models.EitherStreamedResponse,
+        streamed_response: models.StreamedResponse,
         run_context: RunContext[AgentDeps],
-    ) -> (
-        _MarkFinalResult[models.EitherStreamedResponse]
-        | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]
-    ):
+        result_schema: _result.ResultSchema[RunResultData] | None,
+    ) -> _MarkFinalResult[models.StreamedResponse] | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]:
         """Process a streamed response from the model.

         Returns:
             Either a final result or a tuple of the model response and the tool responses for the next request.
             If a final result is returned, the conversation should end.
         """
-        if isinstance(model_response, models.StreamTextResponse):
-            # plain string response
-            if self._allow_text_result:
-                return _MarkFinalResult(model_response, None)
-            else:
-                self._incr_result_retry(run_context)
-                response = _messages.RetryPromptPart(
-                    content='Plain text responses are not permitted, please call one of the functions instead.',
-                )
-                # stream the response, so usage is correct
-                async for _ in model_response:
-                    pass
-
-                text = ''.join(model_response.get(final=True))
-                return _messages.ModelResponse([_messages.TextPart(text)]), [response]
-        elif isinstance(model_response, models.StreamStructuredResponse):
-            if self._result_schema is not None:
-                # if there's a result schema, iterate over the stream until we find at least one tool
-                # NOTE: this means we ignore any other tools called here
-                structured_msg = model_response.get()
-                while not structured_msg.parts:
-                    try:
-                        await model_response.__anext__()
-                    except StopAsyncIteration:
-                        break
-                    structured_msg = model_response.get()
-
-                if match := self._result_schema.find_tool(structured_msg.parts):
-                    call, _ = match
-                    return _MarkFinalResult(model_response, call.tool_name)
-
-            # the model is calling a tool function, consume the response to get the next message
-            async for _ in model_response:
-                pass
-            model_response_msg = model_response.get()
-            if not model_response_msg.parts:
-                raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
-
-            # we now run all tool functions in parallel
-            tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
-            parts: list[_messages.ModelRequestPart] = []
-            for item in model_response_msg.parts:
-                if isinstance(item, _messages.ToolCallPart):
-                    call = item
-                    if tool := self._function_tools.get(call.tool_name):
-                        tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
-                    else:
-                        parts.append(self._unknown_tool(call.tool_name, run_context))
+        received_text = False
+
+        async for maybe_part_event in streamed_response:
+            if isinstance(maybe_part_event, _messages.PartStartEvent):
+                new_part = maybe_part_event.part
+                if isinstance(new_part, _messages.TextPart):
+                    received_text = True
+                    if self._allow_text_result(result_schema):
+                        return _MarkFinalResult(streamed_response, None)
+                elif isinstance(new_part, _messages.ToolCallPart):
+                    if result_schema is not None and (match := result_schema.find_tool([new_part])):
+                        call, _ = match
+                        return _MarkFinalResult(streamed_response, call.tool_name)
+                else:
+                    assert_never(new_part)

-            with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-                task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
-                parts.extend(task_results)
-            return model_response_msg, parts
-        else:
-            assert_never(model_response)
+        tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
+        parts: list[_messages.ModelRequestPart] = []
+        model_response = streamed_response.get()
+        if not model_response.parts:
+            raise exceptions.UnexpectedModelBehavior('Received empty model response')
+        for p in model_response.parts:
+            if isinstance(p, _messages.ToolCallPart):
+                if tool := self._function_tools.get(p.tool_name):
+                    tasks.append(asyncio.create_task(tool.run(p, run_context), name=p.tool_name))
+                else:
+                    parts.append(self._unknown_tool(p.tool_name, run_context, result_schema))
+
+        if received_text and not tasks and not parts:
+            # Can only get here if self._allow_text_result returns `False` for the provided result_schema
+            self._incr_result_retry(run_context)
+            model_response = _messages.RetryPromptPart(
+                content='Plain text responses are not permitted, please call one of the functions instead.',
+            )
+            return streamed_response.get(), [model_response]
+
+        with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
+            task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
+            parts.extend(task_results)
+        return model_response, parts

     async def _validate_result(
         self,
-        result_data: ResultData,
+        result_data: RunResultData,
         run_context: RunContext[AgentDeps],
         tool_call: _messages.ToolCallPart | None,
-    ) -> ResultData:
-        for validator in self._result_validators:
-            result_data = await validator.validate(result_data, tool_call, run_context)
-        return result_data
+    ) -> RunResultData:
+        if self._result_validators:
+            agent_result_data = cast(ResultData, result_data)
+            for validator in self._result_validators:
+                agent_result_data = await validator.validate(agent_result_data, tool_call, run_context)
+            return cast(RunResultData, agent_result_data)
+        else:
+            return result_data

     def _incr_result_retry(self, run_context: RunContext[AgentDeps]) -> None:
         run_context.retry += 1
@@ -1088,14 +1284,22 @@ class Agent(Generic[AgentDeps, ResultData]):
         messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self._system_prompts]
         for sys_prompt_runner in self._system_prompt_functions:
             prompt = await sys_prompt_runner.run(run_context)
-            messages.append(_messages.SystemPromptPart(prompt))
+            if sys_prompt_runner.dynamic:
+                messages.append(_messages.SystemPromptPart(prompt, dynamic_ref=sys_prompt_runner.function.__qualname__))
+            else:
+                messages.append(_messages.SystemPromptPart(prompt))
         return messages

-    def _unknown_tool(self, tool_name: str, run_context: RunContext[AgentDeps]) -> _messages.RetryPromptPart:
+    def _unknown_tool(
+        self,
+        tool_name: str,
+        run_context: RunContext[AgentDeps],
+        result_schema: _result.ResultSchema[RunResultData] | None,
+    ) -> _messages.RetryPromptPart:
         self._incr_result_retry(run_context)
         names = list(self._function_tools.keys())
-        if self._result_schema:
-            names.extend(self._result_schema.tool_names())
+        if result_schema:
+            names.extend(result_schema.tool_names())
         if names:
             msg = f'Available tools: {", ".join(names)}'
         else:
@@ -1133,6 +1337,10 @@ class Agent(Generic[AgentDeps, ResultData]):
                         self.name = name
                         return

+    @staticmethod
+    def _allow_text_result(result_schema: _result.ResultSchema[RunResultData] | None) -> bool:
+        return result_schema is None or result_schema.allow_text_result
+
     @property
     @deprecated(
         'The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.', category=None