snowflake-ml-python 1.24.0__py3-none-any.whl → 1.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. snowflake/ml/_internal/utils/mixins.py +26 -1
  2. snowflake/ml/data/_internal/arrow_ingestor.py +5 -1
  3. snowflake/ml/data/data_connector.py +2 -2
  4. snowflake/ml/data/data_ingestor.py +2 -1
  5. snowflake/ml/experiment/_experiment_info.py +3 -3
  6. snowflake/ml/jobs/_interop/data_utils.py +8 -8
  7. snowflake/ml/jobs/_interop/dto_schema.py +52 -7
  8. snowflake/ml/jobs/_interop/protocols.py +124 -7
  9. snowflake/ml/jobs/_interop/utils.py +92 -33
  10. snowflake/ml/jobs/_utils/arg_protocol.py +7 -0
  11. snowflake/ml/jobs/_utils/constants.py +4 -0
  12. snowflake/ml/jobs/_utils/feature_flags.py +97 -13
  13. snowflake/ml/jobs/_utils/payload_utils.py +6 -40
  14. snowflake/ml/jobs/_utils/runtime_env_utils.py +12 -111
  15. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +204 -27
  16. snowflake/ml/jobs/decorators.py +17 -22
  17. snowflake/ml/jobs/job.py +25 -10
  18. snowflake/ml/jobs/job_definition.py +100 -8
  19. snowflake/ml/model/_client/model/model_version_impl.py +25 -14
  20. snowflake/ml/model/_client/ops/service_ops.py +6 -6
  21. snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
  22. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
  23. snowflake/ml/model/models/huggingface_pipeline.py +3 -0
  24. snowflake/ml/model/openai_signatures.py +154 -0
  25. snowflake/ml/registry/_manager/model_parameter_reconciler.py +2 -3
  26. snowflake/ml/version.py +1 -1
  27. {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/METADATA +41 -2
  28. {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/RECORD +31 -32
  29. {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/WHEEL +1 -1
  30. snowflake/ml/jobs/_utils/function_payload_utils.py +0 -43
  31. snowflake/ml/jobs/_utils/spec_utils.py +0 -22
  32. {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/licenses/LICENSE.txt +0 -0
  33. {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,11 @@
1
1
  import os
2
- from enum import Enum
3
- from typing import Optional
2
+ from typing import Callable, Optional, Union
3
+
4
+ from snowflake.ml._internal.utils.snowflake_env import SnowflakeCloudType
5
+ from snowflake.snowpark import context as sp_context
6
+
7
+ # Default value type: can be a bool or a callable that returns a bool
8
+ DefaultValue = Union[bool, Callable[[], bool]]
4
9
 
5
10
 
6
11
  def parse_bool_env_value(value: Optional[str], default: bool = False) -> bool:
@@ -28,22 +33,101 @@ def parse_bool_env_value(value: Optional[str], default: bool = False) -> bool:
28
33
  return default
29
34
 
30
35
 
31
- class FeatureFlags(Enum):
32
- USE_SUBMIT_JOB_V2 = "MLRS_USE_SUBMIT_JOB_V2"
33
- ENABLE_RUNTIME_VERSIONS = "MLRS_ENABLE_RUNTIME_VERSIONS"
34
- ENABLE_STAGE_MOUNT_V2 = "MLRS_ENABLE_STAGE_MOUNT_V2"
36
+ def _enabled_in_clouds(*clouds: SnowflakeCloudType) -> Callable[[], bool]:
37
+ """Create a callable that checks if the current environment is in any of the specified clouds.
38
+
39
+ This factory function returns a callable that can be used as a dynamic default
40
+ for feature flags. The returned callable will check if the current Snowflake
41
+ session is connected to a region in any of the specified cloud providers.
42
+
43
+ Args:
44
+ *clouds: One or more SnowflakeCloudType values to check against.
45
+
46
+ Returns:
47
+ A callable that returns True if running in any of the specified clouds,
48
+ False otherwise (including when no session is available).
49
+
50
+ Example:
51
+ >>> # Enable feature only in GCP
52
+ >>> default=_enabled_in_clouds(SnowflakeCloudType.GCP)
53
+ >>>
54
+ >>> # Enable feature in both GCP and Azure
55
+ >>> default=_enabled_in_clouds(SnowflakeCloudType.GCP, SnowflakeCloudType.AZURE)
56
+ """
57
+ cloud_set = frozenset(clouds)
58
+
59
+ def check() -> bool:
60
+ try:
61
+ from snowflake.ml._internal.utils.snowflake_env import get_current_cloud
62
+
63
+ session = sp_context.get_active_session()
64
+ current_cloud = get_current_cloud(session, default=SnowflakeCloudType.AWS)
65
+ return current_cloud in cloud_set
66
+ except Exception:
67
+ # If we can't determine the cloud (no session, SQL error, etc.),
68
+ # default to False for safety
69
+ return False
70
+
71
+ return check
35
72
 
36
- def is_enabled(self, default: bool = False) -> bool:
37
- """Check if the feature flag is enabled.
73
+
74
+ class _FeatureFlag:
75
+ """A feature flag backed by an environment variable with a configurable default.
76
+
77
+ The default value can be a constant boolean or a callable that dynamically
78
+ determines the default based on runtime context (e.g., cloud provider).
79
+ """
80
+
81
+ def __init__(self, env_var: str, default: DefaultValue = False) -> None:
82
+ """Initialize a feature flag.
38
83
 
39
84
  Args:
40
- default: The default value to return if the environment variable is not set.
85
+ env_var: The environment variable name that controls this flag.
86
+ default: The default value when the env var is not set. Can be:
87
+ - A boolean constant (True/False)
88
+ - A callable that returns a boolean (evaluated at check time)
89
+ """
90
+ self._env_var = env_var
91
+ self._default = default
92
+
93
+ @property
94
+ def value(self) -> str:
95
+ """Return the environment variable name (for compatibility with Enum-style access)."""
96
+ return self._env_var
97
+
98
+ def _get_default(self) -> bool:
99
+ """Get the default value, calling it if it's a callable."""
100
+ if callable(self._default):
101
+ return self._default()
102
+ return self._default
103
+
104
+ def is_enabled(self) -> bool:
105
+ """Check if the feature flag is enabled.
106
+
107
+ First checks the environment variable. If not set or unrecognized,
108
+ falls back to the configured default value.
41
109
 
42
110
  Returns:
43
- True if the environment variable is set to a truthy value,
44
- False if set to a falsy value, or the default value if not set.
111
+ True if the feature is enabled, False otherwise.
45
112
  """
46
- return parse_bool_env_value(os.getenv(self.value), default)
113
+ env_value = os.getenv(self._env_var)
114
+ if env_value is not None:
115
+ # Environment variable is set, parse it
116
+ result = parse_bool_env_value(env_value, default=self._get_default())
117
+ return result
118
+ else:
119
+ # Environment variable not set, use the default
120
+ return self._get_default()
47
121
 
48
122
  def __str__(self) -> str:
49
- return self.value
123
+ return self._env_var
124
+
125
+
126
+ class FeatureFlags:
127
+ """Collection of feature flags for ML Jobs."""
128
+
129
+ ENABLE_RUNTIME_VERSIONS = _FeatureFlag("MLRS_ENABLE_RUNTIME_VERSIONS", default=True)
130
+ ENABLE_STAGE_MOUNT_V2 = _FeatureFlag(
131
+ "MLRS_ENABLE_STAGE_MOUNT_V2",
132
+ default=_enabled_in_clouds(SnowflakeCloudType.GCP),
133
+ )
@@ -17,20 +17,12 @@ import cloudpickle as cp
17
17
  from packaging import version
18
18
 
19
19
  from snowflake import snowpark
20
- from snowflake.ml.jobs._utils import (
21
- constants,
22
- function_payload_utils,
23
- query_helper,
24
- stage_utils,
25
- types,
26
- )
20
+ from snowflake.ml.jobs._utils import constants, query_helper, stage_utils, types
27
21
  from snowflake.snowpark import exceptions as sp_exceptions
28
22
  from snowflake.snowpark._internal import code_generation
29
23
  from snowflake.snowpark._internal.utils import zip_file_or_directory_to_stream
30
24
 
31
25
  logger = logging.getLogger(__name__)
32
-
33
- cp.register_pickle_by_value(function_payload_utils)
34
26
  ImportType = Union[str, Path, ModuleType]
35
27
 
36
28
  _SUPPORTED_ARG_TYPES = {str, int, float}
@@ -561,7 +553,6 @@ class JobPayload:
561
553
  env_vars = {
562
554
  constants.STAGE_MOUNT_PATH_ENV_VAR: constants.STAGE_VOLUME_MOUNT_PATH,
563
555
  constants.PAYLOAD_DIR_ENV_VAR: constants.APP_STAGE_SUBPATH,
564
- constants.RESULT_PATH_ENV_VAR: constants.RESULT_PATH_DEFAULT_VALUE,
565
556
  }
566
557
 
567
558
  return types.UploadedPayload(
@@ -691,14 +682,9 @@ def _generate_param_handler_code(signature: inspect.Signature, output_name: str
691
682
  return param_code
692
683
 
693
684
 
694
- def generate_python_code(payload: Callable[..., Any], source_code_display: bool = False) -> str:
685
+ def generate_python_code(function: Callable[..., Any], source_code_display: bool = False) -> str:
695
686
  """Generate an entrypoint script from a Python function."""
696
687
 
697
- if isinstance(payload, function_payload_utils.FunctionPayload):
698
- function = payload.function
699
- else:
700
- function = payload
701
-
702
688
  signature = inspect.signature(function)
703
689
  if any(
704
690
  p.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
@@ -711,7 +697,7 @@ def generate_python_code(payload: Callable[..., Any], source_code_display: bool
711
697
  source_code_comment = _generate_source_code_comment(function) if source_code_display else ""
712
698
 
713
699
  arg_dict_name = "kwargs"
714
- if isinstance(payload, function_payload_utils.FunctionPayload):
700
+ if getattr(function, constants.IS_MLJOB_REMOTE_ATTR, None):
715
701
  param_code = f"{arg_dict_name} = {{}}"
716
702
  else:
717
703
  param_code = _generate_param_handler_code(signature, arg_dict_name)
@@ -721,7 +707,7 @@ import pickle
721
707
 
722
708
  try:
723
709
  {textwrap.indent(source_code_comment, ' ')}
724
- {_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(payload).hex()}'))
710
+ {_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(function).hex()}'))
725
711
  except (TypeError, pickle.PickleError):
726
712
  if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor != {sys.version_info.minor}:
727
713
  raise RuntimeError(
@@ -747,26 +733,6 @@ if __name__ == '__main__':
747
733
  """
748
734
 
749
735
 
750
- def create_function_payload(
751
- func: Callable[..., Any], *args: Any, **kwargs: Any
752
- ) -> function_payload_utils.FunctionPayload:
753
- signature = inspect.signature(func)
754
- bound = signature.bind(*args, **kwargs)
755
- bound.apply_defaults()
756
- session_argument = ""
757
- session = None
758
- for name, val in list(bound.arguments.items()):
759
- if isinstance(val, snowpark.Session):
760
- if session:
761
- raise TypeError(f"Expected only one Session-type argument, but got both {session_argument} and {name}.")
762
- session = val
763
- session_argument = name
764
- del bound.arguments[name]
765
- payload = function_payload_utils.FunctionPayload(func, session, session_argument, *bound.args, **bound.kwargs)
766
-
767
- return payload
768
-
769
-
770
736
  def get_payload_name(source: Union[str, Callable[..., Any]], entrypoint: Optional[Union[str, list[str]]] = None) -> str:
771
737
 
772
738
  if entrypoint and isinstance(entrypoint, (list, tuple)):
@@ -775,7 +741,7 @@ def get_payload_name(source: Union[str, Callable[..., Any]], entrypoint: Optiona
775
741
  return f"{PurePath(entrypoint).stem}"
776
742
  elif source and not callable(source):
777
743
  return f"{PurePath(source).stem}"
778
- elif isinstance(source, function_payload_utils.FunctionPayload):
779
- return f"{source.function.__name__}"
744
+ elif callable(source):
745
+ return f"{source.__name__}"
780
746
  else:
781
747
  return f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
@@ -1,117 +1,18 @@
1
- import datetime
2
- import logging
3
- from typing import Any, Literal, Optional, Union
4
-
5
- from packaging.version import Version
6
- from pydantic import BaseModel, Field, RootModel, field_validator
1
+ from typing import Optional, cast
7
2
 
8
3
  from snowflake import snowpark
9
- from snowflake.ml.jobs._utils import constants, query_helper
10
-
11
-
12
- class SpcsContainerRuntime(BaseModel):
13
- python_version: Version = Field(alias="pythonVersion")
14
- hardware_type: str = Field(alias="hardwareType")
15
- runtime_container_image: str = Field(alias="runtimeContainerImage")
16
-
17
- @field_validator("python_version", mode="before")
18
- @classmethod
19
- def validate_python_version(cls, v: Union[str, Version]) -> Version:
20
- if isinstance(v, Version):
21
- return v
22
- try:
23
- return Version(v)
24
- except Exception:
25
- raise ValueError(f"Invalid Python version format: {v}")
26
-
27
- class Config:
28
- frozen = True
29
- extra = "allow"
30
- arbitrary_types_allowed = True
31
-
32
-
33
- class RuntimeEnvironmentEntry(BaseModel):
34
- spcs_container_runtime: Optional[SpcsContainerRuntime] = Field(alias="spcsContainerRuntime", default=None)
35
- created_on: datetime.datetime = Field(alias="createdOn")
36
- id: Optional[str] = Field(alias="id")
37
-
38
- class Config:
39
- extra = "allow"
40
- frozen = True
41
-
42
-
43
- class RuntimeEnvironmentsDict(RootModel[dict[str, RuntimeEnvironmentEntry]]):
44
- @field_validator("root", mode="before")
45
- @classmethod
46
- def _filter_to_dict_entries(cls, data: Any) -> dict[str, dict[str, Any]]:
47
- """
48
- Pre-validation hook: keep only those items at the root level
49
- whose values are dicts. Non-dict values will be dropped.
4
+ from snowflake.ml.jobs._utils import query_helper
50
5
 
51
- Args:
52
- data: The input data to filter, expected to be a dictionary.
53
6
 
54
- Returns:
55
- A dictionary containing only the key-value pairs where values are dictionaries.
56
-
57
- Raises:
58
- ValueError: If input data is not a dictionary.
59
- """
60
- # If the entire root is not a dict, raise error immediately
61
- if not isinstance(data, dict):
62
- raise ValueError(f"Expected dictionary data, but got {type(data).__name__}: {data}")
63
-
64
- # Filter out any key whose value is not a dict
65
- return {key: value for key, value in data.items() if isinstance(value, dict)}
66
-
67
- def get_spcs_container_runtimes(
68
- self,
69
- *,
70
- hardware_type: Optional[str] = None,
71
- python_version: Optional[Version] = None,
72
- ) -> list[SpcsContainerRuntime]:
73
- # TODO(SNOW-2682000): parse version from NRE in a safer way, like relying on the label,id or image tag.
74
- entries: list[RuntimeEnvironmentEntry] = [
75
- entry
76
- for entry in self.root.values()
77
- if entry.spcs_container_runtime is not None
78
- and (hardware_type is None or entry.spcs_container_runtime.hardware_type.lower() == hardware_type.lower())
79
- and (
80
- python_version is None
81
- or (
82
- entry.spcs_container_runtime.python_version.major == python_version.major
83
- and entry.spcs_container_runtime.python_version.minor == python_version.minor
84
- )
85
- )
86
- ]
87
- entries.sort(key=lambda e: e.created_on, reverse=True)
88
-
89
- return [entry.spcs_container_runtime for entry in entries if entry.spcs_container_runtime is not None]
90
-
91
-
92
- def _extract_image_tag(image_url: str) -> Optional[str]:
93
- image_tag = image_url.rsplit(":", 1)[-1]
94
- return image_tag
95
-
96
-
97
- def find_runtime_image(
98
- session: snowpark.Session, target_hardware: Literal["CPU", "GPU"], target_python_version: Optional[str] = None
7
+ def get_runtime_image(
8
+ session: snowpark.Session, compute_pool: str, runtime_environment: Optional[str] = None
99
9
  ) -> Optional[str]:
100
- python_version = (
101
- Version(target_python_version) if target_python_version else Version(constants.DEFAULT_PYTHON_VERSION)
102
- )
103
- rows = query_helper.run_query(session, "CALL SYSTEM$NOTEBOOKS_FIND_LABELED_RUNTIMES()")
10
+ runtime_environment = runtime_environment if runtime_environment else ""
11
+ rows = query_helper.run_query(session, f"CALL SYSTEM$GET_ML_JOB_RUNTIME('{compute_pool}', '{runtime_environment}')")
104
12
  if not rows:
105
- return None
106
- try:
107
- runtime_envs = RuntimeEnvironmentsDict.model_validate_json(rows[0][0])
108
- spcs_container_runtimes = runtime_envs.get_spcs_container_runtimes(
109
- hardware_type=target_hardware,
110
- python_version=python_version,
111
- )
112
- except Exception as e:
113
- logging.warning(f"Failed to parse runtime image name from {rows[0][0]}, error: {e}")
114
- return None
115
-
116
- selected_runtime = spcs_container_runtimes[0] if spcs_container_runtimes else None
117
- return selected_runtime.runtime_container_image if selected_runtime else None
13
+ raise ValueError("Failed to get any available runtime image")
14
+ image = rows[0][0]
15
+ url, tag = image.rsplit(":", 1)
16
+ if url is None or tag is None:
17
+ raise ValueError(f"image {image} is not a valid runtime image")
18
+ return cast(str, image) if image else None
@@ -1,6 +1,7 @@
1
1
  import argparse
2
2
  import copy
3
3
  import importlib.util
4
+ import io
4
5
  import json
5
6
  import logging
6
7
  import math
@@ -12,15 +13,22 @@ import sys
12
13
  import time
13
14
  import traceback
14
15
  import zipfile
15
- from pathlib import Path
16
- from typing import Any, Optional
16
+ from pathlib import Path, PurePosixPath
17
+ from typing import Any, Callable, Optional
17
18
 
18
19
  # Ensure payload directory is in sys.path for module imports before importing other modules
19
20
  # This is needed to support relative imports in user scripts and to allow overriding
20
21
  # modules using modules in the payload directory
21
22
  # TODO: Inject the environment variable names at job submission time
22
23
  STAGE_MOUNT_PATH = os.environ.get("MLRS_STAGE_MOUNT_PATH", "/mnt/job_stage")
23
- JOB_RESULT_PATH = os.environ.get("MLRS_RESULT_PATH", "output/mljob_result.pkl")
24
+ STAGE_RESULT_PATH = os.environ.get("MLRS_STAGE_RESULT_PATH")
25
+ # Updated MLRS_RESULT_PATH to use unique stage mounts for each ML Job.
26
+ # To prevent output collisions between jobs sharing the same definition,
27
+ # the server-side mount now dynamically includes the job_name.
28
+ # Format: @payload_stage/{job_definition_name}/{job_name}/mljob_result
29
+ JOB_RESULT_PATH = os.environ.get("MLRS_RESULT_PATH", "mljob_result")
30
+ if STAGE_RESULT_PATH:
31
+ JOB_RESULT_PATH = os.path.join(STAGE_RESULT_PATH, JOB_RESULT_PATH)
24
32
  PAYLOAD_PATH = os.environ.get("MLRS_PAYLOAD_DIR")
25
33
 
26
34
  if PAYLOAD_PATH and not os.path.isabs(PAYLOAD_PATH):
@@ -347,24 +355,156 @@ def wait_for_instances(
347
355
  )
348
356
 
349
357
 
350
- def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = None) -> Any:
358
+ def _load_dto_fallback(function_args: str, path_transform: Callable[[str], str]) -> Any:
359
+ from snowflake.ml.jobs._interop import data_utils
360
+ from snowflake.ml.jobs._interop.utils import DEFAULT_CODEC, DEFAULT_PROTOCOL
361
+ from snowflake.snowpark import exceptions as sp_exceptions
362
+
363
+ try:
364
+ with data_utils.open_stream(function_args, "r") as stream:
365
+ # Load the DTO as a dict for easy fallback to legacy loading if necessary
366
+ data = DEFAULT_CODEC.decode(stream, as_dict=True)
367
+ # the exception could be OSError or BlockingIOError(the file name is too long)
368
+ except OSError as e:
369
+ # path_or_data might be inline data
370
+ try:
371
+ data = DEFAULT_CODEC.decode(io.StringIO(function_args), as_dict=True)
372
+ except Exception:
373
+ raise e
374
+
375
+ if data["protocol"] is not None:
376
+ try:
377
+ from snowflake.ml.jobs._interop.dto_schema import ProtocolInfo
378
+
379
+ protocol_info = ProtocolInfo.model_validate(data["protocol"])
380
+ logger.debug(f"Loading result value with protocol {protocol_info}")
381
+ result_value = DEFAULT_PROTOCOL.load(protocol_info, session=None, path_transform=path_transform)
382
+ except sp_exceptions.SnowparkSQLException:
383
+ raise
384
+ else:
385
+ result_value = None
386
+
387
+ return data["value"] or result_value
388
+
389
+
390
+ def _unpack_obj_fallback(obj: Any, session: Optional[snowflake.snowpark.Session]) -> Any:
391
+ SESSION_KEY_PREFIX = "session@"
392
+
393
+ if not isinstance(obj, dict):
394
+ return obj
395
+ elif len(obj) == 1 and SESSION_KEY_PREFIX in obj:
396
+ return session
397
+ else:
398
+ type = obj.get("type@", None)
399
+ # If type is None, we are unpacking a dict
400
+ if type is None:
401
+ result_dict = {}
402
+ for k, v in obj.items():
403
+ if k.startswith(SESSION_KEY_PREFIX):
404
+ result_key = k[len(SESSION_KEY_PREFIX) :]
405
+ result_dict[result_key] = session
406
+ else:
407
+ result_dict[k] = _unpack_obj_fallback(v, session)
408
+ return result_dict
409
+ # If type is not None, we are unpacking a tuple or list
410
+ else:
411
+ indexes = []
412
+ for k, _ in obj.items():
413
+ if "#" in k:
414
+ indexes.append(int(k.split("#")[-1]))
415
+
416
+ if not indexes:
417
+ return tuple() if type is tuple else []
418
+ result_list: list[Any] = [None] * (max(indexes) + 1)
419
+
420
+ for k, v in obj.items():
421
+ if k == "type@":
422
+ continue
423
+ idx = int(k.split("#")[-1])
424
+ if k.startswith(SESSION_KEY_PREFIX):
425
+ result_list[idx] = session
426
+ else:
427
+ result_list[idx] = _unpack_obj_fallback(v, session)
428
+ return tuple(result_list) if type is tuple else result_list
429
+
430
+
431
+ def _load_function_args(
432
+ session: snowflake.snowpark.Session,
433
+ function_args: Optional[str] = None,
434
+ ) -> tuple[tuple[Any, ...], dict[str, Any]]:
435
+ """Load and deserialize function arguments.
436
+
437
+ Args:
438
+ function_args: Inline serialized function arguments or path to serialized file.
439
+ session: Optional Snowpark session for stage access if needed.
440
+
441
+ Returns:
442
+ A tuple of (positional_args, keyword_args)
443
+
444
+ """
445
+ if not function_args:
446
+ return (), {}
447
+
448
+ def path_transform(stage_path: str) -> str:
449
+ if not PAYLOAD_PATH:
450
+ return stage_path
451
+
452
+ payload_path = PurePosixPath(PAYLOAD_PATH)
453
+ payload_dir_name = payload_path.name # e.g., "app"
454
+
455
+ # Parse stage path and find the payload directory
456
+ stage_parts = PurePosixPath(stage_path.lstrip("@")).parts
457
+
458
+ try:
459
+ # Find index of payload directory (e.g., "app") in stage path
460
+ idx = stage_parts.index(payload_dir_name)
461
+ # Get relative path after the payload directory
462
+ relative_parts = stage_parts[idx + 1 :]
463
+ return str(payload_path.joinpath(*relative_parts))
464
+ except (ValueError, IndexError):
465
+ # Fallback to just the filename
466
+ return str(payload_path / PurePosixPath(stage_path).name)
467
+
468
+ try:
469
+ from snowflake.ml.jobs._interop import utils as interop_utils
470
+
471
+ args, kwargs = interop_utils.load(
472
+ function_args,
473
+ session=session,
474
+ path_transform=path_transform,
475
+ )
476
+ return args, kwargs
477
+ except (AttributeError, ImportError):
478
+ # Backwards compatibility: load may not exist in older SnowML versions
479
+ packed = _load_dto_fallback(function_args, path_transform)
480
+ args, kwargs = _unpack_obj_fallback(packed, session)
481
+ return args, kwargs
482
+
483
+
484
+ def run_script(
485
+ script_path: str,
486
+ payload_args: Optional[tuple[Any, ...]] = None,
487
+ payload_kwargs: Optional[dict[str, Any]] = None,
488
+ main_func: Optional[str] = None,
489
+ ) -> Any:
351
490
  """
352
491
  Execute a Python script and return its result.
353
492
 
354
493
  Args:
355
- script_path: Path to the Python script
356
- script_args: Arguments to pass to the script
357
- main_func: The name of the function to call in the script (if any)
494
+ script_path: Path to the Python script.
495
+ payload_args: Positional arguments to pass to the script or entrypoint.
496
+ payload_kwargs: Keyword arguments to pass to the script or entrypoint.
497
+ main_func: The name of the function to call in the script (if any).
358
498
 
359
499
  Returns:
360
500
  Result from script execution, either from the main function or the script's __return__ value
361
501
 
362
502
  Raises:
363
503
  RuntimeError: If the specified main_func is not found or not callable
504
+ ValueError: If payload_kwargs is provided for runpy execution.
364
505
  """
365
506
  # Save original sys.argv and modify it for the script (applies to runpy execution only)
366
507
  original_argv = sys.argv
367
- sys.argv = [script_path, *script_args]
368
508
 
369
509
  try:
370
510
  if main_func:
@@ -381,10 +521,13 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
381
521
  raise RuntimeError(f"Function '{main_func}' not a valid entrypoint for {script_path}")
382
522
 
383
523
  # Call main function
384
- result = func(*script_args)
524
+ result = func(*(payload_args or ()), **(payload_kwargs or {}))
385
525
  return result
386
526
  else:
387
- # Use runpy for other scripts
527
+ if payload_kwargs:
528
+ raise ValueError("payload_kwargs is not supported for runpy execution; use payload_args instead")
529
+ # Save original sys.argv and modify it for the script.
530
+ sys.argv = [script_path, *(payload_args or ())]
388
531
  globals_dict = runpy.run_path(script_path, run_name="__main__")
389
532
  result = globals_dict.get("__return__", None)
390
533
  return result
@@ -393,24 +536,28 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
393
536
  sys.argv = original_argv
394
537
 
395
538
 
396
- def main(entrypoint: str, *script_args: Any, script_main_func: Optional[str] = None) -> Any:
539
+ def main(
540
+ entrypoint: str,
541
+ session: snowflake.snowpark.Session,
542
+ payload_args: Optional[tuple[Any, ...]] = None,
543
+ payload_kwargs: Optional[dict[str, Any]] = None,
544
+ script_main_func: Optional[str] = None,
545
+ ) -> Any:
397
546
  """Executes a Python script and serializes the result to JOB_RESULT_PATH.
398
547
 
399
548
  Args:
400
549
  entrypoint (str): The job payload entrypoint to execute.
401
- script_args (Any): Arguments to pass to the script.
550
+ payload_args (tuple[Any, ...], optional): Positional args to pass to the script or entrypoint.
551
+ payload_kwargs (dict[str, Any], optional): Keyword args to pass to the script or entrypoint.
402
552
  script_main_func (str, optional): The name of the function to call in the script (if any).
553
+ session (snowflake.snowpark.Session, optional): Snowpark session for stage access if needed.
403
554
 
404
555
  Returns:
405
556
  Any: The result of the script execution.
406
557
 
407
558
  Raises:
408
- Exception: Re-raises any exception caught during script execution.
559
+ ValueError: If payload_kwargs is provided for runpy execution.
409
560
  """
410
- try:
411
- from snowflake.ml._internal.utils.connection_params import SnowflakeLoginOptions
412
- except ImportError:
413
- from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
414
561
 
415
562
  # Initialize Ray if available
416
563
  try:
@@ -420,12 +567,6 @@ def main(entrypoint: str, *script_args: Any, script_main_func: Optional[str] = N
420
567
  except ModuleNotFoundError:
421
568
  logger.debug("Ray is not installed, skipping Ray initialization")
422
569
 
423
- # Create a Snowpark session before starting
424
- # Session can be retrieved from using snowflake.snowpark.context.get_active_session()
425
- config = SnowflakeLoginOptions()
426
- config["client_session_keep_alive"] = "True"
427
- session = snowflake.snowpark.Session.builder.configs(config).create() # noqa: F841
428
-
429
570
  execution_result_is_error = False
430
571
  execution_result_value = None
431
572
  try:
@@ -446,10 +587,21 @@ def main(entrypoint: str, *script_args: Any, script_main_func: Optional[str] = N
446
587
 
447
588
  if is_python:
448
589
  # Run as Python script
449
- execution_result_value = run_script(resolved_entrypoint, *script_args, main_func=script_main_func)
590
+ execution_result_value = run_script(
591
+ resolved_entrypoint,
592
+ payload_args=payload_args,
593
+ payload_kwargs=payload_kwargs,
594
+ main_func=script_main_func,
595
+ )
450
596
  else:
451
597
  # Run as subprocess
452
- run_command(resolved_entrypoint, *script_args)
598
+ if payload_kwargs:
599
+ raise ValueError("payload_kwargs is not supported for subprocesses")
600
+
601
+ run_command(
602
+ resolved_entrypoint,
603
+ *(payload_args or ()),
604
+ )
453
605
 
454
606
  # Log end marker for user script execution
455
607
  print(LOG_END_MSG) # noqa: T201
@@ -487,11 +639,36 @@ if __name__ == "__main__":
487
639
  parser.add_argument(
488
640
  "--script_main_func", required=False, help="The name of the main function to call in the script"
489
641
  )
642
+ parser.add_argument(
643
+ "--function_args",
644
+ required=False,
645
+ help="Serialized function arguments or path to serialized function arguments file",
646
+ )
490
647
  args, unknown_args = parser.parse_known_args()
491
648
 
649
+ try:
650
+ from snowflake.ml._internal.utils.connection_params import SnowflakeLoginOptions
651
+ except ImportError:
652
+ from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
653
+
654
+ # Create a Snowpark session before starting
655
+ # Session can be retrieved from using snowflake.snowpark.context.get_active_session()
656
+ # _load_function_args will use the session to load the function arguments
657
+ config = SnowflakeLoginOptions()
658
+ config["client_session_keep_alive"] = "True"
659
+ session = snowflake.snowpark.Session.builder.configs(config).create() # noqa: F841
660
+
661
+ if args.function_args:
662
+ if args.script_args or unknown_args:
663
+ raise ValueError("Only one of function_args and script_args can be provided")
664
+ payload_args, payload_kwargs = _load_function_args(session, args.function_args)
665
+ else:
666
+ payload_args, payload_kwargs = (args.script_args + unknown_args), {}
667
+
492
668
  main(
493
669
  args.entrypoint,
494
- *args.script_args,
495
- *unknown_args,
670
+ session=session,
671
+ payload_args=payload_args,
672
+ payload_kwargs=payload_kwargs,
496
673
  script_main_func=args.script_main_func,
497
674
  )