snowflake-ml-python 1.25.0__py3-none-any.whl → 1.25.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -17,12 +17,20 @@ import cloudpickle as cp
17
17
  from packaging import version
18
18
 
19
19
  from snowflake import snowpark
20
- from snowflake.ml.jobs._utils import constants, query_helper, stage_utils, types
20
+ from snowflake.ml.jobs._utils import (
21
+ constants,
22
+ function_payload_utils,
23
+ query_helper,
24
+ stage_utils,
25
+ types,
26
+ )
21
27
  from snowflake.snowpark import exceptions as sp_exceptions
22
28
  from snowflake.snowpark._internal import code_generation
23
29
  from snowflake.snowpark._internal.utils import zip_file_or_directory_to_stream
24
30
 
25
31
  logger = logging.getLogger(__name__)
32
+
33
+ cp.register_pickle_by_value(function_payload_utils)
26
34
  ImportType = Union[str, Path, ModuleType]
27
35
 
28
36
  _SUPPORTED_ARG_TYPES = {str, int, float}
@@ -553,6 +561,7 @@ class JobPayload:
553
561
  env_vars = {
554
562
  constants.STAGE_MOUNT_PATH_ENV_VAR: constants.STAGE_VOLUME_MOUNT_PATH,
555
563
  constants.PAYLOAD_DIR_ENV_VAR: constants.APP_STAGE_SUBPATH,
564
+ constants.RESULT_PATH_ENV_VAR: constants.RESULT_PATH_DEFAULT_VALUE,
556
565
  }
557
566
 
