snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284) hide show
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -1,2048 +0,0 @@
1
- import inspect
2
- import json
3
- import sys
4
- import textwrap
5
- import types
6
- import warnings
7
- from typing import (
8
- TYPE_CHECKING,
9
- Any,
10
- Callable,
11
- Dict,
12
- List,
13
- Optional,
14
- Tuple,
15
- Union,
16
- cast,
17
- )
18
- from uuid import uuid1
19
-
20
- from absl import logging
21
-
22
- from snowflake import connector, snowpark
23
- from snowflake.ml._internal import telemetry
24
- from snowflake.ml._internal.utils import (
25
- formatting,
26
- identifier,
27
- query_result_checker,
28
- spcs_attribution_utils,
29
- table_manager,
30
- uri,
31
- )
32
- from snowflake.ml.model import (
33
- _api as model_api,
34
- deploy_platforms,
35
- model_signature,
36
- type_hints as model_types,
37
- )
38
- from snowflake.ml.registry import _initial_schema, _schema_version_manager
39
- from snowflake.snowpark._internal import utils as snowpark_utils
40
-
41
- if TYPE_CHECKING:
42
- import pandas as pd
43
-
44
- _DEFAULT_REGISTRY_NAME: str = "_SYSTEM_MODEL_REGISTRY"
45
- _DEFAULT_SCHEMA_NAME: str = "_SYSTEM_MODEL_REGISTRY_SCHEMA"
46
- _MODELS_TABLE_NAME: str = "_SYSTEM_REGISTRY_MODELS"
47
- _METADATA_TABLE_NAME: str = "_SYSTEM_REGISTRY_METADATA"
48
- _DEPLOYMENT_TABLE_NAME: str = "_SYSTEM_REGISTRY_DEPLOYMENTS"
49
-
50
- # Metadata operation types.
51
- _SET_METADATA_OPERATION: str = "SET"
52
- _ADD_METADATA_OPERATION: str = "ADD"
53
- _DROP_METADATA_OPERATION: str = "DROP"
54
-
55
- # Metadata types.
56
- _METADATA_ATTRIBUTE_DESCRIPTION: str = "DESCRIPTION"
57
- _METADATA_ATTRIBUTE_METRICS: str = "METRICS"
58
- _METADATA_ATTRIBUTE_REGISTRATION: str = "REGISTRATION"
59
- _METADATA_ATTRIBUTE_TAGS: str = "TAGS"
60
- _METADATA_ATTRIBUTE_DEPLOYMENT: str = "DEPLOYMENTS"
61
- _METADATA_ATTRIBUTE_DELETION: str = "DELETION"
62
-
63
- # Leaving out REGISTRATION/DEPLOYMENT events as they will be handled differently from all mutable attributes.
64
- _LIST_METADATA_ATTRIBUTE: List[str] = [
65
- _METADATA_ATTRIBUTE_DESCRIPTION,
66
- _METADATA_ATTRIBUTE_METRICS,
67
- _METADATA_ATTRIBUTE_TAGS,
68
- ]
69
- _TELEMETRY_PROJECT = "MLOps"
70
- _TELEMETRY_SUBPROJECT = "ModelRegistry"
71
-
72
- _STAGE_PREFIX = "@"
73
-
74
-
75
- def _create_registry_database(
76
- session: snowpark.Session,
77
- database_name: str,
78
- statement_params: Dict[str, Any],
79
- ) -> None:
80
- """Private helper to create the model registry database.
81
-
82
- The creation will be skipped if the target database already exists.
83
-
84
- Args:
85
- session: Session object to communicate with Snowflake.
86
- database_name: Desired name of the model registry database.
87
- statement_params: Function usage statement parameters used in sql query executions.
88
- """
89
- registry_databases = session.sql(f"SHOW DATABASES LIKE '{identifier.get_unescaped_names(database_name)}'").collect(
90
- statement_params=statement_params
91
- )
92
- if len(registry_databases) > 0:
93
- logging.warning(f"The database {database_name} already exists. Skipping creation.")
94
- return
95
-
96
- session.sql(f"CREATE DATABASE {database_name}").collect(statement_params=statement_params)
97
-
98
-
99
- def _create_registry_schema(
100
- session: snowpark.Session,
101
- database_name: str,
102
- schema_name: str,
103
- statement_params: Dict[str, Any],
104
- ) -> None:
105
- """Private helper to create the model registry schema.
106
-
107
- The creation will be skipped if the target schema already exists.
108
-
109
- Args:
110
- session: Session object to communicate with Snowflake.
111
- database_name: Desired name of the model registry database.
112
- schema_name: Desired name of the schema used by this model registry inside the database.
113
- statement_params: Function usage statement parameters used in sql query executions.
114
- """
115
- # The default PUBLIC schema is created by default so it might already exist even in a new database.
116
- registry_schemas = session.sql(
117
- f"SHOW SCHEMAS LIKE '{identifier.get_unescaped_names(schema_name)}' IN DATABASE {database_name}"
118
- ).collect(statement_params=statement_params)
119
-
120
- if len(registry_schemas) > 0:
121
- logging.warning(
122
- f"The schema {table_manager.get_fully_qualified_schema_name(database_name, schema_name)} already exists. "
123
- + "Skipping creation."
124
- )
125
- return
126
-
127
- session.sql(f"CREATE SCHEMA {table_manager.get_fully_qualified_schema_name(database_name, schema_name)}").collect(
128
- statement_params=statement_params
129
- )
130
-
131
-
132
- def _create_registry_views(
133
- session: snowpark.Session,
134
- database_name: str,
135
- schema_name: str,
136
- registry_table_name: str,
137
- metadata_table_name: str,
138
- deployment_table_name: str,
139
- statement_params: Dict[str, Any],
140
- ) -> None:
141
- """Create views on underlying ModelRegistry tables.
142
-
143
- Args:
144
- session: Session object to communicate with Snowflake.
145
- database_name: Desired name of the model registry database.
146
- schema_name: Desired name of the schema used by this model registry inside the database.
147
- registry_table_name: Name for the main model registry table.
148
- metadata_table_name: Name for the metadata table used by the model registry.
149
- deployment_table_name: Name for the deployment event table.
150
- statement_params: Function usage statement parameters used in sql query executions.
151
- """
152
- fully_qualified_schema_name = table_manager.get_fully_qualified_schema_name(database_name, schema_name)
153
-
154
- # From the documentation: Each DDL statement executes as a separate transaction. Races should not be an issue.
155
- # https://docs.snowflake.com/en/sql-reference/transactions.html#ddl
156
-
157
- # Create a view on active permanent deployments.
158
- _create_active_permanent_deployment_view(
159
- session,
160
- fully_qualified_schema_name,
161
- registry_table_name,
162
- deployment_table_name,
163
- statement_params,
164
- )
165
-
166
- # Create views on most recent metadata items.
167
- metadata_view_name_prefix = identifier.concat_names([metadata_table_name, "_LAST_"])
168
- metadata_view_template = formatting.unwrap(
169
- """CREATE OR REPLACE TEMPORARY VIEW {database}.{schema}.{attribute_view} COPY GRANTS AS
170
- SELECT DISTINCT MODEL_ID, {select_expression} AS {final_attribute_name} FROM {metadata_table}
171
- WHERE ATTRIBUTE_NAME = '{attribute_name}'"""
172
- )
173
-
174
- # Create a separate view for the most recent item in each metadata column.
175
- metadata_view_names = []
176
- metadata_select_fields = []
177
- for attribute_name in _LIST_METADATA_ATTRIBUTE:
178
- view_name = identifier.concat_names([metadata_view_name_prefix, attribute_name])
179
- select_expression = (
180
- f"(LAST_VALUE(VALUE) OVER (PARTITION BY MODEL_ID ORDER BY EVENT_TIMESTAMP))['{attribute_name}']"
181
- )
182
- sql = metadata_view_template.format(
183
- database=database_name,
184
- schema=schema_name,
185
- select_expression=select_expression,
186
- attribute_view=view_name,
187
- attribute_name=attribute_name,
188
- final_attribute_name=attribute_name,
189
- metadata_table=metadata_table_name,
190
- )
191
- session.sql(sql).collect(statement_params=statement_params)
192
- metadata_view_names.append(view_name)
193
- metadata_select_fields.append(f"{view_name}.{attribute_name} AS {attribute_name}")
194
-
195
- # Create a special view for the registration timestamp.
196
- attribute_name = _METADATA_ATTRIBUTE_REGISTRATION
197
- final_attribute_name = identifier.concat_names([attribute_name, "_TIMESTAMP"])
198
- view_name = identifier.concat_names([metadata_view_name_prefix, attribute_name])
199
- create_registration_view_sql = metadata_view_template.format(
200
- database=database_name,
201
- schema=schema_name,
202
- select_expression="EVENT_TIMESTAMP",
203
- attribute_view=view_name,
204
- attribute_name=attribute_name,
205
- final_attribute_name=final_attribute_name,
206
- metadata_table=metadata_table_name,
207
- )
208
- session.sql(create_registration_view_sql).collect(statement_params=statement_params)
209
- metadata_view_names.append(view_name)
210
- metadata_select_fields.append(f"{view_name}.{final_attribute_name} AS {final_attribute_name}")
211
-
212
- metadata_views_join = " ".join(
213
- [
214
- "LEFT JOIN {view} ON ({view}.MODEL_ID = {registry_table}.ID)".format(
215
- view=view, registry_table=registry_table_name
216
- )
217
- for view in metadata_view_names
218
- ]
219
- )
220
-
221
- # Create view to combine all attributes.
222
- registry_view_name = identifier.concat_names([registry_table_name, "_VIEW"])
223
- metadata_select_fields_formatted = ",".join(metadata_select_fields)
224
- session.sql(
225
- f"""CREATE OR REPLACE TEMPORARY VIEW {fully_qualified_schema_name}.{registry_view_name} COPY GRANTS AS
226
- SELECT {registry_table_name}.*, {metadata_select_fields_formatted}
227
- FROM {registry_table_name} {metadata_views_join}"""
228
- ).collect(statement_params=statement_params)
229
-
230
-
231
- def _create_active_permanent_deployment_view(
232
- session: snowpark.Session,
233
- fully_qualified_schema_name: str,
234
- registry_table_name: str,
235
- deployment_table_name: str,
236
- statement_params: Dict[str, Any],
237
- ) -> None:
238
- """Create a view which lists all available permanent deployments.
239
-
240
- Args:
241
- session: Session object to communicate with Snowflake.
242
- fully_qualified_schema_name: Schema name to the target table.
243
- registry_table_name: Name for the main model registry table.
244
- deployment_table_name: Name of the deployment table.
245
- statement_params: Function usage statement parameters used in sql query executions.
246
- """
247
-
248
- # Create a view on active permanent deployments
249
- # Active deployments are those whose last operation is not DROP.
250
- active_deployments_view_name = identifier.concat_names([deployment_table_name, "_VIEW"])
251
- active_deployments_view_expr = f"""
252
- CREATE OR REPLACE TEMPORARY VIEW {fully_qualified_schema_name}.{active_deployments_view_name}
253
- COPY GRANTS AS
254
- SELECT
255
- DEPLOYMENT_NAME,
256
- MODEL_ID,
257
- {registry_table_name}.NAME as MODEL_NAME,
258
- {registry_table_name}.VERSION as MODEL_VERSION,
259
- {deployment_table_name}.CREATION_TIME as CREATION_TIME,
260
- TARGET_METHOD,
261
- TARGET_PLATFORM,
262
- SIGNATURE,
263
- OPTIONS,
264
- STAGE_PATH,
265
- ROLE
266
- FROM {deployment_table_name}
267
- LEFT JOIN {registry_table_name}
268
- ON {deployment_table_name}.MODEL_ID = {registry_table_name}.ID
269
- """
270
- session.sql(active_deployments_view_expr).collect(statement_params=statement_params)
271
-
272
-
273
- class ModelRegistry:
274
- """Model Management API."""
275
-
276
def __init__(
    self,
    *,
    session: snowpark.Session,
    database_name: str = _DEFAULT_REGISTRY_NAME,
    schema_name: str = _DEFAULT_SCHEMA_NAME,
    create_if_not_exists: bool = False,
) -> None:
    """Attach to the (deprecated) model registry stored in ``database_name.schema_name``.

    A ``DeprecationWarning`` is emitted on every construction. When ``create_if_not_exists`` is set,
    ``create_model_registry`` is invoked first. The constructor then checks access to the schema, validates the
    registry schema version and (re)creates the registry views.

    Args:
        session: Session object to communicate with Snowflake.
        database_name: Name of the database that holds the model registry.
        schema_name: Name of the schema, inside that database, used by the model registry.
        create_if_not_exists: If True, create the model registry when it does not exist yet.
    """
    warnings.warn(
        """
The `snowflake.ml.registry.model_registry.ModelRegistry` has been deprecated starting from version 1.2.0.
It will stay in the Private Preview phase. For future implementations, kindly utilize `snowflake.ml.registry.Registry`,
except when specifically required. The old model registry will be removed once all its primary functionalities are
fully integrated into the new registry.
""",
        DeprecationWarning,
        stacklevel=2,
    )
    if create_if_not_exists:
        create_model_registry(session=session, database_name=database_name, schema_name=schema_name)

    # Resolve each object name once; the view and stage names are derived from the base table names.
    self._name = identifier.get_inferred_name(database_name)
    self._schema = identifier.get_inferred_name(schema_name)
    registry_table = identifier.get_inferred_name(_MODELS_TABLE_NAME)
    self._registry_table = registry_table
    self._registry_table_view = identifier.concat_names([registry_table, "_VIEW"])
    self._metadata_table = identifier.get_inferred_name(_METADATA_TABLE_NAME)
    deployment_table = identifier.get_inferred_name(_DEPLOYMENT_TABLE_NAME)
    self._deployment_table = deployment_table
    self._permanent_deployment_view = identifier.concat_names([deployment_table, "_VIEW"])
    self._permanent_deployment_stage = identifier.concat_names([deployment_table, "_STAGE"])

    self._session = session
    self._svm = _schema_version_manager.SchemaVersionManager(session, self._name, self._schema)

    # In-memory store for information about temporary deployments.
    # TODO(zhe): Use a temporary table to replace the in-memory cache.
    self._temporary_deployments: Dict[str, model_types.Deployment] = {}

    _initial_schema.check_access(session, self._name, self._schema)

    # This frame is handed to telemetry so the queries below are reported under this method's name.
    statement_params = self._get_statement_params(inspect.currentframe())
    self._svm.validate_schema_version(statement_params)

    _create_registry_views(
        session,
        self._name,
        self._schema,
        self._registry_table,
        self._metadata_table,
        self._deployment_table,
        statement_params,
    )
