sshmd 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sshm/daemon.py ADDED
@@ -0,0 +1,289 @@
1
+ """sshmd — SSH session manager daemon."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import os
7
+ import signal
8
+ import sys
9
+ import threading
10
+ from collections.abc import Callable
11
+ from typing import Any
12
+
13
+ from . import protocol
14
+ from .config import (
15
+ PortForward,
16
+ add_port_forward,
17
+ find_entry,
18
+ load_entries,
19
+ remove_host,
20
+ remove_port_forward,
21
+ rename_host,
22
+ set_enabled,
23
+ )
24
+ from .ipc import IPC_HOST, IPC_PORT, IpcServer, StreamingResponse
25
+ from .process import ProcessManager
26
+ from .state import log_file, new_token, pid_file, write_token
27
+
28
# Shared "sshm" logger; its records propagate to the handlers that main()
# installs on the root logger via logging.basicConfig().
log = logging.getLogger("sshm")

# Seconds the watchdog sleeps between maintenance passes (orphaned-attach and
# health checks, plus keeping enabled aliases supplied with a session).
WATCHDOG_INTERVAL: float = 5.0
31
+
32
+
33
+ def _required(req: dict[str, Any], *keys: str) -> list[Any]:
34
+ values = [req.get(k) for k in keys]
35
+ if not all(values):
36
+ raise ValueError(f"Missing {' or '.join(keys)}")
37
+ return values
38
+
39
+
40
class Daemon:
    """sshmd's core: owns the SSH sessions and answers CLI requests.

    Requests are delivered by an ``IpcServer`` (which runs one worker thread
    per client connection) and dispatched through ``handle_request``. A
    background watchdog thread periodically checks orphaned attaches and
    session health, and keeps an unattached session for every enabled alias.
    """

    def __init__(self) -> None:
        self.pm = ProcessManager()
        self.token = new_token()
        self.server = IpcServer(handler=self.handle_request, token=self.token)
        self._shutdown_event = threading.Event()
        # Protocol command name -> handler taking the decoded request dict.
        self._handlers: dict[str, Callable[[dict[str, Any]], Any]] = {
            protocol.CMD_STATUS: self._cmd_status,
            protocol.CMD_LIST: self._cmd_list,
            protocol.CMD_CONNECT: self._cmd_connect,
            protocol.CMD_ATTACH: self._cmd_attach,
            protocol.CMD_DETACH: self._cmd_detach,
            protocol.CMD_RESIZE: self._cmd_resize,
            protocol.CMD_DISCONNECT: self._cmd_disconnect,
            protocol.CMD_PORT_ADD: self._cmd_port_add,
            protocol.CMD_PORT_REMOVE: self._cmd_port_remove,
            protocol.CMD_ENABLE: self._cmd_enable,
            protocol.CMD_DISABLE: self._cmd_disable,
            protocol.CMD_REMOVE: self._cmd_remove,
            protocol.CMD_RENAME: self._cmd_rename,
            protocol.CMD_SHUTDOWN: self._cmd_shutdown,
        }

    def handle_request(self, req: dict[str, Any]) -> dict[str, Any] | StreamingResponse:
        """Dispatch one authenticated request and always return a response.

        Unknown commands and handler exceptions become ``protocol.err``
        replies. A ``ValueError`` (bad or missing request fields) is reported
        without logging a traceback; any other exception is logged with one.
        """
        cmd = req.get("cmd", "")
        handler = self._handlers.get(cmd)
        if handler is None:
            return protocol.err(f"Unknown command: {cmd}")
        try:
            return handler(req)
        except ValueError as e:
            # Request validation errors — report to the client without traceback noise
            return protocol.err(str(e))
        except Exception as e:
            log.exception("Error handling %s", cmd)
            return protocol.err(str(e))

    # --- Command handlers ---

    def _cmd_status(self, req: dict[str, Any]):
        return protocol.ok({"status": "running", "sessions": len(self.pm.get_sessions())})

    def _cmd_list(self, req: dict[str, Any]):
        """List sessions of one alias, or a summary row per configured host."""
        alias = req.get("alias")
        if alias:
            return protocol.ok([s.to_dict() for s in self.pm.get_sessions(alias)])

        result = []
        for entry in load_entries():
            # Skip wildcard / empty Host entries; they aren't connectable hosts.
            if entry.alias in ("*", ""):
                continue
            active = self.pm.get_sessions(entry.alias)
            result.append({
                "alias": entry.alias,
                "hostname": entry.hostname,
                "user": entry.user,
                "port": entry.port,
                "enabled": entry.enabled,
                "connections": len(active),
                "attached": sum(1 for s in active if s.attached),
                "port_forwards": [pf.to_str() for pf in entry.port_forwards],
            })
        return protocol.ok(result)

    def _cmd_connect(self, req: dict[str, Any]):
        [alias] = _required(req, "alias")
        if not find_entry(alias):
            return protocol.err(f"Unknown alias: {alias}")
        session = self.pm.connect(alias, req.get("name"))
        return protocol.ok(session.to_dict())

    def _cmd_attach(self, req: dict[str, Any]):
        """Reserve a session and return a StreamingResponse bridging it to the client."""
        [alias] = _required(req, "alias")
        if not find_entry(alias):
            return protocol.err(f"Unknown alias: {alias}")
        session = self.pm.attach(alias, req.get("name"), cli_pid=req.get("cli_pid"))
        if not session:
            return protocol.err(f"No available session for '{alias}'")

        # attach() has now RESERVED the session (attached=True). If anything below
        # raises (e.g. a bogus cols/rows), release the reservation so the session
        # isn't leaked out of the attachable pool.
        try:
            # Size the remote terminal to the attaching client before streaming starts
            cols, rows = req.get("cols"), req.get("rows")
            if cols and rows:
                session.set_winsize(int(cols), int(rows))

            # The IPC server bridges this connection after sending the response
            return StreamingResponse(
                protocol.ok(session.to_dict()), session, cli_pid=req.get("cli_pid")
            )
        except Exception:
            session.detach()
            raise

    def _cmd_resize(self, req: dict[str, Any]):
        alias, name, cols, rows = _required(req, "alias", "name", "cols", "rows")
        for s in self.pm.get_sessions(alias):
            if s.name == name:
                size = (int(cols), int(rows))  # ValueError -> reported to client
                s.set_winsize(*size)
                return protocol.ok({"resized": list(size)})
        return protocol.err(f"No session {alias}/{name}")

    def _cmd_detach(self, req: dict[str, Any]):
        alias, name = _required(req, "alias", "name")
        # Detach is mostly handled by the bridge ending; explicit detach works too
        for s in self.pm.get_sessions(alias):
            if s.name == name and s.attached:
                s.detach()
                # For enabled aliases, keep an unattached session ready
                entry = find_entry(alias)
                if entry and entry.enabled:
                    self.pm.ensure_unattached(alias)
        return protocol.ok({"detached": True})

    def _cmd_disconnect(self, req: dict[str, Any]):
        alias, name = _required(req, "alias", "name")
        return protocol.ok({"disconnected": self.pm.disconnect(alias, name)})

    def _cmd_port_add(self, req: dict[str, Any]):
        alias, direction, rule = _required(req, "alias", "direction", "rule")
        if direction == "D":
            pf = PortForward.socks(int(rule))  # SOCKS proxy: rule is just the port
        else:
            pf = PortForward.parse_rule(rule, direction)
        add_port_forward(alias, pf)
        self._rebuild_alias_sessions(alias)
        return protocol.ok({"added": pf.to_str()})

    def _cmd_port_remove(self, req: dict[str, Any]):
        alias, rule = _required(req, "alias", "rule")  # rule is already serialized
        remove_port_forward(alias, rule)
        self._rebuild_alias_sessions(alias)
        return protocol.ok({"removed": rule})

    def _cmd_enable(self, req: dict[str, Any]):
        [alias] = _required(req, "alias")
        set_enabled(alias, True)
        self.pm.ensure_unattached(alias)
        return protocol.ok({"enabled": alias})

    def _cmd_disable(self, req: dict[str, Any]):
        [alias] = _required(req, "alias")
        set_enabled(alias, False)
        return protocol.ok({"disabled": alias})

    def _cmd_remove(self, req: dict[str, Any]):
        [alias] = _required(req, "alias")
        self.pm.disconnect_alias(alias)
        remove_host(alias)
        return protocol.ok({"removed": alias})

    def _cmd_rename(self, req: dict[str, Any]):
        old_alias, new_alias = _required(req, "alias", "new_alias")
        if not find_entry(old_alias):
            return protocol.err(f"Unknown alias: {old_alias}")
        if find_entry(new_alias):
            return protocol.err(f"Alias '{new_alias}' already exists")
        # Sessions are keyed by alias; drop the old ones (an enabled host gets a
        # fresh session under the new alias on the next watchdog tick).
        self.pm.disconnect_alias(old_alias)
        rename_host(old_alias, new_alias)
        return protocol.ok({"renamed": new_alias})

    def _cmd_shutdown(self, req: dict[str, Any]):
        self._shutdown_event.set()
        return protocol.ok({"shutting_down": True})

    # --- Background maintenance ---

    def _rebuild_alias_sessions(self, alias: str) -> None:
        """Restart active sessions so they pick up config changes (e.g. forwards)."""
        # Snapshot the names first: rebuild_session() replaces sessions, and if
        # get_sessions() hands back a live list, iterating it while it changes
        # would skip some sessions and rebuild others twice.
        names = [s.name for s in self.pm.get_sessions(alias)]
        for name in names:
            self.pm.rebuild_session(alias, name)

    def _watchdog(self) -> None:
        """Run maintenance passes every WATCHDOG_INTERVAL until shutdown."""
        while not self._shutdown_event.is_set():
            try:
                self.pm.check_orphaned_attaches()
                self.pm.check_health()
                self._ensure_enabled_sessions()
            except Exception:
                log.exception("Watchdog error")
            self._shutdown_event.wait(WATCHDOG_INTERVAL)

    def _ensure_enabled_sessions(self) -> None:
        for entry in load_entries():
            if entry.enabled and entry.alias not in ("*", ""):
                try:
                    self.pm.ensure_unattached(entry.alias)
                except Exception:
                    # One failing alias must not block the others
                    log.exception("Failed to ensure session for %s", entry.alias)

    def run(self) -> None:
        """Serve until shutdown is requested (signal or CMD_SHUTDOWN), then clean up.

        Returns immediately, without touching the pid/token files, if the IPC
        port cannot be bound.
        """
        try:
            self.server.start()
        except OSError as e:
            # Don't touch the live daemon's pid/token files if the port is taken
            log.error(
                "Cannot bind %s:%s (%s) — is another sshmd already running?",
                IPC_HOST, IPC_PORT, e,
            )
            return

        pf = pid_file()
        watchdog_thread: threading.Thread | None = None
        try:
            # Publish credentials only after the port is ours
            write_token(self.token)
            pf.write_text(str(os.getpid()), encoding="utf-8")

            log.info("sshmd started (PID %s)", os.getpid())
            log.info("IPC server listening on %s:%s", IPC_HOST, IPC_PORT)

            # First watchdog tick auto-connects enabled aliases
            watchdog_thread = threading.Thread(target=self._watchdog, daemon=True)
            watchdog_thread.start()

            def handle_signal(sig, frame):
                log.info("Received signal %s, shutting down...", sig)
                self._shutdown_event.set()

            signal.signal(signal.SIGTERM, handle_signal)
            signal.signal(signal.SIGINT, handle_signal)

            # Wait in short slices instead of one untimed wait(). Python runs
            # signal handlers only when the main thread executes bytecode, and
            # on Windows an untimed lock wait is not interruptible, so Ctrl+C
            # could otherwise never reach handle_signal.
            while not self._shutdown_event.wait(1.0):
                pass
        finally:
            # Stop the watchdog loop on every exit path, including an
            # exception raised above.
            self._shutdown_event.set()
            log.info("Shutting down...")
            try:
                self.server.stop()
                # Let an in-flight watchdog pass finish first; otherwise it could
                # spawn a fresh session right after disconnect_all() and leak it.
                if watchdog_thread is not None:
                    watchdog_thread.join(timeout=2 * WATCHDOG_INTERVAL)
                self.pm.disconnect_all()
            finally:
                pf.unlink(missing_ok=True)
        log.info("sshmd stopped")
272
+
273
+
274
def main() -> None:
    """Configure logging and run the daemon until it shuts down.

    Log records always go to the log file. They are also echoed to stderr
    when the process has one.
    """
    log_handlers: list[logging.Handler] = [
        logging.FileHandler(log_file(), encoding="utf-8"),
    ]
    # Under pythonw or a detached start there is no stderr: sys.stderr is None.
    has_console = sys.stderr is not None
    if has_console:
        log_handlers.append(logging.StreamHandler())

    logging.basicConfig(
        handlers=log_handlers,
        format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
        level=logging.INFO,
    )

    daemon = Daemon()
    daemon.run()


if __name__ == "__main__":
    # `python -m sshm.daemon` is how the CLI's ensure_daemon() launches sshmd.
    main()
sshm/ipc.py ADDED
@@ -0,0 +1,249 @@
1
+ """IPC client/server for sshm CLI <-> daemon communication over localhost TCP."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import hmac
6
+ import socket
7
+ import subprocess
8
+ import sys
9
+ import threading
10
+ import time
11
+ from collections.abc import Callable
12
+ from typing import Any
13
+
14
+ from . import protocol
15
+ from .procutil import daemon_interpreter, detached_popen_flags, pid_alive
16
+ from .state import pid_file, read_token, resolve_port, token_file
17
+
18
# The daemon listens on loopback only; the port comes from the state module.
IPC_HOST = "127.0.0.1"
IPC_PORT = resolve_port()

