snowflake-ml-python 1.8.0__py3-none-any.whl → 1.8.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +44 -10
- snowflake/ml/_internal/platform_capabilities.py +39 -3
- snowflake/ml/data/data_connector.py +25 -0
- snowflake/ml/dataset/dataset_reader.py +5 -1
- snowflake/ml/jobs/_utils/constants.py +3 -5
- snowflake/ml/jobs/_utils/interop_utils.py +442 -0
- snowflake/ml/jobs/_utils/payload_utils.py +81 -47
- snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +178 -0
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
- snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
- snowflake/ml/jobs/_utils/spec_utils.py +27 -8
- snowflake/ml/jobs/_utils/types.py +6 -0
- snowflake/ml/jobs/decorators.py +10 -6
- snowflake/ml/jobs/job.py +145 -23
- snowflake/ml/jobs/manager.py +79 -12
- snowflake/ml/model/_client/ops/model_ops.py +6 -3
- snowflake/ml/model/_client/ops/service_ops.py +57 -39
- snowflake/ml/model/_client/service/model_deployment_spec.py +7 -4
- snowflake/ml/model/_client/sql/service.py +11 -5
- snowflake/ml/model/_model_composer/model_composer.py +29 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +1 -2
- snowflake/ml/model/_packager/model_env/model_env.py +8 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +1 -4
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +6 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
- snowflake/ml/model/_packager/model_packager.py +2 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/type_hints.py +2 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +5 -1
- snowflake/ml/registry/_manager/model_manager.py +20 -1
- snowflake/ml/registry/registry.py +46 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/METADATA +55 -4
- {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/RECORD +40 -34
- {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,136 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
# This file is modified from mlruntime/service/snowflake/runtime/utils
|
3
|
+
import argparse
|
4
|
+
import logging
|
5
|
+
import socket
|
6
|
+
import sys
|
7
|
+
import time
|
8
|
+
from typing import Optional
|
9
|
+
|
10
|
+
# Configure logging
|
11
|
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
12
|
+
logger = logging.getLogger(__name__)
|
13
|
+
|
14
|
+
|
15
|
+
def get_self_ip() -> Optional[str]:
    """Resolve the IP address of the current service instance.

    Uses local hostname resolution, the mechanism recommended for SPCS
    service-to-service communication.

    References:
        - https://docs.snowflake.com/en/developer-guide/snowpark-container-services/working-with-services#general-guidelines-related-to-service-to-service-communications # noqa: E501

    Returns:
        Optional[str]: The IP address of the current service instance, or None if unable to retrieve.
    """
    try:
        # Resolve our own hostname to the instance-local IP.
        return socket.gethostbyname(socket.gethostname())
    except OSError as e:
        logger.error(f"Error: Unable to get IP address via socket. {e}")
        return None
+
def get_first_instance(service_name: str) -> Optional[tuple[str, str]]:
    """Determine the head instance of a batch job.

    The head is the instance with the earliest start time; ties are broken by
    numeric instance ID.

    Args:
        service_name (str): The name of the service to query.

    Returns:
        tuple[str, str]: A tuple containing (instance_id, ip_address) of the head instance.
    """
    from snowflake.runtime.utils import session_utils

    session = session_utils.get_session()
    rows = (
        session.sql(f"show service instances in service {service_name}")
        .select('"instance_id"', '"ip_address"', '"start_time"')
        .collect()
    )

    if not rows:
        return None

    # Earliest start_time wins; numeric instance_id breaks ties.
    head = min(rows, key=lambda r: (r["start_time"], int(r["instance_id"])))
    if not head["instance_id"] or not head["ip_address"]:
        return None

    candidate_ip = head["ip_address"]
    try:
        # Reject anything that is not a well-formed IPv4 address.
        socket.inet_aton(candidate_ip)
    except OSError:
        logger.error(f"Error: Invalid IP address format: {candidate_ip}")
        return None
    return (head["instance_id"], candidate_ip)
+
def main():
    """Retrieves the IP address of a specified service instance or the current service.

    Args:
        service_name (str,required) Name of the service to query
        --instance-index (int, optional) Index of the service instance to query. Default: -1
            Currently only supports -1 to get the IP address of the current service instance.
        --head (bool, optional) Get the head instance information using show services.
            If set, instance-index will be ignored, and the script will return the index and IP address of
            the head instance, split by a space. Default: False.
        --timeout (int, optional) Maximum time to wait for IP address retrieval in seconds. Default: 720 seconds
        --retry-interval (int, optional) Time to wait between retry attempts in seconds. Default: 10 seconds
    Usage Examples:
        python get_instance_ip.py myservice --instance-index=1 --retry-interval=5
    Returns:
        Prints the IP address to stdout if successful. Exits with status code 0 on success, 1 on failure
    """

    parser = argparse.ArgumentParser(description="Get IP address of a service instance")
    group = parser.add_mutually_exclusive_group()
    parser.add_argument("service_name", help="Name of the service")
    group.add_argument(
        "--instance-index",
        type=int,
        default=-1,
        help="Index of service instance (default: -1 for self instance)",
    )
    group.add_argument(
        "--head",
        action="store_true",
        help="Get head instance information using show services",
    )
    parser.add_argument("--timeout", type=int, default=720, help="Timeout in seconds (default: 720)")
    parser.add_argument(
        "--retry-interval",
        type=int,
        default=10,
        help="Retry interval in seconds (default: 10)",
    )

    args = parser.parse_args()
    deadline = time.time() + args.timeout

    if args.head:
        # Poll until the head instance reports both its ID and IP, or we time out.
        while time.time() < deadline:
            head_info = get_first_instance(args.service_name)
            if head_info:
                # Print to stdout to allow capture but don't use logger
                sys.stdout.write(f"{head_info[0]} {head_info[1]}\n")
                sys.exit(0)
            time.sleep(args.retry_interval)
        logger.error("Error: Unable to retrieve head IP address")
        sys.exit(1)

    if args.instance_index != -1:
        # We don't support querying a specific instance index other than -1
        logger.error("Error: Invalid arguments. Only --instance-index=-1 is supported for now.")
        sys.exit(1)

    # Index -1 means "this instance": resolve our own IP locally.
    ip_address = get_self_ip()
    if not ip_address:
        logger.error("Error: Unable to retrieve self IP address")
        sys.exit(1)
    sys.stdout.write(f"{ip_address}\n")
    sys.exit(0)
+
if __name__ == "__main__":
    # Script entry point; all logic lives in main() so the module stays importable.
    main()
@@ -0,0 +1,178 @@
|
|
1
|
+
import argparse
|
2
|
+
import importlib.util
|
3
|
+
import json
|
4
|
+
import os
|
5
|
+
import runpy
|
6
|
+
import sys
|
7
|
+
import traceback
|
8
|
+
import warnings
|
9
|
+
from pathlib import Path
|
10
|
+
from typing import Any, Dict, Optional
|
11
|
+
|
12
|
+
import cloudpickle
|
13
|
+
|
14
|
+
from snowflake.ml.jobs._utils import constants
|
15
|
+
from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
|
16
|
+
from snowflake.snowpark import Session
|
17
|
+
|
18
|
+
# Fallbacks in case of SnowML version mismatch
RESULT_PATH_ENV_VAR = getattr(constants, "RESULT_PATH_ENV_VAR", "MLRS_RESULT_PATH")

# Destination for the pickled job result; overridable via the env var above.
JOB_RESULT_PATH = os.getenv(RESULT_PATH_ENV_VAR, "mljob_result.pkl")
|
+
try:
    from snowflake.ml.jobs._utils.interop_utils import ExecutionResult
except ImportError:
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class ExecutionResult:
        """Minimal fallback for interop_utils.ExecutionResult (version-mismatch shim)."""

        result: Optional[Any] = None
        exception: Optional[BaseException] = None

        @property
        def success(self) -> bool:
            # Execution succeeded iff no exception was captured.
            return self.exception is None

        def to_dict(self) -> Dict[str, Any]:
            """Return the serializable dictionary."""
            if not isinstance(self.exception, BaseException):
                return {
                    "success": True,
                    "result_type": type(self.result).__qualname__,
                    "result": self.result,
                }
            exc_cls = type(self.exception)
            return {
                "success": False,
                "exc_type": f"{exc_cls.__module__}.{exc_cls.__name__}",
                "exc_value": self.exception,
                "exc_tb": "".join(traceback.format_tb(self.exception.__traceback__)),
            }
|
55
|
+
# JSON encoder that degrades gracefully: anything the base encoder rejects
# is rendered via str() instead of raising TypeError.
class SimpleJSONEncoder(json.JSONEncoder):
    def default(self, obj: Any) -> Any:
        try:
            return super().default(obj)
        except TypeError:
            # Not natively serializable -> fall back to its string form.
            return str(obj)
def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = None) -> Any:
    """
    Execute a Python script and return its result.

    Args:
        script_path: Path to the Python script
        script_args: Arguments to pass to the script
        main_func: The name of the function to call in the script (if any)

    Returns:
        Result from script execution, either from the main function or the script's __return__ value

    Raises:
        RuntimeError: If the specified main_func is not found or not callable
    """
    # Swap in a script-specific argv (runpy execution reads it); restored in finally.
    saved_argv = sys.argv
    sys.argv = [script_path, *script_args]

    # Create a Snowpark session before running the script
    # Session can be retrieved from using snowflake.snowpark.context.get_active_session()
    session = Session.builder.configs(SnowflakeLoginOptions()).create()  # noqa: F841

    try:
        if not main_func:
            # Plain scripts run under runpy; a module-level __return__ carries the result.
            module_globals = runpy.run_path(script_path, run_name="__main__")
            return module_globals.get("__return__", None)

        # Scripts with an explicit entrypoint are imported as a module instead.
        module_name = Path(script_path).stem
        spec = importlib.util.spec_from_file_location(module_name, script_path)
        assert spec is not None
        assert spec.loader is not None
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        # Validate the requested entrypoint before calling it.
        entrypoint = getattr(module, main_func, None)
        if not entrypoint or not callable(entrypoint):
            raise RuntimeError(f"Function '{main_func}' not a valid entrypoint for {script_path}")
        return entrypoint(*script_args)
    finally:
        # Restore original sys.argv
        sys.argv = saved_argv
def main(script_path: str, *script_args: Any, script_main_func: Optional[str] = None) -> ExecutionResult:
    """Executes a Python script and serializes the result to JOB_RESULT_PATH.

    Args:
        script_path (str): Path to the Python script to execute.
        script_args (Any): Arguments to pass to the script.
        script_main_func (str, optional): The name of the function to call in the script (if any).

    Returns:
        ExecutionResult: Object containing execution results.

    Raises:
        Exception: Re-raises any exception caught during script execution.
    """
    # Run the script with the specified arguments
    try:
        result = run_script(script_path, *script_args, main_func=script_main_func)
        result_obj = ExecutionResult(result=result)
        return result_obj
    except BaseException as e:
        # Catch BaseException (not just Exception): a user script calling sys.exit()
        # raises SystemExit, which would otherwise skip this handler and leave
        # `result_obj` unbound, making the finally block fail with NameError and
        # masking the real error without writing any result file.
        tb = e.__traceback__
        skip_files = {__file__, runpy.__file__}
        while tb and tb.tb_frame.f_code.co_filename in skip_files:
            # Skip any frames preceding user script execution
            tb = tb.tb_next
        result_obj = ExecutionResult(exception=e.with_traceback(tb))
        raise
    finally:
        result_dict = result_obj.to_dict()
        # Compute both output paths up front so the warning messages below can
        # always reference them, even if open() itself fails.
        result_pickle_path = JOB_RESULT_PATH
        result_json_path = os.path.splitext(JOB_RESULT_PATH)[0] + ".json"
        try:
            # Serialize result using cloudpickle
            with open(result_pickle_path, "wb") as f:
                cloudpickle.dump(result_dict, f)  # Pickle dictionary form for compatibility
        except Exception as pkl_exc:
            warnings.warn(f"Failed to pickle result to {result_pickle_path}: {pkl_exc}", RuntimeWarning, stacklevel=1)

        try:
            # Serialize result to JSON as fallback path in case of cross version incompatibility
            # TODO: Manually convert non-serializable types to strings
            with open(result_json_path, "w") as f:
                json.dump(result_dict, f, indent=2, cls=SimpleJSONEncoder)
        except Exception as json_exc:
            warnings.warn(
                f"Failed to serialize JSON result to {result_json_path}: {json_exc}", RuntimeWarning, stacklevel=1
            )
if __name__ == "__main__":
    # Command-line entry point: parse arguments, then delegate to main().
    parser = argparse.ArgumentParser(description="Launch a Python script and save the result")
    parser.add_argument("script_path", help="Path to the Python script to execute")
    parser.add_argument("script_args", nargs="*", help="Arguments to pass to the script")
    parser.add_argument(
        "--script_main_func", required=False, help="The name of the main function to call in the script"
    )
    parsed, extra = parser.parse_known_args()

    # Unrecognized flags are forwarded to the target script as extra arguments.
    main(
        parsed.script_path,
        *parsed.script_args,
        *extra,
        script_main_func=parsed.script_main_func,
    )
@@ -0,0 +1,203 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
# This file is part of the Ray-based distributed job system for Snowflake ML.
|
3
|
+
# Architecture overview:
|
4
|
+
# - Head node creates a ShutdownSignal actor and signals workers when job completes
|
5
|
+
# - Worker nodes listen for this signal and gracefully shut down
|
6
|
+
# - This ensures clean termination of distributed Ray jobs
|
7
|
+
import argparse
|
8
|
+
import logging
|
9
|
+
import socket
|
10
|
+
import sys
|
11
|
+
import time
|
12
|
+
from typing import Any, Dict, List, Set
|
13
|
+
|
14
|
+
import ray
|
15
|
+
from constants import (
|
16
|
+
SHUTDOWN_ACTOR_NAME,
|
17
|
+
SHUTDOWN_ACTOR_NAMESPACE,
|
18
|
+
SHUTDOWN_RPC_TIMEOUT_SECONDS,
|
19
|
+
)
|
20
|
+
from ray.actor import ActorHandle
|
21
|
+
|
22
|
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
23
|
+
|
24
|
+
|
25
|
+
@ray.remote
class ShutdownSignal:
    """A simple Ray actor that workers can check to determine if they should shutdown"""

    def __init__(self) -> None:
        self.hostname = socket.gethostname()
        self.shutdown_requested = False
        self.timestamp = None
        self.acknowledged_workers: Set[str] = set()
        logging.info(f"ShutdownSignal actor created on {self.hostname}")

    def request_shutdown(self) -> Dict[str, Any]:
        """Mark shutdown as requested and record when it happened."""
        now = time.time()
        self.shutdown_requested = True
        self.timestamp = now
        logging.info(f"Shutdown requested by head node at {now}")
        return {"status": "shutdown_requested", "timestamp": now, "host": self.hostname}

    def should_shutdown(self) -> Dict[str, Any]:
        """Report whether shutdown has been requested, and when."""
        return {"shutdown": self.shutdown_requested, "timestamp": self.timestamp, "host": self.hostname}

    def ping(self) -> Dict[str, Any]:
        """Liveness probe for connectivity checks."""
        return {"status": "alive", "host": self.hostname}

    def acknowledge_shutdown(self, worker_id: str) -> Dict[str, Any]:
        """Record that a worker received the shutdown signal and is terminating."""
        self.acknowledged_workers.add(worker_id)
        count = len(self.acknowledged_workers)
        logging.info(f"Worker {worker_id} acknowledged shutdown. Total acknowledged: {count}")

        return {"status": "acknowledged", "worker_id": worker_id, "acknowledged_count": count}

    def get_acknowledgment_workers(self) -> Set[str]:
        """Return the set of worker IDs that have acknowledged shutdown."""
        return self.acknowledged_workers
def get_worker_node_ids() -> List[str]:
    """Get the IDs of all active worker nodes.

    Returns:
        List[str]: List of worker node IDs. Empty list if no workers are present.
    """

    def _is_live_worker(node: Dict[str, Any]) -> bool:
        # A worker node is alive and advertises a positive "node_tag:worker" resource.
        return bool(node.get("Alive")) and node.get("Resources", {}).get("node_tag:worker", 0) > 0

    worker_node_ids = [node.get("NodeName") for node in ray.nodes() if _is_live_worker(node)]

    if worker_node_ids:
        logging.info(f"Found {len(worker_node_ids)} worker nodes")
    else:
        logging.info("No active worker nodes found")

    return worker_node_ids
def get_or_create_shutdown_signal() -> ActorHandle:
    """Get existing shutdown signal actor or create a new one.

    Returns:
        ActorHandle: Reference to shutdown signal actor
    """
    try:
        # Reuse the actor if a previous run already created it.
        actor = ray.get_actor(SHUTDOWN_ACTOR_NAME, namespace=SHUTDOWN_ACTOR_NAMESPACE)
        logging.info("Found existing shutdown signal actor")
        return actor
    except (ValueError, ray.exceptions.RayActorError) as e:
        logging.info(f"Creating new shutdown signal actor: {e}")

        actor = ShutdownSignal.options(
            name=SHUTDOWN_ACTOR_NAME,
            namespace=SHUTDOWN_ACTOR_NAMESPACE,
            lifetime="detached",  # Ensure actor survives client disconnect
            resources={"node_tag:head": 0.001},  # Resource constraint to ensure it runs on head node
        ).remote()

        # Confirm the freshly created actor is reachable before handing it out.
        ping_result = ray.get(actor.ping.remote(), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS)
        logging.debug(f"New actor ping response: {ping_result}")

        return actor
def request_shutdown(shutdown_signal: ActorHandle) -> None:
    """Ask the shutdown actor to flag all workers for termination.

    Args:
        shutdown_signal: Reference to the shutdown signal actor
    """
    pending = shutdown_signal.request_shutdown.remote()
    response = ray.get(pending, timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS)
    logging.info(f"Shutdown requested: {response}")
def verify_shutdown(shutdown_signal: ActorHandle) -> None:
    """Sanity-check that the shutdown flag is now visible on the actor.

    Args:
        shutdown_signal: Reference to the shutdown signal actor
    """
    status = ray.get(shutdown_signal.should_shutdown.remote(), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS)
    logging.debug(f"Shutdown status check: {status}")
def wait_for_acknowledgments(shutdown_signal: ActorHandle, worker_node_ids: List[str], wait_time: int) -> None:
    """Block until every worker acknowledges shutdown, or raise on timeout.

    Args:
        shutdown_signal: Reference to the shutdown signal actor
        worker_node_ids: List of worker node IDs
        wait_time: Time in seconds to wait for acknowledgments

    Raises:
        TimeoutError: When workers don't acknowledge within the wait time or if actor communication times out
    """
    if not worker_node_ids:
        return

    logging.info(f"Waiting up to {wait_time}s for workers to acknowledge shutdown signal...")
    expected = set(worker_node_ids)
    started = time.time()
    poll_interval = 1.0

    while time.time() - started < wait_time:
        try:
            acked = ray.get(
                shutdown_signal.get_acknowledgment_workers.remote(), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS
            )
            if acked and acked == expected:
                logging.info(
                    f"All {len(worker_node_ids)} workers acknowledged shutdown. "
                    f"Completed in {time.time() - started:.2f}s"
                )
                return
            logging.debug(f"Waiting for acknowledgments: {len(acked)}/{len(worker_node_ids)} workers")
        except Exception as e:
            # Transient RPC failures are logged and retried until the deadline.
            logging.warning(f"Error checking acknowledgment status: {e}")

        time.sleep(poll_interval)

    raise TimeoutError(
        f"Timed out waiting for {len(worker_node_ids)} workers to acknowledge shutdown after {wait_time}s"
    )
def signal_workers(wait_time: int = 10) -> int:
    """
    Signal worker nodes to shut down by creating a shutdown signal actor.

    Args:
        wait_time: Time in seconds to wait for workers to receive the message

    Returns:
        0 for success, 1 for failure
    """
    ray.init(address="auto", ignore_reinit_error=True)

    worker_node_ids = get_worker_node_ids()
    if not worker_node_ids:
        # Nothing to do when the cluster has no live workers.
        logging.info("No active worker nodes found to signal.")
        return 0

    actor = get_or_create_shutdown_signal()
    request_shutdown(actor)
    verify_shutdown(actor)
    wait_for_acknowledgments(actor, worker_node_ids, wait_time)
    return 0
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(description="Signal Ray workers to shutdown")
    arg_parser.add_argument(
        "--wait-time", type=int, default=10, help="Time in seconds to wait for workers to receive the signal"
    )
    cli_args = arg_parser.parse_args()

    # Process exit code mirrors signal_workers' return value (0 = success).
    sys.exit(signal_workers(cli_args.wait_time))