336
-
337
- # Private methods
338
-
339
def _get_statement_params(self, frame: Optional[types.FrameType]) -> Dict[str, Any]:
    """Build the telemetry statement parameters attached to queries issued by a registry method.

    Args:
        frame: Frame of the calling method (callers pass ``inspect.currentframe()``); it is used together with
            the class name ``"ModelRegistry"`` to build the function name reported to telemetry.

    Returns:
        The statement parameters produced by ``telemetry.get_function_usage_statement_params``.
    """
    caller_name = telemetry.get_statement_params_full_func_name(frame, "ModelRegistry")
    return telemetry.get_function_usage_statement_params(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
        function_name=caller_name,
    )
345
-
346
def _get_new_unique_identifier(self) -> str:
    """Generate a fresh identifier for models and metadata events.

    Returns:
        The 32-character lowercase hex form of a new ``uuid1`` (time-based) UUID.
    """
    fresh = uuid1()
    return fresh.hex
352
-
353
def _fully_qualified_registry_table_name(self) -> str:
    """Return the fully qualified name of this registry's model table."""
    database, schema, table = self._name, self._schema, self._registry_table
    return table_manager.get_fully_qualified_table_name(database, schema, table)
356
-
357
def _fully_qualified_registry_view_name(self) -> str:
    """Return the fully qualified name of the view built on top of the model table."""
    database, schema, view = self._name, self._schema, self._registry_table_view
    return table_manager.get_fully_qualified_table_name(database, schema, view)
360
-
361
def _fully_qualified_metadata_table_name(self) -> str:
    """Return the fully qualified name of the metadata (event history) table."""
    database, schema, table = self._name, self._schema, self._metadata_table
    return table_manager.get_fully_qualified_table_name(database, schema, table)
364
-
365
def _fully_qualified_deployment_table_name(self) -> str:
    """Return the fully qualified name of the deployment table."""
    database, schema, table = self._name, self._schema, self._deployment_table
    return table_manager.get_fully_qualified_table_name(database, schema, table)
368
-
369
def _fully_qualified_permanent_deployment_view_name(self) -> str:
    """Return the fully qualified name of the view over permanent deployments."""
    database, schema, view = self._name, self._schema, self._permanent_deployment_view
    return table_manager.get_fully_qualified_table_name(database, schema, view)
372
-
373
def _fully_qualified_schema_name(self) -> str:
    """Return the fully qualified ``database.schema`` name of this registry."""
    database, schema = self._name, self._schema
    return table_manager.get_fully_qualified_schema_name(database, schema)
376
-
377
def _fully_qualified_deployment_name(self, deployment_name: str) -> str:
    """Return ``deployment_name`` qualified with this registry's database and schema."""
    database, schema = self._name, self._schema
    return table_manager.get_fully_qualified_table_name(database, schema, deployment_name)
380
-
381
def _insert_registry_entry(
    self, *, id: str, name: str, version: str, properties: Dict[str, Any]
) -> List[snowpark.Row]:
    """Insert a new row into the model registry table.

    ``ID``, ``NAME`` and ``VERSION`` are filled in from the keyword arguments when ``properties`` does not already
    contain them. The caller's ``properties`` dictionary is left untouched; a copy is completed and inserted.

    Args:
        id: Model id to register.
        name: Model Name string.
        version: Model Version string.
        properties: Dictionary of properties corresponding to table columns.

    Returns:
        The rows returned by ``table_manager.insert_table_entry`` for the insertion.

    Raises:
        DataError: ``id`` is empty, or a non-empty id/name/version conflicts with the value already present in
            ``properties``.
    """
    if not id:
        raise connector.DataError("Model ID is required but none given.")

    # Work on a copy so that completing the mandatory columns does not mutate the caller's dictionary.
    row = dict(properties)
    for column, given in {"ID": id, "NAME": name, "VERSION": version}.items():
        if column not in row:
            row[column] = given
        elif given and given != row[column]:
            raise connector.DataError(
                formatting.unwrap(
                    f"""Parameter '{column.lower()}' is given and parameter 'properties' has the field '{column}' set but
                    the values do not match: {column.lower()}=="{given}" properties['{column}']=="{row[column]}"."""
                )
            )

    # Could do a multi-table insert here with some pros and cons:
    # [PRO] Atomic insert across multiple tables.
    # [CON] Code logic becomes messy depending on which fields are set.
    # [CON] Harder to reuse existing methods like set_model_name.
    # Context: https://docs.snowflake.com/en/sql-reference/sql/insert-multi-table.html
    return table_manager.insert_table_entry(
        self._session,
        table=self._fully_qualified_registry_table_name(),
        columns=row,
    )
422
-
423
def _insert_metadata_entry(self, *, id: str, attribute: str, value: Any, operation: str) -> List[snowpark.Row]:
    """Append one event row to the model metadata table.

    The row records the current timestamp, a freshly generated event id, the model id, the session's current
    role, the operation type, the attribute name and its value.

    Args:
        id: ID of the model the event refers to.
        attribute: Name of the metadata attribute being recorded.
        value: Value recorded for the attribute.
        operation: Operation type of the metadata entry.

    Returns:
        The rows returned by ``table_manager.insert_table_entry``.

    Raises:
        DataError: ``id`` is empty.
    """
    if not id:
        raise connector.DataError("Model ID is required but none given.")

    event_row: Dict[str, Any] = {
        "EVENT_TIMESTAMP": formatting.SqlStr("CURRENT_TIMESTAMP()"),
        "EVENT_ID": self._get_new_unique_identifier(),
        "MODEL_ID": id,
        "ROLE": self._session.get_current_role(),
        "OPERATION": operation,
        "ATTRIBUTE_NAME": attribute,
        "VALUE": value,
    }
    return table_manager.insert_table_entry(
        self._session,
        table=self._fully_qualified_metadata_table_name(),
        columns=event_row,
    )
