iotsploit-mcp 0.0.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,67 @@
1
+ Metadata-Version: 2.3
2
+ Name: iotsploit-mcp
3
+ Version: 0.0.6
4
+ Summary: IoTSploit MCP runtime (FastMCP stdio server + WebSocket bridge), framework-free outer ring.
5
+ License: GPL-3.0-or-later
6
+ Keywords: iot,security,mcp,model-context-protocol,pentest
7
+ Author: IoTSploit Team
8
+ Author-email: support@iotsploit.org
9
+ Requires-Python: >=3.10,<4.0
10
+ Classifier: Development Status :: 3 - Alpha
11
+ Classifier: Intended Audience :: Developers
12
+ Classifier: Intended Audience :: Information Technology
13
+ Classifier: License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)
14
+ Classifier: Programming Language :: Python :: 3
15
+ Classifier: Programming Language :: Python :: 3.10
16
+ Classifier: Programming Language :: Python :: 3.11
17
+ Classifier: Programming Language :: Python :: 3.12
18
+ Classifier: Programming Language :: Python :: 3.13
19
+ Classifier: Topic :: Security
20
+ Classifier: Topic :: System :: Hardware
21
+ Requires-Dist: iotsploit-core (>=0.0.6,<0.0.7)
22
+ Requires-Dist: mcp (>=1.9.4,<2.0.0)
23
+ Requires-Dist: nest-asyncio (>=1.6.0,<2.0.0)
24
+ Requires-Dist: requests (>=2.31.0,<3.0.0)
25
+ Requires-Dist: websockets (>=12.0,<13.0)
26
+ Project-URL: Documentation, https://www.iotsploit.org/
27
+ Project-URL: Homepage, https://www.iotsploit.org/
28
+ Project-URL: Repository, https://github.com/TKXB/iotsploit
29
+ Description-Content-Type: text/markdown
30
+
31
+ # iotsploit-mcp
32
+
33
+ `iotsploit-mcp` 是 IoTSploit 的 MCP 运行时组件(execution plane / outer ring),包含:
34
+
35
+ - FastMCP **stdio server**(提供 MCP tools)
36
+ - WebSocket **bridge**(给上层 UI/Django consumer 通过 `ws://host:9998` 访问)
37
+
38
+ ## 命令
39
+
40
+ - 默认启动 WebSocket bridge(等价于 `ws`):
41
+
42
+ ```bash
43
+ iotsploit-mcp
44
+ ```
45
+
46
+ - 显式启动 WebSocket bridge:
47
+
48
+ ```bash
49
+ iotsploit-mcp ws --host 0.0.0.0 --port 9998
50
+ ```
51
+
52
+ - 只启动 stdio FastMCP server(一般由 bridge 拉起):
53
+
54
+ ```bash
55
+ iotsploit-mcp stdio
56
+ ```
57
+
58
+ ## 环境变量
59
+
60
+ - `IOTSPLOIT_DJANGO_API_BASE_URL`:Django HTTP API base URL(默认 `http://127.0.0.1:8888`)
61
+ - `IOTSPLOIT_DJANGO_API_TOKEN`:可选 Bearer token
62
+ - `IOTSPLOIT_DJANGO_API_TIMEOUT_S`:可选超时(秒)
63
+ - `IOTSPLOIT_DEVICE_PLUGINS_DIR`:device driver 插件目录
64
+ - `IOTSPLOIT_EXPLOIT_PLUGINS_DIR`:exploit 插件目录
65
+
66
+
67
+
@@ -0,0 +1,36 @@
1
+ # iotsploit-mcp
2
+
3
+ `iotsploit-mcp` 是 IoTSploit 的 MCP 运行时组件(execution plane / outer ring),包含:
4
+
5
+ - FastMCP **stdio server**(提供 MCP tools)
6
+ - WebSocket **bridge**(给上层 UI/Django consumer 通过 `ws://host:9998` 访问)
7
+
8
+ ## 命令
9
+
10
+ - 默认启动 WebSocket bridge(等价于 `ws`):
11
+
12
+ ```bash
13
+ iotsploit-mcp
14
+ ```
15
+
16
+ - 显式启动 WebSocket bridge:
17
+
18
+ ```bash
19
+ iotsploit-mcp ws --host 0.0.0.0 --port 9998
20
+ ```
21
+
22
+ - 只启动 stdio FastMCP server(一般由 bridge 拉起):
23
+
24
+ ```bash
25
+ iotsploit-mcp stdio
26
+ ```
27
+
28
+ ## 环境变量
29
+
30
+ - `IOTSPLOIT_DJANGO_API_BASE_URL`:Django HTTP API base URL(默认 `http://127.0.0.1:8888`)
31
+ - `IOTSPLOIT_DJANGO_API_TOKEN`:可选 Bearer token
32
+ - `IOTSPLOIT_DJANGO_API_TIMEOUT_S`:可选的 HTTP 请求超时(秒,默认 `5.0`)
33
+ - `IOTSPLOIT_DEVICE_PLUGINS_DIR`:device driver 插件目录
34
+ - `IOTSPLOIT_EXPLOIT_PLUGINS_DIR`:exploit 插件目录(兼容旧变量 `SAT_EXPLOIT_PLUGINS_DIR`;两者均未设置时默认为当前工作目录下的 `plugins/exploits`)
35
+
36
+
@@ -0,0 +1,39 @@
1
+ [tool.poetry]
2
+ name = "iotsploit-mcp"
3
+ version = "0.0.6"
4
+ description = "IoTSploit MCP runtime (FastMCP stdio server + WebSocket bridge), framework-free outer ring."
5
+ authors = ["IoTSploit Team <support@iotsploit.org>"]
6
+ readme = "README.md"
7
+ license = "GPL-3.0-or-later"
8
+ homepage = "https://www.iotsploit.org/"
9
+ repository = "https://github.com/TKXB/iotsploit"
10
+ documentation = "https://www.iotsploit.org/"
11
+ keywords = ["iot", "security", "mcp", "model-context-protocol", "pentest"]
12
+ classifiers = [
13
+ "Development Status :: 3 - Alpha",
14
+ "Intended Audience :: Developers",
15
+ "Intended Audience :: Information Technology",
16
+ "License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)",
17
+ "Topic :: Security",
18
+ "Topic :: System :: Hardware",
19
+ "Programming Language :: Python :: 3",
20
+ "Programming Language :: Python :: 3.10",
21
+ "Programming Language :: Python :: 3.11",
22
+ "Programming Language :: Python :: 3.12",
23
+ ]
24
+ packages = [{ include = "iotsploit_mcp", from = "src" }]
25
+
26
+ [tool.poetry.dependencies]
27
+ python = ">=3.10,<4.0"
28
+ mcp = "^1.9.4"
29
+ websockets = "^12.0"
30
+ requests = "^2.31.0"
31
+ nest-asyncio = "^1.6.0"
32
+ iotsploit-core = "^0.0.6"
33
+
34
+ [tool.poetry.scripts]
35
+ iotsploit-mcp = "iotsploit_mcp.cli:main"
36
+
37
+ [build-system]
38
+ requires = ["poetry-core>=1.8.0"]
39
+ build-backend = "poetry.core.masonry.api"
@@ -0,0 +1,7 @@
1
+ """MCP entrypoint package (framework-free outer ring).
2
+
3
+ This package wires `iotsploit_core` with adapters that talk to external services
4
+ (e.g. iotsploit_django HTTP API) without importing Django at runtime.
5
+ """
6
+
7
+
@@ -0,0 +1,8 @@
1
+ from __future__ import annotations
2
+
3
+ from iotsploit_mcp.cli import main
4
+
5
# When executed as a script, delegate to the CLI entry point (iotsploit_mcp.cli.main),
# the same function the `iotsploit-mcp` console script is bound to in pyproject.toml.
if __name__ == "__main__":
    main()
7
+
8
+
@@ -0,0 +1,3 @@
1
+ """Adapters for `iotsploit_mcp`."""
2
+
3
+
@@ -0,0 +1,107 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from dataclasses import dataclass
5
+ from typing import Any
6
+
7
+ import requests
8
+
9
+
10
@dataclass(frozen=True)
class DjangoHttpApiConfig:
    """Connection settings for the iotsploit_django HTTP API."""

    base_url: str  # normalized by the repository: stripped, no trailing "/"
    timeout_s: float = 5.0  # per-request timeout (seconds) passed to requests
    bearer_token: str | None = None  # sent as "Authorization: Bearer <token>" when set


