snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +64 -31
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +41 -5
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +8 -8
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +40 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +12 -8
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/constants.py +2 -4
- snowflake/ml/jobs/_utils/interop_utils.py +442 -0
- snowflake/ml/jobs/_utils/payload_utils.py +86 -62
- snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
- snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
- snowflake/ml/jobs/_utils/spec_utils.py +22 -36
- snowflake/ml/jobs/_utils/types.py +8 -2
- snowflake/ml/jobs/decorators.py +7 -8
- snowflake/ml/jobs/job.py +158 -26
- snowflake/ml/jobs/manager.py +78 -30
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +230 -50
- snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +22 -18
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +46 -25
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +35 -26
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
- snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +12 -8
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +8 -8
- snowflake/ml/model/model_signature.py +9 -20
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +5 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/registry/_manager/model_manager.py +50 -29
- snowflake/ml/registry/registry.py +34 -23
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,203 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
# This file is part of the Ray-based distributed job system for Snowflake ML.
|
3
|
+
# Architecture overview:
|
4
|
+
# - Head node creates a ShutdownSignal actor and signals workers when job completes
|
5
|
+
# - Worker nodes listen for this signal and gracefully shut down
|
6
|
+
# - This ensures clean termination of distributed Ray jobs
|
7
|
+
import argparse
|
8
|
+
import logging
|
9
|
+
import socket
|
10
|
+
import sys
|
11
|
+
import time
|
12
|
+
from typing import Any
|
13
|
+
|
14
|
+
import ray
|
15
|
+
from constants import (
|
16
|
+
SHUTDOWN_ACTOR_NAME,
|
17
|
+
SHUTDOWN_ACTOR_NAMESPACE,
|
18
|
+
SHUTDOWN_RPC_TIMEOUT_SECONDS,
|
19
|
+
)
|
20
|
+
from ray.actor import ActorHandle
|
21
|
+
|
22
|
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
23
|
+
|
24
|
+
|
25
|
+
@ray.remote
class ShutdownSignal:
    """Ray actor that holds the shutdown flag polled by worker nodes.

    It tracks whether a shutdown was requested, when the request was made,
    the host the actor was created on, and which workers acknowledged it.
    """

    def __init__(self) -> None:
        self.hostname = socket.gethostname()
        self.shutdown_requested = False
        # Stays None until request_shutdown() stores time.time() here.
        self.timestamp = None
        # Worker ids passed to acknowledge_shutdown().
        self.acknowledged_workers = set()
        logging.info(f"ShutdownSignal actor created on {self.hostname}")

    def request_shutdown(self) -> dict[str, Any]:
        """Raise the shutdown flag and record when it was raised.

        Returns:
            Dict with the keys "status", "timestamp" and "host".
        """
        self.timestamp = time.time()
        self.shutdown_requested = True
        logging.info(f"Shutdown requested by head node at {self.timestamp}")
        response = {"status": "shutdown_requested", "timestamp": self.timestamp, "host": self.hostname}
        return response

    def should_shutdown(self) -> dict[str, Any]:
        """Report the current shutdown flag.

        Returns:
            Dict with the keys "shutdown", "timestamp" and "host".
        """
        state = {"shutdown": self.shutdown_requested, "timestamp": self.timestamp, "host": self.hostname}
        return state

    def ping(self) -> dict[str, Any]:
        """Cheap round-trip used by callers to test that the actor is reachable."""
        return {"status": "alive", "host": self.hostname}

    def acknowledge_shutdown(self, worker_id: str) -> dict[str, Any]:
        """Record that a worker has seen the shutdown signal and is terminating.

        Args:
            worker_id: Identifier of the acknowledging worker

        Returns:
            Dict with the keys "status", "worker_id" and "acknowledged_count".
        """
        self.acknowledged_workers.add(worker_id)
        acknowledged_count = len(self.acknowledged_workers)
        logging.info(f"Worker {worker_id} acknowledged shutdown. Total acknowledged: {acknowledged_count}")

        return {"status": "acknowledged", "worker_id": worker_id, "acknowledged_count": acknowledged_count}

    def get_acknowledgment_workers(self) -> set[str]:
        """Return the ids of all workers that have acknowledged shutdown so far."""
        return self.acknowledged_workers