455
-
456
def _insert_deployment_entry(
    self,
    *,
    id: str,
    name: str,
    platform: str,
    stage_path: str,
    signature: Dict[str, Any],
    target_method: str,
    options: Optional[
        Union[
            model_types.WarehouseDeployOptions,
            model_types.SnowparkContainerServiceDeployOptions,
        ]
    ] = None,
) -> List[snowpark.Row]:
    """Record one deployment event as a new row of the deployment table.

    Args:
        id: ID of the deployed model.
        name: Name of the deployment.
        platform: The deployment target destination.
        stage_path: Stage location where the deployment UDF is stored.
        signature: The model signature.
        target_method: Name of the model method used by the deployment.
        options: The deployment options, if any.

    Returns:
        The rows returned by ``table_manager.insert_table_entry``.

    Raises:
        DataError: ``id`` is empty.
    """
    if not id:
        raise connector.DataError("Model ID is required but none given.")

    deployment_row: Dict[str, Any] = {
        "CREATION_TIME": formatting.SqlStr("CURRENT_TIMESTAMP()"),
        "MODEL_ID": id,
        "DEPLOYMENT_NAME": name,
        "TARGET_PLATFORM": platform,
        "STAGE_PATH": stage_path,
        "ROLE": self._session.get_current_role(),
        "SIGNATURE": signature,
        "TARGET_METHOD": target_method,
        "OPTIONS": options,
    }
    return table_manager.insert_table_entry(
        self._session,
        table=self._fully_qualified_deployment_table_name(),
        columns=deployment_row,
    )
510
-
511
def _prepare_deployment_stage(self) -> str:
    """Ensure the stage that stores all permanent deployments exists.

    The stage is created with ``CREATE STAGE IF NOT EXISTS`` and Snowflake server-side encryption.

    Returns:
        The stage location, prefixed with ``@``.
    """
    stage_name = f"{self._fully_qualified_schema_name()}.{self._permanent_deployment_stage}"
    statement_params = self._get_statement_params(inspect.currentframe())
    create_sql = f"CREATE STAGE IF NOT EXISTS {stage_name} ENCRYPTION = (TYPE= 'SNOWFLAKE_SSE')"
    self._session.sql(create_sql).collect(statement_params=statement_params)
    return "@" + stage_name
525
-
526
def _prepare_model_stage(self, model_id: str) -> str:
    """(Re)create the stage that stores the artifacts of the model with the given id.

    A permanent stage is used because a stage cannot be switched from temporary to permanent later. If a later
    step fails, the stage can be left behind orphaned. An alternative would be to use a temporary stage and
    convert it once every operation has succeeded.

    Args:
        model_id: Internal model ID string.

    Returns:
        Fully qualified name of the created stage.

    Raises:
        DatabaseError: ``CREATE OR REPLACE STAGE`` returned no rows, or did not return exactly one row.
    """
    # The id is uppercased so that the stage name never needs quoting.
    fully_qualified_model_stage_name = f"{self._fully_qualified_schema_name()}.SNOWML_MODEL_{model_id.upper()}"
    statement_params = self._get_statement_params(inspect.currentframe())

    create_sql = f"CREATE OR REPLACE STAGE {fully_qualified_model_stage_name} ENCRYPTION = (TYPE= 'SNOWFLAKE_SSE')"
    outcome = self._session.sql(create_sql).collect(statement_params=statement_params)
    if not outcome:
        raise connector.DatabaseError("Unable to create stage for model. Operation returned not result.")
    if len(outcome) != 1:
        raise connector.DatabaseError(
            f"Unable to create stage for model. Creating the model stage returned unexpected result: {outcome!s}."
        )

    return fully_qualified_model_stage_name
565
-
566
def _get_fully_qualified_stage_name_from_uri(self, model_uri: str) -> Optional[str]:
    """Translate a model URI into the fully qualified name of the Snowflake stage it points at.

    Args:
        model_uri: URI for which the stage is needed.

    Returns:
        The fully qualified stage name (database, schema and stage; any path inside the stage is dropped), or
        None when the URI does not point to a Snowflake stage.
    """
    raw_stage_path = uri.get_snowflake_stage_path_from_uri(model_uri)
    if raw_stage_path:
        db, schema, stage, _ = identifier.parse_schema_level_object_identifier(raw_stage_path)
        return identifier.get_schema_level_object_identifier(db, schema, stage)
    return None
581
-
582
def _list_selected_models(
    self,
    *,
    id: Optional[str] = None,
    model_name: Optional[str] = None,
    model_version: Optional[str] = None,
) -> snowpark.DataFrame:
    """Return a dataframe of the registered models matching an ID or a (name, version) pair.

    ``id`` takes precedence. Without it, both ``model_name`` and ``model_version`` must be non-empty.

    Args:
        id: Model ID string. Required if either name or version is None.
        model_name: Model Name string. Required if id is None.
        model_version: Model Version string. Required if id is None.

    Returns:
        A Snowpark dataframe of the models that satisfy the filter.

    Raises:
        ValueError: ``id`` is empty and the name or version is missing.
    """
    all_models = self.list_models()

    if not id:
        self._model_identifier_is_nonempty_or_raise(model_name, model_version)
        # mypy cannot infer non-emptiness from the call above.
        assert model_name
        assert model_version
        matching = all_models.filter(snowpark.Column("NAME") == model_name).filter(
            snowpark.Column("VERSION") == model_version
        )
    else:
        matching = all_models.filter(snowpark.Column("ID") == id)

    return cast(snowpark.DataFrame, matching)
615
-
616
def _validate_exact_one_result(
    self, selected_model: snowpark.DataFrame, model_identifier: str
) -> List[snowpark.Row]:
    """Collect ``selected_model`` and check that exactly one model matched.

    Args:
        selected_model: A snowpark dataframe of the models that matched the caller's filter.
        model_identifier: Human-readable description of the filter, used in error messages.

    Returns:
        The collected rows; the list always holds exactly one row.

    Raises:
        KeyError: No model matched.
        DataError: More than one model matched, which indicates an integrity problem in the registry data.
    """
    statement_params = self._get_statement_params(inspect.currentframe())
    # Inspect the collected rows directly. The previous version relied on a validator that raised before its result
    # could be assigned. That left `model_info` as None in the error branch, so duplicates were reported as a
    # missing model.
    model_info = selected_model.collect(statement_params=statement_params)
    if not model_info:
        raise KeyError(f"The model {model_identifier} does not exist in the current registry.")
    if len(model_info) > 1:
        raise connector.DataError(
            formatting.unwrap(
                f"""There are {len(model_info)} models {model_identifier}. This might indicate a problem with
                the integrity of the model registry data."""
            )
        )
    return model_info
651
-
652
def _get_metadata_attribute(
    self,
    attribute: str,
    id: Optional[str] = None,
    model_name: Optional[str] = None,
    model_version: Optional[str] = None,
) -> Any:
    """Read a single column of the registry view for exactly one model.

    The model is selected by ``id`` when given, otherwise by ``model_name`` + ``model_version``.

    Args:
        attribute: Name of the attribute (column) to read.
        id: Model ID string. Required if either name or version is None.
        model_name: Model Name string. Required if id is None.
        model_version: Model Version string. Required if id is None.

    Returns:
        The value of the requested attribute; it can be None when the attribute is not set.

    Raises:
        KeyError: No model matched (raised by ``_validate_exact_one_result``).
        ValueError: ``id`` is empty and the name or version is missing (raised by ``_list_selected_models``).
    """
    if id:
        model_identifier = f"id {id}"
    else:
        model_identifier = f"{model_name}/{model_version}"
    rows = self._validate_exact_one_result(
        self._list_selected_models(id=id, model_name=model_name, model_version=model_version),
        model_identifier,
    )
    return rows[0][attribute]
674
-
675
def _set_metadata_attribute(
    self,
    attribute: str,
    value: Any,
    id: Optional[str] = None,
    model_name: Optional[str] = None,
    model_version: Optional[str] = None,
    operation: str = _SET_METADATA_OPERATION,
    enable_model_presence_check: bool = True,
) -> None:
    """Record a new value of a metadata attribute for the model selected by id or (name + version).

    Args:
        attribute: Name of the attribute to set.
        value: Value to set; it is stored as ``{attribute: value}``.
        id: Model ID string. Required if either name or version is None.
        model_name: Model Name string. Required if id is None.
        model_version: Model Version string. Required if id is None.
        operation: the operation type of the metadata entry.
        enable_model_presence_check: If True (the default), raise when the model is not currently registered.
            If False, a missing model is tolerated, but only when ``id`` is given. Without an id there is no way
            to tell which model the entry belongs to.

    Raises:
        DataError: Failed to set the metadata attribute.
        KeyError: The target model doesn't exist, and either the presence check is enabled or no ``id`` was given.
    """
    selected_models = self._list_selected_models(id=id, model_name=model_name, model_version=model_version)
    model_identifier = f"id {id}" if id else f"{model_name}/{model_version}"
    try:
        model_info = self._validate_exact_one_result(selected_models, model_identifier)
    except KeyError:
        # Previously a missing model without an id fell through and failed later with UnboundLocalError on
        # `model_info`. Surface the KeyError instead, because the model id cannot be determined.
        if enable_model_presence_check or not id:
            raise
    else:
        if not id:
            id = model_info[0]["ID"]
    assert id is not None

    try:
        self._insert_metadata_entry(
            id=id,
            attribute=attribute,
            value={attribute: value},
            operation=operation,
        )
    except connector.DataError as e:
        raise connector.DataError(f"Setting {attribute} for model id {id} failed.") from e
723
-
724
def _model_identifier_is_nonempty_or_raise(self, model_name: Optional[str], model_version: Optional[str]) -> None:
    """Ensure both parts of a (name, version) model identifier are present.

    Args:
        model_name: Model Name string.
        model_version: Model Version string.

    Raises:
        ValueError: ``model_name`` or ``model_version`` is None or empty.
    """
    if model_name and model_version:
        return
    raise ValueError("model_name and model_version have to be non-empty strings.")
736
-
737
def _get_model_id(self, model_name: str, model_version: str) -> str:
    """Look up the internal ID of the model registered under ``model_name``/``model_version``.

    Args:
        model_name: Model Name string.
        model_version: Model Version string.

    Returns:
        The model ID, converted to ``str``.

    Raises:
        DataError: The lookup yielded an empty ID.
    """
    model_id = self._get_metadata_attribute("ID", model_name=model_name, model_version=model_version)
    if model_id:
        return str(model_id)
    raise connector.DataError(f"Model {model_name}/{model_version} doesn't exist.")
