slurmray 6.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of slurmray might be problematic.
- slurmray/RayLauncher.py +1019 -0
- slurmray/__init__.py +0 -0
- slurmray/__main__.py +5 -0
- slurmray/assets/cleanup_old_projects.py +171 -0
- slurmray/assets/sbatch_template.sh +67 -0
- slurmray/assets/slurmray_server.sh +145 -0
- slurmray/assets/slurmray_server_template.py +28 -0
- slurmray/assets/spython_template.py +113 -0
- slurmray/backend/__init__.py +0 -0
- slurmray/backend/base.py +1040 -0
- slurmray/backend/desi.py +856 -0
- slurmray/backend/local.py +124 -0
- slurmray/backend/remote.py +191 -0
- slurmray/backend/slurm.py +1234 -0
- slurmray/cli.py +904 -0
- slurmray/detection.py +1 -0
- slurmray/file_sync.py +276 -0
- slurmray/scanner.py +441 -0
- slurmray/utils.py +359 -0
- slurmray-6.0.4.dist-info/LICENSE +201 -0
- slurmray-6.0.4.dist-info/METADATA +85 -0
- slurmray-6.0.4.dist-info/RECORD +24 -0
- slurmray-6.0.4.dist-info/WHEEL +4 -0
- slurmray-6.0.4.dist-info/entry_points.txt +3 -0
slurmray/RayLauncher.py
ADDED
@@ -0,0 +1,1019 @@

from typing import Any, Callable, List, Tuple, Set, Generator
import sys
import os
import dill
import logging
import signal
import dis
import builtins
import inspect
from getpass import getpass
import time

from dotenv import load_dotenv

from slurmray.backend.slurm import SlurmBackend
from slurmray.backend.local import LocalBackend
from slurmray.backend.desi import DesiBackend

dill.settings["recurse"] = True
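
# Editorial note (hedged, based on dill's documented "recurse" setting): with
# recurse=True, dill traces recursively through the objects a function
# references via its globals, so the unpickled function can still find its
# helpers on the remote side.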

class RayLauncher:
    """A class that automatically connects Ray workers and executes the function requested by the user.

    Official tool from DESI @ HEC UNIL.

    Supports multiple execution modes:

    - **Curnagl mode** (`cluster='curnagl'`): For Slurm-based clusters like Curnagl. Uses sbatch/squeue for job management.
    - **Desi mode** (`cluster='desi'`): For standalone servers like ISIPOL09. Uses Smart Lock scheduling for resource management.
    - **Local mode** (`cluster='local'`): For local execution without a remote server/cluster.
    - **Custom IP** (`cluster='<ip_or_hostname>'`): For custom Slurm clusters. Uses the provided IP/hostname.

    The launcher automatically selects the appropriate backend based on the `cluster` parameter and environment detection.
    """

    class FunctionReturn:
        """Object returned when running in asynchronous mode.

        Allows monitoring logs and retrieving the result later.
        """

        def __init__(self, launcher, job_id=None):
            self.launcher = launcher
            self.job_id = job_id
            self._cached_result = None

        @property
        def result(self):
            """Get the result of the function execution.

            Returns "Compute still in progress" if not finished.
            """
            if self._cached_result is not None:
                return self._cached_result

            # Attempt to fetch the result from the backend without blocking.
            if hasattr(self.launcher.backend, "get_result"):
                res = self.launcher.backend.get_result(self.job_id)
                if res is not None:
                    self._cached_result = res
                    return res

            return "Compute still in progress"

        @property
        def logs(self) -> Generator[str, None, None]:
            """Get the logs of the function execution as a stream (generator)."""
            if hasattr(self.launcher.backend, "get_logs"):
                yield from self.launcher.backend.get_logs(self.job_id)
            else:
                yield "Logs not available for this backend."

        def cancel(self):
            """Cancel the running job."""
            if hasattr(self.launcher.backend, "cancel"):
                self.launcher.backend.cancel(self.job_id)

        def __getstate__(self):
            """Custom serialization to ensure picklability."""
            state = self.__dict__.copy()
            # The launcher itself might hold non-picklable attributes (like an
            # ssh_client); RayLauncher and the backends handle their own
            # serialization safety.
            return state

        def __setstate__(self, state):
            self.__dict__.update(state)
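
    # A minimal asynchronous-usage sketch (illustrative only; assumes a
    # launcher constructed with asynchronous=True and a picklable `my_func`):
    #
    #     handle = launcher(my_func, args={"x": 1})
    #     for line in handle.logs:          # stream remote logs
    #         print(line)
    #     print(handle.result)              # result, or "Compute still in progress"
    #     handle.cancel()                   # stop the job if needed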

    def __init__(
        self,
        project_name: str = None,
        files: List[str] = [],
        modules: List[str] = [],
        node_nbr: int = 1,
        use_gpu: bool = False,
        memory: int = 64,
        max_running_time: int = 60,
        runtime_env: dict = {"env_vars": {}},
        server_run: bool = True,
        server_ssh: str = None,  # Auto-detected from the cluster parameter
        server_username: str = None,
        server_password: str = None,
        log_file: str = "logs/RayLauncher.log",
        cluster: str = "curnagl",  # 'curnagl', 'desi', 'local', or a custom IP/hostname
        force_reinstall_venv: bool = False,
        retention_days: int = 7,
        asynchronous: bool = False,
    ):
        """Initialize the launcher.

        Args:
            project_name (str, optional): Name of the project. Defaults to None.
            files (List[str], optional): List of files to push to the cluster/server. Paths must be **relative** to the project directory. Defaults to [].
            modules (List[str], optional): List of modules to load (Slurm mode only). Use `module spider` to see available modules. Ignored in Desi mode. Defaults to [].
            node_nbr (int, optional): Number of nodes to use. In Desi mode this is always 1 (single server). Defaults to 1.
            use_gpu (bool, optional): Whether to use a GPU. Defaults to False.
            memory (int, optional): Amount of RAM to use per node, in gigabytes. Not enforced in Desi mode (shared resource). Defaults to 64.
            max_running_time (int, optional): Maximum running time of the job in minutes. Not enforced by a scheduler in Desi mode. Defaults to 60.
            runtime_env (dict, optional): Environment variables to share between all the workers. Can be useful for issues like https://github.com/ray-project/ray/issues/418. Defaults to an empty mapping.
            server_run (bool, optional): If you run the launcher from your local machine, set this to execute your function using online cluster/server resources. Defaults to True.
            server_ssh (str, optional): If `server_run` is True, the address of the server to use. Auto-detected from the `cluster` parameter if not provided. Defaults to None (auto-detected).
            server_username (str, optional): If `server_run` is True, the username with which to connect. Credentials are automatically loaded from a `.env` file (CURNAGL_USERNAME for Curnagl/custom IP, DESI_USERNAME for Desi) if available. Priority: explicit parameter → environment variables → default ("hjamet" for Curnagl/custom IP, "henri" for Desi).
            server_password (str, optional): If `server_run` is True, the password of the user to connect to the server. Credentials are automatically loaded from a `.env` file (CURNAGL_PASSWORD for Curnagl/custom IP, DESI_PASSWORD for Desi) if available. Priority: explicit parameter → environment variables → interactive prompt. CAUTION: never write your password in the code. Defaults to None.
            log_file (str, optional): Path to the log file. Defaults to "logs/RayLauncher.log".
            cluster (str, optional): Cluster/server to use: 'curnagl' (default, Slurm cluster), 'desi' (ISIPOL09/Desi server), 'local' (local execution), or a custom IP/hostname (for custom Slurm clusters). Defaults to "curnagl".
            force_reinstall_venv (bool, optional): Force complete removal and recreation of the virtual environment on the remote server/cluster. This deletes the existing venv and reinstalls all packages from requirements.txt. Use this if the venv is corrupted or you need a clean installation. Defaults to False.
            retention_days (int, optional): Number of days to retain files and the venv on the cluster before automatic cleanup. Must be between 1 and 30. Defaults to 7.
            asynchronous (bool, optional): If True, calling the launcher returns immediately with a FunctionReturn object. Defaults to False.
        """
        # Load environment variables from the .env file
        load_dotenv()

        # Normalize the cluster parameter
        cluster_lower = cluster.lower()

        # Detect whether cluster is a custom IP/hostname (not a known name)
        is_custom_ip = cluster_lower not in ["curnagl", "desi", "local"]

        # Determine cluster type and backend type
        if cluster_lower == "local":
            self.cluster_type = "local"
            self.backend_type = "local"
            # Force local execution
            self._force_local = True
        else:
            self._force_local = False
            if cluster_lower == "desi":
                self.cluster_type = "desi"
                self.backend_type = "desi"
            elif cluster_lower == "curnagl" or is_custom_ip:
                self.cluster_type = "curnagl"  # Use "curnagl" for credential loading
                self.backend_type = "slurm"  # Use SlurmBackend
            else:
                raise ValueError(
                    f"Invalid cluster value: '{cluster}'. Use 'curnagl', 'desi', 'local', or a custom IP/hostname."
                )

        # Determine environment variable names based on the cluster type
        if self.cluster_type == "desi":
            env_username_key = "DESI_USERNAME"
            env_password_key = "DESI_PASSWORD"
            default_username = "henri"
        else:  # curnagl or custom IP (both use CURNAGL credentials)
            env_username_key = "CURNAGL_USERNAME"
            env_password_key = "CURNAGL_PASSWORD"
            default_username = "hjamet"

        # Load credentials (explicit parameter → environment → default/prompt)
        env_username = os.getenv(env_username_key)
        env_password = os.getenv(env_password_key)

        # Username: explicit parameter → env → default
        if server_username is not None:
            self.server_username = server_username
        elif env_username:
            self.server_username = env_username
        else:
            self.server_username = default_username

        # Password: explicit parameter → env → None (prompted by the backend later if needed)
        if server_password is not None:
            self.server_password = server_password
        elif env_password:
            self.server_password = env_password
        else:
            self.server_password = None

        # Save the other parameters
        self.project_name = project_name
        self.files = files
        self.modules = modules
        self.node_nbr = node_nbr
        self.use_gpu = use_gpu
        self.memory = memory
        self.max_running_time = max_running_time

        # Validate and save retention_days
        if retention_days < 1 or retention_days > 30:
            raise ValueError(
                f"retention_days must be between 1 and 30, got {retention_days}"
            )
        self.retention_days = retention_days

        # Set a default runtime_env and add Ray warning suppression
        if runtime_env is None:
            runtime_env = {"env_vars": {}}
        elif "env_vars" not in runtime_env:
            runtime_env["env_vars"] = {}

        # Suppress the Ray FutureWarning about accelerator visible devices
        if "RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO" not in runtime_env["env_vars"]:
            runtime_env["env_vars"]["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"

        self.runtime_env = runtime_env

        # Disable server_run if cluster is "local"
        if hasattr(self, "_force_local") and self._force_local:
            self.server_run = False
        else:
            self.server_run = server_run

        # Auto-detect server_ssh from the cluster parameter if not provided
        if self.server_run and server_ssh is None:
            if cluster_lower == "desi":
                self.server_ssh = "130.223.73.209"
            elif cluster_lower == "curnagl":
                self.server_ssh = "curnagl.dcsr.unil.ch"
            elif is_custom_ip:
                # Use the provided IP/hostname directly
                self.server_ssh = cluster
            else:
                # Fallback (should not happen)
                self.server_ssh = "curnagl.dcsr.unil.ch"
        else:
            self.server_ssh = server_ssh or "curnagl.dcsr.unil.ch"

        self.log_file = log_file
        self.force_reinstall_venv = force_reinstall_venv
        self.asynchronous = asynchronous

        # Track which parameters were explicitly passed (for warnings)
        frame = inspect.currentframe()
        args, _, _, values = inspect.getargvalues(frame)
        self._explicit_params = {
            arg: values[arg] for arg in args[1:] if arg in values
        }  # Skip 'self'

        self.__setup_logger()

        # Create the project directory if it does not exist (needed for pwd_path)
        self.pwd_path = os.getcwd()
        self.module_path = os.path.dirname(os.path.abspath(__file__))
        self.project_path = os.path.join(self.pwd_path, ".slogs", self.project_name)
        if not os.path.exists(self.project_path):
            os.makedirs(self.project_path)

        # Detect the local Python version
        self.local_python_version = self._detect_local_python_version()

        # Default modules with specific versions for Curnagl compatibility
        # (latest stable versions available on Curnagl, SLURM 24.05.3):
        #   gcc/13.2.0: latest GCC version
        #   python/3.12.1: latest Python version on Curnagl
        #   cuda/12.6.2: latest CUDA version
        #   cudnn/9.2.0.82-12: compatible with cuda/12.6.2
        default_modules = ["gcc/13.2.0", "python/3.12.1"]

        # Filter any gcc or python modules out of the user list (the defaults
        # above are used); users can override by providing specific versions.
        user_modules = []
        for mod in modules:
            # Skip gcc/python modules (the defaults are used instead)
            if mod.startswith("gcc") or mod.startswith("python"):
                continue
            user_modules.append(mod)

        self.modules = default_modules + user_modules

        if self.use_gpu is True:
            # Check whether the user provided specific cuda/cudnn versions
            has_cuda = any("cuda" in mod for mod in self.modules)
            has_cudnn = any("cudnn" in mod for mod in self.modules)
            if not has_cuda:
                self.modules.append("cuda/12.6.2")
            if not has_cudnn:
                self.modules.append("cudnn/9.2.0.82-12")

        # --- Argument validation ---
        self._validate_arguments()

        # Check whether this code is running on a cluster (only relevant for Slurm)
        self.cluster = os.path.exists("/usr/bin/sbatch")

        # Initialize the backend
        if self.backend_type == "local":
            self.backend = LocalBackend(self)
        elif self.server_run:
            if self.backend_type == "desi":
                self.backend = DesiBackend(self)
            elif self.backend_type == "slurm":
                self.backend = SlurmBackend(self)
            else:
                raise ValueError(
                    f"Unknown backend type: {self.backend_type}. This should not happen."
                )
        elif self.cluster:  # Running ON a cluster (Slurm)
            self.backend = SlurmBackend(self)
        else:
            self.backend = LocalBackend(self)

        # Note: intelligent dependency detection is done in __call__, once the
        # function to analyze is known. Editable package source paths are no
        # longer added blindly, to avoid uploading unwanted files or breaking
        # complex setups.
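
    # Credential sketch (illustrative values only): a `.env` file in the
    # working directory, picked up by load_dotenv() above, might look like
    #
    #     CURNAGL_USERNAME=alice
    #     CURNAGL_PASSWORD=secret
    #     DESI_USERNAME=bob
    #     DESI_PASSWORD=secret
    #
    # Explicit constructor parameters take precedence over these variables.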

    def __setup_logger(self):
        """Setup the logger."""
        # Create the log directory if it does not exist
        log_dir = os.path.dirname(self.log_file)
        if log_dir and not os.path.exists(log_dir):
            os.makedirs(log_dir)

        # Configure the logger
        self.logger = logging.getLogger(f"RayLauncher-{self.project_name}")
        self.logger.setLevel(logging.INFO)

        # Remove existing handlers to avoid duplication when instantiated multiple times
        if self.logger.hasHandlers():
            self.logger.handlers.clear()

        # File handler (rewritten on every run)
        file_handler = logging.FileHandler(self.log_file, mode="w")
        file_handler.setLevel(logging.INFO)
        file_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
        file_handler.setFormatter(file_formatter)
        self.logger.addHandler(file_handler)

        # Console handler (warnings and errors only)
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.WARNING)
        console_formatter = logging.Formatter("%(levelname)s: %(message)s")
        console_handler.setFormatter(console_formatter)
        self.logger.addHandler(console_handler)

    def _detect_local_python_version(self) -> str:
        """Detect the local Python version from a .python-version file or sys.version_info.

        Returns:
            str: Python version in the format "X.Y.Z" (e.g., "3.12.1")
        """
        # Try to read from a .python-version file first
        python_version_file = os.path.join(self.pwd_path, ".python-version")
        if os.path.exists(python_version_file):
            with open(python_version_file, "r") as f:
                version_str = f.read().strip()
            # Validate the format (should be X.Y or X.Y.Z)
            import re

            if re.match(r"^\d+\.\d+(\.\d+)?$", version_str):
                # If only X.Y, add .0 as the micro version
                if version_str.count(".") == 1:
                    version_str = f"{version_str}.0"
                self.logger.info(
                    f"Detected Python version from .python-version: {version_str}"
                )
                return version_str

        # Fall back to sys.version_info
        version_str = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
        self.logger.info(
            f"Detected Python version from sys.version_info: {version_str}"
        )
        return version_str
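
    # Accepted .python-version contents, per the regex above (examples):
    #   "3.12"   -> normalized to "3.12.0"
    #   "3.12.1" -> used as-is
    # Anything else falls through to sys.version_info.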

    def _validate_arguments(self):
        """Validate arguments and warn about inconsistencies."""
        # project_name is required for project-based organization on the cluster
        if self.project_name is None:
            raise ValueError(
                "project_name cannot be None. A project name is required for cluster execution."
            )

        if self.cluster_type == "desi":
            # server_ssh is already set correctly in __init__

            if self.node_nbr > 1:
                self.logger.warning(
                    f"Warning: Desi cluster only supports single node execution. node_nbr={self.node_nbr} will be ignored (effectively 1)."
                )

            # Only warn if modules were explicitly passed by the user (not just
            # the gcc/python defaults). GPU modules (cuda/cudnn) are added
            # automatically when use_gpu=True, so they do not count as
            # user-provided.
            user_provided_modules = [
                m
                for m in self.modules
                if not (
                    m.startswith("gcc")
                    or m.startswith("python")
                    or m.startswith("cuda")
                    or m.startswith("cudnn")
                )
            ]
            if "modules" in self._explicit_params and user_provided_modules:
                self.logger.warning(
                    "Warning: Modules loading is not supported on Desi (no module system). Modules list will be ignored."
                )

            if "memory" in self._explicit_params and self.memory != 64:  # 64 is the default
                self.logger.warning(
                    "Warning: Memory allocation is not enforced on Desi (shared resource)."
                )

    def _handle_signal(self, signum, frame):
        """Handle interruption signals (SIGINT, SIGTERM) to clean up resources."""
        sig_name = signal.Signals(signum).name
        self.logger.warning(f"Signal {sig_name} received. Cleaning up resources...")
        print(f"\nInterruption received ({sig_name}). Canceling job and cleaning up...")

        self.cancel()
        sys.exit(1)

    def cancel(self, target: Any = None):
        """Cancel a running job.

        Args:
            target (Any, optional): The job to cancel. Can be:
                - None: cancels the last job run by this launcher instance.
                - str: a specific job ID.
                - FunctionReturn: a specific FunctionReturn object.
        """
        if hasattr(self, "backend"):
            job_id = None

            # Determine the job_id based on target
            if target is None:
                # Fall back to the last job
                if hasattr(self.backend, "job_id") and self.backend.job_id:
                    job_id = self.backend.job_id
                elif hasattr(self, "job_id") and self.job_id:
                    job_id = self.job_id
            elif isinstance(target, str):
                job_id = target
            elif isinstance(target, self.FunctionReturn):
                job_id = target.job_id

            if job_id:
                self.backend.cancel(job_id)
            else:
                self.logger.warning("No job ID found to cancel.")
        else:
            self.logger.warning("No backend initialized, cannot cancel.")

    def __call__(
        self,
        func: Callable,
        args: dict = None,
        cancel_old_jobs: bool = True,
        serialize: bool = True,
    ) -> Any:
        """Launch the job and return the result.

        Args:
            func (Callable): Function to execute. This function should not be a Ray remote function, but it can use Ray resources.
            args (dict, optional): Arguments of the function. Defaults to None (empty dict).
            cancel_old_jobs (bool, optional): Cancel old jobs first. Defaults to True.
            serialize (bool, optional): Serialize the function and the arguments. Set to False when the function is called automatically by the server. Defaults to True.

        Returns:
            Any: Result of the function
        """
        if args is None:
            args = {}

        # Intelligent dependency detection from the function's source file
        if self.server_run:
            try:
                from slurmray.scanner import ProjectScanner

                scanner = ProjectScanner(self.pwd_path, self.logger)
                detected_dependencies = scanner.detect_dependencies_from_function(func)

                added_count = 0
                for dep in detected_dependencies:
                    # Skip invalid paths (empty, current directory, etc.)
                    if (
                        not dep
                        or dep == "."
                        or dep == ".."
                        or dep.startswith("./")
                        or dep.startswith("../")
                    ):
                        continue

                    # Skip paths outside the project or in ignored directories
                    dep_abs = os.path.abspath(os.path.join(self.pwd_path, dep))
                    if not dep_abs.startswith(os.path.abspath(self.pwd_path)):
                        continue

                    # Check that it is a valid file or directory
                    if not os.path.exists(dep_abs):
                        continue

                    # Check whether the dependency is already covered by
                    # existing files/dirs (e.g. if 'src' is in files,
                    # 'src/module.py' is covered)
                    is_covered = False
                    for existing in self.files:
                        if dep == existing or dep.startswith(existing + os.sep):
                            is_covered = True
                            break

                    if not is_covered:
                        self.files.append(dep)
                        added_count += 1

                if added_count > 0:
                    self.logger.info(
                        f"Auto-added {added_count} local dependencies to upload list (from function imports)."
                    )

                # Display warnings for dynamic imports
                if scanner.dynamic_imports_warnings:
                    print("\n" + "=" * 60)
                    print("⚠️ WARNING: Dynamic imports or file operations detected ⚠️")
                    print("=" * 60)
                    print(
                        "The following lines might require files that cannot be auto-detected."
                    )
                    print(
                        "Please verify if you need to add them manually to 'files=[...]':"
                    )
                    for warning in scanner.dynamic_imports_warnings:
                        print(f"  - {warning}")
                    print("=" * 60 + "\n")

                    # Also log them
                    for warning in scanner.dynamic_imports_warnings:
                        self.logger.warning(f"Dynamic import warning: {warning}")

            except Exception as e:
                self.logger.warning(f"Dependency detection from function failed: {e}")

        # Register signal handlers
        original_sigint = signal.getsignal(signal.SIGINT)
        original_sigterm = signal.getsignal(signal.SIGTERM)
        signal.signal(signal.SIGINT, self._handle_signal)
        signal.signal(signal.SIGTERM, self._handle_signal)

        try:
            # Serialize the function and arguments
            if serialize:
                self.__serialize_func_and_args(func, args)

            # Execute
            if self.asynchronous:
                job_id = self.backend.run(cancel_old_jobs=cancel_old_jobs, wait=False)
                return self.FunctionReturn(self, job_id)
            else:
                return self.backend.run(cancel_old_jobs=cancel_old_jobs, wait=True)
        finally:
            # Restore the original signal handlers. Even in asynchronous mode
            # this happens immediately, because __call__ returns right away.
            signal.signal(signal.SIGINT, original_sigint)
            signal.signal(signal.SIGTERM, original_sigterm)

    def _dedent_source(self, source: str) -> str:
        """Dedent source code."""
        lines = source.split("\n")
        if not lines:
            return source

        # Find the first non-empty line and measure its indentation
        first_non_empty = next((i for i, line in enumerate(lines) if line.strip()), 0)

        if first_non_empty < len(lines):
            first_line = lines[first_non_empty]
            indent = len(first_line) - len(first_line.lstrip())

            # Strip the common indentation, but preserve empty lines
            dedented_lines = []
            for line in lines:
                if line.strip():  # Non-empty line
                    if len(line) >= indent:
                        dedented_lines.append(line[indent:])
                    else:
                        dedented_lines.append(line)
                else:  # Empty line
                    dedented_lines.append("")
            return "\n".join(dedented_lines)

        return source
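
    # Dedent behaviour (illustrative): a method body extracted with
    # inspect.getsource, e.g.
    #
    #     "    def f():\n        return 1\n"
    #
    # comes back as
    #
    #     "def f():\n    return 1\n"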

    def _resolve_dependencies(
        self, func: Callable
    ) -> Tuple[List[str], List[str], bool]:
        """Analyze function dependencies and resolve them recursively.

        Returns: (imports_to_add, source_code_to_add, is_safe)
        """
        imports = set()
        # name -> source of user-defined functions/classes. Definition order
        # does not matter strictly for functions (Python is late-binding, they
        # only need to exist before the call), though it can matter for
        # classes and variables.
        sources_map = {}

        queue = [func]
        processed_funcs = set()  # code objects (or funcs without one)

        while queue:
            current_func = queue.pop(0)

            # Use the code object for identity if possible, else the func object
            func_id = current_func
            if hasattr(current_func, "__code__"):
                func_id = current_func.__code__

            if func_id in processed_funcs:
                continue
            processed_funcs.add(func_id)

            # Closures are still hard: reject them.
            if hasattr(current_func, "__code__") and current_func.__code__.co_freevars:
                self.logger.debug(
                    f"Function {current_func.__name__} uses closures. Unsafe."
                )
                return [], [], False

            builtin_names = set(dir(builtins))
            global_names = set()

            # Find the global names used
            try:
                for instruction in dis.get_instructions(current_func):
                    if instruction.opname == "LOAD_GLOBAL":
                        if instruction.argval not in builtin_names:
                            global_names.add(instruction.argval)
            except Exception as e:
                self.logger.debug(f"Bytecode analysis failed for {current_func}: {e}")
                # Whether this is the main func or a dependency, failing safe
                # is the better option.
                return [], [], False

            # Resolve each name
            for name in global_names:
                # If the name is not in globals, it might be a problem
                if (
                    not hasattr(current_func, "__globals__")
                    or name not in current_func.__globals__
                ):
                    # Maybe it's a recursive self-reference?
                    if (
                        hasattr(current_func, "__name__")
                        and name == current_func.__name__
                    ):
                        continue
                    self.logger.debug(f"Global '{name}' not found in function globals.")
                    return [], [], False

                obj = current_func.__globals__[name]

                # Case 1: Module
                if inspect.ismodule(obj):
                    if obj.__name__ == name:
                        imports.add(f"import {name}")
                    else:
                        imports.add(f"import {obj.__name__} as {name}")

                # Case 2: Function (user-defined)
                elif inspect.isfunction(obj):
                    if obj not in queue and obj.__code__ not in processed_funcs:
                        try:
                            src = inspect.getsource(obj)
                            sources_map[name] = self._dedent_source(src)
                            queue.append(obj)
                        except Exception:
                            self.logger.debug(
                                f"Could not get source for function '{name}'"
                            )
                            return [], [], False

                # Case 3: Class
                elif inspect.isclass(obj):
                    # No recursion into classes yet; just add the source
                    if name not in sources_map:
                        try:
                            src = inspect.getsource(obj)
                            sources_map[name] = self._dedent_source(src)
                        except Exception:
                            self.logger.debug(
                                f"Could not get source for class '{name}'"
                            )
                            return [], [], False

                # Case 4: Builtin function/method
                elif inspect.isbuiltin(obj):
                    mod = inspect.getmodule(obj)
                    if mod:
                        if obj.__name__ == name:
                            imports.add(f"from {mod.__name__} import {name}")
                        else:
                            imports.add(
                                f"from {mod.__name__} import {obj.__name__} as {name}"
                            )
                    else:
                        return [], [], False

                else:
                    self.logger.debug(
                        f"Unsupported global object type: {type(obj)} for '{name}'"
                    )
                    return [], [], False

        # Sort the imports for consistency
        sorted_imports = sorted(list(imports))
        sorted_sources = list(sources_map.values())

        return sorted_imports, sorted_sources, True
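
    # What _resolve_dependencies produces (illustrative sketch; `np` and
    # `clean` are hypothetical names): for a function calling `np.mean` (where
    # np is an imported module) and a local helper `clean()`, it would return
    # roughly
    #
    #     (["import numpy as np"], ["def clean(...):\n    ..."], True)
    #
    # while any closure variable makes it return ([], [], False).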

    def __serialize_func_and_args(self, func: Callable = None, args: list = None):
        """Serialize the function and the arguments.

        This method uses a simplified serialization strategy:

        - **Always tries dill pickle first** (better performance; handles closures and complex objects)
        - **Falls back to source extraction** only if dill pickle fails
        - With pyenv, Python versions are identical, so dill pickle should always work

        **Fallback to source extraction happens when:**

        - Python versions are incompatible (rare with pyenv)
        - The function is not serializable by dill (built-ins, C functions, etc.)
        - Other serialization errors occur

        **Limitations of source-based serialization:**

        - Functions with closures: only the function body is captured, not the
          captured variables. The function may fail at runtime if it depends
          on closure variables.
        - Functions defined in interactive shells or dynamically compiled code
          may not have accessible source.
        - Lambda functions defined inline may have limited source information.

        Args:
            func (Callable, optional): Function to serialize. Defaults to None.
            args (list, optional): Arguments of the function. Defaults to None.
        """
        self.logger.info("Serializing function and arguments...")

        source_extracted = False
        source_method = None
        dill_pickle_used = False
        serialization_method = "dill_pickle"  # Default method

        # Step 1: Always try dill pickle first
        self.logger.info("Attempting dill pickle serialization...")
        try:
            # Try to pickle the function directly with dill
            func_pickle_path = os.path.join(self.project_path, "func.pkl")
            with open(func_pickle_path, "wb") as f:
                dill.dump(func, f)

            # If successful, use pickle
            dill_pickle_used = True
            serialization_method = "dill_pickle"
            self.logger.info("✅ Successfully serialized function with dill pickle.")

            # Clean up any stale source files since we're using dill pickle
            source_path = os.path.join(self.project_path, "func_source.py")
            name_path = os.path.join(self.project_path, "func_name.txt")
            if os.path.exists(source_path):
                os.remove(source_path)
            if os.path.exists(name_path):
                os.remove(name_path)

        except Exception as e:
            # Dill pickle failed: analyze why and fall back to source extraction
            error_type = type(e).__name__
            error_message = str(e)

            # Determine the likely reason for the failure
            reason_explanation = "Unknown error"
            if "opcode" in error_message.lower() or "bytecode" in error_message.lower():
                reason_explanation = (
                    "Python version incompatibility (bytecode mismatch)"
                )
            elif "cannot pickle" in error_message.lower():
                reason_explanation = (
                    "Function not serializable by dill (built-in, C function, etc.)"
                )
            elif "recursion" in error_message.lower():
                reason_explanation = "Recursion limit reached during serialization"
            else:
                reason_explanation = f"Serialization error: {error_type}"

            self.logger.error(
                f"❌ dill pickle serialization failed: {error_type}: {error_message}"
            )
            self.logger.warning(
                f"⚠️ Falling back to source extraction. Reason: {reason_explanation}"
            )

            dill_pickle_used = False
            serialization_method = "source_extraction"

            # Continue with source extraction below

        # Step 2: Try source extraction if dill pickle failed
        if not dill_pickle_used:
            self.logger.info("📝 Using source extraction fallback (dill pickle failed)")

            # Only analyze dependencies if source extraction is needed
            extra_imports, extra_sources, is_safe = self._resolve_dependencies(func)

            if is_safe:
                # Method 1: Try inspect.getsource() (standard library, most common)
                try:
                    source = inspect.getsource(func)
                    source_method = "inspect.getsource"

                    # Combine the parts:
                    #   1. Imports
                    #   2. Dependency sources
                    #   3. Main function source
                    parts = []
                    if extra_imports:
                        parts.extend(extra_imports)
                        parts.append("")  # newline

                    if extra_sources:
                        parts.extend(extra_sources)
                        parts.append("")  # newline

                    # Dedent the main source
                    source = self._dedent_source(source)
                    parts.append(source)

                    source = "\n".join(parts)
                    source_extracted = True

                except (OSError, TypeError) as e:
                    self.logger.debug(f"inspect.getsource() failed: {e}")
                except Exception as e:
                    self.logger.debug(f"inspect.getsource() unexpected error: {e}")

                # Method 2: Try dill.source.getsource() as a backup for simple
                # functions. Raw dill source cannot reliably be combined with
                # the injected imports/sources (indentation and structure may
                # differ), so this fallback is skipped when dependencies exist.
                if not source_extracted:
                    if not extra_imports and not extra_sources:
                        try:
                            if hasattr(dill, "source") and hasattr(
                                dill.source, "getsource"
                            ):
                                source = dill.source.getsource(func)
                                source_method = "dill.source.getsource"
                                source_extracted = True
                        except Exception as e:
                            self.logger.debug(f"dill.source.getsource() failed: {e}")
            else:
                self.logger.warning(
                    "⚠️ Function unsafe for source extraction (unresolvable globals/closures). Will use dill pickle anyway."
                )

            # Process and save the source if extracted
            if source_extracted:
                try:
                    # The source is already prepared and dedented above when
                    # using inspect; the dill fallback path still needs a dedent.
                    if source_method == "dill.source.getsource":
                        source = self._dedent_source(source)

                    # Save the source code
                    with open(os.path.join(self.project_path, "func_source.py"), "w") as f:
                        f.write(source)

                    # Save the function name for loading
                    with open(os.path.join(self.project_path, "func_name.txt"), "w") as f:
                        f.write(func.__name__)

                    self.logger.info(
                        f"✅ Function source extracted successfully using {source_method}."
                    )
                    serialization_method = "source_extraction"

                except Exception as e:
                    self.logger.warning(f"Failed to process/save function source: {e}")
                    source_extracted = False
                    # Fall back to dill pickle if saving the source failed
                    serialization_method = "dill_pickle"

        # If source extraction failed or was skipped, ensure no stale source files exist
        if not source_extracted:
            source_path = os.path.join(self.project_path, "func_source.py")
            if os.path.exists(source_path):
                os.remove(source_path)

            # Always create func_name.txt even if source extraction failed;
            # the Desi backend needs it for queue management.
            func_name_path = os.path.join(self.project_path, "func_name.txt")
            if not os.path.exists(func_name_path):
                try:
                    with open(func_name_path, "w") as f:
                        f.write(func.__name__)
                    self.logger.debug(f"Created func_name.txt with function name: {func.__name__}")
                except Exception as e:
                    self.logger.warning(f"Failed to create func_name.txt: {e}")

            # Source extraction was attempted but failed: log it
            if not dill_pickle_used:
                self.logger.warning(
                    "⚠️ Source extraction failed. Using dill pickle as final fallback."
                )
                # Still pickle the function even though source extraction failed
                serialization_method = "dill_pickle"

        # Always pickle the function (used by the dill pickle strategy or as a
        # fallback). func.pkl was already created above when dill pickle succeeded.
        if not dill_pickle_used:
            try:
                func_pickle_path = os.path.join(self.project_path, "func.pkl")
                with open(func_pickle_path, "wb") as f:
                    dill.dump(func, f)
                self.logger.debug(
                    "Created func.pkl as fallback (dill pickle or source extraction fallback)"
                )
            except Exception as e:
                # Only raise if there is no other option (both dill pickle and
                # source extraction failed)
                if not source_extracted:
                    self.logger.error(
                        f"❌ Critical: Failed to pickle function even as fallback: {e}"
                    )
                    raise
                else:
                    # If source extraction succeeded, warn but do not fail
                    self.logger.warning(
                        f"⚠️ Failed to create func.pkl fallback (source extraction will be used): {e}"
                    )

        # Save the serialization method indicator
        method_file = os.path.join(self.project_path, "serialization_method.txt")
        with open(method_file, "w") as f:
            f.write(f"{serialization_method}\n")

        # Pickle the arguments
        if args is None:
            args = {}
        with open(os.path.join(self.project_path, "args.pkl"), "wb") as f:
            dill.dump(args, f)
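
# Artifacts written under .slogs/<project_name>/ by the serialization step
# above (summary of the code, for orientation):
#   func.pkl                         dill-pickled function (primary path)
#   func_source.py / func_name.txt   source-extraction fallback
#   serialization_method.txt         which strategy was used
#   args.pkl                         dill-pickled arguments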

# ---------------------------------------------------------------------------- #
#                              EXAMPLE OF EXECUTION                             #
# ---------------------------------------------------------------------------- #
if __name__ == "__main__":
    import ray
    import torch

    def function_inside_function():
        # Check that the file exists before reading it, as paths might differ
        if os.path.exists("documentation/RayLauncher.html"):
            with open("documentation/RayLauncher.html", "r") as f:
                return f.read()[0:10]
        return "DocNotFound"

    def example_func(x):
        import time  # Encapsulated imports work too!

        print("Waiting for 60 seconds so that you can check the dashboard...")
        time.sleep(60)
        print("Done waiting!")
        result = (
            ray.cluster_resources(),
            f"GPU is available: {torch.cuda.is_available()}",
            x + 1,
            function_inside_function(),
        )
        return result

    cluster = RayLauncher(
        project_name="example",  # Name of the project (a directory with this name is created under .slogs)
        files=["documentation/RayLauncher.html"],  # List of files to push to the server
        use_gpu=True,  # Set to True if you need a GPU
        runtime_env={
            "env_vars": {"NCCL_SOCKET_IFNAME": "eno1"}
        },  # Example of an environment variable
        server_run=True,  # Run the code on the server, not locally
        cluster="desi",  # Use the Desi backend (credentials loaded from .env: DESI_USERNAME and DESI_PASSWORD)
        force_reinstall_venv=False,  # Set to True to force a clean reinstall of the remote venv
        retention_days=1,  # Retain files and the venv for 1 day before cleanup
    )

    result = cluster(example_func, args={"x": 5})  # Execute the function with arguments
    print(result)