snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
@@ -0,0 +1,163 @@
1
+ from types import ModuleType
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ import pandas as pd
5
+ from absl.logging import logging
6
+
7
+ from snowflake.ml._internal.utils import sql_identifier
8
+ from snowflake.ml.model import model_signature, type_hints as model_types
9
+ from snowflake.ml.model._client.model import model_impl, model_version_impl
10
+ from snowflake.ml.model._client.ops import metadata_ops, model_ops
11
+ from snowflake.ml.model._model_composer import model_composer
12
+ from snowflake.snowpark import session
13
+
14
logger = logging.getLogger(__name__)


class ModelManager:
    """Model-level operations for a registry located in a single database and schema.

    Every operation is delegated to a ``model_ops.ModelOperator`` bound to the session,
    database name, and schema name supplied at construction time.
    """

    def __init__(
        self,
        session: session.Session,
        *,
        database_name: sql_identifier.SqlIdentifier,
        schema_name: sql_identifier.SqlIdentifier,
    ) -> None:
        """Bind the manager to a Snowpark session and a database/schema location.

        Args:
            session: Snowpark session handed to the underlying model operator.
            database_name: Database that holds the registry's models.
            schema_name: Schema that holds the registry's models.
        """
        self._database_name = database_name
        self._schema_name = schema_name
        operator = model_ops.ModelOperator(
            session,
            database_name=database_name,
            schema_name=schema_name,
        )
        self._model_ops = operator

    def log_model(
        self,
        model: model_types.SupportedModelType,
        *,
        model_name: str,
        version_name: str,
        comment: Optional[str] = None,
        metrics: Optional[Dict[str, Any]] = None,
        conda_dependencies: Optional[List[str]] = None,
        pip_requirements: Optional[List[str]] = None,
        python_version: Optional[str] = None,
        signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
        sample_input_data: Optional[model_types.SupportedDataType] = None,
        code_paths: Optional[List[str]] = None,
        ext_modules: Optional[List[ModuleType]] = None,
        options: Optional[model_types.ModelSaveOption] = None,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> model_version_impl.ModelVersion:
        """Package a model, upload it to a stage, and register it as a new model version.

        Args:
            model: The model object to package.
            model_name: Name of the model; converted to a ``SqlIdentifier``.
            version_name: Name of the version; converted to a ``SqlIdentifier``.
            comment: Optional comment applied to the new version when truthy.
            metrics: Optional metrics saved as version metadata when non-empty.
            conda_dependencies: Conda dependency specs forwarded to the composer.
            pip_requirements: Pip requirement specs forwarded to the composer.
            python_version: Python version forwarded to the composer.
            signatures: Method signatures forwarded to the composer.
            sample_input_data: Sample input forwarded to the composer as ``sample_input``.
            code_paths: Code paths forwarded to the composer.
            ext_modules: External modules forwarded to the composer.
            options: Model save options forwarded to the composer.
            statement_params: Statement parameters attached to the issued operations.

        Returns:
            A ``ModelVersion`` reference to the newly created version.

        Raises:
            ValueError: If the model already exists and already has a version with this name.
        """
        ops = self._model_ops
        model_id = sql_identifier.SqlIdentifier(model_name)
        version_id = sql_identifier.SqlIdentifier(version_name)

        # The version-level lookup is only issued when the model itself exists.
        model_exists = ops.validate_existence(model_name=model_id, statement_params=statement_params)
        if model_exists and ops.validate_existence(
            model_name=model_id, version_name=version_id, statement_params=statement_params
        ):
            raise ValueError(f"Model {model_name} version {version_name} already existed.")

        stage_path = ops.prepare_model_stage_path(statement_params=statement_params)

        logger.info(
            "Start packaging and uploading your model. "
            "It might take some time based on the size of the model."
        )

        composer = model_composer.ModelComposer(ops._session, stage_path=stage_path)
        composer.save(
            name=model_id.resolved(),
            model=model,
            signatures=signatures,
            sample_input=sample_input_data,
            conda_dependencies=conda_dependencies,
            pip_requirements=pip_requirements,
            python_version=python_version,
            code_paths=code_paths,
            ext_modules=ext_modules,
            options=options,
        )

        logger.info("Start creating MODEL object for you in the Snowflake.")

        ops.create_from_stage(
            composed_model=composer,
            model_name=model_id,
            version_name=version_id,
            statement_params=statement_params,
        )

        version_ref = model_version_impl.ModelVersion._ref(ops, model_name=model_id, version_name=version_id)

        # Only truthy values are applied: None and "" leave the comment untouched.
        if comment:
            version_ref.comment = comment

        # Metrics go through the operator's metadata ops; None or an empty dict is skipped.
        if metrics:
            ops._metadata_ops.save(
                metadata_ops.ModelVersionMetadataSchema(metrics=metrics),
                model_name=model_id,
                version_name=version_id,
                statement_params=statement_params,
            )

        return version_ref

    def get_model(
        self,
        model_name: str,
        *,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> model_impl.Model:
        """Return a reference to an existing model.

        Args:
            model_name: Name of the model; converted to a ``SqlIdentifier``.
            statement_params: Statement parameters attached to the existence check.

        Returns:
            A ``Model`` reference bound to this manager's model operator.

        Raises:
            ValueError: If the model cannot be found.
        """
        model_id = sql_identifier.SqlIdentifier(model_name)
        found = self._model_ops.validate_existence(model_name=model_id, statement_params=statement_params)
        if not found:
            raise ValueError(f"Unable to find model {model_name}")
        return model_impl.Model._ref(self._model_ops, model_name=model_id)

    def models(
        self,
        *,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> List[model_impl.Model]:
        """Return a ``Model`` reference for every model name listed by the operator.

        Args:
            statement_params: Statement parameters attached to the listing query.

        Returns:
            One ``Model`` reference per name returned by ``list_models_or_versions``.
        """
        names = self._model_ops.list_models_or_versions(statement_params=statement_params)
        return [model_impl.Model._ref(self._model_ops, model_name=name) for name in names]

    def show_models(
        self,
        *,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> pd.DataFrame:
        """Return model information as a pandas DataFrame.

        Args:
            statement_params: Statement parameters attached to the show query.

        Returns:
            A DataFrame with one row per row returned by ``show_models_or_versions``.
            Its columns come from each row's ``as_dict()``.
        """
        rows = self._model_ops.show_models_or_versions(statement_params=statement_params)
        records = [row.as_dict() for row in rows]
        return pd.DataFrame(records)

    def delete_model(
        self,
        model_name: str,
        *,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Delete a model by name.

        Args:
            model_name: Name of the model; converted to a ``SqlIdentifier``.
            statement_params: Statement parameters attached to the delete operation.
        """
        self._model_ops.delete_model_or_version(
            model_name=sql_identifier.SqlIdentifier(model_name),
            statement_params=statement_params,
        )
@@ -3,6 +3,7 @@ import json
3
3
  import sys
4
4
  import textwrap
5
5
  import types
6
+ import warnings
6
7
  from typing import (
7
8
  TYPE_CHECKING,
8
9
  Any,
@@ -305,6 +306,17 @@ class ModelRegistry:
305
306
  schema_name: Desired name of the schema used by this model registry inside the database.
306
307
  create_if_not_exists: create model registry if it's not exists already.
307
308
  """
309
+
310
+ warnings.warn(
311
+ """
312
+ The `snowflake.ml.registry.model_registry.ModelRegistry` has been deprecated starting from version 1.2.0.
313
+ It will stay in the Private Preview phase. For future implementations, kindly utilize `snowflake.ml.registry.Registry`,
314
+ except when specifically required. The old model registry will be removed once all its primary functionalities are
315
+ fully integrated into the new registry.
316
+ """,
317
+ DeprecationWarning,
318
+ stacklevel=2,
319
+ )
308
320
  if create_if_not_exists:
309
321
  create_model_registry(session=session, database_name=database_name, schema_name=schema_name)
310
322
 
@@ -1,12 +1,17 @@
1
1
  from types import ModuleType
2
- from typing import Dict, List, Optional
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ import pandas as pd
3
5
 
4
6
  from snowflake.ml._internal import telemetry
5
7
  from snowflake.ml._internal.utils import sql_identifier
6
- from snowflake.ml.model import model_signature, type_hints as model_types
7
- from snowflake.ml.model._client.model import model_impl, model_version_impl
8
- from snowflake.ml.model._client.ops import model_ops
9
- from snowflake.ml.model._model_composer import model_composer
8
+ from snowflake.ml.model import (
9
+ Model,
10
+ ModelVersion,
11
+ model_signature,
12
+ type_hints as model_types,
13
+ )
14
+ from snowflake.ml.registry._manager import model_manager
10
15
  from snowflake.snowpark import session
11
16
 
12
17
  _TELEMETRY_PROJECT = "MLOps"
@@ -21,6 +26,18 @@ class Registry:
21
26
  database_name: Optional[str] = None,
22
27
  schema_name: Optional[str] = None,
23
28
  ) -> None:
29
+ """Opens a registry within a pre-created Snowflake schema.
30
+
31
+ Args:
32
+ session: The Snowpark Session to connect with Snowflake.
33
+ database_name: The name of the database. If None, the current database of the session
34
+ will be used. Defaults to None.
35
+ schema_name: The name of the schema. If None, the current schema of the session
36
+ will be used. If there is no active schema, the PUBLIC schema will be used. Defaults to None.
37
+
38
+ Raises:
39
+ ValueError: When there is no specified or active database in the session.
40
+ """
24
41
  if database_name:
25
42
  self._database_name = sql_identifier.SqlIdentifier(database_name)
26
43
  else:
@@ -42,12 +59,13 @@ class Registry:
42
59
  else sql_identifier.SqlIdentifier("PUBLIC")
43
60
  )
44
61
 
45
- self._model_ops = model_ops.ModelOperator(
62
+ self._model_manager = model_manager.ModelManager(
46
63
  session, database_name=self._database_name, schema_name=self._schema_name
47
64
  )
48
65
 
49
66
  @property
50
67
  def location(self) -> str:
68
+ """Get the location (database.schema) of the registry."""
51
69
  return ".".join([self._database_name.identifier(), self._schema_name.identifier()])
52
70
 
53
71
  @telemetry.send_api_usage_telemetry(
@@ -60,6 +78,8 @@ class Registry:
60
78
  *,
61
79
  model_name: str,
62
80
  version_name: str,
81
+ comment: Optional[str] = None,
82
+ metrics: Optional[Dict[str, Any]] = None,
63
83
  conda_dependencies: Optional[List[str]] = None,
64
84
  pip_requirements: Optional[List[str]] = None,
65
85
  python_version: Optional[str] = None,
@@ -68,148 +88,138 @@ class Registry:
68
88
  code_paths: Optional[List[str]] = None,
69
89
  ext_modules: Optional[List[ModuleType]] = None,
70
90
  options: Optional[model_types.ModelSaveOption] = None,
71
- ) -> model_version_impl.ModelVersion:
72
- """Log a model.
91
+ ) -> ModelVersion:
92
+ """
93
+ Log a model with various parameters and metadata.
73
94
 
74
95
  Args:
75
- model: Model Python object
76
- model_name: A string as name.
77
- version_name: A string as version. model_name and version_name combination must be unique.
78
- signatures: Model data signatures for inputs and output for every target methods. If it is None,
96
+ model: Model object of supported types such as Scikit-learn, XGBoost, Snowpark ML,
97
+ PyTorch, TorchScript, Tensorflow, Tensorflow Keras, MLFlow, HuggingFace Pipeline,
98
+ Peft-finetuned LLM, or Custom Model.
99
+ model_name: Name to identify the model.
100
+ version_name: Version identifier for the model. Combination of model_name and version_name must be unique.
101
+ comment: Comment associated with the model version. Defaults to None.
102
+ metrics: A JSON serializable dictionary containing metrics linked to the model version. Defaults to None.
103
+ signatures: Model data signatures for inputs and outputs for various target methods. If it is None,
79
104
  sample_input_data would be used to infer the signatures for those models that cannot automatically
80
- infer the signature. If not None, sample_input should not be specified. Defaults to None.
81
- sample_input_data: Sample input data to infer the model signatures from. If it is None, signatures must be
82
- specified if the model cannot automatically infer the signature. If not None, signatures should not be
83
- specified. Defaults to None.
84
- conda_dependencies: List of Conda package specs. Use "[channel::]package [operator version]" syntax to
85
- specify a dependency. It is a recommended way to specify your dependencies using conda. When channel is
86
- not specified, Snowflake Anaconda Channel will be used.
87
- pip_requirements: List of Pip package specs.
88
- python_version: A string of python version where model is run. Used for user override. If specified as None,
89
- current version would be captured. Defaults to None.
90
- code_paths: Directory of code to import.
91
- ext_modules: External modules that user might want to get pickled with model object. Defaults to None.
92
- options: Model specific kwargs.
105
+ infer the signature. If not None, sample_input_data should not be specified. Defaults to None.
106
+ sample_input_data: Sample input data to infer model signatures from. Defaults to None.
107
+ conda_dependencies: List of Conda package specifications. Use "[channel::]package [operator version]" syntax
108
+ to specify a dependency. It is a recommended way to specify your dependencies using conda. When channel
109
+ is not specified, Snowflake Anaconda Channel will be used. Defaults to None.
110
+ pip_requirements: List of Pip package specifications. Defaults to None.
111
+ python_version: Python version in which the model is run. Defaults to None.
112
+ code_paths: List of directories containing code to import. Defaults to None.
113
+ ext_modules: List of external modules to pickle with the model object.
114
+ Only supported when logging the following types of model:
115
+ Scikit-learn, Snowpark ML, PyTorch, TorchScript and Custom Model. Defaults to None.
116
+ options (Dict[str, Any], optional): Additional model saving options.
117
+
118
+ Model Saving Options include:
119
+
120
+ - embed_local_ml_library: Embed local Snowpark ML into the code directory or folder.
121
+ Override to True if the local Snowpark ML version is not available in the Snowflake Anaconda
122
+ Channel. Otherwise, defaults to False
123
+ - relax_version: Whether or not relax the version constraints of the dependencies.
124
+ It detects any ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to False.
125
+ - method_options: Per-method saving options including:
126
+ - case_sensitive: Indicates whether the method and its signature should be case sensitive.
127
+ This means when you refer the method in the SQL, you need to double quote it.
128
+ This will be helpful if you need case to tell apart your methods or features, or you have
129
+ non-alphabetic characters in your method or feature name. Defaults to False.
130
+ - max_batch_size: Maximum batch size that the method could accept in the Snowflake Warehouse.
131
+ Defaults to None, determined automatically by Snowflake.
93
132
 
94
133
  Returns:
95
- A ModelVersion object corresponding to the model just get logged.
134
+ ModelVersion: ModelVersion object corresponding to the model just logged.
96
135
  """
97
136
 
98
137
  statement_params = telemetry.get_statement_params(
99
138
  project=_TELEMETRY_PROJECT,
100
139
  subproject=_MODEL_TELEMETRY_SUBPROJECT,
101
140
  )
102
- model_name_id = sql_identifier.SqlIdentifier(model_name)
103
-
104
- version_name_id = sql_identifier.SqlIdentifier(version_name)
105
-
106
- stage_path = self._model_ops.prepare_model_stage_path(
107
- statement_params=statement_params,
108
- )
109
-
110
- mc = model_composer.ModelComposer(self._model_ops._session, stage_path=stage_path)
111
- mc.save(
112
- name=model_name_id.resolved(),
141
+ return self._model_manager.log_model(
113
142
  model=model,
114
- signatures=signatures,
115
- sample_input=sample_input_data,
143
+ model_name=model_name,
144
+ version_name=version_name,
145
+ comment=comment,
146
+ metrics=metrics,
116
147
  conda_dependencies=conda_dependencies,
117
148
  pip_requirements=pip_requirements,
118
149
  python_version=python_version,
150
+ signatures=signatures,
151
+ sample_input_data=sample_input_data,
119
152
  code_paths=code_paths,
120
153
  ext_modules=ext_modules,
121
154
  options=options,
122
- )
123
- self._model_ops.create_from_stage(
124
- composed_model=mc,
125
- model_name=model_name_id,
126
- version_name=version_name_id,
127
155
  statement_params=statement_params,
128
156
  )
129
157
 
130
- return model_version_impl.ModelVersion._ref(
131
- self._model_ops,
132
- model_name=model_name_id,
133
- version_name=version_name_id,
134
- )
135
-
136
158
  @telemetry.send_api_usage_telemetry(
137
159
  project=_TELEMETRY_PROJECT,
138
160
  subproject=_MODEL_TELEMETRY_SUBPROJECT,
139
161
  )
140
- def get_model(self, model_name: str) -> model_impl.Model:
141
- """Get the model object.
162
+ def get_model(self, model_name: str) -> Model:
163
+ """Get the model object by its name.
142
164
 
143
165
  Args:
144
- model_name: The model name.
145
-
146
- Raises:
147
- ValueError: Raised when the model requested does not exist.
166
+ model_name: The name of the model.
148
167
 
149
168
  Returns:
150
- The model object.
169
+ The corresponding model object.
151
170
  """
152
- model_name_id = sql_identifier.SqlIdentifier(model_name)
153
-
154
171
  statement_params = telemetry.get_statement_params(
155
172
  project=_TELEMETRY_PROJECT,
156
173
  subproject=_MODEL_TELEMETRY_SUBPROJECT,
157
174
  )
158
- if self._model_ops.validate_existence(
159
- model_name=model_name_id,
160
- statement_params=statement_params,
161
- ):
162
- return model_impl.Model._ref(
163
- self._model_ops,
164
- model_name=model_name_id,
165
- )
166
- else:
167
- raise ValueError(f"Unable to find model {model_name}")
175
+ return self._model_manager.get_model(model_name=model_name, statement_params=statement_params)
168
176
 
169
177
  @telemetry.send_api_usage_telemetry(
170
178
  project=_TELEMETRY_PROJECT,
171
179
  subproject=_MODEL_TELEMETRY_SUBPROJECT,
172
180
  )
173
- def list_models(self) -> List[model_impl.Model]:
174
- """List all models in the schema where the registry is opened.
181
+ def models(self) -> List[Model]:
182
+ """Get all models in the schema where the registry is opened.
175
183
 
176
184
  Returns:
177
- A List of Model= object representing all models in the schema where the registry is opened.
185
+ A list of Model objects representing all models in the opened registry.
178
186
  """
179
187
  statement_params = telemetry.get_statement_params(
180
188
  project=_TELEMETRY_PROJECT,
181
189
  subproject=_MODEL_TELEMETRY_SUBPROJECT,
182
190
  )
183
- model_names = self._model_ops.list_models_or_versions(
184
- statement_params=statement_params,
191
+ return self._model_manager.models(statement_params=statement_params)
192
+
193
+ @telemetry.send_api_usage_telemetry(
194
+ project=_TELEMETRY_PROJECT,
195
+ subproject=_MODEL_TELEMETRY_SUBPROJECT,
196
+ )
197
+ def show_models(self) -> pd.DataFrame:
198
+ """Show information of all models in the schema where the registry is opened.
199
+
200
+ Returns:
201
+ A Pandas DataFrame containing information of all models in the schema.
202
+ """
203
+ statement_params = telemetry.get_statement_params(
204
+ project=_TELEMETRY_PROJECT,
205
+ subproject=_MODEL_TELEMETRY_SUBPROJECT,
185
206
  )
186
- return [
187
- model_impl.Model._ref(
188
- self._model_ops,
189
- model_name=model_name,
190
- )
191
- for model_name in model_names
192
- ]
207
+ return self._model_manager.show_models(statement_params=statement_params)
193
208
 
194
209
  @telemetry.send_api_usage_telemetry(
195
210
  project=_TELEMETRY_PROJECT,
196
211
  subproject=_MODEL_TELEMETRY_SUBPROJECT,
197
212
  )
198
213
  def delete_model(self, model_name: str) -> None:
199
- """Delete the model.
214
+ """
215
+ Delete the model by its name.
200
216
 
201
217
  Args:
202
- model_name: The model name, can be fully qualified one.
203
- If not, use database name and schema name of the registry.
218
+ model_name: The name of the model to be deleted.
204
219
  """
205
- model_name_id = sql_identifier.SqlIdentifier(model_name)
206
-
207
220
  statement_params = telemetry.get_statement_params(
208
221
  project=_TELEMETRY_PROJECT,
209
222
  subproject=_MODEL_TELEMETRY_SUBPROJECT,
210
223
  )
211
224
 
212
- self._model_ops.delete_model_or_version(
213
- model_name=model_name_id,
214
- statement_params=statement_params,
215
- )
225
+ self._model_manager.delete_model(model_name=model_name, statement_params=statement_params)
snowflake/ml/version.py CHANGED
@@ -1 +1 @@
1
- VERSION="1.1.2"
1
+ VERSION="1.2.1"