snowflake-ml-python 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/platform_capabilities.py +36 -0
- snowflake/ml/_internal/utils/url.py +42 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +67 -2
- snowflake/ml/data/data_connector.py +103 -1
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +8 -2
- snowflake/ml/experiment/callback/__init__.py +0 -0
- snowflake/ml/experiment/callback/keras.py +25 -2
- snowflake/ml/experiment/callback/lightgbm.py +27 -2
- snowflake/ml/experiment/callback/xgboost.py +25 -2
- snowflake/ml/experiment/experiment_tracking.py +93 -3
- snowflake/ml/experiment/utils.py +6 -0
- snowflake/ml/feature_store/feature_view.py +34 -24
- snowflake/ml/jobs/_interop/protocols.py +3 -0
- snowflake/ml/jobs/_utils/constants.py +1 -0
- snowflake/ml/jobs/_utils/payload_utils.py +354 -356
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +95 -8
- snowflake/ml/jobs/_utils/scripts/start_mlruntime.sh +92 -0
- snowflake/ml/jobs/_utils/scripts/startup.sh +112 -0
- snowflake/ml/jobs/_utils/spec_utils.py +1 -445
- snowflake/ml/jobs/_utils/stage_utils.py +22 -1
- snowflake/ml/jobs/_utils/types.py +14 -7
- snowflake/ml/jobs/job.py +2 -8
- snowflake/ml/jobs/manager.py +57 -135
- snowflake/ml/lineage/lineage_node.py +1 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
- snowflake/ml/model/_client/model/model_version_impl.py +130 -14
- snowflake/ml/model/_client/ops/deployment_step.py +36 -0
- snowflake/ml/model/_client/ops/model_ops.py +93 -8
- snowflake/ml/model/_client/ops/service_ops.py +32 -52
- snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +12 -4
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -0
- snowflake/ml/model/_client/sql/model_version.py +30 -6
- snowflake/ml/model/_client/sql/service.py +94 -5
- snowflake/ml/model/_model_composer/model_composer.py +1 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
- snowflake/ml/model/_model_composer/model_method/model_method.py +61 -2
- snowflake/ml/model/_packager/model_handler.py +8 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
- snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py} +203 -76
- snowflake/ml/model/_packager/model_handlers/mlflow.py +6 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
- snowflake/ml/model/_packager/model_packager.py +1 -1
- snowflake/ml/model/_signatures/core.py +390 -8
- snowflake/ml/model/_signatures/utils.py +13 -4
- snowflake/ml/model/code_path.py +104 -0
- snowflake/ml/model/compute_pool.py +2 -0
- snowflake/ml/model/custom_model.py +55 -13
- snowflake/ml/model/model_signature.py +13 -1
- snowflake/ml/model/models/huggingface.py +285 -0
- snowflake/ml/model/models/huggingface_pipeline.py +19 -208
- snowflake/ml/model/type_hints.py +7 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +12 -0
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +12 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +5 -0
- snowflake/ml/registry/_manager/model_manager.py +230 -15
- snowflake/ml/registry/registry.py +4 -4
- snowflake/ml/utils/html_utils.py +67 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/METADATA +81 -7
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/RECORD +67 -59
- snowflake/ml/jobs/_utils/runtime_env_utils.py +0 -63
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/top_level.txt +0 -0
|
@@ -6,9 +6,13 @@ import logging
|
|
|
6
6
|
import math
|
|
7
7
|
import os
|
|
8
8
|
import runpy
|
|
9
|
+
import shutil
|
|
10
|
+
import subprocess
|
|
9
11
|
import sys
|
|
10
12
|
import time
|
|
11
13
|
import traceback
|
|
14
|
+
import zipfile
|
|
15
|
+
from pathlib import Path
|
|
12
16
|
from typing import Any, Optional
|
|
13
17
|
|
|
14
18
|
# Ensure payload directory is in sys.path for module imports before importing other modules
|
|
@@ -18,11 +22,17 @@ from typing import Any, Optional
|
|
|
18
22
|
STAGE_MOUNT_PATH = os.environ.get("MLRS_STAGE_MOUNT_PATH", "/mnt/job_stage")
|
|
19
23
|
JOB_RESULT_PATH = os.environ.get("MLRS_RESULT_PATH", "output/mljob_result.pkl")
|
|
20
24
|
PAYLOAD_PATH = os.environ.get("MLRS_PAYLOAD_DIR")
|
|
25
|
+
|
|
21
26
|
if PAYLOAD_PATH and not os.path.isabs(PAYLOAD_PATH):
|
|
22
27
|
PAYLOAD_PATH = os.path.join(STAGE_MOUNT_PATH, PAYLOAD_PATH)
|
|
23
|
-
if PAYLOAD_PATH and PAYLOAD_PATH not in sys.path:
|
|
24
|
-
sys.path.insert(0, PAYLOAD_PATH)
|
|
25
28
|
|
|
29
|
+
if PAYLOAD_PATH:
|
|
30
|
+
if PAYLOAD_PATH not in sys.path:
|
|
31
|
+
sys.path.insert(0, PAYLOAD_PATH)
|
|
32
|
+
for zip_file in Path(PAYLOAD_PATH).rglob("*.zip"):
|
|
33
|
+
fpath = str(zip_file)
|
|
34
|
+
if fpath not in sys.path and zipfile.is_zipfile(fpath):
|
|
35
|
+
sys.path.insert(0, fpath)
|
|
26
36
|
# Imports below must come after sys.path modification to support module overrides
|
|
27
37
|
import snowflake.ml.jobs._utils.constants # noqa: E402
|
|
28
38
|
import snowflake.snowpark # noqa: E402
|
|
@@ -81,6 +91,76 @@ TIMEOUT = float(os.getenv(INSTANCES_TIMEOUT_ENV_VAR) or 720) # seconds
|
|
|
81
91
|
CHECK_INTERVAL = float(os.getenv(INSTANCES_CHECK_INTERVAL_ENV_VAR) or 10) # seconds
|
|
82
92
|
|
|
83
93
|
|
|
94
|
+
def is_python_script(file_path: str) -> bool:
    """Determine whether a file looks like a Python script based on its shebang.

    Args:
        file_path: Path of the file to inspect.

    Returns:
        True when the first line is a shebang mentioning 'python'; False for
        any other content or when the file cannot be read.
    """
    try:
        with open(file_path, "rb") as fh:
            header = fh.readline()
    except OSError:
        # Unreadable/missing file: treat as "not a Python script".
        return False
    if not header.startswith(b"#!"):
        return False
    return "python" in header.decode("utf-8", errors="ignore").lower()
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def resolve_entrypoint(entrypoint: str) -> tuple[str, bool]:
    """Figure out how a job entrypoint should be executed.

    Args:
        entrypoint: The entrypoint string (file path or command name).

    Returns:
        A tuple of (resolved_path, is_python):
        - resolved_path: The path to the executable/script.
        - is_python: True if this should be run as a Python script.
    """
    # An existing file is always treated as a Python script (legacy behavior).
    if os.path.isfile(entrypoint):
        return entrypoint, True

    # Not a file on disk: see if it resolves to a command on PATH.
    located = shutil.which(entrypoint)
    if located is None:
        # Unresolvable: fall through as a Python script path so downstream
        # execution fails with a clear error (preserves backwards compatibility).
        return entrypoint, True

    if is_python_script(located):
        return located, True

    # A non-Python executable: run it as a plain command, not a script.
    return entrypoint, False
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def run_command(command: str, *args: Any) -> None:
    """Execute a command as a subprocess, streaming output and raising an exception if it fails.

    Args:
        command: Path to the executable.
        args: Arguments to pass to the command.

    Raises:
        CalledProcessError: If the subprocess exits with a non-zero return code.
    """
    cmd = [command] + [str(a) for a in args]
    logger.debug(f"Running subprocess: {' '.join(cmd)}")

    # Intentionally no output capture: stdout/stderr flow directly to the console.
    completed = subprocess.run(cmd)

    if completed.returncode:
        raise subprocess.CalledProcessError(completed.returncode, cmd)
