gnosys-strata 1.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,104 @@
1
+ """Abstract base class for MCP transport implementations."""
2
+
3
+ import logging
4
+ from abc import ABC, abstractmethod
5
+ from contextlib import AsyncExitStack
6
+ from typing import Optional, Tuple
7
+
8
+ from mcp.client.session import ClientSession
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
class Transport(ABC):
    """Abstract base class for MCP transport implementations.

    Subclasses implement :meth:`_get_streams` to open their transport-specific
    streams. This base class wraps those streams in an MCP ``ClientSession`` and
    manages the connection lifecycle through a single ``AsyncExitStack``:
    closing the stack tears down both the session and the underlying streams.
    """

    def __init__(self):
        """Initialize the transport in the disconnected state."""
        self._session: Optional[ClientSession] = None
        # Holds every async context entered while connecting (streams + session).
        self._exit_stack: Optional[AsyncExitStack] = None
        self._connected: bool = False

    @abstractmethod
    async def _get_streams(self, exit_stack: AsyncExitStack) -> Tuple:
        """Get the transport-specific streams.

        Args:
            exit_stack: AsyncExitStack to manage the streams

        Returns:
            Tuple of (read_stream, write_stream)
        """

    async def initialize(self) -> None:
        """Open the transport streams and initialize an MCP client session.

        On success the transport is marked connected. On failure any partially
        opened contexts are closed, ``_exit_stack`` and ``_session`` are reset to
        ``None``, and the original exception is re-raised.
        """
        # NOTE(review): calling initialize() directly while already connected
        # opens a second session on the same exit stack; connect() guards this.
        if not self._exit_stack:
            self._exit_stack = AsyncExitStack()

        try:
            # Get transport-specific streams
            streams = await self._get_streams(self._exit_stack)

            # Create client session (common for all transports); only the first
            # two stream items are used.
            self._session = await self._exit_stack.enter_async_context(
                ClientSession(streams[0], streams[1])
            )
            logger.info("Client session created successfully")
            # Perform the MCP initialization handshake
            await self._session.initialize()

            self._connected = True
            logger.info(f"Successfully connected via {self.__class__.__name__}")

        except Exception as e:
            logger.error(f"Failed to connect via {self.__class__.__name__}: {e}")
            await self._discard_failed_connection()
            raise

    async def _discard_failed_connection(self) -> None:
        """Close and forget a partially opened connection after a failed connect.

        Cleanup errors are logged rather than raised, so the caller of
        ``initialize()`` sees the original connection failure instead of a
        secondary cleanup error.
        """
        exit_stack, self._exit_stack = self._exit_stack, None
        self._session = None
        if exit_stack is None:
            return
        try:
            await exit_stack.aclose()
        except Exception:
            logger.exception(
                f"Error while cleaning up failed connection via {self.__class__.__name__}"
            )

    async def connect(self) -> None:
        """Connect to the MCP server using the specific transport.

        No-op if already connected.
        """
        if self._connected:
            return

        await self.initialize()

    async def disconnect(self) -> None:
        """Disconnect from the MCP server.

        No-op if not connected. Connection state is always reset, even when
        closing the exit stack raises, so a later ``connect()`` can open a fresh
        session instead of being short-circuited by a stale ``_connected`` flag.

        Raises:
            RuntimeError: re-raised from closing the exit stack unless it is the
                anyio "cancel scope" cross-task cleanup error, which is logged.
        """
        if not self._connected:
            return

        try:
            if self._exit_stack:
                try:
                    await self._exit_stack.aclose()
                except RuntimeError as e:
                    # Handle cross-task cleanup errors from anyio's CancelScope
                    if "cancel scope" in str(e).lower():
                        logger.warning(
                            "Cross-task cleanup detected and handled. "
                            "This typically happens with pytest fixtures."
                        )
                    else:
                        raise
        finally:
            self._session = None
            self._exit_stack = None
            self._connected = False

    def is_connected(self) -> bool:
        """Check if connected to an MCP server."""
        return self._connected

    def get_session(self) -> ClientSession:
        """Get the current client session.

        Raises:
            RuntimeError: If not connected.
        """
        if not self._connected or not self._session:
            raise RuntimeError("Not connected to an MCP server.")
        return self._session

    async def __aenter__(self):
        """Enter async context manager: connect and return self."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Exit async context manager: disconnect (exceptions are not suppressed)."""
        await self.disconnect()
