pydantic-ai-slim 0.0.16__py3-none-any.whl → 0.0.18__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to their public registry. It is provided for informational purposes only.
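The headline change in this release is that `Agent.run`, `Agent.run_sync`, and `Agent.run_stream` gain an optional `result_type` argument (with matching typed overloads) that overrides the agent's result type for a single run, and `@agent.system_prompt` can now be called with `dynamic=True`. As a minimal sketch of the per-run `result_type`, based on the signatures and docstrings in the diff below — the `CityLocation` model and the `'test'` model name are illustrative placeholders, not part of this diff:

```python
from pydantic import BaseModel

from pydantic_ai import Agent


class CityLocation(BaseModel):
    # Placeholder result model, used only for this illustration.
    city: str
    country: str


# The agent's default result type remains str...
agent = Agent('test')

# ...but a single run can now override it via `result_type`, provided the
# agent has no result validators registered (see the docstrings below).
result = agent.run_sync('Where were the 2012 Olympics held?', result_type=CityLocation)
print(result.data)
```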

pydantic_ai/agent.py CHANGED
@@ -4,13 +4,13 @@ import asyncio
 import dataclasses
 import inspect
 from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
-from contextlib import asynccontextmanager, contextmanager
+from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from types import FrameType
 from typing import Any, Callable, Generic, Literal, cast, final, overload

 import logfire_api
-from typing_extensions import assert_never, deprecated
+from typing_extensions import TypeVar, assert_never, deprecated

 from . import (
     _result,
@@ -20,9 +20,10 @@ from . import (
     messages as _messages,
     models,
     result,
+    usage as _usage,
 )
 from .result import ResultData
-from .settings import ModelSettings, UsageLimits, merge_model_settings
+from .settings import ModelSettings, merge_model_settings
 from .tools import (
     AgentDeps,
     RunContext,
@@ -56,6 +57,8 @@ EndStrategy = Literal['early', 'exhaustive']
 - `'early'`: Stop processing other tool calls once a final result is found
 - `'exhaustive'`: Process all tool calls even after finding a final result
 """
+RunResultData = TypeVar('RunResultData')
+"""Type variable for the result data of a run where `result_type` was customized on the run call."""


 @final
@@ -98,14 +101,17 @@ class Agent(Generic[AgentDeps, ResultData]):
     Note, if `model_settings` is provided by `run`, `run_sync`, or `run_stream`, those settings will
     be merged with this value, with the runtime argument taking priority.
     """
-
+    _result_tool_name: str = dataclasses.field(repr=False)
+    _result_tool_description: str | None = dataclasses.field(repr=False)
     _result_schema: _result.ResultSchema[ResultData] | None = dataclasses.field(repr=False)
     _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = dataclasses.field(repr=False)
-    _allow_text_result: bool = dataclasses.field(repr=False)
     _system_prompts: tuple[str, ...] = dataclasses.field(repr=False)
     _function_tools: dict[str, Tool[AgentDeps]] = dataclasses.field(repr=False)
     _default_retries: int = dataclasses.field(repr=False)
     _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field(repr=False)
+    _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDeps]] = dataclasses.field(
+        repr=False
+    )
     _deps_type: type[AgentDeps] = dataclasses.field(repr=False)
     _max_result_retries: int = dataclasses.field(repr=False)
     _override_deps: _utils.Option[AgentDeps] = dataclasses.field(default=None, repr=False)
@@ -165,11 +171,11 @@ class Agent(Generic[AgentDeps, ResultData]):
         self.end_strategy = end_strategy
         self.name = name
         self.model_settings = model_settings
+        self._result_tool_name = result_tool_name
+        self._result_tool_description = result_tool_description
         self._result_schema = _result.ResultSchema[result_type].build(
             result_type, result_tool_name, result_tool_description
         )
-        # if the result tool is None, or its schema allows `str`, we allow plain text results
-        self._allow_text_result = self._result_schema is None or self._result_schema.allow_text_result

         self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
         self._function_tools = {}
@@ -181,21 +187,53 @@ class Agent(Generic[AgentDeps, ResultData]):
                 self._register_tool(Tool(tool))
         self._deps_type = deps_type
         self._system_prompt_functions = []
+        self._system_prompt_dynamic_functions = {}
         self._max_result_retries = result_retries if result_retries is not None else retries
         self._result_validators = []

+    @overload
+    async def run(
+        self,
+        user_prompt: str,
+        *,
+        result_type: None = None,
+        message_history: list[_messages.ModelMessage] | None = None,
+        model: models.Model | models.KnownModelName | None = None,
+        deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
+        infer_name: bool = True,
+    ) -> result.RunResult[ResultData]: ...
+
+    @overload
     async def run(
         self,
         user_prompt: str,
         *,
+        result_type: type[RunResultData],
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
-        usage_limits: UsageLimits | None = None,
-        usage: result.Usage | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
         infer_name: bool = True,
-    ) -> result.RunResult[ResultData]:
+    ) -> result.RunResult[RunResultData]: ...
+
+    async def run(
+        self,
+        user_prompt: str,
+        *,
+        message_history: list[_messages.ModelMessage] | None = None,
+        model: models.Model | models.KnownModelName | None = None,
+        deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
+        result_type: type[RunResultData] | None = None,
+        infer_name: bool = True,
+    ) -> result.RunResult[Any]:
         """Run the agent with a user prompt in async mode.

         Example:
@@ -210,6 +248,8 @@ class Agent(Generic[AgentDeps, ResultData]):
         ```

         Args:
+            result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no
+                result validators since result validators would expect an argument that matches the agent's result type.
             user_prompt: User input to start/continue the conversation.
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
@@ -228,6 +268,7 @@ class Agent(Generic[AgentDeps, ResultData]):

         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
+        result_schema = self._prepare_result_schema(result_type)

         with _logfire.span(
             '{agent_name} run {prompt=}',
@@ -236,7 +277,7 @@ class Agent(Generic[AgentDeps, ResultData]):
             model_name=model_used.name(),
             agent_name=self.name or 'agent',
         ) as run_span:
-            run_context = RunContext(deps, model_used, usage or result.Usage(), user_prompt)
+            run_context = RunContext(deps, model_used, usage or _usage.Usage(), user_prompt)
             messages = await self._prepare_messages(user_prompt, message_history, run_context)
             run_context.messages = messages

@@ -244,14 +285,14 @@ class Agent(Generic[AgentDeps, ResultData]):
                 tool.current_retry = 0

             model_settings = merge_model_settings(self.model_settings, model_settings)
-            usage_limits = usage_limits or UsageLimits()
+            usage_limits = usage_limits or _usage.UsageLimits()

             while True:
                 usage_limits.check_before_request(run_context.usage)

                 run_context.run_step += 1
                 with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
-                    agent_model = await self._prepare_model(run_context)
+                    agent_model = await self._prepare_model(run_context, result_schema)

                 with _logfire.span('model request', run_step=run_context.run_step) as model_req_span:
                     model_response, request_usage = await agent_model.request(messages, model_settings)
@@ -263,7 +304,9 @@ class Agent(Generic[AgentDeps, ResultData]):
                 usage_limits.check_tokens(run_context.usage)

                 with _logfire.span('handle model response', run_step=run_context.run_step) as handle_span:
-                    final_result, tool_responses = await self._handle_model_response(model_response, run_context)
+                    final_result, tool_responses = await self._handle_model_response(
+                        model_response, run_context, result_schema
+                    )

                     if tool_responses:
                         # Add parts to the conversation as a new message
@@ -272,29 +315,62 @@ class Agent(Generic[AgentDeps, ResultData]):
                     # Check if we got a final result
                     if final_result is not None:
                         result_data = final_result.data
+                        result_tool_name = final_result.tool_name
                         run_span.set_attribute('all_messages', messages)
                         run_span.set_attribute('usage', run_context.usage)
                         handle_span.set_attribute('result', result_data)
                         handle_span.message = 'handle model response -> final result'
-                        return result.RunResult(messages, new_message_index, result_data, run_context.usage)
+                        return result.RunResult(
+                            messages, new_message_index, result_data, result_tool_name, run_context.usage
+                        )
                     else:
                         # continue the conversation
                         handle_span.set_attribute('tool_responses', tool_responses)
                         tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
                         handle_span.message = f'handle model response -> {tool_responses_str}'

+    @overload
+    def run_sync(
+        self,
+        user_prompt: str,
+        *,
+        message_history: list[_messages.ModelMessage] | None = None,
+        model: models.Model | models.KnownModelName | None = None,
+        deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
+        infer_name: bool = True,
+    ) -> result.RunResult[ResultData]: ...
+
+    @overload
     def run_sync(
         self,
         user_prompt: str,
         *,
+        result_type: type[RunResultData] | None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
-        usage_limits: UsageLimits | None = None,
-        usage: result.Usage | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
         infer_name: bool = True,
-    ) -> result.RunResult[ResultData]:
+    ) -> result.RunResult[RunResultData]: ...
+
+    def run_sync(
+        self,
+        user_prompt: str,
+        *,
+        result_type: type[RunResultData] | None = None,
+        message_history: list[_messages.ModelMessage] | None = None,
+        model: models.Model | models.KnownModelName | None = None,
+        deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
+        infer_name: bool = True,
+    ) -> result.RunResult[Any]:
         """Run the agent with a user prompt synchronously.

         This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`.
@@ -313,6 +389,8 @@ class Agent(Generic[AgentDeps, ResultData]):
         ```

         Args:
+            result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no
+                result validators since result validators would expect an argument that matches the agent's result type.
             user_prompt: User input to start/continue the conversation.
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
@@ -330,6 +408,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         return asyncio.get_event_loop().run_until_complete(
             self.run(
                 user_prompt,
+                result_type=result_type,
                 message_history=message_history,
                 model=model,
                 deps=deps,
@@ -340,19 +419,50 @@ class Agent(Generic[AgentDeps, ResultData]):
             )
         )

+    @overload
+    def run_stream(
+        self,
+        user_prompt: str,
+        *,
+        result_type: None = None,
+        message_history: list[_messages.ModelMessage] | None = None,
+        model: models.Model | models.KnownModelName | None = None,
+        deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
+        infer_name: bool = True,
+    ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDeps, ResultData]]: ...
+
+    @overload
+    def run_stream(
+        self,
+        user_prompt: str,
+        *,
+        result_type: type[RunResultData],
+        message_history: list[_messages.ModelMessage] | None = None,
+        model: models.Model | models.KnownModelName | None = None,
+        deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
+        infer_name: bool = True,
+    ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDeps, RunResultData]]: ...
+
     @asynccontextmanager
     async def run_stream(
         self,
         user_prompt: str,
         *,
+        result_type: type[RunResultData] | None = None,
         message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
-        usage_limits: UsageLimits | None = None,
-        usage: result.Usage | None = None,
+        usage_limits: _usage.UsageLimits | None = None,
+        usage: _usage.Usage | None = None,
         infer_name: bool = True,
-    ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
+    ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, Any]]:
         """Run the agent with a user prompt in async mode, returning a streamed response.

         Example:
@@ -368,6 +478,8 @@ class Agent(Generic[AgentDeps, ResultData]):
         ```

         Args:
+            result_type: Custom result type to use for this run, `result_type` may only be used if the agent has no
+                result validators since result validators would expect an argument that matches the agent's result type.
             user_prompt: User input to start/continue the conversation.
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
@@ -388,6 +500,7 @@ class Agent(Generic[AgentDeps, ResultData]):

         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
+        result_schema = self._prepare_result_schema(result_type)

         with _logfire.span(
             '{agent_name} run stream {prompt=}',
@@ -396,7 +509,7 @@ class Agent(Generic[AgentDeps, ResultData]):
             model_name=model_used.name(),
             agent_name=self.name or 'agent',
         ) as run_span:
-            run_context = RunContext(deps, model_used, usage or result.Usage(), user_prompt)
+            run_context = RunContext(deps, model_used, usage or _usage.Usage(), user_prompt)
             messages = await self._prepare_messages(user_prompt, message_history, run_context)
             run_context.messages = messages

@@ -404,14 +517,14 @@ class Agent(Generic[AgentDeps, ResultData]):
                 tool.current_retry = 0

             model_settings = merge_model_settings(self.model_settings, model_settings)
-            usage_limits = usage_limits or UsageLimits()
+            usage_limits = usage_limits or _usage.UsageLimits()

             while True:
                 run_context.run_step += 1
                 usage_limits.check_before_request(run_context.usage)

                 with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
-                    agent_model = await self._prepare_model(run_context)
+                    agent_model = await self._prepare_model(run_context, result_schema)

                 with _logfire.span('model request {run_step=}', run_step=run_context.run_step) as model_req_span:
                     async with agent_model.request_stream(messages, model_settings) as model_response:
@@ -422,7 +535,9 @@ class Agent(Generic[AgentDeps, ResultData]):
                         model_req_span.__exit__(None, None, None)

                         with _logfire.span('handle model response') as handle_span:
-                            maybe_final_result = await self._handle_streamed_model_response(model_response, run_context)
+                            maybe_final_result = await self._handle_streamed_model_response(
+                                model_response, run_context, result_schema
+                            )

                             # Check if we got a final result
                             if isinstance(maybe_final_result, _MarkFinalResult):
@@ -442,7 +557,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                                         part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
                                     ]
                                     parts = await self._process_function_tools(
-                                        tool_calls, result_tool_name, run_context
+                                        tool_calls, result_tool_name, run_context, result_schema
                                     )
                                     if parts:
                                         messages.append(_messages.ModelRequest(parts))
@@ -453,7 +568,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                                     new_message_index,
                                     usage_limits,
                                     result_stream,
-                                    self._result_schema,
+                                    result_schema,
                                     run_context,
                                     self._result_validators,
                                     result_tool_name,
@@ -531,17 +646,37 @@ class Agent(Generic[AgentDeps, ResultData]):
     @overload
     def system_prompt(self, func: Callable[[], Awaitable[str]], /) -> Callable[[], Awaitable[str]]: ...

+    @overload
+    def system_prompt(
+        self, /, *, dynamic: bool = False
+    ) -> Callable[[_system_prompt.SystemPromptFunc[AgentDeps]], _system_prompt.SystemPromptFunc[AgentDeps]]: ...
+
     def system_prompt(
-        self, func: _system_prompt.SystemPromptFunc[AgentDeps], /
-    ) -> _system_prompt.SystemPromptFunc[AgentDeps]:
+        self,
+        func: _system_prompt.SystemPromptFunc[AgentDeps] | None = None,
+        /,
+        *,
+        dynamic: bool = False,
+    ) -> (
+        Callable[[_system_prompt.SystemPromptFunc[AgentDeps]], _system_prompt.SystemPromptFunc[AgentDeps]]
+        | _system_prompt.SystemPromptFunc[AgentDeps]
+    ):
         """Decorator to register a system prompt function.

         Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument.
         Can decorate a sync or async functions.

+        The decorator can be used either bare (`agent.system_prompt`) or as a function call
+        (`agent.system_prompt(...)`), see the examples below.
+
         Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
         the type of the function, see `tests/typed_agent.py` for tests.

+        Args:
+            func: The function to decorate
+            dynamic: If True, the system prompt will be reevaluated even when `messages_history` is provided,
+                see [`SystemPromptPart.dynamic_ref`][pydantic_ai.messages.SystemPromptPart.dynamic_ref]
+
         Example:
         ```python
         from pydantic_ai import Agent, RunContext
@@ -552,17 +687,27 @@ class Agent(Generic[AgentDeps, ResultData]):
         def simple_system_prompt() -> str:
             return 'foobar'

-        @agent.system_prompt
+        @agent.system_prompt(dynamic=True)
         async def async_system_prompt(ctx: RunContext[str]) -> str:
             return f'{ctx.deps} is the best'
-
-        result = agent.run_sync('foobar', deps='spam')
-        print(result.data)
-        #> success (no tool calls)
         ```
         """
-        self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func))
-        return func
+        if func is None:
+
+            def decorator(
+                func_: _system_prompt.SystemPromptFunc[AgentDeps],
+            ) -> _system_prompt.SystemPromptFunc[AgentDeps]:
+                runner = _system_prompt.SystemPromptRunner(func_, dynamic=dynamic)
+                self._system_prompt_functions.append(runner)
+                if dynamic:
+                    self._system_prompt_dynamic_functions[func_.__qualname__] = runner
+                return func_
+
+            return decorator
+        else:
+            assert not dynamic, "dynamic can't be True in this case"
+            self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func, dynamic=dynamic))
+            return func

     @overload
     def result_validator(
@@ -814,7 +959,9 @@ class Agent(Generic[AgentDeps, ResultData]):

         return model_

-    async def _prepare_model(self, run_context: RunContext[AgentDeps]) -> models.AgentModel:
+    async def _prepare_model(
+        self, run_context: RunContext[AgentDeps], result_schema: _result.ResultSchema[RunResultData] | None
+    ) -> models.AgentModel:
         """Build tools and create an agent model."""
         function_tools: list[ToolDefinition] = []

@@ -827,10 +974,39 @@ class Agent(Generic[AgentDeps, ResultData]):

         return await run_context.model.agent_model(
             function_tools=function_tools,
-            allow_text_result=self._allow_text_result,
-            result_tools=self._result_schema.tool_defs() if self._result_schema is not None else [],
+            allow_text_result=self._allow_text_result(result_schema),
+            result_tools=result_schema.tool_defs() if result_schema is not None else [],
         )

+    async def _reevaluate_dynamic_prompts(
+        self, messages: list[_messages.ModelMessage], run_context: RunContext[AgentDeps]
+    ) -> None:
+        """Reevaluate any `SystemPromptPart` with dynamic_ref in the provided messages by running the associated runner function."""
+        # Only proceed if there's at least one dynamic runner.
+        if self._system_prompt_dynamic_functions:
+            for msg in messages:
+                if isinstance(msg, _messages.ModelRequest):
+                    for i, part in enumerate(msg.parts):
+                        if isinstance(part, _messages.SystemPromptPart) and part.dynamic_ref:
+                            # Look up the runner by its ref
+                            if runner := self._system_prompt_dynamic_functions.get(part.dynamic_ref):
+                                updated_part_content = await runner.run(run_context)
+                                msg.parts[i] = _messages.SystemPromptPart(
+                                    updated_part_content, dynamic_ref=part.dynamic_ref
+                                )
+
+    def _prepare_result_schema(
+        self, result_type: type[RunResultData] | None
+    ) -> _result.ResultSchema[RunResultData] | None:
+        if result_type is not None:
+            if self._result_validators:
+                raise exceptions.UserError('Cannot set a custom run `result_type` when the agent has result validators')
+            return _result.ResultSchema[result_type].build(
+                result_type, self._result_tool_name, self._result_tool_description
+            )
+        else:
+            return self._result_schema  # pyright: ignore[reportReturnType]
+
     async def _prepare_messages(
         self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
     ) -> list[_messages.ModelMessage]:
@@ -846,8 +1022,10 @@ class Agent(Generic[AgentDeps, ResultData]):
                 ctx_messages.used = True

         if message_history:
-            # shallow copy messages
+            # Shallow copy messages
             messages.extend(message_history)
+            # Reevaluate any dynamic system prompt parts
+            await self._reevaluate_dynamic_prompts(messages, run_context)
             messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)]))
         else:
             parts = await self._sys_parts(run_context)
@@ -857,8 +1035,11 @@ class Agent(Generic[AgentDeps, ResultData]):
         return messages

     async def _handle_model_response(
-        self, model_response: _messages.ModelResponse, run_context: RunContext[AgentDeps]
-    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
+        self,
+        model_response: _messages.ModelResponse,
+        run_context: RunContext[AgentDeps],
+        result_schema: _result.ResultSchema[RunResultData] | None,
+    ) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
         """Process a non-streamed response from the model.

         Returns:
@@ -879,19 +1060,19 @@ class Agent(Generic[AgentDeps, ResultData]):
         # This accounts for cases like anthropic returns that might contain a text response
         # and a tool call response, where the text response just indicates the tool call will happen.
         if tool_calls:
-            return await self._handle_structured_response(tool_calls, run_context)
+            return await self._handle_structured_response(tool_calls, run_context, result_schema)
         elif texts:
             text = '\n\n'.join(texts)
-            return await self._handle_text_response(text, run_context)
+            return await self._handle_text_response(text, run_context, result_schema)
         else:
             raise exceptions.UnexpectedModelBehavior('Received empty model response')

     async def _handle_text_response(
-        self, text: str, run_context: RunContext[AgentDeps]
-    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
+        self, text: str, run_context: RunContext[AgentDeps], result_schema: _result.ResultSchema[RunResultData] | None
+    ) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
         """Handle a plain text response from the model for non-streaming responses."""
-        if self._allow_text_result:
-            result_data_input = cast(ResultData, text)
+        if self._allow_text_result(result_schema):
+            result_data_input = cast(RunResultData, text)
             try:
                 result_data = await self._validate_result(result_data_input, run_context, None)
             except _result.ToolRetryError as e:
@@ -907,16 +1088,19 @@ class Agent(Generic[AgentDeps, ResultData]):
             return None, [response]

     async def _handle_structured_response(
-        self, tool_calls: list[_messages.ToolCallPart], run_context: RunContext[AgentDeps]
-    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
+        self,
+        tool_calls: list[_messages.ToolCallPart],
+        run_context: RunContext[AgentDeps],
+        result_schema: _result.ResultSchema[RunResultData] | None,
+    ) -> tuple[_MarkFinalResult[RunResultData] | None, list[_messages.ModelRequestPart]]:
         """Handle a structured response containing tool calls from the model for non-streaming responses."""
         assert tool_calls, 'Expected at least one tool call'

         # first look for the result tool call
-        final_result: _MarkFinalResult[ResultData] | None = None
+        final_result: _MarkFinalResult[RunResultData] | None = None

         parts: list[_messages.ModelRequestPart] = []
-        if result_schema := self._result_schema:
+        if result_schema := result_schema:
             if match := result_schema.find_tool(tool_calls):
                 call, result_tool = match
                 try:
@@ -929,7 +1113,9 @@ class Agent(Generic[AgentDeps, ResultData]):
                     final_result = _MarkFinalResult(result_data, call.tool_name)

         # Then build the other request parts based on end strategy
-        parts += await self._process_function_tools(tool_calls, final_result and final_result.tool_name, run_context)
+        parts += await self._process_function_tools(
+            tool_calls, final_result and final_result.tool_name, run_context, result_schema
+        )

         return final_result, parts

@@ -938,6 +1124,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         tool_calls: list[_messages.ToolCallPart],
         result_tool_name: str | None,
         run_context: RunContext[AgentDeps],
+        result_schema: _result.ResultSchema[RunResultData] | None,
     ) -> list[_messages.ModelRequestPart]:
         """Process function (non-result) tool calls in parallel.

@@ -971,7 +1158,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                     )
                 else:
                     tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
-            elif self._result_schema is not None and call.tool_name in self._result_schema.tools:
+            elif result_schema is not None and call.tool_name in result_schema.tools:
                 # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
                 # validation, we don't add another part here
                 if result_tool_name is not None:
@@ -983,7 +1170,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                         )
                     )
             else:
-                parts.append(self._unknown_tool(call.tool_name, run_context))
+                parts.append(self._unknown_tool(call.tool_name, run_context, result_schema))

         # Run all tool tasks in parallel
         if tasks:
@@ -996,6 +1183,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         self,
         model_response: models.EitherStreamedResponse,
         run_context: RunContext[AgentDeps],
+        result_schema: _result.ResultSchema[RunResultData] | None,
     ) -> (
         _MarkFinalResult[models.EitherStreamedResponse]
         | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]
@@ -1008,7 +1196,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         """
         if isinstance(model_response, models.StreamTextResponse):
             # plain string response
-            if self._allow_text_result:
+            if self._allow_text_result(result_schema):
                 return _MarkFinalResult(model_response, None)
             else:
                 self._incr_result_retry(run_context)
@@ -1022,7 +1210,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                 text = ''.join(model_response.get(final=True))
                 return _messages.ModelResponse([_messages.TextPart(text)]), [response]
         elif isinstance(model_response, models.StreamStructuredResponse):
-            if self._result_schema is not None:
+            if result_schema is not None:
                 # if there's a result schema, iterate over the stream until we find at least one tool
                 # NOTE: this means we ignore any other tools called here
                 structured_msg = model_response.get()
@@ -1033,7 +1221,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                         break
                     structured_msg = model_response.get()

-                if match := self._result_schema.find_tool(structured_msg.parts):
+                if match := result_schema.find_tool(structured_msg.parts):
                     call, _ = match
                     return _MarkFinalResult(model_response, call.tool_name)

@@ -1053,7 +1241,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                     if tool := self._function_tools.get(call.tool_name):
                         tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
                     else:
-                        parts.append(self._unknown_tool(call.tool_name, run_context))
+                        parts.append(self._unknown_tool(call.tool_name, run_context, result_schema))

             with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
                 task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
@@ -1064,13 +1252,17 @@ class Agent(Generic[AgentDeps, ResultData]):


     async def _validate_result(
         self,
-        result_data: ResultData,
+        result_data: RunResultData,
         run_context: RunContext[AgentDeps],
         tool_call: _messages.ToolCallPart | None,
-    ) -> ResultData:
-        for validator in self._result_validators:
-            result_data = await validator.validate(result_data, tool_call, run_context)
-        return result_data
+    ) -> RunResultData:
+        if self._result_validators:
+            agent_result_data = cast(ResultData, result_data)
+            for validator in self._result_validators:
+                agent_result_data = await validator.validate(agent_result_data, tool_call, run_context)
+            return cast(RunResultData, agent_result_data)
+        else:
+            return result_data

     def _incr_result_retry(self, run_context: RunContext[AgentDeps]) -> None:
         run_context.retry += 1
@@ -1084,14 +1276,22 @@ class Agent(Generic[AgentDeps, ResultData]):
         messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self._system_prompts]
         for sys_prompt_runner in self._system_prompt_functions:
             prompt = await sys_prompt_runner.run(run_context)
-            messages.append(_messages.SystemPromptPart(prompt))
+            if sys_prompt_runner.dynamic:
+                messages.append(_messages.SystemPromptPart(prompt, dynamic_ref=sys_prompt_runner.function.__qualname__))
+            else:
+                messages.append(_messages.SystemPromptPart(prompt))
         return messages

-    def _unknown_tool(self, tool_name: str, run_context: RunContext[AgentDeps]) -> _messages.RetryPromptPart:
+    def _unknown_tool(
+        self,
+        tool_name: str,
+        run_context: RunContext[AgentDeps],
+        result_schema: _result.ResultSchema[RunResultData] | None,
+    ) -> _messages.RetryPromptPart:
         self._incr_result_retry(run_context)
         names = list(self._function_tools.keys())
-        if self._result_schema:
-            names.extend(self._result_schema.tool_names())
+        if result_schema:
+            names.extend(result_schema.tool_names())
         if names:
             msg = f'Available tools: {", ".join(names)}'
         else:
@@ -1129,6 +1329,10 @@ class Agent(Generic[AgentDeps, ResultData]):
                         self.name = name
                         return

+    @staticmethod
+    def _allow_text_result(result_schema: _result.ResultSchema[RunResultData] | None) -> bool:
+        return result_schema is None or result_schema.allow_text_result
+
     @property
     @deprecated(
         'The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.', category=None
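
The other addition visible in the hunks above is dynamic system prompts: `@agent.system_prompt(dynamic=True)` stores the runner under the function's `__qualname__`, the generated `SystemPromptPart` carries that name as `dynamic_ref`, and `_reevaluate_dynamic_prompts` re-runs the function whenever an existing `message_history` is replayed. A hedged usage sketch — the deps values and prompt text are illustrative only, not taken from this diff:

```python
from pydantic_ai import Agent, RunContext

agent = Agent('test', deps_type=str)


@agent.system_prompt(dynamic=True)
def current_user(ctx: RunContext[str]) -> str:
    # Registered with dynamic_ref=current_user.__qualname__, so the part is
    # re-evaluated even when message_history is passed to a later run.
    return f'The current user is {ctx.deps}.'


first = agent.run_sync('Hello', deps='Alice')
# Replaying the earlier messages would normally keep the old system prompt;
# with dynamic=True it is regenerated for the new deps value.
second = agent.run_sync('Hello again', deps='Bob', message_history=first.all_messages())
print(second.data)
```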