snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. snowflake/cortex/_complete.py +44 -10
  2. snowflake/ml/_internal/platform_capabilities.py +39 -3
  3. snowflake/ml/data/data_connector.py +25 -0
  4. snowflake/ml/dataset/dataset_reader.py +5 -1
  5. snowflake/ml/jobs/_utils/constants.py +2 -4
  6. snowflake/ml/jobs/_utils/interop_utils.py +442 -0
  7. snowflake/ml/jobs/_utils/payload_utils.py +81 -47
  8. snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
  9. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
  10. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +178 -0
  11. snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
  12. snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
  13. snowflake/ml/jobs/_utils/spec_utils.py +5 -8
  14. snowflake/ml/jobs/_utils/types.py +6 -0
  15. snowflake/ml/jobs/decorators.py +3 -3
  16. snowflake/ml/jobs/job.py +145 -23
  17. snowflake/ml/jobs/manager.py +62 -10
  18. snowflake/ml/model/_client/ops/service_ops.py +42 -35
  19. snowflake/ml/model/_client/service/model_deployment_spec.py +7 -4
  20. snowflake/ml/model/_client/sql/service.py +9 -5
  21. snowflake/ml/model/_model_composer/model_composer.py +29 -11
  22. snowflake/ml/model/_packager/model_env/model_env.py +8 -2
  23. snowflake/ml/model/_packager/model_meta/model_meta.py +6 -1
  24. snowflake/ml/model/_packager/model_packager.py +2 -0
  25. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  26. snowflake/ml/model/type_hints.py +2 -0
  27. snowflake/ml/registry/_manager/model_manager.py +20 -1
  28. snowflake/ml/registry/registry.py +5 -1
  29. snowflake/ml/version.py +1 -1
  30. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.2.dist-info}/METADATA +35 -4
  31. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.2.dist-info}/RECORD +34 -28
  32. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.2.dist-info}/WHEEL +0 -0
  33. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.2.dist-info}/licenses/LICENSE.txt +0 -0
  34. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,136 @@
1
+ #!/usr/bin/env python3
2
+ # This file is modified from mlruntime/service/snowflake/runtime/utils
3
+ import argparse
4
+ import logging
5
+ import socket
6
+ import sys
7
+ import time
8
+ from typing import Optional
9
+
10
+ # Configure logging
11
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
def get_self_ip() -> Optional[str]:
    """Return the IP address of the current service instance, or None on failure.

    References:
        - https://docs.snowflake.com/en/developer-guide/snowpark-container-services/working-with-services#general-guidelines-related-to-service-to-service-communications # noqa: E501

    Returns:
        Optional[str]: The IP address of the current service instance, or None if unable to retrieve.
    """
    try:
        # Resolve our own hostname to the instance-local IP.
        return socket.gethostbyname(socket.gethostname())
    except OSError as e:
        logger.error(f"Error: Unable to get IP address via socket. {e}")
        return None
30
+
31
+
32
def get_first_instance(service_name: str) -> Optional[tuple[str, str]]:
    """Get the first (head) instance of a batch job by start time and instance ID.

    Args:
        service_name (str): The name of the service to query.

    Returns:
        Optional[tuple[str, str]]: (instance_id, ip_address) of the head instance,
        or None if it cannot be determined.
    """
    from snowflake.runtime.utils import session_utils

    session = session_utils.get_session()
    rows = (
        session.sql(f"show service instances in service {service_name}")
        .select('"instance_id"', '"ip_address"', '"start_time"')
        .collect()
    )
    if not rows:
        return None

    # Earliest start_time wins; ties broken by numerically smallest instance_id.
    # min() with this key is equivalent to sorting and taking the first element.
    head = min(rows, key=lambda row: (row["start_time"], int(row["instance_id"])))
    instance_id = head["instance_id"]
    ip_address = head["ip_address"]
    if not instance_id or not ip_address:
        return None

    try:
        socket.inet_aton(ip_address)  # Validate IPv4 address format
    except OSError:
        logger.error(f"Error: Invalid IP address format: {ip_address}")
        return None
    return (instance_id, ip_address)
