snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +72 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
- snowflake/ml/_internal/telemetry.py +1 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/sql_identifier.py +14 -1
- snowflake/ml/dataset/__init__.py +11 -0
- snowflake/ml/dataset/dataset.py +455 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +199 -0
- snowflake/ml/feature_store/__init__.py +6 -0
- snowflake/ml/feature_store/access_manager.py +279 -0
- snowflake/ml/feature_store/feature_store.py +544 -358
- snowflake/ml/feature_store/feature_view.py +55 -16
- snowflake/ml/fileset/embedded_stage_fs.py +149 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +160 -0
- snowflake/ml/fileset/stage_fs.py +25 -10
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +65 -31
- snowflake/ml/model/_client/model/model_version_impl.py +159 -2
- snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
- snowflake/ml/model/_client/ops/model_ops.py +268 -83
- snowflake/ml/model/_client/sql/_base.py +34 -0
- snowflake/ml/model/_client/sql/model.py +42 -47
- snowflake/ml/model/_client/sql/model_version.py +164 -39
- snowflake/ml/model/_client/sql/stage.py +6 -32
- snowflake/ml/model/_client/sql/tag.py +32 -56
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +64 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +538 -36
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/_manager/model_manager.py +36 -7
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_api.py
CHANGED
@@ -2,6 +2,7 @@ from types import ModuleType
|
|
2
2
|
from typing import Any, Dict, List, Literal, Optional, Union, cast, overload
|
3
3
|
|
4
4
|
import pandas as pd
|
5
|
+
from typing_extensions import deprecated
|
5
6
|
|
6
7
|
from snowflake.ml._internal.exceptions import (
|
7
8
|
error_codes,
|
@@ -23,6 +24,7 @@ from snowflake.ml.model._signatures import snowpark_handler
|
|
23
24
|
from snowflake.snowpark import DataFrame as SnowparkDataFrame, Session, functions as F
|
24
25
|
|
25
26
|
|
27
|
+
@deprecated("Only used by PrPr model registry.")
|
26
28
|
@overload
|
27
29
|
def save_model(
|
28
30
|
*,
|
@@ -61,6 +63,7 @@ def save_model(
|
|
61
63
|
...
|
62
64
|
|
63
65
|
|
66
|
+
@deprecated("Only used by PrPr model registry.")
|
64
67
|
@overload
|
65
68
|
def save_model(
|
66
69
|
*,
|
@@ -101,6 +104,7 @@ def save_model(
|
|
101
104
|
...
|
102
105
|
|
103
106
|
|
107
|
+
@deprecated("Only used by PrPr model registry.")
|
104
108
|
@overload
|
105
109
|
def save_model(
|
106
110
|
*,
|
@@ -142,6 +146,7 @@ def save_model(
|
|
142
146
|
...
|
143
147
|
|
144
148
|
|
149
|
+
@deprecated("Only used by PrPr model registry.")
|
145
150
|
def save_model(
|
146
151
|
*,
|
147
152
|
name: str,
|
@@ -208,6 +213,7 @@ def save_model(
|
|
208
213
|
return m
|
209
214
|
|
210
215
|
|
216
|
+
@deprecated("Only used by PrPr model registry.")
|
211
217
|
@overload
|
212
218
|
def load_model(*, session: Session, stage_path: str) -> model_composer.ModelComposer:
|
213
219
|
"""Load the model into memory from a zip file in the stage.
|
@@ -219,6 +225,7 @@ def load_model(*, session: Session, stage_path: str) -> model_composer.ModelComp
|
|
219
225
|
...
|
220
226
|
|
221
227
|
|
228
|
+
@deprecated("Only used by PrPr model registry.")
|
222
229
|
@overload
|
223
230
|
def load_model(*, session: Session, stage_path: str, meta_only: Literal[False]) -> model_composer.ModelComposer:
|
224
231
|
"""Load the model into memory from a zip file in the stage.
|
@@ -231,6 +238,7 @@ def load_model(*, session: Session, stage_path: str, meta_only: Literal[False])
|
|
231
238
|
...
|
232
239
|
|
233
240
|
|
241
|
+
@deprecated("Only used by PrPr model registry.")
|
234
242
|
@overload
|
235
243
|
def load_model(*, session: Session, stage_path: str, meta_only: Literal[True]) -> model_composer.ModelComposer:
|
236
244
|
"""Load the model into memory from a zip file in the stage with metadata only.
|
@@ -243,6 +251,7 @@ def load_model(*, session: Session, stage_path: str, meta_only: Literal[True]) -
|
|
243
251
|
...
|
244
252
|
|
245
253
|
|
254
|
+
@deprecated("Only used by PrPr model registry.")
|
246
255
|
def load_model(
|
247
256
|
*,
|
248
257
|
session: Session,
|
@@ -261,10 +270,11 @@ def load_model(
|
|
261
270
|
Loaded model.
|
262
271
|
"""
|
263
272
|
m = model_composer.ModelComposer(session=session, stage_path=stage_path)
|
264
|
-
m.
|
273
|
+
m.legacy_load(meta_only=meta_only)
|
265
274
|
return m
|
266
275
|
|
267
276
|
|
277
|
+
@deprecated("Only used by PrPr model registry.")
|
268
278
|
@overload
|
269
279
|
def deploy(
|
270
280
|
session: Session,
|
@@ -290,6 +300,7 @@ def deploy(
|
|
290
300
|
...
|
291
301
|
|
292
302
|
|
303
|
+
@deprecated("Only used by PrPr model registry.")
|
293
304
|
@overload
|
294
305
|
def deploy(
|
295
306
|
session: Session,
|
@@ -319,6 +330,7 @@ def deploy(
|
|
319
330
|
...
|
320
331
|
|
321
332
|
|
333
|
+
@deprecated("Only used by PrPr model registry.")
|
322
334
|
def deploy(
|
323
335
|
session: Session,
|
324
336
|
*,
|
@@ -423,6 +435,7 @@ def deploy(
|
|
423
435
|
return info
|
424
436
|
|
425
437
|
|
438
|
+
@deprecated("Only used by PrPr model registry.")
|
426
439
|
@overload
|
427
440
|
def predict(
|
428
441
|
session: Session,
|
@@ -443,6 +456,7 @@ def predict(
|
|
443
456
|
...
|
444
457
|
|
445
458
|
|
459
|
+
@deprecated("Only used by PrPr model registry.")
|
446
460
|
@overload
|
447
461
|
def predict(
|
448
462
|
session: Session,
|
@@ -462,6 +476,7 @@ def predict(
|
|
462
476
|
...
|
463
477
|
|
464
478
|
|
479
|
+
@deprecated("Only used by PrPr model registry.")
|
465
480
|
def predict(
|
466
481
|
session: Session,
|
467
482
|
*,
|
@@ -1,9 +1,9 @@
|
|
1
|
-
from typing import Dict, List, Optional,
|
1
|
+
from typing import Dict, List, Optional, Union
|
2
2
|
|
3
3
|
import pandas as pd
|
4
4
|
|
5
5
|
from snowflake.ml._internal import telemetry
|
6
|
-
from snowflake.ml._internal.utils import
|
6
|
+
from snowflake.ml._internal.utils import sql_identifier
|
7
7
|
from snowflake.ml.model._client.model import model_version_impl
|
8
8
|
from snowflake.ml.model._client.ops import model_ops
|
9
9
|
|
@@ -45,7 +45,7 @@ class Model:
|
|
45
45
|
@property
|
46
46
|
def fully_qualified_name(self) -> str:
|
47
47
|
"""Return the fully qualified name of the model that can be used to refer to it in SQL."""
|
48
|
-
return self._model_ops._model_version_client.
|
48
|
+
return self._model_ops._model_version_client.fully_qualified_object_name(None, None, self._model_name)
|
49
49
|
|
50
50
|
@property
|
51
51
|
@telemetry.send_api_usage_telemetry(
|
@@ -76,6 +76,8 @@ class Model:
|
|
76
76
|
subproject=_TELEMETRY_SUBPROJECT,
|
77
77
|
)
|
78
78
|
return self._model_ops.get_comment(
|
79
|
+
database_name=None,
|
80
|
+
schema_name=None,
|
79
81
|
model_name=self._model_name,
|
80
82
|
statement_params=statement_params,
|
81
83
|
)
|
@@ -92,6 +94,8 @@ class Model:
|
|
92
94
|
)
|
93
95
|
return self._model_ops.set_comment(
|
94
96
|
comment=comment,
|
97
|
+
database_name=None,
|
98
|
+
schema_name=None,
|
95
99
|
model_name=self._model_name,
|
96
100
|
statement_params=statement_params,
|
97
101
|
)
|
@@ -109,7 +113,7 @@ class Model:
|
|
109
113
|
class_name=self.__class__.__name__,
|
110
114
|
)
|
111
115
|
default_version_name = self._model_ops.get_default_version(
|
112
|
-
model_name=self._model_name, statement_params=statement_params
|
116
|
+
database_name=None, schema_name=None, model_name=self._model_name, statement_params=statement_params
|
113
117
|
)
|
114
118
|
return self.version(default_version_name)
|
115
119
|
|
@@ -129,7 +133,11 @@ class Model:
|
|
129
133
|
else:
|
130
134
|
version_name = version._version_name
|
131
135
|
self._model_ops.set_default_version(
|
132
|
-
|
136
|
+
database_name=None,
|
137
|
+
schema_name=None,
|
138
|
+
model_name=self._model_name,
|
139
|
+
version_name=version_name,
|
140
|
+
statement_params=statement_params,
|
133
141
|
)
|
134
142
|
|
135
143
|
@telemetry.send_api_usage_telemetry(
|
@@ -155,6 +163,8 @@ class Model:
|
|
155
163
|
)
|
156
164
|
version_id = sql_identifier.SqlIdentifier(version_name)
|
157
165
|
if self._model_ops.validate_existence(
|
166
|
+
database_name=None,
|
167
|
+
schema_name=None,
|
158
168
|
model_name=self._model_name,
|
159
169
|
version_name=version_id,
|
160
170
|
statement_params=statement_params,
|
@@ -184,6 +194,8 @@ class Model:
|
|
184
194
|
subproject=_TELEMETRY_SUBPROJECT,
|
185
195
|
)
|
186
196
|
version_names = self._model_ops.list_models_or_versions(
|
197
|
+
database_name=None,
|
198
|
+
schema_name=None,
|
187
199
|
model_name=self._model_name,
|
188
200
|
statement_params=statement_params,
|
189
201
|
)
|
@@ -211,6 +223,8 @@ class Model:
|
|
211
223
|
subproject=_TELEMETRY_SUBPROJECT,
|
212
224
|
)
|
213
225
|
rows = self._model_ops.show_models_or_versions(
|
226
|
+
database_name=None,
|
227
|
+
schema_name=None,
|
214
228
|
model_name=self._model_name,
|
215
229
|
statement_params=statement_params,
|
216
230
|
)
|
@@ -231,6 +245,8 @@ class Model:
|
|
231
245
|
subproject=_TELEMETRY_SUBPROJECT,
|
232
246
|
)
|
233
247
|
self._model_ops.delete_model_or_version(
|
248
|
+
database_name=None,
|
249
|
+
schema_name=None,
|
234
250
|
model_name=self._model_name,
|
235
251
|
version_name=sql_identifier.SqlIdentifier(version_name),
|
236
252
|
statement_params=statement_params,
|
@@ -250,29 +266,9 @@ class Model:
|
|
250
266
|
project=_TELEMETRY_PROJECT,
|
251
267
|
subproject=_TELEMETRY_SUBPROJECT,
|
252
268
|
)
|
253
|
-
return self._model_ops.show_tags(
|
254
|
-
|
255
|
-
|
256
|
-
self,
|
257
|
-
tag_name: str,
|
258
|
-
) -> Tuple[sql_identifier.SqlIdentifier, sql_identifier.SqlIdentifier, sql_identifier.SqlIdentifier]:
|
259
|
-
_tag_db, _tag_schema, _tag_name, _ = identifier.parse_schema_level_object_identifier(tag_name)
|
260
|
-
if _tag_db is None:
|
261
|
-
tag_db_id = self._model_ops._model_client._database_name
|
262
|
-
else:
|
263
|
-
tag_db_id = sql_identifier.SqlIdentifier(_tag_db)
|
264
|
-
|
265
|
-
if _tag_schema is None:
|
266
|
-
tag_schema_id = self._model_ops._model_client._schema_name
|
267
|
-
else:
|
268
|
-
tag_schema_id = sql_identifier.SqlIdentifier(_tag_schema)
|
269
|
-
|
270
|
-
if _tag_name is None:
|
271
|
-
raise ValueError(f"Unable parse the tag name `{tag_name}` you input.")
|
272
|
-
|
273
|
-
tag_name_id = sql_identifier.SqlIdentifier(_tag_name)
|
274
|
-
|
275
|
-
return tag_db_id, tag_schema_id, tag_name_id
|
269
|
+
return self._model_ops.show_tags(
|
270
|
+
database_name=None, schema_name=None, model_name=self._model_name, statement_params=statement_params
|
271
|
+
)
|
276
272
|
|
277
273
|
@telemetry.send_api_usage_telemetry(
|
278
274
|
project=_TELEMETRY_PROJECT,
|
@@ -292,8 +288,10 @@ class Model:
|
|
292
288
|
project=_TELEMETRY_PROJECT,
|
293
289
|
subproject=_TELEMETRY_SUBPROJECT,
|
294
290
|
)
|
295
|
-
tag_db_id, tag_schema_id, tag_name_id =
|
291
|
+
tag_db_id, tag_schema_id, tag_name_id = sql_identifier.parse_fully_qualified_name(tag_name)
|
296
292
|
return self._model_ops.get_tag_value(
|
293
|
+
database_name=None,
|
294
|
+
schema_name=None,
|
297
295
|
model_name=self._model_name,
|
298
296
|
tag_database_name=tag_db_id,
|
299
297
|
tag_schema_name=tag_schema_id,
|
@@ -317,8 +315,10 @@ class Model:
|
|
317
315
|
project=_TELEMETRY_PROJECT,
|
318
316
|
subproject=_TELEMETRY_SUBPROJECT,
|
319
317
|
)
|
320
|
-
tag_db_id, tag_schema_id, tag_name_id =
|
318
|
+
tag_db_id, tag_schema_id, tag_name_id = sql_identifier.parse_fully_qualified_name(tag_name)
|
321
319
|
self._model_ops.set_tag(
|
320
|
+
database_name=None,
|
321
|
+
schema_name=None,
|
322
322
|
model_name=self._model_name,
|
323
323
|
tag_database_name=tag_db_id,
|
324
324
|
tag_schema_name=tag_schema_id,
|
@@ -342,11 +342,45 @@ class Model:
|
|
342
342
|
project=_TELEMETRY_PROJECT,
|
343
343
|
subproject=_TELEMETRY_SUBPROJECT,
|
344
344
|
)
|
345
|
-
tag_db_id, tag_schema_id, tag_name_id =
|
345
|
+
tag_db_id, tag_schema_id, tag_name_id = sql_identifier.parse_fully_qualified_name(tag_name)
|
346
346
|
self._model_ops.unset_tag(
|
347
|
+
database_name=None,
|
348
|
+
schema_name=None,
|
347
349
|
model_name=self._model_name,
|
348
350
|
tag_database_name=tag_db_id,
|
349
351
|
tag_schema_name=tag_schema_id,
|
350
352
|
tag_name=tag_name_id,
|
351
353
|
statement_params=statement_params,
|
352
354
|
)
|
355
|
+
|
356
|
+
@telemetry.send_api_usage_telemetry(
|
357
|
+
project=_TELEMETRY_PROJECT,
|
358
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
359
|
+
)
|
360
|
+
def rename(self, model_name: str) -> None:
|
361
|
+
"""Rename a model. Can be used to move a model when a fully qualified name is provided.
|
362
|
+
|
363
|
+
Args:
|
364
|
+
model_name: The new model name.
|
365
|
+
"""
|
366
|
+
statement_params = telemetry.get_statement_params(
|
367
|
+
project=_TELEMETRY_PROJECT,
|
368
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
369
|
+
)
|
370
|
+
new_db, new_schema, new_model = sql_identifier.parse_fully_qualified_name(model_name)
|
371
|
+
|
372
|
+
self._model_ops.rename(
|
373
|
+
database_name=None,
|
374
|
+
schema_name=None,
|
375
|
+
model_name=self._model_name,
|
376
|
+
new_model_db=new_db,
|
377
|
+
new_model_schema=new_schema,
|
378
|
+
new_model_name=new_model,
|
379
|
+
statement_params=statement_params,
|
380
|
+
)
|
381
|
+
self._model_ops = model_ops.ModelOperator(
|
382
|
+
self._model_ops._session,
|
383
|
+
database_name=new_db or self._model_ops._model_client._database_name,
|
384
|
+
schema_name=new_schema or self._model_ops._model_client._schema_name,
|
385
|
+
)
|
386
|
+
self._model_name = new_model
|
@@ -1,17 +1,29 @@
|
|
1
|
+
import enum
|
2
|
+
import pathlib
|
3
|
+
import tempfile
|
4
|
+
import warnings
|
1
5
|
from typing import Any, Callable, Dict, List, Optional, Union
|
2
6
|
|
3
7
|
import pandas as pd
|
4
8
|
|
5
9
|
from snowflake.ml._internal import telemetry
|
6
10
|
from snowflake.ml._internal.utils import sql_identifier
|
11
|
+
from snowflake.ml.model import type_hints as model_types
|
7
12
|
from snowflake.ml.model._client.ops import metadata_ops, model_ops
|
13
|
+
from snowflake.ml.model._model_composer import model_composer
|
8
14
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
15
|
+
from snowflake.ml.model._packager.model_handlers import snowmlmodel
|
9
16
|
from snowflake.snowpark import dataframe
|
10
17
|
|
11
18
|
_TELEMETRY_PROJECT = "MLOps"
|
12
19
|
_TELEMETRY_SUBPROJECT = "ModelManagement"
|
13
20
|
|
14
21
|
|
22
|
+
class ExportMode(enum.Enum):
|
23
|
+
MODEL = "model"
|
24
|
+
FULL = "full"
|
25
|
+
|
26
|
+
|
15
27
|
class ModelVersion:
|
16
28
|
"""Model Version Object representing a specific version of the model that could be run."""
|
17
29
|
|
@@ -60,7 +72,7 @@ class ModelVersion:
|
|
60
72
|
@property
|
61
73
|
def fully_qualified_model_name(self) -> str:
|
62
74
|
"""Return the fully qualified name of the model to which the model version belongs."""
|
63
|
-
return self._model_ops._model_version_client.
|
75
|
+
return self._model_ops._model_version_client.fully_qualified_object_name(None, None, self._model_name)
|
64
76
|
|
65
77
|
@property
|
66
78
|
@telemetry.send_api_usage_telemetry(
|
@@ -91,6 +103,8 @@ class ModelVersion:
|
|
91
103
|
subproject=_TELEMETRY_SUBPROJECT,
|
92
104
|
)
|
93
105
|
return self._model_ops.get_comment(
|
106
|
+
database_name=None,
|
107
|
+
schema_name=None,
|
94
108
|
model_name=self._model_name,
|
95
109
|
version_name=self._version_name,
|
96
110
|
statement_params=statement_params,
|
@@ -108,6 +122,8 @@ class ModelVersion:
|
|
108
122
|
)
|
109
123
|
return self._model_ops.set_comment(
|
110
124
|
comment=comment,
|
125
|
+
database_name=None,
|
126
|
+
schema_name=None,
|
111
127
|
model_name=self._model_name,
|
112
128
|
version_name=self._version_name,
|
113
129
|
statement_params=statement_params,
|
@@ -128,7 +144,11 @@ class ModelVersion:
|
|
128
144
|
subproject=_TELEMETRY_SUBPROJECT,
|
129
145
|
)
|
130
146
|
return self._model_ops._metadata_ops.load(
|
131
|
-
|
147
|
+
database_name=None,
|
148
|
+
schema_name=None,
|
149
|
+
model_name=self._model_name,
|
150
|
+
version_name=self._version_name,
|
151
|
+
statement_params=statement_params,
|
132
152
|
)["metrics"]
|
133
153
|
|
134
154
|
@telemetry.send_api_usage_telemetry(
|
@@ -171,6 +191,8 @@ class ModelVersion:
|
|
171
191
|
metrics[metric_name] = value
|
172
192
|
self._model_ops._metadata_ops.save(
|
173
193
|
metadata_ops.ModelVersionMetadataSchema(metrics=metrics),
|
194
|
+
database_name=None,
|
195
|
+
schema_name=None,
|
174
196
|
model_name=self._model_name,
|
175
197
|
version_name=self._version_name,
|
176
198
|
statement_params=statement_params,
|
@@ -199,6 +221,8 @@ class ModelVersion:
|
|
199
221
|
del metrics[metric_name]
|
200
222
|
self._model_ops._metadata_ops.save(
|
201
223
|
metadata_ops.ModelVersionMetadataSchema(metrics=metrics),
|
224
|
+
database_name=None,
|
225
|
+
schema_name=None,
|
202
226
|
model_name=self._model_name,
|
203
227
|
version_name=self._version_name,
|
204
228
|
statement_params=statement_params,
|
@@ -210,6 +234,8 @@ class ModelVersion:
|
|
210
234
|
subproject=_TELEMETRY_SUBPROJECT,
|
211
235
|
)
|
212
236
|
return self._model_ops.get_functions(
|
237
|
+
database_name=None,
|
238
|
+
schema_name=None,
|
213
239
|
model_name=self._model_name,
|
214
240
|
version_name=self._version_name,
|
215
241
|
statement_params=statement_params,
|
@@ -240,6 +266,7 @@ class ModelVersion:
|
|
240
266
|
X: Union[pd.DataFrame, dataframe.DataFrame],
|
241
267
|
*,
|
242
268
|
function_name: Optional[str] = None,
|
269
|
+
partition_column: Optional[str] = None,
|
243
270
|
strict_input_validation: bool = False,
|
244
271
|
) -> Union[pd.DataFrame, dataframe.DataFrame]:
|
245
272
|
"""Invoke a method in a model version object.
|
@@ -248,12 +275,14 @@ class ModelVersion:
|
|
248
275
|
X: The input data, which could be a pandas DataFrame or Snowpark DataFrame.
|
249
276
|
function_name: The function name to run. It is the name used to call a function in SQL.
|
250
277
|
Defaults to None. It can only be None if there is only 1 method.
|
278
|
+
partition_column: The partition column name to partition by.
|
251
279
|
strict_input_validation: Enable stricter validation for the input data. This will result value range based
|
252
280
|
type validation to make sure your input data won't overflow when providing to the model.
|
253
281
|
|
254
282
|
Raises:
|
255
283
|
ValueError: When no method with the corresponding name is available.
|
256
284
|
ValueError: When there are more than 1 target methods available in the model but no function name specified.
|
285
|
+
ValueError: When the partition column is not a valid Snowflake identifier.
|
257
286
|
|
258
287
|
Returns:
|
259
288
|
The prediction data. It would be the same type dataframe as your input.
|
@@ -263,6 +292,10 @@ class ModelVersion:
|
|
263
292
|
subproject=_TELEMETRY_SUBPROJECT,
|
264
293
|
)
|
265
294
|
|
295
|
+
if partition_column is not None:
|
296
|
+
# Partition column must be a valid identifier
|
297
|
+
partition_column = sql_identifier.SqlIdentifier(partition_column)
|
298
|
+
|
266
299
|
functions: List[model_manifest_schema.ModelFunctionInfo] = self._functions
|
267
300
|
if function_name:
|
268
301
|
req_method_name = sql_identifier.SqlIdentifier(function_name).identifier()
|
@@ -287,10 +320,134 @@ class ModelVersion:
|
|
287
320
|
target_function_info = functions[0]
|
288
321
|
return self._model_ops.invoke_method(
|
289
322
|
method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
|
323
|
+
method_function_type=target_function_info["target_method_function_type"],
|
290
324
|
signature=target_function_info["signature"],
|
291
325
|
X=X,
|
326
|
+
database_name=None,
|
327
|
+
schema_name=None,
|
292
328
|
model_name=self._model_name,
|
293
329
|
version_name=self._version_name,
|
294
330
|
strict_input_validation=strict_input_validation,
|
331
|
+
partition_column=partition_column,
|
332
|
+
statement_params=statement_params,
|
333
|
+
)
|
334
|
+
|
335
|
+
@telemetry.send_api_usage_telemetry(
    project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, func_params_to_log=["export_mode"]
)
def export(self, target_path: str, *, export_mode: ExportMode = ExportMode.MODEL) -> None:
    """Export model files to a local directory.

    Args:
        target_path: Path to a local directory to export files to. The directory is created if it does not
            exist; its parent directory must already exist.
        export_mode: The mode to export the model. Defaults to ExportMode.MODEL.
            ExportMode.MODEL: All model files including environment to load the model and model weights.
            ExportMode.FULL: Additional files to run the model in Warehouse, besides all files in MODEL mode.

    Raises:
        ValueError: Raised when the target path is a file or a non-empty folder.
        FileNotFoundError: Raised by ``mkdir`` when the parent directory of the target path does not exist.
    """
    target_local_path = pathlib.Path(target_path)
    # Only look inside the target when it already is a directory: calling iterdir() on a path that does not
    # exist raises FileNotFoundError, which would break the "created if it does not exist" contract above.
    if target_local_path.is_file() or (target_local_path.is_dir() and any(target_local_path.iterdir())):
        raise ValueError(f"Target path {target_local_path} is a file or an non-empty folder.")

    # parents=False: only the last path component is created; exist_ok=True accepts an existing empty folder.
    target_local_path.mkdir(parents=False, exist_ok=True)
    statement_params = telemetry.get_statement_params(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    self._model_ops.download_files(
        database_name=None,
        schema_name=None,
        model_name=self._model_name,
        version_name=self._version_name,
        target_path=target_local_path,
        mode=export_mode.value,
        statement_params=statement_params,
    )
|
368
|
+
|
369
|
+
@telemetry.send_api_usage_telemetry(
    project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, func_params_to_log=["force", "options"]
)
def load(
    self,
    *,
    force: bool = False,
    options: Optional[model_types.ModelLoadOption] = None,
) -> model_types.SupportedModelType:
    """Rebuild the original Python model object stored in this model version.

    The returned object is only reliable when the local Python environment matches the one the model was
    logged with; otherwise it may be non-functional or misbehave.

    Args:
        force: Skip the best-effort comparison between the model's recorded environment and the local
            environment. Defaults to False.
        options: Loading options passed through to the model composer; see `snowflake.ml.model.type_hints`
            for the available options. Defaults to None.

    Raises:
        ValueError: If ``force`` is False and the best-effort environment validation reports errors.

    Returns:
        The original Python object restored from the model version.
    """
    statement_params = telemetry.get_statement_params(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    # Arguments common to both downloads below; only the target directory and the mode differ.
    shared_download_kwargs = {
        "database_name": None,
        "schema_name": None,
        "model_name": self._model_name,
        "version_name": self._version_name,
        "statement_params": statement_params,
    }

    if not force:
        # Pull only the "minimal" file set into a throwaway directory so the recorded environment can be
        # checked against the local one before the full model files are fetched.
        with tempfile.TemporaryDirectory() as scratch_dir:
            scratch_path = pathlib.Path(scratch_dir)
            self._model_ops.download_files(target_path=scratch_path, mode="minimal", **shared_download_kwargs)
            meta_packager = model_composer.ModelComposer.load(scratch_path, meta_only=True, options=options)
            assert meta_packager.meta, (
                "Unable to load model metadata for validation. "
                f"model_name={self._model_name}, version_name={self._version_name}"
            )

            # The Snowpark ML version check is requested only for models of the SnowML handler type.
            is_snowml_model = meta_packager.meta.model_type == snowmlmodel.SnowMLModelHandler.HANDLER_TYPE
            validation_errors = meta_packager.meta.env.validate_with_local_env(
                check_snowpark_ml_version=is_snowml_model
            )
            if validation_errors:
                raise ValueError(
                    f"Unable to load this model due to following validation errors: {validation_errors}. "
                    "Make sure your local environment is the same as that when you logged the model, "
                    "or if you believe it should work, specify `force=True` to bypass this check."
                )

    # Emitted directly from this method so stacklevel=2 points one frame above it.
    warnings.warn(
        "Loading model requires to have the exact the same environment as the one when "
        "logging the model, otherwise, the model might be not functional or "
        "some other problems might occur.",
        category=RuntimeWarning,
        stacklevel=2,
    )

    # mkdtemp (not a context-managed TemporaryDirectory): the folder is left on disk after this call returns,
    # presumably because the loaded object may still read files from it later — TODO confirm.
    workspace = pathlib.Path(tempfile.mkdtemp())
    self._model_ops.download_files(target_path=workspace, mode="model", **shared_download_kwargs)
    pk = model_composer.ModelComposer.load(workspace, meta_only=False, options=options)
    assert pk.model, (
        "Unable to load model. "
        f"model_name={self._model_name}, version_name={self._version_name}, metadata={pk.meta}"
    )
    return pk.model
|
@@ -61,12 +61,18 @@ class MetadataOperator:
|
|
61
61
|
def _get_current_metadata_dict(
|
62
62
|
self,
|
63
63
|
*,
|
64
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
65
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
64
66
|
model_name: sql_identifier.SqlIdentifier,
|
65
67
|
version_name: sql_identifier.SqlIdentifier,
|
66
68
|
statement_params: Optional[Dict[str, Any]] = None,
|
67
69
|
) -> Dict[str, Any]:
|
68
70
|
version_info_list = self._model_client.show_versions(
|
69
|
-
|
71
|
+
database_name=database_name,
|
72
|
+
schema_name=schema_name,
|
73
|
+
model_name=model_name,
|
74
|
+
version_name=version_name,
|
75
|
+
statement_params=statement_params,
|
70
76
|
)
|
71
77
|
metadata_str = version_info_list[0][self._model_client.MODEL_VERSION_METADATA_COL_NAME]
|
72
78
|
if not metadata_str:
|
@@ -79,12 +85,18 @@ class MetadataOperator:
|
|
79
85
|
def load(
    self,
    *,
    database_name: Optional[sql_identifier.SqlIdentifier],
    schema_name: Optional[sql_identifier.SqlIdentifier],
    model_name: sql_identifier.SqlIdentifier,
    version_name: sql_identifier.SqlIdentifier,
    statement_params: Optional[Dict[str, Any]] = None,
) -> ModelVersionMetadataSchema:
    """Read the metadata currently stored on a model version and parse it.

    Args:
        database_name: Database holding the model, or None (presumably meaning the session's current
            database — confirm in the model client).
        schema_name: Schema holding the model, or None (presumably meaning the session's current schema —
            confirm in the model client).
        model_name: Name of the model.
        version_name: Name of the model version whose metadata is read.
        statement_params: Optional statement parameters forwarded to the underlying SQL calls.

    Returns:
        The metadata as produced by ``MetadataOperator._parse``.
    """
    return MetadataOperator._parse(
        self._get_current_metadata_dict(
            database_name=database_name,
            schema_name=schema_name,
            model_name=model_name,
            version_name=version_name,
            statement_params=statement_params,
        )
    )
|
90
102
|
|
@@ -92,14 +104,25 @@ class MetadataOperator:
|
|
92
104
|
self,
|
93
105
|
metadata: ModelVersionMetadataSchema,
|
94
106
|
*,
|
107
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
108
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
95
109
|
model_name: sql_identifier.SqlIdentifier,
|
96
110
|
version_name: sql_identifier.SqlIdentifier,
|
97
111
|
statement_params: Optional[Dict[str, Any]] = None,
|
98
112
|
) -> None:
|
99
113
|
metadata_dict = self._get_current_metadata_dict(
|
100
|
-
|
114
|
+
database_name=database_name,
|
115
|
+
schema_name=schema_name,
|
116
|
+
model_name=model_name,
|
117
|
+
version_name=version_name,
|
118
|
+
statement_params=statement_params,
|
101
119
|
)
|
102
120
|
metadata_dict.update({**metadata, "snowpark_ml_schema_version": MODEL_VERSION_METADATA_SCHEMA_VERSION})
|
103
121
|
self._model_version_client.set_metadata(
|
104
|
-
metadata_dict,
|
122
|
+
metadata_dict,
|
123
|
+
database_name=database_name,
|
124
|
+
schema_name=schema_name,
|
125
|
+
model_name=model_name,
|
126
|
+
version_name=version_name,
|
127
|
+
statement_params=statement_params,
|
105
128
|
)
|