pydantic-rpc 0.6.1__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,5 @@
1
+ """MCP (Model Context Protocol) support for pydantic-rpc."""
2
+
3
+ from .exporter import MCPExporter
4
+
5
+ __all__ = ["MCPExporter"]
@@ -0,0 +1,115 @@
1
+ """Type conversion utilities for MCP integration."""
2
+
3
+ import datetime
4
+ import enum
5
+ import inspect
6
+ from collections.abc import AsyncIterator
7
+ from typing import Any, Callable, Union, get_args, get_origin
8
+
9
+ from pydantic import BaseModel
10
+
11
# Lookup table from plain Python scalar types to JSON Schema fragments.
# Keys are matched by dict lookup, so only these exact type objects hit
# (subclasses do not). The fragment dicts are shared module-level state:
# copy before mutating.
_SIMPLE_TYPE_MAP: dict[Any, dict[str, str]] = {
    int: {"type": "integer"},
    float: {"type": "number"},
    str: {"type": "string"},
    bool: {"type": "boolean"},
    bytes: {"type": "string", "format": "byte"},
    datetime.datetime: {"type": "string", "format": "date-time"},
    datetime.date: {"type": "string", "format": "date"},
    datetime.time: {"type": "string", "format": "time"},
    datetime.timedelta: {"type": "string", "format": "duration"},
}
20
+
21
+
22
def is_streaming_return(return_type: Any) -> bool:
    """Tell whether *return_type* denotes a streaming response.

    A return annotation counts as streaming when ``typing.get_origin`` maps it
    to ``collections.abc.AsyncIterator``, as it does for both
    ``collections.abc.AsyncIterator[X]`` and ``typing.AsyncIterator[X]``.
    Other async types, such as ``AsyncGenerator[X, None]``, do not count.
    """
    origin_type = get_origin(return_type)
    return AsyncIterator is origin_type
26
+
27
+
28
def python_type_to_json_type(python_type: Any) -> dict[str, Any]:
    """Convert a Python type annotation to a JSON Schema fragment.

    Supported forms:

    * exact scalar types listed in ``_SIMPLE_TYPE_MAP`` (a copy is returned);
    * Pydantic ``BaseModel`` subclasses, via ``model_json_schema()``;
    * ``enum.Enum`` subclasses, as ``{"enum": [...]}``. A ``"type"`` is added
      when all member values are strings, all integers, or all numbers;
    * ``Annotated[X, ...]``, converted as ``X`` (the metadata is ignored);
    * ``list[X]`` / ``typing.List[X]``, as an array of ``X``. Bare ``List``
      gives an array with unconstrained items;
    * ``dict[K, V]`` / ``typing.Dict[K, V]``, as an object whose values follow
      ``V``. Bare ``Dict`` gives an unconstrained object;
    * ``Optional[X]`` / ``X | None``, as the schema of ``X`` plus
      ``"nullable": True``. Any other union (``Union[...]`` or ``A | B``)
      becomes ``{"oneOf": [...]}`` over its non-``None`` members.

    Anything else falls back to ``{"type": "object"}``.

    Args:
        python_type: The annotation to convert.

    Returns:
        A JSON Schema fragment as a dict.
    """
    # PEP 604 unions (``int | None``) report origin types.UnionType, which is
    # not typing.Union on Python 3.10-3.13, so both must be checked.
    from types import UnionType
    from typing import Annotated

    origin = get_origin(python_type)
    args = get_args(python_type)

    if origin is None:
        # Plain, non-parameterized types.
        if python_type in _SIMPLE_TYPE_MAP:
            return _SIMPLE_TYPE_MAP[python_type].copy()
        if inspect.isclass(python_type):
            if issubclass(python_type, BaseModel):
                return python_type.model_json_schema()
            if issubclass(python_type, enum.Enum):
                values = [member.value for member in python_type]
                if all(isinstance(v, str) for v in values):
                    return {"type": "string", "enum": values}
                # bool is an int subclass but is not a JSON integer/number.
                if all(isinstance(v, int) and not isinstance(v, bool) for v in values):
                    return {"type": "integer", "enum": values}
                if all(
                    isinstance(v, (int, float)) and not isinstance(v, bool)
                    for v in values
                ):
                    return {"type": "number", "enum": values}
                # Mixed or non-scalar values: no single JSON type applies.
                return {"enum": values}
        # Default to object type for unknown types
        return {"type": "object"}

    if origin is Annotated:
        # Annotated[X, *metadata] -> schema of X.
        return python_type_to_json_type(args[0])

    if origin is list:
        # Bare typing.List has origin ``list`` but no type arguments.
        if not args:
            return {"type": "array"}
        return {"type": "array", "items": python_type_to_json_type(args[0])}

    if origin is dict:
        # Bare typing.Dict has origin ``dict`` but no type arguments.
        if len(args) != 2:
            return {"type": "object"}
        return {
            "type": "object",
            "additionalProperties": python_type_to_json_type(args[1]),
        }

    if origin is Union or origin is UnionType:
        non_none_types = [arg for arg in args if arg is not type(None)]
        if len(non_none_types) == 1:
            schema = python_type_to_json_type(non_none_types[0])
            schema["nullable"] = True
            return schema
        return {"oneOf": [python_type_to_json_type(arg) for arg in non_none_types]}

    # Default to object type for unknown parameterized types
    return {"type": "object"}
