pydantic-ai-slim 0.0.13__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic. Click here for more details.

pydantic_ai/__init__.py CHANGED
@@ -1,8 +1,18 @@
1
1
  from importlib.metadata import version
2
2
 
3
3
  from .agent import Agent
4
- from .exceptions import ModelRetry, UnexpectedModelBehavior, UserError
4
+ from .exceptions import AgentRunError, ModelRetry, UnexpectedModelBehavior, UsageLimitExceeded, UserError
5
5
  from .tools import RunContext, Tool
6
6
 
7
- __all__ = 'Agent', 'Tool', 'RunContext', 'ModelRetry', 'UnexpectedModelBehavior', 'UserError', '__version__'
7
+ __all__ = (
8
+ 'Agent',
9
+ 'RunContext',
10
+ 'Tool',
11
+ 'AgentRunError',
12
+ 'ModelRetry',
13
+ 'UnexpectedModelBehavior',
14
+ 'UsageLimitExceeded',
15
+ 'UserError',
16
+ '__version__',
17
+ )
8
18
  __version__ = version('pydantic_ai_slim')
pydantic_ai/_result.py CHANGED
@@ -29,25 +29,22 @@ class ResultValidator(Generic[AgentDeps, ResultData]):
29
29
  async def validate(
30
30
  self,
31
31
  result: ResultData,
32
- deps: AgentDeps,
33
- retry: int,
34
32
  tool_call: _messages.ToolCallPart | None,
35
- messages: list[_messages.ModelMessage],
33
+ run_context: RunContext[AgentDeps],
36
34
  ) -> ResultData:
37
35
  """Validate a result by calling the function.
38
36
 
39
37
  Args:
40
38
  result: The result data after Pydantic validation of the message content.
41
- deps: The agent dependencies.
42
- retry: The current retry number.
43
39
  tool_call: The original tool call message, `None` if there was no tool call.
44
- messages: The messages exchanged so far in the conversation.
40
+ run_context: The current run context.
45
41
 
46
42
  Returns:
47
43
  Result of either the validated result data (ok) or a retry message (Err).
48
44
  """
49
45
  if self._takes_ctx:
50
- args = RunContext(deps, retry, messages, tool_call.tool_name if tool_call else None), result
46
+ ctx = run_context.replace_with(tool_name=tool_call.tool_name if tool_call else None)
47
+ args = ctx, result
51
48
  else:
52
49
  args = (result,)
53
50
 
@@ -19,9 +19,9 @@ class SystemPromptRunner(Generic[AgentDeps]):
19
19
  self._takes_ctx = len(inspect.signature(self.function).parameters) > 0
20
20
  self._is_async = inspect.iscoroutinefunction(self.function)
21
21
 
22
- async def run(self, deps: AgentDeps) -> str:
22
+ async def run(self, run_context: RunContext[AgentDeps]) -> str:
23
23
  if self._takes_ctx:
24
- args = (RunContext(deps, 0, [], None),)
24
+ args = (run_context,)
25
25
  else:
26
26
  args = ()
27
27
 
pydantic_ai/agent.py CHANGED
@@ -22,7 +22,7 @@ from . import (
22
22
  result,
23
23
  )
24
24
  from .result import ResultData
25
- from .settings import ModelSettings, merge_model_settings
25
+ from .settings import ModelSettings, UsageLimits, merge_model_settings
26
26
  from .tools import (
27
27
  AgentDeps,
28
28
  RunContext,
@@ -104,7 +104,6 @@ class Agent(Generic[AgentDeps, ResultData]):
104
104
  _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False)
105
105
  _deps_type: type[AgentDeps] = field(repr=False)
106
106
  _max_result_retries: int = field(repr=False)
107
- _current_result_retry: int = field(repr=False)
108
107
  _override_deps: _utils.Option[AgentDeps] = field(default=None, repr=False)
109
108
  _override_model: _utils.Option[models.Model] = field(default=None, repr=False)
110
109
 
@@ -180,7 +179,6 @@ class Agent(Generic[AgentDeps, ResultData]):
180
179
  self._deps_type = deps_type
181
180
  self._system_prompt_functions = []
182
181
  self._max_result_retries = result_retries if result_retries is not None else retries
183
- self._current_result_retry = 0
184
182
  self._result_validators = []
185
183
 
