slurmray-6.0.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -0,0 +1,1234 @@
1
+ import os
2
+ import sys
3
+ import time
4
+ import subprocess
5
+ import paramiko
6
+ import dill
7
+ import re
8
+ from getpass import getpass
9
+ from typing import Any
10
+
11
+ from slurmray.backend.base import ClusterBackend
12
+ from slurmray.backend.remote import RemoteMixin
13
+ from slurmray.utils import SSHTunnel, DependencyManager
14
+
15
+
16
+ class SlurmBackend(RemoteMixin):
17
+ """Backend for Slurm cluster execution (local or remote via SSH)"""
18
+
19
+ def __init__(self, launcher):
20
+ super().__init__(launcher)
21
+ self.ssh_client = None
22
+ self.job_id = None
23
+
24
+ def run(self, cancel_old_jobs: bool = True, wait: bool = True) -> Any:
25
+ """Run the job on Slurm (locally or remotely)"""
26
+
27
+ # Generate the Python script (spython.py) - needed for both modes
28
+ if not self.launcher.server_run:
29
+ self._write_python_script()
30
+ self.script_file, self.job_name = self._write_slurm_script()
31
+ else:
32
+ # In server_run mode the scripts are generated on the cluster instead:
33
+ # _launch_server ships the workload over SSH and runs slurmray_server.py,
34
+ # which re-instantiates RayLauncher on the remote machine. There sbatch is
35
+ # detected, self.cluster is True, and that instance takes the cluster branch.
43
+ pass
44
+
45
+ if self.launcher.cluster:
46
+ self.logger.info("Cluster detected, running on cluster...")
47
+ # Cancel the old jobs
48
+ if cancel_old_jobs:
49
+ self.logger.info("Canceling old jobs...")
50
+ subprocess.Popen(
51
+ ["scancel", "-u", os.environ["USER"]],
52
+ stdout=subprocess.PIPE,
53
+ stderr=subprocess.PIPE,
54
+ )
55
+ # Launch the job
56
+ # We need to ensure scripts are written if we are on the cluster
57
+ if not hasattr(self, "script_file"):
58
+ self._write_python_script()
59
+ self.script_file, self.job_name = self._write_slurm_script()
60
+
61
+ self._launch_job(self.script_file, self.job_name)
62
+
63
+ if not wait:
64
+ return self.job_id
65
+
66
+ elif self.launcher.server_run:
67
+ return self._launch_server(cancel_old_jobs, wait=wait)
68
+
69
+ # Load the result
70
+ # Note: In server_run mode, _launch_server downloads result.pkl
71
+ # In cluster mode, we wait for result.pkl
72
+
73
+ if self.launcher.cluster:
74
+ # Wait for result in cluster mode (same filesystem)
75
+ while not os.path.exists(
76
+ os.path.join(self.launcher.project_path, "result.pkl")
77
+ ):
78
+ time.sleep(0.25)
79
+
80
+ # Result should be there now
81
+ with open(os.path.join(self.launcher.project_path, "result.pkl"), "rb") as f:
82
+ result = dill.load(f)
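+ # dill is used instead of pickle so that closures and lambdas round-trip cleanly.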
83
+
84
+ return result
85
+
86
+ def cancel(self, job_id: str):
87
+ """Cancel a job"""
88
+ if self.launcher.cluster:
89
+ self.logger.info(f"Canceling local job {job_id}...")
90
+ try:
91
+ subprocess.run(
92
+ ["scancel", job_id],
93
+ stdout=subprocess.PIPE,
94
+ stderr=subprocess.PIPE,
95
+ check=False,
96
+ )
97
+ self.logger.info(f"Job {job_id} canceled.")
98
+ except Exception as e:
99
+ self.logger.error(f"Failed to cancel job {job_id}: {e}")
100
+ elif self.launcher.server_run and self.ssh_client:
101
+ self.logger.info(f"Canceling remote job {job_id} via SSH...")
102
+ try:
103
+ # Check if connection is still active
104
+ if (
105
+ self.ssh_client.get_transport()
106
+ and self.ssh_client.get_transport().is_active()
107
+ ):
108
+ self.ssh_client.exec_command(f"scancel {job_id}")
109
+ self.logger.info(f"Remote job {job_id} canceled.")
110
+ else:
111
+ self.logger.warning(
112
+ "SSH connection lost, cannot cancel remote job."
113
+ )
114
+ except Exception as e:
115
+ self.logger.error(f"Failed to cancel remote job {job_id}: {e}")
116
+ finally:
117
+ try:
118
+ self.ssh_client.close()
119
+ except Exception:
120
+ pass
121
+
122
+ # =========================================================================
123
+ # Private methods extracted from RayLauncher
124
+ # =========================================================================
125
+
126
+ def _write_python_script(self):
127
+ """Write the python script that will be executed by the job"""
128
+ self.logger.info("Writing python script...")
129
+
130
+ # Remove the old python script
131
+ for file in os.listdir(self.launcher.project_path):
132
+ if file.endswith(".py"):
133
+ os.remove(os.path.join(self.launcher.project_path, file))
134
+
135
+ # Write the python script
136
+ with open(
137
+ os.path.join(self.launcher.module_path, "assets", "spython_template.py"),
138
+ "r",
139
+ ) as f:
140
+ text = f.read()
141
+
142
+ text = text.replace("{{PROJECT_PATH}}", f'"{self.launcher.project_path}"')
143
+ local_mode = ""
144
+ if self.launcher.cluster or self.launcher.server_run:
145
+ # Add Ray warning suppression to runtime_env if not already present
146
+ runtime_env = self.launcher.runtime_env.copy()
147
+ if "env_vars" not in runtime_env:
148
+ runtime_env["env_vars"] = {}
149
+ if "RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO" not in runtime_env["env_vars"]:
150
+ runtime_env["env_vars"]["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"
151
+
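+ # On the cluster the generated spython.py attaches to an already-running Ray
+ # head (address='auto'), presumably started by the sbatch template, and serves
+ # the dashboard on port 8265.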
152
+ local_mode = f"\n\taddress='auto',\n\tinclude_dashboard=True,\n\tdashboard_host='0.0.0.0',\n\tdashboard_port=8265,\n\truntime_env={runtime_env},\n"
153
+ text = text.replace(
154
+ "{{LOCAL_MODE}}",
155
+ local_mode,
156
+ )
157
+ with open(os.path.join(self.launcher.project_path, "spython.py"), "w") as f:
158
+ f.write(text)
159
+
160
+ def _write_slurm_script(self):
161
+ """Write the slurm script that will be executed by the job"""
162
+ self.logger.info("Writing slurm script...")
163
+ template_file = os.path.join(
164
+ self.launcher.module_path, "assets", "sbatch_template.sh"
165
+ )
166
+
167
+ JOB_NAME = "{{JOB_NAME}}"
168
+ NUM_NODES = "{{NUM_NODES}}"
169
+ MEMORY = "{{MEMORY}}"
170
+ RUNNING_TIME = "{{RUNNING_TIME}}"
171
+ PARTITION_NAME = "{{PARTITION_NAME}}"
172
+ COMMAND_PLACEHOLDER = "{{COMMAND_PLACEHOLDER}}"
173
+ GIVEN_NODE = "{{GIVEN_NODE}}"
174
+ COMMAND_SUFFIX = "{{COMMAND_SUFFIX}}"
175
+ LOAD_ENV = "{{LOAD_ENV}}"
176
+ PARTITION_SPECIFICS = "{{PARTITION_SPECIFICS}}"
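+ # These literal tokens are substituted into sbatch_template.sh below via
+ # str.replace. {{JOB_NAME}} receives the full <project_path>/<job_name> string,
+ # which the template presumably uses for its output file: _launch_job later
+ # waits for "<job_name>.log" to appear in the project directory.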
177
+
178
+ job_name = "{}_{}".format(
179
+ self.launcher.project_name, time.strftime("%d%m-%Hh%M", time.localtime())
180
+ )
181
+
182
+ # Convert the time to xx:xx:xx format
183
+ max_time = "{}:{}:{}".format(
184
+ str(self.launcher.max_running_time // 60).zfill(2),
185
+ str(self.launcher.max_running_time % 60).zfill(2),
186
+ str(0).zfill(2),
187
+ )
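+ # e.g. max_running_time=90 (minutes) -> "01:30:00"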
188
+
189
+ # ===== Modified the template script =====
190
+ with open(template_file, "r") as f:
191
+ text = f.read()
192
+ text = text.replace(
193
+ JOB_NAME, os.path.join(self.launcher.project_path, job_name)
194
+ )
195
+ text = text.replace(NUM_NODES, str(self.launcher.node_nbr))
196
+ text = text.replace(MEMORY, str(self.launcher.memory))
197
+ text = text.replace(RUNNING_TIME, str(max_time))
198
+ text = text.replace(
199
+ PARTITION_NAME, str("gpu" if self.launcher.use_gpu > 0 else "cpu")
200
+ )
201
+ text = text.replace(
202
+ COMMAND_PLACEHOLDER,
203
+ str(f"{sys.executable} {self.launcher.project_path}/spython.py"),
204
+ )
205
+ text = text.replace(
206
+ LOAD_ENV, str(f"module load {' '.join(self.launcher.modules)}")
207
+ )
208
+ text = text.replace(GIVEN_NODE, "")
209
+ text = text.replace(COMMAND_SUFFIX, "")
210
+ text = text.replace(
211
+ "# THIS FILE IS A TEMPLATE AND IT SHOULD NOT BE DEPLOYED TO " "PRODUCTION!",
212
+ "# THIS FILE IS MODIFIED AUTOMATICALLY FROM TEMPLATE AND SHOULD BE "
213
+ "RUNNABLE!",
214
+ )
215
+
216
+ # ===== Add partition specifics =====
217
+ if self.launcher.use_gpu > 0:
218
+ text = text.replace(
219
+ PARTITION_SPECIFICS,
220
+ str("#SBATCH --gres gpu:1\n#SBATCH --gres-flags enforce-binding"),
221
+ )
222
+ else:
223
+ text = text.replace(PARTITION_SPECIFICS, "#SBATCH --exclusive")
224
+
225
+ # ===== Save the script =====
226
+ script_file = "sbatch.sh"
227
+ with open(os.path.join(self.launcher.project_path, script_file), "w") as f:
228
+ f.write(text)
229
+
230
+ return script_file, job_name
231
+
232
+ def _launch_job(self, script_file: str = None, job_name: str = None):
233
+ """Launch the job"""
234
+ # ===== Submit the job =====
235
+ self.logger.info("Start to submit job!")
236
+ result = subprocess.run(
237
+ ["sbatch", os.path.join(self.launcher.project_path, script_file)],
238
+ capture_output=True,
239
+ text=True,
240
+ )
241
+ if result.returncode != 0:
242
+ self.logger.error(f"Error submitting job: {result.stderr}")
243
+ return
244
+
245
+ # Extract job ID from output (format: "Submitted batch job 12345")
246
+ job_id = None
247
+ if result.stdout:
248
+ match = re.search(r"Submitted batch job (\d+)", result.stdout)
249
+ if match:
250
+ job_id = match.group(1)
251
+
252
+ if job_id:
253
+ self.job_id = job_id
254
+ self.logger.info(f"Job submitted with ID: {job_id}")
255
+ else:
256
+ self.logger.warning("Could not extract job ID from sbatch output")
257
+
258
+ self.logger.info(
259
+ "Job submitted! Script file is at: <{}>. Log file is at: <{}>".format(
260
+ os.path.join(self.launcher.project_path, script_file),
261
+ os.path.join(self.launcher.project_path, "{}.log".format(job_name)),
262
+ )
263
+ )
264
+
265
+ # Wait for log file to be created and job to start running
266
+ self._monitor_queue(job_name, job_id)
267
+
268
+ def _monitor_queue(self, job_name, job_id):
269
+ current_queue = None
270
+ queue_log_file = os.path.join(self.launcher.project_path, "queue.log")
271
+ with open(queue_log_file, "w") as f:
272
+ f.write("")
273
+ self.logger.info(
274
+ "Start to monitor the queue... You can check the queue at: <{}>".format(
275
+ queue_log_file
276
+ )
277
+ )
278
+
279
+ start_time = time.time()
280
+ last_print_time = 0
281
+ job_running = False
282
+
283
+ while True:
284
+ time.sleep(0.25)
285
+ if os.path.exists(
286
+ os.path.join(self.launcher.project_path, "{}.log".format(job_name))
287
+ ):
288
+ break
289
+ else:
290
+ # Get result from squeue -p {{PARTITION_NAME}}
291
+ result = subprocess.run(
292
+ ["squeue", "-p", "gpu" if self.launcher.use_gpu is True else "cpu"],
293
+ capture_output=True,
294
+ )
295
+ df = result.stdout.decode("utf-8").split("\n")
296
+
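+ # squeue output is fixed-width; column boundaries are derived from the
+ # positions of "ST", "NODE" and "NODELIST(REASON)" in the header row
+ # (df[0]), and the header itself is dropped via the [1:] slice below.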
297
+ try:
298
+ users = list(
299
+ map(
300
+ lambda row: row[: len(df[0].split("ST")[0])][:-1].split(
301
+ " "
302
+ )[-1],
303
+ df,
304
+ )
305
+ )
306
+ status = list(
307
+ map(
308
+ lambda row: row[len(df[0].split("ST")[0]) :]
309
+ .strip()
310
+ .split(" ")[0],
311
+ df,
312
+ )
313
+ )
314
+ nodes = list(
315
+ map(
316
+ lambda row: row[len(df[0].split("NODE")[0]) :]
317
+ .strip()
318
+ .split(" ")[0],
319
+ df,
320
+ )
321
+ )
322
+ node_list = list(
323
+ map(
324
+ lambda row: row[len(df[0].split("NODELIST(REASON)")[0]) :],
325
+ df,
326
+ )
327
+ )
328
+
329
+ to_queue = list(
330
+ zip(
331
+ users,
332
+ status,
333
+ nodes,
334
+ node_list,
335
+ )
336
+ )[1:]
337
+
338
+ # Check if our job is running (status "R")
339
+ if job_id and not job_running:
340
+ job_position = None
341
+ total_jobs = len(to_queue)
342
+
343
+ for i, (user, stat, node_count, node_lst) in enumerate(
344
+ to_queue
345
+ ):
346
+ # Find our job by checking job IDs in squeue output
347
+ if i < len(df) - 1:
348
+ job_line = df[i + 1]
349
+ if job_id in job_line:
350
+ if stat == "R":
351
+ job_running = True
352
+ # Get head node
353
+ head_node = self._get_head_node_from_job_id(
354
+ job_id
355
+ )
356
+ if head_node:
357
+ self.logger.info(
358
+ f"Job is running on node {head_node}."
359
+ )
360
+ self.logger.info(
361
+ f"Dashboard should be accessible at http://{head_node}:8888 (if running on cluster)"
362
+ )
363
+ else:
364
+ job_position = i + 1
365
+ break
366
+
367
+ if job_running:
368
+ break
369
+
370
+ # Print queue status periodically
371
+ if time.time() - last_print_time > 30:
372
+ position_str = (
373
+ f"{job_position}/{total_jobs}"
374
+ if job_position
375
+ else "unknown"
376
+ )
377
+ print(
378
+ f"Waiting for job... (Position in queue : {position_str})"
379
+ )
380
+ last_print_time = time.time()
381
+
382
+ # Update the queue log
383
+ if time.time() - start_time > 60:
384
+ start_time = time.time()
385
+ # Log to file only, no print
386
+ with open(queue_log_file, "a") as f:
387
+ f.write(f"Update time: {time.strftime('%H:%M:%S')}\n")
388
+
389
+ if current_queue is None or current_queue != to_queue:
390
+ current_queue = to_queue
391
+ with open(queue_log_file, "w") as f:
392
+ text = f"Current queue ({time.strftime('%H:%M:%S')}):\n"
393
+ format_row = "{:>30}" * (len(current_queue[0]))
394
+ for user, status, nodes, node_list in current_queue:
395
+ text += (
396
+ format_row.format(user, status, nodes, node_list)
397
+ + "\n"
398
+ )
399
+ text += "\n"
400
+ f.write(text)
401
+ except Exception:
402
+ # If squeue format changes or fails parsing, don't crash
403
+ pass
404
+
405
+ # Wait for the job to finish while printing the log
406
+ self.logger.info("Job started! Waiting for the job to finish...")
407
+ log_cursor_position = 0
408
+ job_finished = False
409
+ while not job_finished:
410
+ time.sleep(0.25)
411
+ if os.path.exists(os.path.join(self.launcher.project_path, "result.pkl")):
412
+ job_finished = True
413
+ else:
414
+ with open(
415
+ os.path.join(self.launcher.project_path, "{}.log".format(job_name)),
416
+ "r",
417
+ ) as f:
418
+ f.seek(log_cursor_position)
419
+ text = f.read()
420
+ if text != "":
421
+ print(text, end="")
422
+ self.logger.info(text.strip())
423
+ log_cursor_position = f.tell()
424
+
425
+ self.logger.info("Job finished!")
426
+
427
+ def _get_head_node_from_job_id(
428
+ self, job_id: str, ssh_client: paramiko.SSHClient = None
429
+ ) -> str:
430
+ """Get the head node name from a SLURM job ID"""
431
+ try:
432
+ # Execute scontrol show job
433
+ if ssh_client:
434
+ stdin, stdout, stderr = ssh_client.exec_command(
435
+ f"scontrol show job {job_id}"
436
+ )
437
+ output = stdout.read().decode("utf-8")
438
+ if stderr.read().decode("utf-8"):
439
+ return None
440
+ else:
441
+ result = subprocess.run(
442
+ ["scontrol", "show", "job", job_id],
443
+ capture_output=True,
444
+ text=True,
445
+ )
446
+ if result.returncode != 0:
447
+ return None
448
+ output = result.stdout
449
+
450
+ # Extract NodeList from output
451
+ node_list_match = re.search(r"NodeList=([^\s]+)", output)
452
+ if not node_list_match:
453
+ return None
454
+
455
+ node_list = node_list_match.group(1)
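+ # NodeList may be a compressed range such as "node[001-003]"; the
+ # "scontrol show hostnames" call below expands it to one hostname per
+ # line, and the first entry is taken as the head node.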
456
+
457
+ # Get hostnames from NodeList using scontrol show hostnames
458
+ if ssh_client:
459
+ stdin, stdout, stderr = ssh_client.exec_command(
460
+ f"scontrol show hostnames {node_list}"
461
+ )
462
+ hostnames_output = stdout.read().decode("utf-8")
463
+ if stderr.read().decode("utf-8"):
464
+ return None
465
+ else:
466
+ result = subprocess.run(
467
+ ["scontrol", "show", "hostnames", node_list],
468
+ capture_output=True,
469
+ text=True,
470
+ )
471
+ if result.returncode != 0:
472
+ return None
473
+ hostnames_output = result.stdout
474
+
475
+ # Get first hostname (head node)
476
+ hostnames = hostnames_output.strip().split("\n")
477
+ if hostnames and hostnames[0]:
478
+ return hostnames[0].strip()
479
+
480
+ return None
481
+ except Exception as e:
482
+ self.logger.warning(f"Failed to get head node from job ID {job_id}: {e}")
483
+ return None
484
+
485
+ def _launch_server(self, cancel_old_jobs: bool = True, wait: bool = True):
486
+ """Launch the server on the cluster and run the function using the ressources."""
487
+ connected = False
488
+ self.logger.info("Connecting to the cluster...")
489
+ ssh_client = paramiko.SSHClient()
490
+ self.ssh_client = ssh_client
491
+ ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
492
+ while not connected:
493
+ try:
494
+ if self.launcher.server_password is None:
495
+ # Prompt for the cluster password if it was not provided
496
+ self.launcher.server_password = getpass(
497
+ "Enter your cluster password: "
498
+ )
499
+
500
+ ssh_client.connect(
501
+ hostname=self.launcher.server_ssh,
502
+ username=self.launcher.server_username,
503
+ password=self.launcher.server_password,
504
+ )
505
+ sftp = ssh_client.open_sftp()
506
+ connected = True
507
+ except paramiko.ssh_exception.AuthenticationException:
508
+ self.launcher.server_password = None
509
+ self.logger.warning("Wrong password, please try again.")
510
+
511
+ # Setup pyenv Python version if available
512
+ self.pyenv_python_cmd = None
513
+ if hasattr(self.launcher, "local_python_version"):
514
+ self.pyenv_python_cmd = self._setup_pyenv_python(
515
+ ssh_client, self.launcher.local_python_version
516
+ )
517
+
518
+ # Check Python version compatibility (with pyenv if available)
519
+ is_compatible = self._check_python_version_compatibility(
520
+ ssh_client, self.pyenv_python_cmd
521
+ )
522
+ self.python_version_compatible = is_compatible
523
+
524
+ # Define project directory on cluster (organized by project name)
525
+ project_dir = f"slurmray-server/{self.launcher.project_name}"
526
+
527
+ # Write server script
528
+ self._write_server_script()
529
+
530
+ self.logger.info("Downloading server...")
531
+
532
+ # Generate requirements first to check venv hash
533
+ self._generate_requirements()
534
+
535
+ # Add slurmray (unpinned for now to match legacy behavior, but could be pinned)
536
+ # Check if slurmray is already in requirements.txt to avoid duplicates
537
+ req_file = f"{self.launcher.project_path}/requirements.txt"
538
+ with open(req_file, "r") as f:
539
+ content = f.read()
540
+ if "slurmray" not in content.lower():
541
+ with open(req_file, "a") as f:
542
+ f.write("slurmray\n")
543
+
544
+ # Check if venv can be reused based on requirements hash
545
+ dep_manager = DependencyManager(self.launcher.project_path, self.logger)
546
+ req_file = os.path.join(self.launcher.project_path, "requirements.txt")
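+ # Venv reuse policy: hash the local requirements and compare with the hash
+ # stored on the cluster (.slogs/venv_hash.txt). If they match and .venv still
+ # exists remotely, the environment is reused; force_reinstall_venv always
+ # wipes it.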
547
+
548
+ should_recreate_venv = True
549
+ if self.launcher.force_reinstall_venv:
550
+ # Force recreation: remove venv if it exists
551
+ self.logger.info("Force reinstall enabled: removing existing virtualenv...")
552
+ ssh_client.exec_command(f"rm -rf {project_dir}/.venv")
553
+ should_recreate_venv = True
554
+ elif os.path.exists(req_file):
555
+ with open(req_file, "r") as f:
556
+ req_lines = f.readlines()
557
+ # Check remote hash (if venv exists on remote)
558
+ remote_hash_file = f"{project_dir}/.slogs/venv_hash.txt"
559
+ stdin, stdout, stderr = ssh_client.exec_command(
560
+ f"test -f {remote_hash_file} && cat {remote_hash_file} || echo ''"
561
+ )
562
+ remote_hash = stdout.read().decode("utf-8").strip()
563
+ current_hash = dep_manager.compute_requirements_hash(req_lines)
564
+
565
+ if remote_hash and remote_hash == current_hash:
566
+ # Hash matches, check if venv exists
567
+ stdin, stdout, stderr = ssh_client.exec_command(
568
+ f"test -d {project_dir}/.venv && echo exists || echo missing"
569
+ )
570
+ venv_exists = stdout.read().decode("utf-8").strip() == "exists"
571
+ if venv_exists:
572
+ should_recreate_venv = False
573
+ self.logger.info(
574
+ "Virtualenv can be reused (requirements hash matches)"
575
+ )
576
+
577
+ # Optimize requirements
578
+ # Assuming standard path structure on Slurm cluster (Curnagl)
579
+ venv_cmd = (
580
+ f"cd {project_dir} && source .venv/bin/activate &&"
581
+ if not should_recreate_venv
582
+ else ""
583
+ )
584
+ req_file_to_push = self._optimize_requirements(ssh_client, venv_cmd)
585
+
586
+ # Copy files from the project to the server
587
+ for file in os.listdir(self.launcher.project_path):
588
+ if file.endswith(".py") or file.endswith(".pkl") or file.endswith(".sh"):
589
+ if file == "requirements.txt":
590
+ continue
591
+ sftp.put(os.path.join(self.launcher.project_path, file), file)
592
+
593
+ # Smart cleanup: preserve venv if hash matches, only clean server logs
594
+ ssh_client.exec_command(f"mkdir -p {project_dir}/.slogs/server")
595
+ if should_recreate_venv:
596
+ # Clean server logs only (venv will be recreated by script if needed)
597
+ ssh_client.exec_command(f"rm -rf {project_dir}/.slogs/server/*")
598
+ # Create flag file to force venv recreation in script
599
+ if self.launcher.force_reinstall_venv:
600
+ ssh_client.exec_command(f"touch {project_dir}/.force_reinstall")
601
+ self.logger.info(
602
+ "Virtualenv will be recreated if needed (requirements changed or missing)"
603
+ )
604
+ else:
605
+ # Clean server logs only, preserve venv
606
+ ssh_client.exec_command(f"rm -rf {project_dir}/.slogs/server/*")
607
+ # Remove flag file if it exists
608
+ ssh_client.exec_command(f"rm -f {project_dir}/.force_reinstall")
609
+ self.logger.info("Preserving virtualenv (requirements unchanged)")
610
+ # Filter valid files
611
+ valid_files = []
612
+ for file in self.launcher.files:
613
+ # Skip invalid paths
614
+ if (
615
+ not file
616
+ or file == "."
617
+ or file == ".."
618
+ or file.startswith("./")
619
+ or file.startswith("../")
620
+ ):
621
+ self.logger.warning(f"Skipping invalid file path: {file}")
622
+ continue
623
+ valid_files.append(file)
624
+
625
+ # Use incremental sync for local files
626
+ if valid_files:
627
+ self._sync_local_files_incremental(
628
+ sftp, project_dir, valid_files, ssh_client
629
+ )
630
+
631
+ # Copy the requirements.txt (optimized) to the server
632
+ sftp.put(req_file_to_push, "requirements.txt")
633
+
634
+ # Store venv hash on remote for future checks
635
+ if os.path.exists(req_file):
636
+ with open(req_file, "r") as f:
637
+ req_lines = f.readlines()
638
+ current_hash = dep_manager.compute_requirements_hash(req_lines)
639
+ # Ensure .slogs directory exists on remote
640
+ ssh_client.exec_command(f"mkdir -p {project_dir}/.slogs")
641
+ stdin, stdout, stderr = ssh_client.exec_command(
642
+ f"echo '{current_hash}' > {project_dir}/.slogs/venv_hash.txt"
643
+ )
644
+ stdout.channel.recv_exit_status()
645
+ # Also store locally
646
+ dep_manager.store_venv_hash(current_hash)
647
+
648
+ # Update retention timestamp
649
+ self._update_retention_timestamp(
650
+ ssh_client, project_dir, self.launcher.retention_days
651
+ )
652
+
653
+ # Write and copy the server script to the server
654
+ self._write_slurmray_server_sh()
655
+ sftp.put(
656
+ os.path.join(self.launcher.project_path, "slurmray_server.sh"),
657
+ "slurmray_server.sh",
658
+ )
659
+ # Chmod script
660
+ sftp.chmod("slurmray_server.sh", 0o755)
661
+
662
+ # Run the server
663
+ self.logger.info("Running server...")
664
+ stdin, stdout, stderr = ssh_client.exec_command("./slurmray_server.sh")
665
+
666
+ # Read the output in real time and capture job ID
667
+ job_id = None
668
+ tunnel = None
669
+ output_lines = []
670
+
671
+ # Read output line by line to capture job ID
672
+ while True:
673
+ line = stdout.readline()
674
+ if not line:
675
+ break
676
+ output_lines.append(line)
677
+
678
+ # Double output: console + log file
679
+ print(line, end="")
680
+ self.logger.info(line.strip())
681
+
682
+ # Try to extract job ID from output (format: "Submitted batch job 12345")
683
+ if not job_id:
684
+ match = re.search(r"Submitted batch job (\d+)", line)
685
+ if match:
686
+ job_id = match.group(1)
687
+ self.job_id = job_id
688
+ self.logger.info(f"Job ID detected: {job_id}")
689
+
690
+ exit_status = stdout.channel.recv_exit_status()
691
+
692
+ # Check if script failed - fail-fast immediately
693
+ if exit_status != 0:
694
+ # Collect error information
695
+ stderr_output = stderr.read().decode("utf-8")
696
+ error_msg = f"Server script exited with non-zero status: {exit_status}"
697
+ if stderr_output.strip():
698
+ error_msg += f"\nScript errors:\n{stderr_output}"
699
+
700
+ # Log the error
701
+ self.logger.error(error_msg)
702
+
703
+ # Close tunnel if open
704
+ if tunnel:
705
+ try:
706
+ tunnel.__exit__(None, None, None)
707
+ except Exception:
708
+ pass
709
+
710
+ # Raise exception immediately (fail-fast)
711
+ raise RuntimeError(error_msg)
712
+
713
+ # If job ID not found in output, try to find it via squeue
714
+ if not job_id:
715
+ self.logger.info("Job ID not found in output, trying to find via squeue...")
716
+ stdin, stdout, stderr = ssh_client.exec_command(
717
+ f"squeue -u {self.launcher.server_username} -o '%i %j' --noheader"
718
+ )
719
+ squeue_output = stdout.read().decode("utf-8")
720
+ # Try to find job matching project name pattern
721
+ if self.launcher.project_name:
722
+ for line in squeue_output.strip().split("\n"):
723
+ parts = line.strip().split()
724
+ if len(parts) >= 2 and self.launcher.project_name in parts[1]:
725
+ job_id = parts[0]
726
+ self.job_id = job_id
727
+ self.logger.info(f"Found job ID via squeue: {job_id}")
728
+ break
729
+
730
+ # If job ID found and we are not waiting, return it and stop
731
+ if job_id and not wait:
732
+ self.logger.info("Async mode: Job submitted with ID {}. Disconnecting...".format(job_id))
733
+ return job_id
734
+
735
+ # No job ID detected and not waiting: warn and return None
736
+ if not job_id and not wait:
737
+ self.logger.warning("Async mode: Could not detect job ID. Returning None.")
738
+ return None
739
+
740
+ # Loop for monitoring (only if wait=True)
741
+
742
+ # If job ID found, wait for job to be running and set up tunnel
743
+ if job_id:
744
+ self.logger.info(f"Waiting for job {job_id} to start running...")
745
+ max_wait_time = 300 # Wait up to 5 minutes
746
+ wait_start = time.time()
747
+ job_running = False
748
+
749
+ while time.time() - wait_start < max_wait_time:
750
+ time.sleep(2)
751
+ stdin, stdout, stderr = ssh_client.exec_command(
752
+ f"squeue -j {job_id} -o '%T' --noheader"
753
+ )
754
+ status_output = stdout.read().decode("utf-8").strip()
755
+ if status_output == "R":
756
+ job_running = True
757
+ break
758
+
759
+ if job_running:
760
+ # Get head node
761
+ head_node = self._get_head_node_from_job_id(job_id, ssh_client)
762
+ if head_node:
763
+ self.logger.info(
764
+ f"Job is running on node {head_node}. Setting up SSH tunnel for dashboard..."
765
+ )
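+ # Forward local port 8888 through the login node to the Ray dashboard
+ # (port 8265) on the compute node hosting the head.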
766
+ try:
767
+ tunnel = SSHTunnel(
768
+ ssh_host=self.launcher.server_ssh,
769
+ ssh_username=self.launcher.server_username,
770
+ ssh_password=self.launcher.server_password,
771
+ remote_host=head_node,
772
+ local_port=8888,
773
+ remote_port=8265,
774
+ logger=self.logger,
775
+ )
776
+ tunnel.__enter__()
777
+ self.logger.info(
778
+ "Dashboard accessible at http://localhost:8888"
779
+ )
780
+
781
+ # Wait for job to complete while maintaining tunnel
782
+ # Check periodically if job is still running
783
+ while True:
784
+ time.sleep(5)
785
+ stdin, stdout, stderr = ssh_client.exec_command(
786
+ f"squeue -j {job_id} -o '%T' --noheader"
787
+ )
788
+ status_output = stdout.read().decode("utf-8").strip()
789
+ if status_output != "R":
790
+ # Job finished or no longer running
791
+ break
792
+ except Exception as e:
793
+ self.logger.warning(f"Failed to create SSH tunnel: {e}")
794
+ self.logger.info(
795
+ "Dashboard will not be accessible via port forwarding"
796
+ )
797
+ tunnel = None
798
+ else:
799
+ self.logger.warning(
800
+ "Job did not start running within timeout, skipping tunnel setup"
801
+ )
802
+
803
+ # Close tunnel if it was created
804
+ if tunnel:
805
+ tunnel.__exit__(None, None, None)
806
+
807
+ # Downloading result
808
+ self.logger.info("Downloading result...")
809
+ project_dir = f"slurmray-server/{self.launcher.project_name}"
810
+ try:
811
+ sftp.get(
812
+ f"{project_dir}/.slogs/server/result.pkl",
813
+ os.path.join(self.launcher.project_path, "result.pkl"),
814
+ )
815
+ self.logger.info("Result downloaded!")
816
+
817
+ # Clean up remote temporary files (preserve venv and cache)
818
+ self.logger.info("Cleaning up remote temporary files...")
819
+ ssh_client.exec_command(
820
+ f"cd {project_dir} && "
821
+ "find . -maxdepth 1 -type f \\( -name '*.py' -o -name '*.pkl' -o -name '*.sh' \\) "
822
+ "! -name 'requirements.txt' -delete 2>/dev/null || true && "
823
+ "rm -rf .slogs/server 2>/dev/null || true"
824
+ )
825
+
826
+ # Clean up local temporary files after successful download
827
+ self._cleanup_local_temp_files()
828
+
829
+ except FileNotFoundError:
830
+ # Check for errors
831
+ stderr_lines = stderr.readlines()
832
+ if stderr_lines:
833
+ self.logger.error("Errors:")
834
+ for line in stderr_lines:
835
+ print(line, end="")
836
+ self.logger.error(line.strip())
837
+ self.logger.error("An error occured, please check the logs.")
838
+
839
+ def _cleanup_local_temp_files(self):
840
+ """Clean up local temporary files after successful execution"""
841
+ temp_files = [
842
+ "func_source.py",
843
+ "func_name.txt",
844
+ "func.pkl",
845
+ "args.pkl",
846
+ "result.pkl",
847
+ "spython.py",
848
+ "sbatch.sh",
849
+ "slurmray_server.py",
850
+ "requirements_to_install.txt",
851
+ ]
852
+
853
+ for temp_file in temp_files:
854
+ file_path = os.path.join(self.launcher.project_path, temp_file)
855
+ if os.path.exists(file_path):
856
+ os.remove(file_path)
857
+ self.logger.debug(f"Removed temporary file: {temp_file}")
858
+
859
+ def _write_server_script(self):
860
+ """This funtion will write a script with the given specifications to run slurmray on the cluster"""
861
+ self.logger.info("Writing slurmray server script...")
862
+ template_file = os.path.join(
863
+ self.launcher.module_path, "assets", "slurmray_server_template.py"
864
+ )
865
+
866
+ MODULES = self.launcher.modules
867
+ NODE_NBR = self.launcher.node_nbr
868
+ USE_GPU = self.launcher.use_gpu
869
+ MEMORY = self.launcher.memory
870
+ MAX_RUNNING_TIME = self.launcher.max_running_time
871
+
872
+ # ===== Modified the template script =====
873
+ with open(template_file, "r") as f:
874
+ text = f.read()
875
+ text = text.replace("{{MODULES}}", str(MODULES))
876
+ text = text.replace("{{NODE_NBR}}", str(NODE_NBR))
877
+ text = text.replace("{{USE_GPU}}", str(USE_GPU))
878
+ text = text.replace("{{MEMORY}}", str(MEMORY))
879
+ text = text.replace("{{MAX_RUNNING_TIME}}", str(MAX_RUNNING_TIME))
880
+
881
+ # ===== Save the script =====
882
+ script_file = "slurmray_server.py"
883
+ with open(os.path.join(self.launcher.project_path, script_file), "w") as f:
884
+ f.write(text)
885
+
886
+ def _write_slurmray_server_sh(self):
887
+ """Write the slurmray_server.sh script with pyenv support if available"""
888
+ # Determine Python command
889
+ if self.pyenv_python_cmd:
890
+ # Use pyenv: the command already includes eval and pyenv shell
891
+ python_cmd = self.pyenv_python_cmd.split(" && ")[
892
+ -1
893
+ ] # Extract just "python" from the command
894
+ python3_cmd = python_cmd.replace("python", "python3")
895
+ pyenv_setup = self.pyenv_python_cmd.rsplit(" && ", 1)[
896
+ 0
897
+ ] # Get "eval ... && pyenv shell X.Y.Z"
898
+ use_pyenv = True
899
+ else:
900
+ # Fallback to system Python
901
+ python_cmd = "python"
902
+ python3_cmd = "python3"
903
+ pyenv_setup = ""
904
+ use_pyenv = False
905
+
906
+ # Filter out python module from modules list if using pyenv
907
+ modules_list = self.launcher.modules.copy()
908
+ if use_pyenv:
909
+ modules_list = [m for m in modules_list if not m.startswith("python")]
910
+
911
+ modules_str = " ".join(modules_list) if modules_list else ""
912
+ project_dir = f"slurmray-server/{self.launcher.project_name}"
913
+
914
+ script_content = f"""#!/bin/sh
915
+
916
+ echo "Installing slurmray server"
917
+
918
+ # Copy files
919
+ mv -t {project_dir} requirements.txt slurmray_server.py
920
+ mv -t {project_dir}/.slogs/server func.pkl args.pkl
921
+ cd {project_dir}
922
+
923
+ # Load modules
924
+ # Using specific versions for Curnagl compatibility (SLURM 24.05.3)
925
+ # gcc/13.2.0: Latest GCC version
926
+ # python module is loaded only if pyenv is not available
927
+ # cuda/12.6.2: Latest CUDA version
928
+ # cudnn/9.2.0.82-12: Compatible with cuda/12.6.2
929
+ """
930
+
931
+ if modules_str:
932
+ script_content += f"module load {modules_str}\n"
933
+
934
+ script_content += f"""
935
+ # Setup pyenv if available
936
+ """
937
+
938
+ if use_pyenv:
939
+ script_content += f"""# Using pyenv for Python version management
940
+ export PATH="$HOME/.pyenv/bin:/usr/local/bin:/opt/pyenv/bin:$PATH"
941
+ {pyenv_setup}
942
+ """
943
+ else:
944
+ script_content += """# pyenv not available, using system Python
945
+ """
946
+
947
+ script_content += f"""
948
+ # Create venv if it doesn't exist (hash check is done in Python before file upload)
949
+ # If venv needs recreation, it has already been removed by Python
950
+ # Check for force reinstall flag
951
+ if [ -f ".force_reinstall" ]; then
952
+ echo "Force reinstall flag detected: removing existing virtualenv..."
953
+ rm -rf .venv
954
+ rm -f .force_reinstall
955
+ fi
956
+
957
+ if [ ! -d ".venv" ]; then
958
+ echo "Creating virtualenv..."
959
+ """
960
+
961
+ if use_pyenv:
962
+ script_content += f""" {pyenv_setup} && {python3_cmd} -m venv .venv
963
+ """
964
+ else:
965
+ script_content += f""" {python3_cmd} -m venv .venv
966
+ """
967
+
968
+ script_content += f"""else
969
+ echo "Using existing virtualenv (requirements unchanged)..."
970
+ VENV_EXISTED=true
971
+ fi
972
+
973
+ source .venv/bin/activate
974
+
975
+ # Install requirements if file exists and is not empty
976
+ if [ -f requirements.txt ]; then
977
+ # Only install when requirements.txt is non-empty
978
+ if [ -s requirements.txt ]; then
979
+ echo "📥 Installing dependencies from requirements.txt..."
980
+
981
+ # Get installed packages once (fast, single command) - create lookup file
982
+ uv pip list --format=freeze 2>/dev/null | sed 's/==/ /' | awk '{{print $1" "$2}}' > /tmp/installed_packages.txt || touch /tmp/installed_packages.txt
983
+
984
+ # Process requirements: filter duplicates and check what needs installation
985
+ INSTALL_ERRORS=0
986
+ SKIPPED_COUNT=0
987
+ > /tmp/to_install.txt # Clear file
+ > /tmp/seen_packages.txt # Clear duplicate-tracking file from any previous run
988
+
989
+ while IFS= read -r line || [ -n "$line" ]; do
990
+ # Skip empty lines and comments
991
+ line=$(echo "$line" | sed 's/^[[:space:]]*//;s/[[:space:]]*$//')
992
+ if [ -z "$line" ] || [ "${{line#"#"}}" != "$line" ]; then
993
+ continue
994
+ fi
995
+
996
+ # Extract package name (remove version specifiers and extras)
997
+ pkg_name=$(echo "$line" | sed 's/[<>=!].*//' | sed 's/\\[.*\\]//' | sed 's/[[:space:]]*//' | tr '[:upper:]' '[:lower:]')
998
+ if [ -z "$pkg_name" ]; then
999
+ continue
1000
+ fi
1001
+
1002
+ # Skip duplicates (check if we've already processed this package)
1003
+ if grep -qi "^$pkg_name$" /tmp/seen_packages.txt 2>/dev/null; then
1004
+ continue
1005
+ fi
1006
+ echo "$pkg_name" >> /tmp/seen_packages.txt
1007
+
1008
+ # Extract required version if present
1009
+ required_version=""
1010
+ if echo "$line" | grep -q "=="; then
1011
+ required_version=$(echo "$line" | sed 's/.*==\\([^;]*\\).*/\\1/' | sed 's/[[:space:]]*//')
1012
+ fi
1013
+
1014
+ # Check if package is already installed with correct version
1015
+ installed_version=$(grep -i "^$pkg_name " /tmp/installed_packages.txt 2>/dev/null | awk '{{print $2}}' | head -1)
1016
+
1017
+ if [ -n "$installed_version" ]; then
1018
+ if [ -z "$required_version" ] || [ "$installed_version" = "$required_version" ]; then
1019
+ echo " ⏭️ $pkg_name==$installed_version (already installed)"
1020
+ SKIPPED_COUNT=$((SKIPPED_COUNT + 1))
1021
+ continue
1022
+ fi
1023
+ fi
1024
+
1025
+ # Package not installed or version mismatch, add to install list
1026
+ echo "$line" >> /tmp/to_install.txt
1027
+ done < requirements.txt
1028
+
1029
+ # Install packages that need installation
1030
+ if [ -s /tmp/to_install.txt ]; then
1031
+ > /tmp/install_errors.txt # Track errors
1032
+ while IFS= read -r line; do
1033
+ pkg_name=$(echo "$line" | sed 's/[<>=!].*//' | sed 's/\\[.*\\]//' | sed 's/[[:space:]]*//')
1034
+ if uv pip install --quiet "$line" >/dev/null 2>&1; then
1035
+ echo " ✅ $pkg_name"
1036
+ else
1037
+ echo " ❌ $pkg_name"
1038
+ echo "1" >> /tmp/install_errors.txt
1039
+ # Show error details
1040
+ uv pip install "$line" 2>&1 | grep -E "(error|Error|ERROR|failed|Failed|FAILED)" | head -3 | sed 's/^/ /' || true
1041
+ fi
1042
+ done < /tmp/to_install.txt
1043
+ INSTALL_ERRORS=$(wc -l < /tmp/install_errors.txt 2>/dev/null | tr -d ' ' || echo "0")
1044
+ rm -f /tmp/install_errors.txt
1045
+ fi
1046
+
1047
+ # Count newly installed packages before cleanup
1048
+ NEWLY_INSTALLED=0
1049
+ if [ -s /tmp/to_install.txt ]; then
1050
+ NEWLY_INSTALLED=$(wc -l < /tmp/to_install.txt 2>/dev/null | tr -d ' ' || echo "0")
1051
+ fi
1052
+
1053
+ # Cleanup temp files
1054
+ rm -f /tmp/installed_packages.txt /tmp/seen_packages.txt /tmp/to_install.txt
1055
+
1056
+ if [ $INSTALL_ERRORS -eq 0 ]; then
1057
+ if [ $SKIPPED_COUNT -gt 0 ]; then
1058
+ echo "✅ All dependencies up to date ($SKIPPED_COUNT already installed, $NEWLY_INSTALLED newly installed)"
1059
+ else
1060
+ echo "✅ All dependencies installed successfully"
1061
+ fi
1062
+ else
1063
+ echo "❌ Failed to install $INSTALL_ERRORS package(s)" >&2
1064
+ exit 1
1065
+ fi
1066
+ else
1067
+ if [ "$VENV_EXISTED" = "true" ]; then
1068
+ echo "✅ All dependencies already installed (requirements.txt is empty)"
1069
+ else
1070
+ echo "⚠️ requirements.txt is empty, skipping dependency installation"
1071
+ fi
1072
+ fi
1073
+ else
1074
+ echo "⚠️ No requirements.txt found, skipping dependency installation"
1075
+ fi
1076
+
1077
+ # Fix torch bug (https://github.com/pytorch/pytorch/issues/111469)
1078
+ PYTHON_VERSION=$({python3_cmd} -c 'import sys; print(f"{{sys.version_info.major}}.{{sys.version_info.minor}}")')
1079
+ export LD_LIBRARY_PATH=$HOME/{project_dir}/.venv/lib/python$PYTHON_VERSION/site-packages/nvidia/nvjitlink/lib:$LD_LIBRARY_PATH
1080
+
1081
+
1082
+ # Run server
1083
+ """
1084
+
1085
+ if use_pyenv:
1086
+ script_content += f"""{pyenv_setup} && {python_cmd} -u slurmray_server.py
1087
+ """
1088
+ else:
1089
+ script_content += f"""{python_cmd} -u slurmray_server.py
1090
+ """
1091
+
1092
+ script_file = "slurmray_server.sh"
1093
+ with open(os.path.join(self.launcher.project_path, script_file), "w") as f:
1094
+ f.write(script_content)
1095
+
1096
+ def _push_file(
1097
+ self, file_path: str, sftp: paramiko.SFTPClient, ssh_client: paramiko.SSHClient
1098
+ ):
1099
+ """Push a file to the cluster"""
1100
+ self.logger.info(
1101
+ f"Pushing file {os.path.basename(file_path)} to the cluster..."
1102
+ )
1103
+
1104
+ # Determine the path to the file
1105
+ local_path = file_path
1106
+ local_path_from_pwd = os.path.relpath(local_path, self.launcher.pwd_path)
1107
+ cluster_path = os.path.join(
1108
+ "/users",
1109
+ self.launcher.server_username,
1110
+ "slurmray-server",
1111
+ self.launcher.project_name,
1112
+ local_path_from_pwd,
1113
+ )
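+ # NOTE: assumes home directories live under /users/<username>, as on the
+ # Curnagl cluster targeted elsewhere in this module.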
1114
+
1115
+ # Create the directory if not exists
1116
+
1117
+ stdin, stdout, stderr = ssh_client.exec_command(
1118
+ f"mkdir -p '{os.path.dirname(cluster_path)}'"
1119
+ )
1120
+ while True:
1121
+ line = stdout.readline()
1122
+ if not line:
1123
+ break
1124
+ # Keep print for real-time feedback on directory creation (optional, could be logger.debug)
1125
+ self.logger.debug(line.strip())
1126
+ time.sleep(1) # Wait for the directory to be created
1127
+
1128
+ sftp.put(file_path, cluster_path)
1129
+
1130
+ def get_result(self, job_id: str) -> Any:
1131
+ """Get result for a specific job ID"""
1132
+ # Load local result if available (cluster mode or already downloaded)
1133
+ local_path = os.path.join(self.launcher.project_path, "result.pkl")
1134
+ if os.path.exists(local_path):
1135
+ with open(local_path, "rb") as f:
1136
+ return dill.load(f)
1137
+
1138
+ # If clustered/server run, try to fetch it
1139
+ if self.launcher.cluster:
1140
+ # Already checked local path, so it's missing
1141
+ return None
1142
+
1143
+ if self.launcher.server_run:
1144
+ self._connect()
1145
+ try:
1146
+ project_dir = f"slurmray-server/{self.launcher.project_name}"
1147
+ remote_path = f"{project_dir}/.slogs/server/result.pkl"
1148
+ self.ssh_client.open_sftp().get(remote_path, local_path)
1149
+ with open(local_path, "rb") as f:
1150
+ return dill.load(f)
1151
+ except Exception:
1152
+ return None
1153
+
1154
+ return None
1155
+
1156
+ def get_logs(self, job_id: str) -> Any:
1157
+ """Get logs for a specific job ID"""
1158
+ if not job_id:
1159
+ yield "No Job ID provided."
1160
+ return
1161
+
1162
+ # Attempt to get job info to find log file
1163
+ cmd = f"scontrol show job {job_id}"
1164
+
1165
+ output = ""
1166
+ if self.launcher.cluster:
1167
+ try:
1168
+ res = subprocess.run(cmd.split(), capture_output=True, text=True)
1169
+ output = res.stdout
1170
+ except Exception as e:
1171
+ yield f"Error querying local scontrol: {e}"
1172
+ elif self.launcher.server_run:
1173
+ self._connect()
1174
+ try:
1175
+ stdin, stdout, stderr = self.ssh_client.exec_command(cmd)
1176
+ output = stdout.read().decode("utf-8")
1177
+ except Exception as e:
1178
+ yield f"Error querying remote job info: {e}"
1179
+ return
1180
+ else:
1181
+ yield "Not running on cluster."
1182
+ return
1183
+
1184
+ # Extract JobName
1185
+ # JobName=example_1012-14h22
1187
+ match = re.search(r"JobName=([^\s]+)", output)
1188
+ if match:
1189
+ job_name = match.group(1)
1190
+ log_filename = f"{job_name}.log"
1191
+ else:
1192
+ yield "Could not determine log filename via scontrol (Job might be finished). Checking standard path..."
1193
+ # scontrol no longer reports finished jobs and the log name embeds a
1194
+ # timestamp, so fall back to the newest log file matching the project name.
1196
+ log_filename = None
1197
+
1198
+ # Attempt to find log file by listing
1200
+ if self.launcher.cluster:
1201
+ if os.path.exists(self.launcher.project_path):
1202
+ import glob
1203
+ files = glob.glob(os.path.join(self.launcher.project_path, f"{self.launcher.project_name}*.log"))
1204
+ if files:
1205
+ log_filename = os.path.basename(max(files, key=os.path.getmtime))
1206
+ elif self.launcher.server_run:
1207
+ project_dir = f"slurmray-server/{self.launcher.project_name}"
1208
+ stdin, stdout, stderr = self.ssh_client.exec_command(f"cd {project_dir} && ls -t {self.launcher.project_name}*.log | head -n 1")
1209
+ possible_log = stdout.read().decode("utf-8").strip()
1210
+ if possible_log:
1211
+ log_filename = possible_log
1212
+
1213
+ if not log_filename:
1214
+ yield "Log file could not be found."
1215
+ return
1216
+
1217
+ # Read log file
1218
+ if self.launcher.cluster:
1219
+ log_path = os.path.join(self.launcher.project_path, log_filename)
1220
+ if os.path.exists(log_path):
1221
+ with open(log_path, "r") as f:
1222
+ yield from f
1223
+ else:
1224
+ yield f"Log file {log_path} not found."
1225
+ elif self.launcher.server_run:
1226
+ project_dir = f"slurmray-server/{self.launcher.project_name}"
1227
+ remote_log = f"{project_dir}/{log_filename}"
1228
+ try:
1229
+ stdin, stdout, stderr = self.ssh_client.exec_command(f"cat {remote_log}")
1230
+ # Stream output? stdout is a channel.
1231
+ for line in stdout:
1232
+ yield line
1233
+ except Exception as e:
1234
+ yield f"Error reading remote log: {e}"