|
61
|
+
|
62
|
+
|
63
|
+
def get_worker_node_ids() -> list[str]:
    """Collect the node names of every live Ray node tagged as a worker.

    A node is treated as a worker when ``ray.nodes()`` reports it as alive and
    it advertises a positive ``node_tag:worker`` resource.

    Returns:
        List[str]: ``NodeName`` of each matching node. Empty list if no workers are present.
    """
    worker_node_ids = []
    for node in ray.nodes():
        if not node.get("Alive"):
            continue
        if node.get("Resources", {}).get("node_tag:worker", 0) > 0:
            worker_node_ids.append(node.get("NodeName"))

    if not worker_node_ids:
        logging.info("No active worker nodes found")
    else:
        logging.info(f"Found {len(worker_node_ids)} worker nodes")

    return worker_node_ids
|
81
|
+
|
82
|
+
|
83
|
+
def get_or_create_shutdown_signal() -> ActorHandle:
    """Look up the named shutdown signal actor, creating it when the lookup fails.

    Returns:
        ActorHandle: Handle to the existing or newly created shutdown signal actor
    """
    try:
        actor = ray.get_actor(SHUTDOWN_ACTOR_NAME, namespace=SHUTDOWN_ACTOR_NAMESPACE)
        logging.info("Found existing shutdown signal actor")
    except (ValueError, ray.exceptions.RayActorError) as lookup_error:
        logging.info(f"Creating new shutdown signal actor: {lookup_error}")
        actor_options = {
            "name": SHUTDOWN_ACTOR_NAME,
            "namespace": SHUTDOWN_ACTOR_NAMESPACE,
            # Detached so the actor survives this client disconnecting.
            "lifetime": "detached",
            # Tiny head-only resource request pins the actor to the head node.
            "resources": {"node_tag:head": 0.001},
        }
        actor = ShutdownSignal.options(**actor_options).remote()

        # Round-trip once so a broken actor is detected before it is handed out.
        ping_result = ray.get(actor.ping.remote(), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS)
        logging.debug(f"New actor ping response: {ping_result}")

    return actor
|
108
|
+
|
109
|
+
|
110
|
+
def request_shutdown(shutdown_signal: ActorHandle) -> None:
    """Ask the shutdown signal actor to raise its shutdown flag.

    Blocks until the actor replies or the RPC timeout elapses.

    Args:
        shutdown_signal: Reference to the shutdown signal actor
    """
    pending_reply = shutdown_signal.request_shutdown.remote()
    response = ray.get(pending_reply, timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS)
    logging.info(f"Shutdown requested: {response}")
|
118
|
+
|
119
|
+
|
120
|
+
def verify_shutdown(shutdown_signal: ActorHandle) -> None:
    """Read back the actor's shutdown state and log it at debug level.

    The state is only logged; it is not validated.

    Args:
        shutdown_signal: Reference to the shutdown signal actor
    """
    pending_reply = shutdown_signal.should_shutdown.remote()
    check = ray.get(pending_reply, timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS)
    logging.debug(f"Shutdown status check: {check}")
