slurmray 6.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of slurmray might be problematic.

slurmray/cli.py ADDED
@@ -0,0 +1,904 @@
+ import os
+ import time
+ import subprocess
+ import signal
+ import sys
+ import webbrowser
+ import argparse
+ import threading
+ import json
+ from abc import ABC, abstractmethod
+ from rich.console import Console
+ from rich.table import Table
+ from rich.panel import Panel
+ from rich.prompt import Prompt, Confirm
+ from rich.live import Live
+ from rich.layout import Layout
+ from dotenv import load_dotenv
+ import paramiko
+ from getpass import getpass
+ import inquirer
+
+ # Try to import RayLauncher components
+ try:
+     from slurmray.utils import SSHTunnel
+ except ImportError:
+     # Fallback if running from root without package installed
+     sys.path.append(os.getcwd())
+     from slurmray.utils import SSHTunnel
+
+ load_dotenv()
+
+ console = Console()
+
+ class ClusterManager(ABC):
+     """Abstract base class for cluster managers"""
+
+     def __init__(self, username: str = None, password: str = None, ssh_host: str = None):
+         self.username = username
+         self.password = password
+         self.ssh_host = ssh_host
+         self.active_tunnel = None
+         self.ssh_client = None
+
+     @abstractmethod
+     def _connect(self):
+         """Connect to the cluster if not already connected"""
+         pass
+
+     @abstractmethod
+     def get_jobs(self):
+         """Retrieve jobs from the cluster"""
+         pass
+
+     @abstractmethod
+     def cancel_job(self, job_id):
+         """Cancel a job"""
+         pass
+
+     @abstractmethod
+     def get_head_node(self, job_id):
+         """Get head node for a job"""
+         pass
+
+     def open_dashboard(self, job_id):
+         """Open Ray dashboard for a job"""
+         head_node = self.get_head_node(job_id)
+         if not head_node:
+             console.print(f"[red]Could not determine head node for job {job_id}. Is it running?[/red]")
+             return
+
+         console.print(f"[blue]Head node identified: {head_node}[/blue]")
+
+         if not self.password:
+             self.password = getpass("Enter cluster password: ")
+
+         try:
+             console.print("[yellow]Setting up SSH tunnel... (Press Ctrl+C to stop)[/yellow]")
+             self.active_tunnel = SSHTunnel(
+                 ssh_host=self.ssh_host,
+                 ssh_username=self.username,
+                 ssh_password=self.password,
+                 remote_host=head_node,
+                 local_port=0,
+                 remote_port=8265
+             )
+
+             with self.active_tunnel:
+                 url = f"http://localhost:{self.active_tunnel.local_port}"
+                 console.print(f"[green]Dashboard available at: {url}[/green]")
+
+                 console.print("Opening browser...")
+                 try:
+                     webbrowser.open(url)
+                 except Exception as e:
+                     console.print(f"[yellow]Could not open browser automatically: {e}[/yellow]")
+
+                 console.print("Tunnel active. Keeping connection alive...")
+                 try:
+                     while True:
+                         time.sleep(1)
+                 except KeyboardInterrupt:
+                     console.print("\n[yellow]Closing tunnel...[/yellow]")
+         except Exception as e:
+             console.print(f"[red]Failed to establish tunnel: {e}[/red]")
+
+ class SlurmManager(ClusterManager):
+     def __init__(self, username: str = None, password: str = None, ssh_host: str = None):
+         super().__init__(username, password, ssh_host)
+         self.username = username or os.getenv("CURNAGL_USERNAME") or os.environ.get("USER")
+         self.password = password or os.getenv("CURNAGL_PASSWORD")
+         self.ssh_host = ssh_host or "curnagl.dcsr.unil.ch"
+
+     def _connect(self):
+         """Connect to the cluster if not already connected"""
+         if self.ssh_client and self.ssh_client.get_transport() and self.ssh_client.get_transport().is_active():
+             return
+
+         if not self.password:
+             self.password = getpass("Enter cluster password: ")
+
+         try:
+             self.ssh_client = paramiko.SSHClient()
+             self.ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+             self.ssh_client.connect(
+                 hostname=self.ssh_host,
+                 username=self.username,
+                 password=self.password
+             )
+         except Exception as e:
+             console.print(f"[red]Failed to connect to cluster: {e}[/red]")
+             self.ssh_client = None
+             raise
+
+     def run_command(self, command):
+         """Run a command on the cluster via SSH"""
+         self._connect()
+         stdin, stdout, stderr = self.ssh_client.exec_command(command)
+         return stdout.read().decode("utf-8"), stderr.read().decode("utf-8")
+
+     def get_jobs(self):
+         """Retrieve jobs from squeue"""
+         try:
+             # Run squeue command remotely
+             stdout, stderr = self.run_command(f"squeue -u {self.username} -o '%.18i %.9P %.30j %.8u %.8T %.10M %.6D %R' --noheader")
+
+             lines = stdout.strip().split("\n")
+             jobs = []
+             for line in lines:
+                 if not line.strip():
+                     continue
+                 parts = line.split()
+                 if len(parts) >= 8:
+                     jobs.append({
+                         "id": parts[0],
+                         "partition": parts[1],
+                         "name": parts[2],
+                         "user": parts[3],
+                         "state": parts[4],
+                         "time": parts[5],
+                         "nodes": parts[6],
+                         "nodelist": parts[7]
+                     })
+             return jobs
+         except Exception as e:
+             console.print(f"[red]Error retrieving jobs: {e}[/red]")
+             return []
+
+     def cancel_job(self, job_id):
+         """Cancel a SLURM job"""
+         try:
+             stdout, stderr = self.run_command(f"scancel {job_id}")
+             if stderr:
+                 console.print(f"[red]Failed to cancel job {job_id}: {stderr}[/red]")
+             else:
+                 console.print(f"[green]Job {job_id} cancelled successfully.[/green]")
+         except Exception as e:
+             console.print(f"[red]Error cancelling job: {e}[/red]")
+
+     def get_head_node(self, job_id):
+         """Get head node for a job"""
+         try:
+             # Get job info remotely
+             stdout, stderr = self.run_command(f"scontrol show job {job_id}")
+             output = stdout
+
+             # Simple parsing for NodeList
+             import re
+             match = re.search(r"NodeList=([^\s]+)", output)
+             if match:
+                 nodelist = match.group(1)
+                 # Convert nodelist to hostname
+                 stdout, stderr = self.run_command(f"scontrol show hostnames {nodelist}")
+                 hosts = stdout.strip().split("\n")
+                 if hosts:
+                     return hosts[0]
+             return None
+         except Exception:
+             return None
+
+
+ class DesiManager(ClusterManager):
+     """Manager for Desi server (ISIPOL09) using Smart Lock"""
+
+     LOCK_FILE = "/tmp/slurmray_desi.lock"
+     QUEUE_FILE = "/tmp/slurmray_desi.queue"
+
+     def __init__(self, username: str = None, password: str = None, ssh_host: str = None):
+         super().__init__(username, password, ssh_host)
+         self.username = username or os.getenv("DESI_USERNAME") or os.environ.get("USER")
+         self.password = password or os.getenv("DESI_PASSWORD")
+         self.ssh_host = ssh_host or "130.223.73.209"
+         self.base_dir = f"/home/{self.username}/slurmray-server"
+
+     def _connect(self):
+         """Connect to the Desi server if not already connected"""
+         if self.ssh_client and self.ssh_client.get_transport() and self.ssh_client.get_transport().is_active():
+             return
+
+         if not self.password:
+             self.password = getpass("Enter Desi server password: ")
+
+         try:
+             self.ssh_client = paramiko.SSHClient()
+             self.ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+             self.ssh_client.connect(
+                 hostname=self.ssh_host,
+                 username=self.username,
+                 password=self.password
+             )
+         except Exception as e:
+             console.print(f"[red]Failed to connect to Desi server: {e}[/red]")
+             self.ssh_client = None
+             raise
+
+     def run_command(self, command):
+         """Run a command on the Desi server via SSH"""
+         self._connect()
+         stdin, stdout, stderr = self.ssh_client.exec_command(command)
+         exit_status = stdout.channel.recv_exit_status()
+         return stdout.read().decode("utf-8"), stderr.read().decode("utf-8"), exit_status
+
+     def _read_queue(self):
+         """Read queue file from remote server (read-only, no lock needed)"""
+         try:
+             stdout, stderr, exit_status = self.run_command(
+                 f"test -f {self.QUEUE_FILE} && cat {self.QUEUE_FILE} || echo '[]'"
+             )
+             if exit_status != 0:
+                 return []
+             import json
+             queue_data = json.loads(stdout.strip() or "[]")
+             return queue_data if isinstance(queue_data, list) else []
+         except (json.JSONDecodeError, ValueError, Exception):
+             return []
+
+     def _write_queue(self, queue_data):
+         """Write queue file to remote server with lock management"""
+         import json
+         import base64
+         import shlex
+
+         # Serialize the queue and base64-encode it for safe transport to the remote host
+         queue_json = json.dumps(queue_data, indent=2)
+         queue_b64 = base64.b64encode(queue_json.encode('utf-8')).decode('utf-8')
+
+         # Write using Python on remote with lock
+         python_script = f"""
+ import json
+ import fcntl
+ import base64
+ import sys
+
+ QUEUE_FILE = "{self.QUEUE_FILE}"
+ queue_b64 = "{queue_b64}"
+
+ max_retries = 10
+ retry_delay = 0.1
+ success = False
+
+ for attempt in range(max_retries):
+     try:
+         queue_fd = open(QUEUE_FILE, 'w')
+         try:
+             fcntl.lockf(queue_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
+             queue_json = base64.b64decode(queue_b64).decode('utf-8')
+             queue_data = json.loads(queue_json)
+             json.dump(queue_data, queue_fd, indent=2)
+             queue_fd.flush()
+             import os
+             os.fsync(queue_fd.fileno())
+             fcntl.lockf(queue_fd, fcntl.LOCK_UN)
+             queue_fd.close()
+             success = True
+             break
+         except IOError:
+             queue_fd.close()
+             if attempt < max_retries - 1:
+                 import time
+                 time.sleep(retry_delay)
+             continue
+     except IOError:
+         if attempt < max_retries - 1:
+             import time
+             time.sleep(retry_delay)
+         continue
+
+ sys.exit(0 if success else 1)
+ """
+         # Execute the locking script remotely via python3 -c (shlex.quote keeps the
+         # embedded quotes intact when the script passes through the remote shell)
+         stdout, stderr, exit_status = self.run_command(
+             f"python3 -c {shlex.quote(python_script)}"
+         )
+         return exit_status == 0
+
+     def _clean_dead_processes_from_queue(self):
+         """Remove entries for processes that no longer exist"""
+         queue = self._read_queue()
+         if not queue:
+             return
+
+         cleaned_queue = []
+         pids_to_check = [entry.get("pid") for entry in queue if entry.get("pid")]
+
+         if not pids_to_check:
+             return
+
+         # Check all PIDs at once (coerce to str in case the queue stores them as ints)
+         pids_str = " ".join(str(pid) for pid in pids_to_check)
+         stdout, stderr, exit_status = self.run_command(
+             f"for pid in {pids_str}; do ps -p $pid >/dev/null 2>&1 && echo $pid; done"
+         )
+         alive_pids = set(stdout.strip().split())
+
+         # Keep only entries with alive processes
+         cleaned_queue = [entry for entry in queue if str(entry.get("pid")) in alive_pids]
+
+         # If queue changed, write it back
+         if len(cleaned_queue) != len(queue):
+             self._write_queue(cleaned_queue)
+
+     def _get_function_name_from_project(self, project_dir):
+         """Try to read function name from func_name.txt in project directory"""
+         try:
+             # Try direct path first
+             func_name_path = f"{project_dir}/func_name.txt"
+             stdout, stderr, exit_status = self.run_command(
+                 f"test -f {func_name_path} && cat {func_name_path} || echo ''"
+             )
+             func_name = stdout.strip()
+             if func_name:
+                 return func_name
+
+             # If not found, try to find func_name.txt in parent directories (up to 2 levels)
+             # Sometimes the process might be running from a subdirectory
+             for parent_level in [1, 2]:
+                 parent_dir = "/".join(project_dir.split("/")[:-parent_level]) if parent_level > 0 else project_dir
+                 if not parent_dir or parent_dir == "/":
+                     break
+                 func_name_path = f"{parent_dir}/func_name.txt"
+                 stdout, stderr, exit_status = self.run_command(
+                     f"test -f {func_name_path} && cat {func_name_path} || echo ''"
+                 )
+                 func_name = stdout.strip()
+                 if func_name:
+                     return func_name
+
+             return None
+         except Exception:
+             return None
+
+     def _get_project_dir_from_pid(self, pid):
+         """Try to get project directory from process working directory"""
+         try:
+             # Try multiple methods to get working directory
+             # Method 1: readlink /proc/{pid}/cwd (Linux)
+             stdout, stderr, exit_status = self.run_command(
+                 f"readlink /proc/{pid}/cwd 2>/dev/null || echo ''"
+             )
+             cwd = stdout.strip()
+
+             # Method 2: pwdx (if available)
+             if not cwd:
+                 stdout, stderr, exit_status = self.run_command(
+                     f"pwdx {pid} 2>/dev/null | awk '{{print $2}}' || echo ''"
+                 )
+                 cwd = stdout.strip()
+
+             # Method 3: ps -o cwd
+             if not cwd:
+                 stdout, stderr, exit_status = self.run_command(
+                     f"ps -p {pid} -o cwd --no-headers 2>/dev/null || echo ''"
+                 )
+                 cwd = stdout.strip()
+
+             if cwd and "slurmray-server" in cwd:
+                 # Extract project directory (should be in /home/{user}/slurmray-server/{project_name})
+                 return cwd
+             return None
+         except Exception:
+             return None
+
+     def get_jobs(self):
+         """Retrieve jobs from Desi using queue file"""
+         try:
+             self._connect()
+             jobs = []
+
+             # Clean dead processes from queue first
+             self._clean_dead_processes_from_queue()
+
+             # Read queue file
+             queue = self._read_queue()
+             if not queue:
+                 return []
+
+             # Separate running and waiting jobs, sort waiting by timestamp
+             running_jobs = []
+             waiting_jobs = []
+
+             for entry in queue:
+                 pid = entry.get("pid")
+                 if not pid:
+                     continue
+
+                 # Verify process exists
+                 stdout, stderr, exit_status = self.run_command(
+                     f"ps -p {pid} >/dev/null 2>&1 && echo 'alive' || echo 'dead'"
+                 )
+                 if "dead" in stdout.strip():
+                     # Process is dead, skip (will be cleaned next time)
+                     continue
+
+                 # Get elapsed time from ps (more accurate than calculating from timestamp)
+                 stdout_etime, stderr_etime, exit_status_etime = self.run_command(
+                     f"ps -p {pid} -o etime --no-headers 2>/dev/null || echo 'N/A'"
+                 )
+                 elapsed_time = stdout_etime.strip() or "N/A"
+
+                 # Extract job information
+                 func_name = entry.get("func_name")
+                 user = entry.get("user", "unknown")
+                 status = entry.get("status", "unknown")
+
+                 # If func_name is missing or "Unknown function", try to read it from project_dir
+                 if not func_name or func_name == "Unknown function":
+                     project_dir = entry.get("project_dir")
+                     if project_dir:
+                         func_name = self._get_function_name_from_project(project_dir)
+                     if not func_name or func_name == "Unknown function":
+                         # Fail-fast: if we can't get the function name, it's an error
+                         console.print(f"[yellow]Warning: Job {pid} has no function name in queue and func_name.txt not found[/yellow]")
+                         func_name = "ERROR: func_name missing"
+
+                 job = {
+                     "id": f"desi-{pid}",
+                     "name": func_name,
+                     "user": user,
+                     "state": "RUNNING" if status == "running" else "WAITING",
+                     "time": elapsed_time,
+                     "nodes": "1",
+                     "nodelist": "localhost",
+                     "pid": pid,
+                     "timestamp": entry.get("timestamp", 0),
+                     "queue_position": None
+                 }
+
+                 if status == "running":
+                     running_jobs.append(job)
+                 elif status == "waiting":
+                     waiting_jobs.append(job)
+
+             # Sort waiting jobs by timestamp (oldest first)
+             waiting_jobs.sort(key=lambda x: x.get("timestamp", 0))
+
+             # Calculate queue positions for waiting jobs
+             for idx, job in enumerate(waiting_jobs):
+                 job["queue_position"] = idx + 1
+
+             # Combine: running jobs first, then waiting jobs
+             jobs = running_jobs + waiting_jobs
+
+             return jobs
+         except Exception as e:
+             console.print(f"[red]Error retrieving jobs: {e}[/red]")
+             return []
+
+     def cancel_job(self, job_id):
+         """Cancel a Desi job by killing the process"""
+         try:
+             # Extract PID from job_id (format: desi-<pid>)
+             if job_id.startswith("desi-"):
+                 pid = job_id.split("-", 1)[1]
+             else:
+                 pid = job_id
+
+             # Kill the process
+             stdout, stderr, exit_status = self.run_command(f"kill -TERM {pid} 2>&1")
+             if exit_status == 0:
+                 console.print(f"[green]Job {job_id} (PID {pid}) cancelled successfully.[/green]")
+                 # Wait a bit and force kill if still running
+                 time.sleep(2)
+                 stdout, stderr, exit_status = self.run_command(f"kill -9 {pid} 2>&1")
+             else:
+                 console.print(f"[red]Failed to cancel job {job_id}: {stderr or 'Process not found'}[/red]")
+         except Exception as e:
+             console.print(f"[red]Error cancelling job: {e}[/red]")
+
+     def get_head_node(self, job_id):
+         """Get head node for a Desi job (always localhost)"""
+         # For Desi, the head node is always localhost since it's a single machine
+         return "127.0.0.1"
+
+ def display_jobs_table(jobs, cluster_type="slurm"):
+     """Display jobs in a table, adapting to cluster type"""
+     if cluster_type == "slurm":
+         table = Table(title=f"Slurm Jobs ({len(jobs)})")
+         table.add_column("ID", style="cyan", no_wrap=True)
+         table.add_column("Name", style="magenta")
+         table.add_column("State", style="green")
+         table.add_column("Time", style="yellow")
+         table.add_column("Nodes", justify="right")
+         table.add_column("NodeList", style="blue")
+
+         for i, job in enumerate(jobs):
+             # Add position in queue for pending jobs
+             state_display = job["state"]
+             if job["state"] == "PENDING":
+                 state_display += f" (#{i+1})"
+
+             table.add_row(
+                 job["id"],
+                 job["name"],
+                 state_display,
+                 job["time"],
+                 job["nodes"],
+                 job["nodelist"]
+             )
+     else:  # desi
+         table = Table(title=f"Desi Jobs ({len(jobs)})")
+         table.add_column("ID", style="cyan", no_wrap=True)
+         table.add_column("Name", style="magenta")
+         table.add_column("User", style="blue", no_wrap=True)
+         table.add_column("State", style="green")
+         table.add_column("Queue", style="yellow", justify="center")
+         table.add_column("Time", style="yellow")
+         table.add_column("PID", style="blue", no_wrap=True)
+
+         for job in jobs:
+             # Format state display
+             state_display = job["state"]
+             if job["state"] == "WAITING":
+                 queue_pos = job.get("queue_position")
+                 if queue_pos:
+                     state_display = f"WAITING (#{queue_pos})"
+
+             # Format queue column
+             queue_display = "-"
+             if job["state"] == "WAITING":
+                 queue_pos = job.get("queue_position")
+                 if queue_pos:
+                     queue_display = f"#{queue_pos}"
+             elif job["state"] == "RUNNING":
+                 queue_display = "RUNNING"
+
+             table.add_row(
+                 job["id"],
+                 job.get("name", "N/A"),
+                 job.get("user", "N/A"),
+                 state_display,
+                 queue_display,
+                 job.get("time", "N/A"),
+                 str(job.get("pid", "N/A"))
+             )
+     return table
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description="SlurmRay CLI - Interactive job manager for Slurm clusters and Desi server",
+         formatter_class=argparse.RawDescriptionHelpFormatter,
+         epilog="""
+ Examples:
+   slurmray              # Show help
+   slurmray curnagl      # Connect to Curnagl (Slurm cluster)
+   slurmray desi         # Connect to Desi server (ISIPOL09)
+         """
+     )
+     parser.add_argument(
+         "cluster",
+         nargs="?",
+         choices=["curnagl", "desi"],
+         help="Cluster to connect to (curnagl for Slurm, desi for Desi server)"
+     )
+     parser.add_argument(
+         "--username",
+         help="Username for SSH connection (overrides environment variables)"
+     )
+     parser.add_argument(
+         "--password",
+         help="Password for SSH connection (overrides environment variables, not recommended)"
+     )
+     parser.add_argument(
+         "--host",
+         help="SSH hostname (overrides default: curnagl.dcsr.unil.ch for Curnagl, 130.223.73.209 for Desi)"
+     )
+
+     args = parser.parse_args()
+
+     # If no cluster specified, show help
+     if not args.cluster:
+         parser.print_help()
+         return
+
+     # Create appropriate manager
+     if args.cluster == "curnagl":
+         manager = SlurmManager(
+             username=args.username,
+             password=args.password,
+             ssh_host=args.host
+         )
+         cluster_name = "Curnagl (Slurm)"
+         cluster_type = "slurm"
+     elif args.cluster == "desi":
+         manager = DesiManager(
+             username=args.username,
+             password=args.password,
+             ssh_host=args.host
+         )
+         cluster_name = "Desi (ISIPOL09)"
+         cluster_type = "desi"
+     else:
+         console.print("[red]Invalid cluster specified[/red]")
+         return
+
+     # Main interactive loop with auto-refresh
+     refresh_interval = 10  # seconds
+     jobs_data = {"jobs": [], "last_refresh": 0}
+     should_exit = threading.Event()
+     user_wants_menu = threading.Event()
+
+     def auto_refresh_worker():
+         """Background thread to auto-refresh jobs every refresh_interval seconds"""
+         while not should_exit.is_set():
+             time.sleep(refresh_interval)
+             if not should_exit.is_set():
+                 jobs_data["jobs"] = manager.get_jobs()
+                 jobs_data["last_refresh"] = time.time()
+
+     # Start auto-refresh thread
+     refresh_thread = threading.Thread(target=auto_refresh_worker, daemon=True)
+     refresh_thread.start()
+
+     # Initial load
+     jobs_data["jobs"] = manager.get_jobs()
+     jobs_data["last_refresh"] = time.time()
+
+     def render_display():
+         """Render the current display"""
+         jobs = jobs_data["jobs"]
+         last_refresh_time = time.strftime('%H:%M:%S', time.localtime(jobs_data["last_refresh"]))
+
+         # Create layout
+         layout = Layout()
+         layout.split_column(
+             Layout(name="header", size=3),
+             Layout(name="table"),
+             Layout(name="footer", size=4)
+         )
+
+         layout["header"].update(Panel(f"[bold blue]SlurmRay Manager - {cluster_name}[/bold blue]", style="bold"))
+         layout["table"].update(display_jobs_table(jobs, cluster_type))
+
+         footer_text = f"[yellow]Auto-refreshing every {refresh_interval}s... Last refresh: {last_refresh_time}[/yellow]\n"
+         footer_text += "[dim]Press Enter to open menu (↑↓ to navigate, ←→ to select)[/dim]"
+         layout["footer"].update(Panel(footer_text))
+
+         return layout
+
+     # Main loop: alternate between Live display and menu
+     while not should_exit.is_set():
+         try:
+             # Phase 1: Live auto-updating display
+             with Live(render_display(), refresh_per_second=2, screen=True) as live:
+                 import sys
+                 import select
+                 import tty
+                 import termios
+
+                 # Set terminal to cbreak mode (less disruptive than raw mode)
+                 old_settings = termios.tcgetattr(sys.stdin)
+                 try:
+                     tty.setcbreak(sys.stdin.fileno())
+                 except Exception:
+                     # If terminal manipulation fails, use simple mode
+                     pass
+
+                 try:
+                     while not should_exit.is_set():
+                         # Update display with latest data (reads from jobs_data which is updated by thread)
+                         live.update(render_display())
+
+                         # Check if user pressed Enter (non-blocking)
+                         try:
+                             if select.select([sys.stdin], [], [], 0.1)[0]:
+                                 char = sys.stdin.read(1)
+                                 if char == '\n' or char == '\r':
+                                     # Exit Live to show menu
+                                     break
+                                 elif char == '\x03':  # Ctrl+C
+                                     should_exit.set()
+                                     break
+                         except (OSError, ValueError):
+                             # Terminal might not support select, fall back to simple wait
+                             time.sleep(0.5)
+                             continue
+
+                         # Small sleep to avoid busy waiting
+                         time.sleep(0.1)
+                 finally:
+                     # Restore terminal settings
+                     try:
+                         termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
+                     except Exception:
+                         pass
+
+             if should_exit.is_set():
+                 break
+
+             # Phase 2: Show menu with inquirer (outside Live context)
+             # Refresh jobs before showing menu
+             jobs_data["jobs"] = manager.get_jobs()
+             jobs_data["last_refresh"] = time.time()
+
+             console.clear()
+             console.rule(f"[bold blue]SlurmRay Manager - {cluster_name}[/bold blue]")
+
+             # Use latest jobs data
+             jobs = jobs_data["jobs"]
+             console.print(display_jobs_table(jobs, cluster_type))
+
+             last_refresh_time = time.strftime('%H:%M:%S', time.localtime(jobs_data["last_refresh"]))
+             console.print(f"\n[yellow]Auto-refreshing every {refresh_interval}s... Last refresh: {last_refresh_time}[/yellow]")
+             console.print("[dim]Navigation: Use arrow keys (↑↓) to navigate, Enter to select[/dim]")
+
+             # Show menu with inquirer for arrow key navigation
+             questions = [
+                 inquirer.List(
+                     'action',
+                     message="Select an option",
+                     choices=[
+                         ('Refresh Now', 'refresh'),
+                         ('Cancel Job', 'cancel'),
+                         ('Open Dashboard', 'dashboard'),
+                         ('Quit', 'quit')
+                     ],
+                     default='refresh'
+                 )
+             ]
+
+             answers = inquirer.prompt(questions)
+             if not answers:
+                 # User cancelled (Ctrl+C)
+                 break
+
+             user_input = answers['action']
+
+             # Handle user input
+             if user_input == 'quit':
+                 should_exit.set()
+                 console.print("Bye!")
+                 break
+             elif user_input == 'refresh':
+                 jobs_data["jobs"] = manager.get_jobs()
+                 jobs_data["last_refresh"] = time.time()
+                 # Return to Live display (continue loop)
+                 continue
+             elif user_input == 'cancel':
+                 # Refresh jobs before showing cancel menu
+                 jobs_data["jobs"] = manager.get_jobs()
+                 jobs_data["last_refresh"] = time.time()
+                 jobs = jobs_data["jobs"]
+                 if jobs:
+                     job_choices = [f"{job['id']} - {job.get('name', 'N/A')}" for job in jobs]
+                     cancel_questions = [
+                         inquirer.List(
+                             'job_id',
+                             message="Select job to cancel",
+                             choices=job_choices
+                         )
+                     ]
+                     cancel_answers = inquirer.prompt(cancel_questions)
+                     if cancel_answers:
+                         selected_job = cancel_answers['job_id']
+                         job_id = selected_job.split(' - ')[0]
+                         if Confirm.ask(f"Are you sure you want to cancel job {job_id}?"):
+                             manager.cancel_job(job_id)
+                             console.print("[green]Job cancelled. Refreshing...[/green]")
+                             time.sleep(1.5)
+                             jobs_data["jobs"] = manager.get_jobs()
+                             jobs_data["last_refresh"] = time.time()
+                 else:
+                     console.print("[yellow]No jobs to cancel[/yellow]")
+                     time.sleep(1)
+                 # Return to Live display (continue loop)
+                 continue
+             elif user_input == 'dashboard':
+                 # Refresh jobs before showing dashboard menu
+                 jobs_data["jobs"] = manager.get_jobs()
+                 jobs_data["last_refresh"] = time.time()
+                 jobs = jobs_data["jobs"]
+                 if jobs:
+                     job_choices = [f"{job['id']} - {job.get('name', 'N/A')}" for job in jobs]
+                     dashboard_questions = [
+                         inquirer.List(
+                             'job_id',
+                             message="Select job to connect to",
+                             choices=job_choices
+                         )
+                     ]
+                     dashboard_answers = inquirer.prompt(dashboard_questions)
+                     if dashboard_answers:
+                         selected_job = dashboard_answers['job_id']
+                         job_id = selected_job.split(' - ')[0]
+                         manager.open_dashboard(job_id)
+                         # Refresh after tunnel close
+                         jobs_data["jobs"] = manager.get_jobs()
+                         jobs_data["last_refresh"] = time.time()
+                 else:
+                     console.print("[yellow]No jobs available[/yellow]")
+                     time.sleep(1)
+                 # Return to Live display (continue loop)
+                 continue
+
+         except (KeyboardInterrupt, EOFError):
+             should_exit.set()
+             console.print("\nBye!")
+             break
+         except Exception as e:
+             # Fallback if terminal manipulation fails
+             console.print(f"[yellow]Warning: Auto-refresh unavailable ({e}). Using simple mode...[/yellow]")
+             should_exit.set()
+             # Simple fallback without auto-refresh
+             while True:
+                 console.clear()
+                 console.rule(f"[bold blue]SlurmRay Manager - {cluster_name}[/bold blue]")
+                 jobs = manager.get_jobs()
+                 console.print(display_jobs_table(jobs, cluster_type))
+
+                 questions = [
+                     inquirer.List(
+                         'action',
+                         message="Select an option",
+                         choices=[
+                             ('Refresh', 'refresh'),
+                             ('Cancel Job', 'cancel'),
+                             ('Open Dashboard', 'dashboard'),
+                             ('Quit', 'quit')
+                         ],
+                         default='refresh'
+                     )
+                 ]
+
+                 answers = inquirer.prompt(questions)
+                 if not answers or answers['action'] == 'quit':
+                     break
+                 elif answers['action'] == 'refresh':
+                     continue
+                 elif answers['action'] == 'cancel':
+                     jobs = manager.get_jobs()
+                     if jobs:
+                         job_choices = [f"{job['id']} - {job.get('name', 'N/A')}" for job in jobs]
+                         cancel_questions = [
+                             inquirer.List(
+                                 'job_id',
+                                 message="Select job to cancel",
+                                 choices=job_choices
+                             )
+                         ]
+                         cancel_answers = inquirer.prompt(cancel_questions)
+                         if cancel_answers:
+                             selected_job = cancel_answers['job_id']
+                             job_id = selected_job.split(' - ')[0]
+                             if Confirm.ask(f"Are you sure you want to cancel job {job_id}?"):
+                                 manager.cancel_job(job_id)
+                                 time.sleep(1.5)
+                 elif answers['action'] == 'dashboard':
+                     jobs = manager.get_jobs()
+                     if jobs:
+                         job_choices = [f"{job['id']} - {job.get('name', 'N/A')}" for job in jobs]
+                         dashboard_questions = [
+                             inquirer.List(
+                                 'job_id',
+                                 message="Select job to connect to",
+                                 choices=job_choices
+                             )
+                         ]
+                         dashboard_answers = inquirer.prompt(dashboard_questions)
+                         if dashboard_answers:
+                             selected_job = dashboard_answers['job_id']
+                             job_id = selected_job.split(' - ')[0]
+                             manager.open_dashboard(job_id)
+             break
+
+ if __name__ == "__main__":
+     main()
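
For reference, the managers defined in this file can also be driven outside the interactive loop. The following is a minimal sketch, assuming the file is importable as slurmray.cli (as the path above suggests) and that credentials come from CURNAGL_USERNAME/CURNAGL_PASSWORD or the getpass prompt; the username below is a placeholder, not part of the package.

from slurmray.cli import SlurmManager

# Scripted equivalent of running the `slurmray curnagl` entry point.
manager = SlurmManager(username="jdoe")  # placeholder username

# List the caller's jobs, as parsed from `squeue` on the cluster.
for job in manager.get_jobs():
    print(job["id"], job["name"], job["state"], job["nodelist"])

# Tunnel localhost -> <head node>:8265 and open the Ray dashboard for the
# first job; this blocks until Ctrl+C, mirroring the CLI behaviour.
jobs = manager.get_jobs()
if jobs:
    manager.open_dashboard(jobs[0]["id"])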