@@ -0,0 +1,80 @@
1
+ """HTTP and SSE transport implementations for MCP."""
2
+
3
+ import logging
4
+ from contextlib import AsyncExitStack
5
+ from typing import Dict, Literal, Optional, Tuple
6
+ from urllib.parse import urlparse
7
+
8
+ from mcp.client.sse import sse_client
9
+ from mcp.client.streamable_http import streamablehttp_client
10
+
11
+ from strata.mcp_proxy.auth_provider import create_oauth_provider
12
+ from .base import Transport
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
class HTTPTransport(Transport):
    """HTTP/SSE transport for MCP communication."""

    def __init__(
        self,
        server_name: str,
        url: str,
        mode: Literal["http", "sse"] = "http",
        headers: Optional[Dict[str, str]] = None,
        auth: str = "",
    ):
        """Initialize HTTP transport.

        Args:
            server_name: Name of the configured server; passed to
                ``create_oauth_provider`` when ``auth == "oauth"``.
            url: HTTP/HTTPS URL of the MCP server.
            mode: Transport mode - "http" uses the streamable-HTTP client,
                "sse" uses the server-sent-events client.
            headers: Optional headers to send with requests.
            auth: Authentication scheme. Only "oauth" is acted upon; any other
                value connects without an auth provider.

        Raises:
            ValueError: If the URL scheme is not http/https or the URL has no host.
        """
        super().__init__()
        self.server_name = server_name
        self.url = url
        self.mode = mode
        self.headers = headers or {}
        self.auth = auth

        # Validate URL up front so misconfiguration fails at construction time
        parsed = urlparse(url)
        if parsed.scheme not in ("http", "https"):
            raise ValueError(f"Invalid URL scheme: {parsed.scheme}")
        if not parsed.netloc:
            # e.g. "http:foo" or "http:///path" - scheme is valid but no host
            raise ValueError(f"Invalid URL (missing host): {url}")

    async def _get_streams(self, exit_stack: AsyncExitStack) -> Tuple:
        """Get HTTP/SSE transport streams.

        Args:
            exit_stack: AsyncExitStack to manage the streams

        Returns:
            The value yielded by the selected MCP client context manager; the
            base class uses its first two items as (read_stream, write_stream).

        Raises:
            ValueError: If ``mode`` is neither "http" nor "sse".
        """
        if self.mode == "sse":
            # Connect via SSE for server-sent events
            logger.info(f"Connecting to MCP server via SSE: {self.url}")
            client_factory = sse_client
        elif self.mode == "http":
            # Connect via standard HTTP (request/response)
            logger.info(f"Connecting to MCP server via HTTP: {self.url}")
            client_factory = streamablehttp_client
        else:
            raise ValueError(
                f"Invalid transport mode: {self.mode}. Use 'http' or 'sse'."
            )

        # Both clients take the same headers/auth keywords, so build them once
        client_kwargs = {"headers": self.headers}
        if self.auth == "oauth":
            client_kwargs["auth"] = create_oauth_provider(self.server_name, self.url)

        return await exit_stack.enter_async_context(
            client_factory(self.url, **client_kwargs)
        )
