slurmray 6.0.4__py3-none-any.whl


@@ -0,0 +1,1019 @@
from typing import Any, Callable, Generator, List, Set, Tuple
import sys
import os
import dill
import logging
import signal
import dis
import builtins
import inspect
from getpass import getpass
import time


from dotenv import load_dotenv

from slurmray.backend.slurm import SlurmBackend
from slurmray.backend.local import LocalBackend
from slurmray.backend.desi import DesiBackend

dill.settings["recurse"] = True


class RayLauncher:
    """A class that automatically connects Ray workers and executes the function requested by the user.

    Official tool from DESI @ HEC UNIL.

    Supports multiple execution modes:
    - **Curnagl mode** (`cluster='curnagl'`): For Slurm-based clusters like Curnagl. Uses sbatch/squeue for job management.
    - **Desi mode** (`cluster='desi'`): For standalone servers like ISIPOL09. Uses Smart Lock scheduling for resource management.
    - **Local mode** (`cluster='local'`): For local execution without a remote server/cluster.
    - **Custom IP** (`cluster='<ip_or_hostname>'`): For custom Slurm clusters. Uses the provided IP/hostname.

    The launcher automatically selects the appropriate backend based on the `cluster` parameter and environment detection.
    """

    class FunctionReturn:
        """Object returned when running in asynchronous mode.
        Allows monitoring logs and retrieving the result later.
        """

        def __init__(self, launcher, job_id=None):
            self.launcher = launcher
            self.job_id = job_id
            self._cached_result = None

        @property
        def result(self):
            """Get the result of the function execution.
            Returns "Compute still in progress" if not finished.
            """
            if self._cached_result is not None:
                return self._cached_result

            # Check/fetch the result from the backend without blocking
            if hasattr(self.launcher.backend, "get_result"):
                res = self.launcher.backend.get_result(self.job_id)
                if res is not None:
                    self._cached_result = res
                    return res

            return "Compute still in progress"

        @property
        def logs(self) -> Generator[str, None, None]:
            """Get the logs of the function execution as a stream (generator)."""
            if hasattr(self.launcher.backend, "get_logs"):
                yield from self.launcher.backend.get_logs(self.job_id)
            else:
                yield "Logs not available for this backend."

        def cancel(self):
            """Cancel the running job."""
            if hasattr(self.launcher.backend, "cancel"):
                self.launcher.backend.cancel(self.job_id)

        def __getstate__(self):
            """Custom serialization to ensure picklability"""
            state = self.__dict__.copy()
            # The launcher itself might have non-picklable attributes (like an
            # ssh_client); RayLauncher and the backends handle their own
            # serialization safety.
            return state

        def __setstate__(self, state):
            self.__dict__.update(state)
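
    # A minimal asynchronous-mode sketch (illustrative, not part of this file;
    # assumes a reachable backend and valid credentials, and `my_func` is any
    # picklable function):
    #
    #     launcher = RayLauncher(project_name="demo", asynchronous=True)
    #     handle = launcher(my_func, args={"x": 1})  # returns a FunctionReturn
    #     for line in handle.logs:                   # stream logs while the job runs
    #         print(line)
    #     print(handle.result)                       # "Compute still in progress" until done
    #     handle.cancel()                            # cancel the job if needed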

    def __init__(
        self,
        project_name: str = None,
        files: List[str] = None,
        modules: List[str] = None,
        node_nbr: int = 1,
        use_gpu: bool = False,
        memory: int = 64,
        max_running_time: int = 60,
        runtime_env: dict = None,
        server_run: bool = True,
        server_ssh: str = None,  # Auto-detected from cluster parameter
        server_username: str = None,
        server_password: str = None,
        log_file: str = "logs/RayLauncher.log",
        cluster: str = "curnagl",  # 'curnagl', 'desi', 'local', or custom IP/hostname
        force_reinstall_venv: bool = False,
        retention_days: int = 7,
        asynchronous: bool = False,
    ):
        """Initialize the launcher

        Args:
            project_name (str, optional): Name of the project. Defaults to None.
            files (List[str], optional): List of files to push to the cluster/server. Paths must be **relative** to the project directory. Defaults to [].
            modules (List[str], optional): List of modules to load (Slurm mode only). Use `module spider` to see available modules. Ignored in Desi mode. Defaults to [].
            node_nbr (int, optional): Number of nodes to use. For Desi mode, this is always 1 (single server). Defaults to 1.
            use_gpu (bool, optional): Use GPU or not. Defaults to False.
            memory (int, optional): Amount of RAM to use per node, in gigabytes. For Desi mode, this is not enforced (shared resource). Defaults to 64.
            max_running_time (int, optional): Maximum running time of the job in minutes. For Desi mode, this is not enforced by a scheduler. Defaults to 60.
            runtime_env (dict, optional): Environment variables to share between all the workers. Can be useful for issues like https://github.com/ray-project/ray/issues/418. Defaults to an empty `env_vars` dict.
            server_run (bool, optional): If you run the launcher from your local machine, set this parameter to execute your function using online cluster/server resources. Defaults to True.
            server_ssh (str, optional): If `server_run` is True, the address of the server to use. Auto-detected from the `cluster` parameter if not provided. Defaults to None (auto-detected).
            server_username (str, optional): If `server_run` is True, the username with which to connect. Credentials are automatically loaded from a `.env` file (CURNAGL_USERNAME for Curnagl/custom IP, DESI_USERNAME for Desi) if available. Priority: explicit parameter → environment variables → default ("hjamet" for Curnagl/custom IP, "henri" for Desi).
            server_password (str, optional): If `server_run` is True, the password of the user to connect to the server. Credentials are automatically loaded from a `.env` file (CURNAGL_PASSWORD for Curnagl/custom IP, DESI_PASSWORD for Desi) if available. Priority: explicit parameter → environment variables → interactive prompt. CAUTION: never write your password in the code. Defaults to None.
            log_file (str, optional): Path to the log file. Defaults to "logs/RayLauncher.log".
            cluster (str, optional): Cluster/server to use: 'curnagl' (default, Slurm cluster), 'desi' (ISIPOL09/Desi server), 'local' (local execution), or a custom IP/hostname (for custom Slurm clusters). Defaults to "curnagl".
            force_reinstall_venv (bool, optional): Force complete removal and recreation of the virtual environment on the remote server/cluster. This deletes the existing venv and reinstalls all packages from requirements.txt. Use this if the venv is corrupted or you need a clean installation. Defaults to False.
            retention_days (int, optional): Number of days to retain files and the venv on the cluster before automatic cleanup. Must be between 1 and 30 days. Defaults to 7.
            asynchronous (bool, optional): If True, calling the launcher returns immediately with a FunctionReturn object. Defaults to False.
        """
        # Load environment variables from .env file
        load_dotenv()

        # Avoid mutable default arguments (they would be shared across instances)
        files = list(files) if files is not None else []
        modules = list(modules) if modules is not None else []

        # Normalize cluster parameter
        cluster_lower = cluster.lower()

        # Detect if cluster is a custom IP/hostname (not a known name)
        is_custom_ip = cluster_lower not in ["curnagl", "desi", "local"]

        # Determine cluster type and backend type
        if cluster_lower == "local":
            self.cluster_type = "local"
            self.backend_type = "local"
            # Force local execution
            self._force_local = True
        else:
            self._force_local = False
            if cluster_lower == "desi":
                self.cluster_type = "desi"
                self.backend_type = "desi"
            elif cluster_lower == "curnagl" or is_custom_ip:
                self.cluster_type = "curnagl"  # Use "curnagl" for credential loading
                self.backend_type = "slurm"  # Use SlurmBackend
            else:
                raise ValueError(
                    f"Invalid cluster value: '{cluster}'. Use 'curnagl', 'desi', 'local', or a custom IP/hostname."
                )

        # Determine environment variable names based on cluster type
        if self.cluster_type == "desi":
            env_username_key = "DESI_USERNAME"
            env_password_key = "DESI_PASSWORD"
            default_username = "henri"
        else:  # curnagl or custom IP (both use CURNAGL credentials)
            env_username_key = "CURNAGL_USERNAME"
            env_password_key = "CURNAGL_PASSWORD"
            default_username = "hjamet"

        # Load credentials with priority: explicit parameter → environment (.env) → default/prompt
        env_username = os.getenv(env_username_key)
        env_password = os.getenv(env_password_key)

        # For username: explicit parameter → env → default
        if server_username is not None:
            # Explicit parameter provided
            self.server_username = server_username
        elif env_username:
            # Load from environment
            self.server_username = env_username
        else:
            # Use default
            self.server_username = default_username

        # For password: explicit parameter → env → None (will prompt later if needed)
        if server_password is not None:
            # Explicit parameter provided
            self.server_password = server_password
        elif env_password:
            # Load from environment
            self.server_password = env_password
        else:
            # None: will be prompted by the backend if needed
            self.server_password = None

        # Save the other parameters
        self.project_name = project_name
        self.files = files
        self.modules = modules
        self.node_nbr = node_nbr
        self.use_gpu = use_gpu
        self.memory = memory
        self.max_running_time = max_running_time

        # Validate and save retention_days
        if retention_days < 1 or retention_days > 30:
            raise ValueError(
                f"retention_days must be between 1 and 30, got {retention_days}"
            )
        self.retention_days = retention_days

        # Set default runtime_env and add Ray warning suppression
        if runtime_env is None:
            runtime_env = {"env_vars": {}}
        elif "env_vars" not in runtime_env:
            runtime_env["env_vars"] = {}

        # Suppress Ray FutureWarning about accelerator visible devices
        if "RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO" not in runtime_env["env_vars"]:
            runtime_env["env_vars"]["RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO"] = "0"

        self.runtime_env = runtime_env
        # Override server_run if cluster is "local"
        if self._force_local:
            self.server_run = False
        else:
            self.server_run = server_run

        # Auto-detect server_ssh from cluster parameter if not provided
        if self.server_run and server_ssh is None:
            if cluster_lower == "desi":
                self.server_ssh = "130.223.73.209"
            elif cluster_lower == "curnagl":
                self.server_ssh = "curnagl.dcsr.unil.ch"
            elif is_custom_ip:
                # Use the provided IP/hostname directly
                self.server_ssh = cluster
            else:
                # Fallback (should not happen)
                self.server_ssh = "curnagl.dcsr.unil.ch"
        else:
            self.server_ssh = server_ssh or "curnagl.dcsr.unil.ch"

        self.log_file = log_file
        self.force_reinstall_venv = force_reinstall_venv
        self.asynchronous = asynchronous

        # Track which parameters were explicitly passed (for warnings)
        frame = inspect.currentframe()
        args, _, _, values = inspect.getargvalues(frame)
        self._explicit_params = {
            arg: values[arg] for arg in args[1:] if arg in values
        }  # Skip 'self'

        self.__setup_logger()

        # Create the project directory if it does not exist (needed for pwd_path)
        self.pwd_path = os.getcwd()
        self.module_path = os.path.dirname(os.path.abspath(__file__))
        self.project_path = os.path.join(self.pwd_path, ".slogs", self.project_name)
        if not os.path.exists(self.project_path):
            os.makedirs(self.project_path)

        # Detect local Python version
        self.local_python_version = self._detect_local_python_version()

        # Default modules with specific versions for Curnagl compatibility,
        # using the latest stable versions available on Curnagl (SLURM 24.05.3):
        #   gcc/13.2.0: latest GCC version
        #   python/3.12.1: latest Python version on Curnagl
        #   cuda/12.6.2: latest CUDA version
        #   cudnn/9.2.0.82-12: compatible with cuda/12.6.2
        default_modules = ["gcc/13.2.0", "python/3.12.1"]

        # Filter out any gcc or python modules from the user list: the pinned
        # defaults above are always used.
        user_modules = []
        for mod in modules:
            if mod.startswith("gcc") or mod.startswith("python"):
                continue
            user_modules.append(mod)

        self.modules = default_modules + user_modules

        if self.use_gpu is True:
            # Check if the user provided specific cuda/cudnn versions
            has_cuda = any("cuda" in mod for mod in self.modules)
            has_cudnn = any("cudnn" in mod for mod in self.modules)
            if not has_cuda:
                self.modules.append("cuda/12.6.2")
            if not has_cudnn:
                self.modules.append("cudnn/9.2.0.82-12")

        # --- Argument validation ---
        self._validate_arguments()

        # Check if this code is running on a cluster (only relevant for Slurm, usually)
        self.cluster = os.path.exists("/usr/bin/sbatch")

        # Initialize backend
        if self.backend_type == "local":
            self.backend = LocalBackend(self)
        elif self.server_run:
            if self.backend_type == "desi":
                self.backend = DesiBackend(self)
            elif self.backend_type == "slurm":
                self.backend = SlurmBackend(self)
            else:
                raise ValueError(
                    f"Unknown backend type: {self.backend_type}. This should not happen."
                )
        elif self.cluster:  # Running ON a cluster (Slurm)
            self.backend = SlurmBackend(self)
        else:
            self.backend = LocalBackend(self)

        # Note: Intelligent dependency detection is done in __call__, where the
        # function to analyze is available. Editable package source paths are no
        # longer auto-added blindly, to avoid uploading unwanted files or breaking
        # with complex setups.
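
    # A sketch of the .env file read by load_dotenv() above (keys as documented
    # in __init__; the values here are placeholders):
    #
    #     CURNAGL_USERNAME=jdoe
    #     CURNAGL_PASSWORD=********
    #     DESI_USERNAME=jdoe
    #     DESI_PASSWORD=********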
325
+ def __setup_logger(self):
326
+ """Setup the logger"""
327
+ # Create the log directory if not exists
328
+ log_dir = os.path.dirname(self.log_file)
329
+ if log_dir and not os.path.exists(log_dir):
330
+ os.makedirs(log_dir)
331
+
332
+ # Configure the logger
333
+ self.logger = logging.getLogger(f"RayLauncher-{self.project_name}")
334
+ self.logger.setLevel(logging.INFO)
335
+
336
+ # Remove existing handlers to avoid duplication if instantiated multiple times
337
+ if self.logger.hasHandlers():
338
+ self.logger.handlers.clear()
339
+
340
+ # File handler (constantly rewritten)
341
+ file_handler = logging.FileHandler(self.log_file, mode="w")
342
+ file_handler.setLevel(logging.INFO)
343
+ file_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
344
+ file_handler.setFormatter(file_formatter)
345
+ self.logger.addHandler(file_handler)
346
+
347
+ # Console handler (only warnings and errors)
348
+ console_handler = logging.StreamHandler()
349
+ console_handler.setLevel(logging.WARNING)
350
+ console_formatter = logging.Formatter("%(levelname)s: %(message)s")
351
+ console_handler.setFormatter(console_formatter)
352
+ self.logger.addHandler(console_handler)
353
+
354
+ def _detect_local_python_version(self) -> str:
355
+ """Detect local Python version from .python-version file or sys.version_info
356
+
357
+ Returns:
358
+ str: Python version in format "X.Y.Z" (e.g., "3.12.1")
359
+ """
360
+ # Try to read from .python-version file first
361
+ python_version_file = os.path.join(self.pwd_path, ".python-version")
362
+ if os.path.exists(python_version_file):
363
+ with open(python_version_file, "r") as f:
364
+ version_str = f.read().strip()
365
+ # Validate format (should be X.Y or X.Y.Z)
366
+ import re
367
+
368
+ if re.match(r"^\d+\.\d+(\.\d+)?$", version_str):
369
+ # If only X.Y, add .0 for micro version
370
+ if version_str.count(".") == 1:
371
+ version_str = f"{version_str}.0"
372
+ self.logger.info(
373
+ f"Detected Python version from .python-version: {version_str}"
374
+ )
375
+ return version_str
376
+
377
+ # Fallback to sys.version_info
378
+ version_str = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
379
+ self.logger.info(
380
+ f"Detected Python version from sys.version_info: {version_str}"
381
+ )
382
+ return version_str
383
+
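
    # A sketch of the optional .python-version file read above (a single line,
    # "X.Y" or "X.Y.Z"; "3.12" would be normalized to "3.12.0"):
    #
    #     3.12.1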

    def _validate_arguments(self):
        """Validate arguments and warn about inconsistencies"""
        # project_name is required for project-based organization on the cluster
        if self.project_name is None:
            raise ValueError(
                "project_name cannot be None. A project name is required for cluster execution."
            )

        if self.cluster_type == "desi":
            # server_ssh is already set correctly in __init__

            if self.node_nbr > 1:
                self.logger.warning(
                    f"Warning: Desi cluster only supports single-node execution. node_nbr={self.node_nbr} will be ignored (effectively 1)."
                )

            # Only warn if modules were explicitly passed by the user (not just
            # the defaults). GPU modules (cuda/cudnn) are added automatically if
            # use_gpu=True, so they don't count as user-provided.
            user_provided_modules = [
                m
                for m in self.modules
                if not (
                    m.startswith("gcc")
                    or m.startswith("python")
                    or m.startswith("cuda")
                    or m.startswith("cudnn")
                )
            ]
            if "modules" in self._explicit_params and user_provided_modules:
                self.logger.warning(
                    "Warning: Module loading is not supported on Desi (no module system). The modules list will be ignored."
                )

            if "memory" in self._explicit_params and self.memory != 64:  # 64 is the default
                self.logger.warning(
                    "Warning: Memory allocation is not enforced on Desi (shared resource)."
                )

    def _handle_signal(self, signum, frame):
        """Handle interruption signals (SIGINT, SIGTERM) to clean up resources"""
        sig_name = signal.Signals(signum).name
        self.logger.warning(f"Signal {sig_name} received. Cleaning up resources...")
        print(f"\nInterruption received ({sig_name}). Canceling job and cleaning up...")

        self.cancel()
        sys.exit(1)

    def cancel(self, target: Any = None):
        """
        Cancel a running job.

        Args:
            target (Any, optional): The job to cancel. Can be:
                - None: Cancels the last job run by this launcher instance.
                - str: A specific job ID.
                - FunctionReturn: A specific FunctionReturn object.
        """
        if hasattr(self, "backend"):
            job_id = None

            # Determine job_id based on target
            if target is None:
                # Fall back to the last job
                if hasattr(self.backend, "job_id") and self.backend.job_id:
                    job_id = self.backend.job_id
                elif hasattr(self, "job_id") and self.job_id:
                    job_id = self.job_id
            elif isinstance(target, str):
                job_id = target
            elif isinstance(target, self.FunctionReturn):
                job_id = target.job_id

            if job_id:
                self.backend.cancel(job_id)
            else:
                self.logger.warning("No job ID found to cancel.")
        else:
            self.logger.warning("No backend initialized, cannot cancel.")
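
    # The accepted cancellation targets, as a sketch (the job ID is illustrative):
    #
    #     launcher.cancel()            # last job run by this launcher instance
    #     launcher.cancel("12345678")  # a specific job ID
    #     launcher.cancel(handle)      # a FunctionReturn from asynchronous mode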
465
+ def __call__(
466
+ self,
467
+ func: Callable,
468
+ args: dict = None,
469
+ cancel_old_jobs: bool = True,
470
+ serialize: bool = True,
471
+ ) -> Any:
472
+ """Launch the job and return the result
473
+
474
+ Args:
475
+ func (Callable): Function to execute. This function should not be remote but can use ray ressources.
476
+ args (dict, optional): Arguments of the function. Defaults to None (empty dict).
477
+ cancel_old_jobs (bool, optional): Cancel the old jobs. Defaults to True.
478
+ serialize (bool, optional): Serialize the function and the arguments. This should be set to False if the function is automatically called by the server. Defaults to True.
479
+
480
+ Returns:
481
+ Any: Result of the function
482
+ """
483
+ if args is None:
484
+ args = {}
485
+
486
+ # Intelligent dependency detection from function source file
487
+ if self.server_run:
488
+ try:
489
+ from slurmray.scanner import ProjectScanner
490
+
491
+ scanner = ProjectScanner(self.pwd_path, self.logger)
492
+ detected_dependencies = scanner.detect_dependencies_from_function(func)
493
+
494
+ added_count = 0
495
+ for dep in detected_dependencies:
496
+ # Skip invalid paths (empty, current directory, etc.)
497
+ if (
498
+ not dep
499
+ or dep == "."
500
+ or dep == ".."
501
+ or dep.startswith("./")
502
+ or dep.startswith("../")
503
+ ):
504
+ continue
505
+
506
+ # Skip paths that are outside project or in ignored directories
507
+ dep_abs = os.path.abspath(os.path.join(self.pwd_path, dep))
508
+ if not dep_abs.startswith(os.path.abspath(self.pwd_path)):
509
+ continue
510
+
511
+ # Check if it's a valid file or directory
512
+ if not os.path.exists(dep_abs):
513
+ continue
514
+
515
+ # Check if dependency is already covered by existing files/dirs
516
+ # E.g. if 'src' is in files, 'src/module.py' is covered
517
+ is_covered = False
518
+ for existing in self.files:
519
+ if dep == existing or (dep.startswith(existing + os.sep)):
520
+ is_covered = True
521
+ break
522
+
523
+ if not is_covered:
524
+ self.files.append(dep)
525
+ added_count += 1
526
+
527
+ if added_count > 0:
528
+ self.logger.info(
529
+ f"Auto-added {added_count} local dependencies to upload list (from function imports)."
530
+ )
531
+
532
+ # Display warnings for dynamic imports
533
+ if scanner.dynamic_imports_warnings:
534
+ print("\n" + "=" * 60)
535
+ print("⚠️ WARNING: Dynamic imports or file operations detected ⚠️")
536
+ print("=" * 60)
537
+ print(
538
+ "The following lines might require files that cannot be auto-detected."
539
+ )
540
+ print(
541
+ "Please verify if you need to add them manually to 'files=[...]':"
542
+ )
543
+ for warning in scanner.dynamic_imports_warnings:
544
+ print(f" - {warning}")
545
+ print("=" * 60 + "\n")
546
+
547
+ # Also log them
548
+ for warning in scanner.dynamic_imports_warnings:
549
+ self.logger.warning(f"Dynamic import warning: {warning}")
550
+
551
+ except Exception as e:
552
+ self.logger.warning(f"Dependency detection from function failed: {e}")
553
+
554
+ # Register signal handlers
555
+ original_sigint = signal.getsignal(signal.SIGINT)
556
+ original_sigterm = signal.getsignal(signal.SIGTERM)
557
+ signal.signal(signal.SIGINT, self._handle_signal)
558
+ signal.signal(signal.SIGTERM, self._handle_signal)
559
+
560
+ try:
561
+ # Serialize function and arguments
562
+ if serialize:
563
+ self.__serialize_func_and_args(func, args)
564
+
565
+ # Execute
566
+ if self.asynchronous:
567
+ job_id = self.backend.run(cancel_old_jobs=cancel_old_jobs, wait=False)
568
+ return self.FunctionReturn(self, job_id)
569
+ else:
570
+ return self.backend.run(cancel_old_jobs=cancel_old_jobs, wait=True)
571
+ finally:
572
+ # Restore original signal handlers
573
+ # In asynchronous mode, we might not want to restore if we return immediately
574
+ # but usually we do because the launcher __call__ returns.
575
+ signal.signal(signal.SIGINT, original_sigint)
576
+ signal.signal(signal.SIGTERM, original_sigterm)
577
+
578
+ def _dedent_source(self, source: str) -> str:
579
+ """Dedent source code"""
580
+ lines = source.split("\n")
581
+ if not lines:
582
+ return source
583
+
584
+ first_line = lines[0]
585
+ # Skip empty lines at the start
586
+ first_non_empty = next((i for i, line in enumerate(lines) if line.strip()), 0)
587
+
588
+ if first_non_empty < len(lines):
589
+ first_line = lines[first_non_empty]
590
+ indent = len(first_line) - len(first_line.lstrip())
591
+
592
+ # Deduplicate indentation, but preserve empty lines
593
+ deduplicated_lines = []
594
+ for line in lines:
595
+ if line.strip(): # Non-empty line
596
+ if len(line) >= indent:
597
+ deduplicated_lines.append(line[indent:])
598
+ else:
599
+ deduplicated_lines.append(line)
600
+ else: # Empty line
601
+ deduplicated_lines.append("")
602
+ return "\n".join(deduplicated_lines)
603
+
604
+ return source
605
+
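
    # A worked example of the dedent above (sketch):
    #
    #     src = "    def g():\n        return 1\n"
    #     # The first non-empty line has indent 4, so every non-empty line loses
    #     # its first 4 characters: "def g():\n    return 1\n"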

    def _resolve_dependencies(
        self, func: Callable
    ) -> Tuple[List[str], List[str], bool]:
        """
        Analyze function dependencies and resolve them recursively.

        Returns: (imports_to_add, source_code_to_add, is_safe)
        """
        imports = set()
        # Python functions are late-binding, so the order of definitions does not
        # matter strictly as long as they are defined before the call; a
        # name → source map is enough to deduplicate dependency sources.
        sources_map = {}  # name -> source

        queue = [func]
        processed_funcs = set()  # code objects (or function objects as fallback)

        while queue:
            current_func = queue.pop(0)

            # Use the code object for identity if possible, else the function object
            func_id = current_func
            if hasattr(current_func, "__code__"):
                func_id = current_func.__code__

            if func_id in processed_funcs:
                continue
            processed_funcs.add(func_id)

            # Closures are still hard. Reject them.
            if hasattr(current_func, "__code__") and current_func.__code__.co_freevars:
                self.logger.debug(
                    f"Function {current_func.__name__} uses closures. Unsafe."
                )
                return [], [], False

            builtin_names = set(dir(builtins))
            global_names = set()

            # Find the global names used
            try:
                for instruction in dis.get_instructions(current_func):
                    if instruction.opname == "LOAD_GLOBAL":
                        if instruction.argval not in builtin_names:
                            global_names.add(instruction.argval)
            except Exception as e:
                self.logger.debug(f"Bytecode analysis failed for {current_func}: {e}")
                # Fail safe: whether this is the main function or a dependency,
                # treat the whole function as unsafe for source extraction.
                return [], [], False

            # Resolve each name
            for name in global_names:
                # If the name is not in globals, it might be a problem
                if (
                    not hasattr(current_func, "__globals__")
                    or name not in current_func.__globals__
                ):
                    # Maybe it's a recursive self-reference?
                    if (
                        hasattr(current_func, "__name__")
                        and name == current_func.__name__
                    ):
                        continue
                    self.logger.debug(f"Global '{name}' not found in function globals.")
                    return [], [], False

                obj = current_func.__globals__[name]

                # Case 1: Module
                if inspect.ismodule(obj):
                    if obj.__name__ == name:
                        imports.add(f"import {name}")
                    else:
                        imports.add(f"import {obj.__name__} as {name}")

                # Case 2: Function (user defined)
                elif inspect.isfunction(obj):
                    if obj not in queue and obj.__code__ not in processed_funcs:
                        try:
                            src = inspect.getsource(obj)
                            sources_map[name] = self._dedent_source(src)
                            queue.append(obj)
                        except Exception:
                            self.logger.debug(
                                f"Could not get source for function '{name}'"
                            )
                            return [], [], False

                # Case 3: Class
                elif inspect.isclass(obj):
                    # We don't recurse into classes yet, just add the source
                    if name not in sources_map:
                        try:
                            src = inspect.getsource(obj)
                            sources_map[name] = self._dedent_source(src)
                        except Exception:
                            self.logger.debug(
                                f"Could not get source for class '{name}'"
                            )
                            return [], [], False

                # Case 4: Builtin function/method
                elif inspect.isbuiltin(obj):
                    mod = inspect.getmodule(obj)
                    if mod:
                        if obj.__name__ == name:
                            imports.add(f"from {mod.__name__} import {name}")
                        else:
                            imports.add(
                                f"from {mod.__name__} import {obj.__name__} as {name}"
                            )
                    else:
                        return [], [], False

                else:
                    self.logger.debug(
                        f"Unsupported global object type: {type(obj)} for '{name}'"
                    )
                    return [], [], False

        # Sort imports for consistency
        sorted_imports = sorted(imports)
        sorted_sources = list(sources_map.values())

        return sorted_imports, sorted_sources, True
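
    # A standalone sketch of the LOAD_GLOBAL scan used above (same technique,
    # shown outside the class):
    #
    #     import dis, builtins
    #
    #     def f():
    #         return os.getcwd() + str(len("x"))
    #
    #     names = {
    #         i.argval for i in dis.get_instructions(f) if i.opname == "LOAD_GLOBAL"
    #     }
    #     names -= set(dir(builtins))
    #     # names == {"os"}; _resolve_dependencies turns it into "import os"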

    def __serialize_func_and_args(self, func: Callable = None, args: dict = None):
        """Serialize the function and the arguments

        This method uses a simplified serialization strategy:
        - **Always tries dill pickle first** (better performance; handles closures and complex objects)
        - **Falls back to source extraction** only if dill pickle fails
        - With pyenv, the Python versions are identical, so dill pickle should always work

        **Fallback to source extraction happens when:**
        - Python versions are incompatible (rare with pyenv)
        - The function is not serializable by dill (built-ins, C functions, etc.)
        - Other serialization errors occur

        **Limitations of source-based serialization:**
        - Functions with closures: only the function body is captured, not the captured
          variables. The function may fail at runtime if it depends on closure variables.
        - Functions defined in interactive shells or dynamically compiled code may not
          have accessible source.
        - Lambda functions defined inline may have limited source information.

        Args:
            func (Callable, optional): Function to serialize. Defaults to None.
            args (dict, optional): Arguments of the function. Defaults to None.
        """
        self.logger.info("Serializing function and arguments...")

        source_extracted = False
        source_method = None
        dill_pickle_used = False
        serialization_method = "dill_pickle"  # Default method

        # Step 1: Always try dill pickle first
        self.logger.info("Attempting dill pickle serialization...")
        try:
            # Try to pickle the function directly with dill
            func_pickle_path = os.path.join(self.project_path, "func.pkl")
            with open(func_pickle_path, "wb") as f:
                dill.dump(func, f)

            # If successful, use pickle
            dill_pickle_used = True
            serialization_method = "dill_pickle"
            self.logger.info("✅ Successfully serialized function with dill pickle.")

            # Clean up any stale source files since we're using dill pickle
            source_path = os.path.join(self.project_path, "func_source.py")
            name_path = os.path.join(self.project_path, "func_name.txt")
            if os.path.exists(source_path):
                os.remove(source_path)
            if os.path.exists(name_path):
                os.remove(name_path)

        except Exception as e:
            # Dill pickle failed: analyze why and fall back to source extraction
            error_type = type(e).__name__
            error_message = str(e)

            # Determine the likely reason for the failure
            if "opcode" in error_message.lower() or "bytecode" in error_message.lower():
                reason_explanation = (
                    "Python version incompatibility (bytecode mismatch)"
                )
            elif "cannot pickle" in error_message.lower():
                reason_explanation = (
                    "Function not serializable by dill (built-in, C function, etc.)"
                )
            elif "recursion" in error_message.lower():
                reason_explanation = "Recursion limit reached during serialization"
            else:
                reason_explanation = f"Serialization error: {error_type}"

            self.logger.error(
                f"❌ dill pickle serialization failed: {error_type}: {error_message}"
            )
            self.logger.warning(
                f"⚠️ Falling back to source extraction. Reason: {reason_explanation}"
            )

            dill_pickle_used = False
            serialization_method = "source_extraction"

            # Continue with source extraction below

        # Step 2: Try source extraction if dill pickle failed
        if not dill_pickle_used:
            self.logger.info("📝 Using source extraction fallback (dill pickle failed)")

            # Only analyze dependencies if we need source extraction
            extra_imports, extra_sources, is_safe = self._resolve_dependencies(func)

            if is_safe:
                # Method 1: Try inspect.getsource() (standard library, most common)
                try:
                    source = inspect.getsource(func)
                    source_method = "inspect.getsource"

                    # Combine the parts:
                    # 1. Imports
                    # 2. Dependency sources
                    # 3. Main function source
                    parts = []
                    if extra_imports:
                        parts.extend(extra_imports)
                        parts.append("")  # newline

                    if extra_sources:
                        parts.extend(extra_sources)
                        parts.append("")  # newline

                    # Dedent the main source
                    source = self._dedent_source(source)
                    parts.append(source)

                    source = "\n".join(parts)
                    source_extracted = True

                except (OSError, TypeError) as e:
                    self.logger.debug(f"inspect.getsource() failed: {e}")
                except Exception as e:
                    self.logger.debug(f"inspect.getsource() unexpected error: {e}")

                # Method 2: Try dill.source.getsource().
                # dill does not support our dependency injection easily: if inspect
                # fails and we have resolved dependencies, we cannot reliably combine
                # raw dill source with the injected imports/sources (indentation and
                # structure may differ). So only use dill as a backup for simple
                # functions with no extra dependencies.
                if not source_extracted:
                    if not extra_imports and not extra_sources:
                        try:
                            if hasattr(dill, "source") and hasattr(
                                dill.source, "getsource"
                            ):
                                source = dill.source.getsource(func)
                                source_method = "dill.source.getsource"
                                source_extracted = True
                        except Exception as e:
                            self.logger.debug(f"dill.source.getsource() failed: {e}")
            else:
                self.logger.warning(
                    "⚠️ Function unsafe for source extraction (unresolvable globals/closures). Will use dill pickle anyway."
                )

            # Process and save the source if extracted
            if source_extracted:
                try:
                    # The source is already prepared and dedented above if using
                    # inspect. If using dill (fallback path), we might need to dedent.
                    if source_method == "dill.source.getsource":
                        source = self._dedent_source(source)

                    # Save the source code
                    with open(os.path.join(self.project_path, "func_source.py"), "w") as f:
                        f.write(source)

                    # Save the function name for loading
                    with open(os.path.join(self.project_path, "func_name.txt"), "w") as f:
                        f.write(func.__name__)

                    self.logger.info(
                        f"✅ Function source extracted successfully using {source_method}."
                    )
                    serialization_method = "source_extraction"

                except Exception as e:
                    self.logger.warning(f"Failed to process/save function source: {e}")
                    source_extracted = False
                    # Fall back to dill pickle if saving the source failed
                    serialization_method = "dill_pickle"

        # If source extraction failed or was skipped, ensure no stale source files exist
        if not source_extracted:
            source_path = os.path.join(self.project_path, "func_source.py")
            if os.path.exists(source_path):
                os.remove(source_path)

            # Always create func_name.txt even if source extraction failed;
            # this is needed for Desi backend queue management
            func_name_path = os.path.join(self.project_path, "func_name.txt")
            if not os.path.exists(func_name_path):
                try:
                    with open(func_name_path, "w") as f:
                        f.write(func.__name__)
                    self.logger.debug(
                        f"Created func_name.txt with function name: {func.__name__}"
                    )
                except Exception as e:
                    self.logger.warning(f"Failed to create func_name.txt: {e}")

            # If source extraction was attempted but failed, log it
            if not dill_pickle_used:
                self.logger.warning(
                    "⚠️ Source extraction failed. Using dill pickle as final fallback."
                )
                # Ensure we still pickle the function even if source extraction failed
                serialization_method = "dill_pickle"

        # Always pickle the function (used by the dill pickle strategy or as a
        # fallback). Create func.pkl if it was not already created in Step 1.
        if not dill_pickle_used:
            try:
                func_pickle_path = os.path.join(self.project_path, "func.pkl")
                with open(func_pickle_path, "wb") as f:
                    dill.dump(func, f)
                self.logger.debug(
                    "Created func.pkl as fallback (dill pickle or source extraction fallback)"
                )
            except Exception as e:
                # Only raise if we have no other option (both dill pickle and source extraction failed)
                if not source_extracted:
                    self.logger.error(
                        f"❌ Critical: Failed to pickle function even as fallback: {e}"
                    )
                    raise
                else:
                    # If source extraction succeeded, warn but don't fail
                    self.logger.warning(
                        f"⚠️ Failed to create func.pkl fallback (source extraction will be used): {e}"
                    )

        # Save the serialization method indicator
        method_file = os.path.join(self.project_path, "serialization_method.txt")
        with open(method_file, "w") as f:
            f.write(f"{serialization_method}\n")

        # Pickle the arguments
        if args is None:
            args = {}
        with open(os.path.join(self.project_path, "args.pkl"), "wb") as f:
            dill.dump(args, f)
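
# The serializer above leaves its artifacts in .slogs/<project_name>/ for the
# remote side. A sketch of a loader for the dill-pickle path (hypothetical code,
# not the backends' actual loader):
#
#     import dill
#
#     with open(".slogs/example/func.pkl", "rb") as f:
#         func = dill.load(f)
#     with open(".slogs/example/args.pkl", "rb") as f:
#         args = dill.load(f)
#     # serialization_method.txt records "dill_pickle" or "source_extraction";
#     # func_source.py / func_name.txt exist only on the source-extraction path.
#     result = func(**args)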
978
+ # ---------------------------------------------------------------------------- #
979
+ # EXAMPLE OF EXECUTION #
980
+ # ---------------------------------------------------------------------------- #
981
+ if __name__ == "__main__":
982
+ import ray
983
+ import torch
984
+
985
+ def function_inside_function():
986
+ # Check if file exists before trying to read it, as paths might differ
987
+ if os.path.exists("documentation/RayLauncher.html"):
988
+ with open("documentation/RayLauncher.html", "r") as f:
989
+ return f.read()[0:10]
990
+ return "DocNotFound"
991
+
992
+ def example_func(x):
993
+ import time # Encapsulated imports works too !
994
+ print("Waiting for 60 seconds so that you can check the dashboard...")
995
+ time.sleep(60)
996
+ print("Done waiting !")
997
+ result = (
998
+ ray.cluster_resources(),
999
+ f"GPU is available : {torch.cuda.is_available()}",
1000
+ x + 1,
1001
+ function_inside_function(),
1002
+ )
1003
+ return result
1004
+
1005
+ cluster = RayLauncher(
1006
+ project_name="example", # Name of the project (will create a directory with this name in the current directory)
1007
+ files=["documentation/RayLauncher.html"], # List of files to push to the server
1008
+ use_gpu=True, # If you need GPU, you can set it to True
1009
+ runtime_env={
1010
+ "env_vars": {"NCCL_SOCKET_IFNAME": "eno1"}
1011
+ }, # Example of environment variable
1012
+ server_run=True, # To run the code on the server and not locally
1013
+ cluster="desi", # Use Desi backend (credentials loaded from .env: DESI_USERNAME and DESI_PASSWORD)
1014
+ force_reinstall_venv=False, # Force reinstall venv to test with Python 3.12.1
1015
+ retention_days=1, # Retain files and venv for 1 day before cleanup
1016
+ )
1017
+
1018
+ result = cluster(example_func, args={"x": 5}) # Execute function with arguments
1019
+ print(result)