stirrup 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
stirrup/tools/mcp.py ADDED
@@ -0,0 +1,500 @@
1
+ """MCP (Model Context Protocol) tool provider for connecting to MCP servers.
2
+
3
+ This module provides MCPToolProvider, a ToolProvider that manages connections to
4
+ multiple MCP servers and exposes each MCP tool as a separate Tool object.
5
+
6
+ Example usage:
7
+ ```python
8
+ from stirrup.clients.chat_completions_client import ChatCompletionsClient
9
+
10
+ # With Agent (preferred)
11
+ client = ChatCompletionsClient(model="gpt-5")
12
+ agent = Agent(
13
+ client=client,
14
+ name="assistant",
15
+ tools=[*DEFAULT_TOOLS, MCPToolProvider.from_config("mcp.json")],
16
+ )
17
+ async with agent.session() as session:
18
+ await session.run("Use MCP tools")
19
+
20
+ # Standalone usage
21
+ provider = MCPToolProvider.from_config(Path("mcp.json"))
22
+ async with provider as tools:
23
+ # tools is a list of Tool objects
24
+ pass
25
+ ```
26
+
27
+ Requires the optional `mcp` dependency:
28
+ pip install stirrup[mcp]
29
+ """
30
+
31
from collections.abc import AsyncIterator, Awaitable, Callable
from contextlib import AsyncExitStack, asynccontextmanager
from pathlib import Path
from types import TracebackType
from typing import Any, Self
from urllib.parse import urlsplit

from json_schema_to_pydantic import create_model
from pydantic import BaseModel, Field, model_validator

from stirrup.core.models import Tool, ToolProvider, ToolResult, ToolUseCountMetadata
41
+
42
+ # MCP imports (optional dependency)
43
+ try:
44
+ from mcp import ClientSession, StdioServerParameters
45
+ from mcp.client.sse import sse_client
46
+ from mcp.client.stdio import stdio_client
47
+ from mcp.client.streamable_http import streamablehttp_client
48
+ except ImportError as e:
49
+ raise ImportError(
50
+ "Requires installation of the mcp extra. Install with (for example): `uv pip install stirrup[mcp]` or `uv add stirrup[mcp]`",
51
+ ) from e
52
+
53
+ # WebSocket client requires additional 'websockets' package
54
+ try:
55
+ from mcp.client.websocket import websocket_client
56
+ except ImportError:
57
+ websocket_client = None # type: ignore[assignment, misc]
58
+
59
+
60
# Public names of stirrup.tools.mcp (what `from stirrup.tools.mcp import *` exports):
# the root config model, the per-transport server config models, and the provider.
__all__ = [
    "MCPConfig",
    "MCPServerConfig",
    "MCPToolProvider",
    "SseServerConfig",
    "StdioServerConfig",
    "StreamableHttpServerConfig",
    "WebSocketServerConfig",
]
69
+
70
+
71
+ # === Models ===
72
+
73
+
74
class StdioServerConfig(BaseModel):
    """Configuration for stdio-based MCP servers (local process).

    Typically inferred for raw mcp.json server entries that have a ``command`` key.
    MCPToolProvider copies every field below into ``mcp.StdioServerParameters`` and
    launches the server through ``stdio_client``.
    """

    command: str
    """Command to run the MCP server (e.g., "npx", "python")."""

    args: list[str] = Field(default_factory=list)
    """Arguments to pass to the command."""

    # Forwarded unchanged as StdioServerParameters(env=...). NOTE(review): None presumably
    # means "let the MCP SDK choose the environment" rather than "empty environment" — confirm.
    env: dict[str, str] | None = None
    """Environment variables to set for the server process."""

    # Forwarded unchanged as StdioServerParameters(cwd=...); None leaves the choice to the SDK.
    cwd: str | None = None
    """Working directory for the server process."""

    encoding: str = "utf-8"
    """Text encoding for messages."""