@@ -0,0 +1,69 @@
1
+ """Stdio transport implementation for MCP."""
2
+
3
+ import logging
4
+ from contextlib import AsyncExitStack
5
+ from typing import Dict, List, Optional, Tuple
6
+
7
+ from mcp.client.stdio import StdioServerParameters, stdio_client
8
+
9
+ from .base import Transport
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
class StdioTransport(Transport):
    """Stdio transport for MCP communication (server launched as a subprocess)."""

    def __init__(
        self,
        command: str,
        args: Optional[List[str]] = None,
        env: Optional[Dict[str, str]] = None,
    ):
        """Initialize stdio transport.

        Args:
            command: Command to execute (e.g., "docker")
            args: Command arguments (e.g., ["run", "-i", "--rm", "-e", "GITHUB_PERSONAL_ACCESS_TOKEN", "ghcr.io/github/github-mcp-server"])
            env: Environment variables to pass to the command (e.g., {"GITHUB_PERSONAL_ACCESS_TOKEN": "${input:github_token}"}).
                Values are handed to the MCP SDK as-is; this class does not
                expand placeholders such as ``${input:...}``.

        Example server configuration:
            "servers": {
                "github": {
                    "command": "docker",
                    "args": [
                        "run",
                        "-i",
                        "--rm",
                        "-e",
                        "GITHUB_PERSONAL_ACCESS_TOKEN",
                        "ghcr.io/github/github-mcp-server"
                    ],
                    "env": {
                        "GITHUB_PERSONAL_ACCESS_TOKEN": "${input:github_token}"
                    }
                }
            }
        """
        super().__init__()
        self.command = command
        self.args = args or []
        self.env = env or {}

    async def _get_streams(self, exit_stack: AsyncExitStack) -> Tuple:
        """Get stdio transport streams.

        Args:
            exit_stack: AsyncExitStack to manage the streams

        Returns:
            Tuple of (read_stream, write_stream)
        """
        # Pass None rather than {} when no env was configured. Some MCP SDK
        # versions use `env` verbatim as the child's entire environment (only
        # None selects the SDK's default environment), so {} would start the
        # server without PATH/HOME. Newer versions treat both the same.
        server_params = StdioServerParameters(
            command=self.command, args=self.args, env=self.env or None
        )

        # Connect via stdio
        logger.info(f"Connecting to MCP server via stdio: {self.command} {self.args}")
        return await exit_stack.enter_async_context(stdio_client(server_params))
strata/server.py ADDED
@@ -0,0 +1,216 @@
1
+ """Main server module for Strata MCP Router."""
2
+
3
+ import asyncio
4
+ import contextlib
5
+ import logging
6
+ import os
7
+ from collections.abc import AsyncIterator
8
+
9
+ import mcp.types as types
10
+ from mcp.server.lowlevel import Server
11
+ from mcp.server.sse import SseServerTransport
12
+ from mcp.server.stdio import stdio_server
13
+ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
14
+ from starlette.applications import Starlette
15
+ from starlette.responses import Response
16
+ from starlette.routing import Mount, Route
17
+ from starlette.types import Receive, Scope, Send
18
+
19
+ from .mcp_client_manager import MCPClientManager
20
+ from .tools import execute_tool, get_tool_definitions
21
+
22
# Module-level logger for the router.
logger = logging.getLogger(__name__)

# Router port from the environment, defaulting to 8080. int() raises
# ValueError at import time if the variable is set to a non-integer.
MCP_ROUTER_PORT = int(os.environ.get("MCP_ROUTER_PORT", "8080"))

