pydantic-ai-slim 0.2.18__tar.gz → 0.2.20__tar.gz

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.

This version of pydantic-ai-slim was flagged as a potentially problematic release.

Files changed (76)
  1. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/PKG-INFO +4 -4
  2. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/_agent_graph.py +68 -14
  3. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/_function_schema.py +14 -6
  4. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/_output.py +1 -1
  5. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/_system_prompt.py +1 -1
  6. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/_utils.py +28 -3
  7. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/agent.py +13 -3
  8. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/mcp.py +66 -5
  9. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/messages.py +4 -5
  10. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/anthropic.py +1 -1
  11. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/gemini.py +1 -3
  12. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/google.py +18 -6
  13. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/__init__.py +23 -17
  14. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/tools.py +1 -2
  15. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/.gitignore +0 -0
  16. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/LICENSE +0 -0
  17. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/README.md +0 -0
  18. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/__init__.py +0 -0
  19. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/__main__.py +0 -0
  20. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/_a2a.py +0 -0
  21. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/_cli.py +0 -0
  22. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/_griffe.py +0 -0
  23. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/_parts_manager.py +0 -0
  24. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/common_tools/__init__.py +0 -0
  25. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/common_tools/duckduckgo.py +0 -0
  26. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/common_tools/tavily.py +0 -0
  27. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/direct.py +0 -0
  28. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/exceptions.py +0 -0
  29. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/ext/__init__.py +0 -0
  30. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/ext/langchain.py +0 -0
  31. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/format_as_xml.py +0 -0
  32. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/format_prompt.py +0 -0
  33. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/__init__.py +0 -0
  34. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/bedrock.py +0 -0
  35. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/cohere.py +0 -0
  36. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/fallback.py +0 -0
  37. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/function.py +0 -0
  38. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/groq.py +0 -0
  39. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/instrumented.py +0 -0
  40. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/mistral.py +0 -0
  41. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/openai.py +0 -0
  42. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/test.py +0 -0
  43. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/models/wrapper.py +0 -0
  44. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/profiles/__init__.py +0 -0
  45. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/profiles/_json_schema.py +0 -0
  46. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/profiles/amazon.py +0 -0
  47. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/profiles/anthropic.py +0 -0
  48. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/profiles/cohere.py +0 -0
  49. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/profiles/deepseek.py +0 -0
  50. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/profiles/google.py +0 -0
  51. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/profiles/grok.py +0 -0
  52. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/profiles/meta.py +0 -0
  53. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/profiles/mistral.py +0 -0
  54. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/profiles/openai.py +0 -0
  55. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/profiles/qwen.py +0 -0
  56. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/anthropic.py +0 -0
  57. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/azure.py +0 -0
  58. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/bedrock.py +0 -0
  59. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/cohere.py +0 -0
  60. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/deepseek.py +0 -0
  61. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/fireworks.py +0 -0
  62. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/google.py +0 -0
  63. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/google_gla.py +0 -0
  64. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/google_vertex.py +0 -0
  65. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/grok.py +0 -0
  66. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/groq.py +0 -0
  67. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/heroku.py +0 -0
  68. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/mistral.py +0 -0
  69. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/openai.py +0 -0
  70. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/openrouter.py +0 -0
  71. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/providers/together.py +0 -0
  72. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/py.typed +0 -0
  73. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/result.py +0 -0
  74. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/settings.py +0 -0
  75. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pydantic_ai/usage.py +0 -0
  76. {pydantic_ai_slim-0.2.18 → pydantic_ai_slim-0.2.20}/pyproject.toml +0 -0
--- pydantic_ai_slim-0.2.18/PKG-INFO
+++ pydantic_ai_slim-0.2.20/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.2.18
+Version: 0.2.20
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>, Marcelo Trylesinski <marcelotryle@gmail.com>, David Montague <david@pydantic.dev>, Alex Hall <alex@pydantic.dev>
 License-Expression: MIT
@@ -30,11 +30,11 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: opentelemetry-api>=1.28.0
-Requires-Dist: pydantic-graph==0.2.18
+Requires-Dist: pydantic-graph==0.2.20
 Requires-Dist: pydantic>=2.10
 Requires-Dist: typing-inspection>=0.4.0
 Provides-Extra: a2a
-Requires-Dist: fasta2a==0.2.18; extra == 'a2a'
+Requires-Dist: fasta2a==0.2.20; extra == 'a2a'
 Provides-Extra: anthropic
 Requires-Dist: anthropic>=0.52.0; extra == 'anthropic'
 Provides-Extra: bedrock
@@ -48,7 +48,7 @@ Requires-Dist: cohere>=5.13.11; (platform_system != 'Emscripten') and extra == 'cohere'
 Provides-Extra: duckduckgo
 Requires-Dist: duckduckgo-search>=7.0.0; extra == 'duckduckgo'
 Provides-Extra: evals
-Requires-Dist: pydantic-evals==0.2.18; extra == 'evals'
+Requires-Dist: pydantic-evals==0.2.20; extra == 'evals'
 Provides-Extra: google
 Requires-Dist: google-genai>=1.15.0; extra == 'google'
 Provides-Extra: groq
--- pydantic_ai_slim-0.2.18/pydantic_ai/_agent_graph.py
+++ pydantic_ai_slim-0.2.20/pydantic_ai/_agent_graph.py
@@ -12,18 +12,12 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
 from opentelemetry.trace import Tracer
 from typing_extensions import TypeGuard, TypeVar, assert_never

+from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx  # type: ignore
+from pydantic_ai._utils import is_async_callable, run_in_executor
 from pydantic_graph import BaseNode, Graph, GraphRunContext
 from pydantic_graph.nodes import End, NodeRunEndT

-from . import (
-    _output,
-    _system_prompt,
-    exceptions,
-    messages as _messages,
-    models,
-    result,
-    usage as _usage,
-)
+from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage
 from .result import OutputDataT
 from .settings import ModelSettings, merge_model_settings
 from .tools import RunContext, Tool, ToolDefinition, ToolsPrepareFunc
@@ -39,6 +33,7 @@ __all__ = (
     'CallToolsNode',
     'build_run_context',
     'capture_run_messages',
+    'HistoryProcessor',
 )


@@ -54,6 +49,23 @@ EndStrategy = Literal['early', 'exhaustive']
 DepsT = TypeVar('DepsT')
 OutputT = TypeVar('OutputT')

+_HistoryProcessorSync = Callable[[list[_messages.ModelMessage]], list[_messages.ModelMessage]]
+_HistoryProcessorAsync = Callable[[list[_messages.ModelMessage]], Awaitable[list[_messages.ModelMessage]]]
+_HistoryProcessorSyncWithCtx = Callable[[RunContext[DepsT], list[_messages.ModelMessage]], list[_messages.ModelMessage]]
+_HistoryProcessorAsyncWithCtx = Callable[
+    [RunContext[DepsT], list[_messages.ModelMessage]], Awaitable[list[_messages.ModelMessage]]
+]
+HistoryProcessor = Union[
+    _HistoryProcessorSync,
+    _HistoryProcessorAsync,
+    _HistoryProcessorSyncWithCtx[DepsT],
+    _HistoryProcessorAsyncWithCtx[DepsT],
+]
+"""A function that processes a list of model messages and returns a list of model messages.
+
+Can optionally accept a `RunContext` as a parameter.
+"""
+

 @dataclasses.dataclass
 class GraphAgentState:
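
A history processor is just a callable over the message list; the simplest shape takes and returns `list[ModelMessage]`. A minimal sketch of wiring one up (the `keep_recent` function is illustrative, not part of the package):

    from pydantic_ai import Agent
    from pydantic_ai.messages import ModelMessage

    def keep_recent(messages: list[ModelMessage]) -> list[ModelMessage]:
        # Sync, context-free variant: trim the history to the last five messages.
        return messages[-5:]

    agent = Agent('openai:gpt-4o', history_processors=[keep_recent])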
@@ -93,6 +105,8 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
     output_schema: _output.OutputSchema[OutputDataT] | None
     output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]

+    history_processors: Sequence[HistoryProcessor[DepsT]]
+
     function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
     mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
     default_retries: int
@@ -327,8 +341,11 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):

         model_settings, model_request_parameters = await self._prepare_request(ctx)
         model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
+        message_history = await _process_message_history(
+            ctx.state.message_history, ctx.deps.history_processors, build_run_context(ctx)
+        )
         async with ctx.deps.model.request_stream(
-            ctx.state.message_history, model_settings, model_request_parameters
+            message_history, model_settings, model_request_parameters
         ) as streamed_response:
             self._did_stream = True
             ctx.state.usage.requests += 1
@@ -350,9 +367,10 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):

         model_settings, model_request_parameters = await self._prepare_request(ctx)
         model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
-        model_response = await ctx.deps.model.request(
-            ctx.state.message_history, model_settings, model_request_parameters
+        message_history = await _process_message_history(
+            ctx.state.message_history, ctx.deps.history_processors, build_run_context(ctx)
         )
+        model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
         ctx.state.usage.incr(_usage.Usage())

         return self._finish_handling(ctx, model_response)
@@ -647,6 +665,7 @@ async def process_function_tools(  # noqa C901
         # if tool_name is in output_schema, it means we found a output tool but an error occurred in
         # validation, we don't add another part here
         if output_tool_name is not None:
+            yield _messages.FunctionToolCallEvent(call)
             if found_used_output_tool:
                 content = 'Output tool not used - a final result was already processed.'
             else:
@@ -657,9 +676,14 @@ async def process_function_tools(  # noqa C901
                 content=content,
                 tool_call_id=call.tool_call_id,
             )
+            yield _messages.FunctionToolResultEvent(part, tool_call_id=call.tool_call_id)
             output_parts.append(part)
         else:
-            output_parts.append(_unknown_tool(call.tool_name, call.tool_call_id, ctx))
+            yield _messages.FunctionToolCallEvent(call)
+
+            part = _unknown_tool(call.tool_name, call.tool_call_id, ctx)
+            yield _messages.FunctionToolResultEvent(part, tool_call_id=call.tool_call_id)
+            output_parts.append(part)

     if not calls_to_run:
         return
@@ -755,7 +779,12 @@ async def _tool_from_mcp_server(
         # some weird edge case occurs.
         if not server.is_running:  # pragma: no cover
             raise exceptions.UserError(f'MCP server is not running: {server}')
-        result = await server.call_tool(tool_name, args)
+
+        if server.process_tool_call is not None:
+            result = await server.process_tool_call(ctx, server.call_tool, tool_name, args)
+        else:
+            result = await server.call_tool(tool_name, args)
+
         return result

     for server in ctx.deps.mcp_servers:
@@ -865,3 +894,28 @@ def build_agent_graph(
         auto_instrument=False,
     )
     return graph
+
+
+async def _process_message_history(
+    messages: list[_messages.ModelMessage],
+    processors: Sequence[HistoryProcessor[DepsT]],
+    run_context: RunContext[DepsT],
+) -> list[_messages.ModelMessage]:
+    """Process message history through a sequence of processors."""
+    for processor in processors:
+        takes_ctx = is_takes_ctx(processor)
+
+        if is_async_callable(processor):
+            if takes_ctx:
+                messages = await processor(run_context, messages)
+            else:
+                async_processor = cast(_HistoryProcessorAsync, processor)
+                messages = await async_processor(messages)
+        else:
+            if takes_ctx:
+                sync_processor_with_ctx = cast(_HistoryProcessorSyncWithCtx[DepsT], processor)
+                messages = await run_in_executor(sync_processor_with_ctx, run_context, messages)
+            else:
+                sync_processor = cast(_HistoryProcessorSync, processor)
+                messages = await run_in_executor(sync_processor, messages)
+    return messages
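
The dispatcher above handles all four shapes in the `HistoryProcessor` union: async processors are awaited, sync ones are pushed through `run_in_executor` so they don't block the event loop, and `_takes_ctx` decides whether the run context is passed. A sketch of the async, context-taking variant (name and body illustrative):

    from pydantic_ai import RunContext
    from pydantic_ai.messages import ModelMessage

    async def drop_secret(ctx: RunContext[str], messages: list[ModelMessage]) -> list[ModelMessage]:
        # Async + ctx-taking: awaited directly, with the RunContext
        # (str deps here) passed as the first argument.
        return [m for m in messages if ctx.deps not in repr(m)]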
--- pydantic_ai_slim-0.2.18/pydantic_ai/_function_schema.py
+++ pydantic_ai_slim-0.2.20/pydantic_ai/_function_schema.py
@@ -5,11 +5,10 @@ This module has to use numerous internal Pydantic APIs and is therefore brittle

 from __future__ import annotations as _annotations

-import inspect
 from collections.abc import Awaitable
 from dataclasses import dataclass, field
 from inspect import Parameter, signature
-from typing import TYPE_CHECKING, Any, Callable, cast
+from typing import TYPE_CHECKING, Any, Callable, Union, cast

 from pydantic import ConfigDict
 from pydantic._internal import _decorators, _generate_schema, _typing_extra
@@ -18,12 +17,12 @@ from pydantic.fields import FieldInfo
 from pydantic.json_schema import GenerateJsonSchema
 from pydantic.plugin._schema_validator import create_schema_validator
 from pydantic_core import SchemaValidator, core_schema
-from typing_extensions import get_origin
+from typing_extensions import Concatenate, ParamSpec, TypeIs, TypeVar, get_origin

 from pydantic_ai.tools import RunContext

 from ._griffe import doc_descriptions
-from ._utils import check_object_json_schema, is_model_like, run_in_executor
+from ._utils import check_object_json_schema, is_async_callable, is_model_like, run_in_executor

 if TYPE_CHECKING:
     from .tools import DocstringFormat, ObjectJsonSchema
@@ -214,12 +213,21 @@ def function_schema(  # noqa: C901
         positional_fields=positional_fields,
         var_positional_field=var_positional_field,
         takes_ctx=takes_ctx,
-        is_async=inspect.iscoroutinefunction(function),
+        is_async=is_async_callable(function),
         function=function,
     )


-def _takes_ctx(function: Callable[..., Any]) -> bool:
+P = ParamSpec('P')
+R = TypeVar('R')
+
+
+WithCtx = Callable[Concatenate[RunContext[Any], P], R]
+WithoutCtx = Callable[P, R]
+TargetFunc = Union[WithCtx[P, R], WithoutCtx[P, R]]
+
+
+def _takes_ctx(function: TargetFunc[P, R]) -> TypeIs[WithCtx[P, R]]:
     """Check if a function takes a `RunContext` first argument.

     Args:
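
The change from returning `bool` to `TypeIs[WithCtx[P, R]]` lets type checkers narrow the callable at call sites. A sketch of the payoff, reusing the aliases defined in the hunk above (the `dispatch` helper is illustrative, not part of the package):

    def dispatch(function: TargetFunc[P, R], ctx: RunContext[Any], *args: P.args, **kwargs: P.kwargs) -> R:
        if _takes_ctx(function):
            # Narrowed to WithCtx[P, R]: passing ctx as the first argument now type-checks.
            return function(ctx, *args, **kwargs)
        return function(*args, **kwargs)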
--- pydantic_ai_slim-0.2.18/pydantic_ai/_output.py
+++ pydantic_ai_slim-0.2.20/pydantic_ai/_output.py
@@ -60,7 +60,7 @@ class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):

     def __post_init__(self):
         self._takes_ctx = len(inspect.signature(self.function).parameters) > 1
-        self._is_async = inspect.iscoroutinefunction(self.function)
+        self._is_async = _utils.is_async_callable(self.function)

     async def validate(
         self,
--- pydantic_ai_slim-0.2.18/pydantic_ai/_system_prompt.py
+++ pydantic_ai_slim-0.2.20/pydantic_ai/_system_prompt.py
@@ -18,7 +18,7 @@ class SystemPromptRunner(Generic[AgentDepsT]):

     def __post_init__(self):
         self._takes_ctx = len(inspect.signature(self.function).parameters) > 0
-        self._is_async = inspect.iscoroutinefunction(self.function)
+        self._is_async = _utils.is_async_callable(self.function)

     async def run(self, run_context: RunContext[AgentDepsT]) -> str:
         if self._takes_ctx:
--- pydantic_ai_slim-0.2.18/pydantic_ai/_utils.py
+++ pydantic_ai_slim-0.2.20/pydantic_ai/_utils.py
@@ -1,20 +1,22 @@
 from __future__ import annotations as _annotations

 import asyncio
+import functools
+import inspect
 import time
 import uuid
-from collections.abc import AsyncIterable, AsyncIterator, Iterator
+from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterator
 from contextlib import asynccontextmanager, suppress
 from dataclasses import dataclass, fields, is_dataclass
 from datetime import datetime, timezone
 from functools import partial
 from types import GenericAlias
-from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, overload

 from anyio.to_thread import run_sync
 from pydantic import BaseModel, TypeAdapter
 from pydantic.json_schema import JsonSchemaValue
-from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict
+from typing_extensions import ParamSpec, TypeAlias, TypeGuard, TypeIs, is_typeddict

 from pydantic_graph._utils import AbstractSpan

@@ -302,3 +304,26 @@ def dataclasses_no_defaults_repr(self: Any) -> str:

 def number_to_datetime(x: int | float) -> datetime:
     return TypeAdapter(datetime).validate_python(x)
+
+
+AwaitableCallable = Callable[..., Awaitable[T]]
+
+
+@overload
+def is_async_callable(obj: AwaitableCallable[T]) -> TypeIs[AwaitableCallable[T]]: ...
+
+
+@overload
+def is_async_callable(obj: Any) -> TypeIs[AwaitableCallable[Any]]: ...
+
+
+def is_async_callable(obj: Any) -> Any:
+    """Correctly check if a callable is async.
+
+    This function was copied from Starlette:
+    https://github.com/encode/starlette/blob/78da9b9e218ab289117df7d62aee200ed4c59617/starlette/_utils.py#L36-L40
+    """
+    while isinstance(obj, functools.partial):
+        obj = obj.func
+
+    return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__))  # type: ignore
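
Why the plain `inspect.iscoroutinefunction` check was not enough, sketched:

    import functools
    import inspect

    class Runner:
        async def __call__(self) -> None: ...

    async def task(x: int) -> int:
        return x

    inspect.iscoroutinefunction(Runner())  # False: the instance itself is not a coroutine function
    is_async_callable(Runner())            # True: falls back to checking __call__
    # functools.partial can likewise hide the coroutine function from inspect on some
    # Python versions; is_async_callable unwraps the partial chain first.
    is_async_callable(functools.partial(task, 1))  # True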
--- pydantic_ai_slim-0.2.18/pydantic_ai/agent.py
+++ pydantic_ai_slim-0.2.20/pydantic_ai/agent.py
@@ -28,6 +28,7 @@ from . import (
     result,
     usage as _usage,
 )
+from ._agent_graph import HistoryProcessor
 from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model
 from .result import FinalResult, OutputDataT, StreamedRunResult
 from .settings import ModelSettings, merge_model_settings
@@ -179,6 +180,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         defer_model_check: bool = False,
         end_strategy: EndStrategy = 'early',
         instrument: InstrumentationSettings | bool | None = None,
+        history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
     ) -> None: ...

     @overload
@@ -208,6 +210,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         defer_model_check: bool = False,
         end_strategy: EndStrategy = 'early',
         instrument: InstrumentationSettings | bool | None = None,
+        history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
     ) -> None: ...

     def __init__(
@@ -232,6 +235,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         defer_model_check: bool = False,
         end_strategy: EndStrategy = 'early',
         instrument: InstrumentationSettings | bool | None = None,
+        history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
         **_deprecated_kwargs: Any,
     ):
         """Create an agent.
@@ -275,6 +279,9 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
                 [`Agent.instrument_all()`][pydantic_ai.Agent.instrument_all]
                 will be used, which defaults to False.
                 See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info.
+            history_processors: Optional list of callables to process the message history before sending it to the model.
+                Each processor takes a list of messages and returns a modified list of messages.
+                Processors can be sync or async and are applied in sequence.
         """
         if model is None or defer_model_check:
             self.model = model
@@ -343,6 +350,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         self._max_result_retries = output_retries if output_retries is not None else retries
         self._mcp_servers = mcp_servers
         self._prepare_tools = prepare_tools
+        self.history_processors = history_processors or []
         for tool in tools:
             if isinstance(tool, Tool):
                 self._register_tool(tool)
@@ -669,10 +677,11 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         if self._instructions is None and not self._instructions_functions:
             return None

-        instructions = self._instructions or ''
+        instructions = [self._instructions] if self._instructions else []
         for instructions_runner in self._instructions_functions:
-            instructions += '\n' + await instructions_runner.run(run_context)
-        return instructions.strip()
+            instructions.append(await instructions_runner.run(run_context))
+        concatenated_instructions = '\n'.join(instruction for instruction in instructions if instruction)
+        return concatenated_instructions.strip() if concatenated_instructions else None

         # Copy the function tools so that retry state is agent-run-specific
         # Note that the retry count is reset to 0 when this happens due to the `default=0` and `init=False`.
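
The effect of the new joining logic, sketched with illustrative values:

    parts = ['', 'Be concise.', '']  # empty static instructions, one runner result, one empty result
    joined = '\n'.join(p for p in parts if p)
    assert joined == 'Be concise.'  # no stray leading newline from empty entries

Previously an all-empty result could come back as '' after stripping; the new code returns None whenever nothing remains after joining.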
@@ -689,6 +698,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
             end_strategy=self.end_strategy,
             output_schema=output_schema,
             output_validators=output_validators,
+            history_processors=self.history_processors,
             function_tools=run_function_tools,
             mcp_servers=self._mcp_servers,
             default_retries=self._default_retries,
--- pydantic_ai_slim-0.2.18/pydantic_ai/mcp.py
+++ pydantic_ai_slim-0.2.20/pydantic_ai/mcp.py
@@ -4,7 +4,7 @@ import base64
 import functools
 import json
 from abc import ABC, abstractmethod
-from collections.abc import AsyncIterator, Sequence
+from collections.abc import AsyncIterator, Awaitable, Sequence
 from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
 from dataclasses import dataclass
 from pathlib import Path
@@ -15,14 +15,20 @@ import anyio
 import httpx
 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
 from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
+from mcp.shared.exceptions import McpError
 from mcp.shared.message import SessionMessage
 from mcp.types import (
     AudioContent,
     BlobResourceContents,
+    CallToolRequest,
+    CallToolRequestParams,
+    CallToolResult,
+    ClientRequest,
     Content,
     EmbeddedResource,
     ImageContent,
     LoggingLevel,
+    RequestParams,
     TextContent,
     TextResourceContents,
 )
@@ -30,7 +36,7 @@ from typing_extensions import Self, assert_never, deprecated

 from pydantic_ai.exceptions import ModelRetry
 from pydantic_ai.messages import BinaryContent
-from pydantic_ai.tools import ToolDefinition
+from pydantic_ai.tools import RunContext, ToolDefinition

 try:
     from mcp.client.session import ClientSession
@@ -60,6 +66,9 @@ class MCPServer(ABC):
     e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
     """

+    process_tool_call: ProcessToolCallback | None = None
+    """Hook to customize tool calling and optionally pass extra metadata."""
+
     _client: ClientSession
     _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
     _write_stream: MemoryObjectSendStream[SessionMessage]
@@ -113,13 +122,17 @@ class MCPServer(ABC):
         ]

     async def call_tool(
-        self, tool_name: str, arguments: dict[str, Any]
-    ) -> str | BinaryContent | dict[str, Any] | list[Any] | Sequence[str | BinaryContent | dict[str, Any] | list[Any]]:
+        self,
+        tool_name: str,
+        arguments: dict[str, Any],
+        metadata: dict[str, Any] | None = None,
+    ) -> ToolResult:
         """Call a tool on the server.

         Args:
             tool_name: The name of the tool to call.
             arguments: The arguments to pass to the tool.
+            metadata: Request-level metadata (optional)

         Returns:
             The result of the tool call.
@@ -127,7 +140,23 @@ class MCPServer(ABC):
         Raises:
             ModelRetry: If the tool call fails.
         """
-        result = await self._client.call_tool(self.get_unprefixed_tool_name(tool_name), arguments)
+        try:
+            # The meta param is not exposed by the session yet, so build the request and call send_request directly.
+            result = await self._client.send_request(
+                ClientRequest(
+                    CallToolRequest(
+                        method='tools/call',
+                        params=CallToolRequestParams(
+                            name=self.get_unprefixed_tool_name(tool_name),
+                            arguments=arguments,
+                            _meta=RequestParams.Meta(**metadata) if metadata else None,
+                        ),
+                    )
+                ),
+                CallToolResult,
+            )
+        except McpError as e:
+            raise ModelRetry(e.error.message)

         content = [self._map_tool_result_part(part) for part in result.content]

@@ -265,6 +294,9 @@ class MCPServerStdio(MCPServer):
     e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
     """

+    process_tool_call: ProcessToolCallback | None = None
+    """Hook to customize tool calling and optionally pass extra metadata."""
+
     timeout: float = 5
     """ The timeout in seconds to wait for the client to initialize."""

@@ -359,6 +391,9 @@ class _MCPServerHTTP(MCPServer):
     For example, if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
     """

+    process_tool_call: ProcessToolCallback | None = None
+    """Hook to customize tool calling and optionally pass extra metadata."""
+
     @property
     @abstractmethod
     def _transport_client(
@@ -517,3 +552,29 @@ class MCPServerStreamableHTTP(_MCPServerHTTP):
     @property
     def _transport_client(self):
         return streamablehttp_client  # pragma: no cover
+
+
+ToolResult = (
+    str | BinaryContent | dict[str, Any] | list[Any] | Sequence[str | BinaryContent | dict[str, Any] | list[Any]]
+)
+"""The result type of a tool call."""
+
+CallToolFunc = Callable[[str, dict[str, Any], dict[str, Any] | None], Awaitable[ToolResult]]
+"""A function type that represents a tool call."""
+
+ProcessToolCallback = Callable[
+    [
+        RunContext[Any],
+        CallToolFunc,
+        str,
+        dict[str, Any],
+    ],
+    Awaitable[ToolResult],
+]
+"""A process tool callback.
+
+It accepts a run context, the original tool call function, a tool name, and arguments.
+
+Allows wrapping an MCP server tool call to customize it, including adding extra request
+metadata.
+"""
--- pydantic_ai_slim-0.2.18/pydantic_ai/messages.py
+++ pydantic_ai_slim-0.2.20/pydantic_ai/messages.py
@@ -1,7 +1,6 @@
 from __future__ import annotations as _annotations

 import base64
-import uuid
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from dataclasses import dataclass, field, replace
@@ -888,13 +887,13 @@ class FunctionToolCallEvent:

     part: ToolCallPart
     """The (function) tool call to make."""
-    call_id: str = field(init=False)
-    """An ID used for matching details about the call to its result. If present, defaults to the part's tool_call_id."""
     event_kind: Literal['function_tool_call'] = 'function_tool_call'
     """Event type identifier, used as a discriminator."""

-    def __post_init__(self):
-        self.call_id = self.part.tool_call_id or str(uuid.uuid4())
+    @property
+    def call_id(self) -> str:
+        """An ID used for matching details about the call to its result. If present, defaults to the part's tool_call_id."""
+        return self.part.tool_call_id

     __repr__ = _utils.dataclasses_no_defaults_repr

--- pydantic_ai_slim-0.2.18/pydantic_ai/models/anthropic.py
+++ pydantic_ai_slim-0.2.20/pydantic_ai/models/anthropic.py
@@ -220,7 +220,7 @@ class AnthropicModel(Model):
         extra_headers = model_settings.get('extra_headers', {})
         extra_headers.setdefault('User-Agent', get_user_agent())
         return await self.client.beta.messages.create(
-            max_tokens=model_settings.get('max_tokens', 1024),
+            max_tokens=model_settings.get('max_tokens', 4096),
             system=system_prompt or NOT_GIVEN,
             messages=anthropic_messages,
             model=self._model_name,
--- pydantic_ai_slim-0.2.18/pydantic_ai/models/gemini.py
+++ pydantic_ai_slim-0.2.20/pydantic_ai/models/gemini.py
@@ -723,9 +723,7 @@ class _GeminiFunction(TypedDict):

 def _function_from_abstract_tool(tool: ToolDefinition) -> _GeminiFunction:
     json_schema = tool.parameters_json_schema
-    f = _GeminiFunction(name=tool.name, description=tool.description)
-    if json_schema.get('properties'):
-        f['parameters'] = json_schema
+    f = _GeminiFunction(name=tool.name, description=tool.description, parameters=json_schema)
     return f

--- pydantic_ai_slim-0.2.18/pydantic_ai/models/google.py
+++ pydantic_ai_slim-0.2.20/pydantic_ai/models/google.py
@@ -10,9 +10,8 @@ from uuid import uuid4

 from typing_extensions import assert_never

-from pydantic_ai.providers import Provider
-
 from .. import UnexpectedModelBehavior, _utils, usage
+from ..exceptions import UserError
 from ..messages import (
     BinaryContent,
     FileUrl,
@@ -30,6 +29,7 @@ from ..messages import (
     VideoUrl,
 )
 from ..profiles import ModelProfileSpec
+from ..providers import Provider
 from ..settings import ModelSettings
 from ..tools import ToolDefinition
 from . import (
@@ -52,6 +52,7 @@ try:
         FunctionDeclarationDict,
         GenerateContentConfigDict,
         GenerateContentResponse,
+        HttpOptionsDict,
         Part,
         PartDict,
         SafetySettingDict,
@@ -252,8 +253,17 @@ class GoogleModel(Model):
         tool_config = self._get_tool_config(model_request_parameters, tools)
         system_instruction, contents = await self._map_messages(messages)

+        http_options: HttpOptionsDict = {
+            'headers': {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
+        }
+        if timeout := model_settings.get('timeout'):
+            if isinstance(timeout, (int, float)):
+                http_options['timeout'] = int(1000 * timeout)
+            else:
+                raise UserError('Google does not support setting ModelSettings.timeout to a httpx.Timeout')
+
         config = GenerateContentConfigDict(
-            http_options={'headers': {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}},
+            http_options=http_options,
             system_instruction=system_instruction,
             temperature=model_settings.get('temperature'),
             top_p=model_settings.get('top_p'),
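
A sketch of what the new timeout handling accepts (model name illustrative): an int or float is interpreted as seconds and converted to the millisecond value google-genai expects, while an httpx.Timeout raises UserError.

    from pydantic_ai import Agent
    from pydantic_ai.models.google import GoogleModel

    agent = Agent(GoogleModel('gemini-1.5-flash'))
    # 30 seconds -> http_options['timeout'] == 30000 (milliseconds)
    result = agent.run_sync('Hello', model_settings={'timeout': 30})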
@@ -469,9 +479,11 @@ def _process_response_from_parts(

 def _function_declaration_from_tool(tool: ToolDefinition) -> FunctionDeclarationDict:
     json_schema = tool.parameters_json_schema
-    f = FunctionDeclarationDict(name=tool.name, description=tool.description)
-    if json_schema.get('properties'):  # pragma: no branch
-        f['parameters'] = json_schema  # type: ignore
+    f = FunctionDeclarationDict(
+        name=tool.name,
+        description=tool.description,
+        parameters=json_schema,  # type: ignore
+    )
     return f

--- pydantic_ai_slim-0.2.18/pydantic_ai/providers/__init__.py
+++ pydantic_ai_slim-0.2.20/pydantic_ai/providers/__init__.py
@@ -48,68 +48,74 @@ class Provider(ABC, Generic[InterfaceClient]):
         return None  # pragma: no cover


-def infer_provider(provider: str) -> Provider[Any]:  # noqa: C901
-    """Infer the provider from the provider name."""
+def infer_provider_class(provider: str) -> type[Provider[Any]]:  # noqa: C901
+    """Infers the provider class from the provider name."""
     if provider == 'openai':
         from .openai import OpenAIProvider

-        return OpenAIProvider()
+        return OpenAIProvider
     elif provider == 'deepseek':
         from .deepseek import DeepSeekProvider

-        return DeepSeekProvider()
+        return DeepSeekProvider
     elif provider == 'openrouter':
         from .openrouter import OpenRouterProvider

-        return OpenRouterProvider()
+        return OpenRouterProvider
     elif provider == 'azure':
         from .azure import AzureProvider

-        return AzureProvider()
+        return AzureProvider
     elif provider == 'google-vertex':
         from .google_vertex import GoogleVertexProvider

-        return GoogleVertexProvider()
+        return GoogleVertexProvider
     elif provider == 'google-gla':
         from .google_gla import GoogleGLAProvider

-        return GoogleGLAProvider()
+        return GoogleGLAProvider
     # NOTE: We don't test because there are many ways the `boto3.client` can retrieve the credentials.
     elif provider == 'bedrock':
         from .bedrock import BedrockProvider

-        return BedrockProvider()
+        return BedrockProvider
     elif provider == 'groq':
         from .groq import GroqProvider

-        return GroqProvider()
+        return GroqProvider
     elif provider == 'anthropic':
         from .anthropic import AnthropicProvider

-        return AnthropicProvider()
+        return AnthropicProvider
     elif provider == 'mistral':
         from .mistral import MistralProvider

-        return MistralProvider()
+        return MistralProvider
     elif provider == 'cohere':
         from .cohere import CohereProvider

-        return CohereProvider()
+        return CohereProvider
     elif provider == 'grok':
         from .grok import GrokProvider

-        return GrokProvider()
+        return GrokProvider
     elif provider == 'fireworks':
         from .fireworks import FireworksProvider

-        return FireworksProvider()
+        return FireworksProvider
     elif provider == 'together':
         from .together import TogetherProvider

-        return TogetherProvider()
+        return TogetherProvider
     elif provider == 'heroku':
         from .heroku import HerokuProvider

-        return HerokuProvider()
+        return HerokuProvider
     else:  # pragma: no cover
         raise ValueError(f'Unknown provider: {provider}')
+
+
+def infer_provider(provider: str) -> Provider[Any]:
+    """Infer the provider from the provider name."""
+    provider_class = infer_provider_class(provider)
+    return provider_class()
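
The split lets callers resolve a provider class without instantiating it, which matters because instantiation typically requires credentials. Illustrative usage:

    from pydantic_ai.providers import infer_provider, infer_provider_class

    cls = infer_provider_class('openai')  # no API key needed to resolve the class
    provider = infer_provider('openai')   # equivalent to cls(); typically requires OPENAI_API_KEY to be set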
--- pydantic_ai_slim-0.2.18/pydantic_ai/tools.py
+++ pydantic_ai_slim-0.2.20/pydantic_ai/tools.py
@@ -1,6 +1,5 @@
 from __future__ import annotations as _annotations

-import asyncio
 import dataclasses
 import json
 from collections.abc import Awaitable, Sequence
@@ -337,7 +336,7 @@ class Tool(Generic[AgentDepsT]):
             validator=SchemaValidator(schema=core_schema.any_schema()),
             json_schema=json_schema,
             takes_ctx=False,
-            is_async=asyncio.iscoroutinefunction(function),
+            is_async=_utils.is_async_callable(function),
         )

         return cls(