# Maximum bytes requested per recv() call.
BUFFER_SIZE = 64 * 1024
# Seconds a request/response socket operation may block before timing out.
REQUEST_TIMEOUT = 10.0
22
+
23
+
24
+ def _recv_line(sock: socket.socket) -> tuple[bytes, bytes] | None:
25
+ """Read until a newline.
26
+
27
+ Returns ``(line, leftover)`` where ``line`` is the bytes up to (excluding)
28
+ the first newline and ``leftover`` is anything already received past it.
29
+ Returns None if the peer closed before sending a newline. The leftover
30
+ matters for the streaming handshake: the daemon writes the JSON response and
31
+ then immediately starts streaming session output, and on localhost both can
32
+ land in a single recv() — without preserving the tail those first bytes of
33
+ terminal output would be lost (or corrupt the JSON decode).
34
+ """
35
+ data = b""
36
+ while b"\n" not in data:
37
+ chunk = sock.recv(BUFFER_SIZE)
38
+ if not chunk:
39
+ return None
40
+ data += chunk
41
+ line, _, leftover = data.partition(b"\n")
42
+ return line, leftover
43
+
44
+
45
+ # --- Server ---
46
+
47
+
48
class StreamingResponse:
    """Handler result that turns its IPC connection into a raw session bridge.

    IpcServer first sends ``response`` as the ordinary JSON reply line. It then
    hands the socket to ``session.bridge()``, along with ``cli_pid``.
    """

    def __init__(self, response: dict[str, Any], session, cli_pid: int | None = None):
        self.cli_pid = cli_pid
        self.session = session  # the SshSession whose I/O gets bridged
        self.response = response
