pydantic-ai-slim 0.0.13__py3-none-any.whl → 0.0.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic; consult the registry's advisory page for more details.

pydantic_ai/__init__.py CHANGED
@@ -1,8 +1,19 @@
1
1
  from importlib.metadata import version
2
2
 
3
- from .agent import Agent
4
- from .exceptions import ModelRetry, UnexpectedModelBehavior, UserError
3
+ from .agent import Agent, capture_run_messages
4
+ from .exceptions import AgentRunError, ModelRetry, UnexpectedModelBehavior, UsageLimitExceeded, UserError
5
5
  from .tools import RunContext, Tool
6
6
 
7
- __all__ = 'Agent', 'Tool', 'RunContext', 'ModelRetry', 'UnexpectedModelBehavior', 'UserError', '__version__'
7
+ __all__ = (
8
+ 'Agent',
9
+ 'capture_run_messages',
10
+ 'RunContext',
11
+ 'Tool',
12
+ 'AgentRunError',
13
+ 'ModelRetry',
14
+ 'UnexpectedModelBehavior',
15
+ 'UsageLimitExceeded',
16
+ 'UserError',
17
+ '__version__',
18
+ )
8
19
  __version__ = version('pydantic_ai_slim')
pydantic_ai/_result.py CHANGED
@@ -12,8 +12,8 @@ from typing_extensions import Self, TypeAliasType, TypedDict
12
12
 
13
13
  from . import _utils, messages as _messages
14
14
  from .exceptions import ModelRetry
15
- from .result import ResultData
16
- from .tools import AgentDeps, ResultValidatorFunc, RunContext, ToolDefinition
15
+ from .result import ResultData, ResultValidatorFunc
16
+ from .tools import AgentDeps, RunContext, ToolDefinition
17
17
 
18
18
 
19
19
  @dataclass
@@ -29,25 +29,22 @@ class ResultValidator(Generic[AgentDeps, ResultData]):
29
29
  async def validate(
30
30
  self,
31
31
  result: ResultData,
32
- deps: AgentDeps,
33
- retry: int,
34
32
  tool_call: _messages.ToolCallPart | None,
35
- messages: list[_messages.ModelMessage],
33
+ run_context: RunContext[AgentDeps],
36
34
  ) -> ResultData:
37
35
  """Validate a result by calling the function.
38
36
 
39
37
  Args:
40
38
  result: The result data after Pydantic validation the message content.
41
- deps: The agent dependencies.
42
- retry: The current retry number.
43
39
  tool_call: The original tool call message, `None` if there was no tool call.
44
- messages: The messages exchanged so far in the conversation.
40
+ run_context: The current run context.
45
41
 
46
42
  Returns:
47
43
  Result of either the validated result data (ok) or a retry message (Err).
48
44
  """
49
45
  if self._takes_ctx:
50
- args = RunContext(deps, retry, messages, tool_call.tool_name if tool_call else None), result
46
+ ctx = run_context.replace_with(tool_name=tool_call.tool_name if tool_call else None)
47
+ args = ctx, result
51
48
  else:
52
49
  args = (result,)
53
50
 
@@ -19,9 +19,9 @@ class SystemPromptRunner(Generic[AgentDeps]):
19
19
  self._takes_ctx = len(inspect.signature(self.function).parameters) > 0
20
20
  self._is_async = inspect.iscoroutinefunction(self.function)
21
21
 
22
- async def run(self, deps: AgentDeps) -> str:
22
+ async def run(self, run_context: RunContext[AgentDeps]) -> str:
23
23
  if self._takes_ctx:
24
- args = (RunContext(deps, 0, [], None),)
24
+ args = (run_context,)
25
25
  else:
26
26
  args = ()
27
27
 
pydantic_ai/agent.py CHANGED
@@ -5,12 +5,13 @@ import dataclasses
5
5
  import inspect
6
6
  from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
7
7
  from contextlib import asynccontextmanager, contextmanager
8
+ from contextvars import ContextVar
8
9
  from dataclasses import dataclass, field