|
128
|
+
|
129
|
+
|
130
|
+
def wait_for_acknowledgments(shutdown_signal: ActorHandle, worker_node_ids: list[str], wait_time: int) -> None:
    """Poll the shutdown actor until every expected worker has acknowledged shutdown.

    Returns immediately when there are no workers to wait for. Errors raised while
    querying the actor (including RPC timeouts) are logged and the poll is retried
    until the overall deadline passes.

    Args:
        shutdown_signal: Reference to the shutdown signal actor
        worker_node_ids: List of worker node IDs
        wait_time: Time in seconds to wait for acknowledgments

    Raises:
        TimeoutError: When not all expected workers acknowledge within the wait time
    """
    if not worker_node_ids:
        return

    logging.info(f"Waiting up to {wait_time}s for workers to acknowledge shutdown signal...")
    expected_workers = set(worker_node_ids)
    check_interval = 1.0
    # Monotonic clock: the elapsed-time check must not be skewed by wall-clock adjustments.
    start_time = time.monotonic()

    while time.monotonic() - start_time < wait_time:
        try:
            ack_workers = ray.get(
                shutdown_signal.get_acknowledgment_workers.remote(), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS
            )
            # Subset rather than equality: acknowledgments from workers outside the
            # expected list must not prevent completion.
            if expected_workers.issubset(ack_workers):
                logging.info(
                    f"All {len(worker_node_ids)} workers acknowledged shutdown. "
                    f"Completed in {time.monotonic() - start_time:.2f}s"
                )
                return
            acknowledged = len(expected_workers.intersection(ack_workers))
            logging.debug(f"Waiting for acknowledgments: {acknowledged}/{len(worker_node_ids)} workers")
        except Exception as e:
            # Best effort: a failed poll is retried on the next iteration.
            logging.warning(f"Error checking acknowledgment status: {e}")

        time.sleep(check_interval)

    raise TimeoutError(
        f"Timed out waiting for {len(worker_node_ids)} workers to acknowledge shutdown after {wait_time}s"
    )
|
169
|
+
|
170
|
+
|
171
|
+
def signal_workers(wait_time: int = 10) -> int:
    """
    Connect to the running Ray cluster and tell its worker nodes to shut down.

    When workers are present, the shutdown actor is obtained (or created), the
    shutdown flag is raised and read back, and the call waits for every worker
    to acknowledge. Failures in any of these steps surface as exceptions.

    Args:
        wait_time: Time in seconds to wait for workers to receive the message

    Returns:
        0 once signaling completed or when there were no workers to signal
    """
    ray.init(address="auto", ignore_reinit_error=True)

    worker_node_ids = get_worker_node_ids()
    if not worker_node_ids:
        logging.info("No active worker nodes found to signal.")
        return 0

    shutdown_signal = get_or_create_shutdown_signal()
    request_shutdown(shutdown_signal)
    verify_shutdown(shutdown_signal)
    wait_for_acknowledgments(shutdown_signal, worker_node_ids, wait_time)
    return 0
|
194
|
+
|
195
|
+
|
196
|
+
if __name__ == "__main__":
    # Script entry point: parse --wait-time and use signal_workers()'s return value as the exit code.
    cli = argparse.ArgumentParser(description="Signal Ray workers to shutdown")
    cli.add_argument(
        "--wait-time",
        type=int,
        default=10,
        help="Time in seconds to wait for workers to receive the signal",
    )
    cli_args = cli.parse_args()

    exit_code = signal_workers(cli_args.wait_time)
    sys.exit(exit_code)
|
@@ -0,0 +1,242 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
# This file is part of the Ray-based distributed job system for Snowflake ML.
|
3
|
+
# Architecture overview:
|
4
|
+
# - Head node creates a ShutdownSignal actor and signals workers when job completes
|
5
|
+
# - Worker nodes listen for this signal via this script and gracefully shut down
|
6
|
+
# - This ensures clean termination of distributed Ray jobs
|
7
|
+
import logging
|
8
|
+
import signal
|
9
|
+
import sys
|
10
|
+
import time
|
11
|
+
from typing import Optional
|
12
|
+
|
13
|
+
import get_instance_ip
|
14
|
+
import ray
|
15
|
+
from constants import (
|
16
|
+
SHUTDOWN_ACTOR_NAME,
|
17
|
+
SHUTDOWN_ACTOR_NAMESPACE,
|
18
|
+
SHUTDOWN_RPC_TIMEOUT_SECONDS,
|
19
|
+
)
|
20
|
+
from ray.actor import ActorHandle
|
21
|
+
|
22
|
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
23
|
+
|
24
|
+
|
25
|
+
def get_shutdown_actor() -> Optional[ActorHandle]:
    """
    Look up the named shutdown signal actor in the Ray cluster.

    Returns:
        Handle to the shutdown signal actor, or None if the lookup raised for any reason
    """
    try:
        return ray.get_actor(SHUTDOWN_ACTOR_NAME, namespace=SHUTDOWN_ACTOR_NAMESPACE)
    except Exception:
        # Any lookup failure is reported to the caller as "no actor".
        return None
