slurmray 6.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of slurmray has been flagged as possibly problematic.

@@ -0,0 +1,125 @@
+ import os
+ import sys
+ import time
+ import signal
+ import subprocess
+ import dill
+ from typing import Any
+
+ from slurmray.backend.base import ClusterBackend
+
+ class LocalBackend(ClusterBackend):
+     """Backend for local execution (no scheduler)"""
+
+     def run(self, cancel_old_jobs: bool = True, wait: bool = True) -> Any:
+         """Run the job locally"""
+         self.logger.info("No cluster detected, running locally...")
+
+         self._write_python_script()
+
+         # Determine log file path
+         # Use job name logic similar to Slurm if possible, or just project name
+         job_name = f"{self.launcher.project_name}_local"
+         log_path = os.path.join(self.launcher.project_path, f"{job_name}.log")
+
+         with open(log_path, "w") as log_file:
+             process = subprocess.Popen(
+                 [sys.executable, os.path.join(self.launcher.project_path, "spython.py")],
+                 stdout=log_file,
+                 stderr=subprocess.STDOUT
+             )
+
+         if not wait:
+             self.logger.info(f"Local job started asynchronously. Logs: {log_path} PID: {process.pid}")
+             return str(process.pid)
+
+         # Wait for result
+         while not os.path.exists(os.path.join(self.launcher.project_path, "result.pkl")):
+             time.sleep(0.25)
+             # Check if process died
+             if process.poll() is not None:
+                 # Process finished but no result?
+                 if not os.path.exists(os.path.join(self.launcher.project_path, "result.pkl")):
+                     # Check logs
+                     with open(log_path, "r") as f:
+                         print(f.read())
+                     raise RuntimeError("Local process exited but no result.pkl was found (the job may have failed); see logs above.")
+
+         with open(os.path.join(self.launcher.project_path, "result.pkl"), "rb") as f:
+             result = dill.load(f)
+
+         return result
+
+     def get_result(self, job_id: str) -> Any:
+         """Get result from local execution (result.pkl)"""
+         # job_id is just project_name descriptor here
+         result_path = os.path.join(self.launcher.project_path, "result.pkl")
+         if os.path.exists(result_path):
+             try:
+                 with open(result_path, "rb") as f:
+                     return dill.load(f)
+             except Exception:
+                 return None
+         return None
+
+     def get_logs(self, job_id: str) -> Any:
+         """Get logs from local execution"""
+         log_path = os.path.join(self.launcher.project_path, f"{job_id}.log")
+         if os.path.exists(log_path):
+             with open(log_path, "r") as f:
+                 # Yield lines
+                 for line in f:
+                     yield line.strip()
+         else:
+             yield "Log file not found."
+
+     def cancel(self, job_id: str):
+         """Cancel a job (kill local process)"""
+         self.logger.info(f"Canceling local job {job_id}...")
+         try:
+             # Check if job_id looks like a PID (digits)
+             if job_id and job_id.isdigit():
+                 pid = int(job_id)
+                 self.logger.info(f"Killing process {pid}...")
+                 os.kill(pid, signal.SIGTERM)
+             else:
+                 # Fallback to pkill by pattern if not a PID (old behavior or synchronous fallback)
+                 cmd = ["pkill", "-f", f"{self.launcher.project_path}/spython.py"]
+                 subprocess.run(cmd, check=False)
+
+             self.logger.info("Local process killed.")
+         except Exception as e:
+             self.logger.warning(f"Failed to kill local process: {e}")
+
+     def _write_python_script(self):
+         """Write the python script that will be executed by the job"""
+         self.logger.info("Writing python script...")
+
+         # Remove the old python script
+         for file in os.listdir(self.launcher.project_path):
+             if file.endswith(".py"):
+                 os.remove(os.path.join(self.launcher.project_path, file))
+
+         # Write the python script
+         with open(
+             os.path.join(self.launcher.module_path, "assets", "spython_template.py"),
+             "r",
+         ) as f:
+             text = f.read()
+
+         text = text.replace("{{PROJECT_PATH}}", f'"{self.launcher.project_path}"')
+         # Local mode doesn't need special address usually, or 'auto' is fine.
+         # Original code:
+         # if self.cluster or self.server_run:
+         #     local_mode = ...
+         # else:
+         #     local_mode = "" (empty)
+
+         local_mode = ""
+         text = text.replace(
+             "{{LOCAL_MODE}}",
+             local_mode,
+         )
+         with open(os.path.join(self.launcher.project_path, "spython.py"), "w") as f:
+             f.write(text)
+
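For orientation, a minimal sketch of how this backend might be driven, assuming a configured launcher object (the RayLauncher referenced in the next file) that exposes project_name, project_path and module_path; the snippet is illustrative and not taken from the package:

    backend = LocalBackend(launcher)   # launcher: a configured RayLauncher (assumed)

    result = backend.run(wait=True)    # blocks until result.pkl is written, returns the unpickled value
    job_id = backend.run(wait=False)   # asynchronous: returns the subprocess PID as a string
    for line in backend.get_logs(f"{launcher.project_name}_local"):
        print(line)
    backend.cancel(job_id)             # SIGTERM by PID, pkill on the spython.py pattern otherwise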
@@ -0,0 +1,191 @@
+ import os
+ import sys
+ import time
+ import paramiko
+ import subprocess
+ import logging
+ from getpass import getpass
+ from typing import Any, Optional, List
+
+ from slurmray.backend.base import ClusterBackend
+
+
+ class RemoteMixin(ClusterBackend):
+     """Mixin for remote execution via SSH"""
+
+     def __init__(self, launcher):
+         super().__init__(launcher)
+         self.ssh_client = None
+         self.job_id = None
+
+     def _connect(self):
+         """Establish SSH connection"""
+         if (
+             self.ssh_client
+             and self.ssh_client.get_transport()
+             and self.ssh_client.get_transport().is_active()
+         ):
+             return
+
+         connected = False
+         self.logger.info("Connecting to the remote server...")
+         self.ssh_client = paramiko.SSHClient()
+         self.ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+
+         while not connected:
+             try:
+                 if self.launcher.server_password is None:
+                     # Password not provided: prompt interactively
+                     # (Credentials from .env or explicit parameters are already loaded in RayLauncher.__init__)
+                     try:
+                         self.launcher.server_password = getpass(
+                             "Enter your cluster password: "
+                         )
+                     except Exception:
+                         # Handle case where getpass fails (e.g. non-interactive terminal)
+                         pass
+
+                 if self.launcher.server_password is None:
+                     raise ValueError(
+                         "No password provided and cannot prompt (non-interactive)"
+                     )
+
+                 self.ssh_client.connect(
+                     hostname=self.launcher.server_ssh,
+                     username=self.launcher.server_username,
+                     password=self.launcher.server_password,
+                 )
+                 connected = True
+             except paramiko.ssh_exception.AuthenticationException:
+                 self.launcher.server_password = None
+                 self.logger.warning("Wrong password, please try again.")
+                 # Only retry interactively if we failed
+                 if not sys.stdin.isatty():
+                     raise  # Fail fast if non-interactive
+
+     def _push_file(
+         self, file_path: str, sftp: paramiko.SFTPClient, remote_base_dir: str
+     ):
+         """Push a file to the remote server"""
+         self.logger.info(
+             f"Pushing file {os.path.basename(file_path)} to the remote server..."
+         )
+
+         # Determine the path to the file
+         local_path = file_path
+         if not os.path.isabs(local_path):
+             local_path = os.path.join(self.launcher.pwd_path, local_path)
+         local_path = os.path.abspath(local_path)
+
+         local_path_from_pwd = os.path.relpath(local_path, self.launcher.pwd_path)
+         remote_path = os.path.join(remote_base_dir, local_path_from_pwd)
+
+         # Create the directory if not exists
+         stdin, stdout, stderr = self.ssh_client.exec_command(
+             f"mkdir -p '{os.path.dirname(remote_path)}'"
+         )
+         # Wait for command to finish
+         stdout.channel.recv_exit_status()
+
+         # Copy the file to the server
+         sftp.put(local_path, remote_path)
+
+     def _push_file_wrapper(
+         self,
+         rel_path: str,
+         sftp: paramiko.SFTPClient,
+         remote_base_dir: str,
+         ssh_client: Optional[paramiko.SSHClient] = None,
+     ):
+         """
+         Wrapper to call _push_file with the correct signature for each backend.
+         """
+         import inspect
+
+         # Dispatch on _push_file's parameter names: bound methods do not expose
+         # self, so counting parameters cannot distinguish the two signatures.
+         sig = inspect.signature(self._push_file)
+
+         if "remote_base_dir" not in sig.parameters:  # SlurmBackend: (file_path, sftp, ssh_client)
+             if ssh_client is None:
+                 ssh_client = self.ssh_client
+             self._push_file(rel_path, sftp, ssh_client)
+         else:  # RemoteMixin/DesiBackend: (file_path, sftp, remote_base_dir)
+             self._push_file(rel_path, sftp, remote_base_dir)
+
+     def _sync_local_files_incremental(
+         self,
+         sftp: paramiko.SFTPClient,
+         remote_base_dir: str,
+         local_files: List[str],
+         ssh_client: Optional[paramiko.SSHClient] = None,
+     ):
+         """
+         Synchronize local files incrementally using hash comparison.
+         Only uploads files that have changed.
+         """
+         from slurmray.file_sync import FileHashManager, LocalFileSyncManager
+
+         # Use provided ssh_client or self.ssh_client
+         if ssh_client is None:
+             ssh_client = self.ssh_client
+
+         # Initialize sync manager
+         hash_manager = FileHashManager(self.launcher.pwd_path, self.logger)
+         sync_manager = LocalFileSyncManager(
+             self.launcher.pwd_path, hash_manager, self.logger
+         )
+
+         # Remote hash file path
+         remote_hash_file = os.path.join(
+             remote_base_dir, ".slogs", ".remote_file_hashes.json"
+         )
+
+         # Fetch remote hashes
+         remote_hashes = sync_manager.fetch_remote_hashes(ssh_client, remote_hash_file)
+
+         # Determine which files need uploading
+         files_to_upload = sync_manager.get_files_to_upload(local_files, remote_hashes)
+
+         if not files_to_upload:
+             self.logger.info("✅ All local files are up to date, no upload needed.")
+             return
+
+         self.logger.info(
+             f"📤 Uploading {len(files_to_upload)} modified/new file(s) out of {len(local_files)} total..."
+         )
+
+         # Upload files
+         uploaded_files = []
+         for rel_path in files_to_upload:
+             # Convert relative path to absolute
+             abs_path = os.path.join(self.launcher.pwd_path, rel_path)
+
+             # Handle directories (packages)
+             if os.path.isdir(abs_path):
+                 # Upload all Python files in the directory recursively
+                 for root, dirs, files in os.walk(abs_path):
+                     # Skip __pycache__
+                     dirs[:] = [d for d in dirs if d != "__pycache__"]
+                     for file in files:
+                         if file.endswith(".py"):
+                             file_path = os.path.join(root, file)
+                             file_rel = os.path.relpath(
+                                 file_path, self.launcher.pwd_path
+                             )
+                             self._push_file_wrapper(
+                                 file_rel, sftp, remote_base_dir, ssh_client
+                             )
+                             uploaded_files.append(file_rel)
+             else:
+                 # Single file
+                 self._push_file_wrapper(rel_path, sftp, remote_base_dir, ssh_client)
+                 uploaded_files.append(rel_path)
+
+         # Update remote hashes after successful upload
+         sync_manager.update_remote_hashes(uploaded_files, remote_hashes)
+         sync_manager.save_remote_hashes_to_server(
+             ssh_client, remote_hash_file, remote_hashes
+         )
+
+         self.logger.info(f"✅ Successfully uploaded {len(uploaded_files)} file(s).")