186
184
  async def run(
@@ -191,6 +189,7 @@ class Agent(Generic[AgentDeps, ResultData]):
191
189
  model: models.Model | models.KnownModelName | None = None,
192
190
  deps: AgentDeps = None,
193
191
  model_settings: ModelSettings | None = None,
192
+ usage_limits: UsageLimits | None = None,
194
193
  infer_name: bool = True,
195
194
  ) -> result.RunResult[ResultData]:
196
195
  """Run the agent with a user prompt in async mode.
@@ -211,8 +210,9 @@ class Agent(Generic[AgentDeps, ResultData]):
211
210
  message_history: History of the conversation so far.
212
211
  model: Optional model to use for this run, required if `model` was not set when creating the agent.
213
212
  deps: Optional dependencies to use for this run.
214
- infer_name: Whether to try to infer the agent name from the call frame if it's not set.
215
213
  model_settings: Optional settings to use for this model's request.
214
+ usage_limits: Optional limits on model request count or token usage.
215
+ infer_name: Whether to try to infer the agent name from the call frame if it's not set.
216
216
 
217
217
  Returns:
218
218
  The result of the run.
@@ -232,31 +232,37 @@ class Agent(Generic[AgentDeps, ResultData]):
232
232
  model_name=model_used.name(),
233
233
  agent_name=self.name or 'agent',
234
234
  ) as run_span:
235
- self.last_run_messages = messages = await self._prepare_messages(deps, user_prompt, message_history)
235
+ run_context = RunContext(deps, 0, [], None, model_used)
236
+ messages = await self._prepare_messages(user_prompt, message_history, run_context)
237
+ self.last_run_messages = run_context.messages = messages
236
238
 
237
239
  for tool in self._function_tools.values():
238
240
  tool.current_retry = 0
239
241
 
240
- cost = result.Cost()
241
-
242
+ usage = result.Usage(requests=0)
242
243
  model_settings = merge_model_settings(self.model_settings, model_settings)
244
+ usage_limits = usage_limits or UsageLimits()
243
245
 
244
246
  run_step = 0
245
247
  while True:
248
+ usage_limits.check_before_request(usage)
249
+
246
250
  run_step += 1
247
251
  with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
248
- agent_model = await self._prepare_model(model_used, deps, messages)
252
+ agent_model = await self._prepare_model(run_context)
249
253
 
250
254
  with _logfire.span('model request', run_step=run_step) as model_req_span:
251
- model_response, request_cost = await agent_model.request(messages, model_settings)
255
+ model_response, request_usage = await agent_model.request(messages, model_settings)
252
256
  model_req_span.set_attribute('response', model_response)
253
- model_req_span.set_attribute('cost', request_cost)
257
+ model_req_span.set_attribute('usage', request_usage)
254
258
 
255
259
  messages.append(model_response)
256
- cost += request_cost
260
+ usage += request_usage
261
+ usage.requests += 1
262
+ usage_limits.check_tokens(request_usage)
257
263
 
258
264
  with _logfire.span('handle model response', run_step=run_step) as handle_span:
259
- final_result, tool_responses = await self._handle_model_response(model_response, deps, messages)
265
+ final_result, tool_responses = await self._handle_model_response(model_response, run_context)
260
266
 
261
267
  if tool_responses:
262
268
  # Add parts to the conversation as a new message
@@ -266,10 +272,10 @@ class Agent(Generic[AgentDeps, ResultData]):
266
272
  if final_result is not None:
267
273
  result_data = final_result.data
268
274
  run_span.set_attribute('all_messages', messages)
269
- run_span.set_attribute('cost', cost)
275
+ run_span.set_attribute('usage', usage)
270
276
  handle_span.set_attribute('result', result_data)
271
277
  handle_span.message = 'handle model response -> final result'
272
- return result.RunResult(messages, new_message_index, result_data, cost)
278
+ return result.RunResult(messages, new_message_index, result_data, usage)
273
279
  else:
274
280
  # continue the conversation
275
281
  handle_span.set_attribute('tool_responses', tool_responses)
@@ -284,6 +290,7 @@ class Agent(Generic[AgentDeps, ResultData]):
284
290
  model: models.Model | models.KnownModelName | None = None,
285
291
  deps: AgentDeps = None,
286
292
  model_settings: ModelSettings | None = None,
293
+ usage_limits: UsageLimits | None = None,
287
294
  infer_name: bool = True,
288
295
  ) -> result.RunResult[ResultData]:
289
296
  """Run the agent with a user prompt synchronously.