558
567
  return types.UploadedPayload(
@@ -682,9 +691,14 @@ def _generate_param_handler_code(signature: inspect.Signature, output_name: str
682
691
  return param_code
683
692
 
684
693
 
685
- def generate_python_code(function: Callable[..., Any], source_code_display: bool = False) -> str:
694
+ def generate_python_code(payload: Callable[..., Any], source_code_display: bool = False) -> str:
686
695
  """Generate an entrypoint script from a Python function."""
687
696
 
697
+ if isinstance(payload, function_payload_utils.FunctionPayload):
698
+ function = payload.function
699
+ else:
700
+ function = payload
701
+
688
702
  signature = inspect.signature(function)
689
703
  if any(
690
704
  p.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
@@ -697,7 +711,7 @@ def generate_python_code(function: Callable[..., Any], source_code_display: bool
697
711
  source_code_comment = _generate_source_code_comment(function) if source_code_display else ""
698
712
 
699
713
  arg_dict_name = "kwargs"
700
- if getattr(function, constants.IS_MLJOB_REMOTE_ATTR, None):
714
+ if isinstance(payload, function_payload_utils.FunctionPayload):
701
715
  param_code = f"{arg_dict_name} = {{}}"
702
716
  else:
703
717
  param_code = _generate_param_handler_code(signature, arg_dict_name)
@@ -707,7 +721,7 @@ import pickle
707
721
 
708
722
  try:
709
723
  {textwrap.indent(source_code_comment, ' ')}
710
- {_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(function).hex()}'))
724
+ {_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(payload).hex()}'))
711
725
  except (TypeError, pickle.PickleError):
712
726
  if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor != {sys.version_info.minor}:
713
727
  raise RuntimeError(
@@ -733,6 +747,26 @@ if __name__ == '__main__':
733
747
  """
734
748
 
735
749
 
750
+ def create_function_payload(
751
+ func: Callable[..., Any], *args: Any, **kwargs: Any
752
+ ) -> function_payload_utils.FunctionPayload:
753
+ signature = inspect.signature(func)
754
+ bound = signature.bind(*args, **kwargs)
755
+ bound.apply_defaults()
756
+ session_argument = ""
757
+ session = None
758
+ for name, val in list(bound.arguments.items()):
759
+ if isinstance(val, snowpark.Session):
760
+ if session:
761
+ raise TypeError(f"Expected only one Session-type argument, but got both {session_argument} and {name}.")
762
+ session = val
763
+ session_argument = name
764
+ del bound.arguments[name]
765
+ payload = function_payload_utils.FunctionPayload(func, session, session_argument, *bound.args, **bound.kwargs)
766
+
767
+ return payload
768
+
769
+
736
770
  def get_payload_name(source: Union[str, Callable[..., Any]], entrypoint: Optional[Union[str, list[str]]] = None) -> str:
737
771
 
738
772
  if entrypoint and isinstance(entrypoint, (list, tuple)):
@@ -741,7 +775,7 @@ def get_payload_name(source: Union[str, Callable[..., Any]], entrypoint: Optiona
741
775
  return f"{PurePath(entrypoint).stem}"
742
776
  elif source and not callable(source):
743
777
  return f"{PurePath(source).stem}"
744
- elif callable(source):
745
- return f"{source.__name__}"
778
+ elif isinstance(source, function_payload_utils.FunctionPayload):
779
+ return f"{source.function.__name__}"
746
780
  else:
747
781
  return f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
@@ -1,18 +1,117 @@
1
- from typing import Optional, cast
1
+ import datetime
2
+ import logging
3
+ from typing import Any, Literal, Optional, Union
4
+
5
+ from packaging.version import Version
6
+ from pydantic import BaseModel, Field, RootModel, field_validator
2
7
 
3
8
  from snowflake import snowpark
4
- from snowflake.ml.jobs._utils import query_helper
9
+ from snowflake.ml.jobs._utils import constants, query_helper
10
+
11
+
12
+ class SpcsContainerRuntime(BaseModel):
13
+ python_version: Version = Field(alias="pythonVersion")
14
+ hardware_type: str = Field(alias="hardwareType")
15
+ runtime_container_image: str = Field(alias="runtimeContainerImage")
16
+
17
+ @field_validator("python_version", mode="before")
18
+ @classmethod
19
+ def validate_python_version(cls, v: Union[str, Version]) -> Version:
20
+ if isinstance(v, Version):
21
+ return v
22
+ try:
23
+ return Version(v)
24
+ except Exception:
25
+ raise ValueError(f"Invalid Python version format: {v}")
26
+
27
+ class Config:
28
+ frozen = True
29
+ extra = "allow"
30
+ arbitrary_types_allowed = True
31
+
32
+
33
+ class RuntimeEnvironmentEntry(BaseModel):
34
+ spcs_container_runtime: Optional[SpcsContainerRuntime] = Field(alias="spcsContainerRuntime", default=None)
35
+ created_on: datetime.datetime = Field(alias="createdOn")
36
+ id: Optional[str] = Field(alias="id")
37
+
38
+ class Config:
39
+ extra = "allow"
40
+ frozen = True
41
+
42
+
43
+ class RuntimeEnvironmentsDict(RootModel[dict[str, RuntimeEnvironmentEntry]]):
44
+ @field_validator("root", mode="before")
45
+ @classmethod
46
+ def _filter_to_dict_entries(cls, data: Any) -> dict[str, dict[str, Any]]:
47
+ """
48
+ Pre-validation hook: keep only those items at the root level
49
+ whose values are dicts. Non-dict values will be dropped.
5
50
 
51
+ Args:
52
+ data: The input data to filter, expected to be a dictionary.
6
53
 
7
- def get_runtime_image(
8
- session: snowpark.Session, compute_pool: str, runtime_environment: Optional[str] = None
54
+ Returns:
55
+ A dictionary containing only the key-value pairs where values are dictionaries.
56
+
57
+ Raises:
58
+ ValueError: If input data is not a dictionary.
59
+ """
60
+ # If the entire root is not a dict, raise error immediately
61
+ if not isinstance(data, dict):
62
+ raise ValueError(f"Expected dictionary data, but got {type(data).__name__}: {data}")
63
+
64
+ # Filter out any key whose value is not a dict
65
+ return {key: value for key, value in data.items() if isinstance(value, dict)}
66
+
67
+ def get_spcs_container_runtimes(
68
+ self,
69
+ *,
70
+ hardware_type: Optional[str] = None,
71
+ python_version: Optional[Version] = None,
72
+ ) -> list[SpcsContainerRuntime]:
73
+ # TODO(SNOW-2682000): parse version from NRE in a safer way, like relying on the label,id or image tag.
74
+ entries: list[RuntimeEnvironmentEntry] = [
75
+ entry
76
+ for entry in self.root.values()
77
+ if entry.spcs_container_runtime is not None
78
+ and (hardware_type is None or entry.spcs_container_runtime.hardware_type.lower() == hardware_type.lower())
79
+ and (
80
+ python_version is None
81
+ or (
82
+ entry.spcs_container_runtime.python_version.major == python_version.major
83
+ and entry.spcs_container_runtime.python_version.minor == python_version.minor
84
+ )
85
+ )
86
+ ]
87
+ entries.sort(key=lambda e: e.created_on, reverse=True)
88
+
89
+ return [entry.spcs_container_runtime for entry in entries if entry.spcs_container_runtime is not None]
90
+
91
+
92
+ def _extract_image_tag(image_url: str) -> Optional[str]:
93
+ image_tag = image_url.rsplit(":", 1)[-1]
94
+ return image_tag
95
+
96
+
97
+ def find_runtime_image(
98
+ session: snowpark.Session, target_hardware: Literal["CPU", "GPU"], target_python_version: Optional[str] = None
9
99
  ) -> Optional[str]:
10
- runtime_environment = runtime_environment if runtime_environment else ""
11
- rows = query_helper.run_query(session, f"CALL SYSTEM$GET_ML_JOB_RUNTIME('{compute_pool}', '{runtime_environment}')")
100
+ python_version = (
101
+ Version(target_python_version) if target_python_version else Version(constants.DEFAULT_PYTHON_VERSION)
102
+ )
103
+ rows = query_helper.run_query(session, "CALL SYSTEM$NOTEBOOKS_FIND_LABELED_RUNTIMES()")
12
104
  if not rows:
13
- raise ValueError("Failed to get any available runtime image")
14
- image = rows[0][0]
15
- url, tag = image.rsplit(":", 1)
16
- if url is None or tag is None:
17
- raise ValueError(f"image {image} is not a valid runtime image")
18
- return cast(str, image) if image else None
105
+ return None
106
+ try:
107
+ runtime_envs = RuntimeEnvironmentsDict.model_validate_json(rows[0][0])
108
+ spcs_container_runtimes = runtime_envs.get_spcs_container_runtimes(
109
+ hardware_type=target_hardware,
110
+ python_version=python_version,
111
+ )
112
+ except Exception as e:
113
+ logging.warning(f"Failed to parse runtime image name from {rows[0][0]}, error: {e}")
114
+ return None
115
+
116
+ selected_runtime = spcs_container_runtimes[0] if spcs_container_runtimes else None
117
+ return selected_runtime.runtime_container_image if selected_runtime else None
@@ -1,7 +1,6 @@
1
1
  import argparse
2
2
  import copy
3
3
  import importlib.util
4
- import io
5
4
  import json
6
5
  import logging
7
6
  import math
@@ -13,22 +12,15 @@ import sys
13
12
  import time
14
13
  import traceback
15
14
  import zipfile
16
- from pathlib import Path, PurePosixPath
17
- from typing import Any, Callable, Optional
15
+ from pathlib import Path
16
+ from typing import Any, Optional
18
17
 
19
18
  # Ensure payload directory is in sys.path for module imports before importing other modules
20
19
  # This is needed to support relative imports in user scripts and to allow overriding
21
20
  # modules using modules in the payload directory
22
21
  # TODO: Inject the environment variable names at job submission time
23
22
  STAGE_MOUNT_PATH = os.environ.get("MLRS_STAGE_MOUNT_PATH", "/mnt/job_stage")
24
- STAGE_RESULT_PATH = os.environ.get("MLRS_STAGE_RESULT_PATH")
25
- # Updated MLRS_RESULT_PATH to use unique stage mounts for each ML Job.
26
- # To prevent output collisions between jobs sharing the same definition,
27
- # the server-side mount now dynamically includes the job_name.
28
- # Format: @payload_stage/{job_definition_name}/{job_name}/mljob_result
29
- JOB_RESULT_PATH = os.environ.get("MLRS_RESULT_PATH", "mljob_result")
30
- if STAGE_RESULT_PATH:
31
- JOB_RESULT_PATH = os.path.join(STAGE_RESULT_PATH, JOB_RESULT_PATH)
23
+ JOB_RESULT_PATH = os.environ.get("MLRS_RESULT_PATH", "output/mljob_result.pkl")
32
24
  PAYLOAD_PATH = os.environ.get("MLRS_PAYLOAD_DIR")
33
25
 
34
26
  if PAYLOAD_PATH and not os.path.isabs(PAYLOAD_PATH):
@@ -355,156 +347,24 @@ def wait_for_instances(
355
347
  )
356
348
 
357
349
 
358
- def _load_dto_fallback(function_args: str, path_transform: Callable[[str], str]) -> Any:
359
- from snowflake.ml.jobs._interop import data_utils
360
- from snowflake.ml.jobs._interop.utils import DEFAULT_CODEC, DEFAULT_PROTOCOL
361
- from snowflake.snowpark import exceptions as sp_exceptions
362
-
363
- try:
364
- with data_utils.open_stream(function_args, "r") as stream:
365
- # Load the DTO as a dict for easy fallback to legacy loading if necessary
366
- data = DEFAULT_CODEC.decode(stream, as_dict=True)
367
- # the exception could be OSError or BlockingIOError(the file name is too long)
368
- except OSError as e:
369
- # path_or_data might be inline data
370
- try:
371
- data = DEFAULT_CODEC.decode(io.StringIO(function_args), as_dict=True)
372
- except Exception:
373
- raise e
374
-
375
- if data["protocol"] is not None:
376
- try:
377
- from snowflake.ml.jobs._interop.dto_schema import ProtocolInfo
378
-
379
- protocol_info = ProtocolInfo.model_validate(data["protocol"])
380
- logger.debug(f"Loading result value with protocol {protocol_info}")
381
- result_value = DEFAULT_PROTOCOL.load(protocol_info, session=None, path_transform=path_transform)
382
- except sp_exceptions.SnowparkSQLException:
383
- raise
384
- else:
385
- result_value = None
386
-
387
- return data["value"] or result_value
388
-
389
-
390
- def _unpack_obj_fallback(obj: Any, session: Optional[snowflake.snowpark.Session]) -> Any:
391
- SESSION_KEY_PREFIX = "session@"
392
-
393
- if not isinstance(obj, dict):
394
- return obj
395
- elif len(obj) == 1 and SESSION_KEY_PREFIX in obj:
396
- return session
397
- else:
398
- type = obj.get("type@", None)
399
- # If type is None, we are unpacking a dict
400
- if type is None:
401
- result_dict = {}
402
- for k, v in obj.items():
403
- if k.startswith(SESSION_KEY_PREFIX):
404
- result_key = k[len(SESSION_KEY_PREFIX) :]
405
- result_dict[result_key] = session
406
- else:
407
- result_dict[k] = _unpack_obj_fallback(v, session)
408
- return result_dict
409
- # If type is not None, we are unpacking a tuple or list
410
- else:
411
- indexes = []
412
- for k, _ in obj.items():
413
- if "#" in k:
414
- indexes.append(int(k.split("#")[-1]))
415
-
416
- if not indexes:
417
- return tuple() if type is tuple else []
418
- result_list: list[Any] = [None] * (max(indexes) + 1)
419
-
420
- for k, v in obj.items():
421
- if k == "type@":
422
- continue
423
- idx = int(k.split("#")[-1])
424
- if k.startswith(SESSION_KEY_PREFIX):
425
- result_list[idx] = session
426
- else:
427
- result_list[idx] = _unpack_obj_fallback(v, session)
428
- return tuple(result_list) if type is tuple else result_list
429
-
430
-
431
- def _load_function_args(
432
- session: snowflake.snowpark.Session,
433
- function_args: Optional[str] = None,
434
- ) -> tuple[tuple[Any, ...], dict[str, Any]]:
435
- """Load and deserialize function arguments.
436
-
437
- Args:
438
- function_args: Inline serialized function arguments or path to serialized file.
439
- session: Optional Snowpark session for stage access if needed.
440
-
441
- Returns:
442
- A tuple of (positional_args, keyword_args)
443
-
444
- """
445
- if not function_args:
446
- return (), {}
447
-
448
- def path_transform(stage_path: str) -> str:
449
- if not PAYLOAD_PATH:
450
- return stage_path
451
-
452
- payload_path = PurePosixPath(PAYLOAD_PATH)
453
- payload_dir_name = payload_path.name # e.g., "app"
454
-
455
- # Parse stage path and find the payload directory
456
- stage_parts = PurePosixPath(stage_path.lstrip("@")).parts
457
-
458
- try:
459
- # Find index of payload directory (e.g., "app") in stage path
460
- idx = stage_parts.index(payload_dir_name)
461
- # Get relative path after the payload directory
462
- relative_parts = stage_parts[idx + 1 :]
463
- return str(payload_path.joinpath(*relative_parts))
464
- except (ValueError, IndexError):
465
- # Fallback to just the filename
466
- return str(payload_path / PurePosixPath(stage_path).name)
467
-
468
- try:
469
- from snowflake.ml.jobs._interop import utils as interop_utils
470
-
471
- args, kwargs = interop_utils.load(
472
- function_args,
473
- session=session,
474
- path_transform=path_transform,
475
- )
476
- return args, kwargs
477
- except (AttributeError, ImportError):
478
- # Backwards compatibility: load may not exist in older SnowML versions
479
- packed = _load_dto_fallback(function_args, path_transform)
480
- args, kwargs = _unpack_obj_fallback(packed, session)
481
- return args, kwargs
482
-
483
-
484
- def run_script(
485
- script_path: str,
486
- payload_args: Optional[tuple[Any, ...]] = None,
487
- payload_kwargs: Optional[dict[str, Any]] = None,
488
- main_func: Optional[str] = None,
489
- ) -> Any:
350
+ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = None) -> Any:
490
351
  """
491
352
  Execute a Python script and return its result.
492
353
 
493
354
  Args:
494
- script_path: Path to the Python script.
495
- payload_args: Positional arguments to pass to the script or entrypoint.
496
- payload_kwargs: Keyword arguments to pass to the script or entrypoint.
497
- main_func: The name of the function to call in the script (if any).
355
+ script_path: Path to the Python script
356
+ script_args: Arguments to pass to the script
357
+ main_func: The name of the function to call in the script (if any)
498
358
 
499
359
  Returns:
500
360
  Result from script execution, either from the main function or the script's __return__ value
501
361
 
502
362
  Raises:
503
363
  RuntimeError: If the specified main_func is not found or not callable
504
- ValueError: If payload_kwargs is provided for runpy execution.
505
364
  """
506
365
  # Save original sys.argv and modify it for the script (applies to runpy execution only)
507
366
  original_argv = sys.argv
367
+ sys.argv = [script_path, *script_args]
508
368
 
509
369
  try:
510
370
  if main_func:
@@ -521,13 +381,10 @@ def run_script(
521
381
  raise RuntimeError(f"Function '{main_func}' not a valid entrypoint for {script_path}")
522
382
 
523
383
  # Call main function
524
- result = func(*(payload_args or ()), **(payload_kwargs or {}))
384
+ result = func(*script_args)
525
385
  return result
526
386
  else:
527
- if payload_kwargs:
528
- raise ValueError("payload_kwargs is not supported for runpy execution; use payload_args instead")
529
- # Save original sys.argv and modify it for the script.
530
- sys.argv = [script_path, *(payload_args or ())]
387
+ # Use runpy for other scripts
531
388
  globals_dict = runpy.run_path(script_path, run_name="__main__")
532
389
  result = globals_dict.get("__return__", None)
533
390
  return result
@@ -536,28 +393,24 @@ def run_script(
536
393
  sys.argv = original_argv
537
394
 
538
395
 
539
- def main(
540
- entrypoint: str,
541
- session: snowflake.snowpark.Session,
542
- payload_args: Optional[tuple[Any, ...]] = None,
543
- payload_kwargs: Optional[dict[str, Any]] = None,
544
- script_main_func: Optional[str] = None,
545
- ) -> Any:
396
+ def main(entrypoint: str, *script_args: Any, script_main_func: Optional[str] = None) -> Any:
546
397
  """Executes a Python script and serializes the result to JOB_RESULT_PATH.
547
398
 
548
399
  Args:
549
400
  entrypoint (str): The job payload entrypoint to execute.
550
- payload_args (tuple[Any, ...], optional): Positional args to pass to the script or entrypoint.
551
- payload_kwargs (dict[str, Any], optional): Keyword args to pass to the script or entrypoint.
401
+ script_args (Any): Arguments to pass to the script.
552
402
  script_main_func (str, optional): The name of the function to call in the script (if any).
553
- session (snowflake.snowpark.Session, optional): Snowpark session for stage access if needed.
554
403
 
555
404
  Returns:
556
405
  Any: The result of the script execution.
557
406
 
558
407
  Raises:
559
- ValueError: If payload_kwargs is provided for runpy execution.
408
+ Exception: Re-raises any exception caught during script execution.
560
409
  """
410
+ try:
411
+ from snowflake.ml._internal.utils.connection_params import SnowflakeLoginOptions
412
+ except ImportError:
413
+ from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
561
414
 
562
415
  # Initialize Ray if available
563
416
  try:
@@ -567,6 +420,12 @@ def main(
567
420
  except ModuleNotFoundError:
568
421
  logger.debug("Ray is not installed, skipping Ray initialization")
569
422
 
423
+ # Create a Snowpark session before starting
424
+ # Session can be retrieved from using snowflake.snowpark.context.get_active_session()
425
+ config = SnowflakeLoginOptions()
426
+ config["client_session_keep_alive"] = "True"
427
+ session = snowflake.snowpark.Session.builder.configs(config).create() # noqa: F841
428
+
570
429
  execution_result_is_error = False
571
430
  execution_result_value = None
572
431
  try:
@@ -587,21 +446,10 @@ def main(
587
446
 
588
447
  if is_python:
589
448
  # Run as Python script
590
- execution_result_value = run_script(
591
- resolved_entrypoint,
592
- payload_args=payload_args,
593
- payload_kwargs=payload_kwargs,
594
- main_func=script_main_func,
595
- )
449
+ execution_result_value = run_script(resolved_entrypoint, *script_args, main_func=script_main_func)
596
450
  else:
597
451
  # Run as subprocess
598
- if payload_kwargs:
599
- raise ValueError("payload_kwargs is not supported for subprocesses")
600
-
601
- run_command(
602
- resolved_entrypoint,
603
- *(payload_args or ()),
604
- )
452
+ run_command(resolved_entrypoint, *script_args)
605
453
 
606
454
  # Log end marker for user script execution
607
455
  print(LOG_END_MSG) # noqa: T201
@@ -639,36 +487,11 @@ if __name__ == "__main__":
639
487
  parser.add_argument(
640
488
  "--script_main_func", required=False, help="The name of the main function to call in the script"
641
489
  )
642
- parser.add_argument(
643
- "--function_args",
644
- required=False,
645
- help="Serialized function arguments or path to serialized function arguments file",
646
- )
647
490
  args, unknown_args = parser.parse_known_args()
648
491
 
649
- try:
650
- from snowflake.ml._internal.utils.connection_params import SnowflakeLoginOptions
651
- except ImportError:
652
- from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
653
-
654
- # Create a Snowpark session before starting
655
- # Session can be retrieved from using snowflake.snowpark.context.get_active_session()
656
- # _load_function_args will use the session to load the function arguments
657
- config = SnowflakeLoginOptions()
658
- config["client_session_keep_alive"] = "True"
659
- session = snowflake.snowpark.Session.builder.configs(config).create() # noqa: F841
660
-
661
- if args.function_args:
662
- if args.script_args or unknown_args:
663
- raise ValueError("Only one of function_args and script_args can be provided")
664
- payload_args, payload_kwargs = _load_function_args(session, args.function_args)
665
- else:
666
- payload_args, payload_kwargs = (args.script_args + unknown_args), {}
667
-
668
492
  main(
669
493
  args.entrypoint,
670
- session=session,
671
- payload_args=payload_args,
672
- payload_kwargs=payload_kwargs,
494
+ *args.script_args,
495
+ *unknown_args,
673
496
  script_main_func=args.script_main_func,
674
497
  )
@@ -0,0 +1,22 @@
1
+ from snowflake import snowpark
2
+ from snowflake.ml._internal.utils import snowflake_env
3
+ from snowflake.ml.jobs._utils import constants, query_helper, types
4
+
5
+
6
+ def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
7
+ """Extract resource information for the specified compute pool"""
8
+ # Get the instance family
9
+ rows = query_helper.run_query(
10
+ session,
11
+ "show compute pools like ?",
12
+ params=[compute_pool],
13
+ )
14
+ if not rows:
15
+ raise ValueError(f"Compute pool '{compute_pool}' not found")
16
+ instance_family: str = rows[0]["instance_family"]
17
+ cloud = snowflake_env.get_current_cloud(session, default=snowflake_env.SnowflakeCloudType.AWS)
18
+
19
+ return (
20
+ constants.COMMON_INSTANCE_FAMILIES.get(instance_family)
21
+ or constants.CLOUD_INSTANCE_FAMILIES[cloud][instance_family]
22
+ )
@@ -1,12 +1,13 @@
1
1
  import copy
2
+ import functools
2
3
  from typing import Any, Callable, Optional, TypeVar
3
4
 
4
5
  from typing_extensions import ParamSpec
5
6
 
6
7
  from snowflake import snowpark
7
8
  from snowflake.ml._internal import telemetry
8
- from snowflake.ml.jobs import job_definition as jd
9
- from snowflake.ml.jobs._utils import arg_protocol, constants
9
+ from snowflake.ml.jobs import job as jb, manager as jm
10
+ from snowflake.ml.jobs._utils import payload_utils
10
11
 
11
12
  _PROJECT = "MLJob"
12
13
 
@@ -24,7 +25,7 @@ def remote(
24
25
  external_access_integrations: Optional[list[str]] = None,
25
26
  session: Optional[snowpark.Session] = None,
26
27
  **kwargs: Any,
27
- ) -> Callable[[Callable[_Args, _ReturnValue]], jd.MLJobDefinition[_Args, _ReturnValue]]:
28
+ ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob[_ReturnValue]]]:
28
29
  """
29
30
  Submit a job to the compute pool.
30
31
 
@@ -50,25 +51,29 @@ def remote(
50
51
  Decorator that dispatches invocations of the decorated function as remote jobs.
51
52
  """
52
53
 
53
- def decorator(func: Callable[_Args, _ReturnValue]) -> jd.MLJobDefinition[_Args, _ReturnValue]:
54
+ def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, jb.MLJob[_ReturnValue]]:
54
55
  # Copy the function to avoid modifying the original
55
56
  # We need to modify the line number of the function to exclude the
56
57
  # decorator from the copied source code
57
58
  wrapped_func = copy.copy(func)
58
59
  wrapped_func.__code__ = wrapped_func.__code__.replace(co_firstlineno=func.__code__.co_firstlineno + 1)
59
60
 
60
- setattr(wrapped_func, constants.IS_MLJOB_REMOTE_ATTR, True)
61
- return jd.MLJobDefinition.register(
62
- source=wrapped_func,
63
- compute_pool=compute_pool,
64
- stage_name=stage_name,
65
- target_instances=target_instances,
66
- pip_requirements=pip_requirements,
67
- external_access_integrations=external_access_integrations,
68
- session=session or snowpark.context.get_active_session(),
69
- arg_protocol=arg_protocol.ArgProtocol.PICKLE,
70
- generate_suffix=True,
71
- **kwargs,
72
- )
61
+ @functools.wraps(func)
62
+ def wrapper(*_args: _Args.args, **_kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:
63
+ payload = payload_utils.create_function_payload(func, *_args, **_kwargs)
64
+ job = jm._submit_job(
65
+ source=payload,
66
+ stage_name=stage_name,
67
+ compute_pool=compute_pool,
68
+ target_instances=target_instances,
69
+ pip_requirements=pip_requirements,
70
+ external_access_integrations=external_access_integrations,
71
+ session=payload.session or session,
72
+ **kwargs,
73
+ )
74
+ assert isinstance(job, jb.MLJob), f"Unexpected job type: {type(job)}"
75
+ return job
76
+
77
+ return wrapper
73
78
 
74
79
  return decorator