snowflake-ml-python 1.24.0__py3-none-any.whl → 1.25.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/utils/mixins.py +26 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +5 -1
- snowflake/ml/data/data_connector.py +2 -2
- snowflake/ml/data/data_ingestor.py +2 -1
- snowflake/ml/experiment/_experiment_info.py +3 -3
- snowflake/ml/jobs/_interop/data_utils.py +8 -8
- snowflake/ml/jobs/_interop/dto_schema.py +52 -7
- snowflake/ml/jobs/_interop/protocols.py +124 -7
- snowflake/ml/jobs/_interop/utils.py +92 -33
- snowflake/ml/jobs/_utils/arg_protocol.py +7 -0
- snowflake/ml/jobs/_utils/constants.py +4 -0
- snowflake/ml/jobs/_utils/feature_flags.py +97 -13
- snowflake/ml/jobs/_utils/payload_utils.py +6 -40
- snowflake/ml/jobs/_utils/runtime_env_utils.py +12 -111
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +204 -27
- snowflake/ml/jobs/decorators.py +17 -22
- snowflake/ml/jobs/job.py +25 -10
- snowflake/ml/jobs/job_definition.py +100 -8
- snowflake/ml/model/_client/model/model_version_impl.py +25 -14
- snowflake/ml/model/_client/ops/service_ops.py +6 -6
- snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
- snowflake/ml/model/models/huggingface_pipeline.py +3 -0
- snowflake/ml/model/openai_signatures.py +154 -0
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +2 -3
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/METADATA +41 -2
- {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/RECORD +31 -32
- {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/WHEEL +1 -1
- snowflake/ml/jobs/_utils/function_payload_utils.py +0 -43
- snowflake/ml/jobs/_utils/spec_utils.py +0 -22
- {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.24.0.dist-info → snowflake_ml_python-1.25.0.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,11 @@
|
|
|
1
1
|
import os
|
|
2
|
-
from
|
|
3
|
-
|
|
2
|
+
from typing import Callable, Optional, Union
|
|
3
|
+
|
|
4
|
+
from snowflake.ml._internal.utils.snowflake_env import SnowflakeCloudType
|
|
5
|
+
from snowflake.snowpark import context as sp_context
|
|
6
|
+
|
|
7
|
+
# Default value type: can be a bool or a callable that returns a bool
|
|
8
|
+
DefaultValue = Union[bool, Callable[[], bool]]
|
|
4
9
|
|
|
5
10
|
|
|
6
11
|
def parse_bool_env_value(value: Optional[str], default: bool = False) -> bool:
|
|
@@ -28,22 +33,101 @@ def parse_bool_env_value(value: Optional[str], default: bool = False) -> bool:
|
|
|
28
33
|
return default
|
|
29
34
|
|
|
30
35
|
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
36
|
+
def _enabled_in_clouds(*clouds: SnowflakeCloudType) -> Callable[[], bool]:
    """Build a zero-argument predicate that tests the current cloud provider.

    The returned callable is meant to be used as a dynamic default for a
    feature flag: every time it is invoked it looks up the active Snowpark
    session, resolves the cloud that session is running in, and reports
    whether that cloud is one of ``clouds``.

    Args:
        *clouds: One or more SnowflakeCloudType values that should enable the flag.

    Returns:
        A callable returning True when the current cloud is one of ``clouds``
        and False otherwise. Any exception raised while resolving the cloud
        (no active session, SQL error, etc.) also yields False.

    Example:
        >>> # Enable feature only in GCP
        >>> default=_enabled_in_clouds(SnowflakeCloudType.GCP)
        >>>
        >>> # Enable feature in both GCP and Azure
        >>> default=_enabled_in_clouds(SnowflakeCloudType.GCP, SnowflakeCloudType.AZURE)
    """
    # Frozen once at factory time; the predicate only performs membership tests.
    allowed_clouds = frozenset(clouds)

    def check() -> bool:
        try:
            # Imported lazily, at check time, exactly as the flag is evaluated.
            from snowflake.ml._internal.utils.snowflake_env import get_current_cloud

            active_cloud = get_current_cloud(
                sp_context.get_active_session(),
                default=SnowflakeCloudType.AWS,
            )
            return active_cloud in allowed_clouds
        except Exception:
            # Best effort: when the cloud cannot be determined the feature
            # stays disabled rather than failing the caller.
            return False

    return check
|
|
35
72
|
|
|
36
|
-
|
|
37
|
-
|
|
73
|
+
|
|
74
|
+
class _FeatureFlag:
    """A feature flag backed by an environment variable with a configurable default.

    The default value can be a constant boolean or a callable that dynamically
    determines the default based on runtime context (e.g., cloud provider).
    A callable default is only invoked when the environment variable does not
    decide the outcome on its own.
    """

    def __init__(self, env_var: str, default: DefaultValue = False) -> None:
        """Initialize a feature flag.

        Args:
            env_var: The environment variable name that controls this flag.
            default: The default value when the env var is not set. Can be:
                - A boolean constant (True/False)
                - A callable that returns a boolean (evaluated at check time)
        """
        self._env_var = env_var
        self._default = default

    @property
    def value(self) -> str:
        """Return the environment variable name (for compatibility with Enum-style access)."""
        return self._env_var

    def _get_default(self) -> bool:
        """Get the default value, calling it if it's a callable."""
        if callable(self._default):
            return self._default()
        return self._default

    def is_enabled(self) -> bool:
        """Check if the feature flag is enabled.

        First checks the environment variable. If not set or unrecognized,
        falls back to the configured default value.

        Returns:
            True if the feature is enabled, False otherwise.
        """
        env_value = os.getenv(self._env_var)
        if env_value is not None:
            # Parse with both possible defaults: when the two results agree the
            # environment value was recognized and decides the outcome by itself,
            # so the (possibly expensive, session-querying) default is never
            # evaluated. When they differ the parser fell back to its default,
            # i.e. the value is unrecognized and the configured default applies.
            parsed = parse_bool_env_value(env_value, default=True)
            if parsed == parse_bool_env_value(env_value, default=False):
                return parsed
        # Environment variable not set (or not recognized): use the default.
        return self._get_default()

    def __str__(self) -> str:
        return self._env_var
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class FeatureFlags:
    """Namespace holding the feature flags recognized by ML Jobs."""

    # Controlled by MLRS_ENABLE_RUNTIME_VERSIONS; enabled when the variable is unset.
    ENABLE_RUNTIME_VERSIONS = _FeatureFlag(env_var="MLRS_ENABLE_RUNTIME_VERSIONS", default=True)

    # Controlled by MLRS_ENABLE_STAGE_MOUNT_V2; when the variable is unset the
    # default is resolved at check time and is True only in GCP.
    ENABLE_STAGE_MOUNT_V2 = _FeatureFlag(
        env_var="MLRS_ENABLE_STAGE_MOUNT_V2",
        default=_enabled_in_clouds(SnowflakeCloudType.GCP),
    )
|
|
@@ -17,20 +17,12 @@ import cloudpickle as cp
|
|
|
17
17
|
from packaging import version
|
|
18
18
|
|
|
19
19
|
from snowflake import snowpark
|
|
20
|
-
from snowflake.ml.jobs._utils import
|
|
21
|
-
constants,
|
|
22
|
-
function_payload_utils,
|
|
23
|
-
query_helper,
|
|
24
|
-
stage_utils,
|
|
25
|
-
types,
|
|
26
|
-
)
|
|
20
|
+
from snowflake.ml.jobs._utils import constants, query_helper, stage_utils, types
|
|
27
21
|
from snowflake.snowpark import exceptions as sp_exceptions
|
|
28
22
|
from snowflake.snowpark._internal import code_generation
|
|
29
23
|
from snowflake.snowpark._internal.utils import zip_file_or_directory_to_stream
|
|
30
24
|
|
|
31
25
|
logger = logging.getLogger(__name__)
|
|
32
|
-
|
|
33
|
-
cp.register_pickle_by_value(function_payload_utils)
|
|
34
26
|
ImportType = Union[str, Path, ModuleType]
|
|
35
27
|
|
|
36
28
|
_SUPPORTED_ARG_TYPES = {str, int, float}
|
|
@@ -561,7 +553,6 @@ class JobPayload:
|
|
|
561
553
|
env_vars = {
|
|
562
554
|
constants.STAGE_MOUNT_PATH_ENV_VAR: constants.STAGE_VOLUME_MOUNT_PATH,
|
|
563
555
|
constants.PAYLOAD_DIR_ENV_VAR: constants.APP_STAGE_SUBPATH,
|
|
564
|
-
constants.RESULT_PATH_ENV_VAR: constants.RESULT_PATH_DEFAULT_VALUE,
|
|
565
556
|
}
|
|
566
557
|
|
|
567
558
|
return types.UploadedPayload(
|
|
@@ -691,14 +682,9 @@ def _generate_param_handler_code(signature: inspect.Signature, output_name: str
|
|
|
691
682
|
return param_code
|
|
692
683
|
|
|
693
684
|
|
|
694
|
-
def generate_python_code(
|
|
685
|
+
def generate_python_code(function: Callable[..., Any], source_code_display: bool = False) -> str:
|
|
695
686
|
"""Generate an entrypoint script from a Python function."""
|
|
696
687
|
|
|
697
|
-
if isinstance(payload, function_payload_utils.FunctionPayload):
|
|
698
|
-
function = payload.function
|
|
699
|
-
else:
|
|
700
|
-
function = payload
|
|
701
|
-
|
|
702
688
|
signature = inspect.signature(function)
|
|
703
689
|
if any(
|
|
704
690
|
p.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
|
|
@@ -711,7 +697,7 @@ def generate_python_code(payload: Callable[..., Any], source_code_display: bool
|
|
|
711
697
|
source_code_comment = _generate_source_code_comment(function) if source_code_display else ""
|
|
712
698
|
|
|
713
699
|
arg_dict_name = "kwargs"
|
|
714
|
-
if
|
|
700
|
+
if getattr(function, constants.IS_MLJOB_REMOTE_ATTR, None):
|
|
715
701
|
param_code = f"{arg_dict_name} = {{}}"
|
|
716
702
|
else:
|
|
717
703
|
param_code = _generate_param_handler_code(signature, arg_dict_name)
|
|
@@ -721,7 +707,7 @@ import pickle
|
|
|
721
707
|
|
|
722
708
|
try:
|
|
723
709
|
{textwrap.indent(source_code_comment, ' ')}
|
|
724
|
-
{_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(
|
|
710
|
+
{_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(function).hex()}'))
|
|
725
711
|
except (TypeError, pickle.PickleError):
|
|
726
712
|
if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor != {sys.version_info.minor}:
|
|
727
713
|
raise RuntimeError(
|
|
@@ -747,26 +733,6 @@ if __name__ == '__main__':
|
|
|
747
733
|
"""
|
|
748
734
|
|
|
749
735
|
|
|
750
|
-
def create_function_payload(
|
|
751
|
-
func: Callable[..., Any], *args: Any, **kwargs: Any
|
|
752
|
-
) -> function_payload_utils.FunctionPayload:
|
|
753
|
-
signature = inspect.signature(func)
|
|
754
|
-
bound = signature.bind(*args, **kwargs)
|
|
755
|
-
bound.apply_defaults()
|
|
756
|
-
session_argument = ""
|
|
757
|
-
session = None
|
|
758
|
-
for name, val in list(bound.arguments.items()):
|
|
759
|
-
if isinstance(val, snowpark.Session):
|
|
760
|
-
if session:
|
|
761
|
-
raise TypeError(f"Expected only one Session-type argument, but got both {session_argument} and {name}.")
|
|
762
|
-
session = val
|
|
763
|
-
session_argument = name
|
|
764
|
-
del bound.arguments[name]
|
|
765
|
-
payload = function_payload_utils.FunctionPayload(func, session, session_argument, *bound.args, **bound.kwargs)
|
|
766
|
-
|
|
767
|
-
return payload
|
|
768
|
-
|
|
769
|
-
|
|
770
736
|
def get_payload_name(source: Union[str, Callable[..., Any]], entrypoint: Optional[Union[str, list[str]]] = None) -> str:
|
|
771
737
|
|
|
772
738
|
if entrypoint and isinstance(entrypoint, (list, tuple)):
|
|
@@ -775,7 +741,7 @@ def get_payload_name(source: Union[str, Callable[..., Any]], entrypoint: Optiona
|
|
|
775
741
|
return f"{PurePath(entrypoint).stem}"
|
|
776
742
|
elif source and not callable(source):
|
|
777
743
|
return f"{PurePath(source).stem}"
|
|
778
|
-
elif
|
|
779
|
-
return f"{source.
|
|
744
|
+
elif callable(source):
|
|
745
|
+
return f"{source.__name__}"
|
|
780
746
|
else:
|
|
781
747
|
return f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
|
|
@@ -1,117 +1,18 @@
|
|
|
1
|
-
import
|
|
2
|
-
import logging
|
|
3
|
-
from typing import Any, Literal, Optional, Union
|
|
4
|
-
|
|
5
|
-
from packaging.version import Version
|
|
6
|
-
from pydantic import BaseModel, Field, RootModel, field_validator
|
|
1
|
+
from typing import Optional, cast
|
|
7
2
|
|
|
8
3
|
from snowflake import snowpark
|
|
9
|
-
from snowflake.ml.jobs._utils import
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
class SpcsContainerRuntime(BaseModel):
|
|
13
|
-
python_version: Version = Field(alias="pythonVersion")
|
|
14
|
-
hardware_type: str = Field(alias="hardwareType")
|
|
15
|
-
runtime_container_image: str = Field(alias="runtimeContainerImage")
|
|
16
|
-
|
|
17
|
-
@field_validator("python_version", mode="before")
|
|
18
|
-
@classmethod
|
|
19
|
-
def validate_python_version(cls, v: Union[str, Version]) -> Version:
|
|
20
|
-
if isinstance(v, Version):
|
|
21
|
-
return v
|
|
22
|
-
try:
|
|
23
|
-
return Version(v)
|
|
24
|
-
except Exception:
|
|
25
|
-
raise ValueError(f"Invalid Python version format: {v}")
|
|
26
|
-
|
|
27
|
-
class Config:
|
|
28
|
-
frozen = True
|
|
29
|
-
extra = "allow"
|
|
30
|
-
arbitrary_types_allowed = True
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
class RuntimeEnvironmentEntry(BaseModel):
|
|
34
|
-
spcs_container_runtime: Optional[SpcsContainerRuntime] = Field(alias="spcsContainerRuntime", default=None)
|
|
35
|
-
created_on: datetime.datetime = Field(alias="createdOn")
|
|
36
|
-
id: Optional[str] = Field(alias="id")
|
|
37
|
-
|
|
38
|
-
class Config:
|
|
39
|
-
extra = "allow"
|
|
40
|
-
frozen = True
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
class RuntimeEnvironmentsDict(RootModel[dict[str, RuntimeEnvironmentEntry]]):
|
|
44
|
-
@field_validator("root", mode="before")
|
|
45
|
-
@classmethod
|
|
46
|
-
def _filter_to_dict_entries(cls, data: Any) -> dict[str, dict[str, Any]]:
|
|
47
|
-
"""
|
|
48
|
-
Pre-validation hook: keep only those items at the root level
|
|
49
|
-
whose values are dicts. Non-dict values will be dropped.
|
|
4
|
+
from snowflake.ml.jobs._utils import query_helper
|
|
50
5
|
|
|
51
|
-
Args:
|
|
52
|
-
data: The input data to filter, expected to be a dictionary.
|
|
53
6
|
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
Raises:
|
|
58
|
-
ValueError: If input data is not a dictionary.
|
|
59
|
-
"""
|
|
60
|
-
# If the entire root is not a dict, raise error immediately
|
|
61
|
-
if not isinstance(data, dict):
|
|
62
|
-
raise ValueError(f"Expected dictionary data, but got {type(data).__name__}: {data}")
|
|
63
|
-
|
|
64
|
-
# Filter out any key whose value is not a dict
|
|
65
|
-
return {key: value for key, value in data.items() if isinstance(value, dict)}
|
|
66
|
-
|
|
67
|
-
def get_spcs_container_runtimes(
|
|
68
|
-
self,
|
|
69
|
-
*,
|
|
70
|
-
hardware_type: Optional[str] = None,
|
|
71
|
-
python_version: Optional[Version] = None,
|
|
72
|
-
) -> list[SpcsContainerRuntime]:
|
|
73
|
-
# TODO(SNOW-2682000): parse version from NRE in a safer way, like relying on the label,id or image tag.
|
|
74
|
-
entries: list[RuntimeEnvironmentEntry] = [
|
|
75
|
-
entry
|
|
76
|
-
for entry in self.root.values()
|
|
77
|
-
if entry.spcs_container_runtime is not None
|
|
78
|
-
and (hardware_type is None or entry.spcs_container_runtime.hardware_type.lower() == hardware_type.lower())
|
|
79
|
-
and (
|
|
80
|
-
python_version is None
|
|
81
|
-
or (
|
|
82
|
-
entry.spcs_container_runtime.python_version.major == python_version.major
|
|
83
|
-
and entry.spcs_container_runtime.python_version.minor == python_version.minor
|
|
84
|
-
)
|
|
85
|
-
)
|
|
86
|
-
]
|
|
87
|
-
entries.sort(key=lambda e: e.created_on, reverse=True)
|
|
88
|
-
|
|
89
|
-
return [entry.spcs_container_runtime for entry in entries if entry.spcs_container_runtime is not None]
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
def _extract_image_tag(image_url: str) -> Optional[str]:
|
|
93
|
-
image_tag = image_url.rsplit(":", 1)[-1]
|
|
94
|
-
return image_tag
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
def find_runtime_image(
|
|
98
|
-
session: snowpark.Session, target_hardware: Literal["CPU", "GPU"], target_python_version: Optional[str] = None
|
|
7
|
+
def get_runtime_image(
    session: snowpark.Session, compute_pool: str, runtime_environment: Optional[str] = None
) -> Optional[str]:
    """Resolve the ML Job runtime image for a compute pool.

    Calls ``SYSTEM$GET_ML_JOB_RUNTIME`` with the compute pool name and the
    requested runtime environment and returns the first column of the first
    row of the result.

    Args:
        session: Snowpark session used to run the query.
        compute_pool: Name of the compute pool the job will run on.
        runtime_environment: Requested runtime environment; ``None`` is sent
            as an empty string.

    Returns:
        The runtime image reference (``<url>:<tag>``), or None when the query
        returns an empty value.

    Raises:
        ValueError: If the query returns no rows, or the returned image is not
            of the form ``<url>:<tag>``.
    """

    def _escape_literal(text: str) -> str:
        # The arguments are embedded in single-quoted SQL string literals, so
        # backslashes and single quotes must be escaped to keep the statement
        # well-formed (and to avoid the value terminating the literal early).
        return text.replace("\\", "\\\\").replace("'", "''")

    runtime_environment = runtime_environment if runtime_environment else ""
    query = (
        "CALL SYSTEM$GET_ML_JOB_RUNTIME("
        f"'{_escape_literal(compute_pool)}', '{_escape_literal(runtime_environment)}')"
    )
    rows = query_helper.run_query(session, query)
    if not rows:
        raise ValueError("Failed to get any available runtime image")

    image = rows[0][0]
    if not image:
        # Honors the Optional[str] contract instead of failing while parsing.
        return None

    # rpartition never raises; a missing separator, url or tag is reported
    # with a descriptive error instead of an opaque unpacking failure.
    url, separator, tag = image.rpartition(":")
    if not (separator and url and tag):
        raise ValueError(f"image {image} is not a valid runtime image")
    return cast(str, image)
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import argparse
|
|
2
2
|
import copy
|
|
3
3
|
import importlib.util
|
|
4
|
+
import io
|
|
4
5
|
import json
|
|
5
6
|
import logging
|
|
6
7
|
import math
|
|
@@ -12,15 +13,22 @@ import sys
|
|
|
12
13
|
import time
|
|
13
14
|
import traceback
|
|
14
15
|
import zipfile
|
|
15
|
-
from pathlib import Path
|
|
16
|
-
from typing import Any, Optional
|
|
16
|
+
from pathlib import Path, PurePosixPath
|
|
17
|
+
from typing import Any, Callable, Optional
|
|
17
18
|
|
|
18
19
|
# Ensure payload directory is in sys.path for module imports before importing other modules
|
|
19
20
|
# This is needed to support relative imports in user scripts and to allow overriding
|
|
20
21
|
# modules using modules in the payload directory
|
|
21
22
|
# TODO: Inject the environment variable names at job submission time
|
|
22
23
|
STAGE_MOUNT_PATH = os.environ.get("MLRS_STAGE_MOUNT_PATH", "/mnt/job_stage")
|
|
23
|
-
|
|
24
|
+
STAGE_RESULT_PATH = os.environ.get("MLRS_STAGE_RESULT_PATH")
|
|
25
|
+
# Updated MLRS_RESULT_PATH to use unique stage mounts for each ML Job.
|
|
26
|
+
# To prevent output collisions between jobs sharing the same definition,
|
|
27
|
+
# the server-side mount now dynamically includes the job_name.
|
|
28
|
+
# Format: @payload_stage/{job_definition_name}/{job_name}/mljob_result
|
|
29
|
+
JOB_RESULT_PATH = os.environ.get("MLRS_RESULT_PATH", "mljob_result")
|
|
30
|
+
if STAGE_RESULT_PATH:
|
|
31
|
+
JOB_RESULT_PATH = os.path.join(STAGE_RESULT_PATH, JOB_RESULT_PATH)
|
|
24
32
|
PAYLOAD_PATH = os.environ.get("MLRS_PAYLOAD_DIR")
|
|
25
33
|
|
|
26
34
|
if PAYLOAD_PATH and not os.path.isabs(PAYLOAD_PATH):
|
|
@@ -347,24 +355,156 @@ def wait_for_instances(
|
|
|
347
355
|
)
|
|
348
356
|
|
|
349
357
|
|
|
350
|
-
def
|
|
358
|
+
def _load_dto_fallback(function_args: str, path_transform: Callable[[str], str]) -> Any:
    """Decode a serialized DTO and return its payload (legacy loading path).

    ``function_args`` is first treated as a path and opened as a stream; if
    opening fails with an OSError it is decoded as inline data instead. When
    the decoded DTO carries protocol information, the payload is additionally
    loaded through the default protocol.

    Args:
        function_args: Path to the serialized DTO, or the serialized DTO itself.
        path_transform: Callable forwarded to the protocol loader to rewrite paths.

    Returns:
        The DTO's ``value`` entry when it is truthy, otherwise the value loaded
        through the protocol (None when the DTO has no protocol).
    """
    from snowflake.ml.jobs._interop import data_utils
    from snowflake.ml.jobs._interop.utils import DEFAULT_CODEC, DEFAULT_PROTOCOL
    from snowflake.snowpark import exceptions as sp_exceptions

    try:
        with data_utils.open_stream(function_args, "r") as handle:
            # Decode to a plain dict so both legacy and protocol payloads can be read
            dto = DEFAULT_CODEC.decode(handle, as_dict=True)
    # OSError also covers the "file name too long" case for inline payloads
    except OSError as open_error:
        # function_args might be inline data rather than a path
        try:
            dto = DEFAULT_CODEC.decode(io.StringIO(function_args), as_dict=True)
        except Exception:
            # Inline decoding failed too: surface the original open failure
            raise open_error

    if dto["protocol"] is None:
        result_value = None
    else:
        try:
            from snowflake.ml.jobs._interop.dto_schema import ProtocolInfo

            protocol_info = ProtocolInfo.model_validate(dto["protocol"])
            logger.debug(f"Loading result value with protocol {protocol_info}")
            result_value = DEFAULT_PROTOCOL.load(protocol_info, session=None, path_transform=path_transform)
        except sp_exceptions.SnowparkSQLException:
            # SQL errors are propagated to the caller unchanged
            raise

    return dto["value"] or result_value
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
def _unpack_obj_fallback(obj: Any, session: Optional[snowflake.snowpark.Session]) -> Any:
|
|
391
|
+
SESSION_KEY_PREFIX = "session@"
|
|
392
|
+
|
|
393
|
+
if not isinstance(obj, dict):
|
|
394
|
+
return obj
|
|
395
|
+
elif len(obj) == 1 and SESSION_KEY_PREFIX in obj:
|
|
396
|
+
return session
|
|
397
|
+
else:
|
|
398
|
+
type = obj.get("type@", None)
|
|
399
|
+
# If type is None, we are unpacking a dict
|
|
400
|
+
if type is None:
|
|
401
|
+
result_dict = {}
|
|
402
|
+
for k, v in obj.items():
|
|
403
|
+
if k.startswith(SESSION_KEY_PREFIX):
|
|
404
|
+
result_key = k[len(SESSION_KEY_PREFIX) :]
|
|
405
|
+
result_dict[result_key] = session
|
|
406
|
+
else:
|
|
407
|
+
result_dict[k] = _unpack_obj_fallback(v, session)
|
|
408
|
+
return result_dict
|
|
409
|
+
# If type is not None, we are unpacking a tuple or list
|
|
410
|
+
else:
|
|
411
|
+
indexes = []
|
|
412
|
+
for k, _ in obj.items():
|
|
413
|
+
if "#" in k:
|
|
414
|
+
indexes.append(int(k.split("#")[-1]))
|
|
415
|
+
|
|
416
|
+
if not indexes:
|
|
417
|
+
return tuple() if type is tuple else []
|
|
418
|
+
result_list: list[Any] = [None] * (max(indexes) + 1)
|
|
419
|
+
|
|
420
|
+
for k, v in obj.items():
|
|
421
|
+
if k == "type@":
|
|
422
|
+
continue
|
|
423
|
+
idx = int(k.split("#")[-1])
|
|
424
|
+
if k.startswith(SESSION_KEY_PREFIX):
|
|
425
|
+
result_list[idx] = session
|
|
426
|
+
else:
|
|
427
|
+
result_list[idx] = _unpack_obj_fallback(v, session)
|
|
428
|
+
return tuple(result_list) if type is tuple else result_list
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
def _load_function_args(
    session: snowflake.snowpark.Session,
    function_args: Optional[str] = None,
) -> tuple[tuple[Any, ...], dict[str, Any]]:
    """Load and deserialize function arguments.

    Args:
        session: Snowpark session passed to the loader (and substituted for
            session placeholders on the legacy path).
        function_args: Inline serialized function arguments or path to serialized file.

    Returns:
        A tuple of (positional_args, keyword_args). Both are empty when
        ``function_args`` is not provided.

    """
    if not function_args:
        return (), {}

    def path_transform(stage_path: str) -> str:
        # Map a stage path onto the local payload directory; without a payload
        # directory the stage path is returned untouched.
        if not PAYLOAD_PATH:
            return stage_path

        payload_path = PurePosixPath(PAYLOAD_PATH)
        payload_dir_name = payload_path.name  # e.g., "app"

        # Parse stage path and find the payload directory
        stage_parts = PurePosixPath(stage_path.lstrip("@")).parts

        try:
            # Find index of payload directory (e.g., "app") in stage path
            idx = stage_parts.index(payload_dir_name)
            # Get relative path after the payload directory
            relative_parts = stage_parts[idx + 1 :]
            return str(payload_path.joinpath(*relative_parts))
        except (ValueError, IndexError):
            # Fallback to just the filename
            return str(payload_path / PurePosixPath(stage_path).name)

    try:
        from snowflake.ml.jobs._interop import utils as interop_utils

        # Only the import and the attribute lookup are guarded: an
        # AttributeError/ImportError raised while deserializing the arguments
        # is a real failure and must not be mistaken for a missing `load`.
        load = interop_utils.load
    except (AttributeError, ImportError):
        # Backwards compatibility: load may not exist in older SnowML versions
        packed = _load_dto_fallback(function_args, path_transform)
        args, kwargs = _unpack_obj_fallback(packed, session)
        return args, kwargs

    args, kwargs = load(
        function_args,
        session=session,
        path_transform=path_transform,
    )
    return args, kwargs
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
def run_script(
|
|
485
|
+
script_path: str,
|
|
486
|
+
payload_args: Optional[tuple[Any, ...]] = None,
|
|
487
|
+
payload_kwargs: Optional[dict[str, Any]] = None,
|
|
488
|
+
main_func: Optional[str] = None,
|
|
489
|
+
) -> Any:
|
|
351
490
|
"""
|
|
352
491
|
Execute a Python script and return its result.
|
|
353
492
|
|
|
354
493
|
Args:
|
|
355
|
-
script_path: Path to the Python script
|
|
356
|
-
|
|
357
|
-
|
|
494
|
+
script_path: Path to the Python script.
|
|
495
|
+
payload_args: Positional arguments to pass to the script or entrypoint.
|
|
496
|
+
payload_kwargs: Keyword arguments to pass to the script or entrypoint.
|
|
497
|
+
main_func: The name of the function to call in the script (if any).
|
|
358
498
|
|
|
359
499
|
Returns:
|
|
360
500
|
Result from script execution, either from the main function or the script's __return__ value
|
|
361
501
|
|
|
362
502
|
Raises:
|
|
363
503
|
RuntimeError: If the specified main_func is not found or not callable
|
|
504
|
+
ValueError: If payload_kwargs is provided for runpy execution.
|
|
364
505
|
"""
|
|
365
506
|
# Save original sys.argv and modify it for the script (applies to runpy execution only)
|
|
366
507
|
original_argv = sys.argv
|
|
367
|
-
sys.argv = [script_path, *script_args]
|
|
368
508
|
|
|
369
509
|
try:
|
|
370
510
|
if main_func:
|
|
@@ -381,10 +521,13 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
|
|
|
381
521
|
raise RuntimeError(f"Function '{main_func}' not a valid entrypoint for {script_path}")
|
|
382
522
|
|
|
383
523
|
# Call main function
|
|
384
|
-
result = func(*
|
|
524
|
+
result = func(*(payload_args or ()), **(payload_kwargs or {}))
|
|
385
525
|
return result
|
|
386
526
|
else:
|
|
387
|
-
|
|
527
|
+
if payload_kwargs:
|
|
528
|
+
raise ValueError("payload_kwargs is not supported for runpy execution; use payload_args instead")
|
|
529
|
+
# Save original sys.argv and modify it for the script.
|
|
530
|
+
sys.argv = [script_path, *(payload_args or ())]
|
|
388
531
|
globals_dict = runpy.run_path(script_path, run_name="__main__")
|
|
389
532
|
result = globals_dict.get("__return__", None)
|
|
390
533
|
return result
|
|
@@ -393,24 +536,28 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
|
|
|
393
536
|
sys.argv = original_argv
|
|
394
537
|
|
|
395
538
|
|
|
396
|
-
def main(
|
|
539
|
+
def main(
|
|
540
|
+
entrypoint: str,
|
|
541
|
+
session: snowflake.snowpark.Session,
|
|
542
|
+
payload_args: Optional[tuple[Any, ...]] = None,
|
|
543
|
+
payload_kwargs: Optional[dict[str, Any]] = None,
|
|
544
|
+
script_main_func: Optional[str] = None,
|
|
545
|
+
) -> Any:
|
|
397
546
|
"""Executes a Python script and serializes the result to JOB_RESULT_PATH.
|
|
398
547
|
|
|
399
548
|
Args:
|
|
400
549
|
entrypoint (str): The job payload entrypoint to execute.
|
|
401
|
-
|
|
550
|
+
payload_args (tuple[Any, ...], optional): Positional args to pass to the script or entrypoint.
|
|
551
|
+
payload_kwargs (dict[str, Any], optional): Keyword args to pass to the script or entrypoint.
|
|
402
552
|
script_main_func (str, optional): The name of the function to call in the script (if any).
|
|
553
|
+
session (snowflake.snowpark.Session, optional): Snowpark session for stage access if needed.
|
|
403
554
|
|
|
404
555
|
Returns:
|
|
405
556
|
Any: The result of the script execution.
|
|
406
557
|
|
|
407
558
|
Raises:
|
|
408
|
-
|
|
559
|
+
ValueError: If payload_kwargs is provided for runpy execution.
|
|
409
560
|
"""
|
|
410
|
-
try:
|
|
411
|
-
from snowflake.ml._internal.utils.connection_params import SnowflakeLoginOptions
|
|
412
|
-
except ImportError:
|
|
413
|
-
from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
|
|
414
561
|
|
|
415
562
|
# Initialize Ray if available
|
|
416
563
|
try:
|
|
@@ -420,12 +567,6 @@ def main(entrypoint: str, *script_args: Any, script_main_func: Optional[str] = N
|
|
|
420
567
|
except ModuleNotFoundError:
|
|
421
568
|
logger.debug("Ray is not installed, skipping Ray initialization")
|
|
422
569
|
|
|
423
|
-
# Create a Snowpark session before starting
|
|
424
|
-
# Session can be retrieved from using snowflake.snowpark.context.get_active_session()
|
|
425
|
-
config = SnowflakeLoginOptions()
|
|
426
|
-
config["client_session_keep_alive"] = "True"
|
|
427
|
-
session = snowflake.snowpark.Session.builder.configs(config).create() # noqa: F841
|
|
428
|
-
|
|
429
570
|
execution_result_is_error = False
|
|
430
571
|
execution_result_value = None
|
|
431
572
|
try:
|
|
@@ -446,10 +587,21 @@ def main(entrypoint: str, *script_args: Any, script_main_func: Optional[str] = N
|
|
|
446
587
|
|
|
447
588
|
if is_python:
|
|
448
589
|
# Run as Python script
|
|
449
|
-
execution_result_value = run_script(
|
|
590
|
+
execution_result_value = run_script(
|
|
591
|
+
resolved_entrypoint,
|
|
592
|
+
payload_args=payload_args,
|
|
593
|
+
payload_kwargs=payload_kwargs,
|
|
594
|
+
main_func=script_main_func,
|
|
595
|
+
)
|
|
450
596
|
else:
|
|
451
597
|
# Run as subprocess
|
|
452
|
-
|
|
598
|
+
if payload_kwargs:
|
|
599
|
+
raise ValueError("payload_kwargs is not supported for subprocesses")
|
|
600
|
+
|
|
601
|
+
run_command(
|
|
602
|
+
resolved_entrypoint,
|
|
603
|
+
*(payload_args or ()),
|
|
604
|
+
)
|
|
453
605
|
|
|
454
606
|
# Log end marker for user script execution
|
|
455
607
|
print(LOG_END_MSG) # noqa: T201
|
|
@@ -487,11 +639,36 @@ if __name__ == "__main__":
|
|
|
487
639
|
parser.add_argument(
|
|
488
640
|
"--script_main_func", required=False, help="The name of the main function to call in the script"
|
|
489
641
|
)
|
|
642
|
+
parser.add_argument(
|
|
643
|
+
"--function_args",
|
|
644
|
+
required=False,
|
|
645
|
+
help="Serialized function arguments or path to serialized function arguments file",
|
|
646
|
+
)
|
|
490
647
|
args, unknown_args = parser.parse_known_args()
|
|
491
648
|
|
|
649
|
+
try:
|
|
650
|
+
from snowflake.ml._internal.utils.connection_params import SnowflakeLoginOptions
|
|
651
|
+
except ImportError:
|
|
652
|
+
from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
|
|
653
|
+
|
|
654
|
+
# Create a Snowpark session before starting
|
|
655
|
+
# Session can be retrieved using snowflake.snowpark.context.get_active_session()
|
|
656
|
+
# _load_function_args will use the session to load the function arguments
|
|
657
|
+
config = SnowflakeLoginOptions()
|
|
658
|
+
config["client_session_keep_alive"] = "True"
|
|
659
|
+
session = snowflake.snowpark.Session.builder.configs(config).create() # noqa: F841
|
|
660
|
+
|
|
661
|
+
if args.function_args:
|
|
662
|
+
if args.script_args or unknown_args:
|
|
663
|
+
raise ValueError("Only one of function_args and script_args can be provided")
|
|
664
|
+
payload_args, payload_kwargs = _load_function_args(session, args.function_args)
|
|
665
|
+
else:
|
|
666
|
+
payload_args, payload_kwargs = (args.script_args + unknown_args), {}
|
|
667
|
+
|
|
492
668
|
main(
|
|
493
669
|
args.entrypoint,
|
|
494
|
-
|
|
495
|
-
|
|
670
|
+
session=session,
|
|
671
|
+
payload_args=payload_args,
|
|
672
|
+
payload_kwargs=payload_kwargs,
|
|
496
673
|
script_main_func=args.script_main_func,
|
|
497
674
|
)
|