@@ -308,8 +315,9 @@ class Agent(Generic[AgentDeps, ResultData]):
308
315
  message_history: History of the conversation so far.
309
316
  model: Optional model to use for this run, required if `model` was not set when creating the agent.
310
317
  deps: Optional dependencies to use for this run.
311
- infer_name: Whether to try to infer the agent name from the call frame if it's not set.
312
318
  model_settings: Optional settings to use for this model's request.
319
+ usage_limits: Optional limits on model request count or token usage.
320
+ infer_name: Whether to try to infer the agent name from the call frame if it's not set.
313
321
 
314
322
  Returns:
315
323
  The result of the run.
@@ -322,8 +330,9 @@ class Agent(Generic[AgentDeps, ResultData]):
322
330
  message_history=message_history,
323
331
  model=model,
324
332
  deps=deps,
325
- infer_name=False,
326
333
  model_settings=model_settings,
334
+ usage_limits=usage_limits,
335
+ infer_name=False,
327
336
  )
328
337
  )
329
338
 
@@ -336,6 +345,7 @@ class Agent(Generic[AgentDeps, ResultData]):
336
345
  model: models.Model | models.KnownModelName | None = None,
337
346
  deps: AgentDeps = None,
338
347
  model_settings: ModelSettings | None = None,
348
+ usage_limits: UsageLimits | None = None,
339
349
  infer_name: bool = True,
340
350
  ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