754
-
755
def _get_model_path(
    self,
    id: Optional[str] = None,
    model_name: Optional[str] = None,
    model_version: Optional[str] = None,
) -> str:
    """Resolve the stage path holding the artifacts of the model selected by id or (name + version).

    Args:
        id: Id of the model to deploy. Required if either model name or model version is None.
        model_name: Model Name string. Required if id is None.
        model_version: Model Version string. Required if id is None.

    Returns:
        str: Stage path for the model, i.e. ``_STAGE_PREFIX`` followed by the fully qualified stage name.

    Raises:
        DataError: The model URI is not a Snowflake stage URI, or the stage contains no files.
    """
    statement_params = self._get_statement_params(inspect.currentframe())
    model_identifier = f"id {id}" if id else f"{model_name}/{model_version}"
    model_row = self._validate_exact_one_result(
        self._list_selected_models(id=id, model_name=model_name, model_version=model_version),
        model_identifier,
    )[0]
    if not id:
        id = model_row["ID"]
    model_uri = model_row["URI"]

    if not uri.is_snowflake_stage_uri(model_uri):
        raise connector.DataError(
            f"Artifacts with URI scheme {uri.get_uri_scheme(model_uri)} are currently not supported."
        )

    model_stage_path = self._get_fully_qualified_stage_name_from_uri(model_uri=model_uri)

    # Currently we assume only the model is on the stage, so an empty listing means the artifact is missing.
    listed_files = self._session.sql(f"LIST @{model_stage_path}").collect(statement_params=statement_params)
    if not listed_files:
        raise connector.DataError(f"No files in model artifact for id {id} located at {model_uri}.")
    return f"{_STAGE_PREFIX}{model_stage_path}"
794
-
795
def _log_model_path(
    self,
    model_name: str,
    model_version: str,
) -> Tuple[str, str]:
    """Allocate a new model id and create the stage that will hold the model's artifacts.

    ``model_name`` and ``model_version`` are not used by this method.

    Args:
        model_name: The given name for the model.
        model_version: Version string to be set for the model.

    Returns:
        Tuple of (newly generated model id, fully qualified name of the created model stage).
    """
    # TODO(zhe): Check if we could use the same stage for multiple models.
    new_id = self._get_new_unique_identifier()
    return new_id, self._prepare_model_stage(model_id=new_id)
816
-
817
def _register_model_with_id(
    self,
    model_name: str,
    model_version: str,
    model_id: str,
    *,
    type: str,
    uri: str,
    input_spec: Optional[Dict[str, str]] = None,
    output_spec: Optional[Dict[str, str]] = None,
    description: Optional[str] = None,
    tags: Optional[Dict[str, str]] = None,
) -> None:
    """Write the registry row and the initial metadata events for a new model.

    Args:
        model_name: Name to be set for the model. The model name can NOT be changed after registration. The
            combination of name and version is expected to be unique inside the registry.
        model_version: Version string to be set for the model. The model version string can NOT be changed after
            model registration. The combination of name and version is expected to be unique inside the registry.
        model_id: The internal id for the model.
        type: Type of the model. Only a subset of types are supported natively.
        uri: Resource identifier pointing to the model artifact. There are no restrictions on the URI format,
            however only a limited set of URI schemes is supported natively.
        input_spec: The expected input schema of the model. Dictionary where the keys are
            expected column names and the values are the value types.
        output_spec: The expected output schema of the model. Dictionary where the keys
            are expected column names and the values are the value types.
        description: A description for the model. The description can be changed later.
        tags: Key-value pairs of tags to be set for this model. Tags can be modified
            after model registration.

    Raises:
        DataError: The given model already exists.
        DatabaseError: Unable to register the model properties into table.
    """
    new_model: Dict[Any, Any] = {
        "ID": model_id,
        "NAME": model_name,
        "VERSION": model_version,
        "TYPE": type,
        "URI": uri,
        "INPUT_SPEC": input_spec,
        "OUTPUT_SPEC": output_spec,
        "CREATION_TIME": formatting.SqlStr("CURRENT_TIMESTAMP()"),
        "CREATION_ROLE": self._session.get_current_role(),
        # Records the Python version (major.minor.micro) of the registering client.
        "CREATION_ENVIRONMENT_SPEC": {"python": ".".join(map(str, sys.version_info[:3]))},
    }

    if self._list_selected_models(model_name=model_name, model_version=model_version).count():
        raise connector.DataError(
            f"Model {model_name}/{model_version} already exists. Unable to register the model."
        )

    inserted = self._insert_registry_entry(id=model_id, name=model_name, version=model_version, properties=new_model)
    if not inserted:
        raise connector.DatabaseError("Failed to insert the model properties to the registry table.")

    self._set_metadata_attribute(
        model_name=model_name,
        model_version=model_version,
        attribute=_METADATA_ATTRIBUTE_REGISTRATION,
        value=new_model,
    )
    if description:
        self.set_model_description(
            model_name=model_name,
            model_version=model_version,
            description=description,
        )
    if tags:
        self._set_metadata_attribute(
            _METADATA_ATTRIBUTE_TAGS,
            value=tags,
            model_name=model_name,
            model_version=model_version,
        )
893
-
894
def _get_deployment(self, *, model_name: str, model_version: str, deployment_name: str) -> snowpark.Row:
    """Fetch the row of the permanent deployment view for a named deployment of a model.

    Args:
        model_name: Model Name string.
        model_version: Model Version string.
        deployment_name: Name of the deployment to look up.

    Returns:
        The matching row of the permanent deployment view.

    Raises:
        KeyError: No deployment with that name exists for the model.
    """
    statement_params = self._get_statement_params(inspect.currentframe())
    view_query = f"SELECT * FROM {self._fully_qualified_permanent_deployment_view_name()}"
    matches = (
        self._session.sql(view_query)
        .filter(snowpark.Column("DEPLOYMENT_NAME") == deployment_name)
        .filter(snowpark.Column("MODEL_NAME") == model_name)
        .filter(snowpark.Column("MODEL_VERSION") == model_version)
        .collect(statement_params=statement_params)
    )
    if not matches:
        raise KeyError(
            f"Unable to find deployment named {deployment_name} in the model {model_name}/{model_version}."
        )
    assert len(matches) == 1, "_get_deployment should return exactly 1 deployment"
    return cast(snowpark.Row, matches[0])
908
-
909
- # Registry operations
910
-
911
@telemetry.send_api_usage_telemetry(
    project=_TELEMETRY_PROJECT,
    subproject=_TELEMETRY_SUBPROJECT,
)
@snowpark._internal.utils.private_preview(version="0.2.0")
def list_models(self) -> snowpark.DataFrame:
    """List the models contained in the registry.

    Returns:
        A read-only snowpark.DataFrame over the registry's model view. Its "id" column uniquely identifies each
        model and can be used to reference the model in other operations.
    """
    view_path = f"{self._name}.{self._schema}.{self._registry_table_view}"
    # Deliberately not collected: the caller receives a lazily evaluated dataframe.
    return self._session.sql(f"SELECT * FROM {view_path}")
930
-
931
@telemetry.send_api_usage_telemetry(
    project=_TELEMETRY_PROJECT,
    subproject=_TELEMETRY_SUBPROJECT,
)
@snowpark._internal.utils.private_preview(version="0.2.0")
def set_tag(
    self,
    model_name: str,
    model_version: str,
    tag_name: str,
    tag_value: Optional[str] = None,
) -> None:
    """Set a tag on the model, overwriting any existing value for that tag.

    Args:
        model_name: Model Name string.
        model_version: Model Version string.
        tag_name: Desired tag name string.
        tag_value: (optional) New tag value string. If no value is given the value of the tag will be set to None.
    """
    # Read-modify-write of the whole tag dictionary; concurrent writers are not coordinated.
    # TODO(amauser): Investigate the use of transactions to avoid race conditions.
    updated_tags = {**self.get_tags(model_name=model_name, model_version=model_version), tag_name: tag_value}
    self._set_metadata_attribute(
        _METADATA_ATTRIBUTE_TAGS,
        updated_tags,
        model_name=model_name,
        model_version=model_version,
    )
963
-
964
@telemetry.send_api_usage_telemetry(
    project=_TELEMETRY_PROJECT,
    subproject=_TELEMETRY_SUBPROJECT,
)
@snowpark._internal.utils.private_preview(version="0.2.0")
def remove_tag(self, model_name: str, model_version: str, tag_name: str) -> None:
    """Delete a tag from the model.

    Args:
        model_name: Model Name string.
        model_version: Model Version string.
        tag_name: Name of the tag to delete.

    Raises:
        DataError: If the model does not have the requested tag.
    """
    # Read-modify-write of the whole tag dictionary.
    remaining_tags = self.get_tags(model_name=model_name, model_version=model_version)
    try:
        remaining_tags.pop(tag_name)
    except KeyError:
        raise connector.DataError(
            f"Model {model_name}/{model_version} has no tag named {tag_name}. Full list of tags: {remaining_tags}"
        )

    self._set_metadata_attribute(
        _METADATA_ATTRIBUTE_TAGS,
        remaining_tags,
        model_name=model_name,
        model_version=model_version,
    )
996
-
997
@telemetry.send_api_usage_telemetry(
    project=_TELEMETRY_PROJECT,
    subproject=_TELEMETRY_SUBPROJECT,
)
@snowpark._internal.utils.private_preview(version="0.2.0")
def has_tag(
    self,
    model_name: str,
    model_version: str,
    tag_name: str,
    tag_value: Optional[str] = None,
) -> bool:
    """Check whether the model carries a tag, optionally with a specific value.

    Args:
        model_name: Model Name string.
        model_version: Model Version string.
        tag_name: Desired tag name string.
        tag_value: (optional) Value the tag must have. When omitted, only the presence of the tag is checked;
            otherwise the stored value is compared against ``str(tag_value)``.

    Returns:
        True if the tag (and, when given, the value) matches, False otherwise.
    """
    tags = self.get_tags(model_name=model_name, model_version=model_version)
    if tag_name not in tags:
        return False
    return tag_value is None or tags[tag_name] == str(tag_value)
1028
-
1029
@telemetry.send_api_usage_telemetry(
    project=_TELEMETRY_PROJECT,
    subproject=_TELEMETRY_SUBPROJECT,
)
@snowpark._internal.utils.private_preview(version="0.2.0")
def get_tag_value(self, model_name: str, model_version: str, tag_name: str) -> Any:
    """Return the value stored for one tag of the model.

    Args:
        model_name: Model Name string.
        model_version: Model Version string.
        tag_name: Desired tag name string.

    Returns:
        The tag's value, which can be None when the tag was set without a value.

    Raises:
        KeyError: The model has no tag named ``tag_name``.
    """
    all_tags = self.get_tags(model_name=model_name, model_version=model_version)
    return all_tags[tag_name]
