slurmray 6.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of slurmray might be problematic. Click here for more details.
- slurmray/RayLauncher.py +1019 -0
- slurmray/__init__.py +0 -0
- slurmray/__main__.py +5 -0
- slurmray/assets/cleanup_old_projects.py +171 -0
- slurmray/assets/sbatch_template.sh +67 -0
- slurmray/assets/slurmray_server.sh +145 -0
- slurmray/assets/slurmray_server_template.py +28 -0
- slurmray/assets/spython_template.py +113 -0
- slurmray/backend/__init__.py +0 -0
- slurmray/backend/base.py +1040 -0
- slurmray/backend/desi.py +856 -0
- slurmray/backend/local.py +124 -0
- slurmray/backend/remote.py +191 -0
- slurmray/backend/slurm.py +1234 -0
- slurmray/cli.py +904 -0
- slurmray/detection.py +1 -0
- slurmray/file_sync.py +276 -0
- slurmray/scanner.py +441 -0
- slurmray/utils.py +359 -0
- slurmray-6.0.4.dist-info/LICENSE +201 -0
- slurmray-6.0.4.dist-info/METADATA +85 -0
- slurmray-6.0.4.dist-info/RECORD +24 -0
- slurmray-6.0.4.dist-info/WHEEL +4 -0
- slurmray-6.0.4.dist-info/entry_points.txt +3 -0
slurmray/__init__.py
ADDED
|
File without changes
|
slurmray/assets/cleanup_old_projects.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Cleanup script for old SlurmRay projects on cluster.
|
|
4
|
+
|
|
5
|
+
This script is designed to be run as a cron job to automatically remove
|
|
6
|
+
old project files and venv that have exceeded their retention period.
|
|
7
|
+
|
|
8
|
+
Usage:
|
|
9
|
+
python cleanup_old_projects.py [--cluster-type slurm|desi] [--base-path /path/to/users]
|
|
10
|
+
|
|
11
|
+
The script will:
|
|
12
|
+
- Scan user directories for slurmray-server/{project_name}/ directories
|
|
13
|
+
- Read .retention_timestamp and .retention_days files
|
|
14
|
+
- Delete projects that have exceeded their retention period
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import os
|
|
18
|
+
import sys
|
|
19
|
+
import time
|
|
20
|
+
import argparse
|
|
21
|
+
import logging
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def read_retention_files(project_dir):
    """
    Read retention metadata (timestamp and retention days) for a project.

    Args:
        project_dir: Path to the project directory containing the
            ``.retention_timestamp`` and ``.retention_days`` marker files.

    Returns:
        tuple: ``(timestamp, retention_days)`` as ints, or ``(None, None)``
        when either file is missing or cannot be parsed as an integer.
    """
    marker_paths = [
        os.path.join(project_dir, marker)
        for marker in (".retention_timestamp", ".retention_days")
    ]

    # Both markers must be present; a project without them is not managed.
    if not all(os.path.exists(path) for path in marker_paths):
        return None, None

    try:
        values = []
        for path in marker_paths:
            with open(path, 'r') as handle:
                values.append(int(handle.read().strip()))
        return values[0], values[1]
    except (ValueError, IOError):
        # Unreadable or non-integer content is treated the same as "absent".
        return None, None
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def cleanup_user_projects(username, base_path, cluster_type, logger):
    """
    Cleanup old projects for a specific user.

    Scans ``{base_path}/{username}/slurmray-server/`` for project
    directories, reads each project's retention markers via
    ``read_retention_files``, and deletes projects whose age exceeds
    their retention period. Projects without retention markers are
    never deleted.

    Args:
        username: Username
        base_path: Base path to user directories (e.g., /users or /home)
        cluster_type: Type of cluster ('slurm' or 'desi'); currently unused
            by this function, kept for interface compatibility with callers.
        logger: Logger instance

    Raises:
        Exception: Re-raises any deletion failure (fail-fast policy).
    """
    # Hoisted out of the per-project loop: the original re-imported shutil
    # on every deletion.
    import shutil

    slurmray_base = os.path.join(base_path, username, "slurmray-server")
    # A single isdir() covers "user dir missing", "slurmray dir missing"
    # and "exists but is not a directory" (the original checked exists()
    # and isdir() on the same path separately).
    if not os.path.isdir(slurmray_base):
        return

    for item in os.listdir(slurmray_base):
        project_dir = os.path.join(slurmray_base, item)
        if not os.path.isdir(project_dir):
            continue

        # Read retention files
        timestamp, retention_days = read_retention_files(project_dir)
        if timestamp is None or retention_days is None:
            # No retention files, skip this project
            continue

        # Age in days since the recorded timestamp.
        current_time = int(time.time())
        age_days = (current_time - timestamp) / (24 * 3600)

        # Check if project should be deleted
        if age_days > retention_days:
            logger.info(f"Deleting project {username}/{item}: age={age_days:.1f} days, retention={retention_days} days")
            try:
                shutil.rmtree(project_dir)
                logger.info(f"Successfully deleted project {username}/{item}")
            except Exception as e:
                logger.error(f"Failed to delete project {username}/{item}: {e}")
                raise  # Fail-fast on errors
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def main():
    """Main entry point for cleanup script.

    Parses CLI arguments, configures logging to both a file and stdout,
    resolves the base path for the selected cluster type, and runs the
    cleanup for every user directory found under it.

    Exits with status 1 when the base path does not exist; any per-user
    error is logged and re-raised (fail-fast policy).
    """
    parser = argparse.ArgumentParser(description="Cleanup old SlurmRay projects on cluster")
    parser.add_argument(
        "--cluster-type",
        choices=["slurm", "desi"],
        default="slurm",
        help="Type of cluster (default: slurm)"
    )
    parser.add_argument(
        "--base-path",
        help="Base path to user directories (default: /users for slurm, /home for desi)"
    )
    parser.add_argument(
        "--log-file",
        help="Log file path (default: /tmp/slurmray_cleanup.log)"
    )

    args = parser.parse_args()

    # Setup logging: mirror everything to a log file and to stdout.
    log_file = args.log_file or "/tmp/slurmray_cleanup.log"
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler(sys.stdout)
        ]
    )
    logger = logging.getLogger(__name__)

    # Determine base path: explicit flag wins, otherwise per-cluster default.
    if args.base_path:
        base_path = args.base_path
    elif args.cluster_type == "slurm":
        base_path = "/users"
    else:  # desi
        base_path = "/home"

    logger.info(f"Starting cleanup for cluster type: {args.cluster_type}, base path: {base_path}")

    if not os.path.exists(base_path):
        logger.error(f"Base path does not exist: {base_path}")
        sys.exit(1)

    # Scan all user directories.
    # (Removed an unused `cleaned_count` local that was initialized but
    # never incremented or read.)
    for username in os.listdir(base_path):
        user_path = os.path.join(base_path, username)
        if not os.path.isdir(user_path):
            continue

        try:
            cleanup_user_projects(username, base_path, args.cluster_type, logger)
        except Exception as e:
            logger.error(f"Error processing user {username}: {e}")
            raise  # Fail-fast on errors

    logger.info("Cleanup completed")
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
# Allow running this file directly as a script (intended for cron usage).
if __name__ == "__main__":
    main()
