snowflake-ml-python 1.8.3__py3-none-any.whl → 1.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +7 -1
- snowflake/ml/_internal/platform_capabilities.py +13 -11
- snowflake/ml/_internal/utils/identifier.py +2 -2
- snowflake/ml/jobs/_utils/constants.py +1 -1
- snowflake/ml/jobs/_utils/payload_utils.py +39 -30
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +1 -1
- snowflake/ml/jobs/_utils/spec_utils.py +1 -1
- snowflake/ml/jobs/decorators.py +6 -0
- snowflake/ml/jobs/job.py +63 -16
- snowflake/ml/jobs/manager.py +50 -16
- snowflake/ml/model/_client/model/model_version_impl.py +1 -1
- snowflake/ml/model/_client/ops/service_ops.py +26 -14
- snowflake/ml/model/_client/service/model_deployment_spec.py +340 -170
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +25 -0
- snowflake/ml/model/_client/sql/service.py +4 -13
- snowflake/ml/model/_model_composer/model_composer.py +41 -18
- snowflake/ml/model/_packager/model_handlers/_utils.py +32 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +1 -1
- snowflake/ml/model/_packager/model_handlers/pytorch.py +1 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +100 -41
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +7 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +16 -7
- snowflake/ml/model/_packager/model_meta/model_meta.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +4 -4
- snowflake/ml/model/_signatures/dmatrix_handler.py +15 -2
- snowflake/ml/model/custom_model.py +17 -4
- snowflake/ml/model/model_signature.py +3 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
- snowflake/ml/modeling/cluster/birch.py +9 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
- snowflake/ml/modeling/cluster/dbscan.py +9 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
- snowflake/ml/modeling/cluster/k_means.py +9 -1
- snowflake/ml/modeling/cluster/mean_shift.py +9 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
- snowflake/ml/modeling/cluster/optics.py +9 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
- snowflake/ml/modeling/compose/column_transformer.py +9 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
- snowflake/ml/modeling/covariance/oas.py +9 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/pca.py +9 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
- snowflake/ml/modeling/impute/knn_imputer.py +9 -1
- snowflake/ml/modeling/impute/missing_indicator.py +9 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/lars.py +9 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/perceptron.py +9 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ridge.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
- snowflake/ml/modeling/manifold/isomap.py +9 -1
- snowflake/ml/modeling/manifold/mds.py +9 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
- snowflake/ml/modeling/manifold/tsne.py +9 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
- snowflake/ml/modeling/svm/linear_svc.py +9 -1
- snowflake/ml/modeling/svm/linear_svr.py +9 -1
- snowflake/ml/modeling/svm/nu_svc.py +9 -1
- snowflake/ml/modeling/svm/nu_svr.py +9 -1
- snowflake/ml/modeling/svm/svc.py +9 -1
- snowflake/ml/modeling/svm/svr.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
- snowflake/ml/monitoring/explain_visualize.py +286 -0
- snowflake/ml/registry/_manager/model_manager.py +23 -2
- snowflake/ml/registry/registry.py +10 -9
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/METADATA +40 -8
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/RECORD +190 -189
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/top_level.txt +0 -0
snowflake/cortex/__init__.py
CHANGED

```diff
@@ -1,5 +1,10 @@
 from snowflake.cortex._classify_text import ClassifyText, classify_text
-from snowflake.cortex._complete import
+from snowflake.cortex._complete import (
+    Complete,
+    CompleteOptions,
+    ConversationMessage,
+    complete,
+)
 from snowflake.cortex._embed_text_768 import EmbedText768, embed_text_768
 from snowflake.cortex._embed_text_1024 import EmbedText1024, embed_text_1024
 from snowflake.cortex._extract_answer import ExtractAnswer, extract_answer
@@ -14,6 +19,7 @@ __all__ = [
     "Complete",
     "complete",
     "CompleteOptions",
+    "ConversationMessage",
     "EmbedText768",
     "embed_text_768",
     "EmbedText1024",
```
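The upshot for callers is that `ConversationMessage` is now importable directly from `snowflake.cortex` and listed in `__all__`. A minimal multi-turn sketch, assuming an active Snowflake session is configured; the model name, message shape, and option values are illustrative rather than taken from this diff:

```python
from snowflake.cortex import CompleteOptions, ConversationMessage, complete

# A multi-turn conversation expressed as role/content messages.
history: list[ConversationMessage] = [
    {"role": "system", "content": "You answer questions about Snowflake."},
    {"role": "user", "content": "What is a compute pool?"},
]

# CompleteOptions is a typed mapping of generation options.
response = complete("snowflake-arctic", history, options=CompleteOptions(max_tokens=256))
print(response)
```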
snowflake/ml/_internal/platform_capabilities.py
CHANGED

````diff
@@ -11,6 +11,9 @@ from snowflake.snowpark import (
     session as snowpark_session,
 )
 
+LIVE_COMMIT_PARAMETER = "ENABLE_LIVE_VERSION_IN_SDK"
+INLINE_DEPLOYMENT_SPEC_PARAMETER = "ENABLE_INLINE_DEPLOYMENT_SPEC"
+
 
 class PlatformCapabilities:
     """Class that retrieves platform feature values for the currently running server.
@@ -18,12 +21,12 @@ class PlatformCapabilities:
     Example usage:
     ```
     pc = PlatformCapabilities.get_instance(session)
-    if pc.
-    #
-    print("
+    if pc.is_inlined_deployment_spec_enabled():
+        # Inline deployment spec is enabled.
+        print("Inline deployment spec is enabled.")
     else:
-    #
-    print("
+        # Inline deployment spec is disabled.
+        print("Inline deployment spec is disabled or not supported.")
     ```
     """
@@ -50,9 +53,11 @@ class PlatformCapabilities:
 
     # For contextmanager, we need to have return type Iterator[Never]. However, Never type is introduced only in
     # Python 3.11. So, we are ignoring the type for this method.
+    _dummy_features: dict[str, Any] = {"dummy": "dummy"}
+
     @classmethod  # type: ignore[arg-type]
     @contextmanager
-    def mock_features(cls, features: dict[str, Any]) -> None:  # type: ignore[misc]
+    def mock_features(cls, features: dict[str, Any] = _dummy_features) -> None:  # type: ignore[misc]
         logging.debug(f"Setting mock features: {features}")
         cls.set_mock_features(features)
         try:
@@ -61,14 +66,11 @@ class PlatformCapabilities:
             logging.debug(f"Clearing mock features: {features}")
             cls.clear_mock_features()
 
-    def is_nested_function_enabled(self) -> bool:
-        return self._get_bool_feature("SPCS_MODEL_ENABLE_EMBEDDED_SERVICE_FUNCTIONS", False)
-
     def is_inlined_deployment_spec_enabled(self) -> bool:
-        return self._get_bool_feature(
+        return self._get_bool_feature(INLINE_DEPLOYMENT_SPEC_PARAMETER, False)
 
     def is_live_commit_enabled(self) -> bool:
-        return self._get_bool_feature(
+        return self._get_bool_feature(LIVE_COMMIT_PARAMETER, False)
 
     @staticmethod
     def _get_features(session: snowpark_session.Session) -> dict[str, Any]:
````
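Both feature checks now read named parameters (`ENABLE_INLINE_DEPLOYMENT_SPEC`, `ENABLE_LIVE_VERSION_IN_SDK`) instead of inline string literals, and `is_nested_function_enabled()` is removed. The docstring example above maps one-to-one onto client code; a minimal sketch, assuming configured Snowflake credentials for the session:

```python
from snowflake.snowpark import Session

from snowflake.ml._internal.platform_capabilities import PlatformCapabilities

session = Session.builder.getOrCreate()  # assumes configured credentials
pc = PlatformCapabilities.get_instance(session)
if pc.is_inlined_deployment_spec_enabled():
    print("Inline deployment spec is enabled.")
else:
    print("Inline deployment spec is disabled or not supported.")
```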
snowflake/ml/_internal/utils/identifier.py
CHANGED

```diff
@@ -12,7 +12,7 @@ SF_IDENTIFIER_RE = re.compile(_SF_IDENTIFIER)
 _SF_SCHEMA_LEVEL_OBJECT = (
     rf"(?:(?:(?P<db>{_SF_IDENTIFIER})\.)?(?P<schema>{_SF_IDENTIFIER})\.)?(?P<object>{_SF_IDENTIFIER})"
 )
-_SF_STAGE_PATH = rf"{_SF_SCHEMA_LEVEL_OBJECT}(?P<path
+_SF_STAGE_PATH = rf"@?{_SF_SCHEMA_LEVEL_OBJECT}(?P<path>/.*)?"
 _SF_SCHEMA_LEVEL_OBJECT_RE = re.compile(_SF_SCHEMA_LEVEL_OBJECT)
 _SF_STAGE_PATH_RE = re.compile(_SF_STAGE_PATH)
 
@@ -197,7 +197,7 @@ def parse_snowflake_stage_path(
         res.group("db"),
         res.group("schema"),
         res.group("object"),
-        res.group("path"),
+        res.group("path") or "",
     )
```
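The pattern now tolerates a leading `@` and makes the path group optional, and `parse_snowflake_stage_path` returns `""` instead of `None` for a missing path. A small sketch of the expected behavior (identifier values are illustrative):

```python
from snowflake.ml._internal.utils import identifier

# A bare stage reference with a leading "@" and no trailing path now parses cleanly.
db, schema, obj, path = identifier.parse_snowflake_stage_path("@MY_DB.MY_SCHEMA.MY_STAGE")
print(db, schema, obj, repr(path))  # MY_DB MY_SCHEMA MY_STAGE ''
```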
|
snowflake/ml/jobs/_utils/constants.py
CHANGED

```diff
@@ -13,7 +13,7 @@ STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
 DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
 DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
 DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
-DEFAULT_IMAGE_TAG = "1.
+DEFAULT_IMAGE_TAG = "1.2.3"
 DEFAULT_ENTRYPOINT_PATH = "func.py"
 
 # Percent of container memory to allocate for /dev/shm volume
```
snowflake/ml/jobs/_utils/payload_utils.py
CHANGED

```diff
@@ -9,6 +9,7 @@ from pathlib import Path, PurePath
 from typing import Any, Callable, Optional, Union, cast, get_args, get_origin
 
 import cloudpickle as cp
+from packaging import version
 
 from snowflake import snowpark
 from snowflake.ml.jobs._utils import constants, types
@@ -97,11 +98,18 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 head_info=$(python3 get_instance_ip.py "$SNOWFLAKE_SERVICE_NAME" --head)
 if [ $? -eq 0 ]; then
     # Parse the output using read
-    read head_index head_ip <<< "$head_info"
+    read head_index head_ip head_status <<< "$head_info"
 
     # Use the parsed variables
     echo "Head Instance Index: $head_index"
     echo "Head Instance IP: $head_ip"
+    echo "Head Instance Status: $head_status"
+
+    # If the head status is not "READY" or "PENDING", exit early
+    if [ "$head_status" != "READY" ] && [ "$head_status" != "PENDING" ]; then
+        echo "Head instance status is not READY or PENDING. Exiting."
+        exit 0
+    fi
 
 else
     echo "Error: Failed to get head instance information."
@@ -278,17 +286,19 @@ class JobPayload:
         stage_path = PurePath(stage_path) if isinstance(stage_path, str) else stage_path
         source = resolve_source(self.source)
         entrypoint = resolve_entrypoint(source, self.entrypoint)
+        pip_requirements = self.pip_requirements or []
 
         # Create stage if necessary
         stage_name = stage_path.parts[0].lstrip("@")
         # Explicitly check if stage exists first since we may not have CREATE STAGE privilege
         try:
-            session.sql(
+            session.sql("describe stage identifier(?)", params=[stage_name]).collect()
         except sp_exceptions.SnowparkSQLException:
             session.sql(
-
+                "create stage if not exists identifier(?)"
                 " encryption = ( type = 'SNOWFLAKE_SSE' )"
-                " comment = 'Created by snowflake.ml.jobs Python API'"
+                " comment = 'Created by snowflake.ml.jobs Python API'",
+                params=[stage_name],
             ).collect()
 
         # Upload payload to stage
@@ -301,6 +311,8 @@ class JobPayload:
                 overwrite=True,
             )
             source = Path(entrypoint.file_path.parent)
+            if not any(r.startswith("cloudpickle") for r in pip_requirements):
+                pip_requirements.append(f"cloudpickle~={version.parse(cp.__version__).major}.0")
         elif source.is_dir():
             # Manually traverse the directory and upload each file, since Snowflake PUT
             # can't handle directories. Reduce the number of PUT operations by using
@@ -325,10 +337,10 @@ class JobPayload:
 
         # Upload requirements
         # TODO: Check if payload includes both a requirements.txt file and pip_requirements
-        if
+        if pip_requirements:
            # Upload requirements.txt to stage
            session.file.put_stream(
-                io.BytesIO("\n".join(
+                io.BytesIO("\n".join(pip_requirements).encode()),
                 stage_location=stage_path.joinpath("requirements.txt").as_posix(),
                 auto_compress=False,
                 overwrite=True,
@@ -495,13 +507,6 @@ def generate_python_code(func: Callable[..., Any], source_code_display: bool = F
     # https://github.com/snowflakedb/snowpark-python/blob/main/src/snowflake/snowpark/_internal/udf_utils.py
     source_code_comment = _generate_source_code_comment(func) if source_code_display else ""
 
-    func_code = f"""
-{source_code_comment}
-
-import pickle
-{_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(func).hex()}'))
-"""
-
     arg_dict_name = "kwargs"
     if getattr(func, constants.IS_MLJOB_REMOTE_ATTR, None):
         param_code = f"{arg_dict_name} = {{}}"
@@ -509,25 +514,29 @@ import pickle
         param_code = _generate_param_handler_code(signature, arg_dict_name)
 
     return f"""
-### Version guard to check compatibility across Python versions ###
-import os
 import sys
-import warnings
-
-if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor != {sys.version_info.minor}:
-    warnings.warn(
-        "Python version mismatch: job was created using"
-        " python{sys.version_info.major}.{sys.version_info.minor}"
-        f" but runtime environment uses python{{sys.version_info.major}}.{{sys.version_info.minor}}."
-        " Compatibility across Python versions is not guaranteed and may result in unexpected behavior."
-        " This will be fixed in a future release; for now, please use Python version"
-        f" {{sys.version_info.major}}.{{sys.version_info.minor}}.",
-        RuntimeWarning,
-        stacklevel=0,
-    )
-### End version guard ###
+import pickle
 
-
+try:
+{textwrap.indent(source_code_comment, '    ')}
+    {_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(func).hex()}'))
+except (TypeError, pickle.PickleError):
+    if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor != {sys.version_info.minor}:
+        raise RuntimeError(
+            "Failed to deserialize function due to Python version mismatch."
+            f" Runtime environment is Python {{sys.version_info.major}}.{{sys.version_info.minor}}"
+            " but function was serialized using Python {sys.version_info.major}.{sys.version_info.minor}."
+        ) from None
+    raise
+except AttributeError as e:
+    if 'cloudpickle' in str(e):
+        import cloudpickle as cp
+        raise RuntimeError(
+            "Failed to deserialize function due to cloudpickle version mismatch."
+            f" Runtime environment uses cloudpickle=={{cp.__version__}}"
+            " but job was serialized using cloudpickle=={cp.__version__}."
+        ) from e
+    raise
 
 if __name__ == '__main__':
 {textwrap.indent(param_code, '    ')}
```
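Two behavioral changes here are easy to miss. First, when a pickled function payload is uploaded, the requirements now pin `cloudpickle` to the same major version that serialized it, unless the caller already requested one; the pinning rule in isolation (a standalone sketch of the hunk above):

```python
import cloudpickle as cp
from packaging import version

pip_requirements = ["pandas"]  # illustrative caller-supplied requirements
if not any(r.startswith("cloudpickle") for r in pip_requirements):
    pip_requirements.append(f"cloudpickle~={version.parse(cp.__version__).major}.0")
print(pip_requirements)  # e.g. ['pandas', 'cloudpickle~=3.0']
```

Second, the generated entrypoint no longer emits an up-front warning about Python version skew; it attempts the unpickle and only raises a `RuntimeError` naming the Python or cloudpickle mismatch if deserialization actually fails.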
snowflake/ml/jobs/_utils/scripts/get_instance_ip.py
CHANGED

```diff
@@ -29,7 +29,7 @@ def get_self_ip() -> Optional[str]:
         return None
 
 
-def get_first_instance(service_name: str) -> Optional[tuple[str, str]]:
+def get_first_instance(service_name: str) -> Optional[tuple[str, str, str]]:
     """Get the first instance of a batch job based on start time and instance ID.
 
     Args:
@@ -42,7 +42,7 @@ def get_first_instance(service_name: str) -> Optional[tuple[str, str]]:
 
     session = session_utils.get_session()
     df = session.sql(f"show service instances in service {service_name}")
-    result = df.select('"instance_id"', '"ip_address"', '"start_time"').collect()
+    result = df.select('"instance_id"', '"ip_address"', '"start_time"', '"status"').collect()
 
     if not result:
         return None
@@ -57,7 +57,7 @@ def get_first_instance(service_name: str) -> Optional[tuple[str, str]]:
     ip_address = head_instance["ip_address"]
     try:
         socket.inet_aton(ip_address)  # Validate IPv4 address
-        return (head_instance["instance_id"], ip_address)
+        return (head_instance["instance_id"], ip_address, head_instance["status"])
     except OSError:
         logger.error(f"Error: Invalid IP address format: {ip_address}")
         return None
@@ -110,7 +110,7 @@ def main():
         head_info = get_first_instance(args.service_name)
         if head_info:
             # Print to stdout to allow capture but don't use logger
-            sys.stdout.write(
+            sys.stdout.write(" ".join(head_info) + "\n")
             sys.exit(0)
         time.sleep(args.retry_interval)
     # If we get here, we've timed out
```
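`get_first_instance` now returns a third element, the instance status, and `main()` prints all three fields space-separated; the startup script above exits early unless that status is `READY` or `PENDING`. A consumer of the script's stdout would unpack it like this (the sample line is illustrative):

```python
line = "0 10.0.0.12 READY"  # sample "index ip status" output from get_instance_ip.py
head_index, head_ip, head_status = line.split()
if head_status not in ("READY", "PENDING"):
    raise SystemExit("Head instance is not schedulable; giving up early.")
print(f"head={head_index} ip={head_ip} status={head_status}")
```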
snowflake/ml/jobs/_utils/scripts/mljob_launcher.py
CHANGED

```diff
@@ -59,7 +59,7 @@ class SimpleJSONEncoder(json.JSONEncoder):
         try:
             return super().default(obj)
         except TypeError:
-            return
+            return f"Unserializable object: {repr(obj)}"
 
 
 def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = None) -> Any:
```
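With this change, values that `json` cannot serialize become a descriptive placeholder string rather than an empty result. A self-contained sketch of the encoder's new fallback:

```python
import json


class SimpleJSONEncoder(json.JSONEncoder):
    def default(self, obj):
        try:
            return super().default(obj)
        except TypeError:
            # Fall back to a readable marker instead of failing the whole dump.
            return f"Unserializable object: {repr(obj)}"


print(json.dumps({"ok": 1, "handle": object()}, cls=SimpleJSONEncoder))
# -> {"ok": 1, "handle": "Unserializable object: <object object at 0x...>"}
```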
snowflake/ml/jobs/_utils/spec_utils.py
CHANGED

```diff
@@ -11,7 +11,7 @@ from snowflake.ml.jobs._utils import constants, types
 def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
     """Extract resource information for the specified compute pool"""
     # Get the instance family
-    rows = session.sql(
+    rows = session.sql("show compute pools like ?", params=[compute_pool]).collect()
     if not rows:
         raise ValueError(f"Compute pool '{compute_pool}' not found")
     instance_family: str = rows[0]["instance_family"]
```
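As in `payload_utils.py`, the compute-pool lookup now binds the pool name as a query parameter rather than interpolating it into the SQL text, sidestepping quoting and injection issues. A minimal sketch, assuming configured Snowflake credentials and an illustrative pool name:

```python
from snowflake.snowpark import Session

session = Session.builder.getOrCreate()  # assumes configured credentials
rows = session.sql("show compute pools like ?", params=["MY_POOL"]).collect()
if not rows:
    raise ValueError("Compute pool 'MY_POOL' not found")
print(rows[0]["instance_family"])
```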
snowflake/ml/jobs/decorators.py
CHANGED

```diff
@@ -26,6 +26,8 @@ def remote(
     env_vars: Optional[dict[str, str]] = None,
     num_instances: Optional[int] = None,
     enable_metrics: bool = False,
+    database: Optional[str] = None,
+    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
 ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob[_ReturnValue]]]:
     """
@@ -40,6 +42,8 @@ def remote(
         env_vars: Environment variables to set in container
         num_instances: The number of nodes in the job. If none specified, create a single node job.
         enable_metrics: Whether to enable metrics publishing for the job.
+        database: The database to use for the job.
+        schema: The schema to use for the job.
         session: The Snowpark session to use. If none specified, uses active session.
 
     Returns:
@@ -67,6 +71,8 @@ def remote(
             env_vars=env_vars,
             num_instances=num_instances,
             enable_metrics=enable_metrics,
+            database=database,
+            schema=schema,
             session=session,
         )
         assert isinstance(job, jb.MLJob), f"Unexpected job type: {type(job)}"
```
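The new arguments are pass-throughs to job submission, letting a job land in a database and schema other than the session's current context. A usage sketch; the compute pool, stage, and name values are illustrative, and `stage_name` as the payload-stage keyword is an assumption carried over from earlier releases rather than something this diff shows:

```python
from snowflake.ml import jobs


@jobs.remote(
    "MY_COMPUTE_POOL",           # illustrative compute pool
    stage_name="payload_stage",  # assumed keyword for the payload stage
    database="MY_DB",            # new in 1.8.4
    schema="MY_SCHEMA",          # new in 1.8.4
)
def double(n: int) -> int:
    return n * 2


job = double(21)     # returns an MLJob handle instead of running locally
print(job.wait())    # poll until a terminal status
print(job.result())  # fetch the remote return value (42)
```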
snowflake/ml/jobs/job.py
CHANGED

```diff
@@ -1,12 +1,15 @@
 import time
+from functools import cached_property
 from typing import Any, Generic, Literal, Optional, TypeVar, Union, cast, overload
 
 import yaml
 
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
+from snowflake.ml._internal.utils import identifier
 from snowflake.ml.jobs._utils import constants, interop_utils, types
-from snowflake.snowpark import context as sp_context
+from snowflake.snowpark import Row, context as sp_context
+from snowflake.snowpark.exceptions import SnowparkSQLException
 
 _PROJECT = "MLJob"
 TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "INTERNAL_ERROR"}
@@ -28,6 +31,14 @@ class MLJob(Generic[T]):
         self._status: types.JOB_STATUS = "PENDING"
         self._result: Optional[interop_utils.ExecutionResult] = None
 
+    @cached_property
+    def name(self) -> str:
+        return identifier.parse_schema_level_object_identifier(self.id)[-1]
+
+    @cached_property
+    def num_instances(self) -> int:
+        return _get_num_instances(self._session, self.id)
+
     @property
     def id(self) -> str:
         """Get the unique job ID"""
@@ -67,7 +78,7 @@ class MLJob(Generic[T]):
         """Get the job's result file location."""
         result_path = self._container_spec["env"].get(constants.RESULT_PATH_ENV_VAR)
         if result_path is None:
-            raise RuntimeError(f"Job {self.
+            raise RuntimeError(f"Job {self.name} doesn't have a result path configured")
         return f"{self._stage_path}/{result_path}"
 
     @overload
@@ -128,7 +139,7 @@ class MLJob(Generic[T]):
         start_time = time.monotonic()
         while self.status not in TERMINAL_JOB_STATUSES:
             if timeout >= 0 and (elapsed := time.monotonic() - start_time) >= timeout:
-                raise TimeoutError(f"Job {self.
+                raise TimeoutError(f"Job {self.name} did not complete within {elapsed} seconds")
             time.sleep(delay)
             delay = min(delay * 2, constants.JOB_POLL_MAX_DELAY_SECONDS)  # Exponential backoff
         return self.status
@@ -154,11 +165,11 @@ class MLJob(Generic[T]):
         try:
             self._result = interop_utils.fetch_result(self._session, self._result_path)
         except Exception as e:
-            raise RuntimeError(f"Failed to retrieve result for job (id={self.
+            raise RuntimeError(f"Failed to retrieve result for job (id={self.name})") from e
 
         if self._result.success:
             return cast(T, self._result.result)
-        raise RuntimeError(f"Job execution failed (id={self.
+        raise RuntimeError(f"Job execution failed (id={self.name})") from self._result.exception
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "instance_id"])
@@ -172,14 +183,14 @@ def _get_status(session: snowpark.Session, job_id: str, instance_id: Optional[in
             return cast(types.JOB_STATUS, row["status"])
         raise ValueError(f"Instance {instance_id} not found in job {job_id}")
     else:
-
+        row = _get_service_info(session, job_id)
         return cast(types.JOB_STATUS, row["status"])
 
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
 def _get_service_spec(session: snowpark.Session, job_id: str) -> dict[str, Any]:
     """Retrieve job execution service spec."""
-
+    row = _get_service_info(session, job_id)
     return cast(dict[str, Any], yaml.safe_load(row["spec"]))
 
 
@@ -196,10 +207,21 @@ def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1, instance_
 
     Returns:
         The job's execution logs.
+
+    Raises:
+        SnowparkSQLException: if the container is pending
+        RuntimeError: if failed to get head instance_id
+
     """
     # If instance_id is not specified, try to get the head instance ID
     if instance_id is None:
-        instance_id = _get_head_instance_id(session, job_id)
+        try:
+            instance_id = _get_head_instance_id(session, job_id)
+        except RuntimeError:
+            raise RuntimeError(
+                "Failed to retrieve job logs. "
+                "Logs may be inaccessible due to job expiration and can be retrieved from Event Table instead."
+            )
 
     # Assemble params: [job_id, instance_id, container_name, (optional) limit]
     params: list[Any] = [
@@ -210,10 +232,15 @@ def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1, instance_
     if limit > 0:
         params.append(limit)
 
-    (row,) = session.sql(
-        f"SELECT SYSTEM$GET_SERVICE_LOGS(?, ?, ?{f', ?' if limit > 0 else ''})",
-        params=params,
-    ).collect()
+    try:
+        (row,) = session.sql(
+            f"SELECT SYSTEM$GET_SERVICE_LOGS(?, ?, ?{f', ?' if limit > 0 else ''})",
+            params=params,
+        ).collect()
+    except SnowparkSQLException as e:
+        if "Container Status: PENDING" in e.message:
+            return "Warning: Waiting for container to start. Logs will be shown when available."
+        raise
     return str(row[0])
 
 
@@ -223,18 +250,27 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
     Retrieve the head instance ID of a job.
 
     Args:
-        session: The Snowpark session to use.
-        job_id: The job ID.
+        session (Session): The Snowpark session to use.
+        job_id (str): The job ID.
 
     Returns:
-        The head instance ID of the job
+        Optional[int]: The head instance ID of the job, or None if the head instance has not started yet.
+
+    Raises:
+        RuntimeError: If the instances died or if some instances disappeared.
     """
     rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
     if not rows:
         return None
+    if _get_num_instances(session, job_id) > len(rows):
+        raise RuntimeError("Couldn’t retrieve head instance due to missing instances.")
 
     # Sort by start_time first, then by instance_id
-    sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"])))
+    try:
+        sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"])))
+    except TypeError:
+        raise RuntimeError("Job instance information unavailable.")
+
     head_instance = sorted_instances[0]
     if not head_instance["start_time"]:
         # If head instance hasn't started yet, return None
@@ -243,3 +279,14 @@ def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[in
         return int(head_instance["instance_id"])
     except (ValueError, TypeError):
         return 0
+
+
+def _get_service_info(session: snowpark.Session, job_id: str) -> Row:
+    (row,) = session.sql("DESCRIBE SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+    return row
+
+
+@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
+def _get_num_instances(session: snowpark.Session, job_id: str) -> int:
+    row = _get_service_info(session, job_id)
+    return int(row["target_instances"]) if row["target_instances"] else 0
```