snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (218)
  1. snowflake/ml/_internal/env_utils.py +72 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
  7. snowflake/ml/_internal/telemetry.py +1 -0
  8. snowflake/ml/_internal/utils/identifier.py +1 -1
  9. snowflake/ml/_internal/utils/sql_identifier.py +14 -1
  10. snowflake/ml/dataset/__init__.py +11 -0
  11. snowflake/ml/dataset/dataset.py +455 -129
  12. snowflake/ml/dataset/dataset_factory.py +53 -0
  13. snowflake/ml/dataset/dataset_metadata.py +103 -0
  14. snowflake/ml/dataset/dataset_reader.py +199 -0
  15. snowflake/ml/feature_store/__init__.py +6 -0
  16. snowflake/ml/feature_store/access_manager.py +279 -0
  17. snowflake/ml/feature_store/feature_store.py +544 -358
  18. snowflake/ml/feature_store/feature_view.py +55 -16
  19. snowflake/ml/fileset/embedded_stage_fs.py +149 -0
  20. snowflake/ml/fileset/sfcfs.py +0 -4
  21. snowflake/ml/fileset/snowfs.py +160 -0
  22. snowflake/ml/fileset/stage_fs.py +25 -10
  23. snowflake/ml/model/__init__.py +2 -2
  24. snowflake/ml/model/_api.py +16 -1
  25. snowflake/ml/model/_client/model/model_impl.py +65 -31
  26. snowflake/ml/model/_client/model/model_version_impl.py +159 -2
  27. snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
  28. snowflake/ml/model/_client/ops/model_ops.py +268 -83
  29. snowflake/ml/model/_client/sql/_base.py +34 -0
  30. snowflake/ml/model/_client/sql/model.py +42 -47
  31. snowflake/ml/model/_client/sql/model_version.py +164 -39
  32. snowflake/ml/model/_client/sql/stage.py +6 -32
  33. snowflake/ml/model/_client/sql/tag.py +32 -56
  34. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  35. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  36. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  37. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  38. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  39. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  40. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  41. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  42. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  43. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  45. snowflake/ml/model/_packager/model_packager.py +0 -3
  46. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  47. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  48. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  49. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  50. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
  51. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  52. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
  53. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  54. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  55. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  56. snowflake/ml/modeling/cluster/birch.py +53 -52
  57. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  58. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  59. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  60. snowflake/ml/modeling/cluster/k_means.py +53 -52
  61. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  62. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  63. snowflake/ml/modeling/cluster/optics.py +51 -52
  64. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  65. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  66. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  67. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  68. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  69. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  70. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  71. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  72. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  73. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  74. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  75. snowflake/ml/modeling/covariance/oas.py +51 -52
  76. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  77. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  78. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  79. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  80. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  81. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  82. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  83. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  84. snowflake/ml/modeling/decomposition/pca.py +53 -52
  85. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  86. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  87. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  88. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  89. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  92. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  93. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  94. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  95. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  96. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  97. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  98. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  99. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  100. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  101. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  102. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  103. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  104. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  105. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  106. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  107. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  108. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  109. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  110. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  111. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  112. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  113. snowflake/ml/modeling/framework/base.py +64 -36
  114. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  115. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  116. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  117. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  118. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  119. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  120. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  121. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  122. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  123. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  124. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  125. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  126. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  127. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  128. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  129. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  130. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  131. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  132. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  133. snowflake/ml/modeling/linear_model/lars.py +51 -52
  134. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  135. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  136. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  137. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  138. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  139. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  140. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  141. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  142. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  143. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  144. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  146. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  147. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  148. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  149. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  151. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  152. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  153. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  154. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  155. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  156. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  157. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  158. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  159. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  160. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  161. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  162. snowflake/ml/modeling/manifold/isomap.py +53 -52
  163. snowflake/ml/modeling/manifold/mds.py +53 -52
  164. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  165. snowflake/ml/modeling/manifold/tsne.py +53 -52
  166. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  167. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  168. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  169. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  170. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  171. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  172. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  173. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  174. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  175. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  176. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  177. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  178. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  179. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  180. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  181. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  182. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  183. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  184. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  185. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  186. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  187. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  188. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  189. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  190. snowflake/ml/modeling/pipeline/pipeline.py +538 -36
  191. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  192. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  193. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  194. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  195. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  196. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  197. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  198. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  199. snowflake/ml/modeling/svm/svc.py +51 -52
  200. snowflake/ml/modeling/svm/svr.py +51 -52
  201. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  202. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  203. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  204. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  205. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  206. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  207. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  208. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  209. snowflake/ml/registry/_manager/model_manager.py +36 -7
  210. snowflake/ml/registry/model_registry.py +3 -149
  211. snowflake/ml/version.py +1 -1
  212. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
  213. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
  214. snowflake/ml/registry/_artifact_manager.py +0 -156
  215. snowflake/ml/registry/artifact.py +0 -46
  216. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
  217. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
  218. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -2,6 +2,7 @@ from types import ModuleType