class HttpDriverStateRepository:
    """DriverStateRepository over iotsploit_django HTTP API.

    Django endpoints (mounted under `/api/`):
    - GET  /api/get_driver_states/
    - POST /api/enable_driver/
    - POST /api/disable_driver/

    All failures (non-2xx status, non-JSON body, non-object body, or a body whose
    ``status`` is neither absent nor ``"success"``) are raised as ``RuntimeError``.
    """

    def __init__(self, *, base_url: str, timeout_s: float = 5.0, bearer_token: str | None = None) -> None:
        base_url = (base_url or "").strip().rstrip("/")
        if not base_url:
            raise ValueError("base_url is required for HttpDriverStateRepository")
        self._cfg = DjangoHttpApiConfig(base_url=base_url, timeout_s=float(timeout_s), bearer_token=bearer_token)
        self._session = requests.Session()

    @staticmethod
    def _timeout_from_env(default: float = 5.0) -> float:
        """Read IOTSPLOIT_DJANGO_API_TIMEOUT_S, treating unset or blank values as *default*.

        Raises:
            ValueError: if the variable holds a non-numeric value.
        """
        raw = (os.getenv("IOTSPLOIT_DJANGO_API_TIMEOUT_S") or "").strip()
        return float(raw) if raw else default

    @staticmethod
    def from_env() -> "HttpDriverStateRepository":
        """Build a repository from the IOTSPLOIT_DJANGO_API_* environment variables."""
        # `or` instead of a getenv default: variables exported as "" fall back to the
        # documented default instead of failing base_url validation.
        base_url = os.getenv("IOTSPLOIT_DJANGO_API_BASE_URL") or "http://127.0.0.1:8888"
        timeout_s = HttpDriverStateRepository._timeout_from_env()
        bearer_token = os.getenv("IOTSPLOIT_DJANGO_API_TOKEN") or None
        return HttpDriverStateRepository(base_url=base_url, timeout_s=timeout_s, bearer_token=bearer_token)

    def _headers(self) -> dict[str, str]:
        """Default request headers, including the bearer token when configured."""
        headers: dict[str, str] = {"Accept": "application/json"}
        if self._cfg.bearer_token:
            headers["Authorization"] = f"Bearer {self._cfg.bearer_token}"
        return headers

    def _url(self, path: str) -> str:
        """Join *path* onto the base URL, adding the leading "/" if missing."""
        path = path if path.startswith("/") else f"/{path}"
        return f"{self._cfg.base_url}{path}"

    @staticmethod
    def _raise_for_bad_response(resp: requests.Response, *, context: str) -> None:
        """Raise RuntimeError (with a truncated body) for any non-2xx response."""
        if 200 <= resp.status_code < 300:
            return
        raise RuntimeError(f"{context} failed: HTTP {resp.status_code}: {resp.text[:500]}")

    @staticmethod
    def _json_object(resp: requests.Response, *, context: str) -> dict[str, Any]:
        """Decode *resp* as a JSON object whose ``status`` is absent or ``"success"``.

        Raises:
            RuntimeError: on a non-JSON body, a non-object body, or an error status.
        """
        try:
            payload = resp.json()
        except ValueError as exc:  # requests' JSONDecodeError subclasses ValueError
            raise RuntimeError(f"{context} returned a non-JSON body: {resp.text[:500]}") from exc
        if not isinstance(payload, dict):
            raise RuntimeError(f"{context} returned unexpected payload: {payload!r}")
        if payload.get("status") not in (None, "success"):
            raise RuntimeError(f"{context} returned error: {payload!r}")
        return payload

    def list_enabled(self) -> dict[str, bool]:
        """Return ``{driver_name: enabled}`` for every driver Django knows about.

        Entries may be objects (``enabled`` read with a default of True) or bare values
        (coerced with ``bool``).
        """
        context = "GET /api/get_driver_states/"
        resp = self._session.get(
            self._url("/api/get_driver_states/"),
            headers=self._headers(),
            timeout=self._cfg.timeout_s,
        )
        self._raise_for_bad_response(resp, context=context)
        payload = self._json_object(resp, context=context)

        drivers = payload.get("drivers") or {}
        if not isinstance(drivers, dict):
            raise RuntimeError(f"Unexpected drivers payload: {drivers!r}")

        out: dict[str, bool] = {}
        for name, state in drivers.items():
            if isinstance(state, dict):
                out[str(name)] = bool(state.get("enabled", True))
            else:
                out[str(name)] = bool(state)
        return out

    def get_enabled(self, driver_name: str) -> bool | None:
        """Return the enabled flag for *driver_name*, or None if unknown or the name is empty."""
        if not driver_name:
            return None
        states = self.list_enabled()
        return states.get(driver_name)

    def set_enabled(self, driver_name: str, enabled: bool, description: str | None = None) -> None:
        """Enable or disable *driver_name* through the matching Django endpoint.

        Raises:
            ValueError: if *driver_name* is empty.
            RuntimeError: if the request fails or Django reports an error.
        """
        if not driver_name:
            raise ValueError("driver_name is required")

        endpoint = "/api/enable_driver/" if enabled else "/api/disable_driver/"
        body: dict[str, Any] = {"driver_name": driver_name}
        if description is not None:
            body["description"] = description

        resp = self._session.post(
            self._url(endpoint),
            headers={**self._headers(), "Content-Type": "application/json"},
            json=body,
            timeout=self._cfg.timeout_s,
        )
        context = f"POST {endpoint}"
        self._raise_for_bad_response(resp, context=context)
        self._json_object(resp, context=context)
+
107
+
@@ -0,0 +1,135 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from dataclasses import dataclass
5
+ from typing import Any
6
+
7
+ import requests
8
+
9
+ from iotsploit_core.domain.execution_plan import GroupStepSpec, PluginGroupSpec, PluginStepSpec
10
+ from iotsploit_core.ports.plugin_repo import PluginGroupRepository
11
+
12
+
13
@dataclass(frozen=True)
class DjangoHttpApiConfig:
    """Connection settings for the iotsploit_django HTTP API."""

    base_url: str  # normalized by the repository: stripped, no trailing "/"
    timeout_s: float = 5.0  # per-request timeout (seconds) passed to requests
    bearer_token: str | None = None  # sent as "Authorization: Bearer <token>" when set


