snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +44 -10
- snowflake/ml/_internal/platform_capabilities.py +39 -3
- snowflake/ml/data/data_connector.py +25 -0
- snowflake/ml/dataset/dataset_reader.py +5 -1
- snowflake/ml/jobs/_utils/constants.py +2 -4
- snowflake/ml/jobs/_utils/interop_utils.py +442 -0
- snowflake/ml/jobs/_utils/payload_utils.py +81 -47
- snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +178 -0
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
- snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
- snowflake/ml/jobs/_utils/spec_utils.py +5 -8
- snowflake/ml/jobs/_utils/types.py +6 -0
- snowflake/ml/jobs/decorators.py +3 -3
- snowflake/ml/jobs/job.py +145 -23
- snowflake/ml/jobs/manager.py +62 -10
- snowflake/ml/model/_client/ops/service_ops.py +42 -35
- snowflake/ml/model/_client/service/model_deployment_spec.py +7 -4
- snowflake/ml/model/_client/sql/service.py +9 -5
- snowflake/ml/model/_model_composer/model_composer.py +29 -11
- snowflake/ml/model/_packager/model_env/model_env.py +8 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +6 -1
- snowflake/ml/model/_packager/model_packager.py +2 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/type_hints.py +2 -0
- snowflake/ml/registry/_manager/model_manager.py +20 -1
- snowflake/ml/registry/registry.py +5 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.2.dist-info}/METADATA +35 -4
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.2.dist-info}/RECORD +34 -28
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.2.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.2.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.2.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py
ADDED
@@ -0,0 +1,242 @@
+#!/usr/bin/env python3
+# This file is part of the Ray-based distributed job system for Snowflake ML.
+# Architecture overview:
+# - Head node creates a ShutdownSignal actor and signals workers when job completes
+# - Worker nodes listen for this signal via this script and gracefully shut down
+# - This ensures clean termination of distributed Ray jobs
+import logging
+import signal
+import sys
+import time
+from typing import Optional
+
+import get_instance_ip
+import ray
+from constants import (
+    SHUTDOWN_ACTOR_NAME,
+    SHUTDOWN_ACTOR_NAMESPACE,
+    SHUTDOWN_RPC_TIMEOUT_SECONDS,
+)
+from ray.actor import ActorHandle
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+
+
+def get_shutdown_actor() -> Optional[ActorHandle]:
+    """
+    Retrieve the shutdown signal actor from Ray.
+
+    Returns:
+        The shutdown signal actor or None if not found
+    """
+    try:
+        shutdown_signal = ray.get_actor(SHUTDOWN_ACTOR_NAME, namespace=SHUTDOWN_ACTOR_NAMESPACE)
+        return shutdown_signal
+    except Exception:
+        return None
+
+
+def ping_shutdown_actor(shutdown_signal: ActorHandle) -> bool:
+    """
+    Ping the shutdown actor to ensure connectivity.
+
+    Args:
+        shutdown_signal: The Ray actor handle for the shutdown signal
+
+    Returns:
+        True if ping succeeds, False otherwise
+    """
+    try:
+        ping_result = ray.get(shutdown_signal.ping.remote(), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS)
+        logging.debug(f"Actor ping result: {ping_result}")
+        return True
+    except (ray.exceptions.GetTimeoutError, Exception) as e:
+        logging.debug(f"Actor ping failed: {e}")
+        return False
+
+
+def check_shutdown_status(shutdown_signal: ActorHandle, worker_id: str) -> bool:
+    """
+    Check if worker should shutdown and acknowledge if needed.
+
+    Args:
+        shutdown_signal: The Ray actor handle for the shutdown signal
+        worker_id: Worker identifier (IP address)
+
+    Returns:
+        True if should shutdown, False otherwise
+    """
+    try:
+        status = ray.get(shutdown_signal.should_shutdown.remote(), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS)
+        logging.debug(f"Shutdown status: {status}")
+
+        if status.get("shutdown", False):
+            logging.info(
+                f"Received shutdown signal from head node at {status.get('timestamp')}. " f"Exiting worker process."
+            )
+
+            # Acknowledge shutdown before exiting
+            try:
+                ack_result = ray.get(
+                    shutdown_signal.acknowledge_shutdown.remote(worker_id), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS
+                )
+                logging.info(f"Acknowledged shutdown: {ack_result}")
+            except Exception as e:
+                logging.warning(f"Failed to acknowledge shutdown: {e}. Continue to exit worker.")
+
+            return True
+        return False
+
+    except Exception as e:
+        logging.debug(f"Error checking shutdown status: {e}")
+        return False
+
+
+def check_ray_connectivity() -> bool:
+    """
+    Check if the Ray cluster is accessible.
+
+    Returns:
+        True if Ray is connected, False otherwise
+    """
+    try:
+        # A simple check to verify Ray is working
+        nodes = ray.nodes()
+        if nodes:
+            return True
+        return False
+    except Exception as e:
+        logging.debug(f"Ray connectivity check failed: {e}")
+        return False
+
+
+def initialize_ray_connection(max_retries: int, initial_retry_delay: int, max_retry_delay: int) -> bool:
+    """
+    Initialize connection to Ray with retries.
+
+    Args:
+        max_retries: Maximum number of connection attempts
+        initial_retry_delay: Initial delay between retries in seconds
+        max_retry_delay: Maximum delay between retries in seconds
+
+    Returns:
+        bool: True if connection successful, False otherwise
+    """
+    retry_count = 0
+    retry_delay = initial_retry_delay
+
+    while retry_count < max_retries:
+        try:
+            ray.init(address="auto", ignore_reinit_error=True)
+            return True
+        except (ConnectionError, TimeoutError, RuntimeError) as e:
+            retry_count += 1
+            if retry_count >= max_retries:
+                logging.error(f"Failed to connect to Ray head after {max_retries} attempts: {e}")
+                return False
+
+            logging.debug(
+                f"Attempt {retry_count}/{max_retries} to connect to Ray failed: {e}. "
+                f"Retrying in {retry_delay} seconds..."
+            )
+            time.sleep(retry_delay)
+            # Exponential backoff with cap
+            retry_delay = min(retry_delay * 1.5, max_retry_delay)
+
+    return False  # Should not reach here, but added for completeness
+
+
+def monitor_shutdown_signal(check_interval: int, max_consecutive_failures: int) -> int:
+    """
+    Main loop to monitor for shutdown signals.
+
+    Args:
+        check_interval: Time in seconds between checks
+        max_consecutive_failures: Maximum allowed consecutive connection failures
+
+    Returns:
+        int: Exit code (0 for success, non-zero for failure)
+
+    Raises:
+        ConnectionError: If Ray connection failures exceed threshold
+    """
+    worker_id = get_instance_ip.get_self_ip()
+    actor_check_count = 0
+    consecutive_connection_failures = 0
+
+    logging.debug(
+        f"Starting to monitor for shutdown signal using actor {SHUTDOWN_ACTOR_NAME}"
+        f" in namespace {SHUTDOWN_ACTOR_NAMESPACE}."
+    )
+
+    while True:
+        actor_check_count += 1
+
+        # Check Ray connectivity before proceeding
+        if not check_ray_connectivity():
+            consecutive_connection_failures += 1
+            logging.debug(
+                f"Ray connectivity check failed (attempt {consecutive_connection_failures}/{max_consecutive_failures})"
+            )
+            if consecutive_connection_failures >= max_consecutive_failures:
+                raise ConnectionError("Exceeded max consecutive Ray connection failures")
+            time.sleep(check_interval)
+            continue
+
+        # Reset counter on successful connection
+        consecutive_connection_failures = 0
+
+        # Get shutdown actor
+        shutdown_signal = get_shutdown_actor()
+        if not shutdown_signal:
+            logging.debug(f"Shutdown signal actor not found at check #{actor_check_count}, continuing to wait...")
+            time.sleep(check_interval)
+            continue
+
+        # Ping the actor to ensure connectivity
+        if not ping_shutdown_actor(shutdown_signal):
+            time.sleep(check_interval)
+            continue
+
+        # Check shutdown status
+        if check_shutdown_status(shutdown_signal, worker_id):
+            return 0
+
+        # Wait before checking again
+        time.sleep(check_interval)
+
+
+def run_listener() -> int:
+    """Listen for shutdown signals from the head node"""
+    # Configuration
+    max_retries = 15
+    initial_retry_delay = 2
+    max_retry_delay = 30
+    check_interval = 5  # How often to check for ray connection or shutdown signal
+    max_consecutive_failures = 12  # Exit after about 1 minute of connection failures
+
+    # Initialize Ray connection
+    if not initialize_ray_connection(max_retries, initial_retry_delay, max_retry_delay):
+        raise ConnectionError("Failed to connect to Ray cluster. Aborting worker.")
+
+    # Monitor for shutdown signals
+    return monitor_shutdown_signal(check_interval, max_consecutive_failures)
+
+
+def main():
+    """Main entry point with signal handling"""
+
+    def signal_handler(signum, frame):
+        logging.info(f"Received signal {signum}, exiting worker process.")
+        sys.exit(0)
+
+    signal.signal(signal.SIGTERM, signal_handler)
+    signal.signal(signal.SIGINT, signal_handler)
+
+    # Run the listener - this will block until a shutdown signal is received
+    result = run_listener()
+    sys.exit(result)
+
+
+if __name__ == "__main__":
+    main()
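Note on the shutdown protocol: the listener above looks up a named actor and calls three methods on it — ping, should_shutdown (returning a dict with "shutdown" and "timestamp" keys), and acknowledge_shutdown(worker_id). The real head-side actor ships in the new signal_workers.py (+203 lines, not excerpted here); the following is only a minimal sketch, assuming the actor exposes exactly those three methods, to make the handshake concrete. The request_shutdown method name below is invented for illustration.

import time
from typing import Any, Dict

import ray


@ray.remote
class ShutdownSignal:
    """Hypothetical head-side actor mirroring the calls made by worker_shutdown_listener.py."""

    def __init__(self) -> None:
        self._shutdown_requested = False
        self._timestamp: float = 0.0
        self._acknowledged: set = set()

    def ping(self) -> str:
        # Lets workers verify connectivity before polling for status.
        return "pong"

    def request_shutdown(self) -> None:
        # Called on the head node once the job completes (name assumed, not from the diff).
        self._shutdown_requested = True
        self._timestamp = time.time()

    def should_shutdown(self) -> Dict[str, Any]:
        # The listener reads the "shutdown" and "timestamp" keys from this dict.
        return {"shutdown": self._shutdown_requested, "timestamp": self._timestamp}

    def acknowledge_shutdown(self, worker_id: str) -> str:
        # Workers report their ID (IP address) before exiting.
        self._acknowledged.add(worker_id)
        return f"shutdown acknowledged for {worker_id}"

On the head node, such an actor would presumably be registered with ShutdownSignal.options(name=SHUTDOWN_ACTOR_NAME, namespace=SHUTDOWN_ACTOR_NAMESPACE).remote() so that ray.get_actor() in get_shutdown_actor() can find it.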
snowflake/ml/jobs/_utils/spec_utils.py
CHANGED
@@ -114,15 +114,9 @@ def generate_service_spec(
         Job service specification
     """
     is_multi_node = num_instances is not None and num_instances > 1
+    image_spec = _get_image_spec(session, compute_pool)
 
     # Set resource requests/limits, including nvidia.com/gpu quantity if applicable
-    if is_multi_node:
-        # If the job is of multi-node, we will need a different image which contains
-        # module snowflake.runtime.utils.get_instance_ip
-        # TODO(SNOW-1961849): Remove the hard-coded image name
-        image_spec = _get_image_spec(session, compute_pool, constants.MULTINODE_HEADLESS_IMAGE_TAG)
-    else:
-        image_spec = _get_image_spec(session, compute_pool)
     resource_requests: Dict[str, Union[str, int]] = {
         "cpu": f"{int(image_spec.resource_requests.cpu * 1000)}m",
         "memory": f"{image_spec.resource_limits.memory}Gi",
@@ -191,7 +185,10 @@ def generate_service_spec(
 
     # TODO: Add hooks for endpoints for integration with TensorBoard etc
 
-    env_vars = {
+    env_vars = {
+        constants.PAYLOAD_DIR_ENV_VAR: stage_mount.as_posix(),
+        constants.RESULT_PATH_ENV_VAR: constants.RESULT_PATH_DEFAULT_VALUE,
+    }
    endpoints = []
 
    if is_multi_node:
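How this ties into the rest of the release: the env vars added to the generated spec here are what MLJob reads back in job.py below — RESULT_PATH_ENV_VAR via MLJob._result_path and the stage volume via MLJob._stage_path. A rough sketch of the resulting spec fragment follows; the dict keys mirror what those properties look up, but every literal name and value is a placeholder rather than the package's actual defaults.

# Hypothetical shape only; container/volume names and env var spellings are placeholders.
service_spec = {
    "spec": {
        "containers": [
            {
                "name": "main",  # compared against constants.DEFAULT_CONTAINER_NAME
                "env": {
                    "PAYLOAD_DIR": "/mnt/job_stage",    # constants.PAYLOAD_DIR_ENV_VAR -> stage_mount.as_posix()
                    "RESULT_PATH": "mljob_result.pkl",  # constants.RESULT_PATH_ENV_VAR -> RESULT_PATH_DEFAULT_VALUE
                },
            }
        ],
        "volumes": [
            {"name": "stage-volume", "source": "@MY_STAGE/MLJOB_..."},  # constants.STAGE_VOLUME_NAME / stage path
        ],
    }
}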
snowflake/ml/jobs/_utils/types.py
CHANGED
@@ -11,6 +11,12 @@ JOB_STATUS = Literal[
 ]
 
 
+@dataclass(frozen=True)
+class PayloadEntrypoint:
+    file_path: PurePath
+    main_func: Optional[str]
+
+
 @dataclass(frozen=True)
 class UploadedPayload:
     # TODO: Include manifest of payload files for validation
snowflake/ml/jobs/decorators.py
CHANGED
@@ -28,7 +28,7 @@ def remote(
     num_instances: Optional[int] = None,
     enable_metrics: bool = False,
     session: Optional[snowpark.Session] = None,
-) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob]]:
+) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob[_ReturnValue]]]:
     """
     Submit a job to the compute pool.
 
@@ -47,7 +47,7 @@ def remote(
     Decorator that dispatches invocations of the decorated function as remote jobs.
     """
 
-    def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, jb.MLJob]:
+    def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, jb.MLJob[_ReturnValue]]:
        # Copy the function to avoid modifying the original
        # We need to modify the line number of the function to exclude the
        # decorator from the copied source code
@@ -55,7 +55,7 @@ def remote(
         wrapped_func.__code__ = wrapped_func.__code__.replace(co_firstlineno=func.__code__.co_firstlineno + 1)
 
         @functools.wraps(func)
-        def wrapper(*args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob:
+        def wrapper(*args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:
             payload = functools.partial(func, *args, **kwargs)
             setattr(payload, constants.IS_MLJOB_REMOTE_ATTR, True)
             job = jm._submit_job(
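The practical effect of threading _ReturnValue through the decorator is that the returned handle is now parameterized by the wrapped function's return type. A minimal sketch, assuming the decorator is imported from snowflake.ml.jobs and called with a compute pool and stage name (the argument names and values below are placeholders, not a documented signature):

from snowflake.ml.jobs import remote  # assumed public import path


@remote("MY_COMPUTE_POOL", stage_name="MY_PAYLOAD_STAGE")  # placeholder arguments
def train(n: int) -> float:
    return sum(range(n)) / n


job = train(100)
# Before 1.8.2 this was typed as plain jb.MLJob; it is now jb.MLJob[float],
# so type checkers know what job.result() (added in job.py below) will return.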
snowflake/ml/jobs/job.py
CHANGED
@@ -1,20 +1,32 @@
 import time
-from typing import Any, List, Optional, cast
+from typing import Any, Dict, Generic, List, Optional, TypeVar, cast
+
+import yaml
 
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
-from snowflake.ml.jobs._utils import constants, types
+from snowflake.ml.jobs._utils import constants, interop_utils, types
 from snowflake.snowpark import context as sp_context
 
 _PROJECT = "MLJob"
 TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "INTERNAL_ERROR"}
 
+T = TypeVar("T")
+
 
-class MLJob:
-    def __init__(
+class MLJob(Generic[T]):
+    def __init__(
+        self,
+        id: str,
+        service_spec: Optional[Dict[str, Any]] = None,
+        session: Optional[snowpark.Session] = None,
+    ) -> None:
         self._id = id
+        self._service_spec_cached: Optional[Dict[str, Any]] = service_spec
         self._session = session or sp_context.get_active_session()
+
         self._status: types.JOB_STATUS = "PENDING"
+        self._result: Optional[interop_utils.ExecutionResult] = None
 
     @property
     def id(self) -> str:
@@ -29,33 +41,66 @@ class MLJob:
         self._status = _get_status(self._session, self.id)
         return self._status
 
+    @property
+    def _service_spec(self) -> Dict[str, Any]:
+        """Get the job's service spec."""
+        if not self._service_spec_cached:
+            self._service_spec_cached = _get_service_spec(self._session, self.id)
+        return self._service_spec_cached
+
+    @property
+    def _container_spec(self) -> Dict[str, Any]:
+        """Get the job's main container spec."""
+        containers = self._service_spec["spec"]["containers"]
+        container_spec = next(c for c in containers if c["name"] == constants.DEFAULT_CONTAINER_NAME)
+        return cast(Dict[str, Any], container_spec)
+
+    @property
+    def _stage_path(self) -> str:
+        """Get the job's artifact storage stage location."""
+        volumes = self._service_spec["spec"]["volumes"]
+        stage_path = next(v for v in volumes if v["name"] == constants.STAGE_VOLUME_NAME)["source"]
+        return cast(str, stage_path)
+
+    @property
+    def _result_path(self) -> str:
+        """Get the job's result file location."""
+        result_path = self._container_spec["env"].get(constants.RESULT_PATH_ENV_VAR)
+        if result_path is None:
+            raise RuntimeError(f"Job {self.id} doesn't have a result path configured")
+        return f"{self._stage_path}/{result_path}"
+
     @snowpark._internal.utils.private_preview(version="1.7.4")
-    def get_logs(self, limit: int = -1) -> str:
+    def get_logs(self, limit: int = -1, instance_id: Optional[int] = None) -> str:
         """
         Return the job's execution logs.
 
         Args:
             limit: The maximum number of lines to return. Negative values are treated as no limit.
+            instance_id: Optional instance ID to get logs from a specific instance.
+                If not provided, returns logs from the head node.
 
         Returns:
             The job's execution logs.
         """
-        logs = _get_logs(self._session, self.id, limit)
+        logs = _get_logs(self._session, self.id, limit, instance_id)
         assert isinstance(logs, str)  # mypy
         return logs
 
     @snowpark._internal.utils.private_preview(version="1.7.4")
-    def show_logs(self, limit: int = -1) -> None:
+    def show_logs(self, limit: int = -1, instance_id: Optional[int] = None) -> None:
         """
         Display the job's execution logs.
 
         Args:
             limit: The maximum number of lines to display. Negative values are treated as no limit.
+            instance_id: Optional instance ID to get logs from a specific instance.
+                If not provided, displays logs from the head node.
         """
-        print(self.get_logs(limit))  # noqa: T201: we need to print here.
+        print(self.get_logs(limit, instance_id))  # noqa: T201: we need to print here.
 
     @snowpark._internal.utils.private_preview(version="1.7.4")
-    @telemetry.send_api_usage_telemetry(project=_PROJECT)
+    @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["timeout"])
     def wait(self, timeout: float = -1) -> types.JOB_STATUS:
         """
         Block until completion. Returns completion status.
@@ -78,20 +123,58 @@ class MLJob:
             delay = min(delay * 2, constants.JOB_POLL_MAX_DELAY_SECONDS)  # Exponential backoff
         return self.status
 
+    @snowpark._internal.utils.private_preview(version="1.8.2")
+    @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["timeout"])
+    def result(self, timeout: float = -1) -> T:
+        """
+        Block until completion. Returns job execution result.
+
+        Args:
+            timeout: The maximum time to wait in seconds. Negative values are treated as no timeout.
+
+        Returns:
+            T: The deserialized job result. # noqa: DAR401
+
+        Raises:
+            RuntimeError: If the job failed or if the job doesn't have a result to retrieve.
+            TimeoutError: If the job does not complete within the specified timeout. # noqa: DAR402
+        """
+        if self._result is None:
+            self.wait(timeout)
+            try:
+                self._result = interop_utils.fetch_result(self._session, self._result_path)
+            except Exception as e:
+                raise RuntimeError(f"Failed to retrieve result for job (id={self.id})") from e
+
+        if self._result.success:
+            return cast(T, self._result.result)
+        raise RuntimeError(f"Job execution failed (id={self.id})") from self._result.exception
+
+
+@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "instance_id"])
+def _get_status(session: snowpark.Session, job_id: str, instance_id: Optional[int] = None) -> types.JOB_STATUS:
+    """Retrieve job or job instance execution status."""
+    if instance_id is not None:
+        # Get specific instance status
+        rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+        for row in rows:
+            if row["instance_id"] == str(instance_id):
+                return cast(types.JOB_STATUS, row["status"])
+        raise ValueError(f"Instance {instance_id} not found in job {job_id}")
+    else:
+        (row,) = session.sql("DESCRIBE SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+        return cast(types.JOB_STATUS, row["status"])
+
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
-def
-"""Retrieve job execution
-
-
-
-
-
-
-
-
-@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit"])
-def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1) -> str:
+def _get_service_spec(session: snowpark.Session, job_id: str) -> Dict[str, Any]:
+    """Retrieve job execution service spec."""
+    (row,) = session.sql("DESCRIBE SERVICE IDENTIFIER(?)", params=[job_id]).collect()
+    return cast(Dict[str, Any], yaml.safe_load(row["spec"]))
+
+
+@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit", "instance_id"])
+def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1, instance_id: Optional[int] = None) -> str:
     """
     Retrieve the job's execution logs.
 
@@ -99,15 +182,54 @@ def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1) -> str:
         job_id: The job ID.
         limit: The maximum number of lines to return. Negative values are treated as no limit.
         session: The Snowpark session to use. If none specified, uses active session.
+        instance_id: Optional instance ID to get logs from a specific instance.
 
     Returns:
         The job's execution logs.
     """
-
+    # If instance_id is not specified, try to get the head instance ID
+    if instance_id is None:
+        instance_id = _get_head_instance_id(session, job_id)
+
+    # Assemble params: [job_id, instance_id, container_name, (optional) limit]
+    params: List[Any] = [
+        job_id,
+        0 if instance_id is None else instance_id,
+        constants.DEFAULT_CONTAINER_NAME,
+    ]
     if limit > 0:
         params.append(limit)
+
     (row,) = session.sql(
-        f"SELECT SYSTEM$GET_SERVICE_LOGS(?,
+        f"SELECT SYSTEM$GET_SERVICE_LOGS(?, ?, ?{f', ?' if limit > 0 else ''})",
         params=params,
     ).collect()
     return str(row[0])
+
+
+@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
+def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[int]:
+    """
+    Retrieve the head instance ID of a job.
+
+    Args:
+        session: The Snowpark session to use.
+        job_id: The job ID.
+
+    Returns:
+        The head instance ID of the job. Returns None if the head instance has not started yet.
+    """
+    rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+    if not rows:
+        return None
+
+    # Sort by start_time first, then by instance_id
+    sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"])))
+    head_instance = sorted_instances[0]
+    if not head_instance["start_time"]:
+        # If head instance hasn't started yet, return None
+        return None
+    try:
+        return int(head_instance["instance_id"])
+    except (ValueError, TypeError):
+        return 0
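Putting the job.py additions together: wait() polls status, _get_head_instance_id() picks the earliest-started instance when no instance_id is given, and result() fetches and deserializes the stored result via interop_utils.fetch_result, re-raising the remote exception on failure. A short usage sketch, continuing the hypothetical train example from the decorators section above:

job = train(100)                              # MLJob[float] handle from the sketch above

status = job.wait(timeout=600)                # block up to 10 minutes; returns a JOB_STATUS string
print(job.get_logs(limit=50))                 # head-node logs (head instance auto-detected)
print(job.get_logs(limit=50, instance_id=1))  # logs from a specific worker instance

value = job.result()                          # 49.5 here; raises RuntimeError if the job failed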
snowflake/ml/jobs/manager.py
CHANGED
@@ -1,6 +1,16 @@
 import pathlib
 import textwrap
-from typing import
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    TypeVar,
+    Union,
+    overload,
+)
 from uuid import uuid4
 
 import yaml
@@ -16,6 +26,8 @@ from snowflake.snowpark.exceptions import SnowparkSQLException
 _PROJECT = "MLJob"
 JOB_ID_PREFIX = "MLJOB_"
 
+T = TypeVar("T")
+
 
 @snowpark._internal.utils.private_preview(version="1.7.4")
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["limit", "scope"])
@@ -59,7 +71,7 @@ def list_jobs(
 
 @snowpark._internal.utils.private_preview(version="1.7.4")
 @telemetry.send_api_usage_telemetry(project=_PROJECT)
-def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob:
+def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob[Any]:
     """Retrieve a job service from the backend."""
     session = session or get_active_session()
 
@@ -71,7 +83,8 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
 
     try:
         # Validate that job exists by doing a status check
-
+        # FIXME: Retrieve return path
+        job = jb.MLJob[Any](job_id, session=session)
         _ = job.status
         return job
     except SnowparkSQLException as e:
@@ -82,7 +95,7 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
 
 @snowpark._internal.utils.private_preview(version="1.7.4")
 @telemetry.send_api_usage_telemetry(project=_PROJECT)
-def delete_job(job: Union[str, jb.MLJob], session: Optional[snowpark.Session] = None) -> None:
+def delete_job(job: Union[str, jb.MLJob[Any]], session: Optional[snowpark.Session] = None) -> None:
     """Delete a job service from the backend. Status and logs will be lost."""
     if isinstance(job, jb.MLJob):
         job_id = job.id
@@ -109,7 +122,7 @@ def submit_file(
     num_instances: Optional[int] = None,
     enable_metrics: bool = False,
     session: Optional[snowpark.Session] = None,
-) -> jb.MLJob:
+) -> jb.MLJob[None]:
     """
     Submit a Python file as a job to the compute pool.
 
@@ -163,7 +176,7 @@ def submit_directory(
     num_instances: Optional[int] = None,
     enable_metrics: bool = False,
     session: Optional[snowpark.Session] = None,
-) -> jb.MLJob:
+) -> jb.MLJob[None]:
     """
     Submit a directory containing Python script(s) as a job to the compute pool.
 
@@ -202,6 +215,46 @@ def submit_directory(
     )
 
 
+@overload
+def _submit_job(
+    source: str,
+    compute_pool: str,
+    *,
+    stage_name: str,
+    entrypoint: Optional[str] = None,
+    args: Optional[List[str]] = None,
+    env_vars: Optional[Dict[str, str]] = None,
+    pip_requirements: Optional[List[str]] = None,
+    external_access_integrations: Optional[List[str]] = None,
+    query_warehouse: Optional[str] = None,
+    spec_overrides: Optional[Dict[str, Any]] = None,
+    num_instances: Optional[int] = None,
+    enable_metrics: bool = False,
+    session: Optional[snowpark.Session] = None,
+) -> jb.MLJob[None]:
+    ...
+
+
+@overload
+def _submit_job(
+    source: Callable[..., T],
+    compute_pool: str,
+    *,
+    stage_name: str,
+    entrypoint: Optional[str] = None,
+    args: Optional[List[str]] = None,
+    env_vars: Optional[Dict[str, str]] = None,
+    pip_requirements: Optional[List[str]] = None,
+    external_access_integrations: Optional[List[str]] = None,
+    query_warehouse: Optional[str] = None,
+    spec_overrides: Optional[Dict[str, Any]] = None,
+    num_instances: Optional[int] = None,
+    enable_metrics: bool = False,
+    session: Optional[snowpark.Session] = None,
+) -> jb.MLJob[T]:
+    ...
+
+
 @telemetry.send_api_usage_telemetry(
     project=_PROJECT,
     func_params_to_log=[
@@ -213,7 +266,7 @@ def submit_directory(
     ],
 )
 def _submit_job(
-    source: Union[str, Callable[...,
+    source: Union[str, Callable[..., T]],
     compute_pool: str,
     *,
     stage_name: str,
@@ -227,7 +280,7 @@ def _submit_job(
     num_instances: Optional[int] = None,
     enable_metrics: bool = False,
     session: Optional[snowpark.Session] = None,
-) -> jb.MLJob:
+) -> jb.MLJob[T]:
     """
     Submit a job to the compute pool.
 
@@ -314,5 +367,4 @@ def _submit_job(
         ) from e
         raise
 
-
-    return jb.MLJob(job_id, session=session)
+    return jb.MLJob(job_id, service_spec=spec, session=session)
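The two @overload stubs above are typing-only: a str source (a file or directory payload) has no Python return value, so it maps to MLJob[None], while a callable source propagates its return type into MLJob[T]. A standalone sketch of the same pattern with stand-in names, independent of the Snowflake classes:

from typing import Any, Callable, Generic, TypeVar, Union, overload

T = TypeVar("T")


class Job(Generic[T]):
    """Stand-in for jb.MLJob[T]."""


@overload
def submit(source: str) -> Job[None]: ...
@overload
def submit(source: Callable[..., T]) -> Job[T]: ...


def submit(source: Union[str, Callable[..., Any]]) -> Job[Any]:
    # Single runtime implementation; the overloads only guide the type checker,
    # mirroring how _submit_job serves both submit_file/submit_directory and @remote.
    return Job()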
|