snowflake-ml-python 1.8.0__py3-none-any.whl → 1.8.2__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries; it is provided for informational purposes only.
- snowflake/cortex/_complete.py +44 -10
- snowflake/ml/_internal/platform_capabilities.py +39 -3
- snowflake/ml/data/data_connector.py +25 -0
- snowflake/ml/dataset/dataset_reader.py +5 -1
- snowflake/ml/jobs/_utils/constants.py +3 -5
- snowflake/ml/jobs/_utils/interop_utils.py +442 -0
- snowflake/ml/jobs/_utils/payload_utils.py +81 -47
- snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +178 -0
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
- snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
- snowflake/ml/jobs/_utils/spec_utils.py +27 -8
- snowflake/ml/jobs/_utils/types.py +6 -0
- snowflake/ml/jobs/decorators.py +10 -6
- snowflake/ml/jobs/job.py +145 -23
- snowflake/ml/jobs/manager.py +79 -12
- snowflake/ml/model/_client/ops/model_ops.py +6 -3
- snowflake/ml/model/_client/ops/service_ops.py +57 -39
- snowflake/ml/model/_client/service/model_deployment_spec.py +7 -4
- snowflake/ml/model/_client/sql/service.py +11 -5
- snowflake/ml/model/_model_composer/model_composer.py +29 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +1 -2
- snowflake/ml/model/_packager/model_env/model_env.py +8 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +1 -4
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +6 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
- snowflake/ml/model/_packager/model_packager.py +2 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/type_hints.py +2 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +5 -1
- snowflake/ml/registry/_manager/model_manager.py +20 -1
- snowflake/ml/registry/registry.py +46 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/METADATA +55 -4
- {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/RECORD +40 -34
- {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py
ADDED
@@ -0,0 +1,242 @@
+#!/usr/bin/env python3
+# This file is part of the Ray-based distributed job system for Snowflake ML.
+# Architecture overview:
+# - Head node creates a ShutdownSignal actor and signals workers when job completes
+# - Worker nodes listen for this signal via this script and gracefully shut down
+# - This ensures clean termination of distributed Ray jobs
+import logging
+import signal
+import sys
+import time
+from typing import Optional
+
+import get_instance_ip
+import ray
+from constants import (
+    SHUTDOWN_ACTOR_NAME,
+    SHUTDOWN_ACTOR_NAMESPACE,
+    SHUTDOWN_RPC_TIMEOUT_SECONDS,
+)
+from ray.actor import ActorHandle
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+
+
+def get_shutdown_actor() -> Optional[ActorHandle]:
+    """
+    Retrieve the shutdown signal actor from Ray.
+
+    Returns:
+        The shutdown signal actor or None if not found
+    """
+    try:
+        shutdown_signal = ray.get_actor(SHUTDOWN_ACTOR_NAME, namespace=SHUTDOWN_ACTOR_NAMESPACE)
+        return shutdown_signal
+    except Exception:
+        return None
+
+
+def ping_shutdown_actor(shutdown_signal: ActorHandle) -> bool:
+    """
+    Ping the shutdown actor to ensure connectivity.
+
+    Args:
+        shutdown_signal: The Ray actor handle for the shutdown signal
+
+    Returns:
+        True if ping succeeds, False otherwise
+    """
+    try:
+        ping_result = ray.get(shutdown_signal.ping.remote(), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS)
+        logging.debug(f"Actor ping result: {ping_result}")
+        return True
+    except (ray.exceptions.GetTimeoutError, Exception) as e:
+        logging.debug(f"Actor ping failed: {e}")
+        return False
+
+
+def check_shutdown_status(shutdown_signal: ActorHandle, worker_id: str) -> bool:
+    """
+    Check if worker should shutdown and acknowledge if needed.
+
+    Args:
+        shutdown_signal: The Ray actor handle for the shutdown signal
+        worker_id: Worker identifier (IP address)
+
+    Returns:
+        True if should shutdown, False otherwise
+    """
+    try:
+        status = ray.get(shutdown_signal.should_shutdown.remote(), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS)
+        logging.debug(f"Shutdown status: {status}")
+
+        if status.get("shutdown", False):
+            logging.info(
+                f"Received shutdown signal from head node at {status.get('timestamp')}. " f"Exiting worker process."
+            )
+
+            # Acknowledge shutdown before exiting
+            try:
+                ack_result = ray.get(
+                    shutdown_signal.acknowledge_shutdown.remote(worker_id), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS
+                )
+                logging.info(f"Acknowledged shutdown: {ack_result}")
+            except Exception as e:
+                logging.warning(f"Failed to acknowledge shutdown: {e}. Continue to exit worker.")
+
+            return True
+        return False
+
+    except Exception as e:
+        logging.debug(f"Error checking shutdown status: {e}")
+        return False
+
+
+def check_ray_connectivity() -> bool:
+    """
+    Check if the Ray cluster is accessible.
+
+    Returns:
+        True if Ray is connected, False otherwise
+    """
+    try:
+        # A simple check to verify Ray is working
+        nodes = ray.nodes()
+        if nodes:
+            return True
+        return False
+    except Exception as e:
+        logging.debug(f"Ray connectivity check failed: {e}")
+        return False
+
+
+def initialize_ray_connection(max_retries: int, initial_retry_delay: int, max_retry_delay: int) -> bool:
+    """
+    Initialize connection to Ray with retries.
+
+    Args:
+        max_retries: Maximum number of connection attempts
+        initial_retry_delay: Initial delay between retries in seconds
+        max_retry_delay: Maximum delay between retries in seconds
+
+    Returns:
+        bool: True if connection successful, False otherwise
+    """
+    retry_count = 0
+    retry_delay = initial_retry_delay
+
+    while retry_count < max_retries:
+        try:
+            ray.init(address="auto", ignore_reinit_error=True)
+            return True
+        except (ConnectionError, TimeoutError, RuntimeError) as e:
+            retry_count += 1
+            if retry_count >= max_retries:
+                logging.error(f"Failed to connect to Ray head after {max_retries} attempts: {e}")
+                return False
+
+            logging.debug(
+                f"Attempt {retry_count}/{max_retries} to connect to Ray failed: {e}. "
+                f"Retrying in {retry_delay} seconds..."
+            )
+            time.sleep(retry_delay)
+            # Exponential backoff with cap
+            retry_delay = min(retry_delay * 1.5, max_retry_delay)
+
+    return False  # Should not reach here, but added for completeness
+
+
+def monitor_shutdown_signal(check_interval: int, max_consecutive_failures: int) -> int:
+    """
+    Main loop to monitor for shutdown signals.
+
+    Args:
+        check_interval: Time in seconds between checks
+        max_consecutive_failures: Maximum allowed consecutive connection failures
+
+    Returns:
+        int: Exit code (0 for success, non-zero for failure)
+
+    Raises:
+        ConnectionError: If Ray connection failures exceed threshold
+    """
+    worker_id = get_instance_ip.get_self_ip()
+    actor_check_count = 0
+    consecutive_connection_failures = 0
+
+    logging.debug(
+        f"Starting to monitor for shutdown signal using actor {SHUTDOWN_ACTOR_NAME}"
+        f" in namespace {SHUTDOWN_ACTOR_NAMESPACE}."
+    )
+
+    while True:
+        actor_check_count += 1
+
+        # Check Ray connectivity before proceeding
+        if not check_ray_connectivity():
+            consecutive_connection_failures += 1
+            logging.debug(
+                f"Ray connectivity check failed (attempt {consecutive_connection_failures}/{max_consecutive_failures})"
+            )
+            if consecutive_connection_failures >= max_consecutive_failures:
+                raise ConnectionError("Exceeded max consecutive Ray connection failures")
+            time.sleep(check_interval)
+            continue
+
+        # Reset counter on successful connection
+        consecutive_connection_failures = 0
+
+        # Get shutdown actor
+        shutdown_signal = get_shutdown_actor()
+        if not shutdown_signal:
+            logging.debug(f"Shutdown signal actor not found at check #{actor_check_count}, continuing to wait...")
+            time.sleep(check_interval)
+            continue
+
+        # Ping the actor to ensure connectivity
+        if not ping_shutdown_actor(shutdown_signal):
+            time.sleep(check_interval)
+            continue
+
+        # Check shutdown status
+        if check_shutdown_status(shutdown_signal, worker_id):
+            return 0
+
+        # Wait before checking again
+        time.sleep(check_interval)
+
+
+def run_listener() -> int:
+    """Listen for shutdown signals from the head node"""
+    # Configuration
+    max_retries = 15
+    initial_retry_delay = 2
+    max_retry_delay = 30
+    check_interval = 5  # How often to check for ray connection or shutdown signal
+    max_consecutive_failures = 12  # Exit after about 1 minute of connection failures
+
+    # Initialize Ray connection
+    if not initialize_ray_connection(max_retries, initial_retry_delay, max_retry_delay):
+        raise ConnectionError("Failed to connect to Ray cluster. Aborting worker.")
+
+    # Monitor for shutdown signals
+    return monitor_shutdown_signal(check_interval, max_consecutive_failures)
+
+
+def main():
+    """Main entry point with signal handling"""

+    def signal_handler(signum, frame):
+        logging.info(f"Received signal {signum}, exiting worker process.")
+        sys.exit(0)
+
+    signal.signal(signal.SIGTERM, signal_handler)
+    signal.signal(signal.SIGINT, signal_handler)
+
+    # Run the listener - this will block until a shutdown signal is received
+    result = run_listener()
+    sys.exit(result)
+
+
+if __name__ == "__main__":
+    main()
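The listener above only consumes the shutdown protocol; the head-node side that creates the ShutdownSignal actor is not part of this file. A minimal sketch of an actor satisfying the interface the listener expects — ping(), should_shutdown() returning a dict with "shutdown" and "timestamp" keys, and acknowledge_shutdown(worker_id) — might look like the following; this is an illustrative assumption, and the real actor used by the job system (not shown in this excerpt) may differ.

# Hypothetical head-node counterpart to the worker listener above; names mirror the
# constants the listener imports, but this is a sketch, not the shipped implementation.
import time

import ray


@ray.remote
class ShutdownSignal:
    """Named actor the head node creates so workers can poll for shutdown."""

    def __init__(self) -> None:
        self._shutdown_requested = False
        self._timestamp = None
        self._acknowledged = set()

    def ping(self) -> str:
        # Lets workers verify the actor is reachable.
        return "pong"

    def request_shutdown(self) -> None:
        # Called on the head node once the job's entrypoint finishes.
        self._shutdown_requested = True
        self._timestamp = time.time()

    def should_shutdown(self) -> dict:
        # Shape matches what check_shutdown_status() reads above.
        return {"shutdown": self._shutdown_requested, "timestamp": self._timestamp}

    def acknowledge_shutdown(self, worker_id: str) -> str:
        # Workers report back before exiting so the head node can track them.
        self._acknowledged.add(worker_id)
        return f"acknowledged {worker_id}"


# The head node would create it as a named, detached actor in the shared namespace, e.g.:
# ShutdownSignal.options(name=SHUTDOWN_ACTOR_NAME, namespace=SHUTDOWN_ACTOR_NAMESPACE,
#                        lifetime="detached").remote()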
snowflake/ml/jobs/_utils/spec_utils.py
CHANGED
@@ -97,6 +97,7 @@ def generate_service_spec(
     payload: types.UploadedPayload,
     args: Optional[List[str]] = None,
     num_instances: Optional[int] = None,
+    enable_metrics: bool = False,
 ) -> Dict[str, Any]:
     """
     Generate a service specification for a job.
@@ -107,20 +108,15 @@ def generate_service_spec(
         payload: Uploaded job payload
         args: Arguments to pass to entrypoint script
         num_instances: Number of instances for multi-node job
+        enable_metrics: Enable platform metrics for the job
 
     Returns:
         Job service specification
     """
     is_multi_node = num_instances is not None and num_instances > 1
+    image_spec = _get_image_spec(session, compute_pool)
 
     # Set resource requests/limits, including nvidia.com/gpu quantity if applicable
-    if is_multi_node:
-        # If the job is of multi-node, we will need a different image which contains
-        # module snowflake.runtime.utils.get_instance_ip
-        # TODO(SNOW-1961849): Remove the hard-coded image name
-        image_spec = _get_image_spec(session, compute_pool, constants.MULTINODE_HEADLESS_IMAGE_TAG)
-    else:
-        image_spec = _get_image_spec(session, compute_pool)
     resource_requests: Dict[str, Union[str, int]] = {
         "cpu": f"{int(image_spec.resource_requests.cpu * 1000)}m",
         "memory": f"{image_spec.resource_limits.memory}Gi",
@@ -189,7 +185,10 @@ def generate_service_spec(
 
     # TODO: Add hooks for endpoints for integration with TensorBoard etc
 
-    env_vars = {
+    env_vars = {
+        constants.PAYLOAD_DIR_ENV_VAR: stage_mount.as_posix(),
+        constants.RESULT_PATH_ENV_VAR: constants.RESULT_PATH_DEFAULT_VALUE,
+    }
     endpoints = []
 
     if is_multi_node:
@@ -211,6 +210,16 @@ def generate_service_spec(
         ]
         endpoints.extend(ray_endpoints)
 
+    metrics = []
+    if enable_metrics:
+        # https://docs.snowflake.com/en/developer-guide/snowpark-container-services/monitoring-services#label-spcs-available-platform-metrics
+        metrics = [
+            "system",
+            "status",
+            "network",
+            "storage",
+        ]
+
     spec_dict = {
         "containers": [
             {
@@ -233,6 +242,16 @@ def generate_service_spec(
     }
     if endpoints:
         spec_dict["endpoints"] = endpoints
+    if metrics:
+        spec_dict.update(
+            {
+                "platformMonitor": {
+                    "metricConfig": {
+                        "groups": metrics,
+                    },
+                },
+            }
+        )
 
     # Assemble into service specification dict
     spec = {"spec": spec_dict}
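With enable_metrics=True, the hunks above extend the generated service spec with a platformMonitor section alongside the existing containers/endpoints entries. Roughly, the resulting fragment of spec_dict would look like the following sketch (surrounding container and volume entries elided):

# Illustrative fragment of the spec dict produced when enable_metrics=True.
spec_dict = {
    "containers": [...],  # populated earlier in generate_service_spec
    "platformMonitor": {
        "metricConfig": {
            # Platform metric groups published for the job's service
            "groups": ["system", "status", "network", "storage"],
        },
    },
}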
snowflake/ml/jobs/_utils/types.py
CHANGED
@@ -11,6 +11,12 @@ JOB_STATUS = Literal[
 ]
 
 
+@dataclass(frozen=True)
+class PayloadEntrypoint:
+    file_path: PurePath
+    main_func: Optional[str]
+
+
 @dataclass(frozen=True)
 class UploadedPayload:
     # TODO: Include manifest of payload files for validation
snowflake/ml/jobs/decorators.py
CHANGED
@@ -19,14 +19,16 @@ _ReturnValue = TypeVar("_ReturnValue")
 @telemetry.send_api_usage_telemetry(project=_PROJECT)
 def remote(
     compute_pool: str,
+    *,
     stage_name: str,
     pip_requirements: Optional[List[str]] = None,
     external_access_integrations: Optional[List[str]] = None,
     query_warehouse: Optional[str] = None,
     env_vars: Optional[Dict[str, str]] = None,
-    session: Optional[snowpark.Session] = None,
     num_instances: Optional[int] = None,
-
+    enable_metrics: bool = False,
+    session: Optional[snowpark.Session] = None,
+) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob[_ReturnValue]]]:
     """
     Submit a job to the compute pool.
 
@@ -37,14 +39,15 @@ def remote(
         external_access_integrations: A list of external access integrations.
         query_warehouse: The query warehouse to use. Defaults to session warehouse.
         env_vars: Environment variables to set in container
-        session: The Snowpark session to use. If none specified, uses active session.
         num_instances: The number of nodes in the job. If none specified, create a single node job.
+        enable_metrics: Whether to enable metrics publishing for the job.
+        session: The Snowpark session to use. If none specified, uses active session.
 
     Returns:
         Decorator that dispatches invocations of the decorated function as remote jobs.
     """
 
-    def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, jb.MLJob]:
+    def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, jb.MLJob[_ReturnValue]]:
        # Copy the function to avoid modifying the original
        # We need to modify the line number of the function to exclude the
        # decorator from the copied source code
@@ -52,7 +55,7 @@ def remote(
         wrapped_func.__code__ = wrapped_func.__code__.replace(co_firstlineno=func.__code__.co_firstlineno + 1)
 
         @functools.wraps(func)
-        def wrapper(*args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob:
+        def wrapper(*args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:
             payload = functools.partial(func, *args, **kwargs)
             setattr(payload, constants.IS_MLJOB_REMOTE_ATTR, True)
             job = jm._submit_job(
@@ -63,8 +66,9 @@ def remote(
                 external_access_integrations=external_access_integrations,
                 query_warehouse=query_warehouse,
                 env_vars=env_vars,
-                session=session,
                 num_instances=num_instances,
+                enable_metrics=enable_metrics,
+                session=session,
             )
             assert isinstance(job, jb.MLJob), f"Unexpected job type: {type(job)}"
             return job
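As restructured above, everything after compute_pool is now keyword-only, enable_metrics slots in before session, and the decorator's return type is parameterized so the returned MLJob carries the wrapped function's return type. A hedged usage sketch, assuming remote is exported from snowflake.ml.jobs and using placeholder pool and stage names:

# Sketch only: "MY_POOL" and "@MY_STAGE" are placeholders for real account objects.
from snowflake.ml.jobs import remote


@remote(
    "MY_POOL",                # compute_pool is the only positional argument
    stage_name="@MY_STAGE",   # remaining arguments must now be passed by keyword
    num_instances=2,
    enable_metrics=True,      # new in 1.8.2: publish platform metrics for the job
)
def train(n_rounds: int) -> float:
    return 0.95


job = train(10)  # returns an MLJob[float] handle instead of running locally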
snowflake/ml/jobs/job.py
CHANGED
@@ -1,20 +1,32 @@
 import time
-from typing import Any, List, Optional, cast
+from typing import Any, Dict, Generic, List, Optional, TypeVar, cast
+
+import yaml
 
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
-from snowflake.ml.jobs._utils import constants, types
+from snowflake.ml.jobs._utils import constants, interop_utils, types
 from snowflake.snowpark import context as sp_context
 
 _PROJECT = "MLJob"
 TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "INTERNAL_ERROR"}
 
+T = TypeVar("T")
+
 
-class MLJob:
-    def __init__(
+class MLJob(Generic[T]):
+    def __init__(
+        self,
+        id: str,
+        service_spec: Optional[Dict[str, Any]] = None,
+        session: Optional[snowpark.Session] = None,
+    ) -> None:
         self._id = id
+        self._service_spec_cached: Optional[Dict[str, Any]] = service_spec
         self._session = session or sp_context.get_active_session()
+
         self._status: types.JOB_STATUS = "PENDING"
+        self._result: Optional[interop_utils.ExecutionResult] = None
 
     @property
     def id(self) -> str:
@@ -29,33 +41,66 @@ class MLJob:
         self._status = _get_status(self._session, self.id)
         return self._status
 
+    @property
+    def _service_spec(self) -> Dict[str, Any]:
+        """Get the job's service spec."""
+        if not self._service_spec_cached:
+            self._service_spec_cached = _get_service_spec(self._session, self.id)
+        return self._service_spec_cached
+
+    @property
+    def _container_spec(self) -> Dict[str, Any]:
+        """Get the job's main container spec."""
+        containers = self._service_spec["spec"]["containers"]
+        container_spec = next(c for c in containers if c["name"] == constants.DEFAULT_CONTAINER_NAME)
+        return cast(Dict[str, Any], container_spec)
+
+    @property
+    def _stage_path(self) -> str:
+        """Get the job's artifact storage stage location."""
+        volumes = self._service_spec["spec"]["volumes"]
+        stage_path = next(v for v in volumes if v["name"] == constants.STAGE_VOLUME_NAME)["source"]
+        return cast(str, stage_path)
+
+    @property
+    def _result_path(self) -> str:
+        """Get the job's result file location."""
+        result_path = self._container_spec["env"].get(constants.RESULT_PATH_ENV_VAR)
+        if result_path is None:
+            raise RuntimeError(f"Job {self.id} doesn't have a result path configured")
+        return f"{self._stage_path}/{result_path}"
+
     @snowpark._internal.utils.private_preview(version="1.7.4")
-    def get_logs(self, limit: int = -1) -> str:
+    def get_logs(self, limit: int = -1, instance_id: Optional[int] = None) -> str:
         """
         Return the job's execution logs.
 
         Args:
             limit: The maximum number of lines to return. Negative values are treated as no limit.
+            instance_id: Optional instance ID to get logs from a specific instance.
+                If not provided, returns logs from the head node.
 
         Returns:
             The job's execution logs.
         """
-        logs = _get_logs(self._session, self.id, limit)
+        logs = _get_logs(self._session, self.id, limit, instance_id)
         assert isinstance(logs, str)  # mypy
         return logs
 
     @snowpark._internal.utils.private_preview(version="1.7.4")
-    def show_logs(self, limit: int = -1) -> None:
+    def show_logs(self, limit: int = -1, instance_id: Optional[int] = None) -> None:
         """
         Display the job's execution logs.
 
         Args:
             limit: The maximum number of lines to display. Negative values are treated as no limit.
+            instance_id: Optional instance ID to get logs from a specific instance.
+                If not provided, displays logs from the head node.
         """
-        print(self.get_logs(limit))  # noqa: T201: we need to print here.
+        print(self.get_logs(limit, instance_id))  # noqa: T201: we need to print here.
 
     @snowpark._internal.utils.private_preview(version="1.7.4")
-    @telemetry.send_api_usage_telemetry(project=_PROJECT)
+    @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["timeout"])
     def wait(self, timeout: float = -1) -> types.JOB_STATUS:
         """
         Block until completion. Returns completion status.
@@ -78,20 +123,58 @@ class MLJob:
             delay = min(delay * 2, constants.JOB_POLL_MAX_DELAY_SECONDS)  # Exponential backoff
         return self.status
 
+    @snowpark._internal.utils.private_preview(version="1.8.2")
+    @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["timeout"])
+    def result(self, timeout: float = -1) -> T:
+        """
+        Block until completion. Returns job execution result.
+
+        Args:
+            timeout: The maximum time to wait in seconds. Negative values are treated as no timeout.
+
+        Returns:
+            T: The deserialized job result. # noqa: DAR401
+
+        Raises:
+            RuntimeError: If the job failed or if the job doesn't have a result to retrieve.
+            TimeoutError: If the job does not complete within the specified timeout. # noqa: DAR402
+        """
+        if self._result is None:
+            self.wait(timeout)
+            try:
+                self._result = interop_utils.fetch_result(self._session, self._result_path)
+            except Exception as e:
+                raise RuntimeError(f"Failed to retrieve result for job (id={self.id})") from e
+
+        if self._result.success:
+            return cast(T, self._result.result)
+        raise RuntimeError(f"Job execution failed (id={self.id})") from self._result.exception
+
+
+@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "instance_id"])
+def _get_status(session: snowpark.Session, job_id: str, instance_id: Optional[int] = None) -> types.JOB_STATUS:
+    """Retrieve job or job instance execution status."""
+    if instance_id is not None:
+        # Get specific instance status
+        rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+        for row in rows:
+            if row["instance_id"] == str(instance_id):
+                return cast(types.JOB_STATUS, row["status"])
+        raise ValueError(f"Instance {instance_id} not found in job {job_id}")
+    else:
+        (row,) = session.sql("DESCRIBE SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+        return cast(types.JOB_STATUS, row["status"])
+
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
-def 
-    """Retrieve job execution
-
-
-
-
-
-
-
-
-@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit"])
-def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1) -> str:
+def _get_service_spec(session: snowpark.Session, job_id: str) -> Dict[str, Any]:
+    """Retrieve job execution service spec."""
+    (row,) = session.sql("DESCRIBE SERVICE IDENTIFIER(?)", params=[job_id]).collect()
+    return cast(Dict[str, Any], yaml.safe_load(row["spec"]))
+
+
+@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit", "instance_id"])
+def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1, instance_id: Optional[int] = None) -> str:
     """
     Retrieve the job's execution logs.
 
@@ -99,15 +182,54 @@ def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1) -> str:
         job_id: The job ID.
         limit: The maximum number of lines to return. Negative values are treated as no limit.
         session: The Snowpark session to use. If none specified, uses active session.
+        instance_id: Optional instance ID to get logs from a specific instance.
 
     Returns:
         The job's execution logs.
     """
-
+    # If instance_id is not specified, try to get the head instance ID
+    if instance_id is None:
+        instance_id = _get_head_instance_id(session, job_id)
+
+    # Assemble params: [job_id, instance_id, container_name, (optional) limit]
+    params: List[Any] = [
+        job_id,
+        0 if instance_id is None else instance_id,
+        constants.DEFAULT_CONTAINER_NAME,
+    ]
     if limit > 0:
         params.append(limit)
+
     (row,) = session.sql(
-        f"SELECT SYSTEM$GET_SERVICE_LOGS(?, 
+        f"SELECT SYSTEM$GET_SERVICE_LOGS(?, ?, ?{f', ?' if limit > 0 else ''})",
         params=params,
     ).collect()
     return str(row[0])
+
+
+@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
+def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[int]:
+    """
+    Retrieve the head instance ID of a job.
+
+    Args:
+        session: The Snowpark session to use.
+        job_id: The job ID.
+
+    Returns:
+        The head instance ID of the job. Returns None if the head instance has not started yet.
+    """
+    rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+    if not rows:
+        return None
+
+    # Sort by start_time first, then by instance_id
+    sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"])))
+    head_instance = sorted_instances[0]
+    if not head_instance["start_time"]:
+        # If head instance hasn't started yet, return None
+        return None
+    try:
+        return int(head_instance["instance_id"])
+    except (ValueError, TypeError):
+        return 0