snowflake-ml-python 1.16.0__py3-none-any.whl → 1.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- snowflake/ml/_internal/human_readable_id/adjectives.txt +5 -5
- snowflake/ml/_internal/human_readable_id/animals.txt +3 -3
- snowflake/ml/_internal/telemetry.py +3 -2
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +17 -12
- snowflake/ml/experiment/callback/keras.py +3 -0
- snowflake/ml/experiment/callback/lightgbm.py +3 -0
- snowflake/ml/experiment/callback/xgboost.py +3 -0
- snowflake/ml/experiment/experiment_tracking.py +19 -7
- snowflake/ml/feature_store/feature_store.py +236 -61
- snowflake/ml/jobs/__init__.py +4 -0
- snowflake/ml/jobs/_interop/__init__.py +0 -0
- snowflake/ml/jobs/_interop/data_utils.py +124 -0
- snowflake/ml/jobs/_interop/dto_schema.py +95 -0
- snowflake/ml/jobs/{_utils/interop_utils.py → _interop/exception_utils.py} +49 -178
- snowflake/ml/jobs/_interop/legacy.py +225 -0
- snowflake/ml/jobs/_interop/protocols.py +471 -0
- snowflake/ml/jobs/_interop/results.py +51 -0
- snowflake/ml/jobs/_interop/utils.py +144 -0
- snowflake/ml/jobs/_utils/constants.py +16 -2
- snowflake/ml/jobs/_utils/feature_flags.py +37 -5
- snowflake/ml/jobs/_utils/payload_utils.py +8 -2
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +139 -102
- snowflake/ml/jobs/_utils/spec_utils.py +2 -1
- snowflake/ml/jobs/_utils/stage_utils.py +4 -0
- snowflake/ml/jobs/_utils/types.py +15 -0
- snowflake/ml/jobs/job.py +186 -40
- snowflake/ml/jobs/manager.py +48 -39
- snowflake/ml/model/__init__.py +19 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +63 -0
- snowflake/ml/model/_client/model/inference_engine_utils.py +1 -5
- snowflake/ml/model/_client/model/model_version_impl.py +168 -18
- snowflake/ml/model/_client/ops/model_ops.py +4 -0
- snowflake/ml/model/_client/ops/service_ops.py +3 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +3 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +3 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +11 -3
- snowflake/ml/model/_model_composer/model_method/utils.py +28 -0
- snowflake/ml/model/_packager/model_env/model_env.py +22 -5
- snowflake/ml/model/_packager/model_handlers/_utils.py +70 -0
- snowflake/ml/model/_packager/model_handlers/prophet.py +566 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +8 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +7 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
- snowflake/ml/model/type_hints.py +16 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +5 -5
- snowflake/ml/modeling/metrics/metrics_utils.py +9 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/METADATA +50 -4
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/RECORD +54 -45
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.16.0.dist-info → snowflake_ml_python-1.18.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/_utils/constants.py

```diff
@@ -12,6 +12,9 @@ PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
 RESULT_PATH_ENV_VAR = "MLRS_RESULT_PATH"
 MIN_INSTANCES_ENV_VAR = "MLRS_MIN_INSTANCES"
 TARGET_INSTANCES_ENV_VAR = "SNOWFLAKE_JOBS_COUNT"
+INSTANCES_MIN_WAIT_ENV_VAR = "MLRS_INSTANCES_MIN_WAIT"
+INSTANCES_TIMEOUT_ENV_VAR = "MLRS_INSTANCES_TIMEOUT"
+INSTANCES_CHECK_INTERVAL_ENV_VAR = "MLRS_INSTANCES_CHECK_INTERVAL"
 RUNTIME_IMAGE_TAG_ENV_VAR = "MLRS_CONTAINER_IMAGE_TAG"
 
 # Stage mount paths
@@ -19,7 +22,7 @@ STAGE_VOLUME_MOUNT_PATH = "/mnt/job_stage"
 APP_STAGE_SUBPATH = "app"
 SYSTEM_STAGE_SUBPATH = "system"
 OUTPUT_STAGE_SUBPATH = "output"
-RESULT_PATH_DEFAULT_VALUE = f"{OUTPUT_STAGE_SUBPATH}/mljob_result
+RESULT_PATH_DEFAULT_VALUE = f"{OUTPUT_STAGE_SUBPATH}/mljob_result"
 
 # Default container image information
 DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
@@ -53,8 +56,9 @@ ENABLE_HEALTH_CHECKS_ENV_VAR = "ENABLE_HEALTH_CHECKS"
 ENABLE_HEALTH_CHECKS = "false"
 
 # Job status polling constants
-JOB_POLL_INITIAL_DELAY_SECONDS =
+JOB_POLL_INITIAL_DELAY_SECONDS = 5
 JOB_POLL_MAX_DELAY_SECONDS = 30
+JOB_SPCS_TIMEOUT_SECONDS = 30
 
 # Log start and end messages
 LOG_START_MSG = "--------------------------------\nML job started\n--------------------------------"
@@ -70,6 +74,7 @@ COMMON_INSTANCE_FAMILIES = {
     "CPU_X64_XS": ComputeResources(cpu=1, memory=6),
     "CPU_X64_S": ComputeResources(cpu=3, memory=13),
     "CPU_X64_M": ComputeResources(cpu=6, memory=28),
+    "CPU_X64_SL": ComputeResources(cpu=14, memory=54),
     "CPU_X64_L": ComputeResources(cpu=28, memory=116),
     "HIGHMEM_X64_S": ComputeResources(cpu=6, memory=58),
 }
@@ -82,6 +87,7 @@ AWS_INSTANCE_FAMILIES = {
 }
 AZURE_INSTANCE_FAMILIES = {
     "HIGHMEM_X64_M": ComputeResources(cpu=28, memory=244),
+    "HIGHMEM_X64_SL": ComputeResources(cpu=92, memory=654),
     "HIGHMEM_X64_L": ComputeResources(cpu=92, memory=654),
     "GPU_NV_XS": ComputeResources(cpu=3, memory=26, gpu=1, gpu_type="T4"),
     "GPU_NV_SM": ComputeResources(cpu=32, memory=424, gpu=1, gpu_type="A10"),
@@ -89,7 +95,15 @@ AZURE_INSTANCE_FAMILIES = {
     "GPU_NV_3M": ComputeResources(cpu=44, memory=424, gpu=2, gpu_type="A100"),
     "GPU_NV_SL": ComputeResources(cpu=92, memory=858, gpu=4, gpu_type="A100"),
 }
+GCP_INSTANCE_FAMILIES = {
+    "HIGHMEM_X64_M": ComputeResources(cpu=28, memory=244),
+    "HIGHMEM_X64_SL": ComputeResources(cpu=92, memory=654),
+    "GPU_GCP_NV_L4_1_24G": ComputeResources(cpu=6, memory=28, gpu=1, gpu_type="L4"),
+    "GPU_GCP_NV_L4_4_24G": ComputeResources(cpu=44, memory=178, gpu=4, gpu_type="L4"),
+    "GPU_GCP_NV_A100_8_40G": ComputeResources(cpu=92, memory=654, gpu=8, gpu_type="A100"),
+}
 CLOUD_INSTANCE_FAMILIES = {
     SnowflakeCloudType.AWS: AWS_INSTANCE_FAMILIES,
     SnowflakeCloudType.AZURE: AZURE_INSTANCE_FAMILIES,
+    SnowflakeCloudType.GCP: GCP_INSTANCE_FAMILIES,
 }
```
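
The new `GCP_INSTANCE_FAMILIES` table plugs into the existing per-cloud lookup. A minimal sketch of how these tables compose, using local stand-ins for the library's `ComputeResources` dataclass and `SnowflakeCloudType` enum (definitions below are illustrative, not the package's own):

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


@dataclass(frozen=True)
class ComputeResources:
    cpu: float
    memory: float  # GiB
    gpu: int = 0
    gpu_type: Optional[str] = None


class SnowflakeCloudType(Enum):
    AWS = "aws"
    AZURE = "azure"
    GCP = "gcp"


GCP_INSTANCE_FAMILIES = {
    "GPU_GCP_NV_L4_1_24G": ComputeResources(cpu=6, memory=28, gpu=1, gpu_type="L4"),
}
CLOUD_INSTANCE_FAMILIES = {SnowflakeCloudType.GCP: GCP_INSTANCE_FAMILIES}

# Resolve the resources for an instance family on a given cloud, which the
# jobs API can now do for GCP as well as AWS and Azure.
resources = CLOUD_INSTANCE_FAMILIES[SnowflakeCloudType.GCP]["GPU_GCP_NV_L4_1_24G"]
print(resources.gpu, resources.gpu_type)  # -> 1 L4
```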
snowflake/ml/jobs/_utils/feature_flags.py

```diff
@@ -1,16 +1,48 @@
 import os
 from enum import Enum
+from typing import Optional
+
+
+def parse_bool_env_value(value: Optional[str], default: bool = False) -> bool:
+    """Parse a boolean value from an environment variable string.
+
+    Args:
+        value: The environment variable value to parse (may be None).
+        default: The default value to return if the value is None or unrecognized.
+
+    Returns:
+        True if the value is a truthy string (true, 1, yes, on - case insensitive),
+        False if the value is a falsy string (false, 0, no, off - case insensitive),
+        or the default value if the value is None or unrecognized.
+    """
+    if value is None:
+        return default
+
+    normalized_value = value.strip().lower()
+    if normalized_value in ("true", "1", "yes", "on"):
+        return True
+    elif normalized_value in ("false", "0", "no", "off"):
+        return False
+    else:
+        # For unrecognized values, return the default
+        return default
 
 
 class FeatureFlags(Enum):
     USE_SUBMIT_JOB_V2 = "MLRS_USE_SUBMIT_JOB_V2"
-
+    ENABLE_RUNTIME_VERSIONS = "MLRS_ENABLE_RUNTIME_VERSIONS"
+
+    def is_enabled(self, default: bool = False) -> bool:
+        """Check if the feature flag is enabled.
 
-
-
+        Args:
+            default: The default value to return if the environment variable is not set.
 
-
-
+        Returns:
+            True if the environment variable is set to a truthy value,
+            False if set to a falsy value, or the default value if not set.
+        """
+        return parse_bool_env_value(os.getenv(self.value), default)
 
     def __str__(self) -> str:
         return self.value
```
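
Since the hunk above shows the full helper, usage is straightforward. A short sketch, assuming the module exports `FeatureFlags` and `parse_bool_env_value` exactly as added above:

```python
import os

from snowflake.ml.jobs._utils.feature_flags import FeatureFlags, parse_bool_env_value

# Each flag reads its state from the environment variable named by the enum value.
os.environ["MLRS_ENABLE_RUNTIME_VERSIONS"] = "yes"  # any of: true / 1 / yes / on
assert FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled() is True

# Unrecognized values fall back to the supplied default instead of raising.
assert parse_bool_env_value("maybe", default=True) is True
assert parse_bool_env_value(None) is False
```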
snowflake/ml/jobs/_utils/payload_utils.py

```diff
@@ -268,7 +268,7 @@ def upload_payloads(session: snowpark.Session, stage_path: PurePath, *payload_sp
     # can't handle directories. Reduce the number of PUT operations by using
     # wildcard patterns to batch upload files with the same extension.
     upload_path_patterns = set()
-    for p in source_path.
+    for p in source_path.rglob("*"):
         if p.is_dir():
             continue
         if p.name.startswith("."):
@@ -488,10 +488,13 @@ class JobPayload:
             " comment = 'Created by snowflake.ml.jobs Python API'",
             params=[stage_name],
         )
-
+        payload_name = None
         # Upload payload to stage - organize into app/ subdirectory
         app_stage_path = stage_path.joinpath(constants.APP_STAGE_SUBPATH)
         if not isinstance(source, types.PayloadPath):
+            if isinstance(source, function_payload_utils.FunctionPayload):
+                payload_name = source.function.__name__
+
             source_code = generate_python_code(source, source_code_display=True)
             _ = session.file.put_stream(
                 io.BytesIO(source_code.encode()),
@@ -502,12 +505,14 @@ class JobPayload:
             source = Path(entrypoint.file_path.parent)
 
         elif isinstance(source, stage_utils.StagePath):
+            payload_name = entrypoint.file_path.stem
             # copy payload to stage
             if source == entrypoint.file_path:
                 source = source.parent
             upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))
 
         elif isinstance(source, Path):
+            payload_name = entrypoint.file_path.stem
             upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))
             if source.is_file():
                 source = source.parent
@@ -562,6 +567,7 @@ class JobPayload:
                 *python_entrypoint,
             ],
             env_vars=env_vars,
+            payload_name=payload_name,
         )
 
 
```
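
The additions thread a `payload_name` through job submission: function payloads are named after the wrapped callable, while file and stage payloads are named after the entrypoint's stem. A simplified, self-contained sketch of that naming rule (the `FunctionPayload` stand-in and `derive_payload_name` helper are illustrative, not the library's API):

```python
from pathlib import Path
from typing import Any, Callable, Optional


class FunctionPayload:
    """Stand-in for the library's function-payload wrapper."""

    def __init__(self, function: Callable[..., Any]) -> None:
        self.function = function


def derive_payload_name(source: object, entrypoint: Path) -> Optional[str]:
    if isinstance(source, FunctionPayload):
        return source.function.__name__  # name the job after the callable
    if isinstance(source, Path):
        return entrypoint.stem  # name the job after the entrypoint file
    return None


def train_model() -> None: ...

print(derive_payload_name(FunctionPayload(train_model), Path("main.py")))  # -> train_model
print(derive_payload_name(Path("src"), Path("src/train.py")))              # -> train
```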
snowflake/ml/jobs/_utils/scripts/mljob_launcher.py

```diff
@@ -9,19 +9,23 @@ import runpy
 import sys
 import time
 import traceback
-import warnings
-from pathlib import Path
 from typing import Any, Optional
 
-
-
-
-
-
-
-
-
-
+# Ensure payload directory is in sys.path for module imports before importing other modules
+# This is needed to support relative imports in user scripts and to allow overriding
+# modules using modules in the payload directory
+# TODO: Inject the environment variable names at job submission time
+STAGE_MOUNT_PATH = os.environ.get("MLRS_STAGE_MOUNT_PATH", "/mnt/job_stage")
+JOB_RESULT_PATH = os.environ.get("MLRS_RESULT_PATH", "output/mljob_result.pkl")
+PAYLOAD_PATH = os.environ.get("MLRS_PAYLOAD_DIR")
+if PAYLOAD_PATH and not os.path.isabs(PAYLOAD_PATH):
+    PAYLOAD_PATH = os.path.join(STAGE_MOUNT_PATH, PAYLOAD_PATH)
+if PAYLOAD_PATH and PAYLOAD_PATH not in sys.path:
+    sys.path.insert(0, PAYLOAD_PATH)
+
+# Imports below must come after sys.path modification to support module overrides
+import snowflake.ml.jobs._utils.constants  # noqa: E402
+import snowflake.snowpark  # noqa: E402
 
 # Configure logging
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
@@ -33,48 +37,74 @@ logger = logging.getLogger(__name__)
 # not have the latest version of the code
 # Log start and end messages
 LOG_START_MSG = getattr(
-    constants,
+    snowflake.ml.jobs._utils.constants,
     "LOG_START_MSG",
     "--------------------------------\nML job started\n--------------------------------",
 )
 LOG_END_MSG = getattr(
-    constants,
+    snowflake.ml.jobs._utils.constants,
     "LOG_END_MSG",
     "--------------------------------\nML job finished\n--------------------------------",
 )
+MIN_INSTANCES_ENV_VAR = getattr(
+    snowflake.ml.jobs._utils.constants,
+    "MIN_INSTANCES_ENV_VAR",
+    "MLRS_MIN_INSTANCES",
+)
+TARGET_INSTANCES_ENV_VAR = getattr(
+    snowflake.ml.jobs._utils.constants,
+    "TARGET_INSTANCES_ENV_VAR",
+    "SNOWFLAKE_JOBS_COUNT",
+)
+INSTANCES_MIN_WAIT_ENV_VAR = getattr(
+    snowflake.ml.jobs._utils.constants,
+    "INSTANCES_MIN_WAIT_ENV_VAR",
+    "MLRS_INSTANCES_MIN_WAIT",
+)
+INSTANCES_TIMEOUT_ENV_VAR = getattr(
+    snowflake.ml.jobs._utils.constants,
+    "INSTANCES_TIMEOUT_ENV_VAR",
+    "MLRS_INSTANCES_TIMEOUT",
+)
+INSTANCES_CHECK_INTERVAL_ENV_VAR = getattr(
+    snowflake.ml.jobs._utils.constants,
+    "INSTANCES_CHECK_INTERVAL_ENV_VAR",
+    "MLRS_INSTANCES_CHECK_INTERVAL",
+)
 
-# min_instances environment variable name
-MIN_INSTANCES_ENV_VAR = getattr(constants, "MIN_INSTANCES_ENV_VAR", "MLRS_MIN_INSTANCES")
-TARGET_INSTANCES_ENV_VAR = getattr(constants, "TARGET_INSTANCES_ENV_VAR", "SNOWFLAKE_JOBS_COUNT")
-
-# Fallbacks in case of SnowML version mismatch
-STAGE_MOUNT_PATH_ENV_VAR = getattr(constants, "STAGE_MOUNT_PATH_ENV_VAR", "MLRS_STAGE_MOUNT_PATH")
-RESULT_PATH_ENV_VAR = getattr(constants, "RESULT_PATH_ENV_VAR", "MLRS_RESULT_PATH")
-PAYLOAD_DIR_ENV_VAR = getattr(constants, "PAYLOAD_DIR_ENV_VAR", "MLRS_PAYLOAD_DIR")
 
 # Constants for the wait_for_instances function
-
-
-
+MIN_INSTANCES = int(os.environ.get(MIN_INSTANCES_ENV_VAR) or "1")
+TARGET_INSTANCES = int(os.environ.get(TARGET_INSTANCES_ENV_VAR) or MIN_INSTANCES)
+MIN_WAIT_TIME = float(os.getenv(INSTANCES_MIN_WAIT_ENV_VAR) or -1)  # seconds
+TIMEOUT = float(os.getenv(INSTANCES_TIMEOUT_ENV_VAR) or 720)  # seconds
+CHECK_INTERVAL = float(os.getenv(INSTANCES_CHECK_INTERVAL_ENV_VAR) or 10)  # seconds
 
-STAGE_MOUNT_PATH = os.environ.get(STAGE_MOUNT_PATH_ENV_VAR, "/mnt/job_stage")
-JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "output/mljob_result.pkl")
 
+def save_mljob_result_v2(value: Any, is_error: bool, path: str) -> None:
+    from snowflake.ml.jobs._interop import (
+        results as interop_result,
+        utils as interop_utils,
+    )
+
+    result_obj = interop_result.ExecutionResult(success=not is_error, value=value)
+    interop_utils.save_result(result_obj, path)
 
-
-
-except ImportError:
+
+def save_mljob_result_v1(value: Any, is_error: bool, path: str) -> None:
     from dataclasses import dataclass
 
+    import cloudpickle
+
+    # Directly in-line the ExecutionResult class since the legacy type
+    # instead of attempting to import the to-be-deprecated
+    # snowflake.ml.jobs._utils.interop module
+    # Eventually, this entire function will be removed in favor of v2
     @dataclass(frozen=True)
-    class ExecutionResult:
+    class ExecutionResult:
         result: Optional[Any] = None
         exception: Optional[BaseException] = None
 
-        @property
-        def success(self) -> bool:
-            return self.exception is None
-
         def to_dict(self) -> dict[str, Any]:
             """Return the serializable dictionary."""
             if isinstance(self.exception, BaseException):
@@ -91,14 +121,45 @@ except ImportError:
                 "result": self.result,
             }
 
+    # Create a custom JSON encoder that converts non-serializable types to strings
+    class SimpleJSONEncoder(json.JSONEncoder):
+        def default(self, obj: Any) -> Any:
+            try:
+                return super().default(obj)
+            except TypeError:
+                return f"Unserializable object: {repr(obj)}"
+
+    result_obj = ExecutionResult(result=None if is_error else value, exception=value if is_error else None)
+    result_dict = result_obj.to_dict()
+    try:
+        # Serialize result using cloudpickle
+        result_pickle_path = path
+        with open(result_pickle_path, "wb") as f:
+            cloudpickle.dump(result_dict, f)  # Pickle dictionary form for compatibility
+    except Exception as pkl_exc:
+        logger.warning(f"Failed to pickle result to {result_pickle_path}: {pkl_exc}")
 
-
-
-
-
-
-
-
+    try:
+        # Serialize result to JSON as fallback path in case of cross version incompatibility
+        result_json_path = os.path.splitext(path)[0] + ".json"
+        with open(result_json_path, "w") as f:
+            json.dump(result_dict, f, indent=2, cls=SimpleJSONEncoder)
+    except Exception as json_exc:
+        logger.warning(f"Failed to serialize JSON result to {result_json_path}: {json_exc}")
+
+
+def save_mljob_result(result_obj: Any, is_error: bool, path: str) -> None:
+    """Saves the result or error message to a file in the stage mount path.
+
+    Args:
+        result_obj: The result object to save, either the return value or the exception.
+        is_error: Whether the result_obj is a raised exception.
+        path: The file path to save the result to.
+    """
+    try:
+        save_mljob_result_v2(result_obj, is_error, path)
+    except ImportError:
+        save_mljob_result_v1(result_obj, is_error, path)
 
 
 def wait_for_instances(
@@ -225,20 +286,10 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
     original_argv = sys.argv
     sys.argv = [script_path, *script_args]
 
-    # Ensure payload directory is in sys.path for module imports
-    # This is needed because mljob_launcher.py is now in /mnt/job_stage/system
-    # but user scripts are in the payload directory and may import from each other
-    payload_dir = os.environ.get(PAYLOAD_DIR_ENV_VAR)
-    if payload_dir and not os.path.isabs(payload_dir):
-        payload_dir = os.path.join(STAGE_MOUNT_PATH, payload_dir)
-    if payload_dir and payload_dir not in sys.path:
-        sys.path.insert(0, payload_dir)
-
     try:
-
         if main_func:
             # Use importlib for scripts with a main function defined
-            module_name =
+            module_name = os.path.splitext(os.path.basename(script_path))[0]
             spec = importlib.util.spec_from_file_location(module_name, script_path)
             assert spec is not None
             assert spec.loader is not None
@@ -262,7 +313,7 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
         sys.argv = original_argv
 
 
-def main(script_path: str, *script_args: Any, script_main_func: Optional[str] = None) ->
+def main(script_path: str, *script_args: Any, script_main_func: Optional[str] = None) -> Any:
     """Executes a Python script and serializes the result to JOB_RESULT_PATH.
 
     Args:
@@ -271,55 +322,53 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
         script_main_func (str, optional): The name of the function to call in the script (if any).
 
     Returns:
-
+        Any: The result of the script execution.
 
     Raises:
         Exception: Re-raises any exception caught during script execution.
     """
-
-
-
-
-    output_dir = os.path.dirname(result_abs_path)
-    os.makedirs(output_dir, exist_ok=True)
+    try:
+        from snowflake.ml._internal.utils.connection_params import SnowflakeLoginOptions
+    except ImportError:
+        from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
 
+    # Initialize Ray if available
     try:
         import ray
 
         ray.init(address="auto")
     except ModuleNotFoundError:
-
+        logger.debug("Ray is not installed, skipping Ray initialization")
 
     # Create a Snowpark session before starting
     # Session can be retrieved from using snowflake.snowpark.context.get_active_session()
     config = SnowflakeLoginOptions()
     config["client_session_keep_alive"] = "True"
-    session = Session.builder.configs(config).create()  # noqa: F841
+    session = snowflake.snowpark.Session.builder.configs(config).create()  # noqa: F841
 
+    execution_result_is_error = False
+    execution_result_value = None
     try:
-        # Wait for minimum required instances
-
-
-
-
-
-
-
-
-
-        )
-
-        # Log start marker for user script execution
+        # Wait for minimum required instances before starting user script execution
+        wait_for_instances(
+            MIN_INSTANCES,
+            TARGET_INSTANCES,
+            min_wait_time=MIN_WAIT_TIME,
+            timeout=TIMEOUT,
+            check_interval=CHECK_INTERVAL,
+        )
+
+        # Log start marker before starting user script execution
         print(LOG_START_MSG)  # noqa: T201
 
-        # Run the script
-
+        # Run the user script
+        execution_result_value = run_script(script_path, *script_args, main_func=script_main_func)
 
         # Log end marker for user script execution
        print(LOG_END_MSG)  # noqa: T201
 
-
-
+        return execution_result_value
+
     except Exception as e:
         tb = e.__traceback__
         skip_files = {__file__, runpy.__file__}
@@ -328,35 +377,23 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
             tb = tb.tb_next
         cleaned_ex = copy.copy(e)  # Need to create a mutable copy of exception to set __traceback__
         cleaned_ex = cleaned_ex.with_traceback(tb)
-
+        execution_result_value = cleaned_ex
+        execution_result_is_error = True
         raise
     finally:
-
-
-
-
-
-
-
-
-
-        try:
-            # Serialize result to JSON as fallback path in case of cross version incompatibility
-            # TODO: Manually convert non-serializable types to strings
-            result_json_path = os.path.splitext(result_abs_path)[0] + ".json"
-            with open(result_json_path, "w") as f:
-                json.dump(result_dict, f, indent=2, cls=SimpleJSONEncoder)
-        except Exception as json_exc:
-            warnings.warn(
-                f"Failed to serialize JSON result to {result_json_path}: {json_exc}", RuntimeWarning, stacklevel=1
-            )
-
-        # Close the session after serializing the result
+        # Ensure the output directory exists before trying to write result files.
+        result_abs_path = (
+            JOB_RESULT_PATH if os.path.isabs(JOB_RESULT_PATH) else os.path.join(STAGE_MOUNT_PATH, JOB_RESULT_PATH)
+        )
+        output_dir = os.path.dirname(result_abs_path)
+        os.makedirs(output_dir, exist_ok=True)
+
+        # Save the result before closing the session
+        save_mljob_result(execution_result_value, execution_result_is_error, result_abs_path)
         session.close()
 
 
 if __name__ == "__main__":
-    # Parse command line arguments
     parser = argparse.ArgumentParser(description="Launch a Python script and save the result")
     parser.add_argument("script_path", help="Path to the Python script to execute")
     parser.add_argument("script_args", nargs="*", help="Arguments to pass to the script")
```
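
The launcher now tries the new `_interop` result protocol first and falls back to the legacy cloudpickle-plus-JSON writer on `ImportError`. The JSON fallback relies on the `SimpleJSONEncoder` shown above, which stringifies anything `json` cannot encode instead of failing the whole result write. A self-contained demonstration of that pattern:

```python
import json
from typing import Any


class SimpleJSONEncoder(json.JSONEncoder):
    """Convert non-serializable values to descriptive strings instead of raising."""

    def default(self, obj: Any) -> Any:
        try:
            return super().default(obj)
        except TypeError:
            return f"Unserializable object: {repr(obj)}"


result_dict = {"success": True, "result": {1, 2, 3}}  # a set is not JSON-serializable
print(json.dumps(result_dict, indent=2, cls=SimpleJSONEncoder))
# The set is emitted as "Unserializable object: {1, 2, 3}" rather than raising TypeError.
```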
snowflake/ml/jobs/_utils/spec_utils.py

```diff
@@ -104,7 +104,7 @@ def _get_image_spec(
             image_tag = runtime_environment
         else:
             container_image = runtime_environment
-    elif feature_flags.FeatureFlags.
+    elif feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled():
         container_image = _get_runtime_image(session, hardware)  # type: ignore[arg-type]
 
     container_image = container_image or f"{image_repo}/{image_name}:{image_tag}"
@@ -266,6 +266,7 @@ def generate_service_spec(
         {"name": "ray-client-server-endpoint", "port": 10001, "protocol": "TCP"},
         {"name": "ray-gcs-endpoint", "port": 12001, "protocol": "TCP"},
         {"name": "ray-dashboard-grpc-endpoint", "port": 12002, "protocol": "TCP"},
+        {"name": "ray-dashboard-endpoint", "port": 12003, "protocol": "TCP"},
         {"name": "ray-object-manager-endpoint", "port": 12011, "protocol": "TCP"},
         {"name": "ray-node-manager-endpoint", "port": 12012, "protocol": "TCP"},
         {"name": "ray-runtime-agent-endpoint", "port": 12013, "protocol": "TCP"},
```
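
The first hunk suggests a resolution order for the container image: an explicitly supplied image wins, then the `ENABLE_RUNTIME_VERSIONS`-gated runtime-image lookup, then the `repo/name:tag` fallback. A rough sketch of that precedence (the function, its parameters, and the example values below are illustrative; the surrounding conditional is only partially visible in the diff):

```python
from typing import Optional


def resolve_container_image(
    explicit_image: Optional[str],
    runtime_image: Optional[str],  # e.g. looked up via _get_runtime_image() when the flag is on
    image_repo: str,
    image_name: str,
    image_tag: str,
) -> str:
    # First match wins: explicit image, flag-gated runtime image, default template.
    container_image = explicit_image or runtime_image
    return container_image or f"{image_repo}/{image_name}:{image_tag}"


print(resolve_container_image(None, None, "/snowflake/images/snowflake_images", "mlruntime", "1.8"))
# -> /snowflake/images/snowflake_images/mlruntime:1.8
```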
snowflake/ml/jobs/_utils/stage_utils.py

```diff
@@ -32,6 +32,10 @@ class StagePath:
         self._root = self._raw_path[0:start].rstrip("/") if relpath else self._raw_path.rstrip("/")
         self._path = Path(relpath or "")
 
+    @property
+    def stem(self) -> str:
+        return self._path.stem
+
     @property
     def parts(self) -> tuple[str, ...]:
         return self._path.parts
```
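
`StagePath.stem` delegates to `pathlib`, so it follows `Path.stem` semantics (the file name minus its final suffix), which is what `payload_utils` uses to name a payload after its entrypoint:

```python
from pathlib import Path

print(Path("app/train.py").stem)        # -> train
print(Path("app/archive.tar.gz").stem)  # -> archive.tar (only the last suffix is removed)
```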
snowflake/ml/jobs/_utils/types.py

```diff
@@ -11,6 +11,7 @@ JOB_STATUS = Literal[
     "CANCELLING",
     "CANCELLED",
     "INTERNAL_ERROR",
+    "DELETED",
 ]
 
 
@@ -22,6 +23,10 @@ class PayloadPath(Protocol):
     def name(self) -> str:
         ...
 
+    @property
+    def stem(self) -> str:
+        ...
+
     @property
     def suffix(self) -> str:
         ...
@@ -91,6 +96,7 @@ class UploadedPayload:
     stage_path: PurePath
     entrypoint: list[Union[str, PurePath]]
     env_vars: dict[str, str] = field(default_factory=dict)
+    payload_name: Optional[str] = None
 
 
 @dataclass(frozen=True)
@@ -106,3 +112,12 @@ class ImageSpec:
     resource_requests: ComputeResources
     resource_limits: ComputeResources
     container_image: str
+
+
+@dataclass(frozen=True)
+class ServiceInfo:
+    database_name: str
+    schema_name: str
+    status: str
+    compute_pool: str
+    target_instances: int
```