9
10
  from types import FrameType
10
11
  from typing import Any, Callable, Generic, Literal, cast, final, overload
11
12
 
12
13
  import logfire_api
13
- from typing_extensions import assert_never
14
+ from typing_extensions import assert_never, deprecated
14
15
 
15
16
  from . import (
16
17
  _result,
@@ -22,7 +23,7 @@ from . import (
22
23
  result,
23
24
  )
24
25
  from .result import ResultData
25
- from .settings import ModelSettings, merge_model_settings
26
+ from .settings import ModelSettings, UsageLimits, merge_model_settings
26
27
  from .tools import (
27
28
  AgentDeps,
28
29
  RunContext,
@@ -35,7 +36,7 @@ from .tools import (
35
36
  ToolPrepareFunc,
36
37
  )
37
38
 
38
- __all__ = ('Agent',)
39
+ __all__ = 'Agent', 'capture_run_messages', 'EndStrategy'
39
40
 
40
41
  _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
41
42
 
@@ -89,12 +90,6 @@ class Agent(Generic[AgentDeps, ResultData]):
89
90
  be merged with this value, with the runtime argument taking priority.
90
91
  """
91
92
 
92
- last_run_messages: list[_messages.ModelMessage] | None
93
- """The messages from the last run, useful when a run raised an exception.
94
-
95
- Note: these are not used by the agent, e.g. in future runs, they are just stored for developers' convenience.
96
- """
97
-
98
93
  _result_schema: _result.ResultSchema[ResultData] | None = field(repr=False)
99
94
  _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
100
95
  _allow_text_result: bool = field(repr=False)
@@ -104,7 +99,6 @@ class Agent(Generic[AgentDeps, ResultData]):
104
99
  _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False)
105
100
  _deps_type: type[AgentDeps] = field(repr=False)
106
101
  _max_result_retries: int = field(repr=False)
107
- _current_result_retry: int = field(repr=False)
108
102
  _override_deps: _utils.Option[AgentDeps] = field(default=None, repr=False)
109
103
  _override_model: _utils.Option[models.Model] = field(default=None, repr=False)
110
104
 
@@ -162,7 +156,6 @@ class Agent(Generic[AgentDeps, ResultData]):
162
156
  self.end_strategy = end_strategy
163
157
  self.name = name
164
158
  self.model_settings = model_settings
165
- self.last_run_messages = None
166
159
  self._result_schema = _result.ResultSchema[result_type].build(
167
160
  result_type, result_tool_name, result_tool_description
168
161
  )
@@ -180,7 +173,6 @@ class Agent(Generic[AgentDeps, ResultData]):
180
173
  self._deps_type = deps_type
181
174
  self._system_prompt_functions = []
182
175
  self._max_result_retries = result_retries if result_retries is not None else retries
183
- self._current_result_retry = 0
184
176
  self._result_validators = []
185
177
 
186
178
  async def run(
@@ -191,6 +183,7 @@ class Agent(Generic[AgentDeps, ResultData]):
191
183
  model: models.Model | models.KnownModelName | None = None,
192
184
  deps: AgentDeps = None,
193
185
  model_settings: ModelSettings | None = None,
186
+ usage_limits: UsageLimits | None = None,
194
187
  infer_name: bool = True,
195
188
  ) -> result.RunResult[ResultData]:
196
189
  """Run the agent with a user prompt in async mode.
@@ -211,8 +204,9 @@ class Agent(Generic[AgentDeps, ResultData]):
211
204
  message_history: History of the conversation so far.
212
205
  model: Optional model to use for this run, required if `model` was not set when creating the agent.
213
206
  deps: Optional dependencies to use for this run.
214
- infer_name: Whether to try to infer the agent name from the call frame if it's not set.
215
207
  model_settings: Optional settings to use for this model's request.
208
+ usage_limits: Optional limits on model request count or token usage.
209
+ infer_name: Whether to try to infer the agent name from the call frame if it's not set.
216
210
 
217
211
  Returns:
218
212
  The result of the run.
@@ -232,31 +226,37 @@ class Agent(Generic[AgentDeps, ResultData]):
232
226
  model_name=model_used.name(),
233
227
  agent_name=self.name or 'agent',
234
228
  ) as run_span:
235
- self.last_run_messages = messages = await self._prepare_messages(deps, user_prompt, message_history)
229
+ run_context = RunContext(deps, 0, [], None, model_used)
230
+ messages = await self._prepare_messages(user_prompt, message_history, run_context)
231
+ run_context.messages = messages
236
232
 
237
233
  for tool in self._function_tools.values():
238
234
  tool.current_retry = 0
239
235
 
240
- cost = result.Cost()
241
-
236
+ usage = result.Usage(requests=0)
242
237
  model_settings = merge_model_settings(self.model_settings, model_settings)
238
+ usage_limits = usage_limits or UsageLimits()
243
239
 
244
240
  run_step = 0
245
241
  while True:
242
+ usage_limits.check_before_request(usage)
243
+
246
244
  run_step += 1
247
245
  with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
248
- agent_model = await self._prepare_model(model_used, deps, messages)
246
+ agent_model = await self._prepare_model(run_context)
249
247
 
250
248
  with _logfire.span('model request', run_step=run_step) as model_req_span:
251
- model_response, request_cost = await agent_model.request(messages, model_settings)
249
+ model_response, request_usage = await agent_model.request(messages, model_settings)
252
250
  model_req_span.set_attribute('response', model_response)
253
- model_req_span.set_attribute('cost', request_cost)
251
+ model_req_span.set_attribute('usage', request_usage)
254
252
 
255
253
  messages.append(model_response)
256
- cost += request_cost
254
+ usage += request_usage
255
+ usage.requests += 1
256
+ usage_limits.check_tokens(request_usage)
257
257
 
258
258
  with _logfire.span('handle model response', run_step=run_step) as handle_span:
259
- final_result, tool_responses = await self._handle_model_response(model_response, deps, messages)
259
+ final_result, tool_responses = await self._handle_model_response(model_response, run_context)
260
260
 
261
261
  if tool_responses:
262
262
  # Add parts to the conversation as a new message
@@ -266,10 +266,10 @@ class Agent(Generic[AgentDeps, ResultData]):
266
266
  if final_result is not None:
267
267
  result_data = final_result.data
268
268
  run_span.set_attribute('all_messages', messages)
269
- run_span.set_attribute('cost', cost)
269
+ run_span.set_attribute('usage', usage)
270
270
  handle_span.set_attribute('result', result_data)
271
271
  handle_span.message = 'handle model response -> final result'
272
- return result.RunResult(messages, new_message_index, result_data, cost)
272
+ return result.RunResult(messages, new_message_index, result_data, usage)
273
273
  else:
274
274
  # continue the conversation
275
275
  handle_span.set_attribute('tool_responses', tool_responses)
@@ -284,6 +284,7 @@ class Agent(Generic[AgentDeps, ResultData]):
284
284
  model: models.Model | models.KnownModelName | None = None,
285
285
  deps: AgentDeps = None,
286
286
  model_settings: ModelSettings | None = None,
287
+ usage_limits: UsageLimits | None = None,
287
288
  infer_name: bool = True,
288
289
  ) -> result.RunResult[ResultData]:
289
290
  """Run the agent with a user prompt synchronously.
@@ -308,8 +309,9 @@ class Agent(Generic[AgentDeps, ResultData]):
308
309
  message_history: History of the conversation so far.
309
310
  model: Optional model to use for this run, required if `model` was not set when creating the agent.
310
311
  deps: Optional dependencies to use for this run.
311
- infer_name: Whether to try to infer the agent name from the call frame if it's not set.
312
312
  model_settings: Optional settings to use for this model's request.
313
+ usage_limits: Optional limits on model request count or token usage.
314
+ infer_name: Whether to try to infer the agent name from the call frame if it's not set.
313
315
 
314
316
  Returns:
315
317
  The result of the run.
@@ -322,8 +324,9 @@ class Agent(Generic[AgentDeps, ResultData]):
322
324
  message_history=message_history,
323
325
  model=model,
324
326
  deps=deps,
325
- infer_name=False,
326
327
  model_settings=model_settings,
328
+ usage_limits=usage_limits,
329
+ infer_name=False,
327
330
  )
328
331
  )
329
332
 
@@ -336,6 +339,7 @@ class Agent(Generic[AgentDeps, ResultData]):
336
339
  model: models.Model | models.KnownModelName | None = None,
337
340
  deps: AgentDeps = None,
338
341
  model_settings: ModelSettings | None = None,
342
+ usage_limits: UsageLimits | None = None,
339
343
  infer_name: bool = True,
340
344
  ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
341
345
  """Run the agent with a user prompt in async mode, returning a streamed response.