341
351
  """Run the agent with a user prompt in async mode, returning a streamed response.
@@ -357,8 +367,9 @@ class Agent(Generic[AgentDeps, ResultData]):
357
367
  message_history: History of the conversation so far.
358
368
  model: Optional model to use for this run, required if `model` was not set when creating the agent.
359
369
  deps: Optional dependencies to use for this run.
360
- infer_name: Whether to try to infer the agent name from the call frame if it's not set.
361
370
  model_settings: Optional settings to use for this model's request.
371
+ usage_limits: Optional limits on model request count or token usage.
372
+ infer_name: Whether to try to infer the agent name from the call frame if it's not set.
362
373
 
363
374
  Returns:
364
375
  The result of the run.
@@ -380,32 +391,35 @@ class Agent(Generic[AgentDeps, ResultData]):
380
391
  model_name=model_used.name(),
381
392
  agent_name=self.name or 'agent',
382
393
  ) as run_span:
383
- self.last_run_messages = messages = await self._prepare_messages(deps, user_prompt, message_history)
394
+ run_context = RunContext(deps, 0, [], None, model_used)
395
+ messages = await self._prepare_messages(user_prompt, message_history, run_context)
396
+ self.last_run_messages = run_context.messages = messages
384
397
 
385
398
  for tool in self._function_tools.values():
386
399
  tool.current_retry = 0
387
400
 
388
- cost = result.Cost()
401
+ usage = result.Usage()
389
402
  model_settings = merge_model_settings(self.model_settings, model_settings)
403
+ usage_limits = usage_limits or UsageLimits()
390
404
 
391
405
  run_step = 0
392
406
  while True:
393
407
  run_step += 1
408
+ usage_limits.check_before_request(usage)
394
409
 
395
410
  with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
396
- agent_model = await self._prepare_model(model_used, deps, messages)
411
+ agent_model = await self._prepare_model(run_context)
397
412
 
398
413
  with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
399
414
  async with agent_model.request_stream(messages, model_settings) as model_response:
415
+ usage.requests += 1
400
416
  model_req_span.set_attribute('response_type', model_response.__class__.__name__)
401
417
  # We want to end the "model request" span here, but we can't exit the context manager
402
418
  # in the traditional way
403
419
  model_req_span.__exit__(None, None, None)
404
420
 
405
421
  with _logfire.span('handle model response') as handle_span:
406
- maybe_final_result = await self._handle_streamed_model_response(
407
- model_response, deps, messages
408
- )
422
+ maybe_final_result = await self._handle_streamed_model_response(model_response, run_context)
409
423
 
410
424
  # Check if we got a final result
411
425
  if isinstance(maybe_final_result, _MarkFinalResult):
@@ -425,7 +439,7 @@ class Agent(Generic[AgentDeps, ResultData]):
425
439
  part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
426
440
  ]
427
441
  parts = await self._process_function_tools(
428
- tool_calls, result_tool_name, deps, messages
442
+ tool_calls, result_tool_name, run_context
429
443
  )
430
444
  if parts:
431
445
  messages.append(_messages.ModelRequest(parts))
@@ -434,10 +448,11 @@ class Agent(Generic[AgentDeps, ResultData]):
434
448
  yield result.StreamedRunResult(
435
449
  messages,
436
450
  new_message_index,
437
- cost,
451
+ usage,
452
+ usage_limits,
438
453
  result_stream,
439
454
  self._result_schema,
440
- deps,
455
+ run_context,
441
456
  self._result_validators,
442
457
  result_tool_name,
443
458
  on_complete,
@@ -455,8 +470,10 @@ class Agent(Generic[AgentDeps, ResultData]):
455
470
  handle_span.set_attribute('tool_responses', tool_responses)
456
471
  tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
457
472
  handle_span.message = f'handle model response -> {tool_responses_str}'
458
- # the model_response should have been fully streamed by now, we can add it's cost
459
- cost += model_response.cost()
473
+ # the model_response should have been fully streamed by now, we can add its usage
474
+ model_response_usage = model_response.usage()
475
+ usage += model_response_usage
476
+ usage_limits.check_tokens(usage)
460
477
 
461
478
  @contextmanager
462
479
  def override(
@@ -798,41 +815,39 @@ class Agent(Generic[AgentDeps, ResultData]):
798
815
 
799
816
  return model_, mode_selection
800
817
 
801
- async def _prepare_model(
802
- self, model: models.Model, deps: AgentDeps, messages: list[_messages.ModelMessage]
803
- ) -> models.AgentModel:
804
- """Create building tools and create an agent model."""
818
+ async def _prepare_model(self, run_context: RunContext[AgentDeps]) -> models.AgentModel:
819
+ """Build tools and create an agent model."""
805
820
  function_tools: list[ToolDefinition] = []
806
821
 
807
822
  async def add_tool(tool: Tool[AgentDeps]) -> None:
808
- ctx = RunContext(deps, tool.current_retry, messages, tool.name)
823
+ ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name)
809
824
  if tool_def := await tool.prepare_tool_def(ctx):
810
825
  function_tools.append(tool_def)
811
826
 
812
827
  await asyncio.gather(*map(add_tool, self._function_tools.values()))
813
828
 
814
- return await model.agent_model(
829
+ return await run_context.model.agent_model(
815
830
  function_tools=function_tools,
816
831
  allow_text_result=self._allow_text_result,
817
832
  result_tools=self._result_schema.tool_defs() if self._result_schema is not None else [],
818
833
  )
819
834
 
820
835
  async def _prepare_messages(
821
- self, deps: AgentDeps, user_prompt: str, message_history: list[_messages.ModelMessage] | None
836
+ self, user_prompt: str, message_history: list[_messages.ModelMessage] | None, run_context: RunContext[AgentDeps]
822
837
  ) -> list[_messages.ModelMessage]:
823
838
  if message_history:
824
839
  # shallow copy messages
825
840
  messages = message_history.copy()
826
841
  messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)]))
827
842
  else:
828
- parts = await self._sys_parts(deps)
843
+ parts = await self._sys_parts(run_context)
829
844
  parts.append(_messages.UserPromptPart(user_prompt))
830
845
  messages: list[_messages.ModelMessage] = [_messages.ModelRequest(parts)]
831
846
 
832
847
  return messages
833
848
 
834
849
  async def _handle_model_response(
835
- self, model_response: _messages.ModelResponse, deps: AgentDeps, conv_messages: list[_messages.ModelMessage]
850
+ self, model_response: _messages.ModelResponse, run_context: RunContext[AgentDeps]
836
851
  ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
837
852
  """Process a non-streamed response from the model.
838
853
 