1048
-
1049
@telemetry.send_api_usage_telemetry(
    project=_TELEMETRY_PROJECT,
    subproject=_TELEMETRY_SUBPROJECT,
)
@snowpark._internal.utils.private_preview(version="0.2.0")
def get_tags(self, model_name: Optional[str] = None, model_version: Optional[str] = None) -> Dict[str, Any]:
    """Return every tag stored for the model, together with its value.

    Args:
        model_name: Model Name string.
        model_version: Model Version string.

    Returns:
        Dictionary mapping tag names to tag values; empty when the model has no tags.
    """
    raw_tags = self._get_metadata_attribute(
        _METADATA_ATTRIBUTE_TAGS, model_name=model_name, model_version=model_version
    )
    if not raw_tags:
        return {}
    # The dataframe hands back the tag dictionary as a JSON string, so decode it here.
    parsed: Dict[str, Optional[str]] = json.loads(raw_tags)
    return parsed
1075
-
1076
@telemetry.send_api_usage_telemetry(
    project=_TELEMETRY_PROJECT,
    subproject=_TELEMETRY_SUBPROJECT,
)
@snowpark._internal.utils.private_preview(version="0.2.0")
def get_model_description(self, model_name: str, model_version: str) -> Optional[str]:
    """Return the model's description.

    Args:
        model_name: Model Name string.
        model_version: Model Version string.

    Returns:
        The JSON-decoded description, or None when no description is stored.
    """
    raw_description = self._get_metadata_attribute(
        _METADATA_ATTRIBUTE_DESCRIPTION,
        model_name=model_name,
        model_version=model_version,
    )
    if raw_description is None:
        return None
    return json.loads(raw_description)
1097
-
1098
@telemetry.send_api_usage_telemetry(
    project=_TELEMETRY_PROJECT,
    subproject=_TELEMETRY_SUBPROJECT,
)
@snowpark._internal.utils.private_preview(version="0.2.0")
def set_model_description(
    self,
    model_name: str,
    model_version: str,
    description: str,
) -> None:
    """Replace the model's description by recording a new metadata event.

    Args:
        model_name: Model Name string.
        model_version: Model Version string.
        description: Desired new model description.
    """
    self._set_metadata_attribute(
        attribute=_METADATA_ATTRIBUTE_DESCRIPTION,
        value=description,
        model_name=model_name,
        model_version=model_version,
    )
1122
-
1123
@telemetry.send_api_usage_telemetry(
    project=_TELEMETRY_PROJECT,
    subproject=_TELEMETRY_SUBPROJECT,
)
@snowpark._internal.utils.private_preview(version="0.2.0")
def get_history(self) -> snowpark.DataFrame:
    """Return the history of operations recorded in the registry's metadata table.

    The dataframe is ordered by ``EVENT_TIMESTAMP`` and can be filtered further by the caller.

    Returns:
        snowpark.DataFrame with the registry's operation history.
    """
    history_columns = (
        "EVENT_TIMESTAMP",
        "EVENT_ID",
        "MODEL_ID",
        "ROLE",
        "OPERATION",
        "ATTRIBUTE_NAME",
        "VALUE[ATTRIBUTE_NAME]",
    )
    ordered_events = self._session.table(self._fully_qualified_metadata_table_name()).order_by("EVENT_TIMESTAMP")
    return cast(snowpark.DataFrame, ordered_events.select_expr(*history_columns))
