snowflake-ml-python 1.25.0__py3-none-any.whl → 1.25.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/utils/mixins.py +1 -26
- snowflake/ml/jobs/_interop/data_utils.py +8 -8
- snowflake/ml/jobs/_interop/dto_schema.py +7 -52
- snowflake/ml/jobs/_interop/protocols.py +7 -124
- snowflake/ml/jobs/_interop/utils.py +33 -92
- snowflake/ml/jobs/_utils/constants.py +0 -4
- snowflake/ml/jobs/_utils/function_payload_utils.py +43 -0
- snowflake/ml/jobs/_utils/payload_utils.py +40 -6
- snowflake/ml/jobs/_utils/runtime_env_utils.py +111 -12
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +27 -204
- snowflake/ml/jobs/_utils/spec_utils.py +22 -0
- snowflake/ml/jobs/decorators.py +22 -17
- snowflake/ml/jobs/job.py +10 -25
- snowflake/ml/jobs/job_definition.py +4 -90
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.25.0.dist-info → snowflake_ml_python-1.25.1.dist-info}/METADATA +7 -1
- {snowflake_ml_python-1.25.0.dist-info → snowflake_ml_python-1.25.1.dist-info}/RECORD +20 -19
- snowflake/ml/jobs/_utils/arg_protocol.py +0 -7
- {snowflake_ml_python-1.25.0.dist-info → snowflake_ml_python-1.25.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.25.0.dist-info → snowflake_ml_python-1.25.1.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.25.0.dist-info → snowflake_ml_python-1.25.1.dist-info}/top_level.txt +0 -0
|
@@ -17,12 +17,20 @@ import cloudpickle as cp
|
|
|
17
17
|
from packaging import version
|
|
18
18
|
|
|
19
19
|
from snowflake import snowpark
|
|
20
|
-
from snowflake.ml.jobs._utils import
|
|
20
|
+
from snowflake.ml.jobs._utils import (
|
|
21
|
+
constants,
|
|
22
|
+
function_payload_utils,
|
|
23
|
+
query_helper,
|
|
24
|
+
stage_utils,
|
|
25
|
+
types,
|
|
26
|
+
)
|
|
21
27
|
from snowflake.snowpark import exceptions as sp_exceptions
|
|
22
28
|
from snowflake.snowpark._internal import code_generation
|
|
23
29
|
from snowflake.snowpark._internal.utils import zip_file_or_directory_to_stream
|
|
24
30
|
|
|
25
31
|
logger = logging.getLogger(__name__)
|
|
32
|
+
|
|
33
|
+
cp.register_pickle_by_value(function_payload_utils)
|
|
26
34
|
ImportType = Union[str, Path, ModuleType]
|
|
27
35
|
|
|
28
36
|
_SUPPORTED_ARG_TYPES = {str, int, float}
|
|
@@ -553,6 +561,7 @@ class JobPayload:
|
|
|
553
561
|
env_vars = {
|
|
554
562
|
constants.STAGE_MOUNT_PATH_ENV_VAR: constants.STAGE_VOLUME_MOUNT_PATH,
|
|
555
563
|
constants.PAYLOAD_DIR_ENV_VAR: constants.APP_STAGE_SUBPATH,
|
|
564
|
+
constants.RESULT_PATH_ENV_VAR: constants.RESULT_PATH_DEFAULT_VALUE,
|
|
556
565
|
}
|
|
557
566
|
|
|
558
567
|
return types.UploadedPayload(
|
|
@@ -682,9 +691,14 @@ def _generate_param_handler_code(signature: inspect.Signature, output_name: str
|
|
|
682
691
|
return param_code
|
|
683
692
|
|
|
684
693
|
|
|
685
|
-
def generate_python_code(
|
|
694
|
+
def generate_python_code(payload: Callable[..., Any], source_code_display: bool = False) -> str:
|
|
686
695
|
"""Generate an entrypoint script from a Python function."""
|
|
687
696
|
|
|
697
|
+
if isinstance(payload, function_payload_utils.FunctionPayload):
|
|
698
|
+
function = payload.function
|
|
699
|
+
else:
|
|
700
|
+
function = payload
|
|
701
|
+
|
|
688
702
|
signature = inspect.signature(function)
|
|
689
703
|
if any(
|
|
690
704
|
p.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
|
|
@@ -697,7 +711,7 @@ def generate_python_code(function: Callable[..., Any], source_code_display: bool
|
|
|
697
711
|
source_code_comment = _generate_source_code_comment(function) if source_code_display else ""
|
|
698
712
|
|
|
699
713
|
arg_dict_name = "kwargs"
|
|
700
|
-
if
|
|
714
|
+
if isinstance(payload, function_payload_utils.FunctionPayload):
|
|
701
715
|
param_code = f"{arg_dict_name} = {{}}"
|
|
702
716
|
else:
|
|
703
717
|
param_code = _generate_param_handler_code(signature, arg_dict_name)
|
|
@@ -707,7 +721,7 @@ import pickle
|
|
|
707
721
|
|
|
708
722
|
try:
|
|
709
723
|
{textwrap.indent(source_code_comment, ' ')}
|
|
710
|
-
{_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(
|
|
724
|
+
{_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(payload).hex()}'))
|
|
711
725
|
except (TypeError, pickle.PickleError):
|
|
712
726
|
if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor != {sys.version_info.minor}:
|
|
713
727
|
raise RuntimeError(
|
|
@@ -733,6 +747,26 @@ if __name__ == '__main__':
|
|
|
733
747
|
"""
|
|
734
748
|
|
|
735
749
|
|
|
750
|
+
def create_function_payload(
|
|
751
|
+
func: Callable[..., Any], *args: Any, **kwargs: Any
|
|
752
|
+
) -> function_payload_utils.FunctionPayload:
|
|
753
|
+
signature = inspect.signature(func)
|
|
754
|
+
bound = signature.bind(*args, **kwargs)
|
|
755
|
+
bound.apply_defaults()
|
|
756
|
+
session_argument = ""
|
|
757
|
+
session = None
|
|
758
|
+
for name, val in list(bound.arguments.items()):
|
|
759
|
+
if isinstance(val, snowpark.Session):
|
|
760
|
+
if session:
|
|
761
|
+
raise TypeError(f"Expected only one Session-type argument, but got both {session_argument} and {name}.")
|
|
762
|
+
session = val
|
|
763
|
+
session_argument = name
|
|
764
|
+
del bound.arguments[name]
|
|
765
|
+
payload = function_payload_utils.FunctionPayload(func, session, session_argument, *bound.args, **bound.kwargs)
|
|
766
|
+
|
|
767
|
+
return payload
|
|
768
|
+
|
|
769
|
+
|
|
736
770
|
def get_payload_name(source: Union[str, Callable[..., Any]], entrypoint: Optional[Union[str, list[str]]] = None) -> str:
|
|
737
771
|
|
|
738
772
|
if entrypoint and isinstance(entrypoint, (list, tuple)):
|
|
@@ -741,7 +775,7 @@ def get_payload_name(source: Union[str, Callable[..., Any]], entrypoint: Optiona
|
|
|
741
775
|
return f"{PurePath(entrypoint).stem}"
|
|
742
776
|
elif source and not callable(source):
|
|
743
777
|
return f"{PurePath(source).stem}"
|
|
744
|
-
elif
|
|
745
|
-
return f"{source.__name__}"
|
|
778
|
+
elif isinstance(source, function_payload_utils.FunctionPayload):
|
|
779
|
+
return f"{source.function.__name__}"
|
|
746
780
|
else:
|
|
747
781
|
return f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
|
|
@@ -1,18 +1,117 @@
|
|
|
1
|
-
|
|
1
|
+
import datetime
|
|
2
|
+
import logging
|
|
3
|
+
from typing import Any, Literal, Optional, Union
|
|
4
|
+
|
|
5
|
+
from packaging.version import Version
|
|
6
|
+
from pydantic import BaseModel, Field, RootModel, field_validator
|
|
2
7
|
|
|
3
8
|
from snowflake import snowpark
|
|
4
|
-
from snowflake.ml.jobs._utils import query_helper
|
|
9
|
+
from snowflake.ml.jobs._utils import constants, query_helper
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SpcsContainerRuntime(BaseModel):
|
|
13
|
+
python_version: Version = Field(alias="pythonVersion")
|
|
14
|
+
hardware_type: str = Field(alias="hardwareType")
|
|
15
|
+
runtime_container_image: str = Field(alias="runtimeContainerImage")
|
|
16
|
+
|
|
17
|
+
@field_validator("python_version", mode="before")
|
|
18
|
+
@classmethod
|
|
19
|
+
def validate_python_version(cls, v: Union[str, Version]) -> Version:
|
|
20
|
+
if isinstance(v, Version):
|
|
21
|
+
return v
|
|
22
|
+
try:
|
|
23
|
+
return Version(v)
|
|
24
|
+
except Exception:
|
|
25
|
+
raise ValueError(f"Invalid Python version format: {v}")
|
|
26
|
+
|
|
27
|
+
class Config:
|
|
28
|
+
frozen = True
|
|
29
|
+
extra = "allow"
|
|
30
|
+
arbitrary_types_allowed = True
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class RuntimeEnvironmentEntry(BaseModel):
|
|
34
|
+
spcs_container_runtime: Optional[SpcsContainerRuntime] = Field(alias="spcsContainerRuntime", default=None)
|
|
35
|
+
created_on: datetime.datetime = Field(alias="createdOn")
|
|
36
|
+
id: Optional[str] = Field(alias="id")
|
|
37
|
+
|
|
38
|
+
class Config:
|
|
39
|
+
extra = "allow"
|
|
40
|
+
frozen = True
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class RuntimeEnvironmentsDict(RootModel[dict[str, RuntimeEnvironmentEntry]]):
|
|
44
|
+
@field_validator("root", mode="before")
|
|
45
|
+
@classmethod
|
|
46
|
+
def _filter_to_dict_entries(cls, data: Any) -> dict[str, dict[str, Any]]:
|
|
47
|
+
"""
|
|
48
|
+
Pre-validation hook: keep only those items at the root level
|
|
49
|
+
whose values are dicts. Non-dict values will be dropped.
|
|
5
50
|
|
|
51
|
+
Args:
|
|
52
|
+
data: The input data to filter, expected to be a dictionary.
|
|
6
53
|
|
|
7
|
-
|
|
8
|
-
|
|
54
|
+
Returns:
|
|
55
|
+
A dictionary containing only the key-value pairs where values are dictionaries.
|
|
56
|
+
|
|
57
|
+
Raises:
|
|
58
|
+
ValueError: If input data is not a dictionary.
|
|
59
|
+
"""
|
|
60
|
+
# If the entire root is not a dict, raise error immediately
|
|
61
|
+
if not isinstance(data, dict):
|
|
62
|
+
raise ValueError(f"Expected dictionary data, but got {type(data).__name__}: {data}")
|
|
63
|
+
|
|
64
|
+
# Filter out any key whose value is not a dict
|
|
65
|
+
return {key: value for key, value in data.items() if isinstance(value, dict)}
|
|
66
|
+
|
|
67
|
+
def get_spcs_container_runtimes(
|
|
68
|
+
self,
|
|
69
|
+
*,
|
|
70
|
+
hardware_type: Optional[str] = None,
|
|
71
|
+
python_version: Optional[Version] = None,
|
|
72
|
+
) -> list[SpcsContainerRuntime]:
|
|
73
|
+
# TODO(SNOW-2682000): parse version from NRE in a safer way, like relying on the label,id or image tag.
|
|
74
|
+
entries: list[RuntimeEnvironmentEntry] = [
|
|
75
|
+
entry
|
|
76
|
+
for entry in self.root.values()
|
|
77
|
+
if entry.spcs_container_runtime is not None
|
|
78
|
+
and (hardware_type is None or entry.spcs_container_runtime.hardware_type.lower() == hardware_type.lower())
|
|
79
|
+
and (
|
|
80
|
+
python_version is None
|
|
81
|
+
or (
|
|
82
|
+
entry.spcs_container_runtime.python_version.major == python_version.major
|
|
83
|
+
and entry.spcs_container_runtime.python_version.minor == python_version.minor
|
|
84
|
+
)
|
|
85
|
+
)
|
|
86
|
+
]
|
|
87
|
+
entries.sort(key=lambda e: e.created_on, reverse=True)
|
|
88
|
+
|
|
89
|
+
return [entry.spcs_container_runtime for entry in entries if entry.spcs_container_runtime is not None]
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _extract_image_tag(image_url: str) -> Optional[str]:
|
|
93
|
+
image_tag = image_url.rsplit(":", 1)[-1]
|
|
94
|
+
return image_tag
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def find_runtime_image(
|
|
98
|
+
session: snowpark.Session, target_hardware: Literal["CPU", "GPU"], target_python_version: Optional[str] = None
|
|
9
99
|
) -> Optional[str]:
|
|
10
|
-
|
|
11
|
-
|
|
100
|
+
python_version = (
|
|
101
|
+
Version(target_python_version) if target_python_version else Version(constants.DEFAULT_PYTHON_VERSION)
|
|
102
|
+
)
|
|
103
|
+
rows = query_helper.run_query(session, "CALL SYSTEM$NOTEBOOKS_FIND_LABELED_RUNTIMES()")
|
|
12
104
|
if not rows:
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
105
|
+
return None
|
|
106
|
+
try:
|
|
107
|
+
runtime_envs = RuntimeEnvironmentsDict.model_validate_json(rows[0][0])
|
|
108
|
+
spcs_container_runtimes = runtime_envs.get_spcs_container_runtimes(
|
|
109
|
+
hardware_type=target_hardware,
|
|
110
|
+
python_version=python_version,
|
|
111
|
+
)
|
|
112
|
+
except Exception as e:
|
|
113
|
+
logging.warning(f"Failed to parse runtime image name from {rows[0][0]}, error: {e}")
|
|
114
|
+
return None
|
|
115
|
+
|
|
116
|
+
selected_runtime = spcs_container_runtimes[0] if spcs_container_runtimes else None
|
|
117
|
+
return selected_runtime.runtime_container_image if selected_runtime else None
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
import argparse
|
|
2
2
|
import copy
|
|
3
3
|
import importlib.util
|
|
4
|
-
import io
|
|
5
4
|
import json
|
|
6
5
|
import logging
|
|
7
6
|
import math
|
|
@@ -13,22 +12,15 @@ import sys
|
|
|
13
12
|
import time
|
|
14
13
|
import traceback
|
|
15
14
|
import zipfile
|
|
16
|
-
from pathlib import Path
|
|
17
|
-
from typing import Any,
|
|
15
|
+
from pathlib import Path
|
|
16
|
+
from typing import Any, Optional
|
|
18
17
|
|
|
19
18
|
# Ensure payload directory is in sys.path for module imports before importing other modules
|
|
20
19
|
# This is needed to support relative imports in user scripts and to allow overriding
|
|
21
20
|
# modules using modules in the payload directory
|
|
22
21
|
# TODO: Inject the environment variable names at job submission time
|
|
23
22
|
STAGE_MOUNT_PATH = os.environ.get("MLRS_STAGE_MOUNT_PATH", "/mnt/job_stage")
|
|
24
|
-
|
|
25
|
-
# Updated MLRS_RESULT_PATH to use unique stage mounts for each ML Job.
|
|
26
|
-
# To prevent output collisions between jobs sharing the same definition,
|
|
27
|
-
# the server-side mount now dynamically includes the job_name.
|
|
28
|
-
# Format: @payload_stage/{job_definition_name}/{job_name}/mljob_result
|
|
29
|
-
JOB_RESULT_PATH = os.environ.get("MLRS_RESULT_PATH", "mljob_result")
|
|
30
|
-
if STAGE_RESULT_PATH:
|
|
31
|
-
JOB_RESULT_PATH = os.path.join(STAGE_RESULT_PATH, JOB_RESULT_PATH)
|
|
23
|
+
JOB_RESULT_PATH = os.environ.get("MLRS_RESULT_PATH", "output/mljob_result.pkl")
|
|
32
24
|
PAYLOAD_PATH = os.environ.get("MLRS_PAYLOAD_DIR")
|
|
33
25
|
|
|
34
26
|
if PAYLOAD_PATH and not os.path.isabs(PAYLOAD_PATH):
|
|
@@ -355,156 +347,24 @@ def wait_for_instances(
|
|
|
355
347
|
)
|
|
356
348
|
|
|
357
349
|
|
|
358
|
-
def
|
|
359
|
-
from snowflake.ml.jobs._interop import data_utils
|
|
360
|
-
from snowflake.ml.jobs._interop.utils import DEFAULT_CODEC, DEFAULT_PROTOCOL
|
|
361
|
-
from snowflake.snowpark import exceptions as sp_exceptions
|
|
362
|
-
|
|
363
|
-
try:
|
|
364
|
-
with data_utils.open_stream(function_args, "r") as stream:
|
|
365
|
-
# Load the DTO as a dict for easy fallback to legacy loading if necessary
|
|
366
|
-
data = DEFAULT_CODEC.decode(stream, as_dict=True)
|
|
367
|
-
# the exception could be OSError or BlockingIOError(the file name is too long)
|
|
368
|
-
except OSError as e:
|
|
369
|
-
# path_or_data might be inline data
|
|
370
|
-
try:
|
|
371
|
-
data = DEFAULT_CODEC.decode(io.StringIO(function_args), as_dict=True)
|
|
372
|
-
except Exception:
|
|
373
|
-
raise e
|
|
374
|
-
|
|
375
|
-
if data["protocol"] is not None:
|
|
376
|
-
try:
|
|
377
|
-
from snowflake.ml.jobs._interop.dto_schema import ProtocolInfo
|
|
378
|
-
|
|
379
|
-
protocol_info = ProtocolInfo.model_validate(data["protocol"])
|
|
380
|
-
logger.debug(f"Loading result value with protocol {protocol_info}")
|
|
381
|
-
result_value = DEFAULT_PROTOCOL.load(protocol_info, session=None, path_transform=path_transform)
|
|
382
|
-
except sp_exceptions.SnowparkSQLException:
|
|
383
|
-
raise
|
|
384
|
-
else:
|
|
385
|
-
result_value = None
|
|
386
|
-
|
|
387
|
-
return data["value"] or result_value
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
def _unpack_obj_fallback(obj: Any, session: Optional[snowflake.snowpark.Session]) -> Any:
|
|
391
|
-
SESSION_KEY_PREFIX = "session@"
|
|
392
|
-
|
|
393
|
-
if not isinstance(obj, dict):
|
|
394
|
-
return obj
|
|
395
|
-
elif len(obj) == 1 and SESSION_KEY_PREFIX in obj:
|
|
396
|
-
return session
|
|
397
|
-
else:
|
|
398
|
-
type = obj.get("type@", None)
|
|
399
|
-
# If type is None, we are unpacking a dict
|
|
400
|
-
if type is None:
|
|
401
|
-
result_dict = {}
|
|
402
|
-
for k, v in obj.items():
|
|
403
|
-
if k.startswith(SESSION_KEY_PREFIX):
|
|
404
|
-
result_key = k[len(SESSION_KEY_PREFIX) :]
|
|
405
|
-
result_dict[result_key] = session
|
|
406
|
-
else:
|
|
407
|
-
result_dict[k] = _unpack_obj_fallback(v, session)
|
|
408
|
-
return result_dict
|
|
409
|
-
# If type is not None, we are unpacking a tuple or list
|
|
410
|
-
else:
|
|
411
|
-
indexes = []
|
|
412
|
-
for k, _ in obj.items():
|
|
413
|
-
if "#" in k:
|
|
414
|
-
indexes.append(int(k.split("#")[-1]))
|
|
415
|
-
|
|
416
|
-
if not indexes:
|
|
417
|
-
return tuple() if type is tuple else []
|
|
418
|
-
result_list: list[Any] = [None] * (max(indexes) + 1)
|
|
419
|
-
|
|
420
|
-
for k, v in obj.items():
|
|
421
|
-
if k == "type@":
|
|
422
|
-
continue
|
|
423
|
-
idx = int(k.split("#")[-1])
|
|
424
|
-
if k.startswith(SESSION_KEY_PREFIX):
|
|
425
|
-
result_list[idx] = session
|
|
426
|
-
else:
|
|
427
|
-
result_list[idx] = _unpack_obj_fallback(v, session)
|
|
428
|
-
return tuple(result_list) if type is tuple else result_list
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
def _load_function_args(
|
|
432
|
-
session: snowflake.snowpark.Session,
|
|
433
|
-
function_args: Optional[str] = None,
|
|
434
|
-
) -> tuple[tuple[Any, ...], dict[str, Any]]:
|
|
435
|
-
"""Load and deserialize function arguments.
|
|
436
|
-
|
|
437
|
-
Args:
|
|
438
|
-
function_args: Inline serialized function arguments or path to serialized file.
|
|
439
|
-
session: Optional Snowpark session for stage access if needed.
|
|
440
|
-
|
|
441
|
-
Returns:
|
|
442
|
-
A tuple of (positional_args, keyword_args)
|
|
443
|
-
|
|
444
|
-
"""
|
|
445
|
-
if not function_args:
|
|
446
|
-
return (), {}
|
|
447
|
-
|
|
448
|
-
def path_transform(stage_path: str) -> str:
|
|
449
|
-
if not PAYLOAD_PATH:
|
|
450
|
-
return stage_path
|
|
451
|
-
|
|
452
|
-
payload_path = PurePosixPath(PAYLOAD_PATH)
|
|
453
|
-
payload_dir_name = payload_path.name # e.g., "app"
|
|
454
|
-
|
|
455
|
-
# Parse stage path and find the payload directory
|
|
456
|
-
stage_parts = PurePosixPath(stage_path.lstrip("@")).parts
|
|
457
|
-
|
|
458
|
-
try:
|
|
459
|
-
# Find index of payload directory (e.g., "app") in stage path
|
|
460
|
-
idx = stage_parts.index(payload_dir_name)
|
|
461
|
-
# Get relative path after the payload directory
|
|
462
|
-
relative_parts = stage_parts[idx + 1 :]
|
|
463
|
-
return str(payload_path.joinpath(*relative_parts))
|
|
464
|
-
except (ValueError, IndexError):
|
|
465
|
-
# Fallback to just the filename
|
|
466
|
-
return str(payload_path / PurePosixPath(stage_path).name)
|
|
467
|
-
|
|
468
|
-
try:
|
|
469
|
-
from snowflake.ml.jobs._interop import utils as interop_utils
|
|
470
|
-
|
|
471
|
-
args, kwargs = interop_utils.load(
|
|
472
|
-
function_args,
|
|
473
|
-
session=session,
|
|
474
|
-
path_transform=path_transform,
|
|
475
|
-
)
|
|
476
|
-
return args, kwargs
|
|
477
|
-
except (AttributeError, ImportError):
|
|
478
|
-
# Backwards compatibility: load may not exist in older SnowML versions
|
|
479
|
-
packed = _load_dto_fallback(function_args, path_transform)
|
|
480
|
-
args, kwargs = _unpack_obj_fallback(packed, session)
|
|
481
|
-
return args, kwargs
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
def run_script(
|
|
485
|
-
script_path: str,
|
|
486
|
-
payload_args: Optional[tuple[Any, ...]] = None,
|
|
487
|
-
payload_kwargs: Optional[dict[str, Any]] = None,
|
|
488
|
-
main_func: Optional[str] = None,
|
|
489
|
-
) -> Any:
|
|
350
|
+
def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = None) -> Any:
|
|
490
351
|
"""
|
|
491
352
|
Execute a Python script and return its result.
|
|
492
353
|
|
|
493
354
|
Args:
|
|
494
|
-
script_path: Path to the Python script
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
main_func: The name of the function to call in the script (if any).
|
|
355
|
+
script_path: Path to the Python script
|
|
356
|
+
script_args: Arguments to pass to the script
|
|
357
|
+
main_func: The name of the function to call in the script (if any)
|
|
498
358
|
|
|
499
359
|
Returns:
|
|
500
360
|
Result from script execution, either from the main function or the script's __return__ value
|
|
501
361
|
|
|
502
362
|
Raises:
|
|
503
363
|
RuntimeError: If the specified main_func is not found or not callable
|
|
504
|
-
ValueError: If payload_kwargs is provided for runpy execution.
|
|
505
364
|
"""
|
|
506
365
|
# Save original sys.argv and modify it for the script (applies to runpy execution only)
|
|
507
366
|
original_argv = sys.argv
|
|
367
|
+
sys.argv = [script_path, *script_args]
|
|
508
368
|
|
|
509
369
|
try:
|
|
510
370
|
if main_func:
|
|
@@ -521,13 +381,10 @@ def run_script(
|
|
|
521
381
|
raise RuntimeError(f"Function '{main_func}' not a valid entrypoint for {script_path}")
|
|
522
382
|
|
|
523
383
|
# Call main function
|
|
524
|
-
result = func(*
|
|
384
|
+
result = func(*script_args)
|
|
525
385
|
return result
|
|
526
386
|
else:
|
|
527
|
-
|
|
528
|
-
raise ValueError("payload_kwargs is not supported for runpy execution; use payload_args instead")
|
|
529
|
-
# Save original sys.argv and modify it for the script.
|
|
530
|
-
sys.argv = [script_path, *(payload_args or ())]
|
|
387
|
+
# Use runpy for other scripts
|
|
531
388
|
globals_dict = runpy.run_path(script_path, run_name="__main__")
|
|
532
389
|
result = globals_dict.get("__return__", None)
|
|
533
390
|
return result
|
|
@@ -536,28 +393,24 @@ def run_script(
|
|
|
536
393
|
sys.argv = original_argv
|
|
537
394
|
|
|
538
395
|
|
|
539
|
-
def main(
|
|
540
|
-
entrypoint: str,
|
|
541
|
-
session: snowflake.snowpark.Session,
|
|
542
|
-
payload_args: Optional[tuple[Any, ...]] = None,
|
|
543
|
-
payload_kwargs: Optional[dict[str, Any]] = None,
|
|
544
|
-
script_main_func: Optional[str] = None,
|
|
545
|
-
) -> Any:
|
|
396
|
+
def main(entrypoint: str, *script_args: Any, script_main_func: Optional[str] = None) -> Any:
|
|
546
397
|
"""Executes a Python script and serializes the result to JOB_RESULT_PATH.
|
|
547
398
|
|
|
548
399
|
Args:
|
|
549
400
|
entrypoint (str): The job payload entrypoint to execute.
|
|
550
|
-
|
|
551
|
-
payload_kwargs (dict[str, Any], optional): Keyword args to pass to the script or entrypoint.
|
|
401
|
+
script_args (Any): Arguments to pass to the script.
|
|
552
402
|
script_main_func (str, optional): The name of the function to call in the script (if any).
|
|
553
|
-
session (snowflake.snowpark.Session, optional): Snowpark session for stage access if needed.
|
|
554
403
|
|
|
555
404
|
Returns:
|
|
556
405
|
Any: The result of the script execution.
|
|
557
406
|
|
|
558
407
|
Raises:
|
|
559
|
-
|
|
408
|
+
Exception: Re-raises any exception caught during script execution.
|
|
560
409
|
"""
|
|
410
|
+
try:
|
|
411
|
+
from snowflake.ml._internal.utils.connection_params import SnowflakeLoginOptions
|
|
412
|
+
except ImportError:
|
|
413
|
+
from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
|
|
561
414
|
|
|
562
415
|
# Initialize Ray if available
|
|
563
416
|
try:
|
|
@@ -567,6 +420,12 @@ def main(
|
|
|
567
420
|
except ModuleNotFoundError:
|
|
568
421
|
logger.debug("Ray is not installed, skipping Ray initialization")
|
|
569
422
|
|
|
423
|
+
# Create a Snowpark session before starting
|
|
424
|
+
# Session can be retrieved from using snowflake.snowpark.context.get_active_session()
|
|
425
|
+
config = SnowflakeLoginOptions()
|
|
426
|
+
config["client_session_keep_alive"] = "True"
|
|
427
|
+
session = snowflake.snowpark.Session.builder.configs(config).create() # noqa: F841
|
|
428
|
+
|
|
570
429
|
execution_result_is_error = False
|
|
571
430
|
execution_result_value = None
|
|
572
431
|
try:
|
|
@@ -587,21 +446,10 @@ def main(
|
|
|
587
446
|
|
|
588
447
|
if is_python:
|
|
589
448
|
# Run as Python script
|
|
590
|
-
execution_result_value = run_script(
|
|
591
|
-
resolved_entrypoint,
|
|
592
|
-
payload_args=payload_args,
|
|
593
|
-
payload_kwargs=payload_kwargs,
|
|
594
|
-
main_func=script_main_func,
|
|
595
|
-
)
|
|
449
|
+
execution_result_value = run_script(resolved_entrypoint, *script_args, main_func=script_main_func)
|
|
596
450
|
else:
|
|
597
451
|
# Run as subprocess
|
|
598
|
-
|
|
599
|
-
raise ValueError("payload_kwargs is not supported for subprocesses")
|
|
600
|
-
|
|
601
|
-
run_command(
|
|
602
|
-
resolved_entrypoint,
|
|
603
|
-
*(payload_args or ()),
|
|
604
|
-
)
|
|
452
|
+
run_command(resolved_entrypoint, *script_args)
|
|
605
453
|
|
|
606
454
|
# Log end marker for user script execution
|
|
607
455
|
print(LOG_END_MSG) # noqa: T201
|
|
@@ -639,36 +487,11 @@ if __name__ == "__main__":
|
|
|
639
487
|
parser.add_argument(
|
|
640
488
|
"--script_main_func", required=False, help="The name of the main function to call in the script"
|
|
641
489
|
)
|
|
642
|
-
parser.add_argument(
|
|
643
|
-
"--function_args",
|
|
644
|
-
required=False,
|
|
645
|
-
help="Serialized function arguments or path to serialized function arguments file",
|
|
646
|
-
)
|
|
647
490
|
args, unknown_args = parser.parse_known_args()
|
|
648
491
|
|
|
649
|
-
try:
|
|
650
|
-
from snowflake.ml._internal.utils.connection_params import SnowflakeLoginOptions
|
|
651
|
-
except ImportError:
|
|
652
|
-
from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
|
|
653
|
-
|
|
654
|
-
# Create a Snowpark session before starting
|
|
655
|
-
# Session can be retrieved from using snowflake.snowpark.context.get_active_session()
|
|
656
|
-
# _load_function_args will use the session to load the function arguments
|
|
657
|
-
config = SnowflakeLoginOptions()
|
|
658
|
-
config["client_session_keep_alive"] = "True"
|
|
659
|
-
session = snowflake.snowpark.Session.builder.configs(config).create() # noqa: F841
|
|
660
|
-
|
|
661
|
-
if args.function_args:
|
|
662
|
-
if args.script_args or unknown_args:
|
|
663
|
-
raise ValueError("Only one of function_args and script_args can be provided")
|
|
664
|
-
payload_args, payload_kwargs = _load_function_args(session, args.function_args)
|
|
665
|
-
else:
|
|
666
|
-
payload_args, payload_kwargs = (args.script_args + unknown_args), {}
|
|
667
|
-
|
|
668
492
|
main(
|
|
669
493
|
args.entrypoint,
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
payload_kwargs=payload_kwargs,
|
|
494
|
+
*args.script_args,
|
|
495
|
+
*unknown_args,
|
|
673
496
|
script_main_func=args.script_main_func,
|
|
674
497
|
)
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from snowflake import snowpark
|
|
2
|
+
from snowflake.ml._internal.utils import snowflake_env
|
|
3
|
+
from snowflake.ml.jobs._utils import constants, query_helper, types
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
|
|
7
|
+
"""Extract resource information for the specified compute pool"""
|
|
8
|
+
# Get the instance family
|
|
9
|
+
rows = query_helper.run_query(
|
|
10
|
+
session,
|
|
11
|
+
"show compute pools like ?",
|
|
12
|
+
params=[compute_pool],
|
|
13
|
+
)
|
|
14
|
+
if not rows:
|
|
15
|
+
raise ValueError(f"Compute pool '{compute_pool}' not found")
|
|
16
|
+
instance_family: str = rows[0]["instance_family"]
|
|
17
|
+
cloud = snowflake_env.get_current_cloud(session, default=snowflake_env.SnowflakeCloudType.AWS)
|
|
18
|
+
|
|
19
|
+
return (
|
|
20
|
+
constants.COMMON_INSTANCE_FAMILIES.get(instance_family)
|
|
21
|
+
or constants.CLOUD_INSTANCE_FAMILIES[cloud][instance_family]
|
|
22
|
+
)
|
snowflake/ml/jobs/decorators.py
CHANGED
|
@@ -1,12 +1,13 @@
|
|
|
1
1
|
import copy
|
|
2
|
+
import functools
|
|
2
3
|
from typing import Any, Callable, Optional, TypeVar
|
|
3
4
|
|
|
4
5
|
from typing_extensions import ParamSpec
|
|
5
6
|
|
|
6
7
|
from snowflake import snowpark
|
|
7
8
|
from snowflake.ml._internal import telemetry
|
|
8
|
-
from snowflake.ml.jobs import
|
|
9
|
-
from snowflake.ml.jobs._utils import
|
|
9
|
+
from snowflake.ml.jobs import job as jb, manager as jm
|
|
10
|
+
from snowflake.ml.jobs._utils import payload_utils
|
|
10
11
|
|
|
11
12
|
_PROJECT = "MLJob"
|
|
12
13
|
|
|
@@ -24,7 +25,7 @@ def remote(
|
|
|
24
25
|
external_access_integrations: Optional[list[str]] = None,
|
|
25
26
|
session: Optional[snowpark.Session] = None,
|
|
26
27
|
**kwargs: Any,
|
|
27
|
-
) -> Callable[[Callable[_Args, _ReturnValue]],
|
|
28
|
+
) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob[_ReturnValue]]]:
|
|
28
29
|
"""
|
|
29
30
|
Submit a job to the compute pool.
|
|
30
31
|
|
|
@@ -50,25 +51,29 @@ def remote(
|
|
|
50
51
|
Decorator that dispatches invocations of the decorated function as remote jobs.
|
|
51
52
|
"""
|
|
52
53
|
|
|
53
|
-
def decorator(func: Callable[_Args, _ReturnValue]) ->
|
|
54
|
+
def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, jb.MLJob[_ReturnValue]]:
|
|
54
55
|
# Copy the function to avoid modifying the original
|
|
55
56
|
# We need to modify the line number of the function to exclude the
|
|
56
57
|
# decorator from the copied source code
|
|
57
58
|
wrapped_func = copy.copy(func)
|
|
58
59
|
wrapped_func.__code__ = wrapped_func.__code__.replace(co_firstlineno=func.__code__.co_firstlineno + 1)
|
|
59
60
|
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
61
|
+
@functools.wraps(func)
|
|
62
|
+
def wrapper(*_args: _Args.args, **_kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:
|
|
63
|
+
payload = payload_utils.create_function_payload(func, *_args, **_kwargs)
|
|
64
|
+
job = jm._submit_job(
|
|
65
|
+
source=payload,
|
|
66
|
+
stage_name=stage_name,
|
|
67
|
+
compute_pool=compute_pool,
|
|
68
|
+
target_instances=target_instances,
|
|
69
|
+
pip_requirements=pip_requirements,
|
|
70
|
+
external_access_integrations=external_access_integrations,
|
|
71
|
+
session=payload.session or session,
|
|
72
|
+
**kwargs,
|
|
73
|
+
)
|
|
74
|
+
assert isinstance(job, jb.MLJob), f"Unexpected job type: {type(job)}"
|
|
75
|
+
return job
|
|
76
|
+
|
|
77
|
+
return wrapper
|
|
73
78
|
|
|
74
79
|
return decorator
|