pydantic-ai-slim 0.0.12__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of pydantic-ai-slim has been flagged as potentially problematic.

pydantic_ai/agent.py CHANGED
@@ -7,7 +7,7 @@ from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
 from types import FrameType
-from typing import Any, Callable, Generic, cast, final, overload
+from typing import Any, Callable, Generic, Literal, cast, final, overload
 
 import logfire_api
 from typing_extensions import assert_never
@@ -22,6 +22,7 @@ from . import (
     result,
 )
 from .result import ResultData
+from .settings import ModelSettings, merge_model_settings
 from .tools import (
     AgentDeps,
     RunContext,
@@ -39,6 +40,12 @@ __all__ = ('Agent',)
 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
 
 NoneType = type(None)
+EndStrategy = Literal['early', 'exhaustive']
+"""The strategy for handling multiple tool calls when a final result is found.
+
+- `'early'`: Stop processing other tool calls once a final result is found
+- `'exhaustive'`: Process all tool calls even after finding a final result
+"""
 
 
 @final
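Based on the new `EndStrategy` flag above (wired through `Agent.__init__` in the hunks below), a minimal sketch of opting into exhaustive tool handling; the model name is illustrative:

```python
from pydantic_ai import Agent

# 'early' is the default: remaining tool calls are stubbed out once a final
# result is found; 'exhaustive' runs every requested tool call regardless.
agent = Agent('openai:gpt-4o', end_strategy='exhaustive')
```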
@@ -53,7 +60,7 @@ class Agent(Generic[AgentDeps, ResultData]):
 
     Minimal usage example:
 
-    ```py
+    ```python
     from pydantic_ai import Agent
 
     agent = Agent('openai:gpt-4o')
@@ -63,14 +70,31 @@ class Agent(Generic[AgentDeps, ResultData]):
     ```
     """
 
-    # dataclass fields mostly for my sanity knowing what attributes are available
+    # we use dataclass fields in order to conveniently know what attributes are available
     model: models.Model | models.KnownModelName | None
     """The default model configured for this agent."""
+
     name: str | None
    """The name of the agent, used for logging.

    If `None`, we try to infer the agent name from the call frame when the agent is first run.
    """
+    end_strategy: EndStrategy
+    """Strategy for handling tool calls when a final result is found."""
+
+    model_settings: ModelSettings | None
+    """Optional model request settings to use for this agent's runs, by default.
+
+    Note: if `model_settings` is provided by `run`, `run_sync`, or `run_stream`, those settings will
+    be merged with this value, with the runtime argument taking priority.
+    """
+
+    last_run_messages: list[_messages.ModelMessage] | None
+    """The messages from the last run, useful when a run raised an exception.
+
+    Note: these are not used by the agent, e.g. in future runs, they are just stored for developers' convenience.
+    """
+
     _result_schema: _result.ResultSchema[ResultData] | None = field(repr=False)
     _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
     _allow_text_result: bool = field(repr=False)
@@ -83,11 +107,6 @@ class Agent(Generic[AgentDeps, ResultData]):
     _current_result_retry: int = field(repr=False)
     _override_deps: _utils.Option[AgentDeps] = field(default=None, repr=False)
     _override_model: _utils.Option[models.Model] = field(default=None, repr=False)
-    last_run_messages: list[_messages.Message] | None = None
-    """The messages from the last run, useful when a run raised an exception.
-
-    Note: these are not used by the agent, e.g. in future runs, they are just stored for developers' convenience.
-    """
 
     def __init__(
         self,
@@ -97,18 +116,20 @@ class Agent(Generic[AgentDeps, ResultData]):
         system_prompt: str | Sequence[str] = (),
         deps_type: type[AgentDeps] = NoneType,
         name: str | None = None,
+        model_settings: ModelSettings | None = None,
         retries: int = 1,
         result_tool_name: str = 'final_result',
         result_tool_description: str | None = None,
         result_retries: int | None = None,
         tools: Sequence[Tool[AgentDeps] | ToolFuncEither[AgentDeps, ...]] = (),
         defer_model_check: bool = False,
+        end_strategy: EndStrategy = 'early',
     ):
         """Create an agent.
 
         Args:
             model: The default model to use for this agent, if not provided,
-                you must provide the model when calling the agent.
+                you must provide the model when calling it.
             result_type: The type of the result data, used to validate the result data, defaults to `str`.
             system_prompt: Static system prompts to use for this agent, you can also register system
                 prompts via a function with [`system_prompt`][pydantic_ai.Agent.system_prompt].
@@ -118,6 +139,7 @@ class Agent(Generic[AgentDeps, ResultData]):
                 or add a type hint `: Agent[None, <return type>]`.
             name: The name of the agent, used for logging. If `None`, we try to infer the agent name from the call frame
                 when the agent is first run.
+            model_settings: Optional model request settings to use for this agent's runs, by default.
             retries: The default number of retries to allow before raising an error.
             result_tool_name: The name of the tool to use for the final result.
             result_tool_description: The description of the final result tool.
@@ -129,13 +151,18 @@ class Agent(Generic[AgentDeps, ResultData]):
                 which checks for the necessary environment variables. Set this to `false`
                 to defer the evaluation until the first run. Useful if you want to
                 [override the model][pydantic_ai.Agent.override] for testing.
+            end_strategy: Strategy for handling tool calls that are requested alongside a final result.
+                See [`EndStrategy`][pydantic_ai.agent.EndStrategy] for more information.
         """
         if model is None or defer_model_check:
             self.model = model
         else:
             self.model = models.infer_model(model)
 
+        self.end_strategy = end_strategy
         self.name = name
+        self.model_settings = model_settings
+        self.last_run_messages = None
         self._result_schema = _result.ResultSchema[result_type].build(
             result_type, result_tool_name, result_tool_description
         )
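The agent-level `model_settings` default set here is combined with any per-run value via `merge_model_settings` (imported above). Per the field docstring, the runtime argument takes priority, so the merge behaves roughly like a dict update; a sketch, assuming `ModelSettings` is a `TypedDict` of optional request parameters and that keys such as `temperature` and `max_tokens` exist:

```python
from pydantic_ai.settings import ModelSettings, merge_model_settings

agent_defaults = ModelSettings(temperature=0.0, max_tokens=500)  # assumed keys
run_overrides = ModelSettings(temperature=0.7)

merged = merge_model_settings(agent_defaults, run_overrides)
# runtime values win, unset defaults survive:
# {'temperature': 0.7, 'max_tokens': 500}
```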
@@ -160,19 +187,32 @@ class Agent(Generic[AgentDeps, ResultData]):
         self,
         user_prompt: str,
         *,
-        message_history: list[_messages.Message] | None = None,
+        message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
         infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt in async mode.
 
+        Example:
+        ```python
+        from pydantic_ai import Agent
+
+        agent = Agent('openai:gpt-4o')
+
+        async def main():
+            result = await agent.run('What is the capital of France?')
+            print(result.data)
+            #> Paris
+        ```
+
         Args:
             user_prompt: User input to start/continue the conversation.
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
             deps: Optional dependencies to use for this run.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.
+            model_settings: Optional settings to use for this model's request.
 
         Returns:
             The result of the run.
@@ -182,6 +222,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         model_used, mode_selection = await self._get_model(model)
 
         deps = self._get_deps(deps)
+        new_message_index = len(message_history) if message_history else 0
 
         with _logfire.span(
             '{agent_name} run {prompt=}',
@@ -191,34 +232,35 @@ class Agent(Generic[AgentDeps, ResultData]):
             model_name=model_used.name(),
             agent_name=self.name or 'agent',
         ) as run_span:
-            new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
-            self.last_run_messages = messages
+            self.last_run_messages = messages = await self._prepare_messages(deps, user_prompt, message_history)
 
             for tool in self._function_tools.values():
                 tool.current_retry = 0
 
             cost = result.Cost()
 
+            model_settings = merge_model_settings(self.model_settings, model_settings)
+
             run_step = 0
             while True:
                 run_step += 1
                 with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
-                    agent_model = await self._prepare_model(model_used, deps)
+                    agent_model = await self._prepare_model(model_used, deps, messages)
 
                 with _logfire.span('model request', run_step=run_step) as model_req_span:
-                    model_response, request_cost = await agent_model.request(messages)
+                    model_response, request_cost = await agent_model.request(messages, model_settings)
                     model_req_span.set_attribute('response', model_response)
                     model_req_span.set_attribute('cost', request_cost)
-                    model_req_span.message = f'model request -> {model_response.role}'
 
                 messages.append(model_response)
                 cost += request_cost
 
                 with _logfire.span('handle model response', run_step=run_step) as handle_span:
-                    final_result, response_messages = await self._handle_model_response(model_response, deps)
+                    final_result, tool_responses = await self._handle_model_response(model_response, deps, messages)
 
-                    # Add all messages to the conversation
-                    messages.extend(response_messages)
+                    if tool_responses:
+                        # Add parts to the conversation as a new message
+                        messages.append(_messages.ModelRequest(tool_responses))
 
                     # Check if we got a final result
                     if final_result is not None:
@@ -230,22 +272,36 @@ class Agent(Generic[AgentDeps, ResultData]):
                         return result.RunResult(messages, new_message_index, result_data, cost)
                     else:
                         # continue the conversation
-                        handle_span.set_attribute('tool_responses', response_messages)
-                        response_msgs = ' '.join(r.role for r in response_messages)
-                        handle_span.message = f'handle model response -> {response_msgs}'
+                        handle_span.set_attribute('tool_responses', tool_responses)
+                        tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
+                        handle_span.message = f'handle model response -> {tool_responses_str}'
 
     def run_sync(
         self,
         user_prompt: str,
         *,
-        message_history: list[_messages.Message] | None = None,
+        message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
         infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt synchronously.
 
-        This is a convenience method that wraps `self.run` with `loop.run_until_complete()`.
+        This is a convenience method that wraps [`self.run`][pydantic_ai.Agent.run] with `loop.run_until_complete(...)`.
+        You therefore can't use this method inside async code or if there's an active event loop.
+
+        Example:
+        ```python
+        from pydantic_ai import Agent
+
+        agent = Agent('openai:gpt-4o')
+
+        result_sync = agent.run_sync('What is the capital of Italy?')
+        print(result_sync.data)
+        #> Rome
+        ```
 
         Args:
             user_prompt: User input to start/continue the conversation.
@@ -253,15 +309,22 @@ class Agent(Generic[AgentDeps, ResultData]):
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
             deps: Optional dependencies to use for this run.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.
+            model_settings: Optional settings to use for this model's request.
 
         Returns:
             The result of the run.
         """
         if infer_name and self.name is None:
             self._infer_name(inspect.currentframe())
-        loop = asyncio.get_event_loop()
-        return loop.run_until_complete(
-            self.run(user_prompt, message_history=message_history, model=model, deps=deps, infer_name=False)
+        return asyncio.get_event_loop().run_until_complete(
+            self.run(
+                user_prompt,
+                message_history=message_history,
+                model=model,
+                deps=deps,
+                infer_name=False,
+                model_settings=model_settings,
+            )
         )
 
     @asynccontextmanager
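All three entry points (`run`, `run_sync`, `run_stream`) now accept the new `model_settings` keyword, so a per-call override looks like this (a sketch; the `temperature` key is an assumption as above):

```python
from pydantic_ai import Agent

agent = Agent('openai:gpt-4o', model_settings={'temperature': 0.5})

# the per-run value is merged over the agent default, so this call runs
# with temperature=0.0 while other calls keep the 0.5 default
result = agent.run_sync('What is the capital of Italy?', model_settings={'temperature': 0.0})
```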
@@ -269,19 +332,33 @@ class Agent(Generic[AgentDeps, ResultData]):
         self,
         user_prompt: str,
         *,
-        message_history: list[_messages.Message] | None = None,
+        message_history: list[_messages.ModelMessage] | None = None,
         model: models.Model | models.KnownModelName | None = None,
         deps: AgentDeps = None,
+        model_settings: ModelSettings | None = None,
         infer_name: bool = True,
     ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
         """Run the agent with a user prompt in async mode, returning a streamed response.
 
+        Example:
+        ```python
+        from pydantic_ai import Agent
+
+        agent = Agent('openai:gpt-4o')
+
+        async def main():
+            async with agent.run_stream('What is the capital of the UK?') as response:
+                print(await response.get_data())
+                #> London
+        ```
+
         Args:
             user_prompt: User input to start/continue the conversation.
             message_history: History of the conversation so far.
             model: Optional model to use for this run, required if `model` was not set when creating the agent.
             deps: Optional dependencies to use for this run.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.
+            model_settings: Optional settings to use for this model's request.
 
         Returns:
             The result of the run.
@@ -293,6 +370,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         model_used, mode_selection = await self._get_model(model)
 
         deps = self._get_deps(deps)
+        new_message_index = len(message_history) if message_history else 0
 
         with _logfire.span(
             '{agent_name} run stream {prompt=}',
@@ -302,42 +380,57 @@ class Agent(Generic[AgentDeps, ResultData]):
             model_name=model_used.name(),
             agent_name=self.name or 'agent',
         ) as run_span:
-            new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
-            self.last_run_messages = messages
+            self.last_run_messages = messages = await self._prepare_messages(deps, user_prompt, message_history)
 
             for tool in self._function_tools.values():
                 tool.current_retry = 0
 
             cost = result.Cost()
+            model_settings = merge_model_settings(self.model_settings, model_settings)
 
             run_step = 0
             while True:
                 run_step += 1
 
                 with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
-                    agent_model = await self._prepare_model(model_used, deps)
+                    agent_model = await self._prepare_model(model_used, deps, messages)
 
                 with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
-                    async with agent_model.request_stream(messages) as model_response:
+                    async with agent_model.request_stream(messages, model_settings) as model_response:
                         model_req_span.set_attribute('response_type', model_response.__class__.__name__)
                         # We want to end the "model request" span here, but we can't exit the context manager
                         # in the traditional way
                         model_req_span.__exit__(None, None, None)
 
                         with _logfire.span('handle model response') as handle_span:
-                            final_result, response_messages = await self._handle_streamed_model_response(
-                                model_response, deps
+                            maybe_final_result = await self._handle_streamed_model_response(
                                model_response, deps, messages
                             )
 
-                            # Add all messages to the conversation
-                            messages.extend(response_messages)
-
                             # Check if we got a final result
-                            if final_result is not None:
-                                result_stream = final_result.data
-                                run_span.set_attribute('all_messages', messages)
-                                handle_span.set_attribute('result_type', result_stream.__class__.__name__)
+                            if isinstance(maybe_final_result, _MarkFinalResult):
+                                result_stream = maybe_final_result.data
+                                result_tool_name = maybe_final_result.tool_name
                                 handle_span.message = 'handle model response -> final result'
+
+                                async def on_complete():
+                                    """Called when the stream has completed.
+
+                                    The model response will have been added to messages by now
+                                    by `StreamedRunResult._marked_completed`.
+                                    """
+                                    last_message = messages[-1]
+                                    assert isinstance(last_message, _messages.ModelResponse)
+                                    tool_calls = [
+                                        part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
+                                    ]
+                                    parts = await self._process_function_tools(
+                                        tool_calls, result_tool_name, deps, messages
+                                    )
+                                    if parts:
+                                        messages.append(_messages.ModelRequest(parts))
+                                    run_span.set_attribute('all_messages', messages)
+
                                 yield result.StreamedRunResult(
                                     messages,
                                     new_message_index,
@@ -346,14 +439,22 @@ class Agent(Generic[AgentDeps, ResultData]):
                                     self._result_schema,
                                     deps,
                                     self._result_validators,
-                                    lambda m: run_span.set_attribute('all_messages', messages),
+                                    result_tool_name,
+                                    on_complete,
                                 )
                                 return
                             else:
                                 # continue the conversation
-                                handle_span.set_attribute('tool_responses', response_messages)
-                                response_msgs = ' '.join(r.role for r in response_messages)
-                                handle_span.message = f'handle model response -> {response_msgs}'
+                                model_response_msg, tool_responses = maybe_final_result
+                                # if we got a model response add that to messages
+                                messages.append(model_response_msg)
+                                if tool_responses:
+                                    # if we got one or more tool response parts, add a model request message
+                                    messages.append(_messages.ModelRequest(tool_responses))
+
+                                handle_span.set_attribute('tool_responses', tool_responses)
+                                tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
+                                handle_span.message = f'handle model response -> {tool_responses_str}'
                             # the model_response should have been fully streamed by now, we can add its cost
                             cost += model_response.cost()
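The streaming changes above lean on the same message restructuring as `run`: flat role-based messages are replaced by `ModelRequest`/`ModelResponse` containers holding typed parts. A rough sketch of the shapes this diff implies (constructor arguments beyond the parts list are assumptions):

```python
from pydantic_ai import messages as _messages

# a request groups parts: system prompts, user prompts, tool returns, retries
request = _messages.ModelRequest([_messages.UserPromptPart('What is 2+2?')])

# a response groups text parts and/or tool call parts
response = _messages.ModelResponse([_messages.TextPart('4')])

# each part advertises a `part_kind`, which the spans above join for logging
print(' '.join(part.part_kind for part in request.parts))
```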
@@ -367,6 +468,7 @@ class Agent(Generic[AgentDeps, ResultData]):
         """Context manager to temporarily override agent dependencies and model.
 
         This is particularly useful when testing.
+        You can find an example of this [here](../testing-evals.md#overriding-model-via-pytest-fixtures).
 
         Args:
             deps: The dependencies to use instead of the dependencies passed to the agent run.
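For reference, the kind of test-time override the newly linked docs describe might look like this (a sketch, assuming `TestModel` from `pydantic_ai.models.test`):

```python
from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel

agent = Agent('openai:gpt-4o')

with agent.override(model=TestModel()):
    # inside the block the agent talks to TestModel instead of OpenAI
    result = agent.run_sync('What is the capital of Italy?')
```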
@@ -415,14 +517,14 @@ class Agent(Generic[AgentDeps, ResultData]):
     ) -> _system_prompt.SystemPromptFunc[AgentDeps]:
         """Decorator to register a system prompt function.
 
-        Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as it's only argument.
+        Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument.
         Can decorate sync or async functions.
 
         Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
         the type of the function, see `tests/typed_agent.py` for tests.
 
         Example:
-        ```py
+        ```python
         from pydantic_ai import Agent, RunContext
 
         agent = Agent('test', deps_type=str)
@@ -466,14 +568,14 @@ class Agent(Generic[AgentDeps, ResultData]):
     ) -> _result.ResultValidatorFunc[AgentDeps, ResultData]:
         """Decorator to register a result validator function.
 
-        Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as it's first argument.
+        Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
         Can decorate sync or async functions.
 
         Overloads for every possible signature of `result_validator` are included so the decorator doesn't obscure
         the type of the function, see `tests/typed_agent.py` for tests.
 
         Example:
-        ```py
+        ```python
         from pydantic_ai import Agent, ModelRetry, RunContext
 
         agent = Agent('test', deps_type=str)
@@ -523,13 +625,13 @@ class Agent(Generic[AgentDeps, ResultData]):
         Can decorate sync or async functions.
 
         The docstring is inspected to extract both the tool description and description of each parameter,
-        [learn more](../agents.md#function-tools-and-schema).
+        [learn more](../tools.md#function-tools-and-schema).
 
         We can't add overloads for every possible signature of tool, since the return type is a recursive union
         so the signature of functions decorated with `@agent.tool` is obscured.
 
         Example:
-        ```py
+        ```python
         from pydantic_ai import Agent, RunContext
 
         agent = Agent('test', deps_type=int)
@@ -595,13 +697,13 @@ class Agent(Generic[AgentDeps, ResultData]):
         Can decorate sync or async functions.
 
         The docstring is inspected to extract both the tool description and description of each parameter,
-        [learn more](../agents.md#function-tools-and-schema).
+        [learn more](../tools.md#function-tools-and-schema).
 
         We can't add overloads for every possible signature of tool, since the return type is a recursive union
         so the signature of functions decorated with `@agent.tool` is obscured.
 
         Example:
-        ```py
+        ```python
         from pydantic_ai import Agent, RunContext
 
         agent = Agent('test')
@@ -696,12 +798,14 @@ class Agent(Generic[AgentDeps, ResultData]):
 
         return model_, mode_selection
 
-    async def _prepare_model(self, model: models.Model, deps: AgentDeps) -> models.AgentModel:
+    async def _prepare_model(
+        self, model: models.Model, deps: AgentDeps, messages: list[_messages.ModelMessage]
+    ) -> models.AgentModel:
         """Build tools and create an agent model."""
         function_tools: list[ToolDefinition] = []
 
         async def add_tool(tool: Tool[AgentDeps]) -> None:
-            ctx = RunContext(deps, tool.current_retry, tool.name)
+            ctx = RunContext(deps, tool.current_retry, messages, tool.name)
             if tool_def := await tool.prepare_tool_def(ctx):
                 function_tools.append(tool_def)
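`RunContext` now receives the conversation messages as its third positional argument, so tool (and prepare) functions can inspect the history; a small sketch:

```python
from pydantic_ai import Agent, RunContext

agent = Agent('test')

@agent.tool
async def history_size(ctx: RunContext[None]) -> str:
    # ctx.messages carries the conversation so far, per the change above
    return f'{len(ctx.messages)} messages so far'
```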
@@ -714,156 +818,234 @@ class Agent(Generic[AgentDeps, ResultData]):
         )
 
     async def _prepare_messages(
-        self, deps: AgentDeps, user_prompt: str, message_history: list[_messages.Message] | None
-    ) -> tuple[int, list[_messages.Message]]:
-        # if message history includes system prompts, we don't want to regenerate them
-        if message_history and any(m.role == 'system' for m in message_history):
+        self, deps: AgentDeps, user_prompt: str, message_history: list[_messages.ModelMessage] | None
+    ) -> list[_messages.ModelMessage]:
+        if message_history:
             # shallow copy messages
             messages = message_history.copy()
+            messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)]))
         else:
-            messages = await self._init_messages(deps)
-            if message_history:
-                messages += message_history
+            parts = await self._sys_parts(deps)
+            parts.append(_messages.UserPromptPart(user_prompt))
+            messages: list[_messages.ModelMessage] = [_messages.ModelRequest(parts)]
 
-        new_message_index = len(messages)
-        messages.append(_messages.UserPrompt(user_prompt))
-        return new_message_index, messages
+        return messages
 
     async def _handle_model_response(
-        self, model_response: _messages.ModelAnyResponse, deps: AgentDeps
-    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.Message]]:
+        self, model_response: _messages.ModelResponse, deps: AgentDeps, conv_messages: list[_messages.ModelMessage]
+    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
         """Process a non-streamed response from the model.
 
         Returns:
-            A tuple of `(final_result, messages)`. If `final_result` is not `None`, the conversation should end.
+            A tuple of `(final_result, request parts)`. If `final_result` is not `None`, the conversation should end.
         """
-        if model_response.role == 'model-text-response':
-            # plain string response
-            if self._allow_text_result:
-                result_data_input = cast(ResultData, model_response.content)
+        texts: list[str] = []
+        tool_calls: list[_messages.ToolCallPart] = []
+        for item in model_response.parts:
+            if isinstance(item, _messages.TextPart):
+                texts.append(item.content)
+            else:
+                tool_calls.append(item)
+
+        if texts:
+            text = '\n\n'.join(texts)
+            return await self._handle_text_response(text, deps, conv_messages)
+        elif tool_calls:
+            return await self._handle_structured_response(tool_calls, deps, conv_messages)
+        else:
+            raise exceptions.UnexpectedModelBehavior('Received empty model response')
+
+    async def _handle_text_response(
+        self, text: str, deps: AgentDeps, conv_messages: list[_messages.ModelMessage]
+    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
+        """Handle a plain text response from the model for non-streaming responses."""
+        if self._allow_text_result:
+            result_data_input = cast(ResultData, text)
+            try:
+                result_data = await self._validate_result(result_data_input, deps, None, conv_messages)
+            except _result.ToolRetryError as e:
+                self._incr_result_retry()
+                return None, [e.tool_retry]
+            else:
+                return _MarkFinalResult(result_data, None), []
+        else:
+            self._incr_result_retry()
+            response = _messages.RetryPromptPart(
+                content='Plain text responses are not permitted, please call one of the functions instead.',
+            )
+            return None, [response]
+
+    async def _handle_structured_response(
+        self, tool_calls: list[_messages.ToolCallPart], deps: AgentDeps, conv_messages: list[_messages.ModelMessage]
+    ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
+        """Handle a structured response containing tool calls from the model for non-streaming responses."""
+        assert tool_calls, 'Expected at least one tool call'
+
+        # first look for the result tool call
+        final_result: _MarkFinalResult[ResultData] | None = None
+
+        parts: list[_messages.ModelRequestPart] = []
+        if result_schema := self._result_schema:
+            if match := result_schema.find_tool(tool_calls):
+                call, result_tool = match
                 try:
-                    result_data = await self._validate_result(result_data_input, deps, None)
+                    result_data = result_tool.validate(call)
+                    result_data = await self._validate_result(result_data, deps, call, conv_messages)
                 except _result.ToolRetryError as e:
                     self._incr_result_retry()
-                    return None, [e.tool_retry]
+                    parts.append(e.tool_retry)
                 else:
-                    return _MarkFinalResult(result_data), []
-            else:
-                self._incr_result_retry()
-                response = _messages.RetryPrompt(
-                    content='Plain text responses are not permitted, please call one of the functions instead.',
+                    final_result = _MarkFinalResult(result_data, call.tool_name)
+
+        # Then build the other request parts based on end strategy
+        parts += await self._process_function_tools(
+            tool_calls, final_result and final_result.tool_name, deps, conv_messages
+        )
+
+        return final_result, parts
+
+    async def _process_function_tools(
+        self,
+        tool_calls: list[_messages.ToolCallPart],
+        result_tool_name: str | None,
+        deps: AgentDeps,
+        conv_messages: list[_messages.ModelMessage],
+    ) -> list[_messages.ModelRequestPart]:
+        """Process function (non-result) tool calls in parallel.
+
+        Also add stub return parts for any other tools that need it.
+        """
+        parts: list[_messages.ModelRequestPart] = []
+        tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
+
+        stub_function_tools = bool(result_tool_name) and self.end_strategy == 'early'
+
+        # we rely on the fact that if we found a result, it's the first result tool in the last message
+        found_used_result_tool = False
+        for call in tool_calls:
+            if call.tool_name == result_tool_name and not found_used_result_tool:
+                found_used_result_tool = True
+                parts.append(
+                    _messages.ToolReturnPart(
+                        tool_name=call.tool_name,
+                        content='Final result processed.',
+                        tool_call_id=call.tool_call_id,
+                    )
                 )
-                return None, [response]
-        elif model_response.role == 'model-structured-response':
-            if self._result_schema is not None:
-                # if there's a result schema, and any of the calls match one of its tools, return the result
-                # NOTE: this means we ignore any other tools called here
-                if match := self._result_schema.find_tool(model_response):
-                    call, result_tool = match
-                    try:
-                        result_data = result_tool.validate(call)
-                        result_data = await self._validate_result(result_data, deps, call)
-                    except _result.ToolRetryError as e:
-                        self._incr_result_retry()
-                        return None, [e.tool_retry]
-                    else:
-                        # Add a ToolReturn message for the schema tool call
-                        tool_return = _messages.ToolReturn(
+            elif tool := self._function_tools.get(call.tool_name):
+                if stub_function_tools:
+                    parts.append(
+                        _messages.ToolReturnPart(
                             tool_name=call.tool_name,
-                            content='Final result processed.',
-                            tool_id=call.tool_id,
+                            content='Tool not executed - a final result was already processed.',
+                            tool_call_id=call.tool_call_id,
                         )
-                        return _MarkFinalResult(result_data), [tool_return]
-
-            if not model_response.calls:
-                raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
-
-            # otherwise we run all tool functions in parallel
-            messages: list[_messages.Message] = []
-            tasks: list[asyncio.Task[_messages.Message]] = []
-            for call in model_response.calls:
-                if tool := self._function_tools.get(call.tool_name):
-                    tasks.append(asyncio.create_task(tool.run(deps, call), name=call.tool_name))
+                    )
                 else:
-                    messages.append(self._unknown_tool(call.tool_name))
+                    tasks.append(asyncio.create_task(tool.run(deps, call, conv_messages), name=call.tool_name))
+            elif self._result_schema is not None and call.tool_name in self._result_schema.tools:
+                # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
+                # validation, we don't add another part here
+                if result_tool_name is not None:
+                    parts.append(
+                        _messages.ToolReturnPart(
+                            tool_name=call.tool_name,
+                            content='Result tool not used - a final result was already processed.',
+                            tool_call_id=call.tool_call_id,
+                        )
+                    )
+            else:
+                parts.append(self._unknown_tool(call.tool_name))
 
+        # Run all tool tasks in parallel
+        if tasks:
             with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-                task_results: Sequence[_messages.Message] = await asyncio.gather(*tasks)
-                messages.extend(task_results)
-            return None, messages
-        else:
-            assert_never(model_response)
+                task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
+                parts.extend(task_results)
+        return parts
 
     async def _handle_streamed_model_response(
-        self, model_response: models.EitherStreamedResponse, deps: AgentDeps
-    ) -> tuple[_MarkFinalResult[models.EitherStreamedResponse] | None, list[_messages.Message]]:
+        self,
+        model_response: models.EitherStreamedResponse,
+        deps: AgentDeps,
+        conv_messages: list[_messages.ModelMessage],
+    ) -> (
+        _MarkFinalResult[models.EitherStreamedResponse]
+        | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]
+    ):
         """Process a streamed response from the model.
 
         Returns:
-            A tuple of (final_result, messages). If final_result is not None, the conversation should end.
+            Either a final result or a tuple of the model response and the tool responses for the next request.
+            If a final result is returned, the conversation should end.
         """
         if isinstance(model_response, models.StreamTextResponse):
             # plain string response
             if self._allow_text_result:
-                return _MarkFinalResult(model_response), []
+                return _MarkFinalResult(model_response, None)
             else:
                 self._incr_result_retry()
-                response = _messages.RetryPrompt(
+                response = _messages.RetryPromptPart(
                     content='Plain text responses are not permitted, please call one of the functions instead.',
                 )
                 # stream the response, so cost is correct
                 async for _ in model_response:
                     pass
 
-                return None, [response]
-        else:
-            assert isinstance(model_response, models.StreamStructuredResponse), f'Unexpected response: {model_response}'
+                text = ''.join(model_response.get(final=True))
+                return _messages.ModelResponse([_messages.TextPart(text)]), [response]
+        elif isinstance(model_response, models.StreamStructuredResponse):
             if self._result_schema is not None:
                 # if there's a result schema, iterate over the stream until we find at least one tool
                 # NOTE: this means we ignore any other tools called here
                 structured_msg = model_response.get()
-                while not structured_msg.calls:
+                while not structured_msg.parts:
                     try:
                         await model_response.__anext__()
                     except StopAsyncIteration:
                         break
                     structured_msg = model_response.get()
 
-                if match := self._result_schema.find_tool(structured_msg):
+                if match := self._result_schema.find_tool(structured_msg.parts):
                     call, _ = match
-                    tool_return = _messages.ToolReturn(
-                        tool_name=call.tool_name,
-                        content='Final result processed.',
-                        tool_id=call.tool_id,
-                    )
-                    return _MarkFinalResult(model_response), [tool_return]
+                    return _MarkFinalResult(model_response, call.tool_name)
 
             # the model is calling a tool function, consume the response to get the next message
             async for _ in model_response:
                 pass
-            structured_msg = model_response.get()
-            if not structured_msg.calls:
+            model_response_msg = model_response.get()
+            if not model_response_msg.parts:
                 raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
-            messages: list[_messages.Message] = [structured_msg]
 
             # we now run all tool functions in parallel
-            tasks: list[asyncio.Task[_messages.Message]] = []
-            for call in structured_msg.calls:
-                if tool := self._function_tools.get(call.tool_name):
-                    tasks.append(asyncio.create_task(tool.run(deps, call), name=call.tool_name))
-                else:
-                    messages.append(self._unknown_tool(call.tool_name))
+            tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
+            parts: list[_messages.ModelRequestPart] = []
+            for item in model_response_msg.parts:
+                if isinstance(item, _messages.ToolCallPart):
+                    call = item
+                    if tool := self._function_tools.get(call.tool_name):
+                        tasks.append(asyncio.create_task(tool.run(deps, call, conv_messages), name=call.tool_name))
+                    else:
+                        parts.append(self._unknown_tool(call.tool_name))
 
             with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
-                task_results: Sequence[_messages.Message] = await asyncio.gather(*tasks)
-                messages.extend(task_results)
-            return None, messages
+                task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
+                parts.extend(task_results)
+            return model_response_msg, parts
+        else:
+            assert_never(model_response)
 
     async def _validate_result(
-        self, result_data: ResultData, deps: AgentDeps, tool_call: _messages.ToolCall | None
+        self,
+        result_data: ResultData,
+        deps: AgentDeps,
+        tool_call: _messages.ToolCallPart | None,
+        conv_messages: list[_messages.ModelMessage],
     ) -> ResultData:
         for validator in self._result_validators:
-            result_data = await validator.validate(result_data, deps, self._current_result_retry, tool_call)
+            result_data = await validator.validate(
+                result_data, deps, self._current_result_retry, tool_call, conv_messages
+            )
         return result_data
 
     def _incr_result_retry(self) -> None:
@@ -873,15 +1055,15 @@ class Agent(Generic[AgentDeps, ResultData]):
                 f'Exceeded maximum retries ({self._max_result_retries}) for result validation'
             )
 
-    async def _init_messages(self, deps: AgentDeps) -> list[_messages.Message]:
+    async def _sys_parts(self, deps: AgentDeps) -> list[_messages.ModelRequestPart]:
         """Build the initial messages for the conversation."""
-        messages: list[_messages.Message] = [_messages.SystemPrompt(p) for p in self._system_prompts]
+        messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self._system_prompts]
         for sys_prompt_runner in self._system_prompt_functions:
             prompt = await sys_prompt_runner.run(deps)
-            messages.append(_messages.SystemPrompt(prompt))
+            messages.append(_messages.SystemPromptPart(prompt))
         return messages
 
-    def _unknown_tool(self, tool_name: str) -> _messages.RetryPrompt:
+    def _unknown_tool(self, tool_name: str) -> _messages.RetryPromptPart:
         self._incr_result_retry()
         names = list(self._function_tools.keys())
         if self._result_schema:
@@ -890,7 +1072,7 @@ class Agent(Generic[AgentDeps, ResultData]):
             msg = f'Available tools: {", ".join(names)}'
         else:
             msg = 'No tools available.'
-        return _messages.RetryPrompt(content=f'Unknown tool name: {tool_name!r}. {msg}')
+        return _messages.RetryPromptPart(content=f'Unknown tool name: {tool_name!r}. {msg}')
 
     def _get_deps(self, deps: AgentDeps) -> AgentDeps:
         """Get deps for a run.
@@ -934,3 +1116,6 @@ class _MarkFinalResult(Generic[ResultData]):
     """
 
     data: ResultData
+    """The final result data."""
+    tool_name: str | None
+    """Name of the final result tool, None if the result is a string."""