|
|
162
|
+
|
|
163
|
+
|
|
84
164
|
def save_mljob_result_v2(value: Any, is_error: bool, path: str) -> None:
|
|
85
165
|
from snowflake.ml.jobs._interop import (
|
|
86
166
|
results as interop_result,
|
|
@@ -313,11 +393,11 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
|
|
|
313
393
|
sys.argv = original_argv
|
|
314
394
|
|
|
315
395
|
|
|
316
|
-
def main(
|
|
396
|
+
def main(entrypoint: str, *script_args: Any, script_main_func: Optional[str] = None) -> Any:
|
|
317
397
|
"""Executes a Python script and serializes the result to JOB_RESULT_PATH.
|
|
318
398
|
|
|
319
399
|
Args:
|
|
320
|
-
|
|
400
|
+
entrypoint (str): The job payload entrypoint to execute.
|
|
321
401
|
script_args (Any): Arguments to pass to the script.
|
|
322
402
|
script_main_func (str, optional): The name of the function to call in the script (if any).
|
|
323
403
|
|
|
@@ -361,8 +441,15 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
|
|
|
361
441
|
# Log start marker before starting user script execution
|
|
362
442
|
print(LOG_START_MSG) # noqa: T201
|
|
363
443
|
|
|
364
|
-
#
|
|
365
|
-
|
|
444
|
+
# Resolve entrypoint to determine execution method
|
|
445
|
+
resolved_entrypoint, is_python = resolve_entrypoint(entrypoint)
|
|
446
|
+
|
|
447
|
+
if is_python:
|
|
448
|
+
# Run as Python script
|
|
449
|
+
execution_result_value = run_script(resolved_entrypoint, *script_args, main_func=script_main_func)
|
|
450
|
+
else:
|
|
451
|
+
# Run as subprocess
|
|
452
|
+
run_command(resolved_entrypoint, *script_args)
|
|
366
453
|
|
|
367
454
|
# Log end marker for user script execution
|
|
368
455
|
print(LOG_END_MSG) # noqa: T201
|
|
@@ -395,7 +482,7 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
|
|
|
395
482
|
|
|
396
483
|
if __name__ == "__main__":
|
|
397
484
|
parser = argparse.ArgumentParser(description="Launch a Python script and save the result")
|
|
398
|
-
parser.add_argument("
|
|
485
|
+
parser.add_argument("entrypoint", help="The job payload entrypoint to execute")
|
|
399
486
|
parser.add_argument("script_args", nargs="*", help="Arguments to pass to the script")
|
|
400
487
|
parser.add_argument(
|
|
401
488
|
"--script_main_func", required=False, help="The name of the main function to call in the script"
|
|
@@ -403,7 +490,7 @@ if __name__ == "__main__":
|
|
|
403
490
|
args, unknown_args = parser.parse_known_args()
|
|
404
491
|
|
|
405
492
|
main(
|
|
406
|
-
args.
|
|
493
|
+
args.entrypoint,
|
|
407
494
|
*args.script_args,
|
|
408
495
|
*unknown_args,
|
|
409
496
|
script_main_func=args.script_main_func,
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
#!/bin/bash
# Bootstraps the Snowflake ML Runtime inside a container: prepares log
# directories and a cron-driven Ray log copy job, then starts Ray as either
# a head or a worker node depending on $NODE_TYPE.

set -e  # exit if a command fails

echo "Creating log directories..."
mkdir -p /var/log/managedservices/user/mlrs
mkdir -p /var/log/managedservices/system/mlrs
mkdir -p /var/log/managedservices/system/ray

# Run the Ray log copy script every minute via cron.
echo "*/1 * * * * root /etc/ray_copy_cron.sh" >> /etc/cron.d/ray_copy_cron
echo "" >> /etc/cron.d/ray_copy_cron
chmod 744 /etc/cron.d/ray_copy_cron

service cron start

mkdir -p /tmp/prometheus-multi-dir

# Configure IP address and logging directory
eth0Ip=$(ifconfig eth0 | sed -En -e 's/.*inet ([0-9.]+).*/\1/p')
log_dir="/tmp/ray"

# Check if eth0Ip is empty and set default if necessary
if [ -z "$eth0Ip" ]; then
    # This should never happen, but just in case eth0Ip is not set, we should default to localhost
    eth0Ip="127.0.0.1"
fi

shm_size=$(df --output=size --block-size=1 /dev/shm | tail -n 1)
total_memory_size=$(awk '/MemTotal/ {print int($2/1024)}' /proc/meminfo)

# Determine if dashboard should be enabled based on total memory size
# Enable dashboard only if total memory size >= 8GB (i.e. not on XS compute pool)
# TODO (SNOW-2860029): use a environment variable to determine the node type
total_memory_threshold=8192
if [ "$total_memory_size" -ge "$total_memory_threshold" ]; then
    enable_dashboard="true"
else
    enable_dashboard="false"
fi

echo "Shared memory size: $shm_size bytes"
echo "Dashboard enabled: $enable_dashboard"

# Common parameters for both head and worker nodes
common_params=(
    "--node-ip-address=$eth0Ip"
    "--object-manager-port=${RAY_OBJECT_MANAGER_PORT:-12011}"
    "--node-manager-port=${RAY_NODE_MANAGER_PORT:-12012}"
    "--runtime-env-agent-port=${RAY_RUNTIME_ENV_AGENT_PORT:-12013}"
    "--dashboard-agent-grpc-port=${RAY_DASHBOARD_AGENT_GRPC_PORT:-12014}"
    "--dashboard-agent-listen-port=${RAY_DASHBOARD_AGENT_LISTEN_PORT:-12015}"
    "--min-worker-port=${RAY_MIN_WORKER_PORT:-12031}"
    "--max-worker-port=${RAY_MAX_WORKER_PORT:-13000}"
    "--metrics-export-port=11502"
    "--temp-dir=$log_dir"
    "--disable-usage-stats"
)

# Specific parameters for head and worker nodes
if [ "$NODE_TYPE" = "worker" ]; then
    # Check mandatory environment variables for worker
    if [ -z "$RAY_HEAD_ADDRESS" ] || [ -z "$SERVICE_NAME" ]; then
        echo "Error: RAY_HEAD_ADDRESS and SERVICE_NAME must be set."
        exit 1
    fi

    # Additional worker-specific parameters
    worker_params=(
        "--address=${RAY_HEAD_ADDRESS}:${RAY_HEAD_GCS_PORT:-12001}"  # Connect to head node
        "--resources={\"${SERVICE_NAME}\":1, \"node_tag:worker\":1}"  # Custom resource for node identification
        "--object-store-memory=${shm_size}"
    )

    # Start Ray on a worker node
    ray start "${common_params[@]}" "${worker_params[@]}" "$@" -v
else
    # Additional head-specific parameters
    head_params=(
        "--head"
        "--include-dashboard=$enable_dashboard"
        "--disable-usage-stats"
        "--port=${RAY_HEAD_GCS_PORT:-12001}"  # Port of Ray (GCS server)
        "--ray-client-server-port=${RAY_HEAD_CLIENT_SERVER_PORT:-10001}"  # Listening port for Ray Client Server
        "--dashboard-host=${NODE_IP_ADDRESS}"  # Host to bind the dashboard server
        "--dashboard-grpc-port=${RAY_HEAD_DASHBOARD_GRPC_PORT:-12002}"  # Dashboard head to listen for grpc on
        "--dashboard-port=${DASHBOARD_PORT}"  # Port to bind the dashboard server for local debugging
        "--resources={\"node_tag:head\":1}"  # Resource tag for selecting head as coordinator
    )

    # Start Ray
    ray start "${common_params[@]}" "${head_params[@]}" "$@"
fi
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
#!/bin/bash
# Container entrypoint for ML jobs: sets up the Python environment, resolves
# multi-node (head/worker) topology, starts the ML Runtime, then runs the
# user payload via mljob_launcher.py (head) or waits for shutdown (worker).

set -e  # exit if a command fails

# Get and change to system scripts directory
SYSTEM_DIR=$(cd "$(dirname "$0")" && pwd)

# Change directory to user payload directory
if [ -n "${MLRS_PAYLOAD_DIR}" ]; then
    # Quoted to survive spaces/special characters in the mount or payload path.
    cd "${MLRS_STAGE_MOUNT_PATH}/${MLRS_PAYLOAD_DIR}"
fi

##### Set up Python environment #####
export PYTHONPATH=/opt/env/site-packages/
MLRS_SYSTEM_REQUIREMENTS_FILE=${MLRS_SYSTEM_REQUIREMENTS_FILE:-"${SYSTEM_DIR}/requirements.txt"}
if [ -f "${MLRS_SYSTEM_REQUIREMENTS_FILE}" ]; then
    echo "Installing packages from $MLRS_SYSTEM_REQUIREMENTS_FILE"
    # Prefer the offline (pre-staged) wheels; fall back to the package index.
    if ! pip install --no-index -r "$MLRS_SYSTEM_REQUIREMENTS_FILE"; then
        echo "Offline install failed, falling back to regular pip install"
        pip install -r "$MLRS_SYSTEM_REQUIREMENTS_FILE"
    fi
fi

MLRS_REQUIREMENTS_FILE=${MLRS_REQUIREMENTS_FILE:-"requirements.txt"}
if [ -f "${MLRS_REQUIREMENTS_FILE}" ]; then
    # TODO: Prevent collisions with MLRS packages using virtualenvs
    echo "Installing packages from $MLRS_REQUIREMENTS_FILE"
    pip install -r "$MLRS_REQUIREMENTS_FILE"
fi

MLRS_CONDA_ENV_FILE=${MLRS_CONDA_ENV_FILE:-"environment.yml"}
if [ -f "${MLRS_CONDA_ENV_FILE}" ]; then
    # TODO: Handle conda environment
    echo "Custom conda environments not currently supported"
    exit 1
fi
##### End Python environment setup #####

##### Set up multi-node configuration #####
# Configure IP address
if [ -f "${SYSTEM_DIR}/get_instance_ip.py" ]; then
    eth0Ip=$(python3 "${SYSTEM_DIR}/get_instance_ip.py" \
        "$SNOWFLAKE_SERVICE_NAME" --instance-index=-1)
else
    eth0Ip=$(ifconfig eth0 2>/dev/null | sed -En -e 's/.*inet ([0-9.]+).*/\1/p')
fi

# Check if eth0Ip is a valid IP address and fall back to default if necessary
if [[ ! $eth0Ip =~ ^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
    eth0Ip="127.0.0.1"
fi

# Set default values for job environment variables if they don't exist
# (e.g. some are only populated by SPCS for batch jobs, others just may not be set at all)
export SNOWFLAKE_JOBS_COUNT=${SNOWFLAKE_JOBS_COUNT:-1}
export SNOWFLAKE_JOB_INDEX=${SNOWFLAKE_JOB_INDEX:-0}
export SERVICE_NAME="${SERVICE_NAME:-$SNOWFLAKE_SERVICE_NAME}"

##### Ray configuration #####

# Determine if it should be a worker or a head node for batch jobs
if [[ "$SNOWFLAKE_JOBS_COUNT" -gt 1 ]]; then
    # BUGFIX: test the command substitution inside `if` rather than checking
    # `$?` afterwards -- under `set -e` a failing plain assignment would abort
    # the script before the error branch could ever run.
    if head_info=$(python3 "${SYSTEM_DIR}/get_instance_ip.py" "$SNOWFLAKE_SERVICE_NAME" --head); then
        # Parse the output using read (-r: keep backslashes literal)
        read -r head_index head_ip head_status <<< "$head_info"

        if [ "$SNOWFLAKE_JOB_INDEX" -ne "$head_index" ]; then
            NODE_TYPE="worker"
        fi

        # Use the parsed variables
        echo "Head Instance Index: $head_index"
        echo "Head Instance IP: $head_ip"
        echo "Head Instance Status: $head_status"

        # If the head status is not "READY" or "PENDING", exit early
        if [ "$head_status" != "READY" ] && [ "$head_status" != "PENDING" ]; then
            echo "Head instance status is not READY or PENDING. Exiting."
            exit 0
        fi

    else
        echo "Error: Failed to get head instance information."
        echo "$head_info" # Print the error message
        exit 1
    fi
fi

# Start ML Runtime (non-blocking call)
NODE_TYPE=$NODE_TYPE RAY_HEAD_ADDRESS="$head_ip" bash "${SYSTEM_DIR}/start_mlruntime.sh"

if [ "$NODE_TYPE" = "worker" ]; then
    echo "Worker node started on address $eth0Ip. See more logs in the head node."

    # Start the worker shutdown listener in the background
    echo "Starting worker shutdown listener..."
    # BUGFIX: capture the exit code via `||` -- under `set -e` a non-zero exit
    # would terminate the script before `WORKER_EXIT_CODE=$?` could run.
    WORKER_EXIT_CODE=0
    python "${SYSTEM_DIR}/worker_shutdown_listener.py" || WORKER_EXIT_CODE=$?

    echo "Worker shutdown listener exited with code $WORKER_EXIT_CODE"
    exit $WORKER_EXIT_CODE
else
    # Run user's Python entrypoint via mljob_launcher
    echo Running command: python "${SYSTEM_DIR}/mljob_launcher.py" "$@"
    # Capture the job's exit code so workers are signaled to shut down even
    # when the user job fails; the code is propagated on exit below.
    JOB_EXIT_CODE=0
    python "${SYSTEM_DIR}/mljob_launcher.py" "$@" || JOB_EXIT_CODE=$?

    # After the user's job completes, signal workers to shut down
    echo "User job completed. Signaling workers to shut down..."
    python "${SYSTEM_DIR}/signal_workers.py" --wait-time 15
    echo "Head node job completed. Exiting."
    exit $JOB_EXIT_CODE
fi
|