|
37
|
+
|
38
|
+
|
39
|
+
def ping_shutdown_actor(shutdown_signal: ActorHandle) -> bool:
    """
    Ping the shutdown actor to ensure connectivity.

    Args:
        shutdown_signal: The Ray actor handle for the shutdown signal

    Returns:
        True if the actor answered within the RPC timeout, False if the call
        timed out or raised any other exception
    """
    try:
        ping_result = ray.get(shutdown_signal.ping.remote(), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS)
        logging.debug(f"Actor ping result: {ping_result}")
        return True
    except Exception as e:
        # Exception already covers ray.exceptions.GetTimeoutError, so a single
        # handler deals with timeouts and every other failure alike.
        logging.debug(f"Actor ping failed: {e}")
        return False
|
56
|
+
|
57
|
+
|
58
|
+
def check_shutdown_status(shutdown_signal: ActorHandle, worker_id: str) -> bool:
    """
    Ask the shutdown actor whether this worker should stop, acknowledging if so.

    A failed acknowledgment is logged but does not prevent shutdown. Any other
    error while querying the actor is logged and treated as "keep running".

    Args:
        shutdown_signal: The Ray actor handle for the shutdown signal
        worker_id: Worker identifier (IP address)

    Returns:
        True if should shutdown, False otherwise
    """
    try:
        status = ray.get(shutdown_signal.should_shutdown.remote(), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS)
        logging.debug(f"Shutdown status: {status}")

        if not status.get("shutdown", False):
            return False

        logging.info(f"Received shutdown signal from head node at {status.get('timestamp')}. Exiting worker process.")

        # Let the head node know this worker saw the signal before it exits.
        try:
            pending_ack = shutdown_signal.acknowledge_shutdown.remote(worker_id)
            ack_result = ray.get(pending_ack, timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS)
            logging.info(f"Acknowledged shutdown: {ack_result}")
        except Exception as ack_error:
            logging.warning(f"Failed to acknowledge shutdown: {ack_error}. Continue to exit worker.")

        return True

    except Exception as status_error:
        logging.debug(f"Error checking shutdown status: {status_error}")
        return False
|
93
|
+
|
94
|
+
|
95
|
+
def check_ray_connectivity() -> bool:
    """
    Probe the Ray cluster by listing its nodes.

    Returns:
        True if the node listing succeeded and was non-empty, False otherwise
    """
    try:
        # Listing nodes is used as a lightweight liveness check.
        return bool(ray.nodes())
    except Exception as probe_error:
        logging.debug(f"Ray connectivity check failed: {probe_error}")
        return False
|
111
|
+
|
112
|
+
|
113
|
+
def initialize_ray_connection(max_retries: int, initial_retry_delay: int, max_retry_delay: int) -> bool:
    """
    Connect to the Ray cluster, retrying with a growing delay between attempts.

    Only ConnectionError, TimeoutError and RuntimeError trigger a retry; any
    other exception propagates to the caller.

    Args:
        max_retries: Maximum number of connection attempts
        initial_retry_delay: Initial delay between retries in seconds
        max_retry_delay: Maximum delay between retries in seconds

    Returns:
        bool: True if connection successful, False otherwise
    """
    failed_attempts = 0
    delay = initial_retry_delay

    while failed_attempts < max_retries:
        try:
            ray.init(address="auto", ignore_reinit_error=True)
        except (ConnectionError, TimeoutError, RuntimeError) as e:
            failed_attempts += 1
            if failed_attempts >= max_retries:
                logging.error(f"Failed to connect to Ray head after {max_retries} attempts: {e}")
                return False

            logging.debug(
                f"Attempt {failed_attempts}/{max_retries} to connect to Ray failed: {e}. Retrying in {delay} seconds..."
            )
            time.sleep(delay)
            # Grow the delay by 1.5x per failure, never exceeding the configured cap.
            delay = min(delay * 1.5, max_retry_delay)
        else:
            return True

    # Reached only when no attempt was made (max_retries <= 0).
    return False