@@ -357,8 +361,9 @@ class Agent(Generic[AgentDeps, ResultData]):
357
361
  message_history: History of the conversation so far.
358
362
  model: Optional model to use for this run, required if `model` was not set when creating the agent.
359
363
  deps: Optional dependencies to use for this run.
360
- infer_name: Whether to try to infer the agent name from the call frame if it's not set.
361
364
  model_settings: Optional settings to use for this model's request.
365
+ usage_limits: Optional limits on model request count or token usage.
366
+ infer_name: Whether to try to infer the agent name from the call frame if it's not set.
362
367
 
363
368
  Returns:
364
369
  The result of the run.
@@ -380,32 +385,35 @@ class Agent(Generic[AgentDeps, ResultData]):
380
385
  model_name=model_used.name(),
381
386
  agent_name=self.name or 'agent',
382
387
  ) as run_span:
383
- self.last_run_messages = messages = await self._prepare_messages(deps, user_prompt, message_history)
388
+ run_context = RunContext(deps, 0, [], None, model_used)
389
+ messages = await self._prepare_messages(user_prompt, message_history, run_context)
390
+ run_context.messages = messages
384
391
 
385
392
  for tool in self._function_tools.values():
386
393
  tool.current_retry = 0
387
394
 
388
- cost = result.Cost()
395
+ usage = result.Usage()
389
396
  model_settings = merge_model_settings(self.model_settings, model_settings)
397
+ usage_limits = usage_limits or UsageLimits()
390
398
 
391
399
  run_step = 0
392
400
  while True:
393
401
  run_step += 1
402
+ usage_limits.check_before_request(usage)
394
403
 
395
404
  with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
396
- agent_model = await self._prepare_model(model_used, deps, messages)
405
+ agent_model = await self._prepare_model(run_context)
397
406
 
398
407
  with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
399
408
  async with agent_model.request_stream(messages, model_settings) as model_response:
409
+ usage.requests += 1
400
410
  model_req_span.set_attribute('response_type', model_response.__class__.__name__)
401
411
  # We want to end the "model request" span here, but we can't exit the context manager
402
412
  # in the traditional way
403
413
  model_req_span.__exit__(None, None, None)
404
414
 
405
415
  with _logfire.span('handle model response') as handle_span:
406
- maybe_final_result = await self._handle_streamed_model_response(
407
- model_response, deps, messages
408
- )
416
+ maybe_final_result = await self._handle_streamed_model_response(model_response, run_context)
409
417
 
410
418
  # Check if we got a final result
411
419
  if isinstance(maybe_final_result, _MarkFinalResult):
@@ -425,7 +433,7 @@ class Agent(Generic[AgentDeps, ResultData]):
425
433
  part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
426
434
  ]
427
435
  parts = await self._process_function_tools(
428
- tool_calls, result_tool_name, deps, messages
436
+ tool_calls, result_tool_name, run_context
429
437
  )
430
438
  if parts:
431
439
  messages.append(_messages.ModelRequest(parts))
