snowflake-ml-python 1.15.0__py3-none-any.whl → 1.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/human_readable_id/adjectives.txt +5 -5
- snowflake/ml/_internal/human_readable_id/animals.txt +3 -3
- snowflake/ml/_internal/platform_capabilities.py +4 -0
- snowflake/ml/_internal/utils/mixins.py +24 -9
- snowflake/ml/experiment/experiment_tracking.py +63 -19
- snowflake/ml/jobs/__init__.py +4 -0
- snowflake/ml/jobs/_interop/__init__.py +0 -0
- snowflake/ml/jobs/_interop/data_utils.py +124 -0
- snowflake/ml/jobs/_interop/dto_schema.py +95 -0
- snowflake/ml/jobs/{_utils/interop_utils.py → _interop/exception_utils.py} +49 -178
- snowflake/ml/jobs/_interop/legacy.py +225 -0
- snowflake/ml/jobs/_interop/protocols.py +471 -0
- snowflake/ml/jobs/_interop/results.py +51 -0
- snowflake/ml/jobs/_interop/utils.py +144 -0
- snowflake/ml/jobs/_utils/constants.py +4 -1
- snowflake/ml/jobs/_utils/feature_flags.py +37 -5
- snowflake/ml/jobs/_utils/payload_utils.py +1 -1
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +139 -102
- snowflake/ml/jobs/_utils/spec_utils.py +50 -11
- snowflake/ml/jobs/_utils/types.py +10 -0
- snowflake/ml/jobs/job.py +168 -36
- snowflake/ml/jobs/manager.py +54 -36
- snowflake/ml/model/__init__.py +16 -2
- snowflake/ml/model/_client/model/batch_inference_specs.py +18 -2
- snowflake/ml/model/_client/model/model_version_impl.py +44 -7
- snowflake/ml/model/_client/ops/model_ops.py +4 -0
- snowflake/ml/model/_client/ops/service_ops.py +50 -5
- snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
- snowflake/ml/model/_client/sql/model_version.py +3 -1
- snowflake/ml/model/_client/sql/stage.py +8 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_method/model_method.py +32 -4
- snowflake/ml/model/_model_composer/model_method/utils.py +28 -0
- snowflake/ml/model/_packager/model_env/model_env.py +48 -21
- snowflake/ml/model/_packager/model_meta/model_meta.py +8 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
- snowflake/ml/model/type_hints.py +13 -0
- snowflake/ml/model/volatility.py +34 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +5 -5
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
- snowflake/ml/modeling/cluster/birch.py +1 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
- snowflake/ml/modeling/cluster/dbscan.py +1 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
- snowflake/ml/modeling/cluster/k_means.py +1 -1
- snowflake/ml/modeling/cluster/mean_shift.py +1 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
- snowflake/ml/modeling/cluster/optics.py +1 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
- snowflake/ml/modeling/compose/column_transformer.py +1 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/pca.py +1 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
- snowflake/ml/modeling/impute/knn_imputer.py +1 -1
- snowflake/ml/modeling/impute/missing_indicator.py +1 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +1 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +1 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ridge.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +1 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
- snowflake/ml/modeling/manifold/tsne.py +1 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
- snowflake/ml/modeling/svm/linear_svc.py +1 -1
- snowflake/ml/modeling/svm/linear_svr.py +1 -1
- snowflake/ml/modeling/svm/nu_svc.py +1 -1
- snowflake/ml/modeling/svm/nu_svr.py +1 -1
- snowflake/ml/modeling/svm/svc.py +1 -1
- snowflake/ml/modeling/svm/svr.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
- snowflake/ml/registry/_manager/model_manager.py +1 -0
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +27 -0
- snowflake/ml/registry/registry.py +15 -0
- snowflake/ml/utils/authentication.py +16 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/METADATA +65 -5
- {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/RECORD +201 -192
- {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.15.0.dist-info → snowflake_ml_python-1.17.0.dist-info}/top_level.txt +0 -0
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
aerial
|
|
1
2
|
afraid
|
|
2
3
|
ancient
|
|
3
4
|
angry
|
|
@@ -26,7 +27,6 @@ dull
|
|
|
26
27
|
empty
|
|
27
28
|
evil
|
|
28
29
|
fast
|
|
29
|
-
fat
|
|
30
30
|
fluffy
|
|
31
31
|
foolish
|
|
32
32
|
fresh
|
|
@@ -57,10 +57,10 @@ lovely
|
|
|
57
57
|
lucky
|
|
58
58
|
massive
|
|
59
59
|
mean
|
|
60
|
+
metallic
|
|
60
61
|
mighty
|
|
61
62
|
modern
|
|
62
63
|
moody
|
|
63
|
-
nasty
|
|
64
64
|
neat
|
|
65
65
|
nervous
|
|
66
66
|
new
|
|
@@ -85,7 +85,6 @@ rotten
|
|
|
85
85
|
rude
|
|
86
86
|
selfish
|
|
87
87
|
serious
|
|
88
|
-
shaggy
|
|
89
88
|
sharp
|
|
90
89
|
short
|
|
91
90
|
shy
|
|
@@ -96,14 +95,15 @@ slippery
|
|
|
96
95
|
smart
|
|
97
96
|
smooth
|
|
98
97
|
soft
|
|
98
|
+
solid
|
|
99
99
|
sour
|
|
100
100
|
spicy
|
|
101
101
|
splendid
|
|
102
102
|
spotty
|
|
103
|
+
squishy
|
|
103
104
|
stale
|
|
104
105
|
strange
|
|
105
106
|
strong
|
|
106
|
-
stupid
|
|
107
107
|
sweet
|
|
108
108
|
swift
|
|
109
109
|
tall
|
|
@@ -116,7 +116,6 @@ tidy
|
|
|
116
116
|
tiny
|
|
117
117
|
tough
|
|
118
118
|
tricky
|
|
119
|
-
ugly
|
|
120
119
|
warm
|
|
121
120
|
weak
|
|
122
121
|
wet
|
|
@@ -124,5 +123,6 @@ wicked
|
|
|
124
123
|
wise
|
|
125
124
|
witty
|
|
126
125
|
wonderful
|
|
126
|
+
wooden
|
|
127
127
|
yellow
|
|
128
128
|
young
|
|
@@ -1,10 +1,9 @@
|
|
|
1
1
|
anaconda
|
|
2
2
|
ant
|
|
3
|
-
ape
|
|
4
|
-
baboon
|
|
5
3
|
badger
|
|
6
4
|
bat
|
|
7
5
|
bear
|
|
6
|
+
beetle
|
|
8
7
|
bird
|
|
9
8
|
bobcat
|
|
10
9
|
bulldog
|
|
@@ -73,7 +72,6 @@ lobster
|
|
|
73
72
|
mayfly
|
|
74
73
|
mamba
|
|
75
74
|
mole
|
|
76
|
-
monkey
|
|
77
75
|
moose
|
|
78
76
|
moth
|
|
79
77
|
mouse
|
|
@@ -114,6 +112,7 @@ swan
|
|
|
114
112
|
termite
|
|
115
113
|
tiger
|
|
116
114
|
treefrog
|
|
115
|
+
tuna
|
|
117
116
|
turkey
|
|
118
117
|
turtle
|
|
119
118
|
vampirebat
|
|
@@ -126,3 +125,4 @@ worm
|
|
|
126
125
|
yak
|
|
127
126
|
yeti
|
|
128
127
|
zebra
|
|
128
|
+
zebrafish
|
|
@@ -17,6 +17,7 @@ logger = logging.getLogger(__name__)
|
|
|
17
17
|
|
|
18
18
|
LIVE_COMMIT_PARAMETER = "ENABLE_LIVE_VERSION_IN_SDK"
|
|
19
19
|
INLINE_DEPLOYMENT_SPEC_PARAMETER = "ENABLE_INLINE_DEPLOYMENT_SPEC_FROM_CLIENT_VERSION"
|
|
20
|
+
SET_MODULE_FUNCTIONS_VOLATILITY_FROM_MANIFEST = "SET_MODULE_FUNCTIONS_VOLATILITY_FROM_MANIFEST"
|
|
20
21
|
|
|
21
22
|
|
|
22
23
|
class PlatformCapabilities:
|
|
@@ -73,6 +74,9 @@ class PlatformCapabilities:
|
|
|
73
74
|
def is_inlined_deployment_spec_enabled(self) -> bool:
|
|
74
75
|
return self._is_version_feature_enabled(INLINE_DEPLOYMENT_SPEC_PARAMETER)
|
|
75
76
|
|
|
77
|
+
def is_set_module_functions_volatility_from_manifest(self) -> bool:
|
|
78
|
+
return self._get_bool_feature(SET_MODULE_FUNCTIONS_VOLATILITY_FROM_MANIFEST, False)
|
|
79
|
+
|
|
76
80
|
def is_live_commit_enabled(self) -> bool:
|
|
77
81
|
return self._get_bool_feature(LIVE_COMMIT_PARAMETER, False)
|
|
78
82
|
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
1
2
|
from typing import Any, Optional
|
|
2
3
|
|
|
3
4
|
from snowflake.ml._internal.utils import identifier
|
|
@@ -16,6 +17,14 @@ def _identifiers_match(saved: Optional[str], current: Optional[str]) -> bool:
|
|
|
16
17
|
return saved_resolved == current_resolved
|
|
17
18
|
|
|
18
19
|
|
|
20
|
+
@dataclass(frozen=True)
|
|
21
|
+
class _SessionState:
|
|
22
|
+
account: Optional[str]
|
|
23
|
+
role: Optional[str]
|
|
24
|
+
database: Optional[str]
|
|
25
|
+
schema: Optional[str]
|
|
26
|
+
|
|
27
|
+
|
|
19
28
|
class SerializableSessionMixin:
|
|
20
29
|
"""Mixin that provides pickling capabilities for objects with Snowpark sessions."""
|
|
21
30
|
|
|
@@ -40,17 +49,23 @@ class SerializableSessionMixin:
|
|
|
40
49
|
|
|
41
50
|
def __setstate__(self, state: dict[str, Any]) -> None:
|
|
42
51
|
"""Restore session from context during unpickling."""
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
52
|
+
session_state = _SessionState(
|
|
53
|
+
account=state.pop(_SESSION_ACCOUNT_KEY, None),
|
|
54
|
+
role=state.pop(_SESSION_ROLE_KEY, None),
|
|
55
|
+
database=state.pop(_SESSION_DATABASE_KEY, None),
|
|
56
|
+
schema=state.pop(_SESSION_SCHEMA_KEY, None),
|
|
57
|
+
)
|
|
47
58
|
|
|
48
59
|
if hasattr(super(), "__setstate__"):
|
|
49
60
|
super().__setstate__(state) # type: ignore[misc]
|
|
50
61
|
else:
|
|
51
62
|
self.__dict__.update(state)
|
|
52
63
|
|
|
53
|
-
|
|
64
|
+
self._set_session(session_state)
|
|
65
|
+
|
|
66
|
+
def _set_session(self, session_state: _SessionState) -> None:
|
|
67
|
+
|
|
68
|
+
if session_state.account is not None:
|
|
54
69
|
active_sessions = snowpark_session._get_active_sessions()
|
|
55
70
|
if len(active_sessions) == 0:
|
|
56
71
|
raise RuntimeError("No active Snowpark session available. Please create a session.")
|
|
@@ -63,10 +78,10 @@ class SerializableSessionMixin:
|
|
|
63
78
|
active_sessions,
|
|
64
79
|
key=lambda s: sum(
|
|
65
80
|
(
|
|
66
|
-
_identifiers_match(
|
|
67
|
-
_identifiers_match(
|
|
68
|
-
_identifiers_match(
|
|
69
|
-
_identifiers_match(
|
|
81
|
+
_identifiers_match(session_state.account, s.get_current_account()),
|
|
82
|
+
_identifiers_match(session_state.role, s.get_current_role()),
|
|
83
|
+
_identifiers_match(session_state.database, s.get_current_database()),
|
|
84
|
+
_identifiers_match(session_state.schema, s.get_current_schema()),
|
|
70
85
|
)
|
|
71
86
|
),
|
|
72
87
|
),
|
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
import functools
|
|
2
2
|
import json
|
|
3
3
|
import sys
|
|
4
|
-
from typing import Any, Optional, Union
|
|
4
|
+
from typing import Any, Callable, Concatenate, Optional, ParamSpec, TypeVar, Union
|
|
5
5
|
from urllib.parse import quote
|
|
6
6
|
|
|
7
|
-
|
|
7
|
+
from snowflake import snowpark
|
|
8
8
|
from snowflake.ml import model as ml_model, registry
|
|
9
9
|
from snowflake.ml._internal.human_readable_id import hrid_generator
|
|
10
10
|
from snowflake.ml._internal.utils import mixins, sql_identifier
|
|
@@ -18,20 +18,40 @@ from snowflake.ml.experiment._client import (
|
|
|
18
18
|
)
|
|
19
19
|
from snowflake.ml.model import type_hints
|
|
20
20
|
from snowflake.ml.utils import sql_client as sql_client_utils
|
|
21
|
-
from snowflake.snowpark import session
|
|
22
21
|
|
|
23
22
|
DEFAULT_EXPERIMENT_NAME = sql_identifier.SqlIdentifier("DEFAULT")
|
|
24
23
|
|
|
24
|
+
P = ParamSpec("P")
|
|
25
|
+
T = TypeVar("T")
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _restore_session(
|
|
29
|
+
func: Callable[Concatenate["ExperimentTracking", P], T],
|
|
30
|
+
) -> Callable[Concatenate["ExperimentTracking", P], T]:
|
|
31
|
+
@functools.wraps(func)
|
|
32
|
+
def wrapper(self: "ExperimentTracking", /, *args: P.args, **kwargs: P.kwargs) -> T:
|
|
33
|
+
if self._session is None:
|
|
34
|
+
if self._session_state is None:
|
|
35
|
+
raise RuntimeError(
|
|
36
|
+
f"Session is not set before calling {func.__name__}, and there is no session state to restore from"
|
|
37
|
+
)
|
|
38
|
+
self._set_session(self._session_state)
|
|
39
|
+
if self._session is None:
|
|
40
|
+
raise RuntimeError(f"Failed to restore session before calling {func.__name__}")
|
|
41
|
+
return func(self, *args, **kwargs)
|
|
42
|
+
|
|
43
|
+
return wrapper
|
|
44
|
+
|
|
25
45
|
|
|
26
46
|
class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
27
47
|
"""
|
|
28
48
|
Class to manage experiments in Snowflake.
|
|
29
49
|
"""
|
|
30
50
|
|
|
31
|
-
@
|
|
51
|
+
@snowpark._internal.utils.private_preview(version="1.9.1")
|
|
32
52
|
def __init__(
|
|
33
53
|
self,
|
|
34
|
-
session:
|
|
54
|
+
session: snowpark.Session,
|
|
35
55
|
*,
|
|
36
56
|
database_name: Optional[str] = None,
|
|
37
57
|
schema_name: Optional[str] = None,
|
|
@@ -73,7 +93,10 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
73
93
|
database_name=self._database_name,
|
|
74
94
|
schema_name=self._schema_name,
|
|
75
95
|
)
|
|
76
|
-
self._session = session
|
|
96
|
+
self._session: Optional[snowpark.Session] = session
|
|
97
|
+
# Used to store information about the session if the session could not be restored during unpickling
|
|
98
|
+
# _session_state is None if and only if _session is not None
|
|
99
|
+
self._session_state: Optional[mixins._SessionState] = None
|
|
77
100
|
|
|
78
101
|
# The experiment in context
|
|
79
102
|
self._experiment: Optional[entities.Experiment] = None
|
|
@@ -87,20 +110,29 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
87
110
|
state["_registry"] = None
|
|
88
111
|
return state
|
|
89
112
|
|
|
90
|
-
def
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
session
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
113
|
+
def _set_session(self, session_state: mixins._SessionState) -> None:
|
|
114
|
+
try:
|
|
115
|
+
super()._set_session(session_state)
|
|
116
|
+
assert self._session is not None
|
|
117
|
+
except (snowpark.exceptions.SnowparkSessionException, AssertionError):
|
|
118
|
+
# If session was not set, store the session state
|
|
119
|
+
self._session = None
|
|
120
|
+
self._session_state = session_state
|
|
121
|
+
else:
|
|
122
|
+
# If session was set, clear the session state, and reinitialize the SQL client and registry
|
|
123
|
+
self._session_state = None
|
|
124
|
+
self._sql_client = sql_client.ExperimentTrackingSQLClient(
|
|
125
|
+
session=self._session,
|
|
126
|
+
database_name=self._database_name,
|
|
127
|
+
schema_name=self._schema_name,
|
|
128
|
+
)
|
|
129
|
+
self._registry = registry.Registry(
|
|
130
|
+
session=self._session,
|
|
131
|
+
database_name=self._database_name,
|
|
132
|
+
schema_name=self._schema_name,
|
|
133
|
+
)
|
|
103
134
|
|
|
135
|
+
@_restore_session
|
|
104
136
|
def set_experiment(
|
|
105
137
|
self,
|
|
106
138
|
experiment_name: str,
|
|
@@ -125,6 +157,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
125
157
|
self._run = None
|
|
126
158
|
return self._experiment
|
|
127
159
|
|
|
160
|
+
@_restore_session
|
|
128
161
|
def delete_experiment(
|
|
129
162
|
self,
|
|
130
163
|
experiment_name: str,
|
|
@@ -141,8 +174,10 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
141
174
|
self._run = None
|
|
142
175
|
|
|
143
176
|
@functools.wraps(registry.Registry.log_model)
|
|
177
|
+
@_restore_session
|
|
144
178
|
def log_model(
|
|
145
179
|
self,
|
|
180
|
+
/, # self needs to be a positional argument to stop mypy from complaining
|
|
146
181
|
model: Union[type_hints.SupportedModelType, ml_model.ModelVersion],
|
|
147
182
|
*,
|
|
148
183
|
model_name: str,
|
|
@@ -152,6 +187,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
152
187
|
with experiment_info.ExperimentInfoPatcher(experiment_info=run._get_experiment_info()):
|
|
153
188
|
return self._registry.log_model(model, model_name=model_name, **kwargs)
|
|
154
189
|
|
|
190
|
+
@_restore_session
|
|
155
191
|
def start_run(
|
|
156
192
|
self,
|
|
157
193
|
run_name: Optional[str] = None,
|
|
@@ -181,6 +217,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
181
217
|
self._run = entities.Run(experiment_tracking=self, experiment_name=experiment.name, run_name=run_name)
|
|
182
218
|
return self._run
|
|
183
219
|
|
|
220
|
+
@_restore_session
|
|
184
221
|
def end_run(self, run_name: Optional[str] = None) -> None:
|
|
185
222
|
"""
|
|
186
223
|
End the current run if no run name is provided. Otherwise, the specified run is ended.
|
|
@@ -210,6 +247,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
210
247
|
self._run = None
|
|
211
248
|
self._print_urls(experiment_name=experiment_name, run_name=run_name)
|
|
212
249
|
|
|
250
|
+
@_restore_session
|
|
213
251
|
def delete_run(
|
|
214
252
|
self,
|
|
215
253
|
run_name: str,
|
|
@@ -248,6 +286,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
248
286
|
"""
|
|
249
287
|
self.log_metrics(metrics={key: value}, step=step)
|
|
250
288
|
|
|
289
|
+
@_restore_session
|
|
251
290
|
def log_metrics(
|
|
252
291
|
self,
|
|
253
292
|
metrics: dict[str, float],
|
|
@@ -284,6 +323,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
284
323
|
"""
|
|
285
324
|
self.log_params({key: value})
|
|
286
325
|
|
|
326
|
+
@_restore_session
|
|
287
327
|
def log_params(
|
|
288
328
|
self,
|
|
289
329
|
params: dict[str, Any],
|
|
@@ -305,6 +345,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
305
345
|
params=json.dumps([param.to_dict() for param in params_list]),
|
|
306
346
|
)
|
|
307
347
|
|
|
348
|
+
@_restore_session
|
|
308
349
|
def log_artifact(
|
|
309
350
|
self,
|
|
310
351
|
local_path: str,
|
|
@@ -328,6 +369,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
328
369
|
file_path=file_path,
|
|
329
370
|
)
|
|
330
371
|
|
|
372
|
+
@_restore_session
|
|
331
373
|
def list_artifacts(
|
|
332
374
|
self,
|
|
333
375
|
run_name: str,
|
|
@@ -356,6 +398,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
356
398
|
artifact_path=artifact_path or "",
|
|
357
399
|
)
|
|
358
400
|
|
|
401
|
+
@_restore_session
|
|
359
402
|
def download_artifacts(
|
|
360
403
|
self,
|
|
361
404
|
run_name: str,
|
|
@@ -397,6 +440,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
397
440
|
return self._run
|
|
398
441
|
return self.start_run()
|
|
399
442
|
|
|
443
|
+
@_restore_session
|
|
400
444
|
def _generate_run_name(self, experiment: entities.Experiment) -> sql_identifier.SqlIdentifier:
|
|
401
445
|
generator = hrid_generator.HRID16()
|
|
402
446
|
existing_runs = self._sql_client.show_runs_in_experiment(experiment_name=experiment.name)
|
snowflake/ml/jobs/__init__.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from snowflake.ml.jobs._interop.exception_utils import install_exception_display_hooks
|
|
1
2
|
from snowflake.ml.jobs._utils.types import JOB_STATUS
|
|
2
3
|
from snowflake.ml.jobs.decorators import remote
|
|
3
4
|
from snowflake.ml.jobs.job import MLJob
|
|
@@ -10,6 +11,9 @@ from snowflake.ml.jobs.manager import (
|
|
|
10
11
|
submit_from_stage,
|
|
11
12
|
)
|
|
12
13
|
|
|
14
|
+
# Initialize exception display hooks for remote job error handling
|
|
15
|
+
install_exception_display_hooks()
|
|
16
|
+
|
|
13
17
|
__all__ = [
|
|
14
18
|
"remote",
|
|
15
19
|
"submit_file",
|
|
File without changes
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
import io
|
|
2
|
+
import json
|
|
3
|
+
from typing import Any, Literal, Optional, Protocol, Union, cast, overload
|
|
4
|
+
|
|
5
|
+
from snowflake import snowpark
|
|
6
|
+
from snowflake.ml.jobs._interop import dto_schema
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class StageFileWriter(io.IOBase):
|
|
10
|
+
"""
|
|
11
|
+
A context manager IOBase implementation that proxies writes to an internal BytesIO
|
|
12
|
+
and uploads to Snowflake stage on close.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
def __init__(self, session: snowpark.Session, path: str) -> None:
|
|
16
|
+
self._session = session
|
|
17
|
+
self._path = path
|
|
18
|
+
self._buffer = io.BytesIO()
|
|
19
|
+
self._closed = False
|
|
20
|
+
self._exception_occurred = False
|
|
21
|
+
|
|
22
|
+
def write(self, data: Union[bytes, bytearray]) -> int:
|
|
23
|
+
"""Write data to the internal buffer."""
|
|
24
|
+
if self._closed:
|
|
25
|
+
raise ValueError("I/O operation on closed file")
|
|
26
|
+
return self._buffer.write(data)
|
|
27
|
+
|
|
28
|
+
def close(self, write_contents: bool = True) -> None:
|
|
29
|
+
"""Close the file and upload the buffer contents to the stage."""
|
|
30
|
+
if not self._closed:
|
|
31
|
+
# Only upload if buffer has content and no exception occurred
|
|
32
|
+
if write_contents and self._buffer.tell() > 0:
|
|
33
|
+
self._buffer.seek(0)
|
|
34
|
+
self._session.file.put_stream(self._buffer, self._path)
|
|
35
|
+
self._buffer.close()
|
|
36
|
+
self._closed = True
|
|
37
|
+
|
|
38
|
+
def __enter__(self) -> "StageFileWriter":
|
|
39
|
+
return self
|
|
40
|
+
|
|
41
|
+
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
|
42
|
+
exception_occurred = exc_type is not None
|
|
43
|
+
self.close(write_contents=not exception_occurred)
|
|
44
|
+
|
|
45
|
+
@property
|
|
46
|
+
def closed(self) -> bool:
|
|
47
|
+
return self._closed
|
|
48
|
+
|
|
49
|
+
def writable(self) -> bool:
|
|
50
|
+
return not self._closed
|
|
51
|
+
|
|
52
|
+
def readable(self) -> bool:
|
|
53
|
+
return False
|
|
54
|
+
|
|
55
|
+
def seekable(self) -> bool:
|
|
56
|
+
return not self._closed
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _is_stage_path(path: str) -> bool:
|
|
60
|
+
return path.startswith("@") or path.startswith("snow://")
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def open_stream(path: str, mode: str = "rb", session: Optional[snowpark.Session] = None) -> io.IOBase:
|
|
64
|
+
if _is_stage_path(path):
|
|
65
|
+
if session is None:
|
|
66
|
+
raise ValueError("Session is required when opening a stage path")
|
|
67
|
+
if "r" in mode:
|
|
68
|
+
stream: io.IOBase = session.file.get_stream(path) # type: ignore[assignment]
|
|
69
|
+
return stream
|
|
70
|
+
elif "w" in mode:
|
|
71
|
+
return StageFileWriter(session, path)
|
|
72
|
+
else:
|
|
73
|
+
raise ValueError(f"Unsupported mode '{mode}' for stage path")
|
|
74
|
+
else:
|
|
75
|
+
result: io.IOBase = open(path, mode) # type: ignore[assignment]
|
|
76
|
+
return result
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class DtoCodec(Protocol):
|
|
80
|
+
@overload
|
|
81
|
+
@staticmethod
|
|
82
|
+
def decode(stream: io.IOBase, as_dict: Literal[True]) -> dict[str, Any]:
|
|
83
|
+
...
|
|
84
|
+
|
|
85
|
+
@overload
|
|
86
|
+
@staticmethod
|
|
87
|
+
def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.ResultDTO:
|
|
88
|
+
...
|
|
89
|
+
|
|
90
|
+
@staticmethod
|
|
91
|
+
def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.ResultDTO, dict[str, Any]]:
|
|
92
|
+
pass
|
|
93
|
+
|
|
94
|
+
@staticmethod
|
|
95
|
+
def encode(dto: dto_schema.ResultDTO) -> bytes:
|
|
96
|
+
pass
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class JsonDtoCodec(DtoCodec):
|
|
100
|
+
@overload
|
|
101
|
+
@staticmethod
|
|
102
|
+
def decode(stream: io.IOBase, as_dict: Literal[True]) -> dict[str, Any]:
|
|
103
|
+
...
|
|
104
|
+
|
|
105
|
+
@overload
|
|
106
|
+
@staticmethod
|
|
107
|
+
def decode(stream: io.IOBase, as_dict: Literal[False] = False) -> dto_schema.ResultDTO:
|
|
108
|
+
...
|
|
109
|
+
|
|
110
|
+
@staticmethod
|
|
111
|
+
def decode(stream: io.IOBase, as_dict: bool = False) -> Union[dto_schema.ResultDTO, dict[str, Any]]:
|
|
112
|
+
data = cast(dict[str, Any], json.load(stream))
|
|
113
|
+
if as_dict:
|
|
114
|
+
return data
|
|
115
|
+
return dto_schema.ResultDTO.model_validate(data)
|
|
116
|
+
|
|
117
|
+
@staticmethod
|
|
118
|
+
def encode(dto: dto_schema.ResultDTO) -> bytes:
|
|
119
|
+
# Temporarily extract the value to avoid accidentally applying model_dump() on it
|
|
120
|
+
result_value = dto.value
|
|
121
|
+
dto.value = None # Clear value to avoid serializing it in the model_dump
|
|
122
|
+
result_dict = dto.model_dump()
|
|
123
|
+
result_dict["value"] = result_value # Put back the value
|
|
124
|
+
return json.dumps(result_dict).encode("utf-8")
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
from typing import Any, Optional, Union
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, model_validator
|
|
4
|
+
from typing_extensions import NotRequired, TypedDict
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class BinaryManifest(TypedDict):
|
|
8
|
+
"""
|
|
9
|
+
Binary data manifest schema.
|
|
10
|
+
Contains one of: path, bytes, or base64 for the serialized data.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
path: NotRequired[str] # Path to file
|
|
14
|
+
bytes: NotRequired[bytes] # In-line byte string (not supported with JSON codec)
|
|
15
|
+
base64: NotRequired[str] # Base64 encoded string
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ParquetManifest(TypedDict):
|
|
19
|
+
"""Protocol manifest schema for parquet files."""
|
|
20
|
+
|
|
21
|
+
paths: list[str] # File paths
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Union type for all manifest types, including catch-all dict[str, Any] for backward compatibility
|
|
25
|
+
PayloadManifest = Union[BinaryManifest, ParquetManifest, dict[str, Any]]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ProtocolInfo(BaseModel):
|
|
29
|
+
"""
|
|
30
|
+
The protocol used to serialize the result and the manifest of the result.
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
name: str
|
|
34
|
+
version: Optional[str] = None
|
|
35
|
+
metadata: Optional[dict[str, str]] = None
|
|
36
|
+
manifest: Optional[PayloadManifest] = None
|
|
37
|
+
|
|
38
|
+
def __str__(self) -> str:
|
|
39
|
+
result = self.name
|
|
40
|
+
if self.version:
|
|
41
|
+
result += f"-{self.version}"
|
|
42
|
+
return result
|
|
43
|
+
|
|
44
|
+
def with_manifest(self, manifest: PayloadManifest) -> "ProtocolInfo":
|
|
45
|
+
"""
|
|
46
|
+
Return a new ProtocolInfo object with the manifest.
|
|
47
|
+
"""
|
|
48
|
+
return ProtocolInfo(
|
|
49
|
+
name=self.name,
|
|
50
|
+
version=self.version,
|
|
51
|
+
metadata=self.metadata,
|
|
52
|
+
manifest=manifest,
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class ResultMetadata(BaseModel):
|
|
57
|
+
"""
|
|
58
|
+
The metadata of a result.
|
|
59
|
+
"""
|
|
60
|
+
|
|
61
|
+
type: str
|
|
62
|
+
repr: str
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class ExceptionMetadata(ResultMetadata):
|
|
66
|
+
message: str
|
|
67
|
+
traceback: str
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class ResultDTO(BaseModel):
|
|
71
|
+
"""
|
|
72
|
+
A JSON representation of an execution result.
|
|
73
|
+
|
|
74
|
+
Args:
|
|
75
|
+
success: Whether the execution was successful.
|
|
76
|
+
value: The value of the execution or the exception if the execution failed.
|
|
77
|
+
protocol: The protocol used to serialize the result.
|
|
78
|
+
metadata: The metadata of the result.
|
|
79
|
+
"""
|
|
80
|
+
|
|
81
|
+
success: bool
|
|
82
|
+
value: Optional[Any] = None
|
|
83
|
+
protocol: Optional[ProtocolInfo] = None
|
|
84
|
+
metadata: Optional[Union[ResultMetadata, ExceptionMetadata]] = None
|
|
85
|
+
serialize_error: Optional[str] = None
|
|
86
|
+
|
|
87
|
+
@model_validator(mode="before")
|
|
88
|
+
@classmethod
|
|
89
|
+
def validate_fields(cls, data: Any) -> Any:
|
|
90
|
+
"""Ensure at least one of value, protocol, or metadata keys is specified."""
|
|
91
|
+
if isinstance(data, dict):
|
|
92
|
+
required_fields = {"value", "protocol", "metadata"}
|
|
93
|
+
if not any(field in data for field in required_fields):
|
|
94
|
+
raise ValueError("At least one of 'value', 'protocol', or 'metadata' must be specified")
|
|
95
|
+
return data
|