slurmray 6.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of slurmray might be problematic. Click here for more details.
- slurmray/RayLauncher.py +1019 -0
- slurmray/__init__.py +0 -0
- slurmray/__main__.py +5 -0
- slurmray/assets/cleanup_old_projects.py +171 -0
- slurmray/assets/sbatch_template.sh +67 -0
- slurmray/assets/slurmray_server.sh +145 -0
- slurmray/assets/slurmray_server_template.py +28 -0
- slurmray/assets/spython_template.py +113 -0
- slurmray/backend/__init__.py +0 -0
- slurmray/backend/base.py +1040 -0
- slurmray/backend/desi.py +856 -0
- slurmray/backend/local.py +124 -0
- slurmray/backend/remote.py +191 -0
- slurmray/backend/slurm.py +1234 -0
- slurmray/cli.py +904 -0
- slurmray/detection.py +1 -0
- slurmray/file_sync.py +276 -0
- slurmray/scanner.py +441 -0
- slurmray/utils.py +359 -0
- slurmray-6.0.4.dist-info/LICENSE +201 -0
- slurmray-6.0.4.dist-info/METADATA +85 -0
- slurmray-6.0.4.dist-info/RECORD +24 -0
- slurmray-6.0.4.dist-info/WHEEL +4 -0
- slurmray-6.0.4.dist-info/entry_points.txt +3 -0
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import time
|
|
4
|
+
import subprocess
|
|
5
|
+
import dill
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from slurmray.backend.base import ClusterBackend
|
|
9
|
+
|
|
10
|
+
class LocalBackend(ClusterBackend):
    """Backend for local execution (no scheduler).

    Runs the generated ``spython.py`` worker script in a subprocess using the
    current interpreter, instead of submitting the job to a cluster scheduler.
    """

    def run(self, cancel_old_jobs: bool = True, wait: bool = True) -> Any:
        """Run the job locally.

        Args:
            cancel_old_jobs: Accepted for interface compatibility with the
                scheduler backends; unused locally (there is no queue).
            wait: If False, start the subprocess and return immediately with
                its PID as a string. If True, block until ``result.pkl``
                appears and return the unpickled result.

        Returns:
            The deserialized result when ``wait`` is True, otherwise the PID
            of the background process as a string.

        Raises:
            RuntimeError: If the subprocess exits without producing
                ``result.pkl``.
        """
        self.logger.info("No cluster detected, running locally...")

        self._write_python_script()

        # Log file named after the project, mirroring the Slurm naming scheme.
        job_name = f"{self.launcher.project_name}_local"
        log_path = os.path.join(self.launcher.project_path, f"{job_name}.log")
        result_path = os.path.join(self.launcher.project_path, "result.pkl")

        with open(log_path, "w") as log_file:
            process = subprocess.Popen(
                [sys.executable, os.path.join(self.launcher.project_path, "spython.py")],
                stdout=log_file,
                stderr=subprocess.STDOUT,
            )

        if not wait:
            self.logger.info(
                f"Local job started asynchronously. Logs: {log_path} PID: {process.pid}"
            )
            return str(process.pid)

        # Poll until the worker writes its pickled result.
        while not os.path.exists(result_path):
            time.sleep(0.25)
            # If the process already exited, it will never write the result.
            if process.poll() is not None:
                # Re-check: the result may have landed just before exit.
                if not os.path.exists(result_path):
                    # Surface the captured output to help diagnose the failure.
                    with open(log_path, "r") as f:
                        print(f.read())
                    raise RuntimeError("Local process succeeded but no result.pkl found (or failed).")

        with open(result_path, "rb") as f:
            result = dill.load(f)

        return result

    def get_result(self, job_id: str) -> Any:
        """Get result from local execution (result.pkl).

        Args:
            job_id: Unused here; the result location depends only on the
                project path.

        Returns:
            The unpickled result, or ``None`` if the file is missing or
            cannot be deserialized (best-effort read).
        """
        result_path = os.path.join(self.launcher.project_path, "result.pkl")
        if os.path.exists(result_path):
            try:
                with open(result_path, "rb") as f:
                    return dill.load(f)
            except Exception:
                # Best-effort: a partially-written or corrupt pickle yields None.
                return None
        return None

    def get_logs(self, job_id: str) -> Any:
        """Yield log lines for a local job.

        Args:
            job_id: Job name used to locate ``<job_id>.log`` in the project
                directory (the ``run`` method uses ``<project>_local``).

        Yields:
            Stripped log lines, or a single "Log file not found." message
            when no log file exists.
        """
        log_path = os.path.join(self.launcher.project_path, f"{job_id}.log")
        if os.path.exists(log_path):
            with open(log_path, "r") as f:
                for line in f:
                    yield line.strip()
        else:
            yield "Log file not found."

    def cancel(self, job_id: str):
        """Cancel a job (kill local process).

        Args:
            job_id: Either a PID (all digits, as returned by an asynchronous
                ``run``), in which case SIGTERM is sent to that process, or
                any other value, in which case a ``pkill`` pattern match on
                the script path is used as a fallback.
        """
        # Local import: `signal` was referenced but never imported at module
        # level, which made the PID branch raise NameError.
        import signal

        self.logger.info(f"Canceling local job {job_id}...")
        try:
            # Check if job_id looks like a PID (digits)
            if job_id and job_id.isdigit():
                pid = int(job_id)
                self.logger.info(f"Killing process {pid}...")
                os.kill(pid, signal.SIGTERM)
            else:
                # Fallback to pkill by pattern if not a PID (old behavior or
                # synchronous fallback); check=False because a missing
                # process is not an error here.
                cmd = ["pkill", "-f", f"{self.launcher.project_path}/spython.py"]
                subprocess.run(cmd, check=False)

            self.logger.info("Local process killed.")
        except Exception as e:
            self.logger.warning(f"Failed to kill local process: {e}")

    def _write_python_script(self):
        """Write the python script that will be executed by the job.

        Renders ``assets/spython_template.py`` with the project path and an
        empty ``{{LOCAL_MODE}}`` placeholder (cluster/server runs fill it
        in), then writes the result as ``spython.py`` in the project
        directory. Any pre-existing ``.py`` files there are removed first.
        """
        self.logger.info("Writing python script...")

        # Remove any stale python scripts from previous runs.
        for file in os.listdir(self.launcher.project_path):
            if file.endswith(".py"):
                os.remove(os.path.join(self.launcher.project_path, file))

        # Load the template shipped with the package.
        template_path = os.path.join(
            self.launcher.module_path, "assets", "spython_template.py"
        )
        with open(template_path, "r") as f:
            text = f.read()

        text = text.replace("{{PROJECT_PATH}}", f'"{self.launcher.project_path}"')
        # Local mode needs no special Ray address, so the placeholder is
        # replaced with an empty string.
        text = text.replace(
            "{{LOCAL_MODE}}",
            "",
        )
        with open(os.path.join(self.launcher.project_path, "spython.py"), "w") as f:
            f.write(text)
