slurmray 6.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- slurmray/RayLauncher.py +1019 -0
- slurmray/__init__.py +0 -0
- slurmray/__main__.py +5 -0
- slurmray/assets/cleanup_old_projects.py +171 -0
- slurmray/assets/sbatch_template.sh +67 -0
- slurmray/assets/slurmray_server.sh +145 -0
- slurmray/assets/slurmray_server_template.py +28 -0
- slurmray/assets/spython_template.py +113 -0
- slurmray/backend/__init__.py +0 -0
- slurmray/backend/base.py +1040 -0
- slurmray/backend/desi.py +856 -0
- slurmray/backend/local.py +124 -0
- slurmray/backend/remote.py +191 -0
- slurmray/backend/slurm.py +1234 -0
- slurmray/cli.py +904 -0
- slurmray/detection.py +1 -0
- slurmray/file_sync.py +276 -0
- slurmray/scanner.py +441 -0
- slurmray/utils.py +359 -0
- slurmray-6.0.4.dist-info/LICENSE +201 -0
- slurmray-6.0.4.dist-info/METADATA +85 -0
- slurmray-6.0.4.dist-info/RECORD +24 -0
- slurmray-6.0.4.dist-info/WHEEL +4 -0
- slurmray-6.0.4.dist-info/entry_points.txt +3 -0
slurmray/backend/slurm.py
@@ -0,0 +1,1234 @@
import os
import sys
import time
import subprocess
import paramiko
import dill
import re
from getpass import getpass
from typing import Any

from slurmray.backend.base import ClusterBackend
from slurmray.backend.remote import RemoteMixin
from slurmray.utils import SSHTunnel, DependencyManager


class SlurmBackend(RemoteMixin):
    """Backend for Slurm cluster execution (local or remote via SSH)"""

    def __init__(self, launcher):
        super().__init__(launcher)
        self.ssh_client = None
        self.job_id = None

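    # run() dispatches on flags provided by the launcher:
    #   - launcher.cluster: sbatch is available locally, so the job is submitted
    #     directly and the method blocks until result.pkl appears in project_path
    #     (unless wait=False, in which case the Slurm job ID is returned).
    #   - launcher.server_run: the work is delegated to _launch_server(), which
    #     drives a remote Slurm cluster over SSH.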
    def run(self, cancel_old_jobs: bool = True, wait: bool = True) -> Any:
        """Run the job on Slurm (locally or remotely)"""

        # Generate the Python script (spython.py) - needed for both modes
        if not self.launcher.server_run:
            self._write_python_script()
            self.script_file, self.job_name = self._write_slurm_script()
        else:
            # In server-run mode the scripts are not generated here:
            # _launch_server() writes the server script and pushes the files,
            # and slurmray_server.py re-instantiates RayLauncher on the cluster,
            # where sbatch is detected, so that remote instance takes the
            # "if self.launcher.cluster:" branch below.
            pass

        if self.launcher.cluster:
            self.logger.info("Cluster detected, running on cluster...")
            # Cancel the old jobs
            if cancel_old_jobs:
                self.logger.info("Canceling old jobs...")
                subprocess.Popen(
                    ["scancel", "-u", os.environ["USER"]],
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                )
            # Launch the job
            # We need to ensure scripts are written if we are on the cluster
            if not hasattr(self, "script_file"):
                self._write_python_script()
                self.script_file, self.job_name = self._write_slurm_script()

            self._launch_job(self.script_file, self.job_name)

            if not wait:
                return self.job_id

        elif self.launcher.server_run:
            return self._launch_server(cancel_old_jobs, wait=wait)

        # Load the result
        # Note: In server_run mode, _launch_server downloads result.pkl
        # In cluster mode, we wait for result.pkl

        if self.launcher.cluster:
            # Wait for result in cluster mode (same filesystem)
            while not os.path.exists(
                os.path.join(self.launcher.project_path, "result.pkl")
            ):
                time.sleep(0.25)

        # Result should be there now
        with open(os.path.join(self.launcher.project_path, "result.pkl"), "rb") as f:
            result = dill.load(f)

        return result

    def cancel(self, job_id: str):
        """Cancel a job"""
        if self.launcher.cluster:
            self.logger.info(f"Canceling local job {job_id}...")
            try:
                subprocess.run(
                    ["scancel", job_id],
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                    check=False,
                )
                self.logger.info(f"Job {job_id} canceled.")
            except Exception as e:
                self.logger.error(f"Failed to cancel job {job_id}: {e}")
        elif self.launcher.server_run and self.ssh_client:
            self.logger.info(f"Canceling remote job {job_id} via SSH...")
            try:
                # Check if connection is still active
                if (
                    self.ssh_client.get_transport()
                    and self.ssh_client.get_transport().is_active()
                ):
                    self.ssh_client.exec_command(f"scancel {job_id}")
                    self.logger.info(f"Remote job {job_id} canceled.")
                else:
                    self.logger.warning(
                        "SSH connection lost, cannot cancel remote job."
                    )
            except Exception as e:
                self.logger.error(f"Failed to cancel remote job {job_id}: {e}")
            finally:
                try:
                    self.ssh_client.close()
                except Exception:
                    pass

    # =========================================================================
    # Private methods extracted from RayLauncher
    # =========================================================================

    def _write_python_script(self):
        """Write the python script that will be executed by the job"""
        self.logger.info("Writing python script...")

        # Remove the old python script
        for file in os.listdir(self.launcher.project_path):
            if file.endswith(".py"):
                os.remove(os.path.join(self.launcher.project_path, file))

        # Write the python script
        with open(
            os.path.join(self.launcher.module_path, "assets", "spython_template.py"),
            "r",
        ) as f:
            text = f.read()

        text = text.replace("{{PROJECT_PATH}}", f'"{self.launcher.project_path}"')
        local_mode = ""
        if self.launcher.cluster or self.launcher.server_run:
            # Add Ray warning suppression to runtime_env if not already present
            runtime_env = self.launcher.runtime_env.copy()
            if "env_vars" not in runtime_env:
                runtime_env["env_vars"] = {}
            if "RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO" not in runtime_env["env_vars"]:
                runtime_env["env_vars"]["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"

            local_mode = f"\n\taddress='auto',\n\tinclude_dashboard=True,\n\tdashboard_host='0.0.0.0',\n\tdashboard_port=8265,\nruntime_env = {runtime_env},\n"
        text = text.replace(
            "{{LOCAL_MODE}}",
            local_mode,
        )
        with open(os.path.join(self.launcher.project_path, "spython.py"), "w") as f:
            f.write(text)

    def _write_slurm_script(self):
        """Write the slurm script that will be executed by the job"""
        self.logger.info("Writing slurm script...")
        template_file = os.path.join(
            self.launcher.module_path, "assets", "sbatch_template.sh"
        )

        JOB_NAME = "{{JOB_NAME}}"
        NUM_NODES = "{{NUM_NODES}}"
        MEMORY = "{{MEMORY}}"
        RUNNING_TIME = "{{RUNNING_TIME}}"
        PARTITION_NAME = "{{PARTITION_NAME}}"
        COMMAND_PLACEHOLDER = "{{COMMAND_PLACEHOLDER}}"
        GIVEN_NODE = "{{GIVEN_NODE}}"
        COMMAND_SUFFIX = "{{COMMAND_SUFFIX}}"
        LOAD_ENV = "{{LOAD_ENV}}"
        PARTITION_SPECIFICS = "{{PARTITION_SPECIFICS}}"

        job_name = "{}_{}".format(
            self.launcher.project_name, time.strftime("%d%m-%Hh%M", time.localtime())
        )

        # Convert the time to xx:xx:xx format
        max_time = "{}:{}:{}".format(
            str(self.launcher.max_running_time // 60).zfill(2),
            str(self.launcher.max_running_time % 60).zfill(2),
            str(0).zfill(2),
        )

        # ===== Modified the template script =====
        with open(template_file, "r") as f:
            text = f.read()
        text = text.replace(
            JOB_NAME, os.path.join(self.launcher.project_path, job_name)
        )
        text = text.replace(NUM_NODES, str(self.launcher.node_nbr))
        text = text.replace(MEMORY, str(self.launcher.memory))
        text = text.replace(RUNNING_TIME, str(max_time))
        text = text.replace(
            PARTITION_NAME, str("gpu" if self.launcher.use_gpu > 0 else "cpu")
        )
        text = text.replace(
            COMMAND_PLACEHOLDER,
            str(f"{sys.executable} {self.launcher.project_path}/spython.py"),
        )
        text = text.replace(
            LOAD_ENV, str(f"module load {' '.join(self.launcher.modules)}")
        )
        text = text.replace(GIVEN_NODE, "")
        text = text.replace(COMMAND_SUFFIX, "")
        text = text.replace(
            "# THIS FILE IS A TEMPLATE AND IT SHOULD NOT BE DEPLOYED TO " "PRODUCTION!",
            "# THIS FILE IS MODIFIED AUTOMATICALLY FROM TEMPLATE AND SHOULD BE "
            "RUNNABLE!",
        )

        # ===== Add partition specifics =====
        if self.launcher.use_gpu > 0:
            text = text.replace(
                PARTITION_SPECIFICS,
                str("#SBATCH --gres gpu:1\n#SBATCH --gres-flags enforce-binding"),
            )
        else:
            text = text.replace(PARTITION_SPECIFICS, "#SBATCH --exclusive")

        # ===== Save the script =====
        script_file = "sbatch.sh"
        with open(os.path.join(self.launcher.project_path, script_file), "w") as f:
            f.write(text)

        return script_file, job_name

    def _launch_job(self, script_file: str = None, job_name: str = None):
        """Launch the job"""
        # ===== Submit the job =====
        self.logger.info("Start to submit job!")
        result = subprocess.run(
            ["sbatch", os.path.join(self.launcher.project_path, script_file)],
            capture_output=True,
            text=True,
        )
        if result.returncode != 0:
            self.logger.error(f"Error submitting job: {result.stderr}")
            return

        # Extract job ID from output (format: "Submitted batch job 12345")
        job_id = None
        if result.stdout:
            match = re.search(r"Submitted batch job (\d+)", result.stdout)
            if match:
                job_id = match.group(1)

        if job_id:
            self.job_id = job_id
            self.logger.info(f"Job submitted with ID: {job_id}")
        else:
            self.logger.warning("Could not extract job ID from sbatch output")

        self.logger.info(
            "Job submitted! Script file is at: <{}>. Log file is at: <{}>".format(
                os.path.join(self.launcher.project_path, script_file),
                os.path.join(self.launcher.project_path, "{}.log".format(job_name)),
            )
        )

        # Wait for log file to be created and job to start running
        self._monitor_queue(job_name, job_id)

    def _monitor_queue(self, job_name, job_id):
        current_queue = None
        queue_log_file = os.path.join(self.launcher.project_path, "queue.log")
        with open(queue_log_file, "w") as f:
            f.write("")
        self.logger.info(
            "Start to monitor the queue... You can check the queue at: <{}>".format(
                queue_log_file
            )
        )

        start_time = time.time()
        last_print_time = 0
        job_running = False

        while True:
            time.sleep(0.25)
            if os.path.exists(
                os.path.join(self.launcher.project_path, "{}.log".format(job_name))
            ):
                break
            else:
                # Get result from squeue -p {{PARTITION_NAME}}
                result = subprocess.run(
                    ["squeue", "-p", "gpu" if self.launcher.use_gpu > 0 else "cpu"],
                    capture_output=True,
                )
                df = result.stdout.decode("utf-8").split("\n")

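                # Parse squeue's fixed-width output: column boundaries are
                # derived from the offsets of the "ST", "NODE" and
                # "NODELIST(REASON)" labels in the header row (df[0]); the
                # header itself is dropped by the [1:] slice below.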
                try:
                    users = list(
                        map(
                            lambda row: row[: len(df[0].split("ST")[0])][:-1].split(
                                " "
                            )[-1],
                            df,
                        )
                    )
                    status = list(
                        map(
                            lambda row: row[len(df[0].split("ST")[0]) :]
                            .strip()
                            .split(" ")[0],
                            df,
                        )
                    )
                    nodes = list(
                        map(
                            lambda row: row[len(df[0].split("NODE")[0]) :]
                            .strip()
                            .split(" ")[0],
                            df,
                        )
                    )
                    node_list = list(
                        map(
                            lambda row: row[len(df[0].split("NODELIST(REASON)")[0]) :],
                            df,
                        )
                    )

                    to_queue = list(
                        zip(
                            users,
                            status,
                            nodes,
                            node_list,
                        )
                    )[1:]

                    # Check if our job is running (status "R")
                    if job_id and not job_running:
                        job_position = None
                        total_jobs = len(to_queue)

                        for i, (user, stat, node_count, node_lst) in enumerate(
                            to_queue
                        ):
                            # Find our job by checking job IDs in squeue output
                            if i < len(df) - 1:
                                job_line = df[i + 1]
                                if job_id in job_line:
                                    if stat == "R":
                                        job_running = True
                                        # Get head node
                                        head_node = self._get_head_node_from_job_id(
                                            job_id
                                        )
                                        if head_node:
                                            self.logger.info(
                                                f"Job is running on node {head_node}."
                                            )
                                            self.logger.info(
                                                f"Dashboard should be accessible at http://{head_node}:8888 (if running on cluster)"
                                            )
                                    else:
                                        job_position = i + 1
                                    break

                        if job_running:
                            break

                        # Print queue status periodically
                        if time.time() - last_print_time > 30:
                            position_str = (
                                f"{job_position}/{total_jobs}"
                                if job_position
                                else "unknown"
                            )
                            print(
                                f"Waiting for job... (Position in queue : {position_str})"
                            )
                            last_print_time = time.time()

                    # Update the queue log
                    if time.time() - start_time > 60:
                        start_time = time.time()
                        # Log to file only, no print
                        with open(queue_log_file, "a") as f:
                            f.write(f"Update time: {time.strftime('%H:%M:%S')}\n")

                    if current_queue is None or current_queue != to_queue:
                        current_queue = to_queue
                        with open(queue_log_file, "w") as f:
                            text = f"Current queue ({time.strftime('%H:%M:%S')}):\n"
                            format_row = "{:>30}" * (len(current_queue[0]))
                            for user, status, nodes, node_list in current_queue:
                                text += (
                                    format_row.format(user, status, nodes, node_list)
                                    + "\n"
                                )
                            text += "\n"
                            f.write(text)
                except Exception as e:
                    # If squeue format changes or fails parsing, don't crash
                    pass

        # Wait for the job to finish while printing the log
        self.logger.info("Job started! Waiting for the job to finish...")
        log_cursor_position = 0
        job_finished = False
        while not job_finished:
            time.sleep(0.25)
            if os.path.exists(os.path.join(self.launcher.project_path, "result.pkl")):
                job_finished = True
            else:
                with open(
                    os.path.join(self.launcher.project_path, "{}.log".format(job_name)),
                    "r",
                ) as f:
                    f.seek(log_cursor_position)
                    text = f.read()
                    if text != "":
                        print(text, end="")
                        self.logger.info(text.strip())
                    log_cursor_position = f.tell()

        self.logger.info("Job finished!")

    def _get_head_node_from_job_id(
        self, job_id: str, ssh_client: paramiko.SSHClient = None
    ) -> str:
        """Get the head node name from a SLURM job ID"""
        try:
            # Execute scontrol show job
            if ssh_client:
                stdin, stdout, stderr = ssh_client.exec_command(
                    f"scontrol show job {job_id}"
                )
                output = stdout.read().decode("utf-8")
                if stderr.read().decode("utf-8"):
                    return None
            else:
                result = subprocess.run(
                    ["scontrol", "show", "job", job_id],
                    capture_output=True,
                    text=True,
                )
                if result.returncode != 0:
                    return None
                output = result.stdout

            # Extract NodeList from output
            node_list_match = re.search(r"NodeList=([^\s]+)", output)
            if not node_list_match:
                return None

            node_list = node_list_match.group(1)

            # Get hostnames from NodeList using scontrol show hostnames
            if ssh_client:
                stdin, stdout, stderr = ssh_client.exec_command(
                    f"scontrol show hostnames {node_list}"
                )
                hostnames_output = stdout.read().decode("utf-8")
                if stderr.read().decode("utf-8"):
                    return None
            else:
                result = subprocess.run(
                    ["scontrol", "show", "hostnames", node_list],
                    capture_output=True,
                    text=True,
                )
                if result.returncode != 0:
                    return None
                hostnames_output = result.stdout

            # Get first hostname (head node)
            hostnames = hostnames_output.strip().split("\n")
            if hostnames and hostnames[0]:
                return hostnames[0].strip()

            return None
        except Exception as e:
            self.logger.warning(f"Failed to get head node from job ID {job_id}: {e}")
            return None

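    # _launch_server() drives a remote cluster over SSH: it connects with
    # paramiko, uploads the pickled function/arguments and generated scripts via
    # SFTP into ~/slurmray-server/<project_name>, reuses or rebuilds the remote
    # virtualenv based on a requirements hash, runs slurmray_server.sh, streams
    # its output, optionally tunnels the Ray dashboard (local 8888 -> node 8265),
    # and finally downloads result.pkl.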
    def _launch_server(self, cancel_old_jobs: bool = True, wait: bool = True):
        """Launch the server on the cluster and run the function using the resources."""
        connected = False
        self.logger.info("Connecting to the cluster...")
        ssh_client = paramiko.SSHClient()
        self.ssh_client = ssh_client
        ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        while not connected:
            try:
                if self.launcher.server_password is None:
                    # Add ssh key
                    self.launcher.server_password = getpass(
                        "Enter your cluster password: "
                    )

                ssh_client.connect(
                    hostname=self.launcher.server_ssh,
                    username=self.launcher.server_username,
                    password=self.launcher.server_password,
                )
                sftp = ssh_client.open_sftp()
                connected = True
            except paramiko.ssh_exception.AuthenticationException:
                self.launcher.server_password = None
                self.logger.warning("Wrong password, please try again.")

        # Setup pyenv Python version if available
        self.pyenv_python_cmd = None
        if hasattr(self.launcher, "local_python_version"):
            self.pyenv_python_cmd = self._setup_pyenv_python(
                ssh_client, self.launcher.local_python_version
            )

        # Check Python version compatibility (with pyenv if available)
        is_compatible = self._check_python_version_compatibility(
            ssh_client, self.pyenv_python_cmd
        )
        self.python_version_compatible = is_compatible

        # Define project directory on cluster (organized by project name)
        project_dir = f"slurmray-server/{self.launcher.project_name}"

        # Write server script
        self._write_server_script()

        self.logger.info("Downloading server...")

        # Generate requirements first to check venv hash
        self._generate_requirements()

        # Add slurmray (unpinned for now to match legacy behavior, but could be pinned)
        # Check if slurmray is already in requirements.txt to avoid duplicates
        req_file = f"{self.launcher.project_path}/requirements.txt"
        with open(req_file, "r") as f:
            content = f.read()
        if "slurmray" not in content.lower():
            with open(req_file, "a") as f:
                f.write("slurmray\n")

        # Check if venv can be reused based on requirements hash
        dep_manager = DependencyManager(self.launcher.project_path, self.logger)
        req_file = os.path.join(self.launcher.project_path, "requirements.txt")

        should_recreate_venv = True
        if self.launcher.force_reinstall_venv:
            # Force recreation: remove venv if it exists
            self.logger.info("Force reinstall enabled: removing existing virtualenv...")
            ssh_client.exec_command(f"rm -rf {project_dir}/.venv")
            should_recreate_venv = True
        elif os.path.exists(req_file):
            with open(req_file, "r") as f:
                req_lines = f.readlines()
            # Check remote hash (if venv exists on remote)
            remote_hash_file = f"{project_dir}/.slogs/venv_hash.txt"
            stdin, stdout, stderr = ssh_client.exec_command(
                f"test -f {remote_hash_file} && cat {remote_hash_file} || echo ''"
            )
            remote_hash = stdout.read().decode("utf-8").strip()
            current_hash = dep_manager.compute_requirements_hash(req_lines)

            if remote_hash and remote_hash == current_hash:
                # Hash matches, check if venv exists
                stdin, stdout, stderr = ssh_client.exec_command(
                    f"test -d {project_dir}/.venv && echo exists || echo missing"
                )
                venv_exists = stdout.read().decode("utf-8").strip() == "exists"
                if venv_exists:
                    should_recreate_venv = False
                    self.logger.info(
                        "Virtualenv can be reused (requirements hash matches)"
                    )

        # Optimize requirements
        # Assuming standard path structure on Slurm cluster (Curnagl)
        venv_cmd = (
            f"cd {project_dir} && source .venv/bin/activate &&"
            if not should_recreate_venv
            else ""
        )
        req_file_to_push = self._optimize_requirements(ssh_client, venv_cmd)

        # Copy files from the project to the server
        for file in os.listdir(self.launcher.project_path):
            if file.endswith(".py") or file.endswith(".pkl") or file.endswith(".sh"):
                if file == "requirements.txt":
                    continue
                sftp.put(os.path.join(self.launcher.project_path, file), file)

        # Smart cleanup: preserve venv if hash matches, only clean server logs
        ssh_client.exec_command(f"mkdir -p {project_dir}/.slogs/server")
        if should_recreate_venv:
            # Clean server logs only (venv will be recreated by script if needed)
            ssh_client.exec_command(f"rm -rf {project_dir}/.slogs/server/*")
            # Create flag file to force venv recreation in script
            if self.launcher.force_reinstall_venv:
                ssh_client.exec_command(f"touch {project_dir}/.force_reinstall")
            self.logger.info(
                "Virtualenv will be recreated if needed (requirements changed or missing)"
            )
        else:
            # Clean server logs only, preserve venv
            ssh_client.exec_command(f"rm -rf {project_dir}/.slogs/server/*")
            # Remove flag file if it exists
            ssh_client.exec_command(f"rm -f {project_dir}/.force_reinstall")
            self.logger.info("Preserving virtualenv (requirements unchanged)")
        # Filter valid files
        valid_files = []
        for file in self.launcher.files:
            # Skip invalid paths
            if (
                not file
                or file == "."
                or file == ".."
                or file.startswith("./")
                or file.startswith("../")
            ):
                self.logger.warning(f"Skipping invalid file path: {file}")
                continue
            valid_files.append(file)

        # Use incremental sync for local files
        if valid_files:
            self._sync_local_files_incremental(
                sftp, project_dir, valid_files, ssh_client
            )

        # Copy the requirements.txt (optimized) to the server
        sftp.put(req_file_to_push, "requirements.txt")

        # Store venv hash on remote for future checks
        if os.path.exists(req_file):
            with open(req_file, "r") as f:
                req_lines = f.readlines()
            current_hash = dep_manager.compute_requirements_hash(req_lines)
            # Ensure .slogs directory exists on remote
            ssh_client.exec_command(f"mkdir -p {project_dir}/.slogs")
            stdin, stdout, stderr = ssh_client.exec_command(
                f"echo '{current_hash}' > {project_dir}/.slogs/venv_hash.txt"
            )
            stdout.channel.recv_exit_status()
            # Also store locally
            dep_manager.store_venv_hash(current_hash)

        # Update retention timestamp
        self._update_retention_timestamp(
            ssh_client, project_dir, self.launcher.retention_days
        )

        # Write and copy the server script to the server
        self._write_slurmray_server_sh()
        sftp.put(
            os.path.join(self.launcher.project_path, "slurmray_server.sh"),
            "slurmray_server.sh",
        )
        # Chmod script
        sftp.chmod("slurmray_server.sh", 0o755)

        # Run the server
        self.logger.info("Running server...")
        stdin, stdout, stderr = ssh_client.exec_command("./slurmray_server.sh")

        # Read the output in real time and capture job ID
        job_id = None
        tunnel = None
        output_lines = []

        # Read output line by line to capture job ID
        while True:
            line = stdout.readline()
            if not line:
                break
            output_lines.append(line)

            # Double output: console + log file
            print(line, end="")
            self.logger.info(line.strip())

            # Try to extract job ID from output (format: "Submitted batch job 12345")
            if not job_id:
                match = re.search(r"Submitted batch job (\d+)", line)
                if match:
                    job_id = match.group(1)
                    self.job_id = job_id
                    self.logger.info(f"Job ID detected: {job_id}")

        exit_status = stdout.channel.recv_exit_status()

        # Check if script failed - fail-fast immediately
        if exit_status != 0:
            # Collect error information
            stderr_output = stderr.read().decode("utf-8")
            error_msg = f"Server script exited with non-zero status: {exit_status}"
            if stderr_output.strip():
                error_msg += f"\nScript errors:\n{stderr_output}"

            # Log the error
            self.logger.error(error_msg)

            # Close tunnel if open
            if tunnel:
                try:
                    tunnel.__exit__(None, None, None)
                except Exception:
                    pass

            # Raise exception immediately (fail-fast)
            raise RuntimeError(error_msg)

        # If job ID not found in output, try to find it via squeue
        if not job_id:
            self.logger.info("Job ID not found in output, trying to find via squeue...")
            stdin, stdout, stderr = ssh_client.exec_command(
                f"squeue -u {self.launcher.server_username} -o '%i %j' --noheader"
            )
            squeue_output = stdout.read().decode("utf-8")
            # Try to find job matching project name pattern
            if self.launcher.project_name:
                for line in squeue_output.strip().split("\n"):
                    parts = line.strip().split()
                    if len(parts) >= 2 and self.launcher.project_name in parts[1]:
                        job_id = parts[0]
                        self.job_id = job_id
                        self.logger.info(f"Found job ID via squeue: {job_id}")
                        break

        # If job ID found and we are not waiting, return it and stop
        if job_id and not wait:
            self.logger.info("Async mode: Job submitted with ID {}. Disconnecting...".format(job_id))
            return job_id

        # If no job ID was found and we are not waiting, warn and return None
        if not job_id and not wait:
            self.logger.warning("Async mode: Could not detect job ID. Returning None.")
            return None

        # Loop for monitoring (only if wait=True)

        # If job ID found, wait for job to be running and set up tunnel
        if job_id:
            self.logger.info(f"Waiting for job {job_id} to start running...")
            max_wait_time = 300  # Wait up to 5 minutes
            wait_start = time.time()
            job_running = False

            while time.time() - wait_start < max_wait_time:
                time.sleep(2)
                stdin, stdout, stderr = ssh_client.exec_command(
                    f"squeue -j {job_id} -o '%T' --noheader"
                )
                status_output = stdout.read().decode("utf-8").strip()
                if status_output == "R":
                    job_running = True
                    break

            if job_running:
                # Get head node
                head_node = self._get_head_node_from_job_id(job_id, ssh_client)
                if head_node:
                    self.logger.info(
                        f"Job is running on node {head_node}. Setting up SSH tunnel for dashboard..."
                    )
                    try:
                        tunnel = SSHTunnel(
                            ssh_host=self.launcher.server_ssh,
                            ssh_username=self.launcher.server_username,
                            ssh_password=self.launcher.server_password,
                            remote_host=head_node,
                            local_port=8888,
                            remote_port=8265,
                            logger=self.logger,
                        )
                        tunnel.__enter__()
                        self.logger.info(
                            "Dashboard accessible at http://localhost:8888"
                        )

                        # Wait for job to complete while maintaining tunnel
                        # Check periodically if job is still running
                        while True:
                            time.sleep(5)
                            stdin, stdout, stderr = ssh_client.exec_command(
                                f"squeue -j {job_id} -o '%T' --noheader"
                            )
                            status_output = stdout.read().decode("utf-8").strip()
                            if status_output != "R":
                                # Job finished or no longer running
                                break
                    except Exception as e:
                        self.logger.warning(f"Failed to create SSH tunnel: {e}")
                        self.logger.info(
                            "Dashboard will not be accessible via port forwarding"
                        )
                        tunnel = None
            else:
                self.logger.warning(
                    "Job did not start running within timeout, skipping tunnel setup"
                )

        # Close tunnel if it was created
        if tunnel:
            tunnel.__exit__(None, None, None)

        # Downloading result
        self.logger.info("Downloading result...")
        project_dir = f"slurmray-server/{self.launcher.project_name}"
        try:
            sftp.get(
                f"{project_dir}/.slogs/server/result.pkl",
                os.path.join(self.launcher.project_path, "result.pkl"),
            )
            self.logger.info("Result downloaded!")

            # Clean up remote temporary files (preserve venv and cache)
            self.logger.info("Cleaning up remote temporary files...")
            ssh_client.exec_command(
                f"cd {project_dir} && "
                "find . -maxdepth 1 -type f \\( -name '*.py' -o -name '*.pkl' -o -name '*.sh' \\) "
                "! -name 'requirements.txt' -delete 2>/dev/null || true && "
                "rm -rf .slogs/server 2>/dev/null || true"
            )

            # Clean up local temporary files after successful download
            self._cleanup_local_temp_files()

        except FileNotFoundError:
            # Check for errors
            stderr_lines = stderr.readlines()
            if stderr_lines:
                self.logger.error("Errors:")
                for line in stderr_lines:
                    print(line, end="")
                    self.logger.error(line.strip())
            self.logger.error("An error occurred, please check the logs.")

    def _cleanup_local_temp_files(self):
        """Clean up local temporary files after successful execution"""
        temp_files = [
            "func_source.py",
            "func_name.txt",
            "func.pkl",
            "args.pkl",
            "result.pkl",
            "spython.py",
            "sbatch.sh",
            "slurmray_server.py",
            "requirements_to_install.txt",
        ]

        for temp_file in temp_files:
            file_path = os.path.join(self.launcher.project_path, temp_file)
            if os.path.exists(file_path):
                os.remove(file_path)
                self.logger.debug(f"Removed temporary file: {temp_file}")

    def _write_server_script(self):
        """This function will write a script with the given specifications to run slurmray on the cluster"""
        self.logger.info("Writing slurmray server script...")
        template_file = os.path.join(
            self.launcher.module_path, "assets", "slurmray_server_template.py"
        )

        MODULES = self.launcher.modules
        NODE_NBR = self.launcher.node_nbr
        USE_GPU = self.launcher.use_gpu
        MEMORY = self.launcher.memory
        MAX_RUNNING_TIME = self.launcher.max_running_time

        # ===== Modified the template script =====
        with open(template_file, "r") as f:
            text = f.read()
        text = text.replace("{{MODULES}}", str(MODULES))
        text = text.replace("{{NODE_NBR}}", str(NODE_NBR))
        text = text.replace("{{USE_GPU}}", str(USE_GPU))
        text = text.replace("{{MEMORY}}", str(MEMORY))
        text = text.replace("{{MAX_RUNNING_TIME}}", str(MAX_RUNNING_TIME))

        # ===== Save the script =====
        script_file = "slurmray_server.py"
        with open(os.path.join(self.launcher.project_path, script_file), "w") as f:
            f.write(text)

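    # The generated slurmray_server.sh moves the uploaded files into the project
    # directory, loads the requested environment modules, optionally activates a
    # pyenv interpreter, creates or reuses the .venv (a .force_reinstall flag
    # forces recreation), installs only the requirements that are missing or
    # outdated via `uv pip`, extends LD_LIBRARY_PATH for the nvjitlink torch
    # workaround, and then runs slurmray_server.py.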
    def _write_slurmray_server_sh(self):
        """Write the slurmray_server.sh script with pyenv support if available"""
        # Determine Python command
        if self.pyenv_python_cmd:
            # Use pyenv: the command already includes eval and pyenv shell
            python_cmd = self.pyenv_python_cmd.split(" && ")[
                -1
            ]  # Extract just "python" from the command
            python3_cmd = python_cmd.replace("python", "python3")
            pyenv_setup = self.pyenv_python_cmd.rsplit(" && ", 1)[
                0
            ]  # Get "eval ... && pyenv shell X.Y.Z"
            use_pyenv = True
        else:
            # Fallback to system Python
            python_cmd = "python"
            python3_cmd = "python3"
            pyenv_setup = ""
            use_pyenv = False

        # Filter out python module from modules list if using pyenv
        modules_list = self.launcher.modules.copy()
        if use_pyenv:
            modules_list = [m for m in modules_list if not m.startswith("python")]

        modules_str = " ".join(modules_list) if modules_list else ""
        project_dir = f"slurmray-server/{self.launcher.project_name}"

        script_content = f"""#!/bin/sh

echo "Installing slurmray server"

# Copy files
mv -t {project_dir} requirements.txt slurmray_server.py
mv -t {project_dir}/.slogs/server func.pkl args.pkl
cd {project_dir}

# Load modules
# Using specific versions for Curnagl compatibility (SLURM 24.05.3)
# gcc/13.2.0: Latest GCC version
# python module is loaded only if pyenv is not available
# cuda/12.6.2: Latest CUDA version
# cudnn/9.2.0.82-12: Compatible with cuda/12.6.2
"""

        if modules_str:
            script_content += f"module load {modules_str}\n"

        script_content += f"""
# Setup pyenv if available
"""

        if use_pyenv:
            script_content += f"""# Using pyenv for Python version management
export PATH="$HOME/.pyenv/bin:/usr/local/bin:/opt/pyenv/bin:$PATH"
{pyenv_setup}
"""
        else:
            script_content += """# pyenv not available, using system Python
"""

        script_content += f"""
# Create venv if it doesn't exist (hash check is done in Python before file upload)
# If venv needs recreation, it has already been removed by Python
# Check for force reinstall flag
if [ -f ".force_reinstall" ]; then
    echo "Force reinstall flag detected: removing existing virtualenv..."
    rm -rf .venv
    rm -f .force_reinstall
fi

if [ ! -d ".venv" ]; then
    echo "Creating virtualenv..."
"""

        if use_pyenv:
            script_content += f"""    {pyenv_setup} && {python3_cmd} -m venv .venv
"""
        else:
            script_content += f"""    {python3_cmd} -m venv .venv
"""

        script_content += f"""else
    echo "Using existing virtualenv (requirements unchanged)..."
    VENV_EXISTED=true
fi

source .venv/bin/activate

# Install requirements if file exists and is not empty
if [ -f requirements.txt ]; then
    # Check if requirements.txt is empty (only whitespace)
    if [ -s requirements.txt ]; then
        echo "📥 Installing dependencies from requirements.txt..."

        # Get installed packages once (fast, single command) - create lookup file
        uv pip list --format=freeze 2>/dev/null | sed 's/==/ /' | awk '{{print $1" "$2}}' > /tmp/installed_packages.txt || touch /tmp/installed_packages.txt

        # Process requirements: filter duplicates and check what needs installation
        INSTALL_ERRORS=0
        SKIPPED_COUNT=0
        > /tmp/to_install.txt # Clear file

        while IFS= read -r line || [ -n "$line" ]; do
            # Skip empty lines and comments
            line=$(echo "$line" | sed 's/^[[:space:]]*//;s/[[:space:]]*$//')
            if [ -z "$line" ] || [ "${{line#"#"}}" != "$line" ]; then
                continue
            fi

            # Extract package name (remove version specifiers and extras)
            pkg_name=$(echo "$line" | sed 's/[<>=!].*//' | sed 's/\\[.*\\]//' | sed 's/[[:space:]]*//' | tr '[:upper:]' '[:lower:]')
            if [ -z "$pkg_name" ]; then
                continue
            fi

            # Skip duplicates (check if we've already processed this package)
            if grep -qi "^$pkg_name$" /tmp/seen_packages.txt 2>/dev/null; then
                continue
            fi
            echo "$pkg_name" >> /tmp/seen_packages.txt

            # Extract required version if present
            required_version=""
            if echo "$line" | grep -q "=="; then
                required_version=$(echo "$line" | sed 's/.*==\\([^;]*\\).*/\\1/' | sed 's/[[:space:]]*//')
            fi

            # Check if package is already installed with correct version
            installed_version=$(grep -i "^$pkg_name " /tmp/installed_packages.txt 2>/dev/null | awk '{{print $2}}' | head -1)

            if [ -n "$installed_version" ]; then
                if [ -z "$required_version" ] || [ "$installed_version" = "$required_version" ]; then
                    echo " ⏭️ $pkg_name==$installed_version (already installed)"
                    SKIPPED_COUNT=$((SKIPPED_COUNT + 1))
                    continue
                fi
            fi

            # Package not installed or version mismatch, add to install list
            echo "$line" >> /tmp/to_install.txt
        done < requirements.txt

        # Install packages that need installation
        if [ -s /tmp/to_install.txt ]; then
            > /tmp/install_errors.txt # Track errors
            while IFS= read -r line; do
                pkg_name=$(echo "$line" | sed 's/[<>=!].*//' | sed 's/\\[.*\\]//' | sed 's/[[:space:]]*//')
                if uv pip install --quiet "$line" >/dev/null 2>&1; then
                    echo " ✅ $pkg_name"
                else
                    echo " ❌ $pkg_name"
                    echo "1" >> /tmp/install_errors.txt
                    # Show error details
                    uv pip install "$line" 2>&1 | grep -E "(error|Error|ERROR|failed|Failed|FAILED)" | head -3 | sed 's/^/ /' || true
                fi
            done < /tmp/to_install.txt
            INSTALL_ERRORS=$(wc -l < /tmp/install_errors.txt 2>/dev/null | tr -d ' ' || echo "0")
            rm -f /tmp/install_errors.txt
        fi

        # Count newly installed packages before cleanup
        NEWLY_INSTALLED=0
        if [ -s /tmp/to_install.txt ]; then
            NEWLY_INSTALLED=$(wc -l < /tmp/to_install.txt 2>/dev/null | tr -d ' ' || echo "0")
        fi

        # Cleanup temp files
        rm -f /tmp/installed_packages.txt /tmp/seen_packages.txt /tmp/to_install.txt

        if [ $INSTALL_ERRORS -eq 0 ]; then
            if [ $SKIPPED_COUNT -gt 0 ]; then
                echo "✅ All dependencies up to date ($SKIPPED_COUNT already installed, $NEWLY_INSTALLED newly installed)"
            else
                echo "✅ All dependencies installed successfully"
            fi
        else
            echo "❌ Failed to install $INSTALL_ERRORS package(s)" >&2
            exit 1
        fi
    else
        if [ "$VENV_EXISTED" = "true" ]; then
            echo "✅ All dependencies already installed (requirements.txt is empty)"
        else
            echo "⚠️ requirements.txt is empty, skipping dependency installation"
        fi
    fi
else
    echo "⚠️ No requirements.txt found, skipping dependency installation"
fi

# Fix torch bug (https://github.com/pytorch/pytorch/issues/111469)
PYTHON_VERSION=$({python3_cmd} -c 'import sys; print(f"{{sys.version_info.major}}.{{sys.version_info.minor}}")')
export LD_LIBRARY_PATH=$HOME/{project_dir}/.venv/lib/python$PYTHON_VERSION/site-packages/nvidia/nvjitlink/lib:$LD_LIBRARY_PATH


# Run server
"""

        if use_pyenv:
            script_content += f"""{pyenv_setup} && {python_cmd} -u slurmray_server.py
"""
        else:
            script_content += f"""{python_cmd} -u slurmray_server.py
"""

        script_file = "slurmray_server.sh"
        with open(os.path.join(self.launcher.project_path, script_file), "w") as f:
            f.write(script_content)

    def _push_file(
        self, file_path: str, sftp: paramiko.SFTPClient, ssh_client: paramiko.SSHClient
    ):
        """Push a file to the cluster"""
        self.logger.info(
            f"Pushing file {os.path.basename(file_path)} to the cluster..."
        )

        # Determine the path to the file
        local_path = file_path
        local_path_from_pwd = os.path.relpath(local_path, self.launcher.pwd_path)
        cluster_path = os.path.join(
            "/users",
            self.launcher.server_username,
            "slurmray-server",
            self.launcher.project_name,
            local_path_from_pwd,
        )

        # Create the directory if not exists

        stdin, stdout, stderr = ssh_client.exec_command(
            f"mkdir -p '{os.path.dirname(cluster_path)}'"
        )
        while True:
            line = stdout.readline()
            if not line:
                break
            # Keep print for real-time feedback on directory creation (optional, could be logger.debug)
            self.logger.debug(line.strip())
        time.sleep(1)  # Wait for the directory to be created

        sftp.put(file_path, cluster_path)

    def get_result(self, job_id: str) -> Any:
        """Get result for a specific job ID"""
        # Load local result if available (cluster mode or already downloaded)
        local_path = os.path.join(self.launcher.project_path, "result.pkl")
        if os.path.exists(local_path):
            with open(local_path, "rb") as f:
                return dill.load(f)

        # If clustered/server run, try to fetch it
        if self.launcher.cluster:
            # Already checked local path, so it's missing
            return None

        if self.launcher.server_run:
            self._connect()
            try:
                project_dir = f"slurmray-server/{self.launcher.project_name}"
                remote_path = f"{project_dir}/.slogs/server/result.pkl"
                self.ssh_client.open_sftp().get(remote_path, local_path)
                with open(local_path, "rb") as f:
                    return dill.load(f)
            except Exception:
                return None

        return None

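    # get_logs() locates the job's log file by asking `scontrol show job` for
    # the JobName (logs are written as <job_name>.log); if the job has already
    # left the queue it falls back to the newest <project_name>*.log file, then
    # streams the file locally or over SSH depending on the execution mode.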
    def get_logs(self, job_id: str) -> Any:
        """Get logs for a specific job ID"""
        if not job_id:
            yield "No Job ID provided."
            return

        # Attempt to get job info to find log file
        cmd = f"scontrol show job {job_id}"

        output = ""
        if self.launcher.cluster:
            try:
                res = subprocess.run(cmd.split(), capture_output=True, text=True)
                output = res.stdout
            except Exception as e:
                yield f"Error querying local scontrol: {e}"
        elif self.launcher.server_run:
            self._connect()
            try:
                stdin, stdout, stderr = self.ssh_client.exec_command(cmd)
                output = stdout.read().decode("utf-8")
            except Exception as e:
                yield f"Error querying remote job info: {e}"
                return
        else:
            yield "Not running on cluster."
            return

        # Extract JobName
        # JobName=example_1012-14h22
        import re
        match = re.search(r"JobName=([^\s]+)", output)
        if match:
            job_name = match.group(1)
            log_filename = f"{job_name}.log"
        else:
            yield "Could not determine log filename via scontrol (Job might be finished). Checking standard path..."
            # scontrol no longer knows the job (it may have finished) and log names
            # include a timestamp, so fall back to listing <project_name>*.log
            # files and taking the newest one.
            log_filename = None

            # Attempt to find log file by listing
            find_cmd = f"ls -t {self.launcher.project_name}*.log | head -n 1"
            if self.launcher.cluster:
                if os.path.exists(self.launcher.project_path):
                    import glob
                    files = glob.glob(os.path.join(self.launcher.project_path, f"{self.launcher.project_name}*.log"))
                    if files:
                        log_filename = os.path.basename(max(files, key=os.path.getmtime))
            elif self.launcher.server_run:
                project_dir = f"slurmray-server/{self.launcher.project_name}"
                stdin, stdout, stderr = self.ssh_client.exec_command(f"cd {project_dir} && ls -t {self.launcher.project_name}*.log | head -n 1")
                possible_log = stdout.read().decode("utf-8").strip()
                if possible_log:
                    log_filename = possible_log

        if not log_filename:
            yield "Log file could not be found."
            return

        # Read log file
        if self.launcher.cluster:
            log_path = os.path.join(self.launcher.project_path, log_filename)
            if os.path.exists(log_path):
                with open(log_path, "r") as f:
                    yield from f
            else:
                yield f"Log file {log_path} not found."
        elif self.launcher.server_run:
            project_dir = f"slurmray-server/{self.launcher.project_name}"
            remote_log = f"{project_dir}/{log_filename}"
            try:
                stdin, stdout, stderr = self.ssh_client.exec_command(f"cat {remote_log}")
                # Stream the output line by line (stdout is a paramiko channel file).
                for line in stdout:
                    yield line
            except Exception as e:
                yield f"Error reading remote log: {e}"