|
147
|
+
|
148
|
+
|
149
|
+
def monitor_shutdown_signal(check_interval: int, max_consecutive_failures: int) -> int:
    """
    Poll the shutdown actor until it tells this worker to stop.

    Each iteration checks Ray connectivity, looks up the shutdown actor, pings
    it, and asks whether shutdown was requested. Every unsuccessful iteration
    is followed by a sleep of ``check_interval`` seconds.

    Args:
        check_interval: Time in seconds between checks
        max_consecutive_failures: Maximum allowed consecutive connection failures

    Returns:
        int: 0 once a shutdown signal has been received

    Raises:
        ConnectionError: If Ray connection failures exceed threshold
    """
    worker_id = get_instance_ip.get_self_ip()
    actor_check_count = 0
    consecutive_connection_failures = 0

    logging.debug(
        f"Starting to monitor for shutdown signal using actor {SHUTDOWN_ACTOR_NAME}"
        f" in namespace {SHUTDOWN_ACTOR_NAMESPACE}."
    )

    while True:
        actor_check_count += 1

        if check_ray_connectivity():
            # A successful check resets the failure streak.
            consecutive_connection_failures = 0

            shutdown_signal = get_shutdown_actor()
            if not shutdown_signal:
                logging.debug(f"Shutdown signal actor not found at check #{actor_check_count}, continuing to wait...")
            elif ping_shutdown_actor(shutdown_signal) and check_shutdown_status(shutdown_signal, worker_id):
                # Actor reachable and shutdown requested: stop monitoring.
                return 0
        else:
            consecutive_connection_failures += 1
            logging.debug(
                f"Ray connectivity check failed (attempt {consecutive_connection_failures}/{max_consecutive_failures})"
            )
            if consecutive_connection_failures >= max_consecutive_failures:
                raise ConnectionError("Exceeded max consecutive Ray connection failures")

        # Every path that did not return or raise waits before the next check.
        time.sleep(check_interval)
|
207
|
+
|
208
|
+
|
209
|
+
def run_listener() -> int:
    """Connect to Ray, then block until the head node signals shutdown.

    Returns:
        The exit code produced by monitor_shutdown_signal()

    Raises:
        ConnectionError: If the initial connection to the Ray cluster fails
    """
    # Up to 15 connection attempts, starting at a 2s delay and capped at 30s.
    connected = initialize_ray_connection(max_retries=15, initial_retry_delay=2, max_retry_delay=30)
    if not connected:
        raise ConnectionError("Failed to connect to Ray cluster. Aborting worker.")

    # Poll every 5s; 12 consecutive connectivity failures (about 1 minute) abort the worker.
    return monitor_shutdown_signal(check_interval=5, max_consecutive_failures=12)
|
224
|
+
|
225
|
+
|
226
|
+
def main():
    """Install SIGTERM/SIGINT handlers, run the listener, and exit with its result"""

    def signal_handler(signum, frame):
        # Exit with status 0 when the process is asked to stop.
        logging.info(f"Received signal {signum}, exiting worker process.")
        sys.exit(0)

    for handled_signal in (signal.SIGTERM, signal.SIGINT):
        signal.signal(handled_signal, signal_handler)

    # Blocks until a shutdown signal is received; its return value becomes the exit code.
    sys.exit(run_listener())
