mcp-proxy 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Sergey Parfenyuk
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,7 @@
1
+ Metadata-Version: 2.1
2
+ Name: mcp-proxy
3
+ Version: 0.2.1
4
+ Author-email: Sergey Parfenyuk <sergey.parfenyuk@gmail.com>
5
+ Requires-Python: >=3.11
6
+ License-File: LICENSE
7
+ Requires-Dist: mcp
@@ -0,0 +1,76 @@
1
+ # mcp-proxy
2
+
3
+ - [mcp-proxy](#mcp-proxy)
4
+ - [About](#about)
5
+ - [Installation](#installation)
6
+ - [Configuration](#configuration)
7
+ - [Claude Desktop Configuration](#claude-desktop-configuration)
8
+ - [Advanced Configuration](#advanced-configuration)
9
+ - [Environment Variables](#environment-variables)
10
+
11
+ ## About
12
+
13
+ Use the MCP Proxy server to connect to MCP servers that use the SSE transport.
14
+
15
+ ```mermaid
16
+ graph LR
17
+ A["Claude Desktop"] <--> B["mcp-proxy"]
18
+ B <--> C["External MCP Server"]
19
+
20
+ style A fill:#ffe6f9,stroke:#333,color:black,stroke-width:2px
21
+ style B fill:#e6e6ff,stroke:#333,color:black,stroke-width:2px
22
+ style C fill:#e6ffe6,stroke:#333,color:black,stroke-width:2px
23
+ ```
24
+
25
+ > [!TIP]
26
+ > As of now, Claude Desktop does not support MCP servers that use the SSE transport. This server works around that limitation by bridging Claude Desktop's stdio connection to the remote SSE server.
27
+
28
+ ## Installation
29
+
30
+ ```bash
31
+ uv tool install git+https://github.com/sparfenyuk/mcp-proxy
32
+ ```
33
+
34
+ > [!NOTE]
35
+ > If you have already installed the server, you can update it with the `uv tool upgrade --reinstall mcp-proxy` command.
36
+
37
+ > [!NOTE]
38
+ > If you want to delete the server, use the `uv tool uninstall mcp-proxy` command.
39
+
40
+ ## Configuration
41
+
42
+ ### Claude Desktop Configuration
43
+
44
+ Configure Claude Desktop to recognize the MCP server.
45
+
46
+ 1. Open the Claude Desktop configuration file:
47
+ - on macOS, the configuration file is located at `~/Library/Application Support/Claude/claude_desktop_config.json`
48
+ - on Windows, the configuration file is located at `%APPDATA%\Claude\claude_desktop_config.json`
49
+
50
+ > __Note:__
51
+ > You can also open `claude_desktop_config.json` from the settings of the Claude Desktop app.
52
+
53
+ 2. Add the server configuration
54
+
55
+ ```json
56
+ {
57
+ "mcpServers": {
58
+ "mcp-proxy": {
59
+ "command": "mcp-proxy",
60
+ "env": {
61
+ "SSE_URL": "http://example.io/sse"
62
+ }
63
+ }
64
+ }
65
+ }
66
+
67
+ ```
68
+
69
+ ## Advanced Configuration
70
+
71
+ ### Environment Variables
72
+
73
+ | Name | Description |
74
+ | ---------------- | ---------------------------------------------------------------------------------- |
75
+ | SSE_URL | The MCP server SSE endpoint to connect to, e.g. http://example.io/sse |
76
+ | API_ACCESS_TOKEN | Sent in the `Authorization` header of the HTTP request as a `Bearer` access token |
@@ -0,0 +1,63 @@
1
+ [project]
2
+ name = "mcp-proxy"
3
+ authors = [{ name = "Sergey Parfenyuk", email = "sergey.parfenyuk@gmail.com" }]
4
+ version = "0.2.1"
5
+ requires-python = ">=3.11"
6
+ dependencies = ["mcp"]
7
+
8
+ [build-system]
9
+ requires = ["setuptools"]
10
+ build-backend = "setuptools.build_meta"
11
+
12
+ [project.scripts]
13
+ mcp-proxy = "mcp_proxy.__main__:main"
14
+
15
+ [tool.setuptools.package-data]
16
+ "*" = ["py.typed"]
17
+
18
+ [tool.uv]
19
+ dev-dependencies = [
20
+ "pytest>=8.3.3",
21
+ "pytest-asyncio>=0.25.0",
22
+ "coverage>=7.6.0",
23
+ ]
24
+
25
+ [tool.coverage.run]
26
+ branch = true
27
+
28
+ [tool.coverage.report]
29
+ skip_covered = true
30
+ show_missing = true
31
+ precision = 2
32
+ exclude_lines = [
33
+ 'pragma: no cover',
34
+ 'raise NotImplementedError',
35
+ 'if TYPE_CHECKING:',
36
+ 'if typing.TYPE_CHECKING:',
37
+ '@overload',
38
+ '@typing.overload',
39
+ '\(Protocol\):$',
40
+ 'typing.assert_never',
41
+ '$\s*assert_never\(',
42
+ 'if __name__ == .__main__.:',
43
+ ]
44
+
45
+ [tool.ruff.lint]
46
+ select = ["ALL"]
47
+ ignore = [
48
+ "EM101", # Exception must not use a string literal, assign to variable first
49
+ "TRY003", # Avoid specifying long messages outside the exception class
50
+ "ERA001", # Found commented-out code
51
+ ]
52
+
53
+ [tool.ruff.lint.per-file-ignores]
54
+ "tests/*" = ["S101", "INP001"]
55
+
56
+ [tool.ruff]
57
+ line-length = 100
58
+
59
+ [tool.pytest.ini_options]
60
+ pythonpath = "src"
61
+ addopts = ["--import-mode=importlib"]
62
+ asyncio_mode = "auto"
63
+ asyncio_default_fixture_loop_scope = "function"
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,143 @@
1
+ """Create a local server that proxies requests to a remote server over SSE."""
2
+
3
+ import logging
4
+ import typing as t
5
+
6
+ from mcp import server, types
7
+ from mcp.client.session import ClientSession
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
async def create_proxy_server(remote_app: ClientSession) -> server.Server:
    """Create a local server whose handlers forward requests to a remote session.

    The remote session is initialized first. The returned server is named after the
    remote server and only exposes request handlers for the capability groups
    (prompts, resources, logging, tools) that the remote server advertises.
    Progress notifications and completion requests are always forwarded.

    Args:
        remote_app: Client session connected to the remote MCP server. This
            function calls ``initialize()`` on it.

    Returns:
        A ``server.Server`` whose handlers delegate to ``remote_app``.

    """
    response = await remote_app.initialize()
    capabilities = response.capabilities

    app = server.Server(response.serverInfo.name)

    if capabilities.prompts:
        _register_prompt_handlers(app, remote_app)

    if capabilities.resources:
        _register_resource_handlers(app, remote_app)

    if capabilities.logging:
        _register_logging_handlers(app, remote_app)

    if capabilities.tools:
        _register_tool_handlers(app, remote_app)

    _register_common_handlers(app, remote_app)

    return app


def _register_prompt_handlers(app: server.Server, remote_app: ClientSession) -> None:
    """Forward prompt listing and prompt retrieval to ``remote_app``."""

    async def _list_prompts(_: t.Any) -> types.ServerResult:  # noqa: ANN401
        result = await remote_app.list_prompts()
        return types.ServerResult(result)

    async def _get_prompt(req: types.GetPromptRequest) -> types.ServerResult:
        result = await remote_app.get_prompt(req.params.name, req.params.arguments)
        return types.ServerResult(result)

    app.request_handlers[types.ListPromptsRequest] = _list_prompts
    app.request_handlers[types.GetPromptRequest] = _get_prompt


def _register_resource_handlers(app: server.Server, remote_app: ClientSession) -> None:
    """Forward resource listing, reading and (un)subscription to ``remote_app``."""

    async def _list_resources(_: t.Any) -> types.ServerResult:  # noqa: ANN401
        result = await remote_app.list_resources()
        return types.ServerResult(result)

    async def _read_resource(req: types.ReadResourceRequest) -> types.ServerResult:
        result = await remote_app.read_resource(req.params.uri)
        return types.ServerResult(result)

    async def _subscribe_resource(req: types.SubscribeRequest) -> types.ServerResult:
        await remote_app.subscribe_resource(req.params.uri)
        return types.ServerResult(types.EmptyResult())

    async def _unsubscribe_resource(req: types.UnsubscribeRequest) -> types.ServerResult:
        await remote_app.unsubscribe_resource(req.params.uri)
        return types.ServerResult(types.EmptyResult())

    app.request_handlers[types.ListResourcesRequest] = _list_resources
    # ListResourceTemplatesRequest is deliberately not forwarded: the client
    # session does not implement list_resource_templates().
    app.request_handlers[types.ReadResourceRequest] = _read_resource
    app.request_handlers[types.SubscribeRequest] = _subscribe_resource
    app.request_handlers[types.UnsubscribeRequest] = _unsubscribe_resource


def _register_logging_handlers(app: server.Server, remote_app: ClientSession) -> None:
    """Forward log-level changes to ``remote_app``."""

    async def _set_logging_level(req: types.SetLevelRequest) -> types.ServerResult:
        await remote_app.set_logging_level(req.params.level)
        return types.ServerResult(types.EmptyResult())

    app.request_handlers[types.SetLevelRequest] = _set_logging_level


def _register_tool_handlers(app: server.Server, remote_app: ClientSession) -> None:
    """Forward tool listing and tool invocation to ``remote_app``.

    Any exception raised while calling a remote tool is returned to the client as a
    ``CallToolResult`` with ``isError=True`` rather than failing the request.
    """

    async def _list_tools(_: t.Any) -> types.ServerResult:  # noqa: ANN401
        tools = await remote_app.list_tools()
        return types.ServerResult(tools)

    async def _call_tool(req: types.CallToolRequest) -> types.ServerResult:
        try:
            result = await remote_app.call_tool(
                req.params.name,
                (req.params.arguments or {}),
            )
            return types.ServerResult(result)
        except Exception as e:  # noqa: BLE001
            return types.ServerResult(
                types.CallToolResult(
                    content=[types.TextContent(type="text", text=str(e))],
                    isError=True,
                ),
            )

    app.request_handlers[types.ListToolsRequest] = _list_tools
    app.request_handlers[types.CallToolRequest] = _call_tool


def _register_common_handlers(app: server.Server, remote_app: ClientSession) -> None:
    """Forward progress notifications and completion requests to ``remote_app``.

    These handlers are registered regardless of the capabilities the remote
    server advertises.
    """

    async def _send_progress_notification(req: types.ProgressNotification) -> None:
        await remote_app.send_progress_notification(
            req.params.progressToken,
            req.params.progress,
            req.params.total,
        )

    async def _complete(req: types.CompleteRequest) -> types.ServerResult:
        result = await remote_app.complete(
            req.params.ref,
            req.params.argument.model_dump(),
        )
        return types.ServerResult(result)

    app.notification_handlers[types.ProgressNotification] = _send_progress_notification
    app.request_handlers[types.CompleteRequest] = _complete
120
+
121
+
122
async def run_sse_client(url: str, api_access_token: str | None = None) -> None:
    """Run the SSE client and expose the remote server over stdio.

    Connects to ``url`` with the SSE client, builds a proxy server on top of that
    session and serves it over this process's stdin/stdout until it finishes.

    Args:
        url: The URL of the remote server's SSE endpoint.
        api_access_token: Optional token sent as ``Authorization: Bearer <token>``.
            ``None`` or an empty string means no ``Authorization`` header is sent.

    """
    from mcp.client.sse import sse_client

    headers: dict[str, str] = {}
    # Also skip empty tokens: an empty API_ACCESS_TOKEN would otherwise produce a
    # malformed "Bearer " header instead of an unauthenticated request.
    if api_access_token:
        headers["Authorization"] = f"Bearer {api_access_token}"

    async with sse_client(url=url, headers=headers) as streams, ClientSession(*streams) as session:
        app = await create_proxy_server(session)
        async with server.stdio_server() as (read_stream, write_stream):
            await app.run(
                read_stream,
                write_stream,
                app.create_initialization_options(),
            )
@@ -0,0 +1,30 @@
1
+ """The entry point for the mcp-proxy application. It sets up the logging and runs the main function.
2
+
3
+ Two ways to run the application:
4
+ 1. Run the application as a module `uv run -m mcp_proxy`
5
+ 2. Run the application as a package `uv run mcp-proxy`
6
+
7
+ """
8
+
9
+ import asyncio
10
+ import logging
11
+ import os
12
+ import typing as t
13
+
14
+ from . import run_sse_client
15
+
16
+ logging.basicConfig(level=logging.DEBUG)
17
+ SSE_URL: t.Final[str] = os.getenv("SSE_URL", "")
18
+ API_ACCESS_TOKEN: t.Final[str | None] = os.getenv("API_ACCESS_TOKEN", None)
19
+
20
+ if not SSE_URL:
21
+ raise ValueError("SSE_URL environment variable is not set")
22
+
23
+
24
+ def main() -> None:
25
+ """Start the client using asyncio."""
26
+ asyncio.run(run_sse_client(SSE_URL, api_access_token=API_ACCESS_TOKEN))
27
+
28
+
29
+ if __name__ == "__main__":
30
+ main()
File without changes
@@ -0,0 +1,7 @@
1
+ Metadata-Version: 2.1
2
+ Name: mcp-proxy
3
+ Version: 0.2.1
4
+ Author-email: Sergey Parfenyuk <sergey.parfenyuk@gmail.com>
5
+ Requires-Python: >=3.11
6
+ License-File: LICENSE
7
+ Requires-Dist: mcp
@@ -0,0 +1,13 @@
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ src/mcp_proxy/__init__.py
5
+ src/mcp_proxy/__main__.py
6
+ src/mcp_proxy/py.typed
7
+ src/mcp_proxy.egg-info/PKG-INFO
8
+ src/mcp_proxy.egg-info/SOURCES.txt
9
+ src/mcp_proxy.egg-info/dependency_links.txt
10
+ src/mcp_proxy.egg-info/entry_points.txt
11
+ src/mcp_proxy.egg-info/requires.txt
12
+ src/mcp_proxy.egg-info/top_level.txt
13
+ tests/test_init.py
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ mcp-proxy = mcp_proxy.__main__:main
@@ -0,0 +1 @@
1
+ mcp_proxy
@@ -0,0 +1,464 @@
1
+ """Tests for the mcp-proxy module.
2
+
3
+ Tests are running in two modes:
4
+ - One where the server is exercised directly though an in memory client, just to
5
+ set a baseline for the expected behavior.
6
+ - Another where the server is exercised through a proxy server, which forwards
7
+ the requests to the original server.
8
+
9
+ The same test code is run on both to ensure parity.
10
+ """
11
+
12
+ from collections.abc import AsyncGenerator, Awaitable, Callable
13
+ from contextlib import AbstractAsyncContextManager, asynccontextmanager
14
+ from unittest.mock import AsyncMock
15
+
16
+ import pytest
17
+ from mcp import types
18
+ from mcp.client.session import ClientSession
19
+ from mcp.server import Server
20
+ from mcp.shared.exceptions import McpError
21
+ from mcp.shared.memory import create_connected_server_and_client_session
22
+ from pydantic import AnyUrl
23
+
24
+ from mcp_proxy import create_proxy_server
25
+
26
TOOL_INPUT_SCHEMA = {"type": "object", "properties": {"input1": {"type": "string"}}}

# Factory type: given a Server, open a ClientSession connected to it.
SessionContextManager = Callable[[Server], AbstractAsyncContextManager[ClientSession]]

# Baseline strategy: talk to the server directly over an in-memory transport.
in_memory: SessionContextManager = create_connected_server_and_client_session


@asynccontextmanager
async def proxy(server: Server) -> AsyncGenerator[ClientSession, None]:
    """Open a client session that reaches ``server`` through the proxy.

    ``server`` is connected in memory to an upstream session. A proxy server is
    built on that session with ``create_proxy_server``, and the yielded session
    is connected in memory to the proxy.
    """
    async with in_memory(server) as upstream:
        proxied = await create_proxy_server(upstream)
        async with in_memory(proxied) as downstream:
            yield downstream


@pytest.fixture(params=["server", "proxy"])
def session_generator(request: pytest.FixtureRequest) -> SessionContextManager:
    """Select the session strategy for the current parametrization.

    ``"server"`` connects directly; any other value routes through the proxy.
    """
    strategies: dict[str, SessionContextManager] = {"server": in_memory}
    return strategies.get(request.param, proxy)
49
+
50
+
51
@pytest.fixture
def server() -> Server:
    """Return a fresh ``Server`` named ``test-server``."""
    return Server("test-server")


@pytest.fixture
def server_can_list_prompts(server: Server, prompt: types.Prompt) -> Server:
    """Return ``server`` with a list_prompts handler that returns ``[prompt]``."""

    @server.list_prompts()
    async def _() -> list[types.Prompt]:
        return [prompt]

    return server


@pytest.fixture
def server_can_get_prompt(
    server_can_list_prompts: Server,
    prompt_callback: Callable[[str, dict[str, str] | None], Awaitable[types.GetPromptResult]],
) -> Server:
    """Return a prompt-listing server whose get_prompt handler is ``prompt_callback``."""
    server_can_list_prompts.get_prompt()(prompt_callback)

    return server_can_list_prompts


@pytest.fixture
def server_can_list_tools(server: Server, tool: types.Tool) -> Server:
    """Return ``server`` with a list_tools handler that returns ``[tool]``."""

    @server.list_tools()
    async def _() -> list[types.Tool]:
        return [tool]

    return server


@pytest.fixture
def server_can_call_tool(
    server_can_list_tools: Server,
    tool: Callable[..., Awaitable[object]],
) -> Server:
    """Return a tool-listing server whose call_tool handler is ``tool``.

    The same ``tool`` object is also what ``server_can_list_tools`` lists; the
    tests parametrize it with an ``AsyncMock``.
    """
    server_can_list_tools.call_tool()(tool)

    return server_can_list_tools


@pytest.fixture
def server_can_list_resources(server: Server, resource: types.Resource) -> Server:
    """Return ``server`` with a list_resources handler that returns ``[resource]``."""

    @server.list_resources()
    async def _() -> list[types.Resource]:
        return [resource]

    return server


@pytest.fixture
def server_can_subscribe_resource(
    server_can_list_resources: Server,
    subscribe_callback: Callable[[AnyUrl], Awaitable[None]],
) -> Server:
    """Return a resource-listing server whose subscribe handler is ``subscribe_callback``."""
    server_can_list_resources.subscribe_resource()(subscribe_callback)

    return server_can_list_resources


@pytest.fixture
def server_can_unsubscribe_resource(
    server_can_list_resources: Server,
    unsubscribe_callback: Callable[[AnyUrl], Awaitable[None]],
) -> Server:
    """Return a resource-listing server whose unsubscribe handler is ``unsubscribe_callback``."""
    server_can_list_resources.unsubscribe_resource()(unsubscribe_callback)

    return server_can_list_resources


@pytest.fixture
def server_can_read_resource(
    server_can_list_resources: Server,
    resource_callback: Callable[[AnyUrl], Awaitable[str | bytes]],
) -> Server:
    """Return a resource-listing server whose read_resource handler is ``resource_callback``."""
    server_can_list_resources.read_resource()(resource_callback)

    return server_can_list_resources


@pytest.fixture
def server_can_set_logging_level(
    server: Server,
    logging_level_callback: Callable[[types.LoggingLevel], Awaitable[None]],
) -> Server:
    """Return ``server`` with a set_logging_level handler (``logging_level_callback``)."""
    server.set_logging_level()(logging_level_callback)

    return server


@pytest.fixture
def server_can_send_progress_notification(
    server: Server,
) -> Server:
    """Return the bare ``server``; no progress-notification handler is registered."""
    return server


@pytest.fixture
def server_can_complete(
    server: Server,
    complete_callback: Callable[
        [types.PromptReference | types.ResourceReference, types.CompletionArgument],
        Awaitable[types.Completion | None],
    ],
) -> Server:
    """Return ``server`` with a completion handler (``complete_callback``)."""
    server.completion()(complete_callback)
    return server
172
+
173
+
174
@pytest.mark.parametrize("prompt", [types.Prompt(name="prompt1")])
async def test_list_prompts(
    session_generator: SessionContextManager,
    server_can_list_prompts: Server,
    prompt: types.Prompt,
) -> None:
    """Test list_prompts: only the prompts capability is advertised and served."""
    async with session_generator(server_can_list_prompts) as session:
        result = await session.initialize()
        assert result.capabilities
        assert result.capabilities.prompts
        assert not result.capabilities.tools
        assert not result.capabilities.resources
        assert not result.capabilities.logging

        list_prompts_result = await session.list_prompts()
        assert list_prompts_result.prompts == [prompt]

        # No tool handlers were registered, so the request must be rejected.
        with pytest.raises(McpError, match="Method not found"):
            await session.list_tools()


@pytest.mark.parametrize(
    "tool",
    [
        types.Tool(
            name="tool-name",
            description="tool-description",
            inputSchema=TOOL_INPUT_SCHEMA,
        ),
    ],
)
async def test_list_tools(
    session_generator: SessionContextManager,
    server_can_list_tools: Server,
    tool: types.Tool,
) -> None:
    """Test list_tools: only the tools capability is advertised and served."""
    async with session_generator(server_can_list_tools) as session:
        result = await session.initialize()
        assert result.capabilities
        assert result.capabilities.tools
        assert not result.capabilities.prompts
        assert not result.capabilities.resources
        assert not result.capabilities.logging

        list_tools_result = await session.list_tools()
        assert list_tools_result.tools == [tool]

        # No prompt handlers were registered, so the request must be rejected.
        with pytest.raises(McpError, match="Method not found"):
            await session.list_prompts()


@pytest.mark.parametrize("logging_level_callback", [AsyncMock()])
@pytest.mark.parametrize(
    "log_level",
    ["debug", "info", "notice", "warning", "error", "critical", "alert", "emergency"],
)
async def test_set_logging_error(
    session_generator: SessionContextManager,
    server_can_set_logging_level: Server,
    logging_level_callback: AsyncMock,
    log_level: types.LoggingLevel,
) -> None:
    """Test set_logging_level for every logging level."""
    # The AsyncMock above is created once at collection time and shared by every
    # parametrization. Reset it first so a failure in an earlier case cannot leave
    # recorded calls behind that break assert_called_once_with here.
    logging_level_callback.reset_mock()
    logging_level_callback.return_value = None

    async with session_generator(server_can_set_logging_level) as session:
        result = await session.initialize()
        assert result.capabilities
        assert result.capabilities.logging
        assert not result.capabilities.prompts
        assert not result.capabilities.resources
        assert not result.capabilities.tools

        await session.set_logging_level(log_level)
        logging_level_callback.assert_called_once_with(log_level)
251
+
252
+
253
@pytest.mark.parametrize("tool", [AsyncMock()])
async def test_call_tool(
    session_generator: SessionContextManager,
    server_can_call_tool: Server,
    tool: AsyncMock,
) -> None:
    """Test call_tool: the tool receives the name and arguments and its result is relayed."""
    # The AsyncMock is shared by the "server" and "proxy" parametrizations; reset
    # it up front so a failure in one case cannot poison the other.
    tool.reset_mock()
    tool.return_value = []

    async with session_generator(server_can_call_tool) as session:
        result = await session.initialize()
        assert result.capabilities
        assert result.capabilities.tools
        assert not result.capabilities.prompts
        assert not result.capabilities.resources
        assert not result.capabilities.logging

        call_tool_result = await session.call_tool("tool", {})
        assert call_tool_result.content == []
        assert not call_tool_result.isError

        tool.assert_called_once_with("tool", {})


@pytest.mark.parametrize(
    "resource",
    [
        types.Resource(
            uri=AnyUrl("scheme://resource-uri"),
            name="resource-name",
            description="resource-description",
        ),
    ],
)
async def test_list_resources(
    session_generator: SessionContextManager,
    server_can_list_resources: Server,
    resource: types.Resource,
) -> None:
    """Test list_resources: only the resources capability is advertised and served."""
    async with session_generator(server_can_list_resources) as session:
        result = await session.initialize()
        assert result.capabilities
        assert result.capabilities.resources
        assert not result.capabilities.prompts
        assert not result.capabilities.tools
        assert not result.capabilities.logging

        list_resources_result = await session.list_resources()
        assert list_resources_result.resources == [resource]


@pytest.mark.parametrize("prompt_callback", [AsyncMock()])
@pytest.mark.parametrize("prompt", [types.Prompt(name="prompt1")])
async def test_get_prompt(
    session_generator: SessionContextManager,
    server_can_get_prompt: Server,
    prompt_callback: AsyncMock,
) -> None:
    """Test get_prompt: the callback receives the prompt name and arguments."""
    # Shared across parametrizations; reset before use (see test_call_tool).
    prompt_callback.reset_mock()
    prompt_callback.return_value = types.GetPromptResult(messages=[])

    async with session_generator(server_can_get_prompt) as session:
        await session.initialize()

        await session.get_prompt("prompt", {})
        prompt_callback.assert_called_once_with("prompt", {})
322
+
323
+
324
@pytest.mark.parametrize("resource_callback", [AsyncMock()])
@pytest.mark.parametrize(
    "resource",
    [
        types.Resource(
            uri=AnyUrl("scheme://resource-uri"),
            name="resource-name",
            description="resource-description",
        ),
    ],
)
async def test_read_resource(
    session_generator: SessionContextManager,
    server_can_read_resource: Server,
    resource_callback: AsyncMock,
    resource: types.Resource,
) -> None:
    """Test read_resource: the callback receives the requested URI."""
    # The AsyncMock is shared by the "server" and "proxy" parametrizations; reset
    # it up front so a failure in one case cannot poison the other.
    resource_callback.reset_mock()
    resource_callback.return_value = "resource-content"

    async with session_generator(server_can_read_resource) as session:
        await session.initialize()

        await session.read_resource(resource.uri)
        resource_callback.assert_called_once_with(resource.uri)


@pytest.mark.parametrize("subscribe_callback", [AsyncMock()])
@pytest.mark.parametrize(
    "resource",
    [
        types.Resource(
            uri=AnyUrl("scheme://resource-uri"),
            name="resource-name",
            description="resource-description",
        ),
    ],
)
async def test_subscribe_resource(
    session_generator: SessionContextManager,
    server_can_subscribe_resource: Server,
    subscribe_callback: AsyncMock,
    resource: types.Resource,
) -> None:
    """Test subscribe_resource: the callback receives the subscribed URI."""
    # Shared across parametrizations; reset before use.
    subscribe_callback.reset_mock()
    subscribe_callback.return_value = None

    async with session_generator(server_can_subscribe_resource) as session:
        await session.initialize()

        await session.subscribe_resource(resource.uri)
        subscribe_callback.assert_called_once_with(resource.uri)


@pytest.mark.parametrize("unsubscribe_callback", [AsyncMock()])
@pytest.mark.parametrize(
    "resource",
    [
        types.Resource(
            uri=AnyUrl("scheme://resource-uri"),
            name="resource-name",
            description="resource-description",
        ),
    ],
)
async def test_unsubscribe_resource(
    session_generator: SessionContextManager,
    server_can_unsubscribe_resource: Server,
    unsubscribe_callback: AsyncMock,
    resource: types.Resource,
) -> None:
    """Test unsubscribe_resource: the callback receives the unsubscribed URI."""
    # Shared across parametrizations; reset before use.
    unsubscribe_callback.reset_mock()
    unsubscribe_callback.return_value = None

    async with session_generator(server_can_unsubscribe_resource) as session:
        await session.initialize()

        await session.unsubscribe_resource(resource.uri)
        unsubscribe_callback.assert_called_once_with(resource.uri)
403
+
404
+
405
async def test_send_progress_notification(
    session_generator: SessionContextManager,
    server_can_send_progress_notification: Server,
) -> None:
    """Test that sending a progress notification completes without raising."""
    async with session_generator(server_can_send_progress_notification) as session:
        await session.initialize()

        await session.send_progress_notification(
            progress_token=1,
            progress=0.5,
            total=1,
        )


@pytest.mark.parametrize("complete_callback", [AsyncMock()])
async def test_complete(
    session_generator: SessionContextManager,
    server_can_complete: Server,
    complete_callback: AsyncMock,
) -> None:
    """Test complete: a ``None`` completion yields empty values and the callback gets the args."""
    # The AsyncMock is shared by the "server" and "proxy" parametrizations; reset
    # it up front so a failure in one case cannot poison the other.
    complete_callback.reset_mock()
    complete_callback.return_value = None

    async with session_generator(server_can_complete) as session:
        await session.initialize()

        result = await session.complete(
            types.PromptReference(type="ref/prompt", name="name"),
            argument={"name": "name", "value": "value"},
        )

        assert result.completion.values == []

        complete_callback.assert_called_with(
            types.PromptReference(type="ref/prompt", name="name"),
            types.CompletionArgument(name="name", value="value"),
        )


@pytest.mark.parametrize("tool", [AsyncMock()])
async def test_call_tool_with_error(
    session_generator: SessionContextManager,
    server_can_call_tool: Server,
    tool: AsyncMock,
) -> None:
    """Test call_tool when the tool raises: the result is flagged with ``isError``."""
    async with session_generator(server_can_call_tool) as session:
        result = await session.initialize()
        assert result.capabilities
        assert result.capabilities.tools
        assert not result.capabilities.prompts
        assert not result.capabilities.resources
        assert not result.capabilities.logging

        tool.side_effect = Exception("Error")

        call_tool_result = await session.call_tool("tool", {})
        assert call_tool_result.isError