openadapt-ml 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/benchmarks/__init__.py +8 -0
- openadapt_ml/benchmarks/agent.py +90 -11
- openadapt_ml/benchmarks/azure.py +35 -6
- openadapt_ml/benchmarks/cli.py +4449 -201
- openadapt_ml/benchmarks/live_tracker.py +180 -0
- openadapt_ml/benchmarks/runner.py +41 -4
- openadapt_ml/benchmarks/viewer.py +1219 -0
- openadapt_ml/benchmarks/vm_monitor.py +610 -0
- openadapt_ml/benchmarks/waa.py +61 -4
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +222 -0
- openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +539 -0
- openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
- openadapt_ml/benchmarks/waa_live.py +619 -0
- openadapt_ml/cloud/local.py +1555 -1
- openadapt_ml/cloud/ssh_tunnel.py +553 -0
- openadapt_ml/datasets/next_action.py +87 -68
- openadapt_ml/evals/grounding.py +26 -8
- openadapt_ml/evals/trajectory_matching.py +84 -36
- openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
- openadapt_ml/experiments/demo_prompt/format_demo.py +226 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
- openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
- openadapt_ml/experiments/demo_prompt/run_experiment.py +531 -0
- openadapt_ml/experiments/waa_demo/__init__.py +10 -0
- openadapt_ml/experiments/waa_demo/demos.py +357 -0
- openadapt_ml/experiments/waa_demo/runner.py +717 -0
- openadapt_ml/experiments/waa_demo/tasks.py +151 -0
- openadapt_ml/export/__init__.py +9 -0
- openadapt_ml/export/__main__.py +6 -0
- openadapt_ml/export/cli.py +89 -0
- openadapt_ml/export/parquet.py +265 -0
- openadapt_ml/ingest/__init__.py +3 -4
- openadapt_ml/ingest/capture.py +89 -81
- openadapt_ml/ingest/loader.py +116 -68
- openadapt_ml/ingest/synthetic.py +221 -159
- openadapt_ml/retrieval/README.md +226 -0
- openadapt_ml/retrieval/USAGE.md +391 -0
- openadapt_ml/retrieval/__init__.py +91 -0
- openadapt_ml/retrieval/demo_retriever.py +817 -0
- openadapt_ml/retrieval/embeddings.py +629 -0
- openadapt_ml/retrieval/index.py +194 -0
- openadapt_ml/retrieval/retriever.py +160 -0
- openadapt_ml/runtime/policy.py +10 -10
- openadapt_ml/schema/__init__.py +104 -0
- openadapt_ml/schema/converters.py +541 -0
- openadapt_ml/schema/episode.py +457 -0
- openadapt_ml/scripts/compare.py +26 -16
- openadapt_ml/scripts/eval_policy.py +4 -5
- openadapt_ml/scripts/prepare_synthetic.py +14 -17
- openadapt_ml/scripts/train.py +81 -70
- openadapt_ml/training/benchmark_viewer.py +3225 -0
- openadapt_ml/training/trainer.py +120 -363
- openadapt_ml/training/trl_trainer.py +354 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/METADATA +102 -60
- openadapt_ml-0.2.0.dist-info/RECORD +86 -0
- openadapt_ml/schemas/__init__.py +0 -53
- openadapt_ml/schemas/sessions.py +0 -122
- openadapt_ml/schemas/validation.py +0 -252
- openadapt_ml-0.1.0.dist-info/RECORD +0 -55
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,553 @@
|
|
|
1
|
+
"""SSH Tunnel Manager for Azure VMs.
|
|
2
|
+
|
|
3
|
+
This module provides automatic SSH tunnel management for accessing services
|
|
4
|
+
running inside Azure VMs (VNC, WAA server) that are not exposed via NSG.
|
|
5
|
+
|
|
6
|
+
Architecture:
|
|
7
|
+
Azure VMs have Network Security Groups (NSGs) that act as firewalls.
|
|
8
|
+
By default, only port 22 (SSH) is open. To access other services like
|
|
9
|
+
VNC (8006) and WAA (5000), we create SSH tunnels:
|
|
10
|
+
|
|
11
|
+
Browser → localhost:8006 → SSH Tunnel → Azure VM:8006 → Docker → noVNC
|
|
12
|
+
|
|
13
|
+
This is more secure than opening ports in NSG because:
|
|
14
|
+
1. All traffic is encrypted through SSH
|
|
15
|
+
2. VNC is never exposed publicly (it has no authentication by default)
|
|
16
|
+
3. Access requires SSH key authentication
|
|
17
|
+
|
|
18
|
+
Usage:
|
|
19
|
+
from openadapt_ml.cloud.ssh_tunnel import SSHTunnelManager
|
|
20
|
+
|
|
21
|
+
# Create manager
|
|
22
|
+
manager = SSHTunnelManager()
|
|
23
|
+
|
|
24
|
+
# Start tunnels for a VM
|
|
25
|
+
manager.start_tunnels_for_vm(
    vm_ip="172.171.112.41",
    ssh_user="azureuser",
    # Optional: tunnels=[TunnelConfig(name="vnc", local_port=8006, remote_port=8006)]
    # (defaults to the VNC + WAA tunnels)
)

# Check tunnel status (values are TunnelStatus instances)
status = manager.get_tunnel_status()
# {'vnc': TunnelStatus(name='vnc', active=True, local_port=8006,
#                      remote_endpoint='172.171.112.41:8006', pid=12345, error=None), ...}
|
|
34
|
+
|
|
35
|
+
# Stop all tunnels
|
|
36
|
+
manager.stop_all_tunnels()
|
|
37
|
+
|
|
38
|
+
Integration:
|
|
39
|
+
The SSHTunnelManager is integrated with the dashboard server (local.py):
|
|
40
|
+
- When a VM's WAA probe becomes "ready", tunnels are auto-started
|
|
41
|
+
- When VM goes offline, tunnels are auto-stopped
|
|
42
|
+
- Dashboard shows tunnel status next to VNC button
|
|
43
|
+
- VNC button links to localhost:port (tunnel endpoint)
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
from __future__ import annotations
|
|
47
|
+
|
|
48
|
+
import logging
|
|
49
|
+
import os
|
|
50
|
+
import signal
|
|
51
|
+
import socket
|
|
52
|
+
import subprocess
|
|
53
|
+
import time
|
|
54
|
+
from dataclasses import dataclass, field
|
|
55
|
+
from pathlib import Path
|
|
56
|
+
from typing import Any
|
|
57
|
+
|
|
58
|
+
logger = logging.getLogger(__name__)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@dataclass
class TunnelConfig:
    """Describe one SSH local port forward (``ssh -L``).

    Attributes:
        name: Short identifier for the tunnel, e.g. ``"vnc"`` or ``"waa"``.
        local_port: Port opened on this machine.
        remote_port: Port the VM-side end of the forward connects to.
        remote_host: Host the forward targets as seen from the VM;
            ``"localhost"`` means the VM itself.
    """

    name: str
    local_port: int
    remote_port: int
    remote_host: str = field(default="localhost")
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@dataclass
|
|
72
|
+
class TunnelStatus:
|
|
73
|
+
"""Status of an SSH tunnel."""
|
|
74
|
+
|
|
75
|
+
name: str
|
|
76
|
+
active: bool
|
|
77
|
+
local_port: int
|
|
78
|
+
remote_endpoint: str # e.g., "172.171.112.41:8006"
|
|
79
|
+
pid: int | None = None
|
|
80
|
+
error: str | None = None
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class SSHTunnelManager:
    """Manages SSH tunnels for Azure VM access.

    Each tunnel is a background ``ssh -N -L local_port:remote_host:remote_port``
    process giving access to a service running inside an Azure VM (noVNC on
    8006 and the WAA server on 5000 by default) whose port is not opened in
    the VM's NSG.

    Features:
    - Auto-reconnect: ``get_tunnel_status`` restarts tunnels whose ssh process
      exited, giving up after ``MAX_RECONNECT_ATTEMPTS`` consecutive failures.
    - Health monitoring: local ports are probed directly, so working tunnels
      started by another process (or before this manager existed) are
      reported as active.
    - SSH keepalives tolerate transient network interruptions.

    Attributes:
        ssh_key_path: Path to the SSH private key passed to ``ssh -i``.
        tunnel_configs: Tunnels this manager starts and reports on.
    """

    # Default tunnel configurations
    DEFAULT_TUNNELS = [
        TunnelConfig(name="vnc", local_port=8006, remote_port=8006),
        TunnelConfig(name="waa", local_port=5000, remote_port=5000),
    ]

    # Auto-reconnect settings
    MAX_RECONNECT_ATTEMPTS = 3
    RECONNECT_DELAY_SECONDS = 2

    def __init__(
        self,
        ssh_key_path: str | Path | None = None,
        tunnels: list[TunnelConfig] | None = None,
        auto_reconnect: bool = True,
    ):
        """Initialize tunnel manager.

        Args:
            ssh_key_path: Path to SSH private key. Defaults to ~/.ssh/id_rsa.
            tunnels: List of tunnel configurations. Defaults to VNC + WAA.
            auto_reconnect: If True, automatically restart dead tunnels.
        """
        self.ssh_key_path = Path(ssh_key_path or Path.home() / ".ssh" / "id_rsa")
        # Copy so mutating this manager's list never alters the class-wide
        # DEFAULT_TUNNELS (shared by every instance) or the caller's list.
        self.tunnel_configs = list(tunnels or self.DEFAULT_TUNNELS)
        self._active_tunnels: dict[str, tuple[TunnelConfig, subprocess.Popen]] = {}
        # Per-tunnel temporary files receiving ssh's stderr (see _start_tunnel).
        self._stderr_files: dict[str, Any] = {}
        self._current_vm_ip: str | None = None
        self._current_ssh_user: str | None = None
        self._auto_reconnect = auto_reconnect
        self._reconnect_attempts: dict[str, int] = {}  # Track reconnect attempts per tunnel

    def start_tunnels_for_vm(
        self,
        vm_ip: str,
        ssh_user: str = "azureuser",
        tunnels: list[TunnelConfig] | None = None,
    ) -> dict[str, TunnelStatus]:
        """Start SSH tunnels for a VM.

        Also records ``vm_ip``/``ssh_user`` as the current target used by
        auto-reconnect in ``get_tunnel_status``.

        Args:
            vm_ip: IP address of the Azure VM.
            ssh_user: SSH username (default: azureuser).
            tunnels: Optional list of tunnels to start. Defaults to all configured tunnels.

        Returns:
            Dict of tunnel name -> TunnelStatus.
        """
        self._current_vm_ip = vm_ip
        self._current_ssh_user = ssh_user

        tunnels_to_start = tunnels or self.tunnel_configs
        results = {}

        for config in tunnels_to_start:
            status = self._start_tunnel(config, vm_ip, ssh_user)
            results[config.name] = status

        return results

    def _start_tunnel(
        self,
        config: TunnelConfig,
        vm_ip: str,
        ssh_user: str,
    ) -> TunnelStatus:
        """Start a single SSH tunnel.

        ssh's stderr is captured in an anonymous temporary file rather than a
        pipe: a pipe is never drained while the tunnel runs, so once its OS
        buffer filled ssh would block on writing and the tunnel would hang.
        The file is read if ssh exits and closed when the tunnel is released.

        Args:
            config: Tunnel configuration.
            vm_ip: IP address of the Azure VM.
            ssh_user: SSH username.

        Returns:
            TunnelStatus indicating success or failure. An already-running
            tracked tunnel, or a working tunnel owned by another process, is
            reported as active without starting anything.
        """
        import tempfile

        remote_endpoint = f"{vm_ip}:{config.remote_port}"

        # Check if tunnel already active
        tracked = self._active_tunnels.get(config.name)
        if tracked is not None and tracked[1].poll() is None:  # Still running
            logger.debug(f"Tunnel {config.name} already active")
            return TunnelStatus(
                name=config.name,
                active=True,
                local_port=config.local_port,
                remote_endpoint=remote_endpoint,
                pid=tracked[1].pid,
            )

        # Check if local port is already in use
        if self._is_port_in_use(config.local_port):
            # Port in use - if the service answers through it, it is most likely
            # an existing SSH tunnel (e.g. created manually): treat it as active.
            if self._check_tunnel_works(config.local_port, config.remote_port):
                logger.info(f"Port {config.local_port} has existing working tunnel")
                return TunnelStatus(
                    name=config.name,
                    active=True,
                    local_port=config.local_port,
                    remote_endpoint=remote_endpoint,
                    pid=None,  # We don't know the PID of the external tunnel
                )
            logger.warning(f"Port {config.local_port} already in use by unknown process")
            return TunnelStatus(
                name=config.name,
                active=False,
                local_port=config.local_port,
                remote_endpoint=remote_endpoint,
                error=f"Port {config.local_port} in use by another process",
            )

        # Build SSH command with keepalive settings to prevent timeout during long runs
        # ServerAliveInterval=60: Send keepalive every 60 seconds
        # ServerAliveCountMax=10: Disconnect after 10 missed keepalives (10 min tolerance)
        # TCPKeepAlive=yes: Enable TCP-level keepalive as additional safeguard
        # ExitOnForwardFailure=yes: exit (detectable via poll) if the forward can't be set up
        ssh_cmd = [
            "ssh",
            "-o", "StrictHostKeyChecking=no",
            "-o", "UserKnownHostsFile=/dev/null",
            "-o", "LogLevel=ERROR",
            "-o", "ServerAliveInterval=60",
            "-o", "ServerAliveCountMax=10",
            "-o", "TCPKeepAlive=yes",
            "-o", "ExitOnForwardFailure=yes",
            "-i", str(self.ssh_key_path),
            "-N",  # Don't execute remote command
            "-L", f"{config.local_port}:{config.remote_host}:{config.remote_port}",
            f"{ssh_user}@{vm_ip}",
        ]

        stderr_file = None
        try:
            stderr_file = tempfile.TemporaryFile()
            # Start SSH tunnel in background
            proc = subprocess.Popen(
                ssh_cmd,
                stdout=subprocess.DEVNULL,
                stderr=stderr_file,
                start_new_session=True,  # Own session/process group, detached from terminal
            )
        except Exception as e:
            if stderr_file is not None:
                stderr_file.close()
            logger.error(f"Failed to start tunnel {config.name}: {e}")
            return TunnelStatus(
                name=config.name,
                active=False,
                local_port=config.local_port,
                remote_endpoint=remote_endpoint,
                error=str(e)[:200],
            )

        # Wait briefly to check if it started successfully
        time.sleep(0.5)

        if proc.poll() is not None:
            # Process exited, get error
            error_msg = self._read_captured_stderr(stderr_file) or "Unknown error"
            stderr_file.close()
            logger.error(f"Tunnel {config.name} failed: {error_msg}")
            return TunnelStatus(
                name=config.name,
                active=False,
                local_port=config.local_port,
                remote_endpoint=remote_endpoint,
                error=error_msg[:200],
            )

        # Tunnel started successfully. Drop any stale (exited) entry for this
        # name first so its stderr file is closed, then track the new process.
        self._release_tunnel(config.name)
        self._active_tunnels[config.name] = (config, proc)
        self._stderr_files[config.name] = stderr_file
        logger.info(f"Started tunnel {config.name}: localhost:{config.local_port} -> {vm_ip}:{config.remote_port}")

        return TunnelStatus(
            name=config.name,
            active=True,
            local_port=config.local_port,
            remote_endpoint=remote_endpoint,
            pid=proc.pid,
        )

    @staticmethod
    def _read_captured_stderr(stderr_file: Any) -> str:
        """Return the stripped text ssh wrote to *stderr_file*, or "" if unreadable.

        Only call this once the ssh process has exited: the file offset is
        shared with the child, so rewinding while ssh still runs would make
        its later writes overwrite earlier output.
        """
        try:
            stderr_file.seek(0)
            return stderr_file.read().decode(errors="replace").strip()
        except (OSError, ValueError):
            return ""

    def _release_tunnel(self, name: str) -> None:
        """Stop tracking tunnel *name* and close its captured-stderr file.

        Does not signal the ssh process; callers stop or reap it first.
        """
        self._active_tunnels.pop(name, None)
        stderr_file = self._stderr_files.pop(name, None)
        if stderr_file is not None:
            stderr_file.close()

    def stop_tunnel(self, name: str) -> bool:
        """Stop a specific tunnel by name.

        Sends SIGTERM to the ssh process group, escalating to SIGKILL after
        5 seconds. An explicit stop also clears the tunnel's reconnect counter
        so ``get_tunnel_status`` does not try to bring it back.

        Args:
            name: Tunnel name (e.g., "vnc", "waa").

        Returns:
            True if tunnel was stopped, False if not found.
        """
        self._reconnect_attempts.pop(name, None)
        tracked = self._active_tunnels.get(name)
        if tracked is None:
            return False

        proc = tracked[1]
        # Only signal a process that is still running: once poll() has reaped
        # an exited ssh, its PID may be reused by an unrelated process whose
        # group must not be killed.
        if proc.poll() is None:
            try:
                # Send SIGTERM to gracefully stop
                os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
                proc.wait(timeout=5)
            except ProcessLookupError:
                pass  # Already dead
            except subprocess.TimeoutExpired:
                # Force kill, then reap so no zombie is left behind
                try:
                    os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
                    proc.wait(timeout=5)
                except (ProcessLookupError, subprocess.TimeoutExpired):
                    pass

        self._release_tunnel(name)
        logger.info(f"Stopped tunnel {name}")
        return True

    def stop_all_tunnels(self) -> None:
        """Stop all active tunnels and forget the current VM target."""
        for name in list(self._active_tunnels.keys()):
            self.stop_tunnel(name)
        self._current_vm_ip = None
        self._current_ssh_user = None

    def _external_status(self, config: TunnelConfig) -> TunnelStatus | None:
        """Return an active status if a working tunnel we don't own serves config's port.

        Returns None when the port is free or does not answer like the service.
        """
        if not (
            self._is_port_in_use(config.local_port)
            and self._check_tunnel_works(config.local_port, config.remote_port)
        ):
            return None
        return TunnelStatus(
            name=config.name,
            active=True,
            local_port=config.local_port,
            remote_endpoint=f"{self._current_vm_ip}:{config.remote_port}" if self._current_vm_ip else "external",
            pid=None,  # External tunnel, PID unknown
        )

    def get_tunnel_status(self, auto_restart: bool = True) -> dict[str, TunnelStatus]:
        """Get status of all configured tunnels.

        This method checks the actual port status, not just internal state.
        This correctly reports tunnels as active even if they were started
        by a different process or if the tunnel manager was restarted.

        If auto_reconnect is enabled, a VM target is set, and a tunnel is
        found dead, this method restarts it (sleeping RECONNECT_DELAY_SECONDS
        before each attempt). A restart that fails is retried on later calls
        until MAX_RECONNECT_ATTEMPTS consecutive attempts have been made.

        Args:
            auto_restart: If True and auto_reconnect is enabled, restart dead tunnels.

        Returns:
            Dict of tunnel name -> TunnelStatus.
        """
        results: dict[str, TunnelStatus] = {}
        tunnels_to_restart: list[TunnelConfig] = []
        can_restart = bool(self._auto_reconnect and auto_restart and self._current_vm_ip)

        for config in self.tunnel_configs:
            tracked = self._active_tunnels.get(config.name)
            if tracked is not None and tracked[1].poll() is None:  # Still running
                # Reset reconnect attempts on successful check
                self._reconnect_attempts[config.name] = 0
                results[config.name] = TunnelStatus(
                    name=config.name,
                    active=True,
                    local_port=config.local_port,
                    remote_endpoint=f"{self._current_vm_ip}:{config.remote_port}" if self._current_vm_ip else "unknown",
                    pid=tracked[1].pid,
                )
                continue

            if tracked is not None:
                # Our ssh process died: log why, then forget it and its stderr file.
                stderr_file = self._stderr_files.get(config.name)
                detail = self._read_captured_stderr(stderr_file) if stderr_file is not None else ""
                logger.warning(f"Tunnel {config.name} process exited: {(detail or 'no error output')[:200]}")
                self._release_tunnel(config.name)

            # Tracked or not, another tunnel may be serving the port
            # (started manually, by another process, or before a manager restart).
            external = self._external_status(config)
            if external is not None:
                if tracked is None:
                    logger.debug(f"Found working external tunnel on port {config.local_port}")
                results[config.name] = external
                continue

            attempts = self._reconnect_attempts.get(config.name, 0)
            if tracked is not None:
                # Tunnel is dead - mark for restart if auto_reconnect enabled
                if can_restart:
                    tunnels_to_restart.append(config)
                results[config.name] = TunnelStatus(
                    name=config.name,
                    active=False,
                    local_port=config.local_port,
                    remote_endpoint="",
                    error="Tunnel process exited",
                )
            elif can_restart and 0 < attempts < self.MAX_RECONNECT_ATTEMPTS:
                # A previous auto-reconnect failed before the tunnel came up, so
                # it was never tracked; keep retrying until the limit is reached.
                tunnels_to_restart.append(config)
                results[config.name] = TunnelStatus(
                    name=config.name,
                    active=False,
                    local_port=config.local_port,
                    remote_endpoint="",
                    error="Tunnel process exited",
                )
            elif can_restart and attempts >= self.MAX_RECONNECT_ATTEMPTS:
                results[config.name] = TunnelStatus(
                    name=config.name,
                    active=False,
                    local_port=config.local_port,
                    remote_endpoint="",
                    error=f"Max reconnect attempts ({self.MAX_RECONNECT_ATTEMPTS}) exceeded",
                )
            else:
                results[config.name] = TunnelStatus(
                    name=config.name,
                    active=False,
                    local_port=config.local_port,
                    remote_endpoint="",
                )

        # Auto-restart dead tunnels
        for config in tunnels_to_restart:
            attempts = self._reconnect_attempts.get(config.name, 0)
            if attempts < self.MAX_RECONNECT_ATTEMPTS:
                logger.info(f"Auto-reconnecting tunnel {config.name} (attempt {attempts + 1}/{self.MAX_RECONNECT_ATTEMPTS})")
                time.sleep(self.RECONNECT_DELAY_SECONDS)
                self._reconnect_attempts[config.name] = attempts + 1
                status = self._start_tunnel(config, self._current_vm_ip, self._current_ssh_user or "azureuser")
                results[config.name] = status
                if status.active:
                    logger.info(f"Successfully reconnected tunnel {config.name}")
                    self._reconnect_attempts[config.name] = 0  # Reset on success
            else:
                logger.warning(f"Tunnel {config.name} exceeded max reconnect attempts ({self.MAX_RECONNECT_ATTEMPTS})")
                results[config.name] = TunnelStatus(
                    name=config.name,
                    active=False,
                    local_port=config.local_port,
                    remote_endpoint="",
                    error=f"Max reconnect attempts ({self.MAX_RECONNECT_ATTEMPTS}) exceeded",
                )

        return results

    def is_tunnel_active(self, name: str) -> bool:
        """Check if a specific tunnel is active.

        Delegates to ``get_tunnel_status()``, so it may auto-restart dead
        tunnels as a side effect.

        Args:
            name: Tunnel name.

        Returns:
            True if tunnel is active.
        """
        status = self.get_tunnel_status()
        return name in status and status[name].active

    def reset_reconnect_attempts(self, name: str | None = None) -> None:
        """Reset reconnect attempt counter for tunnels.

        Call this after manually fixing connectivity issues or when
        VM is known to be healthy again.

        Args:
            name: Tunnel name to reset, or None to reset all.
        """
        if name is not None:
            self._reconnect_attempts[name] = 0
        else:
            self._reconnect_attempts.clear()
        logger.info(f"Reset reconnect attempts for {name if name is not None else 'all tunnels'}")

    def ensure_tunnels_for_vm(
        self,
        vm_ip: str,
        ssh_user: str = "azureuser",
    ) -> dict[str, TunnelStatus]:
        """Ensure tunnels are running for a VM, starting if needed.

        This is idempotent - safe to call repeatedly.

        Args:
            vm_ip: IP address of the Azure VM.
            ssh_user: SSH username.

        Returns:
            Dict of tunnel name -> TunnelStatus.
        """
        # If VM changed, stop old tunnels and reset reconnect attempts
        if self._current_vm_ip and self._current_vm_ip != vm_ip:
            logger.info(f"VM IP changed from {self._current_vm_ip} to {vm_ip}, restarting tunnels")
            self.stop_all_tunnels()
            self.reset_reconnect_attempts()  # Fresh start for new VM

        # Check current status and start any missing tunnels
        # get_tunnel_status will auto-restart dead tunnels if enabled
        current_status = self.get_tunnel_status()
        all_active = all(s.active for s in current_status.values())

        if all_active and self._current_vm_ip == vm_ip:
            return current_status

        # Start tunnels (also resets reconnect attempts for this VM)
        self.reset_reconnect_attempts()
        return self.start_tunnels_for_vm(vm_ip, ssh_user)

    def _is_port_in_use(self, port: int) -> bool:
        """Check if a local port is in use.

        Args:
            port: Port number.

        Returns:
            True if binding an IPv4 socket to ("localhost", port) fails.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(("localhost", port))
                return False
            except OSError:
                return True

    def _check_tunnel_works(self, local_port: int, remote_port: int) -> bool:
        """Check if an existing tunnel on a port is actually working.

        For VNC (8006), check if we get HTTP response from noVNC.
        For WAA (5000), check if /probe endpoint responds.
        Any other port only needs to accept a TCP connection.

        Args:
            local_port: Local port to check.
            remote_port: Remote port (used to determine service type).

        Returns:
            True if tunnel appears to be working.
        """
        import http.client
        import urllib.error
        import urllib.request

        # HTTP probe path per known service; other ports get a plain TCP connect.
        probe_paths = {5000: "/probe", 8006: "/"}

        try:
            path = probe_paths.get(remote_port)
            if path is None:
                # Unknown service - try to connect
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.settimeout(3)
                    s.connect(("localhost", local_port))
                    return True

            # An empty ProxyHandler disables HTTP(S)_PROXY env vars: the probe must
            # reach the local tunnel directly, never a proxy.
            opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))
            req = urllib.request.Request(
                f"http://localhost:{local_port}{path}",
                method="GET",
            )
            with opener.open(req, timeout=3) as resp:
                return resp.status == 200
        except (urllib.error.URLError, http.client.HTTPException, OSError):
            # HTTPException covers non-HTTP listeners (e.g. BadStatusLine).
            return False

    def __del__(self):
        """Clean up tunnels on destruction."""
        try:
            self.stop_all_tunnels()
        except Exception:
            pass
|
|
538
|
+
|
|
539
|
+
|
|
540
|
+
# Process-wide tunnel manager, created lazily by get_tunnel_manager()
_tunnel_manager: SSHTunnelManager | None = None


def get_tunnel_manager() -> SSHTunnelManager:
    """Return the shared SSHTunnelManager, creating it on first use.

    The instance is constructed with SSHTunnelManager's default arguments
    and cached in the module-level ``_tunnel_manager``.

    Returns:
        SSHTunnelManager instance.
    """
    global _tunnel_manager
    if _tunnel_manager is not None:
        return _tunnel_manager
    _tunnel_manager = SSHTunnelManager()
    return _tunnel_manager
|