snowflake-ml-python 1.8.5__py3-none-any.whl → 1.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. snowflake/ml/_internal/telemetry.py +6 -9
  2. snowflake/ml/_internal/utils/connection_params.py +196 -0
  3. snowflake/ml/_internal/utils/identifier.py +1 -1
  4. snowflake/ml/_internal/utils/mixins.py +61 -0
  5. snowflake/ml/jobs/__init__.py +2 -0
  6. snowflake/ml/jobs/_utils/constants.py +3 -2
  7. snowflake/ml/jobs/_utils/function_payload_utils.py +43 -0
  8. snowflake/ml/jobs/_utils/interop_utils.py +63 -4
  9. snowflake/ml/jobs/_utils/payload_utils.py +89 -40
  10. snowflake/ml/jobs/_utils/query_helper.py +9 -0
  11. snowflake/ml/jobs/_utils/scripts/constants.py +19 -3
  12. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -26
  13. snowflake/ml/jobs/_utils/spec_utils.py +29 -5
  14. snowflake/ml/jobs/_utils/stage_utils.py +119 -0
  15. snowflake/ml/jobs/_utils/types.py +5 -1
  16. snowflake/ml/jobs/decorators.py +20 -28
  17. snowflake/ml/jobs/job.py +197 -61
  18. snowflake/ml/jobs/manager.py +253 -121
  19. snowflake/ml/model/_client/model/model_impl.py +58 -0
  20. snowflake/ml/model/_client/model/model_version_impl.py +90 -0
  21. snowflake/ml/model/_client/ops/model_ops.py +18 -6
  22. snowflake/ml/model/_client/ops/service_ops.py +23 -6
  23. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -0
  24. snowflake/ml/model/_client/sql/service.py +68 -20
  25. snowflake/ml/model/_client/sql/stage.py +5 -2
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -10
  27. snowflake/ml/model/_packager/model_env/model_env.py +35 -27
  28. snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
  29. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +103 -73
  30. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
  31. snowflake/ml/model/_signatures/core.py +24 -0
  32. snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
  33. snowflake/ml/model/target_platform.py +11 -0
  34. snowflake/ml/model/task.py +9 -0
  35. snowflake/ml/model/type_hints.py +5 -13
  36. snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
  37. snowflake/ml/monitoring/explain_visualize.py +2 -2
  38. snowflake/ml/monitoring/model_monitor.py +0 -4
  39. snowflake/ml/registry/_manager/model_manager.py +30 -15
  40. snowflake/ml/registry/registry.py +144 -47
  41. snowflake/ml/utils/connection_params.py +1 -1
  42. snowflake/ml/utils/html_utils.py +263 -0
  43. snowflake/ml/version.py +1 -1
  44. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/METADATA +64 -19
  45. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/RECORD +48 -41
  46. snowflake/ml/monitoring/model_monitor_version.py +0 -1
  47. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/WHEEL +0 -0
  48. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/licenses/LICENSE.txt +0 -0
  49. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/top_level.txt +0 -0
@@ -12,10 +12,17 @@ import cloudpickle as cp
 from packaging import version
 
 from snowflake import snowpark
-from snowflake.ml.jobs._utils import constants, types
-from snowflake.snowpark import exceptions as sp_exceptions
+from snowflake.connector import errors
+from snowflake.ml.jobs._utils import (
+    constants,
+    function_payload_utils,
+    stage_utils,
+    types,
+)
 from snowflake.snowpark._internal import code_generation
 
+cp.register_pickle_by_value(function_payload_utils)
+
 _SUPPORTED_ARG_TYPES = {str, int, float}
 _SUPPORTED_ENTRYPOINT_EXTENSIONS = {".py"}
 _ENTRYPOINT_FUNC_NAME = "func"
@@ -217,20 +224,23 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 ).strip()
 
 
-def resolve_source(source: Union[Path, Callable[..., Any]]) -> Union[Path, Callable[..., Any]]:
+def resolve_source(
+    source: Union[Path, stage_utils.StagePath, Callable[..., Any]]
+) -> Union[Path, stage_utils.StagePath, Callable[..., Any]]:
     if callable(source):
         return source
-    elif isinstance(source, Path):
-        # Validate source
-        source = source
+    elif isinstance(source, (Path, stage_utils.StagePath)):
         if not source.exists():
             raise FileNotFoundError(f"{source} does not exist")
         return source.absolute()
     else:
-        raise ValueError("Unsupported source type. Source must be a file, directory, or callable.")
+        raise ValueError("Unsupported source type. Source must be a stage, file, directory, or callable.")
 
 
-def resolve_entrypoint(source: Union[Path, Callable[..., Any]], entrypoint: Optional[Path]) -> types.PayloadEntrypoint:
+def resolve_entrypoint(
+    source: Union[Path, stage_utils.StagePath, Callable[..., Any]],
+    entrypoint: Optional[Union[stage_utils.StagePath, Path]],
+) -> types.PayloadEntrypoint:
     if callable(source):
         # Entrypoint is generated for callable payloads
         return types.PayloadEntrypoint(
@@ -245,11 +255,11 @@ def resolve_entrypoint(source: Union[Path, Callable[..., Any]], entrypoint: Opti
             # Infer entrypoint from source
             entrypoint = parent
         else:
-            raise ValueError("entrypoint must be provided when source is a directory")
+            raise ValueError("Entrypoint must be provided when source is a directory")
     elif entrypoint.is_absolute():
         # Absolute path - validate it's a subpath of source dir
         if not entrypoint.is_relative_to(parent):
-            raise ValueError(f"Entrypoint must be a subpath of {parent}, got: {entrypoint})")
+            raise ValueError(f"Entrypoint must be a subpath of {parent}, got: {entrypoint}")
     else:
         # Relative path
         if (abs_entrypoint := entrypoint.absolute()).is_relative_to(parent) and abs_entrypoint.is_file():
@@ -265,6 +275,7 @@ def resolve_entrypoint(source: Union[Path, Callable[..., Any]], entrypoint: Opti
             "Entrypoint not found. Ensure the entrypoint is a valid file and is under"
             f" the source directory (source={parent}, entrypoint={entrypoint})"
         )
+
     if entrypoint.suffix not in _SUPPORTED_ENTRYPOINT_EXTENSIONS:
         raise ValueError(
             "Unsupported entrypoint type:"
@@ -285,8 +296,9 @@ class JobPayload:
         *,
         pip_requirements: Optional[list[str]] = None,
     ) -> None:
-        self.source = Path(source) if isinstance(source, str) else source
-        self.entrypoint = Path(entrypoint) if isinstance(entrypoint, str) else entrypoint
+        # For stage paths like snow://domain/..., Path(path) would collapse "snow://" into "snow:/", so parse stage paths separately.
+        self.source = stage_utils.identify_stage_path(source) if isinstance(source, str) else source
+        self.entrypoint = stage_utils.identify_stage_path(entrypoint) if isinstance(entrypoint, str) else entrypoint
         self.pip_requirements = pip_requirements
 
     def upload(self, session: snowpark.Session, stage_path: Union[str, PurePath]) -> types.UploadedPayload:
@@ -300,17 +312,18 @@ class JobPayload:
         stage_name = stage_path.parts[0].lstrip("@")
         # Explicitly check if stage exists first since we may not have CREATE STAGE privilege
         try:
-            session.sql("describe stage identifier(?)", params=[stage_name]).collect()
-        except sp_exceptions.SnowparkSQLException:
-            session.sql(
+            session._conn.run_query("describe stage identifier(?)", params=[stage_name], _force_qmark_paramstyle=True)
+        except errors.ProgrammingError:
+            session._conn.run_query(
                 "create stage if not exists identifier(?)"
                 " encryption = ( type = 'SNOWFLAKE_SSE' )"
                 " comment = 'Created by snowflake.ml.jobs Python API'",
                 params=[stage_name],
-            ).collect()
+                _force_qmark_paramstyle=True,
+            )
 
         # Upload payload to stage
-        if not isinstance(source, Path):
+        if not isinstance(source, (Path, stage_utils.StagePath)):
             source_code = generate_python_code(source, source_code_display=True)
             _ = session.file.put_stream(
                 io.BytesIO(source_code.encode()),
@@ -321,27 +334,38 @@ class JobPayload:
             source = Path(entrypoint.file_path.parent)
             if not any(r.startswith("cloudpickle") for r in pip_requirements):
                 pip_requirements.append(f"cloudpickle~={version.parse(cp.__version__).major}.0")
-        elif source.is_dir():
-            # Manually traverse the directory and upload each file, since Snowflake PUT
-            # can't handle directories. Reduce the number of PUT operations by using
-            # wildcard patterns to batch upload files with the same extension.
-            for path in {
-                p.parent.joinpath(f"*{p.suffix}") if p.suffix else p for p in source.resolve().rglob("*") if p.is_file()
-            }:
+
+        elif isinstance(source, stage_utils.StagePath):
+            # copy payload to stage
+            if source == entrypoint.file_path:
+                source = source.parent
+            source_path = source.as_posix() + "/"
+            session.sql(f"copy files into {stage_path}/ from {source_path}").collect()
+
+        elif isinstance(source, Path):
+            if source.is_dir():
+                # Manually traverse the directory and upload each file, since Snowflake PUT
+                # can't handle directories. Reduce the number of PUT operations by using
+                # wildcard patterns to batch upload files with the same extension.
+                for path in {
+                    p.parent.joinpath(f"*{p.suffix}") if p.suffix else p
+                    for p in source.resolve().rglob("*")
+                    if p.is_file()
+                }:
+                    session.file.put(
+                        str(path),
+                        stage_path.joinpath(path.parent.relative_to(source)).as_posix(),
+                        overwrite=True,
+                        auto_compress=False,
+                    )
+            else:
                 session.file.put(
-                    str(path),
-                    stage_path.joinpath(path.parent.relative_to(source)).as_posix(),
+                    str(source.resolve()),
+                    stage_path.as_posix(),
                     overwrite=True,
                     auto_compress=False,
                 )
-        else:
-            session.file.put(
-                str(source.resolve()),
-                stage_path.as_posix(),
-                overwrite=True,
-                auto_compress=False,
-            )
-            source = source.parent
+                source = source.parent
 
         # Upload requirements
         # TODO: Check if payload includes both a requirements.txt file and pip_requirements
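The hunk above reworks how JobPayload.upload() places the payload: callables are still pickled into a generated entrypoint script and streamed with put_stream, stage-resident sources are now copied server-side with COPY FILES, and local paths keep the PUT-based upload. A minimal orientation sketch of that dispatch order; the helper function below is hypothetical and not part of the package:

    from pathlib import Path

    from snowflake.ml.jobs._utils import stage_utils

    def describe_upload_strategy(source) -> str:
        # Mirrors the branch order in JobPayload.upload(); illustrative only.
        if callable(source):
            return "pickle the callable into a generated entrypoint script (session.file.put_stream)"
        if isinstance(source, stage_utils.StagePath):
            return "server-side 'copy files into <job stage> from <source stage>'"
        if isinstance(source, Path):
            return "session.file.put(), batching directory files by extension"
        raise ValueError("Unsupported source type")

    print(describe_upload_strategy(Path("train.py")))  # session.file.put(), ...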
@@ -502,9 +526,15 @@ def _generate_param_handler_code(signature: inspect.Signature, output_name: str
     return param_code
 
 
-def generate_python_code(func: Callable[..., Any], source_code_display: bool = False) -> str:
+def generate_python_code(payload: Callable[..., Any], source_code_display: bool = False) -> str:
     """Generate an entrypoint script from a Python function."""
-    signature = inspect.signature(func)
+
+    if isinstance(payload, function_payload_utils.FunctionPayload):
+        function = payload.function
+    else:
+        function = payload
+
+    signature = inspect.signature(function)
     if any(
         p.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
         for p in signature.parameters.values()
@@ -513,21 +543,20 @@ def generate_python_code(func: Callable[..., Any], source_code_display: bool = F
 
     # Mirrored from Snowpark generate_python_code() function
     # https://github.com/snowflakedb/snowpark-python/blob/main/src/snowflake/snowpark/_internal/udf_utils.py
-    source_code_comment = _generate_source_code_comment(func) if source_code_display else ""
+    source_code_comment = _generate_source_code_comment(function) if source_code_display else ""
 
     arg_dict_name = "kwargs"
-    if getattr(func, constants.IS_MLJOB_REMOTE_ATTR, None):
+    if isinstance(payload, function_payload_utils.FunctionPayload):
         param_code = f"{arg_dict_name} = {{}}"
     else:
         param_code = _generate_param_handler_code(signature, arg_dict_name)
-
     return f"""
 import sys
 import pickle
 
 try:
 {textwrap.indent(source_code_comment, ' ')}
-    {_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(func).hex()}'))
+    {_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(payload).hex()}'))
 except (TypeError, pickle.PickleError):
     if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor != {sys.version_info.minor}:
         raise RuntimeError(
@@ -551,3 +580,23 @@ if __name__ == '__main__':
 
     __return__ = {_ENTRYPOINT_FUNC_NAME}(**{arg_dict_name})
 """
+
+
+def create_function_payload(
+    func: Callable[..., Any], *args: Any, **kwargs: Any
+) -> function_payload_utils.FunctionPayload:
+    signature = inspect.signature(func)
+    bound = signature.bind(*args, **kwargs)
+    bound.apply_defaults()
+    session_argument = ""
+    session = None
+    for name, val in list(bound.arguments.items()):
+        if isinstance(val, snowpark.Session):
+            if session:
+                raise TypeError(f"Expected only one Session-type argument, but got both {session_argument} and {name}.")
+            session = val
+            session_argument = name
+            del bound.arguments[name]
+    payload = function_payload_utils.FunctionPayload(func, session, session_argument, *bound.args, **bound.kwargs)
+
+    return payload
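create_function_payload binds the caller's arguments up front and strips out any snowpark.Session so the session object is never pickled with the function. A hedged usage sketch; the train function, table name, and connection setup are illustrative, and it assumes connection configuration is already available to Session.builder:

    from snowflake.snowpark import Session

    from snowflake.ml.jobs._utils import payload_utils

    def train(session: Session, table_name: str, epochs: int = 3) -> str:
        return f"trained on {table_name} for {epochs} epochs"

    session = Session.builder.getOrCreate()  # assumes a default connection is configured
    payload = payload_utils.create_function_payload(train, session, "MY_TABLE", epochs=5)

    # The Session is captured on the payload (see payload.session in the decorators.py hunk below)
    # rather than in the bound arguments; only "MY_TABLE" and epochs=5 travel with the pickled function.
    assert payload.session is session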
@@ -0,0 +1,9 @@
+from snowflake import snowpark
+
+
+def get_attribute_map(session: snowpark.Session, requested_attributes: dict[str, int]) -> dict[str, int]:
+    metadata = session._conn._cursor.description
+    for index in range(len(metadata)):
+        if metadata[index].name in requested_attributes.keys():
+            requested_attributes[metadata[index].name] = index
+    return requested_attributes
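Per the file list, this hunk adds snowflake/ml/jobs/_utils/query_helper.py. get_attribute_map resolves requested column names to their positional index in the most recent cursor's result metadata, so callers can index raw result rows by name (see its use in _get_node_resources further down). The core loop can be exercised without a live connection; the column names and positions below are illustrative:

    from collections import namedtuple

    Column = namedtuple("Column", ["name"])  # stand-in for cursor.description entries
    metadata = [
        Column("name"), Column("state"), Column("min_nodes"),
        Column("max_nodes"), Column("num_services"), Column("instance_family"),
    ]

    requested = {"instance_family": 4}  # default position, overwritten if the column is found
    for index in range(len(metadata)):
        if metadata[index].name in requested:
            requested[metadata[index].name] = index

    print(requested)  # {'instance_family': 5}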
@@ -1,10 +1,26 @@
+from snowflake.ml.jobs._utils import constants as mljob_constants
+
 # Constants defining the shutdown signal actor configuration.
 SHUTDOWN_ACTOR_NAME = "ShutdownSignal"
 SHUTDOWN_ACTOR_NAMESPACE = "default"
 SHUTDOWN_RPC_TIMEOUT_SECONDS = 5.0
 
 
+# The following are inherited from snowflake.ml.jobs._utils.constants.
+# They are copied here because the snowml package on the server side may
+# not have the latest version of the code.
+
 # Log start and end messages
-# Inherited from snowflake.ml.jobs._utils.constants
-LOG_START_MSG = "--------------------------------\nML job started\n--------------------------------"
-LOG_END_MSG = "--------------------------------\nML job finished\n--------------------------------"
+LOG_START_MSG = getattr(
+    mljob_constants,
+    "LOG_START_MSG",
+    "--------------------------------\nML job started\n--------------------------------",
+)
+LOG_END_MSG = getattr(
+    mljob_constants,
+    "LOG_END_MSG",
+    "--------------------------------\nML job finished\n--------------------------------",
+)
+
+# min_instances environment variable name
+MIN_INSTANCES_ENV_VAR = getattr(mljob_constants, "MIN_INSTANCES_ENV_VAR", "MLRS_MIN_INSTANCES")
@@ -13,7 +13,7 @@ from pathlib import Path
 from typing import Any, Optional
 
 import cloudpickle
-from constants import LOG_END_MSG, LOG_START_MSG
+from constants import LOG_END_MSG, LOG_START_MSG, MIN_INSTANCES_ENV_VAR
 
 from snowflake.ml.jobs._utils import constants
 from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
@@ -72,28 +72,6 @@ class SimpleJSONEncoder(json.JSONEncoder):
         return f"Unserializable object: {repr(obj)}"
 
 
-def get_active_node_count() -> int:
-    """
-    Count the number of active nodes in the Ray cluster.
-
-    Returns:
-        int: Total count of active nodes
-    """
-    import ray
-
-    if not ray.is_initialized():
-        ray.init(address="auto", ignore_reinit_error=True, log_to_driver=False)
-    try:
-        nodes = [node for node in ray.nodes() if node.get("Alive")]
-        total_active = len(nodes)
-
-        logger.info(f"Active nodes: {total_active}")
-        return total_active
-    except Exception as e:
-        logger.warning(f"Error getting active node count: {e}")
-        return 0
-
-
 def wait_for_min_instances(min_instances: int) -> None:
     """
     Wait until the specified minimum number of instances are available in the Ray cluster.
@@ -108,13 +86,16 @@ def wait_for_min_instances(min_instances: int) -> None:
         logger.debug("Minimum instances is 1 or less, no need to wait for additional instances")
         return
 
+    # mljob_launcher runs inside the CR where mlruntime libraries are available, so we can import common_util directly
+    from common_utils import common_util as mlrs_util
+
     start_time = time.time()
     timeout = os.getenv("JOB_MIN_INSTANCES_TIMEOUT", TIMEOUT)
     check_interval = os.getenv("JOB_MIN_INSTANCES_CHECK_INTERVAL", CHECK_INTERVAL)
     logger.debug(f"Waiting for at least {min_instances} instances to be ready (timeout: {timeout}s)")
 
     while time.time() - start_time < timeout:
-        total_nodes = get_active_node_count()
+        total_nodes = mlrs_util.get_num_ray_nodes()
 
         if total_nodes >= min_instances:
             elapsed = time.time() - start_time
@@ -128,7 +109,8 @@ def wait_for_min_instances(min_instances: int) -> None:
         time.sleep(check_interval)
 
     raise TimeoutError(
-        f"Timed out after {timeout}s waiting for {min_instances} instances, only {get_active_node_count()} available"
+        f"Timed out after {timeout}s waiting for {min_instances} instances, only "
+        f"{mlrs_util.get_num_ray_nodes()} available"
     )
 
 
@@ -199,7 +181,7 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
     """
     try:
         # Wait for minimum required instances if specified
-        min_instances_str = os.environ.get("JOB_MIN_INSTANCES", 1)
+        min_instances_str = os.environ.get(MIN_INSTANCES_ENV_VAR) or "1"
         if min_instances_str and int(min_instances_str) > 1:
             wait_for_min_instances(int(min_instances_str))
 
@@ -1,20 +1,23 @@
 import logging
+import os
 from math import ceil
 from pathlib import PurePath
 from typing import Any, Optional, Union
 
 from snowflake import snowpark
 from snowflake.ml._internal.utils import snowflake_env
-from snowflake.ml.jobs._utils import constants, types
+from snowflake.ml.jobs._utils import constants, query_helper, types
 
 
 def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
     """Extract resource information for the specified compute pool"""
     # Get the instance family
-    rows = session.sql("show compute pools like ?", params=[compute_pool]).collect()
-    if not rows:
+    rows = session._conn.run_query("show compute pools like ?", params=[compute_pool], _force_qmark_paramstyle=True)
+    if not rows or not isinstance(rows, dict) or not rows.get("data"):
         raise ValueError(f"Compute pool '{compute_pool}' not found")
-    instance_family: str = rows[0]["instance_family"]
+    requested_attributes = query_helper.get_attribute_map(session, {"instance_family": 4})
+    compute_pool_info = rows["data"]
+    instance_family: str = compute_pool_info[0][requested_attributes["instance_family"]]
     cloud = snowflake_env.get_current_cloud(session, default=snowflake_env.SnowflakeCloudType.AWS)
 
     return (
@@ -30,7 +33,7 @@ def _get_image_spec(session: snowpark.Session, compute_pool: str) -> types.Image
     # Use MLRuntime image
     image_repo = constants.DEFAULT_IMAGE_REPO
     image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
-    image_tag = constants.DEFAULT_IMAGE_TAG
+    image_tag = _get_runtime_image_tag()
 
     # TODO: Should each instance consume the entire pod?
     return types.ImageSpec(
@@ -346,3 +349,24 @@ def _merge_lists_of_dicts(
             result[key] = d
 
     return list(result.values())
+
+
+def _get_runtime_image_tag() -> str:
+    """
+    Detect runtime image tag from container environment.
+
+    Checks in order:
+    1. Environment variable MLRS_CONTAINER_IMAGE_TAG
+    2. Falls back to hardcoded default
+
+    Returns:
+        str: The runtime image tag to use for job containers
+    """
+    env_tag = os.environ.get(constants.RUNTIME_IMAGE_TAG_ENV_VAR)
+    if env_tag:
+        logging.debug(f"Using runtime image tag from environment: {env_tag}")
+        return env_tag
+
+    # Fall back to default
+    logging.debug(f"Using default runtime image tag: {constants.DEFAULT_IMAGE_TAG}")
+    return constants.DEFAULT_IMAGE_TAG
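Per the docstring above, the runtime image tag can be pinned through the MLRS_CONTAINER_IMAGE_TAG environment variable before a job is submitted; otherwise the hardcoded default is used. For example (the tag value is illustrative):

    import os

    os.environ["MLRS_CONTAINER_IMAGE_TAG"] = "1.2.3"  # illustrative tag; picked up by _get_runtime_image_tag()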
@@ -0,0 +1,119 @@
+import os
+import re
+from os import PathLike
+from pathlib import Path, PurePath
+from typing import Union
+
+from snowflake.ml._internal.utils import identifier
+
+PROTOCOL_NAME = "snow"
+_SNOWURL_PATH_RE = re.compile(
+    rf"^(?:(?:{PROTOCOL_NAME}://)?"
+    r"(?<!@)(?P<domain>\w+)/"
+    rf"(?P<name>(?:{identifier._SF_IDENTIFIER}\.){{,2}}{identifier._SF_IDENTIFIER})/)?"
+    r"(?P<path>versions(?:/(?P<version>[^/]+)(?:/(?P<relpath>.*))?)?)$"
+)
+
+_STAGEF_PATH_RE = re.compile(r"^@(?P<stage>~|%?\w+)(?:/(?P<relpath>[\w\-./]*))?$")
+
+
+class StagePath:
+    def __init__(self, path: str) -> None:
+        stage_match = _SNOWURL_PATH_RE.fullmatch(path) or _STAGEF_PATH_RE.fullmatch(path)
+        if not stage_match:
+            raise ValueError(f"{path} is not a valid stage path")
+        path = path.strip()
+        self._raw_path = path
+        relpath = stage_match.group("relpath")
+        start, _ = stage_match.span("relpath")
+        self._root = self._raw_path[0:start].rstrip("/") if relpath else self._raw_path.rstrip("/")
+        self._path = Path(relpath or "")
+
+    @property
+    def parent(self) -> "StagePath":
+        if self._path.parent == Path(""):
+            return StagePath(self._root)
+        else:
+            return StagePath(f"{self._root}/{self._path.parent}")
+
+    @property
+    def root(self) -> str:
+        return self._root
+
+    @property
+    def suffix(self) -> str:
+        return self._path.suffix
+
+    def _compose_path(self, path: Path) -> str:
+        # in pathlib, Path("") = "."
+        if path == Path(""):
+            return self.root
+        else:
+            return f"{self.root}/{path}"
+
+    def is_relative_to(self, path: Union[str, PathLike[str], "StagePath"]) -> bool:
+        stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
+        if stage_path.root == self.root:
+            return self._path.is_relative_to(stage_path._path)
+        else:
+            return False
+
+    def relative_to(self, path: Union[str, PathLike[str], "StagePath"]) -> PurePath:
+        stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
+        if self.root == stage_path.root:
+            return self._path.relative_to(stage_path._path)
+        raise ValueError(f"{self._raw_path} does not start with {stage_path._raw_path}")
+
+    def absolute(self) -> "StagePath":
+        return self
+
+    def as_posix(self) -> str:
+        return self._compose_path(self._path)
+
+    # TODO Add actual implementation https://snowflakecomputing.atlassian.net/browse/SNOW-2112795
+    def exists(self) -> bool:
+        return True
+
+    # TODO Add actual implementation https://snowflakecomputing.atlassian.net/browse/SNOW-2112795
+    def is_file(self) -> bool:
+        return True
+
+    # TODO Add actual implementation https://snowflakecomputing.atlassian.net/browse/SNOW-2112795
+    def is_dir(self) -> bool:
+        return True
+
+    def is_absolute(self) -> bool:
+        return True
+
+    def __str__(self) -> str:
+        return self.as_posix()
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, StagePath):
+            raise NotImplementedError
+        return bool(self.root == other.root and self._path == other._path)
+
+    def __fspath__(self) -> str:
+        return self._compose_path(self._path)
+
+    def joinpath(self, *args: Union[str, PathLike[str], "StagePath"]) -> "StagePath":
+        path = self
+        for arg in args:
+            path = path._make_child(arg)
+        return path
+
+    def _make_child(self, path: Union[str, PathLike[str], "StagePath"]) -> "StagePath":
+        stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
+        if self.root == stage_path.root:
+            child_path = self._path.joinpath(stage_path._path)
+            return StagePath(self._compose_path(child_path))
+        else:
+            return stage_path
+
+
+def identify_stage_path(path: str) -> Union[StagePath, Path]:
+    try:
+        stage_path = StagePath(path)
+    except ValueError:
+        return Path(path)
+    return stage_path
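Per the file list, this hunk adds snowflake/ml/jobs/_utils/stage_utils.py. identify_stage_path returns a StagePath for @stage or snow:// style paths and falls back to pathlib.Path for anything else, which is what lets JobPayload accept stage-resident sources. A small usage sketch; the stage and file names are illustrative:

    from snowflake.ml.jobs._utils import stage_utils

    p = stage_utils.identify_stage_path("@my_stage/src/train.py")
    print(type(p).__name__)      # StagePath
    print(p.root)                # @my_stage
    print(p.parent.as_posix())   # @my_stage/src
    print(p.suffix)              # .py

    local = stage_utils.identify_stage_path("src/train.py")
    print(type(local).__name__)  # PosixPath (or WindowsPath on Windows)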
@@ -2,18 +2,22 @@ from dataclasses import dataclass
 from pathlib import PurePath
 from typing import Literal, Optional, Union
 
+from snowflake.ml.jobs._utils import stage_utils
+
 JOB_STATUS = Literal[
     "PENDING",
     "RUNNING",
     "FAILED",
     "DONE",
+    "CANCELLING",
+    "CANCELLED",
     "INTERNAL_ERROR",
 ]
 
 
 @dataclass(frozen=True)
 class PayloadEntrypoint:
-    file_path: PurePath
+    file_path: Union[PurePath, stage_utils.StagePath]
     main_func: Optional[str]
 
 
@@ -1,13 +1,13 @@
 import copy
 import functools
-from typing import Callable, Optional, TypeVar
+from typing import Any, Callable, Optional, TypeVar
 
 from typing_extensions import ParamSpec
 
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml.jobs import job as jb, manager as jm
-from snowflake.ml.jobs._utils import constants
+from snowflake.ml.jobs._utils import payload_utils
 
 _PROJECT = "MLJob"
 
@@ -20,16 +20,11 @@ def remote(
     compute_pool: str,
     *,
     stage_name: str,
+    target_instances: int = 1,
     pip_requirements: Optional[list[str]] = None,
     external_access_integrations: Optional[list[str]] = None,
-    query_warehouse: Optional[str] = None,
-    env_vars: Optional[dict[str, str]] = None,
-    target_instances: int = 1,
-    min_instances: int = 1,
-    enable_metrics: bool = False,
-    database: Optional[str] = None,
-    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
+    **kwargs: Any,
 ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob[_ReturnValue]]]:
     """
     Submit a job to the compute pool.
@@ -37,17 +32,20 @@ def remote(
     Args:
         compute_pool: The compute pool to use for the job.
         stage_name: The name of the stage where the job payload will be uploaded.
+        target_instances: The number of nodes in the job. If none specified, create a single node job.
         pip_requirements: A list of pip requirements for the job.
         external_access_integrations: A list of external access integrations.
-        query_warehouse: The query warehouse to use. Defaults to session warehouse.
-        env_vars: Environment variables to set in container
-        target_instances: The number of nodes in the job. If none specified, create a single node job.
-        min_instances: The minimum number of nodes required to start the job. If none specified, defaults to 1.
-            If set, the job will not start until the minimum number of nodes is available.
-        enable_metrics: Whether to enable metrics publishing for the job.
-        database: The database to use for the job.
-        schema: The schema to use for the job.
         session: The Snowpark session to use. If none specified, uses active session.
+        kwargs: Additional keyword arguments. Supported arguments:
+            database (str): The database to use for the job.
+            schema (str): The schema to use for the job.
+            min_instances (int): The minimum number of nodes required to start the job.
+                If none specified, defaults to target_instances. If set, the job
+                will not start until the minimum number of nodes is available.
+            env_vars (dict): Environment variables to set in container.
+            enable_metrics (bool): Whether to enable metrics publishing for the job.
+            query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
+            spec_overrides (dict): A dictionary of overrides for the service spec.
 
     Returns:
         Decorator that dispatches invocations of the decorated function as remote jobs.
@@ -61,23 +59,17 @@ def remote(
         wrapped_func.__code__ = wrapped_func.__code__.replace(co_firstlineno=func.__code__.co_firstlineno + 1)
 
         @functools.wraps(func)
-        def wrapper(*args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:
-            payload = functools.partial(func, *args, **kwargs)
-            setattr(payload, constants.IS_MLJOB_REMOTE_ATTR, True)
+        def wrapper(*_args: _Args.args, **_kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:
+            payload = payload_utils.create_function_payload(func, *_args, **_kwargs)
             job = jm._submit_job(
                 source=payload,
                 stage_name=stage_name,
                 compute_pool=compute_pool,
+                target_instances=target_instances,
                 pip_requirements=pip_requirements,
                 external_access_integrations=external_access_integrations,
-                query_warehouse=query_warehouse,
-                env_vars=env_vars,
-                target_instances=target_instances,
-                min_instances=min_instances,
-                enable_metrics=enable_metrics,
-                database=database,
-                schema=schema,
-                session=session,
+                session=payload.session or session,
+                **kwargs,
             )
             assert isinstance(job, jb.MLJob), f"Unexpected job type: {type(job)}"
            return job
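With the new signature, only compute_pool, stage_name, target_instances, pip_requirements, external_access_integrations, and session remain explicit parameters of remote(); the remaining options (min_instances, env_vars, enable_metrics, query_warehouse, database, schema, spec_overrides) are forwarded through **kwargs to _submit_job. A hedged usage sketch; the compute pool, stage, and table names are illustrative, and it assumes the decorator is exported as snowflake.ml.jobs.remote and that the returned MLJob exposes a blocking result():

    from snowflake.ml import jobs

    @jobs.remote(
        "MY_COMPUTE_POOL",
        stage_name="payload_stage",
        target_instances=2,
        min_instances=2,                 # forwarded via **kwargs
        env_vars={"LOG_LEVEL": "DEBUG"},
    )
    def train(table_name: str) -> str:
        return f"trained on {table_name}"

    job = train("MY_TABLE")              # dispatches a remote job instead of running locally
    print(job.result())                  # assumption: blocks until the job finishes and returns the value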