|
|
124
|
+
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import time
|
|
4
|
+
import paramiko
|
|
5
|
+
import subprocess
|
|
6
|
+
import logging
|
|
7
|
+
from getpass import getpass
|
|
8
|
+
from typing import Any, Optional, List
|
|
9
|
+
|
|
10
|
+
from slurmray.backend.base import ClusterBackend
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class RemoteMixin(ClusterBackend):
    """Mixin for remote execution via SSH (paramiko)."""

    def __init__(self, launcher):
        super().__init__(launcher)
        self.ssh_client = None  # lazily-created paramiko.SSHClient
        self.job_id = None  # remote job id, populated by subclasses

    def _connect(self):
        """Establish the SSH connection, prompting for a password if needed.

        Reuses an existing active transport when possible. On an
        authentication failure the stored password is cleared and the user is
        re-prompted; in a non-interactive session the failure is re-raised
        immediately instead of looping forever.
        """
        if (
            self.ssh_client
            and self.ssh_client.get_transport()
            and self.ssh_client.get_transport().is_active()
        ):
            return

        connected = False
        self.logger.info("Connecting to the remote server...")
        self.ssh_client = paramiko.SSHClient()
        # NOTE(review): AutoAddPolicy trusts unknown host keys silently —
        # convenient, but vulnerable to MITM; consider a known_hosts check.
        self.ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

        while not connected:
            try:
                if self.launcher.server_password is None:
                    # Password not provided: prompt interactively.
                    # (Credentials from .env or explicit parameters are
                    # already loaded in RayLauncher.__init__.)
                    try:
                        self.launcher.server_password = getpass(
                            "Enter your cluster password: "
                        )
                    except Exception:
                        # getpass can fail in a non-interactive terminal;
                        # fall through to the explicit check below.
                        pass

                    if self.launcher.server_password is None:
                        raise ValueError(
                            "No password provided and cannot prompt (non-interactive)"
                        )

                self.ssh_client.connect(
                    hostname=self.launcher.server_ssh,
                    username=self.launcher.server_username,
                    password=self.launcher.server_password,
                )
                connected = True
            except paramiko.ssh_exception.AuthenticationException:
                # Clear the bad password so the next loop iteration re-prompts.
                self.launcher.server_password = None
                self.logger.warning("Wrong password, please try again.")
                if not sys.stdin.isatty():
                    raise  # Fail fast if non-interactive

    def _push_file(
        self, file_path: str, sftp: paramiko.SFTPClient, remote_base_dir: str
    ):
        """Push a file to the remote server.

        Args:
            file_path: Local file path; relative paths are resolved against
                the launcher's working directory.
            sftp: Open SFTP session used for the transfer.
            remote_base_dir: Remote directory mirroring the local working
                directory; the file keeps its relative location under it.
        """
        self.logger.info(
            f"Pushing file {os.path.basename(file_path)} to the remote server..."
        )

        # Resolve the local path to an absolute path.
        local_path = file_path
        if not os.path.isabs(local_path):
            local_path = os.path.join(self.launcher.pwd_path, local_path)
        local_path = os.path.abspath(local_path)

        # Mirror the local layout under the remote base directory.
        # NOTE(review): os.path.join uses the local OS separator; on a
        # Windows client this would produce backslashes in remote POSIX
        # paths — posixpath.join would be safer. Confirm supported clients.
        local_path_from_pwd = os.path.relpath(local_path, self.launcher.pwd_path)
        remote_path = os.path.join(remote_base_dir, local_path_from_pwd)

        # Create the remote parent directory if it does not exist.
        stdin, stdout, stderr = self.ssh_client.exec_command(
            f"mkdir -p '{os.path.dirname(remote_path)}'"
        )
        # Block until mkdir completes before uploading into the directory.
        stdout.channel.recv_exit_status()

        # Copy the file to the server.
        sftp.put(local_path, remote_path)

    def _push_file_wrapper(
        self,
        rel_path: str,
        sftp: paramiko.SFTPClient,
        remote_base_dir: str,
        ssh_client: paramiko.SSHClient = None,
    ):
        """
        Wrapper to call _push_file with the correct signature for each backend.

        SlurmBackend declares ``_push_file(self, file_path, sftp, ssh_client)``
        while RemoteMixin/DesiBackend declare
        ``_push_file(self, file_path, sftp, remote_base_dir)``. Dispatch is on
        the parameter *name*: ``inspect.signature`` on a bound method excludes
        ``self``, so both variants expose three parameters and a parameter
        count alone cannot distinguish them (the previous ``== 4`` check never
        matched, silently passing ``remote_base_dir`` where ``ssh_client`` was
        expected).
        """
        import inspect

        # Inspect this backend's _push_file to pick the right third argument.
        sig = inspect.signature(self._push_file)

        if "ssh_client" in sig.parameters:
            # SlurmBackend variant: (file_path, sftp, ssh_client)
            if ssh_client is None:
                ssh_client = self.ssh_client
            self._push_file(rel_path, sftp, ssh_client)
        else:
            # RemoteMixin/DesiBackend variant: (file_path, sftp, remote_base_dir)
            self._push_file(rel_path, sftp, remote_base_dir)

    def _sync_local_files_incremental(
        self,
        sftp: paramiko.SFTPClient,
        remote_base_dir: str,
        local_files: List[str],
        ssh_client: paramiko.SSHClient = None,
    ):
        """
        Synchronize local files incrementally using hash comparison.
        Only uploads files that have changed.

        Args:
            sftp: Open SFTP session used for uploads.
            remote_base_dir: Remote directory mirroring the local tree.
            local_files: Paths (relative to the working directory) to sync;
                directories are walked recursively for ``.py`` files.
            ssh_client: Optional SSH client; defaults to ``self.ssh_client``.
        """
        from slurmray.file_sync import FileHashManager, LocalFileSyncManager

        # Use the provided ssh_client or fall back to our own connection.
        if ssh_client is None:
            ssh_client = self.ssh_client

        # Initialize sync manager.
        hash_manager = FileHashManager(self.launcher.pwd_path, self.logger)
        sync_manager = LocalFileSyncManager(
            self.launcher.pwd_path, hash_manager, self.logger
        )

        # Remote hash-manifest path (stored under the hidden .slogs dir).
        remote_hash_file = os.path.join(
            remote_base_dir, ".slogs", ".remote_file_hashes.json"
        )

        # Fetch the hashes recorded on the server from the last sync.
        remote_hashes = sync_manager.fetch_remote_hashes(ssh_client, remote_hash_file)

        # Determine which files need uploading.
        files_to_upload = sync_manager.get_files_to_upload(local_files, remote_hashes)

        if not files_to_upload:
            self.logger.info("✅ All local files are up to date, no upload needed.")
            return

        self.logger.info(
            f"📤 Uploading {len(files_to_upload)} modified/new file(s) out of {len(local_files)} total..."
        )

        # Upload files, expanding package directories into their .py files.
        uploaded_files = []
        for rel_path in files_to_upload:
            abs_path = os.path.join(self.launcher.pwd_path, rel_path)

            if os.path.isdir(abs_path):
                # Upload all Python files in the directory recursively.
                for root, dirs, files in os.walk(abs_path):
                    # Skip bytecode caches.
                    dirs[:] = [d for d in dirs if d != "__pycache__"]
                    for file in files:
                        if file.endswith(".py"):
                            file_path = os.path.join(root, file)
                            file_rel = os.path.relpath(
                                file_path, self.launcher.pwd_path
                            )
                            self._push_file_wrapper(
                                file_rel, sftp, remote_base_dir, ssh_client
                            )
                            uploaded_files.append(file_rel)
            else:
                # Single file.
                self._push_file_wrapper(rel_path, sftp, remote_base_dir, ssh_client)
                uploaded_files.append(rel_path)

        # Record the new hashes on the server so the next sync is incremental.
        sync_manager.update_remote_hashes(uploaded_files, remote_hashes)
        sync_manager.save_remote_hashes_to_server(
            ssh_client, remote_hash_file, remote_hashes
        )

        self.logger.info(f"✅ Successfully uploaded {len(uploaded_files)} file(s).")
|