class HttpPluginGroupRepository(PluginGroupRepository):
    """PluginGroupRepository over iotsploit_django HTTP API (SSOT).

    Endpoints:
    - GET /api/plugin_groups/enabled/
    - GET /api/plugin_groups/<name>/   (<name> is percent-encoded as one path segment)

    Transport and payload failures are raised as ``RuntimeError``.
    """

    def __init__(self, *, base_url: str, timeout_s: float = 5.0, bearer_token: str | None = None) -> None:
        base_url = (base_url or "").strip().rstrip("/")
        if not base_url:
            raise ValueError("base_url is required for HttpPluginGroupRepository")
        self._cfg = DjangoHttpApiConfig(base_url=base_url, timeout_s=float(timeout_s), bearer_token=bearer_token)
        self._session = requests.Session()

    @staticmethod
    def _timeout_from_env(default: float = 5.0) -> float:
        """Read IOTSPLOIT_DJANGO_API_TIMEOUT_S, treating unset or blank values as *default*.

        Raises:
            ValueError: if the variable holds a non-numeric value.
        """
        raw = (os.getenv("IOTSPLOIT_DJANGO_API_TIMEOUT_S") or "").strip()
        return float(raw) if raw else default

    @staticmethod
    def from_env() -> "HttpPluginGroupRepository":
        """Build a repository from the IOTSPLOIT_DJANGO_API_* environment variables."""
        # `or` instead of a getenv default: variables exported as "" fall back to defaults.
        base_url = os.getenv("IOTSPLOIT_DJANGO_API_BASE_URL") or "http://127.0.0.1:8888"
        timeout_s = HttpPluginGroupRepository._timeout_from_env()
        bearer_token = os.getenv("IOTSPLOIT_DJANGO_API_TOKEN") or None
        return HttpPluginGroupRepository(base_url=base_url, timeout_s=timeout_s, bearer_token=bearer_token)

    def _headers(self) -> dict[str, str]:
        """Default request headers, including the bearer token when configured."""
        headers: dict[str, str] = {"Accept": "application/json"}
        if self._cfg.bearer_token:
            headers["Authorization"] = f"Bearer {self._cfg.bearer_token}"
        return headers

    def _url(self, path: str) -> str:
        """Join *path* onto the base URL, adding the leading "/" if missing."""
        path = path if path.startswith("/") else f"/{path}"
        return f"{self._cfg.base_url}{path}"

    @staticmethod
    def _raise_for_bad_response(resp: requests.Response, *, context: str) -> None:
        """Raise RuntimeError (with a truncated body) for any non-2xx response."""
        if 200 <= resp.status_code < 300:
            return
        raise RuntimeError(f"{context} failed: HTTP {resp.status_code}: {resp.text[:500]}")

    @staticmethod
    def _json_object(resp: requests.Response, *, context: str) -> dict[str, Any]:
        """Decode *resp* as a JSON object whose ``status`` is absent or ``"success"``.

        Raises:
            RuntimeError: on a non-JSON body, a non-object body, or an error status.
        """
        try:
            payload = resp.json()
        except ValueError as exc:  # requests' JSONDecodeError subclasses ValueError
            raise RuntimeError(f"{context} returned a non-JSON body: {resp.text[:500]}") from exc
        if not isinstance(payload, dict):
            raise RuntimeError(f"{context} returned unexpected payload: {payload!r}")
        if payload.get("status") not in (None, "success"):
            raise RuntimeError(f"{context} returned error: {payload!r}")
        return payload

    @staticmethod
    def _to_spec(payload: dict[str, Any]) -> PluginGroupSpec:
        """Convert one group object from the API into a PluginGroupSpec.

        Non-object step entries are skipped. Missing fields get the defaults below:
        sequence 100, flags False, enabled True.
        """
        plugin_steps: list[PluginStepSpec] = []
        for s in payload.get("plugin_steps") or []:
            if not isinstance(s, dict):
                continue
            plugin_steps.append(
                PluginStepSpec(
                    sequence=int(s.get("sequence", 100)),
                    plugin_name=str(s.get("plugin_name") or ""),
                    ignore_fail=bool(s.get("ignore_fail", False)),
                )
            )

        group_steps: list[GroupStepSpec] = []
        for s in payload.get("group_steps") or []:
            if not isinstance(s, dict):
                continue
            group_steps.append(
                GroupStepSpec(
                    sequence=int(s.get("sequence", 100)),
                    group_name=str(s.get("group_name") or ""),
                    ignore_fail=bool(s.get("ignore_fail", False)),
                    force_exec=bool(s.get("force_exec", False)),
                )
            )

        return PluginGroupSpec(
            name=str(payload.get("name") or ""),
            enabled=bool(payload.get("enabled", True)),
            plugin_steps=plugin_steps,
            group_steps=group_steps,
        )

    def list_enabled_groups(self) -> list[PluginGroupSpec]:
        """Return all enabled groups; entries without a name are dropped."""
        context = "GET /api/plugin_groups/enabled/"
        resp = self._session.get(
            self._url("/api/plugin_groups/enabled/"),
            headers=self._headers(),
            timeout=self._cfg.timeout_s,
        )
        self._raise_for_bad_response(resp, context=context)
        payload = self._json_object(resp, context=context)

        items = payload.get("groups") or []
        if not isinstance(items, list):
            raise RuntimeError(f"Unexpected groups payload: {items!r}")
        out: list[PluginGroupSpec] = []
        for it in items:
            if not isinstance(it, dict):
                continue
            spec = self._to_spec(it)
            if spec.name:
                out.append(spec)
        return out

    def get_group(self, name: str) -> PluginGroupSpec | None:
        """Fetch one group by name.

        Returns None for an empty name, an HTTP 404, or a payload without a usable group.
        """
        from urllib.parse import quote

        if not name:
            return None
        # Encode the name as a single path segment so "/", "?", "#" or spaces in a
        # group name cannot change which URL is requested.
        path = f"/api/plugin_groups/{quote(name, safe='')}/"
        context = f"GET /api/plugin_groups/{name}/"
        resp = self._session.get(
            self._url(path),
            headers=self._headers(),
            timeout=self._cfg.timeout_s,
        )
        if resp.status_code == 404:
            return None
        self._raise_for_bad_response(resp, context=context)
        payload = self._json_object(resp, context=context)
        g = payload.get("group")
        if not isinstance(g, dict):
            return None
        spec = self._to_spec(g)
        return spec if spec.name else None
134
+
135
+
@@ -0,0 +1,117 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from dataclasses import dataclass
5
+ from typing import Any
6
+
7
+ import requests
8
+
9
+ from iotsploit_core.domain.plugin import PluginMeta
10
+ from iotsploit_core.ports.plugin_repo import PluginMetaRepository
11
+
12
+
13
@dataclass(frozen=True)
class DjangoHttpApiConfig:
    """Connection settings for the iotsploit_django HTTP API."""

    base_url: str  # normalized by the repository: stripped, no trailing "/"
    timeout_s: float = 5.0  # per-request timeout (seconds) passed to requests
    bearer_token: str | None = None  # sent as "Authorization: Bearer <token>" when set


class HttpPluginMetaRepository(PluginMetaRepository):
    """PluginMetaRepository over iotsploit_django HTTP API (SSOT).

    Endpoints:
    - GET  /api/plugins/exploits/enabled/
    - POST /api/plugins/exploits/discovered/  (optional discovery report; must NOT override enabled/disabled)

    Transport and payload failures are raised as ``RuntimeError``.
    """

    def __init__(self, *, base_url: str, timeout_s: float = 5.0, bearer_token: str | None = None) -> None:
        base_url = (base_url or "").strip().rstrip("/")
        if not base_url:
            raise ValueError("base_url is required for HttpPluginMetaRepository")
        self._cfg = DjangoHttpApiConfig(base_url=base_url, timeout_s=float(timeout_s), bearer_token=bearer_token)
        self._session = requests.Session()

    @staticmethod
    def _timeout_from_env(default: float = 5.0) -> float:
        """Read IOTSPLOIT_DJANGO_API_TIMEOUT_S, treating unset or blank values as *default*.

        Raises:
            ValueError: if the variable holds a non-numeric value.
        """
        raw = (os.getenv("IOTSPLOIT_DJANGO_API_TIMEOUT_S") or "").strip()
        return float(raw) if raw else default

    @staticmethod
    def from_env() -> "HttpPluginMetaRepository":
        """Build a repository from the IOTSPLOIT_DJANGO_API_* environment variables."""
        # `or` instead of a getenv default: variables exported as "" fall back to defaults.
        base_url = os.getenv("IOTSPLOIT_DJANGO_API_BASE_URL") or "http://127.0.0.1:8888"
        timeout_s = HttpPluginMetaRepository._timeout_from_env()
        bearer_token = os.getenv("IOTSPLOIT_DJANGO_API_TOKEN") or None
        return HttpPluginMetaRepository(base_url=base_url, timeout_s=timeout_s, bearer_token=bearer_token)

    def _headers(self) -> dict[str, str]:
        """Default request headers, including the bearer token when configured."""
        headers: dict[str, str] = {"Accept": "application/json"}
        if self._cfg.bearer_token:
            headers["Authorization"] = f"Bearer {self._cfg.bearer_token}"
        return headers

    def _url(self, path: str) -> str:
        """Join *path* onto the base URL, adding the leading "/" if missing."""
        path = path if path.startswith("/") else f"/{path}"
        return f"{self._cfg.base_url}{path}"

    @staticmethod
    def _raise_for_bad_response(resp: requests.Response, *, context: str) -> None:
        """Raise RuntimeError (with a truncated body) for any non-2xx response."""
        if 200 <= resp.status_code < 300:
            return
        raise RuntimeError(f"{context} failed: HTTP {resp.status_code}: {resp.text[:500]}")

    @staticmethod
    def _json_object(resp: requests.Response, *, context: str) -> dict[str, Any]:
        """Decode *resp* as a JSON object whose ``status`` is absent or ``"success"``.

        Raises:
            RuntimeError: on a non-JSON body, a non-object body, or an error status.
        """
        try:
            payload = resp.json()
        except ValueError as exc:  # requests' JSONDecodeError subclasses ValueError
            raise RuntimeError(f"{context} returned a non-JSON body: {resp.text[:500]}") from exc
        if not isinstance(payload, dict):
            raise RuntimeError(f"{context} returned unexpected payload: {payload!r}")
        if payload.get("status") not in (None, "success"):
            raise RuntimeError(f"{context} returned error: {payload!r}")
        return payload

    def upsert(self, meta: PluginMeta) -> None:
        """Report one locally discovered plugin to Django.

        This is a discovery report only. ``meta.enabled`` is deliberately not sent, so it
        cannot act as an SSOT write for enabled/disabled state.
        """
        body = {
            "plugins": [
                {
                    "name": meta.name,
                    "module_path": meta.module_path,
                    "description": meta.description,
                    "author": meta.author,
                    "license": meta.license,
                    "parameters": meta.parameters or {},
                }
            ]
        }
        resp = self._session.post(
            self._url("/api/plugins/exploits/discovered/"),
            headers={**self._headers(), "Content-Type": "application/json"},
            json=body,
            timeout=self._cfg.timeout_s,
        )
        self._raise_for_bad_response(resp, context="POST /api/plugins/exploits/discovered/")

    def list_enabled(self) -> list[PluginMeta]:
        """Return enabled plugins; entries missing a name or module_path are dropped."""
        context = "GET /api/plugins/exploits/enabled/"
        resp = self._session.get(
            self._url("/api/plugins/exploits/enabled/"),
            headers=self._headers(),
            timeout=self._cfg.timeout_s,
        )
        self._raise_for_bad_response(resp, context=context)
        payload = self._json_object(resp, context=context)

        items = payload.get("plugins") or []
        if not isinstance(items, list):
            raise RuntimeError(f"Unexpected plugins payload: {items!r}")

        metas: list[PluginMeta] = []
        for it in items:
            if not isinstance(it, dict):
                continue
            params = it.get("parameters")
            metas.append(
                PluginMeta(
                    name=str(it.get("name") or ""),
                    module_path=str(it.get("module_path") or ""),
                    enabled=bool(it.get("enabled", True)),
                    description=str(it.get("description") or ""),
                    author=str(it.get("author") or ""),
                    license=str(it.get("license") or ""),
                    parameters=params if isinstance(params, dict) else None,
                )
            )
        return [m for m in metas if m.name and m.module_path]

    def disable_missing(self, names: set[str]) -> int:
        """Intentionally a no-op that always returns 0 disabled plugins.

        SSOT note: do NOT disable plugins globally based on a single MCP node's
        filesystem. In multi-node deployments, a plugin may exist on another node.
        """
        return 0
