pydantic-ai-slim 0.0.40.tar.gz → 0.0.42.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pydantic-ai-slim might be problematic; see the registry's advisory for this release for more details.

Files changed (50)
  1. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/.gitignore +1 -1
  2. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/PKG-INFO +5 -3
  3. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/_agent_graph.py +61 -7
  4. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/_pydantic.py +6 -5
  5. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/agent.py +53 -6
  6. pydantic_ai_slim-0.0.42/pydantic_ai/mcp.py +198 -0
  7. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/models/anthropic.py +31 -2
  8. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/models/fallback.py +13 -8
  9. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/models/gemini.py +3 -5
  10. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/models/groq.py +2 -3
  11. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/models/mistral.py +37 -5
  12. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/models/openai.py +2 -3
  13. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/providers/__init__.py +8 -0
  14. pydantic_ai_slim-0.0.42/pydantic_ai/providers/anthropic.py +74 -0
  15. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/providers/groq.py +15 -11
  16. pydantic_ai_slim-0.0.42/pydantic_ai/providers/mistral.py +73 -0
  17. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/tools.py +23 -2
  18. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pyproject.toml +8 -3
  19. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/README.md +0 -0
  20. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/__init__.py +0 -0
  21. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/_cli.py +0 -0
  22. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/_griffe.py +0 -0
  23. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/_parts_manager.py +0 -0
  24. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/_result.py +0 -0
  25. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/_system_prompt.py +0 -0
  26. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/_utils.py +0 -0
  27. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/common_tools/__init__.py +0 -0
  28. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/common_tools/duckduckgo.py +0 -0
  29. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/common_tools/tavily.py +0 -0
  30. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/exceptions.py +0 -0
  31. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/format_as_xml.py +0 -0
  32. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/messages.py +0 -0
  33. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/models/__init__.py +0 -0
  34. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/models/bedrock.py +0 -0
  35. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/models/cohere.py +0 -0
  36. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/models/function.py +0 -0
  37. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/models/instrumented.py +0 -0
  38. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/models/test.py +0 -0
  39. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/models/vertexai.py +0 -0
  40. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/models/wrapper.py +0 -0
  41. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/providers/azure.py +0 -0
  42. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/providers/bedrock.py +0 -0
  43. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/providers/deepseek.py +0 -0
  44. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/providers/google_gla.py +0 -0
  45. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/providers/google_vertex.py +0 -0
  46. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/providers/openai.py +0 -0
  47. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/py.typed +0 -0
  48. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/result.py +0 -0
  49. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/settings.py +0 -0
  50. {pydantic_ai_slim-0.0.40 → pydantic_ai_slim-0.0.42}/pydantic_ai/usage.py +0 -0
@@ -1,5 +1,4 @@
1
1
  site
2
- .python-version
3
2
  .venv
4
3
  dist
5
4
  __pycache__
@@ -16,3 +15,4 @@ examples/pydantic_ai_examples/.chat_app_messages.sqlite
16
15
  .vscode/
17
16
  /question_graph_history.json
18
17
  /docs-site/.wrangler/
18
+ /CLAUDE.md
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pydantic-ai-slim
3
- Version: 0.0.40
3
+ Version: 0.0.42
4
4
  Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
5
5
  Author-email: Samuel Colvin <samuel@pydantic.dev>
6
6
  License-Expression: MIT
@@ -29,7 +29,7 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
29
29
  Requires-Dist: griffe>=1.3.2
30
30
  Requires-Dist: httpx>=0.27
31
31
  Requires-Dist: opentelemetry-api>=1.28.0
32
- Requires-Dist: pydantic-graph==0.0.40
32
+ Requires-Dist: pydantic-graph==0.0.42
33
33
  Requires-Dist: pydantic>=2.10
34
34
  Requires-Dist: typing-inspection>=0.4.0
35
35
  Provides-Extra: anthropic
@@ -45,9 +45,11 @@ Requires-Dist: cohere>=5.13.11; extra == 'cohere'
45
45
  Provides-Extra: duckduckgo
46
46
  Requires-Dist: duckduckgo-search>=7.0.0; extra == 'duckduckgo'
47
47
  Provides-Extra: groq
48
- Requires-Dist: groq>=0.12.0; extra == 'groq'
48
+ Requires-Dist: groq>=0.15.0; extra == 'groq'
49
49
  Provides-Extra: logfire
50
50
  Requires-Dist: logfire>=2.3; extra == 'logfire'
51
+ Provides-Extra: mcp
52
+ Requires-Dist: mcp>=1.4.1; (python_version >= '3.10') and extra == 'mcp'
51
53
  Provides-Extra: mistral
52
54
  Requires-Dist: mistralai>=1.2.5; extra == 'mistral'
53
55
  Provides-Extra: openai
@@ -7,7 +7,7 @@ from collections.abc import AsyncIterator, Iterator, Sequence
7
7
  from contextlib import asynccontextmanager, contextmanager
8
8
  from contextvars import ContextVar
9
9
  from dataclasses import field
10
- from typing import Any, Generic, Literal, Union, cast
10
+ from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast
11
11
 
12
12
  from opentelemetry.trace import Span, Tracer
13
13
  from typing_extensions import TypeGuard, TypeVar, assert_never
