pydantic-ai-slim 0.3.0__tar.gz → 0.3.2__tar.gz
This diff shows the changes between publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/.gitignore +1 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/PKG-INFO +4 -4
- pydantic_ai_slim-0.3.2/pydantic_ai/_mcp.py +123 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/agent.py +10 -1
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/mcp.py +144 -115
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/messages.py +3 -1
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/__init__.py +6 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/gemini.py +3 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/google.py +3 -0
- pydantic_ai_slim-0.3.2/pydantic_ai/models/mcp_sampling.py +95 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/settings.py +1 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pyproject.toml +1 -1
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/LICENSE +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/README.md +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/__init__.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/__main__.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/_a2a.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/_agent_graph.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/_cli.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/_function_schema.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/_griffe.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/_output.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/_parts_manager.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/_system_prompt.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/_thinking_part.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/_utils.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/common_tools/__init__.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/common_tools/duckduckgo.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/common_tools/tavily.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/direct.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/exceptions.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/ext/__init__.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/ext/langchain.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/format_as_xml.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/format_prompt.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/anthropic.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/bedrock.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/cohere.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/fallback.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/function.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/groq.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/instrumented.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/mistral.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/openai.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/test.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/models/wrapper.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/profiles/__init__.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/profiles/_json_schema.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/profiles/amazon.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/profiles/anthropic.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/profiles/cohere.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/profiles/deepseek.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/profiles/google.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/profiles/grok.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/profiles/meta.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/profiles/mistral.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/profiles/openai.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/profiles/qwen.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/__init__.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/anthropic.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/azure.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/bedrock.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/cohere.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/deepseek.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/fireworks.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/google.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/google_gla.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/google_vertex.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/grok.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/groq.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/heroku.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/mistral.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/openai.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/openrouter.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/providers/together.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/py.typed +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/result.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/tools.py +0 -0
- {pydantic_ai_slim-0.3.0 → pydantic_ai_slim-0.3.2}/pydantic_ai/usage.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pydantic-ai-slim
-Version: 0.3.0
+Version: 0.3.2
 Summary: Agent Framework / shim to use Pydantic with LLMs, slim package
 Author-email: Samuel Colvin <samuel@pydantic.dev>, Marcelo Trylesinski <marcelotryle@gmail.com>, David Montague <david@pydantic.dev>, Alex Hall <alex@pydantic.dev>
 License-Expression: MIT
@@ -30,11 +30,11 @@ Requires-Dist: exceptiongroup; python_version < '3.11'
 Requires-Dist: griffe>=1.3.2
 Requires-Dist: httpx>=0.27
 Requires-Dist: opentelemetry-api>=1.28.0
-Requires-Dist: pydantic-graph==0.3.0
+Requires-Dist: pydantic-graph==0.3.2
 Requires-Dist: pydantic>=2.10
 Requires-Dist: typing-inspection>=0.4.0
 Provides-Extra: a2a
-Requires-Dist: fasta2a==0.3.0; extra == 'a2a'
+Requires-Dist: fasta2a==0.3.2; extra == 'a2a'
 Provides-Extra: anthropic
 Requires-Dist: anthropic>=0.52.0; extra == 'anthropic'
 Provides-Extra: bedrock
@@ -48,7 +48,7 @@ Requires-Dist: cohere>=5.13.11; (platform_system != 'Emscripten') and extra == 'cohere'
 Provides-Extra: duckduckgo
 Requires-Dist: duckduckgo-search>=7.0.0; extra == 'duckduckgo'
 Provides-Extra: evals
-Requires-Dist: pydantic-evals==0.3.0; extra == 'evals'
+Requires-Dist: pydantic-evals==0.3.2; extra == 'evals'
 Provides-Extra: google
 Requires-Dist: google-genai>=1.15.0; extra == 'google'
 Provides-Extra: groq
pydantic_ai/_mcp.py (new file)
@@ -0,0 +1,123 @@
+import base64
+from collections.abc import Sequence
+from typing import Literal
+
+from . import exceptions, messages
+
+try:
+    from mcp import types as mcp_types
+except ImportError as _import_error:
+    raise ImportError(
+        'Please install the `mcp` package to use the MCP server, '
+        'you can use the `mcp` optional group — `pip install "pydantic-ai-slim[mcp]"`'
+    ) from _import_error
+
+
+def map_from_mcp_params(params: mcp_types.CreateMessageRequestParams) -> list[messages.ModelMessage]:
+    """Convert from MCP create message request parameters to pydantic-ai messages."""
+    pai_messages: list[messages.ModelMessage] = []
+    request_parts: list[messages.ModelRequestPart] = []
+    if params.systemPrompt:
+        request_parts.append(messages.SystemPromptPart(content=params.systemPrompt))
+    response_parts: list[messages.ModelResponsePart] = []
+    for msg in params.messages:
+        content = msg.content
+        if msg.role == 'user':
+            # if there are any response parts, add a response message wrapping them
+            if response_parts:
+                pai_messages.append(messages.ModelResponse(parts=response_parts))
+                response_parts = []
+
+            # TODO(Marcelo): We can reuse the `_map_tool_result_part` from the mcp module here.
+            if isinstance(content, mcp_types.TextContent):
+                user_part_content: str | Sequence[messages.UserContent] = content.text
+            else:
+                # image content
+                user_part_content = [
+                    messages.BinaryContent(data=base64.b64decode(content.data), media_type=content.mimeType)
+                ]
+
+            request_parts.append(messages.UserPromptPart(content=user_part_content))
+        else:
+            # role is assistant
+            # if there are any request parts, add a request message wrapping them
+            if request_parts:
+                pai_messages.append(messages.ModelRequest(parts=request_parts))
+                request_parts = []
+
+            response_parts.append(map_from_sampling_content(content))
+
+    if response_parts:
+        pai_messages.append(messages.ModelResponse(parts=response_parts))
+    if request_parts:
+        pai_messages.append(messages.ModelRequest(parts=request_parts))
+    return pai_messages
+
+
+def map_from_pai_messages(pai_messages: list[messages.ModelMessage]) -> tuple[str, list[mcp_types.SamplingMessage]]:
+    """Convert from pydantic-ai messages to MCP sampling messages.
+
+    Returns:
+        A tuple containing the system prompt and a list of sampling messages.
+    """
+    sampling_msgs: list[mcp_types.SamplingMessage] = []
+
+    def add_msg(
+        role: Literal['user', 'assistant'],
+        content: mcp_types.TextContent | mcp_types.ImageContent | mcp_types.AudioContent,
+    ):
+        sampling_msgs.append(mcp_types.SamplingMessage(role=role, content=content))
+
+    system_prompt: list[str] = []
+    for pai_message in pai_messages:
+        if isinstance(pai_message, messages.ModelRequest):
+            if pai_message.instructions is not None:
+                system_prompt.append(pai_message.instructions)
+
+            for part in pai_message.parts:
+                if isinstance(part, messages.SystemPromptPart):
+                    system_prompt.append(part.content)
+                if isinstance(part, messages.UserPromptPart):
+                    if isinstance(part.content, str):
+                        add_msg('user', mcp_types.TextContent(type='text', text=part.content))
+                    else:
+                        for chunk in part.content:
+                            if isinstance(chunk, str):
+                                add_msg('user', mcp_types.TextContent(type='text', text=chunk))
+                            elif isinstance(chunk, messages.BinaryContent) and chunk.is_image:
+                                add_msg(
+                                    'user',
+                                    mcp_types.ImageContent(
+                                        type='image',
+                                        data=base64.b64decode(chunk.data).decode(),
+                                        mimeType=chunk.media_type,
+                                    ),
+                                )
+                            # TODO(Marcelo): Add support for audio content.
+                            else:
+                                raise NotImplementedError(f'Unsupported content type: {type(chunk)}')
+        else:
+            add_msg('assistant', map_from_model_response(pai_message))
+    return ''.join(system_prompt), sampling_msgs
+
+
+def map_from_model_response(model_response: messages.ModelResponse) -> mcp_types.TextContent:
+    """Convert from a model response to MCP text content."""
+    text_parts: list[str] = []
+    for part in model_response.parts:
+        if isinstance(part, messages.TextPart):
+            text_parts.append(part.content)
+        # TODO(Marcelo): We should ignore ThinkingPart here.
+        else:
+            raise exceptions.UnexpectedModelBehavior(f'Unexpected part type: {type(part).__name__}, expected TextPart')
+    return mcp_types.TextContent(type='text', text=''.join(text_parts))
+
+
+def map_from_sampling_content(
+    content: mcp_types.TextContent | mcp_types.ImageContent | mcp_types.AudioContent,
+) -> messages.TextPart:
+    """Convert from sampling content to a pydantic-ai text part."""
+    if isinstance(content, mcp_types.TextContent):  # pragma: no branch
+        return messages.TextPart(content=content.text)
+    else:
+        raise NotImplementedError('Image and Audio responses in sampling are not yet supported')
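The new `_mcp.py` module converts between MCP sampling types and pydantic-ai message types in both directions. A minimal sketch of the mapping; `_mcp` is a private module, so this is for illustration only:

    from pydantic_ai import _mcp
    from pydantic_ai.messages import ModelRequest, SystemPromptPart, UserPromptPart

    history = [
        ModelRequest(parts=[
            SystemPromptPart(content='Be terse.'),
            UserPromptPart(content='Hello!'),
        ]),
    ]
    # system prompts are joined into one string, user/assistant parts become SamplingMessages
    system_prompt, sampling_messages = _mcp.map_from_pai_messages(history)
    assert system_prompt == 'Be terse.'
    assert sampling_messages[0].content.text == 'Hello!'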
pydantic_ai/agent.py
@@ -1691,14 +1691,23 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
         return isinstance(node, End)

     @asynccontextmanager
-    async def run_mcp_servers(self) -> AsyncIterator[None]:
+    async def run_mcp_servers(
+        self, model: models.Model | models.KnownModelName | str | None = None
+    ) -> AsyncIterator[None]:
         """Run [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] so they can be used by the agent.

         Returns: a context manager to start and shutdown the servers.
         """
+        try:
+            sampling_model: models.Model | None = self._get_model(model)
+        except exceptions.UserError:  # pragma: no cover
+            sampling_model = None
+
         exit_stack = AsyncExitStack()
         try:
             for mcp_server in self._mcp_servers:
+                if sampling_model is not None:  # pragma: no branch
+                    mcp_server.sampling_model = sampling_model
                 await exit_stack.enter_async_context(mcp_server)
             yield
         finally:
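`run_mcp_servers()` now accepts an optional model, which is set as the sampling model on each registered server; when omitted, it falls back to the agent's own model. A usage sketch (the server script path is hypothetical):

    from pydantic_ai import Agent
    from pydantic_ai.mcp import MCPServerStdio

    server = MCPServerStdio('python', args=['mcp_server.py'])  # hypothetical server script
    agent = Agent('openai:gpt-4o', mcp_servers=[server])

    async def main():
        # with no argument, the agent's own model is used for server-initiated sampling
        async with agent.run_mcp_servers():
            result = await agent.run('Ask the server for a poem about sockets.')
        print(result.output)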
pydantic_ai/mcp.py
@@ -2,7 +2,6 @@ from __future__ import annotations

 import base64
 import functools
-import json
 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Awaitable, Sequence
 from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
@@ -13,41 +12,28 @@ from typing import Any, Callable

 import anyio
 import httpx
+import pydantic_core
 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
-from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
-from mcp.shared.exceptions import McpError
-from mcp.shared.message import SessionMessage
-from mcp.types import (
-    AudioContent,
-    BlobResourceContents,
-    CallToolRequest,
-    CallToolRequestParams,
-    CallToolResult,
-    ClientRequest,
-    Content,
-    EmbeddedResource,
-    ImageContent,
-    LoggingLevel,
-    RequestParams,
-    TextContent,
-    TextResourceContents,
-)
 from typing_extensions import Self, assert_never, deprecated

-from pydantic_ai.exceptions import ModelRetry
-from pydantic_ai.messages import BinaryContent
-from pydantic_ai.tools import RunContext, ToolDefinition
-
 try:
-    from mcp import ClientSession
+    from mcp import types as mcp_types
+    from mcp.client.session import ClientSession, LoggingFnT
     from mcp.client.sse import sse_client
     from mcp.client.stdio import StdioServerParameters, stdio_client
+    from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
+    from mcp.shared.context import RequestContext
+    from mcp.shared.exceptions import McpError
+    from mcp.shared.message import SessionMessage
 except ImportError as _import_error:
     raise ImportError(
         'Please install the `mcp` package to use the MCP server, '
         'you can use the `mcp` optional group — `pip install "pydantic-ai-slim[mcp]"`'
     ) from _import_error

+# after mcp imports so any import error maps to this file, not _mcp.py
+from . import _mcp, exceptions, messages, models, tools
+
 __all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP', 'MCPServerSSE', 'MCPServerStreamableHTTP'
@@ -57,22 +43,22 @@ class MCPServer(ABC):
     See <https://modelcontextprotocol.io> for more information.
     """

-
+    # these fields should be re-defined by dataclass subclasses so they appear as fields {
     tool_prefix: str | None = None
-    """A prefix to add to all tools that are registered with the server.
-
-    If not empty, will include a trailing underscore (`_`).
-
-    e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
-    """
-
+    log_level: mcp_types.LoggingLevel | None = None
+    log_handler: LoggingFnT | None = None
+    timeout: float = 5
     process_tool_call: ProcessToolCallback | None = None
-    """Hook to customize tool calling and optionally pass extra metadata."""
+    allow_sampling: bool = True
+    # } end of "abstract fields"
+
+    _running_count: int = 0

     _client: ClientSession
     _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
     _write_stream: MemoryObjectSendStream[SessionMessage]
     _exit_stack: AsyncExitStack
+    sampling_model: models.Model | None = None

     @abstractmethod
     @asynccontextmanager
@@ -88,14 +74,6 @@ class MCPServer(ABC):
         raise NotImplementedError('MCP Server subclasses must implement this method.')
         yield

-    @abstractmethod
-    def _get_log_level(self) -> LoggingLevel | None:
-        """Get the log level for the MCP server."""
-        raise NotImplementedError('MCP Server subclasses must implement this method.')
-
-    def _get_client_initialize_timeout(self) -> float:
-        return 5  # pragma: no cover
-
     def get_prefixed_tool_name(self, tool_name: str) -> str:
         """Get the tool name with prefix if `tool_prefix` is set."""
         return f'{self.tool_prefix}_{tool_name}' if self.tool_prefix else tool_name
@@ -104,21 +82,26 @@ class MCPServer(ABC):
         """Get original tool name without prefix for calling tools."""
         return tool_name.removeprefix(f'{self.tool_prefix}_') if self.tool_prefix else tool_name

-    async def list_tools(self) -> list[ToolDefinition]:
+    @property
+    def is_running(self) -> bool:
+        """Check if the MCP server is running."""
+        return bool(self._running_count)
+
+    async def list_tools(self) -> list[tools.ToolDefinition]:
         """Retrieve tools that are currently active on the server.

         Note:
         - We don't cache tools as they might change.
         - We also don't subscribe to the server to avoid complexity.
         """
-        tools = await self._client.list_tools()
+        mcp_tools = await self._client.list_tools()
         return [
-            ToolDefinition(
+            tools.ToolDefinition(
                 name=self.get_prefixed_tool_name(tool.name),
                 description=tool.description or '',
                 parameters_json_schema=tool.inputSchema,
             )
-            for tool in tools.tools
+            for tool in mcp_tools.tools
         ]

     async def call_tool(
@@ -143,44 +126,48 @@ class MCPServer(ABC):
         try:
             # meta param is not provided by session yet, so build and can send_request directly.
             result = await self._client.send_request(
-                ClientRequest(
-                    CallToolRequest(
+                mcp_types.ClientRequest(
+                    mcp_types.CallToolRequest(
                         method='tools/call',
-                        params=CallToolRequestParams(
+                        params=mcp_types.CallToolRequestParams(
                             name=self.get_unprefixed_tool_name(tool_name),
                             arguments=arguments,
-                            _meta=RequestParams.Meta(**metadata) if metadata else None,
+                            _meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None,
                         ),
                     )
                 ),
-                CallToolResult,
+                mcp_types.CallToolResult,
             )
         except McpError as e:
-            raise ModelRetry(e.error.message)
+            raise exceptions.ModelRetry(e.error.message)

         content = [self._map_tool_result_part(part) for part in result.content]

         if result.isError:
             text = '\n'.join(str(part) for part in content)
-            raise ModelRetry(text)
-
-        if len(content) == 1:
-            return content[0]
-        return content
+            raise exceptions.ModelRetry(text)
+        else:
+            return content[0] if len(content) == 1 else content

     async def __aenter__(self) -> Self:
-        self._exit_stack = AsyncExitStack()
-
-        self._read_stream, self._write_stream = await self._exit_stack.enter_async_context(self.client_streams())
-        client = ClientSession(read_stream=self._read_stream, write_stream=self._write_stream)
-        self._client = await self._exit_stack.enter_async_context(client)
+        if self._running_count == 0:
+            self._exit_stack = AsyncExitStack()
+
+            self._read_stream, self._write_stream = await self._exit_stack.enter_async_context(self.client_streams())
+            client = ClientSession(
+                read_stream=self._read_stream,
+                write_stream=self._write_stream,
+                sampling_callback=self._sampling_callback if self.allow_sampling else None,
+                logging_callback=self.log_handler,
+            )
+            self._client = await self._exit_stack.enter_async_context(client)

-        with anyio.fail_after(self._get_client_initialize_timeout()):
-            await self._client.initialize()
+            with anyio.fail_after(self.timeout):
+                await self._client.initialize()

-        if log_level := self._get_log_level():
-            await self._client.set_logging_level(log_level)
-        self.is_running = True
+            if log_level := self.log_level:
+                await self._client.set_logging_level(log_level)
+        self._running_count += 1
         return self
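`__aenter__`/`__aexit__` are now reference-counted via `_running_count`, so entering the same server instance twice (e.g. once by the agent and once by user code) reuses a single `ClientSession` instead of spawning a second connection. A sketch of the resulting behaviour:

    from pydantic_ai.mcp import MCPServerStdio

    async def demo(server: MCPServerStdio) -> None:
        async with server:       # count 0 -> 1: streams, ClientSession, initialize()
            async with server:   # count 1 -> 2: existing session is reused
                assert server.is_running
            assert server.is_running   # outer context still holds a reference
        assert not server.is_running   # count back to 0, exit stack closed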
@@ -189,32 +176,64 @@ class MCPServer(ABC):
         exc_value: BaseException | None,
         traceback: TracebackType | None,
     ) -> bool | None:
-        await self._exit_stack.aclose()
-        self.is_running = False
+        self._running_count -= 1
+        if self._running_count <= 0:
+            await self._exit_stack.aclose()
+
+    async def _sampling_callback(
+        self, context: RequestContext[ClientSession, Any], params: mcp_types.CreateMessageRequestParams
+    ) -> mcp_types.CreateMessageResult | mcp_types.ErrorData:
+        """MCP sampling callback."""
+        if self.sampling_model is None:
+            raise ValueError('Sampling model is not set')  # pragma: no cover
+
+        pai_messages = _mcp.map_from_mcp_params(params)
+        model_settings = models.ModelSettings()
+        if max_tokens := params.maxTokens:  # pragma: no branch
+            model_settings['max_tokens'] = max_tokens
+        if temperature := params.temperature:  # pragma: no branch
+            model_settings['temperature'] = temperature
+        if stop_sequences := params.stopSequences:  # pragma: no branch
+            model_settings['stop_sequences'] = stop_sequences
+
+        model_response = await self.sampling_model.request(
+            pai_messages,
+            model_settings,
+            models.ModelRequestParameters(),
+        )
+        return mcp_types.CreateMessageResult(
+            role='assistant',
+            content=_mcp.map_from_model_response(model_response),
+            model=self.sampling_model.model_name,
+        )

-    def _map_tool_result_part(self, part: Content) -> str | BinaryContent | dict[str, Any] | list[Any]:
+    def _map_tool_result_part(
+        self, part: mcp_types.Content
+    ) -> str | messages.BinaryContent | dict[str, Any] | list[Any]:
         # See https://github.com/jlowin/fastmcp/blob/main/docs/servers/tools.mdx#return-values

-        if isinstance(part, TextContent):
+        if isinstance(part, mcp_types.TextContent):
             text = part.text
             if text.startswith(('[', '{')):
                 try:
-                    return json.loads(text)
+                    return pydantic_core.from_json(text)
                 except ValueError:
                     pass
             return text
-        elif isinstance(part, ImageContent):
-            return BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType)
-        elif isinstance(part, AudioContent):
+        elif isinstance(part, mcp_types.ImageContent):
+            return messages.BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType)
+        elif isinstance(part, mcp_types.AudioContent):
             # NOTE: The FastMCP server doesn't support audio content.
             # See <https://github.com/modelcontextprotocol/python-sdk/issues/952> for more details.
-            return BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType)
-        elif isinstance(part, EmbeddedResource):
+            return messages.BinaryContent(
+                data=base64.b64decode(part.data), media_type=part.mimeType
+            )  # pragma: no cover
+        elif isinstance(part, mcp_types.EmbeddedResource):
             resource = part.resource
-            if isinstance(resource, TextResourceContents):
+            if isinstance(resource, mcp_types.TextResourceContents):
                 return resource.text
-            elif isinstance(resource, BlobResourceContents):
-                return BinaryContent(
+            elif isinstance(resource, mcp_types.BlobResourceContents):
+                return messages.BinaryContent(
                     data=base64.b64decode(resource.blob),
                     media_type=resource.mimeType or 'application/octet-stream',
                 )
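The new `_sampling_callback` is what services `sampling/createMessage` requests from the server: it maps the MCP params to pydantic-ai messages, forwards `maxTokens`/`temperature`/`stopSequences` as model settings, runs the configured `sampling_model`, and maps the response back. Clients that don't want to expose an LLM to the server can opt out; a sketch (server script path hypothetical):

    from pydantic_ai.mcp import MCPServerStdio

    server = MCPServerStdio(
        'python',
        args=['mcp_server.py'],  # hypothetical server script
        allow_sampling=False,    # no sampling_callback is registered with the ClientSession
    )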
@@ -275,17 +294,11 @@ class MCPServerStdio(MCPServer):
     By default the subprocess will not inherit any environment variables from the parent process.
     If you want to inherit the environment variables from the parent process, use `env=os.environ`.
     """
-    log_level: LoggingLevel | None = None
-    """The log level to set when connecting to the server, if any.
-
-    See <https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging#logging> for more details.
-
-    If `None`, no log level will be set.
-    """

     cwd: str | Path | None = None
     """The working directory to use when spawning the process."""

+    # last fields are re-defined from the parent class so they appear as fields
     tool_prefix: str | None = None
     """A prefix to add to all tools that are registered with the server.
@@ -294,11 +307,25 @@ class MCPServerStdio(MCPServer):
     e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
     """

+    log_level: mcp_types.LoggingLevel | None = None
+    """The log level to set when connecting to the server, if any.
+
+    See <https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging#logging> for more details.
+
+    If `None`, no log level will be set.
+    """
+
+    log_handler: LoggingFnT | None = None
+    """A handler for logging messages from the server."""
+
+    timeout: float = 5
+    """The timeout in seconds to wait for the client to initialize."""
+
     process_tool_call: ProcessToolCallback | None = None
     """Hook to customize tool calling and optionally pass extra metadata."""

-    timeout: float = 5
-    """The timeout in seconds to wait for the client to initialize."""
+    allow_sampling: bool = True
+    """Whether to allow MCP sampling through this client."""

     @asynccontextmanager
     async def client_streams(
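Since the re-declared attributes are plain dataclass fields on `MCPServerStdio`, they can all be set directly at construction time; for example, using the run-python server from the pydantic-ai docs:

    from pydantic_ai.mcp import MCPServerStdio

    server = MCPServerStdio(
        'deno',
        args=['run', '-A', 'jsr:@pydantic/mcp-run-python', 'stdio'],
        log_level='debug',  # forwarded to the server via set_logging_level on connect
        timeout=10,         # seconds allowed for ClientSession.initialize()
    )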
@@ -313,15 +340,9 @@ class MCPServerStdio(MCPServer):
         async with stdio_client(server=server) as (read_stream, write_stream):
             yield read_stream, write_stream

-    def _get_log_level(self) -> LoggingLevel | None:
-        return self.log_level
-
     def __repr__(self) -> str:
         return f'MCPServerStdio(command={self.command!r}, args={self.args!r}, tool_prefix={self.tool_prefix!r})'

-    def _get_client_initialize_timeout(self) -> float:
-        return self.timeout
-

 @dataclass
 class _MCPServerHTTP(MCPServer):
@@ -360,13 +381,6 @@ class _MCPServerHTTP(MCPServer):
     ```
     """

-    timeout: float = 5
-    """Initial connection timeout in seconds for establishing the connection.
-
-    This timeout applies to the initial connection setup and handshake.
-    If the connection cannot be established within this time, the operation will fail.
-    """
-
     sse_read_timeout: float = 5 * 60
     """Maximum time in seconds to wait for new SSE messages before timing out.
@@ -375,7 +389,16 @@ class _MCPServerHTTP(MCPServer):
     and may be closed. Defaults to 5 minutes (300 seconds).
     """

-    log_level: LoggingLevel | None = None
+    # last fields are re-defined from the parent class so they appear as fields
+    tool_prefix: str | None = None
+    """A prefix to add to all tools that are registered with the server.
+
+    If not empty, will include a trailing underscore (`_`).
+
+    For example, if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
+    """
+
+    log_level: mcp_types.LoggingLevel | None = None
     """The log level to set when connecting to the server, if any.

     See <https://modelcontextprotocol.io/introduction#logging> for more details.
@@ -383,17 +406,22 @@ class _MCPServerHTTP(MCPServer):
     If `None`, no log level will be set.
     """

-    tool_prefix: str | None = None
-    """A prefix to add to all tools that are registered with the server.
+    log_handler: LoggingFnT | None = None
+    """A handler for logging messages from the server."""

-    If not empty, will include a trailing underscore (`_`).
+    timeout: float = 5
+    """Initial connection timeout in seconds for establishing the connection.

-    e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
+    This timeout applies to the initial connection setup and handshake.
+    If the connection cannot be established within this time, the operation will fail.
     """

     process_tool_call: ProcessToolCallback | None = None
     """Hook to customize tool calling and optionally pass extra metadata."""
+
+    allow_sampling: bool = True
+    """Whether to allow MCP sampling through this client."""

     @property
     @abstractmethod
     def _transport_client(
@@ -419,7 +447,10 @@ class _MCPServerHTTP(MCPServer):
     async def client_streams(
         self,
     ) -> AsyncIterator[
-        tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]]
+        tuple[
+            MemoryObjectReceiveStream[SessionMessage | Exception],
+            MemoryObjectSendStream[SessionMessage],
+        ]
     ]:  # pragma: no cover
         if self.http_client and self.headers:
             raise ValueError('`http_client` is mutually exclusive with `headers`.')
@@ -451,15 +482,9 @@ class _MCPServerHTTP(MCPServer):
         async with transport_client_partial(headers=self.headers) as (read_stream, write_stream, *_):
             yield read_stream, write_stream

-    def _get_log_level(self) -> LoggingLevel | None:
-        return self.log_level
-
     def __repr__(self) -> str:  # pragma: no cover
         return f'{self.__class__.__name__}(url={self.url!r}, tool_prefix={self.tool_prefix!r})'

-    def _get_client_initialize_timeout(self) -> float:  # pragma: no cover
-        return self.timeout
-

 @dataclass
 class MCPServerSSE(_MCPServerHTTP):
@@ -555,7 +580,11 @@ class MCPServerStreamableHTTP(_MCPServerHTTP):


 ToolResult = (
-    str | BinaryContent | dict[str, Any] | list[Any] | Sequence[str | BinaryContent | dict[str, Any] | list[Any]]
+    str
+    | messages.BinaryContent
+    | dict[str, Any]
+    | list[Any]
+    | Sequence[str | messages.BinaryContent | dict[str, Any] | list[Any]]
 )
 """The result type of a tool call."""
@@ -564,7 +593,7 @@ CallToolFunc = Callable[[str, dict[str, Any], dict[str, Any] | None], Awaitable[ToolResult]]

 ProcessToolCallback = Callable[
     [
-        RunContext[Any],
+        tools.RunContext[Any],
         CallToolFunc,
         str,
         dict[str, Any],
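For reference, a `ProcessToolCallback` receives the run context, a callable that performs the actual tool call, the tool name, and the arguments; it can short-circuit the call, post-process the result, or attach metadata. A sketch of a pass-through hook that forwards the agent's deps as request metadata:

    from typing import Any

    from pydantic_ai.mcp import CallToolFunc, ToolResult
    from pydantic_ai.tools import RunContext

    async def process_tool_call(
        ctx: RunContext[Any],
        call_tool: CallToolFunc,
        tool_name: str,
        args: dict[str, Any],
    ) -> ToolResult:
        # pass pydantic-ai deps to the MCP server via the _meta field
        return await call_tool(tool_name, args, {'deps': str(ctx.deps)})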
pydantic_ai/messages.py
@@ -763,7 +763,9 @@ class ThinkingPartDelta:
             ValueError: If `part` is not a `ThinkingPart`.
         """
         if isinstance(part, ThinkingPart):
-            return replace(part, content=part.content + (self.content_delta or ''))
+            new_content = part.content + self.content_delta if self.content_delta else part.content
+            new_signature = self.signature_delta if self.signature_delta is not None else part.signature
+            return replace(part, content=new_content, signature=new_signature)
         elif isinstance(part, ThinkingPartDelta):
             if self.content_delta is None and self.signature_delta is None:
                 raise ValueError('Cannot apply ThinkingPartDelta with no content or signature')
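The fix means a delta's `signature_delta` is no longer dropped when applied to an existing `ThinkingPart`. A sketch of the new behaviour:

    from pydantic_ai.messages import ThinkingPart, ThinkingPartDelta

    part = ThinkingPart(content='Let me think')
    delta = ThinkingPartDelta(content_delta='...', signature_delta='sig-123')
    merged = delta.apply(part)
    assert merged.content == 'Let me think...'
    assert merged.signature == 'sig-123'  # previously lost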
pydantic_ai/models/__init__.py
@@ -139,8 +139,11 @@ KnownModelName = TypeAliasType(
     'google-gla:gemini-2.0-flash-lite-preview-02-05',
     'google-gla:gemini-2.0-pro-exp-02-05',
     'google-gla:gemini-2.5-flash-preview-05-20',
+    'google-gla:gemini-2.5-flash',
+    'google-gla:gemini-2.5-flash-lite-preview-06-17',
     'google-gla:gemini-2.5-pro-exp-03-25',
     'google-gla:gemini-2.5-pro-preview-05-06',
+    'google-gla:gemini-2.5-pro',
     'google-vertex:gemini-1.5-flash',
     'google-vertex:gemini-1.5-flash-8b',
     'google-vertex:gemini-1.5-pro',
@@ -149,8 +152,11 @@ KnownModelName = TypeAliasType(
     'google-vertex:gemini-2.0-flash-lite-preview-02-05',
     'google-vertex:gemini-2.0-pro-exp-02-05',
     'google-vertex:gemini-2.5-flash-preview-05-20',
+    'google-vertex:gemini-2.5-flash',
+    'google-vertex:gemini-2.5-flash-lite-preview-06-17',
     'google-vertex:gemini-2.5-pro-exp-03-25',
     'google-vertex:gemini-2.5-pro-preview-05-06',
+    'google-vertex:gemini-2.5-pro',
     'gpt-3.5-turbo',
     'gpt-3.5-turbo-0125',
     'gpt-3.5-turbo-0301',
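The newly added stable aliases can be used anywhere a `KnownModelName` is accepted, e.g.:

    from pydantic_ai import Agent

    agent = Agent('google-gla:gemini-2.5-pro')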
pydantic_ai/models/gemini.py
@@ -54,8 +54,11 @@ LatestGeminiModelNames = Literal[
     'gemini-2.0-flash-lite-preview-02-05',
     'gemini-2.0-pro-exp-02-05',
     'gemini-2.5-flash-preview-05-20',
+    'gemini-2.5-flash',
+    'gemini-2.5-flash-lite-preview-06-17',
     'gemini-2.5-pro-exp-03-25',
     'gemini-2.5-pro-preview-05-06',
+    'gemini-2.5-pro',
 ]
 """Latest Gemini models."""
pydantic_ai/models/google.py
@@ -79,8 +79,11 @@ LatestGoogleModelNames = Literal[
     'gemini-2.0-flash-lite-preview-02-05',
     'gemini-2.0-pro-exp-02-05',
     'gemini-2.5-flash-preview-05-20',
+    'gemini-2.5-flash',
+    'gemini-2.5-flash-lite-preview-06-17',
     'gemini-2.5-pro-exp-03-25',
     'gemini-2.5-pro-preview-05-06',
+    'gemini-2.5-pro',
 ]
 """Latest Gemini models."""
pydantic_ai/models/mcp_sampling.py (new file)
@@ -0,0 +1,95 @@
+from __future__ import annotations as _annotations
+
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, cast
+
+from .. import _mcp, exceptions, usage
+from ..messages import ModelMessage, ModelResponse
+from ..settings import ModelSettings
+from . import Model, ModelRequestParameters, StreamedResponse
+
+if TYPE_CHECKING:
+    from mcp import ServerSession
+    from mcp.types import ModelPreferences
+
+
+class MCPSamplingModelSettings(ModelSettings, total=False):
+    """Settings used for an MCP Sampling model request.
+
+    ALL FIELDS MUST BE `mcp_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
+    """
+
+    mcp_model_preferences: ModelPreferences
+    """Model preferences to use for MCP Sampling."""
+
+
+@dataclass
+class MCPSamplingModel(Model):
+    """A model that uses MCP Sampling.
+
+    [MCP Sampling](https://modelcontextprotocol.io/docs/concepts/sampling)
+    allows an MCP server to make requests to a model by calling back to the MCP client that connected to it.
+    """
+
+    session: ServerSession
+    """The MCP server session to use for sampling."""
+
+    default_max_tokens: int = 16_384
+    """Default max tokens to use if not set in [`ModelSettings`][pydantic_ai.settings.ModelSettings.max_tokens].
+
+    Max tokens is a required parameter for MCP Sampling, but optional on
+    [`ModelSettings`][pydantic_ai.settings.ModelSettings], so this value is used as fallback.
+    """
+
+    async def request(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> ModelResponse:
+        system_prompt, sampling_messages = _mcp.map_from_pai_messages(messages)
+        model_settings = cast(MCPSamplingModelSettings, model_settings or {})
+
+        result = await self.session.create_message(
+            sampling_messages,
+            max_tokens=model_settings.get('max_tokens', self.default_max_tokens),
+            system_prompt=system_prompt,
+            temperature=model_settings.get('temperature'),
+            model_preferences=model_settings.get('mcp_model_preferences'),
+            stop_sequences=model_settings.get('stop_sequences'),
+        )
+        if result.role == 'assistant':
+            return ModelResponse(
+                parts=[_mcp.map_from_sampling_content(result.content)],
+                usage=usage.Usage(requests=1),
+                model_name=result.model,
+            )
+        else:
+            raise exceptions.UnexpectedModelBehavior(
+                f'Unexpected result from MCP sampling, expected "assistant" role, got {result.role}.'
+            )
+
+    @asynccontextmanager
+    async def request_stream(
+        self,
+        messages: list[ModelMessage],
+        model_settings: ModelSettings | None,
+        model_request_parameters: ModelRequestParameters,
+    ) -> AsyncIterator[StreamedResponse]:
+        raise NotImplementedError('MCP Sampling does not support streaming')
+        yield
+
+    @property
+    def model_name(self) -> str:
+        """The model name.
+
+        Since the model name isn't known until the request is made, this property always returns `'mcp-sampling'`.
+        """
+        return 'mcp-sampling'
+
+    @property
+    def system(self) -> str:
+        """The system / model provider, returns `'MCP'`."""
+        return 'MCP'
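`MCPSamplingModel` is the server-side half of sampling: inside an MCP server it lets a pydantic-ai agent route its LLM calls back through the connected client. A sketch using FastMCP (the tool name and prompt are illustrative):

    from mcp.server.fastmcp import Context, FastMCP
    from pydantic_ai import Agent
    from pydantic_ai.models.mcp_sampling import MCPSamplingModel

    app = FastMCP('Poet')

    @app.tool()
    async def poem(theme: str, ctx: Context) -> str:
        """Generate a poem on the given theme via client-side sampling."""
        # ctx.session is the ServerSession for the connected client
        agent = Agent(MCPSamplingModel(session=ctx.session))
        result = await agent.run(f'Write a poem about {theme}.')
        return result.output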
All remaining files are unchanged between 0.3.0 and 0.3.2.