64
+
65
+
66
def main():
    """Print the IP address of a service instance to stdout and exit.

    Command line arguments:
        service_name (str, required): Name of the service to query.
        --instance-index (int, optional): Index of the instance to query.
            Only -1 (the current instance) is supported. Default: -1.
        --head (bool, optional): If set, --instance-index is ignored and the
            head instance's "<instance_id> <ip_address>" is printed instead.
        --timeout (int, optional): Max seconds to wait for head-instance lookup.
            Default: 720.
        --retry-interval (int, optional): Seconds between lookup attempts.
            Default: 10.

    Usage example:
        python get_instance_ip.py myservice --instance-index=1 --retry-interval=5

    Exits with status code 0 on success (IP written to stdout), 1 on failure.
    """
    parser = argparse.ArgumentParser(description="Get IP address of a service instance")
    parser.add_argument("service_name", help="Name of the service")
    exclusive = parser.add_mutually_exclusive_group()
    exclusive.add_argument(
        "--instance-index",
        type=int,
        default=-1,
        help="Index of service instance (default: -1 for self instance)",
    )
    exclusive.add_argument(
        "--head",
        action="store_true",
        help="Get head instance information using show services",
    )
    parser.add_argument("--timeout", type=int, default=720, help="Timeout in seconds (default: 720)")
    parser.add_argument(
        "--retry-interval",
        type=int,
        default=10,
        help="Retry interval in seconds (default: 10)",
    )
    args = parser.parse_args()

    deadline = time.time() + args.timeout

    if args.head:
        # Poll until the head instance is known or the deadline passes.
        while time.time() < deadline:
            head_info = get_first_instance(args.service_name)
            if head_info:
                # Write to stdout (not the logger) so callers can capture the value.
                sys.stdout.write(f"{head_info[0]} {head_info[1]}\n")
                sys.exit(0)
            time.sleep(args.retry_interval)
        logger.error("Error: Unable to retrieve head IP address")
        sys.exit(1)

    if args.instance_index == -1:
        # Default path: report this instance's own IP.
        ip_address = get_self_ip()
        if ip_address:
            sys.stdout.write(f"{ip_address}\n")
            sys.exit(0)
        logger.error("Error: Unable to retrieve self IP address")
        sys.exit(1)

    # Querying any other specific instance index is not supported yet.
    logger.error("Error: Invalid arguments. Only --instance-index=-1 is supported for now.")
    sys.exit(1)
133
+
134
+
135
# Script entry point.
if __name__ == "__main__":
    main()
@@ -0,0 +1,178 @@
1
+ import argparse
2
+ import importlib.util
3
+ import json
4
+ import os
5
+ import runpy
6
+ import sys
7
+ import traceback
8
+ import warnings
9
+ from pathlib import Path
10
+ from typing import Any, Dict, Optional
11
+
12
+ import cloudpickle
13
+
14
+ from snowflake.ml.jobs._utils import constants
15
+ from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
16
+ from snowflake.snowpark import Session
17
+
18
+ # Fallbacks in case of SnowML version mismatch
19
+ RESULT_PATH_ENV_VAR = getattr(constants, "RESULT_PATH_ENV_VAR", "MLRS_RESULT_PATH")
20
+
21
+ JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "mljob_result.pkl")
22
+
23
+
24
try:
    from snowflake.ml.jobs._utils.interop_utils import ExecutionResult
except ImportError:
    # Older/mismatched SnowML versions may lack interop_utils; define a local stand-in.
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class ExecutionResult:
        result: Optional[Any] = None
        exception: Optional[BaseException] = None

        @property
        def success(self) -> bool:
            # The run succeeded exactly when no exception was captured.
            return self.exception is None

        def to_dict(self) -> Dict[str, Any]:
            """Return the serializable dictionary."""
            exc = self.exception
            if not isinstance(exc, BaseException):
                return {
                    "success": True,
                    "result_type": type(self.result).__qualname__,
                    "result": self.result,
                }
            exc_cls = type(exc)
            return {
                "success": False,
                "exc_type": f"{exc_cls.__module__}.{exc_cls.__name__}",
                "exc_value": exc,
                "exc_tb": "".join(traceback.format_tb(exc.__traceback__)),
            }
