snowflake-ml-python 1.5.0__py3-none-any.whl → 1.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_sentiment.py +7 -4
- snowflake/ml/_internal/env_utils.py +6 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
- snowflake/ml/_internal/telemetry.py +1 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/sql_identifier.py +14 -1
- snowflake/ml/_internal/utils/temp_file_utils.py +5 -2
- snowflake/ml/dataset/__init__.py +2 -1
- snowflake/ml/dataset/dataset.py +4 -3
- snowflake/ml/dataset/dataset_reader.py +5 -8
- snowflake/ml/feature_store/__init__.py +6 -0
- snowflake/ml/feature_store/access_manager.py +283 -0
- snowflake/ml/feature_store/feature_store.py +160 -100
- snowflake/ml/feature_store/feature_view.py +30 -19
- snowflake/ml/fileset/embedded_stage_fs.py +15 -12
- snowflake/ml/fileset/snowfs.py +2 -30
- snowflake/ml/fileset/stage_fs.py +25 -7
- snowflake/ml/model/_client/model/model_impl.py +46 -39
- snowflake/ml/model/_client/model/model_version_impl.py +24 -2
- snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
- snowflake/ml/model/_client/ops/model_ops.py +174 -16
- snowflake/ml/model/_client/sql/_base.py +34 -0
- snowflake/ml/model/_client/sql/model.py +32 -39
- snowflake/ml/model/_client/sql/model_version.py +111 -42
- snowflake/ml/model/_client/sql/stage.py +6 -32
- snowflake/ml/model/_client/sql/tag.py +32 -56
- snowflake/ml/model/_model_composer/model_composer.py +8 -4
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -3
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -27
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +90 -142
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +159 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +81 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +8 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +8 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +8 -1
- snowflake/ml/modeling/cluster/birch.py +8 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +8 -1
- snowflake/ml/modeling/cluster/dbscan.py +8 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +8 -1
- snowflake/ml/modeling/cluster/k_means.py +8 -1
- snowflake/ml/modeling/cluster/mean_shift.py +8 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +8 -1
- snowflake/ml/modeling/cluster/optics.py +8 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +8 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +8 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +8 -1
- snowflake/ml/modeling/compose/column_transformer.py +8 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +8 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +8 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +8 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +8 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +8 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +8 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +8 -1
- snowflake/ml/modeling/covariance/oas.py +8 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +8 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +8 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +8 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +8 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +8 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +8 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +8 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +8 -1
- snowflake/ml/modeling/decomposition/pca.py +8 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +8 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +8 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +8 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +8 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +8 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +8 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +8 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +8 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +8 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +8 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +8 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +8 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +8 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +8 -1
- snowflake/ml/modeling/framework/base.py +4 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +8 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +8 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +8 -1
- snowflake/ml/modeling/impute/knn_imputer.py +8 -1
- snowflake/ml/modeling/impute/missing_indicator.py +8 -1
- snowflake/ml/modeling/impute/simple_imputer.py +21 -2
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +8 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +8 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +8 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +8 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +8 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +8 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +8 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +8 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +8 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +8 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +8 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/lars.py +8 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +8 -1
- snowflake/ml/modeling/linear_model/lasso.py +8 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +8 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +8 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +8 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +8 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +8 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +8 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +8 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +8 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +8 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +8 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +8 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +8 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +8 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/perceptron.py +8 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/ridge.py +8 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +8 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +8 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +8 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +8 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +8 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +8 -1
- snowflake/ml/modeling/manifold/isomap.py +8 -1
- snowflake/ml/modeling/manifold/mds.py +8 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +8 -1
- snowflake/ml/modeling/manifold/tsne.py +8 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +8 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +8 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +8 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +8 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +8 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +8 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +8 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +8 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +8 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +8 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +8 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +8 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +8 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +8 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +8 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +8 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +8 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +8 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +8 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +8 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +8 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +8 -1
- snowflake/ml/modeling/parameters/enable_anonymous_sproc.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +27 -7
- snowflake/ml/modeling/preprocessing/polynomial_features.py +8 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +8 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +8 -1
- snowflake/ml/modeling/svm/linear_svc.py +8 -1
- snowflake/ml/modeling/svm/linear_svr.py +8 -1
- snowflake/ml/modeling/svm/nu_svc.py +8 -1
- snowflake/ml/modeling/svm/nu_svr.py +8 -1
- snowflake/ml/modeling/svm/svc.py +8 -1
- snowflake/ml/modeling/svm/svr.py +8 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +8 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +8 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +8 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +8 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +8 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +8 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +8 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +8 -1
- snowflake/ml/registry/_manager/model_manager.py +95 -8
- snowflake/ml/registry/registry.py +10 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/METADATA +66 -10
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/RECORD +196 -192
- snowflake/ml/_internal/lineage/dataset_dataframe.py +0 -44
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/top_level.txt +0 -0
@@ -74,37 +74,57 @@ class ModelOperator:
|
|
74
74
|
and self._model_version_client == __value._model_version_client
|
75
75
|
)
|
76
76
|
|
77
|
-
def prepare_model_stage_path(
|
77
|
+
def prepare_model_stage_path(
|
78
|
+
self,
|
79
|
+
*,
|
80
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
81
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
82
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
83
|
+
) -> str:
|
78
84
|
stage_name = sql_identifier.SqlIdentifier(
|
79
85
|
snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
|
80
86
|
)
|
81
|
-
self._stage_client.create_tmp_stage(
|
82
|
-
|
87
|
+
self._stage_client.create_tmp_stage(
|
88
|
+
database_name=database_name,
|
89
|
+
schema_name=schema_name,
|
90
|
+
stage_name=stage_name,
|
91
|
+
statement_params=statement_params,
|
92
|
+
)
|
93
|
+
return f"@{self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)}/model"
|
83
94
|
|
84
95
|
def create_from_stage(
|
85
96
|
self,
|
86
97
|
composed_model: model_composer.ModelComposer,
|
87
98
|
*,
|
99
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
100
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
88
101
|
model_name: sql_identifier.SqlIdentifier,
|
89
102
|
version_name: sql_identifier.SqlIdentifier,
|
90
103
|
statement_params: Optional[Dict[str, Any]] = None,
|
91
104
|
) -> None:
|
92
105
|
stage_path = str(composed_model.stage_path)
|
93
106
|
if self.validate_existence(
|
107
|
+
database_name=database_name,
|
108
|
+
schema_name=schema_name,
|
94
109
|
model_name=model_name,
|
95
110
|
statement_params=statement_params,
|
96
111
|
):
|
97
112
|
if self.validate_existence(
|
113
|
+
database_name=database_name,
|
114
|
+
schema_name=schema_name,
|
98
115
|
model_name=model_name,
|
99
116
|
version_name=version_name,
|
100
117
|
statement_params=statement_params,
|
101
118
|
):
|
102
119
|
raise ValueError(
|
103
|
-
|
104
|
-
f"
|
120
|
+
"Model "
|
121
|
+
f"{self._model_version_client.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
122
|
+
f" version {version_name} already existed."
|
105
123
|
)
|
106
124
|
else:
|
107
125
|
self._model_version_client.add_version_from_stage(
|
126
|
+
database_name=database_name,
|
127
|
+
schema_name=schema_name,
|
108
128
|
stage_path=stage_path,
|
109
129
|
model_name=model_name,
|
110
130
|
version_name=version_name,
|
@@ -112,26 +132,77 @@ class ModelOperator:
|
|
112
132
|
)
|
113
133
|
else:
|
114
134
|
self._model_version_client.create_from_stage(
|
135
|
+
database_name=database_name,
|
136
|
+
schema_name=schema_name,
|
115
137
|
stage_path=stage_path,
|
116
138
|
model_name=model_name,
|
117
139
|
version_name=version_name,
|
118
140
|
statement_params=statement_params,
|
119
141
|
)
|
120
142
|
|
143
|
+
def create_from_model_version(
|
144
|
+
self,
|
145
|
+
*,
|
146
|
+
source_database_name: Optional[sql_identifier.SqlIdentifier],
|
147
|
+
source_schema_name: Optional[sql_identifier.SqlIdentifier],
|
148
|
+
source_model_name: sql_identifier.SqlIdentifier,
|
149
|
+
source_version_name: sql_identifier.SqlIdentifier,
|
150
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
151
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
152
|
+
model_name: sql_identifier.SqlIdentifier,
|
153
|
+
version_name: sql_identifier.SqlIdentifier,
|
154
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
155
|
+
) -> None:
|
156
|
+
if self.validate_existence(
|
157
|
+
database_name=database_name,
|
158
|
+
schema_name=schema_name,
|
159
|
+
model_name=model_name,
|
160
|
+
statement_params=statement_params,
|
161
|
+
):
|
162
|
+
return self._model_version_client.add_version_from_model_version(
|
163
|
+
source_database_name=source_database_name,
|
164
|
+
source_schema_name=source_schema_name,
|
165
|
+
source_model_name=source_model_name,
|
166
|
+
source_version_name=source_version_name,
|
167
|
+
database_name=database_name,
|
168
|
+
schema_name=schema_name,
|
169
|
+
model_name=model_name,
|
170
|
+
version_name=version_name,
|
171
|
+
statement_params=statement_params,
|
172
|
+
)
|
173
|
+
else:
|
174
|
+
return self._model_version_client.create_from_model_version(
|
175
|
+
source_database_name=source_database_name,
|
176
|
+
source_schema_name=source_schema_name,
|
177
|
+
source_model_name=source_model_name,
|
178
|
+
source_version_name=source_version_name,
|
179
|
+
database_name=database_name,
|
180
|
+
schema_name=schema_name,
|
181
|
+
model_name=model_name,
|
182
|
+
version_name=version_name,
|
183
|
+
statement_params=statement_params,
|
184
|
+
)
|
185
|
+
|
121
186
|
def show_models_or_versions(
|
122
187
|
self,
|
123
188
|
*,
|
189
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
190
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
124
191
|
model_name: Optional[sql_identifier.SqlIdentifier] = None,
|
125
192
|
statement_params: Optional[Dict[str, Any]] = None,
|
126
193
|
) -> List[row.Row]:
|
127
194
|
if model_name:
|
128
195
|
return self._model_client.show_versions(
|
196
|
+
database_name=database_name,
|
197
|
+
schema_name=schema_name,
|
129
198
|
model_name=model_name,
|
130
199
|
validate_result=False,
|
131
200
|
statement_params=statement_params,
|
132
201
|
)
|
133
202
|
else:
|
134
203
|
return self._model_client.show_models(
|
204
|
+
database_name=database_name,
|
205
|
+
schema_name=schema_name,
|
135
206
|
validate_result=False,
|
136
207
|
statement_params=statement_params,
|
137
208
|
)
|
@@ -139,10 +210,14 @@ class ModelOperator:
|
|
139
210
|
def list_models_or_versions(
|
140
211
|
self,
|
141
212
|
*,
|
213
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
214
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
142
215
|
model_name: Optional[sql_identifier.SqlIdentifier] = None,
|
143
216
|
statement_params: Optional[Dict[str, Any]] = None,
|
144
217
|
) -> List[sql_identifier.SqlIdentifier]:
|
145
218
|
res = self.show_models_or_versions(
|
219
|
+
database_name=database_name,
|
220
|
+
schema_name=schema_name,
|
146
221
|
model_name=model_name,
|
147
222
|
statement_params=statement_params,
|
148
223
|
)
|
@@ -155,12 +230,16 @@ class ModelOperator:
|
|
155
230
|
def validate_existence(
|
156
231
|
self,
|
157
232
|
*,
|
233
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
234
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
158
235
|
model_name: sql_identifier.SqlIdentifier,
|
159
236
|
version_name: Optional[sql_identifier.SqlIdentifier] = None,
|
160
237
|
statement_params: Optional[Dict[str, Any]] = None,
|
161
238
|
) -> bool:
|
162
239
|
if version_name:
|
163
240
|
res = self._model_client.show_versions(
|
241
|
+
database_name=database_name,
|
242
|
+
schema_name=schema_name,
|
164
243
|
model_name=model_name,
|
165
244
|
version_name=version_name,
|
166
245
|
validate_result=False,
|
@@ -168,6 +247,8 @@ class ModelOperator:
|
|
168
247
|
)
|
169
248
|
else:
|
170
249
|
res = self._model_client.show_models(
|
250
|
+
database_name=database_name,
|
251
|
+
schema_name=schema_name,
|
171
252
|
model_name=model_name,
|
172
253
|
validate_result=False,
|
173
254
|
statement_params=statement_params,
|
@@ -177,12 +258,16 @@ class ModelOperator:
|
|
177
258
|
def get_comment(
|
178
259
|
self,
|
179
260
|
*,
|
261
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
262
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
180
263
|
model_name: sql_identifier.SqlIdentifier,
|
181
264
|
version_name: Optional[sql_identifier.SqlIdentifier] = None,
|
182
265
|
statement_params: Optional[Dict[str, Any]] = None,
|
183
266
|
) -> str:
|
184
267
|
if version_name:
|
185
268
|
res = self._model_client.show_versions(
|
269
|
+
database_name=database_name,
|
270
|
+
schema_name=schema_name,
|
186
271
|
model_name=model_name,
|
187
272
|
version_name=version_name,
|
188
273
|
statement_params=statement_params,
|
@@ -190,6 +275,8 @@ class ModelOperator:
|
|
190
275
|
col_name = self._model_client.MODEL_VERSION_COMMENT_COL_NAME
|
191
276
|
else:
|
192
277
|
res = self._model_client.show_models(
|
278
|
+
database_name=database_name,
|
279
|
+
schema_name=schema_name,
|
193
280
|
model_name=model_name,
|
194
281
|
statement_params=statement_params,
|
195
282
|
)
|
@@ -200,6 +287,8 @@ class ModelOperator:
|
|
200
287
|
self,
|
201
288
|
*,
|
202
289
|
comment: str,
|
290
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
291
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
203
292
|
model_name: sql_identifier.SqlIdentifier,
|
204
293
|
version_name: Optional[sql_identifier.SqlIdentifier] = None,
|
205
294
|
statement_params: Optional[Dict[str, Any]] = None,
|
@@ -207,6 +296,8 @@ class ModelOperator:
|
|
207
296
|
if version_name:
|
208
297
|
self._model_version_client.set_comment(
|
209
298
|
comment=comment,
|
299
|
+
database_name=database_name,
|
300
|
+
schema_name=schema_name,
|
210
301
|
model_name=model_name,
|
211
302
|
version_name=version_name,
|
212
303
|
statement_params=statement_params,
|
@@ -214,6 +305,8 @@ class ModelOperator:
|
|
214
305
|
else:
|
215
306
|
self._model_client.set_comment(
|
216
307
|
comment=comment,
|
308
|
+
database_name=database_name,
|
309
|
+
schema_name=schema_name,
|
217
310
|
model_name=model_name,
|
218
311
|
statement_params=statement_params,
|
219
312
|
)
|
@@ -221,25 +314,42 @@ class ModelOperator:
|
|
221
314
|
def set_default_version(
|
222
315
|
self,
|
223
316
|
*,
|
317
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
318
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
224
319
|
model_name: sql_identifier.SqlIdentifier,
|
225
320
|
version_name: sql_identifier.SqlIdentifier,
|
226
321
|
statement_params: Optional[Dict[str, Any]] = None,
|
227
322
|
) -> None:
|
228
323
|
if not self.validate_existence(
|
229
|
-
|
324
|
+
database_name=database_name,
|
325
|
+
schema_name=schema_name,
|
326
|
+
model_name=model_name,
|
327
|
+
version_name=version_name,
|
328
|
+
statement_params=statement_params,
|
230
329
|
):
|
231
330
|
raise ValueError(f"You cannot set version {version_name} as default version as it does not exist.")
|
232
331
|
self._model_version_client.set_default_version(
|
233
|
-
|
332
|
+
database_name=database_name,
|
333
|
+
schema_name=schema_name,
|
334
|
+
model_name=model_name,
|
335
|
+
version_name=version_name,
|
336
|
+
statement_params=statement_params,
|
234
337
|
)
|
235
338
|
|
236
339
|
def get_default_version(
|
237
340
|
self,
|
238
341
|
*,
|
342
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
343
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
239
344
|
model_name: sql_identifier.SqlIdentifier,
|
240
345
|
statement_params: Optional[Dict[str, Any]] = None,
|
241
346
|
) -> sql_identifier.SqlIdentifier:
|
242
|
-
res = self._model_client.show_models(
|
347
|
+
res = self._model_client.show_models(
|
348
|
+
database_name=database_name,
|
349
|
+
schema_name=schema_name,
|
350
|
+
model_name=model_name,
|
351
|
+
statement_params=statement_params,
|
352
|
+
)[0]
|
243
353
|
return sql_identifier.SqlIdentifier(
|
244
354
|
res[self._model_client.MODEL_DEFAULT_VERSION_NAME_COL_NAME], case_sensitive=True
|
245
355
|
)
|
@@ -247,14 +357,18 @@ class ModelOperator:
|
|
247
357
|
def get_tag_value(
|
248
358
|
self,
|
249
359
|
*,
|
360
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
361
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
250
362
|
model_name: sql_identifier.SqlIdentifier,
|
251
|
-
tag_database_name: sql_identifier.SqlIdentifier,
|
252
|
-
tag_schema_name: sql_identifier.SqlIdentifier,
|
363
|
+
tag_database_name: Optional[sql_identifier.SqlIdentifier],
|
364
|
+
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
253
365
|
tag_name: sql_identifier.SqlIdentifier,
|
254
366
|
statement_params: Optional[Dict[str, Any]] = None,
|
255
367
|
) -> Optional[str]:
|
256
368
|
r = self._tag_client.get_tag_value(
|
257
|
-
|
369
|
+
database_name=database_name,
|
370
|
+
schema_name=schema_name,
|
371
|
+
model_name=model_name,
|
258
372
|
tag_database_name=tag_database_name,
|
259
373
|
tag_schema_name=tag_schema_name,
|
260
374
|
tag_name=tag_name,
|
@@ -268,11 +382,15 @@ class ModelOperator:
|
|
268
382
|
def show_tags(
|
269
383
|
self,
|
270
384
|
*,
|
385
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
386
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
271
387
|
model_name: sql_identifier.SqlIdentifier,
|
272
388
|
statement_params: Optional[Dict[str, Any]] = None,
|
273
389
|
) -> Dict[str, str]:
|
274
390
|
tags_info = self._tag_client.get_tag_list(
|
275
|
-
|
391
|
+
database_name=database_name,
|
392
|
+
schema_name=schema_name,
|
393
|
+
model_name=model_name,
|
276
394
|
statement_params=statement_params,
|
277
395
|
)
|
278
396
|
res: Dict[str, str] = {
|
@@ -288,14 +406,18 @@ class ModelOperator:
|
|
288
406
|
def set_tag(
|
289
407
|
self,
|
290
408
|
*,
|
409
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
410
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
291
411
|
model_name: sql_identifier.SqlIdentifier,
|
292
|
-
tag_database_name: sql_identifier.SqlIdentifier,
|
293
|
-
tag_schema_name: sql_identifier.SqlIdentifier,
|
412
|
+
tag_database_name: Optional[sql_identifier.SqlIdentifier],
|
413
|
+
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
294
414
|
tag_name: sql_identifier.SqlIdentifier,
|
295
415
|
tag_value: str,
|
296
416
|
statement_params: Optional[Dict[str, Any]] = None,
|
297
417
|
) -> None:
|
298
418
|
self._tag_client.set_tag_on_model(
|
419
|
+
database_name=database_name,
|
420
|
+
schema_name=schema_name,
|
299
421
|
model_name=model_name,
|
300
422
|
tag_database_name=tag_database_name,
|
301
423
|
tag_schema_name=tag_schema_name,
|
@@ -307,13 +429,17 @@ class ModelOperator:
|
|
307
429
|
def unset_tag(
|
308
430
|
self,
|
309
431
|
*,
|
432
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
433
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
310
434
|
model_name: sql_identifier.SqlIdentifier,
|
311
|
-
tag_database_name: sql_identifier.SqlIdentifier,
|
312
|
-
tag_schema_name: sql_identifier.SqlIdentifier,
|
435
|
+
tag_database_name: Optional[sql_identifier.SqlIdentifier],
|
436
|
+
tag_schema_name: Optional[sql_identifier.SqlIdentifier],
|
313
437
|
tag_name: sql_identifier.SqlIdentifier,
|
314
438
|
statement_params: Optional[Dict[str, Any]] = None,
|
315
439
|
) -> None:
|
316
440
|
self._tag_client.unset_tag_on_model(
|
441
|
+
database_name=database_name,
|
442
|
+
schema_name=schema_name,
|
317
443
|
model_name=model_name,
|
318
444
|
tag_database_name=tag_database_name,
|
319
445
|
tag_schema_name=tag_schema_name,
|
@@ -324,12 +450,16 @@ class ModelOperator:
|
|
324
450
|
def get_model_version_manifest(
|
325
451
|
self,
|
326
452
|
*,
|
453
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
454
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
327
455
|
model_name: sql_identifier.SqlIdentifier,
|
328
456
|
version_name: sql_identifier.SqlIdentifier,
|
329
457
|
statement_params: Optional[Dict[str, Any]] = None,
|
330
458
|
) -> model_manifest_schema.ModelManifestDict:
|
331
459
|
with tempfile.TemporaryDirectory() as tmpdir:
|
332
460
|
self._model_version_client.get_file(
|
461
|
+
database_name=database_name,
|
462
|
+
schema_name=schema_name,
|
333
463
|
model_name=model_name,
|
334
464
|
version_name=version_name,
|
335
465
|
file_path=pathlib.PurePosixPath(model_manifest.ModelManifest.MANIFEST_FILE_REL_PATH),
|
@@ -362,11 +492,15 @@ class ModelOperator:
|
|
362
492
|
def get_functions(
|
363
493
|
self,
|
364
494
|
*,
|
495
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
496
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
365
497
|
model_name: sql_identifier.SqlIdentifier,
|
366
498
|
version_name: sql_identifier.SqlIdentifier,
|
367
499
|
statement_params: Optional[Dict[str, Any]] = None,
|
368
500
|
) -> List[model_manifest_schema.ModelFunctionInfo]:
|
369
501
|
raw_model_spec_res = self._model_client.show_versions(
|
502
|
+
database_name=database_name,
|
503
|
+
schema_name=schema_name,
|
370
504
|
model_name=model_name,
|
371
505
|
version_name=version_name,
|
372
506
|
check_model_details=True,
|
@@ -375,6 +509,8 @@ class ModelOperator:
|
|
375
509
|
model_spec_dict = yaml.safe_load(raw_model_spec_res)
|
376
510
|
model_spec = model_meta.ModelMetadata._validate_model_metadata(model_spec_dict)
|
377
511
|
show_functions_res = self._model_version_client.show_functions(
|
512
|
+
database_name=database_name,
|
513
|
+
schema_name=schema_name,
|
378
514
|
model_name=model_name,
|
379
515
|
version_name=version_name,
|
380
516
|
statement_params=statement_params,
|
@@ -419,6 +555,8 @@ class ModelOperator:
|
|
419
555
|
method_function_type: str,
|
420
556
|
signature: model_signature.ModelSignature,
|
421
557
|
X: Union[type_hints.SupportedDataType, dataframe.DataFrame],
|
558
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
559
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
422
560
|
model_name: sql_identifier.SqlIdentifier,
|
423
561
|
version_name: sql_identifier.SqlIdentifier,
|
424
562
|
strict_input_validation: bool = False,
|
@@ -466,6 +604,8 @@ class ModelOperator:
|
|
466
604
|
input_df=s_df,
|
467
605
|
input_args=input_args,
|
468
606
|
returns=returns,
|
607
|
+
database_name=database_name,
|
608
|
+
schema_name=schema_name,
|
469
609
|
model_name=model_name,
|
470
610
|
version_name=version_name,
|
471
611
|
statement_params=statement_params,
|
@@ -477,6 +617,8 @@ class ModelOperator:
|
|
477
617
|
input_args=input_args,
|
478
618
|
partition_column=partition_column,
|
479
619
|
returns=returns,
|
620
|
+
database_name=database_name,
|
621
|
+
schema_name=schema_name,
|
480
622
|
model_name=model_name,
|
481
623
|
version_name=version_name,
|
482
624
|
statement_params=statement_params,
|
@@ -504,18 +646,24 @@ class ModelOperator:
|
|
504
646
|
def delete_model_or_version(
|
505
647
|
self,
|
506
648
|
*,
|
649
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
650
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
507
651
|
model_name: sql_identifier.SqlIdentifier,
|
508
652
|
version_name: Optional[sql_identifier.SqlIdentifier] = None,
|
509
653
|
statement_params: Optional[Dict[str, Any]] = None,
|
510
654
|
) -> None:
|
511
655
|
if version_name:
|
512
656
|
self._model_version_client.drop_version(
|
657
|
+
database_name=database_name,
|
658
|
+
schema_name=schema_name,
|
513
659
|
model_name=model_name,
|
514
660
|
version_name=version_name,
|
515
661
|
statement_params=statement_params,
|
516
662
|
)
|
517
663
|
else:
|
518
664
|
self._model_client.drop_model(
|
665
|
+
database_name=database_name,
|
666
|
+
schema_name=schema_name,
|
519
667
|
model_name=model_name,
|
520
668
|
statement_params=statement_params,
|
521
669
|
)
|
@@ -523,6 +671,8 @@ class ModelOperator:
|
|
523
671
|
def rename(
|
524
672
|
self,
|
525
673
|
*,
|
674
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
675
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
526
676
|
model_name: sql_identifier.SqlIdentifier,
|
527
677
|
new_model_db: Optional[sql_identifier.SqlIdentifier],
|
528
678
|
new_model_schema: Optional[sql_identifier.SqlIdentifier],
|
@@ -530,6 +680,8 @@ class ModelOperator:
|
|
530
680
|
statement_params: Optional[Dict[str, Any]] = None,
|
531
681
|
) -> None:
|
532
682
|
self._model_client.rename(
|
683
|
+
database_name=database_name,
|
684
|
+
schema_name=schema_name,
|
533
685
|
model_name=model_name,
|
534
686
|
new_model_db=new_model_db,
|
535
687
|
new_model_schema=new_model_schema,
|
@@ -554,6 +706,8 @@ class ModelOperator:
|
|
554
706
|
def download_files(
|
555
707
|
self,
|
556
708
|
*,
|
709
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
710
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
557
711
|
model_name: sql_identifier.SqlIdentifier,
|
558
712
|
version_name: sql_identifier.SqlIdentifier,
|
559
713
|
target_path: pathlib.Path,
|
@@ -562,6 +716,8 @@ class ModelOperator:
|
|
562
716
|
) -> None:
|
563
717
|
for remote_rel_path, is_dir in self.MODEL_FILE_DOWNLOAD_PATTERN[mode].items():
|
564
718
|
list_file_res = self._model_version_client.list_file(
|
719
|
+
database_name=database_name,
|
720
|
+
schema_name=schema_name,
|
565
721
|
model_name=model_name,
|
566
722
|
version_name=version_name,
|
567
723
|
file_path=remote_rel_path,
|
@@ -576,6 +732,8 @@ class ModelOperator:
|
|
576
732
|
local_file_dir = target_path / stage_file_path.parent
|
577
733
|
local_file_dir.mkdir(parents=True, exist_ok=True)
|
578
734
|
self._model_version_client.get_file(
|
735
|
+
database_name=database_name,
|
736
|
+
schema_name=schema_name,
|
579
737
|
model_name=model_name,
|
580
738
|
version_name=version_name,
|
581
739
|
file_path=stage_file_path,
|
@@ -0,0 +1,34 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
|
3
|
+
from snowflake.ml._internal.utils import identifier, sql_identifier
|
4
|
+
from snowflake.snowpark import session
|
5
|
+
|
6
|
+
|
7
|
+
class _BaseSQLClient:
|
8
|
+
def __init__(
|
9
|
+
self,
|
10
|
+
session: session.Session,
|
11
|
+
*,
|
12
|
+
database_name: sql_identifier.SqlIdentifier,
|
13
|
+
schema_name: sql_identifier.SqlIdentifier,
|
14
|
+
) -> None:
|
15
|
+
self._session = session
|
16
|
+
self._database_name = database_name
|
17
|
+
self._schema_name = schema_name
|
18
|
+
|
19
|
+
def __eq__(self, __value: object) -> bool:
|
20
|
+
if not isinstance(__value, _BaseSQLClient):
|
21
|
+
return False
|
22
|
+
return self._database_name == __value._database_name and self._schema_name == __value._schema_name
|
23
|
+
|
24
|
+
def fully_qualified_object_name(
|
25
|
+
self,
|
26
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
27
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
28
|
+
object_name: sql_identifier.SqlIdentifier,
|
29
|
+
) -> str:
|
30
|
+
actual_database_name = database_name or self._database_name
|
31
|
+
actual_schema_name = schema_name or self._schema_name
|
32
|
+
return identifier.get_schema_level_object_identifier(
|
33
|
+
actual_database_name.identifier(), actual_schema_name.identifier(), object_name.identifier()
|
34
|
+
)
|
@@ -1,14 +1,11 @@
|
|
1
1
|
from typing import Any, Dict, List, Optional
|
2
2
|
|
3
|
-
from snowflake.ml._internal.utils import
|
4
|
-
|
5
|
-
|
6
|
-
sql_identifier,
|
7
|
-
)
|
8
|
-
from snowflake.snowpark import row, session
|
3
|
+
from snowflake.ml._internal.utils import query_result_checker, sql_identifier
|
4
|
+
from snowflake.ml.model._client.sql import _base
|
5
|
+
from snowflake.snowpark import row
|
9
6
|
|
10
7
|
|
11
|
-
class ModelSQLClient:
|
8
|
+
class ModelSQLClient(_base._BaseSQLClient):
|
12
9
|
MODEL_NAME_COL_NAME = "name"
|
13
10
|
MODEL_COMMENT_COL_NAME = "comment"
|
14
11
|
MODEL_DEFAULT_VERSION_NAME_COL_NAME = "default_version_name"
|
@@ -18,35 +15,18 @@ class ModelSQLClient:
|
|
18
15
|
MODEL_VERSION_METADATA_COL_NAME = "metadata"
|
19
16
|
MODEL_VERSION_MODEL_SPEC_COL_NAME = "model_spec"
|
20
17
|
|
21
|
-
def __init__(
|
22
|
-
self,
|
23
|
-
session: session.Session,
|
24
|
-
*,
|
25
|
-
database_name: sql_identifier.SqlIdentifier,
|
26
|
-
schema_name: sql_identifier.SqlIdentifier,
|
27
|
-
) -> None:
|
28
|
-
self._session = session
|
29
|
-
self._database_name = database_name
|
30
|
-
self._schema_name = schema_name
|
31
|
-
|
32
|
-
def __eq__(self, __value: object) -> bool:
|
33
|
-
if not isinstance(__value, ModelSQLClient):
|
34
|
-
return False
|
35
|
-
return self._database_name == __value._database_name and self._schema_name == __value._schema_name
|
36
|
-
|
37
|
-
def fully_qualified_model_name(self, model_name: sql_identifier.SqlIdentifier) -> str:
|
38
|
-
return identifier.get_schema_level_object_identifier(
|
39
|
-
self._database_name.identifier(), self._schema_name.identifier(), model_name.identifier()
|
40
|
-
)
|
41
|
-
|
42
18
|
def show_models(
|
43
19
|
self,
|
44
20
|
*,
|
21
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
22
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
45
23
|
model_name: Optional[sql_identifier.SqlIdentifier] = None,
|
46
24
|
validate_result: bool = True,
|
47
25
|
statement_params: Optional[Dict[str, Any]] = None,
|
48
26
|
) -> List[row.Row]:
|
49
|
-
|
27
|
+
actual_database_name = database_name or self._database_name
|
28
|
+
actual_schema_name = schema_name or self._schema_name
|
29
|
+
fully_qualified_schema_name = ".".join([actual_database_name.identifier(), actual_schema_name.identifier()])
|
50
30
|
like_sql = ""
|
51
31
|
if model_name:
|
52
32
|
like_sql = f" LIKE '{model_name.resolved()}'"
|
@@ -69,6 +49,8 @@ class ModelSQLClient:
|
|
69
49
|
def show_versions(
|
70
50
|
self,
|
71
51
|
*,
|
52
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
53
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
72
54
|
model_name: sql_identifier.SqlIdentifier,
|
73
55
|
version_name: Optional[sql_identifier.SqlIdentifier] = None,
|
74
56
|
validate_result: bool = True,
|
@@ -82,7 +64,10 @@ class ModelSQLClient:
|
|
82
64
|
res = (
|
83
65
|
query_result_checker.SqlResultValidator(
|
84
66
|
self._session,
|
85
|
-
|
67
|
+
(
|
68
|
+
f"SHOW VERSIONS{like_sql} IN "
|
69
|
+
f"MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
70
|
+
),
|
86
71
|
statement_params=statement_params,
|
87
72
|
)
|
88
73
|
.has_column(ModelSQLClient.MODEL_VERSION_NAME_COL_NAME, allow_empty=True)
|
@@ -99,31 +84,40 @@ class ModelSQLClient:
|
|
99
84
|
def set_comment(
|
100
85
|
self,
|
101
86
|
*,
|
102
|
-
|
87
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
88
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
103
89
|
model_name: sql_identifier.SqlIdentifier,
|
90
|
+
comment: str,
|
104
91
|
statement_params: Optional[Dict[str, Any]] = None,
|
105
92
|
) -> None:
|
106
93
|
query_result_checker.SqlResultValidator(
|
107
94
|
self._session,
|
108
|
-
|
95
|
+
(
|
96
|
+
f"COMMENT ON MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
97
|
+
f" IS $${comment}$$"
|
98
|
+
),
|
109
99
|
statement_params=statement_params,
|
110
100
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
111
101
|
|
112
102
|
def drop_model(
|
113
103
|
self,
|
114
104
|
*,
|
105
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
106
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
115
107
|
model_name: sql_identifier.SqlIdentifier,
|
116
108
|
statement_params: Optional[Dict[str, Any]] = None,
|
117
109
|
) -> None:
|
118
110
|
query_result_checker.SqlResultValidator(
|
119
111
|
self._session,
|
120
|
-
f"DROP MODEL {self.
|
112
|
+
f"DROP MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}",
|
121
113
|
statement_params=statement_params,
|
122
114
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
123
115
|
|
124
116
|
def rename(
|
125
117
|
self,
|
126
118
|
*,
|
119
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
120
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
127
121
|
model_name: sql_identifier.SqlIdentifier,
|
128
122
|
new_model_db: Optional[sql_identifier.SqlIdentifier],
|
129
123
|
new_model_schema: Optional[sql_identifier.SqlIdentifier],
|
@@ -131,13 +125,12 @@ class ModelSQLClient:
|
|
131
125
|
statement_params: Optional[Dict[str, Any]] = None,
|
132
126
|
) -> None:
|
133
127
|
# Use registry's database and schema if a non fully qualified new model name is provided.
|
134
|
-
new_fully_qualified_name =
|
135
|
-
new_model_db.identifier() if new_model_db else self._database_name.identifier(),
|
136
|
-
new_model_schema.identifier() if new_model_schema else self._schema_name.identifier(),
|
137
|
-
new_model_name.identifier(),
|
138
|
-
)
|
128
|
+
new_fully_qualified_name = self.fully_qualified_object_name(new_model_db, new_model_schema, new_model_name)
|
139
129
|
query_result_checker.SqlResultValidator(
|
140
130
|
self._session,
|
141
|
-
|
131
|
+
(
|
132
|
+
f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
|
133
|
+
f" RENAME TO {new_fully_qualified_name}"
|
134
|
+
),
|
142
135
|
statement_params=statement_params,
|
143
136
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|