91
+
92
+
93
class SseServerConfig(BaseModel):
    """Configuration for SSE-based MCP servers (HTTP GET with Server-Sent Events).

    MCPToolProvider passes ``url``, ``headers``, ``timeout`` and ``sse_read_timeout``
    straight to ``mcp.client.sse.sse_client``.
    """

    # Nothing in this class enforces the "/sse" suffix mentioned below: it is the
    # heuristic _infer_server_config uses to pick this transport from mcp.json. When the
    # class is constructed directly, any URL is accepted and handed to sse_client unchanged.
    url: str
    """The SSE endpoint URL (must end with /sse)."""

    headers: dict[str, str] | None = None
    """Optional HTTP headers."""

    # Note the shorter default than StreamableHttpServerConfig.timeout (30.0).
    timeout: float = 5.0
    """HTTP timeout for regular operations (seconds)."""

    sse_read_timeout: float = 300.0
    """Timeout for SSE read operations (seconds)."""
107
+
108
+
109
class StreamableHttpServerConfig(BaseModel):
    """Configuration for Streamable HTTP MCP servers (HTTP POST with optional SSE responses).

    This is the fallback transport that _infer_server_config picks for URL-based entries
    that are neither WebSocket (ws/wss) nor SSE (``/sse`` suffix). MCPToolProvider passes
    every field to ``mcp.client.streamable_http.streamablehttp_client``.
    """

    url: str
    """The endpoint URL."""

    headers: dict[str, str] | None = None
    """Optional HTTP headers."""

    timeout: float = 30.0
    """HTTP timeout (seconds)."""

    sse_read_timeout: float = 300.0
    """SSE read timeout (seconds)."""

    # Forwarded as streamablehttp_client(terminate_on_close=...). NOTE(review): in the MCP
    # Python SDK this presumably controls whether the client asks the server to end the
    # session when the client connection closes — confirm against the installed SDK version.
    terminate_on_close: bool = True
    """Close session when transport closes."""
126
+
127
+
128
class WebSocketServerConfig(BaseModel):
    """Configuration for WebSocket-based MCP servers.

    Connecting requires the optional ``websockets`` package; MCPToolProvider raises
    ImportError at connect time when it is missing.

    Raises:
        pydantic.ValidationError: If ``url`` does not use the ws:// or wss:// scheme.
    """

    url: str
    """The WebSocket URL (must start with ws:// or wss://)."""

    @model_validator(mode="after")
    def _check_websocket_scheme(self) -> Self:
        """Enforce the documented ws/wss scheme at construction time.

        Without this, a wrong URL (e.g. https://...) is accepted here and only fails
        later inside the WebSocket client when connecting.
        """
        # URL schemes are case-insensitive (RFC 3986), so accept e.g. "WSS://".
        if urlsplit(self.url).scheme.lower() not in ("ws", "wss"):
            raise ValueError(f"WebSocket URL must start with ws:// or wss://, got {self.url!r}")
        return self
133
+
134
+
135
# Type alias for the union of all server config types.
# MCPConfig.mcp_servers is typed with this union; MCPConfig's before-validator first turns
# each raw dict into one concrete member via _infer_server_config.
MCPServerConfig = StdioServerConfig | SseServerConfig | StreamableHttpServerConfig | WebSocketServerConfig
137
+
138
+
139
+ def _infer_server_config(data: dict[str, Any]) -> MCPServerConfig:
140
+ """Infer and instantiate the correct config class from raw data.
141
+
142
+ Inference rules:
143
+ - 'command' field present -> StdioServerConfig
144
+ - 'url' starts with ws:// or wss:// -> WebSocketServerConfig
145
+ - 'url' ends with /sse -> SseServerConfig
146
+ - 'url' present (default) -> StreamableHttpServerConfig
147
+
148
+ Args:
149
+ data: Raw configuration dictionary.
150
+
151
+ Returns:
152
+ Appropriate server config instance.
153
+
154
+ Raises:
155
+ ValueError: If neither 'command' nor 'url' is provided.
156
+ """
157
+ if "command" in data:
158
+ return StdioServerConfig(**data)
159
+ if "url" in data:
160
+ url = data["url"]
161
+ if url.startswith(("ws://", "wss://")):
162
+ return WebSocketServerConfig(**data)
163
+ if url.endswith("/sse"):
164
+ return SseServerConfig(**data)
165
+ return StreamableHttpServerConfig(**data)
166
+ raise ValueError("Config must have 'command' (stdio) or 'url' (SSE/HTTP/WebSocket)")
167
+
168
+
169
class MCPConfig(BaseModel):
    """Root configuration matching mcp.json format.

    Expected shape::

        {"mcpServers": {"<name>": {<server config>}, ...}}

    Each raw server entry is converted to a concrete server config class by
    ``_infer_server_config`` before field validation runs.
    """

    mcp_servers: dict[str, MCPServerConfig] = Field(alias="mcpServers")
    """Map of server names to their configurations."""

    @model_validator(mode="before")
    @classmethod
    def _infer_transport_types(cls, data: Any) -> Any:
        """Convert raw server configs to appropriate typed instances.

        Returns a new mapping instead of assigning into ``data``, so a dict passed to
        ``MCPConfig.model_validate`` is not mutated. Non-dict input, or a non-dict
        ``mcpServers`` value, is passed through unchanged so that pydantic reports it
        as a ValidationError instead of this hook crashing.
        """
        if not isinstance(data, dict):
            return data
        servers = data.get("mcpServers")
        if not isinstance(servers, dict):
            return data
        typed_servers = {
            name: _infer_server_config(config) if isinstance(config, dict) else config
            for name, config in servers.items()
        }
        return {**data, "mcpServers": typed_servers}