61
+
62
+
63
def extract_method_info(method: Callable[..., Any]) -> dict[str, Any]:
    """Extract method information for an MCP tool definition.

    Only the first parameter after ``self``/``cls`` is treated as the request
    type. String annotations (for example under
    ``from __future__ import annotations``) are evaluated when possible. If
    evaluation fails, the raw annotations are used instead.

    Args:
        method: The RPC method to inspect (typically a bound method).

    Returns:
        A dict with these keys:

        * ``"description"``: the cleaned docstring, or ``""``.
        * ``"parameters"``: JSON Schema of the request type, or ``{}`` when
          there is no annotated request parameter.
        * ``"response"``: JSON Schema of the return type, or ``{}`` when the
          return is unannotated or annotated as ``None``. For
          ``AsyncIterator[X]`` it is an object whose ``"stream"`` property is
          an array of ``X``.
        * ``"is_streaming"``: whether the return type is ``AsyncIterator[...]``.

    Raises:
        ValueError, TypeError: Propagated from ``inspect.signature`` when the
            callable has no retrievable signature.
    """
    try:
        sig = inspect.signature(method, eval_str=True)
    except Exception:
        # Best effort: e.g. a forward reference that only exists under
        # TYPE_CHECKING. Fall back to the unevaluated annotations; a genuine
        # "no signature" error still surfaces from this second call.
        sig = inspect.signature(method)
    doc = inspect.getdoc(method) or ""

    # Get parameter types (skip 'self' for instance methods)
    params = list(sig.parameters.values())
    if params and params[0].name in ("self", "cls"):
        params = params[1:]

    input_type = params[0].annotation if params else None
    return_type = sig.return_annotation

    # python_type_to_json_type already maps Pydantic models through
    # model_json_schema(), so no separate BaseModel branch is needed.
    parameters_schema: dict[str, Any] = {}
    if input_type and input_type is not inspect.Parameter.empty:
        parameters_schema = python_type_to_json_type(input_type)

    is_streaming = is_streaming_return(return_type)
    response_schema: dict[str, Any] = {}
    if return_type and return_type is not inspect.Signature.empty:
        if is_streaming:
            stream_schema: dict[str, Any] = {"type": "array"}
            item_args = get_args(return_type)
            # A bare typing.AsyncIterator carries no item type.
            if item_args:
                stream_schema["items"] = python_type_to_json_type(item_args[0])
            response_schema = {
                "type": "object",
                "properties": {"stream": stream_schema},
            }
        else:
            response_schema = python_type_to_json_type(return_type)

    return {
        "description": doc,
        "parameters": parameters_schema,
        "response": response_schema,
        "is_streaming": is_streaming,
    }