@@ -841,42 +856,44 @@ class Agent(Generic[AgentDeps, ResultData]):
841
856
  """
842
857
  texts: list[str] = []
843
858
  tool_calls: list[_messages.ToolCallPart] = []
844
- for item in model_response.parts:
845
- if isinstance(item, _messages.TextPart):
846
- texts.append(item.content)
859
+ for part in model_response.parts:
860
+ if isinstance(part, _messages.TextPart):
861
+ # ignore empty content for text parts, see #437
862
+ if part.content:
863
+ texts.append(part.content)
847
864
  else:
848
- tool_calls.append(item)
865
+ tool_calls.append(part)
849
866
 
850
867
  if texts:
851
868
  text = '\n\n'.join(texts)
852
- return await self._handle_text_response(text, deps, conv_messages)
869
+ return await self._handle_text_response(text, run_context)
853
870
  elif tool_calls:
854
- return await self._handle_structured_response(tool_calls, deps, conv_messages)
871
+ return await self._handle_structured_response(tool_calls, run_context)
855
872
  else:
856
873
  raise exceptions.UnexpectedModelBehavior('Received empty model response')
857
874
 
858
875
  async def _handle_text_response(
859
- self, text: str, deps: AgentDeps, conv_messages: list[_messages.ModelMessage]
876
+ self, text: str, run_context: RunContext[AgentDeps]
860
877
  ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
861
878
  """Handle a plain text response from the model for non-streaming responses."""
862
879
  if self._allow_text_result:
863
880
  result_data_input = cast(ResultData, text)
864
881
  try:
865
- result_data = await self._validate_result(result_data_input, deps, None, conv_messages)
882
+ result_data = await self._validate_result(result_data_input, run_context, None)
866
883
  except _result.ToolRetryError as e:
867
- self._incr_result_retry()
884
+ self._incr_result_retry(run_context)
868
885
  return None, [e.tool_retry]
869
886
  else:
870
887
  return _MarkFinalResult(result_data, None), []
871
888
  else:
872
- self._incr_result_retry()
889
+ self._incr_result_retry(run_context)
873
890
  response = _messages.RetryPromptPart(
874
891
  content='Plain text responses are not permitted, please call one of the functions instead.',
875
892
  )
876
893
  return None, [response]
877
894
 
878
895
  async def _handle_structured_response(
879
- self, tool_calls: list[_messages.ToolCallPart], deps: AgentDeps, conv_messages: list[_messages.ModelMessage]
896
+ self, tool_calls: list[_messages.ToolCallPart], run_context: RunContext[AgentDeps]
880
897
  ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
881
898
  """Handle a structured response containing tool calls from the model for non-streaming responses."""
882
899
  assert tool_calls, 'Expected at least one tool call'
@@ -890,17 +907,15 @@ class Agent(Generic[AgentDeps, ResultData]):
890
907
  call, result_tool = match
891
908
  try:
892
909
  result_data = result_tool.validate(call)
893
- result_data = await self._validate_result(result_data, deps, call, conv_messages)
910
+ result_data = await self._validate_result(result_data, run_context, call)
894
911
  except _result.ToolRetryError as e:
895
- self._incr_result_retry()
912
+ self._incr_result_retry(run_context)
896
913
  parts.append(e.tool_retry)
897
914
  else:
898
915
  final_result = _MarkFinalResult(result_data, call.tool_name)
899
916
 
900
917
  # Then build the other request parts based on end strategy
901
- parts += await self._process_function_tools(
902
- tool_calls, final_result and final_result.tool_name, deps, conv_messages
903
- )
918
+ parts += await self._process_function_tools(tool_calls, final_result and final_result.tool_name, run_context)
904
919
 
905
920
  return final_result, parts
906
921
 
@@ -908,8 +923,7 @@ class Agent(Generic[AgentDeps, ResultData]):
908
923
  self,
909
924
  tool_calls: list[_messages.ToolCallPart],
910
925
  result_tool_name: str | None,
911
- deps: AgentDeps,
912
- conv_messages: list[_messages.ModelMessage],
926
+ run_context: RunContext[AgentDeps],
913
927
  ) -> list[_messages.ModelRequestPart]:
914
928
  """Process function (non-result) tool calls in parallel.
915
929
 
@@ -942,7 +956,7 @@ class Agent(Generic[AgentDeps, ResultData]):
942
956
  )
943
957
  )
944
958
  else:
945
- tasks.append(asyncio.create_task(tool.run(deps, call, conv_messages), name=call.tool_name))
959
+ tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
946
960
  elif self._result_schema is not None and call.tool_name in self._result_schema.tools:
947
961
  # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
948
962
  # validation, we don't add another part here