53
+
54
+
55
# JSON encoder that stringifies values the default encoder cannot serialize.
class SimpleJSONEncoder(json.JSONEncoder):
    def default(self, obj: Any) -> Any:
        """Return a JSON-serializable form of obj, falling back to str()."""
        try:
            encoded = super().default(obj)
        except TypeError:
            # Non-serializable values (exceptions, custom objects, ...) become strings.
            encoded = str(obj)
        return encoded
62
+
63
+
64
def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = None) -> Any:
    """
    Execute a Python script and return its result.

    Args:
        script_path: Path to the Python script
        script_args: Arguments to pass to the script
        main_func: The name of the function to call in the script (if any)

    Returns:
        Result from script execution, either from the main function or the script's __return__ value

    Raises:
        RuntimeError: If the specified main_func is not found or not callable
    """
    # Swap in the script's own argv; runpy execution reads sys.argv directly.
    saved_argv = sys.argv
    sys.argv = [script_path, *script_args]

    # Create a Snowpark session before running the script; the script can fetch
    # it via snowflake.snowpark.context.get_active_session().
    session = Session.builder.configs(SnowflakeLoginOptions()).create()  # noqa: F841

    try:
        if not main_func:
            # No entrypoint named: execute the script's top-level code via runpy.
            module_globals = runpy.run_path(script_path, run_name="__main__")
            return module_globals.get("__return__", None)

        # Entrypoint named: import the script as a module and call that function.
        module_name = Path(script_path).stem
        spec = importlib.util.spec_from_file_location(module_name, script_path)
        assert spec is not None
        assert spec.loader is not None
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        entrypoint = getattr(module, main_func, None)
        if entrypoint is None or not callable(entrypoint):
            raise RuntimeError(f"Function '{main_func}' not a valid entrypoint for {script_path}")
        return entrypoint(*script_args)
    finally:
        # Restore original sys.argv
        sys.argv = saved_argv
112
+
113
+
114
def main(script_path: str, *script_args: Any, script_main_func: Optional[str] = None) -> ExecutionResult:
    """Executes a Python script and serializes the result to JOB_RESULT_PATH.

    Args:
        script_path (str): Path to the Python script to execute.
        script_args (Any): Arguments to pass to the script.
        script_main_func (str, optional): The name of the function to call in the script (if any).

    Returns:
        ExecutionResult: Object containing execution results.

    Raises:
        Exception: Re-raises any exception caught during script execution.
    """
    # Run the script with the specified arguments
    try:
        result = run_script(script_path, *script_args, main_func=script_main_func)
        result_obj = ExecutionResult(result=result)
        return result_obj
    except BaseException as e:
        # Catch BaseException (not just Exception): SystemExit/KeyboardInterrupt
        # escaping run_script would otherwise leave `result_obj` unbound, and the
        # `finally` block below would raise NameError, masking the real error.
        # The exception is still re-raised after being recorded.
        tb = e.__traceback__
        skip_files = {__file__, runpy.__file__}
        while tb and tb.tb_frame.f_code.co_filename in skip_files:
            # Skip any frames preceding user script execution
            tb = tb.tb_next
        result_obj = ExecutionResult(exception=e.with_traceback(tb))
        raise
    finally:
        result_dict = result_obj.to_dict()
        try:
            # Serialize result using cloudpickle
            result_pickle_path = JOB_RESULT_PATH
            with open(result_pickle_path, "wb") as f:
                cloudpickle.dump(result_dict, f)  # Pickle dictionary form for compatibility
        except Exception as pkl_exc:
            warnings.warn(f"Failed to pickle result to {result_pickle_path}: {pkl_exc}", RuntimeWarning, stacklevel=1)

        try:
            # Serialize result to JSON as fallback path in case of cross version incompatibility
            # TODO: Manually convert non-serializable types to strings
            result_json_path = os.path.splitext(JOB_RESULT_PATH)[0] + ".json"
            with open(result_json_path, "w") as f:
                json.dump(result_dict, f, indent=2, cls=SimpleJSONEncoder)
        except Exception as json_exc:
            warnings.warn(
                f"Failed to serialize JSON result to {result_json_path}: {json_exc}", RuntimeWarning, stacklevel=1
            )
