snowflake-ml-python 1.1.1__py3-none-any.whl → 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (224)
  1. snowflake/cortex/_complete.py +1 -1
  2. snowflake/cortex/_extract_answer.py +1 -1
  3. snowflake/cortex/_sentiment.py +1 -1
  4. snowflake/cortex/_summarize.py +1 -1
  5. snowflake/cortex/_translate.py +1 -1
  6. snowflake/ml/_internal/env_utils.py +68 -6
  7. snowflake/ml/_internal/file_utils.py +34 -4
  8. snowflake/ml/_internal/telemetry.py +79 -91
  9. snowflake/ml/_internal/utils/retryable_http.py +16 -4
  10. snowflake/ml/_internal/utils/spcs_attribution_utils.py +122 -0
  11. snowflake/ml/dataset/dataset.py +1 -1
  12. snowflake/ml/model/_api.py +21 -14
  13. snowflake/ml/model/_client/model/model_impl.py +176 -0
  14. snowflake/ml/model/_client/model/model_method_info.py +19 -0
  15. snowflake/ml/model/_client/model/model_version_impl.py +291 -0
  16. snowflake/ml/model/_client/ops/metadata_ops.py +107 -0
  17. snowflake/ml/model/_client/ops/model_ops.py +308 -0
  18. snowflake/ml/model/_client/sql/model.py +75 -0
  19. snowflake/ml/model/_client/sql/model_version.py +213 -0
  20. snowflake/ml/model/_client/sql/stage.py +40 -0
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -4
  22. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +24 -8
  23. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +23 -0
  24. snowflake/ml/model/_deploy_client/snowservice/deploy.py +14 -2
  25. snowflake/ml/model/_deploy_client/utils/constants.py +1 -0
  26. snowflake/ml/model/_deploy_client/warehouse/deploy.py +2 -2
  27. snowflake/ml/model/_model_composer/model_composer.py +31 -9
  28. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +25 -10
  29. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -2
  30. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
  31. snowflake/ml/model/_model_composer/model_method/model_method.py +34 -3
  32. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +1 -1
  33. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -1
  34. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +10 -28
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +18 -16
  36. snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
  37. snowflake/ml/model/model_signature.py +108 -53
  38. snowflake/ml/model/type_hints.py +1 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +554 -0
  40. snowflake/ml/modeling/_internal/estimator_protocols.py +1 -60
  41. snowflake/ml/modeling/_internal/model_specifications.py +146 -0
  42. snowflake/ml/modeling/_internal/model_trainer.py +13 -0
  43. snowflake/ml/modeling/_internal/model_trainer_builder.py +78 -0
  44. snowflake/ml/modeling/_internal/pandas_trainer.py +54 -0
  45. snowflake/ml/modeling/_internal/snowpark_handlers.py +6 -760
  46. snowflake/ml/modeling/_internal/snowpark_trainer.py +331 -0
  47. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +96 -124
  48. snowflake/ml/modeling/cluster/affinity_propagation.py +94 -124
  49. snowflake/ml/modeling/cluster/agglomerative_clustering.py +94 -124
  50. snowflake/ml/modeling/cluster/birch.py +94 -124
  51. snowflake/ml/modeling/cluster/bisecting_k_means.py +94 -124
  52. snowflake/ml/modeling/cluster/dbscan.py +94 -124
  53. snowflake/ml/modeling/cluster/feature_agglomeration.py +94 -124
  54. snowflake/ml/modeling/cluster/k_means.py +93 -124
  55. snowflake/ml/modeling/cluster/mean_shift.py +94 -124
  56. snowflake/ml/modeling/cluster/mini_batch_k_means.py +93 -124
  57. snowflake/ml/modeling/cluster/optics.py +94 -124
  58. snowflake/ml/modeling/cluster/spectral_biclustering.py +94 -124
  59. snowflake/ml/modeling/cluster/spectral_clustering.py +94 -124
  60. snowflake/ml/modeling/cluster/spectral_coclustering.py +94 -124
  61. snowflake/ml/modeling/compose/column_transformer.py +94 -124
  62. snowflake/ml/modeling/compose/transformed_target_regressor.py +96 -124
  63. snowflake/ml/modeling/covariance/elliptic_envelope.py +94 -124
  64. snowflake/ml/modeling/covariance/empirical_covariance.py +80 -110
  65. snowflake/ml/modeling/covariance/graphical_lasso.py +94 -124
  66. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +94 -124
  67. snowflake/ml/modeling/covariance/ledoit_wolf.py +85 -115
  68. snowflake/ml/modeling/covariance/min_cov_det.py +94 -124
  69. snowflake/ml/modeling/covariance/oas.py +80 -110
  70. snowflake/ml/modeling/covariance/shrunk_covariance.py +84 -114
  71. snowflake/ml/modeling/decomposition/dictionary_learning.py +94 -124
  72. snowflake/ml/modeling/decomposition/factor_analysis.py +94 -124
  73. snowflake/ml/modeling/decomposition/fast_ica.py +94 -124
  74. snowflake/ml/modeling/decomposition/incremental_pca.py +94 -124
  75. snowflake/ml/modeling/decomposition/kernel_pca.py +94 -124
  76. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +94 -124
  77. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +94 -124
  78. snowflake/ml/modeling/decomposition/pca.py +94 -124
  79. snowflake/ml/modeling/decomposition/sparse_pca.py +94 -124
  80. snowflake/ml/modeling/decomposition/truncated_svd.py +94 -124
  81. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +96 -124
  82. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +91 -119
  83. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +96 -124
  84. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +96 -124
  85. snowflake/ml/modeling/ensemble/bagging_classifier.py +96 -124
  86. snowflake/ml/modeling/ensemble/bagging_regressor.py +96 -124
  87. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +96 -124
  88. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +96 -124
  89. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +96 -124
  90. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +96 -124
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +96 -124
  92. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +96 -124
  93. snowflake/ml/modeling/ensemble/isolation_forest.py +94 -124
  94. snowflake/ml/modeling/ensemble/random_forest_classifier.py +96 -124
  95. snowflake/ml/modeling/ensemble/random_forest_regressor.py +96 -124
  96. snowflake/ml/modeling/ensemble/stacking_regressor.py +96 -124
  97. snowflake/ml/modeling/ensemble/voting_classifier.py +96 -124
  98. snowflake/ml/modeling/ensemble/voting_regressor.py +91 -119
  99. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +82 -110
  100. snowflake/ml/modeling/feature_selection/select_fdr.py +80 -108
  101. snowflake/ml/modeling/feature_selection/select_fpr.py +80 -108
  102. snowflake/ml/modeling/feature_selection/select_fwe.py +80 -108
  103. snowflake/ml/modeling/feature_selection/select_k_best.py +81 -109
  104. snowflake/ml/modeling/feature_selection/select_percentile.py +80 -108
  105. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +94 -124
  106. snowflake/ml/modeling/feature_selection/variance_threshold.py +76 -106
  107. snowflake/ml/modeling/framework/base.py +2 -2
  108. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +96 -124
  109. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +96 -124
  110. snowflake/ml/modeling/impute/iterative_imputer.py +94 -124
  111. snowflake/ml/modeling/impute/knn_imputer.py +94 -124
  112. snowflake/ml/modeling/impute/missing_indicator.py +94 -124
  113. snowflake/ml/modeling/impute/simple_imputer.py +1 -1
  114. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +77 -107
  115. snowflake/ml/modeling/kernel_approximation/nystroem.py +94 -124
  116. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +94 -124
  117. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +86 -116
  118. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +84 -114
  119. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +96 -124
  120. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +71 -100
  121. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +71 -100
  122. snowflake/ml/modeling/linear_model/ard_regression.py +96 -124
  123. snowflake/ml/modeling/linear_model/bayesian_ridge.py +96 -124
  124. snowflake/ml/modeling/linear_model/elastic_net.py +96 -124
  125. snowflake/ml/modeling/linear_model/elastic_net_cv.py +96 -124
  126. snowflake/ml/modeling/linear_model/gamma_regressor.py +96 -124
  127. snowflake/ml/modeling/linear_model/huber_regressor.py +96 -124
  128. snowflake/ml/modeling/linear_model/lars.py +96 -124
  129. snowflake/ml/modeling/linear_model/lars_cv.py +96 -124
  130. snowflake/ml/modeling/linear_model/lasso.py +96 -124
  131. snowflake/ml/modeling/linear_model/lasso_cv.py +96 -124
  132. snowflake/ml/modeling/linear_model/lasso_lars.py +96 -124
  133. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +96 -124
  134. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +96 -124
  135. snowflake/ml/modeling/linear_model/linear_regression.py +91 -119
  136. snowflake/ml/modeling/linear_model/logistic_regression.py +96 -124
  137. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +96 -124
  138. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +96 -124
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +96 -124
  140. snowflake/ml/modeling/linear_model/multi_task_lasso.py +96 -124
  141. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +96 -124
  142. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +96 -124
  143. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +96 -124
  144. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +95 -124
  145. snowflake/ml/modeling/linear_model/perceptron.py +95 -124
  146. snowflake/ml/modeling/linear_model/poisson_regressor.py +96 -124
  147. snowflake/ml/modeling/linear_model/ransac_regressor.py +96 -124
  148. snowflake/ml/modeling/linear_model/ridge.py +96 -124
  149. snowflake/ml/modeling/linear_model/ridge_classifier.py +96 -124
  150. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +96 -124
  151. snowflake/ml/modeling/linear_model/ridge_cv.py +96 -124
  152. snowflake/ml/modeling/linear_model/sgd_classifier.py +96 -124
  153. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +94 -124
  154. snowflake/ml/modeling/linear_model/sgd_regressor.py +96 -124
  155. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +96 -124
  156. snowflake/ml/modeling/linear_model/tweedie_regressor.py +96 -124
  157. snowflake/ml/modeling/manifold/isomap.py +94 -124
  158. snowflake/ml/modeling/manifold/mds.py +94 -124
  159. snowflake/ml/modeling/manifold/spectral_embedding.py +94 -124
  160. snowflake/ml/modeling/manifold/tsne.py +94 -124
  161. snowflake/ml/modeling/metrics/classification.py +187 -52
  162. snowflake/ml/modeling/metrics/correlation.py +4 -2
  163. snowflake/ml/modeling/metrics/covariance.py +7 -4
  164. snowflake/ml/modeling/metrics/ranking.py +32 -16
  165. snowflake/ml/modeling/metrics/regression.py +60 -32
  166. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +94 -124
  167. snowflake/ml/modeling/mixture/gaussian_mixture.py +94 -124
  168. snowflake/ml/modeling/model_selection/grid_search_cv.py +88 -138
  169. snowflake/ml/modeling/model_selection/randomized_search_cv.py +90 -144
  170. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +86 -114
  171. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +93 -121
  172. snowflake/ml/modeling/multiclass/output_code_classifier.py +94 -122
  173. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +92 -120
  174. snowflake/ml/modeling/naive_bayes/categorical_nb.py +96 -124
  175. snowflake/ml/modeling/naive_bayes/complement_nb.py +92 -120
  176. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -107
  177. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +88 -116
  178. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +96 -124
  179. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +96 -124
  180. snowflake/ml/modeling/neighbors/kernel_density.py +94 -124
  181. snowflake/ml/modeling/neighbors/local_outlier_factor.py +94 -124
  182. snowflake/ml/modeling/neighbors/nearest_centroid.py +89 -117
  183. snowflake/ml/modeling/neighbors/nearest_neighbors.py +94 -124
  184. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +96 -124
  185. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +96 -124
  186. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +96 -124
  187. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +94 -124
  188. snowflake/ml/modeling/neural_network/mlp_classifier.py +96 -124
  189. snowflake/ml/modeling/neural_network/mlp_regressor.py +96 -124
  190. snowflake/ml/modeling/parameters/disable_distributed_hpo.py +2 -6
  191. snowflake/ml/modeling/preprocessing/binarizer.py +14 -9
  192. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +0 -4
  193. snowflake/ml/modeling/preprocessing/label_encoder.py +21 -13
  194. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +20 -14
  195. snowflake/ml/modeling/preprocessing/min_max_scaler.py +35 -19
  196. snowflake/ml/modeling/preprocessing/normalizer.py +6 -9
  197. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +20 -13
  198. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +25 -13
  199. snowflake/ml/modeling/preprocessing/polynomial_features.py +94 -124
  200. snowflake/ml/modeling/preprocessing/robust_scaler.py +28 -14
  201. snowflake/ml/modeling/preprocessing/standard_scaler.py +25 -13
  202. snowflake/ml/modeling/semi_supervised/label_propagation.py +96 -124
  203. snowflake/ml/modeling/semi_supervised/label_spreading.py +96 -124
  204. snowflake/ml/modeling/svm/linear_svc.py +96 -124
  205. snowflake/ml/modeling/svm/linear_svr.py +96 -124
  206. snowflake/ml/modeling/svm/nu_svc.py +96 -124
  207. snowflake/ml/modeling/svm/nu_svr.py +96 -124
  208. snowflake/ml/modeling/svm/svc.py +96 -124
  209. snowflake/ml/modeling/svm/svr.py +96 -124
  210. snowflake/ml/modeling/tree/decision_tree_classifier.py +96 -124
  211. snowflake/ml/modeling/tree/decision_tree_regressor.py +96 -124
  212. snowflake/ml/modeling/tree/extra_tree_classifier.py +96 -124
  213. snowflake/ml/modeling/tree/extra_tree_regressor.py +96 -124
  214. snowflake/ml/modeling/xgboost/xgb_classifier.py +96 -125
  215. snowflake/ml/modeling/xgboost/xgb_regressor.py +96 -125
  216. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +96 -125
  217. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +96 -125
  218. snowflake/ml/registry/model_registry.py +2 -0
  219. snowflake/ml/registry/registry.py +215 -0
  220. snowflake/ml/version.py +1 -1
  221. {snowflake_ml_python-1.1.1.dist-info → snowflake_ml_python-1.1.2.dist-info}/METADATA +21 -3
  222. snowflake_ml_python-1.1.2.dist-info/RECORD +347 -0
  223. snowflake_ml_python-1.1.1.dist-info/RECORD +0 -331
  224. {snowflake_ml_python-1.1.1.dist-info → snowflake_ml_python-1.1.2.dist-info}/WHEEL +0 -0