@@ -0,0 +1,283 @@
1
+ """MCP exporter for pydantic-rpc services using the official MCP SDK."""
2
+
3
+ import asyncio
4
+ import inspect
5
+ from collections.abc import Callable
6
+ from typing import Any
7
+
8
+ from pydantic import BaseModel
9
+
10
+ try:
11
+ from mcp.server import InitializationOptions, Server
12
+ from mcp.server.sse import SseServerTransport
13
+ from mcp.server.stdio import stdio_server
14
+ from mcp.types import (
15
+ Content,
16
+ ServerCapabilities,
17
+ TextContent,
18
+ Tool,
19
+ )
20
+ except ImportError:
21
+ raise ImportError("mcp is required for MCP support. Install with: pip install mcp")
22
+
23
+ from .converter import extract_method_info
24
+
25
+
26
class MCPExporter:
    """Export pydantic-rpc services as MCP tools using the official MCP SDK.

    Each public method defined in the service's own module (or a submodule of
    it) is exposed as a tool named after the lower-cased method name. When a
    method's first parameter is annotated with a Pydantic model, the tool's
    keyword arguments are used to build that model before the call.
    """

    def __init__(
        self,
        service_obj: object,
        name: str | None = None,
        description: str | None = None,
    ):
        """Initialize MCPExporter with a service object.

        Args:
            service_obj: The service object containing RPC methods to export as MCP tools.
            name: Name for the MCP server (defaults to service class name).
            description: Description for the MCP server; passed to the SDK
                ``Server`` as its ``instructions``.
        """
        self.service: object = service_obj
        self.name: str = name or service_obj.__class__.__name__
        self.description: str = (
            description or f"MCP tools from {service_obj.__class__.__name__}"
        )

        # Create MCP Server instance
        self.server: Server[Any] = Server(
            self.name, version="1.0.0", instructions=self.description
        )

        # tool name -> (Tool definition, callable invoked with the tool's
        # arguments as keyword arguments)
        self.tools: dict[str, tuple[Tool, Any]] = {}

        # SSE transport instance (created lazily by get_asgi_app)
        self._sse_transport: SseServerTransport | None = None

        # The handlers read self.tools at request time, so registering them
        # before the tools are collected is fine.
        self._register_handlers()
        self._extract_tools()

    def _initialization_options(self) -> "InitializationOptions":
        """Build the options sent to clients during the MCP handshake.

        Delegates to ``Server.create_initialization_options()``. That method
        uses the name, version and instructions given to ``Server`` in
        ``__init__``, and derives the capabilities from the registered
        handlers. A bare ``ServerCapabilities()`` would not advertise
        ``tools``.
        """
        return self.server.create_initialization_options()

    def _register_handlers(self):
        """Register the MCP list-tools and call-tool handlers on ``self.server``."""

        @self.server.list_tools()
        async def handle_list_tools() -> list[Tool]:  # pyright:ignore[reportUnusedFunction]
            """List all available tools."""
            return [tool for tool, _ in self.tools.values()]

        @self.server.call_tool()
        async def handle_call_tool(  # pyright:ignore[reportUnusedFunction]
            name: str, arguments: dict[str, Any] | None
        ) -> list[Content]:
            """Execute a tool and return its result as one text content block.

            Pydantic results are serialized as indented JSON; anything else is
            rendered with ``str()``.

            Raises:
                ValueError: If ``name`` is not a registered tool.
            """
            if name not in self.tools:
                raise ValueError(f"Unknown tool: {name}")

            _, method = self.tools[name]
            result = method(**(arguments or {}))
            # Await any awaitable rather than checking for coroutine functions
            # up front. A decorated async method may fail
            # inspect.iscoroutinefunction and would otherwise be stringified
            # as an un-awaited coroutine.
            if inspect.isawaitable(result):
                result = await result

            if isinstance(result, BaseModel):
                text = result.model_dump_json(indent=2)
            else:
                text = str(result)
            return [TextContent(type="text", text=text)]

    def _extract_tools(self):
        """Populate ``self.tools`` from the service object's public methods."""
        service_module = self.service.__class__.__module__

        for method_name, method in inspect.getmembers(self.service, inspect.ismethod):
            # Skip private methods
            if method_name.startswith("_"):
                continue

            # Exclude methods from external modules (like pytest fixtures or
            # inherited helpers). The service's module and its submodules are
            # accepted. The "." boundary keeps e.g. "app" from matching
            # "application".
            method_module = inspect.getmodule(method)
            if method_module is not None:
                module_name = method_module.__name__
                if module_name != service_module and not module_name.startswith(
                    service_module + "."
                ):
                    continue

            try:
                method_info = extract_method_info(method)
            except Exception:
                # Deliberate best effort: a method whose signature cannot be
                # introspected is left out instead of aborting the export.
                continue

            tool_name = method_name.lower()
            tool = Tool(
                name=tool_name,
                description=method_info["description"] or f"Execute {method_name}",
                # MCP tool input schemas are JSON objects; a method without
                # an annotated request parameter takes no arguments.
                inputSchema=method_info["parameters"]
                or {"type": "object", "properties": {}},
            )
            self.tools[tool_name] = (tool, self._make_tool_callable(method))

    @staticmethod
    def _make_tool_callable(method: Callable[..., Any]) -> Callable[..., Any]:
        """Return the callable that executes ``method`` for a tool call.

        If the first parameter after ``self``/``cls`` is annotated with a
        Pydantic model class, return a wrapper that builds that model from the
        tool's keyword arguments and passes it positionally. The wrapper is a
        coroutine function exactly when ``method`` is. Otherwise ``method`` is
        returned unchanged and receives the arguments as keywords.
        """
        try:
            sig = inspect.signature(method, eval_str=True)
        except Exception:
            # Unresolvable string annotations: fall back to the raw ones.
            sig = inspect.signature(method)

        params = list(sig.parameters.values())
        if params and params[0].name in ("self", "cls"):
            params = params[1:]
        if not params:
            return method

        model = params[0].annotation
        if not (inspect.isclass(model) and issubclass(model, BaseModel)):
            # For non-Pydantic (or unannotated) parameters, use the method directly
            return method

        if inspect.iscoroutinefunction(method):

            async def async_wrapper(**kwargs: Any) -> Any:
                return await method(model(**kwargs))

            return async_wrapper

        def sync_wrapper(**kwargs: Any) -> Any:
            return method(model(**kwargs))

        return sync_wrapper

    def run_stdio(self):
        """Run the MCP server in stdio mode, blocking until it finishes."""
        asyncio.run(self._run_stdio())

    async def _run_stdio(self):
        """Async implementation of the stdio server."""
        async with stdio_server() as (read_stream, write_stream):
            await self.server.run(
                read_stream, write_stream, self._initialization_options()
            )

    def get_asgi_app(self, path: str = "/mcp"):
        """Get the ASGI app for HTTP/SSE transport.

        The app serves ``GET /sse`` (the event stream) and ``/messages/``
        (client-to-server messages), both relative to wherever it is mounted.

        Args:
            path: Currently unused; accepted for API compatibility with
                ``mount_to_asgi``, which performs the mounting at ``path``.

        Returns:
            A Starlette application that can be mounted or run directly.

        Raises:
            ImportError: If ``starlette`` is not installed.
        """
        _ = path
        try:
            from starlette.applications import Starlette
            from starlette.responses import Response
            from starlette.routing import Mount, Route
        except ImportError as exc:
            raise ImportError(
                "starlette is required for HTTP/SSE transport. "
                "Install with: pip install starlette"
            ) from exc

        # One transport per exporter, shared by every app built from it.
        if self._sse_transport is None:
            self._sse_transport = SseServerTransport("/messages/")
        sse_transport = self._sse_transport

        async def handle_sse(request: Any) -> Any:
            # connect_sse drives the raw ASGI connection, so it needs the
            # underlying send callable (only exposed as request._send).
            async with sse_transport.connect_sse(
                request.scope, request.receive, request._send
            ) as (read_stream, write_stream):
                await self.server.run(
                    read_stream, write_stream, self._initialization_options()
                )
            # The event stream was written straight to the ASGI connection.
            # Starlette still calls whatever the endpoint returns, and calling
            # None raises TypeError, so return an empty Response.
            return Response()

        return Starlette(
            routes=[
                Route("/sse", endpoint=handle_sse, methods=["GET"]),
                Mount("/messages/", app=sse_transport.handle_post_message),
            ]
        )

    def mount_to_asgi(self, asgi_app: Any, path: str = "/mcp"):
        """Mount MCP endpoints to an existing ASGI application.

        Works with Starlette/FastAPI applications (anything with a ``mount``
        method) and with pydantic-rpc ASGI apps (anything with an ``_app``
        attribute, which is replaced by a path-dispatching wrapper).

        Args:
            asgi_app: The ASGI application to mount to
            path: The path prefix for MCP endpoints (default: "/mcp")

        Raises:
            ValueError: If ``asgi_app`` has neither ``mount`` nor ``_app``.
        """
        mcp_asgi = self.get_asgi_app(path)

        # Check if the app has a mount method (Starlette/FastAPI)
        if hasattr(asgi_app, "mount"):
            asgi_app.mount(path, mcp_asgi)
        # Check if it's a pydantic-rpc ASGI app (ASGIApp/ConnecpyASGIApp)
        elif hasattr(asgi_app, "_app"):
            original_app = asgi_app._app
            # Normalize so "/mcp" and "/mcp/" behave the same. A path of "/"
            # becomes "" and routes every HTTP request to MCP.
            prefix = path.rstrip("/")

            async def wrapped_app(
                scope: dict[str, Any],
                receive: Callable[[], Any],
                send: Callable[[Any], Any],
            ) -> None:
                request_path = scope.get("path", "")
                # Match whole path segments only: "/mcp" and "/mcp/..." but
                # not "/mcpfoo".
                if scope["type"] == "http" and (
                    request_path == prefix or request_path.startswith(prefix + "/")
                ):
                    # Strip the prefix so the MCP app sees paths relative to
                    # its mount point. Also append the prefix to root_path
                    # (the ASGI field for the mount location) so components
                    # that build client-facing URLs can include it.
                    scope = dict(scope)
                    scope["path"] = request_path[len(prefix) :] or "/"
                    scope["root_path"] = scope.get("root_path", "") + prefix
                    await mcp_asgi(scope, receive, send)
                else:
                    await original_app(scope, receive, send)

            asgi_app._app = wrapped_app
        else:
            raise ValueError(
                "Unable to mount MCP to the provided ASGI app. "
                "The app must have either a 'mount' method or '_app' attribute."
            )
pydantic_rpc/py.typed CHANGED
File without changes