161
+
162
+
163
if __name__ == "__main__":
    # Parse command line arguments
    arg_parser = argparse.ArgumentParser(description="Launch a Python script and save the result")
    arg_parser.add_argument("script_path", help="Path to the Python script to execute")
    arg_parser.add_argument("script_args", nargs="*", help="Arguments to pass to the script")
    arg_parser.add_argument(
        "--script_main_func", required=False, help="The name of the main function to call in the script"
    )
    known_args, passthrough_args = arg_parser.parse_known_args()

    # Forward any unrecognized arguments to the script untouched.
    main(
        known_args.script_path,
        *known_args.script_args,
        *passthrough_args,
        script_main_func=known_args.script_main_func,
    )
@@ -0,0 +1,203 @@
1
+ #!/usr/bin/env python3
2
+ # This file is part of the Ray-based distributed job system for Snowflake ML.
3
+ # Architecture overview:
4
+ # - Head node creates a ShutdownSignal actor and signals workers when job completes
5
+ # - Worker nodes listen for this signal and gracefully shut down
6
+ # - This ensures clean termination of distributed Ray jobs
7
+ import argparse
8
+ import logging
9
+ import socket
10
+ import sys
11
+ import time
12
+ from typing import Any, Dict, List, Set
13
+
14
+ import ray
15
+ from constants import (
16
+ SHUTDOWN_ACTOR_NAME,
17
+ SHUTDOWN_ACTOR_NAMESPACE,
18
+ SHUTDOWN_RPC_TIMEOUT_SECONDS,
19
+ )
20
+ from ray.actor import ActorHandle
21
+
22
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
23
+
24
+
25
@ray.remote
class ShutdownSignal:
    """A simple Ray actor that workers can check to determine if they should shutdown"""

    def __init__(self) -> None:
        # No shutdown has been requested until the head node asks for one.
        self.shutdown_requested = False
        self.timestamp = None
        self.hostname = socket.gethostname()
        self.acknowledged_workers = set()
        logging.info(f"ShutdownSignal actor created on {self.hostname}")

    def request_shutdown(self) -> Dict[str, Any]:
        """Signal workers to shut down"""
        self.shutdown_requested = True
        self.timestamp = time.time()
        logging.info(f"Shutdown requested by head node at {self.timestamp}")
        reply = {"status": "shutdown_requested", "timestamp": self.timestamp, "host": self.hostname}
        return reply

    def should_shutdown(self) -> Dict[str, Any]:
        """Check if shutdown has been requested"""
        state = {"shutdown": self.shutdown_requested, "timestamp": self.timestamp, "host": self.hostname}
        return state

    def ping(self) -> Dict[str, Any]:
        """Simple method to test connectivity"""
        return {"status": "alive", "host": self.hostname}

    def acknowledge_shutdown(self, worker_id: str) -> Dict[str, Any]:
        """Worker acknowledges it has received the shutdown signal and is terminating"""
        self.acknowledged_workers.add(worker_id)
        ack_count = len(self.acknowledged_workers)
        logging.info(f"Worker {worker_id} acknowledged shutdown. Total acknowledged: {ack_count}")
        return {"status": "acknowledged", "worker_id": worker_id, "acknowledged_count": ack_count}

    def get_acknowledgment_workers(self) -> Set[str]:
        """Get the set of workers who have acknowledged shutdown"""
        return self.acknowledged_workers
61
+
62
+
63
def get_worker_node_ids() -> List[str]:
    """Get the IDs of all active worker nodes.

    A worker node is an alive node carrying a positive "node_tag:worker" resource.

    Returns:
        List[str]: List of worker node IDs. Empty list if no workers are present.
    """
    worker_node_ids = [
        node.get("NodeName")
        for node in ray.nodes()
        if node.get("Alive") and node.get("Resources", {}).get("node_tag:worker", 0) > 0
    ]

    if worker_node_ids:
        logging.info(f"Found {len(worker_node_ids)} worker nodes")
    else:
        logging.info("No active worker nodes found")

    return worker_node_ids