116
+
117
+
@@ -0,0 +1,31 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ from iotsploit_core.ports.task_runner import TaskRunner
6
+
7
+
8
class LocalTaskRunner(TaskRunner):
    """Minimal TaskRunner that satisfies ExploitPluginManager's dependency inside the MCP process.

    The MCP runtime mainly goes through ``execute_plugin_async`` (awaited directly) and
    does not depend on Celery. ``submit`` exists only to satisfy the
    ``ExploitPluginManager`` constructor signature and its expected return shape. It
    schedules nothing; it just echoes the request back, tagged as in-process.
    """

    def submit(
        self,
        plugin_name: str,
        target: dict | None,
        parameters: dict,
        *,
        context: dict | None = None,
    ) -> dict[str, Any]:
        # Key order matches what callers have always received.
        return dict(
            execution_type="in_process",
            plugin_name=plugin_name,
            target=target,
            parameters=parameters,
            context=context,
        )
30
+
31
+
@@ -0,0 +1,55 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import asyncio
5
+ import sys
6
+
7
+
8
+ def _build_parser() -> argparse.ArgumentParser:
9
+ p = argparse.ArgumentParser(prog="iotsploit-mcp")
10
+ sub = p.add_subparsers(dest="cmd")
11
+
12
+ ws = sub.add_parser("ws", help="Start WebSocket bridge (default)")
13
+ ws.add_argument("--host", default="0.0.0.0")
14
+ ws.add_argument("--port", type=int, default=9998)
15
+
16
+ sub.add_parser("stdio", help="Start FastMCP stdio server (usually started by bridge)")
17
+ return p
18
+
19
+
20
async def _run_ws(host: str, port: int) -> None:
    """Serve the WebSocket bridge on *host*:*port*, awaiting its ``start_server`` coroutine."""
    # Imported lazily so that `iotsploit-mcp stdio` never loads the bridge module.
    from iotsploit_mcp.websocket_bridge_simple import SATMCPWebSocketBridge as Bridge

    server = Bridge(host=host, port=port)
    await server.start_server()
25
+
26
+
27
async def _run_stdio() -> None:
    """Run the FastMCP stdio server inside the current event loop."""
    from nest_asyncio import apply as allow_nested_loops

    # Patch the loop before the server module is imported (same order as before).
    allow_nested_loops()

    from iotsploit_mcp.sat_fastmcp import run_stdio_async as serve_stdio

    await serve_stdio()
35
+
36
+
37
def main(argv: list[str] | None = None) -> None:
    """Console entry point for ``iotsploit-mcp``.

    Args:
        argv: Arguments excluding the program name; defaults to ``sys.argv[1:]``.

    Without a subcommand, the WebSocket bridge (``ws``) is started with its default
    host/port. ``stdio`` starts the FastMCP stdio server. Either call blocks in
    ``asyncio.run`` until the server coroutine returns.
    """
    argv = list(sys.argv[1:] if argv is None else argv)
    p = _build_parser()
    args = p.parse_args(argv)

    cmd = args.cmd or "ws"  # default: ws
    if cmd == "ws":
        # With no subcommand the `ws` subparser never runs, so host/port may be absent
        # from the namespace; fall back to the same defaults the `ws` subcommand uses.
        host = getattr(args, "host", "0.0.0.0")
        port = getattr(args, "port", 9998)
        asyncio.run(_run_ws(host=host, port=port))
        return
    if cmd == "stdio":
        asyncio.run(_run_stdio())
        return

    # argparse already rejects unknown subcommands; this guards against a subcommand
    # being registered in the parser without a handler here.
    p.error(f"Unknown cmd: {cmd}")


if __name__ == "__main__":
    main()
