snowflake-ml-python 1.15.0__py3-none-any.whl → 1.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/human_readable_id/adjectives.txt +5 -5
- snowflake/ml/_internal/human_readable_id/animals.txt +3 -3
- snowflake/ml/_internal/platform_capabilities.py +4 -0
- snowflake/ml/_internal/utils/mixins.py +24 -9
- snowflake/ml/experiment/experiment_tracking.py +63 -19
- snowflake/ml/jobs/__init__.py +4 -0
- snowflake/ml/jobs/_interop/__init__.py +0 -0
- snowflake/ml/jobs/_interop/data_utils.py +124 -0
- snowflake/ml/jobs/_interop/dto_schema.py +95 -0
- snowflake/ml/jobs/{_utils/interop_utils.py → _interop/exception_utils.py} +49 -178
- snowflake/ml/jobs/_interop/legacy.py +225 -0
- snowflake/ml/jobs/_interop/protocols.py +471 -0
- snowflake/ml/jobs/_interop/results.py +51 -0
- snowflake/ml/jobs/_interop/utils.py +144 -0
- snowflake/ml/jobs/_utils/constants.py +4 -1
- snowflake/ml/jobs/_utils/feature_flags.py +37 -5
- snowflake/ml/jobs/_utils/payload_utils.py +1 -1
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +139 -102
- snowflake/ml/jobs/_utils/spec_utils.py +50 -11
- snowflake/ml/jobs/_utils/types.py +10 -0
- snowflake/ml/jobs/job.py +168 -36
- snowflake/ml/jobs/manager.py +54 -36
- snowflake/ml/model/__init__.py +16 -2
- snowflake/ml/model/_client/model/batch_inference_specs.py +18 -2
- snowflake/ml/model/_client/model/model_version_impl.py +44 -7
- snowflake/ml/model/_client/ops/model_ops.py +4 -0
- snowflake/ml/model/_client/ops/service_ops.py +50 -5
- snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
- snowflake/ml/model/_client/sql/model_version.py +3 -1
- snowflake/ml/model/_client/sql/stage.py +8 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_method/model_method.py +32 -4
- snowflake/ml/model/_model_composer/model_method/utils.py +28 -0
- snowflake/ml/model/_packager/model_env/model_env.py +48 -21
- snowflake/ml/model/_packager/model_meta/model_meta.py +8 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
- snowflake/ml/model/type_hints.py +13 -0
- snowflake/ml/model/volatility.py +34 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +5 -5
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
- snowflake/ml/modeling/cluster/birch.py +1 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
- snowflake/ml/modeling/cluster/dbscan.py +1 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
- snowflake/ml/modeling/cluster/k_means.py +1 -1
- snowflake/ml/modeling/cluster/mean_shift.py +1 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
- snowflake/ml/modeling/cluster/optics.py +1 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
- snowflake/ml/modeling/compose/column_transformer.py +1 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/pca.py +1 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
- snowflake/ml/modeling/impute/knn_imputer.py +1 -1
- snowflake/ml/modeling/impute/missing_indicator.py +1 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +1 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +1 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ridge.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +1 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
- snowflake/ml/modeling/manifold/tsne.py +1 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
- snowflake/ml/modeling/svm/linear_svc.py +1 -1
- snowflake/ml/modeling/svm/linear_svr.py +1 -1
- snowflake/ml/modeling/svm/nu_svc.py +1 -1
- snowflake/ml/modeling/svm/nu_svr.py +1 -1
- snowflake/ml/modeling/svm/svc.py +1 -1
- snowflake/ml/modeling/svm/svr.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
- snowflake/ml/registry/_manager/model_manager.py +1 -0
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +27 -0
- snowflake/ml/registry/registry.py +15 -0
- snowflake/ml/utils/authentication.py +16 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/METADATA +65 -5
- {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/RECORD +201 -192
- {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/top_level.txt +0 -0

snowflake/ml/jobs/_interop/protocols.py (new file)

```diff
@@ -0,0 +1,471 @@
+import base64
+import logging
+import pickle
+import posixpath
+import sys
+from typing import Any, Callable, Optional, Protocol, Union, cast, runtime_checkable
+
+from snowflake import snowpark
+from snowflake.ml.jobs._interop import data_utils
+from snowflake.ml.jobs._interop.dto_schema import (
+    BinaryManifest,
+    ParquetManifest,
+    ProtocolInfo,
+)
+
+Condition = Union[type, tuple[type, ...], Callable[[Any], bool], None]
+
+logger = logging.getLogger(__name__)
+
+
+class SerializationError(TypeError):
+    """Exception raised when a serialization protocol fails."""
+
+
+class DeserializationError(ValueError):
+    """Exception raised when a deserialization protocol fails."""
+
+
+class InvalidPayloadError(DeserializationError):
+    """Exception raised when the payload is invalid."""
+
+
+class ProtocolMismatchError(DeserializationError):
+    """Exception raised when the payload's serialization protocol is incompatible."""
+
+
+class VersionMismatchError(ProtocolMismatchError):
+    """Exception raised when the version of the serialization protocol is incompatible."""
+
+
+class ProtocolNotFoundError(SerializationError):
+    """Exception raised when no suitable serialization protocol is available."""
+
+
+@runtime_checkable
+class SerializationProtocol(Protocol):
+    """
+    A protocol that supports flexibility in how results are saved or loaded.
+    Results can be saved as one or more files, or directly inline in the PayloadManifest.
+    If saving as files, the PayloadManifest can carry arbitrary "manifest" information.
+    """
+
+    @property
+    def supported_types(self) -> Condition:
+        """The types that the protocol supports."""
+
+    @property
+    def protocol_info(self) -> ProtocolInfo:
+        """The information about the protocol."""
+
+    def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
+        """Save the object to the destination directory."""
+
+    def load(
+        self,
+        payload_info: ProtocolInfo,
+        session: Optional[snowpark.Session] = None,
+        path_transform: Optional[Callable[[str], str]] = None,
+    ) -> Any:
+        """Load the object from the source directory."""
+
+
+class CloudPickleProtocol(SerializationProtocol):
+    """
+    CloudPickle serialization protocol.
+    Uses BinaryManifest for manifest schema.
+    """
+
+    DEFAULT_PATH = "mljob_extra.pkl"
+
+    def __init__(self) -> None:
+        import cloudpickle as cp
+
+        self._backend = cp
+
+    def _get_compatibility_error(self, payload_info: ProtocolInfo) -> Optional[Exception]:
+        """Check version compatibility and return a descriptive error on mismatch."""
+        version_error = python_error = None
+
+        # Check cloudpickle version compatibility
+        if payload_info.version:
+            try:
+                from packaging import version
+
+                payload_major, current_major = (
+                    version.parse(payload_info.version).major,
+                    version.parse(self._backend.__version__).major,
+                )
+                if payload_major != current_major:
+                    version_error = "cloudpickle version mismatch: payload={}, current={}".format(
+                        payload_info.version, self._backend.__version__
+                    )
+            except Exception:
+                if payload_info.version != self.protocol_info.version:
+                    version_error = "cloudpickle version mismatch: payload={}, current={}".format(
+                        payload_info.version, self.protocol_info.version
+                    )
+
+        # Check Python version compatibility
+        if payload_info.metadata and "python_version" in payload_info.metadata:
+            payload_py, current_py = (
+                payload_info.metadata["python_version"],
+                f"{sys.version_info.major}.{sys.version_info.minor}",
+            )
+            if payload_py != current_py:
+                python_error = f"Python version mismatch: payload={payload_py}, current={current_py}"
+
+        if version_error or python_error:
+            errors = [err for err in [version_error, python_error] if err]
+            return VersionMismatchError(f"Load failed due to incompatibility: {'; '.join(errors)}")
+        return None
+
+    @property
+    def supported_types(self) -> Condition:
+        return None  # All types are supported
+
+    @property
+    def protocol_info(self) -> ProtocolInfo:
+        return ProtocolInfo(
+            name="cloudpickle",
+            version=self._backend.__version__,
+            metadata={
+                "python_version": f"{sys.version_info.major}.{sys.version_info.minor}",
+            },
+        )
+
+    def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
+        """Save the object to the destination directory."""
+        result_path = posixpath.join(dest_dir, self.DEFAULT_PATH)
+        with data_utils.open_stream(result_path, "wb", session=session) as f:
+            self._backend.dump(obj, f)
+        manifest: BinaryManifest = {"path": result_path}
+        return self.protocol_info.with_manifest(manifest)
+
+    def load(
+        self,
+        payload_info: ProtocolInfo,
+        session: Optional[snowpark.Session] = None,
+        path_transform: Optional[Callable[[str], str]] = None,
+    ) -> Any:
+        """Load the object from the source directory."""
+        if payload_info.name != self.protocol_info.name:
+            raise ProtocolMismatchError(
+                f"Invalid payload protocol: expected '{self.protocol_info.name}', got '{payload_info.name}'"
+            )
+
+        payload_manifest = cast(BinaryManifest, payload_info.manifest)
+        try:
+            if payload_bytes := payload_manifest.get("bytes"):
+                return self._backend.loads(payload_bytes)
+            if payload_b64 := payload_manifest.get("base64"):
+                return self._backend.loads(base64.b64decode(payload_b64))
+            result_path = path_transform(payload_manifest["path"]) if path_transform else payload_manifest["path"]
+            with data_utils.open_stream(result_path, "rb", session=session) as f:
+                return self._backend.load(f)
+        except (
+            pickle.UnpicklingError,
+            TypeError,
+            AttributeError,
+            MemoryError,
+        ) as pickle_error:
+            if error := self._get_compatibility_error(payload_info):
+                raise error from pickle_error
+            raise
+
+
+class ArrowTableProtocol(SerializationProtocol):
+    """
+    Arrow Table serialization protocol.
+    Uses ParquetManifest for manifest schema.
+    """
+
+    DEFAULT_PATH_PATTERN = "mljob_extra_{0}.parquet"
+
+    def __init__(self) -> None:
+        import pyarrow as pa
+        import pyarrow.parquet as pq
+
+        self._pa = pa
+        self._pq = pq
+
+    @property
+    def supported_types(self) -> Condition:
+        return cast(type, self._pa.Table)
+
+    @property
+    def protocol_info(self) -> ProtocolInfo:
+        return ProtocolInfo(
+            name="pyarrow",
+            version=self._pa.__version__,
+        )
+
+    def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
+        """Save the object to the destination directory."""
+        if not isinstance(obj, self._pa.Table):
+            raise SerializationError(f"Expected {self._pa.Table.__name__} object, got {type(obj).__name__}")
+
+        # TODO: Support partitioned writes for large datasets
+        result_path = posixpath.join(dest_dir, self.DEFAULT_PATH_PATTERN.format(0))
+        with data_utils.open_stream(result_path, "wb", session=session) as stream:
+            self._pq.write_table(obj, stream)
+
+        manifest: ParquetManifest = {"paths": [result_path]}
+        return self.protocol_info.with_manifest(manifest)
+
+    def load(
+        self,
+        payload_info: ProtocolInfo,
+        session: Optional[snowpark.Session] = None,
+        path_transform: Optional[Callable[[str], str]] = None,
+    ) -> Any:
+        """Load the object from the source directory."""
+        if payload_info.name != self.protocol_info.name:
+            raise ProtocolMismatchError(
+                f"Invalid payload protocol: expected '{self.protocol_info.name}', got '{payload_info.name}'"
+            )
+
+        payload_manifest = cast(ParquetManifest, payload_info.manifest)
+        tables = []
+        for path in payload_manifest["paths"]:
+            transformed_path = path_transform(path) if path_transform else path
+            with data_utils.open_stream(transformed_path, "rb", session=session) as f:
+                table = self._pq.read_table(f)
+                tables.append(table)
+        return self._pa.concat_tables(tables) if len(tables) > 1 else tables[0]
+
+
+class PandasDataFrameProtocol(SerializationProtocol):
+    """
+    Pandas DataFrame serialization protocol.
+    Uses ParquetManifest for manifest schema.
+    """
+
+    DEFAULT_PATH_PATTERN = "mljob_extra_{0}.parquet"
+
+    def __init__(self) -> None:
+        import pandas as pd
+
+        self._pd = pd
+
+    @property
+    def supported_types(self) -> Condition:
+        return cast(type, self._pd.DataFrame)
+
+    @property
+    def protocol_info(self) -> ProtocolInfo:
+        return ProtocolInfo(
+            name="pandas",
+            version=self._pd.__version__,
+        )
+
+    def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
+        """Save the object to the destination directory."""
+        if not isinstance(obj, self._pd.DataFrame):
+            raise SerializationError(f"Expected {self._pd.DataFrame.__name__} object, got {type(obj).__name__}")
+
+        # TODO: Support partitioned writes for large datasets
+        result_path = posixpath.join(dest_dir, self.DEFAULT_PATH_PATTERN.format(0))
+        with data_utils.open_stream(result_path, "wb", session=session) as stream:
+            obj.to_parquet(stream)
+
+        manifest: ParquetManifest = {"paths": [result_path]}
+        return self.protocol_info.with_manifest(manifest)
+
+    def load(
+        self,
+        payload_info: ProtocolInfo,
+        session: Optional[snowpark.Session] = None,
+        path_transform: Optional[Callable[[str], str]] = None,
+    ) -> Any:
+        """Load the object from the source directory."""
+        if payload_info.name != self.protocol_info.name:
+            raise ProtocolMismatchError(
+                f"Invalid payload protocol: expected '{self.protocol_info.name}', got '{payload_info.name}'"
+            )
+
+        payload_manifest = cast(ParquetManifest, payload_info.manifest)
+        dfs = []
+        for path in payload_manifest["paths"]:
+            transformed_path = path_transform(path) if path_transform else path
+            with data_utils.open_stream(transformed_path, "rb", session=session) as f:
+                df = self._pd.read_parquet(f)
+                dfs.append(df)
+        return self._pd.concat(dfs) if len(dfs) > 1 else dfs[0]
+
+
+class NumpyArrayProtocol(SerializationProtocol):
+    """
+    Numpy Array serialization protocol.
+    Uses BinaryManifest for manifest schema.
+    """
+
+    DEFAULT_PATH_PATTERN = "mljob_extra.npy"
+
+    def __init__(self) -> None:
+        import numpy as np
+
+        self._np = np
+
+    @property
+    def supported_types(self) -> Condition:
+        return cast(type, self._np.ndarray)
+
+    @property
+    def protocol_info(self) -> ProtocolInfo:
+        return ProtocolInfo(
+            name="numpy",
+            version=self._np.__version__,
+        )
+
+    def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
+        """Save the object to the destination directory."""
+        if not isinstance(obj, self._np.ndarray):
+            raise SerializationError(f"Expected {self._np.ndarray.__name__} object, got {type(obj).__name__}")
+        result_path = posixpath.join(dest_dir, self.DEFAULT_PATH_PATTERN)
+        with data_utils.open_stream(result_path, "wb", session=session) as stream:
+            self._np.save(stream, obj)
+
+        manifest: BinaryManifest = {"path": result_path}
+        return self.protocol_info.with_manifest(manifest)
+
+    def load(
+        self,
+        payload_info: ProtocolInfo,
+        session: Optional[snowpark.Session] = None,
+        path_transform: Optional[Callable[[str], str]] = None,
+    ) -> Any:
+        """Load the object from the source directory."""
+        if payload_info.name != self.protocol_info.name:
+            raise ProtocolMismatchError(
+                f"Invalid payload protocol: expected '{self.protocol_info.name}', got '{payload_info.name}'"
+            )
+
+        payload_manifest = cast(BinaryManifest, payload_info.manifest)
+        transformed_path = path_transform(payload_manifest["path"]) if path_transform else payload_manifest["path"]
+        with data_utils.open_stream(transformed_path, "rb", session=session) as f:
+            return self._np.load(f)
+
+
+class AutoProtocol(SerializationProtocol):
+    def __init__(self) -> None:
+        self._protocols: list[SerializationProtocol] = []
+        self._protocol_info = ProtocolInfo(
+            name="auto",
+            version=None,
+            metadata=None,
+        )
+
+    @property
+    def supported_types(self) -> Condition:
+        return None  # All types are supported
+
+    @property
+    def protocol_info(self) -> ProtocolInfo:
+        return self._protocol_info
+
+    def try_register_protocol(
+        self,
+        klass: type[SerializationProtocol],
+        *args: Any,
+        index: int = 0,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Try to construct and register a protocol. If the protocol cannot be constructed,
+        log a warning and skip registration. By default (index=0), the most recently
+        registered protocol takes precedence.
+
+        Args:
+            klass: The class of the protocol to register.
+            args: The positional arguments to pass to the protocol constructor.
+            index: The index to register the protocol at. If -1, the protocol is registered at the end of the list.
+            kwargs: The keyword arguments to pass to the protocol constructor.
+        """
+        try:
+            protocol = klass(*args, **kwargs)
+            self.register_protocol(protocol, index=index)
+        except Exception as e:
+            logger.warning(f"Failed to register protocol {klass}: {e}")
+
+    def register_protocol(
+        self,
+        protocol: SerializationProtocol,
+        index: int = 0,
+    ) -> None:
+        """
+        Register a protocol with a condition. By default (index=0), the most recently
+        registered protocol takes precedence.
+
+        Args:
+            protocol: The protocol to register.
+            index: The index to register the protocol at. If -1, the protocol is registered at the end of the list.
+
+        Raises:
+            ValueError: If the condition is invalid.
+            ValueError: If the index is invalid.
+        """
+        # Validate condition
+        # TODO: Build lookup table of supported types to protocols (in priority order)
+        # for faster lookup at save/load time (instead of iterating over all protocols)
+        if not isinstance(protocol, SerializationProtocol):
+            raise ValueError(f"Invalid protocol type: {type(protocol)}. Expected SerializationProtocol.")
+        if index == -1:
+            self._protocols.append(protocol)
+        elif index < 0:
+            raise ValueError(f"Invalid index: {index}. Expected -1 or >= 0.")
+        else:
+            self._protocols.insert(index, protocol)
+
+    def save(self, obj: Any, dest_dir: str, session: Optional[snowpark.Session] = None) -> ProtocolInfo:
+        """Save the object to the destination directory."""
+        last_protocol_error = None
+        for protocol in self._protocols:
+            try:
+                if self._is_supported_type(obj, protocol):
+                    logger.debug(f"Dumping object of type {type(obj)} with protocol {protocol}")
+                    return protocol.save(obj, dest_dir, session)
+            except Exception as e:
+                logger.warning(f"Error dumping object {obj} with protocol {protocol}: {repr(e)}")
+                last_protocol_error = (protocol.protocol_info, e)
+        last_error_str = (
+            f", most recent error ({last_protocol_error[0]}): {repr(last_protocol_error[1])}"
+            if last_protocol_error
+            else ""
+        )
+        raise ProtocolNotFoundError(
+            f"No suitable protocol found for type {type(obj).__name__}"
+            f" (available: {', '.join(str(p.protocol_info) for p in self._protocols)}){last_error_str}"
+        )
+
+    def load(
+        self,
+        payload_info: ProtocolInfo,
+        session: Optional[snowpark.Session] = None,
+        path_transform: Optional[Callable[[str], str]] = None,
+    ) -> Any:
+        """Load the object from the source directory."""
+        last_error = None
+        for protocol in self._protocols:
+            if protocol.protocol_info.name == payload_info.name:
+                try:
+                    return protocol.load(payload_info, session, path_transform)
+                except Exception as e:
+                    logger.warning(f"Error loading object with protocol {protocol}: {repr(e)}")
+                    last_error = e
+        if last_error:
+            raise last_error
+        raise ProtocolNotFoundError(
+            f"No protocol matching {payload_info} available"
+            f" (available: {', '.join(str(p.protocol_info) for p in self._protocols)})"
+            ", possibly due to snowflake-ml-python package version mismatch"
+        )
+
+    def _is_supported_type(self, obj: Any, protocol: SerializationProtocol) -> bool:
+        if protocol.supported_types is None:
+            return True  # None means all types are supported
+        elif isinstance(protocol.supported_types, (type, tuple)):
+            return isinstance(obj, protocol.supported_types)
+        elif callable(protocol.supported_types):
+            return protocol.supported_types(obj) is True
+        raise ValueError(f"Invalid supported types: {protocol.supported_types} for protocol {protocol}")
```
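AutoProtocol dispatches on `supported_types` at save time and on `ProtocolInfo.name` at load time, so the chain is extensible. The sketch below registers a hypothetical JSON-backed protocol ahead of the built-ins; `JsonDictProtocol` and its `"json-dict"` name are invented for illustration, and the sketch assumes it is defined alongside the built-ins so that `Condition`, `ProtocolInfo`, `BinaryManifest`, and `data_utils` from protocols.py are in scope.

```python
# Sketch only: a hypothetical dict-to-JSON protocol, assuming the names
# defined in protocols.py (Condition, ProtocolInfo, BinaryManifest,
# data_utils, AutoProtocol, CloudPickleProtocol) are in scope.
import json
import posixpath
from typing import Any, cast


class JsonDictProtocol:  # hypothetical; not part of the package
    DEFAULT_PATH = "mljob_extra.json"

    @property
    def supported_types(self) -> Condition:
        return dict  # only plain dicts; everything else falls through

    @property
    def protocol_info(self) -> ProtocolInfo:
        return ProtocolInfo(name="json-dict", version="1")

    def save(self, obj: Any, dest_dir: str, session=None) -> ProtocolInfo:
        result_path = posixpath.join(dest_dir, self.DEFAULT_PATH)
        with data_utils.open_stream(result_path, "wb", session=session) as f:
            f.write(json.dumps(obj).encode("utf-8"))
        # BinaryManifest records where the payload was written
        return self.protocol_info.with_manifest({"path": result_path})

    def load(self, payload_info: ProtocolInfo, session=None, path_transform=None) -> Any:
        manifest = cast(BinaryManifest, payload_info.manifest)
        path = path_transform(manifest["path"]) if path_transform else manifest["path"]
        with data_utils.open_stream(path, "rb", session=session) as f:
            return json.loads(f.read().decode("utf-8"))


# index=0 (the default) inserts at the front, so the newest registration is
# tried first: dicts hit JsonDictProtocol before the catch-all CloudPickleProtocol.
auto = AutoProtocol()
auto.try_register_protocol(CloudPickleProtocol)
auto.try_register_protocol(JsonDictProtocol)
```

Because `SerializationProtocol` is decorated with `runtime_checkable`, the `isinstance` check in `register_protocol` only verifies that the four members exist, so no explicit subclassing is required.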
snowflake/ml/jobs/_interop/results.py (new file)

```diff
@@ -0,0 +1,51 @@
+from dataclasses import dataclass
+from typing import Any, Optional
+
+
+@dataclass(frozen=True)
+class ExecutionResult:
+    """
+    A result of a job execution.
+
+    Args:
+        success: Whether the execution was successful.
+        value: The value of the execution.
+    """
+
+    success: bool
+    value: Any
+
+    def get_value(self, wrap_exceptions: bool = True) -> Any:
+        if not self.success:
+            assert isinstance(self.value, BaseException), "Unexpected non-exception value for failed result"
+            self._raise_exception(self.value, wrap_exceptions)
+        return self.value
+
+    def _raise_exception(self, exception: BaseException, wrap_exceptions: bool) -> None:
+        if wrap_exceptions:
+            raise RuntimeError(f"Job execution failed with error: {exception!r}") from exception
+        else:
+            raise exception
+
+
+@dataclass(frozen=True)
+class LoadedExecutionResult(ExecutionResult):
+    """
+    A result of a job execution that has been loaded from a file.
+    """
+
+    load_error: Optional[Exception] = None
+    result_metadata: Optional[dict[str, Any]] = None
+
+    def get_value(self, wrap_exceptions: bool = True) -> Any:
+        if not self.success:
+            # Raise the original exception if available, otherwise raise the load error
+            ex = self.value
+            if not isinstance(ex, BaseException):
+                ex = RuntimeError(f"Unknown error {ex or ''}")
+            ex.__cause__ = self.load_error
+            self._raise_exception(ex, wrap_exceptions)
+        else:
+            if self.load_error:
+                raise ValueError("Job execution succeeded but result retrieval failed") from self.load_error
+        return self.value
```
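The `wrap_exceptions` flag picks between two failure-surfacing styles: wrap the remote error in a `RuntimeError` (keeping it as `__cause__`) or re-raise it directly. A minimal sketch of both paths, using only the dataclasses above:

```python
from snowflake.ml.jobs._interop.results import ExecutionResult

ok = ExecutionResult(success=True, value=42)
assert ok.get_value() == 42

failed = ExecutionResult(success=False, value=ValueError("boom"))

try:
    failed.get_value()  # default: wrapped, remote error kept as __cause__
except RuntimeError as e:
    assert isinstance(e.__cause__, ValueError)

try:
    failed.get_value(wrap_exceptions=False)  # re-raise the original exception
except ValueError as e:
    assert str(e) == "boom"
```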
snowflake/ml/jobs/_interop/utils.py (new file)

```diff
@@ -0,0 +1,144 @@
+import logging
+import os
+import traceback
+from pathlib import PurePath
+from typing import Any, Callable, Optional
+
+import pydantic
+
+from snowflake import snowpark
+from snowflake.ml.jobs._interop import data_utils, exception_utils, legacy, protocols
+from snowflake.ml.jobs._interop.dto_schema import (
+    ExceptionMetadata,
+    ResultDTO,
+    ResultMetadata,
+)
+from snowflake.ml.jobs._interop.results import ExecutionResult, LoadedExecutionResult
+from snowflake.snowpark import exceptions as sp_exceptions
+
+DEFAULT_CODEC = data_utils.JsonDtoCodec
+DEFAULT_PROTOCOL = protocols.AutoProtocol()
+DEFAULT_PROTOCOL.try_register_protocol(protocols.CloudPickleProtocol)
+DEFAULT_PROTOCOL.try_register_protocol(protocols.ArrowTableProtocol)
+DEFAULT_PROTOCOL.try_register_protocol(protocols.PandasDataFrameProtocol)
+DEFAULT_PROTOCOL.try_register_protocol(protocols.NumpyArrayProtocol)
+
+
+logger = logging.getLogger(__name__)
+
+
+def save_result(result: ExecutionResult, path: str, session: Optional[snowpark.Session] = None) -> None:
+    """
+    Save the result to a file.
+    """
+    result_dto = ResultDTO(
+        success=result.success,
+        value=result.value,
+    )
+
+    try:
+        # Try to encode result directly
+        payload = DEFAULT_CODEC.encode(result_dto)
+    except TypeError:
+        result_dto.value = None  # Remove raw value to avoid serialization error
+        result_dto.metadata = _get_metadata(result.value)  # Add metadata for client fallback on protocol mismatch
+        try:
+            path_dir = PurePath(path).parent.as_posix()
+            protocol_info = DEFAULT_PROTOCOL.save(result.value, path_dir, session=session)
+            result_dto.protocol = protocol_info
+
+        except Exception as e:
+            logger.warning(f"Error dumping result value: {repr(e)}")
+            result_dto.serialize_error = repr(e)
+
+        # Encode the modified result DTO
+        payload = DEFAULT_CODEC.encode(result_dto)
+
+    with data_utils.open_stream(path, "wb", session=session) as stream:
+        stream.write(payload)
+
+
+def load_result(
+    path: str, session: Optional[snowpark.Session] = None, path_transform: Optional[Callable[[str], str]] = None
+) -> ExecutionResult:
+    """Load the result from a file on a Snowflake stage."""
+    try:
+        with data_utils.open_stream(path, "r", session=session) as stream:
+            # Load the DTO as a dict for easy fallback to legacy loading if necessary
+            dto_dict = DEFAULT_CODEC.decode(stream, as_dict=True)
+    except UnicodeDecodeError:
+        # Path may be a legacy result file (cloudpickle)
+        # TODO: Re-use the stream
+        assert session is not None
+        return legacy.load_legacy_result(session, path)
+
+    try:
+        dto = ResultDTO.model_validate(dto_dict)
+    except pydantic.ValidationError as e:
+        if "success" in dto_dict:
+            assert session is not None
+            if path.endswith(".json"):
+                path = os.path.splitext(path)[0] + ".pkl"
+            return legacy.load_legacy_result(session, path, result_json=dto_dict)
+        raise ValueError("Invalid result schema") from e
+
+    # Try loading data from file using the protocol info
+    result_value = None
+    data_load_error = None
+    if dto.protocol is not None:
+        try:
+            logger.debug(f"Loading result value with protocol {dto.protocol}")
+            result_value = DEFAULT_PROTOCOL.load(dto.protocol, session=session, path_transform=path_transform)
+        except sp_exceptions.SnowparkSQLException:
+            raise  # Data retrieval errors should be bubbled up
+        except Exception as e:
+            logger.debug(f"Error loading result value with protocol {dto.protocol}: {repr(e)}")
+            data_load_error = e
+
+    # Wrap serialize_error in a TypeError
+    if dto.serialize_error:
+        serialize_error = TypeError("Original result serialization failed with error: " + dto.serialize_error)
+        if data_load_error:
+            data_load_error.__context__ = serialize_error
+        else:
+            data_load_error = serialize_error
+
+    # Prepare to assemble the final result
+    result_value = result_value if result_value is not None else dto.value
+    if not dto.success and result_value is None:
+        # Try to reconstruct exception from metadata if available
+        if isinstance(dto.metadata, ExceptionMetadata):
+            logger.debug(f"Reconstructing exception from metadata {dto.metadata}")
+            result_value = exception_utils.build_exception(
+                type_str=dto.metadata.type,
+                message=dto.metadata.message,
+                traceback=dto.metadata.traceback,
+                original_repr=dto.metadata.repr,
+            )
+
+        # Generate a generic error if we still don't have a value,
+        # attaching the data load error if any
+        if result_value is None:
+            result_value = exception_utils.RemoteError("Unknown remote error")
+            result_value.__cause__ = data_load_error
+
+    return LoadedExecutionResult(
+        success=dto.success,
+        value=result_value,
+        load_error=data_load_error,
+    )
+
+
+def _get_metadata(value: Any) -> ResultMetadata:
+    type_name = f"{type(value).__module__}.{type(value).__name__}"
+    if isinstance(value, BaseException):
+        return ExceptionMetadata(
+            type=type_name,
+            repr=repr(value),
+            message=str(value),
+            traceback="".join(traceback.format_tb(value.__traceback__)),
+        )
+    return ResultMetadata(
+        type=type_name,
+        repr=repr(value),
+    )
```
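Together, `save_result` and `load_result` form a round trip that degrades gracefully: JSON-encodable values travel inline in the `ResultDTO`, anything else falls back to the protocol chain, and a failed fallback still ships metadata for the client. A rough sketch of the inline path, assuming `data_utils.open_stream` (not shown in this diff) accepts a plain local path when no session is given:

```python
from snowflake.ml.jobs._interop.results import ExecutionResult
from snowflake.ml.jobs._interop.utils import load_result, save_result

# A JSON-encodable value is embedded directly in the DTO; no protocol file
# is written. Assumes the /tmp/output directory already exists.
result = ExecutionResult(success=True, value={"rows": 10, "status": "ok"})
save_result(result, "/tmp/output/mljob_result.json")

loaded = load_result("/tmp/output/mljob_result.json")
assert loaded.get_value() == {"rows": 10, "status": "ok"}

# A pandas DataFrame, by contrast, would fail direct JSON encoding and be
# written next to the DTO by PandasDataFrameProtocol, recorded in dto.protocol.
```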
@@ -12,6 +12,9 @@ PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
|
|
|
12
12
|
RESULT_PATH_ENV_VAR = "MLRS_RESULT_PATH"
|
|
13
13
|
MIN_INSTANCES_ENV_VAR = "MLRS_MIN_INSTANCES"
|
|
14
14
|
TARGET_INSTANCES_ENV_VAR = "SNOWFLAKE_JOBS_COUNT"
|
|
15
|
+
INSTANCES_MIN_WAIT_ENV_VAR = "MLRS_INSTANCES_MIN_WAIT"
|
|
16
|
+
INSTANCES_TIMEOUT_ENV_VAR = "MLRS_INSTANCES_TIMEOUT"
|
|
17
|
+
INSTANCES_CHECK_INTERVAL_ENV_VAR = "MLRS_INSTANCES_CHECK_INTERVAL"
|
|
15
18
|
RUNTIME_IMAGE_TAG_ENV_VAR = "MLRS_CONTAINER_IMAGE_TAG"
|
|
16
19
|
|
|
17
20
|
# Stage mount paths
|
|
@@ -19,7 +22,7 @@ STAGE_VOLUME_MOUNT_PATH = "/mnt/job_stage"
|
|
|
19
22
|
APP_STAGE_SUBPATH = "app"
|
|
20
23
|
SYSTEM_STAGE_SUBPATH = "system"
|
|
21
24
|
OUTPUT_STAGE_SUBPATH = "output"
|
|
22
|
-
RESULT_PATH_DEFAULT_VALUE = f"{OUTPUT_STAGE_SUBPATH}/mljob_result
|
|
25
|
+
RESULT_PATH_DEFAULT_VALUE = f"{OUTPUT_STAGE_SUBPATH}/mljob_result"
|
|
23
26
|
|
|
24
27
|
# Default container image information
|
|
25
28
|
DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
|