snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
@@ -3,11 +3,12 @@ from typing import Any, Callable, Dict, List, Optional, Union
3
3
 
4
4
  import pandas as pd
5
5
 
6
+ from snowflake import connector
6
7
  from snowflake.ml._internal import telemetry
7
8
  from snowflake.ml._internal.utils import sql_identifier
8
9
  from snowflake.ml.model import model_signature
9
- from snowflake.ml.model._client.model import model_method_info
10
10
  from snowflake.ml.model._client.ops import metadata_ops, model_ops
11
+ from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
11
12
  from snowflake.snowpark import dataframe
12
13
 
13
14
  _TELEMETRY_PROJECT = "MLOps"
@@ -49,14 +50,17 @@ class ModelVersion:
49
50
 
50
51
  @property
51
52
  def model_name(self) -> str:
53
+ """Return the name of the model to which the model version belongs, usable as a reference in SQL."""
52
54
  return self._model_name.identifier()
53
55
 
54
56
  @property
55
57
  def version_name(self) -> str:
58
+ """Return the name of the version to which the model version belongs, usable as a reference in SQL."""
56
59
  return self._version_name.identifier()
57
60
 
58
61
  @property
59
62
  def fully_qualified_model_name(self) -> str:
63
+ """Return the fully qualified name of the model to which the model version belongs."""
60
64
  return self._model_ops._model_version_client.fully_qualified_model_name(self._model_name)
61
65
 
62
66
  @property
@@ -65,6 +69,24 @@ class ModelVersion:
65
69
  subproject=_TELEMETRY_SUBPROJECT,
66
70
  )
67
71
  def description(self) -> str:
72
+ """The description for the model version. This is an alias of `comment`."""
73
+ return self.comment
74
+
75
+ @description.setter
76
+ @telemetry.send_api_usage_telemetry(
77
+ project=_TELEMETRY_PROJECT,
78
+ subproject=_TELEMETRY_SUBPROJECT,
79
+ )
80
+ def description(self, description: str) -> None:
81
+ self.comment = description
82
+
83
+ @property
84
+ @telemetry.send_api_usage_telemetry(
85
+ project=_TELEMETRY_PROJECT,
86
+ subproject=_TELEMETRY_SUBPROJECT,
87
+ )
88
+ def comment(self) -> str:
89
+ """The comment to the model version."""
68
90
  statement_params = telemetry.get_statement_params(
69
91
  project=_TELEMETRY_PROJECT,
70
92
  subproject=_TELEMETRY_SUBPROJECT,
@@ -75,18 +97,18 @@ class ModelVersion:
75
97
  statement_params=statement_params,
76
98
  )
77
99
 
78
- @description.setter
100
+ @comment.setter
79
101
  @telemetry.send_api_usage_telemetry(
80
102
  project=_TELEMETRY_PROJECT,
81
103
  subproject=_TELEMETRY_SUBPROJECT,
82
104
  )
83
- def description(self, description: str) -> None:
105
+ def comment(self, comment: str) -> None:
84
106
  statement_params = telemetry.get_statement_params(
85
107
  project=_TELEMETRY_PROJECT,
86
108
  subproject=_TELEMETRY_SUBPROJECT,
87
109
  )
88
110
  return self._model_ops.set_comment(
89
- comment=description,
111
+ comment=comment,
90
112
  model_name=self._model_name,
91
113
  version_name=self._version_name,
92
114
  statement_params=statement_params,
@@ -96,11 +118,11 @@ class ModelVersion:
96
118
  project=_TELEMETRY_PROJECT,
97
119
  subproject=_TELEMETRY_SUBPROJECT,
98
120
  )
99
- def list_metrics(self) -> Dict[str, Any]:
121
+ def show_metrics(self) -> Dict[str, Any]:
100
122
  """Show all metrics logged with the model version.
101
123
 
102
124
  Returns:
103
- A dictionary showing the metrics
125
+ A dictionary showing the metrics.
104
126
  """