@@ -27,11 +27,10 @@ from . import (
27
27
  from .models.instrumented import InstrumentedModel
28
28
  from .result import ResultDataT
29
29
  from .settings import ModelSettings, merge_model_settings
30
- from .tools import (
31
- RunContext,
32
- Tool,
33
- ToolDefinition,
34
- )
30
+ from .tools import RunContext, Tool, ToolDefinition
31
+
32
+ if TYPE_CHECKING:
33
+ from .mcp import MCPServer
35
34
 
36
35
  __all__ = (
37
36
  'GraphAgentState',
@@ -94,6 +93,7 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
94
93
  result_validators: list[_result.ResultValidator[DepsT, ResultDataT]]
95
94
 
96
95
  function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
96
+ mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
97
97
 
98
98
  run_span: Span
99
99
  tracer: Tracer
@@ -219,7 +219,17 @@ async def _prepare_request_parameters(
219
219
  if tool_def := await tool.prepare_tool_def(ctx):
220
220
  function_tool_defs.append(tool_def)
221
221
 
222
- await asyncio.gather(*map(add_tool, ctx.deps.function_tools.values()))
222
+ async def add_mcp_server_tools(server: MCPServer) -> None:
223
+ if not server.is_running:
224
+ raise exceptions.UserError(f'MCP server is not running: {server}')
225
+ tool_defs = await server.list_tools()
226
+ # TODO(Marcelo): We should check if the tool names are unique. If not, we should raise an error.
227
+ function_tool_defs.extend(tool_defs)
228
+
229
+ await asyncio.gather(
230
+ *map(add_tool, ctx.deps.function_tools.values()),
231
+ *map(add_mcp_server_tools, ctx.deps.mcp_servers),
232
+ )
223
233
 
224
234
  result_schema = ctx.deps.result_schema
225
235
  return models.ModelRequestParameters(
@@ -594,6 +604,21 @@ async def process_function_tools(
594
604
  yield event
595
605
  call_index_to_event_id[len(calls_to_run)] = event.call_id
596
606
  calls_to_run.append((tool, call))
607
+ elif mcp_tool := await _tool_from_mcp_server(call.tool_name, ctx):
608
+ if stub_function_tools:
609
+ # TODO(Marcelo): We should add coverage for this part of the code.
610
+ output_parts.append( # pragma: no cover
611
+ _messages.ToolReturnPart(
612
+ tool_name=call.tool_name,
613
+ content='Tool not executed - a final result was already processed.',
614
+ tool_call_id=call.tool_call_id,
615
+ )
616
+ )
617
+ else:
618
+ event = _messages.FunctionToolCallEvent(call)
619
+ yield event
620
+ call_index_to_event_id[len(calls_to_run)] = event.call_id
621
+ calls_to_run.append((mcp_tool, call))
597
622
  elif result_schema is not None and call.tool_name in result_schema.tools:
598
623
  # if tool_name is in _result_schema, it means we found a result tool but an error occurred in
599
624
  # validation, we don't add another part here
@@ -641,6 +666,35 @@ async def process_function_tools(
641
666
  output_parts.append(results_by_index[k])
642
667
 
643
668
 
669
+ async def _tool_from_mcp_server(
670
+ tool_name: str,
671
+ ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
672
+ ) -> Tool[DepsT] | None:
673
+ """Call each MCP server to find the tool with the given name.
674
+
675
+ Args:
676
+ tool_name: The name of the tool to find.
677
+ ctx: The current run context.
678
+
679
+ Returns:
680
+ The tool with the given name, or `None` if no tool with the given name is found.
681
+ """
682
+
683
+ async def run_tool(ctx: RunContext[DepsT], **args: Any) -> Any:
684
+ # There's no normal situation where the server will not be running at this point, we check just in case
685
+ # some weird edge case occurs.
686
+ if not server.is_running: # pragma: no cover
687
+ raise exceptions.UserError(f'MCP server is not running: {server}')
688
+ result = await server.call_tool(tool_name, args)
689
+ return result
690
+
691
+ for server in ctx.deps.mcp_servers:
692
+ tools = await server.list_tools()
693
+ if tool_name in {tool.name for tool in tools}:
694
+ return Tool(name=tool_name, function=run_tool, takes_ctx=True)
695
+ return None
696
+
697
+
644
698
  def _unknown_tool(
645
699
  tool_name: str,
646
700
  ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
@@ -44,6 +44,7 @@ def function_schema( # noqa: C901
44
44
  takes_ctx: bool,
45
45
  docstring_format: DocstringFormat,
46
46
  require_parameter_descriptions: bool,
47
+ schema_generator: type[GenerateJsonSchema],
47
48
  ) -> FunctionSchema:
48
49
  """Build a Pydantic validator and JSON schema from a tool function.
49
50
 
@@ -52,6 +53,7 @@ def function_schema( # noqa: C901
52
53
  takes_ctx: Whether the function takes a `RunContext` first argument.
53
54
  docstring_format: The docstring format to use.
54
55
  require_parameter_descriptions: Whether to require descriptions for all tool function parameters.
56
+ schema_generator: The JSON schema generator class to use.
55
57
 
56
58
  Returns:
57
59
  A `FunctionSchema` instance.
@@ -150,14 +152,12 @@ def function_schema( # noqa: C901
150
152
  )
151
153
  # PluggableSchemaValidator is api compatible with SchemaValidator
152
154
  schema_validator = cast(SchemaValidator, schema_validator)
153
- json_schema = GenerateJsonSchema().generate(schema)
155
+ json_schema = schema_generator().generate(schema)
154
156
 
155
157
  # workaround for https://github.com/pydantic/pydantic/issues/10785
156
- # if we build a custom TypeDict schema (matches when `single_arg_name is None`), we manually set
158
+ # if we build a custom TypedDict schema (matches when `single_arg_name is None`), we manually set
157
159
  # `additionalProperties` in the JSON Schema
158
- if single_arg_name is None:
159
- json_schema['additionalProperties'] = bool(var_kwargs_schema)
160
- elif not description:
160
+ if single_arg_name is not None and not description:
161
161
  # if the tool description is not set, and we have a single parameter, take the description from that
162
162
  # and set it on the tool
163
163
  description = json_schema.pop('description', None)
@@ -218,6 +218,7 @@ def _build_schema(
218
218
  td_schema = core_schema.typed_dict_schema(
219
219
  fields,
220
220
  config=core_config,
221
+ total=var_kwargs_schema is None,
221
222
  extras_schema=gen_schema.generate_schema(var_kwargs_schema) if var_kwargs_schema else None,
222
223
  )
223
224
  return td_schema, None
@@ -3,12 +3,13 @@ from __future__ import annotations as _annotations
3
3
  import dataclasses
4
4
  import inspect
5
5
  from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
6
- from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
6
+ from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
7
7
  from copy import deepcopy
8
8
  from types import FrameType
9
- from typing import Any, Callable, ClassVar, Generic, cast, final, overload
9
+ from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, final, overload
10
10
 
11
11
  from opentelemetry.trace import NoOpTracer, use_span
12
+ from pydantic.json_schema import GenerateJsonSchema
12
13
  from typing_extensions import TypeGuard, TypeVar, deprecated
13
14
 
14
15
  from pydantic_graph import End, Graph, GraphRun, GraphRunContext
@@ -31,6 +32,7 @@ from .settings import ModelSettings, merge_model_settings
31
32
  from .tools import (
32
33
  AgentDepsT,
33
34
  DocstringFormat,
35
+ GenerateToolJsonSchema,
34
36
  RunContext,
35
37
  Tool,
36
38
  ToolFuncContext,
@@ -47,6 +49,9 @@ CallToolsNode = _agent_graph.CallToolsNode
47
49
  ModelRequestNode = _agent_graph.ModelRequestNode
48
50
  UserPromptNode = _agent_graph.UserPromptNode
49
51
 
52
+ if TYPE_CHECKING:
53
+ from pydantic_ai.mcp import MCPServer
54
+
50
55
  __all__ = (
51
56
  'Agent',
52
57
  'AgentRun',
@@ -129,6 +134,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
129
134
  repr=False
130
135
  )
131
136
  _function_tools: dict[str, Tool[AgentDepsT]] = dataclasses.field(repr=False)
137
+ _mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
132
138
  _default_retries: int = dataclasses.field(repr=False)
133
139
  _max_result_retries: int = dataclasses.field(repr=False)
134
140
  _override_deps: _utils.Option[AgentDepsT] = dataclasses.field(default=None, repr=False)
@@ -148,6 +154,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
148
154
  result_tool_description: str | None = None,
149
155
  result_retries: int | None = None,
150
156
  tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
157
+ mcp_servers: Sequence[MCPServer] = (),
151
158
  defer_model_check: bool = False,
152
159
  end_strategy: EndStrategy = 'early',
153
160
  instrument: InstrumentationSettings | bool | None = None,
@@ -173,6 +180,8 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
173
180
  result_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
174
181
  tools: Tools to register with the agent, you can also register tools via the decorators
175
182
  [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
183
+ mcp_servers: MCP servers to register with the agent. You should register a [`MCPServer`][pydantic_ai.mcp.MCPServer]
184
+ for each server you want the agent to connect to.
176
185
  defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model,
177
186
  it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately,
178
187
  which checks for the necessary environment variables. Set this to `false`
@@ -215,6 +224,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
215
224
 
216
225
  self._default_retries = retries
217
226
  self._max_result_retries = result_retries if result_retries is not None else retries
227
+ self._mcp_servers = mcp_servers
218
228
  for tool in tools:
219
229
  if isinstance(tool, Tool):
220
230
  self._register_tool(tool)
@@ -461,6 +471,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
461
471
  result_tools=self._result_schema.tool_defs() if self._result_schema else [],
462
472
  result_validators=result_validators,
463
473
  function_tools=self._function_tools,
474
+ mcp_servers=self._mcp_servers,
464
475
  run_span=run_span,
465
476
  tracer=tracer,
466
477
  )
@@ -927,6 +938,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
927
938
  prepare: ToolPrepareFunc[AgentDepsT] | None = None,
928
939
  docstring_format: DocstringFormat = 'auto',
929
940
  require_parameter_descriptions: bool = False,
941
+ schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
930
942
  ) -> Callable[[ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams]]: ...
931
943
 
932
944
  def tool(
@@ -939,6 +951,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
939
951
  prepare: ToolPrepareFunc[AgentDepsT] | None = None,
940
952
  docstring_format: DocstringFormat = 'auto',
941
953
  require_parameter_descriptions: bool = False,
954
+ schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
942
955
  ) -> Any:
943
956
  """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
944
957
 
@@ -980,6 +993,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
980
993
  docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
981
994
  Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
982
995
  require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
996
+ schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
983
997
  """
984
998
  if func is None:
985
999
 
@@ -988,7 +1002,14 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
988
1002
  ) -> ToolFuncContext[AgentDepsT, ToolParams]:
989
1003
  # noinspection PyTypeChecker
990
1004
  self._register_function(
991
- func_, True, name, retries, prepare, docstring_format, require_parameter_descriptions
1005
+ func_,
1006
+ True,
1007
+ name,
1008
+ retries,
1009
+ prepare,
1010
+ docstring_format,
1011
+ require_parameter_descriptions,
1012
+ schema_generator,
992
1013
  )
993
1014
  return func_
994
1015
 
@@ -996,7 +1017,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
996
1017
  else:
997
1018
  # noinspection PyTypeChecker
998
1019
  self._register_function(
999
- func, True, name, retries, prepare, docstring_format, require_parameter_descriptions
1020
+ func, True, name, retries, prepare, docstring_format, require_parameter_descriptions, schema_generator
1000
1021
  )
1001
1022
  return func
1002
1023
 
@@ -1013,6 +1034,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1013
1034
  prepare: ToolPrepareFunc[AgentDepsT] | None = None,
1014
1035
  docstring_format: DocstringFormat = 'auto',
1015
1036
  require_parameter_descriptions: bool = False,
1037
+ schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
1016
1038
  ) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...
1017
1039
 
1018
1040
  def tool_plain(
@@ -1025,6 +1047,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1025
1047
  prepare: ToolPrepareFunc[AgentDepsT] | None = None,
1026
1048
  docstring_format: DocstringFormat = 'auto',
1027
1049
  require_parameter_descriptions: bool = False,
1050
+ schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
1028
1051
  ) -> Any:
1029
1052
  """Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
1030
1053
 
@@ -1066,20 +1089,28 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1066
1089
  docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
1067
1090
  Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
1068
1091
  require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
1092
+ schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
1069
1093
  """
1070
1094
  if func is None:
1071
1095
 
1072
1096
  def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
1073
1097
  # noinspection PyTypeChecker
1074
1098
  self._register_function(
1075
- func_, False, name, retries, prepare, docstring_format, require_parameter_descriptions
1099
+ func_,
1100
+ False,
1101
+ name,
1102
+ retries,
1103
+ prepare,
1104
+ docstring_format,
1105
+ require_parameter_descriptions,
1106
+ schema_generator,
1076
1107
  )
1077
1108
  return func_
1078
1109
 
1079
1110
  return tool_decorator
1080
1111
  else:
1081
1112
  self._register_function(
1082
- func, False, name, retries, prepare, docstring_format, require_parameter_descriptions
1113
+ func, False, name, retries, prepare, docstring_format, require_parameter_descriptions, schema_generator
1083
1114
  )
1084
1115
  return func
1085
1116
 
@@ -1092,6 +1123,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1092
1123
  prepare: ToolPrepareFunc[AgentDepsT] | None,
1093
1124
  docstring_format: DocstringFormat,
1094
1125
  require_parameter_descriptions: bool,
1126
+ schema_generator: type[GenerateJsonSchema],
1095
1127
  ) -> None:
1096
1128
  """Private utility to register a function as a tool."""
1097
1129
  retries_ = retries if retries is not None else self._default_retries
@@ -1103,6 +1135,7 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1103
1135
  prepare=prepare,
1104
1136
  docstring_format=docstring_format,
1105
1137
  require_parameter_descriptions=require_parameter_descriptions,
1138
+ schema_generator=schema_generator,
1106
1139
  )
1107
1140
  self._register_tool(tool)
1108
1141
 
@@ -1253,6 +1286,20 @@ class Agent(Generic[AgentDepsT, ResultDataT]):
1253
1286
  """
1254
1287
  return isinstance(node, End)
1255
1288
 
1289
+ @asynccontextmanager
1290
+ async def run_mcp_servers(self) -> AsyncIterator[None]:
1291
+ """Run [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] so they can be used by the agent.
1292
+
1293
+ Returns: a context manager to start and shutdown the servers.
1294
+ """
1295
+ exit_stack = AsyncExitStack()
1296
+ try:
1297
+ for mcp_server in self._mcp_servers:
1298
+ await exit_stack.enter_async_context(mcp_server)
1299
+ yield
1300
+ finally:
1301
+ await exit_stack.aclose()
1302
+
1256
1303
 
1257
1304
  @dataclasses.dataclass(repr=False)
1258
1305
  class AgentRun(Generic[AgentDepsT, ResultDataT]):
@@ -0,0 +1,198 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from collections.abc import AsyncIterator, Sequence
5
+ from contextlib import AsyncExitStack, asynccontextmanager
6
+ from dataclasses import dataclass
7
+ from types import TracebackType
8
+ from typing import Any
9
+
10
+ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
11
+ from mcp.types import JSONRPCMessage
12
+ from typing_extensions import Self
13
+
14
+ from pydantic_ai.tools import ToolDefinition
15
+
16
+ try:
17
+ from mcp.client.session import ClientSession
18
+ from mcp.client.sse import sse_client
19
+ from mcp.client.stdio import StdioServerParameters, stdio_client
20
+ from mcp.types import CallToolResult
21
+ except ImportError as _import_error:
22
+ raise ImportError(
23
+ 'Please install the `mcp` package to use the MCP server, '
24
+ "you can use the `mcp` optional group — `pip install 'pydantic-ai-slim[mcp]'`"
25
+ ) from _import_error
26
+
27
+ __all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP'
28
+
29
+
30
+ class MCPServer(ABC):
31
+ """Base class for attaching agents to MCP servers.
32
+
33
+ See <https://modelcontextprotocol.io> for more information.
34
+ """
35
+
36
+ is_running: bool = False
37
+
38
+ _client: ClientSession
39
+ _read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception]
40
+ _write_stream: MemoryObjectSendStream[JSONRPCMessage]
41
+ _exit_stack: AsyncExitStack
42
+
43
+ @abstractmethod
44
+ @asynccontextmanager
45
+ async def client_streams(
46
+ self,
47
+ ) -> AsyncIterator[
48
+ tuple[MemoryObjectReceiveStream[JSONRPCMessage | Exception], MemoryObjectSendStream[JSONRPCMessage]]
49
+ ]:
50
+ """Create the streams for the MCP server."""
51
+ raise NotImplementedError('MCP Server subclasses must implement this method.')
52
+ yield
53
+
54
+ async def list_tools(self) -> list[ToolDefinition]:
55
+ """Retrieve tools that are currently active on the server.
56
+
57
+ Note:
58
+ - We don't cache tools as they might change.
59
+ - We also don't subscribe to the server to avoid complexity.
60
+ """
61
+ tools = await self._client.list_tools()
62
+ return [
63
+ ToolDefinition(
64
+ name=tool.name,
65
+ description=tool.description or '',
66
+ parameters_json_schema=tool.inputSchema,
67
+ )
68
+ for tool in tools.tools
69
+ ]
70
+
71
+ async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> CallToolResult:
72
+ """Call a tool on the server.
73
+
74
+ Args:
75
+ tool_name: The name of the tool to call.
76
+ arguments: The arguments to pass to the tool.
77
+
78
+ Returns:
79
+ The result of the tool call.
80
+ """
81
+ return await self._client.call_tool(tool_name, arguments)
82
+
83
+ async def __aenter__(self) -> Self:
84
+ self._exit_stack = AsyncExitStack()
85
+
86
+ self._read_stream, self._write_stream = await self._exit_stack.enter_async_context(self.client_streams())
87
+ client = ClientSession(read_stream=self._read_stream, write_stream=self._write_stream)
88
+ self._client = await self._exit_stack.enter_async_context(client)
89
+
90
+ await self._client.initialize()
91
+ self.is_running = True
92
+ return self
93
+
94
+ async def __aexit__(
95
+ self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None
96
+ ) -> bool | None:
97
+ await self._exit_stack.aclose()
98
+ self.is_running = False
99
+
100
+
101
+ @dataclass
102
+ class MCPServerStdio(MCPServer):
103
+ """Runs an MCP server in a subprocess and communicates with it over stdin/stdout.
104
+
105
+ This class implements the stdio transport from the MCP specification.
106
+ See <https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio> for more information.
107
+
108
+ !!! note
109
+ Using this class as an async context manager will start the server as a subprocess when entering the context,
110
+ and stop it when exiting the context.
111
+
112
+ Example:
113
+ ```python {py="3.10"}
114
+ from pydantic_ai import Agent
115
+ from pydantic_ai.mcp import MCPServerStdio
116
+
117
+ server = MCPServerStdio('npx', ['-y', '@pydantic/mcp-run-python', 'stdio']) # (1)!
118
+ agent = Agent('openai:gpt-4o', mcp_servers=[server])
119
+
120
+ async def main():
121
+ async with agent.run_mcp_servers(): # (2)!
122
+ ...
123
+ ```
124
+
125
+ 1. See [MCP Run Python](../mcp/run-python.md) for more information.
126
+ 2. This will start the server as a subprocess and connect to it.
127
+ """
128
+
129
+ command: str
130
+ """The command to run."""
131
+
132
+ args: Sequence[str]
133
+ """The arguments to pass to the command."""
134
+
135
+ env: dict[str, str] | None = None
136
+ """The environment variables the CLI server will have access to.
137
+
138
+ By default the subprocess will not inherit any environment variables from the parent process.
139
+ If you want to inherit the environment variables from the parent process, use `env=os.environ`.
140
+ """
141
+
142
+ @asynccontextmanager
143
+ async def client_streams(
144
+ self,
145
+ ) -> AsyncIterator[
146
+ tuple[MemoryObjectReceiveStream[JSONRPCMessage | Exception], MemoryObjectSendStream[JSONRPCMessage]]
147
+ ]:
148
+ server = StdioServerParameters(command=self.command, args=list(self.args), env=self.env)
149
+ async with stdio_client(server=server) as (read_stream, write_stream):
150
+ yield read_stream, write_stream
151
+
152
+
153
+ @dataclass
154
+ class MCPServerHTTP(MCPServer):
155
+ """An MCP server that connects over streamable HTTP connections.
156
+
157
+ This class implements the SSE transport from the MCP specification.
158
+ See <https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse> for more information.
159
+
160
+ The name "HTTP" is used since this implemented will be adapted in future to use the new
161
+ [Streamable HTTP](https://github.com/modelcontextprotocol/specification/pull/206) currently in development.
162
+
163
+ !!! note
164
+ Using this class as an async context manager will create a new pool of HTTP connections to connect
165
+ to a server which should already be running.
166
+
167
+ Example:
168
+ ```python {py="3.10"}
169
+ from pydantic_ai import Agent
170
+ from pydantic_ai.mcp import MCPServerHTTP
171
+
172
+ server = MCPServerHTTP('http://localhost:3001/sse') # (1)!
173
+ agent = Agent('openai:gpt-4o', mcp_servers=[server])
174
+
175
+ async def main():
176
+ async with agent.run_mcp_servers(): # (2)!
177
+ ...
178
+ ```
179
+
180
+ 1. E.g. you might be connecting to a server run with `npx @pydantic/mcp-run-python sse`,
181
+ see [MCP Run Python](../mcp/run-python.md) for more information.
182
+ 2. This will connect to a server running on `localhost:3001`.
183
+ """
184
+
185
+ url: str
186
+ """The URL of the SSE endpoint on the MCP server.
187
+
188
+ For example for a server running locally, this might be `http://localhost:3001/sse`.
189
+ """
190
+
191
+ @asynccontextmanager
192
+ async def client_streams(
193
+ self,
194
+ ) -> AsyncIterator[
195
+ tuple[MemoryObjectReceiveStream[JSONRPCMessage | Exception], MemoryObjectSendStream[JSONRPCMessage]]
196
+ ]: # pragma: no cover
197
+ async with sse_client(url=self.url) as (read_stream, write_stream):
198
+ yield read_stream, write_stream
@@ -11,7 +11,7 @@ from typing import Any, Literal, Union, cast, overload
11
11
 
12
12
  from anthropic.types import DocumentBlockParam
13
13
  from httpx import AsyncClient as AsyncHTTPClient
14
- from typing_extensions import assert_never
14
+ from typing_extensions import assert_never, deprecated
15
15
 
16
16
  from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
17
17
  from .._utils import guard_tool_call_id as _guard_tool_call_id
@@ -31,6 +31,7 @@ from ..messages import (
31
31
  ToolReturnPart,
32
32
  UserPromptPart,
33
33
  )
34
+ from ..providers import Provider, infer_provider
34
35
  from ..settings import ModelSettings
35
36
  from ..tools import ToolDefinition
36
37
  from . import Model, ModelRequestParameters, StreamedResponse, cached_async_http_client, check_allow_model_requests
@@ -111,10 +112,31 @@ class AnthropicModel(Model):
111
112
  _model_name: AnthropicModelName = field(repr=False)
112
113
  _system: str = field(default='anthropic', repr=False)
113
114
 
115
# Preferred signature: configure the client through a `Provider`, or the `'anthropic'` shorthand.
@overload
def __init__(
    self,
    model_name: AnthropicModelName,
    *,
    provider: Literal['anthropic'] | Provider[AsyncAnthropic] = 'anthropic',
) -> None: ...

# Legacy signature, kept for backwards compatibility.
# `@deprecated` must sit *below* `@overload` so the deprecation is attached to this overload
# (and visible via `typing.get_overloads()`), matching the ordering used by `MistralModel`.
@overload
@deprecated('Use the `provider` parameter instead of `api_key`, `anthropic_client`, and `http_client`.')
def __init__(
    self,
    model_name: AnthropicModelName,
    *,
    provider: None = None,
    api_key: str | None = None,
    anthropic_client: AsyncAnthropic | None = None,
    http_client: AsyncHTTPClient | None = None,
) -> None: ...
134
+
114
135
  def __init__(
115
136
  self,
116
137
  model_name: AnthropicModelName,
117
138
  *,
139
+ provider: Literal['anthropic'] | Provider[AsyncAnthropic] | None = None,
118
140
  api_key: str | None = None,
119
141
  anthropic_client: AsyncAnthropic | None = None,
120
142
  http_client: AsyncHTTPClient | None = None,
@@ -124,6 +146,8 @@ class AnthropicModel(Model):
124
146
  Args:
125
147
  model_name: The name of the Anthropic model to use. List of model names available
126
148
  [here](https://docs.anthropic.com/en/docs/about-claude/models).
149
+ provider: The provider to use for the Anthropic API. Can be either the string 'anthropic' or an
150
+ instance of `Provider[AsyncAnthropic]`. If not provided, the other parameters will be used.
127
151
  api_key: The API key to use for authentication, if not provided, the `ANTHROPIC_API_KEY` environment variable
128
152
  will be used if available.
129
153
  anthropic_client: An existing
@@ -132,7 +156,12 @@ class AnthropicModel(Model):
132
156
  http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
133
157
  """
134
158
  self._model_name = model_name
135
- if anthropic_client is not None:
159
+
160
+ if provider is not None:
161
+ if isinstance(provider, str):
162
+ provider = infer_provider(provider)
163
+ self.client = provider.client
164
+ elif anthropic_client is not None:
136
165
  assert http_client is None, 'Cannot provide both `anthropic_client` and `http_client`'
137
166
  assert api_key is None, 'Cannot provide both `anthropic_client` and `api_key`'
138
167
  self.client = anthropic_client
@@ -70,14 +70,9 @@ class FallbackModel(Model):
70
70
  exceptions.append(exc)
71
71
  continue
72
72
  raise exc
73
- else:
74
- with suppress(Exception):
75
- span = get_current_span()
76
- if span.is_recording():
77
- attributes = getattr(span, 'attributes', {})
78
- if attributes.get('gen_ai.request.model') == self.model_name:
79
- span.set_attributes(InstrumentedModel.model_attributes(model))
80
- return response, usage
73
+
74
+ self._set_span_attributes(model)
75
+ return response, usage
81
76
 
82
77
  raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)
83
78
 
@@ -102,11 +97,21 @@ class FallbackModel(Model):
102
97
  exceptions.append(exc)
103
98
  continue
104
99
  raise exc
100
+
101
+ self._set_span_attributes(model)
105
102
  yield response
106
103
  return
107
104
 
108
105
  raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)
109
106
 
107
def _set_span_attributes(self, model: Model):
    """Copy `model`'s attributes onto the current span if that span belongs to this fallback model.

    This runs after `model` has produced a response. It is best-effort: any exception raised
    while reading or updating the span is suppressed.
    """
    with suppress(Exception):
        current_span = get_current_span()
        if not current_span.is_recording():
            return
        span_attributes = getattr(current_span, 'attributes', {})
        # Only touch spans opened for this FallbackModel, identified by its model name.
        if span_attributes.get('gen_ai.request.model') != self.model_name:
            return
        current_span.set_attributes(InstrumentedModel.model_attributes(model))
114
+
110
115
  @property
111
116
  def model_name(self) -> str:
112
117
  """The model name."""
@@ -139,11 +139,9 @@ class GeminiModel(Model):
139
139
 
140
140
  if provider is not None:
141
141
  if isinstance(provider, str):
142
- self._system = provider
143
- self.client = infer_provider(provider).client
144
- else:
145
- self._system = provider.name
146
- self.client = provider.client
142
+ provider = infer_provider(provider)
143
+ self._system = provider.name
144
+ self.client = provider.client
147
145
  self._url = str(self.client.base_url)
148
146
  else:
149
147
  if api_key is None:
@@ -138,9 +138,8 @@ class GroqModel(Model):
138
138
 
139
139
  if provider is not None:
140
140
  if isinstance(provider, str):
141
- self.client = infer_provider(provider).client
142
- else:
143
- self.client = provider.client
141
+ provider = infer_provider(provider)
142
+ self.client = provider.client
144
143
  elif groq_client is not None:
145
144
  assert http_client is None, 'Cannot provide both `groq_client` and `http_client`'
146
145
  assert api_key is None, 'Cannot provide both `groq_client` and `api_key`'
@@ -7,11 +7,11 @@ from contextlib import asynccontextmanager
7
7
  from dataclasses import dataclass, field
8
8
  from datetime import datetime, timezone
9
9
  from itertools import chain
10
- from typing import Any, Callable, Literal, Union, cast
10
+ from typing import Any, Callable, Literal, Union, cast, overload
11
11
 
12
12
  import pydantic_core
13
13
  from httpx import AsyncClient as AsyncHTTPClient, Timeout
14
- from typing_extensions import assert_never
14
+ from typing_extensions import assert_never, deprecated
15
15
 
16
16
  from .. import ModelHTTPError, UnexpectedModelBehavior, _utils
17
17
  from .._utils import now_utc as _now_utc
@@ -31,6 +31,7 @@ from ..messages import (
31
31
  ToolReturnPart,
32
32
  UserPromptPart,
33
33
  )
34
+ from ..providers import Provider, infer_provider
34
35
  from ..result import Usage
35
36
  from ..settings import ModelSettings
36
37
  from ..tools import ToolDefinition
@@ -112,10 +113,33 @@ class MistralModel(Model):
112
113
  _model_name: MistralModelName = field(repr=False)
113
114
  _system: str = field(default='mistral_ai', repr=False)
114
115
 
116
# Preferred signature: configure the client through a `Provider`, or the `'mistral'` shorthand.
@overload
def __init__(
    self,
    model_name: MistralModelName,
    *,
    provider: Literal['mistral'] | Provider[Mistral] = 'mistral',
    json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n""",
) -> None: ...

# Legacy signature that builds the client from `api_key` / `client` / `http_client`.
# Flagged as deprecated in favour of the `provider` parameter.
@overload
@deprecated('Use the `provider` parameter instead of `api_key`, `client` and `http_client`.')
def __init__(
    self,
    model_name: MistralModelName,
    *,
    provider: None = None,
    api_key: str | Callable[[], str | None] | None = None,
    client: Mistral | None = None,
    http_client: AsyncHTTPClient | None = None,
    json_mode_schema_prompt: str = """Answer in JSON Object, respect the format:\n```\n{schema}\n```\n""",
) -> None: ...
137
+
138
+ def __init__(
139
+ self,
140
+ model_name: MistralModelName,
141
+ *,
142
+ provider: Literal['mistral'] | Provider[Mistral] | None = None,
119
143
  api_key: str | Callable[[], str | None] | None = None,
120
144
  client: Mistral | None = None,
121
145
  http_client: AsyncHTTPClient | None = None,
@@ -124,6 +148,9 @@ class MistralModel(Model):
124
148
  """Initialize a Mistral model.
125
149
 
126
150
  Args:
151
+ provider: The provider to use for authentication and API access. Can be either the string
152
+ 'mistral' or an instance of `Provider[Mistral]`. If not provided, a new provider will be
153
+ created using the other parameters.
127
154
  model_name: The name of the model to use.
128
155
  api_key: The API key to use for authentication, if unset uses `MISTRAL_API_KEY` environment variable.
129
156
  client: An existing `Mistral` client to use, if provided, `api_key` and `http_client` must be `None`.
@@ -133,17 +160,22 @@ class MistralModel(Model):
133
160
  self._model_name = model_name
134
161
  self.json_mode_schema_prompt = json_mode_schema_prompt
135
162
 
136
- if client is not None:
163
+ if provider is not None:
164
+ if isinstance(provider, str):
165
+ # TODO(Marcelo): We should add an integration test with VCR when I get the API key.
166
+ provider = infer_provider(provider) # pragma: no cover
167
+ self.client = provider.client
168
+ elif client is not None:
137
169
  assert http_client is None, 'Cannot provide both `mistral_client` and `http_client`'
138
170
  assert api_key is None, 'Cannot provide both `mistral_client` and `api_key`'
139
171
  self.client = client
140
172
  else:
141
- api_key = os.getenv('MISTRAL_API_KEY') if api_key is None else api_key
173
+ api_key = api_key or os.getenv('MISTRAL_API_KEY')
142
174
  self.client = Mistral(api_key=api_key, async_client=http_client or cached_async_http_client())
143
175
 
144
176
@property
def base_url(self) -> str:
    """The server URL reported by the Mistral client's SDK configuration."""
    server_details = self.client.sdk_configuration.get_server_details()
    return server_details[0]
147
179
 
148
180
  async def request(
149
181
  self,
@@ -162,9 +162,8 @@ class OpenAIModel(Model):
162
162
 
163
163
  if provider is not None:
164
164
  if isinstance(provider, str):
165
- self.client = infer_provider(provider).client
166
- else:
167
- self.client = provider.client
165
+ provider = infer_provider(provider)
166
+ self.client = provider.client
168
167
  else: # pragma: no cover
169
168
  # This is a workaround for the OpenAI client requiring an API key, whilst locally served,
170
169
  # openai compatible models do not always need an API key, but a placeholder (non-empty) key is required.
@@ -69,5 +69,13 @@ def infer_provider(provider: str) -> Provider[Any]:
69
69
  from .groq import GroqProvider
70
70
 
71
71
  return GroqProvider()
72
+ elif provider == 'anthropic':
73
+ from .anthropic import AnthropicProvider
74
+
75
+ return AnthropicProvider()
76
+ elif provider == 'mistral':
77
+ from .mistral import MistralProvider
78
+
79
+ return MistralProvider()
72
80
  else: # pragma: no cover
73
81
  raise ValueError(f'Unknown provider: {provider}')
@@ -0,0 +1,74 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ import os
4
+ from typing import overload
5
+
6
+ import httpx
7
+
8
+ from pydantic_ai.models import cached_async_http_client
9
+
10
+ try:
11
+ from anthropic import AsyncAnthropic
12
+ except ImportError as _import_error: # pragma: no cover
13
+ raise ImportError(
14
+ 'Please install the `anthropic` package to use the Anthropic provider, '
15
+ "you can use the `anthropic` optional group — `pip install 'pydantic-ai-slim[anthropic]'`"
16
+ ) from _import_error
17
+
18
+
19
+ from . import Provider
20
+
21
+
22
class AnthropicProvider(Provider[AsyncAnthropic]):
    """Provider for Anthropic API."""

    @property
    def name(self) -> str:
        """The provider name, `'anthropic'`."""
        return 'anthropic'

    @property
    def base_url(self) -> str:
        """The base URL of the underlying `AsyncAnthropic` client, as a string."""
        return str(self._client.base_url)

    @property
    def client(self) -> AsyncAnthropic:
        """The `AsyncAnthropic` client used by this provider."""
        return self._client

    @overload
    def __init__(self, *, anthropic_client: AsyncAnthropic | None = None) -> None: ...

    @overload
    def __init__(self, *, api_key: str | None = None, http_client: httpx.AsyncClient | None = None) -> None: ...

    def __init__(
        self,
        *,
        api_key: str | None = None,
        anthropic_client: AsyncAnthropic | None = None,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Create a new Anthropic provider.

        Args:
            api_key: The API key to use for authentication, if not provided, the `ANTHROPIC_API_KEY` environment variable
                will be used if available.
            anthropic_client: An existing [`AsyncAnthropic`](https://github.com/anthropics/anthropic-sdk-python)
                client to use. If provided, `api_key` and `http_client` must be `None`.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests. If not provided,
                the client returned by `cached_async_http_client()` is used.

        Raises:
            ValueError: If no `anthropic_client` is given and no API key is available from `api_key`
                or the `ANTHROPIC_API_KEY` environment variable.
        """
        if anthropic_client is not None:
            assert http_client is None, 'Cannot provide both `anthropic_client` and `http_client`'
            assert api_key is None, 'Cannot provide both `anthropic_client` and `api_key`'
            self._client = anthropic_client
        else:
            api_key = api_key or os.environ.get('ANTHROPIC_API_KEY')
            if api_key is None:
                # Note the trailing space: the two literals are concatenated implicitly.
                raise ValueError(
                    'Set the `ANTHROPIC_API_KEY` environment variable or pass it via `AnthropicProvider(api_key=...)` '
                    'to use the Anthropic provider.'
                )

            if http_client is None:
                http_client = cached_async_http_client()
            self._client = AsyncAnthropic(api_key=api_key, http_client=http_client)
@@ -57,17 +57,21 @@ class GroqProvider(Provider[AsyncGroq]):
57
57
  client to use. If provided, `api_key` and `http_client` must be `None`.
58
58
  http_client: An existing `AsyncHTTPClient` to use for making HTTP requests.
59
59
  """
60
- api_key = api_key or os.environ.get('GROQ_API_KEY')
61
-
62
- if api_key is None and groq_client is None:
63
- raise ValueError(
64
- 'Set the `GROQ_API_KEY` environment variable or pass it via `GroqProvider(api_key=...)`'
65
- 'to use the Groq provider.'
66
- )
67
-
68
60
  if groq_client is not None:
61
+ assert http_client is None, 'Cannot provide both `groq_client` and `http_client`'
62
+ assert api_key is None, 'Cannot provide both `groq_client` and `api_key`'
69
63
  self._client = groq_client
70
- elif http_client is not None:
71
- self._client = AsyncGroq(base_url=self.base_url, api_key=api_key, http_client=http_client)
72
64
  else:
73
- self._client = AsyncGroq(base_url=self.base_url, api_key=api_key, http_client=cached_async_http_client())
65
+ api_key = api_key or os.environ.get('GROQ_API_KEY')
66
+
67
+ if api_key is None:
68
+ raise ValueError(
69
+ 'Set the `GROQ_API_KEY` environment variable or pass it via `GroqProvider(api_key=...)`'
70
+ 'to use the Groq provider.'
71
+ )
72
+ elif http_client is not None:
73
+ self._client = AsyncGroq(base_url=self.base_url, api_key=api_key, http_client=http_client)
74
+ else:
75
+ self._client = AsyncGroq(
76
+ base_url=self.base_url, api_key=api_key, http_client=cached_async_http_client()
77
+ )
@@ -0,0 +1,73 @@
1
+ from __future__ import annotations as _annotations
2
+
3
+ import os
4
+ from typing import overload
5
+
6
+ from httpx import AsyncClient as AsyncHTTPClient
7
+
8
+ from pydantic_ai.models import cached_async_http_client
9
+
10
# `Mistral` comes from the optional `mistralai` distribution, installed via the `mistral` extra.
try:
    from mistralai import Mistral
except ImportError as e:  # pragma: no cover
    raise ImportError(
        'Please install the `mistralai` package to use the Mistral provider, '
        "you can use the `mistral` optional group — `pip install 'pydantic-ai-slim[mistral]'`"
    ) from e
17
+
18
+
19
+ from . import Provider
20
+
21
+
22
class MistralProvider(Provider[Mistral]):
    """Provider for Mistral API."""

    @property
    def name(self) -> str:
        """The provider name, `'mistral'`."""
        return 'mistral'

    @property
    def base_url(self) -> str:
        """The server URL reported by the client's SDK configuration."""
        return self.client.sdk_configuration.get_server_details()[0]

    @property
    def client(self) -> Mistral:
        """The `Mistral` client used by this provider."""
        return self._client

    @overload
    def __init__(self, *, mistral_client: Mistral | None = None) -> None: ...

    @overload
    def __init__(self, *, api_key: str | None = None, http_client: AsyncHTTPClient | None = None) -> None: ...

    def __init__(
        self,
        *,
        api_key: str | None = None,
        mistral_client: Mistral | None = None,
        http_client: AsyncHTTPClient | None = None,
    ) -> None:
        """Create a new Mistral provider.

        Args:
            api_key: The API key to use for authentication, if not provided, the `MISTRAL_API_KEY` environment variable
                will be used if available.
            mistral_client: An existing `Mistral` client to use, if provided, `api_key` and `http_client` must be `None`.
            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests. If not provided,
                the client returned by `cached_async_http_client()` is used.

        Raises:
            ValueError: If no `mistral_client` is given and no API key is available from `api_key`
                or the `MISTRAL_API_KEY` environment variable.
        """
        if mistral_client is not None:
            assert http_client is None, 'Cannot provide both `mistral_client` and `http_client`'
            assert api_key is None, 'Cannot provide both `mistral_client` and `api_key`'
            self._client = mistral_client
        else:
            api_key = api_key or os.environ.get('MISTRAL_API_KEY')
            if api_key is None:
                # Note the trailing space: the two literals are concatenated implicitly.
                raise ValueError(
                    'Set the `MISTRAL_API_KEY` environment variable or pass it via `MistralProvider(api_key=...)` '
                    'to use the Mistral provider.'
                )

            if http_client is None:
                http_client = cached_async_http_client()
            self._client = Mistral(api_key=api_key, async_client=http_client)
@@ -7,7 +7,8 @@ from dataclasses import dataclass, field
7
7
  from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
8
8
 
9
9
  from pydantic import ValidationError
10
- from pydantic_core import SchemaValidator
10
+ from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
11
+ from pydantic_core import SchemaValidator, core_schema
11
12
  from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar
12
13
 
13
14
  from . import _pydantic, _utils, messages as _messages, models
@@ -142,6 +143,22 @@ DocstringFormat = Literal['google', 'numpy', 'sphinx', 'auto']
142
143
  A = TypeVar('A')
143
144
 
144
145
 
146
class GenerateToolJsonSchema(GenerateJsonSchema):
    """JSON schema generator used for tool parameter schemas.

    It differs from pydantic's `GenerateJsonSchema` in two ways:
    - `additionalProperties` on TypedDict schemas is derived from their `total` flag.
    - Per-property `title` keys are removed.
    """

    def typed_dict_schema(self, schema: core_schema.TypedDictSchema) -> JsonSchemaValue:
        """Build the TypedDict schema and set `additionalProperties` to `not total` when `total` is set."""
        json_schema = super().typed_dict_schema(schema)
        total = schema.get('total')
        if total is None:
            return json_schema
        json_schema['additionalProperties'] = not total
        return json_schema

    def _named_required_fields_schema(self, named_required_fields: Sequence[tuple[str, bool, Any]]) -> JsonSchemaValue:
        # Drop per-property titles; they add little information to a tool's parameter schema.
        json_schema = super()._named_required_fields_schema(named_required_fields)
        for property_schema in json_schema.get('properties', {}).values():
            property_schema.pop('title', None)
        return json_schema
160
+
161
+
145
162
  @dataclass(init=False)
146
163
  class Tool(Generic[AgentDepsT]):
147
164
  """A tool function for an agent."""
@@ -176,6 +193,7 @@ class Tool(Generic[AgentDepsT]):
176
193
  prepare: ToolPrepareFunc[AgentDepsT] | None = None,
177
194
  docstring_format: DocstringFormat = 'auto',
178
195
  require_parameter_descriptions: bool = False,
196
+ schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
179
197
  ):
180
198
  """Create a new tool instance.
181
199
 
@@ -225,11 +243,14 @@ class Tool(Generic[AgentDepsT]):
225
243
  docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
226
244
  Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
227
245
  require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
246
+ schema_generator: The JSON schema generator class to use. Defaults to `GenerateToolJsonSchema`.
228
247
  """
229
248
  if takes_ctx is None:
230
249
  takes_ctx = _pydantic.takes_ctx(function)
231
250
 
232
- f = _pydantic.function_schema(function, takes_ctx, docstring_format, require_parameter_descriptions)
251
+ f = _pydantic.function_schema(
252
+ function, takes_ctx, docstring_format, require_parameter_descriptions, schema_generator
253
+ )
233
254
  self.function = function
234
255
  self.takes_ctx = takes_ctx
235
256
  self.max_retries = max_retries
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "pydantic-ai-slim"
7
- version = "0.0.40"
7
+ version = "0.0.42"
8
8
  description = "Agent Framework / shim to use Pydantic with LLMs, slim package"
9
9
  authors = [{ name = "Samuel Colvin", email = "samuel@pydantic.dev" }]
10
10
  license = "MIT"
@@ -36,7 +36,7 @@ dependencies = [
36
36
  "griffe>=1.3.2",
37
37
  "httpx>=0.27",
38
38
  "pydantic>=2.10",
39
- "pydantic-graph==0.0.40",
39
+ "pydantic-graph==0.0.42",
40
40
  "exceptiongroup; python_version < '3.11'",
41
41
  "opentelemetry-api>=1.28.0",
42
42
  "typing-inspection>=0.4.0",
@@ -50,7 +50,7 @@ openai = ["openai>=1.65.1"]
50
50
  cohere = ["cohere>=5.13.11"]
51
51
  vertexai = ["google-auth>=2.36.0", "requests>=2.32.3"]
52
52
  anthropic = ["anthropic>=0.49.0"]
53
- groq = ["groq>=0.12.0"]
53
+ groq = ["groq>=0.15.0"]
54
54
  mistral = ["mistralai>=1.2.5"]
55
55
  bedrock = ["boto3>=1.34.116"]
56
56
  # Tools
@@ -58,6 +58,8 @@ duckduckgo = ["duckduckgo-search>=7.0.0"]
58
58
  tavily = ["tavily-python>=0.5.0"]
59
59
  # CLI
60
60
  cli = ["rich>=13", "prompt-toolkit>=3", "argcomplete>=3.5.0"]
61
+ # MCP
62
+ mcp = ["mcp>=1.4.1; python_version >= '3.10'"]
61
63
 
62
64
  [dependency-groups]
63
65
  dev = [
@@ -75,6 +77,9 @@ dev = [
75
77
  "boto3-stubs[bedrock-runtime]",
76
78
  ]
77
79
 
80
+ [tool.hatch.metadata]
81
+ allow-direct-references = true
82
+
78
83
  [project.scripts]
79
84
  pai = "pydantic_ai._cli:app"
80
85