pydantic-ai-slim 0.0.41.tar.gz → 0.0.43.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic; see the package's registry page for more details.

Files changed (50)
  1. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/.gitignore +1 -1
  2. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/PKG-INFO +5 -3
  3. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/_agent_graph.py +61 -7
  4. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/_cli.py +1 -1
  5. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/_pydantic.py +6 -5
  6. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/agent.py +55 -7
  7. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/common_tools/duckduckgo.py +1 -1
  8. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/common_tools/tavily.py +1 -1
  9. pydantic_ai_slim-0.0.43/pydantic_ai/mcp.py +198 -0
  10. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/messages.py +3 -0
  11. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/anthropic.py +1 -1
  12. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/cohere.py +1 -1
  13. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/groq.py +1 -1
  14. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/instrumented.py +13 -7
  15. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/mistral.py +1 -1
  16. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/openai.py +1 -1
  17. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/vertexai.py +1 -1
  18. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/wrapper.py +5 -1
  19. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/providers/anthropic.py +1 -1
  20. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/providers/azure.py +1 -1
  21. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/providers/bedrock.py +1 -1
  22. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/providers/deepseek.py +1 -1
  23. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/providers/google_vertex.py +11 -19
  24. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/providers/groq.py +16 -12
  25. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/providers/mistral.py +12 -12
  26. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/providers/openai.py +1 -1
  27. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/tools.py +23 -2
  28. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pyproject.toml +8 -3
  29. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/README.md +0 -0
  30. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/__init__.py +0 -0
  31. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/_griffe.py +0 -0
  32. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/_parts_manager.py +0 -0
  33. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/_result.py +0 -0
  34. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/_system_prompt.py +0 -0
  35. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/_utils.py +0 -0
  36. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/common_tools/__init__.py +0 -0
  37. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/exceptions.py +0 -0
  38. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/format_as_xml.py +0 -0
  39. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/__init__.py +0 -0
  40. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/bedrock.py +0 -0
  41. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/fallback.py +0 -0
  42. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/function.py +0 -0
  43. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/gemini.py +0 -0
  44. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/models/test.py +0 -0
  45. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/providers/__init__.py +0 -0
  46. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/providers/google_gla.py +0 -0
  47. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/py.typed +0 -0
  48. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/result.py +0 -0
  49. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/settings.py +0 -0
  50. {pydantic_ai_slim-0.0.41 → pydantic_ai_slim-0.0.43}/pydantic_ai/usage.py +0 -0
@@ -1,5 +1,4 @@
1
1
  site
2
- .python-version
3
2
  .venv
4
3
  dist
5
4
  __pycache__
@@ -16,3 +15,4 @@ examples/pydantic_ai_examples/.chat_app_messages.sqlite
16
15
  .vscode/
17
16
  /question_graph_history.json
18
17
  /docs-site/.wrangler/
18
+ /CLAUDE.md
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pydantic-ai-slim
3
- Version: 0.0.41
3
+ Version: 0.0.43
4
4
  Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
5
5
  Author-email: Samuel Colvin <samuel@pydantic.dev>
6
6
  License-Expression: MIT
@@ -29,7 +29,7 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
29
29
  Requires-Dist: griffe>=1.3.2
30
30
  Requires-Dist: httpx>=0.27
31
31
  Requires-Dist: opentelemetry-api>=1.28.0
32
- Requires-Dist: pydantic-graph==0.0.41
32
+ Requires-Dist: pydantic-graph==0.0.43
33
33
  Requires-Dist: pydantic>=2.10
34
34
  Requires-Dist: typing-inspection>=0.4.0
35
35
  Provides-Extra: anthropic
@@ -45,9 +45,11 @@ Requires-Dist: cohere>=5.13.11; extra == 'cohere'
45
45
  Provides-Extra: duckduckgo
46
46
  Requires-Dist: duckduckgo-search>=7.0.0; extra == 'duckduckgo'
47
47
  Provides-Extra: groq
48
- Requires-Dist: groq>=0.12.0; extra == 'groq'
48
+ Requires-Dist: groq>=0.15.0; extra == 'groq'
49
49
  Provides-Extra: logfire
50
50
  Requires-Dist: logfire>=2.3; extra == 'logfire'
51
+ Provides-Extra: mcp
52
+ Requires-Dist: mcp>=1.4.1; (python_version >= '3.10') and extra == 'mcp'
51
53
  Provides-Extra: mistral
52
54
  Requires-Dist: mistralai>=1.2.5; extra == 'mistral'
53
55
  Provides-Extra: openai
@@ -7,7 +7,7 @@ from collections.abc import AsyncIterator, Iterator, Sequence
7
7
  from contextlib import asynccontextmanager, contextmanager
8
8
  from contextvars import ContextVar
9
9
  from dataclasses import field
10
- from typing import Any, Generic, Literal, Union, cast
10
+ from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast
11
11
 
12
12
  from opentelemetry.trace import Span, Tracer
13
13
  from typing_extensions import TypeGuard, TypeVar, assert_never
