pydantic-ai-slim 0.3.0__tar.gz → 0.3.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pydantic-ai-slim might be problematic.

Files changed (79)
  1. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/.gitignore +1 -0
  2. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/PKG-INFO +4 -4
  3. pydantic_ai_slim-0.3.2/pydantic_ai/_mcp.py +123 -0
  4. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/agent.py +10 -1
  5. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/mcp.py +144 -115
  6. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/messages.py +3 -1
  7. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/__init__.py +6 -0
  8. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/gemini.py +3 -0
  9. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/google.py +3 -0
  10. pydantic_ai_slim-0.3.2/pydantic_ai/models/mcp_sampling.py +95 -0
  11. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/settings.py +1 -0
  12. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pyproject.toml +1 -1
  13. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/LICENSE +0 -0
  14. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/README.md +0 -0
  15. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/__init__.py +0 -0
  16. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/__main__.py +0 -0
  17. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/_a2a.py +0 -0
  18. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/_agent_graph.py +0 -0
  19. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/_cli.py +0 -0
  20. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/_function_schema.py +0 -0
  21. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/_griffe.py +0 -0
  22. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/_output.py +0 -0
  23. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/_parts_manager.py +0 -0
  24. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/_system_prompt.py +0 -0
  25. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/_thinking_part.py +0 -0
  26. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/_utils.py +0 -0
  27. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/common_tools/__init__.py +0 -0
  28. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/common_tools/duckduckgo.py +0 -0
  29. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/common_tools/tavily.py +0 -0
  30. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/direct.py +0 -0
  31. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/exceptions.py +0 -0
  32. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/ext/__init__.py +0 -0
  33. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/ext/langchain.py +0 -0
  34. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/format_as_xml.py +0 -0
  35. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/format_prompt.py +0 -0
  36. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/anthropic.py +0 -0
  37. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/bedrock.py +0 -0
  38. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/cohere.py +0 -0
  39. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/fallback.py +0 -0
  40. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/function.py +0 -0
  41. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/groq.py +0 -0
  42. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/instrumented.py +0 -0
  43. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/mistral.py +0 -0
  44. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/openai.py +0 -0
  45. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/test.py +0 -0
  46. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/wrapper.py +0 -0
  47. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/profiles/__init__.py +0 -0
  48. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/profiles/_json_schema.py +0 -0
  49. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/profiles/amazon.py +0 -0
  50. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/profiles/anthropic.py +0 -0
  51. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/profiles/cohere.py +0 -0
  52. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/profiles/deepseek.py +0 -0
  53. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/profiles/google.py +0 -0
  54. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/profiles/grok.py +0 -0
  55. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/profiles/meta.py +0 -0
  56. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/profiles/mistral.py +0 -0
  57. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/profiles/openai.py +0 -0
  58. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/profiles/qwen.py +0 -0
  59. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/__init__.py +0 -0
  60. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/anthropic.py +0 -0
  61. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/azure.py +0 -0
  62. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/bedrock.py +0 -0
  63. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/cohere.py +0 -0
  64. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/deepseek.py +0 -0
  65. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/fireworks.py +0 -0
  66. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/google.py +0 -0
  67. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/google_gla.py +0 -0
  68. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/google_vertex.py +0 -0
  69. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/grok.py +0 -0
  70. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/groq.py +0 -0
  71. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/heroku.py +0 -0
  72. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/mistral.py +0 -0
  73. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/openai.py +0 -0
  74. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/openrouter.py +0 -0
  75. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/together.py +0 -0
  76. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/py.typed +0 -0
  77. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/result.py +0 -0
  78. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/tools.py +0 -0
  79. {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/usage.py +0 -0
{pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/.gitignore
@@ -19,3 +19,4 @@ examples/pydantic_ai_examples/.chat_app_messages.sqlite
 node_modules/
 **.idea/
 .coverage*
+/test_tmp/
{pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.3.0
+Version: 0.3.2
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>, Marcelo Trylesinski <marcelotryle@gmail.com>, David Montague <david@pydantic.dev>, Alex Hall <alex@pydantic.dev>
 License-Expression: MIT
@@ -30,11 +30,11 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: opentelemetry-api>=1.28.0
-Requires-Dist: pydantic-graph==0.3.0
+Requires-Dist: pydantic-graph==0.3.2
 Requires-Dist: pydantic>=2.10
 Requires-Dist: typing-inspection>=0.4.0
 Provides-Extra: a2a
-Requires-Dist: fasta2a==0.3.0; extra == 'a2a'
+Requires-Dist: fasta2a==0.3.2; extra == 'a2a'
 Provides-Extra: anthropic
 Requires-Dist: anthropic>=0.52.0; extra == 'anthropic'
 Provides-Extra: bedrock
@@ -48,7 +48,7 @@ Requires-Dist: cohere>=5.13.11; (platform_system != 'Emscripten') and extra == '
 Provides-Extra: duckduckgo
 Requires-Dist: duckduckgo-search>=7.0.0; extra == 'duckduckgo'
 Provides-Extra: evals
-Requires-Dist: pydantic-evals==0.3.0; extra == 'evals'
+Requires-Dist: pydantic-evals==0.3.2; extra == 'evals'
 Provides-Extra: google
 Requires-Dist: google-genai>=1.15.0; extra == 'google'
 Provides-Extra: groq
pydantic_ai_slim-0.3.2/pydantic_ai/_mcp.py (new file)
@@ -0,0 +1,123 @@
+import base64
+from collections.abc import Sequence
+from typing import Literal
+
+from . import exceptions, messages
+
+try:
+    from mcp import types as mcp_types
+except ImportError as _import_error:
+    raise ImportError(
+        'Please install the `mcp` package to use the MCP server, '
+        'you can use the `mcp` optional group — `pip install "pydantic-ai-slim[mcp]"`'
+    ) from _import_error
+
+
+def map_from_mcp_params(params: mcp_types.CreateMessageRequestParams) -> list[messages.ModelMessage]:
+    """Convert from MCP create message request parameters to pydantic-ai messages."""
+    pai_messages: list[messages.ModelMessage] = []
+    request_parts: list[messages.ModelRequestPart] = []
+    if params.systemPrompt:
+        request_parts.append(messages.SystemPromptPart(content=params.systemPrompt))
+    response_parts: list[messages.ModelResponsePart] = []
+    for msg in params.messages:
+        content = msg.content
+        if msg.role == 'user':
+            # if there are any response parts, add a response message wrapping them
+            if response_parts:
+                pai_messages.append(messages.ModelResponse(parts=response_parts))
+                response_parts = []
+
+            # TODO(Marcelo): We can reuse the `_map_tool_result_part` from the mcp module here.
+            if isinstance(content, mcp_types.TextContent):
+                user_part_content: str | Sequence[messages.UserContent] = content.text
+            else:
+                # image content
+                user_part_content = [
+                    messages.BinaryContent(data=base64.b64decode(content.data), media_type=content.mimeType)
+                ]
+
+            request_parts.append(messages.UserPromptPart(content=user_part_content))
+        else:
+            # role is assistant
+            # if there are any request parts, add a request message wrapping them
+            if request_parts:
+                pai_messages.append(messages.ModelRequest(parts=request_parts))
+                request_parts = []
+
+            response_parts.append(map_from_sampling_content(content))
+
+    if response_parts:
+        pai_messages.append(messages.ModelResponse(parts=response_parts))
+    if request_parts:
+        pai_messages.append(messages.ModelRequest(parts=request_parts))
+    return pai_messages
+
+
+def map_from_pai_messages(pai_messages: list[messages.ModelMessage]) -> tuple[str, list[mcp_types.SamplingMessage]]:
+    """Convert from pydantic-ai messages to MCP sampling messages.
+
+    Returns:
+        A tuple containing the system prompt and a list of sampling messages.
+    """
+    sampling_msgs: list[mcp_types.SamplingMessage] = []
+
+    def add_msg(
+        role: Literal['user', 'assistant'],
+        content: mcp_types.TextContent | mcp_types.ImageContent | mcp_types.AudioContent,
+    ):
+        sampling_msgs.append(mcp_types.SamplingMessage(role=role, content=content))
+
+    system_prompt: list[str] = []
+    for pai_message in pai_messages:
+        if isinstance(pai_message, messages.ModelRequest):
+            if pai_message.instructions is not None:
+                system_prompt.append(pai_message.instructions)
+
+            for part in pai_message.parts:
+                if isinstance(part, messages.SystemPromptPart):
+                    system_prompt.append(part.content)
+                if isinstance(part, messages.UserPromptPart):
+                    if isinstance(part.content, str):
+                        add_msg('user', mcp_types.TextContent(type='text', text=part.content))
+                    else:
+                        for chunk in part.content:
+                            if isinstance(chunk, str):
+                                add_msg('user', mcp_types.TextContent(type='text', text=chunk))
+                            elif isinstance(chunk, messages.BinaryContent) and chunk.is_image:
+                                add_msg(
+                                    'user',
+                                    mcp_types.ImageContent(
+                                        type='image',
+                                        data=base64.b64encode(chunk.data).decode(),
+                                        mimeType=chunk.media_type,
+                                    ),
+                                )
+                            # TODO(Marcelo): Add support for audio content.
+                            else:
+                                raise NotImplementedError(f'Unsupported content type: {type(chunk)}')
+        else:
+            add_msg('assistant', map_from_model_response(pai_message))
+    return ''.join(system_prompt), sampling_msgs
+
+
+def map_from_model_response(model_response: messages.ModelResponse) -> mcp_types.TextContent:
+    """Convert from a model response to MCP text content."""
+    text_parts: list[str] = []
+    for part in model_response.parts:
+        if isinstance(part, messages.TextPart):
+            text_parts.append(part.content)
+        # TODO(Marcelo): We should ignore ThinkingPart here.
+        else:
+            raise exceptions.UnexpectedModelBehavior(f'Unexpected part type: {type(part).__name__}, expected TextPart')
+    return mcp_types.TextContent(type='text', text=''.join(text_parts))
+
+
+def map_from_sampling_content(
+    content: mcp_types.TextContent | mcp_types.ImageContent | mcp_types.AudioContent,
+) -> messages.TextPart:
+    """Convert from sampling content to a pydantic-ai text part."""
+    if isinstance(content, mcp_types.TextContent):  # pragma: no branch
+        return messages.TextPart(content=content.text)
+    else:
+        raise NotImplementedError('Image and Audio responses in sampling are not yet supported')
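
For orientation, here is a minimal sketch (not part of the diff) of what the new conversion helpers do. Note that `_mcp` is a private module, and the message history below is illustrative:

    from pydantic_ai import messages
    from pydantic_ai._mcp import map_from_pai_messages

    # a short pydantic-ai history: one request carrying a system prompt and a user prompt
    history: list[messages.ModelMessage] = [
        messages.ModelRequest(
            parts=[
                messages.SystemPromptPart(content='Be terse.'),
                messages.UserPromptPart(content='What is the capital of France?'),
            ]
        ),
    ]

    system_prompt, sampling_messages = map_from_pai_messages(history)
    # system prompts are concatenated into a single string for MCP sampling
    assert system_prompt == 'Be terse.'
    # user/assistant parts become mcp.types.SamplingMessage entries
    assert sampling_messages[0].role == 'user'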
{pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/agent.py
@@ -1691,14 +1691,23 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         return isinstance(node, End)
 
     @asynccontextmanager
-    async def run_mcp_servers(self) -> AsyncIterator[None]:
+    async def run_mcp_servers(
+        self, model: models.Model | models.KnownModelName | str | None = None
+    ) -> AsyncIterator[None]:
         """Run [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] so they can be used by the agent.
 
         Returns: a context manager to start and shutdown the servers.
         """
+        try:
+            sampling_model: models.Model | None = self._get_model(model)
+        except exceptions.UserError:  # pragma: no cover
+            sampling_model = None
+
         exit_stack = AsyncExitStack()
         try:
             for mcp_server in self._mcp_servers:
+                if sampling_model is not None:  # pragma: no branch
+                    mcp_server.sampling_model = sampling_model
                 await exit_stack.enter_async_context(mcp_server)
             yield
         finally:
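
In practice, the new parameter looks like this; a minimal sketch assuming an OpenAI model and a hypothetical local server script:

    from pydantic_ai import Agent
    from pydantic_ai.mcp import MCPServerStdio

    server = MCPServerStdio(command='python', args=['my_mcp_server.py'])  # hypothetical script
    agent = Agent('openai:gpt-4o', mcp_servers=[server])


    async def main() -> None:
        # passing no model here falls back to the agent's own model, which is
        # then set as `sampling_model` on every registered MCP server
        async with agent.run_mcp_servers():
            result = await agent.run('Use the server tools to answer the question.')
            print(result.output)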
{pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/mcp.py
@@ -2,7 +2,6 @@ from __future__ import annotations
 
 import base64
 import functools
-import json
 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Awaitable, Sequence
 from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
@@ -13,41 +12,28 @@ from typing import Any, Callable
 
 import anyio
 import httpx
+import pydantic_core
 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
-from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
-from mcp.shared.exceptions import McpError
-from mcp.shared.message import SessionMessage
-from mcp.types import (
-    AudioContent,
-    BlobResourceContents,
-    CallToolRequest,
-    CallToolRequestParams,
-    CallToolResult,
-    ClientRequest,
-    Content,
-    EmbeddedResource,
-    ImageContent,
-    LoggingLevel,
-    RequestParams,
-    TextContent,
-    TextResourceContents,
-)
 from typing_extensions import Self, assert_never, deprecated
 
-from pydantic_ai.exceptions import ModelRetry
-from pydantic_ai.messages import BinaryContent
-from pydantic_ai.tools import RunContext, ToolDefinition
-
 try:
-    from mcp.client.session import ClientSession
+    from mcp import types as mcp_types
+    from mcp.client.session import ClientSession, LoggingFnT
     from mcp.client.sse import sse_client
     from mcp.client.stdio import StdioServerParameters, stdio_client
+    from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
+    from mcp.shared.context import RequestContext
+    from mcp.shared.exceptions import McpError
+    from mcp.shared.message import SessionMessage
 except ImportError as _import_error:
     raise ImportError(
         'Please install the `mcp` package to use the MCP server, '
         'you can use the `mcp` optional group — `pip install "pydantic-ai-slim[mcp]"`'
     ) from _import_error
 
+# after mcp imports so any import error maps to this file, not _mcp.py
+from . import _mcp, exceptions, messages, models, tools
+
 __all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP', 'MCPServerSSE', 'MCPServerStreamableHTTP'
 
 
@@ -57,22 +43,22 @@ class MCPServer(ABC):
     See <https://modelcontextprotocol.io> for more information.
     """
 
-    is_running: bool = False
+    # these fields should be re-defined by dataclass subclasses so they appear as fields {
     tool_prefix: str | None = None
-    """A prefix to add to all tools that are registered with the server.
-
-    If not empty, will include a trailing underscore(`_`).
-
-    e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
-    """
-
+    log_level: mcp_types.LoggingLevel | None = None
+    log_handler: LoggingFnT | None = None
+    timeout: float = 5
     process_tool_call: ProcessToolCallback | None = None
-    """Hook to customize tool calling and optionally pass extra metadata."""
+    allow_sampling: bool = True
+    # } end of "abstract fields"
+
+    _running_count: int = 0
 
     _client: ClientSession
     _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
     _write_stream: MemoryObjectSendStream[SessionMessage]
     _exit_stack: AsyncExitStack
+    sampling_model: models.Model | None = None
 
    @abstractmethod
    @asynccontextmanager
@@ -88,14 +74,6 @@ class MCPServer(ABC):
         raise NotImplementedError('MCP Server subclasses must implement this method.')
         yield
 
-    @abstractmethod
-    def _get_log_level(self) -> LoggingLevel | None:
-        """Get the log level for the MCP server."""
-        raise NotImplementedError('MCP Server subclasses must implement this method.')
-
-    def _get_client_initialize_timeout(self) -> float:
-        return 5  # pragma: no cover
-
     def get_prefixed_tool_name(self, tool_name: str) -> str:
         """Get the tool name with prefix if `tool_prefix` is set."""
         return f'{self.tool_prefix}_{tool_name}' if self.tool_prefix else tool_name
@@ -104,21 +82,26 @@ class MCPServer(ABC):
         """Get original tool name without prefix for calling tools."""
         return tool_name.removeprefix(f'{self.tool_prefix}_') if self.tool_prefix else tool_name
 
-    async def list_tools(self) -> list[ToolDefinition]:
+    @property
+    def is_running(self) -> bool:
+        """Check if the MCP server is running."""
+        return bool(self._running_count)
+
+    async def list_tools(self) -> list[tools.ToolDefinition]:
         """Retrieve tools that are currently active on the server.
 
         Note:
         - We don't cache tools as they might change.
         - We also don't subscribe to the server to avoid complexity.
         """
-        tools = await self._client.list_tools()
+        mcp_tools = await self._client.list_tools()
         return [
-            ToolDefinition(
+            tools.ToolDefinition(
                 name=self.get_prefixed_tool_name(tool.name),
                 description=tool.description or '',
                 parameters_json_schema=tool.inputSchema,
             )
-            for tool in tools.tools
+            for tool in mcp_tools.tools
         ]
 
     async def call_tool(
@@ -143,44 +126,48 @@ class MCPServer(ABC):
         try:
             # meta param is not provided by session yet, so build and can send_request directly.
             result = await self._client.send_request(
-                ClientRequest(
-                    CallToolRequest(
+                mcp_types.ClientRequest(
+                    mcp_types.CallToolRequest(
                         method='tools/call',
-                        params=CallToolRequestParams(
+                        params=mcp_types.CallToolRequestParams(
                             name=self.get_unprefixed_tool_name(tool_name),
                             arguments=arguments,
-                            _meta=RequestParams.Meta(**metadata) if metadata else None,
+                            _meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None,
                         ),
                     )
                 ),
-                CallToolResult,
+                mcp_types.CallToolResult,
             )
         except McpError as e:
-            raise ModelRetry(e.error.message)
+            raise exceptions.ModelRetry(e.error.message)
 
         content = [self._map_tool_result_part(part) for part in result.content]
 
         if result.isError:
             text = '\n'.join(str(part) for part in content)
-            raise ModelRetry(text)
-
-        if len(content) == 1:
-            return content[0]
-        return content
+            raise exceptions.ModelRetry(text)
+        else:
+            return content[0] if len(content) == 1 else content
 
     async def __aenter__(self) -> Self:
-        self._exit_stack = AsyncExitStack()
-
-        self._read_stream, self._write_stream = await self._exit_stack.enter_async_context(self.client_streams())
-        client = ClientSession(read_stream=self._read_stream, write_stream=self._write_stream)
-        self._client = await self._exit_stack.enter_async_context(client)
+        if self._running_count == 0:
+            self._exit_stack = AsyncExitStack()
+
+            self._read_stream, self._write_stream = await self._exit_stack.enter_async_context(self.client_streams())
+            client = ClientSession(
+                read_stream=self._read_stream,
+                write_stream=self._write_stream,
+                sampling_callback=self._sampling_callback if self.allow_sampling else None,
+                logging_callback=self.log_handler,
+            )
+            self._client = await self._exit_stack.enter_async_context(client)
 
-        with anyio.fail_after(self._get_client_initialize_timeout()):
-            await self._client.initialize()
+            with anyio.fail_after(self.timeout):
+                await self._client.initialize()
 
-        if log_level := self._get_log_level():
-            await self._client.set_logging_level(log_level)
-        self.is_running = True
+            if log_level := self.log_level:
+                await self._client.set_logging_level(log_level)
+        self._running_count += 1
         return self
 
     async def __aexit__(
@@ -189,32 +176,64 @@ class MCPServer(ABC):
         exc_value: BaseException | None,
         traceback: TracebackType | None,
     ) -> bool | None:
-        await self._exit_stack.aclose()
-        self.is_running = False
+        self._running_count -= 1
+        if self._running_count <= 0:
+            await self._exit_stack.aclose()
+
+    async def _sampling_callback(
+        self, context: RequestContext[ClientSession, Any], params: mcp_types.CreateMessageRequestParams
+    ) -> mcp_types.CreateMessageResult | mcp_types.ErrorData:
+        """MCP sampling callback."""
+        if self.sampling_model is None:
+            raise ValueError('Sampling model is not set')  # pragma: no cover
+
+        pai_messages = _mcp.map_from_mcp_params(params)
+        model_settings = models.ModelSettings()
+        if max_tokens := params.maxTokens:  # pragma: no branch
+            model_settings['max_tokens'] = max_tokens
+        if temperature := params.temperature:  # pragma: no branch
+            model_settings['temperature'] = temperature
+        if stop_sequences := params.stopSequences:  # pragma: no branch
+            model_settings['stop_sequences'] = stop_sequences
+
+        model_response = await self.sampling_model.request(
+            pai_messages,
+            model_settings,
+            models.ModelRequestParameters(),
+        )
+        return mcp_types.CreateMessageResult(
+            role='assistant',
+            content=_mcp.map_from_model_response(model_response),
+            model=self.sampling_model.model_name,
+        )
 
-    def _map_tool_result_part(self, part: Content) -> str | BinaryContent | dict[str, Any] | list[Any]:
+    def _map_tool_result_part(
+        self, part: mcp_types.Content
+    ) -> str | messages.BinaryContent | dict[str, Any] | list[Any]:
         # See https://github.com/jlowin/fastmcp/blob/main/docs/servers/tools.mdx#return-values
 
-        if isinstance(part, TextContent):
+        if isinstance(part, mcp_types.TextContent):
             text = part.text
             if text.startswith(('[', '{')):
                 try:
-                    return json.loads(text)
+                    return pydantic_core.from_json(text)
                 except ValueError:
                     pass
             return text
-        elif isinstance(part, ImageContent):
-            return BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType)
-        elif isinstance(part, AudioContent):
+        elif isinstance(part, mcp_types.ImageContent):
+            return messages.BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType)
+        elif isinstance(part, mcp_types.AudioContent):
             # NOTE: The FastMCP server doesn't support audio content.
             # See <https://github.com/modelcontextprotocol/python-sdk/issues/952> for more details.
-            return BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType)  # pragma: no cover
-        elif isinstance(part, EmbeddedResource):
+            return messages.BinaryContent(
+                data=base64.b64decode(part.data), media_type=part.mimeType
+            )  # pragma: no cover
+        elif isinstance(part, mcp_types.EmbeddedResource):
             resource = part.resource
-            if isinstance(resource, TextResourceContents):
+            if isinstance(resource, mcp_types.TextResourceContents):
                 return resource.text
-            elif isinstance(resource, BlobResourceContents):
-                return BinaryContent(
+            elif isinstance(resource, mcp_types.BlobResourceContents):
+                return messages.BinaryContent(
                     data=base64.b64decode(resource.blob),
                     media_type=resource.mimeType or 'application/octet-stream',
                 )
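
Replacing the `is_running` flag with `_running_count` makes the server context manager re-entrant; a sketch (not from the diff) of the intended behavior, assuming `server` is any `MCPServer` instance:

    # entering the same server twice reuses one client session; the session is
    # only torn down when the count returns to zero
    async with server:        # count 0 -> 1: streams opened, session initialized
        async with server:    # count 1 -> 2: no second process or session
            assert server.is_running
        assert server.is_running  # count back to 1, still connected
    # count 0: exit stack closed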
@@ -275,17 +294,11 @@ class MCPServerStdio(MCPServer):
     By default the subprocess will not inherit any environment variables from the parent process.
     If you want to inherit the environment variables from the parent process, use `env=os.environ`.
     """
-    log_level: LoggingLevel | None = None
-    """The log level to set when connecting to the server, if any.
-
-    See <https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging#logging> for more details.
-
-    If `None`, no log level will be set.
-    """
 
     cwd: str | Path | None = None
     """The working directory to use when spawning the process."""
 
+    # last fields are re-defined from the parent class so they appear as fields
     tool_prefix: str | None = None
     """A prefix to add to all tools that are registered with the server.
 
@@ -294,11 +307,25 @@ class MCPServerStdio(MCPServer):
     e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
     """
 
+    log_level: mcp_types.LoggingLevel | None = None
+    """The log level to set when connecting to the server, if any.
+
+    See <https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging#logging> for more details.
+
+    If `None`, no log level will be set.
+    """
+
+    log_handler: LoggingFnT | None = None
+    """A handler for logging messages from the server."""
+
+    timeout: float = 5
+    """The timeout in seconds to wait for the client to initialize."""
+
     process_tool_call: ProcessToolCallback | None = None
     """Hook to customize tool calling and optionally pass extra metadata."""
 
-    timeout: float = 5
-    """ The timeout in seconds to wait for the client to initialize."""
+    allow_sampling: bool = True
+    """Whether to allow MCP sampling through this client."""
 
     @asynccontextmanager
     async def client_streams(
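
Together, the re-homed fields give `MCPServerStdio` one flat configuration surface; an illustrative sketch (server command and handler are assumptions, not from the diff):

    from pydantic_ai.mcp import MCPServerStdio


    async def on_log(params) -> None:  # matches LoggingFnT; params carries level and data
        print(params.level, params.data)

    server = MCPServerStdio(
        command='npx',
        args=['-y', '@modelcontextprotocol/server-everything'],  # illustrative server
        tool_prefix='demo',    # tools are exposed as demo_<name>
        log_level='info',      # forwarded via set_logging_level after initialize
        log_handler=on_log,    # receives server log notifications
        timeout=10,            # seconds allowed for session initialization
        allow_sampling=False,  # do not register a sampling callback
    )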
@@ -313,15 +340,9 @@ class MCPServerStdio(MCPServer):
         async with stdio_client(server=server) as (read_stream, write_stream):
             yield read_stream, write_stream
 
-    def _get_log_level(self) -> LoggingLevel | None:
-        return self.log_level
-
     def __repr__(self) -> str:
         return f'MCPServerStdio(command={self.command!r}, args={self.args!r}, tool_prefix={self.tool_prefix!r})'
 
-    def _get_client_initialize_timeout(self) -> float:
-        return self.timeout
-
 
 @dataclass
 class _MCPServerHTTP(MCPServer):
@@ -360,13 +381,6 @@ class _MCPServerHTTP(MCPServer):
     ```
     """
 
-    timeout: float = 5
-    """Initial connection timeout in seconds for establishing the connection.
-
-    This timeout applies to the initial connection setup and handshake.
-    If the connection cannot be established within this time, the operation will fail.
-    """
-
     sse_read_timeout: float = 5 * 60
     """Maximum time in seconds to wait for new SSE messages before timing out.
 
@@ -375,7 +389,16 @@ class _MCPServerHTTP(MCPServer):
     and may be closed. Defaults to 5 minutes (300 seconds).
     """
 
-    log_level: LoggingLevel | None = None
+    # last fields are re-defined from the parent class so they appear as fields
+    tool_prefix: str | None = None
+    """A prefix to add to all tools that are registered with the server.
+
+    If not empty, will include a trailing underscore (`_`).
+
+    For example, if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
+    """
+
+    log_level: mcp_types.LoggingLevel | None = None
     """The log level to set when connecting to the server, if any.
 
     See <https://modelcontextprotocol.io/introduction#logging> for more details.
@@ -383,17 +406,22 @@ class _MCPServerHTTP(MCPServer):
     If `None`, no log level will be set.
     """
 
-    tool_prefix: str | None = None
-    """A prefix to add to all tools that are registered with the server.
+    log_handler: LoggingFnT | None = None
+    """A handler for logging messages from the server."""
 
-    If not empty, will include a trailing underscore (`_`).
+    timeout: float = 5
+    """Initial connection timeout in seconds for establishing the connection.
 
-    For example, if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
+    This timeout applies to the initial connection setup and handshake.
+    If the connection cannot be established within this time, the operation will fail.
     """
 
     process_tool_call: ProcessToolCallback | None = None
     """Hook to customize tool calling and optionally pass extra metadata."""
 
+    allow_sampling: bool = True
+    """Whether to allow MCP sampling through this client."""
+
     @property
     @abstractmethod
     def _transport_client(
@@ -419,7 +447,10 @@ class _MCPServerHTTP(MCPServer):
     async def client_streams(
         self,
     ) -> AsyncIterator[
-        tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]]
+        tuple[
+            MemoryObjectReceiveStream[SessionMessage | Exception],
+            MemoryObjectSendStream[SessionMessage],
+        ]
     ]:  # pragma: no cover
         if self.http_client and self.headers:
             raise ValueError('`http_client` is mutually exclusive with `headers`.')
@@ -451,15 +482,9 @@ class _MCPServerHTTP(MCPServer):
         async with transport_client_partial(headers=self.headers) as (read_stream, write_stream, *_):
             yield read_stream, write_stream
 
-    def _get_log_level(self) -> LoggingLevel | None:
-        return self.log_level
-
     def __repr__(self) -> str:  # pragma: no cover
         return f'{self.__class__.__name__}(url={self.url!r}, tool_prefix={self.tool_prefix!r})'
 
-    def _get_client_initialize_timeout(self) -> float:  # pragma: no cover
-        return self.timeout
-
 
 @dataclass
 class MCPServerSSE(_MCPServerHTTP):
@@ -555,7 +580,11 @@ class MCPServerStreamableHTTP(_MCPServerHTTP):
 
 
 ToolResult = (
-    str | BinaryContent | dict[str, Any] | list[Any] | Sequence[str | BinaryContent | dict[str, Any] | list[Any]]
+    str
+    | messages.BinaryContent
+    | dict[str, Any]
+    | list[Any]
+    | Sequence[str | messages.BinaryContent | dict[str, Any] | list[Any]]
 )
 """The result type of a tool call."""
 
@@ -564,7 +593,7 @@ CallToolFunc = Callable[[str, dict[str, Any], dict[str, Any] | None], Awaitable[
 
 ProcessToolCallback = Callable[
     [
-        RunContext[Any],
+        tools.RunContext[Any],
         CallToolFunc,
         str,
         dict[str, Any],
{pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/messages.py
@@ -763,7 +763,9 @@ class ThinkingPartDelta:
             ValueError: If `part` is not a `ThinkingPart`.
         """
         if isinstance(part, ThinkingPart):
-            return replace(part, content=part.content + self.content_delta if self.content_delta else None)
+            new_content = part.content + self.content_delta if self.content_delta else part.content
+            new_signature = self.signature_delta if self.signature_delta is not None else part.signature
+            return replace(part, content=new_content, signature=new_signature)
         elif isinstance(part, ThinkingPartDelta):
             if self.content_delta is None and self.signature_delta is None:
                 raise ValueError('Cannot apply ThinkingPartDelta with no content or signature')
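
The messages.py change fixes a bug where applying a delta without a `content_delta` discarded the accumulated thinking content, and it now applies `signature_delta` too; a before/after sketch (not from the diff):

    from pydantic_ai.messages import ThinkingPart, ThinkingPartDelta

    part = ThinkingPart(content='step 1')
    part = ThinkingPartDelta(content_delta=' step 2').apply(part)
    part = ThinkingPartDelta(signature_delta='sig').apply(part)  # previously reset content to None
    assert part.content == 'step 1 step 2'
    assert part.signature == 'sig'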
{pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/__init__.py
@@ -139,8 +139,11 @@ KnownModelName = TypeAliasType(
         'google-gla:gemini-2.0-flash-lite-preview-02-05',
         'google-gla:gemini-2.0-pro-exp-02-05',
         'google-gla:gemini-2.5-flash-preview-05-20',
+        'google-gla:gemini-2.5-flash',
+        'google-gla:gemini-2.5-flash-lite-preview-06-17',
         'google-gla:gemini-2.5-pro-exp-03-25',
         'google-gla:gemini-2.5-pro-preview-05-06',
+        'google-gla:gemini-2.5-pro',
         'google-vertex:gemini-1.5-flash',
         'google-vertex:gemini-1.5-flash-8b',
         'google-vertex:gemini-1.5-pro',
@@ -149,8 +152,11 @@ KnownModelName = TypeAliasType(
         'google-vertex:gemini-2.0-flash-lite-preview-02-05',
         'google-vertex:gemini-2.0-pro-exp-02-05',
         'google-vertex:gemini-2.5-flash-preview-05-20',
+        'google-vertex:gemini-2.5-flash',
+        'google-vertex:gemini-2.5-flash-lite-preview-06-17',
         'google-vertex:gemini-2.5-pro-exp-03-25',
         'google-vertex:gemini-2.5-pro-preview-05-06',
+        'google-vertex:gemini-2.5-pro',
         'gpt-3.5-turbo',
         'gpt-3.5-turbo-0125',
         'gpt-3.5-turbo-0301',
{pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/gemini.py
@@ -54,8 +54,11 @@ LatestGeminiModelNames = Literal[
     'gemini-2.0-flash-lite-preview-02-05',
     'gemini-2.0-pro-exp-02-05',
     'gemini-2.5-flash-preview-05-20',
+    'gemini-2.5-flash',
+    'gemini-2.5-flash-lite-preview-06-17',
     'gemini-2.5-pro-exp-03-25',
     'gemini-2.5-pro-preview-05-06',
+    'gemini-2.5-pro',
 ]
 """Latest Gemini models."""
 
{pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/google.py
@@ -79,8 +79,11 @@ LatestGoogleModelNames = Literal[
     'gemini-2.0-flash-lite-preview-02-05',
     'gemini-2.0-pro-exp-02-05',
     'gemini-2.5-flash-preview-05-20',
+    'gemini-2.5-flash',
+    'gemini-2.5-flash-lite-preview-06-17',
     'gemini-2.5-pro-exp-03-25',
     'gemini-2.5-pro-preview-05-06',
+    'gemini-2.5-pro',
 ]
 """Latest Gemini models."""
 
pydantic_ai_slim-0.3.2/pydantic_ai/models/mcp_sampling.py (new file)
@@ -0,0 +1,95 @@
+from __future__ import annotations as _annotations
+
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, cast
+
+from .. import _mcp, exceptions, usage
+from ..messages import ModelMessage, ModelResponse
+from ..settings import ModelSettings
+from . import Model, ModelRequestParameters, StreamedResponse
+
+if TYPE_CHECKING:
+    from mcp import ServerSession
+    from mcp.types import ModelPreferences
+
+
+class MCPSamplingModelSettings(ModelSettings, total=False):
+    """Settings used for an MCP Sampling model request.
+
+    ALL FIELDS MUST BE `mcp_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
+    """
+
+    mcp_model_preferences: ModelPreferences
+    """Model preferences to use for MCP Sampling."""
+
+
+@dataclass
+class MCPSamplingModel(Model):
+    """A model that uses MCP Sampling.
+
+    [MCP Sampling](https://modelcontextprotocol.io/docs/concepts/sampling)
+    allows an MCP server to make requests to a model by calling back to the MCP client that connected to it.
+    """
+
+    session: ServerSession
+    """The MCP server session to use for sampling."""
+
+    default_max_tokens: int = 16_384
+    """Default max tokens to use if not set in [`ModelSettings`][pydantic_ai.settings.ModelSettings.max_tokens].
+
+    Max tokens is a required parameter for MCP Sampling, but optional on
+    [`ModelSettings`][pydantic_ai.settings.ModelSettings], so this value is used as fallback.
+    """
+
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> ModelResponse:
+        system_prompt, sampling_messages = _mcp.map_from_pai_messages(messages)
+        model_settings = cast(MCPSamplingModelSettings, model_settings or {})
+
+        result = await self.session.create_message(
+            sampling_messages,
+            max_tokens=model_settings.get('max_tokens', self.default_max_tokens),
+            system_prompt=system_prompt,
+            temperature=model_settings.get('temperature'),
+            model_preferences=model_settings.get('mcp_model_preferences'),
+            stop_sequences=model_settings.get('stop_sequences'),
+        )
+        if result.role == 'assistant':
+            return ModelResponse(
+                parts=[_mcp.map_from_sampling_content(result.content)],
+                usage=usage.Usage(requests=1),
+                model_name=result.model,
+            )
+        else:
+            raise exceptions.UnexpectedModelBehavior(
+                f'Unexpected result from MCP sampling, expected "assistant" role, got {result.role}.'
+            )
+
+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        raise NotImplementedError('MCP Sampling does not support streaming')
+        yield
+
+    @property
+    def model_name(self) -> str:
+        """The model name.
+
+        Since the model name isn't known until the request is made, this property always returns `'mcp-sampling'`.
+        """
+        return 'mcp-sampling'
+
+    @property
+    def system(self) -> str:
+        """The system / model provider, returns `'MCP'`."""
+        return 'MCP'
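
The new model is meant to be used from inside an MCP server, so tool calls can delegate LLM work back to the connected client; a sketch close to the pattern the diff enables, with the FastMCP details as assumptions:

    from mcp.server.fastmcp import Context, FastMCP

    from pydantic_ai import Agent
    from pydantic_ai.models.mcp_sampling import MCPSamplingModel

    app = FastMCP('sampling demo')
    poem_agent = Agent(system_prompt='Always reply in rhyme.')


    @app.tool()
    async def poet(ctx: Context, theme: str) -> str:
        """Write a short poem via MCP sampling."""
        # the request is routed through the client session that called this tool
        result = await poem_agent.run(
            f'Write a two-line poem about {theme}.',
            model=MCPSamplingModel(session=ctx.session),
        )
        return result.output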
{pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/settings.py
@@ -23,6 +23,7 @@ class ModelSettings(TypedDict, total=False):
     * Cohere
     * Mistral
     * Bedrock
+    * MCP Sampling
     """
 
     temperature: float
{pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pyproject.toml
@@ -92,7 +92,7 @@ dev = [
     "pytest>=8.3.3",
     "pytest-examples>=0.0.14",
     "pytest-mock>=3.14.0",
-    "pytest-pretty>=1.2.0",
+    "pytest-pretty>=1.3.0",
     "pytest-recording>=0.13.2",
     "diff-cover>=9.2.0",
     "boto3-stubs[bedrock-runtime]",