snowflake-ml-python 1.8.4__py3-none-any.whl → 1.8.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. snowflake/ml/_internal/telemetry.py +42 -16
  2. snowflake/ml/_internal/utils/connection_params.py +196 -0
  3. snowflake/ml/data/data_connector.py +1 -1
  4. snowflake/ml/jobs/__init__.py +2 -0
  5. snowflake/ml/jobs/_utils/constants.py +12 -2
  6. snowflake/ml/jobs/_utils/function_payload_utils.py +43 -0
  7. snowflake/ml/jobs/_utils/interop_utils.py +1 -1
  8. snowflake/ml/jobs/_utils/payload_utils.py +95 -39
  9. snowflake/ml/jobs/_utils/scripts/constants.py +22 -0
  10. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +67 -2
  11. snowflake/ml/jobs/_utils/spec_utils.py +30 -6
  12. snowflake/ml/jobs/_utils/stage_utils.py +119 -0
  13. snowflake/ml/jobs/_utils/types.py +5 -1
  14. snowflake/ml/jobs/decorators.py +10 -7
  15. snowflake/ml/jobs/job.py +176 -28
  16. snowflake/ml/jobs/manager.py +119 -26
  17. snowflake/ml/model/_client/model/model_impl.py +58 -0
  18. snowflake/ml/model/_client/model/model_version_impl.py +90 -0
  19. snowflake/ml/model/_client/ops/model_ops.py +6 -3
  20. snowflake/ml/model/_client/ops/service_ops.py +24 -7
  21. snowflake/ml/model/_client/service/model_deployment_spec.py +11 -0
  22. snowflake/ml/model/_client/sql/model_version.py +1 -1
  23. snowflake/ml/model/_client/sql/service.py +73 -28
  24. snowflake/ml/model/_client/sql/stage.py +5 -2
  25. snowflake/ml/model/_model_composer/model_composer.py +3 -1
  26. snowflake/ml/model/_packager/model_handlers/sklearn.py +1 -1
  27. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +103 -73
  28. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -2
  29. snowflake/ml/model/_signatures/core.py +24 -0
  30. snowflake/ml/monitoring/explain_visualize.py +160 -22
  31. snowflake/ml/monitoring/model_monitor.py +0 -4
  32. snowflake/ml/registry/registry.py +34 -14
  33. snowflake/ml/utils/connection_params.py +9 -3
  34. snowflake/ml/utils/html_utils.py +263 -0
  35. snowflake/ml/version.py +1 -1
  36. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/METADATA +40 -13
  37. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/RECORD +40 -37
  38. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/WHEEL +1 -1
  39. snowflake/ml/monitoring/model_monitor_version.py +0 -1
  40. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/licenses/LICENSE.txt +0 -0
  41. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/_utils/payload_utils.py

@@ -12,10 +12,17 @@ import cloudpickle as cp
 from packaging import version
 
 from snowflake import snowpark
-from snowflake.ml.jobs._utils import constants, types
+from snowflake.ml.jobs._utils import (
+    constants,
+    function_payload_utils,
+    stage_utils,
+    types,
+)
 from snowflake.snowpark import exceptions as sp_exceptions
 from snowflake.snowpark._internal import code_generation
 
+cp.register_pickle_by_value(function_payload_utils)
+
 _SUPPORTED_ARG_TYPES = {str, int, float}
 _SUPPORTED_ENTRYPOINT_EXTENSIONS = {".py"}
 _ENTRYPOINT_FUNC_NAME = "func"
@@ -100,6 +107,11 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 # Parse the output using read
 read head_index head_ip head_status<<< "$head_info"
 
+if [ "$SNOWFLAKE_JOB_INDEX" -ne "$head_index" ]; then
+    NODE_TYPE="worker"
+    echo "{constants.LOG_START_MSG}"
+fi
+
 # Use the parsed variables
 echo "Head Instance Index: $head_index"
 echo "Head Instance IP: $head_ip"
@@ -117,9 +129,7 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 exit 1
 fi
 
-if [ "$SNOWFLAKE_JOB_INDEX" -ne "$head_index" ]; then
-    NODE_TYPE="worker"
-fi
+
 fi
 
 # Common parameters for both head and worker nodes
@@ -168,6 +178,10 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 # Start Ray on a worker node - run in background
 ray start "${{common_params[@]}}" "${{worker_params[@]}}" -v --block &
 
+echo "Worker node started on address $eth0Ip. See more logs in the head node."
+
+echo "{constants.LOG_END_MSG}"
+
 # Start the worker shutdown listener in the background
 echo "Starting worker shutdown listener..."
 python worker_shutdown_listener.py
@@ -189,15 +203,16 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 
 # Start Ray on the head node
 ray start "${{common_params[@]}}" "${{head_params[@]}}" -v
+
 ##### End Ray configuration #####
 
 # TODO: Monitor MLRS and handle process crashes
 python -m web.ml_runtime_grpc_server &
 
 # TODO: Launch worker service(s) using SQL if Ray and MLRS successfully started
+echo Running command: python "$@"
 
 # Run user's Python entrypoint
-echo Running command: python "$@"
 python "$@"
 
 # After the user's job completes, signal workers to shut down
@@ -209,20 +224,23 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 ).strip()
 
 
-def resolve_source(source: Union[Path, Callable[..., Any]]) -> Union[Path, Callable[..., Any]]:
+def resolve_source(
+    source: Union[Path, stage_utils.StagePath, Callable[..., Any]]
+) -> Union[Path, stage_utils.StagePath, Callable[..., Any]]:
     if callable(source):
         return source
-    elif isinstance(source, Path):
-        # Validate source
-        source = source
+    elif isinstance(source, (Path, stage_utils.StagePath)):
         if not source.exists():
             raise FileNotFoundError(f"{source} does not exist")
         return source.absolute()
     else:
-        raise ValueError("Unsupported source type. Source must be a file, directory, or callable.")
+        raise ValueError("Unsupported source type. Source must be a stage, file, directory, or callable.")
 
 
-def resolve_entrypoint(source: Union[Path, Callable[..., Any]], entrypoint: Optional[Path]) -> types.PayloadEntrypoint:
+def resolve_entrypoint(
+    source: Union[Path, stage_utils.StagePath, Callable[..., Any]],
+    entrypoint: Optional[Union[stage_utils.StagePath, Path]],
+) -> types.PayloadEntrypoint:
     if callable(source):
         # Entrypoint is generated for callable payloads
         return types.PayloadEntrypoint(
@@ -237,11 +255,11 @@ def resolve_entrypoint(source: Union[Path, Callable[..., Any]], entrypoint: Opti
             # Infer entrypoint from source
             entrypoint = parent
         else:
-            raise ValueError("entrypoint must be provided when source is a directory")
+            raise ValueError("Entrypoint must be provided when source is a directory")
     elif entrypoint.is_absolute():
         # Absolute path - validate it's a subpath of source dir
         if not entrypoint.is_relative_to(parent):
-            raise ValueError(f"Entrypoint must be a subpath of {parent}, got: {entrypoint})")
+            raise ValueError(f"Entrypoint must be a subpath of {parent}, got: {entrypoint}")
     else:
         # Relative path
         if (abs_entrypoint := entrypoint.absolute()).is_relative_to(parent) and abs_entrypoint.is_file():
@@ -257,6 +275,7 @@ def resolve_entrypoint(source: Union[Path, Callable[..., Any]], entrypoint: Opti
             "Entrypoint not found. Ensure the entrypoint is a valid file and is under"
             f" the source directory (source={parent}, entrypoint={entrypoint})"
         )
+
     if entrypoint.suffix not in _SUPPORTED_ENTRYPOINT_EXTENSIONS:
         raise ValueError(
             "Unsupported entrypoint type:"
@@ -277,8 +296,9 @@ class JobPayload:
         *,
         pip_requirements: Optional[list[str]] = None,
     ) -> None:
-        self.source = Path(source) if isinstance(source, str) else source
-        self.entrypoint = Path(entrypoint) if isinstance(entrypoint, str) else entrypoint
+        # for stage path like snow://domain....., Path(path) will remove duplicate /, it will become snow:/domain...
+        self.source = stage_utils.identify_stage_path(source) if isinstance(source, str) else source
+        self.entrypoint = stage_utils.identify_stage_path(entrypoint) if isinstance(entrypoint, str) else entrypoint
         self.pip_requirements = pip_requirements
 
     def upload(self, session: snowpark.Session, stage_path: Union[str, PurePath]) -> types.UploadedPayload:
@@ -302,7 +322,7 @@ class JobPayload:
         ).collect()
 
         # Upload payload to stage
-        if not isinstance(source, Path):
+        if not isinstance(source, (Path, stage_utils.StagePath)):
             source_code = generate_python_code(source, source_code_display=True)
             _ = session.file.put_stream(
                 io.BytesIO(source_code.encode()),
@@ -313,27 +333,38 @@ class JobPayload:
             source = Path(entrypoint.file_path.parent)
             if not any(r.startswith("cloudpickle") for r in pip_requirements):
                 pip_requirements.append(f"cloudpickle~={version.parse(cp.__version__).major}.0")
-        elif source.is_dir():
-            # Manually traverse the directory and upload each file, since Snowflake PUT
-            # can't handle directories. Reduce the number of PUT operations by using
-            # wildcard patterns to batch upload files with the same extension.
-            for path in {
-                p.parent.joinpath(f"*{p.suffix}") if p.suffix else p for p in source.resolve().rglob("*") if p.is_file()
-            }:
+
+        elif isinstance(source, stage_utils.StagePath):
+            # copy payload to stage
+            if source == entrypoint.file_path:
+                source = source.parent
+            source_path = source.as_posix() + "/"
+            session.sql(f"copy files into {stage_path}/ from {source_path}").collect()
+
+        elif isinstance(source, Path):
+            if source.is_dir():
+                # Manually traverse the directory and upload each file, since Snowflake PUT
+                # can't handle directories. Reduce the number of PUT operations by using
+                # wildcard patterns to batch upload files with the same extension.
+                for path in {
+                    p.parent.joinpath(f"*{p.suffix}") if p.suffix else p
+                    for p in source.resolve().rglob("*")
+                    if p.is_file()
+                }:
+                    session.file.put(
+                        str(path),
+                        stage_path.joinpath(path.parent.relative_to(source)).as_posix(),
+                        overwrite=True,
+                        auto_compress=False,
+                    )
+            else:
                 session.file.put(
-                    str(path),
-                    stage_path.joinpath(path.parent.relative_to(source)).as_posix(),
+                    str(source.resolve()),
+                    stage_path.as_posix(),
                     overwrite=True,
                     auto_compress=False,
                 )
-        else:
-            session.file.put(
-                str(source.resolve()),
-                stage_path.as_posix(),
-                overwrite=True,
-                auto_compress=False,
-            )
-            source = source.parent
+                source = source.parent
 
         # Upload requirements
         # TODO: Check if payload includes both a requirements.txt file and pip_requirements
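For illustration only, the COPY statement generated by the new StagePath branch above takes the following shape (stage and payload locations here are hypothetical, not taken from the package):

source_path = "@payload_src/src/"           # hypothetical stage holding the user's source
stage_path = "@ml_jobs_stage/job_20250101"  # hypothetical job payload stage
print(f"copy files into {stage_path}/ from {source_path}")
# -> copy files into @ml_jobs_stage/job_20250101/ from @payload_src/src/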
@@ -494,9 +525,15 @@ def _generate_param_handler_code(signature: inspect.Signature, output_name: str
     return param_code
 
 
-def generate_python_code(func: Callable[..., Any], source_code_display: bool = False) -> str:
+def generate_python_code(payload: Callable[..., Any], source_code_display: bool = False) -> str:
     """Generate an entrypoint script from a Python function."""
-    signature = inspect.signature(func)
+
+    if isinstance(payload, function_payload_utils.FunctionPayload):
+        function = payload.function
+    else:
+        function = payload
+
+    signature = inspect.signature(function)
     if any(
         p.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
         for p in signature.parameters.values()
@@ -505,21 +542,20 @@ def generate_python_code(func: Callable[..., Any], source_code_display: bool = F
 
     # Mirrored from Snowpark generate_python_code() function
     # https://github.com/snowflakedb/snowpark-python/blob/main/src/snowflake/snowpark/_internal/udf_utils.py
-    source_code_comment = _generate_source_code_comment(func) if source_code_display else ""
+    source_code_comment = _generate_source_code_comment(function) if source_code_display else ""
 
     arg_dict_name = "kwargs"
-    if getattr(func, constants.IS_MLJOB_REMOTE_ATTR, None):
+    if isinstance(payload, function_payload_utils.FunctionPayload):
         param_code = f"{arg_dict_name} = {{}}"
     else:
         param_code = _generate_param_handler_code(signature, arg_dict_name)
-
     return f"""
 import sys
 import pickle
 
 try:
 {textwrap.indent(source_code_comment, ' ')}
-    {_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(func).hex()}'))
+    {_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(payload).hex()}'))
 except (TypeError, pickle.PickleError):
     if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor != {sys.version_info.minor}:
         raise RuntimeError(
@@ -543,3 +579,23 @@ if __name__ == '__main__':
 
     __return__ = {_ENTRYPOINT_FUNC_NAME}(**{arg_dict_name})
 """
+
+
+def create_function_payload(
+    func: Callable[..., Any], *args: Any, **kwargs: Any
+) -> function_payload_utils.FunctionPayload:
+    signature = inspect.signature(func)
+    bound = signature.bind(*args, **kwargs)
+    bound.apply_defaults()
+    session_argument = ""
+    session = None
+    for name, val in list(bound.arguments.items()):
+        if isinstance(val, snowpark.Session):
+            if session:
+                raise TypeError(f"Expected only one Session-type argument, but got both {session_argument} and {name}.")
+            session = val
+            session_argument = name
+            del bound.arguments[name]
+    payload = function_payload_utils.FunctionPayload(func, session, session_argument, *bound.args, **bound.kwargs)
+
+    return payload
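For context, create_function_payload is what the updated @remote decorator (decorators.py, below) calls to bundle a function invocation. A rough sketch of its behavior; the connection and table values are hypothetical, and the FunctionPayload attributes used here are taken from their usages elsewhere in this diff:

from typing import Optional

from snowflake import snowpark
from snowflake.ml.jobs._utils import payload_utils
from snowflake.ml.utils.connection_params import SnowflakeLoginOptions

def train(table: str, session: Optional[snowpark.Session] = None, epochs: int = 3) -> None:
    ...  # user training logic

session = snowpark.Session.builder.configs(SnowflakeLoginOptions()).create()  # assumes configured credentials
payload = payload_utils.create_function_payload(train, "MY_DB.MY_SCHEMA.DATA", session=session)

assert payload.session is session  # the Session argument is detached and carried separately
assert payload.function is train   # the remaining bound arguments ("MY_DB.MY_SCHEMA.DATA", epochs=3) travel with the payload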

snowflake/ml/jobs/_utils/scripts/constants.py

@@ -1,4 +1,26 @@
+from snowflake.ml.jobs._utils import constants as mljob_constants
+
 # Constants defining the shutdown signal actor configuration.
 SHUTDOWN_ACTOR_NAME = "ShutdownSignal"
 SHUTDOWN_ACTOR_NAMESPACE = "default"
 SHUTDOWN_RPC_TIMEOUT_SECONDS = 5.0
+
+
+# The followings are Inherited from snowflake.ml.jobs._utils.constants
+# We need to copy them here since snowml package on the server side does
+# not have the latest version of the code
+
+# Log start and end messages
+LOG_START_MSG = getattr(
+    mljob_constants,
+    "LOG_START_MSG",
+    "--------------------------------\nML job started\n--------------------------------",
+)
+LOG_END_MSG = getattr(
+    mljob_constants,
+    "LOG_END_MSG",
+    "--------------------------------\nML job finished\n--------------------------------",
+)
+
+# min_instances environment variable name
+MIN_INSTANCES_ENV_VAR = getattr(mljob_constants, "MIN_INSTANCES_ENV_VAR", "MLRS_MIN_INSTANCES")

snowflake/ml/jobs/_utils/scripts/mljob_launcher.py

@@ -2,25 +2,35 @@ import argparse
 import copy
 import importlib.util
 import json
+import logging
 import os
 import runpy
 import sys
+import time
 import traceback
 import warnings
 from pathlib import Path
 from typing import Any, Optional
 
 import cloudpickle
+from constants import LOG_END_MSG, LOG_START_MSG, MIN_INSTANCES_ENV_VAR
 
 from snowflake.ml.jobs._utils import constants
 from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
 from snowflake.snowpark import Session
 
+# Configure logging
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
 # Fallbacks in case of SnowML version mismatch
 RESULT_PATH_ENV_VAR = getattr(constants, "RESULT_PATH_ENV_VAR", "MLRS_RESULT_PATH")
-
 JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "mljob_result.pkl")
 
+# Constants for the wait_for_min_instances function
+CHECK_INTERVAL = 10  # seconds
+TIMEOUT = 720  # seconds
+
 
 try:
     from snowflake.ml.jobs._utils.interop_utils import ExecutionResult
@@ -62,6 +72,48 @@ class SimpleJSONEncoder(json.JSONEncoder):
         return f"Unserializable object: {repr(obj)}"
 
 
+def wait_for_min_instances(min_instances: int) -> None:
+    """
+    Wait until the specified minimum number of instances are available in the Ray cluster.
+
+    Args:
+        min_instances: Minimum number of instances required
+
+    Raises:
+        TimeoutError: If failed to connect to Ray or if minimum instances are not available within timeout
+    """
+    if min_instances <= 1:
+        logger.debug("Minimum instances is 1 or less, no need to wait for additional instances")
+        return
+
+    # mljob_launcher runs inside the CR where mlruntime libraries are available, so we can import common_util directly
+    from common_utils import common_util as mlrs_util
+
+    start_time = time.time()
+    timeout = os.getenv("JOB_MIN_INSTANCES_TIMEOUT", TIMEOUT)
+    check_interval = os.getenv("JOB_MIN_INSTANCES_CHECK_INTERVAL", CHECK_INTERVAL)
+    logger.debug(f"Waiting for at least {min_instances} instances to be ready (timeout: {timeout}s)")
+
+    while time.time() - start_time < timeout:
+        total_nodes = mlrs_util.get_num_ray_nodes()
+
+        if total_nodes >= min_instances:
+            elapsed = time.time() - start_time
+            logger.info(f"Minimum instance requirement met: {total_nodes} instances available after {elapsed:.1f}s")
+            return
+
+        logger.debug(
+            f"Waiting for instances: {total_nodes}/{min_instances} available "
+            f"(elapsed: {time.time() - start_time:.1f}s)"
+        )
+        time.sleep(check_interval)
+
+    raise TimeoutError(
+        f"Timed out after {timeout}s waiting for {min_instances} instances, only "
+        f"{mlrs_util.get_num_ray_nodes()} available"
+    )
+
+
 def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = None) -> Any:
     """
     Execute a Python script and return its result.
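A brief sketch of how this gate is driven (values hypothetical; the environment variable name falls back to "MLRS_MIN_INSTANCES" per scripts/constants.py above, and main() below consults it before running the user script):

import os

os.environ["MLRS_MIN_INSTANCES"] = "3"  # normally set by spec_utils for multi-node jobs
# Inside the job container, main() reads this value and calls wait_for_min_instances(3),
# which polls mlrs_util.get_num_ray_nodes() every CHECK_INTERVAL seconds and raises
# TimeoutError if fewer than 3 nodes have joined after TIMEOUT seconds.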
@@ -86,6 +138,7 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
     session = Session.builder.configs(SnowflakeLoginOptions()).create() # noqa: F841
 
     try:
+
         if main_func:
             # Use importlib for scripts with a main function defined
             module_name = Path(script_path).stem
@@ -126,9 +179,21 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
     Raises:
         Exception: Re-raises any exception caught during script execution.
     """
-    # Run the script with the specified arguments
     try:
+        # Wait for minimum required instances if specified
+        min_instances_str = os.environ.get(MIN_INSTANCES_ENV_VAR) or "1"
+        if min_instances_str and int(min_instances_str) > 1:
+            wait_for_min_instances(int(min_instances_str))
+
+        # Log start marker for user script execution
+        print(LOG_START_MSG) # noqa: T201
+
+        # Run the script with the specified arguments
         result = run_script(script_path, *script_args, main_func=script_main_func)
+
+        # Log end marker for user script execution
+        print(LOG_END_MSG) # noqa: T201
+
         result_obj = ExecutionResult(result=result)
         return result_obj
     except Exception as e:

snowflake/ml/jobs/_utils/spec_utils.py

@@ -1,4 +1,5 @@
 import logging
+import os
 from math import ceil
 from pathlib import PurePath
 from typing import Any, Optional, Union
@@ -30,7 +31,7 @@ def _get_image_spec(session: snowpark.Session, compute_pool: str) -> types.Image
     # Use MLRuntime image
     image_repo = constants.DEFAULT_IMAGE_REPO
     image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
-    image_tag = constants.DEFAULT_IMAGE_TAG
+    image_tag = _get_runtime_image_tag()
 
     # TODO: Should each instance consume the entire pod?
     return types.ImageSpec(
@@ -85,7 +86,8 @@ def generate_service_spec(
     compute_pool: str,
     payload: types.UploadedPayload,
     args: Optional[list[str]] = None,
-    num_instances: Optional[int] = None,
+    target_instances: int = 1,
+    min_instances: int = 1,
     enable_metrics: bool = False,
 ) -> dict[str, Any]:
     """
@@ -96,13 +98,13 @@
         compute_pool: Compute pool for job execution
         payload: Uploaded job payload
         args: Arguments to pass to entrypoint script
-        num_instances: Number of instances for multi-node job
+        target_instances: Number of instances for multi-node job
         enable_metrics: Enable platform metrics for the job
+        min_instances: Minimum number of instances required to start the job
 
     Returns:
         Job service specification
     """
-    is_multi_node = num_instances is not None and num_instances > 1
     image_spec = _get_image_spec(session, compute_pool)
 
     # Set resource requests/limits, including nvidia.com/gpu quantity if applicable
@@ -180,10 +182,11 @@
     }
     endpoints = []
 
-    if is_multi_node:
+    if target_instances > 1:
         # Update environment variables for multi-node job
        env_vars.update(constants.RAY_PORTS)
-        env_vars["ENABLE_HEALTH_CHECKS"] = constants.ENABLE_HEALTH_CHECKS
+        env_vars[constants.ENABLE_HEALTH_CHECKS_ENV_VAR] = constants.ENABLE_HEALTH_CHECKS
+        env_vars[constants.MIN_INSTANCES_ENV_VAR] = str(min_instances)
 
         # Define Ray endpoints for intra-service instance communication
         ray_endpoints = [
@@ -344,3 +347,24 @@ def _merge_lists_of_dicts(
         result[key] = d
 
     return list(result.values())
+
+
+def _get_runtime_image_tag() -> str:
+    """
+    Detect runtime image tag from container environment.
+
+    Checks in order:
+    1. Environment variable MLRS_CONTAINER_IMAGE_TAG
+    2. Falls back to hardcoded default
+
+    Returns:
+        str: The runtime image tag to use for job containers
+    """
+    env_tag = os.environ.get(constants.RUNTIME_IMAGE_TAG_ENV_VAR)
+    if env_tag:
+        logging.debug(f"Using runtime image tag from environment: {env_tag}")
+        return env_tag
+
+    # Fall back to default
+    logging.debug(f"Using default runtime image tag: {constants.DEFAULT_IMAGE_TAG}")
+    return constants.DEFAULT_IMAGE_TAG
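A quick sketch of the override path (the tag value is hypothetical; the environment variable name is the one cited in the docstring above):

import os

os.environ["MLRS_CONTAINER_IMAGE_TAG"] = "1.2.3"  # hypothetical runtime image tag
# _get_image_spec() now resolves the tag through _get_runtime_image_tag(), so job
# containers are stamped with "1.2.3" instead of constants.DEFAULT_IMAGE_TAG.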

snowflake/ml/jobs/_utils/stage_utils.py (new file)

@@ -0,0 +1,119 @@
+import os
+import re
+from os import PathLike
+from pathlib import Path, PurePath
+from typing import Union
+
+from snowflake.ml._internal.utils import identifier
+
+PROTOCOL_NAME = "snow"
+_SNOWURL_PATH_RE = re.compile(
+    rf"^(?:(?:{PROTOCOL_NAME}://)?"
+    r"(?<!@)(?P<domain>\w+)/"
+    rf"(?P<name>(?:{identifier._SF_IDENTIFIER}\.){{,2}}{identifier._SF_IDENTIFIER})/)?"
+    r"(?P<path>versions(?:/(?P<version>[^/]+)(?:/(?P<relpath>.*))?)?)$"
+)
+
+_STAGEF_PATH_RE = re.compile(r"^@(?P<stage>~|%?\w+)(?:/(?P<relpath>[\w\-./]*))?$")
+
+
+class StagePath:
+    def __init__(self, path: str) -> None:
+        stage_match = _SNOWURL_PATH_RE.fullmatch(path) or _STAGEF_PATH_RE.fullmatch(path)
+        if not stage_match:
+            raise ValueError(f"{path} is not a valid stage path")
+        path = path.strip()
+        self._raw_path = path
+        relpath = stage_match.group("relpath")
+        start, _ = stage_match.span("relpath")
+        self._root = self._raw_path[0:start].rstrip("/") if relpath else self._raw_path.rstrip("/")
+        self._path = Path(relpath or "")
+
+    @property
+    def parent(self) -> "StagePath":
+        if self._path.parent == Path(""):
+            return StagePath(self._root)
+        else:
+            return StagePath(f"{self._root}/{self._path.parent}")
+
+    @property
+    def root(self) -> str:
+        return self._root
+
+    @property
+    def suffix(self) -> str:
+        return self._path.suffix
+
+    def _compose_path(self, path: Path) -> str:
+        # in pathlib, Path("") = "."
+        if path == Path(""):
+            return self.root
+        else:
+            return f"{self.root}/{path}"
+
+    def is_relative_to(self, path: Union[str, PathLike[str], "StagePath"]) -> bool:
+        stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
+        if stage_path.root == self.root:
+            return self._path.is_relative_to(stage_path._path)
+        else:
+            return False
+
+    def relative_to(self, path: Union[str, PathLike[str], "StagePath"]) -> PurePath:
+        stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
+        if self.root == stage_path.root:
+            return self._path.relative_to(stage_path._path)
+        raise ValueError(f"{self._raw_path} does not start with {stage_path._raw_path}")
+
+    def absolute(self) -> "StagePath":
+        return self
+
+    def as_posix(self) -> str:
+        return self._compose_path(self._path)
+
+    # TODO Add actual implementation https://snowflakecomputing.atlassian.net/browse/SNOW-2112795
+    def exists(self) -> bool:
+        return True
+
+    # TODO Add actual implementation https://snowflakecomputing.atlassian.net/browse/SNOW-2112795
+    def is_file(self) -> bool:
+        return True
+
+    # TODO Add actual implementation https://snowflakecomputing.atlassian.net/browse/SNOW-2112795
+    def is_dir(self) -> bool:
+        return True
+
+    def is_absolute(self) -> bool:
+        return True
+
+    def __str__(self) -> str:
+        return self.as_posix()
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, StagePath):
+            raise NotImplementedError
+        return bool(self.root == other.root and self._path == other._path)
+
+    def __fspath__(self) -> str:
+        return self._compose_path(self._path)
+
+    def joinpath(self, *args: Union[str, PathLike[str], "StagePath"]) -> "StagePath":
+        path = self
+        for arg in args:
+            path = path._make_child(arg)
+        return path
+
+    def _make_child(self, path: Union[str, PathLike[str], "StagePath"]) -> "StagePath":
+        stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
+        if self.root == stage_path.root:
+            child_path = self._path.joinpath(stage_path._path)
+            return StagePath(self._compose_path(child_path))
+        else:
+            return stage_path
+
+
+def identify_stage_path(path: str) -> Union[StagePath, Path]:
+    try:
+        stage_path = StagePath(path)
+    except ValueError:
+        return Path(path)
+    return stage_path
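To illustrate the new helper, a small usage sketch (stage names hypothetical; behavior follows the regexes and methods above):

from pathlib import Path

from snowflake.ml.jobs._utils import stage_utils

src = stage_utils.identify_stage_path("@my_stage/src/train.py")
assert isinstance(src, stage_utils.StagePath)
assert src.root == "@my_stage"
assert src.parent.as_posix() == "@my_stage/src"
assert src.suffix == ".py"
assert src.is_relative_to("@my_stage/src")

local = stage_utils.identify_stage_path("/tmp/train.py")
assert isinstance(local, Path)  # non-stage strings fall back to pathlib.Path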

snowflake/ml/jobs/_utils/types.py

@@ -2,18 +2,22 @@ from dataclasses import dataclass
 from pathlib import PurePath
 from typing import Literal, Optional, Union
 
+from snowflake.ml.jobs._utils import stage_utils
+
 JOB_STATUS = Literal[
     "PENDING",
     "RUNNING",
     "FAILED",
     "DONE",
+    "CANCELLING",
+    "CANCELLED",
     "INTERNAL_ERROR",
 ]
 
 
 @dataclass(frozen=True)
 class PayloadEntrypoint:
-    file_path: PurePath
+    file_path: Union[PurePath, stage_utils.StagePath]
     main_func: Optional[str]
 
 

snowflake/ml/jobs/decorators.py

@@ -7,7 +7,7 @@ from typing_extensions import ParamSpec
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml.jobs import job as jb, manager as jm
-from snowflake.ml.jobs._utils import constants
+from snowflake.ml.jobs._utils import payload_utils
 
 _PROJECT = "MLJob"
 
@@ -24,7 +24,8 @@ def remote(
     external_access_integrations: Optional[list[str]] = None,
     query_warehouse: Optional[str] = None,
     env_vars: Optional[dict[str, str]] = None,
-    num_instances: Optional[int] = None,
+    target_instances: int = 1,
+    min_instances: Optional[int] = None,
     enable_metrics: bool = False,
     database: Optional[str] = None,
     schema: Optional[str] = None,
@@ -40,7 +41,9 @@ def remote(
         external_access_integrations: A list of external access integrations.
         query_warehouse: The query warehouse to use. Defaults to session warehouse.
         env_vars: Environment variables to set in container
-        num_instances: The number of nodes in the job. If none specified, create a single node job.
+        target_instances: The number of nodes in the job. If none specified, create a single node job.
+        min_instances: The minimum number of nodes required to start the job. If none specified,
+            defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
         enable_metrics: Whether to enable metrics publishing for the job.
         database: The database to use for the job.
         schema: The schema to use for the job.
@@ -59,8 +62,7 @@
 
         @functools.wraps(func)
         def wrapper(*args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:
-            payload = functools.partial(func, *args, **kwargs)
-            setattr(payload, constants.IS_MLJOB_REMOTE_ATTR, True)
+            payload = payload_utils.create_function_payload(func, *args, **kwargs)
             job = jm._submit_job(
                 source=payload,
                 stage_name=stage_name,
@@ -69,11 +71,12 @@
                 external_access_integrations=external_access_integrations,
                 query_warehouse=query_warehouse,
                 env_vars=env_vars,
-                num_instances=num_instances,
+                target_instances=target_instances,
+                min_instances=min_instances,
                 enable_metrics=enable_metrics,
                 database=database,
                 schema=schema,
-                session=session,
+                session=payload.session or session,
             )
             assert isinstance(job, jb.MLJob), f"Unexpected job type: {type(job)}"
             return job
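Taken together, a minimal usage sketch of the updated decorator (the import path and the compute pool / stage name arguments are assumed from the pre-existing public API and are hypothetical here; only target_instances and min_instances are new in this release):

from snowflake.ml import jobs

@jobs.remote(
    "MY_COMPUTE_POOL",            # assumed: compute pool, as in prior releases
    stage_name="payload_stage",   # assumed: stage used to upload the payload
    target_instances=3,           # replaces num_instances
    min_instances=2,              # new: wait until at least 2 nodes join before running
)
def train(table: str) -> None:
    ...

job = train("MY_DB.MY_SCHEMA.TRAINING_DATA")  # submits the job and returns an MLJob handle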