55
+
@@ -0,0 +1,78 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
+ from iotsploit_core.core.device_manager import DeviceDriverManager
8
+ from iotsploit_core.core.exploit_manager import ExploitPluginManager
9
+
10
+ from iotsploit_mcp.adapters.http_driver_state_repo import HttpDriverStateRepository
11
+ from iotsploit_mcp.adapters.http_plugin_group_repo import HttpPluginGroupRepository
12
+ from iotsploit_mcp.adapters.http_plugin_meta_repo import HttpPluginMetaRepository
13
+ from iotsploit_mcp.adapters.task_runner_local import LocalTaskRunner
14
+
15
+
16
+ def _default_exploit_plugins_dir() -> Path:
17
+ env = os.getenv("IOTSPLOIT_EXPLOIT_PLUGINS_DIR") or os.getenv("SAT_EXPLOIT_PLUGINS_DIR")
18
+ if env:
19
+ return Path(env)
20
+ return Path.cwd() / "plugins" / "exploits"
21
+
22
+
23
+ def build_device_manager(
24
+ *,
25
+ plugins_dir: str | Path | None = None,
26
+ usb_config_file: str | Path | None = None,
27
+ django_api_base_url: Optional[str] = None,
28
+ ) -> DeviceDriverManager:
29
+ """Build a framework-free DeviceDriverManager for MCP runtime.
30
+
31
+ Driver states are sourced from iotsploit_django HTTP API (single source of truth).
32
+ """
33
+
34
+ repo = (
35
+ HttpDriverStateRepository.from_env()
36
+ if django_api_base_url is None
37
+ else HttpDriverStateRepository(base_url=django_api_base_url)
38
+ )
39
+
40
+ # Fail fast so we don't start with wrong/unknown driver states.
41
+ _ = repo.list_enabled()
42
+
43
+ return DeviceDriverManager(
44
+ driver_state_repo=repo,
45
+ plugins_dir=plugins_dir,
46
+ usb_config_file=usb_config_file,
47
+ )
48
+
49
+
50
+ def build_exploit_manager(*, django_api_base_url: Optional[str] = None, plugins_dir: str | Path | None = None) -> ExploitPluginManager:
51
+ """Build ExploitPluginManager for MCP runtime.
52
+
53
+ SSOT mode: enabled plugins and group specs come from Django over HTTP.
54
+ Discovery reporting (upsert) is allowed but MUST NOT override enabled/disabled decisions.
55
+ """
56
+ plugin_repo = (
57
+ HttpPluginMetaRepository.from_env()
58
+ if django_api_base_url is None
59
+ else HttpPluginMetaRepository(base_url=django_api_base_url)
60
+ )
61
+ group_repo = (
62
+ HttpPluginGroupRepository.from_env()
63
+ if django_api_base_url is None
64
+ else HttpPluginGroupRepository(base_url=django_api_base_url)
65
+ )
66
+
67
+ # Fail fast: ensure Django API reachable & returns a well-formed payload.
68
+ _ = plugin_repo.list_enabled()
69
+ _ = group_repo.list_enabled_groups()
70
+
71
+ return ExploitPluginManager(
72
+ plugin_repo=plugin_repo,
73
+ group_repo=group_repo,
74
+ task_runner=LocalTaskRunner(),
75
+ plugins_dir=Path(plugins_dir) if plugins_dir is not None else _default_exploit_plugins_dir(),
76
+ )
77
+
78
+
@@ -0,0 +1,258 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ SAT FastMCP Server (stdio)
4
+
5
+ Execution plane: MCP tools call iotsploit-core usecases. All enabled/disabled states
6
+ are sourced from Django SSOT over HTTP APIs.
7
+ """
8
+
9
+ import asyncio
10
+ import json
11
+ import os
12
+
13
+ from mcp.server.fastmcp import FastMCP
14
+
15
+ from iotsploit_mcp.tools.xlogger_mcp import xlog_mcp
16
+
17
+ # sat_fastmcp runs as a subprocess; allow optional file logging.
18
logger = xlog_mcp.get_logger(
    "sat_fastmcp",
    # No explicit to_file/file_path here: file logging stays opt-in.
    # enable via env IOTSPLOIT_MCP_LOG_TO_FILE=1 (and optionally IOTSPLOIT_MCP_LOG_FILE / IOTSPLOIT_MCP_LOG_DIR)
)


# FastMCP server instance; the @mcp.tool() decorators in this module register tools on it.
mcp = FastMCP("sat-toolkit")


# Lazy-loaded managers (initialized on first use to avoid startup order issues).
# Both stay None until _ensure_managers_initialized() builds them via the Django HTTP API.
_device_manager = None
_exploit_manager = None
# True once initialization has succeeded; while False, each tool call retries initialization.
_init_attempted = False
31
+
32
+
33
def _ensure_managers_initialized() -> tuple:
    """Lazily initialize device/exploit managers on first use.

    The two managers are built independently. A failure in one (e.g. an unreachable
    Django endpoint) does not prevent the other from being built, and a manager that was
    already built is not rebuilt on retry. Initialization is retried on each call until
    both managers exist.

    Returns (device_manager, exploit_manager). Either may be None if Django API is unavailable.
    """
    global _device_manager, _exploit_manager, _init_attempted

    if _init_attempted:
        return _device_manager, _exploit_manager

    try:
        from iotsploit_mcp.composition_root import build_device_manager, build_exploit_manager
    except Exception as e:
        logger.error("SAT MCP components failed to initialize: %s", e)
        logger.info("Will retry on next tool call...")
        return _device_manager, _exploit_manager

    if _device_manager is None:
        try:
            _device_manager = build_device_manager()
        except Exception as e:
            logger.error("SAT MCP device manager failed to initialize: %s", e)

    if _exploit_manager is None:
        try:
            _exploit_manager = build_exploit_manager()
        except Exception as e:
            logger.error("SAT MCP exploit manager failed to initialize: %s", e)

    if _device_manager is not None and _exploit_manager is not None:
        _init_attempted = True
        logger.info("SAT MCP components initialized (SSOT via Django HTTP API)")
    else:
        logger.info("Will retry on next tool call...")

    return _device_manager, _exploit_manager
56
+
57
+
58
def get_device_manager():
    """Return the lazily built device manager, or None if it is not available yet."""
    return _ensure_managers_initialized()[0]
62
+
63
+
64
def get_exploit_manager():
    """Return the lazily built exploit manager, or None if it is not available yet."""
    return _ensure_managers_initialized()[1]
68
+
69
+
70
def _device_line(device, bullet: str) -> str:
    """Render one scanned device as ``"<bullet> <name>: <device_type>"`` ("Unknown" if missing)."""
    name = getattr(device, "name", "Unknown")
    kind = getattr(device, "device_type", "Unknown")
    return f"{bullet} {name}: {kind}"


def _summarize_driver_scan(manager, driver) -> str:
    """Scan with a single driver and render the outcome as one text section.

    Any exception from the scan is folded into the returned text rather than propagated.
    """
    try:
        outcome = manager.scan_devices(driver)
        if outcome.get("status") != "success":
            return f"Driver '{driver}': {outcome.get('message', 'Scan failed')}"
        lines = [_device_line(dev, " -") for dev in outcome.get("devices", [])]
        if not lines:
            return f"Driver '{driver}': No devices found"
        return f"Driver '{driver}':\n" + "\n".join(lines)
    except Exception as exc:
        return f"Driver '{driver}': Error - {str(exc)}"


@mcp.tool()
async def scan_devices(driver_name: str = "all") -> str:
    """Scan for available devices.

    With ``driver_name="all"``, every enabled driver is scanned and each gets its own
    text section. Otherwise only the named driver is used. Always returns text; errors
    are reported in the returned string.
    """
    try:
        manager = get_device_manager()
        if not manager:
            return "Device manager not available. Ensure Django API is running at IOTSPLOIT_DJANGO_API_BASE_URL"

        logger.info("Scanning devices (driver: %s)", driver_name)

        if driver_name != "all":
            outcome = manager.scan_devices(driver_name)
            if outcome.get("status") != "success":
                return f"Scan failed: {outcome.get('message', 'Unknown error')}"
            found = outcome.get("devices", [])
            if not found:
                return f"No devices found using driver '{driver_name}'"
            return f"Found {len(found)} devices:\n" + "\n".join(_device_line(dev, "-") for dev in found)

        enabled = [drv for drv in manager.list_drivers() if manager.is_driver_enabled(drv)]
        if not enabled:
            return "No enabled device drivers found"

        sections = [_summarize_driver_scan(manager, drv) for drv in enabled]
        return "\n\n".join(sections) if sections else "No devices found"

    except Exception as e:
        logger.error("Error scanning devices: %s", e)
        return f"Error scanning devices: {str(e)}"
121
+
122
+
123
@mcp.tool()
async def get_system_status() -> str:
    """Get overall system status as pretty-printed JSON.

    Always includes manager availability. Driver counts are added when the device manager
    exists, and the plugin count when the exploit manager exists.
    """
    try:
        logger.info("Getting system status")
        dm = get_device_manager()
        em = get_exploit_manager()
        snapshot = dict(
            # Event-loop clock (monotonic), not wall-clock epoch seconds.
            timestamp=asyncio.get_running_loop().time(),
            device_manager_available=dm is not None,
            exploit_manager_available=em is not None,
        )
        if dm:
            snapshot["enabled_drivers"] = sum(1 for drv in dm.list_drivers() if dm.is_driver_enabled(drv))
            snapshot["total_drivers"] = len(dm.list_drivers())
        if em:
            snapshot["enabled_exploit_plugins"] = len(em.list_plugins())
        return json.dumps(snapshot, indent=2)
    except Exception as e:
        logger.error("Error getting system status: %s", e)
        return f"Error getting system status: {str(e)}"
147
+
148
+
149
@mcp.tool()
async def read_serial_port(
    port: str = "/dev/ttyUSB0",
    baudrate: int = 115200,
    timeout: int = 300,
    auto_interact: bool = True,
) -> str:
    """Read and analyze serial port output using exploit plugin system (SSOT mode).

    Runs the exploit plugin named by IOTSPLOIT_SERIAL_READER_PLUGIN_NAME (default
    "Picocom Serial Reader"). Returns a JSON string with ``success`` and either an
    analysis summary or an error description. Exceptions are reported in that JSON,
    never raised.
    """
    try:
        exploit_manager = get_exploit_manager()
        if not exploit_manager:
            return json.dumps({"success": False, "error": "Exploit manager not available. Ensure Django API is running."}, indent=2)

        parameters = {
            "port": port,
            "baudrate": baudrate,
            "timeout": timeout,
            "auto_interact": auto_interact,
            "analyze_output": True,
        }

        plugin_name = os.getenv("IOTSPLOIT_SERIAL_READER_PLUGIN_NAME", "Picocom Serial Reader")
        logger.info("Executing exploit plugin '%s' for serial port %s", plugin_name, port)
        result = await exploit_manager.execute_plugin_async(plugin_name, target=None, parameters=parameters)

        # Plugin-provided values may not be JSON-native; stringify them rather than
        # letting json.dumps fail and hide the plugin's own outcome.
        if result is None:
            return json.dumps(
                {
                    "success": False,
                    "error": f"Plugin '{plugin_name}' not found or disabled in Django SSOT",
                    "available_enabled_plugins": exploit_manager.list_plugins(),
                },
                indent=2,
                default=str,
            )

        if getattr(result, "status", False):
            # Tolerate missing/None/non-dict `data` and `analysis` from the plugin
            # instead of crashing on `.get` after a successful run.
            data = getattr(result, "data", None)
            data = data if isinstance(data, dict) else {}
            analysis_data = data.get("analysis")
            analysis_data = analysis_data if isinstance(analysis_data, dict) else {}
            report = data.get("report", "No report generated")
            return json.dumps(
                {
                    "success": True,
                    "message": getattr(result, "message", ""),
                    "analysis": {
                        "device_type": analysis_data.get("device_type", "unknown"),
                        "confidence": analysis_data.get("confidence", 0.0),
                        "login_detected": analysis_data.get("login_detected", False),
                        "shell_type": analysis_data.get("shell_type", "unknown"),
                        "output_lines": analysis_data.get("output_lines_count", 0),
                        "detected_patterns": analysis_data.get("detected_patterns", []),
                    },
                    "report": report,
                    "sample_output": analysis_data.get("raw_output_sample", []),
                },
                indent=2,
                default=str,
            )

        return json.dumps(
            {
                "success": False,
                "error": getattr(result, "message", "Serial reading failed"),
                "data": getattr(result, "data", None),
            },
            indent=2,
            default=str,
        )

    except Exception as e:
        logger.error("Error reading serial port: %s", e)
        return json.dumps({"success": False, "error": f"Error reading serial port: {str(e)}"}, indent=2)
217
+
218
+
219
@mcp.tool()
async def list_serial_ports() -> str:
    """List available serial ports on the system.

    Returns a JSON string ``{"success": true, "ports": [...]}``. If no ports are found
    or listing fails, a plain-text message is returned instead.
    """
    try:
        logger.info("Listing serial ports")
        import serial.tools.list_ports

        port_list = []
        for info in serial.tools.list_ports.comports():
            port_list.append(
                {
                    "device": info.device,
                    "description": info.description,
                    # pyserial's ListPortInfo defines these attributes but sets them to None
                    # when unknown, so a getattr default alone never yields "Unknown".
                    "manufacturer": getattr(info, "manufacturer", None) or "Unknown",
                    "product": getattr(info, "product", None) or "Unknown",
                    "vid": getattr(info, "vid", None),
                    "pid": getattr(info, "pid", None),
                }
            )

        if not port_list:
            return "No serial ports found on the system"
        return json.dumps({"success": True, "ports": port_list}, indent=2)
    except Exception as e:
        logger.error("Error listing serial ports: %s", e)
        return f"Error listing serial ports: {str(e)}"
246
+
247
+
248
async def run_stdio_async() -> None:
    """Serve this module's registered MCP tools over stdio.

    Awaited by the CLI's ``stdio`` command and by the bridge.
    """
    logger.info("Starting SAT FastMCP Server (stdio)")
    server = mcp
    await server.run_stdio_async()


if __name__ == "__main__":
    # Direct execution: patch asyncio with nest_asyncio first, then serve over stdio.
    import nest_asyncio

    nest_asyncio.apply()
    asyncio.run(run_stdio_async())
@@ -0,0 +1,5 @@
1
+ """
2
+ iotsploit_mcp.tools
3
+ """
4
+
5
+
@@ -0,0 +1,129 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import os
5
+ import sys
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import Dict, Optional
9
+
10
+
11
@dataclass(frozen=True)
class _MCPLogConfig:
    """Immutable logging defaults for MCP loggers.

    Field order is positional-constructor order: fmt, datefmt, default_level,
    default_log_dir.
    """

    # Record layout: "<time> | <LEVEL> | <logger name> | <message>".
    fmt: str = (
        "%(asctime)s | %(levelname)s | %(name)s | %(message)s"
    )
    # strftime pattern used for %(asctime)s.
    datefmt: str = "%Y-%m-%d %H:%M:%S"
    # Fallback level when the caller does not supply one.
    default_level: int = logging.INFO
    # Directory for per-logger log files when no path/dir override is given.
    default_log_dir: str = "/tmp/sat_logs"
17
+
18
+
19
def _env_bool(name: str, default: bool = False) -> bool:
    """Interpret environment variable *name* as a boolean flag.

    Returns *default* when the variable is unset. Otherwise returns True only
    if the trimmed, lower-cased value is one of 1/true/yes/y/on; any other
    value (including an empty string) yields False.
    """
    raw = os.getenv(name)
    if raw is not None:
        return raw.strip().lower() in ("1", "true", "yes", "y", "on")
    return default
24
+
25
+
26
def _env_level(default: int) -> int:
    """Resolve the MCP log level from ``IOTSPLOIT_MCP_LOG_LEVEL``.

    Accepts a level name (case-insensitive, surrounding whitespace ignored,
    e.g. ``debug`` or ``WARNING``) or a non-negative decimal integer
    (e.g. ``15``).

    Args:
        default: level returned when the variable is unset, blank, or not a
            recognised level.

    Returns:
        An integer logging level.
    """
    raw = (os.getenv("IOTSPLOIT_MCP_LOG_LEVEL") or "").strip()
    if not raw:
        return default
    if raw.isdecimal():
        return int(raw)
    # getattr() alone can also return non-level module attributes such as
    # logging.BASIC_FORMAT (a str), which Logger.setLevel() would reject.
    candidate = getattr(logging, raw.upper(), None)
    if isinstance(candidate, int) and not isinstance(candidate, bool):
        return candidate
    return default
31
+
32
+
33
def _safe_filename(name: str) -> str:
    """Map *name* to a deterministic, filesystem-safe file stem.

    Alphanumeric characters (per ``str.isalnum``), ``-`` and ``_`` are kept;
    every other character, ``.`` included, becomes ``_``.
    """
    keep = ("-", "_")
    return "".join(ch if (ch.isalnum() or ch in keep) else "_" for ch in name)
36
+
37
+
38
class XLoggerMCP:
    """
    Logger factory dedicated to the MCP component (process-wide singleton).

    - Does not depend on Django/Channels.
    - Never calls ``logging.basicConfig`` (avoids polluting the host process);
      each logger it hands out gets its own stderr handler and
      ``propagate = False``.
    - Default format: ``YYYY-mm-dd HH:MM:SS | LEVEL | name | message``.
    - Optional file output, enabled per call or via environment variables.

    Environment variables:
        IOTSPLOIT_MCP_LOG_LEVEL    default level; read once, when the
                                   singleton is first initialised.
        IOTSPLOIT_MCP_LOG_TO_FILE  truthy value attaches a file handler when
                                   ``get_logger`` is called with ``to_file=None``.
        IOTSPLOIT_MCP_LOG_FILE     one log file shared by every logger.
        IOTSPLOIT_MCP_LOG_DIR      directory for per-logger ``<name>.log`` files.
    """

    # The single shared instance; every XLoggerMCP() call returns it.
    _instance: Optional["XLoggerMCP"] = None

    def __new__(cls) -> "XLoggerMCP":
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        # __init__ runs on every XLoggerMCP() call; configure only the first time.
        if getattr(self, "_initialized", False):
            return
        self._initialized = True
        self._cfg = _MCPLogConfig(default_level=_env_level(_MCPLogConfig.default_level))
        # Loggers already configured by get_logger(), keyed by logger name.
        self._loggers: Dict[str, logging.Logger] = {}

    def get_logger(
        self,
        name: str = "iotsploit_mcp",
        *,
        level: Optional[int] = None,
        to_file: Optional[bool] = None,
        file_path: Optional[str] = None,
    ) -> logging.Logger:
        """
        Return ``logging.getLogger(name)`` with the MCP handlers attached.

        Args:
            name: logger name.
            level: level applied to the logger and its stderr handler. If
                omitted, a logger configured here for the first time gets the
                default level, while a logger configured earlier keeps its
                current level. This means wrappers such as ``info()`` (which call
                ``get_logger(name)``) do not undo an explicit ``level=`` given
                before.
            to_file: attach a file handler (at DEBUG level). ``None`` defers to
                ``IOTSPLOIT_MCP_LOG_TO_FILE``.
            file_path: log file path. When not given, falls back to
                ``IOTSPLOIT_MCP_LOG_FILE``, and then to
                ``<IOTSPLOIT_MCP_LOG_DIR or default dir>/<sanitised name>.log``.

        Returns:
            The configured logger; its ``propagate`` flag is set to False.
        """
        first_time = name not in self._loggers
        logger = self._loggers.get(name) or logging.getLogger(name)

        if level is not None or first_time:
            lvl = self._cfg.default_level if level is None else level
            logger.setLevel(lvl)
            # Keep already-attached stderr handlers in step. Otherwise a later,
            # more verbose level would still be filtered by the handler's old level.
            for handler in logger.handlers:
                if getattr(handler, "_mcp_stream", False):
                    handler.setLevel(lvl)
        else:
            lvl = logger.level

        want_file = _env_bool("IOTSPLOIT_MCP_LOG_TO_FILE", False) if to_file is None else bool(to_file)

        # Global override: send all MCP logs to a single file if provided.
        env_log_file = (os.getenv("IOTSPLOIT_MCP_LOG_FILE") or "").strip() or None
        env_log_dir = (os.getenv("IOTSPLOIT_MCP_LOG_DIR") or "").strip() or None

        formatter = logging.Formatter(self._cfg.fmt, datefmt=self._cfg.datefmt)

        # Ensure exactly one MCP stream handler (stderr) exists.
        if not any(getattr(h, "_mcp_stream", False) for h in logger.handlers):
            sh = logging.StreamHandler(sys.stderr)
            sh.setLevel(lvl)
            sh.setFormatter(formatter)
            setattr(sh, "_mcp_stream", True)
            logger.addHandler(sh)

        # File handler is optional; one handler per distinct target path.
        if want_file:
            target_file = (
                file_path
                or env_log_file
                or str(Path(env_log_dir or self._cfg.default_log_dir) / f"{_safe_filename(name)}.log")
            )

            if not any(getattr(h, "_mcp_file", None) == target_file for h in logger.handlers):
                Path(target_file).parent.mkdir(parents=True, exist_ok=True)
                # Explicit UTF-8 so non-ASCII messages don't depend on the locale.
                fh = logging.FileHandler(target_file, encoding="utf-8")
                fh.setLevel(logging.DEBUG)
                fh.setFormatter(formatter)
                setattr(fh, "_mcp_file", target_file)
                logger.addHandler(fh)

        # Our handlers already emit the record; don't duplicate via ancestors.
        logger.propagate = False
        self._loggers[name] = logger
        return logger

    # Convenience wrappers (optional, mirrors iotsploit_django.tools.xlogger).
    # Extra keyword arguments (e.g. exc_info) are forwarded to the Logger method.
    def debug(self, msg: str, name: str = "iotsploit_mcp", **kwargs) -> None:
        self.get_logger(name).debug(msg, **kwargs)

    def info(self, msg: str, name: str = "iotsploit_mcp", **kwargs) -> None:
        self.get_logger(name).info(msg, **kwargs)

    def warning(self, msg: str, name: str = "iotsploit_mcp", **kwargs) -> None:
        self.get_logger(name).warning(msg, **kwargs)

    def error(self, msg: str, name: str = "iotsploit_mcp", **kwargs) -> None:
        self.get_logger(name).error(msg, **kwargs)

    def critical(self, msg: str, name: str = "iotsploit_mcp", **kwargs) -> None:
        self.get_logger(name).critical(msg, **kwargs)
124
+
125
+
126
# Module-wide logger factory. XLoggerMCP.__new__ hands back this same object
# on any later XLoggerMCP() call.
xlog_mcp: XLoggerMCP = XLoggerMCP()
128
+
129
+
@@ -0,0 +1,246 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Simplified WebSocket Bridge for SAT FastMCP Server
4
+
5
+ Bridge listens on ws://host:port (default 0.0.0.0:9998) and proxies messages to a
6
+ FastMCP stdio subprocess.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import asyncio
12
+ import json
13
+ import subprocess
14
+ import sys
15
+ from typing import Optional, Set
16
+
17
+ import websockets
18
+
19
+
20
+ from iotsploit_mcp.tools.xlogger_mcp import xlog_mcp
21
+
22
# Bridge logger; xlog_mcp attaches its stderr handler (and optional file handler).
logger = xlog_mcp.get_logger(name="iotsploit_mcp.websocket_bridge_simple")
23
+
24
+
25
class SATMCPWebSocketBridge:
    """Bridge between WebSocket clients and a FastMCP stdio server subprocess.

    Clients send JSON objects over the WebSocket:

        {"type": "mcp_list_tools"}
        {"type": "mcp_call_tool", "tool_name": "...", "arguments": {...}}

    and receive ``tools_list`` / ``tool_result`` / ``tool_error`` / ``error``
    messages back.

    JSON-RPC traffic to the subprocess uses one stdin writer and a single
    stdout reader task that routes every response to the request with the same
    ``id``. Concurrent clients therefore never read each other's responses. A
    request that times out cannot leave its late response in the pipe for the
    next caller. The subprocess's stderr is drained continuously, so its
    logging cannot fill the pipe and stall the server.
    """

    def __init__(self, *, host: str = "0.0.0.0", port: int = 9998):
        from collections import deque

        self.host = host
        self.port = port
        self.mcp_process: Optional[subprocess.Popen] = None
        self.clients: Set[websockets.WebSocketServerProtocol] = set()
        self.mcp_initialized = False
        self.request_id = 0
        # JSON-RPC request id -> asyncio future resolved by the stdout reader.
        self._pending: dict = {}
        self._reader_task: Optional[asyncio.Task] = None
        # Most recent stderr lines of the subprocess, for start-up diagnostics.
        self._stderr_tail = deque(maxlen=200)
        self._stderr_thread = None

    async def start_mcp_server(self) -> bool:
        """Start the FastMCP stdio server process and perform the MCP handshake.

        Returns:
            False if the process could not be spawned or exited right after
            start, True otherwise. A failed handshake is only logged and leaves
            ``mcp_initialized`` False.
        """
        try:
            logger.info("Starting FastMCP server (stdio) via python -m iotsploit_mcp.cli stdio")
            self.mcp_process = subprocess.Popen(
                [sys.executable, "-m", "iotsploit_mcp.cli", "stdio"],
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                # MCP stdio messages are UTF-8; don't depend on the locale encoding.
                encoding="utf-8",
                errors="replace",
                bufsize=0,
            )
            logger.info("FastMCP server started with PID: %s", self.mcp_process.pid)

            # stderr must be drained from now on, or the child eventually blocks on it.
            self._start_stderr_drain(self.mcp_process)
            self._reader_task = asyncio.get_running_loop().create_task(
                self._read_stdout_loop(self.mcp_process)
            )

            # Important: initialize immediately. Some environments may behave as if stdin is EOF
            # unless a first line is promptly written.
            await self.initialize_mcp_session()

            # Give the process a brief moment; if it still exits, surface stderr for debugging.
            await asyncio.sleep(0.1)
            rc = self.mcp_process.poll()
            if rc is not None:
                stderr = await self._stderr_snapshot()
                logger.error(
                    "FastMCP server process terminated immediately (rc=%s). stderr=%r", rc, stderr[-2000:]
                )
                return False

            return True
        except Exception as e:
            logger.error("Failed to start MCP server: %s", e)
            return False

    def _start_stderr_drain(self, proc: subprocess.Popen) -> None:
        """Read *proc*'s stderr on a daemon thread until EOF.

        An unread stderr pipe fills up (the MCP logger writes to stderr) and
        then blocks the subprocess on its next write. Each line is logged at
        DEBUG and kept in ``self._stderr_tail`` (bounded).
        """
        import threading

        stream = proc.stderr
        if stream is None:
            return
        tail = self._stderr_tail

        def _drain() -> None:
            try:
                for line in iter(stream.readline, ""):
                    text = line.rstrip("\n")
                    tail.append(text)
                    logger.debug("[mcp stderr] %s", text)
            except (OSError, ValueError):
                pass  # pipe closed during shutdown; nothing left to drain

        self._stderr_thread = threading.Thread(target=_drain, name="mcp-stderr-drain", daemon=True)
        self._stderr_thread.start()

    async def _stderr_snapshot(self, wait: float = 1.0) -> str:
        """Return the buffered stderr tail, waiting up to *wait* s for the drain to finish."""
        thread = self._stderr_thread
        if thread is not None:
            await asyncio.get_running_loop().run_in_executor(None, thread.join, wait)
        try:
            return "\n".join(self._stderr_tail)
        except RuntimeError:  # deque mutated by a still-running drain thread
            return ""

    async def _read_stdout_loop(self, proc: subprocess.Popen) -> None:
        """Read JSON-RPC lines from *proc*'s stdout until EOF and route them by id.

        This is the only reader of stdout. When it stops (EOF, read error or
        cancellation), the session is marked uninitialised and every pending
        request fails with ``ConnectionError``.
        """
        loop = asyncio.get_running_loop()
        try:
            while True:
                line = await loop.run_in_executor(None, proc.stdout.readline)
                if not line:
                    break  # EOF: the subprocess exited or closed stdout
                self._dispatch_line(line)
        except (OSError, ValueError) as e:
            logger.error("Error reading from MCP server: %s", e)
        finally:
            self.mcp_initialized = False
            for future in list(self._pending.values()):
                if not future.done():
                    future.set_exception(ConnectionError("No response from MCP server"))

    def _dispatch_line(self, line: str) -> None:
        """Resolve the pending request that *line* answers; ignore anything else."""
        text = line.strip()
        if not text:
            return
        try:
            message = json.loads(text)
        except json.JSONDecodeError:
            logger.warning("Ignoring non-JSON output from MCP server: %r", text[:200])
            return
        if not isinstance(message, dict) or "method" in message:
            # Notifications and server-initiated requests are not replies to us.
            logger.debug("Ignoring unsolicited MCP message: %r", text[:200])
            return
        msg_id = message.get("id")
        future = self._pending.get(msg_id) if isinstance(msg_id, str) else None
        if future is None or future.done():
            logger.debug("Dropping MCP response for unknown or expired id %r", msg_id)
            return
        future.set_result(message)

    def _write_line(self, data: dict) -> None:
        """Write *data* as one JSON line to the subprocess's stdin; raises on failure."""
        proc = self.mcp_process
        if proc is None or proc.stdin is None:
            raise RuntimeError("MCP server is not running")
        proc.stdin.write(json.dumps(data) + "\n")
        proc.stdin.flush()

    async def _request(
        self,
        method: str,
        params: dict,
        *,
        timeout: float,
        request_id: Optional[str] = None,
    ) -> dict:
        """Send a JSON-RPC request and wait for the response carrying its id.

        Raises:
            asyncio.TimeoutError: no response within *timeout* seconds.
            ConnectionError: the stdout reader is not running or hit EOF.
            Exception: whatever writing to the subprocess's stdin raised.
        """
        if self._reader_task is None or self._reader_task.done():
            raise ConnectionError("No response from MCP server")
        if request_id is None:
            self.request_id += 1
            request_id = f"req_{self.request_id}"
        future = asyncio.get_running_loop().create_future()
        self._pending[request_id] = future
        try:
            self._write_line({"jsonrpc": "2.0", "method": method, "params": params, "id": request_id})
            return await asyncio.wait_for(future, timeout=timeout)
        finally:
            # Forget the id so a late response is dropped by the reader.
            self._pending.pop(request_id, None)

    async def initialize_mcp_session(self) -> None:
        """Run the MCP ``initialize`` handshake and set ``mcp_initialized`` on success.

        Failures (timeout, EOF, error response) are logged, not raised.
        """
        try:
            logger.info("Initializing MCP session...")
            initialize_params = {
                "protocolVersion": "2024-11-05",
                "capabilities": {"tools": {}},
                "clientInfo": {"name": "SAT-WebSocket-Bridge", "version": "1.0.0"},
            }
            response = await self._request("initialize", initialize_params, timeout=5.0, request_id="init")
            if "result" in response:
                initialized_notification = {"jsonrpc": "2.0", "method": "notifications/initialized", "params": {}}
                await self.send_to_mcp(initialized_notification)
                self.mcp_initialized = True
                logger.info("MCP session initialized successfully")
            else:
                logger.error("MCP initialization failed: %s", response)
        except asyncio.TimeoutError:
            logger.error("No response to MCP initialize request (timed out)")
        except Exception as e:
            logger.error("Failed to initialize MCP session: %s", e)

    async def send_to_mcp(self, data: dict) -> None:
        """Write one JSON-RPC message to the subprocess; errors are logged, not raised."""
        try:
            self._write_line(data)
        except Exception as e:
            logger.error("Error sending to MCP: %s", e)

    async def handle_client(self, websocket):
        """Serve one WebSocket connection until it closes, handling messages in order."""
        client_addr = f"{websocket.remote_address[0]}:{websocket.remote_address[1]}"
        try:
            logger.info("New client connected: %s", client_addr)
            self.clients.add(websocket)
            async for message in websocket:
                try:
                    data = json.loads(message)
                except json.JSONDecodeError:
                    await websocket.send(json.dumps({"type": "error", "error": "Invalid JSON format"}))
                    continue
                if not isinstance(data, dict):
                    # e.g. a JSON array or string: reply instead of letting
                    # data.get() raise and drop the whole connection.
                    await websocket.send(
                        json.dumps({"type": "error", "error": "Invalid message: expected a JSON object"})
                    )
                    continue
                msg_type = data.get("type")
                if msg_type == "mcp_call_tool":
                    await self.handle_tool_call(data, websocket)
                elif msg_type == "mcp_list_tools":
                    await self.handle_list_tools(websocket)
                else:
                    await websocket.send(json.dumps({"type": "error", "error": f"Unknown message type: {msg_type}"}))
        except websockets.exceptions.ConnectionClosed:
            logger.info("Client %s disconnected", client_addr)
        except Exception as e:
            logger.error("Error handling client %s: %s", client_addr, e)
        finally:
            self.clients.discard(websocket)

    async def handle_tool_call(self, data: dict, websocket) -> None:
        """Forward an ``mcp_call_tool`` message as ``tools/call`` and relay the outcome."""
        tool_name = data.get("tool_name")
        try:
            if not self.mcp_initialized:
                await websocket.send(json.dumps({"type": "error", "error": "MCP not initialized"}))
                return

            arguments = data.get("arguments", {})
            if not tool_name:
                await websocket.send(json.dumps({"type": "error", "error": "Missing tool_name"}))
                return

            params = {"name": tool_name}
            if arguments:
                params["arguments"] = arguments

            mcp_response = await self._request("tools/call", params, timeout=30.0)
            if "result" in mcp_response:
                await websocket.send(
                    json.dumps({"type": "tool_result", "tool_name": tool_name, "result": mcp_response["result"]})
                )
            else:
                await websocket.send(
                    json.dumps(
                        {"type": "tool_error", "tool_name": tool_name, "error": mcp_response.get("error", "Unknown error")}
                    )
                )
        except asyncio.TimeoutError:
            await websocket.send(json.dumps({"type": "tool_error", "tool_name": tool_name, "error": "Tool execution timeout"}))
        except Exception as e:
            await websocket.send(json.dumps({"type": "tool_error", "tool_name": tool_name, "error": str(e)}))

    async def handle_list_tools(self, websocket) -> None:
        """Forward an ``mcp_list_tools`` message as ``tools/list`` and relay the tools."""
        try:
            if not self.mcp_initialized:
                await websocket.send(json.dumps({"type": "error", "error": "MCP not initialized"}))
                return

            mcp_response = await self._request("tools/list", {}, timeout=10.0)
            if "result" in mcp_response:
                await websocket.send(json.dumps({"type": "tools_list", "tools": mcp_response["result"].get("tools", [])}))
            else:
                await websocket.send(json.dumps({"type": "error", "error": mcp_response.get("error", "Unknown error")}))
        except asyncio.TimeoutError:
            await websocket.send(json.dumps({"type": "error", "error": "List tools timeout"}))
        except Exception as e:
            await websocket.send(json.dumps({"type": "error", "error": str(e)}))

    async def start_server(self) -> None:
        """Start the MCP subprocess, then serve WebSocket clients until cancelled."""
        logger.info("Starting WebSocket bridge on %s:%s", self.host, self.port)
        if not await self.start_mcp_server():
            logger.error("Failed to start MCP server")
            return
        try:
            async with websockets.serve(self.handle_client, self.host, self.port):
                logger.info("WebSocket bridge running on ws://%s:%s", self.host, self.port)
                await asyncio.Future()  # run until cancelled
        finally:
            await self.cleanup()

    async def cleanup(self) -> None:
        """Stop the MCP subprocess and the stdout reader task.

        Terminates the process, escalates to kill if it is still alive after
        about 1 s, then reaps it. Safe to call repeatedly; later calls only log.
        """
        logger.info("Cleaning up...")
        proc, self.mcp_process = self.mcp_process, None
        self.mcp_initialized = False
        if proc is not None:
            try:
                proc.terminate()
                for _ in range(10):
                    await asyncio.sleep(0.1)
                    if proc.poll() is not None:
                        break
                else:
                    proc.kill()
                # Reap the child so it does not linger as a zombie.
                await asyncio.get_running_loop().run_in_executor(None, proc.wait)
                logger.info("FastMCP server stopped")
            except Exception as e:
                logger.error("Error stopping MCP server: %s", e)
        task, self._reader_task = self._reader_task, None
        if task is not None and not task.done():
            task.cancel()
233
+
234
+
235
async def main() -> None:
    """Run the bridge on its default host/port and always release the subprocess.

    start_server() cleans up on its own normal path. It returns early, without
    cleaning up, when the MCP subprocess fails to start, so cleanup is repeated here.
    """
    server = SATMCPWebSocketBridge()
    run = server.start_server()
    try:
        await run
    finally:
        await server.cleanup()


if __name__ == "__main__":
    asyncio.run(main())
245
+
246
+