snowflake-ml-python 1.15.0__py3-none-any.whl → 1.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/human_readable_id/adjectives.txt +5 -5
- snowflake/ml/_internal/human_readable_id/animals.txt +3 -3
- snowflake/ml/_internal/platform_capabilities.py +4 -0
- snowflake/ml/_internal/utils/mixins.py +24 -9
- snowflake/ml/experiment/experiment_tracking.py +63 -19
- snowflake/ml/jobs/__init__.py +4 -0
- snowflake/ml/jobs/_interop/__init__.py +0 -0
- snowflake/ml/jobs/_interop/data_utils.py +124 -0
- snowflake/ml/jobs/_interop/dto_schema.py +95 -0
- snowflake/ml/jobs/{_utils/interop_utils.py → _interop/exception_utils.py} +49 -178
- snowflake/ml/jobs/_interop/legacy.py +225 -0
- snowflake/ml/jobs/_interop/protocols.py +471 -0
- snowflake/ml/jobs/_interop/results.py +51 -0
- snowflake/ml/jobs/_interop/utils.py +144 -0
- snowflake/ml/jobs/_utils/constants.py +4 -1
- snowflake/ml/jobs/_utils/feature_flags.py +37 -5
- snowflake/ml/jobs/_utils/payload_utils.py +1 -1
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +139 -102
- snowflake/ml/jobs/_utils/spec_utils.py +50 -11
- snowflake/ml/jobs/_utils/types.py +10 -0
- snowflake/ml/jobs/job.py +168 -36
- snowflake/ml/jobs/manager.py +54 -36
- snowflake/ml/model/__init__.py +16 -2
- snowflake/ml/model/_client/model/batch_inference_specs.py +18 -2
- snowflake/ml/model/_client/model/model_version_impl.py +44 -7
- snowflake/ml/model/_client/ops/model_ops.py +4 -0
- snowflake/ml/model/_client/ops/service_ops.py +50 -5
- snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
- snowflake/ml/model/_client/sql/model_version.py +3 -1
- snowflake/ml/model/_client/sql/stage.py +8 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_method/model_method.py +32 -4
- snowflake/ml/model/_model_composer/model_method/utils.py +28 -0
- snowflake/ml/model/_packager/model_env/model_env.py +48 -21
- snowflake/ml/model/_packager/model_meta/model_meta.py +8 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
- snowflake/ml/model/type_hints.py +13 -0
- snowflake/ml/model/volatility.py +34 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +5 -5
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
- snowflake/ml/modeling/cluster/birch.py +1 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
- snowflake/ml/modeling/cluster/dbscan.py +1 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
- snowflake/ml/modeling/cluster/k_means.py +1 -1
- snowflake/ml/modeling/cluster/mean_shift.py +1 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
- snowflake/ml/modeling/cluster/optics.py +1 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
- snowflake/ml/modeling/compose/column_transformer.py +1 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/pca.py +1 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
- snowflake/ml/modeling/impute/knn_imputer.py +1 -1
- snowflake/ml/modeling/impute/missing_indicator.py +1 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +1 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +1 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ridge.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +1 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
- snowflake/ml/modeling/manifold/tsne.py +1 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
- snowflake/ml/modeling/svm/linear_svc.py +1 -1
- snowflake/ml/modeling/svm/linear_svr.py +1 -1
- snowflake/ml/modeling/svm/nu_svc.py +1 -1
- snowflake/ml/modeling/svm/nu_svr.py +1 -1
- snowflake/ml/modeling/svm/svc.py +1 -1
- snowflake/ml/modeling/svm/svr.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
- snowflake/ml/registry/_manager/model_manager.py +1 -0
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +27 -0
- snowflake/ml/registry/registry.py +15 -0
- snowflake/ml/utils/authentication.py +16 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/METADATA +65 -5
- {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/RECORD +201 -192
- {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/top_level.txt +0 -0
|
@@ -1,16 +1,48 @@
|
|
|
1
1
|
import os
|
|
2
2
|
from enum import Enum
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def parse_bool_env_value(value: Optional[str], default: bool = False) -> bool:
|
|
7
|
+
"""Parse a boolean value from an environment variable string.
|
|
8
|
+
|
|
9
|
+
Args:
|
|
10
|
+
value: The environment variable value to parse (may be None).
|
|
11
|
+
default: The default value to return if the value is None or unrecognized.
|
|
12
|
+
|
|
13
|
+
Returns:
|
|
14
|
+
True if the value is a truthy string (true, 1, yes, on - case insensitive),
|
|
15
|
+
False if the value is a falsy string (false, 0, no, off - case insensitive),
|
|
16
|
+
or the default value if the value is None or unrecognized.
|
|
17
|
+
"""
|
|
18
|
+
if value is None:
|
|
19
|
+
return default
|
|
20
|
+
|
|
21
|
+
normalized_value = value.strip().lower()
|
|
22
|
+
if normalized_value in ("true", "1", "yes", "on"):
|
|
23
|
+
return True
|
|
24
|
+
elif normalized_value in ("false", "0", "no", "off"):
|
|
25
|
+
return False
|
|
26
|
+
else:
|
|
27
|
+
# For unrecognized values, return the default
|
|
28
|
+
return default
|
|
3
29
|
|
|
4
30
|
|
|
5
31
|
class FeatureFlags(Enum):
|
|
6
32
|
USE_SUBMIT_JOB_V2 = "MLRS_USE_SUBMIT_JOB_V2"
|
|
7
|
-
|
|
33
|
+
ENABLE_RUNTIME_VERSIONS = "MLRS_ENABLE_RUNTIME_VERSIONS"
|
|
34
|
+
|
|
35
|
+
def is_enabled(self, default: bool = False) -> bool:
|
|
36
|
+
"""Check if the feature flag is enabled.
|
|
8
37
|
|
|
9
|
-
|
|
10
|
-
|
|
38
|
+
Args:
|
|
39
|
+
default: The default value to return if the environment variable is not set.
|
|
11
40
|
|
|
12
|
-
|
|
13
|
-
|
|
41
|
+
Returns:
|
|
42
|
+
True if the environment variable is set to a truthy value,
|
|
43
|
+
False if set to a falsy value, or the default value if not set.
|
|
44
|
+
"""
|
|
45
|
+
return parse_bool_env_value(os.getenv(self.value), default)
|
|
14
46
|
|
|
15
47
|
def __str__(self) -> str:
|
|
16
48
|
return self.value
|
|
@@ -268,7 +268,7 @@ def upload_payloads(session: snowpark.Session, stage_path: PurePath, *payload_sp
|
|
|
268
268
|
# can't handle directories. Reduce the number of PUT operations by using
|
|
269
269
|
# wildcard patterns to batch upload files with the same extension.
|
|
270
270
|
upload_path_patterns = set()
|
|
271
|
-
for p in source_path.
|
|
271
|
+
for p in source_path.rglob("*"):
|
|
272
272
|
if p.is_dir():
|
|
273
273
|
continue
|
|
274
274
|
if p.name.startswith("."):
|
|
@@ -9,19 +9,23 @@ import runpy
|
|
|
9
9
|
import sys
|
|
10
10
|
import time
|
|
11
11
|
import traceback
|
|
12
|
-
import warnings
|
|
13
|
-
from pathlib import Path
|
|
14
12
|
from typing import Any, Optional
|
|
15
13
|
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
14
|
+
# Ensure payload directory is in sys.path for module imports before importing other modules
|
|
15
|
+
# This is needed to support relative imports in user scripts and to allow overriding
|
|
16
|
+
# modules using modules in the payload directory
|
|
17
|
+
# TODO: Inject the environment variable names at job submission time
|
|
18
|
+
STAGE_MOUNT_PATH = os.environ.get("MLRS_STAGE_MOUNT_PATH", "/mnt/job_stage")
|
|
19
|
+
JOB_RESULT_PATH = os.environ.get("MLRS_RESULT_PATH", "output/mljob_result.pkl")
|
|
20
|
+
PAYLOAD_PATH = os.environ.get("MLRS_PAYLOAD_DIR")
|
|
21
|
+
if PAYLOAD_PATH and not os.path.isabs(PAYLOAD_PATH):
|
|
22
|
+
PAYLOAD_PATH = os.path.join(STAGE_MOUNT_PATH, PAYLOAD_PATH)
|
|
23
|
+
if PAYLOAD_PATH and PAYLOAD_PATH not in sys.path:
|
|
24
|
+
sys.path.insert(0, PAYLOAD_PATH)
|
|
25
|
+
|
|
26
|
+
# Imports below must come after sys.path modification to support module overrides
|
|
27
|
+
import snowflake.ml.jobs._utils.constants # noqa: E402
|
|
28
|
+
import snowflake.snowpark # noqa: E402
|
|
25
29
|
|
|
26
30
|
# Configure logging
|
|
27
31
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
|
@@ -33,48 +37,74 @@ logger = logging.getLogger(__name__)
|
|
|
33
37
|
# not have the latest version of the code
|
|
34
38
|
# Log start and end messages
|
|
35
39
|
LOG_START_MSG = getattr(
|
|
36
|
-
constants,
|
|
40
|
+
snowflake.ml.jobs._utils.constants,
|
|
37
41
|
"LOG_START_MSG",
|
|
38
42
|
"--------------------------------\nML job started\n--------------------------------",
|
|
39
43
|
)
|
|
40
44
|
LOG_END_MSG = getattr(
|
|
41
|
-
constants,
|
|
45
|
+
snowflake.ml.jobs._utils.constants,
|
|
42
46
|
"LOG_END_MSG",
|
|
43
47
|
"--------------------------------\nML job finished\n--------------------------------",
|
|
44
48
|
)
|
|
49
|
+
MIN_INSTANCES_ENV_VAR = getattr(
|
|
50
|
+
snowflake.ml.jobs._utils.constants,
|
|
51
|
+
"MIN_INSTANCES_ENV_VAR",
|
|
52
|
+
"MLRS_MIN_INSTANCES",
|
|
53
|
+
)
|
|
54
|
+
TARGET_INSTANCES_ENV_VAR = getattr(
|
|
55
|
+
snowflake.ml.jobs._utils.constants,
|
|
56
|
+
"TARGET_INSTANCES_ENV_VAR",
|
|
57
|
+
"SNOWFLAKE_JOBS_COUNT",
|
|
58
|
+
)
|
|
59
|
+
INSTANCES_MIN_WAIT_ENV_VAR = getattr(
|
|
60
|
+
snowflake.ml.jobs._utils.constants,
|
|
61
|
+
"INSTANCES_MIN_WAIT_ENV_VAR",
|
|
62
|
+
"MLRS_INSTANCES_MIN_WAIT",
|
|
63
|
+
)
|
|
64
|
+
INSTANCES_TIMEOUT_ENV_VAR = getattr(
|
|
65
|
+
snowflake.ml.jobs._utils.constants,
|
|
66
|
+
"INSTANCES_TIMEOUT_ENV_VAR",
|
|
67
|
+
"MLRS_INSTANCES_TIMEOUT",
|
|
68
|
+
)
|
|
69
|
+
INSTANCES_CHECK_INTERVAL_ENV_VAR = getattr(
|
|
70
|
+
snowflake.ml.jobs._utils.constants,
|
|
71
|
+
"INSTANCES_CHECK_INTERVAL_ENV_VAR",
|
|
72
|
+
"MLRS_INSTANCES_CHECK_INTERVAL",
|
|
73
|
+
)
|
|
45
74
|
|
|
46
|
-
# min_instances environment variable name
|
|
47
|
-
MIN_INSTANCES_ENV_VAR = getattr(constants, "MIN_INSTANCES_ENV_VAR", "MLRS_MIN_INSTANCES")
|
|
48
|
-
TARGET_INSTANCES_ENV_VAR = getattr(constants, "TARGET_INSTANCES_ENV_VAR", "SNOWFLAKE_JOBS_COUNT")
|
|
49
|
-
|
|
50
|
-
# Fallbacks in case of SnowML version mismatch
|
|
51
|
-
STAGE_MOUNT_PATH_ENV_VAR = getattr(constants, "STAGE_MOUNT_PATH_ENV_VAR", "MLRS_STAGE_MOUNT_PATH")
|
|
52
|
-
RESULT_PATH_ENV_VAR = getattr(constants, "RESULT_PATH_ENV_VAR", "MLRS_RESULT_PATH")
|
|
53
|
-
PAYLOAD_DIR_ENV_VAR = getattr(constants, "PAYLOAD_DIR_ENV_VAR", "MLRS_PAYLOAD_DIR")
|
|
54
75
|
|
|
55
76
|
# Constants for the wait_for_instances function
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
77
|
+
MIN_INSTANCES = int(os.environ.get(MIN_INSTANCES_ENV_VAR) or "1")
|
|
78
|
+
TARGET_INSTANCES = int(os.environ.get(TARGET_INSTANCES_ENV_VAR) or MIN_INSTANCES)
|
|
79
|
+
MIN_WAIT_TIME = float(os.getenv(INSTANCES_MIN_WAIT_ENV_VAR) or -1) # seconds
|
|
80
|
+
TIMEOUT = float(os.getenv(INSTANCES_TIMEOUT_ENV_VAR) or 720) # seconds
|
|
81
|
+
CHECK_INTERVAL = float(os.getenv(INSTANCES_CHECK_INTERVAL_ENV_VAR) or 10) # seconds
|
|
59
82
|
|
|
60
|
-
STAGE_MOUNT_PATH = os.environ.get(STAGE_MOUNT_PATH_ENV_VAR, "/mnt/job_stage")
|
|
61
|
-
JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "output/mljob_result.pkl")
|
|
62
83
|
|
|
84
|
+
def save_mljob_result_v2(value: Any, is_error: bool, path: str) -> None:
|
|
85
|
+
from snowflake.ml.jobs._interop import (
|
|
86
|
+
results as interop_result,
|
|
87
|
+
utils as interop_utils,
|
|
88
|
+
)
|
|
89
|
+
|
|
90
|
+
result_obj = interop_result.ExecutionResult(success=not is_error, value=value)
|
|
91
|
+
interop_utils.save_result(result_obj, path)
|
|
63
92
|
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
except ImportError:
|
|
93
|
+
|
|
94
|
+
def save_mljob_result_v1(value: Any, is_error: bool, path: str) -> None:
|
|
67
95
|
from dataclasses import dataclass
|
|
68
96
|
|
|
97
|
+
import cloudpickle
|
|
98
|
+
|
|
99
|
+
# Directly in-line the ExecutionResult class since the legacy type
|
|
100
|
+
# instead of attempting to import the to-be-deprecated
|
|
101
|
+
# snowflake.ml.jobs._utils.interop module
|
|
102
|
+
# Eventually, this entire function will be removed in favor of v2
|
|
69
103
|
@dataclass(frozen=True)
|
|
70
|
-
class ExecutionResult:
|
|
104
|
+
class ExecutionResult:
|
|
71
105
|
result: Optional[Any] = None
|
|
72
106
|
exception: Optional[BaseException] = None
|
|
73
107
|
|
|
74
|
-
@property
|
|
75
|
-
def success(self) -> bool:
|
|
76
|
-
return self.exception is None
|
|
77
|
-
|
|
78
108
|
def to_dict(self) -> dict[str, Any]:
|
|
79
109
|
"""Return the serializable dictionary."""
|
|
80
110
|
if isinstance(self.exception, BaseException):
|
|
@@ -91,14 +121,45 @@ except ImportError:
|
|
|
91
121
|
"result": self.result,
|
|
92
122
|
}
|
|
93
123
|
|
|
124
|
+
# Create a custom JSON encoder that converts non-serializable types to strings
|
|
125
|
+
class SimpleJSONEncoder(json.JSONEncoder):
|
|
126
|
+
def default(self, obj: Any) -> Any:
|
|
127
|
+
try:
|
|
128
|
+
return super().default(obj)
|
|
129
|
+
except TypeError:
|
|
130
|
+
return f"Unserializable object: {repr(obj)}"
|
|
131
|
+
|
|
132
|
+
result_obj = ExecutionResult(result=None if is_error else value, exception=value if is_error else None)
|
|
133
|
+
result_dict = result_obj.to_dict()
|
|
134
|
+
try:
|
|
135
|
+
# Serialize result using cloudpickle
|
|
136
|
+
result_pickle_path = path
|
|
137
|
+
with open(result_pickle_path, "wb") as f:
|
|
138
|
+
cloudpickle.dump(result_dict, f) # Pickle dictionary form for compatibility
|
|
139
|
+
except Exception as pkl_exc:
|
|
140
|
+
logger.warning(f"Failed to pickle result to {result_pickle_path}: {pkl_exc}")
|
|
94
141
|
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
142
|
+
try:
|
|
143
|
+
# Serialize result to JSON as fallback path in case of cross version incompatibility
|
|
144
|
+
result_json_path = os.path.splitext(path)[0] + ".json"
|
|
145
|
+
with open(result_json_path, "w") as f:
|
|
146
|
+
json.dump(result_dict, f, indent=2, cls=SimpleJSONEncoder)
|
|
147
|
+
except Exception as json_exc:
|
|
148
|
+
logger.warning(f"Failed to serialize JSON result to {result_json_path}: {json_exc}")
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def save_mljob_result(result_obj: Any, is_error: bool, path: str) -> None:
|
|
152
|
+
"""Saves the result or error message to a file in the stage mount path.
|
|
153
|
+
|
|
154
|
+
Args:
|
|
155
|
+
result_obj: The result object to save, either the return value or the exception.
|
|
156
|
+
is_error: Whether the result_obj is a raised exception.
|
|
157
|
+
path: The file path to save the result to.
|
|
158
|
+
"""
|
|
159
|
+
try:
|
|
160
|
+
save_mljob_result_v2(result_obj, is_error, path)
|
|
161
|
+
except ImportError:
|
|
162
|
+
save_mljob_result_v1(result_obj, is_error, path)
|
|
102
163
|
|
|
103
164
|
|
|
104
165
|
def wait_for_instances(
|
|
@@ -225,20 +286,10 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
|
|
|
225
286
|
original_argv = sys.argv
|
|
226
287
|
sys.argv = [script_path, *script_args]
|
|
227
288
|
|
|
228
|
-
# Ensure payload directory is in sys.path for module imports
|
|
229
|
-
# This is needed because mljob_launcher.py is now in /mnt/job_stage/system
|
|
230
|
-
# but user scripts are in the payload directory and may import from each other
|
|
231
|
-
payload_dir = os.environ.get(PAYLOAD_DIR_ENV_VAR)
|
|
232
|
-
if payload_dir and not os.path.isabs(payload_dir):
|
|
233
|
-
payload_dir = os.path.join(STAGE_MOUNT_PATH, payload_dir)
|
|
234
|
-
if payload_dir and payload_dir not in sys.path:
|
|
235
|
-
sys.path.insert(0, payload_dir)
|
|
236
|
-
|
|
237
289
|
try:
|
|
238
|
-
|
|
239
290
|
if main_func:
|
|
240
291
|
# Use importlib for scripts with a main function defined
|
|
241
|
-
module_name =
|
|
292
|
+
module_name = os.path.splitext(os.path.basename(script_path))[0]
|
|
242
293
|
spec = importlib.util.spec_from_file_location(module_name, script_path)
|
|
243
294
|
assert spec is not None
|
|
244
295
|
assert spec.loader is not None
|
|
@@ -262,7 +313,7 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
|
|
|
262
313
|
sys.argv = original_argv
|
|
263
314
|
|
|
264
315
|
|
|
265
|
-
def main(script_path: str, *script_args: Any, script_main_func: Optional[str] = None) ->
|
|
316
|
+
def main(script_path: str, *script_args: Any, script_main_func: Optional[str] = None) -> Any:
|
|
266
317
|
"""Executes a Python script and serializes the result to JOB_RESULT_PATH.
|
|
267
318
|
|
|
268
319
|
Args:
|
|
@@ -271,55 +322,53 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
|
|
|
271
322
|
script_main_func (str, optional): The name of the function to call in the script (if any).
|
|
272
323
|
|
|
273
324
|
Returns:
|
|
274
|
-
|
|
325
|
+
Any: The result of the script execution.
|
|
275
326
|
|
|
276
327
|
Raises:
|
|
277
328
|
Exception: Re-raises any exception caught during script execution.
|
|
278
329
|
"""
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
output_dir = os.path.dirname(result_abs_path)
|
|
284
|
-
os.makedirs(output_dir, exist_ok=True)
|
|
330
|
+
try:
|
|
331
|
+
from snowflake.ml._internal.utils.connection_params import SnowflakeLoginOptions
|
|
332
|
+
except ImportError:
|
|
333
|
+
from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
|
|
285
334
|
|
|
335
|
+
# Initialize Ray if available
|
|
286
336
|
try:
|
|
287
337
|
import ray
|
|
288
338
|
|
|
289
339
|
ray.init(address="auto")
|
|
290
340
|
except ModuleNotFoundError:
|
|
291
|
-
|
|
341
|
+
logger.debug("Ray is not installed, skipping Ray initialization")
|
|
292
342
|
|
|
293
343
|
# Create a Snowpark session before starting
|
|
294
344
|
# Session can be retrieved from using snowflake.snowpark.context.get_active_session()
|
|
295
345
|
config = SnowflakeLoginOptions()
|
|
296
346
|
config["client_session_keep_alive"] = "True"
|
|
297
|
-
session = Session.builder.configs(config).create() # noqa: F841
|
|
347
|
+
session = snowflake.snowpark.Session.builder.configs(config).create() # noqa: F841
|
|
298
348
|
|
|
349
|
+
execution_result_is_error = False
|
|
350
|
+
execution_result_value = None
|
|
299
351
|
try:
|
|
300
|
-
# Wait for minimum required instances
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
)
|
|
311
|
-
|
|
312
|
-
# Log start marker for user script execution
|
|
352
|
+
# Wait for minimum required instances before starting user script execution
|
|
353
|
+
wait_for_instances(
|
|
354
|
+
MIN_INSTANCES,
|
|
355
|
+
TARGET_INSTANCES,
|
|
356
|
+
min_wait_time=MIN_WAIT_TIME,
|
|
357
|
+
timeout=TIMEOUT,
|
|
358
|
+
check_interval=CHECK_INTERVAL,
|
|
359
|
+
)
|
|
360
|
+
|
|
361
|
+
# Log start marker before starting user script execution
|
|
313
362
|
print(LOG_START_MSG) # noqa: T201
|
|
314
363
|
|
|
315
|
-
# Run the script
|
|
316
|
-
|
|
364
|
+
# Run the user script
|
|
365
|
+
execution_result_value = run_script(script_path, *script_args, main_func=script_main_func)
|
|
317
366
|
|
|
318
367
|
# Log end marker for user script execution
|
|
319
368
|
print(LOG_END_MSG) # noqa: T201
|
|
320
369
|
|
|
321
|
-
|
|
322
|
-
|
|
370
|
+
return execution_result_value
|
|
371
|
+
|
|
323
372
|
except Exception as e:
|
|
324
373
|
tb = e.__traceback__
|
|
325
374
|
skip_files = {__file__, runpy.__file__}
|
|
@@ -328,35 +377,23 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
|
|
|
328
377
|
tb = tb.tb_next
|
|
329
378
|
cleaned_ex = copy.copy(e) # Need to create a mutable copy of exception to set __traceback__
|
|
330
379
|
cleaned_ex = cleaned_ex.with_traceback(tb)
|
|
331
|
-
|
|
380
|
+
execution_result_value = cleaned_ex
|
|
381
|
+
execution_result_is_error = True
|
|
332
382
|
raise
|
|
333
383
|
finally:
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
try:
|
|
344
|
-
# Serialize result to JSON as fallback path in case of cross version incompatibility
|
|
345
|
-
# TODO: Manually convert non-serializable types to strings
|
|
346
|
-
result_json_path = os.path.splitext(result_abs_path)[0] + ".json"
|
|
347
|
-
with open(result_json_path, "w") as f:
|
|
348
|
-
json.dump(result_dict, f, indent=2, cls=SimpleJSONEncoder)
|
|
349
|
-
except Exception as json_exc:
|
|
350
|
-
warnings.warn(
|
|
351
|
-
f"Failed to serialize JSON result to {result_json_path}: {json_exc}", RuntimeWarning, stacklevel=1
|
|
352
|
-
)
|
|
353
|
-
|
|
354
|
-
# Close the session after serializing the result
|
|
384
|
+
# Ensure the output directory exists before trying to write result files.
|
|
385
|
+
result_abs_path = (
|
|
386
|
+
JOB_RESULT_PATH if os.path.isabs(JOB_RESULT_PATH) else os.path.join(STAGE_MOUNT_PATH, JOB_RESULT_PATH)
|
|
387
|
+
)
|
|
388
|
+
output_dir = os.path.dirname(result_abs_path)
|
|
389
|
+
os.makedirs(output_dir, exist_ok=True)
|
|
390
|
+
|
|
391
|
+
# Save the result before closing the session
|
|
392
|
+
save_mljob_result(execution_result_value, execution_result_is_error, result_abs_path)
|
|
355
393
|
session.close()
|
|
356
394
|
|
|
357
395
|
|
|
358
396
|
if __name__ == "__main__":
|
|
359
|
-
# Parse command line arguments
|
|
360
397
|
parser = argparse.ArgumentParser(description="Launch a Python script and save the result")
|
|
361
398
|
parser.add_argument("script_path", help="Path to the Python script to execute")
|
|
362
399
|
parser.add_argument("script_args", nargs="*", help="Arguments to pass to the script")
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import os
|
|
3
|
+
import re
|
|
3
4
|
import sys
|
|
4
5
|
from math import ceil
|
|
5
6
|
from pathlib import PurePath
|
|
@@ -10,6 +11,8 @@ from snowflake.ml._internal.utils import snowflake_env
|
|
|
10
11
|
from snowflake.ml.jobs._utils import constants, feature_flags, query_helper, types
|
|
11
12
|
from snowflake.ml.jobs._utils.runtime_env_utils import RuntimeEnvironmentsDict
|
|
12
13
|
|
|
14
|
+
_OCI_TAG_REGEX = re.compile("^[a-zA-Z0-9._-]{1,128}$")
|
|
15
|
+
|
|
13
16
|
|
|
14
17
|
def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
|
|
15
18
|
"""Extract resource information for the specified compute pool"""
|
|
@@ -56,22 +59,55 @@ def _get_runtime_image(session: snowpark.Session, target_hardware: Literal["CPU"
|
|
|
56
59
|
return selected_runtime.runtime_container_image if selected_runtime else None
|
|
57
60
|
|
|
58
61
|
|
|
59
|
-
def
|
|
62
|
+
def _check_image_tag_valid(tag: Optional[str]) -> bool:
|
|
63
|
+
if tag is None:
|
|
64
|
+
return False
|
|
65
|
+
|
|
66
|
+
return _OCI_TAG_REGEX.fullmatch(tag) is not None
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _get_image_spec(
|
|
70
|
+
session: snowpark.Session, compute_pool: str, runtime_environment: Optional[str] = None
|
|
71
|
+
) -> types.ImageSpec:
|
|
72
|
+
"""
|
|
73
|
+
Resolve image specification (container image and resources) for the job.
|
|
74
|
+
|
|
75
|
+
Behavior:
|
|
76
|
+
- If `runtime_environment` is empty or the feature flag is disabled, use the
|
|
77
|
+
default image tag and image name.
|
|
78
|
+
- If `runtime_environment` is a valid image tag, use that tag with the default
|
|
79
|
+
repository/name.
|
|
80
|
+
- If `runtime_environment` is a full image URL, use it directly.
|
|
81
|
+
- If the feature flag is enabled and `runtime_environment` is not provided,
|
|
82
|
+
select an ML Runtime image matching the local Python major.minor
|
|
83
|
+
- When multiple inputs are provided, `runtime_environment` takes priority.
|
|
84
|
+
|
|
85
|
+
Args:
|
|
86
|
+
session: Snowflake session.
|
|
87
|
+
compute_pool: Compute pool used to infer CPU/GPU resources.
|
|
88
|
+
runtime_environment: Optional image tag or full image URL to override.
|
|
89
|
+
|
|
90
|
+
Returns:
|
|
91
|
+
Image spec including container image and resource requests/limits.
|
|
92
|
+
"""
|
|
60
93
|
# Retrieve compute pool node resources
|
|
61
94
|
resources = _get_node_resources(session, compute_pool=compute_pool)
|
|
95
|
+
hardware = "GPU" if resources.gpu > 0 else "CPU"
|
|
96
|
+
image_tag = _get_runtime_image_tag()
|
|
97
|
+
image_repo = constants.DEFAULT_IMAGE_REPO
|
|
98
|
+
image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
|
|
62
99
|
|
|
63
100
|
# Use MLRuntime image
|
|
64
|
-
hardware = "GPU" if resources.gpu > 0 else "CPU"
|
|
65
101
|
container_image = None
|
|
66
|
-
if
|
|
102
|
+
if runtime_environment:
|
|
103
|
+
if _check_image_tag_valid(runtime_environment):
|
|
104
|
+
image_tag = runtime_environment
|
|
105
|
+
else:
|
|
106
|
+
container_image = runtime_environment
|
|
107
|
+
elif feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled():
|
|
67
108
|
container_image = _get_runtime_image(session, hardware) # type: ignore[arg-type]
|
|
68
109
|
|
|
69
|
-
|
|
70
|
-
image_repo = constants.DEFAULT_IMAGE_REPO
|
|
71
|
-
image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
|
|
72
|
-
image_tag = _get_runtime_image_tag()
|
|
73
|
-
container_image = f"{image_repo}/{image_name}:{image_tag}"
|
|
74
|
-
|
|
110
|
+
container_image = container_image or f"{image_repo}/{image_name}:{image_tag}"
|
|
75
111
|
# TODO: Should each instance consume the entire pod?
|
|
76
112
|
return types.ImageSpec(
|
|
77
113
|
resource_requests=resources,
|
|
@@ -127,6 +163,7 @@ def generate_service_spec(
|
|
|
127
163
|
target_instances: int = 1,
|
|
128
164
|
min_instances: int = 1,
|
|
129
165
|
enable_metrics: bool = False,
|
|
166
|
+
runtime_environment: Optional[str] = None,
|
|
130
167
|
) -> dict[str, Any]:
|
|
131
168
|
"""
|
|
132
169
|
Generate a service specification for a job.
|
|
@@ -139,11 +176,12 @@ def generate_service_spec(
|
|
|
139
176
|
target_instances: Number of instances for multi-node job
|
|
140
177
|
enable_metrics: Enable platform metrics for the job
|
|
141
178
|
min_instances: Minimum number of instances required to start the job
|
|
179
|
+
runtime_environment: The runtime image to use. Only support image tag or full image URL.
|
|
142
180
|
|
|
143
181
|
Returns:
|
|
144
182
|
Job service specification
|
|
145
183
|
"""
|
|
146
|
-
image_spec = _get_image_spec(session, compute_pool)
|
|
184
|
+
image_spec = _get_image_spec(session, compute_pool, runtime_environment)
|
|
147
185
|
|
|
148
186
|
# Set resource requests/limits, including nvidia.com/gpu quantity if applicable
|
|
149
187
|
resource_requests: dict[str, Union[str, int]] = {
|
|
@@ -228,6 +266,7 @@ def generate_service_spec(
|
|
|
228
266
|
{"name": "ray-client-server-endpoint", "port": 10001, "protocol": "TCP"},
|
|
229
267
|
{"name": "ray-gcs-endpoint", "port": 12001, "protocol": "TCP"},
|
|
230
268
|
{"name": "ray-dashboard-grpc-endpoint", "port": 12002, "protocol": "TCP"},
|
|
269
|
+
{"name": "ray-dashboard-endpoint", "port": 12003, "protocol": "TCP"},
|
|
231
270
|
{"name": "ray-object-manager-endpoint", "port": 12011, "protocol": "TCP"},
|
|
232
271
|
{"name": "ray-node-manager-endpoint", "port": 12012, "protocol": "TCP"},
|
|
233
272
|
{"name": "ray-runtime-agent-endpoint", "port": 12013, "protocol": "TCP"},
|
|
@@ -317,7 +356,7 @@ def merge_patch(base: Any, patch: Any, display_name: str = "") -> Any:
|
|
|
317
356
|
Returns:
|
|
318
357
|
The patched object.
|
|
319
358
|
"""
|
|
320
|
-
if
|
|
359
|
+
if type(base) is not type(patch):
|
|
321
360
|
if base is not None:
|
|
322
361
|
logging.warning(f"Type mismatch while merging {display_name} (base={type(base)}, patch={type(patch)})")
|
|
323
362
|
return patch
|
|
@@ -11,6 +11,7 @@ JOB_STATUS = Literal[
|
|
|
11
11
|
"CANCELLING",
|
|
12
12
|
"CANCELLED",
|
|
13
13
|
"INTERNAL_ERROR",
|
|
14
|
+
"DELETED",
|
|
14
15
|
]
|
|
15
16
|
|
|
16
17
|
|
|
@@ -106,3 +107,12 @@ class ImageSpec:
|
|
|
106
107
|
resource_requests: ComputeResources
|
|
107
108
|
resource_limits: ComputeResources
|
|
108
109
|
container_image: str
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@dataclass(frozen=True)
|
|
113
|
+
class ServiceInfo:
|
|
114
|
+
database_name: str
|
|
115
|
+
schema_name: str
|
|
116
|
+
status: str
|
|
117
|
+
compute_pool: str
|
|
118
|
+
target_instances: int
|