snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. snowflake/cortex/_classify_text.py +3 -3
  2. snowflake/cortex/_complete.py +64 -31
  3. snowflake/cortex/_embed_text_1024.py +4 -4
  4. snowflake/cortex/_embed_text_768.py +4 -4
  5. snowflake/cortex/_finetune.py +8 -8
  6. snowflake/cortex/_util.py +8 -12
  7. snowflake/ml/_internal/env.py +4 -3
  8. snowflake/ml/_internal/env_utils.py +63 -34
  9. snowflake/ml/_internal/file_utils.py +10 -21
  10. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  11. snowflake/ml/_internal/init_utils.py +2 -3
  12. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  13. snowflake/ml/_internal/platform_capabilities.py +41 -5
  14. snowflake/ml/_internal/telemetry.py +39 -52
  15. snowflake/ml/_internal/type_utils.py +3 -3
  16. snowflake/ml/_internal/utils/db_utils.py +2 -2
  17. snowflake/ml/_internal/utils/identifier.py +8 -8
  18. snowflake/ml/_internal/utils/import_utils.py +2 -2
  19. snowflake/ml/_internal/utils/parallelize.py +7 -7
  20. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  21. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  22. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  23. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  24. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  25. snowflake/ml/_internal/utils/table_manager.py +9 -9
  26. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  27. snowflake/ml/data/data_connector.py +40 -36
  28. snowflake/ml/data/data_ingestor.py +4 -15
  29. snowflake/ml/data/data_source.py +2 -2
  30. snowflake/ml/data/ingestor_utils.py +3 -3
  31. snowflake/ml/data/torch_utils.py +5 -5
  32. snowflake/ml/dataset/dataset.py +11 -11
  33. snowflake/ml/dataset/dataset_metadata.py +8 -8
  34. snowflake/ml/dataset/dataset_reader.py +12 -8
  35. snowflake/ml/feature_store/__init__.py +1 -1
  36. snowflake/ml/feature_store/access_manager.py +7 -7
  37. snowflake/ml/feature_store/entity.py +6 -6
  38. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  39. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  41. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  45. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  48. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  51. snowflake/ml/feature_store/feature_store.py +52 -64
  52. snowflake/ml/feature_store/feature_view.py +24 -24
  53. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  54. snowflake/ml/fileset/fileset.py +5 -5
  55. snowflake/ml/fileset/sfcfs.py +13 -13
  56. snowflake/ml/fileset/stage_fs.py +15 -15
  57. snowflake/ml/jobs/_utils/constants.py +2 -4
  58. snowflake/ml/jobs/_utils/interop_utils.py +442 -0
  59. snowflake/ml/jobs/_utils/payload_utils.py +86 -62
  60. snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
  61. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
  62. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
  63. snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
  64. snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
  65. snowflake/ml/jobs/_utils/spec_utils.py +22 -36
  66. snowflake/ml/jobs/_utils/types.py +8 -2
  67. snowflake/ml/jobs/decorators.py +7 -8
  68. snowflake/ml/jobs/job.py +158 -26
  69. snowflake/ml/jobs/manager.py +78 -30
  70. snowflake/ml/lineage/lineage_node.py +5 -5
  71. snowflake/ml/model/_client/model/model_impl.py +3 -3
  72. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  73. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  74. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  75. snowflake/ml/model/_client/ops/service_ops.py +230 -50
  76. snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
  77. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
  78. snowflake/ml/model/_client/sql/model.py +8 -8
  79. snowflake/ml/model/_client/sql/model_version.py +26 -26
  80. snowflake/ml/model/_client/sql/service.py +22 -18
  81. snowflake/ml/model/_client/sql/stage.py +2 -2
  82. snowflake/ml/model/_client/sql/tag.py +6 -6
  83. snowflake/ml/model/_model_composer/model_composer.py +46 -25
  84. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  85. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  86. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  87. snowflake/ml/model/_packager/model_env/model_env.py +35 -26
  88. snowflake/ml/model/_packager/model_handler.py +4 -4
  89. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  90. snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
  91. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  92. snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
  93. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  94. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  95. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  96. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  97. snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
  98. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  99. snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
  100. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  101. snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
  102. snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
  103. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
  104. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  105. snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
  106. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
  107. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  108. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  109. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  110. snowflake/ml/model/_packager/model_packager.py +12 -8
  111. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  112. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  113. snowflake/ml/model/_signatures/core.py +16 -24
  114. snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
  115. snowflake/ml/model/_signatures/utils.py +6 -6
  116. snowflake/ml/model/custom_model.py +8 -8
  117. snowflake/ml/model/model_signature.py +9 -20
  118. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  119. snowflake/ml/model/type_hints.py +5 -3
  120. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  121. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  122. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  123. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  124. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  125. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  126. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  127. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  128. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  129. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  130. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  131. snowflake/ml/modeling/framework/_utils.py +10 -10
  132. snowflake/ml/modeling/framework/base.py +32 -32
  133. snowflake/ml/modeling/impute/__init__.py +1 -1
  134. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  135. snowflake/ml/modeling/metrics/__init__.py +1 -1
  136. snowflake/ml/modeling/metrics/classification.py +39 -39
  137. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  138. snowflake/ml/modeling/metrics/ranking.py +7 -7
  139. snowflake/ml/modeling/metrics/regression.py +13 -13
  140. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  141. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  142. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  143. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  144. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  145. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  146. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  147. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  148. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  149. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  150. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  151. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  152. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  153. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  154. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  155. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  156. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  157. snowflake/ml/registry/_manager/model_manager.py +50 -29
  158. snowflake/ml/registry/registry.py +34 -23
  159. snowflake/ml/utils/authentication.py +2 -2
  160. snowflake/ml/utils/connection_params.py +5 -5
  161. snowflake/ml/utils/sparse.py +5 -4
  162. snowflake/ml/utils/sql_client.py +1 -2
  163. snowflake/ml/version.py +2 -1
  164. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
  165. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
  166. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
  167. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  168. snowflake/ml/modeling/_internal/constants.py +0 -2
  169. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  170. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0

snowflake/ml/jobs/_utils/scripts/signal_workers.py
@@ -0,0 +1,203 @@
+ #!/usr/bin/env python3
+ # This file is part of the Ray-based distributed job system for Snowflake ML.
+ # Architecture overview:
+ # - Head node creates a ShutdownSignal actor and signals workers when job completes
+ # - Worker nodes listen for this signal and gracefully shut down
+ # - This ensures clean termination of distributed Ray jobs
+ import argparse
+ import logging
+ import socket
+ import sys
+ import time
+ from typing import Any
+
+ import ray
+ from constants import (
+     SHUTDOWN_ACTOR_NAME,
+     SHUTDOWN_ACTOR_NAMESPACE,
+     SHUTDOWN_RPC_TIMEOUT_SECONDS,
+ )
+ from ray.actor import ActorHandle
+
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+
+
+ @ray.remote
+ class ShutdownSignal:
+     """A simple Ray actor that workers can check to determine if they should shutdown"""
+
+     def __init__(self) -> None:
+         self.shutdown_requested = False
+         self.timestamp = None
+         self.hostname = socket.gethostname()
+         self.acknowledged_workers = set()
+         logging.info(f"ShutdownSignal actor created on {self.hostname}")
+
+     def request_shutdown(self) -> dict[str, Any]:
+         """Signal workers to shut down"""
+         self.shutdown_requested = True
+         self.timestamp = time.time()
+         logging.info(f"Shutdown requested by head node at {self.timestamp}")
+         return {"status": "shutdown_requested", "timestamp": self.timestamp, "host": self.hostname}
+
+     def should_shutdown(self) -> dict[str, Any]:
+         """Check if shutdown has been requested"""
+         return {"shutdown": self.shutdown_requested, "timestamp": self.timestamp, "host": self.hostname}
+
+     def ping(self) -> dict[str, Any]:
+         """Simple method to test connectivity"""
+         return {"status": "alive", "host": self.hostname}
+
+     def acknowledge_shutdown(self, worker_id: str) -> dict[str, Any]:
+         """Worker acknowledges it has received the shutdown signal and is terminating"""
+         self.acknowledged_workers.add(worker_id)
+         logging.info(f"Worker {worker_id} acknowledged shutdown. Total acknowledged: {len(self.acknowledged_workers)}")
+
+         return {"status": "acknowledged", "worker_id": worker_id, "acknowledged_count": len(self.acknowledged_workers)}
+
+     def get_acknowledgment_workers(self) -> set[str]:
+         """Get the set of workers who have acknowledged shutdown"""
+         return self.acknowledged_workers
+
+
+ def get_worker_node_ids() -> list[str]:
+     """Get the IDs of all active worker nodes.
+
+     Returns:
+         List[str]: List of worker node IDs. Empty list if no workers are present.
+     """
+     worker_nodes = [
+         node for node in ray.nodes() if node.get("Alive") and node.get("Resources", {}).get("node_tag:worker", 0) > 0
+     ]
+
+     worker_node_ids = [node.get("NodeName") for node in worker_nodes]
+
+     if worker_node_ids:
+         logging.info(f"Found {len(worker_node_ids)} worker nodes")
+     else:
+         logging.info("No active worker nodes found")
+
+     return worker_node_ids
+
+
+ def get_or_create_shutdown_signal() -> ActorHandle:
+     """Get existing shutdown signal actor or create a new one.
+
+     Returns:
+         ActorHandle: Reference to shutdown signal actor
+     """
+     try:
+         # Try to get existing actor
+         shutdown_signal = ray.get_actor(SHUTDOWN_ACTOR_NAME, namespace=SHUTDOWN_ACTOR_NAMESPACE)
+         logging.info("Found existing shutdown signal actor")
+     except (ValueError, ray.exceptions.RayActorError) as e:
+         logging.info(f"Creating new shutdown signal actor: {e}")
+         # Create new actor if it doesn't exist
+         shutdown_signal = ShutdownSignal.options(
+             name=SHUTDOWN_ACTOR_NAME,
+             namespace=SHUTDOWN_ACTOR_NAMESPACE,
+             lifetime="detached",  # Ensure actor survives client disconnect
+             resources={"node_tag:head": 0.001},  # Resource constraint to ensure it runs on head node
+         ).remote()
+
+         # Verify actor is created and accessible
+         ping_result = ray.get(shutdown_signal.ping.remote(), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS)
+         logging.debug(f"New actor ping response: {ping_result}")
+
+     return shutdown_signal
+
+
+ def request_shutdown(shutdown_signal: ActorHandle) -> None:
+     """Request workers to shut down.
+
+     Args:
+         shutdown_signal: Reference to the shutdown signal actor
+     """
+     response = ray.get(shutdown_signal.request_shutdown.remote(), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS)
+     logging.info(f"Shutdown requested: {response}")
+
+
+ def verify_shutdown(shutdown_signal: ActorHandle) -> None:
+     """Verify that shutdown was properly signaled.
+
+     Args:
+         shutdown_signal: Reference to the shutdown signal actor
+     """
+     check = ray.get(shutdown_signal.should_shutdown.remote(), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS)
+     logging.debug(f"Shutdown status check: {check}")
+
+
+ def wait_for_acknowledgments(shutdown_signal: ActorHandle, worker_node_ids: list[str], wait_time: int) -> None:
+     """Wait for workers to acknowledge shutdown.
+
+     Args:
+         shutdown_signal: Reference to the shutdown signal actor
+         worker_node_ids: List of worker node IDs
+         wait_time: Time in seconds to wait for acknowledgments
+
+     Raises:
+         TimeoutError: When workers don't acknowledge within the wait time or if actor communication times out
+     """
+     if not worker_node_ids:
+         return
+
+     logging.info(f"Waiting up to {wait_time}s for workers to acknowledge shutdown signal...")
+     start_time = time.time()
+     check_interval = 1.0
+
+     while time.time() - start_time < wait_time:
+         try:
+             ack_workers = ray.get(
+                 shutdown_signal.get_acknowledgment_workers.remote(), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS
+             )
+             if ack_workers and ack_workers == set(worker_node_ids):
+                 logging.info(
+                     f"All {len(worker_node_ids)} workers acknowledged shutdown. "
+                     f"Completed in {time.time() - start_time:.2f}s"
+                 )
+                 return
+             else:
+                 logging.debug(f"Waiting for acknowledgments: {len(ack_workers)}/{len(worker_node_ids)} workers")
+         except Exception as e:
+             logging.warning(f"Error checking acknowledgment status: {e}")
+
+         time.sleep(check_interval)
+
+     raise TimeoutError(
+         f"Timed out waiting for {len(worker_node_ids)} workers to acknowledge shutdown after {wait_time}s"
+     )
+
+
+ def signal_workers(wait_time: int = 10) -> int:
+     """
+     Signal worker nodes to shut down by creating a shutdown signal actor.
+
+     Args:
+         wait_time: Time in seconds to wait for workers to receive the message
+
+     Returns:
+         0 for success, 1 for failure
+     """
+     ray.init(address="auto", ignore_reinit_error=True)
+
+     worker_node_ids = get_worker_node_ids()
+
+     if worker_node_ids:
+         shutdown_signal = get_or_create_shutdown_signal()
+         request_shutdown(shutdown_signal)
+         verify_shutdown(shutdown_signal)
+         wait_for_acknowledgments(shutdown_signal, worker_node_ids, wait_time)
+     else:
+         logging.info("No active worker nodes found to signal.")
+
+     return 0
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Signal Ray workers to shutdown")
+     parser.add_argument(
+         "--wait-time", type=int, default=10, help="Time in seconds to wait for workers to receive the signal"
+     )
+     args = parser.parse_args()
+
+     sys.exit(signal_workers(args.wait_time))
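
Operationally, the head node would run this script once the user workload finishes (the invocation context is an assumption; only the flag comes from the argparse setup above):

    python signal_workers.py --wait-time 10

It connects to the local cluster via ray.init(address="auto"), flips the shutdown flag on the detached ShutdownSignal actor, and blocks until every worker acknowledges or wait_time elapses, raising TimeoutError otherwise.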

snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py
@@ -0,0 +1,242 @@
+ #!/usr/bin/env python3
+ # This file is part of the Ray-based distributed job system for Snowflake ML.
+ # Architecture overview:
+ # - Head node creates a ShutdownSignal actor and signals workers when job completes
+ # - Worker nodes listen for this signal via this script and gracefully shut down
+ # - This ensures clean termination of distributed Ray jobs
+ import logging
+ import signal
+ import sys
+ import time
+ from typing import Optional
+
+ import get_instance_ip
+ import ray
+ from constants import (
+     SHUTDOWN_ACTOR_NAME,
+     SHUTDOWN_ACTOR_NAMESPACE,
+     SHUTDOWN_RPC_TIMEOUT_SECONDS,
+ )
+ from ray.actor import ActorHandle
+
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+
+
+ def get_shutdown_actor() -> Optional[ActorHandle]:
+     """
+     Retrieve the shutdown signal actor from Ray.
+
+     Returns:
+         The shutdown signal actor or None if not found
+     """
+     try:
+         shutdown_signal = ray.get_actor(SHUTDOWN_ACTOR_NAME, namespace=SHUTDOWN_ACTOR_NAMESPACE)
+         return shutdown_signal
+     except Exception:
+         return None
+
+
+ def ping_shutdown_actor(shutdown_signal: ActorHandle) -> bool:
+     """
+     Ping the shutdown actor to ensure connectivity.
+
+     Args:
+         shutdown_signal: The Ray actor handle for the shutdown signal
+
+     Returns:
+         True if ping succeeds, False otherwise
+     """
+     try:
+         ping_result = ray.get(shutdown_signal.ping.remote(), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS)
+         logging.debug(f"Actor ping result: {ping_result}")
+         return True
+     except (ray.exceptions.GetTimeoutError, Exception) as e:
+         logging.debug(f"Actor ping failed: {e}")
+         return False
+
+
+ def check_shutdown_status(shutdown_signal: ActorHandle, worker_id: str) -> bool:
+     """
+     Check if worker should shutdown and acknowledge if needed.
+
+     Args:
+         shutdown_signal: The Ray actor handle for the shutdown signal
+         worker_id: Worker identifier (IP address)
+
+     Returns:
+         True if should shutdown, False otherwise
+     """
+     try:
+         status = ray.get(shutdown_signal.should_shutdown.remote(), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS)
+         logging.debug(f"Shutdown status: {status}")
+
+         if status.get("shutdown", False):
+             logging.info(
+                 f"Received shutdown signal from head node at {status.get('timestamp')}. " f"Exiting worker process."
+             )
+
+             # Acknowledge shutdown before exiting
+             try:
+                 ack_result = ray.get(
+                     shutdown_signal.acknowledge_shutdown.remote(worker_id), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS
+                 )
+                 logging.info(f"Acknowledged shutdown: {ack_result}")
+             except Exception as e:
+                 logging.warning(f"Failed to acknowledge shutdown: {e}. Continue to exit worker.")
+
+             return True
+         return False
+
+     except Exception as e:
+         logging.debug(f"Error checking shutdown status: {e}")
+         return False
+
+
+ def check_ray_connectivity() -> bool:
+     """
+     Check if the Ray cluster is accessible.
+
+     Returns:
+         True if Ray is connected, False otherwise
+     """
+     try:
+         # A simple check to verify Ray is working
+         nodes = ray.nodes()
+         if nodes:
+             return True
+         return False
+     except Exception as e:
+         logging.debug(f"Ray connectivity check failed: {e}")
+         return False
+
+
+ def initialize_ray_connection(max_retries: int, initial_retry_delay: int, max_retry_delay: int) -> bool:
+     """
+     Initialize connection to Ray with retries.
+
+     Args:
+         max_retries: Maximum number of connection attempts
+         initial_retry_delay: Initial delay between retries in seconds
+         max_retry_delay: Maximum delay between retries in seconds
+
+     Returns:
+         bool: True if connection successful, False otherwise
+     """
+     retry_count = 0
+     retry_delay = initial_retry_delay
+
+     while retry_count < max_retries:
+         try:
+             ray.init(address="auto", ignore_reinit_error=True)
+             return True
+         except (ConnectionError, TimeoutError, RuntimeError) as e:
+             retry_count += 1
+             if retry_count >= max_retries:
+                 logging.error(f"Failed to connect to Ray head after {max_retries} attempts: {e}")
+                 return False
+
+             logging.debug(
+                 f"Attempt {retry_count}/{max_retries} to connect to Ray failed: {e}. "
+                 f"Retrying in {retry_delay} seconds..."
+             )
+             time.sleep(retry_delay)
+             # Exponential backoff with cap
+             retry_delay = min(retry_delay * 1.5, max_retry_delay)
+
+     return False  # Should not reach here, but added for completeness
+
+
+ def monitor_shutdown_signal(check_interval: int, max_consecutive_failures: int) -> int:
+     """
+     Main loop to monitor for shutdown signals.
+
+     Args:
+         check_interval: Time in seconds between checks
+         max_consecutive_failures: Maximum allowed consecutive connection failures
+
+     Returns:
+         int: Exit code (0 for success, non-zero for failure)
+
+     Raises:
+         ConnectionError: If Ray connection failures exceed threshold
+     """
+     worker_id = get_instance_ip.get_self_ip()
+     actor_check_count = 0
+     consecutive_connection_failures = 0
+
+     logging.debug(
+         f"Starting to monitor for shutdown signal using actor {SHUTDOWN_ACTOR_NAME}"
+         f" in namespace {SHUTDOWN_ACTOR_NAMESPACE}."
+     )
+
+     while True:
+         actor_check_count += 1
+
+         # Check Ray connectivity before proceeding
+         if not check_ray_connectivity():
+             consecutive_connection_failures += 1
+             logging.debug(
+                 f"Ray connectivity check failed (attempt {consecutive_connection_failures}/{max_consecutive_failures})"
+             )
+             if consecutive_connection_failures >= max_consecutive_failures:
+                 raise ConnectionError("Exceeded max consecutive Ray connection failures")
+             time.sleep(check_interval)
+             continue
+
+         # Reset counter on successful connection
+         consecutive_connection_failures = 0
+
+         # Get shutdown actor
+         shutdown_signal = get_shutdown_actor()
+         if not shutdown_signal:
+             logging.debug(f"Shutdown signal actor not found at check #{actor_check_count}, continuing to wait...")
+             time.sleep(check_interval)
+             continue
+
+         # Ping the actor to ensure connectivity
+         if not ping_shutdown_actor(shutdown_signal):
+             time.sleep(check_interval)
+             continue
+
+         # Check shutdown status
+         if check_shutdown_status(shutdown_signal, worker_id):
+             return 0
+
+         # Wait before checking again
+         time.sleep(check_interval)
+
+
+ def run_listener() -> int:
+     """Listen for shutdown signals from the head node"""
+     # Configuration
+     max_retries = 15
+     initial_retry_delay = 2
+     max_retry_delay = 30
+     check_interval = 5  # How often to check for ray connection or shutdown signal
+     max_consecutive_failures = 12  # Exit after about 1 minute of connection failures
+
+     # Initialize Ray connection
+     if not initialize_ray_connection(max_retries, initial_retry_delay, max_retry_delay):
+         raise ConnectionError("Failed to connect to Ray cluster. Aborting worker.")
+
+     # Monitor for shutdown signals
+     return monitor_shutdown_signal(check_interval, max_consecutive_failures)
+
+
+ def main():
+     """Main entry point with signal handling"""
+
+     def signal_handler(signum, frame):
+         logging.info(f"Received signal {signum}, exiting worker process.")
+         sys.exit(0)
+
+     signal.signal(signal.SIGTERM, signal_handler)
+     signal.signal(signal.SIGINT, signal_handler)
+
+     # Run the listener - this will block until a shutdown signal is received
+     result = run_listener()
+     sys.exit(result)
+
+
+ if __name__ == "__main__":
+     main()
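
Together, signal_workers.py and this listener form a poll/acknowledge handshake over the ShutdownSignal actor. A condensed sketch of one worker-side iteration, assuming an already-initialized Ray connection (the worker IP literal is an illustrative stand-in for get_instance_ip.get_self_ip()):

    import ray
    from constants import SHUTDOWN_ACTOR_NAME, SHUTDOWN_ACTOR_NAMESPACE, SHUTDOWN_RPC_TIMEOUT_SECONDS

    # Resolve the detached actor by name, then poll its shutdown flag.
    actor = ray.get_actor(SHUTDOWN_ACTOR_NAME, namespace=SHUTDOWN_ACTOR_NAMESPACE)
    status = ray.get(actor.should_shutdown.remote(), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS)
    if status.get("shutdown", False):
        # Confirm receipt so the head node's wait_for_acknowledgments() can return.
        ray.get(actor.acknowledge_shutdown.remote("10.0.0.7"), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS)
        raise SystemExit(0)  # "10.0.0.7" is illustrative; the real script passes its own IP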

snowflake/ml/jobs/_utils/spec_utils.py
@@ -1,7 +1,7 @@
  import logging
  from math import ceil
  from pathlib import PurePath
- from typing import Any, Dict, List, Optional, Union
+ from typing import Any, Optional, Union

  from snowflake import snowpark
  from snowflake.ml._internal.utils import snowflake_env
@@ -15,10 +15,7 @@ def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.C
      if not rows:
          raise ValueError(f"Compute pool '{compute_pool}' not found")
      instance_family: str = rows[0]["instance_family"]
-
-     # Get the cloud we're using (AWS, Azure, etc)
-     region = snowflake_env.get_regions(session)[snowflake_env.get_current_region_id(session)]
-     cloud = region["cloud"]
+     cloud = snowflake_env.get_current_cloud(session, default=snowflake_env.SnowflakeCloudType.AWS)

      return (
          constants.COMMON_INSTANCE_FAMILIES.get(instance_family)
@@ -26,22 +23,14 @@ def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.C
      )


- def _get_image_spec(session: snowpark.Session, compute_pool: str, image_tag: Optional[str] = None) -> types.ImageSpec:
+ def _get_image_spec(session: snowpark.Session, compute_pool: str) -> types.ImageSpec:
      # Retrieve compute pool node resources
      resources = _get_node_resources(session, compute_pool=compute_pool)

      # Use MLRuntime image
      image_repo = constants.DEFAULT_IMAGE_REPO
      image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
-
-     # Try to pull latest image tag from server side if possible
-     if not image_tag:
-         query_result = session.sql("SHOW PARAMETERS LIKE 'constants.RUNTIME_BASE_IMAGE_TAG' IN ACCOUNT").collect()
-         if query_result:
-             image_tag = query_result[0]["value"]
-
-     if image_tag is None:
-         image_tag = constants.DEFAULT_IMAGE_TAG
+     image_tag = constants.DEFAULT_IMAGE_TAG

      # TODO: Should each instance consume the entire pod?
      return types.ImageSpec(
@@ -54,9 +43,9 @@ def _get_image_spec(session: snowpark.Session, compute_pool: str, image_tag: Opt


  def generate_spec_overrides(
-     environment_vars: Optional[Dict[str, str]] = None,
-     custom_overrides: Optional[Dict[str, Any]] = None,
- ) -> Dict[str, Any]:
+     environment_vars: Optional[dict[str, str]] = None,
+     custom_overrides: Optional[dict[str, Any]] = None,
+ ) -> dict[str, Any]:
      """
      Generate a dictionary of service specification overrides.

@@ -68,7 +57,7 @@ def generate_spec_overrides(
          Resulting service specifiation patch dict. Empty if no overrides were supplied.
      """
      # Generate container level overrides
-     container_spec: Dict[str, Any] = {
+     container_spec: dict[str, Any] = {
          "name": constants.DEFAULT_CONTAINER_NAME,
      }
      if environment_vars:
@@ -95,10 +84,10 @@ def generate_service_spec(
      session: snowpark.Session,
      compute_pool: str,
      payload: types.UploadedPayload,
-     args: Optional[List[str]] = None,
+     args: Optional[list[str]] = None,
      num_instances: Optional[int] = None,
      enable_metrics: bool = False,
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
      """
      Generate a service specification for a job.

@@ -114,20 +103,14 @@ def generate_service_spec(
          Job service specification
      """
      is_multi_node = num_instances is not None and num_instances > 1
+     image_spec = _get_image_spec(session, compute_pool)

      # Set resource requests/limits, including nvidia.com/gpu quantity if applicable
-     if is_multi_node:
-         # If the job is of multi-node, we will need a different image which contains
-         # module snowflake.runtime.utils.get_instance_ip
-         # TODO(SNOW-1961849): Remove the hard-coded image name
-         image_spec = _get_image_spec(session, compute_pool, constants.MULTINODE_HEADLESS_IMAGE_TAG)
-     else:
-         image_spec = _get_image_spec(session, compute_pool)
-     resource_requests: Dict[str, Union[str, int]] = {
+     resource_requests: dict[str, Union[str, int]] = {
          "cpu": f"{int(image_spec.resource_requests.cpu * 1000)}m",
          "memory": f"{image_spec.resource_limits.memory}Gi",
      }
-     resource_limits: Dict[str, Union[str, int]] = {
+     resource_limits: dict[str, Union[str, int]] = {
          "cpu": f"{int(image_spec.resource_requests.cpu * 1000)}m",
          "memory": f"{image_spec.resource_limits.memory}Gi",
      }
@@ -136,8 +119,8 @@ def generate_service_spec(
          resource_limits["nvidia.com/gpu"] = image_spec.resource_limits.gpu

      # Add local volumes for ephemeral logs and artifacts
-     volumes: List[Dict[str, str]] = []
-     volume_mounts: List[Dict[str, str]] = []
+     volumes: list[dict[str, str]] = []
+     volume_mounts: list[dict[str, str]] = []
      for volume_name, mount_path in [
          ("system-logs", "/var/log/managedservices/system/mlrs"),
          ("user-logs", "/var/log/managedservices/user/mlrs"),
@@ -191,7 +174,10 @@ def generate_service_spec(

      # TODO: Add hooks for endpoints for integration with TensorBoard etc

-     env_vars = {constants.PAYLOAD_DIR_ENV_VAR: stage_mount.as_posix()}
+     env_vars = {
+         constants.PAYLOAD_DIR_ENV_VAR: stage_mount.as_posix(),
+         constants.RESULT_PATH_ENV_VAR: constants.RESULT_PATH_DEFAULT_VALUE,
+     }
      endpoints = []

      if is_multi_node:
@@ -305,11 +291,11 @@ def merge_patch(base: Any, patch: Any, display_name: str = "") -> Any:


  def _merge_lists_of_dicts(
-     base: List[Dict[str, Any]],
-     patch: List[Dict[str, Any]],
+     base: list[dict[str, Any]],
+     patch: list[dict[str, Any]],
      merge_key: str = "name",
      display_name: str = "",
- ) -> List[Dict[str, Any]]:
+ ) -> list[dict[str, Any]]:
      """
      Attempts to merge lists of dicts by matching on a merge key (default "name").
      - If the merge key is missing, the behavior falls back to overwriting the list.
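
The last hunk touches _merge_lists_of_dicts, which merges service-spec lists (containers, volumes, and the like) by their "name" field. A hedged illustration of the documented rule (the input dicts are invented; only the merge-by-name behavior comes from the docstring above):

    base = [{"name": "main", "image": "cpu:v1"}, {"name": "sidecar"}]
    patch = [{"name": "main", "image": "gpu:v1"}]

    # Entries whose "name" matches are merged; presumably unmatched base
    # entries pass through, so the expected result would be:
    #   [{"name": "main", "image": "gpu:v1"}, {"name": "sidecar"}]
    merged = _merge_lists_of_dicts(base, patch, merge_key="name")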

snowflake/ml/jobs/_utils/types.py
@@ -1,6 +1,6 @@
  from dataclasses import dataclass
  from pathlib import PurePath
- from typing import List, Literal, Optional, Union
+ from typing import Literal, Optional, Union

  JOB_STATUS = Literal[
      "PENDING",
@@ -11,11 +11,17 @@ JOB_STATUS = Literal[
  ]


+ @dataclass(frozen=True)
+ class PayloadEntrypoint:
+     file_path: PurePath
+     main_func: Optional[str]
+
+
  @dataclass(frozen=True)
  class UploadedPayload:
      # TODO: Include manifest of payload files for validation
      stage_path: PurePath
-     entrypoint: List[Union[str, PurePath]]
+     entrypoint: list[Union[str, PurePath]]


  @dataclass(frozen=True)
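
For context, the new PayloadEntrypoint record pairs the uploaded entrypoint file with an optional function name to invoke. A hypothetical instance (both values are illustrative; the semantics of main_func=None are not shown in this hunk):

    from pathlib import PurePath

    entrypoint = PayloadEntrypoint(
        file_path=PurePath("payload/train.py"),  # illustrative stage-relative path
        main_func="main",  # optional; None presumably executes the script top to bottom
    )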

snowflake/ml/jobs/decorators.py
@@ -1,6 +1,6 @@
  import copy
  import functools
- from typing import Callable, Dict, List, Optional, TypeVar
+ from typing import Callable, Optional, TypeVar

  from typing_extensions import ParamSpec

@@ -15,20 +15,19 @@ _Args = ParamSpec("_Args")
  _ReturnValue = TypeVar("_ReturnValue")


- @snowpark._internal.utils.private_preview(version="1.7.4")
  @telemetry.send_api_usage_telemetry(project=_PROJECT)
  def remote(
      compute_pool: str,
      *,
      stage_name: str,
-     pip_requirements: Optional[List[str]] = None,
-     external_access_integrations: Optional[List[str]] = None,
+     pip_requirements: Optional[list[str]] = None,
+     external_access_integrations: Optional[list[str]] = None,
      query_warehouse: Optional[str] = None,
-     env_vars: Optional[Dict[str, str]] = None,
+     env_vars: Optional[dict[str, str]] = None,
      num_instances: Optional[int] = None,
      enable_metrics: bool = False,
      session: Optional[snowpark.Session] = None,
- ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob]]:
+ ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob[_ReturnValue]]]:
      """
      Submit a job to the compute pool.

@@ -47,7 +46,7 @@ def remote(
      Decorator that dispatches invocations of the decorated function as remote jobs.
      """

-     def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, jb.MLJob]:
+     def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, jb.MLJob[_ReturnValue]]:
          # Copy the function to avoid modifying the original
          # We need to modify the line number of the function to exclude the
          # decorator from the copied source code
@@ -55,7 +54,7 @@ def remote(
          wrapped_func.__code__ = wrapped_func.__code__.replace(co_firstlineno=func.__code__.co_firstlineno + 1)

          @functools.wraps(func)
-         def wrapper(*args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob:
+         def wrapper(*args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:
              payload = functools.partial(func, *args, **kwargs)
              setattr(payload, constants.IS_MLJOB_REMOTE_ATTR, True)
              job = jm._submit_job(
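
These return-type changes make MLJob generic: the handle returned by a decorated function now carries that function's result type. A hedged usage sketch based only on the signature above (the import path, pool, and stage names are assumptions):

    from snowflake.ml.jobs import remote

    @remote("MY_COMPUTE_POOL", stage_name="payload_stage")
    def train(num_epochs: int) -> float:
        return 0.95  # placeholder metric

    job = train(10)  # dispatches as a remote job; static type is now MLJob[float] rather than bare MLJob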