# One client manager shared by every handler and transport in this module.
client_manager = MCPClientManager()
29
+
30
+
31
@contextlib.asynccontextmanager
async def config_watching_context():
    """Shared context manager for config watching in both stdio and HTTP modes.

    On entry: logs startup info, creates a default "starsystem" server set if it
    is missing, and starts a background task running
    ``client_manager.server_list.watch_config``. Config changes schedule
    ``client_manager.sync_with_config`` as background tasks.

    On exit: cancels the watcher and any in-flight sync tasks, then always calls
    ``client_manager.disconnect_all()``.
    """
    # JIT mode: Don't connect any servers on startup
    # Servers connect on-demand via manage_servers tool
    # Catalog is loaded from disk cache (populated via populate_catalog tool)
    logger.info("Strata starting in JIT mode - no servers connected on startup")
    logger.info(f"Catalog has {len(client_manager.catalog.get_all_tools())} cached servers")

    # Create default starsystem set if it doesn't exist
    if not client_manager.server_list.get_set("starsystem"):
        default_servers = ["STARSYSTEM", "starship", "starlog", "waypoint", "metastack"]
        # Only include servers that are actually configured
        available = [s for s in default_servers if client_manager.server_list.get_server(s)]
        if available:
            client_manager.server_list.add_set(
                "starsystem",
                available,
                description="Core compound intelligence: STARSYSTEM wrapper + STARSHIP navigation + STARLOG tracking + WAYPOINT flights + metastack templates"
            )
            logger.info(f"Created default 'starsystem' set with: {', '.join(available)}")

    # Strong references to in-flight sync tasks. The event loop keeps only weak
    # references to tasks, so an unreferenced task may be garbage-collected
    # before it finishes.
    sync_tasks: set = set()

    def on_config_changed(new_servers):
        """Handle config changes by scheduling a client-manager sync."""
        logger.info("Config file changed, syncing client manager...")

        async def safe_sync():
            """Sync with the new config; errors are logged, never raised."""
            try:
                await client_manager.sync_with_config(new_servers)
            except Exception as e:
                # Don't let sync errors crash the server
                logger.exception(f"Error during config sync: {e}")

        # NOTE(review): asyncio.create_task requires a running loop in the
        # calling thread; assumes watch_config invokes this callback from the
        # event loop - confirm.
        task = asyncio.create_task(safe_sync())
        sync_tasks.add(task)
        task.add_done_callback(sync_tasks.discard)

    # Start watching config file for changes
    watch_task = asyncio.create_task(
        client_manager.server_list.watch_config(on_config_changed)
    )
    logger.info("Config file watching enabled - changes will be auto-synced")

    try:
        yield
    finally:
        logger.info("Shutting down...")
        try:
            # Stop config watching
            watch_task.cancel()
            try:
                await watch_task
            except asyncio.CancelledError:
                logger.info("Config watching stopped")  # Expected cancellation, no traceback needed
            except Exception:
                # The watcher crashed earlier; don't let that skip client cleanup
                logger.exception("Config watcher terminated with an error")

            # Cancel syncs still in flight so they cannot race disconnect_all()
            pending = list(sync_tasks)
            for task in pending:
                task.cancel()
            if pending:
                await asyncio.gather(*pending, return_exceptions=True)
        finally:
            # Clean up client managers
            await client_manager.disconnect_all()
87
+
88
+
89
def setup_server_handlers(server: Server) -> None:
    """Register Strata's list-tools and call-tool handlers on ``server``.

    Both stdio and HTTP modes share these handlers, and both delegate to the
    module-level ``client_manager``.
    """

    @server.list_tools()
    async def list_tools() -> list[types.Tool]:
        """Return Strata tool definitions built from the active client names.

        Any error is logged and reported as an empty tool list.
        """
        try:
            active_server_names = list(client_manager.active_clients.keys())
            tool_defs = get_tool_definitions(active_server_names)
        except Exception:
            logger.exception("Error listing strata tools")
            return []
        return tool_defs

    @server.call_tool(validate_input=False)
    async def call_tool(name: str, arguments: dict) -> list[types.ContentBlock]:
        """Dispatch a Strata tool call to ``execute_tool``."""
        content = await execute_tool(name, arguments, client_manager)
        return content
107
+
108
+
109
async def run_stdio_server_async() -> None:
    """Serve the Strata MCP router over stdin/stdout.

    Config watching is active for the whole lifetime of the stdio session. The
    coroutine returns once ``server.run`` returns.
    """
    stdio_router = Server("strata-mcp-stdio")
    setup_server_handlers(stdio_router)

    logger.info("Strata MCP Router running in stdio mode")
    # Config watching is entered first and exited last, around the stdio session
    async with config_watching_context(), stdio_server() as (read_stream, write_stream):
        init_options = stdio_router.create_initialization_options()
        await stdio_router.run(read_stream, write_stream, init_options)