1150
-
1151
- @telemetry.send_api_usage_telemetry(
1152
- project=_TELEMETRY_PROJECT,
1153
- subproject=_TELEMETRY_SUBPROJECT,
1154
- )
1155
- @snowpark._internal.utils.private_preview(version="0.2.0")
1156
- def get_model_history(
1157
- self,
1158
- model_name: str,
1159
- model_version: str,
1160
- ) -> snowpark.DataFrame:
1161
- """Return a dataframe with the history of operations performed on the desired model.
1162
-
1163
- The returned dataframe is order by time and can be filtered further.
1164
-
1165
- Args:
1166
- model_name: Model Name string.
1167
- model_version: Model Version string.
1168
-
1169
- Returns:
1170
- snowpark.DataFrame with the history of the model.
1171
- """
1172
- id = self._get_model_id(model_name=model_name, model_version=model_version)
1173
- return cast(
1174
- snowpark.DataFrame,
1175
- self.get_history().filter(snowpark.Column("MODEL_ID") == id),
1176
- )
1177
-
1178
- @telemetry.send_api_usage_telemetry(
1179
- project=_TELEMETRY_PROJECT,
1180
- subproject=_TELEMETRY_SUBPROJECT,
1181
- )
1182
- @snowpark._internal.utils.private_preview(version="0.2.0")
1183
- def set_metric(
1184
- self,
1185
- model_name: str,
1186
- model_version: str,
1187
- metric_name: str,
1188
- metric_value: object,
1189
- ) -> None:
1190
- """Set scalar model metric to value.
1191
-
1192
- If a metric with that name already exists for the model, the metric value will be overwritten.
1193
-
1194
- Args:
1195
- model_name: Model Name string.
1196
- model_version: Model Version string.
1197
- metric_name: Desired metric name.
1198
- metric_value: New metric value.
1199
- """
1200
- # This method uses a read-modify-write pattern for setting tags.
1201
- # TODO(amauser): Investigate the use of transactions to avoid race conditions.
1202
- model_metrics = self.get_metrics(model_name=model_name, model_version=model_version)
1203
- model_metrics[metric_name] = metric_value
1204
- self._set_metadata_attribute(
1205
- _METADATA_ATTRIBUTE_METRICS,
1206
- model_metrics,
1207
- model_name=model_name,
1208
- model_version=model_version,
1209
- )
1210
-
1211
- @telemetry.send_api_usage_telemetry(
1212
- project=_TELEMETRY_PROJECT,
1213
- subproject=_TELEMETRY_SUBPROJECT,
1214
- )
1215
- @snowpark._internal.utils.private_preview(version="0.2.0")
1216
- def remove_metric(
1217
- self,
1218
- model_name: str,
1219
- model_version: str,
1220
- metric_name: str,
1221
- ) -> None:
1222
- """Remove a specific metric entry from the model.
1223
-
1224
- Args:
1225
- model_name: Model Name string.
1226
- model_version: Model Version string.
1227
- metric_name: Desired metric name.
1228
-
1229
- Raises:
1230
- DataError: If the model does not have the requested metric.
1231
- """
1232
- # This method uses a read-modify-write pattern for updating tags.
1233
-
1234
- model_metrics = self.get_metrics(model_name=model_name, model_version=model_version)
1235
- try:
1236
- del model_metrics[metric_name]
1237
- except KeyError:
1238
- raise connector.DataError(
1239
- f"Model {model_name}/{model_version} has no metric named {metric_name}. "
1240
- f"Full list of metrics: {model_metrics}"
1241
- )
1242
-
1243
- self._set_metadata_attribute(
1244
- _METADATA_ATTRIBUTE_METRICS,
1245
- model_metrics,
1246
- model_name=model_name,
1247
- model_version=model_version,
1248
- )
1249
-
1250
- @telemetry.send_api_usage_telemetry(
1251
- project=_TELEMETRY_PROJECT,
1252
- subproject=_TELEMETRY_SUBPROJECT,
1253
- )
1254
- @snowpark._internal.utils.private_preview(version="0.2.0")
1255
- def has_metric(self, model_name: str, model_version: str, metric_name: str) -> bool:
1256
- """Check if a model has a metric with the given name.
1257
-
1258
- Args:
1259
- model_name: Model Name string.
1260
- model_version: Model Version string.
1261
- metric_name: Desired metric name.
1262
-
1263
- Returns:
1264
- True if the metric is present for the model with the given id, False otherwise.
1265
- """
1266
- metrics = self.get_metrics(model_name=model_name, model_version=model_version)
1267
- return metric_name in metrics
1268
-
1269
- @telemetry.send_api_usage_telemetry(
1270
- project=_TELEMETRY_PROJECT,
1271
- subproject=_TELEMETRY_SUBPROJECT,
1272
- )
1273
- @snowpark._internal.utils.private_preview(version="0.2.0")
1274
- def get_metric_value(self, model_name: str, model_version: str, metric_name: str) -> object:
1275
- """Return the value of the given metric for the model.
1276
-
1277
- The returned value can be None. If the metric does not exist, KeyError will be raised.
1278
-
1279
- Args:
1280
- model_name: Model Name string.
1281
- model_version: Model Version string.
1282
- metric_name: Desired metric name.
1283
-
1284
- Returns:
1285
- Value of the metric. Can be None if the metric was set to None.
1286
- """
1287
- return self.get_metrics(model_name=model_name, model_version=model_version)[metric_name]
1288
-
1289
- @telemetry.send_api_usage_telemetry(
1290
- project=_TELEMETRY_PROJECT,
1291
- subproject=_TELEMETRY_SUBPROJECT,
1292
- )
1293
- @snowpark._internal.utils.private_preview(version="0.2.0")
1294
- def get_metrics(self, model_name: str, model_version: str) -> Dict[str, object]:
1295
- """Get all metrics and values stored for the given model.
1296
-
1297
- Args:
1298
- model_name: Model Name string.
1299
- model_version: Model Version string.
1300
-
1301
- Returns:
1302
- String-to-float dictionary containing all metrics and values. The resulting dictionary can be empty.
1303
- """
1304
- # Snowpark snowpark.dataframe returns dictionary objects as strings. We need to convert it back to a dictionary
1305
- # here.
1306
- result = self._get_metadata_attribute(
1307
- _METADATA_ATTRIBUTE_METRICS,
1308
- model_name=model_name,
1309
- model_version=model_version,
1310
- )
1311
-
1312
- if result:
1313
- ret: Dict[str, object] = json.loads(result)
1314
- return ret
1315
- else:
1316
- return dict()
1317
-
1318
- # Combined Registry and Repository operations.
1319
- @telemetry.send_api_usage_telemetry(
1320
- project=_TELEMETRY_PROJECT,
1321
- subproject=_TELEMETRY_SUBPROJECT,
1322
- )
1323
- @snowpark._internal.utils.private_preview(version="0.2.0")
1324
- def log_model(
1325
- self,
1326
- model_name: str,
1327
- model_version: str,
1328
- *,
1329
- model: Any,
1330
- description: Optional[str] = None,
1331
- tags: Optional[Dict[str, str]] = None,
1332
- conda_dependencies: Optional[List[str]] = None,
1333
- pip_requirements: Optional[List[str]] = None,
1334
- signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
1335
- sample_input_data: Optional[Any] = None,
1336
- code_paths: Optional[List[str]] = None,
1337
- options: Optional[model_types.BaseModelSaveOption] = None,
1338
- ) -> Optional["ModelReference"]:
1339
- """Uploads and register a model to the Model Registry.
1340
-
1341
- Args:
1342
- model_name: The given name for the model. The combination (name + version) must be unique for each model.
1343
- model_version: Version string to be set for the model. The combination (name + version) must be unique for
1344
- each model.
1345
- model: Local model object in a supported format.
1346
- description: A description for the model. The description can be changed later.
1347
- tags: string-to-string dictionary of tag names and values to be set for the model.
1348
- conda_dependencies: List of Conda package specs. Use "[channel::]package [operator version]" syntax to
1349
- specify a dependency. It is a recommended way to specify your dependencies using conda. When channel is
1350
- not specified, defaults channel will be used. When deploying to Snowflake Warehouse, defaults channel
1351
- would be replaced with the Snowflake Anaconda channel.
1352
- pip_requirements: List of PIP package specs. Model will not be able to deploy to the warehouse if there is
1353
- pip requirements.
1354
- signatures: Signatures of the model, which is a mapping from target method name to signatures of input and
1355
- output, which could be inferred by calling `infer_signature` method with sample input data.
1356
- sample_input_data: Sample of the input data for the model.
1357
- code_paths: Directory of code to import when loading and deploying the model.
1358
- options: Additional options when saving the model.
1359
-
1360
- Raises:
1361
- DataError: Raised when:
1362
- 1) the given model already exists;
1363
- ValueError: Raised when: # noqa: DAR402
1364
- 1) Signatures and sample_input_data are both not provided and model is not a
1365
- snowflake estimator.
1366
- Exception: Raised when there is any error raised when saving the model.
1367
-
1368
- Returns:
1369
- Model Reference . None if failed.
1370
- """
1371
- # Ideally, the whole operation should be a single transaction. Currently, transactions do not support stage
1372
- # operations.
1373
-
1374
- statement_params = self._get_statement_params(inspect.currentframe())
1375
- self._svm.validate_schema_version(statement_params)
1376
-
1377
- self._model_identifier_is_nonempty_or_raise(model_name, model_version)
1378
-
1379
- existing_model_nums = self._list_selected_models(model_name=model_name, model_version=model_version).count()
1380
- if existing_model_nums:
1381
- raise connector.DataError(f"Model {model_name}/{model_version} already exists. Unable to log the model.")
1382
- model_id, fully_qualified_model_stage_name = self._log_model_path(
1383
- model_name=model_name,
1384
- model_version=model_version,
1385
- )
1386
- stage_path = f"{_STAGE_PREFIX}{fully_qualified_model_stage_name}"
1387
- model = cast(model_types.SupportedModelType, model)
1388
- try:
1389
- model_composer = model_api.save_model( # type: ignore[call-overload, misc]
1390
- name=model_name,
1391
- session=self._session,
1392
- stage_path=stage_path,
1393
- model=model,
1394
- signatures=signatures,
1395
- metadata=tags,
1396
- conda_dependencies=conda_dependencies,
1397
- pip_requirements=pip_requirements,
1398
- sample_input_data=sample_input_data,
1399
- code_paths=code_paths,
1400
- options=options,
1401
- )
1402
- except Exception:
1403
- # When model saving fails, clean up the model stage.
1404
- query_result_checker.SqlResultValidator(
1405
- self._session, f"DROP STAGE {fully_qualified_model_stage_name}"
1406
- ).has_dimensions(expected_rows=1, expected_cols=1).validate()
1407
- raise
1408
-
1409
- self._register_model_with_id(
1410
- model_name=model_name,
1411
- model_version=model_version,
1412
- model_id=model_id,
1413
- type=model_composer.packager.meta.model_type,
1414
- uri=uri.get_uri_from_snowflake_stage_path(stage_path),
1415
- description=description,
1416
- tags=tags,
1417
- )
1418
-
1419
- return ModelReference(registry=self, model_name=model_name, model_version=model_version)
1420
-
1421
- @telemetry.send_api_usage_telemetry(
1422
- project=_TELEMETRY_PROJECT,
1423
- subproject=_TELEMETRY_SUBPROJECT,
1424
- )
1425
- @snowpark._internal.utils.private_preview(version="0.2.0")
1426
- def load_model(self, model_name: str, model_version: str) -> Any:
1427
- """Loads the model with the given (model_name + model_version) from the registry into memory.
1428
-
1429
- Args:
1430
- model_name: Model Name string.
1431
- model_version: Model Version string.
1432
-
1433
- Returns:
1434
- Restored model object.
1435
- """
1436
- warnings.warn(
1437
- (
1438
- "Please use with caution: "
1439
- "Using `load_model` method requires you to have the EXACT same Python environments "
1440
- "as the one when you logged the model. Any differences will potentially lead to errors.\n"
1441
- "Also, if your model contains custom code imported using `code_paths` argument when logging, "
1442
- "they will be added to your `sys.path`. It might lead to unexpected module importing issues. "
1443
- "If you run into such kind of problems, you need to restart your Python or Notebook kernel."
1444
- ),
1445
- category=UserWarning,
1446
- stacklevel=2,
1447
- )
1448
- remote_model_path = self._get_model_path(model_name=model_name, model_version=model_version)
1449
- restored_model = None
1450
-
1451
- restored_model = model_api.load_model(session=self._session, stage_path=remote_model_path)
1452
-
1453
- return restored_model.packager.model
1454
-
1455
- # Repository Operations
1456
-
1457
- @telemetry.send_api_usage_telemetry(
1458
- project=_TELEMETRY_PROJECT,
1459
- subproject=_TELEMETRY_SUBPROJECT,
1460
- )
1461
- @snowpark._internal.utils.private_preview(version="0.2.0")
1462
- def deploy(
1463
- self,
1464
- model_name: str,
1465
- model_version: str,
1466
- *,
1467
- deployment_name: str,
1468
- target_method: Optional[str] = None,
1469
- permanent: bool = False,
1470
- platform: deploy_platforms.TargetPlatform = deploy_platforms.TargetPlatform.WAREHOUSE,
1471
- options: Optional[
1472
- Union[
1473
- model_types.WarehouseDeployOptions,
1474
- model_types.SnowparkContainerServiceDeployOptions,
1475
- ]
1476
- ] = None,
1477
- ) -> model_types.Deployment:
1478
- """Deploy the model with the given deployment name.
1479
-
1480
- Args:
1481
- model_name: Model Name string.
1482
- model_version: Model Version string.
1483
- deployment_name: name of the generated UDF.
1484
- target_method: The method name to use in deployment. Can be omitted if only 1 method in the model.
1485
- permanent: Whether the deployment is permanent or not. Permanent deployment will generate a permanent UDF.
1486
- (Only applicable for Warehouse deployment)
1487
- platform: Target platform to deploy the model to. Currently supported platforms are defined as enum in
1488
- `snowflake.ml.model.deploy_platforms.TargetPlatform`
1489
- options: Optional options for model deployment. Defaults to None.
1490
-
1491
- Returns:
1492
- Deployment info.
1493
-
1494
- Raises:
1495
- RuntimeError: Raised when parameters are not properly enabled when deploying to Warehouse with temporary UDF
1496
- RuntimeError: Raised when deploying to SPCS with db/schema that starts with underscore.
1497
- """
1498
- statement_params = self._get_statement_params(inspect.currentframe())
1499
- self._svm.validate_schema_version(statement_params)
1500
-
1501
- if options is None:
1502
- options = {}
1503
-
1504
- deployment_stage_path = ""
1505
-
1506
- if platform == deploy_platforms.TargetPlatform.SNOWPARK_CONTAINER_SERVICES:
1507
- if self._name.startswith("_") or self._schema.startswith("_"):
1508
- error_message = """\
1509
- Model deployment to Snowpark Container Service does not support a database/schema name that starts with
1510
- an underscore. Please ensure you pass in a valid db/schema name when initializing the registry with:
1511
-
1512
- model_registry.create_model_registry(
1513
- session=session,
1514
- database_name=db,
1515
- schema_name=schema
1516
- )
1517
-
1518
- registry = model_registry.ModelRegistry(
1519
- session=session,
1520
- database_name=db,
1521
- schema_name=schema
1522
- )
1523
- """
1524
- raise RuntimeError(textwrap.dedent(error_message))
1525
- permanent = True
1526
- options = cast(model_types.SnowparkContainerServiceDeployOptions, options)
1527
- deployment_stage_path = f"{self._prepare_deployment_stage()}/{deployment_name}/"
1528
- elif platform == deploy_platforms.TargetPlatform.WAREHOUSE:
1529
- options = cast(model_types.WarehouseDeployOptions, options)
1530
- if permanent:
1531
- # Every deployment-generated UDF should reside in its own unique directory. As long as each deployment
1532
- # is allocated a distinct directory, multiple deployments can coexist within the same stage.
1533
- # Given that each permanent deployment possesses a unique deployment_name, sharing the same stage does
1534
- # not present any issues
1535
- deployment_stage_path = (
1536
- options.get("permanent_udf_stage_location")
1537
- or f"{self._prepare_deployment_stage()}/{deployment_name}/"
1538
- )
1539
- options["permanent_udf_stage_location"] = deployment_stage_path
1540
-
1541
- remote_model_path = self._get_model_path(model_name=model_name, model_version=model_version)
1542
- model_id = self._get_model_id(model_name, model_version)
1543
-
1544
- # https://snowflakecomputing.atlassian.net/browse/SNOW-858376
1545
- # During temporary deployment on the Warehouse, Snowpark creates an unencrypted temporary stage for UDF-related
1546
- # artifacts. However, UDF generation fails when importing from a mix of encrypted and unencrypted stages.
1547
- # The following workaround copies model between stages (PrPr as of July 7th, 2023) to transfer the SSE
1548
- # encrypted model zip from model stage to the temporary unencrypted stage.
1549
- if not permanent and platform == deploy_platforms.TargetPlatform.WAREHOUSE:
1550
- schema = self._fully_qualified_schema_name()
1551
- unencrypted_stage = (
1552
- f"@{schema}.{snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)}"
1553
- )
1554
- self._session.sql(f"CREATE TEMPORARY STAGE {unencrypted_stage[1:]}").collect()
1555
- try:
1556
- self._session.sql(f"COPY FILES INTO {unencrypted_stage} from {remote_model_path}").collect()
1557
- except Exception:
1558
- raise RuntimeError(
1559
- "Temporary deployment to the warehouse is currently not supported. Please use "
1560
- "permanent deployment by setting the 'permanent' parameter to True"
1561
- )
1562
- remote_model_path = unencrypted_stage
1563
-
1564
- # Step 1: Deploy to get the UDF
1565
- deployment_info = model_api.deploy(
1566
- session=self._session,
1567
- name=self._fully_qualified_deployment_name(deployment_name),
1568
- platform=platform,
1569
- target_method=target_method,
1570
- stage_path=remote_model_path,
1571
- deployment_stage_path=deployment_stage_path,
1572
- model_id=model_id,
1573
- options=options,
1574
- )
1575
-
1576
- # Step 2: Record the deployment
1577
-
1578
- # Assert to convince mypy.
1579
- assert deployment_info
1580
- if permanent:
1581
- self._insert_deployment_entry(
1582
- id=model_id,
1583
- name=deployment_name,
1584
- platform=deployment_info["platform"].value,
1585
- stage_path=deployment_stage_path,
1586
- signature=deployment_info["signature"].to_dict(),
1587
- target_method=deployment_info["target_method"],
1588
- options=options,
1589
- )
1590
-
1591
- self._set_metadata_attribute(
1592
- _METADATA_ATTRIBUTE_DEPLOYMENT,
1593
- {"name": deployment_name, "permanent": permanent},
1594
- id=model_id,
1595
- operation=_ADD_METADATA_OPERATION,
1596
- )
1597
-
1598
- # Store temporary deployment information in the in-memory cache. This allows for future referencing and
1599
- # tracking of its availability status.
1600
- if not permanent:
1601
- self._temporary_deployments[deployment_name] = deployment_info
1602
-
1603
- return deployment_info
1604
-
1605
- @telemetry.send_api_usage_telemetry(
1606
- project=_TELEMETRY_PROJECT,
1607
- subproject=_TELEMETRY_SUBPROJECT,
1608
- )
1609
- @snowpark._internal.utils.private_preview(version="1.0.1")
1610
- def list_deployments(self, model_name: str, model_version: str) -> snowpark.DataFrame:
1611
- """List all permanent deployments that originated from the given model.
1612
-
1613
- Temporary deployment info are currently not supported for listing.
1614
-
1615
- Args:
1616
- model_name: Model Name string.
1617
- model_version: Model Version string.
1618
-
1619
- Returns:
1620
- A snowpark dataframe that contains all deployments that associated with the given model.
1621
- """
1622
- deployments_df = (
1623
- self._session.sql(f"SELECT * FROM {self._fully_qualified_permanent_deployment_view_name()}")
1624
- .filter(snowpark.Column("MODEL_NAME") == model_name)
1625
- .filter(snowpark.Column("MODEL_VERSION") == model_version)
1626
- )
1627
- res = deployments_df.select(
1628
- deployments_df["MODEL_NAME"],
1629
- deployments_df["MODEL_VERSION"],
1630
- deployments_df["DEPLOYMENT_NAME"],
1631
- deployments_df["CREATION_TIME"],
1632
- deployments_df["TARGET_METHOD"],
1633
- deployments_df["TARGET_PLATFORM"],
1634
- deployments_df["SIGNATURE"],
1635
- deployments_df["OPTIONS"],
1636
- deployments_df["STAGE_PATH"],
1637
- deployments_df["ROLE"],
1638
- )
1639
- return cast(snowpark.DataFrame, res)
1640
-
1641
- @telemetry.send_api_usage_telemetry(
1642
- project=_TELEMETRY_PROJECT,
1643
- subproject=_TELEMETRY_SUBPROJECT,
1644
- )
1645
- @snowpark._internal.utils.private_preview(version="1.0.1")
1646
- def get_deployment(self, model_name: str, model_version: str, *, deployment_name: str) -> snowpark.DataFrame:
1647
- """Get the permanent deployment with target name of the given model.
1648
-
1649
- Temporary deployment info are currently not supported.
1650
-
1651
- Args:
1652
- model_name: Model Name string.
1653
- model_version: Model Version string.
1654
- deployment_name: Deployment name string.
1655
-
1656
- Returns:
1657
- A snowpark dataframe that contains the information of the target deployment.
1658
-
1659
- Raises:
1660
- KeyError: Raised if the target deployment is not found.
1661
- """
1662
- deployment = self.list_deployments(model_name, model_version).filter(
1663
- snowpark.Column("DEPLOYMENT_NAME") == deployment_name
1664
- )
1665
- if deployment.count() == 0:
1666
- raise KeyError(
1667
- f"Unable to find deployment named {deployment_name} in the model {model_name}/{model_version}."
1668
- )
1669
- return cast(snowpark.DataFrame, deployment)
1670
-
1671
- @telemetry.send_api_usage_telemetry(
1672
- project=_TELEMETRY_PROJECT,
1673
- subproject=_TELEMETRY_SUBPROJECT,
1674
- )
1675
- @snowpark._internal.utils.private_preview(version="1.0.1")
1676
- def delete_deployment(self, model_name: str, model_version: str, *, deployment_name: str) -> None:
1677
- """Delete the target permanent deployment of the given model.
1678
-
1679
- Deleting temporary deployment are currently not supported.
1680
- Temporary deployment will get cleaned automatically when the current session closed.
1681
-
1682
- Args:
1683
- model_name: Model Name string.
1684
- model_version: Model Version string.
1685
- deployment_name: Name of the deployment that is getting deleted.
1686
-
1687
- """
1688
- deployment = self._get_deployment(
1689
- model_name=model_name,
1690
- model_version=model_version,
1691
- deployment_name=deployment_name,
1692
- )
1693
-
1694
- # TODO(SNOW-759526): The following sequence should be a transaction.
1695
- # Step 1: Drop the UDF
1696
- self._session.sql(
1697
- f"DROP FUNCTION IF EXISTS {self._fully_qualified_deployment_name(deployment_name)}(OBJECT)"
1698
- ).collect()
1699
-
1700
- # Step 2: Remove the staged artifact
1701
- self._session.sql(f"REMOVE {deployment['STAGE_PATH']}").collect()
1702
-
1703
- # Step 3: Delete the deployment from the deployment table
1704
- query_result_checker.SqlResultValidator(
1705
- self._session,
1706
- f"""DELETE FROM {self._fully_qualified_deployment_table_name()}
1707
- WHERE MODEL_ID='{deployment['MODEL_ID']}' AND DEPLOYMENT_NAME='{deployment_name}'
1708
- """,
1709
- ).deletion_success(expected_num_rows=1).validate()
1710
-
1711
- # Step 4: Record the delete event
1712
- self._set_metadata_attribute(
1713
- _METADATA_ATTRIBUTE_DEPLOYMENT,
1714
- {"name": deployment_name},
1715
- id=deployment["MODEL_ID"],
1716
- operation=_DROP_METADATA_OPERATION,
1717
- )
1718
-
1719
- # Optional Step 5: Delete Snowpark container service.
1720
- if deployment["TARGET_PLATFORM"] == deploy_platforms.TargetPlatform.SNOWPARK_CONTAINER_SERVICES.value:
1721
- service_name = identifier.get_schema_level_object_identifier(
1722
- self._name, self._schema, f"service_{deployment['MODEL_ID']}"
1723
- )
1724
- spcs_attribution_utils.record_service_end(self._session, service_name)
1725
- query_result_checker.SqlResultValidator(
1726
- self._session,
1727
- f"DROP SERVICE IF EXISTS {service_name}",
1728
- ).validate()
1729
-
1730
- @telemetry.send_api_usage_telemetry(
1731
- project=_TELEMETRY_PROJECT,
1732
- subproject=_TELEMETRY_SUBPROJECT,
1733
- )
1734
- @snowpark._internal.utils.private_preview(version="0.2.0")
1735
- def delete_model(
1736
- self,
1737
- model_name: str,
1738
- model_version: str,
1739
- delete_artifact: bool = True,
1740
- ) -> None:
1741
- """Delete model with the given ID from the registry.
1742
-
1743
- The history of the model will still be preserved.
1744
-
1745
- Args:
1746
- model_name: Model Name string.
1747
- model_version: Model Version string.
1748
- delete_artifact: If True, the underlying model artifact will also be deleted, not just the entry in
1749
- the registry table.
1750
- """
1751
-
1752
- # Check that a model with the given ID exists and there is only one of them.
1753
- # TODO(amauser): The following sequence should be a transaction. Transactions currently cannot contain DDL
1754
- # statements.
1755
- model_info = None
1756
- selected_models = self._list_selected_models(model_name=model_name, model_version=model_version)
1757
- identifier = f"{model_name}/{model_version}"
1758
- model_info = self._validate_exact_one_result(selected_models, identifier)
1759
- id = model_info[0]["ID"]
1760
- model_uri = model_info[0]["URI"]
1761
-
1762
- # Step 1/3: Delete the registry entry.
1763
- query_result_checker.SqlResultValidator(
1764
- self._session,
1765
- f"DELETE FROM {self._fully_qualified_registry_table_name()} WHERE ID='{id}'",
1766
- ).deletion_success(expected_num_rows=1).validate()
1767
-
1768
- # Step 2/3: Delete the artifact (if desired).
1769
- if delete_artifact:
1770
- if uri.is_snowflake_stage_uri(model_uri):
1771
- stage_path = self._get_fully_qualified_stage_name_from_uri(model_uri)
1772
- query_result_checker.SqlResultValidator(self._session, f"DROP STAGE {stage_path}").has_dimensions(
1773
- expected_rows=1, expected_cols=1
1774
- ).validate()
1775
-
1776
- # Step 3/3: Record the deletion event.
1777
- self._set_metadata_attribute(
1778
- id=id,
1779
- attribute=_METADATA_ATTRIBUTE_DELETION,
1780
- value={"delete_artifact": True, "URI": model_uri},
1781
- enable_model_presence_check=False,
1782
- )
1783
-
1784
-
1785
- class ModelReference:
1786
- """Wrapper class for ModelReference objects that proxy model metadata operations."""
1787
-
1788
- def _remove_arg_from_docstring(self, arg: str, docstring: Optional[str]) -> Optional[str]:
1789
- """Remove the given parameter from a function docstring (Google convention)."""
1790
- if docstring is None:
1791
- return None
1792
- docstring_lines = docstring.split("\n")
1793
-
1794
- args_section_start = None
1795
- args_section_end = None
1796
- args_section_indent = None
1797
- arg_start = None
1798
- arg_end = None
1799
- arg_indent = None
1800
- for i in range(len(docstring_lines)):
1801
- line = docstring_lines[i]
1802
- lstrip_line = line.lstrip()
1803
- indent = len(line) - len(lstrip_line)
1804
-
1805
- if line.strip() == "Args:":
1806
- # Starting the Args section of the docstring (assuming Google-style).
1807
- args_section_start = i
1808
- # logging.info("TEST: args_section_start=" + str(args_section_start))
1809
- args_section_indent = indent
1810
- continue
1811
-
1812
- # logging.info("TEST: " + lstrip_line)
1813
- if args_section_start and lstrip_line.startswith(f"{arg}:"):
1814
- # This is the arg we are looking for.
1815
- arg_start = i
1816
- # logging.info("TEST: arg_start=" + str(arg_start))
1817
- arg_indent = indent
1818
- continue
1819
-
1820
- if arg_start and not arg_end and indent == arg_indent:
1821
- # We got the next arg, previous line was the last of the cut out arg docstring
1822
- # and we do have other args. Saving arg_end for python slice/range notation.
1823
- arg_end = i
1824
- continue
1825
-
1826
- if arg_start and (len(lstrip_line) == 0 or indent == args_section_indent):
1827
- # Arg section ends.
1828
- args_section_end = i
1829
- arg_end = arg_end if arg_end else i
1830
- # We have learned everything we need to know, no need to continue.
1831
- break
1832
-
1833
- if arg_start and not arg_end:
1834
- arg_end = len(docstring_lines)
1835
-
1836
- if args_section_start and not args_section_end:
1837
- args_section_end = len(docstring_lines)
1838
-
1839
- # Determine which lines from the "Args:" section of the docstring to skip or if we
1840
- # should skip the entire section.
1841
- keep_lines = set(range(len(docstring_lines)))
1842
- if args_section_start:
1843
- if arg_start == args_section_start + 1 and arg_end == args_section_end:
1844
- # Removed arg was the only arg, remove the entire section.
1845
- assert args_section_end
1846
- keep_lines.difference_update(range(args_section_start, args_section_end))
1847
- else:
1848
- # Just remove the arg.
1849
- assert arg_start
1850
- assert arg_end
1851
- keep_lines.difference_update(range(arg_start, arg_end))
1852
-
1853
- return "\n".join([docstring_lines[i] for i in sorted(keep_lines)])
1854
-
1855
- def __init__(
1856
- self,
1857
- *,
1858
- registry: ModelRegistry,
1859
- model_name: str,
1860
- model_version: str,
1861
- ) -> None:
1862
- self._registry = registry
1863
- self._id = registry._get_model_id(model_name=model_name, model_version=model_version)
1864
- self._model_name = model_name
1865
- self._model_version = model_version
1866
-
1867
- # Wrap all functions of the ModelRegistry that have an "id" parameter and bind that parameter
1868
- # the the "_id" member of this class.
1869
- if hasattr(self.__class__, "init_complete"):
1870
- # Already did the generation of wrapped method.
1871
- return
1872
-
1873
- for name, obj in self._registry.__class__.__dict__.items():
1874
- if (
1875
- not inspect.isfunction(obj)
1876
- or "model_name" not in inspect.signature(obj).parameters
1877
- or "model_version" not in inspect.signature(obj).parameters
1878
- ):
1879
- continue
1880
-
1881
- # Ensure that we are not silently overwriting existing functions.
1882
- assert not hasattr(self.__class__, name)
1883
-
1884
- def build_method(m: Callable[..., Any]) -> Callable[..., Any]:
1885
- return lambda self, *args, **kwargs: m(
1886
- self._registry,
1887
- self._model_name,
1888
- self._model_version,
1889
- *args,
1890
- **kwargs,
1891
- )
1892
-
1893
- method = build_method(m=obj)
1894
- setattr(self.__class__, name, method)
1895
-
1896
- docstring = self._remove_arg_from_docstring("model_name", obj.__doc__)
1897
- if docstring and "model_version" in docstring:
1898
- docstring = self._remove_arg_from_docstring("model_version", docstring)
1899
- setattr(self.__class__.__dict__[name], "__doc__", docstring) # noqa: B010
1900
-
1901
- setattr(self.__class__, "init_complete", True) # noqa: B010
1902
-
1903
- @telemetry.send_api_usage_telemetry(
1904
- project=_TELEMETRY_PROJECT,
1905
- subproject=_TELEMETRY_SUBPROJECT,
1906
- )
1907
- def get_name(self) -> str:
1908
- return self._model_name
1909
-
1910
- @telemetry.send_api_usage_telemetry(
1911
- project=_TELEMETRY_PROJECT,
1912
- subproject=_TELEMETRY_SUBPROJECT,
1913
- )
1914
- def get_version(self) -> str:
1915
- return self._model_version
1916
-
1917
@telemetry.send_api_usage_telemetry(
    project=_TELEMETRY_PROJECT,
    subproject=_TELEMETRY_SUBPROJECT,
)
@snowpark._internal.utils.private_preview(version="0.2.0")
def predict(self, deployment_name: str, data: Any) -> "pd.DataFrame":
    """Run prediction through a model deployment that lives in Snowflake.

    The deployment is first looked up in the registry's in-memory cache of
    temporary deployments. On a cache miss, it is rebuilt from the row returned
    by ``get_deployment`` for this model name/version.

    Args:
        deployment_name: name of the generated UDF.
        data: Data to run predict.

    Raises:
        ValueError: No deployment with the given name is recorded for this
            model version, or the recorded target platform has no known
            deploy-options type.

    Returns:
        A dataframe containing the result of prediction.
    """
    registry = self._registry

    # Temporary deployments are only tracked in the local in-memory cache.
    deployment_info = registry._temporary_deployments.get(deployment_name)

    # Must be computed in this frame: the telemetry function name is derived
    # from the current frame.
    statement_params = telemetry.get_function_usage_statement_params(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
        function_name=telemetry.get_statement_params_full_func_name(
            inspect.currentframe(), self.__class__.__name__
        ),
    )
    registry._svm.validate_schema_version(statement_params)

    if not deployment_info:
        # Cache miss: reconstruct the deployment description from the remote
        # deployment table. (Mypy requires calling this through the registry.)
        rows = registry.get_deployment(
            self._model_name, self._model_version, deployment_name=deployment_name
        ).collect(statement_params=statement_params)
        if not rows:
            raise ValueError(f"The deployment with name {deployment_name} haven't been deployed")
        row = rows[0]

        target_platform = deploy_platforms.TargetPlatform(row["TARGET_PLATFORM"])
        target_method = row["TARGET_METHOD"]
        signature = model_signature.ModelSignature.from_dict(json.loads(row["SIGNATURE"]))
        raw_options = cast(Dict[str, Any], json.loads(row["OPTIONS"]))

        # Maps each supported platform to the options type used to wrap the stored options.
        options_type_by_platform = {
            deploy_platforms.TargetPlatform.WAREHOUSE: model_types.WarehouseDeployOptions,
            deploy_platforms.TargetPlatform.SNOWPARK_CONTAINER_SERVICES: (
                model_types.SnowparkContainerServiceDeployOptions
            ),
        }
        options_type = options_type_by_platform.get(target_platform)
        if options_type is None:
            raise ValueError(f"Unsupported target Platform: {target_platform}")

        deployment_info = model_types.Deployment(
            name=registry._fully_qualified_deployment_name(deployment_name),
            platform=target_platform,
            target_method=target_method,
            signature=signature,
            options=options_type(raw_options),
        )

    return model_api.predict(
        session=registry._session,
        deployment=deployment_info,
        X=data,
        statement_params=statement_params,
    )