55
+
56
+
57
# Callback IpcServer dispatches authenticated requests to. It receives the
# decoded request dict and returns either a plain response dict or a
# StreamingResponse.
Handler = Callable[
    [dict[str, Any]],
    "dict[str, Any] | StreamingResponse",
]
58
+
59
+
60
class IpcServer:
    """Loopback TCP server for CLI requests, with one thread per connection.

    Each connection carries one JSON request line. The request's ``token``
    must match this server's token (compared in constant time) before
    ``handler`` is called. A ``StreamingResponse`` result keeps the connection
    open as a raw I/O bridge to a session. Any other result is sent back as a
    single reply line.
    """

    def __init__(self, handler: Handler, token: str):
        self.handler = handler
        self.token = token
        self._sock: socket.socket | None = None
        self._running = False
        self._thread: threading.Thread | None = None

    def start(self) -> None:
        """Bind and listen on IPC_HOST:IPC_PORT, then start the accept thread.

        Raises:
            OSError: if the port cannot be bound (e.g. it is already in use).
                The half-configured socket is closed before re-raising.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            if sys.platform == "win32":
                # On Windows SO_REUSEADDR lets another process bind an actively
                # listening port and hijack traffic; exclusive use makes it fail.
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1)
            else:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind((IPC_HOST, IPC_PORT))
            sock.listen(5)
            # Short accept timeout so the loop notices stop() promptly.
            sock.settimeout(1.0)
        except BaseException:
            sock.close()
            raise
        self._sock = sock
        self._running = True
        self._thread = threading.Thread(target=self._accept_loop, daemon=True)
        self._thread.start()

    def _accept_loop(self) -> None:
        """Accept connections until stop(), handing each to its own thread."""
        while self._running:
            try:
                conn, _ = self._sock.accept()
            except socket.timeout:
                # Periodic wake-up to re-check _running. socket.timeout is an
                # alias of TimeoutError only from Python 3.10; before that it
                # is a separate OSError subclass, so catch it by this name.
                continue
            except OSError:
                if not self._running or self._sock.fileno() == -1:
                    # stop() closed the listening socket: normal shutdown.
                    break
                # Transient failure (e.g. out of file descriptors, or an
                # aborted handshake): back off briefly rather than stop
                # accepting for the rest of the daemon's life.
                time.sleep(0.1)
                continue
            threading.Thread(target=self._handle_client, args=(conn,), daemon=True).start()

    def _handle_client(self, conn: socket.socket) -> None:
        """Authenticate one request, run the handler, and reply or bridge."""
        try:
            conn.settimeout(REQUEST_TIMEOUT)
            received = _recv_line(conn)
            if received is None:
                return
            data, _ = received  # a request is exactly one line; nothing trails it

            request = protocol.decode(data)

            token = request.get("token")
            if not isinstance(token, str) or not hmac.compare_digest(token, self.token):
                conn.sendall(protocol.encode(protocol.err("Invalid token")))
                return

            response = self.handler(request)

            if isinstance(response, StreamingResponse):
                # Send the JSON response, then turn this connection into a raw bridge.
                # attach() already reserved the session (attached=True); if anything
                # fails before bridge() starts its own finally-detach, release that
                # reservation so the session isn't leaked out of the attachable pool.
                try:
                    conn.sendall(protocol.encode(response.response))
                    conn.settimeout(None)
                    response.session.bridge(conn, cli_pid=response.cli_pid)  # blocks until detach
                except Exception:
                    response.session.detach()
                    raise
            else:
                conn.sendall(protocol.encode(response))

        except Exception as e:
            # Best-effort error report; the peer may already be gone.
            try:
                conn.sendall(protocol.encode(protocol.err(str(e))))
            except Exception:
                pass
        finally:
            try:
                conn.close()
            except Exception:
                pass

    def stop(self) -> None:
        """Stop accepting, close the listening socket, and join the accept thread."""
        self._running = False
        if self._sock:
            self._sock.close()
        if self._thread:
            self._thread.join(timeout=5)
142
+
143
+
144
+ # --- Client ---
145
+
146
+
147
class DaemonNotRunning(Exception):
    """The IPC client could not reach sshmd.

    Raised when there is no token file, the connection is refused, or the
    daemon hangs up without replying.
    """
149
+
150
+
151
def _open_request(cmd: str, **kwargs) -> socket.socket:
    """Connect to the daemon and send a request; the response is not read yet.

    Raises:
        DaemonNotRunning: if there is no token file, or the connection is
            refused. Other socket errors (timeouts, resets) propagate as-is.
    """
    token = read_token()
    if not token:
        raise DaemonNotRunning("Daemon not running (no token file)")

    # Encode before connecting so a serialization error can't strand a socket.
    payload = protocol.encode(protocol.make_request(cmd, token, **kwargs))

    try:
        sock = socket.create_connection((IPC_HOST, IPC_PORT), timeout=REQUEST_TIMEOUT)
    except ConnectionRefusedError as e:
        raise DaemonNotRunning("Cannot connect to sshmd (connection refused)") from e

    try:
        sock.sendall(payload)
    except BaseException:
        # Don't leak the connection if the daemon drops us mid-send.
        sock.close()
        raise
    return sock
166
+
167
+
168
def send_request(cmd: str, **kwargs) -> dict[str, Any]:
    """Send one request and return the daemon's decoded reply.

    Raises:
        DaemonNotRunning: if the daemon is unreachable or closes the
            connection without replying.
    """
    # The socket context manager closes the connection however reading ends.
    with _open_request(cmd, **kwargs) as sock:
        received = _recv_line(sock)
    if received is None:
        raise DaemonNotRunning("Connection closed")
    reply_line = received[0]
    return protocol.decode(reply_line)
178
+
179
+
180
def connect_streaming(cmd: str, **kwargs) -> tuple[socket.socket, dict[str, Any], bytes]:
    """Send a request; return (socket, response, leftover).

    ``leftover`` is any streamed output that arrived bundled with the response
    line. It must be replayed to the terminal before reading further from the
    socket. On success the socket stays open, with no timeout, for streaming.
    If the response is not ok, the returned socket is already closed and
    ``leftover`` is ``b""``.

    Raises:
        DaemonNotRunning: if the daemon is unreachable or hangs up before replying.
    """
    sock = _open_request(cmd, **kwargs)
    try:
        received = _recv_line(sock)
        if received is None:
            raise DaemonNotRunning("Connection closed")

        data, leftover = received
        resp = protocol.decode(data)
        if not resp.get("ok"):
            sock.close()
            return sock, resp, b""

        sock.settimeout(None)  # no timeout while streaming
    except BaseException:
        # A read timeout or decode failure must not leak the connection.
        sock.close()
        raise
    return sock, resp, leftover
202
+
203
+
204
+ # --- Daemon lifecycle ---
205
+
206
+
207
def is_daemon_running() -> bool:
    """Return True if the pid file names a live process.

    This only shows that *some* process with that pid is alive;
    ensure_daemon() additionally pings the daemon to rule out PID reuse.
    """
    try:
        pid = int(pid_file().read_text(encoding="utf-8").strip())
    except (ValueError, OSError):
        # Covers a missing pid file, a read racing with the daemon's unlink on
        # shutdown, and garbage contents. All of them mean "not running".
        return False
    # The daemon writes os.getpid(), which is always positive. Reject anything
    # else before probing: if pid_alive() uses os.kill, pids 0 and -1 would
    # address whole process groups rather than a single process.
    if pid <= 0:
        return False
    return pid_alive(pid)
217
+
218
+
219
def ensure_daemon() -> None:
    """Make sure a responsive sshmd is running, spawning one if needed.

    Raises:
        RuntimeError: if a freshly spawned daemon still isn't answering after
            30 polls spaced 0.2 s apart.
    """

    def _answers_status() -> bool:
        # Any failure (refused, bad token, timeout, ...) counts as "no answer".
        try:
            return bool(send_request(protocol.CMD_STATUS).get("ok"))
        except Exception:
            return False

    # A live pid alone proves little (PID reuse from a stale pid file, or a
    # wedged daemon), so also require a status reply before trusting it.
    # Otherwise fall through and (re)spawn.
    if is_daemon_running() and _answers_status():
        return

    subprocess.Popen(
        [daemon_interpreter(), "-m", "sshm.daemon"],
        stdin=subprocess.DEVNULL,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        **detached_popen_flags(),
    )

    attempts_left = 30
    while attempts_left:
        attempts_left -= 1
        time.sleep(0.2)
        # The token file appears once the new daemon has bound the port.
        if token_file().exists() and _answers_status():
            return
    raise RuntimeError(
        f"Failed to start sshmd daemon (is port {IPC_PORT} taken by another process?)"
    )