@@ -434,10 +442,11 @@ class Agent(Generic[AgentDeps, ResultData]):
434
442
  yield result.StreamedRunResult(
435
443
  messages,
436
444
  new_message_index,
437
- cost,
445
+ usage,
446
+ usage_limits,
438
447
  result_stream,
439
448
  self._result_schema,
440
- deps,
449
+ run_context,
441
450
  self._result_validators,
442
451
  result_tool_name,
443
452
  on_complete,
@@ -455,8 +464,10 @@ class Agent(Generic[AgentDeps, ResultData]):
455
464
  handle_span.set_attribute('tool_responses', tool_responses)
456
465
  tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
457
466
  handle_span.message = f'handle model response -> {tool_responses_str}'
458
- # the model_response should have been fully streamed by now, we can add it's cost
459
- cost += model_response.cost()
467
+ # the model_response should have been fully streamed by now, we can add its usage
468
+ model_response_usage = model_response.usage()
469
+ usage += model_response_usage
470
+ usage_limits.check_tokens(usage)
460
471
 
461
472
  @contextmanager
462
473
  def override(
@@ -597,7 +608,7 @@ class Agent(Generic[AgentDeps, ResultData]):
597
608
  #> success (no tool calls)
598
609
  ```
599
610
  """
600
- self._result_validators.append(_result.ResultValidator(func))
611
+ self._result_validators.append(_result.ResultValidator[AgentDeps, Any](func))
601
612
  return func
602
613
 
603
614
  @overload
@@ -798,41 +809,50 @@ class Agent(Generic[AgentDeps, ResultData]):
798
809
 
799
810
  return model_, mode_selection
800
811
 
801
- async def _prepare_model(
802
- self, model: models.Model, deps: AgentDeps, messages: list[_messages.ModelMessage]
803
- ) -> models.AgentModel:
804
- """Create building tools and create an agent model."""
812
+ async def _prepare_model(self, run_context: RunContext[AgentDeps]) -> models.AgentModel:
813
+ """Build tools and create an agent model."""
805
814
  function_tools: list[ToolDefinition] = []
806
815
 
807
816
  async def add_tool(tool: Tool[AgentDeps]) -> None:
808
- ctx = RunContext(deps, tool.current_retry, messages, tool.name)
817
+ ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
809
818
  if tool_def := await tool.prepare_tool_def(ctx):
810
819
  function_tools.append(tool_def)
811
820
 
812
821
  await asyncio.gather(*map(add_tool, self._function_tools.values()))
813
822
 
814
- return await model.agent_model(
823
+ return await run_context.model.agent_model(
815
824
  function_tools=function_tools,
816
825
  allow_text_result=self._allow_text_result,
817
826
  result_tools=self._result_schema.tool_defs() if self._result_schema is not None else [],
818
827
  )
819
828
 
820
829
  async def _prepare_messages(
821
- self, deps: AgentDeps, user_prompt: str, message_history: list[_messages.ModelMessage] | None
830
+ self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
822
831
  ) -> list[_messages.ModelMessage]:
832
+ try:
833
+ messages = _messages_ctx_var.get()
834
+ except LookupError:
835
+ messages = []
836
+ else:
837
+ if messages:
838
+ raise exceptions.UserError(
839
+ 'The capture_run_messages() context manager may only be used to wrap '
840
+ 'one call to run(), run_sync(), or run_stream().'
841
+ )
842
+
823
843
  if message_history:
824
844
  # shallow copy messages
825
- messages = message_history.copy()
845
+ messages.extend(message_history)
826
846
  messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)]))
827
847
  else:
828
- parts = await self._sys_parts(deps)
848
+ parts = await self._sys_parts(run_context)
829
849
  parts.append(_messages.UserPromptPart(user_prompt))
830
- messages: list[_messages.ModelMessage] = [_messages.ModelRequest(parts)]
850
+ messages.append(_messages.ModelRequest(parts))
831
851
 
832
852
  return messages
833
853
 
834
854
  async def _handle_model_response(
835
- self, model_response: _messages.ModelResponse, deps: AgentDeps, conv_messages: list[_messages.ModelMessage]
855
+ self, model_response: _messages.ModelResponse, run_context: RunContext[AgentDeps]
836
856
  ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
837
857
  """Process a non-streamed response from the model.
838
858
 
@@ -841,42 +861,48 @@ class Agent(Generic[AgentDeps, ResultData]):
841
861
  """
842
862
  texts: list[str] = []
843
863
  tool_calls: list[_messages.ToolCallPart] = []
844
- for item in model_response.parts:
845
- if isinstance(item, _messages.TextPart):
846
- texts.append(item.content)
864
+ for part in model_response.parts:
865
+ if isinstance(part, _messages.TextPart):
866
+ # ignore empty content for text parts, see #437
867
+ if part.content:
868
+ texts.append(part.content)
847
869
  else:
848
- tool_calls.append(item)
849
-
850
- if texts:
870
+ tool_calls.append(part)
871
+
872
+ # At the moment, we prioritize at least executing tool calls if they are present.
873
+ # In the future, we'd consider making this configurable at the agent or run level.
874
+ # This accounts for cases like anthropic returns that might contain a text response
875
+ # and a tool call response, where the text response just indicates the tool call will happen.
876
+ if tool_calls:
877
+ return await self._handle_structured_response(tool_calls, run_context)
878
+ elif texts:
851
879
  text = '\n\n'.join(texts)
852
- return await self._handle_text_response(text, deps, conv_messages)
853
- elif tool_calls:
854
- return await self._handle_structured_response(tool_calls, deps, conv_messages)
880
+ return await self._handle_text_response(text, run_context)
855
881
  else:
856
882
  raise exceptions.UnexpectedModelBehavior('Received empty model response')
857
883
 
858
884
  async def _handle_text_response(
859
- self, text: str, deps: AgentDeps, conv_messages: list[_messages.ModelMessage]
885
+ self, text: str, run_context: RunContext[AgentDeps]
860
886
  ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
861
887
  """Handle a plain text response from the model for non-streaming responses."""
862
888
  if self._allow_text_result:
863
889
  result_data_input = cast(ResultData, text)
864
890
  try:
865
- result_data = await self._validate_result(result_data_input, deps, None, conv_messages)
891
+ result_data = await self._validate_result(result_data_input, run_context, None)
866
892
  except _result.ToolRetryError as e:
867
- self._incr_result_retry()
893
+ self._incr_result_retry(run_context)
868
894
  return None, [e.tool_retry]
869
895
  else:
870
896
  return _MarkFinalResult(result_data, None), []
871
897
  else:
872
- self._incr_result_retry()
898
+ self._incr_result_retry(run_context)
873
899
  response = _messages.RetryPromptPart(
874
900
  content='Plain text responses are not permitted, please call one of the functions instead.',
875
901
  )
876
902
  return None, [response]
877
903
 
878
904
  async def _handle_structured_response(
879
- self, tool_calls: list[_messages.ToolCallPart], deps: AgentDeps, conv_messages: list[_messages.ModelMessage]
905
+ self, tool_calls: list[_messages.ToolCallPart], run_context: RunContext[AgentDeps]
880
906
  ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
881
907
  """Handle a structured response containing tool calls from the model for non-streaming responses."""
882
908
  assert tool_calls, 'Expected at least one tool call'
@@ -890,17 +916,15 @@ class Agent(Generic[AgentDeps, ResultData]):
890
916
  call, result_tool = match
891
917
  try:
892
918
  result_data = result_tool.validate(call)
893
- result_data = await self._validate_result(result_data, deps, call, conv_messages)
919
+ result_data = await self._validate_result(result_data, run_context, call)
894
920
  except _result.ToolRetryError as e:
895
- self._incr_result_retry()
921
+ self._incr_result_retry(run_context)
896
922
  parts.append(e.tool_retry)
897
923
  else:
898
924
  final_result = _MarkFinalResult(result_data, call.tool_name)
899
925
 
900
926
  # Then build the other request parts based on end strategy
901
- parts += await self._process_function_tools(
902
- tool_calls, final_result and final_result.tool_name, deps, conv_messages
903
- )
927
+ parts += await self._process_function_tools(tool_calls, final_result and final_result.tool_name, run_context)
904
928
 
905
929
  return final_result, parts
906
930
 
@@ -908,8 +932,7 @@ class Agent(Generic[AgentDeps, ResultData]):
908
932
  self,
909
933
  tool_calls: list[_messages.ToolCallPart],
910
934
  result_tool_name: str | None,
911
- deps: AgentDeps,
912
- conv_messages: list[_messages.ModelMessage],
935
+ run_context: RunContext[AgentDeps],
913
936
  ) -> list[_messages.ModelRequestPart]:
914
937
  """Process function (non-result) tool calls in parallel.
915
938
 
@@ -942,7 +965,7 @@ class Agent(Generic[AgentDeps, ResultData]):
942
965
  )
943
966
  )
944
967
  else:
945
- tasks.append(asyncio.create_task(tool.run(deps, call, conv_messages), name=call.tool_name))
968
+ tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
946
969
  elif self._result_schema is not None and call.tool_name in self._result_schema.tools:
947
970
  # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
948
971
  # validation, we don't add another part here
@@ -955,7 +978,7 @@ class Agent(Generic[AgentDeps, ResultData]):
955
978
  )
956
979
  )
957
980
  else:
958
- parts.append(self._unknown_tool(call.tool_name))
981
+ parts.append(self._unknown_tool(call.tool_name, run_context))
959
982
 
960
983
  # Run all tool tasks in parallel
961
984
  if tasks:
@@ -967,8 +990,7 @@ class Agent(Generic[AgentDeps, ResultData]):
967
990
  async def _handle_streamed_model_response(
968
991
  self,
969
992
  model_response: models.EitherStreamedResponse,
970
- deps: AgentDeps,
971
- conv_messages: list[_messages.ModelMessage],
993
+ run_context: RunContext[AgentDeps],
972
994
  ) -> (
973
995
  _MarkFinalResult[models.EitherStreamedResponse]
974
996
  | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]
@@ -984,11 +1006,11 @@ class Agent(Generic[AgentDeps, ResultData]):
984
1006
  if self._allow_text_result:
985
1007
  return _MarkFinalResult(model_response, None)
986
1008
  else:
987
- self._incr_result_retry()
1009
+ self._incr_result_retry(run_context)
988
1010
  response = _messages.RetryPromptPart(
989
1011
  content='Plain text responses are not permitted, please call one of the functions instead.',
990
1012
  )
991
- # stream the response, so cost is correct
1013
+ # stream the response, so usage is correct
992
1014
  async for _ in model_response:
993
1015
  pass
994
1016
 
@@ -1024,9 +1046,9 @@ class Agent(Generic[AgentDeps, ResultData]):
1024
1046
  if isinstance(item, _messages.ToolCallPart):
1025
1047
  call = item
1026
1048
  if tool := self._function_tools.get(call.tool_name):
1027
- tasks.append(asyncio.create_task(tool.run(deps, call, conv_messages), name=call.tool_name))
1049
+ tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
1028
1050
  else:
1029
- parts.append(self._unknown_tool(call.tool_name))
1051
+ parts.append(self._unknown_tool(call.tool_name, run_context))
1030
1052
 
1031
1053
  with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
1032
1054
  task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
@@ -1038,33 +1060,30 @@ class Agent(Generic[AgentDeps, ResultData]):
1038
1060
  async def _validate_result(
1039
1061
  self,
1040
1062
  result_data: ResultData,
1041
- deps: AgentDeps,
1063
+ run_context: RunContext[AgentDeps],
1042
1064
  tool_call: _messages.ToolCallPart | None,
1043
- conv_messages: list[_messages.ModelMessage],
1044
1065
  ) -> ResultData:
1045
1066
  for validator in self._result_validators:
1046
- result_data = await validator.validate(
1047
- result_data, deps, self._current_result_retry, tool_call, conv_messages
1048
- )
1067
+ result_data = await validator.validate(result_data, tool_call, run_context)
1049
1068
  return result_data
1050
1069
 
1051
- def _incr_result_retry(self) -> None:
1052
- self._current_result_retry += 1
1053
- if self._current_result_retry > self._max_result_retries:
1070
+ def _incr_result_retry(self, run_context: RunContext[AgentDeps]) -> None:
1071
+ run_context.retry += 1
1072
+ if run_context.retry > self._max_result_retries:
1054
1073
  raise exceptions.UnexpectedModelBehavior(
1055
1074
  f'Exceeded maximum retries ({self._max_result_retries}) for result validation'
1056
1075
  )
1057
1076
 
1058
- async def _sys_parts(self, deps: AgentDeps) -> list[_messages.ModelRequestPart]:
1077
+ async def _sys_parts(self, run_context: RunContext[AgentDeps]) -> list[_messages.ModelRequestPart]:
1059
1078
  """Build the initial messages for the conversation."""
1060
1079
  messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self._system_prompts]
1061
1080
  for sys_prompt_runner in self._system_prompt_functions:
1062
- prompt = await sys_prompt_runner.run(deps)
1081
+ prompt = await sys_prompt_runner.run(run_context)
1063
1082
  messages.append(_messages.SystemPromptPart(prompt))
1064
1083
  return messages
1065
1084
 
1066
- def _unknown_tool(self, tool_name: str) -> _messages.RetryPromptPart:
1067
- self._incr_result_retry()
1085
+ def _unknown_tool(self, tool_name: str, run_context: RunContext[AgentDeps]) -> _messages.RetryPromptPart:
1086
+ self._incr_result_retry(run_context)
1068
1087
  names = list(self._function_tools.keys())
1069
1088
  if self._result_schema:
1070
1089
  names.extend(self._result_schema.tool_names())
@@ -1105,6 +1124,51 @@ class Agent(Generic[AgentDeps, ResultData]):
1105
1124
  self.name = name
1106
1125
  return
1107
1126
 
1127
+ @property
1128
+ @deprecated(
1129
+ 'The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.', category=None
1130
+ )
1131
+ def last_run_messages(self) -> list[_messages.ModelMessage]:
1132
+ raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.')
1133
+
1134
+
1135
+ _messages_ctx_var: ContextVar[list[_messages.ModelMessage]] = ContextVar('var')
1136
+
1137
+
1138
+ @contextmanager
1139
+ def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]:
1140
+ """Context manager to access the messages used in a [`run`][pydantic_ai.Agent.run], [`run_sync`][pydantic_ai.Agent.run_sync], or [`run_stream`][pydantic_ai.Agent.run_stream] call.
1141
+
1142
+ Useful when a run may raise an exception, see [model errors](../agents.md#model-errors) for more information.
1143
+
1144
+ Examples:
1145
+ ```python
1146
+ from pydantic_ai import Agent, capture_run_messages
1147
+
1148
+ agent = Agent('test')
1149
+
1150
+ with capture_run_messages() as messages:
1151
+ try:
1152
+ result = agent.run_sync('foobar')
1153
+ except Exception:
1154
+ print(messages)
1155
+ raise
1156
+ ```
1157
+
1158
+ !!! note
1159
+ You may not call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context.
1160
+ If you try to do so, a [`UserError`][pydantic_ai.exceptions.UserError] will be raised.
1161
+ """
1162
+ try:
1163
+ yield _messages_ctx_var.get()
1164
+ except LookupError:
1165
+ messages: list[_messages.ModelMessage] = []
1166
+ token = _messages_ctx_var.set(messages)
1167
+ try:
1168
+ yield messages
1169
+ finally:
1170
+ _messages_ctx_var.reset(token)
1171
+
1108
1172
 
1109
1173
  @dataclass
1110
1174
  class _MarkFinalResult(Generic[ResultData]):