mcp-ssh-vps 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mcp_ssh_vps-0.4.1.dist-info/METADATA +482 -0
- mcp_ssh_vps-0.4.1.dist-info/RECORD +47 -0
- mcp_ssh_vps-0.4.1.dist-info/WHEEL +5 -0
- mcp_ssh_vps-0.4.1.dist-info/entry_points.txt +4 -0
- mcp_ssh_vps-0.4.1.dist-info/licenses/LICENSE +21 -0
- mcp_ssh_vps-0.4.1.dist-info/top_level.txt +1 -0
- sshmcp/__init__.py +3 -0
- sshmcp/cli.py +473 -0
- sshmcp/config.py +155 -0
- sshmcp/core/__init__.py +5 -0
- sshmcp/core/container.py +291 -0
- sshmcp/models/__init__.py +15 -0
- sshmcp/models/command.py +69 -0
- sshmcp/models/file.py +102 -0
- sshmcp/models/machine.py +139 -0
- sshmcp/monitoring/__init__.py +0 -0
- sshmcp/monitoring/alerts.py +464 -0
- sshmcp/prompts/__init__.py +7 -0
- sshmcp/prompts/backup.py +151 -0
- sshmcp/prompts/deploy.py +115 -0
- sshmcp/prompts/monitor.py +146 -0
- sshmcp/resources/__init__.py +7 -0
- sshmcp/resources/logs.py +99 -0
- sshmcp/resources/metrics.py +204 -0
- sshmcp/resources/status.py +160 -0
- sshmcp/security/__init__.py +7 -0
- sshmcp/security/audit.py +314 -0
- sshmcp/security/rate_limiter.py +221 -0
- sshmcp/security/totp.py +392 -0
- sshmcp/security/validator.py +234 -0
- sshmcp/security/whitelist.py +169 -0
- sshmcp/server.py +632 -0
- sshmcp/ssh/__init__.py +6 -0
- sshmcp/ssh/async_client.py +247 -0
- sshmcp/ssh/client.py +464 -0
- sshmcp/ssh/executor.py +79 -0
- sshmcp/ssh/forwarding.py +368 -0
- sshmcp/ssh/pool.py +343 -0
- sshmcp/ssh/shell.py +518 -0
- sshmcp/ssh/transfer.py +461 -0
- sshmcp/tools/__init__.py +13 -0
- sshmcp/tools/commands.py +226 -0
- sshmcp/tools/files.py +220 -0
- sshmcp/tools/helpers.py +321 -0
- sshmcp/tools/history.py +372 -0
- sshmcp/tools/processes.py +214 -0
- sshmcp/tools/servers.py +484 -0
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
"""MCP Resource for server status."""
|
|
2
|
+
|
|
3
|
+
import structlog
|
|
4
|
+
|
|
5
|
+
from sshmcp.config import get_machine
|
|
6
|
+
from sshmcp.security.audit import get_audit_logger
|
|
7
|
+
from sshmcp.ssh.pool import get_pool
|
|
8
|
+
|
|
9
|
+
# Module-level structured logger (structlog).
logger = structlog.get_logger()
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def get_status(host: str) -> dict:
    """
    Get status of VPS server.

    Returns server availability, hostname, OS details and running services.
    Exceptions raised while connecting to or querying the server are not
    propagated: they are reported in the returned dictionary with ``status``
    set to ``"offline"`` and the message stored under ``error``.

    Args:
        host: Name of the host from machines.json configuration.

    Returns:
        Dictionary with:
        - host: The requested host name
        - hostname: Server hostname (stdout of ``hostname``)
        - status: "online" when all queries completed, "offline" on any failure
        - services: List of running services (systemd, falling back to pm2)
        - system_info: ``uname`` output and, when available, ``os``
          (PRETTY_NAME from /etc/os-release)
        - error: Error message; present only when status is "offline"

    Raises:
        ValueError: If host not found.

    Example:
        Resource URI: vps://production-server/status
    """
    audit = get_audit_logger()

    # Get machine configuration
    try:
        machine = get_machine(host)
    except Exception as e:
        raise ValueError(f"Host not found: {host}") from e

    pool = get_pool()
    pool.register_machine(machine)

    status_info = {
        "hostname": "",
        "status": "unknown",
        "host": host,
        "services": [],
        "system_info": {},
    }

    try:
        client = pool.get_client(host)
        try:
            # Server is online if we can connect
            status_info["status"] = "online"

            # Get hostname
            hostname_result = client.execute("hostname")
            status_info["hostname"] = hostname_result.stdout.strip()

            # Get basic system info
            uname_result = client.execute("uname -a")
            status_info["system_info"]["uname"] = uname_result.stdout.strip()

            # Get OS info
            os_result = client.execute(
                "cat /etc/os-release 2>/dev/null | grep PRETTY_NAME | cut -d= -f2 | tr -d '\"'"
            )
            if os_result.exit_code == 0 and os_result.stdout.strip():
                status_info["system_info"]["os"] = os_result.stdout.strip()

            # Get running services (try systemd first)
            services_result = client.execute(
                "systemctl list-units --type=service --state=running --no-pager --no-legend 2>/dev/null | head -20"
            )

            if services_result.exit_code == 0 and services_result.stdout.strip():
                status_info["services"] = _parse_systemd_services(
                    services_result.stdout
                )
            else:
                # No usable systemd output: try pm2
                pm2_result = client.execute("pm2 jlist 2>/dev/null")
                if pm2_result.exit_code == 0 and pm2_result.stdout.strip():
                    status_info["services"] = _parse_pm2_services(pm2_result.stdout)

            audit.log(
                event="status_read",
                host=host,
            )

            return status_info

        finally:
            # Always hand the connection back to the pool, even on failure.
            pool.release_client(client)

    except Exception as e:
        # Any failure (connect or query) is reported as offline, not raised.
        status_info["status"] = "offline"
        status_info["error"] = str(e)

        audit.log(
            event="status_read_failed",
            host=host,
            error=str(e),
        )

        return status_info
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _parse_systemd_services(output: str) -> list[dict]:
|
|
115
|
+
"""Parse systemd services from systemctl output."""
|
|
116
|
+
services = []
|
|
117
|
+
|
|
118
|
+
for line in output.strip().split("\n"):
|
|
119
|
+
if not line.strip():
|
|
120
|
+
continue
|
|
121
|
+
|
|
122
|
+
parts = line.split()
|
|
123
|
+
if len(parts) >= 1:
|
|
124
|
+
service_name = parts[0]
|
|
125
|
+
# Remove .service suffix
|
|
126
|
+
if service_name.endswith(".service"):
|
|
127
|
+
service_name = service_name[:-8]
|
|
128
|
+
|
|
129
|
+
services.append(
|
|
130
|
+
{
|
|
131
|
+
"name": service_name,
|
|
132
|
+
"status": "running",
|
|
133
|
+
"manager": "systemd",
|
|
134
|
+
}
|
|
135
|
+
)
|
|
136
|
+
|
|
137
|
+
return services
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def _parse_pm2_services(output: str) -> list[dict]:
|
|
141
|
+
"""Parse pm2 services from pm2 jlist output."""
|
|
142
|
+
import json
|
|
143
|
+
|
|
144
|
+
services = []
|
|
145
|
+
|
|
146
|
+
try:
|
|
147
|
+
data = json.loads(output)
|
|
148
|
+
for app in data:
|
|
149
|
+
services.append(
|
|
150
|
+
{
|
|
151
|
+
"name": app.get("name", "unknown"),
|
|
152
|
+
"status": app.get("pm2_env", {}).get("status", "unknown"),
|
|
153
|
+
"manager": "pm2",
|
|
154
|
+
"pid": app.get("pid"),
|
|
155
|
+
}
|
|
156
|
+
)
|
|
157
|
+
except json.JSONDecodeError:
|
|
158
|
+
pass
|
|
159
|
+
|
|
160
|
+
return services
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
"""Security module for command validation and auditing."""
|
|
2
|
+
|
|
3
|
+
from sshmcp.security.audit import audit_log
|
|
4
|
+
from sshmcp.security.validator import validate_command, validate_path
|
|
5
|
+
from sshmcp.security.whitelist import CommandWhitelist
|
|
6
|
+
|
|
7
|
+
# Public API of the security package.
__all__ = [
    "validate_command",
    "validate_path",
    "CommandWhitelist",
    "audit_log",
]
|
sshmcp/security/audit.py
ADDED
|
@@ -0,0 +1,314 @@
|
|
|
1
|
+
"""Audit logging for SSH MCP operations."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
import socket
|
|
6
|
+
from datetime import datetime, timezone
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any, Optional
|
|
9
|
+
|
|
10
|
+
import structlog
|
|
11
|
+
|
|
12
|
+
from sshmcp.security.validator import sanitize_command_for_log
|
|
13
|
+
|
|
14
|
+
# structlog logger that AuditLogger.log writes to when log_to_stdout is enabled.
logger = structlog.get_logger()
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def get_client_info() -> dict[str, str | None]:
|
|
18
|
+
"""
|
|
19
|
+
Get information about the client making the request.
|
|
20
|
+
|
|
21
|
+
Returns:
|
|
22
|
+
Dictionary with client IP, hostname, and user.
|
|
23
|
+
"""
|
|
24
|
+
info: dict[str, str | None] = {
|
|
25
|
+
"client_ip": None,
|
|
26
|
+
"client_hostname": None,
|
|
27
|
+
"client_user": None,
|
|
28
|
+
}
|
|
29
|
+
|
|
30
|
+
# Get current user
|
|
31
|
+
try:
|
|
32
|
+
info["client_user"] = os.getenv("USER") or os.getenv("USERNAME")
|
|
33
|
+
except Exception:
|
|
34
|
+
pass
|
|
35
|
+
|
|
36
|
+
# Get local hostname
|
|
37
|
+
try:
|
|
38
|
+
info["client_hostname"] = socket.gethostname()
|
|
39
|
+
except Exception:
|
|
40
|
+
pass
|
|
41
|
+
|
|
42
|
+
# Get SSH client IP if available (from SSH_CLIENT or SSH_CONNECTION env vars)
|
|
43
|
+
try:
|
|
44
|
+
ssh_client = os.getenv("SSH_CLIENT")
|
|
45
|
+
if ssh_client:
|
|
46
|
+
info["client_ip"] = ssh_client.split()[0]
|
|
47
|
+
else:
|
|
48
|
+
ssh_connection = os.getenv("SSH_CONNECTION")
|
|
49
|
+
if ssh_connection:
|
|
50
|
+
info["client_ip"] = ssh_connection.split()[0]
|
|
51
|
+
except Exception:
|
|
52
|
+
pass
|
|
53
|
+
|
|
54
|
+
return info
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class AuditLogger:
    """
    Audit logger for tracking all SSH MCP operations.

    Events are emitted to structlog (when ``log_to_stdout`` is true) and, when a
    ``log_file`` is configured, appended to that file as one JSON object per
    line. Instances can be used as context managers so the file handle is
    closed deterministically.
    """

    # Whitelist of result fields copied into audit records; other keys are dropped.
    _SAFE_RESULT_KEYS = ("exit_code", "success", "duration_ms", "size", "path")

    def __init__(
        self,
        log_file: Optional[str] = None,
        log_to_stdout: bool = True,
    ) -> None:
        """
        Initialize audit logger.

        Args:
            log_file: Optional path to audit log file. Missing parent
                directories are created and the file is opened for appending.
            log_to_stdout: Whether to also log to stdout via structlog.
        """
        self.log_file = log_file
        self.log_to_stdout = log_to_stdout
        self._file_handle = None

        if log_file:
            path = Path(log_file)
            path.parent.mkdir(parents=True, exist_ok=True)
            self._file_handle = open(path, "a", encoding="utf-8")

    def __enter__(self) -> "AuditLogger":
        """Return self so the logger can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        """Close the audit log file when leaving the ``with`` block."""
        self.close()

    def log(
        self,
        event: str,
        host: Optional[str] = None,
        command: Optional[str] = None,
        result: Optional[dict[str, Any]] = None,
        error: Optional[str] = None,
        metadata: Optional[dict[str, Any]] = None,
        include_client_info: bool = True,
    ) -> None:
        """
        Log an audit event.

        Args:
            event: Event type (e.g., 'command_executed', 'file_read').
            host: Target host name.
            command: Command that was executed (will be sanitized).
            result: Result of the operation; only whitelisted keys are kept.
            error: Error message if operation failed.
            metadata: Additional metadata to log.
            include_client_info: Whether to include client IP/user info
                (written to the file record only).
        """
        timestamp = datetime.now(timezone.utc).isoformat()

        audit_record: dict[str, Any] = {
            "timestamp": timestamp,
            "event": event,
        }

        if include_client_info:
            client_info = get_client_info()
            audit_record["client"] = {
                k: v for k, v in client_info.items() if v is not None
            }

        # Fields forwarded to structlog; ``event`` is passed positionally.
        log_kwargs: dict[str, Any] = {"timestamp": timestamp}

        if host:
            audit_record["host"] = host
            log_kwargs["host"] = host

        if command:
            # Sanitize command to hide potential secrets
            sanitized = sanitize_command_for_log(command)
            audit_record["command"] = sanitized
            log_kwargs["command"] = sanitized

        if result:
            safe_result = {
                k: v for k, v in result.items() if k in self._SAFE_RESULT_KEYS
            }
            audit_record["result"] = safe_result
            log_kwargs["result"] = safe_result

        if error:
            audit_record["error"] = error
            log_kwargs["error"] = error

        if metadata:
            audit_record["metadata"] = metadata
            log_kwargs["metadata"] = metadata

        if self.log_to_stdout:
            if error:
                logger.error(event, **log_kwargs)
            else:
                logger.info(event, **log_kwargs)

        if self._file_handle:
            # default=str: values such as pathlib.Path or datetime in result or
            # metadata would otherwise raise TypeError and fail the operation
            # that is merely being audited.
            self._file_handle.write(json.dumps(audit_record, default=str) + "\n")
            self._file_handle.flush()

    def log_command_executed(
        self,
        host: str,
        command: str,
        exit_code: int,
        duration_ms: int,
    ) -> None:
        """Log command execution; ``success`` is derived from ``exit_code == 0``."""
        self.log(
            event="command_executed",
            host=host,
            command=command,
            result={
                "exit_code": exit_code,
                "duration_ms": duration_ms,
                "success": exit_code == 0,
            },
        )

    def log_command_failed(
        self,
        host: str,
        command: str,
        error: str,
    ) -> None:
        """Log failed command execution."""
        self.log(
            event="command_failed",
            host=host,
            command=command,
            error=error,
        )

    def log_command_rejected(
        self,
        host: str,
        command: str,
        reason: str,
    ) -> None:
        """Log rejected command (security validation failed)."""
        self.log(
            event="command_rejected",
            host=host,
            command=command,
            error=reason,
            metadata={"security_violation": True},
        )

    def log_file_read(
        self,
        host: str,
        path: str,
        size: int,
    ) -> None:
        """Log file read operation."""
        self.log(
            event="file_read",
            host=host,
            result={"path": path, "size": size},
        )

    def log_file_write(
        self,
        host: str,
        path: str,
        size: int,
    ) -> None:
        """Log file write operation."""
        self.log(
            event="file_write",
            host=host,
            result={"path": path, "size": size},
        )

    def log_path_rejected(
        self,
        host: str,
        path: str,
        reason: str,
    ) -> None:
        """Log rejected path access."""
        self.log(
            event="path_rejected",
            host=host,
            error=reason,
            metadata={"path": path, "security_violation": True},
        )

    def close(self) -> None:
        """Close audit log file; safe to call more than once."""
        if self._file_handle:
            self._file_handle.close()
            self._file_handle = None
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
# Process-wide audit logger; created lazily by get_audit_logger() or replaced
# by init_audit_logger().
_audit_logger: Optional[AuditLogger] = None