105
127
  statement_params = telemetry.get_statement_params(
106
128
  project=_TELEMETRY_PROJECT,
@@ -118,15 +140,15 @@ class ModelVersion:
118
140
  """Get the value of a specific metric.
119
141
 
120
142
  Args:
121
- metric_name: The name of the metric
143
+ metric_name: The name of the metric.
122
144
 
123
145
  Raises:
124
- KeyError: Raised when the requested metric name does not exist.
146
+ KeyError: When the requested metric name does not exist.
125
147
 
126
148
  Returns:
127
149
  The value of the metric.
128
150
  """
129
- metrics = self.list_metrics()
151
+ metrics = self.show_metrics()
130
152
  if metric_name not in metrics:
131
153
  raise KeyError(f"Cannot find metric with name {metric_name}.")
132
154
  return metrics[metric_name]
@@ -136,17 +158,17 @@ class ModelVersion:
136
158
  subproject=_TELEMETRY_SUBPROJECT,
137
159
  )
138
160
  def set_metric(self, metric_name: str, value: Any) -> None:
139
- """Set the value of a specific metric name
161
+ """Set the value of a specific metric.
140
162
 
141
163
  Args:
142
- metric_name: The name of the metric
164
+ metric_name: The name of the metric.
143
165
  value: The value of the metric.
144
166
  """
145
167
  statement_params = telemetry.get_statement_params(
146
168
  project=_TELEMETRY_PROJECT,
147
169
  subproject=_TELEMETRY_SUBPROJECT,
148
170
  )
149
- metrics = self.list_metrics()
171
+ metrics = self.show_metrics()
150
172
  metrics[metric_name] = value
151
173
  self._model_ops._metadata_ops.save(
152
174
  metadata_ops.ModelVersionMetadataSchema(metrics=metrics),
@@ -166,13 +188,13 @@ class ModelVersion:
166
188
  metric_name: The name of the metric to be deleted.
167
189
 
168
190
  Raises:
169
- KeyError: Raised when the requested metric name does not exist.
191
+ KeyError: When the requested metric name does not exist.
170
192
  """
171
193
  statement_params = telemetry.get_statement_params(
172
194
  project=_TELEMETRY_PROJECT,
173
195
  subproject=_TELEMETRY_SUBPROJECT,
174
196
  )
175
- metrics = self.list_metrics()
197
+ metrics = self.show_metrics()
176
198
  if metric_name not in metrics:
177
199
  raise KeyError(f"Cannot find metric with name {metric_name}.")
178
200
  del metrics[metric_name]
@@ -183,24 +205,12 @@ class ModelVersion:
183
205
  statement_params=statement_params,
184
206
  )
185
207
 
186
- @telemetry.send_api_usage_telemetry(
187
- project=_TELEMETRY_PROJECT,
188
- subproject=_TELEMETRY_SUBPROJECT,
189
- )
190
- def list_methods(self) -> List[model_method_info.ModelMethodInfo]:
191
- """List all method information in a model version that is callable.
192
-
193
- Returns:
194
- A list of ModelMethodInfo object containing the following information:
195
- - name: The name of the method to be called (both in SQL and in Python SDK).
196
- - target_method: The original method name in the logged Python object.
197
- - Signature: Python signature of the original method.
198
- """
208
+ # Only used when the model does not contains user_data with client SDK information.
209
+ def _legacy_show_functions(self) -> List[model_manifest_schema.ModelFunctionInfo]:
199
210
  statement_params = telemetry.get_statement_params(
200
211
  project=_TELEMETRY_PROJECT,
201
212
  subproject=_TELEMETRY_SUBPROJECT,
202
213
  )
203
- # TODO(SNOW-986673, SNOW-986675): Avoid parsing manifest and meta file and put Python signature into user_data.
204
214
  manifest = self._model_ops.get_model_version_manifest(
205
215
  model_name=self._model_name,
206
216
  version_name=self._version_name,
@@ -211,7 +221,7 @@ class ModelVersion:
211
221
  version_name=self._version_name,
212
222
  statement_params=statement_params,
213
223
  )
