snowflake-ml-python 1.14.0__py3-none-any.whl → 1.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/platform_capabilities.py +13 -7
- snowflake/ml/_internal/utils/connection_params.py +5 -3
- snowflake/ml/_internal/utils/jwt_generator.py +3 -2
- snowflake/ml/_internal/utils/mixins.py +24 -9
- snowflake/ml/_internal/utils/temp_file_utils.py +1 -2
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +16 -3
- snowflake/ml/experiment/_entities/__init__.py +2 -1
- snowflake/ml/experiment/_entities/run.py +0 -15
- snowflake/ml/experiment/_entities/run_metadata.py +3 -51
- snowflake/ml/experiment/experiment_tracking.py +71 -27
- snowflake/ml/jobs/_utils/spec_utils.py +49 -11
- snowflake/ml/jobs/manager.py +20 -0
- snowflake/ml/model/__init__.py +12 -2
- snowflake/ml/model/_client/model/batch_inference_specs.py +16 -4
- snowflake/ml/model/_client/model/inference_engine_utils.py +55 -0
- snowflake/ml/model/_client/model/model_version_impl.py +30 -62
- snowflake/ml/model/_client/ops/service_ops.py +68 -7
- snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
- snowflake/ml/model/_client/sql/service.py +29 -2
- snowflake/ml/model/_client/sql/stage.py +8 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_method/model_method.py +25 -2
- snowflake/ml/model/_packager/model_env/model_env.py +26 -16
- snowflake/ml/model/_packager/model_handlers/_utils.py +4 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -5
- snowflake/ml/model/_packager/model_packager.py +4 -3
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
- snowflake/ml/model/_signatures/utils.py +0 -21
- snowflake/ml/model/models/huggingface_pipeline.py +56 -21
- snowflake/ml/model/type_hints.py +13 -0
- snowflake/ml/model/volatility.py +34 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
- snowflake/ml/modeling/cluster/birch.py +1 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
- snowflake/ml/modeling/cluster/dbscan.py +1 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
- snowflake/ml/modeling/cluster/k_means.py +1 -1
- snowflake/ml/modeling/cluster/mean_shift.py +1 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
- snowflake/ml/modeling/cluster/optics.py +1 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
- snowflake/ml/modeling/compose/column_transformer.py +1 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/pca.py +1 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
- snowflake/ml/modeling/impute/knn_imputer.py +1 -1
- snowflake/ml/modeling/impute/missing_indicator.py +1 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +1 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +1 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ridge.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +1 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
- snowflake/ml/modeling/manifold/tsne.py +1 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
- snowflake/ml/modeling/svm/linear_svc.py +1 -1
- snowflake/ml/modeling/svm/linear_svr.py +1 -1
- snowflake/ml/modeling/svm/nu_svc.py +1 -1
- snowflake/ml/modeling/svm/nu_svr.py +1 -1
- snowflake/ml/modeling/svm/svc.py +1 -1
- snowflake/ml/modeling/svm/svr.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
- snowflake/ml/registry/_manager/model_manager.py +2 -1
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +29 -2
- snowflake/ml/registry/registry.py +15 -0
- snowflake/ml/utils/authentication.py +16 -0
- snowflake/ml/utils/connection_params.py +5 -3
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/METADATA +81 -36
- {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/RECORD +193 -191
- {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.16.0.dist-info}/top_level.txt +0 -0
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
import json
|
|
2
|
+
import logging
|
|
2
3
|
from contextlib import contextmanager
|
|
3
4
|
from typing import Any, Optional
|
|
4
5
|
|
|
5
|
-
from absl import logging
|
|
6
6
|
from packaging import version
|
|
7
7
|
|
|
8
8
|
from snowflake.ml import version as snowml_version
|
|
@@ -13,8 +13,11 @@ from snowflake.snowpark import (
|
|
|
13
13
|
session as snowpark_session,
|
|
14
14
|
)
|
|
15
15
|
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
16
18
|
LIVE_COMMIT_PARAMETER = "ENABLE_LIVE_VERSION_IN_SDK"
|
|
17
19
|
INLINE_DEPLOYMENT_SPEC_PARAMETER = "ENABLE_INLINE_DEPLOYMENT_SPEC_FROM_CLIENT_VERSION"
|
|
20
|
+
SET_MODULE_FUNCTIONS_VOLATILITY_FROM_MANIFEST = "SET_MODULE_FUNCTIONS_VOLATILITY_FROM_MANIFEST"
|
|
18
21
|
|
|
19
22
|
|
|
20
23
|
class PlatformCapabilities:
|
|
@@ -60,17 +63,20 @@ class PlatformCapabilities:
|
|
|
60
63
|
@classmethod # type: ignore[arg-type]
|
|
61
64
|
@contextmanager
|
|
62
65
|
def mock_features(cls, features: dict[str, Any] = _dummy_features) -> None: # type: ignore[misc]
|
|
63
|
-
|
|
66
|
+
logger.debug(f"Setting mock features: {features}")
|
|
64
67
|
cls.set_mock_features(features)
|
|
65
68
|
try:
|
|
66
69
|
yield
|
|
67
70
|
finally:
|
|
68
|
-
|
|
71
|
+
logger.debug(f"Clearing mock features: {features}")
|
|
69
72
|
cls.clear_mock_features()
|
|
70
73
|
|
|
71
74
|
def is_inlined_deployment_spec_enabled(self) -> bool:
|
|
72
75
|
return self._is_version_feature_enabled(INLINE_DEPLOYMENT_SPEC_PARAMETER)
|
|
73
76
|
|
|
77
|
+
def is_set_module_functions_volatility_from_manifest(self) -> bool:
|
|
78
|
+
return self._get_bool_feature(SET_MODULE_FUNCTIONS_VOLATILITY_FROM_MANIFEST, False)
|
|
79
|
+
|
|
74
80
|
def is_live_commit_enabled(self) -> bool:
|
|
75
81
|
return self._get_bool_feature(LIVE_COMMIT_PARAMETER, False)
|
|
76
82
|
|
|
@@ -98,7 +104,7 @@ class PlatformCapabilities:
|
|
|
98
104
|
error_code=error_codes.INTERNAL_SNOWML_ERROR, original_exception=RuntimeError(message)
|
|
99
105
|
)
|
|
100
106
|
except snowpark_exceptions.SnowparkSQLException as e:
|
|
101
|
-
|
|
107
|
+
logger.debug(f"Failed to retrieve platform capabilities: {e}")
|
|
102
108
|
# This can happen is server side is older than 9.2. That is fine.
|
|
103
109
|
return {}
|
|
104
110
|
|
|
@@ -144,7 +150,7 @@ class PlatformCapabilities:
|
|
|
144
150
|
|
|
145
151
|
value = self.features.get(feature_name)
|
|
146
152
|
if value is None:
|
|
147
|
-
|
|
153
|
+
logger.debug(f"Feature {feature_name} not found, returning large version number")
|
|
148
154
|
return large_version
|
|
149
155
|
|
|
150
156
|
try:
|
|
@@ -152,7 +158,7 @@ class PlatformCapabilities:
|
|
|
152
158
|
version_str = str(value)
|
|
153
159
|
return version.Version(version_str)
|
|
154
160
|
except (version.InvalidVersion, ValueError, TypeError) as e:
|
|
155
|
-
|
|
161
|
+
logger.debug(
|
|
156
162
|
f"Failed to parse version from feature {feature_name} with value '{value}': {e}. "
|
|
157
163
|
f"Returning large version number"
|
|
158
164
|
)
|
|
@@ -171,7 +177,7 @@ class PlatformCapabilities:
|
|
|
171
177
|
feature_version = self._get_version_feature(feature_name)
|
|
172
178
|
|
|
173
179
|
result = current_version >= feature_version
|
|
174
|
-
|
|
180
|
+
logger.debug(
|
|
175
181
|
f"Version comparison for feature {feature_name}: "
|
|
176
182
|
f"current={current_version}, feature={feature_version}, enabled={result}"
|
|
177
183
|
)
|
|
@@ -1,11 +1,13 @@
|
|
|
1
1
|
import configparser
|
|
2
|
+
import logging
|
|
2
3
|
import os
|
|
3
4
|
from typing import Optional, Union
|
|
4
5
|
|
|
5
|
-
from absl import logging
|
|
6
6
|
from cryptography.hazmat import backends
|
|
7
7
|
from cryptography.hazmat.primitives import serialization
|
|
8
8
|
|
|
9
|
+
logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
9
11
|
_DEFAULT_CONNECTION_FILE = "~/.snowsql/config"
|
|
10
12
|
|
|
11
13
|
|
|
@@ -106,7 +108,7 @@ def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -
|
|
|
106
108
|
"""Loads the dictionary from snowsql config file."""
|
|
107
109
|
snowsql_config_file = login_file if login_file else os.path.expanduser(_DEFAULT_CONNECTION_FILE)
|
|
108
110
|
if not os.path.exists(snowsql_config_file):
|
|
109
|
-
|
|
111
|
+
logger.error(f"Connection name given but snowsql config file is not found at: {snowsql_config_file}")
|
|
110
112
|
raise Exception("Snowflake SnowSQL config not found.")
|
|
111
113
|
|
|
112
114
|
config = configparser.ConfigParser(inline_comment_prefixes="#")
|
|
@@ -122,7 +124,7 @@ def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -
|
|
|
122
124
|
# See https://docs.snowflake.com/en/user-guide/snowsql-start.html#configuring-default-connection-settings
|
|
123
125
|
connection_name = "connections"
|
|
124
126
|
|
|
125
|
-
|
|
127
|
+
logger.info(f"Reading {snowsql_config_file} for connection parameters defined as {connection_name}")
|
|
126
128
|
config.read(snowsql_config_file)
|
|
127
129
|
conn_params = dict(config[connection_name])
|
|
128
130
|
# Remap names to appropriate args in Python Connector API
|
|
@@ -110,15 +110,16 @@ class JWTGenerator:
|
|
|
110
110
|
}
|
|
111
111
|
|
|
112
112
|
# Regenerate the actual token
|
|
113
|
-
token = jwt.encode(payload, key=self.private_key, algorithm=JWTGenerator.ALGORITHM)
|
|
113
|
+
token = jwt.encode(payload, key=self.private_key, algorithm=JWTGenerator.ALGORITHM) # type: ignore[arg-type]
|
|
114
114
|
# If you are using a version of PyJWT prior to 2.0, jwt.encode returns a byte string instead of a string.
|
|
115
115
|
# If the token is a byte string, convert it to a string.
|
|
116
116
|
if isinstance(token, bytes):
|
|
117
117
|
token = token.decode("utf-8")
|
|
118
118
|
self.token = token
|
|
119
|
+
public_key = self.private_key.public_key()
|
|
119
120
|
logger.info(
|
|
120
121
|
"Generated a JWT with the following payload: %s",
|
|
121
|
-
jwt.decode(self.token, key=
|
|
122
|
+
jwt.decode(self.token, key=public_key, algorithms=[JWTGenerator.ALGORITHM]), # type: ignore[arg-type]
|
|
122
123
|
)
|
|
123
124
|
|
|
124
125
|
return token
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
1
2
|
from typing import Any, Optional
|
|
2
3
|
|
|
3
4
|
from snowflake.ml._internal.utils import identifier
|
|
@@ -16,6 +17,14 @@ def _identifiers_match(saved: Optional[str], current: Optional[str]) -> bool:
|
|
|
16
17
|
return saved_resolved == current_resolved
|
|
17
18
|
|
|
18
19
|
|
|
20
|
+
@dataclass(frozen=True)
|
|
21
|
+
class _SessionState:
|
|
22
|
+
account: Optional[str]
|
|
23
|
+
role: Optional[str]
|
|
24
|
+
database: Optional[str]
|
|
25
|
+
schema: Optional[str]
|
|
26
|
+
|
|
27
|
+
|
|
19
28
|
class SerializableSessionMixin:
|
|
20
29
|
"""Mixin that provides pickling capabilities for objects with Snowpark sessions."""
|
|
21
30
|
|
|
@@ -40,17 +49,23 @@ class SerializableSessionMixin:
|
|
|
40
49
|
|
|
41
50
|
def __setstate__(self, state: dict[str, Any]) -> None:
|
|
42
51
|
"""Restore session from context during unpickling."""
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
52
|
+
session_state = _SessionState(
|
|
53
|
+
account=state.pop(_SESSION_ACCOUNT_KEY, None),
|
|
54
|
+
role=state.pop(_SESSION_ROLE_KEY, None),
|
|
55
|
+
database=state.pop(_SESSION_DATABASE_KEY, None),
|
|
56
|
+
schema=state.pop(_SESSION_SCHEMA_KEY, None),
|
|
57
|
+
)
|
|
47
58
|
|
|
48
59
|
if hasattr(super(), "__setstate__"):
|
|
49
60
|
super().__setstate__(state) # type: ignore[misc]
|
|
50
61
|
else:
|
|
51
62
|
self.__dict__.update(state)
|
|
52
63
|
|
|
53
|
-
|
|
64
|
+
self._set_session(session_state)
|
|
65
|
+
|
|
66
|
+
def _set_session(self, session_state: _SessionState) -> None:
|
|
67
|
+
|
|
68
|
+
if session_state.account is not None:
|
|
54
69
|
active_sessions = snowpark_session._get_active_sessions()
|
|
55
70
|
if len(active_sessions) == 0:
|
|
56
71
|
raise RuntimeError("No active Snowpark session available. Please create a session.")
|
|
@@ -63,10 +78,10 @@ class SerializableSessionMixin:
|
|
|
63
78
|
active_sessions,
|
|
64
79
|
key=lambda s: sum(
|
|
65
80
|
(
|
|
66
|
-
_identifiers_match(
|
|
67
|
-
_identifiers_match(
|
|
68
|
-
_identifiers_match(
|
|
69
|
-
_identifiers_match(
|
|
81
|
+
_identifiers_match(session_state.account, s.get_current_account()),
|
|
82
|
+
_identifiers_match(session_state.role, s.get_current_role()),
|
|
83
|
+
_identifiers_match(session_state.database, s.get_current_database()),
|
|
84
|
+
_identifiers_match(session_state.schema, s.get_current_schema()),
|
|
70
85
|
)
|
|
71
86
|
),
|
|
72
87
|
),
|
|
@@ -76,17 +76,30 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
|
|
|
76
76
|
self._session, f"ALTER EXPERIMENT {experiment_fqn} DROP RUN {run_name}"
|
|
77
77
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
|
78
78
|
|
|
79
|
-
def
|
|
79
|
+
def modify_run_add_metrics(
|
|
80
80
|
self,
|
|
81
81
|
*,
|
|
82
82
|
experiment_name: sql_identifier.SqlIdentifier,
|
|
83
83
|
run_name: sql_identifier.SqlIdentifier,
|
|
84
|
-
|
|
84
|
+
metrics: str,
|
|
85
85
|
) -> None:
|
|
86
86
|
experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
|
|
87
87
|
query_result_checker.SqlResultValidator(
|
|
88
88
|
self._session,
|
|
89
|
-
f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name}
|
|
89
|
+
f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} ADD METRICS=$${metrics}$$",
|
|
90
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
|
91
|
+
|
|
92
|
+
def modify_run_add_params(
|
|
93
|
+
self,
|
|
94
|
+
*,
|
|
95
|
+
experiment_name: sql_identifier.SqlIdentifier,
|
|
96
|
+
run_name: sql_identifier.SqlIdentifier,
|
|
97
|
+
params: str,
|
|
98
|
+
) -> None:
|
|
99
|
+
experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
|
|
100
|
+
query_result_checker.SqlResultValidator(
|
|
101
|
+
self._session,
|
|
102
|
+
f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} ADD PARAMETERS=$${params}$$",
|
|
90
103
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
|
91
104
|
|
|
92
105
|
def put_artifact(
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from snowflake.ml.experiment._entities.experiment import Experiment
|
|
2
2
|
from snowflake.ml.experiment._entities.run import Run
|
|
3
|
+
from snowflake.ml.experiment._entities.run_metadata import Metric, Param
|
|
3
4
|
|
|
4
|
-
__all__ = ["Experiment", "Run"]
|
|
5
|
+
__all__ = ["Experiment", "Run", "Metric", "Param"]
|
|
@@ -1,11 +1,8 @@
|
|
|
1
|
-
import json
|
|
2
1
|
import types
|
|
3
2
|
from typing import TYPE_CHECKING, Optional
|
|
4
3
|
|
|
5
4
|
from snowflake.ml._internal.utils import sql_identifier
|
|
6
5
|
from snowflake.ml.experiment import _experiment_info as experiment_info
|
|
7
|
-
from snowflake.ml.experiment._client import experiment_tracking_sql_client
|
|
8
|
-
from snowflake.ml.experiment._entities import run_metadata
|
|
9
6
|
|
|
10
7
|
if TYPE_CHECKING:
|
|
11
8
|
from snowflake.ml.experiment import experiment_tracking
|
|
@@ -41,18 +38,6 @@ class Run:
|
|
|
41
38
|
if self._experiment_tracking._run is self:
|
|
42
39
|
self._experiment_tracking.end_run()
|
|
43
40
|
|
|
44
|
-
def _get_metadata(
|
|
45
|
-
self,
|
|
46
|
-
) -> run_metadata.RunMetadata:
|
|
47
|
-
runs = self._experiment_tracking._sql_client.show_runs_in_experiment(
|
|
48
|
-
experiment_name=self.experiment_name, like=str(self.name)
|
|
49
|
-
)
|
|
50
|
-
if not runs:
|
|
51
|
-
raise RuntimeError(f"Run {self.name} not found in experiment {self.experiment_name}.")
|
|
52
|
-
return run_metadata.RunMetadata.from_dict(
|
|
53
|
-
json.loads(runs[0][experiment_tracking_sql_client.ExperimentTrackingSQLClient.RUN_METADATA_COL_NAME])
|
|
54
|
-
)
|
|
55
|
-
|
|
56
41
|
def _get_experiment_info(self) -> experiment_info.ExperimentInfo:
|
|
57
42
|
return experiment_info.ExperimentInfo(
|
|
58
43
|
fully_qualified_name=self._experiment_tracking._sql_client.fully_qualified_object_name(
|
|
@@ -1,12 +1,4 @@
|
|
|
1
1
|
import dataclasses
|
|
2
|
-
import enum
|
|
3
|
-
import typing
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
class RunStatus(str, enum.Enum):
|
|
7
|
-
UNKNOWN = "UNKNOWN"
|
|
8
|
-
RUNNING = "RUNNING"
|
|
9
|
-
FINISHED = "FINISHED"
|
|
10
2
|
|
|
11
3
|
|
|
12
4
|
@dataclasses.dataclass
|
|
@@ -15,54 +7,14 @@ class Metric:
|
|
|
15
7
|
value: float
|
|
16
8
|
step: int
|
|
17
9
|
|
|
10
|
+
def to_dict(self) -> dict: # type: ignore[type-arg]
|
|
11
|
+
return dataclasses.asdict(self)
|
|
12
|
+
|
|
18
13
|
|
|
19
14
|
@dataclasses.dataclass
|
|
20
15
|
class Param:
|
|
21
16
|
name: str
|
|
22
17
|
value: str
|
|
23
18
|
|
|
24
|
-
|
|
25
|
-
@dataclasses.dataclass
|
|
26
|
-
class RunMetadata:
|
|
27
|
-
status: RunStatus
|
|
28
|
-
metrics: list[Metric]
|
|
29
|
-
parameters: list[Param]
|
|
30
|
-
|
|
31
|
-
@classmethod
|
|
32
|
-
def from_dict(
|
|
33
|
-
cls,
|
|
34
|
-
metadata: dict, # type: ignore[type-arg]
|
|
35
|
-
) -> "RunMetadata":
|
|
36
|
-
return RunMetadata(
|
|
37
|
-
status=RunStatus(metadata.get("status", RunStatus.UNKNOWN.value)),
|
|
38
|
-
metrics=[Metric(**m) for m in metadata.get("metrics", [])],
|
|
39
|
-
parameters=[Param(**p) for p in metadata.get("parameters", [])],
|
|
40
|
-
)
|
|
41
|
-
|
|
42
19
|
def to_dict(self) -> dict: # type: ignore[type-arg]
|
|
43
20
|
return dataclasses.asdict(self)
|
|
44
|
-
|
|
45
|
-
def set_metric(
|
|
46
|
-
self,
|
|
47
|
-
key: str,
|
|
48
|
-
value: float,
|
|
49
|
-
step: int,
|
|
50
|
-
) -> None:
|
|
51
|
-
for metric in self.metrics:
|
|
52
|
-
if metric.name == key and metric.step == step:
|
|
53
|
-
metric.value = value
|
|
54
|
-
break
|
|
55
|
-
else:
|
|
56
|
-
self.metrics.append(Metric(name=key, value=value, step=step))
|
|
57
|
-
|
|
58
|
-
def set_param(
|
|
59
|
-
self,
|
|
60
|
-
key: str,
|
|
61
|
-
value: typing.Any,
|
|
62
|
-
) -> None:
|
|
63
|
-
for parameter in self.parameters:
|
|
64
|
-
if parameter.name == key:
|
|
65
|
-
parameter.value = str(value)
|
|
66
|
-
break
|
|
67
|
-
else:
|
|
68
|
-
self.parameters.append(Param(name=key, value=str(value)))
|
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
import functools
|
|
2
2
|
import json
|
|
3
3
|
import sys
|
|
4
|
-
from typing import Any, Optional, Union
|
|
4
|
+
from typing import Any, Callable, Concatenate, Optional, ParamSpec, TypeVar, Union
|
|
5
5
|
from urllib.parse import quote
|
|
6
6
|
|
|
7
|
-
|
|
7
|
+
from snowflake import snowpark
|
|
8
8
|
from snowflake.ml import model as ml_model, registry
|
|
9
9
|
from snowflake.ml._internal.human_readable_id import hrid_generator
|
|
10
10
|
from snowflake.ml._internal.utils import mixins, sql_identifier
|
|
@@ -18,20 +18,40 @@ from snowflake.ml.experiment._client import (
|
|
|
18
18
|
)
|
|
19
19
|
from snowflake.ml.model import type_hints
|
|
20
20
|
from snowflake.ml.utils import sql_client as sql_client_utils
|
|
21
|
-
from snowflake.snowpark import session
|
|
22
21
|
|
|
23
22
|
DEFAULT_EXPERIMENT_NAME = sql_identifier.SqlIdentifier("DEFAULT")
|
|
24
23
|
|
|
24
|
+
P = ParamSpec("P")
|
|
25
|
+
T = TypeVar("T")
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _restore_session(
|
|
29
|
+
func: Callable[Concatenate["ExperimentTracking", P], T],
|
|
30
|
+
) -> Callable[Concatenate["ExperimentTracking", P], T]:
|
|
31
|
+
@functools.wraps(func)
|
|
32
|
+
def wrapper(self: "ExperimentTracking", /, *args: P.args, **kwargs: P.kwargs) -> T:
|
|
33
|
+
if self._session is None:
|
|
34
|
+
if self._session_state is None:
|
|
35
|
+
raise RuntimeError(
|
|
36
|
+
f"Session is not set before calling {func.__name__}, and there is no session state to restore from"
|
|
37
|
+
)
|
|
38
|
+
self._set_session(self._session_state)
|
|
39
|
+
if self._session is None:
|
|
40
|
+
raise RuntimeError(f"Failed to restore session before calling {func.__name__}")
|
|
41
|
+
return func(self, *args, **kwargs)
|
|
42
|
+
|
|
43
|
+
return wrapper
|
|
44
|
+
|
|
25
45
|
|
|
26
46
|
class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
27
47
|
"""
|
|
28
48
|
Class to manage experiments in Snowflake.
|
|
29
49
|
"""
|
|
30
50
|
|
|
31
|
-
@
|
|
51
|
+
@snowpark._internal.utils.private_preview(version="1.9.1")
|
|
32
52
|
def __init__(
|
|
33
53
|
self,
|
|
34
|
-
session:
|
|
54
|
+
session: snowpark.Session,
|
|
35
55
|
*,
|
|
36
56
|
database_name: Optional[str] = None,
|
|
37
57
|
schema_name: Optional[str] = None,
|
|
@@ -73,7 +93,10 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
73
93
|
database_name=self._database_name,
|
|
74
94
|
schema_name=self._schema_name,
|
|
75
95
|
)
|
|
76
|
-
self._session = session
|
|
96
|
+
self._session: Optional[snowpark.Session] = session
|
|
97
|
+
# Used to store information about the session if the session could not be restored during unpickling
|
|
98
|
+
# _session_state is None if and only if _session is not None
|
|
99
|
+
self._session_state: Optional[mixins._SessionState] = None
|
|
77
100
|
|
|
78
101
|
# The experiment in context
|
|
79
102
|
self._experiment: Optional[entities.Experiment] = None
|
|
@@ -87,20 +110,29 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
87
110
|
state["_registry"] = None
|
|
88
111
|
return state
|
|
89
112
|
|
|
90
|
-
def
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
session
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
113
|
+
def _set_session(self, session_state: mixins._SessionState) -> None:
|
|
114
|
+
try:
|
|
115
|
+
super()._set_session(session_state)
|
|
116
|
+
assert self._session is not None
|
|
117
|
+
except (snowpark.exceptions.SnowparkSessionException, AssertionError):
|
|
118
|
+
# If session was not set, store the session state
|
|
119
|
+
self._session = None
|
|
120
|
+
self._session_state = session_state
|
|
121
|
+
else:
|
|
122
|
+
# If session was set, clear the session state, and reinitialize the SQL client and registry
|
|
123
|
+
self._session_state = None
|
|
124
|
+
self._sql_client = sql_client.ExperimentTrackingSQLClient(
|
|
125
|
+
session=self._session,
|
|
126
|
+
database_name=self._database_name,
|
|
127
|
+
schema_name=self._schema_name,
|
|
128
|
+
)
|
|
129
|
+
self._registry = registry.Registry(
|
|
130
|
+
session=self._session,
|
|
131
|
+
database_name=self._database_name,
|
|
132
|
+
schema_name=self._schema_name,
|
|
133
|
+
)
|
|
103
134
|
|
|
135
|
+
@_restore_session
|
|
104
136
|
def set_experiment(
|
|
105
137
|
self,
|
|
106
138
|
experiment_name: str,
|
|
@@ -125,6 +157,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
125
157
|
self._run = None
|
|
126
158
|
return self._experiment
|
|
127
159
|
|
|
160
|
+
@_restore_session
|
|
128
161
|
def delete_experiment(
|
|
129
162
|
self,
|
|
130
163
|
experiment_name: str,
|
|
@@ -141,8 +174,10 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
141
174
|
self._run = None
|
|
142
175
|
|
|
143
176
|
@functools.wraps(registry.Registry.log_model)
|
|
177
|
+
@_restore_session
|
|
144
178
|
def log_model(
|
|
145
179
|
self,
|
|
180
|
+
/, # self needs to be a positional argument to stop mypy from complaining
|
|
146
181
|
model: Union[type_hints.SupportedModelType, ml_model.ModelVersion],
|
|
147
182
|
*,
|
|
148
183
|
model_name: str,
|
|
@@ -152,6 +187,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
152
187
|
with experiment_info.ExperimentInfoPatcher(experiment_info=run._get_experiment_info()):
|
|
153
188
|
return self._registry.log_model(model, model_name=model_name, **kwargs)
|
|
154
189
|
|
|
190
|
+
@_restore_session
|
|
155
191
|
def start_run(
|
|
156
192
|
self,
|
|
157
193
|
run_name: Optional[str] = None,
|
|
@@ -181,6 +217,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
181
217
|
self._run = entities.Run(experiment_tracking=self, experiment_name=experiment.name, run_name=run_name)
|
|
182
218
|
return self._run
|
|
183
219
|
|
|
220
|
+
@_restore_session
|
|
184
221
|
def end_run(self, run_name: Optional[str] = None) -> None:
|
|
185
222
|
"""
|
|
186
223
|
End the current run if no run name is provided. Otherwise, the specified run is ended.
|
|
@@ -210,6 +247,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
210
247
|
self._run = None
|
|
211
248
|
self._print_urls(experiment_name=experiment_name, run_name=run_name)
|
|
212
249
|
|
|
250
|
+
@_restore_session
|
|
213
251
|
def delete_run(
|
|
214
252
|
self,
|
|
215
253
|
run_name: str,
|
|
@@ -248,6 +286,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
248
286
|
"""
|
|
249
287
|
self.log_metrics(metrics={key: value}, step=step)
|
|
250
288
|
|
|
289
|
+
@_restore_session
|
|
251
290
|
def log_metrics(
|
|
252
291
|
self,
|
|
253
292
|
metrics: dict[str, float],
|
|
@@ -261,13 +300,13 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
261
300
|
step: The step of the metrics. Defaults to 0.
|
|
262
301
|
"""
|
|
263
302
|
run = self._get_or_start_run()
|
|
264
|
-
|
|
303
|
+
metrics_list = []
|
|
265
304
|
for key, value in metrics.items():
|
|
266
|
-
|
|
267
|
-
self._sql_client.
|
|
305
|
+
metrics_list.append(entities.Metric(key, value, step))
|
|
306
|
+
self._sql_client.modify_run_add_metrics(
|
|
268
307
|
experiment_name=run.experiment_name,
|
|
269
308
|
run_name=run.name,
|
|
270
|
-
|
|
309
|
+
metrics=json.dumps([metric.to_dict() for metric in metrics_list]),
|
|
271
310
|
)
|
|
272
311
|
|
|
273
312
|
def log_param(
|
|
@@ -284,6 +323,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
284
323
|
"""
|
|
285
324
|
self.log_params({key: value})
|
|
286
325
|
|
|
326
|
+
@_restore_session
|
|
287
327
|
def log_params(
|
|
288
328
|
self,
|
|
289
329
|
params: dict[str, Any],
|
|
@@ -296,15 +336,16 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
296
336
|
to string.
|
|
297
337
|
"""
|
|
298
338
|
run = self._get_or_start_run()
|
|
299
|
-
|
|
339
|
+
params_list = []
|
|
300
340
|
for key, value in params.items():
|
|
301
|
-
|
|
302
|
-
self._sql_client.
|
|
341
|
+
params_list.append(entities.Param(key, str(value)))
|
|
342
|
+
self._sql_client.modify_run_add_params(
|
|
303
343
|
experiment_name=run.experiment_name,
|
|
304
344
|
run_name=run.name,
|
|
305
|
-
|
|
345
|
+
params=json.dumps([param.to_dict() for param in params_list]),
|
|
306
346
|
)
|
|
307
347
|
|
|
348
|
+
@_restore_session
|
|
308
349
|
def log_artifact(
|
|
309
350
|
self,
|
|
310
351
|
local_path: str,
|
|
@@ -328,6 +369,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
328
369
|
file_path=file_path,
|
|
329
370
|
)
|
|
330
371
|
|
|
372
|
+
@_restore_session
|
|
331
373
|
def list_artifacts(
|
|
332
374
|
self,
|
|
333
375
|
run_name: str,
|
|
@@ -356,6 +398,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
356
398
|
artifact_path=artifact_path or "",
|
|
357
399
|
)
|
|
358
400
|
|
|
401
|
+
@_restore_session
|
|
359
402
|
def download_artifacts(
|
|
360
403
|
self,
|
|
361
404
|
run_name: str,
|
|
@@ -397,6 +440,7 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
|
|
|
397
440
|
return self._run
|
|
398
441
|
return self.start_run()
|
|
399
442
|
|
|
443
|
+
@_restore_session
|
|
400
444
|
def _generate_run_name(self, experiment: entities.Experiment) -> sql_identifier.SqlIdentifier:
|
|
401
445
|
generator = hrid_generator.HRID16()
|
|
402
446
|
existing_runs = self._sql_client.show_runs_in_experiment(experiment_name=experiment.name)
|