|
|
171
|
+
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
#!/bin/bash

# THIS FILE IS GENERATED BY AUTOMATION SCRIPT! PLEASE REFER TO ORIGINAL SCRIPT!
# THIS FILE IS A TEMPLATE AND IT SHOULD NOT BE DEPLOYED TO PRODUCTION!
#
# Template for the SLURM batch script that boots a multi-node Ray cluster
# and then runs the user command. `{{...}}` tokens are substituted by the
# generator before submission; #SBATCH directives must stay before the
# first executable command.

#SBATCH --partition={{PARTITION_NAME}}
#SBATCH --mem {{MEMORY}}G
#SBATCH --job-name={{JOB_NAME}}
#SBATCH --output={{JOB_NAME}}.log
#SBATCH --time {{RUNNING_TIME}}
{{GIVEN_NODE}}

### This script works for any number of nodes, Ray will find and manage all resources
#SBATCH --nodes={{NUM_NODES}}

### Give all resources to a single Ray task, ray can manage the resources internally
#SBATCH --ntasks-per-node=1
{{PARTITION_SPECIFICS}}

# Load modules or your own conda environment here
{{LOAD_ENV}}

################# DO NOT CHANGE THINGS HERE UNLESS YOU KNOW WHAT YOU ARE DOING ###############
# This script is a modification to the implementation suggested by gregSchwartz18 here:
# https://github.com/ray-project/ray/issues/826#issuecomment-522116599

