experimaestro-2.0.0b4-py3-none-any.whl → experimaestro-2.0.0b8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of experimaestro might be problematic.
- experimaestro/cli/__init__.py +177 -31
- experimaestro/experiments/cli.py +6 -2
- experimaestro/scheduler/base.py +21 -0
- experimaestro/scheduler/experiment.py +64 -34
- experimaestro/scheduler/interfaces.py +27 -0
- experimaestro/scheduler/remote/__init__.py +31 -0
- experimaestro/scheduler/remote/client.py +874 -0
- experimaestro/scheduler/remote/protocol.py +467 -0
- experimaestro/scheduler/remote/server.py +423 -0
- experimaestro/scheduler/remote/sync.py +144 -0
- experimaestro/scheduler/services.py +158 -32
- experimaestro/scheduler/state_db.py +58 -9
- experimaestro/scheduler/state_provider.py +512 -91
- experimaestro/scheduler/state_sync.py +65 -8
- experimaestro/tests/test_cli_jobs.py +3 -3
- experimaestro/tests/test_remote_state.py +671 -0
- experimaestro/tests/test_state_db.py +8 -8
- experimaestro/tui/app.py +100 -8
- experimaestro/version.py +2 -2
- {experimaestro-2.0.0b4.dist-info → experimaestro-2.0.0b8.dist-info}/METADATA +4 -4
- {experimaestro-2.0.0b4.dist-info → experimaestro-2.0.0b8.dist-info}/RECORD +24 -18
- {experimaestro-2.0.0b4.dist-info → experimaestro-2.0.0b8.dist-info}/WHEEL +0 -0
- {experimaestro-2.0.0b4.dist-info → experimaestro-2.0.0b8.dist-info}/entry_points.txt +0 -0
- {experimaestro-2.0.0b4.dist-info → experimaestro-2.0.0b8.dist-info}/licenses/LICENSE +0 -0
experimaestro/scheduler/remote/client.py
@@ -0,0 +1,874 @@
+"""SSH State Provider Client
+
+Client that connects via SSH to a remote SSHStateProviderServer and implements
+the StateProvider-like interface for local TUI/web UI usage.
+
+Usage:
+    client = SSHStateProviderClient(host="server", remote_workspace="/path/to/workspace")
+    client.connect()
+    experiments = client.get_experiments()
+"""
+
+import atexit
+import logging
+import shutil
+import subprocess
+import tempfile
+import threading
+from concurrent.futures import Future, TimeoutError as FutureTimeoutError
+from datetime import datetime
+from importlib.metadata import version as get_package_version
+from pathlib import Path
+from typing import TYPE_CHECKING, Dict, List, Optional, Set
+
+from experimaestro.scheduler.state_provider import (
+    StateProvider,
+    StateEvent,
+    StateEventType,
+    StateListener,
+    MockJob,
+    MockExperiment,
+    MockService,
+)
+from experimaestro.scheduler.interfaces import (
+    BaseJob,
+    BaseExperiment,
+    BaseService,
+)
+from experimaestro.scheduler.remote.protocol import (
+    RPCMethod,
+    NotificationMethod,
+    RPCResponse,
+    RPCNotification,
+    parse_message,
+    create_request,
+    serialize_datetime,
+)
+
+if TYPE_CHECKING:
+    from experimaestro.scheduler.remote.sync import RemoteFileSynchronizer
+
+
+logger = logging.getLogger("xpm.remote.client")
+
+# Default timeout for RPC requests (seconds)
+DEFAULT_TIMEOUT = 30.0
+
+
+def _strip_dev_version(version: str) -> str:
+    """Strip the .devN suffix from a version string.
+
+    Examples:
+        '2.0.0b3.dev8' -> '2.0.0b3'
+        '1.2.3.dev1' -> '1.2.3'
+        '1.2.3' -> '1.2.3' (unchanged)
+    """
+    import re
+
+    return re.sub(r"\.dev\d+$", "", version)
+
+
+class SSHStateProviderClient(StateProvider):
+    """Client that connects to SSHStateProviderServer via SSH
+
+    This client implements the StateProvider interface for remote experiment
+    monitoring via SSH.
+
+    Features:
+    - JSON-RPC over SSH stdin/stdout
+    - Async request/response handling with futures
+    - Server push notifications converted to StateEvents
+    - On-demand rsync for specific paths (used by services like TensorboardService)
+    """
+
+    def __init__(
+        self,
+        host: str,
+        remote_workspace: str,
+        ssh_options: Optional[List[str]] = None,
+        remote_xpm_path: Optional[str] = None,
+    ):
+        """Initialize the client
+
+        Args:
+            host: SSH host (user@host or host)
+            remote_workspace: Path to workspace on the remote host
+            ssh_options: Additional SSH options (e.g., ["-p", "2222"])
+            remote_xpm_path: Path to experimaestro executable on remote host.
+                If None, uses 'uv tool run experimaestro==<version>'.
+        """
+        self.host = host
+        self.remote_workspace = remote_workspace
+        self.ssh_options = ssh_options or []
+        self.remote_xpm_path = remote_xpm_path
+
+        # Session-specific temporary cache directory (created on connect)
+        self._temp_dir: Optional[str] = None
+        self.local_cache_dir: Optional[Path] = None
+        self.workspace_path: Optional[Path] = None  # For compatibility
+
+        self._process: Optional[subprocess.Popen] = None
+        self._stdin = None
+        self._stdout = None
+        self._stderr = None
+
+        self._listeners: Set[StateListener] = set()
+        self._listener_lock = threading.Lock()
+
+        self._response_handlers: Dict[int, Future] = {}
+        self._response_lock = threading.Lock()
+        self._request_id = 0
+
+        self._read_thread: Optional[threading.Thread] = None
+        self._notify_thread: Optional[threading.Thread] = None
+        self._running = False
+        self._connected = False
+
+        self._synchronizer: Optional["RemoteFileSynchronizer"] = None
+
+        # Throttled notification delivery to avoid flooding UI
+        self._pending_events: List[StateEvent] = []
+        self._pending_events_lock = threading.Lock()
+        self._notify_interval = 2.0  # Seconds between notification batches
+
+        # Service cache (from base class)
+        self._init_service_cache()
+
+    def connect(self, timeout: float = 30.0):
+        """Establish SSH connection and start remote server
+
+        Args:
+            timeout: Connection timeout in seconds
+        """
+        if self._connected:
+            logger.warning("Already connected")
+            return
+
+        # Create session-specific temporary cache directory
+        self._temp_dir = tempfile.mkdtemp(prefix="xpm_remote_")
+        self.local_cache_dir = Path(self._temp_dir)
+        self.workspace_path = self.local_cache_dir
+        logger.debug("Created temporary cache directory: %s", self._temp_dir)
+
+        # Register cleanup on exit (in case disconnect isn't called)
+        atexit.register(self._cleanup_temp_dir)
+
+        # Build SSH command
+        cmd = ["ssh"]
+        cmd.extend(self.ssh_options)
+        cmd.append(self.host)
+
+        # Build remote command (workdir is passed to experiments group)
+        if self.remote_xpm_path:
+            # Use specified path to experimaestro
+            remote_cmd = (
+                f"{self.remote_xpm_path} experiments "
+                f"--workdir {self.remote_workspace} monitor-server"
+            )
+        else:
+            # Use uv tool run with version pinning
+            try:
+                xpm_version = get_package_version("experimaestro")
+                # Strip .devN suffix for release compatibility
+                xpm_version = _strip_dev_version(xpm_version)
+            except Exception:
+                xpm_version = None
+
+            if xpm_version:
+                remote_cmd = (
+                    f"uv tool run experimaestro=={xpm_version} experiments "
+                    f"--workdir {self.remote_workspace} monitor-server"
+                )
+            else:
+                remote_cmd = (
+                    f"uv tool run experimaestro experiments "
+                    f"--workdir {self.remote_workspace} monitor-server"
+                )
+        cmd.append(remote_cmd)
+
+        logger.info("Connecting to %s, workspace: %s", self.host, self.remote_workspace)
+        logger.debug("SSH command: %s", " ".join(cmd))
+
+        try:
+            self._process = subprocess.Popen(
+                cmd,
+                stdin=subprocess.PIPE,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+                bufsize=0,  # Unbuffered
+            )
+            self._stdin = self._process.stdin
+            self._stdout = self._process.stdout
+            self._stderr = self._process.stderr
+        except Exception as e:
+            logger.error("Failed to start SSH process: %s", e)
+            raise ConnectionError(f"Failed to connect to {self.host}: {e}")
+
+        self._running = True
+
+        # Start read thread for responses and notifications
+        self._read_thread = threading.Thread(
+            target=self._read_loop, daemon=True, name="SSHClient-Read"
+        )
+        self._read_thread.start()
+
+        # Start notification thread for throttled event delivery
+        self._notify_thread = threading.Thread(
+            target=self._notify_loop, daemon=True, name="SSHClient-Notify"
+        )
+        self._notify_thread.start()
+
+        # Wait for connection to be established by sending a test request
+        try:
+            sync_info = self._call_sync(RPCMethod.GET_SYNC_INFO, {}, timeout=timeout)
+            logger.info(
+                "Connected to remote workspace: %s", sync_info.get("workspace_path")
+            )
+        except Exception as e:
+            self.disconnect()
+            raise ConnectionError(f"Failed to establish connection: {e}")
+
+        self._connected = True
+
+    def disconnect(self):
+        """Disconnect from the remote server"""
+        self._running = False
+        self._connected = False
+
+        # Close stdin to signal EOF to remote server
+        if self._stdin:
+            try:
+                self._stdin.close()
+            except Exception:
+                pass
+
+        # Terminate the SSH process
+        if self._process:
+            try:
+                self._process.terminate()
+                self._process.wait(timeout=5.0)
+            except Exception:
+                try:
+                    self._process.kill()
+                except Exception:
+                    pass
+
+        # Wait for threads to finish
+        if self._read_thread and self._read_thread.is_alive():
+            self._read_thread.join(timeout=2.0)
+        if self._notify_thread and self._notify_thread.is_alive():
+            self._notify_thread.join(timeout=2.0)
+
+        # Cancel any pending requests
+        with self._response_lock:
+            for future in self._response_handlers.values():
+                if not future.done():
+                    future.set_exception(ConnectionError("Disconnected"))
+            self._response_handlers.clear()
+
+        # Clear service cache (using base class method)
+        self._clear_service_cache()
+
+        # Clean up temporary cache directory
+        self._cleanup_temp_dir()
+
+        logger.info("Disconnected from %s", self.host)
+
+    def _cleanup_temp_dir(self):
+        """Clean up the temporary cache directory"""
+        if self._temp_dir and Path(self._temp_dir).exists():
+            try:
+                shutil.rmtree(self._temp_dir)
+                logger.debug("Cleaned up temporary cache directory: %s", self._temp_dir)
+            except Exception as e:
+                logger.warning("Failed to clean up temp dir %s: %s", self._temp_dir, e)
+            finally:
+                self._temp_dir = None
+                self.local_cache_dir = None
+        # Unregister atexit handler if we cleaned up successfully
+        try:
+            atexit.unregister(self._cleanup_temp_dir)
+        except Exception:
+            pass
+
+    def close(self):
+        """Alias for disconnect() for compatibility with WorkspaceStateProvider"""
+        self.disconnect()
+
+    def _read_loop(self):
+        """Read responses and notifications from SSH stdout"""
+        while self._running:
+            try:
+                line = self._stdout.readline()
+                if not line:
+                    # EOF - connection closed
+                    logger.debug("SSH stdout closed")
+                    break
+
+                line_str = line.decode("utf-8").strip()
+                if not line_str:
+                    continue
+
+                self._process_message(line_str)
+
+            except Exception as e:
+                if self._running:
+                    logger.exception("Error in read loop: %s", e)
+                break
+
+        # Connection lost
+        if self._running:
+            logger.warning("Connection to %s lost", self.host)
+            self._connected = False
+
+    def _process_message(self, line: str):
+        """Process a single message from the server"""
+        try:
+            msg = parse_message(line)
+        except ValueError as e:
+            logger.warning("Failed to parse message: %s", e)
+            return
+
+        if isinstance(msg, RPCResponse):
+            self._handle_response(msg)
+        elif isinstance(msg, RPCNotification):
+            self._handle_notification(msg)
+        else:
+            logger.debug("Unexpected message type: %s", type(msg).__name__)
+
+    def _handle_response(self, response: RPCResponse):
+        """Handle a response from the server"""
+        with self._response_lock:
+            future = self._response_handlers.pop(response.id, None)
+
+        if future is None:
+            logger.warning("Received response for unknown request ID: %s", response.id)
+            return
+
+        if response.error:
+            future.set_exception(
+                RuntimeError(
+                    f"RPC error {response.error.code}: {response.error.message}"
+                )
+            )
+        else:
+            future.set_result(response.result)
+
+    def _handle_notification(self, notification: RPCNotification):
+        """Handle a notification from the server
+
+        Queues events for throttled delivery to avoid flooding the UI.
+        """
+        method = notification.method
+        params = notification.params
+
+        logger.debug("Received notification: %s", method)
+
+        # Convert notification to StateEvent and queue for throttled delivery
+        event = self._notification_to_event(method, params)
+        if event:
+            with self._pending_events_lock:
+                self._pending_events.append(event)
+
+        # Handle shutdown notification immediately
+        if method == NotificationMethod.SHUTDOWN.value:
+            reason = params.get("reason", "unknown")
+            logger.info("Server shutdown: %s", reason)
+            self._connected = False
+
+    def _notify_loop(self):
+        """Background thread that delivers pending events to listeners periodically
+
+        This throttles notification delivery to avoid flooding the UI with
+        rapid state changes.
+        """
+        import time
+
+        while self._running:
+            time.sleep(self._notify_interval)
+
+            if not self._running:
+                break
+
+            # Get and clear pending events atomically
+            with self._pending_events_lock:
+                if not self._pending_events:
+                    continue
+                events = self._pending_events.copy()
+                self._pending_events.clear()
+
+            # Deduplicate events by type (keep latest of each type)
+            # This prevents redundant refreshes for rapidly changing state
+            seen_types = set()
+            unique_events = []
+            for event in reversed(events):
+                if event.event_type not in seen_types:
+                    seen_types.add(event.event_type)
+                    unique_events.append(event)
+            unique_events.reverse()
+
+            # Notify listeners
+            for event in unique_events:
+                self._notify_listeners(event)
+
+    def _notification_to_event(self, method: str, params: Dict) -> Optional[StateEvent]:
+        """Convert a notification to a StateEvent"""
+        if method == NotificationMethod.EXPERIMENT_UPDATED.value:
+            return StateEvent(
+                event_type=StateEventType.EXPERIMENT_UPDATED,
+                data=params.get("data", params),
+            )
+        elif method == NotificationMethod.RUN_UPDATED.value:
+            return StateEvent(
+                event_type=StateEventType.RUN_UPDATED,
+                data=params.get("data", params),
+            )
+        elif method == NotificationMethod.JOB_UPDATED.value:
+            return StateEvent(
+                event_type=StateEventType.JOB_UPDATED,
+                data=params.get("data", params),
+            )
+        elif method == NotificationMethod.SERVICE_UPDATED.value:
+            return StateEvent(
+                event_type=StateEventType.SERVICE_UPDATED,
+                data=params.get("data", params),
+            )
+        return None
+
+    def _notify_listeners(self, event: StateEvent):
+        """Notify all registered listeners of a state event"""
+        with self._listener_lock:
+            listeners = list(self._listeners)
+
+        for listener in listeners:
+            try:
+                listener(event)
+            except Exception as e:
+                logger.exception("Error in listener: %s", e)
+
+    def _call(self, method: RPCMethod, params: Dict) -> Future:
+        """Send an RPC request and return a Future for the response
+
+        Args:
+            method: RPC method to call
+            params: Method parameters
+
+        Returns:
+            Future that resolves to the response result
+        """
+        if not self._running:
+            future = Future()
+            future.set_exception(ConnectionError("Not connected"))
+            return future
+
+        with self._response_lock:
+            self._request_id += 1
+            request_id = self._request_id
+            future = Future()
+            self._response_handlers[request_id] = future
+
+        request_json = create_request(method, params, request_id)
+        try:
+            self._stdin.write((request_json + "\n").encode("utf-8"))
+            self._stdin.flush()
+        except Exception as e:
+            with self._response_lock:
+                self._response_handlers.pop(request_id, None)
+            future.set_exception(e)
+
+        return future
+
+    def _call_sync(
+        self, method: RPCMethod, params: Dict, timeout: float = DEFAULT_TIMEOUT
+    ):
+        """Send an RPC request and wait for the response
+
+        Args:
+            method: RPC method to call
+            params: Method parameters
+            timeout: Request timeout in seconds
+
+        Returns:
+            Response result
+
+        Raises:
+            TimeoutError: If the request times out
+            RuntimeError: If the RPC call returns an error
+        """
+        future = self._call(method, params)
+        try:
+            return future.result(timeout=timeout)
+        except FutureTimeoutError:
+            raise TimeoutError(f"Request {method.value} timed out after {timeout}s")
+
+    # -------------------------------------------------------------------------
+    # StateProvider-like Interface
+    # -------------------------------------------------------------------------
+
+    def add_listener(self, listener: StateListener):
+        """Register a listener for state change events"""
+        with self._listener_lock:
+            self._listeners.add(listener)
+
+    def remove_listener(self, listener: StateListener):
+        """Unregister a listener"""
+        with self._listener_lock:
+            self._listeners.discard(listener)
+
+    def get_experiments(self, since: Optional[datetime] = None) -> List[BaseExperiment]:
+        """Get list of all experiments"""
+        params = {"since": serialize_datetime(since)}
+        result = self._call_sync(RPCMethod.GET_EXPERIMENTS, params)
+        return [self._dict_to_experiment(d) for d in result]
+
+    def get_experiment(self, experiment_id: str) -> Optional[BaseExperiment]:
+        """Get a specific experiment by ID"""
+        params = {"experiment_id": experiment_id}
+        result = self._call_sync(RPCMethod.GET_EXPERIMENT, params)
+        if result is None:
+            return None
+        return self._dict_to_experiment(result)
+
+    def get_experiment_runs(self, experiment_id: str) -> List[Dict]:
+        """Get all runs for an experiment"""
+        params = {"experiment_id": experiment_id}
+        return self._call_sync(RPCMethod.GET_EXPERIMENT_RUNS, params)
+
+    def get_current_run(self, experiment_id: str) -> Optional[str]:
+        """Get the current run ID for an experiment"""
+        exp = self.get_experiment(experiment_id)
+        if exp is None:
+            return None
+        return exp.current_run_id
+
+    def get_jobs(
+        self,
+        experiment_id: Optional[str] = None,
+        run_id: Optional[str] = None,
+        task_id: Optional[str] = None,
+        state: Optional[str] = None,
+        tags: Optional[Dict[str, str]] = None,
+        since: Optional[datetime] = None,
+    ) -> List[BaseJob]:
+        """Query jobs with optional filters"""
+        params = {
+            "experiment_id": experiment_id,
+            "run_id": run_id,
+            "task_id": task_id,
+            "state": state,
+            "tags": tags,
+            "since": serialize_datetime(since),
+        }
+        result = self._call_sync(RPCMethod.GET_JOBS, params)
+        return [self._dict_to_job(d) for d in result]
+
+    def get_job(
+        self, job_id: str, experiment_id: str, run_id: Optional[str] = None
+    ) -> Optional[BaseJob]:
+        """Get a specific job"""
+        params = {
+            "job_id": job_id,
+            "experiment_id": experiment_id,
+            "run_id": run_id,
+        }
+        result = self._call_sync(RPCMethod.GET_JOB, params)
+        if result is None:
+            return None
+        return self._dict_to_job(result)
+
+    def get_all_jobs(
+        self,
+        state: Optional[str] = None,
+        tags: Optional[Dict[str, str]] = None,
+        since: Optional[datetime] = None,
+    ) -> List[BaseJob]:
+        """Get all jobs across all experiments"""
+        params = {
+            "state": state,
+            "tags": tags,
+            "since": serialize_datetime(since),
+        }
+        result = self._call_sync(RPCMethod.GET_ALL_JOBS, params)
+        return [self._dict_to_job(d) for d in result]
+
+    def _fetch_services_from_storage(
+        self, experiment_id: Optional[str], run_id: Optional[str]
+    ) -> List[BaseService]:
+        """Fetch services from remote server.
+
+        Called by base class get_services when cache is empty.
+        """
+        params = {
+            "experiment_id": experiment_id,
+            "run_id": run_id,
+        }
+        result = self._call_sync(RPCMethod.GET_SERVICES, params)
+
+        services = []
+        for d in result:
+            service = self._dict_to_service(d)
+            services.append(service)
+
+        return services
+
+    def get_services_raw(
+        self, experiment_id: Optional[str] = None, run_id: Optional[str] = None
+    ) -> List[Dict]:
+        """Get raw service data as dictionaries"""
+        params = {
+            "experiment_id": experiment_id,
+            "run_id": run_id,
+        }
+        return self._call_sync(RPCMethod.GET_SERVICES, params)
+
+    def kill_job(self, job: BaseJob, perform: bool = False) -> bool:
+        """Kill a running job"""
+        if not perform:
+            # Dry run - just check if job is running
+            return job.state.running()
+
+        params = {
+            "job_id": job.identifier,
+            "experiment_id": getattr(job, "experiment_id", ""),
+            "run_id": getattr(job, "run_id", ""),
+        }
+        result = self._call_sync(RPCMethod.KILL_JOB, params)
+        return result.get("success", False)
+
+    def clean_job(self, job: BaseJob, perform: bool = False) -> bool:
+        """Clean a finished job"""
+        if not perform:
+            # Dry run - just check if job is finished
+            return job.state.finished()
+
+        params = {
+            "job_id": job.identifier,
+            "experiment_id": getattr(job, "experiment_id", ""),
+            "run_id": getattr(job, "run_id", ""),
+        }
+        result = self._call_sync(RPCMethod.CLEAN_JOB, params)
+        return result.get("success", False)
+
+    # -------------------------------------------------------------------------
+    # Data Conversion
+    # -------------------------------------------------------------------------
+
+    def _dict_to_job(self, d: Dict) -> MockJob:
+        """Convert a dictionary to a MockJob"""
+        state_str = d.get("state", "waiting")
+
+        # Map local cache path for the job
+        path = None
+        if d.get("path"):
+            # The path from remote is absolute on remote system
+            # We map it to local cache
+            remote_path = d["path"]
+            if remote_path.startswith(self.remote_workspace):
+                relative = remote_path[len(self.remote_workspace) :].lstrip("/")
+                path = self.local_cache_dir / relative
+            else:
+                path = Path(remote_path)
+
+        return MockJob(
+            identifier=d["identifier"],
+            task_id=d["task_id"],
+            locator=d["locator"],
+            path=path,
+            state=state_str,
+            submittime=self._parse_datetime_to_timestamp(d.get("submittime")),
+            starttime=self._parse_datetime_to_timestamp(d.get("starttime")),
+            endtime=self._parse_datetime_to_timestamp(d.get("endtime")),
+            progress=d.get("progress", []),
+            tags=d.get("tags", {}),
+            experiment_id=d.get("experiment_id", ""),
+            run_id=d.get("run_id", ""),
+            updated_at="",
+        )
+
+    def _dict_to_experiment(self, d: Dict) -> MockExperiment:
+        """Convert a dictionary to a MockExperiment"""
+        # Map local cache path for the experiment
+        workdir = None
+        if d.get("workdir"):
+            remote_path = d["workdir"]
+            if remote_path.startswith(self.remote_workspace):
+                relative = remote_path[len(self.remote_workspace) :].lstrip("/")
+                workdir = self.local_cache_dir / relative
+            else:
+                workdir = Path(remote_path)
+
+        # Convert ISO datetime strings to Unix timestamps for TUI compatibility
+        started_at = self._parse_datetime_to_timestamp(d.get("started_at"))
+        ended_at = self._parse_datetime_to_timestamp(d.get("ended_at"))
+
+        return MockExperiment(
+            workdir=workdir or self.local_cache_dir / "xp" / d["experiment_id"],
+            current_run_id=d.get("current_run_id"),
+            total_jobs=d.get("total_jobs", 0),
+            finished_jobs=d.get("finished_jobs", 0),
+            failed_jobs=d.get("failed_jobs", 0),
+            updated_at=d.get("updated_at", ""),
+            started_at=started_at,
+            ended_at=ended_at,
+            hostname=d.get("hostname"),
+        )
+
+    def _dict_to_service(self, d: Dict) -> BaseService:
+        """Convert a dictionary to a Service or MockService
+
+        Tries to recreate the actual Service from state_dict first.
+        Falls back to MockService with error message if module is missing.
+        """
+        state_dict = d.get("state_dict", {})
+        service_id = d.get("service_id", "")
+
+        # Check for unserializable marker
+        if state_dict.get("__unserializable__"):
+            reason = state_dict.get("__reason__", "Service cannot be recreated")
+            return MockService(
+                service_id=service_id,
+                description_text=f"[{reason}]",
+                state_dict_data=state_dict,
+                experiment_id=d.get("experiment_id"),
+                run_id=d.get("run_id"),
+                url=d.get("url"),
+            )
+
+        # Try to recreate actual Service from state_dict
+        if state_dict and "__class__" in state_dict:
+            try:
+                from experimaestro.scheduler.services import Service
+
+                # Create path translator that syncs and translates paths
+                def path_translator(remote_path: str) -> Path:
+                    """Translate remote path to local, syncing if needed"""
+                    local_path = self.sync_path(remote_path)
+                    if local_path:
+                        return local_path
+                    # Fallback: map to local cache without sync
+                    if remote_path.startswith(self.remote_workspace):
+                        relative = remote_path[len(self.remote_workspace) :].lstrip("/")
+                        return self.local_cache_dir / relative
+                    return Path(remote_path)
+
+                service = Service.from_state_dict(state_dict, path_translator)
+                service.id = service_id
+                # Copy additional attributes
+                if d.get("experiment_id"):
+                    service.experiment_id = d["experiment_id"]
+                if d.get("run_id"):
+                    service.run_id = d["run_id"]
+                return service
+            except ModuleNotFoundError as e:
+                # Module not available locally - show error in description
+                missing_module = str(e).replace("No module named ", "").strip("'\"")
+                return MockService(
+                    service_id=service_id,
+                    description_text=f"[Missing module: {missing_module}]",
+                    state_dict_data=state_dict,
+                    experiment_id=d.get("experiment_id"),
+                    run_id=d.get("run_id"),
+                    url=d.get("url"),
+                )
+            except Exception as e:
+                # Other error - show in description
+                return MockService(
+                    service_id=service_id,
+                    description_text=f"[Error: {e}]",
+                    state_dict_data=state_dict,
+                    experiment_id=d.get("experiment_id"),
+                    run_id=d.get("run_id"),
+                    url=d.get("url"),
+                )
+
+        # No state_dict or no __class__ - use MockService with original description
+        return MockService(
+            service_id=service_id,
+            description_text=d.get("description", ""),
+            state_dict_data=state_dict,
+            experiment_id=d.get("experiment_id"),
+            run_id=d.get("run_id"),
+            url=d.get("url"),
+        )
+
+    def _parse_datetime_to_timestamp(self, value) -> Optional[float]:
+        """Convert datetime value to Unix timestamp
+
+        Handles: None, ISO string, float timestamp, datetime object
+        """
+        if value is None:
+            return None
+        if isinstance(value, (int, float)):
+            return float(value)
+        if isinstance(value, str):
+            try:
+                dt = datetime.fromisoformat(value)
+                return dt.timestamp()
+            except ValueError:
+                return None
+        if isinstance(value, datetime):
+            return value.timestamp()
+        return None
+
+    # -------------------------------------------------------------------------
+    # File Synchronization
+    # -------------------------------------------------------------------------
+
+    def sync_path(self, path: str) -> Optional[Path]:
+        """Sync a specific path from remote on-demand
+
+        Used by services (e.g., TensorboardService) that need access to
+        specific remote directories.
+
+        Args:
+            path: Can be:
+                - Remote absolute path (e.g., /remote/workspace/jobs/xxx)
+                - Local cache path (e.g., /tmp/xpm_remote_xxx/jobs/xxx)
+                - Relative path within workspace (e.g., jobs/xxx)
+
+        Returns:
+            Local path where the files were synced to, or None if sync failed
+        """
+        if not self._connected or not self.local_cache_dir:
+            logger.warning("Cannot sync: not connected")
+            return None
+
+        # Convert local cache path back to remote path if needed
+        local_cache_str = str(self.local_cache_dir)
+        if path.startswith(local_cache_str):
+            # Path is in local cache - extract relative path
+            relative = path[len(local_cache_str) :].lstrip("/")
+            remote_path = f"{self.remote_workspace}/{relative}"
+        elif path.startswith(self.remote_workspace):
+            # Already a remote path
+            remote_path = path
+        else:
+            # Assume it's a relative path
+            remote_path = f"{self.remote_workspace}/{path.lstrip('/')}"
+
+        from experimaestro.scheduler.remote.sync import RemoteFileSynchronizer
+
+        # Create synchronizer lazily
+        if self._synchronizer is None:
+            self._synchronizer = RemoteFileSynchronizer(
+                host=self.host,
+                remote_workspace=Path(self.remote_workspace),
+                local_cache=self.local_cache_dir,
+                ssh_options=self.ssh_options,
+            )
+
+        try:
+            return self._synchronizer.sync_path(remote_path)
+        except Exception as e:
+            logger.warning("Failed to sync path %s: %s", remote_path, e)
+            return None
+
+    @property
+    def read_only(self) -> bool:
+        """Client is always read-only"""
+        return True
+
+    @property
+    def is_remote(self) -> bool:
+        """This is a remote provider"""
+        return True