185
+
186
+
187
+ # === Manager ===
188
+
189
+
190
+ class MCPToolProvider(ToolProvider):
191
+ """MCP tool provider that manages connections to multiple MCP servers.
192
+
193
+ MCPToolProvider connects to MCP servers and exposes each server's tools
194
+ as individual Tool objects.
195
+
196
+ Usage with Agent (preferred):
197
+ from stirrup.clients.chat_completions_client import ChatCompletionsClient
198
+
199
+ client = ChatCompletionsClient(model="gpt-5")
200
+ agent = Agent(
201
+ client=client,
202
+ name="assistant",
203
+ tools=[*DEFAULT_TOOLS, MCPToolProvider.from_config("mcp.json")],
204
+ )
205
+
206
+ async with agent.session(output_dir="./output") as session:
207
+ await session.run("Use MCP tools")
208
+
209
+ Standalone usage with connect() context manager:
210
+ provider = MCPToolProvider.from_config(Path("mcp.json"))
211
+ async with provider.connect() as provider:
212
+ tools = provider.get_all_tools()
213
+ # Use tools...
214
+ """
215
+
216
+ def __init__(
217
+ self,
218
+ config: MCPConfig,
219
+ server_names: list[str] | None = None,
220
+ ) -> None:
221
+ """Initialize the MCP manager.
222
+
223
+ Args:
224
+ config: MCPConfig instance.
225
+ server_names: Which servers to connect to. If None, connects to all servers in config.
226
+ """
227
+ self._config = config
228
+ self._server_names = server_names
229
+ self._servers: dict[str, ClientSession] = {}
230
+ self._tools: dict[str, list[dict[str, Any]]] = {}
231
+ self._exit_stack: AsyncExitStack | None = None
232
+
233
+ @classmethod
234
+ def from_config(cls, config_path: Path | str, server_names: list[str] | None = None) -> Self:
235
+ """Create an MCPToolProvider from a config file.
236
+
237
+ Args:
238
+ config_path: Path to the MCP config file.
239
+ server_names: Which servers to connect to. If None, connects to all servers in config.
240
+
241
+ Returns:
242
+ MCPToolProvider instance.
243
+ """
244
+ config = MCPConfig.model_validate_json(Path(config_path).read_text())
245
+
246
+ return cls(config=config, server_names=server_names)
247
+
248
+ @asynccontextmanager
249
+ async def connect(self) -> AsyncIterator[Self]:
250
+ """Connect to MCP servers from config file.
251
+
252
+ Yields:
253
+ Self with active connections to specified servers.
254
+
255
+ Raises:
256
+ FileNotFoundError: If config file doesn't exist.
257
+ KeyError: If a specified server name doesn't exist in config.
258
+ """
259
+ config = self._config
260
+ servers_to_connect = self._server_names or list(config.mcp_servers.keys())
261
+
262
+ async with AsyncExitStack() as stack:
263
+ for name in servers_to_connect:
264
+ if name not in config.mcp_servers:
265
+ raise KeyError(f"Server '{name}' not found in config. Available: {list(config.mcp_servers.keys())}")
266
+
267
+ server_config = config.mcp_servers[name]
268
+
269
+ # Connect to server based on transport type
270
+ match server_config:
271
+ case StdioServerConfig():
272
+ server_params = StdioServerParameters(
273
+ command=server_config.command,
274
+ args=server_config.args,
275
+ env=server_config.env,
276
+ cwd=server_config.cwd,
277
+ encoding=server_config.encoding,
278
+ )
279
+ read, write = await stack.enter_async_context(stdio_client(server_params))
280
+ case SseServerConfig():
281
+ read, write = await stack.enter_async_context(
282
+ sse_client(
283
+ url=server_config.url,
284
+ headers=server_config.headers,
285
+ timeout=server_config.timeout,
286
+ sse_read_timeout=server_config.sse_read_timeout,
287
+ )
288
+ )
289
+ case StreamableHttpServerConfig():
290
+ read, write, _ = await stack.enter_async_context(
291
+ streamablehttp_client(
292
+ url=server_config.url,
293
+ headers=server_config.headers,
294
+ timeout=server_config.timeout,
295
+ sse_read_timeout=server_config.sse_read_timeout,
296
+ terminate_on_close=server_config.terminate_on_close,
297
+ )
298
+ )
299
+ case WebSocketServerConfig():
300
+ if websocket_client is None:
301
+ raise ImportError(
302
+ f"WebSocket transport for server '{name}' requires the 'websockets' package. "
303
+ "Install with: pip install websockets"
304
+ )
305
+ read, write = await stack.enter_async_context(websocket_client(url=server_config.url))
306
+
307
+ session = await stack.enter_async_context(ClientSession(read, write))
308
+ await session.initialize()
309
+
310
+ # Cache session and available tools
311
+ self._servers[name] = session
312
+ response = await session.list_tools()
313
+ self._tools[name] = [
314
+ {"name": t.name, "description": t.description, "schema": t.inputSchema} for t in response.tools
315
+ ]
316
+
317
+ try:
318
+ yield self
319
+ finally:
320
+ self._servers.clear()
321
+ self._tools.clear()
322
+
323
+ @property
324
+ def servers(self) -> list[str]:
325
+ """List of connected server names."""
326
+ return list(self._servers.keys())
327
+
328
+ def get_tools(self, server: str) -> list[dict[str, Any]]:
329
+ """Get available tools for a specific server.
330
+
331
+ Args:
332
+ server: Server name.
333
+
334
+ Returns:
335
+ List of tool info dicts with name, description, and schema.
336
+ """
337
+ return self._tools.get(server, [])
338
+
339
+ @property
340
+ def all_tools(self) -> dict[str, list[str]]:
341
+ """Get all available tools grouped by server.
342
+
343
+ Returns:
344
+ Dict mapping server names to lists of tool names.
345
+ """
346
+ return {server: [t["name"] for t in tools] for server, tools in self._tools.items()}
347
+
348
+ async def call_tool(self, server: str, tool_name: str, arguments: dict[str, Any]) -> str:
349
+ """Call a tool on a specific MCP server.
350
+
351
+ Args:
352
+ server: Name of the MCP server.
353
+ tool_name: Name of the tool to call.
354
+ arguments: Arguments to pass to the tool.
355
+
356
+ Returns:
357
+ Tool result as a string (text content extracted from response).
358
+
359
+ Raises:
360
+ ValueError: If server is not connected.
361
+ """
362
+ session = self._servers.get(server)
363
+ if session is None:
364
+ raise ValueError(f"Server '{server}' not connected. Available: {self.servers}")
365
+
366
+ result = await session.call_tool(tool_name, arguments)
367
+
368
+ # Extract text content from result
369
+ text_parts = [str(content.text) for content in result.content if hasattr(content, "text")]
370
+ return "\n".join(text_parts)
371
+
372
+ def get_all_tools(self) -> list[Tool[Any, ToolUseCountMetadata]]:
373
+ """Get individual Tool objects for each tool from all connected MCP servers.
374
+
375
+ Each MCP tool is exposed as a separate Tool with its own parameter schema,
376
+ allowing the LLM to see and call each tool directly without routing through
377
+ a unified proxy.
378
+
379
+ Tool names are formatted as '{server}__{tool_name}' to ensure uniqueness
380
+ across servers (e.g., 'supabase__query_table').
381
+
382
+ Returns:
383
+ List of Tool objects, one for each tool available across all connected servers.
384
+ """
385
+ tools: list[Tool[Any, ToolUseCountMetadata]] = []
386
+
387
+ for server_name, server_tools in self._tools.items():
388
+ for tool_info in server_tools:
389
+ mcp_tool_name = tool_info["name"]
390
+ # Create unique tool name with server prefix
391
+ unique_name = f"{server_name}__{mcp_tool_name}"
392
+
393
+ # Convert JSON schema to Pydantic model
394
+ params_model = create_model(
395
+ tool_info.get("schema", {}),
396
+ )
397
+
398
+ # Create executor closure - capture server_name and mcp_tool_name
399
+ # using default arguments to avoid late binding issues in the loop
400
+ async def executor(
401
+ params: BaseModel,
402
+ _server: str = server_name,
403
+ _tool: str = mcp_tool_name,
404
+ ) -> ToolResult[ToolUseCountMetadata]:
405
+ content = await self.call_tool(_server, _tool, params.model_dump())
406
+ xml_content = f"<mcp_result>\n{content}\n</mcp_result>"
407
+ return ToolResult(content=xml_content, metadata=ToolUseCountMetadata())
408
+
409
+ tools.append(
410
+ Tool(
411
+ name=unique_name,
412
+ description=tool_info.get("description") or f"Tool '{mcp_tool_name}' from {server_name}",
413
+ parameters=params_model,
414
+ executor=executor, # ty: ignore[invalid-argument-type]
415
+ )
416
+ )
417
+
418
+ return tools
419
+
420
+ # Tool lifecycle protocol implementation
421
+ async def __aenter__(self) -> list[Tool[Any, ToolUseCountMetadata]]:
422
+ """Enter async context: connect to MCP servers and return all tools.
423
+
424
+ Returns:
425
+ List of Tool objects, one for each tool available across all connected servers.
426
+ """
427
+ self._exit_stack = AsyncExitStack()
428
+ await self._exit_stack.__aenter__()
429
+
430
+ config = self._config
431
+ servers_to_connect = self._server_names or list(config.mcp_servers.keys())
432
+
433
+ for name in servers_to_connect:
434
+ if name not in config.mcp_servers:
435
+ raise KeyError(f"Server '{name}' not found in config. Available: {list(config.mcp_servers.keys())}")
436
+
437
+ server_config = config.mcp_servers[name]
438
+
439
+ # Connect to server based on transport type
440
+ match server_config:
441
+ case StdioServerConfig():
442
+ server_params = StdioServerParameters(
443
+ command=server_config.command,
444
+ args=server_config.args,
445
+ env=server_config.env,
446
+ cwd=server_config.cwd,
447
+ encoding=server_config.encoding,
448
+ )
449
+ read, write = await self._exit_stack.enter_async_context(stdio_client(server_params))
450
+ case SseServerConfig():
451
+ read, write = await self._exit_stack.enter_async_context(
452
+ sse_client(
453
+ url=server_config.url,
454
+ headers=server_config.headers,
455
+ timeout=server_config.timeout,
456
+ sse_read_timeout=server_config.sse_read_timeout,
457
+ )
458
+ )
459
+ case StreamableHttpServerConfig():
460
+ read, write, _ = await self._exit_stack.enter_async_context(
461
+ streamablehttp_client(
462
+ url=server_config.url,
463
+ headers=server_config.headers,
464
+ timeout=server_config.timeout,
465
+ sse_read_timeout=server_config.sse_read_timeout,
466
+ terminate_on_close=server_config.terminate_on_close,
467
+ )
468
+ )
469
+ case WebSocketServerConfig():
470
+ if websocket_client is None:
471
+ raise ImportError(
472
+ f"WebSocket transport for server '{name}' requires the 'websockets' package. "
473
+ "Install with: pip install websockets"
474
+ )
475
+ read, write = await self._exit_stack.enter_async_context(websocket_client(url=server_config.url))
476
+
477
+ session = await self._exit_stack.enter_async_context(ClientSession(read, write))
478
+ await session.initialize()
479
+
480
+ # Cache session and available tools
481
+ self._servers[name] = session
482
+ response = await session.list_tools()
483
+ self._tools[name] = [
484
+ {"name": t.name, "description": t.description, "schema": t.inputSchema} for t in response.tools
485
+ ]
486
+
487
+ return self.get_all_tools()
488
+
489
+ async def __aexit__(
490
+ self,
491
+ exc_type: type[BaseException] | None,
492
+ exc_val: BaseException | None,
493
+ exc_tb: TracebackType | None,
494
+ ) -> None:
495
+ """Exit async context: disconnect from MCP servers."""
496
+ self._servers.clear()
497
+ self._tools.clear()
498
+ if self._exit_stack:
499
+ await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
500
+ self._exit_stack = None
@@ -0,0 +1,83 @@
1
+ """View image tool provider for execution environments."""
2
+
3
+ from stirrup.core.models import Tool, ToolProvider, ToolUseCountMetadata
4
+ from stirrup.tools.code_backends.base import CodeExecToolProvider, ViewImageParams
5
+
6
+
7
class ViewImageToolProvider(ToolProvider):
    """Tool provider exposing a ``view_image`` tool backed by an execution environment.

    The execution environment is either handed in explicitly or, when omitted, looked
    up in the Agent's session state at context-entry time. Because that lookup is
    deferred, the provider may appear anywhere in the Agent's tools list.

    Examples:
        from stirrup.clients.chat_completions_client import ChatCompletionsClient

        client = ChatCompletionsClient(model="gpt-5")

        # Explicit exec_env
        exec_env = LocalCodeExecToolProvider()
        agent = Agent(
            client=client, name="assistant",
            tools=[exec_env, ViewImageToolProvider(exec_env)],
        )

        # Auto-detect (any order works)
        agent = Agent(
            client=client, name="assistant",
            tools=[ViewImageToolProvider(), LocalCodeExecToolProvider()],
        )

    """

    def __init__(
        self,
        exec_env: CodeExecToolProvider | None = None,
        *,
        name: str = "view_image",
        description: str | None = None,
    ) -> None:
        """Record the configuration; the exec env is resolved later in ``__aenter__``.

        Args:
            exec_env: Execution environment to use. When None, the one registered in
                the Agent's session state is used instead.
            name: Name given to the produced tool (default: "view_image").
            description: Tool description; None keeps the standard description.

        """
        self._exec_env = exec_env
        self._name = name
        self._description = description

    async def __aenter__(self) -> Tool[ViewImageParams, ToolUseCountMetadata]:
        """Pick the execution environment and return its view_image tool.

        Raises:
            ValueError: If an explicit exec_env was given and the session state holds a
                different exec_env instance.
            RuntimeError: If no exec_env was given and none is found in session state.
        """
        # Imported lazily to avoid a circular import.
        from stirrup.core.agent import _SESSION_STATE

        session_state = _SESSION_STATE.get(None)
        detected_env = session_state.exec_env if session_state else None
        explicit_env = self._exec_env

        if explicit_env is None:
            # Nothing passed in: the session state must supply the environment.
            if detected_env is None:
                raise RuntimeError(
                    "ViewImageToolProvider requires a CodeExecToolProvider. "
                    "Either pass exec_env explicitly or include a CodeExecToolProvider "
                    "in the Agent's tools list."
                )
            chosen_env = detected_env
        elif detected_env is None or explicit_env is detected_env:
            # Explicit env is fine when there is no session env or it is the same instance.
            chosen_env = explicit_env
        else:
            raise ValueError(
                f"ViewImageToolProvider exec_env ({type(explicit_env).__name__}) "
                f"does not match Agent's exec_env ({type(detected_env).__name__}). "
                "Use the same exec_env instance or omit exec_env to auto-detect."
            )

        return chosen_env.get_view_image_tool(name=self._name, description=self._description)