snowflake-ml-python 1.8.0__py3-none-any.whl → 1.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (40)
  1. snowflake/cortex/_complete.py +44 -10
  2. snowflake/ml/_internal/platform_capabilities.py +39 -3
  3. snowflake/ml/data/data_connector.py +25 -0
  4. snowflake/ml/dataset/dataset_reader.py +5 -1
  5. snowflake/ml/jobs/_utils/constants.py +3 -5
  6. snowflake/ml/jobs/_utils/interop_utils.py +442 -0
  7. snowflake/ml/jobs/_utils/payload_utils.py +81 -47
  8. snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
  9. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
  10. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +178 -0
  11. snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
  12. snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
  13. snowflake/ml/jobs/_utils/spec_utils.py +27 -8
  14. snowflake/ml/jobs/_utils/types.py +6 -0
  15. snowflake/ml/jobs/decorators.py +10 -6
  16. snowflake/ml/jobs/job.py +145 -23
  17. snowflake/ml/jobs/manager.py +79 -12
  18. snowflake/ml/model/_client/ops/model_ops.py +6 -3
  19. snowflake/ml/model/_client/ops/service_ops.py +57 -39
  20. snowflake/ml/model/_client/service/model_deployment_spec.py +7 -4
  21. snowflake/ml/model/_client/sql/service.py +11 -5
  22. snowflake/ml/model/_model_composer/model_composer.py +29 -11
  23. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +1 -2
  24. snowflake/ml/model/_packager/model_env/model_env.py +8 -2
  25. snowflake/ml/model/_packager/model_handlers/sklearn.py +1 -4
  26. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +1 -1
  27. snowflake/ml/model/_packager/model_meta/model_meta.py +6 -1
  28. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
  29. snowflake/ml/model/_packager/model_packager.py +2 -0
  30. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  31. snowflake/ml/model/type_hints.py +2 -0
  32. snowflake/ml/modeling/_internal/estimator_utils.py +5 -1
  33. snowflake/ml/registry/_manager/model_manager.py +20 -1
  34. snowflake/ml/registry/registry.py +46 -2
  35. snowflake/ml/version.py +1 -1
  36. {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/METADATA +55 -4
  37. {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/RECORD +40 -34
  38. {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/WHEEL +1 -1
  39. {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/licenses/LICENSE.txt +0 -0
  40. {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,242 @@
1
+ #!/usr/bin/env python3
2
+ # This file is part of the Ray-based distributed job system for Snowflake ML.
3
+ # Architecture overview:
4
+ # - Head node creates a ShutdownSignal actor and signals workers when job completes
5
+ # - Worker nodes listen for this signal via this script and gracefully shut down
6
+ # - This ensures clean termination of distributed Ray jobs
7
+ import logging
8
+ import signal
9
+ import sys
10
+ import time
11
+ from typing import Optional
12
+
13
+ import get_instance_ip
14
+ import ray
15
+ from constants import (
16
+ SHUTDOWN_ACTOR_NAME,
17
+ SHUTDOWN_ACTOR_NAMESPACE,
18
+ SHUTDOWN_RPC_TIMEOUT_SECONDS,
19
+ )
20
+ from ray.actor import ActorHandle
21
+
22
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
23
+
24
+
25
+ def get_shutdown_actor() -> Optional[ActorHandle]:
26
+ """
27
+ Retrieve the shutdown signal actor from Ray.
28
+
29
+ Returns:
30
+ The shutdown signal actor or None if not found
31
+ """
32
+ try:
33
+ shutdown_signal = ray.get_actor(SHUTDOWN_ACTOR_NAME, namespace=SHUTDOWN_ACTOR_NAMESPACE)
34
+ return shutdown_signal
35
+ except Exception:
36
+ return None
37
+
38
+
39
+ def ping_shutdown_actor(shutdown_signal: ActorHandle) -> bool:
40
+ """
41
+ Ping the shutdown actor to ensure connectivity.
42
+
43
+ Args:
44
+ shutdown_signal: The Ray actor handle for the shutdown signal
45
+
46
+ Returns:
47
+ True if ping succeeds, False otherwise
48
+ """
49
+ try:
50
+ ping_result = ray.get(shutdown_signal.ping.remote(), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS)
51
+ logging.debug(f"Actor ping result: {ping_result}")
52
+ return True
53
+ except (ray.exceptions.GetTimeoutError, Exception) as e:
54
+ logging.debug(f"Actor ping failed: {e}")
55
+ return False
56
+
57
+
58
+ def check_shutdown_status(shutdown_signal: ActorHandle, worker_id: str) -> bool:
59
+ """
60
+ Check if worker should shutdown and acknowledge if needed.
61
+
62
+ Args:
63
+ shutdown_signal: The Ray actor handle for the shutdown signal
64
+ worker_id: Worker identifier (IP address)
65
+
66
+ Returns:
67
+ True if should shutdown, False otherwise
68
+ """
69
+ try:
70
+ status = ray.get(shutdown_signal.should_shutdown.remote(), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS)
71
+ logging.debug(f"Shutdown status: {status}")
72
+
73
+ if status.get("shutdown", False):
74
+ logging.info(
75
+ f"Received shutdown signal from head node at {status.get('timestamp')}. " f"Exiting worker process."
76
+ )
77
+
78
+ # Acknowledge shutdown before exiting
79
+ try:
80
+ ack_result = ray.get(
81
+ shutdown_signal.acknowledge_shutdown.remote(worker_id), timeout=SHUTDOWN_RPC_TIMEOUT_SECONDS
82
+ )
83
+ logging.info(f"Acknowledged shutdown: {ack_result}")
84
+ except Exception as e:
85
+ logging.warning(f"Failed to acknowledge shutdown: {e}. Continue to exit worker.")
86
+
87
+ return True
88
+ return False
89
+
90
+ except Exception as e:
91
+ logging.debug(f"Error checking shutdown status: {e}")
92
+ return False
93
+
94
+
95
+ def check_ray_connectivity() -> bool:
96
+ """
97
+ Check if the Ray cluster is accessible.
98
+
99
+ Returns:
100
+ True if Ray is connected, False otherwise
101
+ """
102
+ try:
103
+ # A simple check to verify Ray is working
104
+ nodes = ray.nodes()
105
+ if nodes:
106
+ return True
107
+ return False
108
+ except Exception as e:
109
+ logging.debug(f"Ray connectivity check failed: {e}")
110
+ return False
111
+
112
+
113
+ def initialize_ray_connection(max_retries: int, initial_retry_delay: int, max_retry_delay: int) -> bool:
114
+ """
115
+ Initialize connection to Ray with retries.
116
+
117
+ Args:
118
+ max_retries: Maximum number of connection attempts
119
+ initial_retry_delay: Initial delay between retries in seconds
120
+ max_retry_delay: Maximum delay between retries in seconds
121
+
122
+ Returns:
123
+ bool: True if connection successful, False otherwise
124
+ """
125
+ retry_count = 0
126
+ retry_delay = initial_retry_delay
127
+
128
+ while retry_count < max_retries:
129
+ try:
130
+ ray.init(address="auto", ignore_reinit_error=True)
131
+ return True
132
+ except (ConnectionError, TimeoutError, RuntimeError) as e:
133
+ retry_count += 1
134
+ if retry_count >= max_retries:
135
+ logging.error(f"Failed to connect to Ray head after {max_retries} attempts: {e}")
136
+ return False
137
+
138
+ logging.debug(
139
+ f"Attempt {retry_count}/{max_retries} to connect to Ray failed: {e}. "
140
+ f"Retrying in {retry_delay} seconds..."
141
+ )
142
+ time.sleep(retry_delay)
143
+ # Exponential backoff with cap
144
+ retry_delay = min(retry_delay * 1.5, max_retry_delay)
145
+
146
+ return False # Should not reach here, but added for completeness
147
+
148
+
149
+ def monitor_shutdown_signal(check_interval: int, max_consecutive_failures: int) -> int:
150
+ """
151
+ Main loop to monitor for shutdown signals.
152
+
153
+ Args:
154
+ check_interval: Time in seconds between checks
155
+ max_consecutive_failures: Maximum allowed consecutive connection failures
156
+
157
+ Returns:
158
+ int: Exit code (0 for success, non-zero for failure)
159
+
160
+ Raises:
161
+ ConnectionError: If Ray connection failures exceed threshold
162
+ """
163
+ worker_id = get_instance_ip.get_self_ip()
164
+ actor_check_count = 0
165
+ consecutive_connection_failures = 0
166
+
167
+ logging.debug(
168
+ f"Starting to monitor for shutdown signal using actor {SHUTDOWN_ACTOR_NAME}"
169
+ f" in namespace {SHUTDOWN_ACTOR_NAMESPACE}."
170
+ )
171
+
172
+ while True:
173
+ actor_check_count += 1
174
+
175
+ # Check Ray connectivity before proceeding
176
+ if not check_ray_connectivity():
177
+ consecutive_connection_failures += 1
178
+ logging.debug(
179
+ f"Ray connectivity check failed (attempt {consecutive_connection_failures}/{max_consecutive_failures})"
180
+ )
181
+ if consecutive_connection_failures >= max_consecutive_failures:
182
+ raise ConnectionError("Exceeded max consecutive Ray connection failures")
183
+ time.sleep(check_interval)
184
+ continue
185
+
186
+ # Reset counter on successful connection
187
+ consecutive_connection_failures = 0
188
+
189
+ # Get shutdown actor
190
+ shutdown_signal = get_shutdown_actor()
191
+ if not shutdown_signal:
192
+ logging.debug(f"Shutdown signal actor not found at check #{actor_check_count}, continuing to wait...")
193
+ time.sleep(check_interval)
194
+ continue
195
+
196
+ # Ping the actor to ensure connectivity
197
+ if not ping_shutdown_actor(shutdown_signal):
198
+ time.sleep(check_interval)
199
+ continue
200
+
201
+ # Check shutdown status
202
+ if check_shutdown_status(shutdown_signal, worker_id):
203
+ return 0
204
+
205
+ # Wait before checking again
206
+ time.sleep(check_interval)
207
+
208
+
209
+ def run_listener() -> int:
210
+ """Listen for shutdown signals from the head node"""
211
+ # Configuration
212
+ max_retries = 15
213
+ initial_retry_delay = 2
214
+ max_retry_delay = 30
215
+ check_interval = 5 # How often to check for ray connection or shutdown signal
216
+ max_consecutive_failures = 12 # Exit after about 1 minute of connection failures
217
+
218
+ # Initialize Ray connection
219
+ if not initialize_ray_connection(max_retries, initial_retry_delay, max_retry_delay):
220
+ raise ConnectionError("Failed to connect to Ray cluster. Aborting worker.")
221
+
222
+ # Monitor for shutdown signals
223
+ return monitor_shutdown_signal(check_interval, max_consecutive_failures)
224
+
225
+
226
+ def main():
227
+ """Main entry point with signal handling"""
228
+
229
+ def signal_handler(signum, frame):
230
+ logging.info(f"Received signal {signum}, exiting worker process.")
231
+ sys.exit(0)
232
+
233
+ signal.signal(signal.SIGTERM, signal_handler)
234
+ signal.signal(signal.SIGINT, signal_handler)
235
+
236
+ # Run the listener - this will block until a shutdown signal is received
237
+ result = run_listener()
238
+ sys.exit(result)
239
+
240
+
241
+ if __name__ == "__main__":
242
+ main()
@@ -97,6 +97,7 @@ def generate_service_spec(
97
97
  payload: types.UploadedPayload,
98
98
  args: Optional[List[str]] = None,
99
99
  num_instances: Optional[int] = None,
100
+ enable_metrics: bool = False,
100
101
  ) -> Dict[str, Any]:
101
102
  """
102
103
  Generate a service specification for a job.
@@ -107,20 +108,15 @@ def generate_service_spec(
107
108
  payload: Uploaded job payload
108
109
  args: Arguments to pass to entrypoint script
109
110
  num_instances: Number of instances for multi-node job
111
+ enable_metrics: Enable platform metrics for the job
110
112
 
111
113
  Returns:
112
114
  Job service specification
113
115
  """
114
116
  is_multi_node = num_instances is not None and num_instances > 1
117
+ image_spec = _get_image_spec(session, compute_pool)
115
118
 
116
119
  # Set resource requests/limits, including nvidia.com/gpu quantity if applicable
117
- if is_multi_node:
118
- # If the job is of multi-node, we will need a different image which contains
119
- # module snowflake.runtime.utils.get_instance_ip
120
- # TODO(SNOW-1961849): Remove the hard-coded image name
121
- image_spec = _get_image_spec(session, compute_pool, constants.MULTINODE_HEADLESS_IMAGE_TAG)
122
- else:
123
- image_spec = _get_image_spec(session, compute_pool)
124
120
  resource_requests: Dict[str, Union[str, int]] = {
125
121
  "cpu": f"{int(image_spec.resource_requests.cpu * 1000)}m",
126
122
  "memory": f"{image_spec.resource_limits.memory}Gi",
@@ -189,7 +185,10 @@ def generate_service_spec(
189
185
 
190
186
  # TODO: Add hooks for endpoints for integration with TensorBoard etc
191
187
 
192
- env_vars = {constants.PAYLOAD_DIR_ENV_VAR: stage_mount.as_posix()}
188
+ env_vars = {
189
+ constants.PAYLOAD_DIR_ENV_VAR: stage_mount.as_posix(),
190
+ constants.RESULT_PATH_ENV_VAR: constants.RESULT_PATH_DEFAULT_VALUE,
191
+ }
193
192
  endpoints = []
194
193
 
195
194
  if is_multi_node:
@@ -211,6 +210,16 @@ def generate_service_spec(
211
210
  ]
212
211
  endpoints.extend(ray_endpoints)
213
212
 
213
+ metrics = []
214
+ if enable_metrics:
215
+ # https://docs.snowflake.com/en/developer-guide/snowpark-container-services/monitoring-services#label-spcs-available-platform-metrics
216
+ metrics = [
217
+ "system",
218
+ "status",
219
+ "network",
220
+ "storage",
221
+ ]
222
+
214
223
  spec_dict = {
215
224
  "containers": [
216
225
  {
@@ -233,6 +242,16 @@ def generate_service_spec(
233
242
  }
234
243
  if endpoints:
235
244
  spec_dict["endpoints"] = endpoints
245
+ if metrics:
246
+ spec_dict.update(
247
+ {
248
+ "platformMonitor": {
249
+ "metricConfig": {
250
+ "groups": metrics,
251
+ },
252
+ },
253
+ }
254
+ )
236
255
 
237
256
  # Assemble into service specification dict
238
257
  spec = {"spec": spec_dict}
@@ -11,6 +11,12 @@ JOB_STATUS = Literal[
11
11
  ]
12
12
 
13
13
 
14
+ @dataclass(frozen=True)
15
+ class PayloadEntrypoint:
16
+ file_path: PurePath
17
+ main_func: Optional[str]
18
+
19
+
14
20
  @dataclass(frozen=True)
15
21
  class UploadedPayload:
16
22
  # TODO: Include manifest of payload files for validation
@@ -19,14 +19,16 @@ _ReturnValue = TypeVar("_ReturnValue")
19
19
  @telemetry.send_api_usage_telemetry(project=_PROJECT)
20
20
  def remote(
21
21
  compute_pool: str,
22
+ *,
22
23
  stage_name: str,
23
24
  pip_requirements: Optional[List[str]] = None,
24
25
  external_access_integrations: Optional[List[str]] = None,
25
26
  query_warehouse: Optional[str] = None,
26
27
  env_vars: Optional[Dict[str, str]] = None,
27
- session: Optional[snowpark.Session] = None,
28
28
  num_instances: Optional[int] = None,
29
- ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob]]:
29
+ enable_metrics: bool = False,
30
+ session: Optional[snowpark.Session] = None,
31
+ ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob[_ReturnValue]]]:
30
32
  """
31
33
  Submit a job to the compute pool.
32
34
 
@@ -37,14 +39,15 @@ def remote(
37
39
  external_access_integrations: A list of external access integrations.
38
40
  query_warehouse: The query warehouse to use. Defaults to session warehouse.
39
41
  env_vars: Environment variables to set in container
40
- session: The Snowpark session to use. If none specified, uses active session.
41
42
  num_instances: The number of nodes in the job. If none specified, create a single node job.
43
+ enable_metrics: Whether to enable metrics publishing for the job.
44
+ session: The Snowpark session to use. If none specified, uses active session.
42
45
 
43
46
  Returns:
44
47
  Decorator that dispatches invocations of the decorated function as remote jobs.
45
48
  """
46
49
 
47
- def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, jb.MLJob]:
50
+ def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, jb.MLJob[_ReturnValue]]:
48
51
  # Copy the function to avoid modifying the original
49
52
  # We need to modify the line number of the function to exclude the
50
53
  # decorator from the copied source code
@@ -52,7 +55,7 @@ def remote(
52
55
  wrapped_func.__code__ = wrapped_func.__code__.replace(co_firstlineno=func.__code__.co_firstlineno + 1)
53
56
 
54
57
  @functools.wraps(func)
55
- def wrapper(*args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob:
58
+ def wrapper(*args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:
56
59
  payload = functools.partial(func, *args, **kwargs)
57
60
  setattr(payload, constants.IS_MLJOB_REMOTE_ATTR, True)
58
61
  job = jm._submit_job(
@@ -63,8 +66,9 @@ def remote(
63
66
  external_access_integrations=external_access_integrations,
64
67
  query_warehouse=query_warehouse,
65
68
  env_vars=env_vars,
66
- session=session,
67
69
  num_instances=num_instances,
70
+ enable_metrics=enable_metrics,
71
+ session=session,
68
72
  )
69
73
  assert isinstance(job, jb.MLJob), f"Unexpected job type: {type(job)}"
70
74
  return job
snowflake/ml/jobs/job.py CHANGED
@@ -1,20 +1,32 @@
1
1
  import time
2
- from typing import Any, List, Optional, cast
2
+ from typing import Any, Dict, Generic, List, Optional, TypeVar, cast
3
+
4
+ import yaml
3
5
 
4
6
  from snowflake import snowpark
5
7
  from snowflake.ml._internal import telemetry
6
- from snowflake.ml.jobs._utils import constants, types
8
+ from snowflake.ml.jobs._utils import constants, interop_utils, types
7
9
  from snowflake.snowpark import context as sp_context
8
10
 
9
11
  _PROJECT = "MLJob"
10
12
  TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "INTERNAL_ERROR"}
11
13
 
14
+ T = TypeVar("T")
15
+
12
16
 
13
- class MLJob:
14
- def __init__(self, id: str, session: Optional[snowpark.Session] = None) -> None:
17
+ class MLJob(Generic[T]):
18
+ def __init__(
19
+ self,
20
+ id: str,
21
+ service_spec: Optional[Dict[str, Any]] = None,
22
+ session: Optional[snowpark.Session] = None,
23
+ ) -> None:
15
24
  self._id = id
25
+ self._service_spec_cached: Optional[Dict[str, Any]] = service_spec
16
26
  self._session = session or sp_context.get_active_session()
27
+
17
28
  self._status: types.JOB_STATUS = "PENDING"
29
+ self._result: Optional[interop_utils.ExecutionResult] = None
18
30
 
19
31
  @property
20
32
  def id(self) -> str:
@@ -29,33 +41,66 @@ class MLJob:
29
41
  self._status = _get_status(self._session, self.id)
30
42
  return self._status
31
43
 
44
+ @property
45
+ def _service_spec(self) -> Dict[str, Any]:
46
+ """Get the job's service spec."""
47
+ if not self._service_spec_cached:
48
+ self._service_spec_cached = _get_service_spec(self._session, self.id)
49
+ return self._service_spec_cached
50
+
51
+ @property
52
+ def _container_spec(self) -> Dict[str, Any]:
53
+ """Get the job's main container spec."""
54
+ containers = self._service_spec["spec"]["containers"]
55
+ container_spec = next(c for c in containers if c["name"] == constants.DEFAULT_CONTAINER_NAME)
56
+ return cast(Dict[str, Any], container_spec)
57
+
58
+ @property
59
+ def _stage_path(self) -> str:
60
+ """Get the job's artifact storage stage location."""
61
+ volumes = self._service_spec["spec"]["volumes"]
62
+ stage_path = next(v for v in volumes if v["name"] == constants.STAGE_VOLUME_NAME)["source"]
63
+ return cast(str, stage_path)
64
+
65
+ @property
66
+ def _result_path(self) -> str:
67
+ """Get the job's result file location."""
68
+ result_path = self._container_spec["env"].get(constants.RESULT_PATH_ENV_VAR)
69
+ if result_path is None:
70
+ raise RuntimeError(f"Job {self.id} doesn't have a result path configured")
71
+ return f"{self._stage_path}/{result_path}"
72
+
32
73
  @snowpark._internal.utils.private_preview(version="1.7.4")
33
- def get_logs(self, limit: int = -1) -> str:
74
+ def get_logs(self, limit: int = -1, instance_id: Optional[int] = None) -> str:
34
75
  """
35
76
  Return the job's execution logs.
36
77
 
37
78
  Args:
38
79
  limit: The maximum number of lines to return. Negative values are treated as no limit.
80
+ instance_id: Optional instance ID to get logs from a specific instance.
81
+ If not provided, returns logs from the head node.
39
82
 
40
83
  Returns:
41
84
  The job's execution logs.
42
85
  """
43
- logs = _get_logs(self._session, self.id, limit)
86
+ logs = _get_logs(self._session, self.id, limit, instance_id)
44
87
  assert isinstance(logs, str) # mypy
45
88
  return logs
46
89
 
47
90
  @snowpark._internal.utils.private_preview(version="1.7.4")
48
- def show_logs(self, limit: int = -1) -> None:
91
+ def show_logs(self, limit: int = -1, instance_id: Optional[int] = None) -> None:
49
92
  """
50
93
  Display the job's execution logs.
51
94
 
52
95
  Args:
53
96
  limit: The maximum number of lines to display. Negative values are treated as no limit.
97
+ instance_id: Optional instance ID to get logs from a specific instance.
98
+ If not provided, displays logs from the head node.
54
99
  """
55
- print(self.get_logs(limit)) # noqa: T201: we need to print here.
100
+ print(self.get_logs(limit, instance_id)) # noqa: T201: we need to print here.
56
101
 
57
102
  @snowpark._internal.utils.private_preview(version="1.7.4")
58
- @telemetry.send_api_usage_telemetry(project=_PROJECT)
103
+ @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["timeout"])
59
104
  def wait(self, timeout: float = -1) -> types.JOB_STATUS:
60
105
  """
61
106
  Block until completion. Returns completion status.
@@ -78,20 +123,58 @@ class MLJob:
78
123
  delay = min(delay * 2, constants.JOB_POLL_MAX_DELAY_SECONDS) # Exponential backoff
79
124
  return self.status
80
125
 
126
+ @snowpark._internal.utils.private_preview(version="1.8.2")
127
+ @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["timeout"])
128
+ def result(self, timeout: float = -1) -> T:
129
+ """
130
+ Block until completion. Returns job execution result.
131
+
132
+ Args:
133
+ timeout: The maximum time to wait in seconds. Negative values are treated as no timeout.
134
+
135
+ Returns:
136
+ T: The deserialized job result. # noqa: DAR401
137
+
138
+ Raises:
139
+ RuntimeError: If the job failed or if the job doesn't have a result to retrieve.
140
+ TimeoutError: If the job does not complete within the specified timeout. # noqa: DAR402
141
+ """
142
+ if self._result is None:
143
+ self.wait(timeout)
144
+ try:
145
+ self._result = interop_utils.fetch_result(self._session, self._result_path)
146
+ except Exception as e:
147
+ raise RuntimeError(f"Failed to retrieve result for job (id={self.id})") from e
148
+
149
+ if self._result.success:
150
+ return cast(T, self._result.result)
151
+ raise RuntimeError(f"Job execution failed (id={self.id})") from self._result.exception
152
+
153
+
154
+ @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "instance_id"])
155
+ def _get_status(session: snowpark.Session, job_id: str, instance_id: Optional[int] = None) -> types.JOB_STATUS:
156
+ """Retrieve job or job instance execution status."""
157
+ if instance_id is not None:
158
+ # Get specific instance status
159
+ rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
160
+ for row in rows:
161
+ if row["instance_id"] == str(instance_id):
162
+ return cast(types.JOB_STATUS, row["status"])
163
+ raise ValueError(f"Instance {instance_id} not found in job {job_id}")
164
+ else:
165
+ (row,) = session.sql("DESCRIBE SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
166
+ return cast(types.JOB_STATUS, row["status"])
167
+
81
168
 
82
169
  @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
83
- def _get_status(session: snowpark.Session, job_id: str) -> types.JOB_STATUS:
84
- """Retrieve job execution status."""
85
- # TODO: snowflake-snowpark-python<1.24.0 shows spurious error messages on
86
- # `DESCRIBE` queries with bind variables
87
- # Switch to use bind variables instead of client side formatting after
88
- # updating to snowflake-snowpark-python>=1.24.0
89
- (row,) = session.sql(f"DESCRIBE SERVICE {job_id}").collect()
90
- return cast(types.JOB_STATUS, row["status"])
91
-
92
-
93
- @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit"])
94
- def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1) -> str:
170
+ def _get_service_spec(session: snowpark.Session, job_id: str) -> Dict[str, Any]:
171
+ """Retrieve job execution service spec."""
172
+ (row,) = session.sql("DESCRIBE SERVICE IDENTIFIER(?)", params=[job_id]).collect()
173
+ return cast(Dict[str, Any], yaml.safe_load(row["spec"]))
174
+
175
+
176
+ @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit", "instance_id"])
177
+ def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1, instance_id: Optional[int] = None) -> str:
95
178
  """
96
179
  Retrieve the job's execution logs.
97
180
 
@@ -99,15 +182,54 @@ def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1) -> str:
99
182
  job_id: The job ID.
100
183
  limit: The maximum number of lines to return. Negative values are treated as no limit.
101
184
  session: The Snowpark session to use. If none specified, uses active session.
185
+ instance_id: Optional instance ID to get logs from a specific instance.
102
186
 
103
187
  Returns:
104
188
  The job's execution logs.
105
189
  """
106
- params: List[Any] = [job_id]
190
+ # If instance_id is not specified, try to get the head instance ID
191
+ if instance_id is None:
192
+ instance_id = _get_head_instance_id(session, job_id)
193
+
194
+ # Assemble params: [job_id, instance_id, container_name, (optional) limit]
195
+ params: List[Any] = [
196
+ job_id,
197
+ 0 if instance_id is None else instance_id,
198
+ constants.DEFAULT_CONTAINER_NAME,
199
+ ]
107
200
  if limit > 0:
108
201
  params.append(limit)
202
+
109
203
  (row,) = session.sql(
110
- f"SELECT SYSTEM$GET_SERVICE_LOGS(?, 0, '{constants.DEFAULT_CONTAINER_NAME}'{f', ?' if limit > 0 else ''})",
204
+ f"SELECT SYSTEM$GET_SERVICE_LOGS(?, ?, ?{f', ?' if limit > 0 else ''})",
111
205
  params=params,
112
206
  ).collect()
113
207
  return str(row[0])
208
+
209
+
210
+ @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
211
+ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[int]:
212
+ """
213
+ Retrieve the head instance ID of a job.
214
+
215
+ Args:
216
+ session: The Snowpark session to use.
217
+ job_id: The job ID.
218
+
219
+ Returns:
220
+ The head instance ID of the job. Returns None if the head instance has not started yet.
221
+ """
222
+ rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
223
+ if not rows:
224
+ return None
225
+
226
+ # Sort by start_time first, then by instance_id
227
+ sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"])))
228
+ head_instance = sorted_instances[0]
229
+ if not head_instance["start_time"]:
230
+ # If head instance hasn't started yet, return None
231
+ return None
232
+ try:
233
+ return int(head_instance["instance_id"])
234
+ except (ValueError, TypeError):
235
+ return 0