2
2
  from typing import Any, Dict, List, Literal, Optional, Union, cast, overload
3
3
 
4
4
  import pandas as pd
5
+ from typing_extensions import deprecated
5
6
 
6
7
  from snowflake.ml._internal.exceptions import (
7
8
  error_codes,
@@ -23,6 +24,7 @@ from snowflake.ml.model._signatures import snowpark_handler
23
24
  from snowflake.snowpark import DataFrame as SnowparkDataFrame, Session, functions as F
24
25
 
25
26
 
27
+ @deprecated("Only used by PrPr model registry.")
26
28
  @overload
27
29
  def save_model(
28
30
  *,
@@ -61,6 +63,7 @@ def save_model(
61
63
  ...
62
64
 
63
65
 
66
+ @deprecated("Only used by PrPr model registry.")
64
67
  @overload
65
68
  def save_model(
66
69
  *,
@@ -101,6 +104,7 @@ def save_model(
101
104
  ...
102
105
 
103
106
 
107
+ @deprecated("Only used by PrPr model registry.")
104
108
  @overload
105
109
  def save_model(
106
110
  *,
@@ -142,6 +146,7 @@ def save_model(
142
146
  ...
143
147
 
144
148
 
149
+ @deprecated("Only used by PrPr model registry.")
145
150
  def save_model(
146
151
  *,
147
152
  name: str,
@@ -208,6 +213,7 @@ def save_model(
208
213
  return m
209
214
 
210
215
 
216
+ @deprecated("Only used by PrPr model registry.")
211
217
  @overload
212
218
  def load_model(*, session: Session, stage_path: str) -> model_composer.ModelComposer:
213
219
  """Load the model into memory from a zip file in the stage.
@@ -219,6 +225,7 @@ def load_model(*, session: Session, stage_path: str) -> model_composer.ModelComp
219
225
  ...
220
226
 
221
227
 
228
+ @deprecated("Only used by PrPr model registry.")
222
229
  @overload
223
230
  def load_model(*, session: Session, stage_path: str, meta_only: Literal[False]) -> model_composer.ModelComposer:
224
231
  """Load the model into memory from a zip file in the stage.
@@ -231,6 +238,7 @@ def load_model(*, session: Session, stage_path: str, meta_only: Literal[False])
231
238
  ...
232
239
 
233
240
 
241
+ @deprecated("Only used by PrPr model registry.")
234
242
  @overload
235
243
  def load_model(*, session: Session, stage_path: str, meta_only: Literal[True]) -> model_composer.ModelComposer:
236
244
  """Load the model into memory from a zip file in the stage with metadata only.
@@ -243,6 +251,7 @@ def load_model(*, session: Session, stage_path: str, meta_only: Literal[True]) -
243
251
  ...
244
252
 
245
253
 
254
+ @deprecated("Only used by PrPr model registry.")
246
255
  def load_model(
247
256
  *,
248
257
  session: Session,
@@ -261,10 +270,11 @@ def load_model(
261
270
  Loaded model.
262
271
  """
263
272
  m = model_composer.ModelComposer(session=session, stage_path=stage_path)
264
- m.load(meta_only=meta_only)
273
+ m.legacy_load(meta_only=meta_only)
265
274
  return m
266
275
 
267
276
 
277
+ @deprecated("Only used by PrPr model registry.")
268
278
  @overload
269
279
  def deploy(
270
280
  session: Session,
@@ -290,6 +300,7 @@ def deploy(
290
300
  ...
291
301
 
292
302
 
303
+ @deprecated("Only used by PrPr model registry.")
293
304
  @overload
294
305
  def deploy(
295
306
  session: Session,
@@ -319,6 +330,7 @@ def deploy(
319
330
  ...
320
331
 
321
332
 
333
+ @deprecated("Only used by PrPr model registry.")
322
334
  def deploy(
323
335
  session: Session,
324
336
  *,
@@ -423,6 +435,7 @@ def deploy(
423
435
  return info
424
436
 
425
437
 
438
+ @deprecated("Only used by PrPr model registry.")
426
439
  @overload
427
440
  def predict(
428
441
  session: Session,
@@ -443,6 +456,7 @@ def predict(
443
456
  ...
444
457
 
445
458
 
459
+ @deprecated("Only used by PrPr model registry.")
446
460
  @overload
447
461
  def predict(
448
462
  session: Session,
@@ -462,6 +476,7 @@ def predict(
462
476
  ...
463
477
 
464
478
 
479
+ @deprecated("Only used by PrPr model registry.")
465
480
  def predict(
466
481
  session: Session,
467
482
  *,
@@ -1,9 +1,9 @@
1
- from typing import Dict, List, Optional, Tuple, Union
1
+ from typing import Dict, List, Optional, Union
2
2
 
3
3
  import pandas as pd
4
4
 
5
5
  from snowflake.ml._internal import telemetry
6
- from snowflake.ml._internal.utils import identifier, sql_identifier
6
+ from snowflake.ml._internal.utils import sql_identifier
7
7
  from snowflake.ml.model._client.model import model_version_impl
8
8
  from snowflake.ml.model._client.ops import model_ops
9
9
 
@@ -45,7 +45,7 @@ class Model:
45
45
  @property
46
46
  def fully_qualified_name(self) -> str:
47
47
  """Return the fully qualified name of the model that can be used to refer to it in SQL."""
48
- return self._model_ops._model_version_client.fully_qualified_model_name(self._model_name)
48
+ return self._model_ops._model_version_client.fully_qualified_object_name(None, None, self._model_name)
49
49
 
50
50
  @property
51
51
  @telemetry.send_api_usage_telemetry(
@@ -76,6 +76,8 @@ class Model:
76
76
  subproject=_TELEMETRY_SUBPROJECT,
77
77
  )
78
78
  return self._model_ops.get_comment(
79
+ database_name=None,
80
+ schema_name=None,
79
81
  model_name=self._model_name,
80
82
  statement_params=statement_params,
81
83
  )
@@ -92,6 +94,8 @@ class Model:
92
94
  )
93
95
  return self._model_ops.set_comment(
94
96
  comment=comment,
97
+ database_name=None,
98
+ schema_name=None,
95
99
  model_name=self._model_name,
96
100
  statement_params=statement_params,
97
101
  )
@@ -109,7 +113,7 @@ class Model:
109
113
  class_name=self.__class__.__name__,
110
114
  )
111
115
  default_version_name = self._model_ops.get_default_version(
112
- model_name=self._model_name, statement_params=statement_params
116
+ database_name=None, schema_name=None, model_name=self._model_name, statement_params=statement_params
113
117
  )
114
118
  return self.version(default_version_name)
115
119
 
@@ -129,7 +133,11 @@ class Model:
129
133
  else:
130
134
  version_name = version._version_name
131
135
  self._model_ops.set_default_version(
132
- model_name=self._model_name, version_name=version_name, statement_params=statement_params
136
+ database_name=None,
137
+ schema_name=None,
138
+ model_name=self._model_name,
139
+ version_name=version_name,
140
+ statement_params=statement_params,
133
141
  )
134
142
 
135
143
  @telemetry.send_api_usage_telemetry(
@@ -155,6 +163,8 @@ class Model:
155
163
  )
156
164
  version_id = sql_identifier.SqlIdentifier(version_name)
157
165
  if self._model_ops.validate_existence(
166
+ database_name=None,
167
+ schema_name=None,
158
168
  model_name=self._model_name,
159
169
  version_name=version_id,
160
170
  statement_params=statement_params,
@@ -184,6 +194,8 @@ class Model:
184
194
  subproject=_TELEMETRY_SUBPROJECT,
185
195
  )
186
196
  version_names = self._model_ops.list_models_or_versions(
197
+ database_name=None,
198
+ schema_name=None,
187
199
  model_name=self._model_name,
188
200
  statement_params=statement_params,
189
201
  )
@@ -211,6 +223,8 @@ class Model:
211
223
  subproject=_TELEMETRY_SUBPROJECT,
212
224
  )
213
225
  rows = self._model_ops.show_models_or_versions(
226
+ database_name=None,
227
+ schema_name=None,
214
228
  model_name=self._model_name,
215
229
  statement_params=statement_params,
216
230
  )
@@ -231,6 +245,8 @@ class Model:
231
245
  subproject=_TELEMETRY_SUBPROJECT,
232
246
  )
233
247
  self._model_ops.delete_model_or_version(
248
+ database_name=None,
249
+ schema_name=None,
234
250
  model_name=self._model_name,
235
251
  version_name=sql_identifier.SqlIdentifier(version_name),
236
252
  statement_params=statement_params,
@@ -250,29 +266,9 @@ class Model:
250
266
  project=_TELEMETRY_PROJECT,
251
267
  subproject=_TELEMETRY_SUBPROJECT,
252
268
  )
253
- return self._model_ops.show_tags(model_name=self._model_name, statement_params=statement_params)
254
-
255
- def _parse_tag_name(
256
- self,
257
- tag_name: str,
258
- ) -> Tuple[sql_identifier.SqlIdentifier, sql_identifier.SqlIdentifier, sql_identifier.SqlIdentifier]:
259
- _tag_db, _tag_schema, _tag_name, _ = identifier.parse_schema_level_object_identifier(tag_name)
260
- if _tag_db is None:
261
- tag_db_id = self._model_ops._model_client._database_name
262
- else:
263
- tag_db_id = sql_identifier.SqlIdentifier(_tag_db)
264
-
265
- if _tag_schema is None:
266
- tag_schema_id = self._model_ops._model_client._schema_name
267
- else:
268
- tag_schema_id = sql_identifier.SqlIdentifier(_tag_schema)
269
-
270
- if _tag_name is None:
271
- raise ValueError(f"Unable parse the tag name `{tag_name}` you input.")
272
-
273
- tag_name_id = sql_identifier.SqlIdentifier(_tag_name)
274
-
275
- return tag_db_id, tag_schema_id, tag_name_id
269
+ return self._model_ops.show_tags(
270
+ database_name=None, schema_name=None, model_name=self._model_name, statement_params=statement_params
271
+ )
276
272
 
277
273
  @telemetry.send_api_usage_telemetry(
278
274
  project=_TELEMETRY_PROJECT,
@@ -292,8 +288,10 @@ class Model:
292
288
  project=_TELEMETRY_PROJECT,
293
289
  subproject=_TELEMETRY_SUBPROJECT,
294
290
  )
295
- tag_db_id, tag_schema_id, tag_name_id = self._parse_tag_name(tag_name)
291
+ tag_db_id, tag_schema_id, tag_name_id = sql_identifier.parse_fully_qualified_name(tag_name)
296
292
  return self._model_ops.get_tag_value(
293
+ database_name=None,
294
+ schema_name=None,
297
295
  model_name=self._model_name,
298
296
  tag_database_name=tag_db_id,
299
297
  tag_schema_name=tag_schema_id,
@@ -317,8 +315,10 @@ class Model:
317
315
  project=_TELEMETRY_PROJECT,
318
316
  subproject=_TELEMETRY_SUBPROJECT,
319
317
  )
320
- tag_db_id, tag_schema_id, tag_name_id = self._parse_tag_name(tag_name)
318
+ tag_db_id, tag_schema_id, tag_name_id = sql_identifier.parse_fully_qualified_name(tag_name)
321
319
  self._model_ops.set_tag(
320
+ database_name=None,
321
+ schema_name=None,
322
322
  model_name=self._model_name,
323
323
  tag_database_name=tag_db_id,
324
324
  tag_schema_name=tag_schema_id,
@@ -342,11 +342,45 @@ class Model:
342
342
  project=_TELEMETRY_PROJECT,
343
343
  subproject=_TELEMETRY_SUBPROJECT,
344
344
  )
345
- tag_db_id, tag_schema_id, tag_name_id = self._parse_tag_name(tag_name)
345
+ tag_db_id, tag_schema_id, tag_name_id = sql_identifier.parse_fully_qualified_name(tag_name)
346
346
  self._model_ops.unset_tag(
347
+ database_name=None,
348
+ schema_name=None,
347
349
  model_name=self._model_name,
348
350
  tag_database_name=tag_db_id,
349
351
  tag_schema_name=tag_schema_id,
350
352
  tag_name=tag_name_id,
351
353
  statement_params=statement_params,
352
354
  )
355
+
356
+ @telemetry.send_api_usage_telemetry(
357
+ project=_TELEMETRY_PROJECT,
358
+ subproject=_TELEMETRY_SUBPROJECT,
359
+ )
360
+ def rename(self, model_name: str) -> None:
361
+ """Rename a model. Can be used to move a model when a fully qualified name is provided.
362
+
363
+ Args:
364
+ model_name: The new model name.
365
+ """
366
+ statement_params = telemetry.get_statement_params(
367
+ project=_TELEMETRY_PROJECT,
368
+ subproject=_TELEMETRY_SUBPROJECT,
369
+ )
370
+ new_db, new_schema, new_model = sql_identifier.parse_fully_qualified_name(model_name)
371
+
372
+ self._model_ops.rename(
373
+ database_name=None,
374
+ schema_name=None,
375
+ model_name=self._model_name,
376
+ new_model_db=new_db,
377
+ new_model_schema=new_schema,
378
+ new_model_name=new_model,
379
+ statement_params=statement_params,
380
+ )
381
+ self._model_ops = model_ops.ModelOperator(
382
+ self._model_ops._session,
383
+ database_name=new_db or self._model_ops._model_client._database_name,
384
+ schema_name=new_schema or self._model_ops._model_client._schema_name,
385
+ )
386
+ self._model_name = new_model
@@ -1,17 +1,29 @@
1
+ import enum
2
+ import pathlib
3
+ import tempfile
4
+ import warnings
1
5
  from typing import Any, Callable, Dict, List, Optional, Union
2
6
 
3
7
  import pandas as pd
4
8
 
5
9
  from snowflake.ml._internal import telemetry
6
10
  from snowflake.ml._internal.utils import sql_identifier
11
+ from snowflake.ml.model import type_hints as model_types
7
12
  from snowflake.ml.model._client.ops import metadata_ops, model_ops
13
+ from snowflake.ml.model._model_composer import model_composer
8
14
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
15
+ from snowflake.ml.model._packager.model_handlers import snowmlmodel
9
16
  from snowflake.snowpark import dataframe
10
17
 
11
18
  _TELEMETRY_PROJECT = "MLOps"
12
19
  _TELEMETRY_SUBPROJECT = "ModelManagement"
13
20
 
14
21
 
22
+ class ExportMode(enum.Enum):
23
+ MODEL = "model"
24
+ FULL = "full"
25
+
26
+
15
27
  class ModelVersion:
16
28
  """Model Version Object representing a specific version of the model that could be run."""
17
29
 
@@ -60,7 +72,7 @@ class ModelVersion:
60
72
  @property
61
73
  def fully_qualified_model_name(self) -> str:
62
74
  """Return the fully qualified name of the model to which the model version belongs."""
63
- return self._model_ops._model_version_client.fully_qualified_model_name(self._model_name)
75
+ return self._model_ops._model_version_client.fully_qualified_object_name(None, None, self._model_name)
64
76
 
65
77
  @property
66
78
  @telemetry.send_api_usage_telemetry(
@@ -91,6 +103,8 @@ class ModelVersion:
91
103
  subproject=_TELEMETRY_SUBPROJECT,
92
104
  )
93
105
  return self._model_ops.get_comment(
106
+ database_name=None,
107
+ schema_name=None,
94
108
  model_name=self._model_name,
95
109
  version_name=self._version_name,
96
110
  statement_params=statement_params,
@@ -108,6 +122,8 @@ class ModelVersion:
108
122
  )
109
123
  return self._model_ops.set_comment(
110
124
  comment=comment,
125
+ database_name=None,
126
+ schema_name=None,
111
127
  model_name=self._model_name,
112
128
  version_name=self._version_name,
113
129
  statement_params=statement_params,
@@ -128,7 +144,11 @@ class ModelVersion:
128
144
  subproject=_TELEMETRY_SUBPROJECT,
129
145
  )
130
146
  return self._model_ops._metadata_ops.load(
131
- model_name=self._model_name, version_name=self._version_name, statement_params=statement_params
147
+ database_name=None,
148
+ schema_name=None,
149
+ model_name=self._model_name,
150
+ version_name=self._version_name,
151
+ statement_params=statement_params,
132
152
  )["metrics"]
133
153
 
134
154
  @telemetry.send_api_usage_telemetry(
@@ -171,6 +191,8 @@ class ModelVersion:
171
191
  metrics[metric_name] = value
172
192
  self._model_ops._metadata_ops.save(
173
193
  metadata_ops.ModelVersionMetadataSchema(metrics=metrics),
194
+ database_name=None,
195
+ schema_name=None,
174
196
  model_name=self._model_name,
175
197
  version_name=self._version_name,
176
198
  statement_params=statement_params,
@@ -199,6 +221,8 @@ class ModelVersion:
199
221
  del metrics[metric_name]
200
222
  self._model_ops._metadata_ops.save(
201
223
  metadata_ops.ModelVersionMetadataSchema(metrics=metrics),
224
+ database_name=None,
225
+ schema_name=None,
202
226
  model_name=self._model_name,
203
227
  version_name=self._version_name,
204
228
  statement_params=statement_params,
@@ -210,6 +234,8 @@ class ModelVersion:
210
234
  subproject=_TELEMETRY_SUBPROJECT,
211
235
  )
212
236
  return self._model_ops.get_functions(
237
+ database_name=None,
238
+ schema_name=None,
213
239
  model_name=self._model_name,
214
240
  version_name=self._version_name,
215
241
  statement_params=statement_params,
@@ -240,6 +266,7 @@ class ModelVersion:
240
266
  X: Union[pd.DataFrame, dataframe.DataFrame],
241
267
  *,
242
268
  function_name: Optional[str] = None,
269
+ partition_column: Optional[str] = None,
243
270
  strict_input_validation: bool = False,
244
271
  ) -> Union[pd.DataFrame, dataframe.DataFrame]:
245
272
  """Invoke a method in a model version object.
@@ -248,12 +275,14 @@ class ModelVersion:
248
275
  X: The input data, which could be a pandas DataFrame or Snowpark DataFrame.
249
276
  function_name: The function name to run. It is the name used to call a function in SQL.
250
277
  Defaults to None. It can only be None if there is only 1 method.
278
+ partition_column: The partition column name to partition by.
251
279
  strict_input_validation: Enable stricter validation for the input data. This will result value range based
252
280
  type validation to make sure your input data won't overflow when providing to the model.
253
281
 
254
282
  Raises:
255
283
  ValueError: When no method with the corresponding name is available.
256
284
  ValueError: When there are more than 1 target methods available in the model but no function name specified.
285
+ ValueError: When the partition column is not a valid Snowflake identifier.
257
286
 
258
287
  Returns:
259
288
  The prediction data. It would be the same type dataframe as your input.
@@ -263,6 +292,10 @@ class ModelVersion:
263
292
  subproject=_TELEMETRY_SUBPROJECT,
264
293
  )
265
294
 
295
+ if partition_column is not None:
296
+ # Partition column must be a valid identifier
297
+ partition_column = sql_identifier.SqlIdentifier(partition_column)
298
+
266
299
  functions: List[model_manifest_schema.ModelFunctionInfo] = self._functions
267
300
  if function_name:
268
301
  req_method_name = sql_identifier.SqlIdentifier(function_name).identifier()
@@ -287,10 +320,134 @@ class ModelVersion:
287
320
  target_function_info = functions[0]
288
321
  return self._model_ops.invoke_method(
289
322
  method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
323
+ method_function_type=target_function_info["target_method_function_type"],
290
324
  signature=target_function_info["signature"],
291
325
  X=X,
326
+ database_name=None,
327
+ schema_name=None,
292
328
  model_name=self._model_name,
293
329
  version_name=self._version_name,
294
330
  strict_input_validation=strict_input_validation,
331
+ partition_column=partition_column,
332
+ statement_params=statement_params,
333
+ )
334
+
335
@telemetry.send_api_usage_telemetry(
    project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, func_params_to_log=["export_mode"]
)
def export(self, target_path: str, *, export_mode: ExportMode = ExportMode.MODEL) -> None:
    """Export model files to a local directory.

    Args:
        target_path: Path to a local directory to export files to. A directory will be created if does not exist.
        export_mode: The mode to export the model. Defaults to ExportMode.MODEL.
            ExportMode.MODEL: All model files including environment to load the model and model weights.
            ExportMode.FULL: Additional files to run the model in Warehouse, besides all files in MODEL mode.

    Raises:
        ValueError: Raised when the target path is a file or an non-empty folder.
    """
    target_local_path = pathlib.Path(target_path)
    # A missing target directory is allowed (it is created below). `iterdir()` raises FileNotFoundError on a
    # nonexistent path, so the contents are only inspected when the directory already exists.
    is_non_empty_dir = target_local_path.is_dir() and any(target_local_path.iterdir())
    if target_local_path.is_file() or is_non_empty_dir:
        raise ValueError(f"Target path {target_local_path} is a file or an non-empty folder.")

    # parents=True so that a nested target path can be created in one go, as the docstring promises.
    target_local_path.mkdir(parents=True, exist_ok=True)
    statement_params = telemetry.get_statement_params(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    self._model_ops.download_files(
        database_name=None,
        schema_name=None,
        model_name=self._model_name,
        version_name=self._version_name,
        target_path=target_local_path,
        mode=export_mode.value,
        statement_params=statement_params,
    )
368
+
369
@telemetry.send_api_usage_telemetry(
    project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, func_params_to_log=["force", "options"]
)
def load(
    self,
    *,
    force: bool = False,
    options: Optional[model_types.ModelLoadOption] = None,
) -> model_types.SupportedModelType:
    """Load the original Python model object stored in this model version.

    The loaded object is only expected to work when the local environment matches the one used when the
    model was logged; otherwise it may not be functional.

    Args:
        force: When True, skip the best-effort check of the local environment against the logged one.
            Defaults to False.
        options: Loading options; see `snowflake.ml.model.type_hints` for what is available. Defaults to None.

    Raises:
        ValueError: When `force` is False and the best-effort environment check reports errors.

    Returns:
        The original Python model object.
    """
    stmt_params = telemetry.get_statement_params(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )

    if not force:
        # Best-effort validation: download only the "minimal" file set into a throwaway directory, read the
        # metadata alone, and compare the recorded environment with the local one.
        with tempfile.TemporaryDirectory() as validation_dir:
            validation_path = pathlib.Path(validation_dir)
            self._model_ops.download_files(
                database_name=None,
                schema_name=None,
                model_name=self._model_name,
                version_name=self._version_name,
                target_path=validation_path,
                mode="minimal",
                statement_params=stmt_params,
            )
            meta_packager = model_composer.ModelComposer.load(validation_path, meta_only=True, options=options)
            model_meta = meta_packager.meta
            assert model_meta, (
                "Unable to load model metadata for validation. "
                f"model_name={self._model_name}, version_name={self._version_name}"
            )

            # The snowpark-ml version itself is only checked for models logged through the SnowML handler.
            env_errors = model_meta.env.validate_with_local_env(
                check_snowpark_ml_version=model_meta.model_type == snowmlmodel.SnowMLModelHandler.HANDLER_TYPE
            )
            if env_errors:
                raise ValueError(
                    f"Unable to load this model due to following validation errors: {env_errors}. "
                    "Make sure your local environment is the same as that when you logged the model, "
                    "or if you believe it should work, specify `force=True` to bypass this check."
                )

    warnings.warn(
        "Loading model requires to have the exact the same environment as the one when "
        "logging the model, otherwise, the model might be not functional or "
        "some other problems might occur.",
        category=RuntimeWarning,
        stacklevel=2,
    )

    # The original note says this folder needs to exist after loading; presumably the loaded model may still
    # reference files in it. It is therefore created with mkdtemp and not removed here.
    # NOTE(review): confirm whether some handlers can cope with cleanup, to avoid leaking the directory.
    model_dir = pathlib.Path(tempfile.mkdtemp())
    self._model_ops.download_files(
        database_name=None,
        schema_name=None,
        model_name=self._model_name,
        version_name=self._version_name,
        target_path=model_dir,
        mode="model",
        statement_params=stmt_params,
    )
    packager = model_composer.ModelComposer.load(model_dir, meta_only=False, options=options)
    assert packager.model, (
        "Unable to load model. "
        f"model_name={self._model_name}, version_name={self._version_name}, metadata={packager.meta}"
    )
    return packager.model
@@ -61,12 +61,18 @@ class MetadataOperator:
61
61
  def _get_current_metadata_dict(
62
62
  self,
63
63
  *,
64
+ database_name: Optional[sql_identifier.SqlIdentifier],
65
+ schema_name: Optional[sql_identifier.SqlIdentifier],
64
66
  model_name: sql_identifier.SqlIdentifier,
65
67
  version_name: sql_identifier.SqlIdentifier,
66
68
  statement_params: Optional[Dict[str, Any]] = None,
67
69
  ) -> Dict[str, Any]:
68
70
  version_info_list = self._model_client.show_versions(
69
- model_name=model_name, version_name=version_name, statement_params=statement_params
71
+ database_name=database_name,
72
+ schema_name=schema_name,
73
+ model_name=model_name,
74
+ version_name=version_name,
75
+ statement_params=statement_params,
70
76
  )
71
77
  metadata_str = version_info_list[0][self._model_client.MODEL_VERSION_METADATA_COL_NAME]
72
78
  if not metadata_str:
@@ -79,12 +85,18 @@ class MetadataOperator:
79
85
  def load(
80
86
  self,
81
87
  *,
88
+ database_name: Optional[sql_identifier.SqlIdentifier],
89
+ schema_name: Optional[sql_identifier.SqlIdentifier],
82
90
  model_name: sql_identifier.SqlIdentifier,
83
91
  version_name: sql_identifier.SqlIdentifier,
84
92
  statement_params: Optional[Dict[str, Any]] = None,
85
93
  ) -> ModelVersionMetadataSchema:
86
94
  metadata_dict = self._get_current_metadata_dict(
87
- model_name=model_name, version_name=version_name, statement_params=statement_params
95
+ database_name=database_name,
96
+ schema_name=schema_name,
97
+ model_name=model_name,
98
+ version_name=version_name,
99
+ statement_params=statement_params,
88
100
  )
89
101
  return MetadataOperator._parse(metadata_dict)
90
102
 
@@ -92,14 +104,25 @@ class MetadataOperator:
92
104
  self,
93
105
  metadata: ModelVersionMetadataSchema,
94
106
  *,
107
+ database_name: Optional[sql_identifier.SqlIdentifier],
108
+ schema_name: Optional[sql_identifier.SqlIdentifier],
95
109
  model_name: sql_identifier.SqlIdentifier,
96
110
  version_name: sql_identifier.SqlIdentifier,
97
111
  statement_params: Optional[Dict[str, Any]] = None,
98
112
  ) -> None:
99
113
  metadata_dict = self._get_current_metadata_dict(
100
- model_name=model_name, version_name=version_name, statement_params=statement_params
114
+ database_name=database_name,
115
+ schema_name=schema_name,
116
+ model_name=model_name,
117
+ version_name=version_name,
118
+ statement_params=statement_params,
101
119
  )
102
120
  metadata_dict.update({**metadata, "snowpark_ml_schema_version": MODEL_VERSION_METADATA_SCHEMA_VERSION})
103
121
  self._model_version_client.set_metadata(
104
- metadata_dict, model_name=model_name, version_name=version_name, statement_params=statement_params
122
+ metadata_dict,
123
+ database_name=database_name,
124
+ schema_name=schema_name,
125
+ model_name=model_name,
126
+ version_name=version_name,
127
+ statement_params=statement_params,
105
128
  )