|
239
|
+
|
240
|
+
|
241
|
+
# Run the listener only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import logging
|
2
2
|
from math import ceil
|
3
3
|
from pathlib import PurePath
|
4
|
-
from typing import Any,
|
4
|
+
from typing import Any, Optional, Union
|
5
5
|
|
6
6
|
from snowflake import snowpark
|
7
7
|
from snowflake.ml._internal.utils import snowflake_env
|
@@ -15,10 +15,7 @@ def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.C
|
|
15
15
|
if not rows:
|
16
16
|
raise ValueError(f"Compute pool '{compute_pool}' not found")
|
17
17
|
instance_family: str = rows[0]["instance_family"]
|
18
|
-
|
19
|
-
# Get the cloud we're using (AWS, Azure, etc)
|
20
|
-
region = snowflake_env.get_regions(session)[snowflake_env.get_current_region_id(session)]
|
21
|
-
cloud = region["cloud"]
|
18
|
+
cloud = snowflake_env.get_current_cloud(session, default=snowflake_env.SnowflakeCloudType.AWS)
|
22
19
|
|
23
20
|
return (
|
24
21
|
constants.COMMON_INSTANCE_FAMILIES.get(instance_family)
|
@@ -26,22 +23,14 @@ def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.C
|
|
26
23
|
)
|
27
24
|
|
28
25
|
|
29
|
-
def _get_image_spec(session: snowpark.Session, compute_pool: str
|
26
|
+
def _get_image_spec(session: snowpark.Session, compute_pool: str) -> types.ImageSpec:
|
30
27
|
# Retrieve compute pool node resources
|
31
28
|
resources = _get_node_resources(session, compute_pool=compute_pool)
|
32
29
|
|
33
30
|
# Use MLRuntime image
|
34
31
|
image_repo = constants.DEFAULT_IMAGE_REPO
|
35
32
|
image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
|
36
|
-
|
37
|
-
# Try to pull latest image tag from server side if possible
|
38
|
-
if not image_tag:
|
39
|
-
query_result = session.sql("SHOW PARAMETERS LIKE 'constants.RUNTIME_BASE_IMAGE_TAG' IN ACCOUNT").collect()
|
40
|
-
if query_result:
|
41
|
-
image_tag = query_result[0]["value"]
|
42
|
-
|
43
|
-
if image_tag is None:
|
44
|
-
image_tag = constants.DEFAULT_IMAGE_TAG
|
33
|
+
image_tag = constants.DEFAULT_IMAGE_TAG
|
45
34
|
|
46
35
|
# TODO: Should each instance consume the entire pod?
|
47
36
|
return types.ImageSpec(
|
@@ -54,9 +43,9 @@ def _get_image_spec(session: snowpark.Session, compute_pool: str, image_tag: Opt
|
|
54
43
|
|
55
44
|
|
56
45
|
def generate_spec_overrides(
|
57
|
-
environment_vars: Optional[
|
58
|
-
custom_overrides: Optional[
|
59
|
-
) ->
|
46
|
+
environment_vars: Optional[dict[str, str]] = None,
|
47
|
+
custom_overrides: Optional[dict[str, Any]] = None,
|
48
|
+
) -> dict[str, Any]:
|
60
49
|
"""
|
61
50
|
Generate a dictionary of service specification overrides.
|
62
51
|
|
@@ -68,7 +57,7 @@ def generate_spec_overrides(
|
|
68
57
|
Resulting service specifiation patch dict. Empty if no overrides were supplied.
|
69
58
|
"""
|
70
59
|
# Generate container level overrides
|
71
|
-
container_spec:
|
60
|
+
container_spec: dict[str, Any] = {
|
72
61
|
"name": constants.DEFAULT_CONTAINER_NAME,
|
73
62
|
}
|
74
63
|
if environment_vars:
|
@@ -95,10 +84,10 @@ def generate_service_spec(
|
|
95
84
|
session: snowpark.Session,
|
96
85
|
compute_pool: str,
|
97
86
|
payload: types.UploadedPayload,
|
98
|
-
args: Optional[
|
87
|
+
args: Optional[list[str]] = None,
|
99
88
|
num_instances: Optional[int] = None,
|
100
89
|
enable_metrics: bool = False,
|
101
|
-
) ->
|
90
|
+
) -> dict[str, Any]:
|
102
91
|
"""
|
103
92
|
Generate a service specification for a job.
|
104
93
|
|
@@ -114,20 +103,14 @@ def generate_service_spec(
|
|
114
103
|
Job service specification
|
115
104
|
"""
|
116
105
|
is_multi_node = num_instances is not None and num_instances > 1
|
106
|
+
image_spec = _get_image_spec(session, compute_pool)
|
117
107
|
|
118
108
|
# Set resource requests/limits, including nvidia.com/gpu quantity if applicable
|
119
|
-
|
120
|
-
# If the job is of multi-node, we will need a different image which contains
|
121
|
-
# module snowflake.runtime.utils.get_instance_ip
|
122
|
-
# TODO(SNOW-1961849): Remove the hard-coded image name
|
123
|
-
image_spec = _get_image_spec(session, compute_pool, constants.MULTINODE_HEADLESS_IMAGE_TAG)
|
124
|
-
else:
|
125
|
-
image_spec = _get_image_spec(session, compute_pool)
|
126
|
-
resource_requests: Dict[str, Union[str, int]] = {
|
109
|
+
resource_requests: dict[str, Union[str, int]] = {
|
127
110
|
"cpu": f"{int(image_spec.resource_requests.cpu * 1000)}m",
|
128
111
|
"memory": f"{image_spec.resource_limits.memory}Gi",
|
129
112
|
}
|
130
|
-
resource_limits:
|
113
|
+
resource_limits: dict[str, Union[str, int]] = {
|
131
114
|
"cpu": f"{int(image_spec.resource_requests.cpu * 1000)}m",
|
132
115
|
"memory": f"{image_spec.resource_limits.memory}Gi",
|
133
116
|
}
|
@@ -136,8 +119,8 @@ def generate_service_spec(
|
|
136
119
|
resource_limits["nvidia.com/gpu"] = image_spec.resource_limits.gpu
|
137
120
|
|
138
121
|
# Add local volumes for ephemeral logs and artifacts
|
139
|
-
volumes:
|
140
|
-
volume_mounts:
|
122
|
+
volumes: list[dict[str, str]] = []
|
123
|
+
volume_mounts: list[dict[str, str]] = []
|
141
124
|
for volume_name, mount_path in [
|
142
125
|
("system-logs", "/var/log/managedservices/system/mlrs"),
|
143
126
|
("user-logs", "/var/log/managedservices/user/mlrs"),
|
@@ -191,7 +174,10 @@ def generate_service_spec(
|
|
191
174
|
|
192
175
|
# TODO: Add hooks for endpoints for integration with TensorBoard etc
|
193
176
|
|
194
|
-
env_vars = {
|
177
|
+
env_vars = {
|
178
|
+
constants.PAYLOAD_DIR_ENV_VAR: stage_mount.as_posix(),
|
179
|
+
constants.RESULT_PATH_ENV_VAR: constants.RESULT_PATH_DEFAULT_VALUE,
|
180
|
+
}
|
195
181
|
endpoints = []
|
196
182
|
|
197
183
|
if is_multi_node:
|
@@ -305,11 +291,11 @@ def merge_patch(base: Any, patch: Any, display_name: str = "") -> Any:
|
|
305
291
|
|
306
292
|
|
307
293
|
def _merge_lists_of_dicts(
|
308
|
-
base:
|
309
|
-
patch:
|
294
|
+
base: list[dict[str, Any]],
|
295
|
+
patch: list[dict[str, Any]],
|
310
296
|
merge_key: str = "name",
|
311
297
|
display_name: str = "",
|
312
|
-
) ->
|
298
|
+
) -> list[dict[str, Any]]:
|
313
299
|
"""
|
314
300
|
Attempts to merge lists of dicts by matching on a merge key (default "name").
|
315
301
|
- If the merge key is missing, the behavior falls back to overwriting the list.
|
@@ -1,6 +1,6 @@
|
|
1
1
|
from dataclasses import dataclass
|
2
2
|
from pathlib import PurePath
|
3
|
-
from typing import
|
3
|
+
from typing import Literal, Optional, Union
|
4
4
|
|
5
5
|
JOB_STATUS = Literal[
|
6
6
|
"PENDING",
|
@@ -11,11 +11,17 @@ JOB_STATUS = Literal[
|
|
11
11
|
]
|
12
12
|
|
13
13
|
|
14
|
+
@dataclass(frozen=True)
|
15
|
+
class PayloadEntrypoint:
|
16
|
+
file_path: PurePath
|
17
|
+
main_func: Optional[str]
|
18
|
+
|
19
|
+
|
14
20
|
@dataclass(frozen=True)
|
15
21
|
class UploadedPayload:
|
16
22
|
# TODO: Include manifest of payload files for validation
|
17
23
|
stage_path: PurePath
|
18
|
-
entrypoint:
|
24
|
+
entrypoint: list[Union[str, PurePath]]
|
19
25
|
|
20
26
|
|
21
27
|
@dataclass(frozen=True)
|
snowflake/ml/jobs/decorators.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
import copy
|
2
2
|
import functools
|
3
|
-
from typing import Callable,
|
3
|
+
from typing import Callable, Optional, TypeVar
|
4
4
|
|
5
5
|
from typing_extensions import ParamSpec
|
6
6
|
|
@@ -15,20 +15,19 @@ _Args = ParamSpec("_Args")
|
|
15
15
|
_ReturnValue = TypeVar("_ReturnValue")
|
16
16
|
|
17
17
|
|
18
|
-
@snowpark._internal.utils.private_preview(version="1.7.4")
|
19
18
|
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
20
19
|
def remote(
|
21
20
|
compute_pool: str,
|
22
21
|
*,
|
23
22
|
stage_name: str,
|
24
|
-
pip_requirements: Optional[
|
25
|
-
external_access_integrations: Optional[
|
23
|
+
pip_requirements: Optional[list[str]] = None,
|
24
|
+
external_access_integrations: Optional[list[str]] = None,
|
26
25
|
query_warehouse: Optional[str] = None,
|
27
|
-
env_vars: Optional[
|
26
|
+
env_vars: Optional[dict[str, str]] = None,
|
28
27
|
num_instances: Optional[int] = None,
|
29
28
|
enable_metrics: bool = False,
|
30
29
|
session: Optional[snowpark.Session] = None,
|
31
|
-
) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob]]:
|
30
|
+
) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob[_ReturnValue]]]:
|
32
31
|
"""
|
33
32
|
Submit a job to the compute pool.
|
34
33
|
|
@@ -47,7 +46,7 @@ def remote(
|
|
47
46
|
Decorator that dispatches invocations of the decorated function as remote jobs.
|
48
47
|
"""
|
49
48
|
|
50
|
-
def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, jb.MLJob]:
|
49
|
+
def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, jb.MLJob[_ReturnValue]]:
|
51
50
|
# Copy the function to avoid modifying the original
|
52
51
|
# We need to modify the line number of the function to exclude the
|
53
52
|
# decorator from the copied source code
|
@@ -55,7 +54,7 @@ def remote(
|
|
55
54
|
wrapped_func.__code__ = wrapped_func.__code__.replace(co_firstlineno=func.__code__.co_firstlineno + 1)
|
56
55
|
|
57
56
|
@functools.wraps(func)
|
58
|
-
def wrapper(*args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob:
|
57
|
+
def wrapper(*args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:
|
59
58
|
payload = functools.partial(func, *args, **kwargs)
|
60
59
|
setattr(payload, constants.IS_MLJOB_REMOTE_ATTR, True)
|
61
60
|
job = jm._submit_job(
|