@@ -7,7 +7,6 @@ from snowflake.ml._internal.exceptions import (
7
7
  error_codes,
8
8
  exceptions as snowml_exceptions,
9
9
  )
10
- from snowflake.ml._internal.utils import identifier
11
10
  from snowflake.ml.model import (
12
11
  deploy_platforms,
13
12
  model_signature,
@@ -188,6 +187,10 @@ def save_model(
188
187
  Returns:
189
188
  Model
190
189
  """
190
+ if options is None:
191
+ options = {}
192
+ options["_legacy_save"] = True
193
+
191
194
  m = model_composer.ModelComposer(session=session, stage_path=stage_path)
192
195
  m.save(
193
196
  name=name,
@@ -481,6 +484,7 @@ def predict(
481
484
  # Get options
482
485
  INTERMEDIATE_OBJ_NAME = "tmp_result"
483
486
  sig = deployment["signature"]
487
+ identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED
484
488
 
485
489
  # Validate and prepare input
486
490
  if not isinstance(X, SnowparkDataFrame):
@@ -491,7 +495,7 @@ def predict(
491
495
  else:
492
496
  keep_order = False
493
497
  output_with_input_features = True
494
- model_signature._validate_snowpark_data(X, sig.inputs)
498
+ identifier_rule = model_signature._validate_snowpark_data(X, sig.inputs)
495
499
  s_df = X
496
500
 
497
501
  if statement_params:
@@ -500,10 +504,14 @@ def predict(
500
504
  else:
501
505
  s_df._statement_params = statement_params # type: ignore[assignment]
502
506
 
507
+ original_cols = s_df.columns
508
+
503
509
  # Infer and get intermediate result
504
510
  input_cols = []
505
- for col_name in s_df.columns:
506
- literal_col_name = identifier.get_unescaped_names(col_name)
511
+ for input_feature in sig.inputs:
512
+ literal_col_name = input_feature.name
513
+ col_name = identifier_rule.get_identifier_from_feature(input_feature.name)
514
+
507
515
  input_cols.extend(
508
516
  [
509
517
  F.lit(literal_col_name),
@@ -511,29 +519,28 @@ def predict(
511
519
  ]
512
520
  )
513
521
 
514
- # TODO[shchen]: SNOW-870032, For SnowService, external function name cannot be double quoted, else it results in
515
- # external function no found.
516
522
  udf_name = deployment["name"]
517
- output_obj = F.call_udf(udf_name, F.object_construct(*input_cols))
518
-
519
- if output_with_input_features:
520
- df_res = s_df.with_column(INTERMEDIATE_OBJ_NAME, output_obj)
521
- else:
522
- df_res = s_df.select(output_obj.alias(INTERMEDIATE_OBJ_NAME))
523
+ output_obj = F.call_udf(udf_name, F.object_construct_keep_null(*input_cols))
524
+ df_res = s_df.with_column(INTERMEDIATE_OBJ_NAME, output_obj)
523
525
 
524
526
  if keep_order:
525
527
  df_res = df_res.order_by(
526
- F.col(INTERMEDIATE_OBJ_NAME)[infer_template._KEEP_ORDER_COL_NAME],
528
+ F.col(infer_template._KEEP_ORDER_COL_NAME),
527
529
  ascending=True,
528
530
  )
529
531
 
532
+ if not output_with_input_features:
533
+ df_res = df_res.drop(*original_cols)
534
+
530
535
  # Prepare the output
531
536
  output_cols = []
537
+ output_col_names = []
532
538
  for output_feature in sig.outputs:
533
539
  output_cols.append(F.col(INTERMEDIATE_OBJ_NAME)[output_feature.name].astype(output_feature.as_snowpark_type()))
540
+ output_col_names.append(identifier_rule.get_identifier_from_feature(output_feature.name))
534
541
 
535
542
  df_res = df_res.with_columns(
536
- [identifier.get_inferred_name(output_feature.name) for output_feature in sig.outputs],
543
+ output_col_names,
537
544
  output_cols,
538
545
  ).drop(INTERMEDIATE_OBJ_NAME)
539
546
 
@@ -0,0 +1,176 @@
1
+ from typing import List, Union
2
+
3
+ from snowflake.ml._internal import telemetry
4
+ from snowflake.ml._internal.utils import sql_identifier
5
+ from snowflake.ml.model._client.model import model_version_impl
6
+ from snowflake.ml.model._client.ops import model_ops
7
+
8
# Telemetry tags attached to every API-usage event and statement issued from this module.
_TELEMETRY_PROJECT = "MLOps"
_TELEMETRY_SUBPROJECT = "ModelManagement"


class Model:
    """Model Object containing multiple versions. Mapping to SQL's MODEL object.

    Handles are obtained through ``Model._ref`` (e.g. from the registry's ``get_model``); calling the
    regular initializer always raises.
    """

    _model_ops: model_ops.ModelOperator
    _model_name: sql_identifier.SqlIdentifier

    def __init__(self) -> None:
        raise RuntimeError("Model's initializer is not meant to be used. Use `get_model` from registry instead.")

    @classmethod
    def _ref(
        cls,
        model_ops: model_ops.ModelOperator,
        *,
        model_name: sql_identifier.SqlIdentifier,
    ) -> "Model":
        """Build a Model handle without running ``__init__``.

        Args:
            model_ops: Operator used to issue SQL for this model.
            model_name: Identifier of the model.

        Returns:
            A Model bound to the given operator and model name.
        """
        handle: "Model" = object.__new__(cls)
        handle._model_ops = model_ops
        handle._model_name = model_name
        return handle

    def __eq__(self, __value: object) -> bool:
        # Equal only to another Model using an equal operator and the same model name.
        if isinstance(__value, Model):
            return self._model_ops == __value._model_ops and self._model_name == __value._model_name
        return False

    @property
    def name(self) -> str:
        """User-facing identifier of the model."""
        return self._model_name.identifier()

    @property
    def fully_qualified_name(self) -> str:
        """Fully qualified model name as produced by the model-version SQL client."""
        return self._model_ops._model_version_client.fully_qualified_model_name(self._model_name)

    @property
    @telemetry.send_api_usage_telemetry(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    def description(self) -> str:
        """Comment attached to the model, read through the model operator."""
        params = telemetry.get_statement_params(
            project=_TELEMETRY_PROJECT,
            subproject=_TELEMETRY_SUBPROJECT,
        )
        return self._model_ops.get_comment(model_name=self._model_name, statement_params=params)

    @description.setter
    @telemetry.send_api_usage_telemetry(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    def description(self, description: str) -> None:
        params = telemetry.get_statement_params(
            project=_TELEMETRY_PROJECT,
            subproject=_TELEMETRY_SUBPROJECT,
        )
        return self._model_ops.set_comment(
            comment=description,
            model_name=self._model_name,
            statement_params=params,
        )

    @property
    @telemetry.send_api_usage_telemetry(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    def default(self) -> model_version_impl.ModelVersion:
        """The model's default version, resolved by name and then looked up via ``version``."""
        params = telemetry.get_statement_params(
            project=_TELEMETRY_PROJECT,
            subproject=_TELEMETRY_SUBPROJECT,
            class_name=type(self).__name__,
        )
        return self.version(
            self._model_ops._model_version_client.get_default_version(
                model_name=self._model_name, statement_params=params
            )
        )

    @default.setter
    @telemetry.send_api_usage_telemetry(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    def default(self, version: Union[str, model_version_impl.ModelVersion]) -> None:
        params = telemetry.get_statement_params(
            project=_TELEMETRY_PROJECT,
            subproject=_TELEMETRY_SUBPROJECT,
            class_name=type(self).__name__,
        )
        # A plain string is parsed as a SQL identifier; a ModelVersion contributes its stored version name.
        target_version = sql_identifier.SqlIdentifier(version) if isinstance(version, str) else version._version_name
        self._model_ops._model_version_client.set_default_version(
            model_name=self._model_name, version_name=target_version, statement_params=params
        )

    @telemetry.send_api_usage_telemetry(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    def version(self, version_name: str) -> model_version_impl.ModelVersion:
        """Get a model version object given a version name in the model.

        Args:
            version_name: The name of version

        Raises:
            ValueError: Raised when the version requested does not exist.

        Returns:
            The model version object.
        """
        params = telemetry.get_statement_params(
            project=_TELEMETRY_PROJECT,
            subproject=_TELEMETRY_SUBPROJECT,
        )
        requested = sql_identifier.SqlIdentifier(version_name)
        found = self._model_ops.validate_existence(
            model_name=self._model_name,
            version_name=requested,
            statement_params=params,
        )
        if not found:
            raise ValueError(
                f"Unable to find version with name {requested.identifier()} in model {self.fully_qualified_name}"
            )
        return model_version_impl.ModelVersion._ref(
            self._model_ops,
            model_name=self._model_name,
            version_name=requested,
        )

    @telemetry.send_api_usage_telemetry(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    def list_versions(self) -> List[model_version_impl.ModelVersion]:
        """List all versions in the model.

        Returns:
            A List of ModelVersion object representing all versions in the model.
        """
        params = telemetry.get_statement_params(
            project=_TELEMETRY_PROJECT,
            subproject=_TELEMETRY_SUBPROJECT,
        )
        listed_names = self._model_ops.list_models_or_versions(
            model_name=self._model_name,
            statement_params=params,
        )
        return [
            model_version_impl.ModelVersion._ref(self._model_ops, model_name=self._model_name, version_name=listed)
            for listed in listed_names
        ]

    def delete_version(self, version_name: str) -> None:
        """Delete a version of the model (not implemented yet).

        Args:
            version_name: The name of the version to delete.

        Raises:
            NotImplementedError: Always; deleting versions is not supported yet.
        """
        raise NotImplementedError("Deleting version has not been supported yet.")
@@ -0,0 +1,19 @@
1
+ from typing import TypedDict
2
+
3
+ from typing_extensions import Required
4
+
5
+ from snowflake.ml.model import model_signature
6
+
7
+
8
class ModelMethodInfo(TypedDict):
    """Description of one callable method of a model version.

    Every key is declared ``Required``.

    Attributes:
        name: Name of the method to be called via SQL.
        target_method: Actual target method name to be called.
        signature: The signature of the model method.
    """

    # Method name as used from SQL.
    name: Required[str]
    # Name of the underlying method that the SQL-callable method dispatches to.
    target_method: Required[str]
    # Input/output signature of the method.
    signature: Required[model_signature.ModelSignature]
@@ -0,0 +1,291 @@
1
+ import re
2
+ from typing import Any, Callable, Dict, List, Optional, Union
3
+
4
+ import pandas as pd
5
+
6
+ from snowflake.ml._internal import telemetry
7
+ from snowflake.ml._internal.utils import sql_identifier
8
+ from snowflake.ml.model import model_signature
9
+ from snowflake.ml.model._client.model import model_method_info
10
+ from snowflake.ml.model._client.ops import metadata_ops, model_ops
11
+ from snowflake.snowpark import dataframe
12
+
13
# Telemetry tags attached to every API-usage event and statement issued from this module.
_TELEMETRY_PROJECT = "MLOps"
_TELEMETRY_SUBPROJECT = "ModelManagement"


class ModelVersion:
    """Model Version Object representing a specific version of the model that could be run.

    Handles are created through ``ModelVersion._ref`` (e.g. from ``Model.version``); calling the regular
    initializer always raises.
    """

    _model_ops: model_ops.ModelOperator
    _model_name: sql_identifier.SqlIdentifier
    _version_name: sql_identifier.SqlIdentifier

    def __init__(self) -> None:
        raise RuntimeError("ModelVersion's initializer is not meant to be used. Use `version` from model instead.")

    @classmethod
    def _ref(
        cls,
        model_ops: model_ops.ModelOperator,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
    ) -> "ModelVersion":
        """Create a handle to a model version without running ``__init__``.

        Args:
            model_ops: Operator used to issue SQL against the model.
            model_name: Identifier of the model owning the version.
            version_name: Identifier of the version.

        Returns:
            A ModelVersion bound to the given operator, model name and version name.
        """
        self: "ModelVersion" = object.__new__(cls)
        self._model_ops = model_ops
        self._model_name = model_name
        self._version_name = version_name
        return self

    def __eq__(self, __value: object) -> bool:
        # Equal only to another ModelVersion with an equal operator, model name and version name.
        if not isinstance(__value, ModelVersion):
            return False
        return (
            self._model_ops == __value._model_ops
            and self._model_name == __value._model_name
            and self._version_name == __value._version_name
        )

    @property
    def model_name(self) -> str:
        """User-facing identifier of the model owning this version."""
        return self._model_name.identifier()

    @property
    def version_name(self) -> str:
        """User-facing identifier of this version."""
        return self._version_name.identifier()

    @property
    def fully_qualified_model_name(self) -> str:
        """Fully qualified name of the owning model, as produced by the model-version SQL client."""
        return self._model_ops._model_version_client.fully_qualified_model_name(self._model_name)

    @property
    @telemetry.send_api_usage_telemetry(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    def description(self) -> str:
        """Comment attached to this model version, read through the model operator."""
        statement_params = telemetry.get_statement_params(
            project=_TELEMETRY_PROJECT,
            subproject=_TELEMETRY_SUBPROJECT,
        )
        return self._model_ops.get_comment(
            model_name=self._model_name,
            version_name=self._version_name,
            statement_params=statement_params,
        )

    @description.setter
    @telemetry.send_api_usage_telemetry(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    def description(self, description: str) -> None:
        statement_params = telemetry.get_statement_params(
            project=_TELEMETRY_PROJECT,
            subproject=_TELEMETRY_SUBPROJECT,
        )
        return self._model_ops.set_comment(
            comment=description,
            model_name=self._model_name,
            version_name=self._version_name,
            statement_params=statement_params,
        )

    @telemetry.send_api_usage_telemetry(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    def list_metrics(self) -> Dict[str, Any]:
        """Show all metrics logged with the model version.

        Returns:
            A dictionary showing the metrics (the ``metrics`` entry of the loaded version metadata).
        """
        statement_params = telemetry.get_statement_params(
            project=_TELEMETRY_PROJECT,
            subproject=_TELEMETRY_SUBPROJECT,
        )
        return self._model_ops._metadata_ops.load(
            model_name=self._model_name, version_name=self._version_name, statement_params=statement_params
        )["metrics"]

    @telemetry.send_api_usage_telemetry(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    def get_metric(self, metric_name: str) -> Any:
        """Get the value of a specific metric.

        Args:
            metric_name: The name of the metric

        Raises:
            KeyError: Raised when the requested metric name does not exist.

        Returns:
            The value of the metric.
        """
        metrics = self.list_metrics()
        if metric_name not in metrics:
            raise KeyError(f"Cannot find metric with name {metric_name}.")
        return metrics[metric_name]

    @telemetry.send_api_usage_telemetry(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    def set_metric(self, metric_name: str, value: Any) -> None:
        """Set the value of a specific metric name.

        The current metrics are loaded, updated locally and saved back as a whole; this is a
        read-modify-write sequence rather than a single atomic update.

        Args:
            metric_name: The name of the metric
            value: The value of the metric.
        """
        statement_params = telemetry.get_statement_params(
            project=_TELEMETRY_PROJECT,
            subproject=_TELEMETRY_SUBPROJECT,
        )
        metrics = self.list_metrics()
        metrics[metric_name] = value
        self._model_ops._metadata_ops.save(
            metadata_ops.ModelVersionMetadataSchema(metrics=metrics),
            model_name=self._model_name,
            version_name=self._version_name,
            statement_params=statement_params,
        )

    @telemetry.send_api_usage_telemetry(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    def delete_metric(self, metric_name: str) -> None:
        """Delete a metric from metric storage.

        Like ``set_metric``, this loads all metrics, removes one locally and saves the rest back.

        Args:
            metric_name: The name of the metric to be deleted.

        Raises:
            KeyError: Raised when the requested metric name does not exist.
        """
        statement_params = telemetry.get_statement_params(
            project=_TELEMETRY_PROJECT,
            subproject=_TELEMETRY_SUBPROJECT,
        )
        metrics = self.list_metrics()
        if metric_name not in metrics:
            raise KeyError(f"Cannot find metric with name {metric_name}.")
        del metrics[metric_name]
        self._model_ops._metadata_ops.save(
            metadata_ops.ModelVersionMetadataSchema(metrics=metrics),
            model_name=self._model_name,
            version_name=self._version_name,
            statement_params=statement_params,
        )

    @telemetry.send_api_usage_telemetry(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    def list_methods(self) -> List[model_method_info.ModelMethodInfo]:
        """List all method information in a model version that is callable.

        Raises:
            ValueError: Raised when a method handler in the manifest is not of the form
                ``functions.<target_method>.infer``.

        Returns:
            A list of ModelMethodInfo object containing the following information:
            - name: The name of the method to be called (both in SQL and in Python SDK).
            - target_method: The original method name in the logged Python object.
            - signature: Python signature of the original method.
        """
        statement_params = telemetry.get_statement_params(
            project=_TELEMETRY_PROJECT,
            subproject=_TELEMETRY_SUBPROJECT,
        )
        # TODO(SNOW-986673, SNOW-986675): Avoid parsing manifest and meta file and put Python signature into user_data.
        manifest = self._model_ops.get_model_version_manifest(
            model_name=self._model_name,
            version_name=self._version_name,
            statement_params=statement_params,
        )
        model_meta = self._model_ops.get_model_version_native_packing_meta(
            model_name=self._model_name,
            version_name=self._version_name,
            statement_params=statement_params,
        )
        return_methods_info: List[model_method_info.ModelMethodInfo] = []
        for method in manifest["methods"]:
            # Method's name is resolved so we need to use case_sensitive as True to get the user-facing identifier.
            method_name = sql_identifier.SqlIdentifier(method["name"], case_sensitive=True).identifier()
            # Method's handler is `functions.<target_method>.infer`. Validate it explicitly (an `assert` would be
            # stripped under `python -O`) and take the target method from the regex capture group.
            handler = method["handler"]
            handler_match = re.match(r"^functions\.([^\d\W]\w*)\.infer$", handler)
            if handler_match is None:
                raise ValueError(f"Get unexpected handler name {handler}")
            target_method = handler_match.group(1)
            signature_dict = model_meta["signatures"][target_method]
            method_info = model_method_info.ModelMethodInfo(
                name=method_name,
                target_method=target_method,
                signature=model_signature.ModelSignature.from_dict(signature_dict),
            )
            return_methods_info.append(method_info)

        return return_methods_info

    @telemetry.send_api_usage_telemetry(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    def run(
        self,
        X: Union[pd.DataFrame, dataframe.DataFrame],
        *,
        method_name: Optional[str] = None,
    ) -> Union[pd.DataFrame, dataframe.DataFrame]:
        """Invoke a method in a model version object

        Args:
            X: The input data. Could be pandas DataFrame or Snowpark DataFrame
            method_name: The method name to run. It is the name you will use to call a method in SQL. Defaults to None.
                It can only be None if there is only 1 method.

        Raises:
            ValueError: No method with the corresponding name is available.
            ValueError: No method name is specified and the model version has no method available.
            ValueError: There are more than 1 target methods available in the model but no method name specified.

        Returns:
            The prediction data.
        """
        statement_params = telemetry.get_statement_params(
            project=_TELEMETRY_PROJECT,
            subproject=_TELEMETRY_SUBPROJECT,
        )

        methods: List[model_method_info.ModelMethodInfo] = self.list_methods()
        if method_name:
            # Parse the requested name as a SQL identifier and compare its identifier() form with the listed names.
            req_method_name = sql_identifier.SqlIdentifier(method_name).identifier()
            find_method: Callable[[model_method_info.ModelMethodInfo], bool] = (
                lambda method: method["name"] == req_method_name
            )
            target_method_info = next(
                filter(find_method, methods),
                None,
            )
            if target_method_info is None:
                raise ValueError(
                    f"There is no method with name {method_name} available in the model"
                    f" {self.fully_qualified_model_name} version {self.version_name}"
                )
        elif not methods:
            # Previously this case fell into the "more than 1 target methods" error, which was misleading.
            raise ValueError(
                f"There is no method available in the model {self.fully_qualified_model_name}"
                f" version {self.version_name}."
            )
        elif len(methods) > 1:
            raise ValueError(
                f"There are more than 1 target methods available in the model {self.fully_qualified_model_name}"
                f" version {self.version_name}. Please specify a `method_name` when calling the `run` method."
            )
        else:
            target_method_info = methods[0]
        return self._model_ops.invoke_method(
            method_name=sql_identifier.SqlIdentifier(target_method_info["name"]),
            signature=target_method_info["signature"],
            X=X,
            model_name=self._model_name,
            version_name=self._version_name,
            statement_params=statement_params,
        )
@@ -0,0 +1,107 @@
1
+ import json
2
+ from typing import Any, Dict, Optional, TypedDict
3
+
4
+ from typing_extensions import NotRequired
5
+
6
+ from snowflake.ml._internal.utils import sql_identifier
7
+ from snowflake.ml.model._client.sql import (
8
+ model as model_sql,
9
+ model_version as model_version_sql,
10
+ )
11
+ from snowflake.snowpark import session
12
+
13
# Schema version tag that this client writes into, and accepts from, a model version's metadata
# under the "snowpark_ml_schema_version" key.
MODEL_VERSION_METADATA_SCHEMA_VERSION: str = "2024-01-01"
14
+
15
+
16
+ class ModelVersionMetadataSchema(TypedDict):
17
+ metrics: NotRequired[Dict[str, Any]]
18
+
19
+
20
class MetadataOperator:
    """Reads and writes the Snowpark ML portion of a model version's metadata.

    The metadata is read as a JSON string through ``ModelSQLClient.show_versions`` and written back
    through ``ModelVersionSQLClient.set_metadata``.
    """

    def __init__(
        self,
        session: session.Session,
        *,
        database_name: sql_identifier.SqlIdentifier,
        schema_name: sql_identifier.SqlIdentifier,
    ) -> None:
        """Create the SQL clients used to read and write model version metadata.

        Args:
            session: Snowpark session used to run the SQL statements.
            database_name: Database the model lives in.
            schema_name: Schema the model lives in.
        """
        self._model_client = model_sql.ModelSQLClient(
            session,
            database_name=database_name,
            schema_name=schema_name,
        )
        self._model_version_client = model_version_sql.ModelVersionSQLClient(
            session,
            database_name=database_name,
            schema_name=schema_name,
        )

    def __eq__(self, __value: object) -> bool:
        # Operators are equal when both underlying SQL clients compare equal.
        # NOTE: defining __eq__ without __hash__ makes instances unhashable.
        if not isinstance(__value, MetadataOperator):
            return False
        return (
            self._model_client == __value._model_client and self._model_version_client == __value._model_version_client
        )

    @staticmethod
    def _parse(metadata_dict: Dict[str, Any]) -> ModelVersionMetadataSchema:
        """Validate a decoded metadata dictionary and extract the Snowpark ML fields.

        Args:
            metadata_dict: The decoded metadata JSON object of a model version.

        Raises:
            ValueError: The stored schema version is present but is not the supported one.
            ValueError: The stored metrics are not a dictionary.

        Returns:
            The parsed metadata. When no schema version is stored (metadata never written by this
            client), empty metrics are returned.
        """
        loaded_metadata_schema_version = metadata_dict.get("snowpark_ml_schema_version", None)
        if loaded_metadata_schema_version is None:
            return ModelVersionMetadataSchema(metrics={})
        elif (
            not isinstance(loaded_metadata_schema_version, str)
            or loaded_metadata_schema_version != MODEL_VERSION_METADATA_SCHEMA_VERSION
        ):
            raise ValueError(f"Unsupported model metadata schema version {loaded_metadata_schema_version} confronted.")
        loaded_metrics = metadata_dict.get("metrics", {})
        if not isinstance(loaded_metrics, dict):
            raise ValueError(f"Metrics in the metadata is expected to be a dictionary, getting {loaded_metrics}")
        return ModelVersionMetadataSchema(metrics=loaded_metrics)

    def _get_current_metadata_dict(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Fetch and decode the full metadata JSON object currently stored on a model version.

        Args:
            model_name: Name of the model.
            version_name: Name of the version.
            statement_params: Optional statement parameters forwarded to the SQL client.

        Raises:
            ValueError: The version lookup did not return exactly one row.
            ValueError: The stored metadata is not valid JSON, or does not decode to a dictionary.

        Returns:
            The decoded metadata dictionary, or an empty dictionary when no metadata is stored.
        """
        version_info_list = self._model_client.show_versions(
            model_name=model_name, version_name=version_name, statement_params=statement_params
        )
        # Explicit check rather than `assert`: assertions are stripped under `python -O`, which would
        # turn an empty result into an opaque IndexError and silently ignore extra rows.
        if len(version_info_list) != 1:
            raise ValueError(
                f"Expected exactly 1 version {version_name} of model {model_name}, "
                f"found {len(version_info_list)}."
            )
        version_info = version_info_list[0]
        metadata_str = version_info.metadata
        if not metadata_str:
            return {}
        res = json.loads(metadata_str)
        if not isinstance(res, dict):
            raise ValueError(f"Metadata is expected to be a dictionary, getting {res}")
        return res

    def load(
        self,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> ModelVersionMetadataSchema:
        """Load the Snowpark ML metadata of a model version.

        Args:
            model_name: Name of the model.
            version_name: Name of the version.
            statement_params: Optional statement parameters forwarded to the SQL client.

        Raises:
            ValueError: The stored metadata cannot be fetched, decoded, or validated.

        Returns:
            The parsed Snowpark ML metadata.
        """
        metadata_dict = self._get_current_metadata_dict(
            model_name=model_name, version_name=version_name, statement_params=statement_params
        )
        return MetadataOperator._parse(metadata_dict)

    def save(
        self,
        metadata: ModelVersionMetadataSchema,
        *,
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Merge ``metadata`` into the metadata stored on a model version and write it back.

        Top-level keys in ``metadata`` replace the stored ones. Other stored top-level keys are
        preserved. The schema version key is always set to ``MODEL_VERSION_METADATA_SCHEMA_VERSION``.

        Args:
            metadata: The Snowpark ML metadata to store.
            model_name: Name of the model.
            version_name: Name of the version.
            statement_params: Optional statement parameters forwarded to the SQL clients.

        Raises:
            ValueError: The currently stored metadata cannot be fetched or decoded.
        """
        metadata_dict = self._get_current_metadata_dict(
            model_name=model_name, version_name=version_name, statement_params=statement_params
        )
        metadata_dict.update({**metadata, "snowpark_ml_schema_version": MODEL_VERSION_METADATA_SCHEMA_VERSION})
        self._model_version_client.set_metadata(
            metadata_dict, model_name=model_name, version_name=version_name, statement_params=statement_params
        )