81
+
82
+
83
def get_or_create_shutdown_signal() -> ActorHandle:
    """Get existing shutdown signal actor or create a new one.

    Returns:
        ActorHandle: Reference to shutdown signal actor
    """
    try:
        # Reuse the actor if a previous invocation already created it.
        existing = ray.get_actor(SHUTDOWN_ACTOR_NAME, namespace=SHUTDOWN_ACTOR_NAMESPACE)
        logging.info("Found existing shutdown signal actor")
        return existing
    except (ValueError, ray.exceptions.RayActorError) as e:
        logging.info(f"Creating new shutdown signal actor: {e}")

    # Create new actor if it doesn't exist
    shutdown_signal = ShutdownSignal.options(
        name=SHUTDOWN_ACTOR_NAME,
        namespace=SHUTDOWN_ACTOR_NAMESPACE,
        lifetime="detached",  # Ensure actor survives client disconnect
        resources={"node_tag:head": 0.001},  # Resource constraint to ensure it runs on head node
    ).remote()

    # Verify actor is created and accessible
    ping_result = ray.get(shutdown_signal.ping.remote(), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS)
    logging.debug(f"New actor ping response: {ping_result}")

    return shutdown_signal
108
+
109
+
110
def request_shutdown(shutdown_signal: ActorHandle) -> None:
    """Request workers to shut down.

    Args:
        shutdown_signal: Reference to the shutdown signal actor
    """
    pending = shutdown_signal.request_shutdown.remote()
    response = ray.get(pending, timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS)
    logging.info(f"Shutdown requested: {response}")
118
+
119
+
120
def verify_shutdown(shutdown_signal: ActorHandle) -> None:
    """Verify that shutdown was properly signaled.

    Args:
        shutdown_signal: Reference to the shutdown signal actor
    """
    pending = shutdown_signal.should_shutdown.remote()
    check = ray.get(pending, timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS)
    logging.debug(f"Shutdown status check: {check}")
128
+
129
+
130
def wait_for_acknowledgments(shutdown_signal: ActorHandle, worker_node_ids: List[str], wait_time: int) -> None:
    """Wait for workers to acknowledge shutdown.

    Args:
        shutdown_signal: Reference to the shutdown signal actor
        worker_node_ids: List of worker node IDs
        wait_time: Time in seconds to wait for acknowledgments

    Raises:
        TimeoutError: When workers don't acknowledge within the wait time or if actor communication times out
    """
    if not worker_node_ids:
        return

    logging.info(f"Waiting up to {wait_time}s for workers to acknowledge shutdown signal...")
    expected_workers = set(worker_node_ids)
    start_time = time.time()
    poll_seconds = 1.0

    while time.time() - start_time < wait_time:
        try:
            ack_workers = ray.get(
                shutdown_signal.get_acknowledgment_workers.remote(), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS
            )
        except Exception as e:
            # Best-effort polling: log and keep retrying until the deadline.
            logging.warning(f"Error checking acknowledgment status: {e}")
        else:
            if ack_workers and ack_workers == expected_workers:
                logging.info(
                    f"All {len(worker_node_ids)} workers acknowledged shutdown. "
                    f"Completed in {time.time() - start_time:.2f}s"
                )
                return
            logging.debug(f"Waiting for acknowledgments: {len(ack_workers)}/{len(worker_node_ids)} workers")
        time.sleep(poll_seconds)

    raise TimeoutError(
        f"Timed out waiting for {len(worker_node_ids)} workers to acknowledge shutdown after {wait_time}s"
    )
169
+
170
+
171
def signal_workers(wait_time: int = 10) -> int:
    """
    Signal worker nodes to shut down by creating a shutdown signal actor.

    Args:
        wait_time: Time in seconds to wait for workers to receive the message

    Returns:
        0 for success, 1 for failure
    """
    ray.init(address="auto", ignore_reinit_error=True)

    worker_node_ids = get_worker_node_ids()
    if not worker_node_ids:
        # Nothing to signal; treat as success.
        logging.info("No active worker nodes found to signal.")
        return 0

    shutdown_signal = get_or_create_shutdown_signal()
    request_shutdown(shutdown_signal)
    verify_shutdown(shutdown_signal)
    wait_for_acknowledgments(shutdown_signal, worker_node_ids, wait_time)
    return 0
194
+
195
+
196
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Signal Ray workers to shutdown")
    cli.add_argument(
        "--wait-time", type=int, default=10, help="Time in seconds to wait for workers to receive the signal"
    )
    cli_args = cli.parse_args()

    sys.exit(signal_workers(cli_args.wait_time))