125
+
126
+
127
def run_stdio_server() -> int:
    """Run the stdio server synchronously and return a process exit code.

    Returns:
        0 on normal completion or KeyboardInterrupt; 1 if any other exception
        escapes (it is logged with its traceback).
    """
    exit_code = 0
    try:
        asyncio.run(run_stdio_server_async())
    except KeyboardInterrupt:
        logger.info("Stdio server stopped by user")
    except Exception:
        logger.exception("Error running stdio server")
        exit_code = 1
    return exit_code
138
+
139
+
140
def run_server(
    port: int,
    json_response: bool,
    *,
    host: str = "0.0.0.0",
    debug: bool = False,
) -> int:
    """Run the MCP router over HTTP with both SSE and StreamableHTTP transports.

    Routes:
        /sse        -- SSE connection endpoint (GET only)
        /messages/  -- SSE message posting, handled by ``sse.handle_post_message``
        /mcp        -- StreamableHTTP endpoint (stateless session manager)

    Args:
        port: TCP port to listen on.
        json_response: Passed to ``StreamableHTTPSessionManager``.
        host: Interface to bind; defaults to all interfaces (the previously
            hard-coded value).
        debug: Starlette debug mode. Off by default because debug mode returns
            tracebacks to HTTP clients and the default host is all interfaces.

    Returns:
        0 once ``uvicorn.run`` returns.
    """
    # Create the MCP router server instance
    app = Server("strata-mcp-server")

    # Set up shared handlers
    setup_server_handlers(app)

    # Set up SSE transport
    sse = SseServerTransport("/messages/")

    async def handle_sse(request):
        """Run one MCP session over an SSE connection."""
        logger.info("Handling SSE connection for router")

        try:
            async with sse.connect_sse(
                request.scope, request.receive, request._send
            ) as streams:
                await app.run(
                    streams[0], streams[1], app.create_initialization_options()
                )
        except Exception as e:
            logger.exception(f"SSE connection error: {e}")
        # Starlette requires the endpoint to return a Response; the SSE stream
        # itself was written through request._send above.
        return Response()

    # Set up StreamableHTTP transport
    session_manager = StreamableHTTPSessionManager(
        app=app,
        event_store=None,  # Stateless mode
        json_response=json_response,
        stateless=True,
    )

    async def handle_streamable_http(
        scope: Scope, receive: Receive, send: Send
    ) -> None:
        """ASGI app delegating each request to the StreamableHTTP session manager."""
        logger.info("Handling StreamableHTTP request for router")
        try:
            await session_manager.handle_request(scope, receive, send)
        finally:
            logger.info("StreamableHTTP request completed")

    @contextlib.asynccontextmanager
    async def lifespan(_starlette_app: Starlette) -> AsyncIterator[None]:
        """Context manager for session manager and client initialization."""
        # Parameter renamed so it no longer shadows the MCP `app` above.
        async with config_watching_context():
            async with session_manager.run():
                logger.info("Strata MCP Router started with dual transports!")
                logger.info("Available tools:")
                logger.info("- discover_server_actions: Discover available actions")
                logger.info("- get_action_details: Get detailed action parameters")
                logger.info("- execute_action: Execute server actions")
                logger.info("- search_documentation: Search server documentation")
                logger.info("- handle_auth_failure: Handle authentication issues")
                yield

    # Create an ASGI application with routes for both transports
    starlette_app = Starlette(
        debug=debug,
        routes=[
            # SSE routes
            Route("/sse", endpoint=handle_sse, methods=["GET"]),
            Mount("/messages/", app=sse.handle_post_message),
            # StreamableHTTP route
            Mount("/mcp", app=handle_streamable_http),
        ],
        lifespan=lifespan,
    )

    logger.info(f"Strata MCP Router starting on port {port} with dual transports")

    import uvicorn

    uvicorn.run(starlette_app, host=host, port=port)

    return 0