pydantic-ai-slim 0.0.21__py3-none-any.whl → 0.0.22__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry; it is provided for informational purposes only.

Potentially problematic release: this version of pydantic-ai-slim might be problematic.

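In pydantic_ai/agent.py, the main change is that the agent run loop moves out of Agent.run / Agent.run_stream into a pydantic_graph-based _agent_graph module: the run methods now build a Graph, an initial GraphAgentState, and a GraphAgentDeps object, then execute the graph. EndStrategy and capture_run_messages are no longer defined in agent.py; they now live in _agent_graph and are re-exported from pydantic_ai.agent, so existing imports should keep working. For reference, the usage sketch below is copied from the capture_run_messages docstring that this diff removes from the module (the context manager is relocated, not removed):

    from pydantic_ai import Agent, capture_run_messages

    agent = Agent('test')

    with capture_run_messages() as messages:
        try:
            result = agent.run_sync('foobar')
        except Exception:
            print(messages)
            raise
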
pydantic_ai/agent.py CHANGED
@@ -5,14 +5,17 @@ import dataclasses
  import inspect
  from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
  from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
- from contextvars import ContextVar
  from types import FrameType
- from typing import Any, Callable, Generic, Literal, cast, final, overload
+ from typing import Any, Callable, Generic, cast, final, overload

  import logfire_api
- from typing_extensions import TypeVar, assert_never, deprecated
+ from typing_extensions import TypeVar, deprecated
+
+ from pydantic_graph import Graph, GraphRunContext, HistoryStep
+ from pydantic_graph.nodes import End

  from . import (
+ _agent_graph,
  _result,
  _system_prompt,
  _utils,
@@ -22,6 +25,7 @@ from . import (
  result,
  usage as _usage,
  )
+ from ._agent_graph import EndStrategy, capture_run_messages # imported for re-export
  from .result import ResultDataT
  from .settings import ModelSettings, merge_model_settings
  from .tools import (
@@ -29,7 +33,6 @@ from .tools import (
  DocstringFormat,
  RunContext,
  Tool,
- ToolDefinition,
  ToolFuncContext,
  ToolFuncEither,
  ToolFuncPlain,
@@ -52,14 +55,7 @@ else:
  logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)

  T = TypeVar('T')
- """An invariant TypeVar."""
  NoneType = type(None)
- EndStrategy = Literal['early', 'exhaustive']
- """The strategy for handling multiple tool calls when a final result is found.
-
- - `'early'`: Stop processing other tool calls once a final result is found
- - `'exhaustive'`: Process all tool calls even after finding a final result
- """
  RunResultDataT = TypeVar('RunResultDataT')
  """Type variable for the result data of a run where `result_type` was customized on the run call."""

@@ -104,18 +100,24 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
  Note, if `model_settings` is provided by `run`, `run_sync`, or `run_stream`, those settings will
  be merged with this value, with the runtime argument taking priority.
  """
+
+ result_type: type[ResultDataT] = dataclasses.field(repr=False)
+ """
+ The type of the result data, used to validate the result data, defaults to `str`.
+ """
+
+ _deps_type: type[AgentDepsT] = dataclasses.field(repr=False)
  _result_tool_name: str = dataclasses.field(repr=False)
  _result_tool_description: str | None = dataclasses.field(repr=False)
  _result_schema: _result.ResultSchema[ResultDataT] | None = dataclasses.field(repr=False)
  _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]] = dataclasses.field(repr=False)
  _system_prompts: tuple[str, ...] = dataclasses.field(repr=False)
- _function_tools: dict[str, Tool[AgentDepsT]] = dataclasses.field(repr=False)
- _default_retries: int = dataclasses.field(repr=False)
  _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False)
  _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(
  repr=False
  )
+ _function_tools: dict[str, Tool[AgentDepsT]] = dataclasses.field(repr=False)
+ _default_retries: int = dataclasses.field(repr=False)
  _max_result_retries: int = dataclasses.field(repr=False)
  _override_deps: _utils.Option[AgentDepsT] = dataclasses.field(default=None, repr=False)
  _override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False)
@@ -174,25 +176,30 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
  self.end_strategy = end_strategy
  self.name = name
  self.model_settings = model_settings
+ self.result_type = result_type
+
+ self._deps_type = deps_type
+
  self._result_tool_name = result_tool_name
  self._result_tool_description = result_tool_description
- self._result_schema = _result.ResultSchema[result_type].build(
+ self._result_schema: _result.ResultSchema[ResultDataT] | None = _result.ResultSchema[result_type].build(
  result_type, result_tool_name, result_tool_description
  )
+ self._result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]] = []

  self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
- self._function_tools = {}
+ self._system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = []
+ self._system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = {}
+
+ self._function_tools: dict[str, Tool[AgentDepsT]] = {}
+
  self._default_retries = retries
+ self._max_result_retries = result_retries if result_retries is not None else retries
  for tool in tools:
  if isinstance(tool, Tool):
  self._register_tool(tool)
  else:
  self._register_tool(Tool(tool))
- self._deps_type = deps_type
- self._system_prompt_functions = []
- self._system_prompt_dynamic_functions = {}
- self._max_result_retries = result_retries if result_retries is not None else retries
- self._result_validators = []

  @overload
  async def run(
@@ -272,66 +279,80 @@ class Agent(Generic[AgentDepsT, ResultDataT]):

  deps = self._get_deps(deps)
  new_message_index = len(message_history) if message_history else 0
- result_schema = self._prepare_result_schema(result_type)
+ result_schema: _result.ResultSchema[RunResultDataT] | None = self._prepare_result_schema(result_type)
+
+ # Build the graph
+ graph = self._build_graph(result_type)
+
+ # Build the initial state
+ state = _agent_graph.GraphAgentState(
+ message_history=message_history[:] if message_history else [],
+ usage=usage or _usage.Usage(),
+ retries=0,
+ run_step=0,
+ )
+
+ # We consider it a user error if a user tries to restrict the result type while having a result validator that
+ # may change the result type from the restricted type to something else. Therefore, we consider the following
+ # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code.
+ result_validators = cast(list[_result.ResultValidator[AgentDepsT, RunResultDataT]], self._result_validators)
+
+ # TODO: Instead of this, copy the function tools to ensure they don't share current_retry state between agent
+ # runs. Requires some changes to `Tool` to make them copyable though.
+ for v in self._function_tools.values():
+ v.current_retry = 0
+
+ model_settings = merge_model_settings(self.model_settings, model_settings)
+ usage_limits = usage_limits or _usage.UsageLimits()

  with _logfire.span(
  '{agent_name} run {prompt=}',
  prompt=user_prompt,
  agent=self,
- model_name=model_used.name(),
+ model_name=model_used.name() if model_used else 'no-model',
  agent_name=self.name or 'agent',
  ) as run_span:
- run_context = RunContext(deps, model_used, usage or _usage.Usage(), user_prompt)
- messages = await self._prepare_messages(user_prompt, message_history, run_context)
- run_context.messages = messages
-
- for tool in self._function_tools.values():
- tool.current_retry = 0
-
- model_settings = merge_model_settings(self.model_settings, model_settings)
- usage_limits = usage_limits or _usage.UsageLimits()
-
- while True:
- usage_limits.check_before_request(run_context.usage)
-
- run_context.run_step += 1
- with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
- agent_model = await self._prepare_model(run_context, result_schema)
-
- with _logfire.span('model request', run_step=run_context.run_step) as model_req_span:
- model_response, request_usage = await agent_model.request(messages, model_settings)
- model_req_span.set_attribute('response', model_response)
- model_req_span.set_attribute('usage', request_usage)
+ # Build the deps object for the graph
+ graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT](
+ user_deps=deps,
+ prompt=user_prompt,
+ new_message_index=new_message_index,
+ model=model_used,
+ model_settings=model_settings,
+ usage_limits=usage_limits,
+ max_result_retries=self._max_result_retries,
+ end_strategy=self.end_strategy,
+ result_schema=result_schema,
+ result_tools=self._result_schema.tool_defs() if self._result_schema else [],
+ result_validators=result_validators,
+ function_tools=self._function_tools,
+ run_span=run_span,
+ )

- messages.append(model_response)
- run_context.usage.incr(request_usage, requests=1)
- usage_limits.check_tokens(run_context.usage)
+ start_node = _agent_graph.UserPromptNode[AgentDepsT](
+ user_prompt=user_prompt,
+ system_prompts=self._system_prompts,
+ system_prompt_functions=self._system_prompt_functions,
+ system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
+ )

- with _logfire.span('handle model response', run_step=run_context.run_step) as handle_span:
- final_result, tool_responses = await self._handle_model_response(
- model_response, run_context, result_schema
- )
+ # Actually run
+ end_result, _ = await graph.run(
+ start_node,
+ state=state,
+ deps=graph_deps,
+ infer_name=False,
+ )

- if tool_responses:
- # Add parts to the conversation as a new message
- messages.append(_messages.ModelRequest(tool_responses))
-
- # Check if we got a final result
- if final_result is not None:
- result_data = final_result.data
- result_tool_name = final_result.tool_name
- run_span.set_attribute('all_messages', messages)
- run_span.set_attribute('usage', run_context.usage)
- handle_span.set_attribute('result', result_data)
- handle_span.message = 'handle model response -> final result'
- return result.RunResult(
- messages, new_message_index, result_data, result_tool_name, run_context.usage
- )
- else:
- # continue the conversation
- handle_span.set_attribute('tool_responses', tool_responses)
- tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
- handle_span.message = f'handle model response -> {tool_responses_str}'
+ # Build final run result
+ # We don't do any advanced checking if the data is actually from a final result or not
+ return result.RunResult(
+ state.message_history,
+ new_message_index,
+ end_result.data,
+ end_result.tool_name,
+ state.usage,
+ )

  @overload
  def run_sync(
@@ -503,7 +524,31 @@ class Agent(Generic[AgentDepsT, ResultDataT]):

  deps = self._get_deps(deps)
  new_message_index = len(message_history) if message_history else 0
- result_schema = self._prepare_result_schema(result_type)
+ result_schema: _result.ResultSchema[RunResultDataT] | None = self._prepare_result_schema(result_type)
+
+ # Build the graph
+ graph = self._build_stream_graph(result_type)
+
+ # Build the initial state
+ graph_state = _agent_graph.GraphAgentState(
+ message_history=message_history[:] if message_history else [],
+ usage=usage or _usage.Usage(),
+ retries=0,
+ run_step=0,
+ )
+
+ # We consider it a user error if a user tries to restrict the result type while having a result validator that
+ # may change the result type from the restricted type to something else. Therefore, we consider the following
+ # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code.
+ result_validators = cast(list[_result.ResultValidator[AgentDepsT, RunResultDataT]], self._result_validators)
+
+ # TODO: Instead of this, copy the function tools to ensure they don't share current_retry state between agent
+ # runs. Requires some changes to `Tool` to make them copyable though.
+ for v in self._function_tools.values():
+ v.current_retry = 0
+
+ model_settings = merge_model_settings(self.model_settings, model_settings)
+ usage_limits = usage_limits or _usage.UsageLimits()

  with _logfire.span(
  '{agent_name} run stream {prompt=}',
@@ -512,97 +557,53 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
  model_name=model_used.name(),
  agent_name=self.name or 'agent',
  ) as run_span:
- run_context = RunContext(deps, model_used, usage or _usage.Usage(), user_prompt)
- messages = await self._prepare_messages(user_prompt, message_history, run_context)
- run_context.messages = messages
-
- for tool in self._function_tools.values():
- tool.current_retry = 0
+ # Build the deps object for the graph
+ graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunResultDataT](
+ user_deps=deps,
+ prompt=user_prompt,
+ new_message_index=new_message_index,
+ model=model_used,
+ model_settings=model_settings,
+ usage_limits=usage_limits,
+ max_result_retries=self._max_result_retries,
+ end_strategy=self.end_strategy,
+ result_schema=result_schema,
+ result_tools=self._result_schema.tool_defs() if self._result_schema else [],
+ result_validators=result_validators,
+ function_tools=self._function_tools,
+ run_span=run_span,
+ )

- model_settings = merge_model_settings(self.model_settings, model_settings)
- usage_limits = usage_limits or _usage.UsageLimits()
+ start_node = _agent_graph.StreamUserPromptNode[AgentDepsT](
+ user_prompt=user_prompt,
+ system_prompts=self._system_prompts,
+ system_prompt_functions=self._system_prompt_functions,
+ system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
+ )

+ # Actually run
+ node = start_node
+ history: list[HistoryStep[_agent_graph.GraphAgentState, RunResultDataT]] = []
  while True:
- run_context.run_step += 1
- usage_limits.check_before_request(run_context.usage)
-
- with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
- agent_model = await self._prepare_model(run_context, result_schema)
-
- with _logfire.span('model request {run_step=}', run_step=run_context.run_step) as model_req_span:
- async with agent_model.request_stream(messages, model_settings) as model_response:
- run_context.usage.requests += 1
- model_req_span.set_attribute('response_type', model_response.__class__.__name__)
- # We want to end the "model request" span here, but we can't exit the context manager
- # in the traditional way
- model_req_span.__exit__(None, None, None)
-
- with _logfire.span('handle model response') as handle_span:
- maybe_final_result = await self._handle_streamed_response(
- model_response, run_context, result_schema
- )
-
- # Check if we got a final result
- if isinstance(maybe_final_result, _MarkFinalResult):
- result_stream = maybe_final_result.data
- result_tool_name = maybe_final_result.tool_name
- handle_span.message = 'handle model response -> final result'
-
- async def on_complete():
- """Called when the stream has completed.
-
- The model response will have been added to messages by now
- by `StreamedRunResult._marked_completed`.
- """
- last_message = messages[-1]
- assert isinstance(last_message, _messages.ModelResponse)
- tool_calls = [
- part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
- ]
- parts = await self._process_function_tools(
- tool_calls, result_tool_name, run_context, result_schema
- )
- if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
- self._incr_result_retry(run_context)
- if parts:
- messages.append(_messages.ModelRequest(parts))
- run_span.set_attribute('all_messages', messages)
-
- # The following is not guaranteed to be true, but we consider it a user error if
- # there are result validators that might convert the result data from an overridden
- # `result_type` to a type that is not valid as such.
- result_validators = cast(
- list[_result.ResultValidator[AgentDepsT, RunResultDataT]], self._result_validators
- )
-
- yield result.StreamedRunResult(
- messages,
- new_message_index,
- usage_limits,
- result_stream,
- result_schema,
- run_context,
- result_validators,
- result_tool_name,
- on_complete,
- )
- return
- else:
- # continue the conversation
- model_response_msg, tool_responses = maybe_final_result
- # if we got a model response add that to messages
- messages.append(model_response_msg)
- if tool_responses:
- # if we got one or more tool response parts, add a model request message
- messages.append(_messages.ModelRequest(tool_responses))
-
- handle_span.set_attribute('tool_responses', tool_responses)
- tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
- handle_span.message = f'handle model response -> {tool_responses_str}'
- # the model_response should have been fully streamed by now, we can add its usage
- model_response_usage = model_response.usage()
- run_context.usage.incr(model_response_usage)
- usage_limits.check_tokens(run_context.usage)
+ if isinstance(node, _agent_graph.StreamModelRequestNode):
+ node = cast(
+ _agent_graph.StreamModelRequestNode[
+ AgentDepsT, result.StreamedRunResult[AgentDepsT, RunResultDataT]
+ ],
+ node,
+ )
+ async with node.run_to_result(GraphRunContext(graph_state, graph_deps)) as r:
+ if isinstance(r, End):
+ yield r.data
+ break
+ assert not isinstance(node, End) # the previous line should be hit first
+ node = await graph.next(
+ node,
+ history,
+ state=graph_state,
+ deps=graph_deps,
+ infer_name=False,
+ )

  @contextmanager
  def override(
@@ -718,7 +719,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
  return decorator
  else:
  assert not dynamic, "dynamic can't be True in this case"
- self._system_prompt_functions.append(_system_prompt.SystemPromptRunner(func, dynamic=dynamic))
+ self._system_prompt_functions.append(_system_prompt.SystemPromptRunner[AgentDepsT](func, dynamic=dynamic))
  return func

  @overload
@@ -998,335 +999,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):

  return model_

- async def _prepare_model(
- self, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[RunResultDataT] | None
- ) -> models.AgentModel:
- """Build tools and create an agent model."""
- function_tools: list[ToolDefinition] = []
-
- async def add_tool(tool: Tool[AgentDepsT]) -> None:
- ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
- if tool_def := await tool.prepare_tool_def(ctx):
- function_tools.append(tool_def)
-
- await asyncio.gather(*map(add_tool, self._function_tools.values()))
-
- return await run_context.model.agent_model(
- function_tools=function_tools,
- allow_text_result=self._allow_text_result(result_schema),
- result_tools=result_schema.tool_defs() if result_schema is not None else [],
- )
-
- async def _reevaluate_dynamic_prompts(
- self, messages: list[_messages.ModelMessage], run_context: RunContext[AgentDepsT]
- ) -> None:
- """Reevaluate any `SystemPromptPart` with dynamic_ref in the provided messages by running the associated runner function."""
- # Only proceed if there's at least one dynamic runner.
- if self._system_prompt_dynamic_functions:
- for msg in messages:
- if isinstance(msg, _messages.ModelRequest):
- for i, part in enumerate(msg.parts):
- if isinstance(part, _messages.SystemPromptPart) and part.dynamic_ref:
- # Look up the runner by its ref
- if runner := self._system_prompt_dynamic_functions.get(part.dynamic_ref):
- updated_part_content = await runner.run(run_context)
- msg.parts[i] = _messages.SystemPromptPart(
- updated_part_content, dynamic_ref=part.dynamic_ref
- )
-
- def _prepare_result_schema(
- self, result_type: type[RunResultDataT] | None
- ) -> _result.ResultSchema[RunResultDataT] | None:
- if result_type is not None:
- if self._result_validators:
- raise exceptions.UserError('Cannot set a custom run `result_type` when the agent has result validators')
- return _result.ResultSchema[result_type].build(
- result_type, self._result_tool_name, self._result_tool_description
- )
- else:
- return self._result_schema # pyright: ignore[reportReturnType]
-
- async def _prepare_messages(
- self,
- user_prompt: str,
- message_history: list[_messages.ModelMessage] | None,
- run_context: RunContext[AgentDepsT],
- ) -> list[_messages.ModelMessage]:
- try:
- ctx_messages = get_captured_run_messages()
- except LookupError:
- messages: list[_messages.ModelMessage] = []
- else:
- if ctx_messages.used:
- messages = []
- else:
- messages = ctx_messages.messages
- ctx_messages.used = True
-
- if message_history:
- # Shallow copy messages
- messages.extend(message_history)
- # Reevaluate any dynamic system prompt parts
- await self._reevaluate_dynamic_prompts(messages, run_context)
- messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)]))
- else:
- parts = await self._sys_parts(run_context)
- parts.append(_messages.UserPromptPart(user_prompt))
- messages.append(_messages.ModelRequest(parts))
-
- return messages
-
- async def _handle_model_response(
- self,
- model_response: _messages.ModelResponse,
- run_context: RunContext[AgentDepsT],
- result_schema: _result.ResultSchema[RunResultDataT] | None,
- ) -> tuple[_MarkFinalResult[RunResultDataT] | None, list[_messages.ModelRequestPart]]:
- """Process a non-streamed response from the model.
-
- Returns:
- A tuple of `(final_result, request parts)`. If `final_result` is not `None`, the conversation should end.
- """
- texts: list[str] = []
- tool_calls: list[_messages.ToolCallPart] = []
- for part in model_response.parts:
- if isinstance(part, _messages.TextPart):
- # ignore empty content for text parts, see #437
- if part.content:
- texts.append(part.content)
- else:
- tool_calls.append(part)
-
- # At the moment, we prioritize at least executing tool calls if they are present.
- # In the future, we'd consider making this configurable at the agent or run level.
- # This accounts for cases like anthropic returns that might contain a text response
- # and a tool call response, where the text response just indicates the tool call will happen.
- if tool_calls:
- return await self._handle_structured_response(tool_calls, run_context, result_schema)
- elif texts:
- text = '\n\n'.join(texts)
- return await self._handle_text_response(text, run_context, result_schema)
- else:
- raise exceptions.UnexpectedModelBehavior('Received empty model response')
-
- async def _handle_text_response(
- self, text: str, run_context: RunContext[AgentDepsT], result_schema: _result.ResultSchema[RunResultDataT] | None
- ) -> tuple[_MarkFinalResult[RunResultDataT] | None, list[_messages.ModelRequestPart]]:
- """Handle a plain text response from the model for non-streaming responses."""
- if self._allow_text_result(result_schema):
- result_data_input = cast(RunResultDataT, text)
- try:
- result_data = await self._validate_result(result_data_input, run_context, None)
- except _result.ToolRetryError as e:
- self._incr_result_retry(run_context)
- return None, [e.tool_retry]
- else:
- return _MarkFinalResult(result_data, None), []
- else:
- self._incr_result_retry(run_context)
- response = _messages.RetryPromptPart(
- content='Plain text responses are not permitted, please call one of the functions instead.',
- )
- return None, [response]
-
- async def _handle_structured_response(
- self,
- tool_calls: list[_messages.ToolCallPart],
- run_context: RunContext[AgentDepsT],
- result_schema: _result.ResultSchema[RunResultDataT] | None,
- ) -> tuple[_MarkFinalResult[RunResultDataT] | None, list[_messages.ModelRequestPart]]:
- """Handle a structured response containing tool calls from the model for non-streaming responses."""
- assert tool_calls, 'Expected at least one tool call'
-
- # first look for the result tool call
- final_result: _MarkFinalResult[RunResultDataT] | None = None
-
- parts: list[_messages.ModelRequestPart] = []
- if result_schema is not None:
- if match := result_schema.find_tool(tool_calls):
- call, result_tool = match
- try:
- result_data = result_tool.validate(call)
- result_data = await self._validate_result(result_data, run_context, call)
- except _result.ToolRetryError as e:
- parts.append(e.tool_retry)
- else:
- final_result = _MarkFinalResult(result_data, call.tool_name)
-
- # Then build the other request parts based on end strategy
- parts += await self._process_function_tools(
- tool_calls, final_result and final_result.tool_name, run_context, result_schema
- )
-
- if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
- self._incr_result_retry(run_context)
-
- return final_result, parts
-
- async def _process_function_tools(
- self,
- tool_calls: list[_messages.ToolCallPart],
- result_tool_name: str | None,
- run_context: RunContext[AgentDepsT],
- result_schema: _result.ResultSchema[RunResultDataT] | None,
- ) -> list[_messages.ModelRequestPart]:
- """Process function (non-result) tool calls in parallel.
-
- Also add stub return parts for any other tools that need it.
- """
- parts: list[_messages.ModelRequestPart] = []
- tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
-
- stub_function_tools = bool(result_tool_name) and self.end_strategy == 'early'
-
- # we rely on the fact that if we found a result, it's the first result tool in the last
- found_used_result_tool = False
- for call in tool_calls:
- if call.tool_name == result_tool_name and not found_used_result_tool:
- found_used_result_tool = True
- parts.append(
- _messages.ToolReturnPart(
- tool_name=call.tool_name,
- content='Final result processed.',
- tool_call_id=call.tool_call_id,
- )
- )
- elif tool := self._function_tools.get(call.tool_name):
- if stub_function_tools:
- parts.append(
- _messages.ToolReturnPart(
- tool_name=call.tool_name,
- content='Tool not executed - a final result was already processed.',
- tool_call_id=call.tool_call_id,
- )
- )
- else:
- tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
- elif result_schema is not None and call.tool_name in result_schema.tools:
- # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
- # validation, we don't add another part here
- if result_tool_name is not None:
- parts.append(
- _messages.ToolReturnPart(
- tool_name=call.tool_name,
- content='Result tool not used - a final result was already processed.',
- tool_call_id=call.tool_call_id,
- )
- )
- else:
- parts.append(self._unknown_tool(call.tool_name, result_schema))
-
- # Run all tool tasks in parallel
- if tasks:
- with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
- task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
- parts.extend(task_results)
- return parts
-
- async def _handle_streamed_response(
- self,
- streamed_response: models.StreamedResponse,
- run_context: RunContext[AgentDepsT],
- result_schema: _result.ResultSchema[RunResultDataT] | None,
- ) -> _MarkFinalResult[models.StreamedResponse] | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]:
- """Process a streamed response from the model.
-
- Returns:
- Either a final result or a tuple of the model response and the tool responses for the next request.
- If a final result is returned, the conversation should end.
- """
- received_text = False
-
- async for maybe_part_event in streamed_response:
- if isinstance(maybe_part_event, _messages.PartStartEvent):
- new_part = maybe_part_event.part
- if isinstance(new_part, _messages.TextPart):
- received_text = True
- if self._allow_text_result(result_schema):
- return _MarkFinalResult(streamed_response, None)
- elif isinstance(new_part, _messages.ToolCallPart):
- if result_schema is not None and (match := result_schema.find_tool([new_part])):
- call, _ = match
- return _MarkFinalResult(streamed_response, call.tool_name)
- else:
- assert_never(new_part)
-
- tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
- parts: list[_messages.ModelRequestPart] = []
- model_response = streamed_response.get()
- if not model_response.parts:
- raise exceptions.UnexpectedModelBehavior('Received empty model response')
- for p in model_response.parts:
- if isinstance(p, _messages.ToolCallPart):
- if tool := self._function_tools.get(p.tool_name):
- tasks.append(asyncio.create_task(tool.run(p, run_context), name=p.tool_name))
- else:
- parts.append(self._unknown_tool(p.tool_name, result_schema))
-
- if received_text and not tasks and not parts:
- # Can only get here if self._allow_text_result returns `False` for the provided result_schema
- self._incr_result_retry(run_context)
- model_response = _messages.RetryPromptPart(
- content='Plain text responses are not permitted, please call one of the functions instead.',
- )
- return streamed_response.get(), [model_response]
-
- with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
- task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
- parts.extend(task_results)
-
- if any(isinstance(part, _messages.RetryPromptPart) for part in parts):
- self._incr_result_retry(run_context)
-
- return model_response, parts
-
- async def _validate_result(
- self,
- result_data: RunResultDataT,
- run_context: RunContext[AgentDepsT],
- tool_call: _messages.ToolCallPart | None,
- ) -> RunResultDataT:
- if self._result_validators:
- agent_result_data = cast(ResultDataT, result_data)
- for validator in self._result_validators:
- agent_result_data = await validator.validate(agent_result_data, tool_call, run_context)
- return cast(RunResultDataT, agent_result_data)
- else:
- return result_data
-
- def _incr_result_retry(self, run_context: RunContext[AgentDepsT]) -> None:
- run_context.retry += 1
- if run_context.retry > self._max_result_retries:
- raise exceptions.UnexpectedModelBehavior(
- f'Exceeded maximum retries ({self._max_result_retries}) for result validation'
- )
-
- async def _sys_parts(self, run_context: RunContext[AgentDepsT]) -> list[_messages.ModelRequestPart]:
- """Build the initial messages for the conversation."""
- messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self._system_prompts]
- for sys_prompt_runner in self._system_prompt_functions:
- prompt = await sys_prompt_runner.run(run_context)
- if sys_prompt_runner.dynamic:
- messages.append(_messages.SystemPromptPart(prompt, dynamic_ref=sys_prompt_runner.function.__qualname__))
- else:
- messages.append(_messages.SystemPromptPart(prompt))
- return messages
-
- def _unknown_tool(
- self,
- tool_name: str,
- result_schema: _result.ResultSchema[RunResultDataT] | None,
- ) -> _messages.RetryPromptPart:
- names = list(self._function_tools.keys())
- if result_schema:
- names.extend(result_schema.tool_names())
- if names:
- msg = f'Available tools: {", ".join(names)}'
- else:
- msg = 'No tools available.'
- return _messages.RetryPromptPart(content=f'Unknown tool name: {tool_name!r}. {msg}')
-
- def _get_deps(self: Agent[T, Any], deps: T) -> T:
+ def _get_deps(self: Agent[T, ResultDataT], deps: T) -> T:
  """Get deps for a run.

  If we've overridden deps via `_override_deps`, use that, otherwise use the deps passed to the call.
@@ -1357,10 +1030,6 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
  self.name = name
  return

- @staticmethod
- def _allow_text_result(result_schema: _result.ResultSchema[RunResultDataT] | None) -> bool:
- return result_schema is None or result_schema.allow_text_result
-
  @property
  @deprecated(
  'The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.', category=None
@@ -1368,65 +1037,24 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
  def last_run_messages(self) -> list[_messages.ModelMessage]:
  raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.')

+ def _build_graph(
+ self, result_type: type[RunResultDataT] | None
+ ) -> Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], Any]:
+ return _agent_graph.build_agent_graph(self.name, self._deps_type, result_type or self.result_type)

- @dataclasses.dataclass
- class _RunMessages:
- messages: list[_messages.ModelMessage]
- used: bool = False
-
-
- _messages_ctx_var: ContextVar[_RunMessages] = ContextVar('var')
-
-
- @contextmanager
- def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
- """Context manager to access the messages used in a [`run`][pydantic_ai.Agent.run], [`run_sync`][pydantic_ai.Agent.run_sync], or [`run_stream`][pydantic_ai.Agent.run_stream] call.
-
- Useful when a run may raise an exception, see [model errors](../agents.md#model-errors) for more information.
-
- Examples:
- ```python
- from pydantic_ai import Agent, capture_run_messages
-
- agent = Agent('test')
-
- with capture_run_messages() as messages:
- try:
- result = agent.run_sync('foobar')
- except Exception:
- print(messages)
- raise
- ```
-
- !!! note
- If you call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context,
- `messages` will represent the messages exchanged during the first call only.
- """
- try:
- yield _messages_ctx_var.get().messages
- except LookupError:
- messages: list[_messages.ModelMessage] = []
- token = _messages_ctx_var.set(_RunMessages(messages))
- try:
- yield messages
- finally:
- _messages_ctx_var.reset(token)
-
-
- def get_captured_run_messages() -> _RunMessages:
- return _messages_ctx_var.get()
-
-
- @dataclasses.dataclass
- class _MarkFinalResult(Generic[ResultDataT]):
- """Marker class to indicate that the result is the final result.
-
- This allows us to use `isinstance`, which wouldn't be possible if we were returning `ResultDataT` directly.
-
- It also avoids problems in the case where the result type is itself `None`, but is set.
- """
+ def _build_stream_graph(
+ self, result_type: type[RunResultDataT] | None
+ ) -> Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], Any]:
+ return _agent_graph.build_agent_stream_graph(self.name, self._deps_type, result_type or self.result_type)

- data: ResultDataT
- """The final result data."""
- tool_name: str | None
- """Name of the final result tool, None if the result is a string."""
+ def _prepare_result_schema(
+ self, result_type: type[RunResultDataT] | None
+ ) -> _result.ResultSchema[RunResultDataT] | None:
+ if result_type is not None:
+ if self._result_validators:
+ raise exceptions.UserError('Cannot set a custom run `result_type` when the agent has result validators')
+ return _result.ResultSchema[result_type].build(
+ result_type, self._result_tool_name, self._result_tool_description
+ )
+ else:
+ return self._result_schema # pyright: ignore[reportReturnType]
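
The relocated _prepare_result_schema keeps the existing guard: a per-run result_type override is rejected with a UserError when the agent has result validators, because (as the typecast comments added above explain) those validators are written against the agent-level result type. Below is a hedged sketch of how that surfaces in user code, assuming the public result_validator decorator and the per-run result_type argument that this module's RunResultDataT type variable refers to:

    from pydantic_ai import Agent

    agent = Agent('test')  # 'test' is the built-in test model, as in the docstring example above

    @agent.result_validator
    def check(data: str) -> str:
        # any registered result validator triggers the guard in _prepare_result_schema
        return data

    # Raises UserError: 'Cannot set a custom run `result_type` when the agent has result validators'
    agent.run_sync('hello', result_type=int)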