stirrup 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stirrup/__init__.py +76 -0
- stirrup/clients/__init__.py +14 -0
- stirrup/clients/chat_completions_client.py +219 -0
- stirrup/clients/litellm_client.py +141 -0
- stirrup/clients/utils.py +161 -0
- stirrup/constants.py +14 -0
- stirrup/core/__init__.py +1 -0
- stirrup/core/agent.py +1097 -0
- stirrup/core/exceptions.py +7 -0
- stirrup/core/models.py +599 -0
- stirrup/prompts/__init__.py +22 -0
- stirrup/prompts/base_system_prompt.txt +1 -0
- stirrup/prompts/message_summarizer.txt +27 -0
- stirrup/prompts/message_summarizer_bridge.txt +11 -0
- stirrup/py.typed +0 -0
- stirrup/tools/__init__.py +77 -0
- stirrup/tools/calculator.py +32 -0
- stirrup/tools/code_backends/__init__.py +38 -0
- stirrup/tools/code_backends/base.py +454 -0
- stirrup/tools/code_backends/docker.py +752 -0
- stirrup/tools/code_backends/e2b.py +359 -0
- stirrup/tools/code_backends/local.py +481 -0
- stirrup/tools/finish.py +23 -0
- stirrup/tools/mcp.py +500 -0
- stirrup/tools/view_image.py +83 -0
- stirrup/tools/web.py +336 -0
- stirrup/utils/__init__.py +10 -0
- stirrup/utils/logging.py +944 -0
- stirrup/utils/text.py +11 -0
- stirrup-0.1.0.dist-info/METADATA +318 -0
- stirrup-0.1.0.dist-info/RECORD +32 -0
- stirrup-0.1.0.dist-info/WHEEL +4 -0
stirrup/tools/mcp.py
ADDED
|
@@ -0,0 +1,500 @@
|
|
|
1
|
+
"""MCP (Model Context Protocol) tool provider for connecting to MCP servers.
|
|
2
|
+
|
|
3
|
+
This module provides MCPToolProvider, a ToolProvider that manages connections to
|
|
4
|
+
multiple MCP servers and exposes each MCP tool as a separate Tool object.
|
|
5
|
+
|
|
6
|
+
Example usage:
|
|
7
|
+
```python
|
|
8
|
+
from stirrup.clients.chat_completions_client import ChatCompletionsClient
|
|
9
|
+
|
|
10
|
+
# With Agent (preferred)
|
|
11
|
+
client = ChatCompletionsClient(model="gpt-5")
|
|
12
|
+
agent = Agent(
|
|
13
|
+
client=client,
|
|
14
|
+
name="assistant",
|
|
15
|
+
tools=[*DEFAULT_TOOLS, MCPToolProvider.from_config("mcp.json")],
|
|
16
|
+
)
|
|
17
|
+
async with agent.session() as session:
|
|
18
|
+
await session.run("Use MCP tools")
|
|
19
|
+
|
|
20
|
+
# Standalone usage
|
|
21
|
+
provider = MCPToolProvider.from_config(Path("mcp.json"))
|
|
22
|
+
async with provider as tools:
|
|
23
|
+
# tools is a list of Tool objects
|
|
24
|
+
pass
|
|
25
|
+
```
|
|
26
|
+
|
|
27
|
+
Requires the optional `mcp` dependency:
|
|
28
|
+
pip install stirrup[mcp]
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
from collections.abc import AsyncIterator
from contextlib import AsyncExitStack, asynccontextmanager
from pathlib import Path
from types import TracebackType
from typing import Any, Self
from urllib.parse import urlsplit
|
|
36
|
+
|
|
37
|
+
from json_schema_to_pydantic import create_model
|
|
38
|
+
from pydantic import BaseModel, Field, model_validator
|
|
39
|
+
|
|
40
|
+
from stirrup.core.models import Tool, ToolProvider, ToolResult, ToolUseCountMetadata
|
|
41
|
+
|
|
42
|
+
# MCP imports (optional dependency)
|
|
43
|
+
try:
|
|
44
|
+
from mcp import ClientSession, StdioServerParameters
|
|
45
|
+
from mcp.client.sse import sse_client
|
|
46
|
+
from mcp.client.stdio import stdio_client
|
|
47
|
+
from mcp.client.streamable_http import streamablehttp_client
|
|
48
|
+
except ImportError as e:
|
|
49
|
+
raise ImportError(
|
|
50
|
+
"Requires installation of the mcp extra. Install with (for example): `uv pip install stirrup[mcp]` or `uv add stirrup[mcp]`",
|
|
51
|
+
) from e
|
|
52
|
+
|
|
53
|
+
# WebSocket client requires additional 'websockets' package
|
|
54
|
+
try:
|
|
55
|
+
from mcp.client.websocket import websocket_client
|
|
56
|
+
except ImportError:
|
|
57
|
+
websocket_client = None # type: ignore[assignment, misc]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# Public names exported by this module (kept in alphabetical order).
__all__ = [
    "MCPConfig", "MCPServerConfig", "MCPToolProvider",
    "SseServerConfig", "StdioServerConfig", "StreamableHttpServerConfig", "WebSocketServerConfig",
]
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# === Models ===
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class StdioServerConfig(BaseModel):
    """Settings for an MCP server started as a local subprocess and spoken to over stdio.

    The fields are handed to ``mcp.StdioServerParameters`` when the provider connects.
    """

    command: str
    """Executable that launches the server, for example "npx" or "python"."""

    args: list[str] = Field(default_factory=list)
    """Command-line arguments placed after ``command`` (none by default)."""

    env: dict[str, str] | None = None
    """Environment variables for the subprocess; None means no explicit environment is passed."""

    cwd: str | None = None
    """Directory to start the subprocess in; None means no override."""

    encoding: str = "utf-8"
    """Text encoding of the stdio message stream."""
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class SseServerConfig(BaseModel):
    """Settings for a remote MCP server using the SSE transport.

    Server-to-client messages arrive on an HTTP GET stream of Server-Sent Events.
    """

    url: str
    """Address of the SSE endpoint (conventionally ending in /sse)."""

    headers: dict[str, str] | None = None
    """Extra HTTP headers, if any."""

    timeout: float = 5.0
    """Timeout in seconds for ordinary HTTP operations."""

    sse_read_timeout: float = 300.0
    """Read timeout in seconds for the SSE event stream."""
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class StreamableHttpServerConfig(BaseModel):
    """Configuration for Streamable HTTP MCP servers (HTTP POST with optional SSE responses)."""

    url: str
    """The endpoint URL."""

    headers: dict[str, str] | None = None
    """Optional HTTP headers."""

    timeout: float = 30.0
    """Timeout (seconds) for regular HTTP operations."""

    sse_read_timeout: float = 300.0
    """How long (seconds) to wait for a new event on an SSE response stream."""

    terminate_on_close: bool = True
    """Explicitly terminate the server-side MCP session when the client disconnects.

    When True, the MCP client transport asks the server to end the session
    (an HTTP DELETE for the session) as the connection closes. Set to False
    to leave the session open on the server.
    """
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class WebSocketServerConfig(BaseModel):
    """Settings for an MCP server reached over a WebSocket connection.

    Connecting with this transport needs the optional ``websockets`` package.
    """

    url: str
    """WebSocket endpoint address using the ws:// or wss:// scheme."""
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
# Any one of the supported transport configurations. Raw dict entries in an
# mcp.json file are turned into the concrete class by _infer_server_config.
MCPServerConfig = (
    StdioServerConfig
    | SseServerConfig
    | StreamableHttpServerConfig
    | WebSocketServerConfig
)
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _infer_server_config(data: dict[str, Any]) -> MCPServerConfig:
    """Infer and instantiate the correct config class from raw data.

    Inference rules, checked in order:
        - 'command' field present -> StdioServerConfig
        - 'url' scheme is ws:// or wss:// (case-insensitive) -> WebSocketServerConfig
        - 'url' path ends with /sse (a trailing slash, query string or
          fragment is ignored) -> SseServerConfig
        - 'url' present (default) -> StreamableHttpServerConfig

    Args:
        data: Raw configuration dictionary for a single server.

    Returns:
        Appropriate server config instance.

    Raises:
        ValueError: If neither 'command' nor 'url' is provided, or if 'url' is
            not a string or cannot be parsed as a URL.
    """
    if "command" in data:
        return StdioServerConfig(**data)
    if "url" not in data:
        raise ValueError("Config must have 'command' (stdio) or 'url' (SSE/HTTP/WebSocket)")

    url = data["url"]
    # Raise ValueError (which pydantic reports as a ValidationError) rather than
    # letting a str method fail with AttributeError on e.g. a null url.
    if not isinstance(url, str):
        raise ValueError(f"Config 'url' must be a string, got {type(url).__name__}")

    # URL schemes are case-insensitive (RFC 3986), so also accept e.g. "WSS://".
    if url.lower().startswith(("ws://", "wss://")):
        return WebSocketServerConfig(**data)

    # Look only at the path component so that "https://host/sse/" and
    # "https://host/sse?token=..." are still recognised as SSE endpoints,
    # while a "/sse" inside a query string is not.
    if urlsplit(url).path.rstrip("/").endswith("/sse"):
        return SseServerConfig(**data)
    return StreamableHttpServerConfig(**data)
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
class MCPConfig(BaseModel):
    """Root configuration matching mcp.json format.

    Example::

        {
            "mcpServers": {
                "local": {"command": "python", "args": ["server.py"]},
                "remote": {"url": "https://example.com/mcp"}
            }
        }
    """

    mcp_servers: dict[str, MCPServerConfig] = Field(alias="mcpServers")
    """Map of server names to their configurations."""

    @model_validator(mode="before")
    @classmethod
    def _infer_transport_types(cls, data: Any) -> Any:
        """Convert raw per-server dicts to the appropriate typed config instances.

        Builds a new mapping instead of assigning into ``data``, so a dict
        passed to ``MCPConfig.model_validate`` by a caller is never mutated.
        Input that is not shaped as expected is returned unchanged so that
        pydantic's own field validation reports a proper ValidationError.
        """
        if not isinstance(data, dict):
            return data
        servers = data.get("mcpServers")
        if not isinstance(servers, dict):
            # Missing or malformed (e.g. null/list): let field validation complain.
            return data
        return {
            **data,
            "mcpServers": {
                name: _infer_server_config(config) if isinstance(config, dict) else config
                for name, config in servers.items()
            },
        }
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
# === Manager ===
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
class MCPToolProvider(ToolProvider):
|
|
191
|
+
"""MCP tool provider that manages connections to multiple MCP servers.
|
|
192
|
+
|
|
193
|
+
MCPToolProvider connects to MCP servers and exposes each server's tools
|
|
194
|
+
as individual Tool objects.
|
|
195
|
+
|
|
196
|
+
Usage with Agent (preferred):
|
|
197
|
+
from stirrup.clients.chat_completions_client import ChatCompletionsClient
|
|
198
|
+
|
|
199
|
+
client = ChatCompletionsClient(model="gpt-5")
|
|
200
|
+
agent = Agent(
|
|
201
|
+
client=client,
|
|
202
|
+
name="assistant",
|
|
203
|
+
tools=[*DEFAULT_TOOLS, MCPToolProvider.from_config("mcp.json")],
|
|
204
|
+
)
|
|
205
|
+
|
|
206
|
+
async with agent.session(output_dir="./output") as session:
|
|
207
|
+
await session.run("Use MCP tools")
|
|
208
|
+
|
|
209
|
+
Standalone usage with connect() context manager:
|
|
210
|
+
provider = MCPToolProvider.from_config(Path("mcp.json"))
|
|
211
|
+
async with provider.connect() as provider:
|
|
212
|
+
tools = provider.get_all_tools()
|
|
213
|
+
# Use tools...
|
|
214
|
+
"""
|
|
215
|
+
|
|
216
|
+
def __init__(
|
|
217
|
+
self,
|
|
218
|
+
config: MCPConfig,
|
|
219
|
+
server_names: list[str] | None = None,
|
|
220
|
+
) -> None:
|
|
221
|
+
"""Initialize the MCP manager.
|
|
222
|
+
|
|
223
|
+
Args:
|
|
224
|
+
config: MCPConfig instance.
|
|
225
|
+
server_names: Which servers to connect to. If None, connects to all servers in config.
|
|
226
|
+
"""
|
|
227
|
+
self._config = config
|
|
228
|
+
self._server_names = server_names
|
|
229
|
+
self._servers: dict[str, ClientSession] = {}
|
|
230
|
+
self._tools: dict[str, list[dict[str, Any]]] = {}
|
|
231
|
+
self._exit_stack: AsyncExitStack | None = None
|
|
232
|
+
|
|
233
|
+
@classmethod
|
|
234
|
+
def from_config(cls, config_path: Path | str, server_names: list[str] | None = None) -> Self:
|
|
235
|
+
"""Create an MCPToolProvider from a config file.
|
|
236
|
+
|
|
237
|
+
Args:
|
|
238
|
+
config_path: Path to the MCP config file.
|
|
239
|
+
server_names: Which servers to connect to. If None, connects to all servers in config.
|
|
240
|
+
|
|
241
|
+
Returns:
|
|
242
|
+
MCPToolProvider instance.
|
|
243
|
+
"""
|
|
244
|
+
config = MCPConfig.model_validate_json(Path(config_path).read_text())
|
|
245
|
+
|
|
246
|
+
return cls(config=config, server_names=server_names)
|
|
247
|
+
|
|
248
|
+
@asynccontextmanager
|
|
249
|
+
async def connect(self) -> AsyncIterator[Self]:
|
|
250
|
+
"""Connect to MCP servers from config file.
|
|
251
|
+
|
|
252
|
+
Yields:
|
|
253
|
+
Self with active connections to specified servers.
|
|
254
|
+
|
|
255
|
+
Raises:
|
|
256
|
+
FileNotFoundError: If config file doesn't exist.
|
|
257
|
+
KeyError: If a specified server name doesn't exist in config.
|
|
258
|
+
"""
|
|
259
|
+
config = self._config
|
|
260
|
+
servers_to_connect = self._server_names or list(config.mcp_servers.keys())
|
|
261
|
+
|
|
262
|
+
async with AsyncExitStack() as stack:
|
|
263
|
+
for name in servers_to_connect:
|
|
264
|
+
if name not in config.mcp_servers:
|
|
265
|
+
raise KeyError(f"Server '{name}' not found in config. Available: {list(config.mcp_servers.keys())}")
|
|
266
|
+
|
|
267
|
+
server_config = config.mcp_servers[name]
|
|
268
|
+
|
|
269
|
+
# Connect to server based on transport type
|
|
270
|
+
match server_config:
|
|
271
|
+
case StdioServerConfig():
|
|
272
|
+
server_params = StdioServerParameters(
|
|
273
|
+
command=server_config.command,
|
|
274
|
+
args=server_config.args,
|
|
275
|
+
env=server_config.env,
|
|
276
|
+
cwd=server_config.cwd,
|
|
277
|
+
encoding=server_config.encoding,
|
|
278
|
+
)
|
|
279
|
+
read, write = await stack.enter_async_context(stdio_client(server_params))
|
|
280
|
+
case SseServerConfig():
|
|
281
|
+
read, write = await stack.enter_async_context(
|
|
282
|
+
sse_client(
|
|
283
|
+
url=server_config.url,
|
|
284
|
+
headers=server_config.headers,
|
|
285
|
+
timeout=server_config.timeout,
|
|
286
|
+
sse_read_timeout=server_config.sse_read_timeout,
|
|
287
|
+
)
|
|
288
|
+
)
|
|
289
|
+
case StreamableHttpServerConfig():
|
|
290
|
+
read, write, _ = await stack.enter_async_context(
|
|
291
|
+
streamablehttp_client(
|
|
292
|
+
url=server_config.url,
|
|
293
|
+
headers=server_config.headers,
|
|
294
|
+
timeout=server_config.timeout,
|
|
295
|
+
sse_read_timeout=server_config.sse_read_timeout,
|
|
296
|
+
terminate_on_close=server_config.terminate_on_close,
|
|
297
|
+
)
|
|
298
|
+
)
|
|
299
|
+
case WebSocketServerConfig():
|
|
300
|
+
if websocket_client is None:
|
|
301
|
+
raise ImportError(
|
|
302
|
+
f"WebSocket transport for server '{name}' requires the 'websockets' package. "
|
|
303
|
+
"Install with: pip install websockets"
|
|
304
|
+
)
|
|
305
|
+
read, write = await stack.enter_async_context(websocket_client(url=server_config.url))
|
|
306
|
+
|
|
307
|
+
session = await stack.enter_async_context(ClientSession(read, write))
|
|
308
|
+
await session.initialize()
|
|
309
|
+
|
|
310
|
+
# Cache session and available tools
|
|
311
|
+
self._servers[name] = session
|
|
312
|
+
response = await session.list_tools()
|
|
313
|
+
self._tools[name] = [
|
|
314
|
+
{"name": t.name, "description": t.description, "schema": t.inputSchema} for t in response.tools
|
|
315
|
+
]
|
|
316
|
+
|
|
317
|
+
try:
|
|
318
|
+
yield self
|
|
319
|
+
finally:
|
|
320
|
+
self._servers.clear()
|
|
321
|
+
self._tools.clear()
|
|
322
|
+
|
|
323
|
+
@property
|
|
324
|
+
def servers(self) -> list[str]:
|
|
325
|
+
"""List of connected server names."""
|
|
326
|
+
return list(self._servers.keys())
|
|
327
|
+
|
|
328
|
+
def get_tools(self, server: str) -> list[dict[str, Any]]:
|
|
329
|
+
"""Get available tools for a specific server.
|
|
330
|
+
|
|
331
|
+
Args:
|
|
332
|
+
server: Server name.
|
|
333
|
+
|
|
334
|
+
Returns:
|
|
335
|
+
List of tool info dicts with name, description, and schema.
|
|
336
|
+
"""
|
|
337
|
+
return self._tools.get(server, [])
|
|
338
|
+
|
|
339
|
+
@property
|
|
340
|
+
def all_tools(self) -> dict[str, list[str]]:
|
|
341
|
+
"""Get all available tools grouped by server.
|
|
342
|
+
|
|
343
|
+
Returns:
|
|
344
|
+
Dict mapping server names to lists of tool names.
|
|
345
|
+
"""
|
|
346
|
+
return {server: [t["name"] for t in tools] for server, tools in self._tools.items()}
|
|
347
|
+
|
|
348
|
+
async def call_tool(self, server: str, tool_name: str, arguments: dict[str, Any]) -> str:
|
|
349
|
+
"""Call a tool on a specific MCP server.
|
|
350
|
+
|
|
351
|
+
Args:
|
|
352
|
+
server: Name of the MCP server.
|
|
353
|
+
tool_name: Name of the tool to call.
|
|
354
|
+
arguments: Arguments to pass to the tool.
|
|
355
|
+
|
|
356
|
+
Returns:
|
|
357
|
+
Tool result as a string (text content extracted from response).
|
|
358
|
+
|
|
359
|
+
Raises:
|
|
360
|
+
ValueError: If server is not connected.
|
|
361
|
+
"""
|
|
362
|
+
session = self._servers.get(server)
|
|
363
|
+
if session is None:
|
|
364
|
+
raise ValueError(f"Server '{server}' not connected. Available: {self.servers}")
|
|
365
|
+
|
|
366
|
+
result = await session.call_tool(tool_name, arguments)
|
|
367
|
+
|
|
368
|
+
# Extract text content from result
|
|
369
|
+
text_parts = [str(content.text) for content in result.content if hasattr(content, "text")]
|
|
370
|
+
return "\n".join(text_parts)
|
|
371
|
+
|
|
372
|
+
def get_all_tools(self) -> list[Tool[Any, ToolUseCountMetadata]]:
|
|
373
|
+
"""Get individual Tool objects for each tool from all connected MCP servers.
|
|
374
|
+
|
|
375
|
+
Each MCP tool is exposed as a separate Tool with its own parameter schema,
|
|
376
|
+
allowing the LLM to see and call each tool directly without routing through
|
|
377
|
+
a unified proxy.
|
|
378
|
+
|
|
379
|
+
Tool names are formatted as '{server}__{tool_name}' to ensure uniqueness
|
|
380
|
+
across servers (e.g., 'supabase__query_table').
|
|
381
|
+
|
|
382
|
+
Returns:
|
|
383
|
+
List of Tool objects, one for each tool available across all connected servers.
|
|
384
|
+
"""
|
|
385
|
+
tools: list[Tool[Any, ToolUseCountMetadata]] = []
|
|
386
|
+
|
|
387
|
+
for server_name, server_tools in self._tools.items():
|
|
388
|
+
for tool_info in server_tools:
|
|
389
|
+
mcp_tool_name = tool_info["name"]
|
|
390
|
+
# Create unique tool name with server prefix
|
|
391
|
+
unique_name = f"{server_name}__{mcp_tool_name}"
|
|
392
|
+
|
|
393
|
+
# Convert JSON schema to Pydantic model
|
|
394
|
+
params_model = create_model(
|
|
395
|
+
tool_info.get("schema", {}),
|
|
396
|
+
)
|
|
397
|
+
|
|
398
|
+
# Create executor closure - capture server_name and mcp_tool_name
|
|
399
|
+
# using default arguments to avoid late binding issues in the loop
|
|
400
|
+
async def executor(
|
|
401
|
+
params: BaseModel,
|
|
402
|
+
_server: str = server_name,
|
|
403
|
+
_tool: str = mcp_tool_name,
|
|
404
|
+
) -> ToolResult[ToolUseCountMetadata]:
|
|
405
|
+
content = await self.call_tool(_server, _tool, params.model_dump())
|
|
406
|
+
xml_content = f"<mcp_result>\n{content}\n</mcp_result>"
|
|
407
|
+
return ToolResult(content=xml_content, metadata=ToolUseCountMetadata())
|
|
408
|
+
|
|
409
|
+
tools.append(
|
|
410
|
+
Tool(
|
|
411
|
+
name=unique_name,
|
|
412
|
+
description=tool_info.get("description") or f"Tool '{mcp_tool_name}' from {server_name}",
|
|
413
|
+
parameters=params_model,
|
|
414
|
+
executor=executor, # ty: ignore[invalid-argument-type]
|
|
415
|
+
)
|
|
416
|
+
)
|
|
417
|
+
|
|
418
|
+
return tools
|
|
419
|
+
|
|
420
|
+
# Tool lifecycle protocol implementation
|
|
421
|
+
async def __aenter__(self) -> list[Tool[Any, ToolUseCountMetadata]]:
|
|
422
|
+
"""Enter async context: connect to MCP servers and return all tools.
|
|
423
|
+
|
|
424
|
+
Returns:
|
|
425
|
+
List of Tool objects, one for each tool available across all connected servers.
|
|
426
|
+
"""
|
|
427
|
+
self._exit_stack = AsyncExitStack()
|
|
428
|
+
await self._exit_stack.__aenter__()
|
|
429
|
+
|
|
430
|
+
config = self._config
|
|
431
|
+
servers_to_connect = self._server_names or list(config.mcp_servers.keys())
|
|
432
|
+
|
|
433
|
+
for name in servers_to_connect:
|
|
434
|
+
if name not in config.mcp_servers:
|
|
435
|
+
raise KeyError(f"Server '{name}' not found in config. Available: {list(config.mcp_servers.keys())}")
|
|
436
|
+
|
|
437
|
+
server_config = config.mcp_servers[name]
|
|
438
|
+
|
|
439
|
+
# Connect to server based on transport type
|
|
440
|
+
match server_config:
|
|
441
|
+
case StdioServerConfig():
|
|
442
|
+
server_params = StdioServerParameters(
|
|
443
|
+
command=server_config.command,
|
|
444
|
+
args=server_config.args,
|
|
445
|
+
env=server_config.env,
|
|
446
|
+
cwd=server_config.cwd,
|
|
447
|
+
encoding=server_config.encoding,
|
|
448
|
+
)
|
|
449
|
+
read, write = await self._exit_stack.enter_async_context(stdio_client(server_params))
|
|
450
|
+
case SseServerConfig():
|
|
451
|
+
read, write = await self._exit_stack.enter_async_context(
|
|
452
|
+
sse_client(
|
|
453
|
+
url=server_config.url,
|
|
454
|
+
headers=server_config.headers,
|
|
455
|
+
timeout=server_config.timeout,
|
|
456
|
+
sse_read_timeout=server_config.sse_read_timeout,
|
|
457
|
+
)
|
|
458
|
+
)
|
|
459
|
+
case StreamableHttpServerConfig():
|
|
460
|
+
read, write, _ = await self._exit_stack.enter_async_context(
|
|
461
|
+
streamablehttp_client(
|
|
462
|
+
url=server_config.url,
|
|
463
|
+
headers=server_config.headers,
|
|
464
|
+
timeout=server_config.timeout,
|
|
465
|
+
sse_read_timeout=server_config.sse_read_timeout,
|
|
466
|
+
terminate_on_close=server_config.terminate_on_close,
|
|
467
|
+
)
|
|
468
|
+
)
|
|
469
|
+
case WebSocketServerConfig():
|
|
470
|
+
if websocket_client is None:
|
|
471
|
+
raise ImportError(
|
|
472
|
+
f"WebSocket transport for server '{name}' requires the 'websockets' package. "
|
|
473
|
+
"Install with: pip install websockets"
|
|
474
|
+
)
|
|
475
|
+
read, write = await self._exit_stack.enter_async_context(websocket_client(url=server_config.url))
|
|
476
|
+
|
|
477
|
+
session = await self._exit_stack.enter_async_context(ClientSession(read, write))
|
|
478
|
+
await session.initialize()
|
|
479
|
+
|
|
480
|
+
# Cache session and available tools
|
|
481
|
+
self._servers[name] = session
|
|
482
|
+
response = await session.list_tools()
|
|
483
|
+
self._tools[name] = [
|
|
484
|
+
{"name": t.name, "description": t.description, "schema": t.inputSchema} for t in response.tools
|
|
485
|
+
]
|
|
486
|
+
|
|
487
|
+
return self.get_all_tools()
|
|
488
|
+
|
|
489
|
+
async def __aexit__(
|
|
490
|
+
self,
|
|
491
|
+
exc_type: type[BaseException] | None,
|
|
492
|
+
exc_val: BaseException | None,
|
|
493
|
+
exc_tb: TracebackType | None,
|
|
494
|
+
) -> None:
|
|
495
|
+
"""Exit async context: disconnect from MCP servers."""
|
|
496
|
+
self._servers.clear()
|
|
497
|
+
self._tools.clear()
|
|
498
|
+
if self._exit_stack:
|
|
499
|
+
await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
|
|
500
|
+
self._exit_stack = None
|
|
stirrup/tools/view_image.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""View image tool provider for execution environments."""
|
|
2
|
+
|
|
3
|
+
from stirrup.core.models import Tool, ToolProvider, ToolUseCountMetadata
|
|
4
|
+
from stirrup.tools.code_backends.base import CodeExecToolProvider, ViewImageParams
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ViewImageToolProvider(ToolProvider):
    """Supplies a view_image tool backed by a code execution environment.

    The environment is either handed in explicitly or looked up in the
    Agent's session state when the provider is entered, so this provider
    may appear anywhere in the Agent's tool list.

    Examples:
        from stirrup.clients.chat_completions_client import ChatCompletionsClient

        client = ChatCompletionsClient(model="gpt-5")

        # Explicit exec_env
        exec_env = LocalCodeExecToolProvider()
        agent = Agent(
            client=client, name="assistant",
            tools=[exec_env, ViewImageToolProvider(exec_env)],
        )

        # Auto-detect (any order works)
        agent = Agent(
            client=client, name="assistant",
            tools=[ViewImageToolProvider(), LocalCodeExecToolProvider()],
        )
    """

    def __init__(
        self,
        exec_env: CodeExecToolProvider | None = None,
        *,
        name: str = "view_image",
        description: str | None = None,
    ) -> None:
        """Record the settings; the environment is resolved only on entry.

        Args:
            exec_env: Environment to read images from. When None, the
                environment is taken from the Agent's session state on entry.
            name: Name given to the produced tool (default: "view_image").
            description: Tool description forwarded to ``get_view_image_tool``;
                None presumably selects that method's default text.
        """
        self._exec_env = exec_env
        self._name = name
        self._description = description

    async def __aenter__(self) -> Tool[ViewImageParams, ToolUseCountMetadata]:
        """Choose the execution environment and build the view_image tool from it.

        Raises:
            ValueError: If an explicit exec_env is not the Agent's exec_env instance.
            RuntimeError: If no exec_env was given and the session provides none.
        """
        # Deferred import: importing the agent module at load time would be circular.
        from stirrup.core.agent import _SESSION_STATE

        session_state = _SESSION_STATE.get(None)
        session_env = None
        if session_state:
            session_env = session_state.exec_env

        chosen_env = self._exec_env
        if chosen_env is None:
            # Nothing explicit: fall back to whatever the Agent registered.
            if session_env is None:
                raise RuntimeError(
                    "ViewImageToolProvider requires a CodeExecToolProvider. "
                    "Either pass exec_env explicitly or include a CodeExecToolProvider "
                    "in the Agent's tools list."
                )
            chosen_env = session_env
        elif session_env is not None and chosen_env is not session_env:
            raise ValueError(
                f"ViewImageToolProvider exec_env ({type(self._exec_env).__name__}) "
                f"does not match Agent's exec_env ({type(session_env).__name__}). "
                "Use the same exec_env instance or omit exec_env to auto-detect."
            )

        return chosen_env.get_view_image_tool(name=self._name, description=self._description)
|