snowflake-ml-python 1.6.2__py3-none-any.whl → 1.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_classify_text.py +2 -2
- snowflake/cortex/_embed_text_1024.py +37 -0
- snowflake/cortex/_embed_text_768.py +37 -0
- snowflake/cortex/_extract_answer.py +2 -2
- snowflake/cortex/_sentiment.py +2 -2
- snowflake/cortex/_summarize.py +2 -2
- snowflake/cortex/_translate.py +2 -2
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/env_utils.py +5 -5
- snowflake/ml/_internal/exceptions/error_codes.py +2 -0
- snowflake/ml/_internal/utils/db_utils.py +50 -0
- snowflake/ml/_internal/utils/service_logger.py +63 -0
- snowflake/ml/_internal/utils/sql_identifier.py +25 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
- snowflake/ml/data/ingestor_utils.py +20 -10
- snowflake/ml/feature_store/access_manager.py +3 -3
- snowflake/ml/feature_store/feature_store.py +19 -2
- snowflake/ml/feature_store/feature_view.py +82 -28
- snowflake/ml/fileset/stage_fs.py +2 -1
- snowflake/ml/lineage/lineage_node.py +7 -2
- snowflake/ml/model/__init__.py +1 -2
- snowflake/ml/model/_client/model/model_version_impl.py +78 -9
- snowflake/ml/model/_client/ops/model_ops.py +89 -7
- snowflake/ml/model/_client/ops/service_ops.py +200 -91
- snowflake/ml/model/_client/service/model_deployment_spec.py +4 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
- snowflake/ml/model/_client/sql/_base.py +5 -0
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +9 -5
- snowflake/ml/model/_client/sql/service.py +47 -13
- snowflake/ml/model/_model_composer/model_composer.py +11 -41
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +29 -4
- snowflake/ml/model/_packager/model_env/model_env.py +4 -38
- snowflake/ml/model/_packager/model_handlers/_utils.py +106 -32
- snowflake/ml/model/_packager/model_handlers/catboost.py +26 -27
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -3
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +21 -6
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +111 -58
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
- snowflake/ml/model/_packager/model_handlers/sklearn.py +50 -66
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +36 -17
- snowflake/ml/model/_packager/model_handlers/xgboost.py +22 -7
- snowflake/ml/model/_packager/model_meta/model_meta.py +16 -45
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -6
- snowflake/ml/model/_packager/model_packager.py +14 -10
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
- snowflake/ml/model/type_hints.py +11 -152
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +0 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -0
- snowflake/ml/modeling/cluster/affinity_propagation.py +1 -0
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -0
- snowflake/ml/modeling/cluster/birch.py +1 -0
- snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -0
- snowflake/ml/modeling/cluster/dbscan.py +1 -0
- snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -0
- snowflake/ml/modeling/cluster/k_means.py +1 -0
- snowflake/ml/modeling/cluster/mean_shift.py +1 -0
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -0
- snowflake/ml/modeling/cluster/optics.py +1 -0
- snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -0
- snowflake/ml/modeling/cluster/spectral_clustering.py +1 -0
- snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -0
- snowflake/ml/modeling/compose/column_transformer.py +1 -0
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -0
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -0
- snowflake/ml/modeling/covariance/empirical_covariance.py +1 -0
- snowflake/ml/modeling/covariance/graphical_lasso.py +1 -0
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -0
- snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -0
- snowflake/ml/modeling/covariance/min_cov_det.py +1 -0
- snowflake/ml/modeling/covariance/oas.py +1 -0
- snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -0
- snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -0
- snowflake/ml/modeling/decomposition/factor_analysis.py +1 -0
- snowflake/ml/modeling/decomposition/fast_ica.py +1 -0
- snowflake/ml/modeling/decomposition/incremental_pca.py +1 -0
- snowflake/ml/modeling/decomposition/kernel_pca.py +1 -0
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -0
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -0
- snowflake/ml/modeling/decomposition/pca.py +1 -0
- snowflake/ml/modeling/decomposition/sparse_pca.py +1 -0
- snowflake/ml/modeling/decomposition/truncated_svd.py +1 -0
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -0
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -0
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/isolation_forest.py +1 -0
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -0
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -0
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -0
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -0
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -0
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -0
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -0
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -0
- snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -0
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -0
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -0
- snowflake/ml/modeling/impute/iterative_imputer.py +1 -0
- snowflake/ml/modeling/impute/knn_imputer.py +1 -0
- snowflake/ml/modeling/impute/missing_indicator.py +1 -0
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -0
- snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -0
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -0
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -0
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -0
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -0
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -0
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -0
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -0
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -0
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -0
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/lars.py +1 -0
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -0
- snowflake/ml/modeling/linear_model/lasso.py +1 -0
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -0
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -0
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -0
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -0
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -0
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -0
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -0
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -0
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -0
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -0
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -0
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -0
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -0
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/perceptron.py +1 -0
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/ridge.py +1 -0
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -0
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -0
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -0
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -0
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -0
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -0
- snowflake/ml/modeling/manifold/isomap.py +1 -0
- snowflake/ml/modeling/manifold/mds.py +1 -0
- snowflake/ml/modeling/manifold/spectral_embedding.py +1 -0
- snowflake/ml/modeling/manifold/tsne.py +1 -0
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
- snowflake/ml/modeling/metrics/ranking.py +0 -3
- snowflake/ml/modeling/metrics/regression.py +0 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -0
- snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -0
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -0
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -0
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -0
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -0
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -0
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -0
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -0
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -0
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -0
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -0
- snowflake/ml/modeling/neighbors/kernel_density.py +1 -0
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -0
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -0
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -0
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -0
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -0
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -0
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -0
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -0
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -0
- snowflake/ml/modeling/pipeline/pipeline.py +0 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -0
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -0
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -0
- snowflake/ml/modeling/svm/linear_svc.py +1 -0
- snowflake/ml/modeling/svm/linear_svr.py +1 -0
- snowflake/ml/modeling/svm/nu_svc.py +1 -0
- snowflake/ml/modeling/svm/nu_svr.py +1 -0
- snowflake/ml/modeling/svm/svc.py +1 -0
- snowflake/ml/modeling/svm/svr.py +1 -0
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -0
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -0
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -0
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -0
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -0
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -0
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -0
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -0
- snowflake/ml/monitoring/_client/model_monitor.py +126 -0
- snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
- snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
- snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
- snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
- snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
- snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
- snowflake/ml/monitoring/entities/output_score_type.py +90 -0
- snowflake/ml/registry/_manager/model_manager.py +4 -4
- snowflake/ml/registry/registry.py +165 -6
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +24 -9
- {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/RECORD +225 -249
- {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
- snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
- snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
- snowflake/ml/_internal/utils/session_token_manager.py +0 -46
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
- snowflake/ml/_internal/utils/uri.py +0 -77
- snowflake/ml/model/_api.py +0 -568
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
- snowflake/ml/model/_packager/model_handlers/llm.py +0 -269
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
- snowflake/ml/model/deploy_platforms.py +0 -6
- snowflake/ml/model/models/llm.py +0 -106
- snowflake/ml/monitoring/monitor.py +0 -203
- snowflake/ml/registry/_initial_schema.py +0 -142
- snowflake/ml/registry/_schema.py +0 -82
- snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
- snowflake/ml/registry/_schema_version_manager.py +0 -163
- snowflake/ml/registry/model_registry.py +0 -2048
- {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -1,2048 +0,0 @@
|
|
1
|
-
import inspect
|
2
|
-
import json
|
3
|
-
import sys
|
4
|
-
import textwrap
|
5
|
-
import types
|
6
|
-
import warnings
|
7
|
-
from typing import (
|
8
|
-
TYPE_CHECKING,
|
9
|
-
Any,
|
10
|
-
Callable,
|
11
|
-
Dict,
|
12
|
-
List,
|
13
|
-
Optional,
|
14
|
-
Tuple,
|
15
|
-
Union,
|
16
|
-
cast,
|
17
|
-
)
|
18
|
-
from uuid import uuid1
|
19
|
-
|
20
|
-
from absl import logging
|
21
|
-
|
22
|
-
from snowflake import connector, snowpark
|
23
|
-
from snowflake.ml._internal import telemetry
|
24
|
-
from snowflake.ml._internal.utils import (
|
25
|
-
formatting,
|
26
|
-
identifier,
|
27
|
-
query_result_checker,
|
28
|
-
spcs_attribution_utils,
|
29
|
-
table_manager,
|
30
|
-
uri,
|
31
|
-
)
|
32
|
-
from snowflake.ml.model import (
|
33
|
-
_api as model_api,
|
34
|
-
deploy_platforms,
|
35
|
-
model_signature,
|
36
|
-
type_hints as model_types,
|
37
|
-
)
|
38
|
-
from snowflake.ml.registry import _initial_schema, _schema_version_manager
|
39
|
-
from snowflake.snowpark._internal import utils as snowpark_utils
|
40
|
-
|
41
|
-
if TYPE_CHECKING:
|
42
|
-
import pandas as pd
|
43
|
-
|
44
|
-
_DEFAULT_REGISTRY_NAME: str = "_SYSTEM_MODEL_REGISTRY"
|
45
|
-
_DEFAULT_SCHEMA_NAME: str = "_SYSTEM_MODEL_REGISTRY_SCHEMA"
|
46
|
-
_MODELS_TABLE_NAME: str = "_SYSTEM_REGISTRY_MODELS"
|
47
|
-
_METADATA_TABLE_NAME: str = "_SYSTEM_REGISTRY_METADATA"
|
48
|
-
_DEPLOYMENT_TABLE_NAME: str = "_SYSTEM_REGISTRY_DEPLOYMENTS"
|
49
|
-
|
50
|
-
# Metadata operation types.
|
51
|
-
_SET_METADATA_OPERATION: str = "SET"
|
52
|
-
_ADD_METADATA_OPERATION: str = "ADD"
|
53
|
-
_DROP_METADATA_OPERATION: str = "DROP"
|
54
|
-
|
55
|
-
# Metadata types.
|
56
|
-
_METADATA_ATTRIBUTE_DESCRIPTION: str = "DESCRIPTION"
|
57
|
-
_METADATA_ATTRIBUTE_METRICS: str = "METRICS"
|
58
|
-
_METADATA_ATTRIBUTE_REGISTRATION: str = "REGISTRATION"
|
59
|
-
_METADATA_ATTRIBUTE_TAGS: str = "TAGS"
|
60
|
-
_METADATA_ATTRIBUTE_DEPLOYMENT: str = "DEPLOYMENTS"
|
61
|
-
_METADATA_ATTRIBUTE_DELETION: str = "DELETION"
|
62
|
-
|
63
|
-
# Leaving out REGISTRATION/DEPLOYMENT events as they will be handled differently from all mutable attributes.
|
64
|
-
_LIST_METADATA_ATTRIBUTE: List[str] = [
|
65
|
-
_METADATA_ATTRIBUTE_DESCRIPTION,
|
66
|
-
_METADATA_ATTRIBUTE_METRICS,
|
67
|
-
_METADATA_ATTRIBUTE_TAGS,
|
68
|
-
]
|
69
|
-
_TELEMETRY_PROJECT = "MLOps"
|
70
|
-
_TELEMETRY_SUBPROJECT = "ModelRegistry"
|
71
|
-
|
72
|
-
_STAGE_PREFIX = "@"
|
73
|
-
|
74
|
-
|
75
|
-
def _create_registry_database(
    session: snowpark.Session,
    database_name: str,
    statement_params: Dict[str, Any],
) -> None:
    """Private helper to create the model registry database.

    The creation will be skipped if the target database already exists.

    Args:
        session: Session object to communicate with Snowflake.
        database_name: Desired name of the model registry database.
        statement_params: Function usage statement parameters used in sql query executions.
    """
    unescaped_name = identifier.get_unescaped_names(database_name)
    candidates = session.sql(f"SHOW DATABASES LIKE '{unescaped_name}'").collect(statement_params=statement_params)

    # SHOW ... LIKE is a case-insensitive pattern match in which `_` and `%` are wildcards, so the
    # server-side filter can return databases other than the requested one (the default registry
    # name itself contains underscores). Confirm an exact name match before skipping creation.
    # NOTE(review): assumes get_unescaped_names returns the name exactly as Snowflake stores it
    # (upper-cased for unquoted identifiers) -- confirm against the identifier module.
    if any(row["name"] == unescaped_name for row in candidates):
        logging.warning("The database %s already exists. Skipping creation.", database_name)
        return

    session.sql(f"CREATE DATABASE {database_name}").collect(statement_params=statement_params)
|
97
|
-
|
98
|
-
|
99
|
-
def _create_registry_schema(
    session: snowpark.Session,
    database_name: str,
    schema_name: str,
    statement_params: Dict[str, Any],
) -> None:
    """Private helper to create the model registry schema.

    The creation will be skipped if the target schema already exists.

    Args:
        session: Session object to communicate with Snowflake.
        database_name: Desired name of the model registry database.
        schema_name: Desired name of the schema used by this model registry inside the database.
        statement_params: Function usage statement parameters used in sql query executions.
    """
    unescaped_name = identifier.get_unescaped_names(schema_name)
    fully_qualified_schema_name = table_manager.get_fully_qualified_schema_name(database_name, schema_name)

    # The default PUBLIC schema is created by default so it might already exist even in a new database.
    candidates = session.sql(f"SHOW SCHEMAS LIKE '{unescaped_name}' IN DATABASE {database_name}").collect(
        statement_params=statement_params
    )

    # SHOW ... LIKE is a case-insensitive pattern match in which `_` and `%` are wildcards, so the
    # server-side filter can return schemas other than the requested one. Confirm an exact name
    # match before skipping creation.
    # NOTE(review): assumes get_unescaped_names returns the name exactly as Snowflake stores it
    # (upper-cased for unquoted identifiers) -- confirm against the identifier module.
    if any(row["name"] == unescaped_name for row in candidates):
        logging.warning("The schema %s already exists. Skipping creation.", fully_qualified_schema_name)
        return

    session.sql(f"CREATE SCHEMA {fully_qualified_schema_name}").collect(statement_params=statement_params)
|
130
|
-
|
131
|
-
|
132
|
-
def _create_registry_views(
    session: snowpark.Session,
    database_name: str,
    schema_name: str,
    registry_table_name: str,
    metadata_table_name: str,
    deployment_table_name: str,
    statement_params: Dict[str, Any],
) -> None:
    """Create views on underlying ModelRegistry tables.

    Issues one CREATE OR REPLACE TEMPORARY VIEW statement for the deployment view, one per
    metadata attribute in `_LIST_METADATA_ATTRIBUTE`, one for the registration timestamp, and a
    final one that left-joins all metadata views onto the registry table.

    Args:
        session: Session object to communicate with Snowflake.
        database_name: Desired name of the model registry database.
        schema_name: Desired name of the schema used by this model registry inside the database.
        registry_table_name: Name for the main model registry table.
        metadata_table_name: Name for the metadata table used by the model registry.
        deployment_table_name: Name for the deployment event table.
        statement_params: Function usage statement parameters used in sql query executions.
    """
    fully_qualified_schema_name = table_manager.get_fully_qualified_schema_name(database_name, schema_name)

    # From the documentation: Each DDL statement executes as a separate transaction. Races should not be an issue.
    # https://docs.snowflake.com/en/sql-reference/transactions.html#ddl

    # Create a view on active permanent deployments.
    _create_active_permanent_deployment_view(
        session,
        fully_qualified_schema_name,
        registry_table_name,
        deployment_table_name,
        statement_params,
    )

    # Create views on most recent metadata items.
    metadata_view_name_prefix = identifier.concat_names([metadata_table_name, "_LAST_"])
    metadata_view_template = formatting.unwrap(
        """CREATE OR REPLACE TEMPORARY VIEW {database}.{schema}.{attribute_view} COPY GRANTS AS
            SELECT DISTINCT MODEL_ID, {select_expression} AS {final_attribute_name} FROM {metadata_table}
            WHERE ATTRIBUTE_NAME = '{attribute_name}'"""
    )

    # Create a separate view for the most recent item in each metadata column.
    # Both lists are filled in lockstep: the view names feed the joins below, the select fields
    # feed the column list of the combined registry view.
    metadata_view_names = []
    metadata_select_fields = []
    for attribute_name in _LIST_METADATA_ATTRIBUTE:
        view_name = identifier.concat_names([metadata_view_name_prefix, attribute_name])
        select_expression = (
            f"(LAST_VALUE(VALUE) OVER (PARTITION BY MODEL_ID ORDER BY EVENT_TIMESTAMP))['{attribute_name}']"
        )
        sql = metadata_view_template.format(
            database=database_name,
            schema=schema_name,
            select_expression=select_expression,
            attribute_view=view_name,
            attribute_name=attribute_name,
            final_attribute_name=attribute_name,
            metadata_table=metadata_table_name,
        )
        session.sql(sql).collect(statement_params=statement_params)
        metadata_view_names.append(view_name)
        metadata_select_fields.append(f"{view_name}.{attribute_name} AS {attribute_name}")

    # Create a special view for the registration timestamp: it exposes the event timestamp itself
    # rather than an entry of the VALUE column.
    attribute_name = _METADATA_ATTRIBUTE_REGISTRATION
    final_attribute_name = identifier.concat_names([attribute_name, "_TIMESTAMP"])
    view_name = identifier.concat_names([metadata_view_name_prefix, attribute_name])
    create_registration_view_sql = metadata_view_template.format(
        database=database_name,
        schema=schema_name,
        select_expression="EVENT_TIMESTAMP",
        attribute_view=view_name,
        attribute_name=attribute_name,
        final_attribute_name=final_attribute_name,
        metadata_table=metadata_table_name,
    )
    session.sql(create_registration_view_sql).collect(statement_params=statement_params)
    metadata_view_names.append(view_name)
    metadata_select_fields.append(f"{view_name}.{final_attribute_name} AS {final_attribute_name}")

    metadata_views_join = " ".join(
        f"LEFT JOIN {view} ON ({view}.MODEL_ID = {registry_table_name}.ID)" for view in metadata_view_names
    )

    # Create view to combine all attributes.
    registry_view_name = identifier.concat_names([registry_table_name, "_VIEW"])
    metadata_select_fields_formatted = ",".join(metadata_select_fields)
    session.sql(
        f"""CREATE OR REPLACE TEMPORARY VIEW {fully_qualified_schema_name}.{registry_view_name} COPY GRANTS AS
            SELECT {registry_table_name}.*, {metadata_select_fields_formatted}
            FROM {registry_table_name} {metadata_views_join}"""
    ).collect(statement_params=statement_params)
|
229
|
-
|
230
|
-
|
231
|
-
def _create_active_permanent_deployment_view(
    session: snowpark.Session,
    fully_qualified_schema_name: str,
    registry_table_name: str,
    deployment_table_name: str,
    statement_params: Dict[str, Any],
) -> None:
    """Create a view which lists all available permanent deployments.

    The view is named after the deployment table with a "_VIEW" suffix and left-joins the
    deployment table with the registry table to add the model name and version.

    Args:
        session: Session object to communicate with Snowflake.
        fully_qualified_schema_name: Schema name to the target table.
        registry_table_name: Name for the main model registry table.
        deployment_table_name: Name of the deployment table.
        statement_params: Function usage statement parameters used in sql query executions.
    """
    # Create a view on active permanent deployments.
    # NOTE(review): the original comment states that active deployments are those whose last
    # operation is not DROP, but the query below applies no such filter -- presumably dropped
    # deployments are removed from the deployment table itself; confirm against the callers.
    active_deployments_view_name = identifier.concat_names([deployment_table_name, "_VIEW"])
    active_deployments_view_expr = f"""
        CREATE OR REPLACE TEMPORARY VIEW {fully_qualified_schema_name}.{active_deployments_view_name}
        COPY GRANTS AS
        SELECT
            DEPLOYMENT_NAME,
            MODEL_ID,
            {registry_table_name}.NAME as MODEL_NAME,
            {registry_table_name}.VERSION as MODEL_VERSION,
            {deployment_table_name}.CREATION_TIME as CREATION_TIME,
            TARGET_METHOD,
            TARGET_PLATFORM,
            SIGNATURE,
            OPTIONS,
            STAGE_PATH,
            ROLE
        FROM {deployment_table_name}
        LEFT JOIN {registry_table_name}
            ON {deployment_table_name}.MODEL_ID = {registry_table_name}.ID
    """
    session.sql(active_deployments_view_expr).collect(statement_params=statement_params)
|
271
|
-
|
272
|
-
|
273
|
-
class ModelRegistry:
|
274
|
-
"""Model Management API."""
|
275
|
-
|
276
|
-
def __init__(
    self,
    *,
    session: snowpark.Session,
    database_name: str = _DEFAULT_REGISTRY_NAME,
    schema_name: str = _DEFAULT_SCHEMA_NAME,
    create_if_not_exists: bool = False,
) -> None:
    """Open a model registry, optionally creating it first.

    Emits a DeprecationWarning, checks access to the registry schema, validates the schema
    version and (re)creates the temporary views the registry reads from.

    Args:
        session: Session object to communicate with Snowflake.
        database_name: Desired name of the model registry database.
        schema_name: Desired name of the schema used by this model registry inside the database.
        create_if_not_exists: Create the model registry if it does not exist already.
    """

    warnings.warn(
        """
The `snowflake.ml.registry.model_registry.ModelRegistry` has been deprecated starting from version 1.2.0.
It will stay in the Private Preview phase. For future implementations, kindly utilize `snowflake.ml.registry.Registry`,
except when specifically required. The old model registry will be removed once all its primary functionalities are
fully integrated into the new registry.
""",
        DeprecationWarning,
        stacklevel=2,
    )
    if create_if_not_exists:
        create_model_registry(session=session, database_name=database_name, schema_name=schema_name)

    # Resolve every object name once; the "_VIEW" / "_STAGE" names are derived from the table names.
    self._name = identifier.get_inferred_name(database_name)
    self._schema = identifier.get_inferred_name(schema_name)
    self._registry_table = identifier.get_inferred_name(_MODELS_TABLE_NAME)
    self._registry_table_view = identifier.concat_names([self._registry_table, "_VIEW"])
    self._metadata_table = identifier.get_inferred_name(_METADATA_TABLE_NAME)
    self._deployment_table = identifier.get_inferred_name(_DEPLOYMENT_TABLE_NAME)
    self._permanent_deployment_view = identifier.concat_names([self._deployment_table, "_VIEW"])
    self._permanent_deployment_stage = identifier.concat_names([self._deployment_table, "_STAGE"])
    self._session = session
    self._svm = _schema_version_manager.SchemaVersionManager(self._session, self._name, self._schema)

    # A in-memory deployment info cache to store information of temporary deployments
    # TODO(zhe): Use a temporary table to replace the in-memory cache.
    self._temporary_deployments: Dict[str, model_types.Deployment] = {}

    _initial_schema.check_access(self._session, self._name, self._schema)

    statement_params = self._get_statement_params(inspect.currentframe())
    self._svm.validate_schema_version(statement_params)

    # The views are TEMPORARY, so they are created again for every registry instance.
    _create_registry_views(
        session,
        self._name,
        self._schema,
        self._registry_table,
        self._metadata_table,
        self._deployment_table,
        statement_params,
    )
|
336
|
-
|
337
|
-
# Private methods
|
338
|
-
|
339
|
-
def _get_statement_params(self, frame: Optional[types.FrameType]) -> Dict[str, Any]:
    """Build the telemetry statement parameters for the calling registry method.

    Args:
        frame: Frame of the calling method, as returned by `inspect.currentframe()`.

    Returns:
        Statement parameters tagged with the registry's telemetry project and subproject and
        the full function name derived from `frame`.
    """
    return telemetry.get_function_usage_statement_params(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
        function_name=telemetry.get_statement_params_full_func_name(frame, "ModelRegistry"),
    )
|
345
|
-
|
346
|
-
def _get_new_unique_identifier(self) -> str:
|
347
|
-
"""Create new unique identifier.
|
348
|
-
|
349
|
-
Returns:
|
350
|
-
String identifier."""
|
351
|
-
return uuid1().hex
|
352
|
-
|
353
|
-
def _fully_qualified_registry_table_name(self) -> str:
|
354
|
-
"""Get the fully qualified name to the current registry table."""
|
355
|
-
return table_manager.get_fully_qualified_table_name(self._name, self._schema, self._registry_table)
|
356
|
-
|
357
|
-
def _fully_qualified_registry_view_name(self) -> str:
|
358
|
-
"""Get the fully qualified name to the current registry view."""
|
359
|
-
return table_manager.get_fully_qualified_table_name(self._name, self._schema, self._registry_table_view)
|
360
|
-
|
361
|
-
def _fully_qualified_metadata_table_name(self) -> str:
|
362
|
-
"""Get the fully qualified name to the current metadata table."""
|
363
|
-
return table_manager.get_fully_qualified_table_name(self._name, self._schema, self._metadata_table)
|
364
|
-
|
365
|
-
def _fully_qualified_deployment_table_name(self) -> str:
|
366
|
-
"""Get the fully qualified name to the current deployment table."""
|
367
|
-
return table_manager.get_fully_qualified_table_name(self._name, self._schema, self._deployment_table)
|
368
|
-
|
369
|
-
def _fully_qualified_permanent_deployment_view_name(self) -> str:
|
370
|
-
"""Get the fully qualified name to the permanent deployment view."""
|
371
|
-
return table_manager.get_fully_qualified_table_name(self._name, self._schema, self._permanent_deployment_view)
|
372
|
-
|
373
|
-
def _fully_qualified_schema_name(self) -> str:
|
374
|
-
"""Get the fully qualified name to the current registry schema."""
|
375
|
-
return table_manager.get_fully_qualified_schema_name(self._name, self._schema)
|
376
|
-
|
377
|
-
def _fully_qualified_deployment_name(self, deployment_name: str) -> str:
|
378
|
-
"""Get the fully qualified name to the given deployment."""
|
379
|
-
return table_manager.get_fully_qualified_table_name(self._name, self._schema, deployment_name)
|
380
|
-
|
381
|
-
def _insert_registry_entry(
|
382
|
-
self, *, id: str, name: str, version: str, properties: Dict[str, Any]
|
383
|
-
) -> List[snowpark.Row]:
|
384
|
-
"""Insert a new row into the model registry table.
|
385
|
-
|
386
|
-
Args:
|
387
|
-
id: Model id to register.
|
388
|
-
name: Model Name string.
|
389
|
-
version: Model Version string.
|
390
|
-
properties: Dictionary of properties corresponding to table columns.
|
391
|
-
|
392
|
-
Returns:
|
393
|
-
snowpark.Dataframe with the result of the operation.
|
394
|
-
|
395
|
-
Raises:
|
396
|
-
DataError: Mismatch between different id fields.
|
397
|
-
"""
|
398
|
-
if not id:
|
399
|
-
raise connector.DataError("Model ID is required but none given.")
|
400
|
-
mandatory_args = {"ID": id, "NAME": name, "VERSION": version}
|
401
|
-
for k, v in mandatory_args.items():
|
402
|
-
if k not in properties:
|
403
|
-
properties[k] = v
|
404
|
-
else:
|
405
|
-
if v and v != properties[k]:
|
406
|
-
raise connector.DataError(
|
407
|
-
formatting.unwrap(
|
408
|
-
f"""Parameter '{k.lower()}' is given and parameter 'properties' has the field '{k}' set but
|
409
|
-
the values do not match: {k.lower()}=="{v}" properties['{k}']=="{properties[k]}"."""
|
410
|
-
)
|
411
|
-
)
|
412
|
-
# Could do a multi-table insert here with some pros and cons:
|
413
|
-
# [PRO] Atomic insert across multiple tables.
|
414
|
-
# [CON] Code logic becomes messy depending on which fields are set.
|
415
|
-
# [CON] Harder to reuse existing methods like set_model_name.
|
416
|
-
# Context: https://docs.snowflake.com/en/sql-reference/sql/insert-multi-table.html
|
417
|
-
return table_manager.insert_table_entry(
|
418
|
-
self._session,
|
419
|
-
table=self._fully_qualified_registry_table_name(),
|
420
|
-
columns=properties,
|
421
|
-
)
|
422
|
-
|
423
|
-
def _insert_metadata_entry(self, *, id: str, attribute: str, value: Any, operation: str) -> List[snowpark.Row]:
|
424
|
-
"""Insert a new row into the model metadata table.
|
425
|
-
|
426
|
-
Args:
|
427
|
-
id: Model id to register.
|
428
|
-
attribute: name of the metadata attribute
|
429
|
-
value: new value of the metadata attribute
|
430
|
-
operation: the operation type of the metadata entry.
|
431
|
-
|
432
|
-
Returns:
|
433
|
-
snowpark.DataFrame with the result of the operation.
|
434
|
-
|
435
|
-
Raises:
|
436
|
-
DataError: Missing ID field.
|
437
|
-
"""
|
438
|
-
if not id:
|
439
|
-
raise connector.DataError("Model ID is required but none given.")
|
440
|
-
|
441
|
-
columns: Dict[str, Any] = {}
|
442
|
-
columns["EVENT_TIMESTAMP"] = formatting.SqlStr("CURRENT_TIMESTAMP()")
|
443
|
-
columns["EVENT_ID"] = self._get_new_unique_identifier()
|
444
|
-
columns["MODEL_ID"] = id
|
445
|
-
columns["ROLE"] = self._session.get_current_role()
|
446
|
-
columns["OPERATION"] = operation
|
447
|
-
columns["ATTRIBUTE_NAME"] = attribute
|
448
|
-
columns["VALUE"] = value
|
449
|
-
|
450
|
-
return table_manager.insert_table_entry(
|
451
|
-
self._session,
|
452
|
-
table=self._fully_qualified_metadata_table_name(),
|
453
|
-
columns=columns,
|
454
|
-
)
|
455
|
-
|
456
|
-
def _insert_deployment_entry(
|
457
|
-
self,
|
458
|
-
*,
|
459
|
-
id: str,
|
460
|
-
name: str,
|
461
|
-
platform: str,
|
462
|
-
stage_path: str,
|
463
|
-
signature: Dict[str, Any],
|
464
|
-
target_method: str,
|
465
|
-
options: Optional[
|
466
|
-
Union[
|
467
|
-
model_types.WarehouseDeployOptions,
|
468
|
-
model_types.SnowparkContainerServiceDeployOptions,
|
469
|
-
]
|
470
|
-
] = None,
|
471
|
-
) -> List[snowpark.Row]:
|
472
|
-
"""Insert a new row into the model deployment table.
|
473
|
-
|
474
|
-
Each row in the deployment table is a deployment event.
|
475
|
-
|
476
|
-
Args:
|
477
|
-
id: Model id of the deployed model.
|
478
|
-
name: Name of the deployment.
|
479
|
-
platform: The deployment target destination.
|
480
|
-
stage_path: The stage location where the deployment UDF is stored.
|
481
|
-
signature: The model signature.
|
482
|
-
target_method: The method name which is used for the deployment.
|
483
|
-
options: The deployment options.
|
484
|
-
|
485
|
-
Returns:
|
486
|
-
A list of snowpark rows which is the insertion result.
|
487
|
-
|
488
|
-
Raises:
|
489
|
-
DataError: Missing ID field.
|
490
|
-
"""
|
491
|
-
if not id:
|
492
|
-
raise connector.DataError("Model ID is required but none given.")
|
493
|
-
|
494
|
-
columns: Dict[str, Any] = {}
|
495
|
-
columns["CREATION_TIME"] = formatting.SqlStr("CURRENT_TIMESTAMP()")
|
496
|
-
columns["MODEL_ID"] = id
|
497
|
-
columns["DEPLOYMENT_NAME"] = name
|
498
|
-
columns["TARGET_PLATFORM"] = platform
|
499
|
-
columns["STAGE_PATH"] = stage_path
|
500
|
-
columns["ROLE"] = self._session.get_current_role()
|
501
|
-
columns["SIGNATURE"] = signature
|
502
|
-
columns["TARGET_METHOD"] = target_method
|
503
|
-
columns["OPTIONS"] = options
|
504
|
-
|
505
|
-
return table_manager.insert_table_entry(
|
506
|
-
self._session,
|
507
|
-
table=self._fully_qualified_deployment_table_name(),
|
508
|
-
columns=columns,
|
509
|
-
)
|
510
|
-
|
511
|
-
def _prepare_deployment_stage(self) -> str:
|
512
|
-
"""Create a stage in the model registry for storing all permanent deployments.
|
513
|
-
|
514
|
-
Returns:
|
515
|
-
Path to the stage that was created.
|
516
|
-
"""
|
517
|
-
schema = self._fully_qualified_schema_name()
|
518
|
-
fully_qualified_deployment_stage_name = f"{schema}.{self._permanent_deployment_stage}"
|
519
|
-
statement_params = self._get_statement_params(inspect.currentframe())
|
520
|
-
self._session.sql(
|
521
|
-
f"CREATE STAGE IF NOT EXISTS {fully_qualified_deployment_stage_name} "
|
522
|
-
f"ENCRYPTION = (TYPE= 'SNOWFLAKE_SSE')"
|
523
|
-
).collect(statement_params=statement_params)
|
524
|
-
return f"@{fully_qualified_deployment_stage_name}"
|
525
|
-
|
526
|
-
def _prepare_model_stage(self, model_id: str) -> str:
|
527
|
-
"""Create a stage in the model registry for storing the model with the given id.
|
528
|
-
|
529
|
-
Creating a permanent stage here since we do not have a way to switch a stage from temporary to permanent.
|
530
|
-
This can result in orphaned stages in case the process fails. It might be better to try to create a
|
531
|
-
temporary stage, attempt to perform all operations and convert the temp stage into permanent once the
|
532
|
-
operation is complete.
|
533
|
-
|
534
|
-
Args:
|
535
|
-
model_id: Internal model ID string.
|
536
|
-
|
537
|
-
Returns:
|
538
|
-
Name of the stage that was created.
|
539
|
-
|
540
|
-
Raises:
|
541
|
-
DatabaseError: Indicates that something went wrong when creating the stage.
|
542
|
-
"""
|
543
|
-
schema = self._fully_qualified_schema_name()
|
544
|
-
|
545
|
-
# Uppercase the model_stage_name to avoid having to quote the the stage name.
|
546
|
-
stage_name = model_id.upper()
|
547
|
-
|
548
|
-
model_stage_name = f"SNOWML_MODEL_{stage_name}"
|
549
|
-
fully_qualified_model_stage_name = f"{schema}.{model_stage_name}"
|
550
|
-
statement_params = self._get_statement_params(inspect.currentframe())
|
551
|
-
|
552
|
-
create_stage_result = self._session.sql(
|
553
|
-
f"CREATE OR REPLACE STAGE {fully_qualified_model_stage_name} ENCRYPTION = (TYPE= 'SNOWFLAKE_SSE')"
|
554
|
-
).collect(statement_params=statement_params)
|
555
|
-
if not create_stage_result:
|
556
|
-
raise connector.DatabaseError("Unable to create stage for model. Operation returned not result.")
|
557
|
-
if len(create_stage_result) != 1:
|
558
|
-
raise connector.DatabaseError(
|
559
|
-
"Unable to create stage for model. Creating the model stage returned unexpected result: {}.".format(
|
560
|
-
str(create_stage_result)
|
561
|
-
)
|
562
|
-
)
|
563
|
-
|
564
|
-
return fully_qualified_model_stage_name
|
565
|
-
|
566
|
-
def _get_fully_qualified_stage_name_from_uri(self, model_uri: str) -> Optional[str]:
|
567
|
-
"""Get fully qualified stage path pointed by the URI.
|
568
|
-
|
569
|
-
Args:
|
570
|
-
model_uri: URI for which stage file is needed.
|
571
|
-
|
572
|
-
Returns:
|
573
|
-
The fully qualified Snowflake stage location encoded by the given URI. Returns None if the URI is not
|
574
|
-
pointing to a Snowflake stage.
|
575
|
-
"""
|
576
|
-
raw_stage_path = uri.get_snowflake_stage_path_from_uri(model_uri)
|
577
|
-
if not raw_stage_path:
|
578
|
-
return None
|
579
|
-
(db, schema, stage, _) = identifier.parse_snowflake_stage_path(raw_stage_path)
|
580
|
-
return identifier.get_schema_level_object_identifier(db, schema, stage)
|
581
|
-
|
582
|
-
def _list_selected_models(
|
583
|
-
self,
|
584
|
-
*,
|
585
|
-
id: Optional[str] = None,
|
586
|
-
model_name: Optional[str] = None,
|
587
|
-
model_version: Optional[str] = None,
|
588
|
-
) -> snowpark.DataFrame:
|
589
|
-
"""Retrieve the Snowpark dataframe of models matching the specified ID or (name and version).
|
590
|
-
|
591
|
-
Args:
|
592
|
-
id: Model ID string. Required if either name or version is None.
|
593
|
-
model_name: Model Name string. Required if id is None.
|
594
|
-
model_version: Model Version string. Required if id is None.
|
595
|
-
|
596
|
-
Returns:
|
597
|
-
A Snowpark dataframe representing the models that match the given constraints.
|
598
|
-
"""
|
599
|
-
models = self.list_models()
|
600
|
-
|
601
|
-
if id:
|
602
|
-
filtered_models = models.filter(snowpark.Column("ID") == id)
|
603
|
-
else:
|
604
|
-
self._model_identifier_is_nonempty_or_raise(model_name, model_version)
|
605
|
-
|
606
|
-
# The following two asserts is to satisfy mypy.
|
607
|
-
assert model_name
|
608
|
-
assert model_version
|
609
|
-
|
610
|
-
filtered_models = models.filter(snowpark.Column("NAME") == model_name).filter(
|
611
|
-
snowpark.Column("VERSION") == model_version
|
612
|
-
)
|
613
|
-
|
614
|
-
return cast(snowpark.DataFrame, filtered_models)
|
615
|
-
|
616
|
-
def _validate_exact_one_result(
|
617
|
-
self, selected_model: snowpark.DataFrame, model_identifier: str
|
618
|
-
) -> List[snowpark.Row]:
|
619
|
-
"""Validate the filtered model has exactly one result.
|
620
|
-
|
621
|
-
Args:
|
622
|
-
selected_model: A snowpark dataframe representing the models that are filtered out.
|
623
|
-
model_identifier: A string which is used to filter the model.
|
624
|
-
|
625
|
-
Returns:
|
626
|
-
A snowpark row which contains the metadata of the filtered model
|
627
|
-
|
628
|
-
Raises:
|
629
|
-
KeyError: The target model doesn't exist.
|
630
|
-
DataError: The target model is not unique.
|
631
|
-
"""
|
632
|
-
statement_params = self._get_statement_params(inspect.currentframe())
|
633
|
-
model_info = None
|
634
|
-
try:
|
635
|
-
model_info = (
|
636
|
-
query_result_checker.ResultValidator(result=selected_model.collect(statement_params=statement_params))
|
637
|
-
.has_dimensions(expected_rows=1)
|
638
|
-
.validate()
|
639
|
-
)
|
640
|
-
except connector.DataError:
|
641
|
-
if model_info is None or len(model_info) == 0:
|
642
|
-
raise KeyError(f"The model {model_identifier} does not exist in the current registry.")
|
643
|
-
else:
|
644
|
-
raise connector.DataError(
|
645
|
-
formatting.unwrap(
|
646
|
-
f"""There are {len(model_info)} models {model_identifier}. This might indicate a problem with
|
647
|
-
the integrity of the model registry data."""
|
648
|
-
)
|
649
|
-
)
|
650
|
-
return model_info
|
651
|
-
|
652
|
-
def _get_metadata_attribute(
|
653
|
-
self,
|
654
|
-
attribute: str,
|
655
|
-
id: Optional[str] = None,
|
656
|
-
model_name: Optional[str] = None,
|
657
|
-
model_version: Optional[str] = None,
|
658
|
-
) -> Any:
|
659
|
-
"""Get the value of the given metadata attribute for target model with given (model name + model version) or id.
|
660
|
-
|
661
|
-
Args:
|
662
|
-
attribute: Name of the attribute to get.
|
663
|
-
id: Model ID string. Required if either name or version is None.
|
664
|
-
model_name: Model Name string. Required if id is None.
|
665
|
-
model_version: Model Version string. Required if version is None.
|
666
|
-
|
667
|
-
Returns:
|
668
|
-
The value of the attribute that was requested. Can be None if the attribute is not set.
|
669
|
-
"""
|
670
|
-
selected_models = self._list_selected_models(id=id, model_name=model_name, model_version=model_version)
|
671
|
-
identifier = f"id {id}" if id else f"{model_name}/{model_version}"
|
672
|
-
model_info = self._validate_exact_one_result(selected_models, identifier)
|
673
|
-
return model_info[0][attribute]
|
674
|
-
|
675
|
-
def _set_metadata_attribute(
|
676
|
-
self,
|
677
|
-
attribute: str,
|
678
|
-
value: Any,
|
679
|
-
id: Optional[str] = None,
|
680
|
-
model_name: Optional[str] = None,
|
681
|
-
model_version: Optional[str] = None,
|
682
|
-
operation: str = _SET_METADATA_OPERATION,
|
683
|
-
enable_model_presence_check: bool = True,
|
684
|
-
) -> None:
|
685
|
-
"""Set the value of the given metadata attribute for target model with given (model name + model version) or id.
|
686
|
-
|
687
|
-
Args:
|
688
|
-
attribute: Name of the attribute to set.
|
689
|
-
value: Value to set.
|
690
|
-
id: Model ID string. Required if either name or version is None.
|
691
|
-
model_name: Model Name string. Required if id is None.
|
692
|
-
model_version: Model Version string. Required if version is None.
|
693
|
-
operation: the operation type of the metadata entry.
|
694
|
-
enable_model_presence_check: If True, we will check if the model with the given ID is currently registered
|
695
|
-
before setting the metadata attribute. False by default meaning that by default we will check.
|
696
|
-
|
697
|
-
Raises:
|
698
|
-
DataError: Failed to set the metadata attribute.
|
699
|
-
KeyError: The target model doesn't exist
|
700
|
-
"""
|
701
|
-
selected_models = self._list_selected_models(id=id, model_name=model_name, model_version=model_version)
|
702
|
-
identifier = f"id {id}" if id else f"{model_name}/{model_version}"
|
703
|
-
try:
|
704
|
-
model_info = self._validate_exact_one_result(selected_models, identifier)
|
705
|
-
except KeyError as e:
|
706
|
-
# If the target model doesn't exist, raise the error only if enable_model_presence_check is True.
|
707
|
-
if enable_model_presence_check:
|
708
|
-
raise e
|
709
|
-
|
710
|
-
if not id:
|
711
|
-
id = model_info[0]["ID"]
|
712
|
-
assert id is not None
|
713
|
-
|
714
|
-
try:
|
715
|
-
self._insert_metadata_entry(
|
716
|
-
id=id,
|
717
|
-
attribute=attribute,
|
718
|
-
value={attribute: value},
|
719
|
-
operation=operation,
|
720
|
-
)
|
721
|
-
except connector.DataError:
|
722
|
-
raise connector.DataError(f"Setting {attribute} for mode id {id} failed.")
|
723
|
-
|
724
|
-
def _model_identifier_is_nonempty_or_raise(self, model_name: Optional[str], model_version: Optional[str]) -> None:
|
725
|
-
"""Validate model_name and model_version are non-empty strings.
|
726
|
-
|
727
|
-
Args:
|
728
|
-
model_name: Model Name string.
|
729
|
-
model_version: Model Version string.
|
730
|
-
|
731
|
-
Raises:
|
732
|
-
ValueError: Raised when either model_name and model_version is empty.
|
733
|
-
"""
|
734
|
-
if not model_name or not model_version:
|
735
|
-
raise ValueError("model_name and model_version have to be non-empty strings.")
|
736
|
-
|
737
|
-
def _get_model_id(self, model_name: str, model_version: str) -> str:
|
738
|
-
"""Get ID of the model with the given (model name + model version).
|
739
|
-
|
740
|
-
Args:
|
741
|
-
model_name: Model Name string.
|
742
|
-
model_version: Model Version string.
|
743
|
-
|
744
|
-
Returns:
|
745
|
-
Id of the model.
|
746
|
-
|
747
|
-
Raises:
|
748
|
-
DataError: The requested model could not be found.
|
749
|
-
"""
|
750
|
-
result = self._get_metadata_attribute("ID", model_name=model_name, model_version=model_version)
|
751
|
-
if not result:
|
752
|
-
raise connector.DataError(f"Model {model_name}/{model_version} doesn't exist.")
|
753
|
-
return str(result)
|
754
|
-
|
755
|
-
def _get_model_path(
|
756
|
-
self,
|
757
|
-
id: Optional[str] = None,
|
758
|
-
model_name: Optional[str] = None,
|
759
|
-
model_version: Optional[str] = None,
|
760
|
-
) -> str:
|
761
|
-
"""Get the stage path for the model with the given (model name + model version) or `id` from the registry.
|
762
|
-
|
763
|
-
Args:
|
764
|
-
id: Id of the model to deploy. Required if either model name or model version is None.
|
765
|
-
model_name: Model Name string. Required if id is None.
|
766
|
-
model_version: Model Version string. Required if id is None.
|
767
|
-
|
768
|
-
Returns:
|
769
|
-
str: Stage path for the model.
|
770
|
-
|
771
|
-
Raises:
|
772
|
-
DataError: When the model cannot be found or not be restored.
|
773
|
-
"""
|
774
|
-
statement_params = self._get_statement_params(inspect.currentframe())
|
775
|
-
selected_models = self._list_selected_models(id=id, model_name=model_name, model_version=model_version)
|
776
|
-
identifier = f"id {id}" if id else f"{model_name}/{model_version}"
|
777
|
-
model_info = self._validate_exact_one_result(selected_models, identifier)
|
778
|
-
if not id:
|
779
|
-
id = model_info[0]["ID"]
|
780
|
-
model_uri = model_info[0]["URI"]
|
781
|
-
|
782
|
-
if not uri.is_snowflake_stage_uri(model_uri):
|
783
|
-
raise connector.DataError(
|
784
|
-
f"Artifacts with URI scheme {uri.get_uri_scheme(model_uri)} are currently not supported."
|
785
|
-
)
|
786
|
-
|
787
|
-
model_stage_path = self._get_fully_qualified_stage_name_from_uri(model_uri=model_uri)
|
788
|
-
|
789
|
-
# Currently we assume only the model is on the stage.
|
790
|
-
model_file_list = self._session.sql(f"LIST @{model_stage_path}").collect(statement_params=statement_params)
|
791
|
-
if len(model_file_list) == 0:
|
792
|
-
raise connector.DataError(f"No files in model artifact for id {id} located at {model_uri}.")
|
793
|
-
return f"{_STAGE_PREFIX}{model_stage_path}"
|
794
|
-
|
795
|
-
def _log_model_path(
|
796
|
-
self,
|
797
|
-
model_name: str,
|
798
|
-
model_version: str,
|
799
|
-
) -> Tuple[str, str]:
|
800
|
-
"""Generate a path in the Model Registry to store a model.
|
801
|
-
|
802
|
-
Args:
|
803
|
-
model_name: The given name for the model.
|
804
|
-
model_version: Version string to be set for the model.
|
805
|
-
|
806
|
-
Returns:
|
807
|
-
String of the auto-generate unique model identifier and path to store it.
|
808
|
-
"""
|
809
|
-
model_id = self._get_new_unique_identifier()
|
810
|
-
|
811
|
-
# Copy model from local disk to remote stage.
|
812
|
-
# TODO(zhe): Check if we could use the same stage for multiple models.
|
813
|
-
fully_qualified_model_stage_name = self._prepare_model_stage(model_id=model_id)
|
814
|
-
|
815
|
-
return model_id, fully_qualified_model_stage_name
|
816
|
-
|
817
|
-
def _register_model_with_id(
|
818
|
-
self,
|
819
|
-
model_name: str,
|
820
|
-
model_version: str,
|
821
|
-
model_id: str,
|
822
|
-
*,
|
823
|
-
type: str,
|
824
|
-
uri: str,
|
825
|
-
input_spec: Optional[Dict[str, str]] = None,
|
826
|
-
output_spec: Optional[Dict[str, str]] = None,
|
827
|
-
description: Optional[str] = None,
|
828
|
-
tags: Optional[Dict[str, str]] = None,
|
829
|
-
) -> None:
|
830
|
-
"""Helper function to register model metadata.
|
831
|
-
|
832
|
-
Args:
|
833
|
-
model_name: Name to be set for the model. The model name can NOT be changed after registration. The
|
834
|
-
combination of name and version is expected to be unique inside the registry.
|
835
|
-
model_version: Version string to be set for the model. The model version string can NOT be changed after
|
836
|
-
model registration. The combination of name and version is expected to be unique inside the registry.
|
837
|
-
model_id: The internal id for the model.
|
838
|
-
type: Type of the model. Only a subset of types are supported natively.
|
839
|
-
uri: Resource identifier pointing to the model artifact. There are no restrictions on the URI format,
|
840
|
-
however only a limited set of URI schemes is supported natively.
|
841
|
-
input_spec: The expected input schema of the model. Dictionary where the keys are
|
842
|
-
expected column names and the values are the value types.
|
843
|
-
output_spec: The expected output schema of the model. Dictionary where the keys
|
844
|
-
are expected column names and the values are the value types.
|
845
|
-
description: A description for the model. The description can be changed later.
|
846
|
-
tags: Key-value pairs of tags to be set for this model. Tags can be modified
|
847
|
-
after model registration.
|
848
|
-
|
849
|
-
Raises:
|
850
|
-
DataError: The given model already exists.
|
851
|
-
DatabaseError: Unable to register the model properties into table.
|
852
|
-
"""
|
853
|
-
new_model: Dict[Any, Any] = {}
|
854
|
-
new_model["ID"] = model_id
|
855
|
-
new_model["NAME"] = model_name
|
856
|
-
new_model["VERSION"] = model_version
|
857
|
-
new_model["TYPE"] = type
|
858
|
-
new_model["URI"] = uri
|
859
|
-
new_model["INPUT_SPEC"] = input_spec
|
860
|
-
new_model["OUTPUT_SPEC"] = output_spec
|
861
|
-
new_model["CREATION_TIME"] = formatting.SqlStr("CURRENT_TIMESTAMP()")
|
862
|
-
new_model["CREATION_ROLE"] = self._session.get_current_role()
|
863
|
-
new_model["CREATION_ENVIRONMENT_SPEC"] = {"python": ".".join(map(str, sys.version_info[:3]))}
|
864
|
-
|
865
|
-
existing_model_nums = self._list_selected_models(model_name=model_name, model_version=model_version).count()
|
866
|
-
if existing_model_nums:
|
867
|
-
raise connector.DataError(
|
868
|
-
f"Model {model_name}/{model_version} already exists. Unable to register the model."
|
869
|
-
)
|
870
|
-
|
871
|
-
if self._insert_registry_entry(id=model_id, name=model_name, version=model_version, properties=new_model):
|
872
|
-
self._set_metadata_attribute(
|
873
|
-
model_name=model_name,
|
874
|
-
model_version=model_version,
|
875
|
-
attribute=_METADATA_ATTRIBUTE_REGISTRATION,
|
876
|
-
value=new_model,
|
877
|
-
)
|
878
|
-
if description:
|
879
|
-
self.set_model_description(
|
880
|
-
model_name=model_name,
|
881
|
-
model_version=model_version,
|
882
|
-
description=description,
|
883
|
-
)
|
884
|
-
if tags:
|
885
|
-
self._set_metadata_attribute(
|
886
|
-
_METADATA_ATTRIBUTE_TAGS,
|
887
|
-
value=tags,
|
888
|
-
model_name=model_name,
|
889
|
-
model_version=model_version,
|
890
|
-
)
|
891
|
-
else:
|
892
|
-
raise connector.DatabaseError("Failed to insert the model properties to the registry table.")
|
893
|
-
|
894
|
-
def _get_deployment(self, *, model_name: str, model_version: str, deployment_name: str) -> snowpark.Row:
|
895
|
-
statement_params = self._get_statement_params(inspect.currentframe())
|
896
|
-
deployment_lst = (
|
897
|
-
self._session.sql(f"SELECT * FROM {self._fully_qualified_permanent_deployment_view_name()}")
|
898
|
-
.filter(snowpark.Column("DEPLOYMENT_NAME") == deployment_name)
|
899
|
-
.filter(snowpark.Column("MODEL_NAME") == model_name)
|
900
|
-
.filter(snowpark.Column("MODEL_VERSION") == model_version)
|
901
|
-
).collect(statement_params=statement_params)
|
902
|
-
if len(deployment_lst) == 0:
|
903
|
-
raise KeyError(
|
904
|
-
f"Unable to find deployment named {deployment_name} in the model {model_name}/{model_version}."
|
905
|
-
)
|
906
|
-
assert len(deployment_lst) == 1, "_get_deployment should return exactly 1 deployment"
|
907
|
-
return cast(snowpark.Row, deployment_lst[0])
|
908
|
-
|
909
|
-
# Registry operations
|
910
|
-
|
911
|
-
@telemetry.send_api_usage_telemetry(
|
912
|
-
project=_TELEMETRY_PROJECT,
|
913
|
-
subproject=_TELEMETRY_SUBPROJECT,
|
914
|
-
)
|
915
|
-
@snowpark._internal.utils.private_preview(version="0.2.0")
|
916
|
-
def list_models(self) -> snowpark.DataFrame:
|
917
|
-
"""Lists models contained in the registry.
|
918
|
-
|
919
|
-
Returns:
|
920
|
-
snowpark.DataFrame with the list of models. Access is read-only through the snowpark.DataFrame.
|
921
|
-
The resulting snowpark.dataframe will have an "id" column that uniquely identifies each model and can be
|
922
|
-
used to reference the model when performing operations.
|
923
|
-
"""
|
924
|
-
# Explicitly not calling collect.
|
925
|
-
return self._session.sql(
|
926
|
-
"SELECT * FROM {database}.{schema}.{view}".format(
|
927
|
-
database=self._name, schema=self._schema, view=self._registry_table_view
|
928
|
-
)
|
929
|
-
)
|
930
|
-
|
931
|
-
@telemetry.send_api_usage_telemetry(
|
932
|
-
project=_TELEMETRY_PROJECT,
|
933
|
-
subproject=_TELEMETRY_SUBPROJECT,
|
934
|
-
)
|
935
|
-
@snowpark._internal.utils.private_preview(version="0.2.0")
|
936
|
-
def set_tag(
|
937
|
-
self,
|
938
|
-
model_name: str,
|
939
|
-
model_version: str,
|
940
|
-
tag_name: str,
|
941
|
-
tag_value: Optional[str] = None,
|
942
|
-
) -> None:
|
943
|
-
"""Set model tag to the model with value.
|
944
|
-
|
945
|
-
If the model tag already exists, the tag value will be overwritten.
|
946
|
-
|
947
|
-
Args:
|
948
|
-
model_name: Model Name string.
|
949
|
-
model_version: Model Version string.
|
950
|
-
tag_name: Desired tag name string.
|
951
|
-
tag_value: (optional) New tag value string. If no value is given the value of the tag will be set to None.
|
952
|
-
"""
|
953
|
-
# This method uses a read-modify-write pattern for setting tags.
|
954
|
-
# TODO(amauser): Investigate the use of transactions to avoid race conditions.
|
955
|
-
model_tags = self.get_tags(model_name=model_name, model_version=model_version)
|
956
|
-
model_tags[tag_name] = tag_value
|
957
|
-
self._set_metadata_attribute(
|
958
|
-
_METADATA_ATTRIBUTE_TAGS,
|
959
|
-
model_tags,
|
960
|
-
model_name=model_name,
|
961
|
-
model_version=model_version,
|
962
|
-
)
|
963
|
-
|
964
|
-
@telemetry.send_api_usage_telemetry(
|
965
|
-
project=_TELEMETRY_PROJECT,
|
966
|
-
subproject=_TELEMETRY_SUBPROJECT,
|
967
|
-
)
|
968
|
-
@snowpark._internal.utils.private_preview(version="0.2.0")
|
969
|
-
def remove_tag(self, model_name: str, model_version: str, tag_name: str) -> None:
|
970
|
-
"""Remove target model tag.
|
971
|
-
|
972
|
-
Args:
|
973
|
-
model_name: Model Name string.
|
974
|
-
model_version: Model Version string.
|
975
|
-
tag_name: Desired tag name string.
|
976
|
-
|
977
|
-
Raises:
|
978
|
-
DataError: If the model does not have the requested tag.
|
979
|
-
"""
|
980
|
-
# This method uses a read-modify-write pattern for updating tags.
|
981
|
-
|
982
|
-
model_tags = self.get_tags(model_name=model_name, model_version=model_version)
|
983
|
-
try:
|
984
|
-
del model_tags[tag_name]
|
985
|
-
except KeyError:
|
986
|
-
raise connector.DataError(
|
987
|
-
f"Model {model_name}/{model_version} has no tag named {tag_name}. Full list of tags: {model_tags}"
|
988
|
-
)
|
989
|
-
|
990
|
-
self._set_metadata_attribute(
|
991
|
-
_METADATA_ATTRIBUTE_TAGS,
|
992
|
-
model_tags,
|
993
|
-
model_name=model_name,
|
994
|
-
model_version=model_version,
|
995
|
-
)
|
996
|
-
|
997
|
-
@telemetry.send_api_usage_telemetry(
|
998
|
-
project=_TELEMETRY_PROJECT,
|
999
|
-
subproject=_TELEMETRY_SUBPROJECT,
|
1000
|
-
)
|
1001
|
-
@snowpark._internal.utils.private_preview(version="0.2.0")
|
1002
|
-
def has_tag(
|
1003
|
-
self,
|
1004
|
-
model_name: str,
|
1005
|
-
model_version: str,
|
1006
|
-
tag_name: str,
|
1007
|
-
tag_value: Optional[str] = None,
|
1008
|
-
) -> bool:
|
1009
|
-
"""Check if a model has a tag with the given name and value.
|
1010
|
-
|
1011
|
-
If no value is given, any value for the tag will return true.
|
1012
|
-
|
1013
|
-
Args:
|
1014
|
-
model_name: Model Name string.
|
1015
|
-
model_version: Model Version string.
|
1016
|
-
tag_name: Desired tag name string.
|
1017
|
-
tag_value: (optional) Tag value to check. If not value is given, only the presence of the tag will be
|
1018
|
-
checked.
|
1019
|
-
|
1020
|
-
Returns:
|
1021
|
-
True if the tag or tag and value combination is present for the model with the given id, False otherwise.
|
1022
|
-
"""
|
1023
|
-
tags = self.get_tags(model_name=model_name, model_version=model_version)
|
1024
|
-
has_tag = tag_name in tags
|
1025
|
-
if tag_value is None:
|
1026
|
-
return has_tag
|
1027
|
-
return has_tag and tags[tag_name] == str(tag_value)
|
1028
|
-
|
1029
|
-
@telemetry.send_api_usage_telemetry(
|
1030
|
-
project=_TELEMETRY_PROJECT,
|
1031
|
-
subproject=_TELEMETRY_SUBPROJECT,
|
1032
|
-
)
|
1033
|
-
@snowpark._internal.utils.private_preview(version="0.2.0")
|
1034
|
-
def get_tag_value(self, model_name: str, model_version: str, tag_name: str) -> Any:
|
1035
|
-
"""Return the value of the tag for the model.
|
1036
|
-
|
1037
|
-
The returned value can be None. If the tag does not exist, KeyError will be raised.
|
1038
|
-
|
1039
|
-
Args:
|
1040
|
-
model_name: Model Name string.
|
1041
|
-
model_version: Model Version string.
|
1042
|
-
tag_name: Desired tag name string.
|
1043
|
-
|
1044
|
-
Returns:
|
1045
|
-
Value string of the tag or None, if no value is set for the tag.
|
1046
|
-
"""
|
1047
|
-
return self.get_tags(model_name=model_name, model_version=model_version)[tag_name]
|
1048
|
-
|
1049
|
-
@telemetry.send_api_usage_telemetry(
|
1050
|
-
project=_TELEMETRY_PROJECT,
|
1051
|
-
subproject=_TELEMETRY_SUBPROJECT,
|
1052
|
-
)
|
1053
|
-
@snowpark._internal.utils.private_preview(version="0.2.0")
|
1054
|
-
def get_tags(self, model_name: Optional[str] = None, model_version: Optional[str] = None) -> Dict[str, Any]:
|
1055
|
-
"""Get all tags and values stored for the target model.
|
1056
|
-
|
1057
|
-
Args:
|
1058
|
-
model_name: Model Name string.
|
1059
|
-
model_version: Model Version string.
|
1060
|
-
|
1061
|
-
Returns:
|
1062
|
-
String-to-string dictionary containing all tags and values. The resulting dictionary can be empty.
|
1063
|
-
"""
|
1064
|
-
# Snowpark snowpark.dataframe returns dictionary objects as strings. We need to convert it back to a dictionary
|
1065
|
-
# here.
|
1066
|
-
result = self._get_metadata_attribute(
|
1067
|
-
_METADATA_ATTRIBUTE_TAGS, model_name=model_name, model_version=model_version
|
1068
|
-
)
|
1069
|
-
|
1070
|
-
if result:
|
1071
|
-
ret: Dict[str, Optional[str]] = json.loads(result)
|
1072
|
-
return ret
|
1073
|
-
else:
|
1074
|
-
return dict()
|
1075
|
-
|
1076
|
-
@telemetry.send_api_usage_telemetry(project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT)
@snowpark._internal.utils.private_preview(version="0.2.0")
def get_model_description(self, model_name: str, model_version: str) -> Optional[str]:
    """Read the description stored for the model.

    Args:
        model_name: Model Name string.
        model_version: Model Version string.

    Returns:
        The JSON-decoded description, or None when the metadata lookup yields None.
    """
    raw_description = self._get_metadata_attribute(
        _METADATA_ATTRIBUTE_DESCRIPTION, model_name=model_name, model_version=model_version
    )
    if raw_description is None:
        return None
    description: Optional[str] = json.loads(raw_description)
    return description
|
1097
|
-
|
1098
|
-
@telemetry.send_api_usage_telemetry(project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT)
@snowpark._internal.utils.private_preview(version="0.2.0")
def set_model_description(self, model_name: str, model_version: str, description: str) -> None:
    """Store a new description for the model.

    Args:
        model_name: Model Name string.
        model_version: Model Version string.
        description: Desired new model description.
    """
    self._set_metadata_attribute(
        attribute=_METADATA_ATTRIBUTE_DESCRIPTION,
        value=description,
        model_name=model_name,
        model_version=model_version,
    )
|
1122
|
-
|
1123
|
-
@telemetry.send_api_usage_telemetry(project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT)
@snowpark._internal.utils.private_preview(version="0.2.0")
def get_history(self) -> snowpark.DataFrame:
    """Return a dataframe with the history of operations performed on the model registry.

    The returned dataframe is ordered by event timestamp and can be filtered further.

    Returns:
        snowpark.DataFrame with the history of the model.
    """
    history_columns = (
        "EVENT_TIMESTAMP",
        "EVENT_ID",
        "MODEL_ID",
        "ROLE",
        "OPERATION",
        "ATTRIBUTE_NAME",
        # Pick, out of the VALUE column, only the entry keyed by the attribute the event touched.
        "VALUE[ATTRIBUTE_NAME]",
    )
    metadata_table = self._session.table(self._fully_qualified_metadata_table_name())
    ordered_events = metadata_table.order_by("EVENT_TIMESTAMP")
    return cast(snowpark.DataFrame, ordered_events.select_expr(*history_columns))
|
1150
|
-
|
1151
|
-
@telemetry.send_api_usage_telemetry(project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT)
@snowpark._internal.utils.private_preview(version="0.2.0")
def get_model_history(self, model_name: str, model_version: str) -> snowpark.DataFrame:
    """Return a dataframe with the history of operations performed on the desired model.

    The result is the registry-wide history from `get_history`, restricted to the rows of this model.

    Args:
        model_name: Model Name string.
        model_version: Model Version string.

    Returns:
        snowpark.DataFrame with the history of the model.
    """
    target_model_id = self._get_model_id(model_name=model_name, model_version=model_version)
    full_history = self.get_history()
    model_events = full_history.filter(snowpark.Column("MODEL_ID") == target_model_id)
    return cast(snowpark.DataFrame, model_events)
|
1177
|
-
|
1178
|
-
@telemetry.send_api_usage_telemetry(project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT)
@snowpark._internal.utils.private_preview(version="0.2.0")
def set_metric(self, model_name: str, model_version: str, metric_name: str, metric_value: object) -> None:
    """Set scalar model metric to value.

    A metric that already exists under the same name is overwritten.

    Args:
        model_name: Model Name string.
        model_version: Model Version string.
        metric_name: Desired metric name.
        metric_value: New metric value.
    """
    # Read-modify-write: load the stored metrics, replace one entry and write the whole mapping back.
    # TODO(amauser): Investigate the use of transactions to avoid race conditions.
    stored_metrics = self.get_metrics(model_name=model_name, model_version=model_version)
    updated_metrics = {**stored_metrics, metric_name: metric_value}
    self._set_metadata_attribute(
        _METADATA_ATTRIBUTE_METRICS, updated_metrics, model_name=model_name, model_version=model_version
    )
|
1210
|
-
|
1211
|
-
@telemetry.send_api_usage_telemetry(
    project=_TELEMETRY_PROJECT,
    subproject=_TELEMETRY_SUBPROJECT,
)
@snowpark._internal.utils.private_preview(version="0.2.0")
def remove_metric(
    self,
    model_name: str,
    model_version: str,
    metric_name: str,
) -> None:
    """Remove a specific metric entry from the model.

    Args:
        model_name: Model Name string.
        model_version: Model Version string.
        metric_name: Desired metric name.

    Raises:
        DataError: If the model does not have the requested metric.
    """
    # This method uses a read-modify-write pattern for updating metrics: the stored mapping is loaded, one entry
    # is dropped, and the remaining mapping is written back as a whole.
    model_metrics = self.get_metrics(model_name=model_name, model_version=model_version)
    try:
        del model_metrics[metric_name]
    except KeyError:
        # The DataError message already names the missing metric and lists the available ones, so the internal
        # KeyError is suppressed to keep the traceback from suggesting a failure inside the error handling.
        raise connector.DataError(
            f"Model {model_name}/{model_version} has no metric named {metric_name}. "
            f"Full list of metrics: {model_metrics}"
        ) from None

    self._set_metadata_attribute(
        _METADATA_ATTRIBUTE_METRICS,
        model_metrics,
        model_name=model_name,
        model_version=model_version,
    )
|
1249
|
-
|
1250
|
-
@telemetry.send_api_usage_telemetry(project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT)
@snowpark._internal.utils.private_preview(version="0.2.0")
def has_metric(self, model_name: str, model_version: str, metric_name: str) -> bool:
    """Tell whether the model carries a metric with the given name.

    Args:
        model_name: Model Name string.
        model_version: Model Version string.
        metric_name: Desired metric name.

    Returns:
        True if `metric_name` is a key of the metrics stored for the model, False otherwise.
    """
    return metric_name in self.get_metrics(model_name=model_name, model_version=model_version)
|
1268
|
-
|
1269
|
-
@telemetry.send_api_usage_telemetry(project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT)
@snowpark._internal.utils.private_preview(version="0.2.0")
def get_metric_value(self, model_name: str, model_version: str, metric_name: str) -> object:
    """Return the value of the given metric for the model.

    The lookup indexes the dictionary returned by `get_metrics`, so a metric that does not exist surfaces as a
    KeyError.

    Args:
        model_name: Model Name string.
        model_version: Model Version string.
        metric_name: Desired metric name.

    Returns:
        Value of the metric. Can be None if the metric was set to None.
    """
    model_metrics = self.get_metrics(model_name=model_name, model_version=model_version)
    return model_metrics[metric_name]
|
1288
|
-
|
1289
|
-
@telemetry.send_api_usage_telemetry(project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT)
@snowpark._internal.utils.private_preview(version="0.2.0")
def get_metrics(self, model_name: str, model_version: str) -> Dict[str, object]:
    """Fetch every metric and its value stored for the given model.

    Args:
        model_name: Model Name string.
        model_version: Model Version string.

    Returns:
        Dictionary mapping metric names to metric values. The dictionary is empty when no metrics are stored.
    """
    # The metrics attribute arrives as a JSON string rather than a dictionary, so it has to be decoded here.
    raw_metrics = self._get_metadata_attribute(
        _METADATA_ATTRIBUTE_METRICS, model_name=model_name, model_version=model_version
    )

    if not raw_metrics:
        return dict()

    decoded_metrics: Dict[str, object] = json.loads(raw_metrics)
    return decoded_metrics
|
1317
|
-
|
1318
|
-
# Combined Registry and Repository operations.
|
1319
|
-
@telemetry.send_api_usage_telemetry(project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT)
@snowpark._internal.utils.private_preview(version="0.2.0")
def log_model(
    self,
    model_name: str,
    model_version: str,
    *,
    model: Any,
    description: Optional[str] = None,
    tags: Optional[Dict[str, str]] = None,
    conda_dependencies: Optional[List[str]] = None,
    pip_requirements: Optional[List[str]] = None,
    signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
    sample_input_data: Optional[Any] = None,
    code_paths: Optional[List[str]] = None,
    options: Optional[model_types.BaseModelSaveOption] = None,
) -> Optional["ModelReference"]:
    """Upload a model to a stage and register it in the Model Registry.

    Args:
        model_name: The given name for the model. The combination (name + version) must be unique for each model.
        model_version: Version string to be set for the model. The combination (name + version) must be unique for
            each model.
        model: Local model object in a supported format.
        description: A description for the model. The description can be changed later.
        tags: string-to-string dictionary of tag names and values to be set for the model.
        conda_dependencies: List of Conda package specs. Use "[channel::]package [operator version]" syntax to
            specify a dependency. It is a recommended way to specify your dependencies using conda. When channel is
            not specified, defaults channel will be used. When deploying to Snowflake Warehouse, defaults channel
            would be replaced with the Snowflake Anaconda channel.
        pip_requirements: List of PIP package specs. Model will not be able to deploy to the warehouse if there is
            pip requirements.
        signatures: Signatures of the model, which is a mapping from target method name to signatures of input and
            output, which could be inferred by calling `infer_signature` method with sample input data.
        sample_input_data: Sample of the input data for the model.
        code_paths: Directory of code to import when loading and deploying the model.
        options: Additional options when saving the model.

    Raises:
        DataError: Raised when:
            1) the given model already exists;
        ValueError: Raised when: # noqa: DAR402
            1) Signatures and sample_input_data are both not provided and model is not a
                snowflake estimator.
        Exception: Raised when there is any error raised when saving the model.

    Returns:
        Reference to the newly registered model.
    """
    # Ideally, the whole operation should be a single transaction. Currently, transactions do not support stage
    # operations.
    self._svm.validate_schema_version(self._get_statement_params(inspect.currentframe()))

    self._model_identifier_is_nonempty_or_raise(model_name, model_version)

    # Refuse to overwrite: a (name, version) pair may be registered only once.
    already_registered = self._list_selected_models(model_name=model_name, model_version=model_version)
    if already_registered.count():
        raise connector.DataError(f"Model {model_name}/{model_version} already exists. Unable to log the model.")

    model_id, fully_qualified_model_stage_name = self._log_model_path(
        model_name=model_name, model_version=model_version
    )
    stage_path = f"{_STAGE_PREFIX}{fully_qualified_model_stage_name}"
    typed_model = cast(model_types.SupportedModelType, model)
    try:
        model_composer = model_api.save_model(  # type: ignore[call-overload, misc]
            name=model_name,
            session=self._session,
            stage_path=stage_path,
            model=typed_model,
            signatures=signatures,
            metadata=tags,
            conda_dependencies=conda_dependencies,
            pip_requirements=pip_requirements,
            sample_input_data=sample_input_data,
            code_paths=code_paths,
            options=options,
        )
    except Exception:
        # Saving failed: drop the stage that was prepared for this model, then let the original error propagate.
        cleanup = query_result_checker.SqlResultValidator(
            self._session, f"DROP STAGE {fully_qualified_model_stage_name}"
        )
        cleanup.has_dimensions(expected_rows=1, expected_cols=1).validate()
        raise

    self._register_model_with_id(
        model_name=model_name,
        model_version=model_version,
        model_id=model_id,
        type=model_composer.packager.meta.model_type,
        uri=uri.get_uri_from_snowflake_stage_path(stage_path),
        description=description,
        tags=tags,
    )

    return ModelReference(registry=self, model_name=model_name, model_version=model_version)
|
1420
|
-
|
1421
|
-
@telemetry.send_api_usage_telemetry(project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT)
@snowpark._internal.utils.private_preview(version="0.2.0")
def load_model(self, model_name: str, model_version: str) -> Any:
    """Load the model identified by (model_name + model_version) from the registry into memory.

    A UserWarning is always emitted first, reminding the caller that the local Python environment has to match
    the one used when the model was logged.

    Args:
        model_name: Model Name string.
        model_version: Model Version string.

    Returns:
        Restored model object.
    """
    caution_message = (
        "Please use with caution: "
        "Using `load_model` method requires you to have the EXACT same Python environments "
        "as the one when you logged the model. Any differences will potentially lead to errors.\n"
        "Also, if your model contains custom code imported using `code_paths` argument when logging, "
        "they will be added to your `sys.path`. It might lead to unexpected module importing issues. "
        "If you run into such kind of problems, you need to restart your Python or Notebook kernel."
    )
    warnings.warn(caution_message, category=UserWarning, stacklevel=2)

    remote_model_path = self._get_model_path(model_name=model_name, model_version=model_version)
    loaded_composer = model_api.load_model(session=self._session, stage_path=remote_model_path)
    return loaded_composer.packager.model
|
1454
|
-
|
1455
|
-
# Repository Operations
|
1456
|
-
|
1457
|
-
@telemetry.send_api_usage_telemetry(
    project=_TELEMETRY_PROJECT,
    subproject=_TELEMETRY_SUBPROJECT,
)
@snowpark._internal.utils.private_preview(version="0.2.0")
def deploy(
    self,
    model_name: str,
    model_version: str,
    *,
    deployment_name: str,
    target_method: Optional[str] = None,
    permanent: bool = False,
    platform: deploy_platforms.TargetPlatform = deploy_platforms.TargetPlatform.WAREHOUSE,
    options: Optional[
        Union[
            model_types.WarehouseDeployOptions,
            model_types.SnowparkContainerServiceDeployOptions,
        ]
    ] = None,
) -> model_types.Deployment:
    """Deploy the model with the given deployment name.

    The caller's `options` dictionary is never modified; the method works on a shallow copy of it.

    Args:
        model_name: Model Name string.
        model_version: Model Version string.
        deployment_name: name of the generated UDF.
        target_method: The method name to use in deployment. Can be omitted if only 1 method in the model.
        permanent: Whether the deployment is permanent or not. Permanent deployment will generate a permanent UDF.
            (Only applicable for Warehouse deployment)
        platform: Target platform to deploy the model to. Currently supported platforms are defined as enum in
            `snowflake.ml.model.deploy_platforms.TargetPlatform`
        options: Optional options for model deployment. Defaults to None.

    Returns:
        Deployment info.

    Raises:
        RuntimeError: Raised when parameters are not properly enabled when deploying to Warehouse with temporary UDF
        RuntimeError: Raised when deploying to SPCS with db/schema that starts with underscore.
    """
    statement_params = self._get_statement_params(inspect.currentframe())
    self._svm.validate_schema_version(statement_params)

    # Work on a private copy. A permanent warehouse deployment stores the resolved stage location in the options;
    # writing that into the caller's dictionary would make a later deployment that reuses the same dictionary
    # pick up this deployment's stage directory instead of getting its own.
    if options is None:
        options = {}
    else:
        options = options.copy()

    deployment_stage_path = ""

    if platform == deploy_platforms.TargetPlatform.SNOWPARK_CONTAINER_SERVICES:
        if self._name.startswith("_") or self._schema.startswith("_"):
            error_message = """\
            Model deployment to Snowpark Container Service does not support a database/schema name that starts with
            an underscore. Please ensure you pass in a valid db/schema name when initializing the registry with:

                model_registry.create_model_registry(
                    session=session,
                    database_name=db,
                    schema_name=schema
                )

                registry = model_registry.ModelRegistry(
                    session=session,
                    database_name=db,
                    schema_name=schema
                )
            """
            raise RuntimeError(textwrap.dedent(error_message))
        # Deployments to Snowpark Container Services are always recorded as permanent.
        permanent = True
        options = cast(model_types.SnowparkContainerServiceDeployOptions, options)
        deployment_stage_path = f"{self._prepare_deployment_stage()}/{deployment_name}/"
    elif platform == deploy_platforms.TargetPlatform.WAREHOUSE:
        options = cast(model_types.WarehouseDeployOptions, options)
        if permanent:
            # Every deployment-generated UDF should reside in its own unique directory. As long as each deployment
            # is allocated a distinct directory, multiple deployments can coexist within the same stage.
            # Given that each permanent deployment possesses a unique deployment_name, sharing the same stage does
            # not present any issues
            deployment_stage_path = (
                options.get("permanent_udf_stage_location")
                or f"{self._prepare_deployment_stage()}/{deployment_name}/"
            )
            options["permanent_udf_stage_location"] = deployment_stage_path

    remote_model_path = self._get_model_path(model_name=model_name, model_version=model_version)
    model_id = self._get_model_id(model_name, model_version)

    # https://snowflakecomputing.atlassian.net/browse/SNOW-858376
    # During temporary deployment on the Warehouse, Snowpark creates an unencrypted temporary stage for UDF-related
    # artifacts. However, UDF generation fails when importing from a mix of encrypted and unencrypted stages.
    # The following workaround copies model between stages (PrPr as of July 7th, 2023) to transfer the SSE
    # encrypted model zip from model stage to the temporary unencrypted stage.
    if not permanent and platform == deploy_platforms.TargetPlatform.WAREHOUSE:
        schema = self._fully_qualified_schema_name()
        unencrypted_stage = (
            f"@{schema}.{snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)}"
        )
        # The leading "@" is only valid when referencing the stage, not when creating it.
        self._session.sql(f"CREATE TEMPORARY STAGE {unencrypted_stage[1:]}").collect()
        try:
            self._session.sql(f"COPY FILES INTO {unencrypted_stage} from {remote_model_path}").collect()
        except Exception as err:
            # Chain the underlying failure so the real cause stays visible next to the user-facing hint.
            raise RuntimeError(
                "Temporary deployment to the warehouse is currently not supported. Please use "
                "permanent deployment by setting the 'permanent' parameter to True"
            ) from err
        remote_model_path = unencrypted_stage

    # Step 1: Deploy to get the UDF
    deployment_info = model_api.deploy(
        session=self._session,
        name=self._fully_qualified_deployment_name(deployment_name),
        platform=platform,
        target_method=target_method,
        stage_path=remote_model_path,
        deployment_stage_path=deployment_stage_path,
        model_id=model_id,
        options=options,
    )

    # Step 2: Record the deployment

    # Assert to convince mypy.
    assert deployment_info
    if permanent:
        self._insert_deployment_entry(
            id=model_id,
            name=deployment_name,
            platform=deployment_info["platform"].value,
            stage_path=deployment_stage_path,
            signature=deployment_info["signature"].to_dict(),
            target_method=deployment_info["target_method"],
            options=options,
        )

    self._set_metadata_attribute(
        _METADATA_ATTRIBUTE_DEPLOYMENT,
        {"name": deployment_name, "permanent": permanent},
        id=model_id,
        operation=_ADD_METADATA_OPERATION,
    )

    # Store temporary deployment information in the in-memory cache. This allows for future referencing and
    # tracking of its availability status.
    if not permanent:
        self._temporary_deployments[deployment_name] = deployment_info

    return deployment_info
|
1604
|
-
|
1605
|
-
@telemetry.send_api_usage_telemetry(project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT)
@snowpark._internal.utils.private_preview(version="1.0.1")
def list_deployments(self, model_name: str, model_version: str) -> snowpark.DataFrame:
    """List all permanent deployments that originated from the given model.

    Temporary deployments are currently not included in the listing.

    Args:
        model_name: Model Name string.
        model_version: Model Version string.

    Returns:
        A snowpark dataframe that contains all deployments associated with the given model.
    """
    view_name = self._fully_qualified_permanent_deployment_view_name()
    all_deployments = self._session.sql(f"SELECT * FROM {view_name}")
    deployments_df = all_deployments.filter(snowpark.Column("MODEL_NAME") == model_name).filter(
        snowpark.Column("MODEL_VERSION") == model_version
    )

    reported_columns = (
        "MODEL_NAME",
        "MODEL_VERSION",
        "DEPLOYMENT_NAME",
        "CREATION_TIME",
        "TARGET_METHOD",
        "TARGET_PLATFORM",
        "SIGNATURE",
        "OPTIONS",
        "STAGE_PATH",
        "ROLE",
    )
    projection = deployments_df.select(*[deployments_df[column_name] for column_name in reported_columns])
    return cast(snowpark.DataFrame, projection)
|
1640
|
-
|
1641
|
-
@telemetry.send_api_usage_telemetry(project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT)
@snowpark._internal.utils.private_preview(version="1.0.1")
def get_deployment(self, model_name: str, model_version: str, *, deployment_name: str) -> snowpark.DataFrame:
    """Get the permanent deployment with target name of the given model.

    Temporary deployments are currently not supported.

    Args:
        model_name: Model Name string.
        model_version: Model Version string.
        deployment_name: Deployment name string.

    Returns:
        A snowpark dataframe that contains the information of the target deployment.

    Raises:
        KeyError: Raised if the target deployment is not found.
    """
    model_deployments = self.list_deployments(model_name, model_version)
    deployment = model_deployments.filter(snowpark.Column("DEPLOYMENT_NAME") == deployment_name)
    if not deployment.count():
        raise KeyError(
            f"Unable to find deployment named {deployment_name} in the model {model_name}/{model_version}."
        )
    return cast(snowpark.DataFrame, deployment)
|
1670
|
-
|
1671
|
-
@telemetry.send_api_usage_telemetry(project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT)
@snowpark._internal.utils.private_preview(version="1.0.1")
def delete_deployment(self, model_name: str, model_version: str, *, deployment_name: str) -> None:
    """Delete the target permanent deployment of the given model.

    Deleting temporary deployments is currently not supported.
    Temporary deployments get cleaned up automatically when the current session is closed.

    Args:
        model_name: Model Name string.
        model_version: Model Version string.
        deployment_name: Name of the deployment that is getting deleted.

    """
    deployment = self._get_deployment(
        model_name=model_name, model_version=model_version, deployment_name=deployment_name
    )

    # TODO(SNOW-759526): The following sequence should be a transaction.
    # Step 1: Drop the UDF
    udf_name = self._fully_qualified_deployment_name(deployment_name)
    self._session.sql(f"DROP FUNCTION IF EXISTS {udf_name}(OBJECT)").collect()

    # Step 2: Remove the staged artifact
    staged_artifact = deployment["STAGE_PATH"]
    self._session.sql(f"REMOVE {staged_artifact}").collect()

    # Step 3: Delete the deployment from the deployment table
    # NOTE(review): the statement is assembled by string interpolation of deployment_name; confirm upstream
    # validation makes this safe.
    deployment_table = self._fully_qualified_deployment_table_name()
    delete_statement = f"""DELETE FROM {deployment_table}
                WHERE MODEL_ID='{deployment['MODEL_ID']}' AND DEPLOYMENT_NAME='{deployment_name}'
            """
    query_result_checker.SqlResultValidator(self._session, delete_statement).deletion_success(
        expected_num_rows=1
    ).validate()

    # Step 4: Record the delete event
    self._set_metadata_attribute(
        _METADATA_ATTRIBUTE_DEPLOYMENT,
        {"name": deployment_name},
        id=deployment["MODEL_ID"],
        operation=_DROP_METADATA_OPERATION,
    )

    # Optional Step 5: Delete Snowpark container service.
    if deployment["TARGET_PLATFORM"] != deploy_platforms.TargetPlatform.SNOWPARK_CONTAINER_SERVICES.value:
        return

    service_name = identifier.get_schema_level_object_identifier(
        self._name, self._schema, f"service_{deployment['MODEL_ID']}"
    )
    spcs_attribution_utils.record_service_end(self._session, service_name)
    drop_service = query_result_checker.SqlResultValidator(self._session, f"DROP SERVICE IF EXISTS {service_name}")
    drop_service.validate()
|
1729
|
-
|
1730
|
-
@telemetry.send_api_usage_telemetry(
    project=_TELEMETRY_PROJECT,
    subproject=_TELEMETRY_SUBPROJECT,
)
@snowpark._internal.utils.private_preview(version="0.2.0")
def delete_model(
    self,
    model_name: str,
    model_version: str,
    delete_artifact: bool = True,
) -> None:
    """Delete the model with the given name and version from the registry.

    The history of the model will still be preserved. A deletion event is recorded that carries the model URI
    and the value of `delete_artifact` the caller passed.

    Args:
        model_name: Model Name string.
        model_version: Model Version string.
        delete_artifact: If True, the underlying model artifact will also be deleted, not just the entry in
            the registry table.
    """

    # Check that a model with the given name and version exists and there is only one of them.
    # TODO(amauser): The following sequence should be a transaction. Transactions currently cannot contain DDL
    # statements.
    selected_models = self._list_selected_models(model_name=model_name, model_version=model_version)
    model_identifier = f"{model_name}/{model_version}"
    model_info = self._validate_exact_one_result(selected_models, model_identifier)
    model_id = model_info[0]["ID"]
    model_uri = model_info[0]["URI"]

    # Step 1/3: Delete the registry entry.
    query_result_checker.SqlResultValidator(
        self._session,
        f"DELETE FROM {self._fully_qualified_registry_table_name()} WHERE ID='{model_id}'",
    ).deletion_success(expected_num_rows=1).validate()

    # Step 2/3: Delete the artifact (if desired). Only artifacts stored on a Snowflake stage are dropped.
    if delete_artifact:
        if uri.is_snowflake_stage_uri(model_uri):
            stage_path = self._get_fully_qualified_stage_name_from_uri(model_uri)
            query_result_checker.SqlResultValidator(self._session, f"DROP STAGE {stage_path}").has_dimensions(
                expected_rows=1, expected_cols=1
            ).validate()

    # Step 3/3: Record the deletion event. The event stores the flag the caller actually passed, so the history
    # does not claim an artifact deletion that was never requested. The presence check is disabled because the
    # registry entry was already removed in step 1.
    self._set_metadata_attribute(
        id=model_id,
        attribute=_METADATA_ATTRIBUTE_DELETION,
        value={"delete_artifact": delete_artifact, "URI": model_uri},
        enable_model_presence_check=False,
    )
|
1783
|
-
|
1784
|
-
|
1785
|
-
class ModelReference:
|
1786
|
-
"""Wrapper class for ModelReference objects that proxy model metadata operations."""
|
1787
|
-
|
1788
|
-
def _remove_arg_from_docstring(self, arg: str, docstring: Optional[str]) -> Optional[str]:
|
1789
|
-
"""Remove the given parameter from a function docstring (Google convention)."""
|
1790
|
-
if docstring is None:
|
1791
|
-
return None
|
1792
|
-
docstring_lines = docstring.split("\n")
|
1793
|
-
|
1794
|
-
args_section_start = None
|
1795
|
-
args_section_end = None
|
1796
|
-
args_section_indent = None
|
1797
|
-
arg_start = None
|
1798
|
-
arg_end = None
|
1799
|
-
arg_indent = None
|
1800
|
-
for i in range(len(docstring_lines)):
|
1801
|
-
line = docstring_lines[i]
|
1802
|
-
lstrip_line = line.lstrip()
|
1803
|
-
indent = len(line) - len(lstrip_line)
|
1804
|
-
|
1805
|
-
if line.strip() == "Args:":
|
1806
|
-
# Starting the Args section of the docstring (assuming Google-style).
|
1807
|
-
args_section_start = i
|
1808
|
-
# logging.info("TEST: args_section_start=" + str(args_section_start))
|
1809
|
-
args_section_indent = indent
|
1810
|
-
continue
|
1811
|
-
|
1812
|
-
# logging.info("TEST: " + lstrip_line)
|
1813
|
-
if args_section_start and lstrip_line.startswith(f"{arg}:"):
|
1814
|
-
# This is the arg we are looking for.
|
1815
|
-
arg_start = i
|
1816
|
-
# logging.info("TEST: arg_start=" + str(arg_start))
|
1817
|
-
arg_indent = indent
|
1818
|
-
continue
|
1819
|
-
|
1820
|
-
if arg_start and not arg_end and indent == arg_indent:
|
1821
|
-
# We got the next arg, previous line was the last of the cut out arg docstring
|
1822
|
-
# and we do have other args. Saving arg_end for python slice/range notation.
|
1823
|
-
arg_end = i
|
1824
|
-
continue
|
1825
|
-
|
1826
|
-
if arg_start and (len(lstrip_line) == 0 or indent == args_section_indent):
|
1827
|
-
# Arg section ends.
|
1828
|
-
args_section_end = i
|
1829
|
-
arg_end = arg_end if arg_end else i
|
1830
|
-
# We have learned everything we need to know, no need to continue.
|
1831
|
-
break
|
1832
|
-
|
1833
|
-
if arg_start and not arg_end:
|
1834
|
-
arg_end = len(docstring_lines)
|
1835
|
-
|
1836
|
-
if args_section_start and not args_section_end:
|
1837
|
-
args_section_end = len(docstring_lines)
|
1838
|
-
|
1839
|
-
# Determine which lines from the "Args:" section of the docstring to skip or if we
|
1840
|
-
# should skip the entire section.
|
1841
|
-
keep_lines = set(range(len(docstring_lines)))
|
1842
|
-
if args_section_start:
|
1843
|
-
if arg_start == args_section_start + 1 and arg_end == args_section_end:
|
1844
|
-
# Removed arg was the only arg, remove the entire section.
|
1845
|
-
assert args_section_end
|
1846
|
-
keep_lines.difference_update(range(args_section_start, args_section_end))
|
1847
|
-
else:
|
1848
|
-
# Just remove the arg.
|
1849
|
-
assert arg_start
|
1850
|
-
assert arg_end
|
1851
|
-
keep_lines.difference_update(range(arg_start, arg_end))
|
1852
|
-
|
1853
|
-
return "\n".join([docstring_lines[i] for i in sorted(keep_lines)])
|
1854
|
-
|
1855
|
-
def __init__(
    self,
    *,
    registry: ModelRegistry,
    model_name: str,
    model_version: str,
) -> None:
    """Bind a model name and version to a registry and mirror the registry's per-model methods.

    The first time an instance is created, every function defined directly on the registry's class
    that takes both a `model_name` and a `model_version` parameter is installed on this class as a
    method which passes the bound name and version as the first two positional arguments.

    Args:
        registry: Registry that the reference operates on.
        model_name: Name of the referenced model.
        model_version: Version of the referenced model.
    """
    self._registry = registry
    self._id = registry._get_model_id(model_name=model_name, model_version=model_version)
    self._model_name = model_name
    self._model_version = model_version

    cls = self.__class__

    # The wrappers live on the class, so they only have to be generated once.
    if hasattr(cls, "init_complete"):
        return

    def build_method(m: Callable[..., Any]) -> Callable[..., Any]:
        # A factory is used so that each wrapper captures its own `m`.
        return lambda self, *args, **kwargs: m(
            self._registry,
            self._model_name,
            self._model_version,
            *args,
            **kwargs,
        )

    # Wrap all functions of the registry that take "model_name" and "model_version" and bind those
    # parameters to the members of this class.
    for name, obj in self._registry.__class__.__dict__.items():
        if not inspect.isfunction(obj):
            continue
        parameters = inspect.signature(obj).parameters
        if "model_name" not in parameters or "model_version" not in parameters:
            continue

        # Ensure that we are not silently overwriting existing functions.
        assert not hasattr(cls, name)

        method = build_method(m=obj)
        setattr(cls, name, method)

        # The bound parameters are no longer part of the wrapper's interface, so drop them from
        # the documentation that is copied over.
        docstring = self._remove_arg_from_docstring("model_name", obj.__doc__)
        if docstring and "model_version" in docstring:
            docstring = self._remove_arg_from_docstring("model_version", docstring)
        # `method` is the object that was just stored on the class under `name`.
        method.__doc__ = docstring

    setattr(cls, "init_complete", True)  # noqa: B010
|
1902
|
-
|
1903
|
-
@telemetry.send_api_usage_telemetry(
|
1904
|
-
project=_TELEMETRY_PROJECT,
|
1905
|
-
subproject=_TELEMETRY_SUBPROJECT,
|
1906
|
-
)
|
1907
|
-
def get_name(self) -> str:
    """Return the model name this reference was created with."""
    return self._model_name
|
1909
|
-
|
1910
|
-
@telemetry.send_api_usage_telemetry(
|
1911
|
-
project=_TELEMETRY_PROJECT,
|
1912
|
-
subproject=_TELEMETRY_SUBPROJECT,
|
1913
|
-
)
|
1914
|
-
def get_version(self) -> str:
    """Return the model version this reference was created with."""
    return self._model_version
|
1916
|
-
|
1917
|
-
@telemetry.send_api_usage_telemetry(
|
1918
|
-
project=_TELEMETRY_PROJECT,
|
1919
|
-
subproject=_TELEMETRY_SUBPROJECT,
|
1920
|
-
)
|
1921
|
-
@snowpark._internal.utils.private_preview(version="0.2.0")
|
1922
|
-
def predict(self, deployment_name: str, data: Any) -> "pd.DataFrame":
    """Predict using the deployed model in Snowflake.

    The deployment is looked up in the registry's in-memory temporary deployments first; if it is
    not found there, its record is fetched through the registry's `get_deployment`.

    Args:
        deployment_name: name of the generated UDF.
        data: Data to run predict.

    Raises:
        ValueError: No deployment record with the given name was found, or the recorded target
            platform is not one of the supported platforms.

    Returns:
        A dataframe containing the result of prediction.
    """
    # Temporary deployments are only known to the local in-memory cache.
    di = self._registry._temporary_deployments.get(deployment_name)

    statement_params = telemetry.get_function_usage_statement_params(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
        function_name=telemetry.get_statement_params_full_func_name(
            inspect.currentframe(), self.__class__.__name__
        ),
    )

    self._registry._svm.validate_schema_version(statement_params)

    if not di:
        # Cache miss: rebuild the deployment description from the remote deployment record.
        # Mypy enforce to refer to the registry for calling the function
        rows = self._registry.get_deployment(
            self._model_name, self._model_version, deployment_name=deployment_name
        ).collect(statement_params=statement_params)
        if not rows:
            raise ValueError(f"The deployment with name {deployment_name} haven't been deployed")

        record = rows[0]
        platform = deploy_platforms.TargetPlatform(record["TARGET_PLATFORM"])
        target_method = record["TARGET_METHOD"]
        signature = model_signature.ModelSignature.from_dict(json.loads(record["SIGNATURE"]))
        options_dict = cast(Dict[str, Any], json.loads(record["OPTIONS"]))

        # Maps each supported platform to the options type used to wrap the stored options.
        platform_options = {
            deploy_platforms.TargetPlatform.WAREHOUSE: model_types.WarehouseDeployOptions,
            deploy_platforms.TargetPlatform.SNOWPARK_CONTAINER_SERVICES: (
                model_types.SnowparkContainerServiceDeployOptions
            ),
        }
        if platform not in platform_options:
            raise ValueError(f"Unsupported target Platform: {platform}")
        options = platform_options[platform](options_dict)

        di = model_types.Deployment(
            name=self._registry._fully_qualified_deployment_name(deployment_name),
            platform=platform,
            target_method=target_method,
            signature=signature,
            options=options,
        )

    return model_api.predict(
        session=self._registry._session,
        deployment=di,
        X=data,
        statement_params=statement_params,
    )
|
1991
|
-
|
1992
|
-
|
1993
|
-
@telemetry.send_api_usage_telemetry(
|
1994
|
-
project=_TELEMETRY_PROJECT,
|
1995
|
-
subproject=_TELEMETRY_SUBPROJECT,
|
1996
|
-
)
|
1997
|
-
@snowpark._internal.utils.private_preview(version="0.2.0")
|
1998
|
-
def create_model_registry(
    *,
    session: snowpark.Session,
    database_name: str = _DEFAULT_REGISTRY_NAME,
    schema_name: str = _DEFAULT_SCHEMA_NAME,
) -> bool:
    """Setup a new model registry. This should be run once per model registry by an administrator role.

    Creates the registry database and schema, creates the initial registry tables when the deployed
    schema version is still the initial one, and then attempts a schema upgrade. Unless running
    inside a stored procedure, the session's current database and schema are switched back to what
    they were before the call, even when one of the steps raises.

    Args:
        session: Session object to communicate with Snowflake.
        database_name: Desired name of the model registry database.
        schema_name: Desired name of the schema used by this model registry inside the database.

    Returns:
        True once all steps have completed. There is no code path returning False; a failing step
        propagates its exception instead.
    """
    # Remember what the session currently points at so it can be restored afterwards.
    original_db = session.get_current_database()
    original_schema = session.get_current_schema()

    database_name = identifier.get_inferred_name(database_name)
    schema_name = identifier.get_inferred_name(schema_name)

    statement_params = telemetry.get_function_usage_statement_params(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
        function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), ""),
    )
    try:
        _create_registry_database(session, database_name, statement_params)
        _create_registry_schema(session, database_name, schema_name, statement_params)

        svm = _schema_version_manager.SchemaVersionManager(session, database_name, schema_name)
        if svm.get_deployed_version(statement_params) == _initial_schema._INITIAL_VERSION:
            # We do not know if registry is being created for the first time.
            # So let's start with creating initial schema, which is idempotent anyways.
            _initial_schema.create_initial_registry_tables(session, database_name, schema_name, statement_params)

        svm.try_upgrade(statement_params)
    finally:
        if not snowpark_utils.is_in_stored_procedure():  # type: ignore[no-untyped-call]
            # Restore the database first, then the schema, and only where the session has moved.
            for previous, get_current, switch_to in (
                (original_db, session.get_current_database, session.use_database),
                (original_schema, session.get_current_schema, session.use_schema),
            ):
                if previous is not None and previous != get_current():
                    switch_to(previous)
    return True
|