snowflake-ml-python 1.1.1__py3-none-any.whl → 1.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +1 -1
- snowflake/cortex/_extract_answer.py +1 -1
- snowflake/cortex/_sentiment.py +1 -1
- snowflake/cortex/_summarize.py +1 -1
- snowflake/cortex/_translate.py +1 -1
- snowflake/ml/_internal/env_utils.py +68 -6
- snowflake/ml/_internal/file_utils.py +34 -4
- snowflake/ml/_internal/telemetry.py +79 -91
- snowflake/ml/_internal/utils/retryable_http.py +16 -4
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +122 -0
- snowflake/ml/dataset/dataset.py +1 -1
- snowflake/ml/model/_api.py +21 -14
- snowflake/ml/model/_client/model/model_impl.py +176 -0
- snowflake/ml/model/_client/model/model_method_info.py +19 -0
- snowflake/ml/model/_client/model/model_version_impl.py +291 -0
- snowflake/ml/model/_client/ops/metadata_ops.py +107 -0
- snowflake/ml/model/_client/ops/model_ops.py +308 -0
- snowflake/ml/model/_client/sql/model.py +75 -0
- snowflake/ml/model/_client/sql/model_version.py +213 -0
- snowflake/ml/model/_client/sql/stage.py +40 -0
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -4
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +24 -8
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +23 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +14 -2
- snowflake/ml/model/_deploy_client/utils/constants.py +1 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +2 -2
- snowflake/ml/model/_model_composer/model_composer.py +31 -9
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +25 -10
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -2
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +34 -3
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +1 -1
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -1
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +10 -28
- snowflake/ml/model/_packager/model_meta/model_meta.py +18 -16
- snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
- snowflake/ml/model/model_signature.py +108 -53
- snowflake/ml/model/type_hints.py +1 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +554 -0
- snowflake/ml/modeling/_internal/estimator_protocols.py +1 -60
- snowflake/ml/modeling/_internal/model_specifications.py +146 -0
- snowflake/ml/modeling/_internal/model_trainer.py +13 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +78 -0
- snowflake/ml/modeling/_internal/pandas_trainer.py +54 -0
- snowflake/ml/modeling/_internal/snowpark_handlers.py +6 -760
- snowflake/ml/modeling/_internal/snowpark_trainer.py +331 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +96 -124
- snowflake/ml/modeling/cluster/affinity_propagation.py +94 -124
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +94 -124
- snowflake/ml/modeling/cluster/birch.py +94 -124
- snowflake/ml/modeling/cluster/bisecting_k_means.py +94 -124
- snowflake/ml/modeling/cluster/dbscan.py +94 -124
- snowflake/ml/modeling/cluster/feature_agglomeration.py +94 -124
- snowflake/ml/modeling/cluster/k_means.py +93 -124
- snowflake/ml/modeling/cluster/mean_shift.py +94 -124
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +93 -124
- snowflake/ml/modeling/cluster/optics.py +94 -124
- snowflake/ml/modeling/cluster/spectral_biclustering.py +94 -124
- snowflake/ml/modeling/cluster/spectral_clustering.py +94 -124
- snowflake/ml/modeling/cluster/spectral_coclustering.py +94 -124
- snowflake/ml/modeling/compose/column_transformer.py +94 -124
- snowflake/ml/modeling/compose/transformed_target_regressor.py +96 -124
- snowflake/ml/modeling/covariance/elliptic_envelope.py +94 -124
- snowflake/ml/modeling/covariance/empirical_covariance.py +80 -110
- snowflake/ml/modeling/covariance/graphical_lasso.py +94 -124
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +94 -124
- snowflake/ml/modeling/covariance/ledoit_wolf.py +85 -115
- snowflake/ml/modeling/covariance/min_cov_det.py +94 -124
- snowflake/ml/modeling/covariance/oas.py +80 -110
- snowflake/ml/modeling/covariance/shrunk_covariance.py +84 -114
- snowflake/ml/modeling/decomposition/dictionary_learning.py +94 -124
- snowflake/ml/modeling/decomposition/factor_analysis.py +94 -124
- snowflake/ml/modeling/decomposition/fast_ica.py +94 -124
- snowflake/ml/modeling/decomposition/incremental_pca.py +94 -124
- snowflake/ml/modeling/decomposition/kernel_pca.py +94 -124
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +94 -124
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +94 -124
- snowflake/ml/modeling/decomposition/pca.py +94 -124
- snowflake/ml/modeling/decomposition/sparse_pca.py +94 -124
- snowflake/ml/modeling/decomposition/truncated_svd.py +94 -124
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +96 -124
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +91 -119
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/bagging_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/bagging_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/isolation_forest.py +94 -124
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/stacking_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/voting_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/voting_regressor.py +91 -119
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +82 -110
- snowflake/ml/modeling/feature_selection/select_fdr.py +80 -108
- snowflake/ml/modeling/feature_selection/select_fpr.py +80 -108
- snowflake/ml/modeling/feature_selection/select_fwe.py +80 -108
- snowflake/ml/modeling/feature_selection/select_k_best.py +81 -109
- snowflake/ml/modeling/feature_selection/select_percentile.py +80 -108
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +94 -124
- snowflake/ml/modeling/feature_selection/variance_threshold.py +76 -106
- snowflake/ml/modeling/framework/base.py +2 -2
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +96 -124
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +96 -124
- snowflake/ml/modeling/impute/iterative_imputer.py +94 -124
- snowflake/ml/modeling/impute/knn_imputer.py +94 -124
- snowflake/ml/modeling/impute/missing_indicator.py +94 -124
- snowflake/ml/modeling/impute/simple_imputer.py +1 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +77 -107
- snowflake/ml/modeling/kernel_approximation/nystroem.py +94 -124
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +94 -124
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +86 -116
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +84 -114
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +96 -124
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +71 -100
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +71 -100
- snowflake/ml/modeling/linear_model/ard_regression.py +96 -124
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +96 -124
- snowflake/ml/modeling/linear_model/elastic_net.py +96 -124
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +96 -124
- snowflake/ml/modeling/linear_model/gamma_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/huber_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/lars.py +96 -124
- snowflake/ml/modeling/linear_model/lars_cv.py +96 -124
- snowflake/ml/modeling/linear_model/lasso.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_cv.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_lars.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +96 -124
- snowflake/ml/modeling/linear_model/linear_regression.py +91 -119
- snowflake/ml/modeling/linear_model/logistic_regression.py +96 -124
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +96 -124
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +96 -124
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +96 -124
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +95 -124
- snowflake/ml/modeling/linear_model/perceptron.py +95 -124
- snowflake/ml/modeling/linear_model/poisson_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/ransac_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/ridge.py +96 -124
- snowflake/ml/modeling/linear_model/ridge_classifier.py +96 -124
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +96 -124
- snowflake/ml/modeling/linear_model/ridge_cv.py +96 -124
- snowflake/ml/modeling/linear_model/sgd_classifier.py +96 -124
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +94 -124
- snowflake/ml/modeling/linear_model/sgd_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +96 -124
- snowflake/ml/modeling/manifold/isomap.py +94 -124
- snowflake/ml/modeling/manifold/mds.py +94 -124
- snowflake/ml/modeling/manifold/spectral_embedding.py +94 -124
- snowflake/ml/modeling/manifold/tsne.py +94 -124
- snowflake/ml/modeling/metrics/classification.py +187 -52
- snowflake/ml/modeling/metrics/correlation.py +4 -2
- snowflake/ml/modeling/metrics/covariance.py +7 -4
- snowflake/ml/modeling/metrics/ranking.py +32 -16
- snowflake/ml/modeling/metrics/regression.py +60 -32
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +94 -124
- snowflake/ml/modeling/mixture/gaussian_mixture.py +94 -124
- snowflake/ml/modeling/model_selection/grid_search_cv.py +88 -138
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +90 -144
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +86 -114
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +93 -121
- snowflake/ml/modeling/multiclass/output_code_classifier.py +94 -122
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +92 -120
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +96 -124
- snowflake/ml/modeling/naive_bayes/complement_nb.py +92 -120
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -107
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +88 -116
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +96 -124
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +96 -124
- snowflake/ml/modeling/neighbors/kernel_density.py +94 -124
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +94 -124
- snowflake/ml/modeling/neighbors/nearest_centroid.py +89 -117
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +94 -124
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +96 -124
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +96 -124
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +96 -124
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +94 -124
- snowflake/ml/modeling/neural_network/mlp_classifier.py +96 -124
- snowflake/ml/modeling/neural_network/mlp_regressor.py +96 -124
- snowflake/ml/modeling/parameters/disable_distributed_hpo.py +2 -6
- snowflake/ml/modeling/preprocessing/binarizer.py +14 -9
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +0 -4
- snowflake/ml/modeling/preprocessing/label_encoder.py +21 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +20 -14
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +35 -19
- snowflake/ml/modeling/preprocessing/normalizer.py +6 -9
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +20 -13
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +25 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +94 -124
- snowflake/ml/modeling/preprocessing/robust_scaler.py +28 -14
- snowflake/ml/modeling/preprocessing/standard_scaler.py +25 -13
- snowflake/ml/modeling/semi_supervised/label_propagation.py +96 -124
- snowflake/ml/modeling/semi_supervised/label_spreading.py +96 -124
- snowflake/ml/modeling/svm/linear_svc.py +96 -124
- snowflake/ml/modeling/svm/linear_svr.py +96 -124
- snowflake/ml/modeling/svm/nu_svc.py +96 -124
- snowflake/ml/modeling/svm/nu_svr.py +96 -124
- snowflake/ml/modeling/svm/svc.py +96 -124
- snowflake/ml/modeling/svm/svr.py +96 -124
- snowflake/ml/modeling/tree/decision_tree_classifier.py +96 -124
- snowflake/ml/modeling/tree/decision_tree_regressor.py +96 -124
- snowflake/ml/modeling/tree/extra_tree_classifier.py +96 -124
- snowflake/ml/modeling/tree/extra_tree_regressor.py +96 -124
- snowflake/ml/modeling/xgboost/xgb_classifier.py +96 -125
- snowflake/ml/modeling/xgboost/xgb_regressor.py +96 -125
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +96 -125
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +96 -125
- snowflake/ml/registry/model_registry.py +2 -0
- snowflake/ml/registry/registry.py +215 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.1.1.dist-info → snowflake_ml_python-1.1.2.dist-info}/METADATA +21 -3
- snowflake_ml_python-1.1.2.dist-info/RECORD +347 -0
- snowflake_ml_python-1.1.1.dist-info/RECORD +0 -331
- {snowflake_ml_python-1.1.1.dist-info → snowflake_ml_python-1.1.2.dist-info}/WHEEL +0 -0
snowflake/ml/model/_api.py
CHANGED
@@ -7,7 +7,6 @@ from snowflake.ml._internal.exceptions import (
|
|
7
7
|
error_codes,
|
8
8
|
exceptions as snowml_exceptions,
|
9
9
|
)
|
10
|
-
from snowflake.ml._internal.utils import identifier
|
11
10
|
from snowflake.ml.model import (
|
12
11
|
deploy_platforms,
|
13
12
|
model_signature,
|
@@ -188,6 +187,10 @@ def save_model(
|
|
188
187
|
Returns:
|
189
188
|
Model
|
190
189
|
"""
|
190
|
+
if options is None:
|
191
|
+
options = {}
|
192
|
+
options["_legacy_save"] = True
|
193
|
+
|
191
194
|
m = model_composer.ModelComposer(session=session, stage_path=stage_path)
|
192
195
|
m.save(
|
193
196
|
name=name,
|
@@ -481,6 +484,7 @@ def predict(
|
|
481
484
|
# Get options
|
482
485
|
INTERMEDIATE_OBJ_NAME = "tmp_result"
|
483
486
|
sig = deployment["signature"]
|
487
|
+
identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED
|
484
488
|
|
485
489
|
# Validate and prepare input
|
486
490
|
if not isinstance(X, SnowparkDataFrame):
|
@@ -491,7 +495,7 @@ def predict(
|
|
491
495
|
else:
|
492
496
|
keep_order = False
|
493
497
|
output_with_input_features = True
|
494
|
-
model_signature._validate_snowpark_data(X, sig.inputs)
|
498
|
+
identifier_rule = model_signature._validate_snowpark_data(X, sig.inputs)
|
495
499
|
s_df = X
|
496
500
|
|
497
501
|
if statement_params:
|
@@ -500,10 +504,14 @@ def predict(
|
|
500
504
|
else:
|
501
505
|
s_df._statement_params = statement_params # type: ignore[assignment]
|
502
506
|
|
507
|
+
original_cols = s_df.columns
|
508
|
+
|
503
509
|
# Infer and get intermediate result
|
504
510
|
input_cols = []
|
505
|
-
for
|
506
|
-
literal_col_name =
|
511
|
+
for input_feature in sig.inputs:
|
512
|
+
literal_col_name = input_feature.name
|
513
|
+
col_name = identifier_rule.get_identifier_from_feature(input_feature.name)
|
514
|
+
|
507
515
|
input_cols.extend(
|
508
516
|
[
|
509
517
|
F.lit(literal_col_name),
|
@@ -511,29 +519,28 @@ def predict(
|
|
511
519
|
]
|
512
520
|
)
|
513
521
|
|
514
|
-
# TODO[shchen]: SNOW-870032, For SnowService, external function name cannot be double quoted, else it results in
|
515
|
-
# external function no found.
|
516
522
|
udf_name = deployment["name"]
|
517
|
-
output_obj = F.call_udf(udf_name, F.
|
518
|
-
|
519
|
-
if output_with_input_features:
|
520
|
-
df_res = s_df.with_column(INTERMEDIATE_OBJ_NAME, output_obj)
|
521
|
-
else:
|
522
|
-
df_res = s_df.select(output_obj.alias(INTERMEDIATE_OBJ_NAME))
|
523
|
+
output_obj = F.call_udf(udf_name, F.object_construct_keep_null(*input_cols))
|
524
|
+
df_res = s_df.with_column(INTERMEDIATE_OBJ_NAME, output_obj)
|
523
525
|
|
524
526
|
if keep_order:
|
525
527
|
df_res = df_res.order_by(
|
526
|
-
F.col(
|
528
|
+
F.col(infer_template._KEEP_ORDER_COL_NAME),
|
527
529
|
ascending=True,
|
528
530
|
)
|
529
531
|
|
532
|
+
if not output_with_input_features:
|
533
|
+
df_res = df_res.drop(*original_cols)
|
534
|
+
|
530
535
|
# Prepare the output
|
531
536
|
output_cols = []
|
537
|
+
output_col_names = []
|
532
538
|
for output_feature in sig.outputs:
|
533
539
|
output_cols.append(F.col(INTERMEDIATE_OBJ_NAME)[output_feature.name].astype(output_feature.as_snowpark_type()))
|
540
|
+
output_col_names.append(identifier_rule.get_identifier_from_feature(output_feature.name))
|
534
541
|
|
535
542
|
df_res = df_res.with_columns(
|
536
|
-
|
543
|
+
output_col_names,
|
537
544
|
output_cols,
|
538
545
|
).drop(INTERMEDIATE_OBJ_NAME)
|
539
546
|
|
@@ -0,0 +1,176 @@
|
|
1
|
+
from typing import List, Union
|
2
|
+
|
3
|
+
from snowflake.ml._internal import telemetry
|
4
|
+
from snowflake.ml._internal.utils import sql_identifier
|
5
|
+
from snowflake.ml.model._client.model import model_version_impl
|
6
|
+
from snowflake.ml.model._client.ops import model_ops
|
7
|
+
|
8
|
+
_TELEMETRY_PROJECT = "MLOps"
|
9
|
+
_TELEMETRY_SUBPROJECT = "ModelManagement"
|
10
|
+
|
11
|
+
|
12
|
+
class Model:
|
13
|
+
"""Model Object containing multiple versions. Mapping to SQL's MODEL object."""
|
14
|
+
|
15
|
+
_model_ops: model_ops.ModelOperator
|
16
|
+
_model_name: sql_identifier.SqlIdentifier
|
17
|
+
|
18
|
+
def __init__(self) -> None:
|
19
|
+
raise RuntimeError("Model's initializer is not meant to be used. Use `get_model` from registry instead.")
|
20
|
+
|
21
|
+
@classmethod
|
22
|
+
def _ref(
|
23
|
+
cls,
|
24
|
+
model_ops: model_ops.ModelOperator,
|
25
|
+
*,
|
26
|
+
model_name: sql_identifier.SqlIdentifier,
|
27
|
+
) -> "Model":
|
28
|
+
self: "Model" = object.__new__(cls)
|
29
|
+
self._model_ops = model_ops
|
30
|
+
self._model_name = model_name
|
31
|
+
return self
|
32
|
+
|
33
|
+
def __eq__(self, __value: object) -> bool:
|
34
|
+
if not isinstance(__value, Model):
|
35
|
+
return False
|
36
|
+
return self._model_ops == __value._model_ops and self._model_name == __value._model_name
|
37
|
+
|
38
|
+
@property
|
39
|
+
def name(self) -> str:
|
40
|
+
return self._model_name.identifier()
|
41
|
+
|
42
|
+
@property
|
43
|
+
def fully_qualified_name(self) -> str:
|
44
|
+
return self._model_ops._model_version_client.fully_qualified_model_name(self._model_name)
|
45
|
+
|
46
|
+
@property
|
47
|
+
@telemetry.send_api_usage_telemetry(
|
48
|
+
project=_TELEMETRY_PROJECT,
|
49
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
50
|
+
)
|
51
|
+
def description(self) -> str:
|
52
|
+
statement_params = telemetry.get_statement_params(
|
53
|
+
project=_TELEMETRY_PROJECT,
|
54
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
55
|
+
)
|
56
|
+
return self._model_ops.get_comment(
|
57
|
+
model_name=self._model_name,
|
58
|
+
statement_params=statement_params,
|
59
|
+
)
|
60
|
+
|
61
|
+
@description.setter
|
62
|
+
@telemetry.send_api_usage_telemetry(
|
63
|
+
project=_TELEMETRY_PROJECT,
|
64
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
65
|
+
)
|
66
|
+
def description(self, description: str) -> None:
|
67
|
+
statement_params = telemetry.get_statement_params(
|
68
|
+
project=_TELEMETRY_PROJECT,
|
69
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
70
|
+
)
|
71
|
+
return self._model_ops.set_comment(
|
72
|
+
comment=description,
|
73
|
+
model_name=self._model_name,
|
74
|
+
statement_params=statement_params,
|
75
|
+
)
|
76
|
+
|
77
|
+
@property
|
78
|
+
@telemetry.send_api_usage_telemetry(
|
79
|
+
project=_TELEMETRY_PROJECT,
|
80
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
81
|
+
)
|
82
|
+
def default(self) -> model_version_impl.ModelVersion:
|
83
|
+
statement_params = telemetry.get_statement_params(
|
84
|
+
project=_TELEMETRY_PROJECT,
|
85
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
86
|
+
class_name=self.__class__.__name__,
|
87
|
+
)
|
88
|
+
default_version_name = self._model_ops._model_version_client.get_default_version(
|
89
|
+
model_name=self._model_name, statement_params=statement_params
|
90
|
+
)
|
91
|
+
return self.version(default_version_name)
|
92
|
+
|
93
|
+
@default.setter
|
94
|
+
@telemetry.send_api_usage_telemetry(
|
95
|
+
project=_TELEMETRY_PROJECT,
|
96
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
97
|
+
)
|
98
|
+
def default(self, version: Union[str, model_version_impl.ModelVersion]) -> None:
|
99
|
+
statement_params = telemetry.get_statement_params(
|
100
|
+
project=_TELEMETRY_PROJECT,
|
101
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
102
|
+
class_name=self.__class__.__name__,
|
103
|
+
)
|
104
|
+
if isinstance(version, str):
|
105
|
+
version_name = sql_identifier.SqlIdentifier(version)
|
106
|
+
else:
|
107
|
+
version_name = version._version_name
|
108
|
+
self._model_ops._model_version_client.set_default_version(
|
109
|
+
model_name=self._model_name, version_name=version_name, statement_params=statement_params
|
110
|
+
)
|
111
|
+
|
112
|
+
@telemetry.send_api_usage_telemetry(
|
113
|
+
project=_TELEMETRY_PROJECT,
|
114
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
115
|
+
)
|
116
|
+
def version(self, version_name: str) -> model_version_impl.ModelVersion:
|
117
|
+
"""Get a model version object given a version name in the model.
|
118
|
+
|
119
|
+
Args:
|
120
|
+
version_name: The name of version
|
121
|
+
|
122
|
+
Raises:
|
123
|
+
ValueError: Raised when the version requested does not exist.
|
124
|
+
|
125
|
+
Returns:
|
126
|
+
The model version object.
|
127
|
+
"""
|
128
|
+
statement_params = telemetry.get_statement_params(
|
129
|
+
project=_TELEMETRY_PROJECT,
|
130
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
131
|
+
)
|
132
|
+
version_id = sql_identifier.SqlIdentifier(version_name)
|
133
|
+
if self._model_ops.validate_existence(
|
134
|
+
model_name=self._model_name,
|
135
|
+
version_name=version_id,
|
136
|
+
statement_params=statement_params,
|
137
|
+
):
|
138
|
+
return model_version_impl.ModelVersion._ref(
|
139
|
+
self._model_ops,
|
140
|
+
model_name=self._model_name,
|
141
|
+
version_name=version_id,
|
142
|
+
)
|
143
|
+
else:
|
144
|
+
raise ValueError(
|
145
|
+
f"Unable to find version with name {version_id.identifier()} in model {self.fully_qualified_name}"
|
146
|
+
)
|
147
|
+
|
148
|
+
@telemetry.send_api_usage_telemetry(
|
149
|
+
project=_TELEMETRY_PROJECT,
|
150
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
151
|
+
)
|
152
|
+
def list_versions(self) -> List[model_version_impl.ModelVersion]:
|
153
|
+
"""List all versions in the model.
|
154
|
+
|
155
|
+
Returns:
|
156
|
+
A List of ModelVersion object representing all versions in the model.
|
157
|
+
"""
|
158
|
+
statement_params = telemetry.get_statement_params(
|
159
|
+
project=_TELEMETRY_PROJECT,
|
160
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
161
|
+
)
|
162
|
+
version_names = self._model_ops.list_models_or_versions(
|
163
|
+
model_name=self._model_name,
|
164
|
+
statement_params=statement_params,
|
165
|
+
)
|
166
|
+
return [
|
167
|
+
model_version_impl.ModelVersion._ref(
|
168
|
+
self._model_ops,
|
169
|
+
model_name=self._model_name,
|
170
|
+
version_name=version_name,
|
171
|
+
)
|
172
|
+
for version_name in version_names
|
173
|
+
]
|
174
|
+
|
175
|
+
def delete_version(self, version_name: str) -> None:
|
176
|
+
raise NotImplementedError("Deleting version has not been supported yet.")
|
@@ -0,0 +1,19 @@
|
|
1
|
+
from typing import TypedDict
|
2
|
+
|
3
|
+
from typing_extensions import Required
|
4
|
+
|
5
|
+
from snowflake.ml.model import model_signature
|
6
|
+
|
7
|
+
|
8
|
+
class ModelMethodInfo(TypedDict):
|
9
|
+
"""Method information.
|
10
|
+
|
11
|
+
Attributes:
|
12
|
+
name: Name of the method to be called via SQL.
|
13
|
+
target_method: actual target method name to be called.
|
14
|
+
signature: The signature of the model method.
|
15
|
+
"""
|
16
|
+
|
17
|
+
name: Required[str]
|
18
|
+
target_method: Required[str]
|
19
|
+
signature: Required[model_signature.ModelSignature]
|
@@ -0,0 +1,291 @@
|
|
1
|
+
import re
|
2
|
+
from typing import Any, Callable, Dict, List, Optional, Union
|
3
|
+
|
4
|
+
import pandas as pd
|
5
|
+
|
6
|
+
from snowflake.ml._internal import telemetry
|
7
|
+
from snowflake.ml._internal.utils import sql_identifier
|
8
|
+
from snowflake.ml.model import model_signature
|
9
|
+
from snowflake.ml.model._client.model import model_method_info
|
10
|
+
from snowflake.ml.model._client.ops import metadata_ops, model_ops
|
11
|
+
from snowflake.snowpark import dataframe
|
12
|
+
|
13
|
+
_TELEMETRY_PROJECT = "MLOps"
|
14
|
+
_TELEMETRY_SUBPROJECT = "ModelManagement"
|
15
|
+
|
16
|
+
|
17
|
+
class ModelVersion:
|
18
|
+
"""Model Version Object representing a specific version of the model that could be run."""
|
19
|
+
|
20
|
+
_model_ops: model_ops.ModelOperator
|
21
|
+
_model_name: sql_identifier.SqlIdentifier
|
22
|
+
_version_name: sql_identifier.SqlIdentifier
|
23
|
+
|
24
|
+
def __init__(self) -> None:
|
25
|
+
raise RuntimeError("ModelVersion's initializer is not meant to be used. Use `version` from model instead.")
|
26
|
+
|
27
|
+
@classmethod
|
28
|
+
def _ref(
|
29
|
+
cls,
|
30
|
+
model_ops: model_ops.ModelOperator,
|
31
|
+
*,
|
32
|
+
model_name: sql_identifier.SqlIdentifier,
|
33
|
+
version_name: sql_identifier.SqlIdentifier,
|
34
|
+
) -> "ModelVersion":
|
35
|
+
self: "ModelVersion" = object.__new__(cls)
|
36
|
+
self._model_ops = model_ops
|
37
|
+
self._model_name = model_name
|
38
|
+
self._version_name = version_name
|
39
|
+
return self
|
40
|
+
|
41
|
+
def __eq__(self, __value: object) -> bool:
|
42
|
+
if not isinstance(__value, ModelVersion):
|
43
|
+
return False
|
44
|
+
return (
|
45
|
+
self._model_ops == __value._model_ops
|
46
|
+
and self._model_name == __value._model_name
|
47
|
+
and self._version_name == __value._version_name
|
48
|
+
)
|
49
|
+
|
50
|
+
@property
|
51
|
+
def model_name(self) -> str:
|
52
|
+
return self._model_name.identifier()
|
53
|
+
|
54
|
+
@property
|
55
|
+
def version_name(self) -> str:
|
56
|
+
return self._version_name.identifier()
|
57
|
+
|
58
|
+
@property
|
59
|
+
def fully_qualified_model_name(self) -> str:
|
60
|
+
return self._model_ops._model_version_client.fully_qualified_model_name(self._model_name)
|
61
|
+
|
62
|
+
@property
|
63
|
+
@telemetry.send_api_usage_telemetry(
|
64
|
+
project=_TELEMETRY_PROJECT,
|
65
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
66
|
+
)
|
67
|
+
def description(self) -> str:
|
68
|
+
statement_params = telemetry.get_statement_params(
|
69
|
+
project=_TELEMETRY_PROJECT,
|
70
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
71
|
+
)
|
72
|
+
return self._model_ops.get_comment(
|
73
|
+
model_name=self._model_name,
|
74
|
+
version_name=self._version_name,
|
75
|
+
statement_params=statement_params,
|
76
|
+
)
|
77
|
+
|
78
|
+
@description.setter
|
79
|
+
@telemetry.send_api_usage_telemetry(
|
80
|
+
project=_TELEMETRY_PROJECT,
|
81
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
82
|
+
)
|
83
|
+
def description(self, description: str) -> None:
|
84
|
+
statement_params = telemetry.get_statement_params(
|
85
|
+
project=_TELEMETRY_PROJECT,
|
86
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
87
|
+
)
|
88
|
+
return self._model_ops.set_comment(
|
89
|
+
comment=description,
|
90
|
+
model_name=self._model_name,
|
91
|
+
version_name=self._version_name,
|
92
|
+
statement_params=statement_params,
|
93
|
+
)
|
94
|
+
|
95
|
+
@telemetry.send_api_usage_telemetry(
|
96
|
+
project=_TELEMETRY_PROJECT,
|
97
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
98
|
+
)
|
99
|
+
def list_metrics(self) -> Dict[str, Any]:
|
100
|
+
"""Show all metrics logged with the model version.
|
101
|
+
|
102
|
+
Returns:
|
103
|
+
A dictionary showing the metrics
|
104
|
+
"""
|
105
|
+
statement_params = telemetry.get_statement_params(
|
106
|
+
project=_TELEMETRY_PROJECT,
|
107
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
108
|
+
)
|
109
|
+
return self._model_ops._metadata_ops.load(
|
110
|
+
model_name=self._model_name, version_name=self._version_name, statement_params=statement_params
|
111
|
+
)["metrics"]
|
112
|
+
|
113
|
+
@telemetry.send_api_usage_telemetry(
|
114
|
+
project=_TELEMETRY_PROJECT,
|
115
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
116
|
+
)
|
117
|
+
def get_metric(self, metric_name: str) -> Any:
|
118
|
+
"""Get the value of a specific metric.
|
119
|
+
|
120
|
+
Args:
|
121
|
+
metric_name: The name of the metric
|
122
|
+
|
123
|
+
Raises:
|
124
|
+
KeyError: Raised when the requested metric name does not exist.
|
125
|
+
|
126
|
+
Returns:
|
127
|
+
The value of the metric.
|
128
|
+
"""
|
129
|
+
metrics = self.list_metrics()
|
130
|
+
if metric_name not in metrics:
|
131
|
+
raise KeyError(f"Cannot find metric with name {metric_name}.")
|
132
|
+
return metrics[metric_name]
|
133
|
+
|
134
|
+
@telemetry.send_api_usage_telemetry(
|
135
|
+
project=_TELEMETRY_PROJECT,
|
136
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
137
|
+
)
|
138
|
+
def set_metric(self, metric_name: str, value: Any) -> None:
|
139
|
+
"""Set the value of a specific metric name
|
140
|
+
|
141
|
+
Args:
|
142
|
+
metric_name: The name of the metric
|
143
|
+
value: The value of the metric.
|
144
|
+
"""
|
145
|
+
statement_params = telemetry.get_statement_params(
|
146
|
+
project=_TELEMETRY_PROJECT,
|
147
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
148
|
+
)
|
149
|
+
metrics = self.list_metrics()
|
150
|
+
metrics[metric_name] = value
|
151
|
+
self._model_ops._metadata_ops.save(
|
152
|
+
metadata_ops.ModelVersionMetadataSchema(metrics=metrics),
|
153
|
+
model_name=self._model_name,
|
154
|
+
version_name=self._version_name,
|
155
|
+
statement_params=statement_params,
|
156
|
+
)
|
157
|
+
|
158
|
+
@telemetry.send_api_usage_telemetry(
|
159
|
+
project=_TELEMETRY_PROJECT,
|
160
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
161
|
+
)
|
162
|
+
def delete_metric(self, metric_name: str) -> None:
|
163
|
+
"""Delete a metric from metric storage.
|
164
|
+
|
165
|
+
Args:
|
166
|
+
metric_name: The name of the metric to be deleted.
|
167
|
+
|
168
|
+
Raises:
|
169
|
+
KeyError: Raised when the requested metric name does not exist.
|
170
|
+
"""
|
171
|
+
statement_params = telemetry.get_statement_params(
|
172
|
+
project=_TELEMETRY_PROJECT,
|
173
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
174
|
+
)
|
175
|
+
metrics = self.list_metrics()
|
176
|
+
if metric_name not in metrics:
|
177
|
+
raise KeyError(f"Cannot find metric with name {metric_name}.")
|
178
|
+
del metrics[metric_name]
|
179
|
+
self._model_ops._metadata_ops.save(
|
180
|
+
metadata_ops.ModelVersionMetadataSchema(metrics=metrics),
|
181
|
+
model_name=self._model_name,
|
182
|
+
version_name=self._version_name,
|
183
|
+
statement_params=statement_params,
|
184
|
+
)
|
185
|
+
|
186
|
+
@telemetry.send_api_usage_telemetry(
    project=_TELEMETRY_PROJECT,
    subproject=_TELEMETRY_SUBPROJECT,
)
def list_methods(self) -> List[model_method_info.ModelMethodInfo]:
    """Describe every callable method of this model version.

    Returns:
        A list of ModelMethodInfo objects, each holding:
            - name: The method name used to call it, both in SQL and in the Python SDK.
            - target_method: The name of the original method on the logged Python object.
            - signature: The Python signature of the original method.
    """
    stmt_params = telemetry.get_statement_params(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    # TODO(SNOW-986673, SNOW-986675): Avoid parsing manifest and meta file and put Python signature into user_data.
    version_manifest = self._model_ops.get_model_version_manifest(
        model_name=self._model_name,
        version_name=self._version_name,
        statement_params=stmt_params,
    )
    packing_meta = self._model_ops.get_model_version_native_packing_meta(
        model_name=self._model_name,
        version_name=self._version_name,
        statement_params=stmt_params,
    )
    signatures = packing_meta["signatures"]

    methods_info: List[model_method_info.ModelMethodInfo] = []
    for manifest_method in version_manifest["methods"]:
        # The manifest stores the resolved name, so it is parsed as case sensitive
        # to recover the user-facing identifier.
        resolved_name = sql_identifier.SqlIdentifier(manifest_method["name"], case_sensitive=True)

        # The handler has the form `functions.<target_method>.infer`.
        handler = manifest_method["handler"]
        handler_is_valid = re.match(r"^functions\.([^\d\W]\w*)\.infer$", handler) is not None
        assert handler_is_valid, f"Get unexpected handler name {handler}"
        target_method = handler.split(".")[1]

        methods_info.append(
            model_method_info.ModelMethodInfo(
                name=resolved_name.identifier(),
                target_method=target_method,
                signature=model_signature.ModelSignature.from_dict(signatures[target_method]),
            )
        )

    return methods_info
|
232
|
+
|
233
|
+
@telemetry.send_api_usage_telemetry(
    project=_TELEMETRY_PROJECT,
    subproject=_TELEMETRY_SUBPROJECT,
)
def run(
    self,
    X: Union[pd.DataFrame, dataframe.DataFrame],
    *,
    method_name: Optional[str] = None,
) -> Union[pd.DataFrame, dataframe.DataFrame]:
    """Call one of the methods of this model version on the given data.

    Args:
        X: The input data. Could be pandas DataFrame or Snowpark DataFrame
        method_name: The method name to run. It is the name you will use to call a method in SQL. Defaults to None.
            It can only be None if there is only 1 method.

    Raises:
        ValueError: No method with the corresponding name is available.
        ValueError: There are more than 1 target methods available in the model but no method name specified.

    Returns:
        The prediction data.
    """
    stmt_params = telemetry.get_statement_params(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )

    available_methods: List[model_method_info.ModelMethodInfo] = self.list_methods()
    if method_name:
        # Normalize the requested name the same way the listed names are expressed.
        req_method_name = sql_identifier.SqlIdentifier(method_name).identifier()
        target_method_info = next(
            (info for info in available_methods if info["name"] == req_method_name),
            None,
        )
        if target_method_info is None:
            raise ValueError(
                f"There is no method with name {method_name} available in the model"
                f" {self.fully_qualified_model_name} version {self.version_name}"
            )
    elif len(available_methods) == 1:
        # No name given: only unambiguous when exactly one method exists.
        target_method_info = available_methods[0]
    else:
        raise ValueError(
            f"There are more than 1 target methods available in the model {self.fully_qualified_model_name}"
            f" version {self.version_name}. Please specify a `method_name` when calling the `run` method."
        )

    return self._model_ops.invoke_method(
        method_name=sql_identifier.SqlIdentifier(target_method_info["name"]),
        signature=target_method_info["signature"],
        X=X,
        model_name=self._model_name,
        version_name=self._version_name,
        statement_params=stmt_params,
    )
|
@@ -0,0 +1,107 @@
|
|
1
|
+
import json
|
2
|
+
from typing import Any, Dict, Optional, TypedDict
|
3
|
+
|
4
|
+
from typing_extensions import NotRequired
|
5
|
+
|
6
|
+
from snowflake.ml._internal.utils import sql_identifier
|
7
|
+
from snowflake.ml.model._client.sql import (
|
8
|
+
model as model_sql,
|
9
|
+
model_version as model_version_sql,
|
10
|
+
)
|
11
|
+
from snowflake.snowpark import session
|
12
|
+
|
13
|
+
# Schema version written under "snowpark_ml_schema_version" by MetadataOperator.save;
# MetadataOperator._parse raises ValueError for metadata carrying a different version.
MODEL_VERSION_METADATA_SCHEMA_VERSION = "2024-01-01"
|
14
|
+
|
15
|
+
|
16
|
+
class ModelVersionMetadataSchema(TypedDict):
    """Typed dictionary describing the model version metadata handled by this module."""

    # Optional mapping from metric name to metric value.
    metrics: NotRequired[Dict[str, Any]]
|
18
|
+
|
19
|
+
|
20
|
+
class MetadataOperator:
    """Loads and saves the metadata of a model version through the model SQL clients."""

    def __init__(
        self,
        session: session.Session,
        *,
        database_name: sql_identifier.SqlIdentifier,
        schema_name: sql_identifier.SqlIdentifier,
    ) -> None:
        """Create the SQL clients used to read and write model version metadata.

        Args:
            session: The Snowpark session handed to both SQL clients.
            database_name: The database identifier handed to both SQL clients.
            schema_name: The schema identifier handed to both SQL clients.
        """
        self._model_client = model_sql.ModelSQLClient(
            session,
            database_name=database_name,
            schema_name=schema_name,
        )
        self._model_version_client = model_version_sql.ModelVersionSQLClient(
            session,
            database_name=database_name,
            schema_name=schema_name,
        )

    def __eq__(self, __value: object) -> bool:
        """Two operators are equal when both of their SQL clients compare equal."""
        if isinstance(__value, MetadataOperator):
            same_model_client = self._model_client == __value._model_client
            return same_model_client and self._model_version_client == __value._model_version_client
        return False

    @staticmethod
    def _parse(metadata_dict: Dict[str, Any]) -> ModelVersionMetadataSchema:
        """Validate a raw metadata dictionary and extract the supported fields.

        Args:
            metadata_dict: The decoded metadata dictionary of a model version.

        Raises:
            ValueError: The schema version is present but is not the supported one.
            ValueError: The metrics entry is not a dictionary.

        Returns:
            The parsed metadata; metrics are empty when no schema version is recorded.
        """
        schema_version = metadata_dict.get("snowpark_ml_schema_version", None)
        if schema_version is None:
            # No version marker: treated as metadata with no metrics.
            return ModelVersionMetadataSchema(metrics={})

        version_supported = (
            isinstance(schema_version, str) and schema_version == MODEL_VERSION_METADATA_SCHEMA_VERSION
        )
        if not version_supported:
            raise ValueError(f"Unsupported model metadata schema version {schema_version} confronted.")

        metrics = metadata_dict.get("metrics", {})
        if isinstance(metrics, dict):
            return ModelVersionMetadataSchema(metrics=metrics)
        raise ValueError(f"Metrics in the metadata is expected to be a dictionary, getting {metrics}")

    def _get_current_metadata_dict(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Fetch the metadata currently stored on a model version as a dictionary.

        Args:
            model_name: The identifier of the model.
            version_name: The identifier of the model version.
            statement_params: Statement parameters forwarded to the SQL client. Defaults to None.

        Raises:
            ValueError: The stored metadata does not decode to a dictionary.

        Returns:
            The decoded metadata, or an empty dictionary when none is stored.
        """
        matching_versions = self._model_client.show_versions(
            model_name=model_name, version_name=version_name, statement_params=statement_params
        )
        assert len(matching_versions) == 1
        raw_metadata = matching_versions[0].metadata
        if not raw_metadata:
            return {}

        decoded = json.loads(raw_metadata)
        if isinstance(decoded, dict):
            return decoded
        raise ValueError(f"Metadata is expected to be a dictionary, getting {decoded}")

    def load(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> ModelVersionMetadataSchema:
        """Read and parse the metadata of a model version.

        Args:
            model_name: The identifier of the model.
            version_name: The identifier of the model version.
            statement_params: Statement parameters forwarded to the SQL client. Defaults to None.

        Returns:
            The parsed metadata of the model version.
        """
        current = self._get_current_metadata_dict(
            model_name=model_name, version_name=version_name, statement_params=statement_params
        )
        return MetadataOperator._parse(current)

    def save(
        self,
        metadata: ModelVersionMetadataSchema,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Merge the given metadata into the stored metadata and write it back.

        Args:
            metadata: The metadata fields to write; they overwrite stored keys of the same name.
            model_name: The identifier of the model.
            version_name: The identifier of the model version.
            statement_params: Statement parameters forwarded to the SQL client. Defaults to None.
        """
        merged = self._get_current_metadata_dict(
            model_name=model_name, version_name=version_name, statement_params=statement_params
        )
        merged.update(metadata)
        # Always stamp the current schema version, overriding any value already present.
        merged["snowpark_ml_schema_version"] = MODEL_VERSION_METADATA_SCHEMA_VERSION
        self._model_version_client.set_metadata(
            merged, model_name=model_name, version_name=version_name, statement_params=statement_params
        )
|