@@ -955,7 +969,7 @@ class Agent(Generic[AgentDeps, ResultData]):
955
969
  )
956
970
  )
957
971
  else:
958
- parts.append(self._unknown_tool(call.tool_name))
972
+ parts.append(self._unknown_tool(call.tool_name, run_context))
959
973
 
960
974
  # Run all tool tasks in parallel
961
975
  if tasks:
@@ -967,8 +981,7 @@ class Agent(Generic[AgentDeps, ResultData]):
967
981
  async def _handle_streamed_model_response(
968
982
  self,
969
983
  model_response: models.EitherStreamedResponse,
970
- deps: AgentDeps,
971
- conv_messages: list[_messages.ModelMessage],
984
+ run_context: RunContext[AgentDeps],
972
985
  ) -> (
973
986
  _MarkFinalResult[models.EitherStreamedResponse]
974
987
  | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]
@@ -984,11 +997,11 @@ class Agent(Generic[AgentDeps, ResultData]):
984
997
  if self._allow_text_result:
985
998
  return _MarkFinalResult(model_response, None)
986
999
  else:
987
- self._incr_result_retry()
1000
+ self._incr_result_retry(run_context)
988
1001
  response = _messages.RetryPromptPart(
989
1002
  content='Plain text responses are not permitted, please call one of the functions instead.',
990
1003
  )
991
- # stream the response, so cost is correct
1004
+ # stream the response, so usage is correct
992
1005
  async for _ in model_response:
993
1006
  pass
994
1007
 
@@ -1024,9 +1037,9 @@ class Agent(Generic[AgentDeps, ResultData]):
1024
1037
  if isinstance(item, _messages.ToolCallPart):
1025
1038
  call = item
1026
1039
  if tool := self._function_tools.get(call.tool_name):
1027
- tasks.append(asyncio.create_task(tool.run(deps, call, conv_messages), name=call.tool_name))
1040
+ tasks.append(asyncio.create_task(tool.run(call, run_context), name=call.tool_name))
1028
1041
  else:
1029
- parts.append(self._unknown_tool(call.tool_name))
1042
+ parts.append(self._unknown_tool(call.tool_name, run_context))
1030
1043
 
1031
1044
  with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
1032
1045
  task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
@@ -1038,33 +1051,30 @@ class Agent(Generic[AgentDeps, ResultData]):
1038
1051
  async def _validate_result(
1039
1052
  self,
1040
1053
  result_data: ResultData,
1041
- deps: AgentDeps,
1054
+ run_context: RunContext[AgentDeps],
1042
1055
  tool_call: _messages.ToolCallPart | None,
1043
- conv_messages: list[_messages.ModelMessage],
1044
1056
  ) -> ResultData:
1045
1057
  for validator in self._result_validators:
1046
- result_data = await validator.validate(
1047
- result_data, deps, self._current_result_retry, tool_call, conv_messages
1048
- )
1058
+ result_data = await validator.validate(result_data, tool_call, run_context)
1049
1059
  return result_data
1050
1060
 
1051
- def _incr_result_retry(self) -> None:
1052
- self._current_result_retry += 1
1053
- if self._current_result_retry > self._max_result_retries:
1061
+ def _incr_result_retry(self, run_context: RunContext[AgentDeps]) -> None:
1062
+ run_context.retry += 1
1063
+ if run_context.retry > self._max_result_retries:
1054
1064
  raise exceptions.UnexpectedModelBehavior(
1055
1065
  f'Exceeded maximum retries ({self._max_result_retries}) for result validation'
1056
1066
  )
1057
1067
 
1058
- async def _sys_parts(self, deps: AgentDeps) -> list[_messages.ModelRequestPart]:
1068
+ async def _sys_parts(self, run_context: RunContext[AgentDeps]) -> list[_messages.ModelRequestPart]:
1059
1069
  """Build the initial messages for the conversation."""
1060
1070
  messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self._system_prompts]
1061
1071
  for sys_prompt_runner in self._system_prompt_functions:
1062
- prompt = await sys_prompt_runner.run(deps)
1072
+ prompt = await sys_prompt_runner.run(run_context)
1063
1073
  messages.append(_messages.SystemPromptPart(prompt))
1064
1074
  return messages
1065
1075
 
