pydantic-ai-slim 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pydantic-ai-slim has been flagged as potentially problematic.

pydantic_ai/agent.py CHANGED
@@ -7,7 +7,7 @@ from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
  from contextlib import asynccontextmanager, contextmanager
  from dataclasses import dataclass, field
  from types import FrameType
- from typing import Any, Callable, Generic, cast, final, overload
+ from typing import Any, Callable, Generic, Literal, cast, final, overload

  import logfire_api
  from typing_extensions import assert_never
@@ -22,6 +22,7 @@ from . import (
  result,
  )
  from .result import ResultData
+ from .settings import ModelSettings, UsageLimits, merge_model_settings
  from .tools import (
  AgentDeps,
  RunContext,
@@ -39,6 +40,12 @@ __all__ = ('Agent',)
  _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')

  NoneType = type(None)
+ EndStrategy = Literal['early', 'exhaustive']
+ """The strategy for handling multiple tool calls when a final result is found.
+
+ - `'early'`: Stop processing other tool calls once a final result is found
+ - `'exhaustive'`: Process all tool calls even after finding a final result
+ """


  @final
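
The `end_strategy` introduced here is threaded through to `_process_function_tools` (see the hunk for that method further down), which either runs or stubs out tool calls that arrive alongside a final result. A minimal sketch of opting into the exhaustive behaviour, using only the constructor signature shown in this diff:

```python
from pydantic_ai import Agent

# 'exhaustive' runs every tool call the model requested even after a final
# result is found; the default 'early' stubs the remaining calls out with
# 'Tool not executed - a final result was already processed.'
agent = Agent('openai:gpt-4o', end_strategy='exhaustive')
```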
@@ -53,7 +60,7 @@ class Agent(Generic[AgentDeps, ResultData]):

  Minimal usage example:

- ```py
+ ```python
  from pydantic_ai import Agent

  agent = Agent('openai:gpt-4o')
@@ -63,14 +70,31 @@ class Agent(Generic[AgentDeps, ResultData]):
  ```
  """

- # dataclass fields mostly for my sanity knowing what attributes are available
+ # we use dataclass fields in order to conveniently know what attributes are available
  model: models.Model | models.KnownModelName | None
  """The default model configured for this agent."""
+
  name: str | None
  """The name of the agent, used for logging.

  If `None`, we try to infer the agent name from the call frame when the agent is first run.
  """
+ end_strategy: EndStrategy
+ """Strategy for handling tool calls when a final result is found."""
+
+ model_settings: ModelSettings | None
+ """Optional model request settings to use for this agents's runs, by default.
+
+ Note, if `model_settings` is provided by `run`, `run_sync`, or `run_stream`, those settings will
+ be merged with this value, with the runtime argument taking priority.
+ """
+
+ last_run_messages: list[_messages.ModelMessage] | None
+ """The messages from the last run, useful when a run raised an exception.
+
+ Note: these are not used by the agent, e.g. in future runs, they are just stored for developers' convenience.
+ """
+
  _result_schema: _result.ResultSchema[ResultData] | None = field(repr=False)
  _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
  _allow_text_result: bool = field(repr=False)
@@ -80,14 +104,8 @@ class Agent(Generic[AgentDeps, ResultData]):
  _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False)
  _deps_type: type[AgentDeps] = field(repr=False)
  _max_result_retries: int = field(repr=False)
- _current_result_retry: int = field(repr=False)
  _override_deps: _utils.Option[AgentDeps] = field(default=None, repr=False)
  _override_model: _utils.Option[models.Model] = field(default=None, repr=False)
- last_run_messages: list[_messages.Message] | None = None
- """The messages from the last run, useful when a run raised an exception.
-
- Note: these are not used by the agent, e.g. in future runs, they are just stored for developers' convenience.
- """

  def __init__(
  self,
@@ -97,18 +115,20 @@ class Agent(Generic[AgentDeps, ResultData]):
  system_prompt: str | Sequence[str] = (),
  deps_type: type[AgentDeps] = NoneType,
  name: str | None = None,
+ model_settings: ModelSettings | None = None,
  retries: int = 1,
  result_tool_name: str = 'final_result',
  result_tool_description: str | None = None,
  result_retries: int | None = None,
  tools: Sequence[Tool[AgentDeps] | ToolFuncEither[AgentDeps, ...]] = (),
  defer_model_check: bool = False,
+ end_strategy: EndStrategy = 'early',
  ):
  """Create an agent.

  Args:
  model: The default model to use for this agent, if not provide,
- you must provide the model when calling the agent.
+ you must provide the model when calling it.
  result_type: The type of the result data, used to validate the result data, defaults to `str`.
  system_prompt: Static system prompts to use for this agent, you can also register system
  prompts via a function with [`system_prompt`][pydantic_ai.Agent.system_prompt].
@@ -118,6 +138,7 @@ class Agent(Generic[AgentDeps, ResultData]):
  or add a type hint `: Agent[None, <return type>]`.
  name: The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame
  when the agent is first run.
+ model_settings: Optional model request settings to use for this agent's runs, by default.
  retries: The default number of retries to allow before raising an error.
  result_tool_name: The name of the tool to use for the final result.
  result_tool_description: The description of the final result tool.
@@ -129,13 +150,18 @@ class Agent(Generic[AgentDeps, ResultData]):
  which checks for the necessary environment variables. Set this to `false`
  to defer the evaluation until the first run. Useful if you want to
  [override the model][pydantic_ai.Agent.override] for testing.
+ end_strategy: Strategy for handling tool calls that are requested alongside a final result.
+ See [`EndStrategy`][pydantic_ai.agent.EndStrategy] for more information.
  """
  if model is None or defer_model_check:
  self.model = model
  else:
  self.model = models.infer_model(model)

+ self.end_strategy = end_strategy
  self.name = name
+ self.model_settings = model_settings
+ self.last_run_messages = None
  self._result_schema = _result.ResultSchema[result_type].build(
  result_type, result_tool_name, result_tool_description
  )
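
With these constructor changes, request settings can be set once on the agent and overridden per run; `merge_model_settings` in the run loop below gives the runtime argument priority. A sketch, assuming `ModelSettings` accepts common sampling options such as `temperature` (its fields are not shown in this diff):

```python
from pydantic_ai import Agent
from pydantic_ai.settings import ModelSettings  # public path assumed from the `.settings` import above

# agent-level default settings
agent = Agent('openai:gpt-4o', model_settings=ModelSettings(temperature=0.0))

# per-run settings are merged over the agent-level ones, with the run winning
result = agent.run_sync(
    'What is the capital of Italy?',
    model_settings=ModelSettings(temperature=1.0),
)
```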
@@ -153,25 +179,39 @@ class Agent(Generic[AgentDeps, ResultData]):
  self._deps_type = deps_type
  self._system_prompt_functions = []
  self._max_result_retries = result_retries if result_retries is not None else retries
- self._current_result_retry = 0
  self._result_validators = []

  async def run(
  self,
  user_prompt: str,
  *,
- message_history: list[_messages.Message] | None = None,
+ message_history: list[_messages.ModelMessage] | None = None,
  model: models.Model | models.KnownModelName | None = None,
  deps: AgentDeps = None,
+ model_settings: ModelSettings | None = None,
+ usage_limits: UsageLimits | None = None,
  infer_name: bool = True,
  ) -> result.RunResult[ResultData]:
  """Run the agent with a user prompt in async mode.

+ Example:
+ ```python
+ from pydantic_ai import Agent
+
+ agent = Agent('openai:gpt-4o')
+
+ result_sync = agent.run_sync('What is the capital of Italy?')
+ print(result_sync.data)
+ #> Rome
+ ```
+
  Args:
  user_prompt: User input to start/continue the conversation.
  message_history: History of the conversation so far.
  model: Optional model to use for this run, required if `model` was not set when creating the agent.
  deps: Optional dependencies to use for this run.
+ model_settings: Optional settings to use for this model's request.
+ usage_limits: Optional limits on model request count or token usage.
  infer_name: Whether to try to infer the agent name from the call frame if it's not set.

  Returns:
@@ -182,6 +222,7 @@ class Agent(Generic[AgentDeps, ResultData]):
  model_used, mode_selection = await self._get_model(model)

  deps = self._get_deps(deps)
+ new_message_index = len(message_history) if message_history else 0

  with _logfire.span(
  '{agent_name} run {prompt=}',
@@ -191,67 +232,91 @@ class Agent(Generic[AgentDeps, ResultData]):
  model_name=model_used.name(),
  agent_name=self.name or 'agent',
  ) as run_span:
- new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
- self.last_run_messages = messages
+ run_context = RunContext(deps, 0, [], None, model_used)
+ messages = await self._prepare_messages(user_prompt, message_history, run_context)
+ self.last_run_messages = run_context.messages = messages

  for tool in self._function_tools.values():
  tool.current_retry = 0

- cost = result.Cost()
+ usage = result.Usage(requests=0)
+ model_settings = merge_model_settings(self.model_settings, model_settings)
+ usage_limits = usage_limits or UsageLimits()

  run_step = 0
  while True:
+ usage_limits.check_before_request(usage)
+
  run_step += 1
  with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
- agent_model = await self._prepare_model(model_used, deps)
+ agent_model = await self._prepare_model(run_context)

  with _logfire.span('model request', run_step=run_step) as model_req_span:
- model_response, request_cost = await agent_model.request(messages)
+ model_response, request_usage = await agent_model.request(messages, model_settings)
  model_req_span.set_attribute('response', model_response)
- model_req_span.set_attribute('cost', request_cost)
- model_req_span.message = f'model request -> {model_response.role}'
+ model_req_span.set_attribute('usage', request_usage)

  messages.append(model_response)
- cost += request_cost
+ usage += request_usage
+ usage.requests += 1
+ usage_limits.check_tokens(request_usage)

  with _logfire.span('handle model response', run_step=run_step) as handle_span:
- final_result, response_messages = await self._handle_model_response(model_response, deps)
+ final_result, tool_responses = await self._handle_model_response(model_response, run_context)

- # Add all messages to the conversation
- messages.extend(response_messages)
+ if tool_responses:
+ # Add parts to the conversation as a new message
+ messages.append(_messages.ModelRequest(tool_responses))

  # Check if we got a final result
  if final_result is not None:
  result_data = final_result.data
  run_span.set_attribute('all_messages', messages)
- run_span.set_attribute('cost', cost)
+ run_span.set_attribute('usage', usage)
  handle_span.set_attribute('result', result_data)
  handle_span.message = 'handle model response -> final result'
- return result.RunResult(messages, new_message_index, result_data, cost)
+ return result.RunResult(messages, new_message_index, result_data, usage)
  else:
  # continue the conversation
- handle_span.set_attribute('tool_responses', response_messages)
- response_msgs = ' '.join(r.role for r in response_messages)
- handle_span.message = f'handle model response -> {response_msgs}'
+ handle_span.set_attribute('tool_responses', tool_responses)
+ tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
+ handle_span.message = f'handle model response -> {tool_responses_str}'

  def run_sync(
  self,
  user_prompt: str,
  *,
- message_history: list[_messages.Message] | None = None,
+ message_history: list[_messages.ModelMessage] | None = None,
  model: models.Model | models.KnownModelName | None = None,
  deps: AgentDeps = None,
+ model_settings: ModelSettings | None = None,
+ usage_limits: UsageLimits | None = None,
  infer_name: bool = True,
  ) -> result.RunResult[ResultData]:
  """Run the agent with a user prompt synchronously.

- This is a convenience method that wraps `self.run` with `loop.run_until_complete()`.
+ This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`.
+ You therefore can't use this method inside async code or if there's an active event loop.
+
+ Example:
+ ```python
+ from pydantic_ai import Agent
+
+ agent = Agent('openai:gpt-4o')
+
+ async def main():
+ result = await agent.run('What is the capital of France?')
+ print(result.data)
+ #> Paris
+ ```

  Args:
  user_prompt: User input to start/continue the conversation.
  message_history: History of the conversation so far.
  model: Optional model to use for this run, required if `model` was not set when creating the agent.
  deps: Optional dependencies to use for this run.
+ model_settings: Optional settings to use for this model's request.
+ usage_limits: Optional limits on model request count or token usage.
  infer_name: Whether to try to infer the agent name from the call frame if it's not set.

  Returns:
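
The run loop above now calls `usage_limits.check_before_request(usage)` before each model request and `usage_limits.check_tokens(request_usage)` after it, so a runaway run fails fast instead of looping. A sketch of passing limits to a run; `request_limit` is an assumed field name, since this diff shows only the `UsageLimits` methods, not its fields:

```python
from pydantic_ai import Agent
from pydantic_ai.settings import UsageLimits  # import path taken from the diff

agent = Agent('openai:gpt-4o')

async def main():
    # check_before_request() should raise once the (assumed) request_limit
    # is exceeded, ending the run rather than letting it retry indefinitely
    result = await agent.run(
        'What is the capital of Italy?',
        usage_limits=UsageLimits(request_limit=3),
    )
    print(result.data)
```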
@@ -259,9 +324,16 @@ class Agent(Generic[AgentDeps, ResultData]):
  """
  if infer_name and self.name is None:
  self._infer_name(inspect.currentframe())
- loop = asyncio.get_event_loop()
- return loop.run_until_complete(
- self.run(user_prompt, message_history=message_history, model=model, deps=deps, infer_name=False)
+ return asyncio.get_event_loop().run_until_complete(
+ self.run(
+ user_prompt,
+ message_history=message_history,
+ model=model,
+ deps=deps,
+ model_settings=model_settings,
+ usage_limits=usage_limits,
+ infer_name=False,
+ )
  )

  @asynccontextmanager
@@ -269,18 +341,34 @@ class Agent(Generic[AgentDeps, ResultData]):
  self,
  user_prompt: str,
  *,
- message_history: list[_messages.Message] | None = None,
+ message_history: list[_messages.ModelMessage] | None = None,
  model: models.Model | models.KnownModelName | None = None,
  deps: AgentDeps = None,
+ model_settings: ModelSettings | None = None,
+ usage_limits: UsageLimits | None = None,
  infer_name: bool = True,
  ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
  """Run the agent with a user prompt in async mode, returning a streamed response.

+ Example:
+ ```python
+ from pydantic_ai import Agent
+
+ agent = Agent('openai:gpt-4o')
+
+ async def main():
+ async with agent.run_stream('What is the capital of the UK?') as response:
+ print(await response.get_data())
+ #> London
+ ```
+
  Args:
  user_prompt: User input to start/continue the conversation.
  message_history: History of the conversation so far.
  model: Optional model to use for this run, required if `model` was not set when creating the agent.
  deps: Optional dependencies to use for this run.
+ model_settings: Optional settings to use for this model's request.
+ usage_limits: Optional limits on model request count or token usage.
  infer_name: Whether to try to infer the agent name from the call frame if it's not set.

  Returns:
@@ -293,6 +381,7 @@ class Agent(Generic[AgentDeps, ResultData]):
  model_used, mode_selection = await self._get_model(model)

  deps = self._get_deps(deps)
+ new_message_index = len(message_history) if message_history else 0

  with _logfire.span(
  '{agent_name} run stream {prompt=}',
@@ -302,60 +391,89 @@ class Agent(Generic[AgentDeps, ResultData]):
  model_name=model_used.name(),
  agent_name=self.name or 'agent',
  ) as run_span:
- new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
- self.last_run_messages = messages
+ run_context = RunContext(deps, 0, [], None, model_used)
+ messages = await self._prepare_messages(user_prompt, message_history, run_context)
+ self.last_run_messages = run_context.messages = messages

  for tool in self._function_tools.values():
  tool.current_retry = 0

- cost = result.Cost()
+ usage = result.Usage()
+ model_settings = merge_model_settings(self.model_settings, model_settings)
+ usage_limits = usage_limits or UsageLimits()

  run_step = 0
  while True:
  run_step += 1
+ usage_limits.check_before_request(usage)

  with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
- agent_model = await self._prepare_model(model_used, deps)
+ agent_model = await self._prepare_model(run_context)

  with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
- async with agent_model.request_stream(messages) as model_response:
+ async with agent_model.request_stream(messages, model_settings) as model_response:
+ usage.requests += 1
  model_req_span.set_attribute('response_type', model_response.__class__.__name__)
  # We want to end the "model request" span here, but we can't exit the context manager
  # in the traditional way
  model_req_span.__exit__(None, None, None)

  with _logfire.span('handle model response') as handle_span:
- final_result, response_messages = await self._handle_streamed_model_response(
- model_response, deps
- )
-
- # Add all messages to the conversation
- messages.extend(response_messages)
+ maybe_final_result = await self._handle_streamed_model_response(model_response, run_context)

  # Check if we got a final result
- if final_result is not None:
- result_stream = final_result.data
- run_span.set_attribute('all_messages', messages)
- handle_span.set_attribute('result_type', result_stream.__class__.__name__)
+ if isinstance(maybe_final_result, _MarkFinalResult):
+ result_stream = maybe_final_result.data
+ result_tool_name = maybe_final_result.tool_name
  handle_span.message = 'handle model response -> final result'
+
+ async def on_complete():
+ """Called when the stream has completed.
+
+ The model response will have been added to messages by now
+ by `StreamedRunResult._marked_completed`.
+ """
+ last_message = messages[-1]
+ assert isinstance(last_message, _messages.ModelResponse)
+ tool_calls = [
+ part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
+ ]
+ parts = await self._process_function_tools(
+ tool_calls, result_tool_name, run_context
+ )
+ if parts:
+ messages.append(_messages.ModelRequest(parts))
+ run_span.set_attribute('all_messages', messages)
+
  yield result.StreamedRunResult(
  messages,
  new_message_index,
- cost,
+ usage,
+ usage_limits,
  result_stream,
  self._result_schema,
- deps,
+ run_context,
  self._result_validators,
- lambda m: run_span.set_attribute('all_messages', messages),
+ result_tool_name,
+ on_complete,
  )
  return
  else:
  # continue the conversation
- handle_span.set_attribute('tool_responses', response_messages)
- response_msgs = ' '.join(r.role for r in response_messages)
- handle_span.message = f'handle model response -> {response_msgs}'
- # the model_response should have been fully streamed by now, we can add it's cost
- cost += model_response.cost()
+ model_response_msg, tool_responses = maybe_final_result
+ # if we got a model response add that to messages
+ messages.append(model_response_msg)
+ if tool_responses:
+ # if we got one or more tool response parts, add a model request message
+ messages.append(_messages.ModelRequest(tool_responses))
+
+ handle_span.set_attribute('tool_responses', tool_responses)
+ tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
+ handle_span.message = f'handle model response -> {tool_responses_str}'
+ # the model_response should have been fully streamed by now, we can add its usage
+ model_response_usage = model_response.usage()
+ usage += model_response_usage
+ usage_limits.check_tokens(usage)

  @contextmanager
  def override(
@@ -367,6 +485,7 @@ class Agent(Generic[AgentDeps, ResultData]):
  """Context manager to temporarily override agent dependencies and model.

  This is particularly useful when testing.
+ You can find an example of this [here](../testing-evals.md#overriding-model-via-pytest-fixtures).

  Args:
  deps: The dependencies to use instead of the dependencies passed to the agent run.
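
A sketch of the testing pattern this links to, assuming the test model lives at `pydantic_ai.models.test` (not shown in this diff):

```python
from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel  # assumed location of the test model

agent = Agent('openai:gpt-4o')

def test_agent():
    # temporarily replace the real model so no API request is made
    with agent.override(model=TestModel()):
        result = agent.run_sync('What is 2 + 2?')
        assert result.data is not None
```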
@@ -415,14 +534,14 @@ class Agent(Generic[AgentDeps, ResultData]):
  ) -> _system_prompt.SystemPromptFunc[AgentDeps]:
  """Decorator to register a system prompt function.

- Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as it's only argument.
+ Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument.
  Can decorate a sync or async functions.

  Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
  the type of the function, see `tests/typed_agent.py` for tests.

  Example:
- ```py
+ ```python
  from pydantic_ai import Agent, RunContext

  agent = Agent('test', deps_type=str)
@@ -466,14 +585,14 @@ class Agent(Generic[AgentDeps, ResultData]):
  ) -> _result.ResultValidatorFunc[AgentDeps, ResultData]:
  """Decorator to register a result validator function.

- Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as it's first argument.
+ Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
  Can decorate a sync or async functions.

  Overloads for every possible signature of `result_validator` are included so the decorator doesn't obscure
  the type of the function, see `tests/typed_agent.py` for tests.

  Example:
- ```py
+ ```python
  from pydantic_ai import Agent, ModelRetry, RunContext

  agent = Agent('test', deps_type=str)
@@ -523,13 +642,13 @@ class Agent(Generic[AgentDeps, ResultData]):
  Can decorate a sync or async functions.

  The docstring is inspected to extract both the tool description and description of each parameter,
- [learn more](../agents.md#function-tools-and-schema).
+ [learn more](../tools.md#function-tools-and-schema).

  We can't add overloads for every possible signature of tool, since the return type is a recursive union
  so the signature of functions decorated with `@agent.tool` is obscured.

  Example:
- ```py
+ ```python
  from pydantic_ai import Agent, RunContext

  agent = Agent('test', deps_type=int)
@@ -595,13 +714,13 @@ class Agent(Generic[AgentDeps, ResultData]):
  Can decorate a sync or async functions.

  The docstring is inspected to extract both the tool description and description of each parameter,
- [learn more](../agents.md#function-tools-and-schema).
+ [learn more](../tools.md#function-tools-and-schema).

  We can't add overloads for every possible signature of tool, since the return type is a recursive union
  so the signature of functions decorated with `@agent.tool` is obscured.

  Example:
- ```py
+ ```python
  from pydantic_ai import Agent, RunContext

  agent = Agent('test')
@@ -696,193 +815,266 @@ class Agent(Generic[AgentDeps, ResultData]):

  return model_, mode_selection

- async def _prepare_model(self, model: models.Model, deps: AgentDeps) -> models.AgentModel:
- """Create building tools and create an agent model."""
+ async def _prepare_model(self, run_context: RunContext[AgentDeps]) -> models.AgentModel:
+ """Build tools and create an agent model."""
  function_tools: list[ToolDefinition] = []

  async def add_tool(tool: Tool[AgentDeps]) -> None:
- ctx = RunContext(deps, tool.current_retry, tool.name)
+ ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
  if tool_def := await tool.prepare_tool_def(ctx):
  function_tools.append(tool_def)

  await asyncio.gather(*map(add_tool, self._function_tools.values()))

- return await model.agent_model(
+ return await run_context.model.agent_model(
  function_tools=function_tools,
  allow_text_result=self._allow_text_result,
  result_tools=self._result_schema.tool_defs() if self._result_schema is not None else [],
  )

  async def _prepare_messages(
- self, deps: AgentDeps, user_prompt: str, message_history: list[_messages.Message] | None
- ) -> tuple[int, list[_messages.Message]]:
- # if message history includes system prompts, we don't want to regenerate them
- if message_history and any(m.role == 'system' for m in message_history):
+ self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
+ ) -> list[_messages.ModelMessage]:
+ if message_history:
  # shallow copy messages
  messages = message_history.copy()
+ messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)]))
  else:
- messages = await self._init_messages(deps)
- if message_history:
- messages += message_history
+ parts = await self._sys_parts(run_context)
+ parts.append(_messages.UserPromptPart(user_prompt))
+ messages: list[_messages.ModelMessage] = [_messages.ModelRequest(parts)]

- new_message_index = len(messages)
- messages.append(_messages.UserPrompt(user_prompt))
- return new_message_index, messages
+ return messages

  async def _handle_model_response(
- self, model_response: _messages.ModelAnyResponse, deps: AgentDeps
- ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.Message]]:
+ self, model_response: _messages.ModelResponse, run_context: RunContext[AgentDeps]
+ ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
  """Process a non-streamed response from the model.

  Returns:
- A tuple of `(final_result, messages)`. If `final_result` is not `None`, the conversation should end.
+ A tuple of `(final_result, request parts)`. If `final_result` is not `None`, the conversation should end.
  """
- if model_response.role == 'model-text-response':
- # plain string response
- if self._allow_text_result:
- result_data_input = cast(ResultData, model_response.content)
+ texts: list[str] = []
+ tool_calls: list[_messages.ToolCallPart] = []
+ for part in model_response.parts:
+ if isinstance(part, _messages.TextPart):
+ # ignore empty content for text parts, see #437
+ if part.content:
+ texts.append(part.content)
+ else:
+ tool_calls.append(part)
+
+ if texts:
+ text = '\n\n'.join(texts)
+ return await self._handle_text_response(text, run_context)
+ elif tool_calls:
+ return await self._handle_structured_response(tool_calls, run_context)
+ else:
+ raise exceptions.UnexpectedModelBehavior('Received empty model response')
+
+ async def _handle_text_response(
+ self, text: str, run_context: RunContext[AgentDeps]
+ ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
+ """Handle a plain text response from the model for non-streaming responses."""
+ if self._allow_text_result:
+ result_data_input = cast(ResultData, text)
+ try:
+ result_data = await self._validate_result(result_data_input, run_context, None)
+ except _result.ToolRetryError as e:
+ self._incr_result_retry(run_context)
+ return None, [e.tool_retry]
+ else:
+ return _MarkFinalResult(result_data, None), []
+ else:
+ self._incr_result_retry(run_context)
+ response = _messages.RetryPromptPart(
+ content='Plain text responses are not permitted, please call one of the functions instead.',
+ )
+ return None, [response]
+
+ async def _handle_structured_response(
+ self, tool_calls: list[_messages.ToolCallPart], run_context: RunContext[AgentDeps]
+ ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
+ """Handle a structured response containing tool calls from the model for non-streaming responses."""
+ assert tool_calls, 'Expected at least one tool call'
+
+ # first look for the result tool call
+ final_result: _MarkFinalResult[ResultData] | None = None
+
+ parts: list[_messages.ModelRequestPart] = []
+ if result_schema := self._result_schema:
+ if match := result_schema.find_tool(tool_calls):
+ call, result_tool = match
  try:
- result_data = await self._validate_result(result_data_input, deps, None)
+ result_data = result_tool.validate(call)
+ result_data = await self._validate_result(result_data, run_context, call)
  except _result.ToolRetryError as e:
- self._incr_result_retry()
- return None, [e.tool_retry]
+ self._incr_result_retry(run_context)
+ parts.append(e.tool_retry)
  else:
- return _MarkFinalResult(result_data), []
- else:
- self._incr_result_retry()
- response = _messages.RetryPrompt(
- content='Plain text responses are not permitted, please call one of the functions instead.',
+ final_result = _MarkFinalResult(result_data, call.tool_name)
+
+ # Then build the other request parts based on end strategy
+ parts += await self._process_function_tools(tool_calls, final_result and final_result.tool_name, run_context)
+
+ return final_result, parts
+
+ async def _process_function_tools(
+ self,
+ tool_calls: list[_messages.ToolCallPart],
+ result_tool_name: str | None,
+ run_context: RunContext[AgentDeps],
+ ) -> list[_messages.ModelRequestPart]:
+ """Process function (non-result) tool calls in parallel.
+
+ Also add stub return parts for any other tools that need it.
+ """
+ parts: list[_messages.ModelRequestPart] = []
+ tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
+
+ stub_function_tools = bool(result_tool_name) and self.end_strategy == 'early'
+
+ # we rely on the fact that if we found a result, it's the first result tool in the last
+ found_used_result_tool = False
+ for call in tool_calls:
+ if call.tool_name == result_tool_name and not found_used_result_tool:
+ found_used_result_tool = True
+ parts.append(
+ _messages.ToolReturnPart(
+ tool_name=call.tool_name,
+ content='Final result processed.',
+ tool_call_id=call.tool_call_id,
+ )
  )
- return None, [response]
- elif model_response.role == 'model-structured-response':
- if self._result_schema is not None:
- # if there's a result schema, and any of the calls match one of its tools, return the result
- # NOTE: this means we ignore any other tools called here
- if match := self._result_schema.find_tool(model_response):
- call, result_tool = match
- try:
- result_data = result_tool.validate(call)
- result_data = await self._validate_result(result_data, deps, call)
- except _result.ToolRetryError as e:
- self._incr_result_retry()
- return None, [e.tool_retry]
- else:
- # Add a ToolReturn message for the schema tool call
- tool_return = _messages.ToolReturn(
+ elif tool := self._function_tools.get(call.tool_name):
+ if stub_function_tools:
+ parts.append(
+ _messages.ToolReturnPart(
  tool_name=call.tool_name,
- content='Final result processed.',
- tool_id=call.tool_id,
+ content='Tool not executed - a final result was already processed.',
+ tool_call_id=call.tool_call_id,
  )
- return _MarkFinalResult(result_data), [tool_return]
-
- if not model_response.calls:
- raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
-
- # otherwise we run all tool functions in parallel
- messages: list[_messages.Message] = []
- tasks: list[asyncio.Task[_messages.Message]] = []
- for call in model_response.calls:
- if tool := self._function_tools.get(call.tool_name):
- tasks.append(asyncio.create_task(tool.run(deps, call), name=call.tool_name))
+ )
  else:
- messages.append(self._unknown_tool(call.tool_name))
+ tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
+ elif self._result_schema is not None and call.tool_name in self._result_schema.tools:
+ # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
+ # validation, we don't add another part here
+ if result_tool_name is not None:
+ parts.append(
+ _messages.ToolReturnPart(
+ tool_name=call.tool_name,
+ content='Result tool not used - a final result was already processed.',
+ tool_call_id=call.tool_call_id,
+ )
+ )
+ else:
+ parts.append(self._unknown_tool(call.tool_name, run_context))

+ # Run all tool tasks in parallel
+ if tasks:
  with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
- task_results: Sequence[_messages.Message] = await asyncio.gather(*tasks)
- messages.extend(task_results)
- return None, messages
- else:
- assert_never(model_response)
+ task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
+ parts.extend(task_results)
+ return parts

  async def _handle_streamed_model_response(
- self, model_response: models.EitherStreamedResponse, deps: AgentDeps
- ) -> tuple[_MarkFinalResult[models.EitherStreamedResponse] | None, list[_messages.Message]]:
+ self,
+ model_response: models.EitherStreamedResponse,
+ run_context: RunContext[AgentDeps],
+ ) -> (
+ _MarkFinalResult[models.EitherStreamedResponse]
+ | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]
+ ):
  """Process a streamed response from the model.

  Returns:
- A tuple of (final_result, messages). If final_result is not None, the conversation should end.
+ Either a final result or a tuple of the model response and the tool responses for the next request.
+ If a final result is returned, the conversation should end.
  """
  if isinstance(model_response, models.StreamTextResponse):
  # plain string response
  if self._allow_text_result:
- return _MarkFinalResult(model_response), []
+ return _MarkFinalResult(model_response, None)
  else:
- self._incr_result_retry()
- response = _messages.RetryPrompt(
+ self._incr_result_retry(run_context)
+ response = _messages.RetryPromptPart(
  content='Plain text responses are not permitted, please call one of the functions instead.',
  )
- # stream the response, so cost is correct
+ # stream the response, so usage is correct
  async for _ in model_response:
  pass

- return None, [response]
- else:
- assert isinstance(model_response, models.StreamStructuredResponse), f'Unexpected response: {model_response}'
+ text = ''.join(model_response.get(final=True))
+ return _messages.ModelResponse([_messages.TextPart(text)]), [response]
+ elif isinstance(model_response, models.StreamStructuredResponse):
  if self._result_schema is not None:
  # if there's a result schema, iterate over the stream until we find at least one tool
  # NOTE: this means we ignore any other tools called here
  structured_msg = model_response.get()
- while not structured_msg.calls:
+ while not structured_msg.parts:
  try:
  await model_response.__anext__()
  except StopAsyncIteration:
  break
  structured_msg = model_response.get()

- if match := self._result_schema.find_tool(structured_msg):
+ if match := self._result_schema.find_tool(structured_msg.parts):
  call, _ = match
- tool_return = _messages.ToolReturn(
- tool_name=call.tool_name,
- content='Final result processed.',
- tool_id=call.tool_id,
- )
- return _MarkFinalResult(model_response), [tool_return]
+ return _MarkFinalResult(model_response, call.tool_name)

  # the model is calling a tool function, consume the response to get the next message
  async for _ in model_response:
  pass
- structured_msg = model_response.get()
- if not structured_msg.calls:
+ model_response_msg = model_response.get()
+ if not model_response_msg.parts:
  raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
- messages: list[_messages.Message] = [structured_msg]

  # we now run all tool functions in parallel
- tasks: list[asyncio.Task[_messages.Message]] = []
- for call in structured_msg.calls:
- if tool := self._function_tools.get(call.tool_name):
- tasks.append(asyncio.create_task(tool.run(deps, call), name=call.tool_name))
- else:
- messages.append(self._unknown_tool(call.tool_name))
+ tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
+ parts: list[_messages.ModelRequestPart] = []
+ for item in model_response_msg.parts:
+ if isinstance(item, _messages.ToolCallPart):
+ call = item
+ if tool := self._function_tools.get(call.tool_name):
+ tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
+ else:
+ parts.append(self._unknown_tool(call.tool_name, run_context))

  with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
- task_results: Sequence[_messages.Message] = await asyncio.gather(*tasks)
- messages.extend(task_results)
- return None, messages
+ task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
+ parts.extend(task_results)
+ return model_response_msg, parts
+ else:
+ assert_never(model_response)

  async def _validate_result(
- self, result_data: ResultData, deps: AgentDeps, tool_call: _messages.ToolCall | None
+ self,
+ result_data: ResultData,
+ run_context: RunContext[AgentDeps],
+ tool_call: _messages.ToolCallPart | None,
  ) -> ResultData:
  for validator in self._result_validators:
- result_data = await validator.validate(result_data, deps, self._current_result_retry, tool_call)
+ result_data = await validator.validate(result_data, tool_call, run_context)
  return result_data

- def _incr_result_retry(self) -> None:
- self._current_result_retry += 1
- if self._current_result_retry > self._max_result_retries:
+ def _incr_result_retry(self, run_context: RunContext[AgentDeps]) -> None:
+ run_context.retry += 1
+ if run_context.retry > self._max_result_retries:
  raise exceptions.UnexpectedModelBehavior(
  f'Exceeded maximum retries ({self._max_result_retries}) for result validation'
  )

- async def _init_messages(self, deps: AgentDeps) -> list[_messages.Message]:
+ async def _sys_parts(self, run_context: RunContext[AgentDeps]) -> list[_messages.ModelRequestPart]:
  """Build the initial messages for the conversation."""
- messages: list[_messages.Message] = [_messages.SystemPrompt(p) for p in self._system_prompts]
+ messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self._system_prompts]
  for sys_prompt_runner in self._system_prompt_functions:
- prompt = await sys_prompt_runner.run(deps)
- messages.append(_messages.SystemPrompt(prompt))
+ prompt = await sys_prompt_runner.run(run_context)
+ messages.append(_messages.SystemPromptPart(prompt))
  return messages

- def _unknown_tool(self, tool_name: str) -> _messages.RetryPrompt:
- self._incr_result_retry()
+ def _unknown_tool(self, tool_name: str, run_context: RunContext[AgentDeps]) -> _messages.RetryPromptPart:
+ self._incr_result_retry(run_context)
  names = list(self._function_tools.keys())
  if self._result_schema:
  names.extend(self._result_schema.tool_names())
@@ -890,7 +1082,7 @@ class Agent(Generic[AgentDeps, ResultData]):
  msg = f'Available tools: {", ".join(names)}'
  else:
  msg = 'No tools available.'
- return _messages.RetryPrompt(content=f'Unknown tool name: {tool_name!r}. {msg}')
+ return _messages.RetryPromptPart(content=f'Unknown tool name: {tool_name!r}. {msg}')

  def _get_deps(self, deps: AgentDeps) -> AgentDeps:
  """Get deps for a run.
@@ -934,3 +1126,6 @@ class _MarkFinalResult(Generic[ResultData]):
  """

  data: ResultData
+ """The final result data."""
+ tool_name: str | None
+ """Name of the final result tool, None if the result is a string."""