1991
-
1992
-
1993
@telemetry.send_api_usage_telemetry(
    project=_TELEMETRY_PROJECT,
    subproject=_TELEMETRY_SUBPROJECT,
)
@snowpark._internal.utils.private_preview(version="0.2.0")
def create_model_registry(
    *,
    session: snowpark.Session,
    database_name: str = _DEFAULT_REGISTRY_NAME,
    schema_name: str = _DEFAULT_SCHEMA_NAME,
) -> bool:
    """Setup a new model registry. This should be run once per model registry by an administrator role.

    The function performs these steps in order:

    1. Create the registry database and schema.
    2. Create the initial registry tables if the deployed schema version is
       still the initial version. This step is idempotent.
    3. Ask the schema version manager to upgrade the registry schema.

    Unless running inside a stored procedure, the session's current database
    and schema are restored afterwards. This restore also runs when one of the
    steps raises.

    Args:
        session: Session object to communicate with Snowflake.
        database_name: Desired name of the model registry database.
        schema_name: Desired name of the schema used by this model registry inside the database.

    Returns:
        Always True. A failure during setup is not reported as False; the
        exception raised by the failing step propagates to the caller.
    """
    # Remember the session's current db & schema so they can be restored below.
    old_db = session.get_current_database()
    old_schema = session.get_current_schema()

    # Normalize the user-supplied names via identifier.get_inferred_name
    # (presumably resolving quoting/casing; see the identifier module).
    database_name = identifier.get_inferred_name(database_name)
    schema_name = identifier.get_inferred_name(schema_name)

    statement_params = telemetry.get_function_usage_statement_params(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
        function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), ""),
    )
    try:
        _create_registry_database(session, database_name, statement_params)
        _create_registry_schema(session, database_name, schema_name, statement_params)

        svm = _schema_version_manager.SchemaVersionManager(session, database_name, schema_name)
        deployed_schema_version = svm.get_deployed_version(statement_params)
        if deployed_schema_version == _initial_schema._INITIAL_VERSION:
            # We cannot tell whether the registry is being created for the first time,
            # so start by creating the initial schema, which is idempotent anyway.
            _initial_schema.create_initial_registry_tables(session, database_name, schema_name, statement_params)

        svm.try_upgrade(statement_params)

    finally:
        # The db/schema are only restored outside stored procedures.
        if not snowpark_utils.is_in_stored_procedure():  # type: ignore[no-untyped-call]
            # Restore the db & schema to the original ones. The database is
            # restored first so that use_schema resolves against it.
            if old_db is not None and old_db != session.get_current_database():
                session.use_database(old_db)
            if old_schema is not None and old_schema != session.get_current_schema():
                session.use_schema(old_schema)
    return True