1066
- def _unknown_tool(self, tool_name: str) -> _messages.RetryPromptPart:
1067
- self._incr_result_retry()
1076
+ def _unknown_tool(self, tool_name: str, run_context: RunContext[AgentDeps]) -> _messages.RetryPromptPart:
1077
+ self._incr_result_retry(run_context)
1068
1078
  names = list(self._function_tools.keys())
1069
1079
  if self._result_schema:
1070
1080
  names.extend(self._result_schema.tool_names())
pydantic_ai/exceptions.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations as _annotations
2
2
 
3
3
  import json
4
4
 
5
- __all__ = 'ModelRetry', 'UserError', 'UnexpectedModelBehavior'
5
+ __all__ = 'ModelRetry', 'UserError', 'AgentRunError', 'UnexpectedModelBehavior', 'UsageLimitExceeded'
6
6
 
7
7
 
8
8
  class ModelRetry(Exception):
@@ -30,7 +30,25 @@ class UserError(RuntimeError):
30
30
  super().__init__(message)
31
31
 
32
32
 
33
- class UnexpectedModelBehavior(RuntimeError):
33
+ class AgentRunError(RuntimeError):
34
+ """Base class for errors occurring during an agent run."""
35
+
36
+ message: str
37
+ """The error message."""
38
+
39
+ def __init__(self, message: str):
40
+ self.message = message
41
+ super().__init__(message)
42
+
43
+ def __str__(self) -> str:
44
+ return self.message
45
+
46
+
47
+ class UsageLimitExceeded(AgentRunError):
48
+ """Error raised when a Model's usage exceeds the specified limits."""
49
+
50
+
51
+ class UnexpectedModelBehavior(AgentRunError):
34
52
  """Error caused by unexpected Model behavior, e.g. an unexpected response code."""
35
53
 
36
54
  message: str
pydantic_ai/messages.py CHANGED
@@ -2,11 +2,11 @@ from __future__ import annotations as _annotations
2
2
 
3
3
  from dataclasses import dataclass, field
4
4
  from datetime import datetime
5
- from typing import Annotated, Any, Literal, Union
5
+ from typing import Annotated, Any, Literal, Union, cast
6
6
 
7
7
  import pydantic
8
8
  import pydantic_core
9
- from typing_extensions import Self
9
+ from typing_extensions import Self, assert_never
10
10
 
11
11
  from ._utils import now_utc as _now_utc
12
12
 
@@ -190,12 +190,34 @@ class ToolCallPart:
190
190
  """Part type identifier, this is available on all parts as a discriminator."""
191
191
 
192
192
  @classmethod
193
- def from_json(cls, tool_name: str, args_json: str, tool_call_id: str | None = None) -> Self:
194
- return cls(tool_name, ArgsJson(args_json), tool_call_id)
193
+ def from_raw_args(cls, tool_name: str, args: str | dict[str, Any], tool_call_id: str | None = None) -> Self:
194
+ """Create a `ToolCallPart` from raw arguments."""
195
+ if isinstance(args, str):
196
+ return cls(tool_name, ArgsJson(args), tool_call_id)
197
+ elif isinstance(args, dict):
198
+ return cls(tool_name, ArgsDict(args), tool_call_id)
199
+ else:
200
+ assert_never(args)
195
201
 
196
- @classmethod
197
- def from_dict(cls, tool_name: str, args_dict: dict[str, Any], tool_call_id: str | None = None) -> Self:
198
- return cls(tool_name, ArgsDict(args_dict), tool_call_id)
202
+ def args_as_dict(self) -> dict[str, Any]:
203
+ """Return the arguments as a Python dictionary.
204
+
205
+ This is just for convenience with models that require dicts as input.
206
+ """
207
+ if isinstance(self.args, ArgsDict):
208
+ return self.args.args_dict
209
+ args = pydantic_core.from_json(self.args.args_json)
210
+ assert isinstance(args, dict), 'args should be a dict'
211
+ return cast(dict[str, Any], args)
212
+
213
+ def args_as_json_str(self) -> str:
214
+ """Return the arguments as a JSON string.
215
+
216
+ This is just for convenience with models that require JSON strings as input.
217
+ """
218
+ if isinstance(self.args, ArgsJson):
219
+ return self.args.args_json
220
+ return pydantic_core.to_json(self.args.args_dict).decode()
199
221
 
200
222
  def has_content(self) -> bool:
201
223
  if isinstance(self.args, ArgsDict):