214
- return_methods_info: List[model_method_info.ModelMethodInfo] = []
224
+ return_functions_info: List[model_manifest_schema.ModelFunctionInfo] = []
215
225
  for method in manifest["methods"]:
216
226
  # Method's name is resolved so we need to use case_sensitive as True to get the user-facing identifier.
217
227
  method_name = sql_identifier.SqlIdentifier(method["name"], case_sensitive=True).identifier()
@@ -221,14 +231,48 @@ class ModelVersion:
221
231
  ), f"Get unexpected handler name {method['handler']}"
222
232
  target_method = method["handler"].split(".")[1]
223
233
  signature_dict = model_meta["signatures"][target_method]
224
- method_info = model_method_info.ModelMethodInfo(
234
+ fi = model_manifest_schema.ModelFunctionInfo(
225
235
  name=method_name,
226
236
  target_method=target_method,
227
237
  signature=model_signature.ModelSignature.from_dict(signature_dict),
228
238
  )
229
- return_methods_info.append(method_info)
239
+ return_functions_info.append(fi)
240
+ return return_functions_info
241
+
242
+ @telemetry.send_api_usage_telemetry(
243
+ project=_TELEMETRY_PROJECT,
244
+ subproject=_TELEMETRY_SUBPROJECT,
245
+ )
246
+ def show_functions(self) -> List[model_manifest_schema.ModelFunctionInfo]:
247
+ """Show all functions information in a model version that is callable.
230
248
 
231
- return return_methods_info
249
+ Returns:
250
+ A list of ModelFunctionInfo objects containing the following information:
251
+
252
+ - name: The name of the function to be called (both in SQL and in Python SDK).
253
+ - target_method: The original method name in the logged Python object.
254
+ - signature: Python signature of the original method.
255
+ """
256
+ statement_params = telemetry.get_statement_params(
257
+ project=_TELEMETRY_PROJECT,
258
+ subproject=_TELEMETRY_SUBPROJECT,
259
+ )
260
+ try:
261
+ client_data = self._model_ops.get_client_data_in_user_data(
262
+ model_name=self._model_name,
263
+ version_name=self._version_name,
264
+ statement_params=statement_params,
265
+ )
266
+ return [
267
+ model_manifest_schema.ModelFunctionInfo(
268
+ name=fi["name"],
269
+ target_method=fi["target_method"],
270
+ signature=model_signature.ModelSignature.from_dict(fi["signature"]),
271
+ )
272
+ for fi in client_data["functions"]
273
+ ]
274
+ except (NotImplementedError, ValueError, connector.DataError):
275
+ return self._legacy_show_functions()
232
276
 
233
277
  @telemetry.send_api_usage_telemetry(
234
278
  project=_TELEMETRY_PROJECT,
@@ -238,52 +282,52 @@ class ModelVersion:
238
282
  self,
239
283
  X: Union[pd.DataFrame, dataframe.DataFrame],
240
284
  *,
241
- method_name: Optional[str] = None,
285
+ function_name: Optional[str] = None,
242
286
  ) -> Union[pd.DataFrame, dataframe.DataFrame]:
243
- """Invoke a method in a model version object
287
+ """Invoke a method in a model version object.
244
288
 
245
289
  Args:
246
- X: The input data. Could be pandas DataFrame or Snowpark DataFrame
247
- method_name: The method name to run. It is the name you will use to call a method in SQL. Defaults to None.
248
- It can only be None if there is only 1 method.
290
+ X: The input data, which could be a pandas DataFrame or Snowpark DataFrame.
291
+ function_name: The function name to run. It is the name used to call a function in SQL.
292
+ Defaults to None. It can only be None if there is only 1 method.
249
293
 
250
294
  Raises:
251
- ValueError: No method with the corresponding name is available.
252
- ValueError: There are more than 1 target methods available in the model but no method name specified.
295
+ ValueError: When no method with the corresponding name is available.
296
+ ValueError: When there are more than 1 target methods available in the model but no function name specified.
253
297
 
254
298
  Returns:
255
- The prediction data.
299
+ The prediction data. It would be the same type dataframe as your input.
256
300
  """
257
301
  statement_params = telemetry.get_statement_params(
258
302
  project=_TELEMETRY_PROJECT,
259
303
  subproject=_TELEMETRY_SUBPROJECT,
260
304
  )
261
305
 
262
- methods: List[model_method_info.ModelMethodInfo] = self.list_methods()
263
- if method_name:
264
- req_method_name = sql_identifier.SqlIdentifier(method_name).identifier()
265
- find_method: Callable[[model_method_info.ModelMethodInfo], bool] = (
306
+ functions: List[model_manifest_schema.ModelFunctionInfo] = self.show_functions()
307
+ if function_name:
308
+ req_method_name = sql_identifier.SqlIdentifier(function_name).identifier()
309
+ find_method: Callable[[model_manifest_schema.ModelFunctionInfo], bool] = (
266
310
  lambda method: method["name"] == req_method_name
267
311
  )
268
- target_method_info = next(
269
- filter(find_method, methods),
312
+ target_function_info = next(
313
+ filter(find_method, functions),
270
314
  None,
271
315
  )
272
- if target_method_info is None:
316
+ if target_function_info is None:
273
317
  raise ValueError(
274
- f"There is no method with name {method_name} available in the model"
318
+ f"There is no method with name {function_name} available in the model"
275
319
  f" {self.fully_qualified_model_name} version {self.version_name}"
276
320
  )
277
- elif len(methods) != 1:
321
+ elif len(functions) != 1:
278
322
  raise ValueError(
279
323
  f"There are more than 1 target methods available in the model {self.fully_qualified_model_name}"
280
324
  f" version {self.version_name}. Please specify a `method_name` when calling the `run` method."
281
325
  )
282
326
  else:
283
- target_method_info = methods[0]
327
+ target_function_info = functions[0]
284
328
  return self._model_ops.invoke_method(
285
- method_name=sql_identifier.SqlIdentifier(target_method_info["name"]),
286
- signature=target_method_info["signature"],
329
+ method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
330
+ signature=target_function_info["signature"],
287
331
  X=X,
288
332
  model_name=self._model_name,
289
333
  version_name=self._version_name,
@@ -68,9 +68,7 @@ class MetadataOperator:
68
68
  version_info_list = self._model_client.show_versions(
69
69
  model_name=model_name, version_name=version_name, statement_params=statement_params
70
70
  )
71
- assert len(version_info_list) == 1
72
- version_info = version_info_list[0]
73
- metadata_str = version_info.metadata
71
+ metadata_str = version_info_list[0][self._model_client.MODEL_VERSION_METADATA_COL_NAME]
74
72
  if not metadata_str:
75
73
  return {}
76
74
  res = json.loads(metadata_str)
@@ -1,16 +1,18 @@
1
+ import json
1
2
  import pathlib
2
3
  import tempfile
3
4
  from typing import Any, Dict, List, Optional, Union, cast
4
5
 
5
6
  import yaml
6
7
 
7
- from snowflake.ml._internal.utils import sql_identifier
8
+ from snowflake.ml._internal.utils import identifier, sql_identifier
8
9
  from snowflake.ml.model import model_signature, type_hints
9
10
  from snowflake.ml.model._client.ops import metadata_ops
10
11
  from snowflake.ml.model._client.sql import (
11
12
  model as model_sql,
12
13
  model_version as model_version_sql,
13
14
  stage as stage_sql,
15
+ tag as tag_sql,
14
16
  )
15
17
  from snowflake.ml.model._model_composer import model_composer
16
18
  from snowflake.ml.model._model_composer.model_manifest import (
@@ -19,7 +21,7 @@ from snowflake.ml.model._model_composer.model_manifest import (
19
21
  )
20
22
  from snowflake.ml.model._packager.model_meta import model_meta, model_meta_schema
21
23
  from snowflake.ml.model._signatures import snowpark_handler
22
- from snowflake.snowpark import dataframe, session
24
+ from snowflake.snowpark import dataframe, row, session
23
25
  from snowflake.snowpark._internal import utils as snowpark_utils
24
26
 
25
27
 
@@ -50,6 +52,11 @@ class ModelOperator:
50
52
  database_name=database_name,
51
53
  schema_name=schema_name,
52
54
  )
55
+ self._tag_client = tag_sql.ModuleTagSQLClient(
56
+ session,
57
+ database_name=database_name,
58
+ schema_name=schema_name,
59
+ )
53
60
  self._metadata_ops = metadata_ops.MetadataOperator(
54
61
  session,
55
62
  database_name=database_name,
@@ -109,22 +116,39 @@ class ModelOperator:
109
116
  statement_params=statement_params,
110
117
  )
111
118
 
112
- def list_models_or_versions(
119
+ def show_models_or_versions(
113
120
  self,
114
121
  *,
115
122
  model_name: Optional[sql_identifier.SqlIdentifier] = None,
116
123
  statement_params: Optional[Dict[str, Any]] = None,
117
- ) -> List[sql_identifier.SqlIdentifier]:
124
+ ) -> List[row.Row]:
118
125
  if model_name:
119
- res = self._model_client.show_versions(
126
+ return self._model_client.show_versions(
120
127
  model_name=model_name,
128
+ validate_result=False,
121
129
  statement_params=statement_params,
122
130
  )
123
131
  else:
124
- res = self._model_client.show_models(
132
+ return self._model_client.show_models(
133
+ validate_result=False,
125
134
  statement_params=statement_params,
126
135
  )
127
- return [sql_identifier.SqlIdentifier(row.name, case_sensitive=True) for row in res]
136
+
137
+ def list_models_or_versions(
138
+ self,
139
+ *,
140
+ model_name: Optional[sql_identifier.SqlIdentifier] = None,
141
+ statement_params: Optional[Dict[str, Any]] = None,
142
+ ) -> List[sql_identifier.SqlIdentifier]:
143
+ res = self.show_models_or_versions(
144
+ model_name=model_name,
145
+ statement_params=statement_params,
146
+ )
147
+ if model_name:
148
+ col_name = self._model_client.MODEL_VERSION_NAME_COL_NAME
149
+ else:
150
+ col_name = self._model_client.MODEL_NAME_COL_NAME
151
+ return [sql_identifier.SqlIdentifier(row[col_name], case_sensitive=True) for row in res]
128
152
 
129
153
  def validate_existence(
130
154
  self,
@@ -137,11 +161,13 @@ class ModelOperator:
137
161
  res = self._model_client.show_versions(
138
162
  model_name=model_name,
139
163
  version_name=version_name,
164
+ validate_result=False,
140
165
  statement_params=statement_params,
141
166
  )
142
167
  else:
143
168
  res = self._model_client.show_models(
144
169
  model_name=model_name,
170
+ validate_result=False,
145
171
  statement_params=statement_params,
146
172
  )
147
173
  return len(res) == 1
@@ -159,13 +185,14 @@ class ModelOperator:
159
185
  version_name=version_name,
160
186
  statement_params=statement_params,
161
187
  )
188
+ col_name = self._model_client.MODEL_VERSION_COMMENT_COL_NAME
162
189
  else:
163
190
  res = self._model_client.show_models(
164
191
  model_name=model_name,
165
192
  statement_params=statement_params,
166
193
  )
167
- assert len(res) == 1
168
- return cast(str, res[0].comment)
194
+ col_name = self._model_client.MODEL_COMMENT_COL_NAME
195
+ return cast(str, res[0][col_name])
169
196
 
170
197
  def set_comment(
171
198
  self,
@@ -189,6 +216,109 @@ class ModelOperator:
189
216
  statement_params=statement_params,
190
217
  )
191
218
 
219
+ def set_default_version(
220
+ self,
221
+ *,
222
+ model_name: sql_identifier.SqlIdentifier,
223
+ version_name: sql_identifier.SqlIdentifier,
224
+ statement_params: Optional[Dict[str, Any]] = None,
225
+ ) -> None:
226
+ if not self.validate_existence(
227
+ model_name=model_name, version_name=version_name, statement_params=statement_params
228
+ ):
229
+ raise ValueError(f"You cannot set version {version_name} as default version as it does not exist.")
230
+ self._model_version_client.set_default_version(
231
+ model_name=model_name, version_name=version_name, statement_params=statement_params
232
+ )
233
+
234
+ def get_default_version(
235
+ self,
236
+ *,
237
+ model_name: sql_identifier.SqlIdentifier,
238
+ statement_params: Optional[Dict[str, Any]] = None,
239
+ ) -> sql_identifier.SqlIdentifier:
240
+ res = self._model_client.show_models(model_name=model_name, statement_params=statement_params)[0]
241
+ return sql_identifier.SqlIdentifier(
242
+ res[self._model_client.MODEL_DEFAULT_VERSION_NAME_COL_NAME], case_sensitive=True
243
+ )
244
+
245
+ def get_tag_value(
246
+ self,
247
+ *,
248
+ model_name: sql_identifier.SqlIdentifier,
249
+ tag_database_name: sql_identifier.SqlIdentifier,
250
+ tag_schema_name: sql_identifier.SqlIdentifier,
251
+ tag_name: sql_identifier.SqlIdentifier,
252
+ statement_params: Optional[Dict[str, Any]] = None,
253
+ ) -> Optional[str]:
254
+ r = self._tag_client.get_tag_value(
255
+ module_name=model_name,
256
+ tag_database_name=tag_database_name,
257
+ tag_schema_name=tag_schema_name,
258
+ tag_name=tag_name,
259
+ statement_params=statement_params,
260
+ )
261
+ value = r.TAG_VALUE
262
+ if value is None:
263
+ return value
264
+ return str(value)
265
+
266
+ def show_tags(
267
+ self,
268
+ *,
269
+ model_name: sql_identifier.SqlIdentifier,
270
+ statement_params: Optional[Dict[str, Any]] = None,
271
+ ) -> Dict[str, str]:
272
+ tags_info = self._tag_client.get_tag_list(
273
+ module_name=model_name,
274
+ statement_params=statement_params,
275
+ )
276
+ res: Dict[str, str] = {
277
+ identifier.get_schema_level_object_identifier(
278
+ sql_identifier.SqlIdentifier(r.TAG_DATABASE, case_sensitive=True),
279
+ sql_identifier.SqlIdentifier(r.TAG_SCHEMA, case_sensitive=True),
280
+ sql_identifier.SqlIdentifier(r.TAG_NAME, case_sensitive=True),
281
+ ): str(r.TAG_VALUE)
282
+ for r in tags_info
283
+ }
284
+ return res
285
+
286
+ def set_tag(
287
+ self,
288
+ *,
289
+ model_name: sql_identifier.SqlIdentifier,
290
+ tag_database_name: sql_identifier.SqlIdentifier,
291
+ tag_schema_name: sql_identifier.SqlIdentifier,
292
+ tag_name: sql_identifier.SqlIdentifier,
293
+ tag_value: str,
294
+ statement_params: Optional[Dict[str, Any]] = None,
295
+ ) -> None:
296
+ self._tag_client.set_tag_on_model(
297
+ model_name=model_name,
298
+ tag_database_name=tag_database_name,
299
+ tag_schema_name=tag_schema_name,
300
+ tag_name=tag_name,
301
+ tag_value=tag_value,
302
+ statement_params=statement_params,
303
+ )
304
+
305
+ def unset_tag(
306
+ self,
307
+ *,
308
+ model_name: sql_identifier.SqlIdentifier,
309
+ tag_database_name: sql_identifier.SqlIdentifier,
310
+ tag_schema_name: sql_identifier.SqlIdentifier,
311
+ tag_name: sql_identifier.SqlIdentifier,
312
+ statement_params: Optional[Dict[str, Any]] = None,
313
+ ) -> None:
314
+ self._tag_client.unset_tag_on_model(
315
+ model_name=model_name,
316
+ tag_database_name=tag_database_name,
317
+ tag_schema_name=tag_schema_name,
318
+ tag_name=tag_name,
319
+ statement_params=statement_params,
320
+ )
321
+
192
322
  def get_model_version_manifest(
193
323
  self,
194
324
  *,
@@ -228,6 +358,22 @@ class ModelOperator:
228
358
  raw_model_meta = yaml.safe_load(f)
229
359
  return model_meta.ModelMetadata._validate_model_metadata(raw_model_meta)
230
360
 
361
+ def get_client_data_in_user_data(
362
+ self,
363
+ *,
364
+ model_name: sql_identifier.SqlIdentifier,
365
+ version_name: sql_identifier.SqlIdentifier,
366
+ statement_params: Optional[Dict[str, Any]] = None,
367
+ ) -> model_manifest_schema.SnowparkMLDataDict:
368
+ raw_user_data_json_string = self._model_client.show_versions(
369
+ model_name=model_name,
370
+ version_name=version_name,
371
+ statement_params=statement_params,
372
+ )[0][self._model_client.MODEL_VERSION_USER_DATA_COL_NAME]
373
+ raw_user_data = json.loads(raw_user_data_json_string)
374
+ assert isinstance(raw_user_data, dict), "user data should be a dictionary"
375
+ return model_manifest.ModelManifest.parse_client_data_from_user_data(raw_user_data)
376
+
231
377
  def invoke_method(
232
378
  self,
233
379
  *,
@@ -1,10 +1,23 @@
1
1
  from typing import Any, Dict, List, Optional
2
2
 
3
- from snowflake.ml._internal.utils import identifier, sql_identifier
3
+ from snowflake.ml._internal.utils import (
4
+ identifier,
5
+ query_result_checker,
6
+ sql_identifier,
7
+ )
4
8
  from snowflake.snowpark import row, session
5
9
 
6
10
 
7
11
  class ModelSQLClient:
12
+ MODEL_NAME_COL_NAME = "name"
13
+ MODEL_COMMENT_COL_NAME = "comment"
14
+ MODEL_DEFAULT_VERSION_NAME_COL_NAME = "default_version_name"
15
+
16
+ MODEL_VERSION_NAME_COL_NAME = "name"
17
+ MODEL_VERSION_COMMENT_COL_NAME = "comment"
18
+ MODEL_VERSION_METADATA_COL_NAME = "metadata"
19
+ MODEL_VERSION_USER_DATA_COL_NAME = "user_data"
20
+
8
21
  def __init__(
9
22
  self,
10
23
  session: session.Session,
@@ -30,29 +43,56 @@ class ModelSQLClient:
30
43
  self,
31
44
  *,
32
45
  model_name: Optional[sql_identifier.SqlIdentifier] = None,
46
+ validate_result: bool = True,
33
47
  statement_params: Optional[Dict[str, Any]] = None,
34
48
  ) -> List[row.Row]:
35
49
  fully_qualified_schema_name = ".".join([self._database_name.identifier(), self._schema_name.identifier()])
36
50
  like_sql = ""
37
51
  if model_name:
38
52
  like_sql = f" LIKE '{model_name.resolved()}'"
39
- res = self._session.sql(f"SHOW MODELS{like_sql} IN SCHEMA {fully_qualified_schema_name}")
40
53
 
41
- return res.collect(statement_params=statement_params)
54
+ res = (
55
+ query_result_checker.SqlResultValidator(
56
+ self._session,
57
+ f"SHOW MODELS{like_sql} IN SCHEMA {fully_qualified_schema_name}",
58
+ statement_params=statement_params,
59
+ )
60
+ .has_column(ModelSQLClient.MODEL_NAME_COL_NAME, allow_empty=True)
61
+ .has_column(ModelSQLClient.MODEL_COMMENT_COL_NAME, allow_empty=True)
62
+ .has_column(ModelSQLClient.MODEL_DEFAULT_VERSION_NAME_COL_NAME, allow_empty=True)
63
+ )
64
+ if validate_result and model_name:
65
+ res = res.has_dimensions(expected_rows=1)
66
+
67
+ return res.validate()
42
68
 
43
69
  def show_versions(
44
70
  self,
45
71
  *,
46
72
  model_name: sql_identifier.SqlIdentifier,
47
73
  version_name: Optional[sql_identifier.SqlIdentifier] = None,
74
+ validate_result: bool = True,
48
75
  statement_params: Optional[Dict[str, Any]] = None,
49
76
  ) -> List[row.Row]:
50
77
  like_sql = ""
51
78
  if version_name:
52
79
  like_sql = f" LIKE '{version_name.resolved()}'"
53
- res = self._session.sql(f"SHOW VERSIONS{like_sql} IN MODEL {self.fully_qualified_model_name(model_name)}")
54
80
 
55
- return res.collect(statement_params=statement_params)
81
+ res = (
82
+ query_result_checker.SqlResultValidator(
83
+ self._session,
84
+ f"SHOW VERSIONS{like_sql} IN MODEL {self.fully_qualified_model_name(model_name)}",
85
+ statement_params=statement_params,
86
+ )
87
+ .has_column(ModelSQLClient.MODEL_VERSION_NAME_COL_NAME, allow_empty=True)
88
+ .has_column(ModelSQLClient.MODEL_VERSION_COMMENT_COL_NAME, allow_empty=True)
89
+ .has_column(ModelSQLClient.MODEL_VERSION_METADATA_COL_NAME, allow_empty=True)
90
+ .has_column(ModelSQLClient.MODEL_VERSION_USER_DATA_COL_NAME, allow_empty=True)
91
+ )
92
+ if validate_result and version_name:
93
+ res = res.has_dimensions(expected_rows=1)
94
+
95
+ return res.validate()
56
96
 
57
97
  def set_comment(
58
98
  self,
@@ -61,8 +101,11 @@ class ModelSQLClient:
61
101
  model_name: sql_identifier.SqlIdentifier,
62
102
  statement_params: Optional[Dict[str, Any]] = None,
63
103
  ) -> None:
64
- comment_sql = f"COMMENT ON MODEL {self.fully_qualified_model_name(model_name)} IS $${comment}$$"
65
- self._session.sql(comment_sql).collect(statement_params=statement_params)
104
+ query_result_checker.SqlResultValidator(
105
+ self._session,
106
+ f"COMMENT ON MODEL {self.fully_qualified_model_name(model_name)} IS $${comment}$$",
107
+ statement_params=statement_params,
108
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
66
109
 
67
110
  def drop_model(
68
111
  self,
@@ -70,6 +113,8 @@ class ModelSQLClient:
70
113
  model_name: sql_identifier.SqlIdentifier,
71
114
  statement_params: Optional[Dict[str, Any]] = None,
72
115
  ) -> None:
73
- self._session.sql(f"DROP MODEL {self.fully_qualified_model_name(model_name)}").collect(
74
- statement_params=statement_params
75
- )
116
+ query_result_checker.SqlResultValidator(
117
+ self._session,
118
+ f"DROP MODEL {self.fully_qualified_model_name(model_name)}",
119
+ statement_params=statement_params,
120
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()