# Per-job password shared by the Ray head and worker processes.
redis_password=$(uuidgen)
export redis_password

nodes=$(scontrol show hostnames $SLURM_JOB_NODELIST) # Getting the node names
nodes_array=($nodes)

# First allocated node hosts the Ray head.
node_1=${nodes_array[0]}
ip=$(srun --nodes=1 --ntasks=1 -w $node_1 hostname --ip-address) # making redis-address

# `hostname --ip-address` can return several space-separated addresses
# (e.g. IPv6 + IPv4); pick the IPv4 one.
# NOTE(review): `>` inside [[ ]] compares strings lexicographically, not
# numerically; if a numeric length check was intended, `-gt` would be
# needed — confirm against the generator's expectations.
if [[ $ip == *" "* ]]; then
  IFS=' ' read -ra ADDR <<<"$ip"
  if [[ ${#ADDR[0]} > 16 ]]; then
    ip=${ADDR[1]}
  else
    ip=${ADDR[0]}
  fi
  echo "We detect space in ip! You are using IPV6 address. We split the IPV4 address as $ip"
fi

port=6379
ip_head=$ip:$port
export ip_head
echo "IP Head: $ip_head"

echo "STARTING HEAD at $node_1"
# srun --nodes=1 --ntasks=1 -w $node_1 start-head.sh $ip $redis_password &
srun --nodes=1 --ntasks=1 -w $node_1 \
  ray start --head --include-dashboard=true --dashboard-host 0.0.0.0 --dashboard-port=8265 --node-ip-address=$ip --port=6379 --redis-password=$redis_password --block &
# Give the head time to come up before attaching workers.
sleep 30

# Start one Ray worker per remaining allocated node.
worker_num=$(($SLURM_JOB_NUM_NODES - 1)) #number of nodes other than the head node
for ((i = 1; i <= $worker_num; i++)); do
  node_i=${nodes_array[$i]}
  echo "STARTING WORKER $i at $node_i"
  srun --nodes=1 --ntasks=1 -w $node_i ray start --address $ip_head --redis-password=$redis_password --block &
  sleep 5
done

##############################################################################################

#### call your code below
{{COMMAND_PLACEHOLDER}} {{COMMAND_SUFFIX}}
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
#!/bin/sh
# Bootstrap script for the slurmray server on the cluster: moves the
# uploaded payload into place, prepares a virtualenv, installs the
# requirements with `uv pip`, then launches slurmray_server.py.
# NOTE(review): `source` below is a bashism; under a strict POSIX /bin/sh
# it would need to be `. .venv/bin/activate` — confirm the target shell.

echo "Installing slurmray server"

# Copy files
mv -t slurmray-server requirements.txt slurmray_server.py
mv -t slurmray-server/.slogs/server func.pkl args.pkl
cd slurmray-server

# Load modules
# Using specific versions for Curnagl compatibility (SLURM 24.05.3)
# gcc/13.2.0: Latest GCC version
# python/3.12.1: Latest Python version on Curnagl
# cuda/12.6.2: Latest CUDA version
# cudnn/9.2.0.82-12: Compatible with cuda/12.6.2
module load gcc/13.2.0 rust python/3.12.1 cuda/12.6.2 cudnn/9.2.0.82-12

# Create venv if it doesn't exist (hash check is done in Python before file upload)
# If venv needs recreation, it has already been removed by Python
# Check for force reinstall flag
if [ -f ".force_reinstall" ]; then
    echo "Force reinstall flag detected: removing existing virtualenv..."
    rm -rf .venv
    rm -f .force_reinstall
fi

if [ ! -d ".venv" ]; then
    echo "Creating virtualenv..."
    python3 -m venv .venv
else
    echo "Using existing virtualenv (requirements unchanged)..."
    VENV_EXISTED=true
fi

source .venv/bin/activate

# Install requirements if file exists and is not empty
if [ -f requirements.txt ]; then
    # Check if requirements.txt is empty (only whitespace)
    if [ -s requirements.txt ]; then
        echo "📥 Installing dependencies from requirements.txt..."

        # Get installed packages once (fast, single command) - create lookup file
        uv pip list --format=freeze 2>/dev/null | sed 's/==/ /' | awk '{print $1" "$2}' > /tmp/installed_packages.txt || touch /tmp/installed_packages.txt

        # Process requirements: filter duplicates and check what needs installation
        INSTALL_ERRORS=0
        SKIPPED_COUNT=0
        > /tmp/to_install.txt # Clear file
        # NOTE(review): /tmp/seen_packages.txt is appended to below but never
        # cleared before this loop; a leftover file from a previous (crashed)
        # run, or another user on a shared /tmp, could cause packages to be
        # skipped — confirm whether it should be truncated here like
        # /tmp/to_install.txt.

        while IFS= read -r line || [ -n "$line" ]; do
            # Skip empty lines and comments
            line=$(echo "$line" | sed 's/^[[:space:]]*//;s/[[:space:]]*$//')
            if [ -z "$line" ] || [ "${line#\#}" != "$line" ]; then
                continue
            fi

            # Extract package name (remove version specifiers and extras)
            pkg_name=$(echo "$line" | sed 's/[<>=!].*//' | sed 's/\[.*\]//' | sed 's/[[:space:]]*//' | tr '[:upper:]' '[:lower:]')
            if [ -z "$pkg_name" ]; then
                continue
            fi

            # Skip duplicates (check if we've already processed this package)
            if grep -qi "^$pkg_name$" /tmp/seen_packages.txt 2>/dev/null; then
                continue
            fi
            echo "$pkg_name" >> /tmp/seen_packages.txt

            # Extract required version if present
            required_version=""
            if echo "$line" | grep -q "=="; then
                required_version=$(echo "$line" | sed 's/.*==\([^;]*\).*/\1/' | sed 's/[[:space:]]*//')
            fi

            # Check if package is already installed with correct version
            installed_version=$(grep -i "^$pkg_name " /tmp/installed_packages.txt 2>/dev/null | awk '{print $2}' | head -1)

            if [ -n "$installed_version" ]; then
                if [ -z "$required_version" ] || [ "$installed_version" = "$required_version" ]; then
                    echo " ⏭️ $pkg_name==$installed_version (already installed)"
                    SKIPPED_COUNT=$((SKIPPED_COUNT + 1))
                    continue
                fi
            fi

            # Package not installed or version mismatch, add to install list
            echo "$line" >> /tmp/to_install.txt
        done < requirements.txt

        # Install packages that need installation
        if [ -s /tmp/to_install.txt ]; then
            > /tmp/install_errors.txt # Track errors
            while IFS= read -r line; do
                pkg_name=$(echo "$line" | sed 's/[<>=!].*//' | sed 's/\[.*\]//' | sed 's/[[:space:]]*//')
                if uv pip install --quiet "$line" >/dev/null 2>&1; then
                    echo " ✅ $pkg_name"
                else
                    echo " ❌ $pkg_name"
                    echo "1" >> /tmp/install_errors.txt
                    # Show error details
                    uv pip install "$line" 2>&1 | grep -E "(error|Error|ERROR|failed|Failed|FAILED)" | head -3 | sed 's/^/ /' || true
                fi
            done < /tmp/to_install.txt
            INSTALL_ERRORS=$(wc -l < /tmp/install_errors.txt 2>/dev/null | tr -d ' ' || echo "0")
            rm -f /tmp/install_errors.txt
        fi

        # Count newly installed packages before cleanup
        NEWLY_INSTALLED=0
        if [ -s /tmp/to_install.txt ]; then
            NEWLY_INSTALLED=$(wc -l < /tmp/to_install.txt 2>/dev/null | tr -d ' ' || echo "0")
        fi

        # Cleanup temp files
        rm -f /tmp/installed_packages.txt /tmp/seen_packages.txt /tmp/to_install.txt

        if [ $INSTALL_ERRORS -eq 0 ]; then
            if [ $SKIPPED_COUNT -gt 0 ]; then
                echo "✅ All dependencies up to date ($SKIPPED_COUNT already installed, $NEWLY_INSTALLED newly installed)"
            else
                echo "✅ All dependencies installed successfully"
            fi
        else
            echo "❌ Failed to install $INSTALL_ERRORS package(s)" >&2
            exit 1
        fi
    else
        # requirements.txt exists but is empty.
        if [ "$VENV_EXISTED" = "true" ]; then
            echo "✅ All dependencies already installed (requirements.txt is empty)"
        else
            echo "⚠️ requirements.txt is empty, skipping dependency installation"
        fi
    fi
else
    echo "⚠️ No requirements.txt found, skipping dependency installation"
fi

# Fix torch bug (https://github.com/pytorch/pytorch/issues/111469)
PYTHON_VERSION=$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')
export LD_LIBRARY_PATH=$HOME/slurmray-server/.venv/lib/python$PYTHON_VERSION/site-packages/nvidia/nvjitlink/lib:$LD_LIBRARY_PATH


# Run server
python -u slurmray_server.py
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
# Template rendered by the generator: `{{...}}` tokens are substituted
# before this file is uploaded, so this file is not valid Python as-is.
from slurmray.RayLauncher import RayLauncher

if __name__ == "__main__":
    # Note: This template creates a RayLauncher instance that runs on the cluster.
    # The function execution is handled by spython.py which loads the serialized function.
    # With the refactored API, we don't need to call cluster() here because:
    # 1. The function is already serialized (func_source.py, func.pkl, args.pkl)
    # 2. The actual execution is done via sbatch -> spython.py
    # The RayLauncher instance here is just used to detect cluster mode and submit the job.

    cluster = RayLauncher(
        project_name="server",
        modules={{MODULES}},
        node_nbr={{NODE_NBR}},
        use_gpu={{USE_GPU}},
        memory={{MEMORY}},
        max_running_time={{MAX_RUNNING_TIME}},
        server_run=False,
        server_ssh=None,
        server_username=None,
    )

    # The job execution is handled by the backend when cluster.run() is called.
    # Since we're on the cluster, the backend will detect cluster=True and submit via sbatch.
    # The actual function execution happens in spython.py.
    # Note: With the refactored API, this would need a dummy function to work.
    # However, this template should be refactored to not need this call at all.
    # For now, we skip execution here as it's handled by the backend infrastructure.
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
# Template rendered by the generator: `{{PROJECT_PATH}}` and `{{LOCAL_MODE}}`
# are substituted before execution, so this file is not valid Python as-is.
# Runs on the cluster: loads the serialized function and arguments, executes
# the function on Ray, and writes the result back as result.pkl.
import ray
import dill
import os
import sys

PROJECT_PATH = {{PROJECT_PATH}}

# Add the project path to the python path
sys.path.append(PROJECT_PATH)

# Add editable package source directories to sys.path
# This handles packages uploaded from editable installs (e.g., Poetry projects)
# Check for src/ directory (common Poetry src/ layout)
# For flat layout, PROJECT_PATH is already in sys.path, so packages at root are importable
src_path = os.path.join(PROJECT_PATH, "src")
if os.path.exists(src_path) and os.path.isdir(src_path):
    if src_path not in sys.path:
        sys.path.insert(0, src_path)  # Insert at beginning for priority

# Suppress Ray FutureWarning about accelerator visible devices
os.environ.setdefault("RAY_ACCEL_ENV_VAR_OVERRIDE_ON_ZERO", "0")

# Start the ray cluster
ray.init({{LOCAL_MODE}})

# Load the function
# Read the serialization method used
serialization_method = "dill_pickle"  # Default
method_file = os.path.join(PROJECT_PATH, "serialization_method.txt")
if os.path.exists(method_file):
    with open(method_file, "r") as f:
        serialization_method = f.read().strip()

# SECURITY NOTE(review): dill.load deserializes uploaded pickle files and can
# execute arbitrary code; this assumes the project directory contents are
# trusted (uploaded by the same user) — confirm.
if serialization_method == "source_extraction":
    # Load from source extraction
    try:
        func_source_path = os.path.join(PROJECT_PATH, "func_source.py")
        func_name_path = os.path.join(PROJECT_PATH, "func_name.txt")

        if not os.path.exists(func_source_path) or not os.path.exists(func_name_path):
            raise FileNotFoundError(
                "Source files missing: func_source.py or func_name.txt not found"
            )

        import importlib.util

        # Import func_source.py as a throwaway module and pull the named function.
        spec = importlib.util.spec_from_file_location("func_module", func_source_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        with open(func_name_path, "r") as f:
            func_name = f.read().strip()

        func = getattr(module, func_name)
        print(f"Loaded function '{func_name}' from source extraction.")
    except Exception as e:
        # Source extraction failed: fall back to the pickled function.
        print(f"❌ Error loading function from source extraction: {e}")
        print("Falling back to dill pickle...")
        import traceback

        traceback.print_exc()
        with open(os.path.join(PROJECT_PATH, "func.pkl"), "rb") as f:
            func = dill.load(f)
        print("✅ Loaded function from dill pickle (fallback).")
else:
    # Load from dill pickle (default method)
    try:
        func_pickle_path = os.path.join(PROJECT_PATH, "func.pkl")
        if not os.path.exists(func_pickle_path):
            raise FileNotFoundError(f"Pickle file not found: {func_pickle_path}")

        with open(func_pickle_path, "rb") as f:
            func = dill.load(f)
        print(f"✅ Loaded function from dill pickle.")
    except Exception as e:
        print(f"❌ Error loading function from dill pickle: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)

# Load the arguments
with open(os.path.join(PROJECT_PATH, "args.pkl"), "rb") as f:
    args = dill.load(f)

# Run the function (arguments are passed as keyword arguments)
try:
    result = func(**args)
except Exception as e:
    print(f"Error executing function: {e}")
    import traceback

    traceback.print_exc()
    sys.exit(1)

# Write the result for the client to download
result_path = os.path.join(PROJECT_PATH, "result.pkl")
try:
    with open(result_path, "wb") as f:
        dill.dump(result, f)
    print(f"Result written to {result_path}")
except Exception as e:
    print(f"Error writing result: {e}")
    import traceback

    traceback.print_exc()
    sys.exit(1)

# Stop ray (best-effort: a shutdown failure should not fail the job)
try:
    ray.shutdown()
except Exception as e:
    print(f"Warning: Error shutting down Ray: {e}")
|
|
File without changes
|