def get_audit_logger() -> AuditLogger:
    """Return the shared audit logger, creating a stdout-only one on first use."""
    global _audit_logger
    if _audit_logger is not None:
        return _audit_logger
    _audit_logger = AuditLogger()
    return _audit_logger
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def init_audit_logger(
    log_file: Optional[str] = None,
    log_to_stdout: bool = True,
) -> AuditLogger:
    """
    Initialize (or re-initialize) the global audit logger.

    The new logger is created first. Any previously installed global logger is
    then closed, so repeated initialization does not leak its audit file handle.

    Args:
        log_file: Optional path to audit log file.
        log_to_stdout: Whether to also log to stdout.

    Returns:
        Initialized AuditLogger.
    """
    global _audit_logger
    previous = _audit_logger
    _audit_logger = AuditLogger(log_file=log_file, log_to_stdout=log_to_stdout)
    if previous is not None:
        # Release the old file handle only after the replacement exists.
        previous.close()
    return _audit_logger
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def audit_log(
    event: str,
    host: Optional[str] = None,
    command: Optional[str] = None,
    result: Optional[dict[str, Any]] = None,
    error: Optional[str] = None,
    metadata: Optional[dict[str, Any]] = None,
) -> None:
    """
    Record an audit event through the process-wide audit logger.

    Shortcut for ``get_audit_logger().log(...)`` with client info included.

    Args:
        event: Event type.
        host: Target host name.
        command: Command that was executed.
        result: Result of the operation.
        error: Error message if operation failed.
        metadata: Additional metadata.
    """
    shared_logger = get_audit_logger()
    shared_logger.log(
        event,
        host=host,
        command=command,
        result=result,
        error=error,
        metadata=metadata,
    )