@@ -27,11 +27,10 @@ from . import (
27
27
  from .models.instrumented import InstrumentedModel
28
28
  from .result import ResultDataT
29
29
  from .settings import ModelSettings, merge_model_settings
30
- from .tools import (
31
- RunContext,
32
- Tool,
33
- ToolDefinition,
34
- )
30
+ from .tools import RunContext, Tool, ToolDefinition
31
+
32
+ if TYPE_CHECKING:
33
+ from .mcp import MCPServer
35
34
 
36
35
  __all__ = (
37
36
  'GraphAgentState',
@@ -94,6 +93,7 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
94
93
  result_validators: list[_result.ResultValidator[DepsT, ResultDataT]]
95
94
 
96
95
  function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
96
+ mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
97
97
 
98
98
  run_span: Span
99
99
  tracer: Tracer
@@ -219,7 +219,17 @@ async def _prepare_request_parameters(
219
219
  if tool_def := await tool.prepare_tool_def(ctx):
220
220
  function_tool_defs.append(tool_def)
221
221
 
222
- await asyncio.gather(*map(add_tool, ctx.deps.function_tools.values()))
222
+ async def add_mcp_server_tools(server: MCPServer) -> None:
223
+ if not server.is_running:
224
+ raise exceptions.UserError(f'MCP server is not running: {server}')
225
+ tool_defs = await server.list_tools()
226
+ # TODO(Marcelo): We should check if the tool names are unique. If not, we should raise an error.
227
+ function_tool_defs.extend(tool_defs)
228
+
229
+ await asyncio.gather(
230
+ *map(add_tool, ctx.deps.function_tools.values()),
231
+ *map(add_mcp_server_tools, ctx.deps.mcp_servers),
232
+ )
223
233
 
224
234
  result_schema = ctx.deps.result_schema
225
235
  return models.ModelRequestParameters(
@@ -594,6 +604,21 @@ async def process_function_tools(
594
604
  yield event
595
605
  call_index_to_event_id[len(calls_to_run)] = event.call_id
596
606
  calls_to_run.append((tool, call))
607
+ elif mcp_tool := await _tool_from_mcp_server(call.tool_name, ctx):
608
+ if stub_function_tools:
609
+ # TODO(Marcelo): We should add coverage for this part of the code.
610
+ output_parts.append( # pragma: no cover
611
+ _messages.ToolReturnPart(
612
+ tool_name=call.tool_name,
613
+ content='Tool not executed - a final result was already processed.',
614
+ tool_call_id=call.tool_call_id,
615
+ )
616
+ )
617
+ else:
618
+ event = _messages.FunctionToolCallEvent(call)
619
+ yield event
620
+ call_index_to_event_id[len(calls_to_run)] = event.call_id
621
+ calls_to_run.append((mcp_tool, call))
597
622
  elif result_schema is not None and call.tool_name in result_schema.tools:
598
623
  # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
599
624
  # validation, we don't add another part here
@@ -641,6 +666,35 @@ async def process_function_tools(
641
666
  output_parts.append(results_by_index[k])
642
667
 
643
668
 
669
+ async def _tool_from_mcp_server(
670
+ tool_name: str,
671
+ ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
672
+ ) -> Tool[DepsT] | None:
673
+ """Call each MCP server to find the tool with the given name.
674
+
675
+ Args:
676
+ tool_name: The name of the tool to find.
677
+ ctx: The current run context.
678
+
679
+ Returns:
680
+ The tool with the given name, or `None` if no tool with the given name is found.
681
+ """
682
+
683
+ async def run_tool(ctx: RunContext[DepsT], **args: Any) -> Any:
684
+ # There's no normal situation where the server will not be running at this point, we check just in case
685
+ # some weird edge case occurs.
686
+ if not server.is_running: # pragma: no cover
687
+ raise exceptions.UserError(f'MCP server is not running: {server}')
688
+ result = await server.call_tool(tool_name, args)
689
+ return result
690
+
691
+ for server in ctx.deps.mcp_servers:
692
+ tools = await server.list_tools()
693
+ if tool_name in {tool.name for tool in tools}:
694
+ return Tool(name=tool_name, function=run_tool, takes_ctx=True)
695
+ return None
696
+
697
+
644
698
  def _unknown_tool(
645
699
  tool_name: str,
646
700
  ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
@@ -31,7 +31,7 @@ try:
31
31
  except ImportError as _import_error:
32
32
  raise ImportError(
33
33
  'Please install `rich`, `prompt-toolkit` and `argcomplete` to use the PydanticAI CLI, '
34
- "you can use the `cli` optional group — `pip install 'pydantic-ai-slim[cli]'`"
34
+ 'you can use the `cli` optional group — `pip install "pydantic-ai-slim[cli]"`'
35
35
  ) from _import_error
36
36
 
37
37
  from pydantic_ai.agent import Agent
@@ -44,6 +44,7 @@ def function_schema( # noqa: C901
44
44
  takes_ctx: bool,
45
45
  docstring_format: DocstringFormat,
46
46
  require_parameter_descriptions: bool,
47
+ schema_generator: type[GenerateJsonSchema],
47
48
  ) -> FunctionSchema:
48
49
  """Build a Pydantic validator and JSON schema from a tool function.
49
50
 
@@ -52,6 +53,7 @@ def function_schema( # noqa: C901
52
53
  takes_ctx: Whether the function takes a `RunContext` first argument.
53
54
  docstring_format: The docstring format to use.
54
55
  require_parameter_descriptions: Whether to require descriptions for all tool function parameters.
56
+ schema_generator: The JSON schema generator class to use.
55
57
 
56
58
  Returns:
57
59
  A `FunctionSchema` instance.
@@ -150,14 +152,12 @@ def function_schema( # noqa: C901
150
152
  )
151
153
  # PluggableSchemaValidator is api compatible with SchemaValidator
152
154
  schema_validator = cast(SchemaValidator, schema_validator)
153
- json_schema = GenerateJsonSchema().generate(schema)
155
+ json_schema = schema_generator().generate(schema)
154
156
 
155
157
  # workaround for https://github.com/pydantic/pydantic/issues/10785
156
- # if we build a custom TypeDict schema (matches when `single_arg_name is None`), we manually set
158
+ # if we build a custom TypedDict schema (matches when `single_arg_name is None`), we manually set
157
159
  # `additionalProperties` in the JSON Schema
158
- if single_arg_name is None:
159
- json_schema['additionalProperties'] = bool(var_kwargs_schema)
160
- elif not description:
160
+ if single_arg_name is not None and not description:
161
161
  # if the tool description is not set, and we have a single parameter, take the description from that
162
162
  # and set it on the tool
163
163
  description = json_schema.pop('description', None)
@@ -218,6 +218,7 @@ def _build_schema(
218
218
  td_schema = core_schema.typed_dict_schema(
219
219
  fields,
220
220
  config=core_config,
221
+ total=var_kwargs_schema is None,
221
222
  extras_schema=gen_schema.generate_schema(var_kwargs_schema) if var_kwargs_schema else None,
222
223
  )
223
224
  return td_schema, None
@@ -3,12 +3,13 @@ from __future__ import annotations as _annotations
3
3
  import dataclasses
4
4
  import inspect
5
5
  from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
6
- from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
6
+ from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
7
7
  from copy import deepcopy
8
8
  from types import FrameType
9
- from typing import Any, Callable, ClassVar, Generic, cast, final, overload
9
+ from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, final, overload
10
10
 
11
11
  from opentelemetry.trace import NoOpTracer, use_span
12
+ from pydantic.json_schema import GenerateJsonSchema
12
13
  from typing_extensions import TypeGuard, TypeVar, deprecated
13
14
 
14
15
  from pydantic_graph import End, Graph, GraphRun, GraphRunContext
@@ -31,6 +32,7 @@ from .settings import ModelSettings, merge_model_settings
31
32
  from .tools import (
32
33
  AgentDepsT,
33
34
  DocstringFormat,
35
+ GenerateToolJsonSchema,
34
36
  RunContext,
35
37
  Tool,
36
38
  ToolFuncContext,
@@ -47,6 +49,9 @@ CallToolsNode = _agent_graph.CallToolsNode
47
49
  ModelRequestNode = _agent_graph.ModelRequestNode
48
50
  UserPromptNode = _agent_graph.UserPromptNode
49
51
 
52
+ if TYPE_CHECKING:
53
+ from pydantic_ai.mcp import MCPServer
54
+
50
55
  __all__ = (
51
56
  'Agent',
52
57
  'AgentRun',
@@ -129,6 +134,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
129
134
  repr=False
130
135
  )
131
136
  _function_tools: dict[str, Tool[AgentDepsT]] = dataclasses.field(repr=False)
137
+ _mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
132
138
  _default_retries: int = dataclasses.field(repr=False)
133
139
  _max_result_retries: int = dataclasses.field(repr=False)
134
140
  _override_deps: _utils.Option[AgentDepsT] = dataclasses.field(default=None, repr=False)
@@ -148,6 +154,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
148
154
  result_tool_description: str | None = None,
149
155
  result_retries: int | None = None,
150
156
  tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
157
+ mcp_servers: Sequence[MCPServer] = (),
151
158
  defer_model_check: bool = False,
152
159
  end_strategy: EndStrategy = 'early',
153
160
  instrument: InstrumentationSettings | bool | None = None,
@@ -173,6 +180,8 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
173
180
  result_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
174
181
  tools: Tools to register with the agent, you can also register tools via the decorators
175
182
  [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
183
+ mcp_servers: MCP servers to register with the agent. You should register a [`MCPServer`][pydantic_ai.mcp.MCPServer]
184
+ for each server you want the agent to connect to.
176
185
  defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model,
177
186
  it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately,
178
187
  which checks for the necessary environment variables. Set this to `false`
@@ -186,6 +195,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
186
195
  If this isn't set, then the last value set by
187
196
  [`Agent.instrument_all()`][pydantic_ai.Agent.instrument_all]
188
197
  will be used, which defaults to False.
198
+ See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info.
189
199
  """
190
200
  if model is None or defer_model_check:
191
201
  self.model = model
@@ -215,6 +225,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
215
225
 
216
226
  self._default_retries = retries
217
227
  self._max_result_retries = result_retries if result_retries is not None else retries
228
+ self._mcp_servers = mcp_servers
218
229
  for tool in tools:
219
230
  if isinstance(tool, Tool):
220
231
  self._register_tool(tool)
@@ -435,7 +446,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
435
446
  usage_limits = usage_limits or _usage.UsageLimits()
436
447
 
437
448
  if isinstance(model_used, InstrumentedModel):
438
- tracer = model_used.options.tracer
449
+ tracer = model_used.settings.tracer
439
450
  else:
440
451
  tracer = NoOpTracer()
441
452
  agent_name = self.name or 'agent'
@@ -461,6 +472,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
461
472
  result_tools=self._result_schema.tool_defs() if self._result_schema else [],
462
473
  result_validators=result_validators,
463
474
  function_tools=self._function_tools,
475
+ mcp_servers=self._mcp_servers,
464
476
  run_span=run_span,
465
477
  tracer=tracer,
466
478
  )
@@ -927,6 +939,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
927
939
  prepare: ToolPrepareFunc[AgentDepsT] | None = None,
928
940
  docstring_format: DocstringFormat = 'auto',
929
941
  require_parameter_descriptions: bool = False,
942
+ schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
930
943
  ) -> Callable[[ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams]]: ...
931
944
 
932
945
  def tool(
@@ -939,6 +952,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
939
952
  prepare: ToolPrepareFunc[AgentDepsT] | None = None,
940
953
  docstring_format: DocstringFormat = 'auto',
941
954
  require_parameter_descriptions: bool = False,
955
+ schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
942
956
  ) -> Any:
943
957
  """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
944
958
 
@@ -980,6 +994,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
980
994
  docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
981
995
  Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
982
996
  require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
997
+ schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
983
998
  """
984
999
  if func is None:
985
1000
 
@@ -988,7 +1003,14 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
988
1003
  ) -> ToolFuncContext[AgentDepsT, ToolParams]:
989
1004
  # noinspection PyTypeChecker
990
1005
  self._register_function(
991
- func_, True, name, retries, prepare, docstring_format, require_parameter_descriptions
1006
+ func_,
1007
+ True,
1008
+ name,
1009
+ retries,
1010
+ prepare,
1011
+ docstring_format,
1012
+ require_parameter_descriptions,
1013
+ schema_generator,
992
1014
  )
993
1015
  return func_
994
1016
 
@@ -996,7 +1018,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
996
1018
  else:
997
1019
  # noinspection PyTypeChecker
998
1020
  self._register_function(
999
- func, True, name, retries, prepare, docstring_format, require_parameter_descriptions
1021
+ func, True, name, retries, prepare, docstring_format, require_parameter_descriptions, schema_generator
1000
1022
  )
1001
1023
  return func
1002
1024
 
@@ -1013,6 +1035,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1013
1035
  prepare: ToolPrepareFunc[AgentDepsT] | None = None,
1014
1036
  docstring_format: DocstringFormat = 'auto',
1015
1037
  require_parameter_descriptions: bool = False,
1038
+ schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
1016
1039
  ) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...
1017
1040
 
1018
1041
  def tool_plain(
@@ -1025,6 +1048,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1025
1048
  prepare: ToolPrepareFunc[AgentDepsT] | None = None,
1026
1049
  docstring_format: DocstringFormat = 'auto',
1027
1050
  require_parameter_descriptions: bool = False,
1051
+ schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
1028
1052
  ) -> Any:
1029
1053
  """Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
1030
1054
 
@@ -1066,20 +1090,28 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1066
1090
  docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
1067
1091
  Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
1068
1092
  require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
1093
+ schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
1069
1094
  """
1070
1095
  if func is None:
1071
1096
 
1072
1097
  def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
1073
1098
  # noinspection PyTypeChecker
1074
1099
  self._register_function(
1075
- func_, False, name, retries, prepare, docstring_format, require_parameter_descriptions
1100
+ func_,
1101
+ False,
1102
+ name,
1103
+ retries,
1104
+ prepare,
1105
+ docstring_format,
1106
+ require_parameter_descriptions,
1107
+ schema_generator,
1076
1108
  )
1077
1109
  return func_
1078
1110
 
1079
1111
  return tool_decorator
1080
1112
  else:
1081
1113
  self._register_function(
1082
- func, False, name, retries, prepare, docstring_format, require_parameter_descriptions
1114
+ func, False, name, retries, prepare, docstring_format, require_parameter_descriptions, schema_generator
1083
1115
  )
1084
1116
  return func
1085
1117
 
@@ -1092,6 +1124,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1092
1124
  prepare: ToolPrepareFunc[AgentDepsT] | None,
1093
1125
  docstring_format: DocstringFormat,
1094
1126
  require_parameter_descriptions: bool,
1127
+ schema_generator: type[GenerateJsonSchema],
1095
1128
  ) -> None:
1096
1129
  """Private utility to register a function as a tool."""
1097
1130
  retries_ = retries if retries is not None else self._default_retries
@@ -1103,6 +1136,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1103
1136
  prepare=prepare,
1104
1137
  docstring_format=docstring_format,
1105
1138
  require_parameter_descriptions=require_parameter_descriptions,
1139
+ schema_generator=schema_generator,
1106
1140
  )
1107
1141
  self._register_tool(tool)
1108
1142
 
@@ -1253,6 +1287,20 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1253
1287
  """
1254
1288
  return isinstance(node, End)
1255
1289
 
1290
+ @asynccontextmanager
1291
+ async def run_mcp_servers(self) -> AsyncIterator[None]:
1292
+ """Run [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] so they can be used by the agent.
1293
+
1294
+ Returns: a context manager to start and shutdown the servers.
1295
+ """
1296
+ exit_stack = AsyncExitStack()
1297
+ try:
1298
+ for mcp_server in self._mcp_servers:
1299
+ await exit_stack.enter_async_context(mcp_server)
1300
+ yield
1301
+ finally:
1302
+ await exit_stack.aclose()
1303
+
1256
1304
 
1257
1305
  @dataclasses.dataclass(repr=False)
1258
1306
  class AgentRun(Generic[AgentDepsT, ResultDataT]):
@@ -13,7 +13,7 @@ try:
13
13
  except ImportError as _import_error:
14
14
  raise ImportError(
15
15
  'Please install `duckduckgo-search` to use the DuckDuckGo search tool, '
16
- "you can use the `duckduckgo` optional group — `pip install 'pydantic-ai-slim[duckduckgo]'`"
16
+ 'you can use the `duckduckgo` optional group — `pip install "pydantic-ai-slim[duckduckgo]"`'
17
17
  ) from _import_error
18
18
 
19
19
  __all__ = ('duckduckgo_search_tool',)
@@ -11,7 +11,7 @@ try:
11
11
  except ImportError as _import_error:
12
12
  raise ImportError(
13
13
  'Please install `tavily-python` to use the Tavily search tool, '
14
- "you can use the `tavily` optional group — `pip install 'pydantic-ai-slim[tavily]'`"
14
+ 'you can use the `tavily` optional group — `pip install "pydantic-ai-slim[tavily]"`'
15
15
  ) from _import_error
16
16
 
17
17
  __all__ = ('tavily_search_tool',)
@@ -0,0 +1,198 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from collections.abc import AsyncIterator, Sequence
5
+ from contextlib import AsyncExitStack, asynccontextmanager
6
+ from dataclasses import dataclass
7
+ from types import TracebackType
8
+ from typing import Any
9
+
10
+ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
11
+ from mcp.types import JSONRPCMessage
12
+ from typing_extensions import Self
13
+
14
+ from pydantic_ai.tools import ToolDefinition
15
+
16
+ try:
17
+ from mcp.client.session import ClientSession
18
+ from mcp.client.sse import sse_client
19
+ from mcp.client.stdio import StdioServerParameters, stdio_client
20
+ from mcp.types import CallToolResult
21
+ except ImportError as _import_error:
22
+ raise ImportError(
23
+ 'Please install the `mcp` package to use the MCP server, '
24
+ 'you can use the `mcp` optional group — `pip install "pydantic-ai-slim[mcp]"`'
25
+ ) from _import_error
26
+
27
+ __all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP'
28
+
29
+
30
+ class MCPServer(ABC):
31
+ """Base class for attaching agents to MCP servers.
32
+
33
+ See <https://modelcontextprotocol.io> for more information.
34
+ """
35
+
36
+ is_running: bool = False
37
+
38
+ _client: ClientSession
39
+ _read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
40
+ _write_stream: MemoryObjectSendStream[JSONRPCMessage]
41
+ _exit_stack: AsyncExitStack
42
+
43
+ @abstractmethod
44
+ @asynccontextmanager
45
+ async def client_streams(
46
+ self,
47
+ ) -> AsyncIterator[
48
+ tuple[MemoryObjectReceiveStream[JSONRPCMessage | Exception], MemoryObjectSendStream[JSONRPCMessage]]
49
+ ]:
50
+ """Create the streams for the MCP server."""
51
+ raise NotImplementedError('MCP Server subclasses must implement this method.')
52
+ yield
53
+
54
+ async def list_tools(self) -> list[ToolDefinition]:
55
+ """Retrieve tools that are currently active on the server.
56
+
57
+ Note:
58
+ - We don't cache tools as they might change.
59
+ - We also don't subscribe to the server to avoid complexity.
60
+ """
61
+ tools = await self._client.list_tools()
62
+ return [
63
+ ToolDefinition(
64
+ name=tool.name,
65
+ description=tool.description or '',
66
+ parameters_json_schema=tool.inputSchema,
67
+ )
68
+ for tool in tools.tools
69
+ ]
70
+
71
+ async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> CallToolResult:
72
+ """Call a tool on the server.
73
+
74
+ Args:
75
+ tool_name: The name of the tool to call.
76
+ arguments: The arguments to pass to the tool.
77
+
78
+ Returns:
79
+ The result of the tool call.
80
+ """
81
+ return await self._client.call_tool(tool_name, arguments)
82
+
83
+ async def __aenter__(self) -> Self:
84
+ self._exit_stack = AsyncExitStack()
85
+
86
+ self._read_stream, self._write_stream = await self._exit_stack.enter_async_context(self.client_streams())
87
+ client = ClientSession(read_stream=self._read_stream, write_stream=self._write_stream)
88
+ self._client = await self._exit_stack.enter_async_context(client)
89
+
90
+ await self._client.initialize()
91
+ self.is_running = True
92
+ return self
93
+
94
+ async def __aexit__(
95
+ self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None
96
+ ) -> bool | None:
97
+ await self._exit_stack.aclose()
98
+ self.is_running = False
99
+
100
+
101
+ @dataclass
102
+ class MCPServerStdio(MCPServer):
103
+ """Runs an MCP server in a subprocess and communicates with it over stdin/stdout.
104
+
105
+ This class implements the stdio transport from the MCP specification.
106
+ See <https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio> for more information.
107
+
108
+ !!! note
109
+ Using this class as an async context manager will start the server as a subprocess when entering the context,
110
+ and stop it when exiting the context.
111
+
112
+ Example:
113
+ ```python {py="3.10"}
114
+ from pydantic_ai import Agent
115
+ from pydantic_ai.mcp import MCPServerStdio
116
+
117
+ server = MCPServerStdio('npx', ['-y', '@pydantic/mcp-run-python', 'stdio']) # (1)!
118
+ agent = Agent('openai:gpt-4o', mcp_servers=[server])
119
+
120
+ async def main():
121
+ async with agent.run_mcp_servers(): # (2)!
122
+ ...
123
+ ```
124
+
125
+ 1. See [MCP Run Python](../mcp/run-python.md) for more information.
126
+ 2. This will start the server as a subprocess and connect to it.
127
+ """
128
+
129
+ command: str
130
+ """The command to run."""
131
+
132
+ args: Sequence[str]
133
+ """The arguments to pass to the command."""
134
+
135
+ env: dict[str, str] | None = None
136
+ """The environment variables the CLI server will have access to.
137
+
138
+ By default the subprocess will not inherit any environment variables from the parent process.
139
+ If you want to inherit the environment variables from the parent process, use `env=os.environ`.
140
+ """
141
+
142
+ @asynccontextmanager
143
+ async def client_streams(
144
+ self,
145
+ ) -> AsyncIterator[
146
+ tuple[MemoryObjectReceiveStream[JSONRPCMessage | Exception], MemoryObjectSendStream[JSONRPCMessage]]
147
+ ]:
148
+ server = StdioServerParameters(command=self.command, args=list(self.args), env=self.env)
149
+ async with stdio_client(server=server) as (read_stream, write_stream):
150
+ yield read_stream, write_stream
151
+
152
+
153
+ @dataclass
154
+ class MCPServerHTTP(MCPServer):
155
+ """An MCP server that connects over streamable HTTP connections.
156
+
157
+ This class implements the SSE transport from the MCP specification.
158
+ See <https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse> for more information.
159
+
160
+ The name "HTTP" is used since this implemented will be adapted in future to use the new
161
+ [Streamable HTTP](https://github.com/modelcontextprotocol/specification/pull/206) currently in development.
162
+
163
+ !!! note
164
+ Using this class as an async context manager will create a new pool of HTTP connections to connect
165
+ to a server which should already be running.
166
+
167
+ Example:
168
+ ```python {py="3.10"}
169
+ from pydantic_ai import Agent
170
+ from pydantic_ai.mcp import MCPServerHTTP
171
+
172
+ server = MCPServerHTTP('http://localhost:3001/sse') # (1)!
173
+ agent = Agent('openai:gpt-4o', mcp_servers=[server])
174
+
175
+ async def main():
176
+ async with agent.run_mcp_servers(): # (2)!
177
+ ...
178
+ ```
179
+
180
+ 1. E.g. you might be connecting to a server run with `npx @pydantic/mcp-run-python sse`,
181
+ see [MCP Run Python](../mcp/run-python.md) for more information.
182
+ 2. This will connect to a server running on `localhost:3001`.
183
+ """
184
+
185
+ url: str
186
+ """The URL of the SSE endpoint on the MCP server.
187
+
188
+ For example for a server running locally, this might be `http://localhost:3001/sse`.
189
+ """
190
+
191
+ @asynccontextmanager
192
+ async def client_streams(
193
+ self,
194
+ ) -> AsyncIterator[
195
+ tuple[MemoryObjectReceiveStream[JSONRPCMessage | Exception], MemoryObjectSendStream[JSONRPCMessage]]
196
+ ]: # pragma: no cover
197
+ async with sse_client(url=self.url) as (read_stream, write_stream):
198
+ yield read_stream, write_stream
@@ -26,6 +26,9 @@ class SystemPromptPart:
26
26
  content: str
27
27
  """The content of the prompt."""
28
28
 
29
+ timestamp: datetime = field(default_factory=_now_utc)
30
+ """The timestamp of the prompt."""
31
+
29
32
  dynamic_ref: str | None = None
30
33
  """The ref of the dynamic system prompt function that generated this part.
31
34
 
@@ -65,7 +65,7 @@ try:
65
65
  except ImportError as _import_error:
66
66
  raise ImportError(
67
67
  'Please install `anthropic` to use the Anthropic model, '
68
- "you can use the `anthropic` optional group — `pip install 'pydantic-ai-slim[anthropic]'`"
68
+ 'you can use the `anthropic` optional group — `pip install "pydantic-ai-slim[anthropic]"`'
69
69
  ) from _import_error
70
70
 
71
71
  LatestAnthropicModelNames = Literal[
@@ -50,7 +50,7 @@ try:
50
50
  except ImportError as _import_error:
51
51
  raise ImportError(
52
52
  'Please install `cohere` to use the Cohere model, '
53
- "you can use the `cohere` optional group — `pip install 'pydantic-ai-slim[cohere]'`"
53
+ 'you can use the `cohere` optional group — `pip install "pydantic-ai-slim[cohere]"`'
54
54
  ) from _import_error
55
55
 
56
56
  LatestCohereModelNames = Literal[
@@ -41,7 +41,7 @@ try:
41
41
  except ImportError as _import_error:
42
42
  raise ImportError(
43
43
  'Please install `groq` to use the Groq model, '
44
- "you can use the `groq` optional group — `pip install 'pydantic-ai-slim[groq]'`"
44
+ 'you can use the `groq` optional group — `pip install "pydantic-ai-slim[groq]"`'
45
45
  ) from _import_error
46
46
 
47
47
 
@@ -52,7 +52,9 @@ class InstrumentationSettings:
52
52
 
53
53
  - `Agent(instrument=...)`
54
54
  - [`Agent.instrument_all()`][pydantic_ai.agent.Agent.instrument_all]
55
- - `InstrumentedModel`
55
+ - [`InstrumentedModel`][pydantic_ai.models.instrumented.InstrumentedModel]
56
+
57
+ See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info.
56
58
  """
57
59
 
58
60
  tracer: Tracer = field(repr=False)
@@ -94,9 +96,13 @@ GEN_AI_REQUEST_MODEL_ATTRIBUTE = 'gen_ai.request.model'
94
96
 
95
97
  @dataclass
96
98
  class InstrumentedModel(WrapperModel):
97
- """Model which is instrumented with OpenTelemetry."""
99
+ """Model which wraps another model so that requests are instrumented with OpenTelemetry.
100
+
101
+ See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info.
102
+ """
98
103
 
99
- options: InstrumentationSettings
104
+ settings: InstrumentationSettings
105
+ """Configuration for instrumenting requests."""
100
106
 
101
107
  def __init__(
102
108
  self,
@@ -104,7 +110,7 @@ class InstrumentedModel(WrapperModel):
104
110
  options: InstrumentationSettings | None = None,
105
111
  ) -> None:
106
112
  super().__init__(wrapped)
107
- self.options = options or InstrumentationSettings()
113
+ self.settings = options or InstrumentationSettings()
108
114
 
109
115
  async def request(
110
116
  self,
@@ -156,7 +162,7 @@ class InstrumentedModel(WrapperModel):
156
162
  if isinstance(value := model_settings.get(key), (float, int)):
157
163
  attributes[f'gen_ai.request.{key}'] = value
158
164
 
159
- with self.options.tracer.start_as_current_span(span_name, attributes=attributes) as span:
165
+ with self.settings.tracer.start_as_current_span(span_name, attributes=attributes) as span:
160
166
 
161
167
  def finish(response: ModelResponse, usage: Usage):
162
168
  if not span.is_recording():
@@ -190,9 +196,9 @@ class InstrumentedModel(WrapperModel):
190
196
  yield finish
191
197
 
192
198
  def _emit_events(self, span: Span, events: list[Event]) -> None:
193
- if self.options.event_mode == 'logs':
199
+ if self.settings.event_mode == 'logs':
194
200
  for event in events:
195
- self.options.event_logger.emit(event)
201
+ self.settings.event_logger.emit(event)
196
202
  else:
197
203
  attr_name = 'events'
198
204
  span.set_attributes(
@@ -75,7 +75,7 @@ try:
75
75
  except ImportError as e:
76
76
  raise ImportError(
77
77
  'Please install `mistral` to use the Mistral model, '
78
- "you can use the `mistral` optional group — `pip install 'pydantic-ai-slim[mistral]'`"
78
+ 'you can use the `mistral` optional group — `pip install "pydantic-ai-slim[mistral]"`'
79
79
  ) from e
80
80
 
81
81
  LatestMistralModelNames = Literal[
@@ -57,7 +57,7 @@ try:
57
57
  except ImportError as _import_error:
58
58
  raise ImportError(
59
59
  'Please install `openai` to use the OpenAI model, '
60
- "you can use the `openai` optional group — `pip install 'pydantic-ai-slim[openai]'`"
60
+ 'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
61
61
  ) from _import_error
62
62
 
63
63
  OpenAIModelName = Union[str, ChatModel]
@@ -27,7 +27,7 @@ try:
27
27
  except ImportError as _import_error:
28
28
  raise ImportError(
29
29
  'Please install `google-auth` to use the VertexAI model, '
30
- "you can use the `vertexai` optional group — `pip install 'pydantic-ai-slim[vertexai]'`"
30
+ 'you can use the `vertexai` optional group — `pip install "pydantic-ai-slim[vertexai]"`'
31
31
  ) from _import_error
32
32
 
33
33
  VERTEX_AI_URL_TEMPLATE = (
@@ -13,9 +13,13 @@ from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, i
13
13
 
14
14
  @dataclass(init=False)
15
15
  class WrapperModel(Model):
16
- """Model which wraps another model."""
16
+ """Model which wraps another model.
17
+
18
+ Does nothing on its own, used as a base class.
19
+ """
17
20
 
18
21
  wrapped: Model
22
+ """The underlying model being wrapped."""
19
23
 
20
24
  def __init__(self, wrapped: Model | KnownModelName):
21
25
  self.wrapped = infer_model(wrapped)
@@ -12,7 +12,7 @@ try:
12
12
  except ImportError as _import_error: # pragma: no cover
13
13
  raise ImportError(
14
14
  'Please install the `anthropic` package to use the Anthropic provider, '
15
- "you can use the `anthropic` optional group — `pip install 'pydantic-ai-slim[anthropic]'`"
15
+ 'you can use the `anthropic` optional group — `pip install "pydantic-ai-slim[anthropic]"`'
16
16
  ) from _import_error
17
17
 
18
18
 
@@ -13,7 +13,7 @@ try:
13
13
  except ImportError as _import_error: # pragma: no cover
14
14
  raise ImportError(
15
15
  'Please install the `openai` package to use the Azure provider, '
16
- "you can use the `openai` optional group — `pip install 'pydantic-ai-slim[openai]'`"
16
+ 'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
17
17
  ) from _import_error
18
18
 
19
19
 
@@ -11,7 +11,7 @@ try:
11
11
  except ImportError as _import_error:
12
12
  raise ImportError(
13
13
  'Please install the `boto3` package to use the Bedrock provider, '
14
- "you can use the `bedrock` optional group — `pip install 'pydantic-ai-slim[bedrock]'`"
14
+ 'you can use the `bedrock` optional group — `pip install "pydantic-ai-slim[bedrock]"`'
15
15
  ) from _import_error
16
16
 
17
17
 
@@ -13,7 +13,7 @@ try:
13
13
  except ImportError as _import_error: # pragma: no cover
14
14
  raise ImportError(
15
15
  'Please install the `openai` package to use the DeepSeek provider, '
16
- "you can use the `openai` optional group — `pip install 'pydantic-ai-slim[openai]'`"
16
+ 'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
17
17
  ) from _import_error
18
18
 
19
19
  from . import Provider
@@ -2,7 +2,6 @@ from __future__ import annotations as _annotations
2
2
 
3
3
  import functools
4
4
  from collections.abc import AsyncGenerator, Mapping
5
- from datetime import datetime, timedelta
6
5
  from pathlib import Path
7
6
  from typing import Literal, overload
8
7
 
@@ -22,15 +21,12 @@ try:
22
21
  except ImportError as _import_error:
23
22
  raise ImportError(
24
23
  'Please install the `google-auth` package to use the Google Vertex AI provider, '
25
- "you can use the `vertexai` optional group — `pip install 'pydantic-ai-slim[vertexai]'`"
24
+ 'you can use the `vertexai` optional group — `pip install "pydantic-ai-slim[vertexai]"`'
26
25
  ) from _import_error
27
26
 
28
27
 
29
28
  __all__ = ('GoogleVertexProvider',)
30
29
 
31
- # default expiry is 3600 seconds
32
- MAX_TOKEN_AGE = timedelta(seconds=3000)
33
-
34
30
 
35
31
  class GoogleVertexProvider(Provider[httpx.AsyncClient]):
36
32
  """Provider for Vertex AI API."""
@@ -131,19 +127,21 @@ class _VertexAIAuth(httpx.Auth):
131
127
  self.region = region
132
128
 
133
129
  self.credentials = None
134
- self.token_created: datetime | None = None
135
130
 
136
131
  async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
137
132
  if self.credentials is None:
138
133
  self.credentials = await self._get_credentials()
139
- if self.credentials.token is None or self._token_expired(): # type: ignore[reportUnknownMemberType]
140
- await anyio.to_thread.run_sync(self._refresh_token)
141
- self.token_created = datetime.now()
134
+ if self.credentials.token is None: # type: ignore[reportUnknownMemberType]
135
+ await self._refresh_token()
142
136
  request.headers['Authorization'] = f'Bearer {self.credentials.token}' # type: ignore[reportUnknownMemberType]
143
-
144
137
  # NOTE: This workaround is in place because we might get the project_id from the credentials.
145
138
  request.url = httpx.URL(str(request.url).replace('projects/None', f'projects/{self.project_id}'))
146
- yield request
139
+ response = yield request
140
+
141
+ if response.status_code == 401:
142
+ await self._refresh_token()
143
+ request.headers['Authorization'] = f'Bearer {self.credentials.token}' # type: ignore[reportUnknownMemberType]
144
+ yield request
147
145
 
148
146
  async def _get_credentials(self) -> BaseCredentials | ServiceAccountCredentials:
149
147
  if self.service_account_file is not None:
@@ -166,15 +164,9 @@ class _VertexAIAuth(httpx.Auth):
166
164
  self.project_id = creds_project_id
167
165
  return creds
168
166
 
169
- def _token_expired(self) -> bool:
170
- if self.token_created is None:
171
- return True
172
- else:
173
- return (datetime.now() - self.token_created) > MAX_TOKEN_AGE
174
-
175
- def _refresh_token(self) -> str: # pragma: no cover
167
+ async def _refresh_token(self) -> str: # pragma: no cover
176
168
  assert self.credentials is not None
177
- self.credentials.refresh(Request()) # type: ignore[reportUnknownMemberType]
169
+ await anyio.to_thread.run_sync(self.credentials.refresh, Request()) # type: ignore[reportUnknownMemberType]
178
170
  assert isinstance(self.credentials.token, str), f'Expected token to be a string, got {self.credentials.token}' # type: ignore[reportUnknownMemberType]
179
171
  return self.credentials.token
180
172
 
@@ -12,7 +12,7 @@ try:
12
12
  except ImportError as _import_error: # pragma: no cover
13
13
  raise ImportError(
14
14
  'Please install the `groq` package to use the Groq provider, '
15
- "you can use the `groq` optional group — `pip install 'pydantic-ai-slim[groq]'`"
15
+ 'you can use the `groq` optional group — `pip install "pydantic-ai-slim[groq]"`'
16
16
  ) from _import_error
17
17
 
18
18
 
@@ -57,17 +57,21 @@ class GroqProvider(Provider[AsyncGroq]):
57
57
  client to use. If provided, `api_key` and `http_client` must be `None`.
58
58
  http_client: An existing `AsyncHTTPClient` to use for making HTTP requests.
59
59
  """
60
- api_key = api_key or os.environ.get('GROQ_API_KEY')
61
-
62
- if api_key is None and groq_client is None:
63
- raise ValueError(
64
- 'Set the `GROQ_API_KEY` environment variable or pass it via `GroqProvider(api_key=...)`'
65
- 'to use the Groq provider.'
66
- )
67
-
68
60
  if groq_client is not None:
61
+ assert http_client is None, 'Cannot provide both `groq_client` and `http_client`'
62
+ assert api_key is None, 'Cannot provide both `groq_client` and `api_key`'
69
63
  self._client = groq_client
70
- elif http_client is not None:
71
- self._client = AsyncGroq(base_url=self.base_url, api_key=api_key, http_client=http_client)
72
64
  else:
73
- self._client = AsyncGroq(base_url=self.base_url, api_key=api_key, http_client=cached_async_http_client())
65
+ api_key = api_key or os.environ.get('GROQ_API_KEY')
66
+
67
+ if api_key is None:
68
+ raise ValueError(
69
+ 'Set the `GROQ_API_KEY` environment variable or pass it via `GroqProvider(api_key=...)`'
70
+ 'to use the Groq provider.'
71
+ )
72
+ elif http_client is not None:
73
+ self._client = AsyncGroq(base_url=self.base_url, api_key=api_key, http_client=http_client)
74
+ else:
75
+ self._client = AsyncGroq(
76
+ base_url=self.base_url, api_key=api_key, http_client=cached_async_http_client()
77
+ )
@@ -12,7 +12,7 @@ try:
12
12
  except ImportError as e: # pragma: no cover
13
13
  raise ImportError(
14
14
  'Please install the `mistral` package to use the Mistral provider, '
15
- "you can use the `mistral` optional group — `pip install 'pydantic-ai-slim[mistral]'`"
15
+ 'you can use the `mistral` optional group — `pip install "pydantic-ai-slim[mistral]"`'
16
16
  ) from e
17
17
 
18
18
 
@@ -55,19 +55,19 @@ class MistralProvider(Provider[Mistral]):
55
55
  mistral_client: An existing `Mistral` client to use, if provided, `api_key` and `http_client` must be `None`.
56
56
  http_client: An existing async client to use for making HTTP requests.
57
57
  """
58
- api_key = api_key or os.environ.get('MISTRAL_API_KEY')
59
-
60
- if api_key is None and mistral_client is None:
61
- raise ValueError(
62
- 'Set the `MISTRAL_API_KEY` environment variable or pass it via `MistralProvider(api_key=...)`'
63
- 'to use the Mistral provider.'
64
- )
65
-
66
58
  if mistral_client is not None:
67
59
  assert http_client is None, 'Cannot provide both `mistral_client` and `http_client`'
68
60
  assert api_key is None, 'Cannot provide both `mistral_client` and `api_key`'
69
61
  self._client = mistral_client
70
- elif http_client is not None:
71
- self._client = Mistral(api_key=api_key, async_client=http_client)
72
62
  else:
73
- self._client = Mistral(api_key=api_key, async_client=cached_async_http_client())
63
+ api_key = api_key or os.environ.get('MISTRAL_API_KEY')
64
+
65
+ if api_key is None:
66
+ raise ValueError(
67
+ 'Set the `MISTRAL_API_KEY` environment variable or pass it via `MistralProvider(api_key=...)`'
68
+ 'to use the Mistral provider.'
69
+ )
70
+ elif http_client is not None:
71
+ self._client = Mistral(api_key=api_key, async_client=http_client)
72
+ else:
73
+ self._client = Mistral(api_key=api_key, async_client=cached_async_http_client())
@@ -11,7 +11,7 @@ try:
11
11
  except ImportError as _import_error: # pragma: no cover
12
12
  raise ImportError(
13
13
  'Please install the `openai` package to use the OpenAI provider, '
14
- "you can use the `openai` optional group — `pip install 'pydantic-ai-slim[openai]'`"
14
+ 'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`'
15
15
  ) from _import_error
16
16
 
17
17
 
@@ -7,7 +7,8 @@ from dataclasses import dataclass, field
7
7
  from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
8
8
 
9
9
  from pydantic import ValidationError
10
- from pydantic_core import SchemaValidator
10
+ from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
11
+ from pydantic_core import SchemaValidator, core_schema
11
12
  from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar
12
13
 
13
14
  from . import _pydantic, _utils, messages as _messages, models
@@ -142,6 +143,22 @@ DocstringFormat = Literal['google', 'numpy', 'sphinx', 'auto']
142
143
  A = TypeVar('A')
143
144
 
144
145
 
146
+ class GenerateToolJsonSchema(GenerateJsonSchema):
147
+ def typed_dict_schema(self, schema: core_schema.TypedDictSchema) -> JsonSchemaValue:
148
+ s = super().typed_dict_schema(schema)
149
+ total = schema.get('total')
150
+ if total is not None:
151
+ s['additionalProperties'] = not total
152
+ return s
153
+
154
+ def _named_required_fields_schema(self, named_required_fields: Sequence[tuple[str, bool, Any]]) -> JsonSchemaValue:
155
+ # Remove largely-useless property titles
156
+ s = super()._named_required_fields_schema(named_required_fields)
157
+ for p in s.get('properties', {}):
158
+ s['properties'][p].pop('title', None)
159
+ return s
160
+
161
+
145
162
  @dataclass(init=False)
146
163
  class Tool(Generic[AgentDepsT]):
147
164
  """A tool function for an agent."""
@@ -176,6 +193,7 @@ class Tool(Generic[AgentDepsT]):
176
193
  prepare: ToolPrepareFunc[AgentDepsT] | None = None,
177
194
  docstring_format: DocstringFormat = 'auto',
178
195
  require_parameter_descriptions: bool = False,
196
+ schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
179
197
  ):
180
198
  """Create a new tool instance.
181
199
 
@@ -225,11 +243,14 @@ class Tool(Generic[AgentDepsT]):
225
243
  docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
226
244
  Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
227
245
  require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
246
+ schema_generator: The JSON schema generator class to use. Defaults to `GenerateToolJsonSchema`.
228
247
  """
229
248
  if takes_ctx is None:
230
249
  takes_ctx = _pydantic.takes_ctx(function)
231
250
 
232
- f = _pydantic.function_schema(function, takes_ctx, docstring_format, require_parameter_descriptions)
251
+ f = _pydantic.function_schema(
252
+ function, takes_ctx, docstring_format, require_parameter_descriptions, schema_generator
253
+ )
233
254
  self.function = function
234
255
  self.takes_ctx = takes_ctx
235
256
  self.max_retries = max_retries
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "pydantic-ai-slim"
7
- version = "0.0.41"
7
+ version = "0.0.43"
8
8
  description = "Agent Framework / shim to use Pydantic with LLMs, slim package"
9
9
  authors = [{ name = "Samuel Colvin", email = "samuel@pydantic.dev" }]
10
10
  license = "MIT"
@@ -36,7 +36,7 @@ dependencies = [
36
36
  "griffe>=1.3.2",
37
37
  "httpx>=0.27",
38
38
  "pydantic>=2.10",
39
- "pydantic-graph==0.0.41",
39
+ "pydantic-graph==0.0.43",
40
40
  "exceptiongroup; python_version < '3.11'",
41
41
  "opentelemetry-api>=1.28.0",
42
42
  "typing-inspection>=0.4.0",
@@ -50,7 +50,7 @@ openai = ["openai>=1.65.1"]
50
50
  cohere = ["cohere>=5.13.11"]
51
51
  vertexai = ["google-auth>=2.36.0", "requests>=2.32.3"]
52
52
  anthropic = ["anthropic>=0.49.0"]
53
- groq = ["groq>=0.12.0"]
53
+ groq = ["groq>=0.15.0"]
54
54
  mistral = ["mistralai>=1.2.5"]
55
55
  bedrock = ["boto3>=1.34.116"]
56
56
  # Tools
@@ -58,6 +58,8 @@ duckduckgo = ["duckduckgo-search>=7.0.0"]
58
58
  tavily = ["tavily-python>=0.5.0"]
59
59
  # CLI
60
60
  cli = ["rich>=13", "prompt-toolkit>=3", "argcomplete>=3.5.0"]
61
+ # MCP
62
+ mcp = ["mcp>=1.4.1; python_version >= '3.10'"]
61
63
 
62
64
  [dependency-groups]
63
65
  dev = [
@@ -75,6 +77,9 @@ dev = [
75
77
  "boto3-stubs[bedrock-runtime]",
76
78
  ]
77
79
 
80
+ [tool.hatch.metadata]
81
+ allow-direct-references = true
82
+
78
83
  [project.scripts]
79
84
  pai = "pydantic_ai._cli:app"
80
85