snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
- snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
- snowflake/ml/_internal/env_utils.py +31 -52
- snowflake/ml/_internal/file_utils.py +17 -0
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/_internal/utils/query_result_checker.py +8 -5
- snowflake/ml/_internal/utils/snowflake_env.py +95 -0
- snowflake/ml/fileset/parquet_parser.py +31 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/model_impl.py +172 -13
- snowflake/ml/model/_client/model/model_version_impl.py +96 -52
- snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
- snowflake/ml/model/_client/ops/model_ops.py +155 -9
- snowflake/ml/model/_client/sql/model.py +55 -10
- snowflake/ml/model/_client/sql/model_version.py +72 -61
- snowflake/ml/model/_client/sql/stage.py +10 -4
- snowflake/ml/model/_client/sql/tag.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/_signatures/core.py +20 -17
- snowflake/ml/model/custom_model.py +30 -27
- snowflake/ml/model/model_signature.py +16 -17
- snowflake/ml/model/type_hints.py +3 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
- snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
- snowflake/ml/modeling/_internal/model_specifications.py +3 -10
- snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
- snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
- snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
- snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
- snowflake/ml/modeling/cluster/birch.py +51 -16
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
- snowflake/ml/modeling/cluster/dbscan.py +51 -16
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
- snowflake/ml/modeling/cluster/k_means.py +51 -16
- snowflake/ml/modeling/cluster/mean_shift.py +51 -16
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
- snowflake/ml/modeling/cluster/optics.py +51 -16
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
- snowflake/ml/modeling/compose/column_transformer.py +51 -16
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
- snowflake/ml/modeling/covariance/oas.py +51 -16
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/pca.py +51 -16
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
- snowflake/ml/modeling/impute/knn_imputer.py +51 -16
- snowflake/ml/modeling/impute/missing_indicator.py +51 -16
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/lars.py +51 -16
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/perceptron.py +51 -16
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/ridge.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
- snowflake/ml/modeling/manifold/isomap.py +51 -16
- snowflake/ml/modeling/manifold/mds.py +51 -16
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
- snowflake/ml/modeling/manifold/tsne.py +51 -16
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
- snowflake/ml/modeling/svm/linear_svc.py +51 -16
- snowflake/ml/modeling/svm/linear_svr.py +51 -16
- snowflake/ml/modeling/svm/nu_svc.py +51 -16
- snowflake/ml/modeling/svm/nu_svr.py +51 -16
- snowflake/ml/modeling/svm/svc.py +51 -16
- snowflake/ml/modeling/svm/svr.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
- snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
- snowflake/ml/registry/__init__.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +163 -0
- snowflake/ml/registry/model_registry.py +12 -0
- snowflake/ml/registry/registry.py +100 -90
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
- snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
- {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
- snowflake/ml/model/_client/model/model_method_info.py +0 -19
- snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
- /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
- /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
@@ -3,11 +3,12 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
|
3
3
|
|
4
4
|
import pandas as pd
|
5
5
|
|
6
|
+
from snowflake import connector
|
6
7
|
from snowflake.ml._internal import telemetry
|
7
8
|
from snowflake.ml._internal.utils import sql_identifier
|
8
9
|
from snowflake.ml.model import model_signature
|
9
|
-
from snowflake.ml.model._client.model import model_method_info
|
10
10
|
from snowflake.ml.model._client.ops import metadata_ops, model_ops
|
11
|
+
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
11
12
|
from snowflake.snowpark import dataframe
|
12
13
|
|
13
14
|
_TELEMETRY_PROJECT = "MLOps"
|
@@ -49,14 +50,17 @@ class ModelVersion:
|
|
49
50
|
|
50
51
|
@property
|
51
52
|
def model_name(self) -> str:
|
53
|
+
"""Return the name of the model to which the model version belongs, usable as a reference in SQL."""
|
52
54
|
return self._model_name.identifier()
|
53
55
|
|
54
56
|
@property
|
55
57
|
def version_name(self) -> str:
|
58
|
+
"""Return the name of the version to which the model version belongs, usable as a reference in SQL."""
|
56
59
|
return self._version_name.identifier()
|
57
60
|
|
58
61
|
@property
|
59
62
|
def fully_qualified_model_name(self) -> str:
|
63
|
+
"""Return the fully qualified name of the model to which the model version belongs."""
|
60
64
|
return self._model_ops._model_version_client.fully_qualified_model_name(self._model_name)
|
61
65
|
|
62
66
|
@property
|
@@ -65,6 +69,24 @@ class ModelVersion:
|
|
65
69
|
subproject=_TELEMETRY_SUBPROJECT,
|
66
70
|
)
|
67
71
|
def description(self) -> str:
|
72
|
+
"""The description for the model version. This is an alias of `comment`."""
|
73
|
+
return self.comment
|
74
|
+
|
75
|
+
@description.setter
|
76
|
+
@telemetry.send_api_usage_telemetry(
|
77
|
+
project=_TELEMETRY_PROJECT,
|
78
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
79
|
+
)
|
80
|
+
def description(self, description: str) -> None:
|
81
|
+
self.comment = description
|
82
|
+
|
83
|
+
@property
|
84
|
+
@telemetry.send_api_usage_telemetry(
|
85
|
+
project=_TELEMETRY_PROJECT,
|
86
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
87
|
+
)
|
88
|
+
def comment(self) -> str:
|
89
|
+
"""The comment to the model version."""
|
68
90
|
statement_params = telemetry.get_statement_params(
|
69
91
|
project=_TELEMETRY_PROJECT,
|
70
92
|
subproject=_TELEMETRY_SUBPROJECT,
|
@@ -75,18 +97,18 @@ class ModelVersion:
|
|
75
97
|
statement_params=statement_params,
|
76
98
|
)
|
77
99
|
|
78
|
-
@
|
100
|
+
@comment.setter
|
79
101
|
@telemetry.send_api_usage_telemetry(
|
80
102
|
project=_TELEMETRY_PROJECT,
|
81
103
|
subproject=_TELEMETRY_SUBPROJECT,
|
82
104
|
)
|
83
|
-
def
|
105
|
+
def comment(self, comment: str) -> None:
|
84
106
|
statement_params = telemetry.get_statement_params(
|
85
107
|
project=_TELEMETRY_PROJECT,
|
86
108
|
subproject=_TELEMETRY_SUBPROJECT,
|
87
109
|
)
|
88
110
|
return self._model_ops.set_comment(
|
89
|
-
comment=
|
111
|
+
comment=comment,
|
90
112
|
model_name=self._model_name,
|
91
113
|
version_name=self._version_name,
|
92
114
|
statement_params=statement_params,
|
@@ -96,11 +118,11 @@ class ModelVersion:
|
|
96
118
|
project=_TELEMETRY_PROJECT,
|
97
119
|
subproject=_TELEMETRY_SUBPROJECT,
|
98
120
|
)
|
99
|
-
def
|
121
|
+
def show_metrics(self) -> Dict[str, Any]:
|
100
122
|
"""Show all metrics logged with the model version.
|
101
123
|
|
102
124
|
Returns:
|
103
|
-
A dictionary showing the metrics
|
125
|
+
A dictionary showing the metrics.
|
104
126
|
"""
|
105
127
|
statement_params = telemetry.get_statement_params(
|
106
128
|
project=_TELEMETRY_PROJECT,
|
@@ -118,15 +140,15 @@ class ModelVersion:
|
|
118
140
|
"""Get the value of a specific metric.
|
119
141
|
|
120
142
|
Args:
|
121
|
-
metric_name: The name of the metric
|
143
|
+
metric_name: The name of the metric.
|
122
144
|
|
123
145
|
Raises:
|
124
|
-
KeyError:
|
146
|
+
KeyError: When the requested metric name does not exist.
|
125
147
|
|
126
148
|
Returns:
|
127
149
|
The value of the metric.
|
128
150
|
"""
|
129
|
-
metrics = self.
|
151
|
+
metrics = self.show_metrics()
|
130
152
|
if metric_name not in metrics:
|
131
153
|
raise KeyError(f"Cannot find metric with name {metric_name}.")
|
132
154
|
return metrics[metric_name]
|
@@ -136,17 +158,17 @@ class ModelVersion:
|
|
136
158
|
subproject=_TELEMETRY_SUBPROJECT,
|
137
159
|
)
|
138
160
|
def set_metric(self, metric_name: str, value: Any) -> None:
|
139
|
-
"""Set the value of a specific metric
|
161
|
+
"""Set the value of a specific metric.
|
140
162
|
|
141
163
|
Args:
|
142
|
-
metric_name: The name of the metric
|
164
|
+
metric_name: The name of the metric.
|
143
165
|
value: The value of the metric.
|
144
166
|
"""
|
145
167
|
statement_params = telemetry.get_statement_params(
|
146
168
|
project=_TELEMETRY_PROJECT,
|
147
169
|
subproject=_TELEMETRY_SUBPROJECT,
|
148
170
|
)
|
149
|
-
metrics = self.
|
171
|
+
metrics = self.show_metrics()
|
150
172
|
metrics[metric_name] = value
|
151
173
|
self._model_ops._metadata_ops.save(
|
152
174
|
metadata_ops.ModelVersionMetadataSchema(metrics=metrics),
|
@@ -166,13 +188,13 @@ class ModelVersion:
|
|
166
188
|
metric_name: The name of the metric to be deleted.
|
167
189
|
|
168
190
|
Raises:
|
169
|
-
KeyError:
|
191
|
+
KeyError: When the requested metric name does not exist.
|
170
192
|
"""
|
171
193
|
statement_params = telemetry.get_statement_params(
|
172
194
|
project=_TELEMETRY_PROJECT,
|
173
195
|
subproject=_TELEMETRY_SUBPROJECT,
|
174
196
|
)
|
175
|
-
metrics = self.
|
197
|
+
metrics = self.show_metrics()
|
176
198
|
if metric_name not in metrics:
|
177
199
|
raise KeyError(f"Cannot find metric with name {metric_name}.")
|
178
200
|
del metrics[metric_name]
|
@@ -183,24 +205,12 @@ class ModelVersion:
|
|
183
205
|
statement_params=statement_params,
|
184
206
|
)
|
185
207
|
|
186
|
-
|
187
|
-
|
188
|
-
subproject=_TELEMETRY_SUBPROJECT,
|
189
|
-
)
|
190
|
-
def list_methods(self) -> List[model_method_info.ModelMethodInfo]:
|
191
|
-
"""List all method information in a model version that is callable.
|
192
|
-
|
193
|
-
Returns:
|
194
|
-
A list of ModelMethodInfo object containing the following information:
|
195
|
-
- name: The name of the method to be called (both in SQL and in Python SDK).
|
196
|
-
- target_method: The original method name in the logged Python object.
|
197
|
-
- Signature: Python signature of the original method.
|
198
|
-
"""
|
208
|
+
# Only used when the model does not contains user_data with client SDK information.
|
209
|
+
def _legacy_show_functions(self) -> List[model_manifest_schema.ModelFunctionInfo]:
|
199
210
|
statement_params = telemetry.get_statement_params(
|
200
211
|
project=_TELEMETRY_PROJECT,
|
201
212
|
subproject=_TELEMETRY_SUBPROJECT,
|
202
213
|
)
|
203
|
-
# TODO(SNOW-986673, SNOW-986675): Avoid parsing manifest and meta file and put Python signature into user_data.
|
204
214
|
manifest = self._model_ops.get_model_version_manifest(
|
205
215
|
model_name=self._model_name,
|
206
216
|
version_name=self._version_name,
|
@@ -211,7 +221,7 @@ class ModelVersion:
|
|
211
221
|
version_name=self._version_name,
|
212
222
|
statement_params=statement_params,
|
213
223
|
)
|
214
|
-
|
224
|
+
return_functions_info: List[model_manifest_schema.ModelFunctionInfo] = []
|
215
225
|
for method in manifest["methods"]:
|
216
226
|
# Method's name is resolved so we need to use case_sensitive as True to get the user-facing identifier.
|
217
227
|
method_name = sql_identifier.SqlIdentifier(method["name"], case_sensitive=True).identifier()
|
@@ -221,14 +231,48 @@ class ModelVersion:
|
|
221
231
|
), f"Get unexpected handler name {method['handler']}"
|
222
232
|
target_method = method["handler"].split(".")[1]
|
223
233
|
signature_dict = model_meta["signatures"][target_method]
|
224
|
-
|
234
|
+
fi = model_manifest_schema.ModelFunctionInfo(
|
225
235
|
name=method_name,
|
226
236
|
target_method=target_method,
|
227
237
|
signature=model_signature.ModelSignature.from_dict(signature_dict),
|
228
238
|
)
|
229
|
-
|
239
|
+
return_functions_info.append(fi)
|
240
|
+
return return_functions_info
|
241
|
+
|
242
|
+
@telemetry.send_api_usage_telemetry(
|
243
|
+
project=_TELEMETRY_PROJECT,
|
244
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
245
|
+
)
|
246
|
+
def show_functions(self) -> List[model_manifest_schema.ModelFunctionInfo]:
|
247
|
+
"""Show all functions information in a model version that is callable.
|
230
248
|
|
231
|
-
|
249
|
+
Returns:
|
250
|
+
A list of ModelFunctionInfo objects containing the following information:
|
251
|
+
|
252
|
+
- name: The name of the function to be called (both in SQL and in Python SDK).
|
253
|
+
- target_method: The original method name in the logged Python object.
|
254
|
+
- signature: Python signature of the original method.
|
255
|
+
"""
|
256
|
+
statement_params = telemetry.get_statement_params(
|
257
|
+
project=_TELEMETRY_PROJECT,
|
258
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
259
|
+
)
|
260
|
+
try:
|
261
|
+
client_data = self._model_ops.get_client_data_in_user_data(
|
262
|
+
model_name=self._model_name,
|
263
|
+
version_name=self._version_name,
|
264
|
+
statement_params=statement_params,
|
265
|
+
)
|
266
|
+
return [
|
267
|
+
model_manifest_schema.ModelFunctionInfo(
|
268
|
+
name=fi["name"],
|
269
|
+
target_method=fi["target_method"],
|
270
|
+
signature=model_signature.ModelSignature.from_dict(fi["signature"]),
|
271
|
+
)
|
272
|
+
for fi in client_data["functions"]
|
273
|
+
]
|
274
|
+
except (NotImplementedError, ValueError, connector.DataError):
|
275
|
+
return self._legacy_show_functions()
|
232
276
|
|
233
277
|
@telemetry.send_api_usage_telemetry(
|
234
278
|
project=_TELEMETRY_PROJECT,
|
@@ -238,52 +282,52 @@ class ModelVersion:
|
|
238
282
|
self,
|
239
283
|
X: Union[pd.DataFrame, dataframe.DataFrame],
|
240
284
|
*,
|
241
|
-
|
285
|
+
function_name: Optional[str] = None,
|
242
286
|
) -> Union[pd.DataFrame, dataframe.DataFrame]:
|
243
|
-
"""Invoke a method in a model version object
|
287
|
+
"""Invoke a method in a model version object.
|
244
288
|
|
245
289
|
Args:
|
246
|
-
X: The input data
|
247
|
-
|
248
|
-
It can only be None if there is only 1 method.
|
290
|
+
X: The input data, which could be a pandas DataFrame or Snowpark DataFrame.
|
291
|
+
function_name: The function name to run. It is the name used to call a function in SQL.
|
292
|
+
Defaults to None. It can only be None if there is only 1 method.
|
249
293
|
|
250
294
|
Raises:
|
251
|
-
ValueError:
|
252
|
-
ValueError:
|
295
|
+
ValueError: When no method with the corresponding name is available.
|
296
|
+
ValueError: When there are more than 1 target methods available in the model but no function name specified.
|
253
297
|
|
254
298
|
Returns:
|
255
|
-
The prediction data.
|
299
|
+
The prediction data. It would be the same type dataframe as your input.
|
256
300
|
"""
|
257
301
|
statement_params = telemetry.get_statement_params(
|
258
302
|
project=_TELEMETRY_PROJECT,
|
259
303
|
subproject=_TELEMETRY_SUBPROJECT,
|
260
304
|
)
|
261
305
|
|
262
|
-
|
263
|
-
if
|
264
|
-
req_method_name = sql_identifier.SqlIdentifier(
|
265
|
-
find_method: Callable[[
|
306
|
+
functions: List[model_manifest_schema.ModelFunctionInfo] = self.show_functions()
|
307
|
+
if function_name:
|
308
|
+
req_method_name = sql_identifier.SqlIdentifier(function_name).identifier()
|
309
|
+
find_method: Callable[[model_manifest_schema.ModelFunctionInfo], bool] = (
|
266
310
|
lambda method: method["name"] == req_method_name
|
267
311
|
)
|
268
|
-
|
269
|
-
filter(find_method,
|
312
|
+
target_function_info = next(
|
313
|
+
filter(find_method, functions),
|
270
314
|
None,
|
271
315
|
)
|
272
|
-
if
|
316
|
+
if target_function_info is None:
|
273
317
|
raise ValueError(
|
274
|
-
f"There is no method with name {
|
318
|
+
f"There is no method with name {function_name} available in the model"
|
275
319
|
f" {self.fully_qualified_model_name} version {self.version_name}"
|
276
320
|
)
|
277
|
-
elif len(
|
321
|
+
elif len(functions) != 1:
|
278
322
|
raise ValueError(
|
279
323
|
f"There are more than 1 target methods available in the model {self.fully_qualified_model_name}"
|
280
324
|
f" version {self.version_name}. Please specify a `method_name` when calling the `run` method."
|
281
325
|
)
|
282
326
|
else:
|
283
|
-
|
327
|
+
target_function_info = functions[0]
|
284
328
|
return self._model_ops.invoke_method(
|
285
|
-
method_name=sql_identifier.SqlIdentifier(
|
286
|
-
signature=
|
329
|
+
method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
|
330
|
+
signature=target_function_info["signature"],
|
287
331
|
X=X,
|
288
332
|
model_name=self._model_name,
|
289
333
|
version_name=self._version_name,
|
@@ -68,9 +68,7 @@ class MetadataOperator:
|
|
68
68
|
version_info_list = self._model_client.show_versions(
|
69
69
|
model_name=model_name, version_name=version_name, statement_params=statement_params
|
70
70
|
)
|
71
|
-
|
72
|
-
version_info = version_info_list[0]
|
73
|
-
metadata_str = version_info.metadata
|
71
|
+
metadata_str = version_info_list[0][self._model_client.MODEL_VERSION_METADATA_COL_NAME]
|
74
72
|
if not metadata_str:
|
75
73
|
return {}
|
76
74
|
res = json.loads(metadata_str)
|
@@ -1,16 +1,18 @@
|
|
1
|
+
import json
|
1
2
|
import pathlib
|
2
3
|
import tempfile
|
3
4
|
from typing import Any, Dict, List, Optional, Union, cast
|
4
5
|
|
5
6
|
import yaml
|
6
7
|
|
7
|
-
from snowflake.ml._internal.utils import sql_identifier
|
8
|
+
from snowflake.ml._internal.utils import identifier, sql_identifier
|
8
9
|
from snowflake.ml.model import model_signature, type_hints
|
9
10
|
from snowflake.ml.model._client.ops import metadata_ops
|
10
11
|
from snowflake.ml.model._client.sql import (
|
11
12
|
model as model_sql,
|
12
13
|
model_version as model_version_sql,
|
13
14
|
stage as stage_sql,
|
15
|
+
tag as tag_sql,
|
14
16
|
)
|
15
17
|
from snowflake.ml.model._model_composer import model_composer
|
16
18
|
from snowflake.ml.model._model_composer.model_manifest import (
|
@@ -19,7 +21,7 @@ from snowflake.ml.model._model_composer.model_manifest import (
|
|
19
21
|
)
|
20
22
|
from snowflake.ml.model._packager.model_meta import model_meta, model_meta_schema
|
21
23
|
from snowflake.ml.model._signatures import snowpark_handler
|
22
|
-
from snowflake.snowpark import dataframe, session
|
24
|
+
from snowflake.snowpark import dataframe, row, session
|
23
25
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
24
26
|
|
25
27
|
|
@@ -50,6 +52,11 @@ class ModelOperator:
|
|
50
52
|
database_name=database_name,
|
51
53
|
schema_name=schema_name,
|
52
54
|
)
|
55
|
+
self._tag_client = tag_sql.ModuleTagSQLClient(
|
56
|
+
session,
|
57
|
+
database_name=database_name,
|
58
|
+
schema_name=schema_name,
|
59
|
+
)
|
53
60
|
self._metadata_ops = metadata_ops.MetadataOperator(
|
54
61
|
session,
|
55
62
|
database_name=database_name,
|
@@ -109,22 +116,39 @@ class ModelOperator:
|
|
109
116
|
statement_params=statement_params,
|
110
117
|
)
|
111
118
|
|
112
|
-
def
|
119
|
+
def show_models_or_versions(
|
113
120
|
self,
|
114
121
|
*,
|
115
122
|
model_name: Optional[sql_identifier.SqlIdentifier] = None,
|
116
123
|
statement_params: Optional[Dict[str, Any]] = None,
|
117
|
-
) -> List[
|
124
|
+
) -> List[row.Row]:
|
118
125
|
if model_name:
|
119
|
-
|
126
|
+
return self._model_client.show_versions(
|
120
127
|
model_name=model_name,
|
128
|
+
validate_result=False,
|
121
129
|
statement_params=statement_params,
|
122
130
|
)
|
123
131
|
else:
|
124
|
-
|
132
|
+
return self._model_client.show_models(
|
133
|
+
validate_result=False,
|
125
134
|
statement_params=statement_params,
|
126
135
|
)
|
127
|
-
|
136
|
+
|
137
|
+
def list_models_or_versions(
|
138
|
+
self,
|
139
|
+
*,
|
140
|
+
model_name: Optional[sql_identifier.SqlIdentifier] = None,
|
141
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
142
|
+
) -> List[sql_identifier.SqlIdentifier]:
|
143
|
+
res = self.show_models_or_versions(
|
144
|
+
model_name=model_name,
|
145
|
+
statement_params=statement_params,
|
146
|
+
)
|
147
|
+
if model_name:
|
148
|
+
col_name = self._model_client.MODEL_VERSION_NAME_COL_NAME
|
149
|
+
else:
|
150
|
+
col_name = self._model_client.MODEL_NAME_COL_NAME
|
151
|
+
return [sql_identifier.SqlIdentifier(row[col_name], case_sensitive=True) for row in res]
|
128
152
|
|
129
153
|
def validate_existence(
|
130
154
|
self,
|
@@ -137,11 +161,13 @@ class ModelOperator:
|
|
137
161
|
res = self._model_client.show_versions(
|
138
162
|
model_name=model_name,
|
139
163
|
version_name=version_name,
|
164
|
+
validate_result=False,
|
140
165
|
statement_params=statement_params,
|
141
166
|
)
|
142
167
|
else:
|
143
168
|
res = self._model_client.show_models(
|
144
169
|
model_name=model_name,
|
170
|
+
validate_result=False,
|
145
171
|
statement_params=statement_params,
|
146
172
|
)
|
147
173
|
return len(res) == 1
|
@@ -159,13 +185,14 @@ class ModelOperator:
|
|
159
185
|
version_name=version_name,
|
160
186
|
statement_params=statement_params,
|
161
187
|
)
|
188
|
+
col_name = self._model_client.MODEL_VERSION_COMMENT_COL_NAME
|
162
189
|
else:
|
163
190
|
res = self._model_client.show_models(
|
164
191
|
model_name=model_name,
|
165
192
|
statement_params=statement_params,
|
166
193
|
)
|
167
|
-
|
168
|
-
return cast(str, res[0]
|
194
|
+
col_name = self._model_client.MODEL_COMMENT_COL_NAME
|
195
|
+
return cast(str, res[0][col_name])
|
169
196
|
|
170
197
|
def set_comment(
|
171
198
|
self,
|
@@ -189,6 +216,109 @@ class ModelOperator:
|
|
189
216
|
statement_params=statement_params,
|
190
217
|
)
|
191
218
|
|
219
|
+
def set_default_version(
    self,
    *,
    model_name: sql_identifier.SqlIdentifier,
    version_name: sql_identifier.SqlIdentifier,
    statement_params: Optional[Dict[str, Any]] = None,
) -> None:
    """Mark ``version_name`` as the default version of ``model_name``.

    Existence is checked first via ``validate_existence``; the change itself is
    delegated to the model-version SQL client.

    Raises:
        ValueError: If the given version does not exist for the model.
    """
    version_exists = self.validate_existence(
        model_name=model_name,
        version_name=version_name,
        statement_params=statement_params,
    )
    if not version_exists:
        raise ValueError(f"You cannot set version {version_name} as default version as it does not exist.")
    self._model_version_client.set_default_version(
        model_name=model_name,
        version_name=version_name,
        statement_params=statement_params,
    )
|
233
|
+
|
234
|
+
def get_default_version(
    self,
    *,
    model_name: sql_identifier.SqlIdentifier,
    statement_params: Optional[Dict[str, Any]] = None,
) -> sql_identifier.SqlIdentifier:
    """Return the default version name of ``model_name`` as a case-sensitive identifier.

    Only the first row returned by ``show_models`` for this model name is read.
    """
    model_rows = self._model_client.show_models(model_name=model_name, statement_params=statement_params)
    first_row = model_rows[0]
    default_version_name = first_row[self._model_client.MODEL_DEFAULT_VERSION_NAME_COL_NAME]
    return sql_identifier.SqlIdentifier(default_version_name, case_sensitive=True)
|
244
|
+
|
245
|
+
def get_tag_value(
    self,
    *,
    model_name: sql_identifier.SqlIdentifier,
    tag_database_name: sql_identifier.SqlIdentifier,
    tag_schema_name: sql_identifier.SqlIdentifier,
    tag_name: sql_identifier.SqlIdentifier,
    statement_params: Optional[Dict[str, Any]] = None,
) -> Optional[str]:
    """Fetch the value of one tag attached to a model.

    Returns:
        The row's ``TAG_VALUE`` converted with ``str``, or ``None`` when it is ``None``.
    """
    tag_row = self._tag_client.get_tag_value(
        module_name=model_name,  # the tag client spells this keyword ``module_name``
        tag_database_name=tag_database_name,
        tag_schema_name=tag_schema_name,
        tag_name=tag_name,
        statement_params=statement_params,
    )
    raw_value = tag_row.TAG_VALUE
    return None if raw_value is None else str(raw_value)
|
265
|
+
|
266
|
+
def show_tags(
    self,
    *,
    model_name: sql_identifier.SqlIdentifier,
    statement_params: Optional[Dict[str, Any]] = None,
) -> Dict[str, str]:
    """Return all tags on a model keyed by their schema-level tag identifier.

    Each key is built by ``identifier.get_schema_level_object_identifier`` from the
    row's TAG_DATABASE / TAG_SCHEMA / TAG_NAME (all treated as case-sensitive), and each
    value is ``str(TAG_VALUE)``. If two rows produce the same key, the later row wins.
    """
    tag_rows = self._tag_client.get_tag_list(
        module_name=model_name,
        statement_params=statement_params,
    )
    tags: Dict[str, str] = {}
    for tag_row in tag_rows:
        qualified_tag = identifier.get_schema_level_object_identifier(
            sql_identifier.SqlIdentifier(tag_row.TAG_DATABASE, case_sensitive=True),
            sql_identifier.SqlIdentifier(tag_row.TAG_SCHEMA, case_sensitive=True),
            sql_identifier.SqlIdentifier(tag_row.TAG_NAME, case_sensitive=True),
        )
        tags[qualified_tag] = str(tag_row.TAG_VALUE)
    return tags
|
285
|
+
|
286
|
+
def set_tag(
    self,
    *,
    model_name: sql_identifier.SqlIdentifier,
    tag_database_name: sql_identifier.SqlIdentifier,
    tag_schema_name: sql_identifier.SqlIdentifier,
    tag_name: sql_identifier.SqlIdentifier,
    tag_value: str,
    statement_params: Optional[Dict[str, Any]] = None,
) -> None:
    """Set the tag ``tag_database_name.tag_schema_name.tag_name`` to ``tag_value`` on a model.

    Thin delegation to the tag SQL client's ``set_tag_on_model``.
    """
    tag_client = self._tag_client
    tag_client.set_tag_on_model(
        model_name=model_name,
        tag_database_name=tag_database_name,
        tag_schema_name=tag_schema_name,
        tag_name=tag_name,
        tag_value=tag_value,
        statement_params=statement_params,
    )
|
304
|
+
|
305
|
+
def unset_tag(
    self,
    *,
    model_name: sql_identifier.SqlIdentifier,
    tag_database_name: sql_identifier.SqlIdentifier,
    tag_schema_name: sql_identifier.SqlIdentifier,
    tag_name: sql_identifier.SqlIdentifier,
    statement_params: Optional[Dict[str, Any]] = None,
) -> None:
    """Remove the tag ``tag_database_name.tag_schema_name.tag_name`` from a model.

    Thin delegation to the tag SQL client's ``unset_tag_on_model``.
    """
    tag_client = self._tag_client
    tag_client.unset_tag_on_model(
        model_name=model_name,
        tag_database_name=tag_database_name,
        tag_schema_name=tag_schema_name,
        tag_name=tag_name,
        statement_params=statement_params,
    )
|
321
|
+
|
192
322
|
def get_model_version_manifest(
|
193
323
|
self,
|
194
324
|
*,
|
@@ -228,6 +358,22 @@ class ModelOperator:
|
|
228
358
|
raw_model_meta = yaml.safe_load(f)
|
229
359
|
return model_meta.ModelMetadata._validate_model_metadata(raw_model_meta)
|
230
360
|
|
361
|
+
def get_client_data_in_user_data(
    self,
    *,
    model_name: sql_identifier.SqlIdentifier,
    version_name: sql_identifier.SqlIdentifier,
    statement_params: Optional[Dict[str, Any]] = None,
) -> model_manifest_schema.SnowparkMLDataDict:
    """Decode the client data stored in a model version's user-data column.

    The first row returned by ``show_versions`` is read, its user-data column is
    parsed as JSON, and the resulting dict is handed to
    ``ModelManifest.parse_client_data_from_user_data``.

    Raises:
        AssertionError: If the decoded JSON is not a dict (only when asserts are enabled).
    """
    version_rows = self._model_client.show_versions(
        model_name=model_name,
        version_name=version_name,
        statement_params=statement_params,
    )
    user_data_column = self._model_client.MODEL_VERSION_USER_DATA_COL_NAME
    raw_user_data = json.loads(version_rows[0][user_data_column])
    assert isinstance(raw_user_data, dict), "user data should be a dictionary"
    return model_manifest.ModelManifest.parse_client_data_from_user_data(raw_user_data)
|
376
|
+
|
231
377
|
def invoke_method(
|
232
378
|
self,
|
233
379
|
*,
|
@@ -1,10 +1,23 @@
|
|
1
1
|
from typing import Any, Dict, List, Optional
|
2
2
|
|
3
|
-
from snowflake.ml._internal.utils import
|
3
|
+
from snowflake.ml._internal.utils import (
|
4
|
+
identifier,
|
5
|
+
query_result_checker,
|
6
|
+
sql_identifier,
|
7
|
+
)
|
4
8
|
from snowflake.snowpark import row, session
|
5
9
|
|
6
10
|
|
7
11
|
class ModelSQLClient:
|
12
|
+
MODEL_NAME_COL_NAME = "name"
|
13
|
+
MODEL_COMMENT_COL_NAME = "comment"
|
14
|
+
MODEL_DEFAULT_VERSION_NAME_COL_NAME = "default_version_name"
|
15
|
+
|
16
|
+
MODEL_VERSION_NAME_COL_NAME = "name"
|
17
|
+
MODEL_VERSION_COMMENT_COL_NAME = "comment"
|
18
|
+
MODEL_VERSION_METADATA_COL_NAME = "metadata"
|
19
|
+
MODEL_VERSION_USER_DATA_COL_NAME = "user_data"
|
20
|
+
|
8
21
|
def __init__(
|
9
22
|
self,
|
10
23
|
session: session.Session,
|
@@ -30,29 +43,56 @@ class ModelSQLClient:
|
|
30
43
|
self,
|
31
44
|
*,
|
32
45
|
model_name: Optional[sql_identifier.SqlIdentifier] = None,
|
46
|
+
validate_result: bool = True,
|
33
47
|
statement_params: Optional[Dict[str, Any]] = None,
|
34
48
|
) -> List[row.Row]:
|
35
49
|
fully_qualified_schema_name = ".".join([self._database_name.identifier(), self._schema_name.identifier()])
|
36
50
|
like_sql = ""
|
37
51
|
if model_name:
|
38
52
|
like_sql = f" LIKE '{model_name.resolved()}'"
|
39
|
-
res = self._session.sql(f"SHOW MODELS{like_sql} IN SCHEMA {fully_qualified_schema_name}")
|
40
53
|
|
41
|
-
|
54
|
+
res = (
|
55
|
+
query_result_checker.SqlResultValidator(
|
56
|
+
self._session,
|
57
|
+
f"SHOW MODELS{like_sql} IN SCHEMA {fully_qualified_schema_name}",
|
58
|
+
statement_params=statement_params,
|
59
|
+
)
|
60
|
+
.has_column(ModelSQLClient.MODEL_NAME_COL_NAME, allow_empty=True)
|
61
|
+
.has_column(ModelSQLClient.MODEL_COMMENT_COL_NAME, allow_empty=True)
|
62
|
+
.has_column(ModelSQLClient.MODEL_DEFAULT_VERSION_NAME_COL_NAME, allow_empty=True)
|
63
|
+
)
|
64
|
+
if validate_result and model_name:
|
65
|
+
res = res.has_dimensions(expected_rows=1)
|
66
|
+
|
67
|
+
return res.validate()
|
42
68
|
|
43
69
|
def show_versions(
    self,
    *,
    model_name: sql_identifier.SqlIdentifier,
    version_name: Optional[sql_identifier.SqlIdentifier] = None,
    validate_result: bool = True,
    statement_params: Optional[Dict[str, Any]] = None,
) -> List[row.Row]:
    """Run ``SHOW VERSIONS`` on a model and return the validated result rows.

    Args:
        model_name: Model whose versions are listed.
        version_name: When given, a ``LIKE '<resolved name>'`` filter is added.
        validate_result: When True and ``version_name`` is given, exactly one row is required.
        statement_params: Forwarded to the result validator.

    Returns:
        The rows produced by ``SqlResultValidator.validate()``, after the name, comment,
        metadata and user-data columns have been registered with ``has_column``.
    """
    # NOTE(review): the resolved name is placed into a LIKE pattern unescaped, so `_` and `%`
    # in a version name behave as wildcards and a `'` would break the statement — confirm
    # whether identifiers reaching here can contain those characters.
    like_clause = f" LIKE '{version_name.resolved()}'" if version_name else ""
    query = f"SHOW VERSIONS{like_clause} IN MODEL {self.fully_qualified_model_name(model_name)}"
    validator = query_result_checker.SqlResultValidator(
        self._session,
        query,
        statement_params=statement_params,
    )
    for column in (
        ModelSQLClient.MODEL_VERSION_NAME_COL_NAME,
        ModelSQLClient.MODEL_VERSION_COMMENT_COL_NAME,
        ModelSQLClient.MODEL_VERSION_METADATA_COL_NAME,
        ModelSQLClient.MODEL_VERSION_USER_DATA_COL_NAME,
    ):
        validator = validator.has_column(column, allow_empty=True)
    if validate_result and version_name:
        validator = validator.has_dimensions(expected_rows=1)
    return validator.validate()
|
56
96
|
|
57
97
|
def set_comment(
|
58
98
|
self,
|
@@ -61,8 +101,11 @@ class ModelSQLClient:
|
|
61
101
|
model_name: sql_identifier.SqlIdentifier,
|
62
102
|
statement_params: Optional[Dict[str, Any]] = None,
|
63
103
|
) -> None:
|
64
|
-
|
65
|
-
|
104
|
+
query_result_checker.SqlResultValidator(
|
105
|
+
self._session,
|
106
|
+
f"COMMENT ON MODEL {self.fully_qualified_model_name(model_name)} IS $${comment}$$",
|
107
|
+
statement_params=statement_params,
|
108
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
66
109
|
|
67
110
|
def drop_model(
|
68
111
|
self,
|
@@ -70,6 +113,8 @@ class ModelSQLClient:
|
|
70
113
|
model_name: sql_identifier.SqlIdentifier,
|
71
114
|
statement_params: Optional[Dict[str, Any]] = None,
|
72
115
|
) -> None:
|
73
|
-
|
74
|
-
|
75
|
-
|
116
|
+
query_result_checker.SqlResultValidator(
|
117
|
+
self._session,
|
118
|
+
f"DROP MODEL {self.fully_qualified_model_name(model_name)}",
|
119
|
+
statement_params=statement_params,
|
120
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|