|
|
@@ -0,0 +1,221 @@
|
|
|
1
|
+
"""Rate limiting for command execution."""
|
|
2
|
+
|
|
3
|
+
import threading
|
|
4
|
+
import time
|
|
5
|
+
from collections import defaultdict
|
|
6
|
+
from typing import NamedTuple
|
|
7
|
+
|
|
8
|
+
import structlog
|
|
9
|
+
|
|
10
|
+
# structlog logger used to report rate-limit rejections.
logger = structlog.get_logger()
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class RateLimitConfig(NamedTuple):
    """
    Per-host rate limit settings.

    Field order is significant because this is a tuple:
        requests_per_minute: Refill rate of the per-minute bucket.
        requests_per_hour: Refill rate of the per-hour bucket.
        burst_size: Token capacity of the per-minute bucket (the per-hour
            bucket is sized at twice this value by RateLimiter).
    """

    # Sustained allowance for the short window.
    requests_per_minute: int = 60
    # Sustained allowance for the long window.
    requests_per_hour: int = 1000
    # Maximum number of back-to-back requests.
    burst_size: int = 10
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class RateLimitExceeded(Exception):
    """
    Raised when a host has used up its rate-limit budget.

    Attributes:
        retry_after: Suggested number of seconds to wait before retrying.
    """

    def __init__(self, message: str, retry_after: float = 0):
        self.retry_after = retry_after
        super().__init__(message)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class TokenBucket:
    """
    Token bucket rate limiter implementation.

    Tokens refill continuously at ``rate`` per second, up to ``capacity``.
    The bucket starts full. Public methods are serialized by an internal lock.
    """

    def __init__(self, rate: float, capacity: int):
        """
        Initialize token bucket.

        Args:
            rate: Tokens per second to add.
            capacity: Maximum tokens in bucket (also the initial amount).
        """
        self.rate = rate
        self.capacity = capacity
        self.tokens = capacity
        self.last_update = time.monotonic()
        self._lock = threading.Lock()

    def _refill(self) -> None:
        """Credit tokens earned since the last update. Caller must hold ``_lock``."""
        now = time.monotonic()
        elapsed = now - self.last_update
        self.last_update = now
        self.tokens = min(self.capacity, self.tokens + elapsed * self.rate)

    def consume(self, tokens: int = 1) -> bool:
        """
        Try to consume tokens from the bucket.

        Args:
            tokens: Number of tokens to consume.

        Returns:
            True if tokens were consumed, False if not enough tokens.
        """
        with self._lock:
            self._refill()
            if self.tokens >= tokens:
                self.tokens -= tokens
                return True
            return False

    def time_until_available(self, tokens: int = 1) -> float:
        """
        Calculate time until tokens are available.

        The bucket is refilled first, so the answer reflects time elapsed since
        the last call rather than a stale token count.

        Args:
            tokens: Number of tokens needed.

        Returns:
            Seconds until tokens available (0 if available now). Returns
            ``float("inf")`` when the request can never be satisfied: more
            tokens than the capacity, or a non-positive refill rate.
        """
        with self._lock:
            self._refill()
            if self.tokens >= tokens:
                return 0
            if tokens > self.capacity or self.rate <= 0:
                return float("inf")
            return (tokens - self.tokens) / self.rate
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class RateLimiter:
    """
    Rate limiter for SSH command execution.

    Each host gets two token buckets:
    - a "minute" bucket refilling at ``requests_per_minute / 60`` tokens per
      second, with ``burst_size`` capacity;
    - an "hour" bucket refilling at ``requests_per_hour / 3600`` tokens per
      second, with ``2 * burst_size`` capacity.

    A request must obtain a token from both buckets.
    """

    def __init__(self, config: RateLimitConfig | None = None):
        """
        Initialize rate limiter.

        Args:
            config: Rate limit configuration (defaults to RateLimitConfig()).
        """
        self.config = config or RateLimitConfig()
        self._buckets: dict[str, dict[str, TokenBucket]] = defaultdict(dict)
        self._lock = threading.Lock()

    def _get_bucket(self, host: str, limit_type: str) -> TokenBucket:
        """Get or create a token bucket for host and limit type."""
        with self._lock:
            if limit_type not in self._buckets[host]:
                if limit_type == "minute":
                    rate = self.config.requests_per_minute / 60.0
                    capacity = self.config.burst_size
                elif limit_type == "hour":
                    rate = self.config.requests_per_hour / 3600.0
                    capacity = self.config.burst_size * 2
                else:
                    rate = 1.0
                    capacity = 10

                self._buckets[host][limit_type] = TokenBucket(rate, capacity)

            return self._buckets[host][limit_type]

    @staticmethod
    def _current_tokens(bucket: TokenBucket) -> int:
        """Whole tokens in *bucket* now, counting refill since its last update (read-only)."""
        with bucket._lock:
            elapsed = time.monotonic() - bucket.last_update
            return int(min(bucket.capacity, bucket.tokens + elapsed * bucket.rate))

    def check_rate_limit(self, host: str) -> None:
        """
        Check if request is within rate limits, consuming one token per bucket.

        Args:
            host: Host name to check rate limit for.

        Raises:
            RateLimitExceeded: If rate limit is exceeded. A rejected request
                does not keep any token it took.
        """
        # Check minute limit
        minute_bucket = self._get_bucket(host, "minute")
        if not minute_bucket.consume():
            retry_after = minute_bucket.time_until_available()
            logger.warning(
                "rate_limit_exceeded",
                host=host,
                limit_type="minute",
                retry_after=retry_after,
            )
            raise RateLimitExceeded(
                f"Rate limit exceeded for {host}. Try again in {retry_after:.1f}s",
                retry_after=retry_after,
            )

        # Check hour limit
        hour_bucket = self._get_bucket(host, "hour")
        if not hour_bucket.consume():
            # The request is rejected, so give back the minute token it already
            # took; otherwise rejected calls would keep draining the minute budget.
            with minute_bucket._lock:
                minute_bucket.tokens = min(
                    minute_bucket.capacity, minute_bucket.tokens + 1
                )
            retry_after = hour_bucket.time_until_available()
            logger.warning(
                "rate_limit_exceeded",
                host=host,
                limit_type="hour",
                retry_after=retry_after,
            )
            raise RateLimitExceeded(
                f"Hourly rate limit exceeded for {host}. Try again in {retry_after:.1f}s",
                retry_after=retry_after,
            )

    def get_remaining(self, host: str) -> dict:
        """
        Get remaining rate limit for host.

        ``remaining`` is the whole number of tokens currently in each bucket,
        including refill since the last request; it is bounded by the bucket
        capacity rather than by ``limit``.

        Args:
            host: Host name.

        Returns:
            Dictionary with remaining limits.
        """
        minute_bucket = self._get_bucket(host, "minute")
        hour_bucket = self._get_bucket(host, "hour")

        return {
            "host": host,
            "minute": {
                "remaining": self._current_tokens(minute_bucket),
                "limit": self.config.requests_per_minute,
            },
            "hour": {
                "remaining": self._current_tokens(hour_bucket),
                "limit": self.config.requests_per_hour,
            },
        }

    def reset(self, host: str | None = None) -> None:
        """
        Reset rate limits.

        Args:
            host: Host to reset, or None for all hosts.
        """
        with self._lock:
            if host is None:
                self._buckets.clear()
            elif host in self._buckets:
                del self._buckets[host]
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
# Process-wide rate limiter; created lazily by get_rate_limiter() or replaced
# by init_rate_limiter().
_rate_limiter: RateLimiter | None = None


def get_rate_limiter() -> RateLimiter:
    """Return the shared rate limiter, building one with default limits on first call."""
    global _rate_limiter
    if _rate_limiter is not None:
        return _rate_limiter
    _rate_limiter = RateLimiter()
    return _rate_limiter
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def init_rate_limiter(config: RateLimitConfig | None = None) -> RateLimiter:
    """
    Replace the shared rate limiter with a freshly configured one.

    Args:
        config: Rate limit configuration; None means library defaults.

    Returns:
        The newly installed RateLimiter.
    """
    global _rate_limiter
    limiter = RateLimiter(config)
    _rate_limiter = limiter
    return limiter
|