snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (218) hide show
  1. snowflake/ml/_internal/env_utils.py +72 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
  7. snowflake/ml/_internal/telemetry.py +1 -0
  8. snowflake/ml/_internal/utils/identifier.py +1 -1
  9. snowflake/ml/_internal/utils/sql_identifier.py +14 -1
  10. snowflake/ml/dataset/__init__.py +11 -0
  11. snowflake/ml/dataset/dataset.py +455 -129
  12. snowflake/ml/dataset/dataset_factory.py +53 -0
  13. snowflake/ml/dataset/dataset_metadata.py +103 -0
  14. snowflake/ml/dataset/dataset_reader.py +199 -0
  15. snowflake/ml/feature_store/__init__.py +6 -0
  16. snowflake/ml/feature_store/access_manager.py +279 -0
  17. snowflake/ml/feature_store/feature_store.py +544 -358
  18. snowflake/ml/feature_store/feature_view.py +55 -16
  19. snowflake/ml/fileset/embedded_stage_fs.py +149 -0
  20. snowflake/ml/fileset/sfcfs.py +0 -4
  21. snowflake/ml/fileset/snowfs.py +160 -0
  22. snowflake/ml/fileset/stage_fs.py +25 -10
  23. snowflake/ml/model/__init__.py +2 -2
  24. snowflake/ml/model/_api.py +16 -1
  25. snowflake/ml/model/_client/model/model_impl.py +65 -31
  26. snowflake/ml/model/_client/model/model_version_impl.py +159 -2
  27. snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
  28. snowflake/ml/model/_client/ops/model_ops.py +268 -83
  29. snowflake/ml/model/_client/sql/_base.py +34 -0
  30. snowflake/ml/model/_client/sql/model.py +42 -47
  31. snowflake/ml/model/_client/sql/model_version.py +164 -39
  32. snowflake/ml/model/_client/sql/stage.py +6 -32
  33. snowflake/ml/model/_client/sql/tag.py +32 -56
  34. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  35. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  36. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  37. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  38. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  39. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  40. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  41. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  42. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  43. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  45. snowflake/ml/model/_packager/model_packager.py +0 -3
  46. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  47. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  48. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  49. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  50. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
  51. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  52. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
  53. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  54. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  55. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  56. snowflake/ml/modeling/cluster/birch.py +53 -52
  57. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  58. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  59. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  60. snowflake/ml/modeling/cluster/k_means.py +53 -52
  61. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  62. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  63. snowflake/ml/modeling/cluster/optics.py +51 -52
  64. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  65. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  66. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  67. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  68. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  69. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  70. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  71. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  72. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  73. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  74. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  75. snowflake/ml/modeling/covariance/oas.py +51 -52
  76. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  77. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  78. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  79. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  80. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  81. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  82. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  83. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  84. snowflake/ml/modeling/decomposition/pca.py +53 -52
  85. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  86. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  87. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  88. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  89. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  92. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  93. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  94. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  95. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  96. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  97. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  98. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  99. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  100. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  101. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  102. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  103. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  104. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  105. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  106. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  107. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  108. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  109. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  110. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  111. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  112. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  113. snowflake/ml/modeling/framework/base.py +64 -36
  114. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  115. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  116. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  117. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  118. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  119. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  120. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  121. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  122. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  123. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  124. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  125. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  126. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  127. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  128. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  129. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  130. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  131. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  132. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  133. snowflake/ml/modeling/linear_model/lars.py +51 -52
  134. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  135. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  136. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  137. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  138. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  139. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  140. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  141. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  142. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  143. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  144. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  146. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  147. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  148. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  149. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  151. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  152. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  153. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  154. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  155. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  156. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  157. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  158. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  159. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  160. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  161. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  162. snowflake/ml/modeling/manifold/isomap.py +53 -52
  163. snowflake/ml/modeling/manifold/mds.py +53 -52
  164. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  165. snowflake/ml/modeling/manifold/tsne.py +53 -52
  166. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  167. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  168. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  169. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  170. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  171. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  172. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  173. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  174. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  175. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  176. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  177. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  178. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  179. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  180. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  181. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  182. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  183. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  184. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  185. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  186. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  187. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  188. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  189. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  190. snowflake/ml/modeling/pipeline/pipeline.py +538 -36
  191. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  192. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  193. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  194. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  195. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  196. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  197. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  198. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  199. snowflake/ml/modeling/svm/svc.py +51 -52
  200. snowflake/ml/modeling/svm/svr.py +51 -52
  201. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  202. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  203. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  204. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  205. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  206. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  207. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  208. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  209. snowflake/ml/registry/_manager/model_manager.py +36 -7
  210. snowflake/ml/registry/model_registry.py +3 -149
  211. snowflake/ml/version.py +1 -1
  212. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
  213. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
  214. snowflake/ml/registry/_artifact_manager.py +0 -156
  215. snowflake/ml/registry/artifact.py +0 -46
  216. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
  217. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
  218. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,7 @@
1
+ import os
1
2
  import pathlib
2
3
  import tempfile
3
- from contextlib import contextmanager
4
- from typing import Any, Dict, Generator, List, Optional, Union, cast
4
+ from typing import Any, Dict, List, Literal, Optional, Union, cast
5
5
 
6
6
  import yaml
7
7
 
@@ -19,7 +19,9 @@ from snowflake.ml.model._model_composer.model_manifest import (
19
19
  model_manifest,
20
20
  model_manifest_schema,
21
21
  )
22
+ from snowflake.ml.model._packager.model_env import model_env
22
23
  from snowflake.ml.model._packager.model_meta import model_meta
24
+ from snowflake.ml.model._packager.model_runtime import model_runtime
23
25
  from snowflake.ml.model._signatures import snowpark_handler
24
26
  from snowflake.snowpark import dataframe, row, session
25
27
  from snowflake.snowpark._internal import utils as snowpark_utils
@@ -72,37 +74,57 @@ class ModelOperator:
72
74
  and self._model_version_client == __value._model_version_client
73
75
  )
74
76
 
75
- def prepare_model_stage_path(self, *, statement_params: Optional[Dict[str, Any]] = None) -> str:
77
+ def prepare_model_stage_path(
78
+ self,
79
+ *,
80
+ database_name: Optional[sql_identifier.SqlIdentifier],
81
+ schema_name: Optional[sql_identifier.SqlIdentifier],
82
+ statement_params: Optional[Dict[str, Any]] = None,
83
+ ) -> str:
76
84
  stage_name = sql_identifier.SqlIdentifier(
77
85
  snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
78
86
  )
79
- self._stage_client.create_tmp_stage(stage_name=stage_name, statement_params=statement_params)
80
- return f"@{self._stage_client.fully_qualified_stage_name(stage_name)}/model"
87
+ self._stage_client.create_tmp_stage(
88
+ database_name=database_name,
89
+ schema_name=schema_name,
90
+ stage_name=stage_name,
91
+ statement_params=statement_params,
92
+ )
93
+ return f"@{self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)}/model"
81
94
 
82
95
  def create_from_stage(
83
96
  self,
84
97
  composed_model: model_composer.ModelComposer,
85
98
  *,
99
+ database_name: Optional[sql_identifier.SqlIdentifier],
100
+ schema_name: Optional[sql_identifier.SqlIdentifier],
86
101
  model_name: sql_identifier.SqlIdentifier,
87
102
  version_name: sql_identifier.SqlIdentifier,
88
103
  statement_params: Optional[Dict[str, Any]] = None,
89
104
  ) -> None:
90
105
  stage_path = str(composed_model.stage_path)
91
106
  if self.validate_existence(
107
+ database_name=database_name,
108
+ schema_name=schema_name,
92
109
  model_name=model_name,
93
110
  statement_params=statement_params,
94
111
  ):
95
112
  if self.validate_existence(
113
+ database_name=database_name,
114
+ schema_name=schema_name,
96
115
  model_name=model_name,
97
116
  version_name=version_name,
98
117
  statement_params=statement_params,
99
118
  ):
100
119
  raise ValueError(
101
- f"Model {self._model_version_client.fully_qualified_model_name(model_name)} "
102
- f"version {version_name} already existed."
120
+ "Model "
121
+ f"{self._model_version_client.fully_qualified_object_name(database_name, schema_name, model_name)}"
122
+ f" version {version_name} already existed."
103
123
  )
104
124
  else:
105
125
  self._model_version_client.add_version_from_stage(
126
+ database_name=database_name,
127
+ schema_name=schema_name,
106
128
  stage_path=stage_path,
107
129
  model_name=model_name,
108
130
  version_name=version_name,
@@ -110,6 +132,8 @@ class ModelOperator:
110
132
  )
111
133
  else:
112
134
  self._model_version_client.create_from_stage(
135
+ database_name=database_name,
136
+ schema_name=schema_name,
113
137
  stage_path=stage_path,
114
138
  model_name=model_name,
115
139
  version_name=version_name,
@@ -119,17 +143,23 @@ class ModelOperator:
119
143
  def show_models_or_versions(
120
144
  self,
121
145
  *,
146
+ database_name: Optional[sql_identifier.SqlIdentifier],
147
+ schema_name: Optional[sql_identifier.SqlIdentifier],
122
148
  model_name: Optional[sql_identifier.SqlIdentifier] = None,
123
149
  statement_params: Optional[Dict[str, Any]] = None,
124
150
  ) -> List[row.Row]:
125
151
  if model_name:
126
152
  return self._model_client.show_versions(
153
+ database_name=database_name,
154
+ schema_name=schema_name,
127
155
  model_name=model_name,
128
156
  validate_result=False,
129
157
  statement_params=statement_params,
130
158
  )
131
159
  else:
132
160
  return self._model_client.show_models(
161
+ database_name=database_name,
162
+ schema_name=schema_name,
133
163
  validate_result=False,
134
164
  statement_params=statement_params,
135
165
  )
@@ -137,10 +167,14 @@ class ModelOperator:
137
167
  def list_models_or_versions(
138
168
  self,
139
169
  *,
170
+ database_name: Optional[sql_identifier.SqlIdentifier],
171
+ schema_name: Optional[sql_identifier.SqlIdentifier],
140
172
  model_name: Optional[sql_identifier.SqlIdentifier] = None,
141
173
  statement_params: Optional[Dict[str, Any]] = None,
142
174
  ) -> List[sql_identifier.SqlIdentifier]:
143
175
  res = self.show_models_or_versions(
176
+ database_name=database_name,
177
+ schema_name=schema_name,
144
178
  model_name=model_name,
145
179
  statement_params=statement_params,
146
180
  )
@@ -153,12 +187,16 @@ class ModelOperator:
153
187
  def validate_existence(
154
188
  self,
155
189
  *,
190
+ database_name: Optional[sql_identifier.SqlIdentifier],
191
+ schema_name: Optional[sql_identifier.SqlIdentifier],
156
192
  model_name: sql_identifier.SqlIdentifier,
157
193
  version_name: Optional[sql_identifier.SqlIdentifier] = None,
158
194
  statement_params: Optional[Dict[str, Any]] = None,
159
195
  ) -> bool:
160
196
  if version_name:
161
197
  res = self._model_client.show_versions(
198
+ database_name=database_name,
199
+ schema_name=schema_name,
162
200
  model_name=model_name,
163
201
  version_name=version_name,
164
202
  validate_result=False,
@@ -166,6 +204,8 @@ class ModelOperator:
166
204
  )
167
205
  else:
168
206
  res = self._model_client.show_models(
207
+ database_name=database_name,
208
+ schema_name=schema_name,
169
209
  model_name=model_name,
170
210
  validate_result=False,
171
211
  statement_params=statement_params,
@@ -175,12 +215,16 @@ class ModelOperator:
175
215
  def get_comment(
176
216
  self,
177
217
  *,
218
+ database_name: Optional[sql_identifier.SqlIdentifier],
219
+ schema_name: Optional[sql_identifier.SqlIdentifier],
178
220
  model_name: sql_identifier.SqlIdentifier,
179
221
  version_name: Optional[sql_identifier.SqlIdentifier] = None,
180
222
  statement_params: Optional[Dict[str, Any]] = None,
181
223
  ) -> str:
182
224
  if version_name:
183
225
  res = self._model_client.show_versions(
226
+ database_name=database_name,
227
+ schema_name=schema_name,
184
228
  model_name=model_name,
185
229
  version_name=version_name,
186
230
  statement_params=statement_params,
@@ -188,6 +232,8 @@ class ModelOperator:
188
232
  col_name = self._model_client.MODEL_VERSION_COMMENT_COL_NAME
189
233
  else:
190
234
  res = self._model_client.show_models(
235
+ database_name=database_name,
236
+ schema_name=schema_name,
191
237
  model_name=model_name,
192
238
  statement_params=statement_params,
193
239
  )
@@ -198,6 +244,8 @@ class ModelOperator:
198
244
  self,
199
245
  *,
200
246
  comment: str,
247
+ database_name: Optional[sql_identifier.SqlIdentifier],
248
+ schema_name: Optional[sql_identifier.SqlIdentifier],
201
249
  model_name: sql_identifier.SqlIdentifier,
202
250
  version_name: Optional[sql_identifier.SqlIdentifier] = None,
203
251
  statement_params: Optional[Dict[str, Any]] = None,
@@ -205,6 +253,8 @@ class ModelOperator:
205
253
  if version_name:
206
254
  self._model_version_client.set_comment(
207
255
  comment=comment,
256
+ database_name=database_name,
257
+ schema_name=schema_name,
208
258
  model_name=model_name,
209
259
  version_name=version_name,
210
260
  statement_params=statement_params,
@@ -212,6 +262,8 @@ class ModelOperator:
212
262
  else:
213
263
  self._model_client.set_comment(
214
264
  comment=comment,
265
+ database_name=database_name,
266
+ schema_name=schema_name,
215
267
  model_name=model_name,
216
268
  statement_params=statement_params,
217
269
  )
@@ -219,25 +271,42 @@ class ModelOperator:
219
271
  def set_default_version(
220
272
  self,
221
273
  *,
274
+ database_name: Optional[sql_identifier.SqlIdentifier],
275
+ schema_name: Optional[sql_identifier.SqlIdentifier],
222
276
  model_name: sql_identifier.SqlIdentifier,
223
277
  version_name: sql_identifier.SqlIdentifier,
224
278
  statement_params: Optional[Dict[str, Any]] = None,
225
279
  ) -> None:
226
280
  if not self.validate_existence(
227
- model_name=model_name, version_name=version_name, statement_params=statement_params
281
+ database_name=database_name,
282
+ schema_name=schema_name,
283
+ model_name=model_name,
284
+ version_name=version_name,
285
+ statement_params=statement_params,
228
286
  ):
229
287
  raise ValueError(f"You cannot set version {version_name} as default version as it does not exist.")
230
288
  self._model_version_client.set_default_version(
231
- model_name=model_name, version_name=version_name, statement_params=statement_params
289
+ database_name=database_name,
290
+ schema_name=schema_name,
291
+ model_name=model_name,
292
+ version_name=version_name,
293
+ statement_params=statement_params,
232
294
  )
233
295
 
234
296
  def get_default_version(
235
297
  self,
236
298
  *,
299
+ database_name: Optional[sql_identifier.SqlIdentifier],
300
+ schema_name: Optional[sql_identifier.SqlIdentifier],
237
301
  model_name: sql_identifier.SqlIdentifier,
238
302
  statement_params: Optional[Dict[str, Any]] = None,
239
303
  ) -> sql_identifier.SqlIdentifier:
240
- res = self._model_client.show_models(model_name=model_name, statement_params=statement_params)[0]
304
+ res = self._model_client.show_models(
305
+ database_name=database_name,
306
+ schema_name=schema_name,
307
+ model_name=model_name,
308
+ statement_params=statement_params,
309
+ )[0]
241
310
  return sql_identifier.SqlIdentifier(
242
311
  res[self._model_client.MODEL_DEFAULT_VERSION_NAME_COL_NAME], case_sensitive=True
243
312
  )
@@ -245,14 +314,18 @@ class ModelOperator:
245
314
  def get_tag_value(
246
315
  self,
247
316
  *,
317
+ database_name: Optional[sql_identifier.SqlIdentifier],
318
+ schema_name: Optional[sql_identifier.SqlIdentifier],
248
319
  model_name: sql_identifier.SqlIdentifier,
249
- tag_database_name: sql_identifier.SqlIdentifier,
250
- tag_schema_name: sql_identifier.SqlIdentifier,
320
+ tag_database_name: Optional[sql_identifier.SqlIdentifier],
321
+ tag_schema_name: Optional[sql_identifier.SqlIdentifier],
251
322
  tag_name: sql_identifier.SqlIdentifier,
252
323
  statement_params: Optional[Dict[str, Any]] = None,
253
324
  ) -> Optional[str]:
254
325
  r = self._tag_client.get_tag_value(
255
- module_name=model_name,
326
+ database_name=database_name,
327
+ schema_name=schema_name,
328
+ model_name=model_name,
256
329
  tag_database_name=tag_database_name,
257
330
  tag_schema_name=tag_schema_name,
258
331
  tag_name=tag_name,
@@ -266,11 +339,15 @@ class ModelOperator:
266
339
  def show_tags(
267
340
  self,
268
341
  *,
342
+ database_name: Optional[sql_identifier.SqlIdentifier],
343
+ schema_name: Optional[sql_identifier.SqlIdentifier],
269
344
  model_name: sql_identifier.SqlIdentifier,
270
345
  statement_params: Optional[Dict[str, Any]] = None,
271
346
  ) -> Dict[str, str]:
272
347
  tags_info = self._tag_client.get_tag_list(
273
- module_name=model_name,
348
+ database_name=database_name,
349
+ schema_name=schema_name,
350
+ model_name=model_name,
274
351
  statement_params=statement_params,
275
352
  )
276
353
  res: Dict[str, str] = {
@@ -286,14 +363,18 @@ class ModelOperator:
286
363
  def set_tag(
287
364
  self,
288
365
  *,
366
+ database_name: Optional[sql_identifier.SqlIdentifier],
367
+ schema_name: Optional[sql_identifier.SqlIdentifier],
289
368
  model_name: sql_identifier.SqlIdentifier,
290
- tag_database_name: sql_identifier.SqlIdentifier,
291
- tag_schema_name: sql_identifier.SqlIdentifier,
369
+ tag_database_name: Optional[sql_identifier.SqlIdentifier],
370
+ tag_schema_name: Optional[sql_identifier.SqlIdentifier],
292
371
  tag_name: sql_identifier.SqlIdentifier,
293
372
  tag_value: str,
294
373
  statement_params: Optional[Dict[str, Any]] = None,
295
374
  ) -> None:
296
375
  self._tag_client.set_tag_on_model(
376
+ database_name=database_name,
377
+ schema_name=schema_name,
297
378
  model_name=model_name,
298
379
  tag_database_name=tag_database_name,
299
380
  tag_schema_name=tag_schema_name,
@@ -305,13 +386,17 @@ class ModelOperator:
305
386
  def unset_tag(
306
387
  self,
307
388
  *,
389
+ database_name: Optional[sql_identifier.SqlIdentifier],
390
+ schema_name: Optional[sql_identifier.SqlIdentifier],
308
391
  model_name: sql_identifier.SqlIdentifier,
309
- tag_database_name: sql_identifier.SqlIdentifier,
310
- tag_schema_name: sql_identifier.SqlIdentifier,
392
+ tag_database_name: Optional[sql_identifier.SqlIdentifier],
393
+ tag_schema_name: Optional[sql_identifier.SqlIdentifier],
311
394
  tag_name: sql_identifier.SqlIdentifier,
312
395
  statement_params: Optional[Dict[str, Any]] = None,
313
396
  ) -> None:
314
397
  self._tag_client.unset_tag_on_model(
398
+ database_name=database_name,
399
+ schema_name=schema_name,
315
400
  model_name=model_name,
316
401
  tag_database_name=tag_database_name,
317
402
  tag_schema_name=tag_schema_name,
@@ -322,12 +407,16 @@ class ModelOperator:
322
407
  def get_model_version_manifest(
323
408
  self,
324
409
  *,
410
+ database_name: Optional[sql_identifier.SqlIdentifier],
411
+ schema_name: Optional[sql_identifier.SqlIdentifier],
325
412
  model_name: sql_identifier.SqlIdentifier,
326
413
  version_name: sql_identifier.SqlIdentifier,
327
414
  statement_params: Optional[Dict[str, Any]] = None,
328
415
  ) -> model_manifest_schema.ModelManifestDict:
329
416
  with tempfile.TemporaryDirectory() as tmpdir:
330
417
  self._model_version_client.get_file(
418
+ database_name=database_name,
419
+ schema_name=schema_name,
331
420
  model_name=model_name,
332
421
  version_name=version_name,
333
422
  file_path=pathlib.PurePosixPath(model_manifest.ModelManifest.MANIFEST_FILE_REL_PATH),
@@ -337,16 +426,6 @@ class ModelOperator:
337
426
  mm = model_manifest.ModelManifest(pathlib.Path(tmpdir))
338
427
  return mm.load()
339
428
 
340
- @contextmanager
341
- def _enable_model_details(
342
- self,
343
- *,
344
- statement_params: Optional[Dict[str, Any]] = None,
345
- ) -> Generator[None, None, None]:
346
- self._model_client.config_model_details(enable=True, statement_params=statement_params)
347
- yield
348
- self._model_client.config_model_details(enable=False, statement_params=statement_params)
349
-
350
429
  @staticmethod
351
430
  def _match_model_spec_with_sql_functions(
352
431
  sql_functions_names: List[sql_identifier.SqlIdentifier], target_methods: List[str]
@@ -370,68 +449,75 @@ class ModelOperator:
370
449
  def get_functions(
371
450
  self,
372
451
  *,
452
+ database_name: Optional[sql_identifier.SqlIdentifier],
453
+ schema_name: Optional[sql_identifier.SqlIdentifier],
373
454
  model_name: sql_identifier.SqlIdentifier,
374
455
  version_name: sql_identifier.SqlIdentifier,
375
456
  statement_params: Optional[Dict[str, Any]] = None,
376
457
  ) -> List[model_manifest_schema.ModelFunctionInfo]:
377
- with self._enable_model_details(statement_params=statement_params):
378
- raw_model_spec_res = self._model_client.show_versions(
379
- model_name=model_name,
380
- version_name=version_name,
381
- check_model_details=True,
382
- statement_params=statement_params,
383
- )[0][self._model_client.MODEL_VERSION_MODEL_SPEC_COL_NAME]
384
- model_spec_dict = yaml.safe_load(raw_model_spec_res)
385
- model_spec = model_meta.ModelMetadata._validate_model_metadata(model_spec_dict)
386
- show_functions_res = self._model_version_client.show_functions(
387
- model_name=model_name,
388
- version_name=version_name,
389
- statement_params=statement_params,
458
+ raw_model_spec_res = self._model_client.show_versions(
459
+ database_name=database_name,
460
+ schema_name=schema_name,
461
+ model_name=model_name,
462
+ version_name=version_name,
463
+ check_model_details=True,
464
+ statement_params={**(statement_params or {}), "SHOW_MODEL_DETAILS_IN_SHOW_VERSIONS_IN_MODEL": True},
465
+ )[0][self._model_client.MODEL_VERSION_MODEL_SPEC_COL_NAME]
466
+ model_spec_dict = yaml.safe_load(raw_model_spec_res)
467
+ model_spec = model_meta.ModelMetadata._validate_model_metadata(model_spec_dict)
468
+ show_functions_res = self._model_version_client.show_functions(
469
+ database_name=database_name,
470
+ schema_name=schema_name,
471
+ model_name=model_name,
472
+ version_name=version_name,
473
+ statement_params=statement_params,
474
+ )
475
+ function_names_and_types = []
476
+ for r in show_functions_res:
477
+ function_name = sql_identifier.SqlIdentifier(
478
+ r[self._model_version_client.FUNCTION_NAME_COL_NAME], case_sensitive=True
390
479
  )
391
- function_names_and_types = []
392
- for r in show_functions_res:
393
- function_name = sql_identifier.SqlIdentifier(
394
- r[self._model_version_client.FUNCTION_NAME_COL_NAME], case_sensitive=True
395
- )
396
480
 
397
- function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
398
- try:
399
- return_type = r[self._model_version_client.FUNCTION_RETURN_TYPE_COL_NAME]
400
- except KeyError:
401
- pass
402
- else:
403
- if "TABLE" in return_type:
404
- function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
405
-
406
- function_names_and_types.append((function_name, function_type))
407
-
408
- signatures = model_spec["signatures"]
409
- function_names = [name for name, _ in function_names_and_types]
410
- function_name_mapping = ModelOperator._match_model_spec_with_sql_functions(
411
- function_names, list(signatures.keys())
412
- )
481
+ function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
482
+ try:
483
+ return_type = r[self._model_version_client.FUNCTION_RETURN_TYPE_COL_NAME]
484
+ except KeyError:
485
+ pass
486
+ else:
487
+ if "TABLE" in return_type:
488
+ function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
413
489
 
414
- return [
415
- model_manifest_schema.ModelFunctionInfo(
416
- name=function_name.identifier(),
417
- target_method=function_name_mapping[function_name],
418
- target_method_function_type=function_type,
419
- signature=model_signature.ModelSignature.from_dict(
420
- signatures[function_name_mapping[function_name]]
421
- ),
422
- )
423
- for function_name, function_type in function_names_and_types
424
- ]
490
+ function_names_and_types.append((function_name, function_type))
491
+
492
+ signatures = model_spec["signatures"]
493
+ function_names = [name for name, _ in function_names_and_types]
494
+ function_name_mapping = ModelOperator._match_model_spec_with_sql_functions(
495
+ function_names, list(signatures.keys())
496
+ )
497
+
498
+ return [
499
+ model_manifest_schema.ModelFunctionInfo(
500
+ name=function_name.identifier(),
501
+ target_method=function_name_mapping[function_name],
502
+ target_method_function_type=function_type,
503
+ signature=model_signature.ModelSignature.from_dict(signatures[function_name_mapping[function_name]]),
504
+ )
505
+ for function_name, function_type in function_names_and_types
506
+ ]
425
507
 
426
508
  def invoke_method(
427
509
  self,
428
510
  *,
429
511
  method_name: sql_identifier.SqlIdentifier,
512
+ method_function_type: str,
430
513
  signature: model_signature.ModelSignature,
431
514
  X: Union[type_hints.SupportedDataType, dataframe.DataFrame],
515
+ database_name: Optional[sql_identifier.SqlIdentifier],
516
+ schema_name: Optional[sql_identifier.SqlIdentifier],
432
517
  model_name: sql_identifier.SqlIdentifier,
433
518
  version_name: sql_identifier.SqlIdentifier,
434
519
  strict_input_validation: bool = False,
520
+ partition_column: Optional[sql_identifier.SqlIdentifier] = None,
435
521
  statement_params: Optional[Dict[str, str]] = None,
436
522
  ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
437
523
  identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED
@@ -469,15 +555,31 @@ class ModelOperator:
469
555
  if output_name in original_cols:
470
556
  original_cols.remove(output_name)
471
557
 
472
- df_res = self._model_version_client.invoke_method(
473
- method_name=method_name,
474
- input_df=s_df,
475
- input_args=input_args,
476
- returns=returns,
477
- model_name=model_name,
478
- version_name=version_name,
479
- statement_params=statement_params,
480
- )
558
+ if method_function_type == model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value:
559
+ df_res = self._model_version_client.invoke_function_method(
560
+ method_name=method_name,
561
+ input_df=s_df,
562
+ input_args=input_args,
563
+ returns=returns,
564
+ database_name=database_name,
565
+ schema_name=schema_name,
566
+ model_name=model_name,
567
+ version_name=version_name,
568
+ statement_params=statement_params,
569
+ )
570
+ elif method_function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
571
+ df_res = self._model_version_client.invoke_table_function_method(
572
+ method_name=method_name,
573
+ input_df=s_df,
574
+ input_args=input_args,
575
+ partition_column=partition_column,
576
+ returns=returns,
577
+ database_name=database_name,
578
+ schema_name=schema_name,
579
+ model_name=model_name,
580
+ version_name=version_name,
581
+ statement_params=statement_params,
582
+ )
481
583
 
482
584
  if keep_order:
483
585
  df_res = df_res.sort(
@@ -486,7 +588,11 @@ class ModelOperator:
486
588
  )
487
589
 
488
590
  if not output_with_input_features:
489
- df_res = df_res.drop(*original_cols)
591
+ cols_to_drop = original_cols
592
+ if partition_column is not None:
593
+ # don't drop partition column
594
+ cols_to_drop.remove(partition_column.identifier())
595
+ df_res = df_res.drop(*cols_to_drop)
490
596
 
491
597
  # Get final result
492
598
  if not isinstance(X, dataframe.DataFrame):
@@ -497,18 +603,97 @@ class ModelOperator:
497
603
  def delete_model_or_version(
498
604
  self,
499
605
  *,
606
+ database_name: Optional[sql_identifier.SqlIdentifier],
607
+ schema_name: Optional[sql_identifier.SqlIdentifier],
500
608
  model_name: sql_identifier.SqlIdentifier,
501
609
  version_name: Optional[sql_identifier.SqlIdentifier] = None,
502
610
  statement_params: Optional[Dict[str, Any]] = None,
503
611
  ) -> None:
504
612
  if version_name:
505
613
  self._model_version_client.drop_version(
614
+ database_name=database_name,
615
+ schema_name=schema_name,
506
616
  model_name=model_name,
507
617
  version_name=version_name,
508
618
  statement_params=statement_params,
509
619
  )
510
620
  else:
511
621
  self._model_client.drop_model(
622
+ database_name=database_name,
623
+ schema_name=schema_name,
624
+ model_name=model_name,
625
+ statement_params=statement_params,
626
+ )
627
+
628
+ def rename(
629
+ self,
630
+ *,
631
+ database_name: Optional[sql_identifier.SqlIdentifier],
632
+ schema_name: Optional[sql_identifier.SqlIdentifier],
633
+ model_name: sql_identifier.SqlIdentifier,
634
+ new_model_db: Optional[sql_identifier.SqlIdentifier],
635
+ new_model_schema: Optional[sql_identifier.SqlIdentifier],
636
+ new_model_name: sql_identifier.SqlIdentifier,
637
+ statement_params: Optional[Dict[str, Any]] = None,
638
+ ) -> None:
639
+ self._model_client.rename(
640
+ database_name=database_name,
641
+ schema_name=schema_name,
642
+ model_name=model_name,
643
+ new_model_db=new_model_db,
644
+ new_model_schema=new_model_schema,
645
+ new_model_name=new_model_name,
646
+ statement_params=statement_params,
647
+ )
648
+
649
+ # Map indicating in different modes, the path to list and download.
650
+ # The boolean value indicates if it is a directory,
651
+ MODEL_FILE_DOWNLOAD_PATTERN = {
652
+ "minimal": {
653
+ pathlib.PurePosixPath(model_composer.ModelComposer.MODEL_DIR_REL_PATH)
654
+ / model_meta.MODEL_METADATA_FILE: False,
655
+ pathlib.PurePosixPath(model_composer.ModelComposer.MODEL_DIR_REL_PATH) / model_env._DEFAULT_ENV_DIR: True,
656
+ pathlib.PurePosixPath(model_composer.ModelComposer.MODEL_DIR_REL_PATH)
657
+ / model_runtime.ModelRuntime.RUNTIME_DIR_REL_PATH: True,
658
+ },
659
+ "model": {pathlib.PurePosixPath(model_composer.ModelComposer.MODEL_DIR_REL_PATH): True},
660
+ "full": {pathlib.PurePosixPath(os.curdir): True},
661
+ }
662
+
663
+ def download_files(
664
+ self,
665
+ *,
666
+ database_name: Optional[sql_identifier.SqlIdentifier],
667
+ schema_name: Optional[sql_identifier.SqlIdentifier],
668
+ model_name: sql_identifier.SqlIdentifier,
669
+ version_name: sql_identifier.SqlIdentifier,
670
+ target_path: pathlib.Path,
671
+ mode: Literal["full", "model", "minimal"] = "model",
672
+ statement_params: Optional[Dict[str, Any]] = None,
673
+ ) -> None:
674
+ for remote_rel_path, is_dir in self.MODEL_FILE_DOWNLOAD_PATTERN[mode].items():
675
+ list_file_res = self._model_version_client.list_file(
676
+ database_name=database_name,
677
+ schema_name=schema_name,
512
678
  model_name=model_name,
679
+ version_name=version_name,
680
+ file_path=remote_rel_path,
681
+ is_dir=is_dir,
513
682
  statement_params=statement_params,
514
683
  )
684
+ file_list = [
685
+ pathlib.PurePosixPath(*pathlib.PurePosixPath(row.name).parts[2:]) # versions/<version_name>/...
686
+ for row in list_file_res
687
+ ]
688
+ for stage_file_path in file_list:
689
+ local_file_dir = target_path / stage_file_path.parent
690
+ local_file_dir.mkdir(parents=True, exist_ok=True)
691
+ self._model_version_client.get_file(
692
+ database_name=database_name,
693
+ schema_name=schema_name,
694
+ model_name=model_name,
695
+ version_name=version_name,
696
+ file_path=stage_file_path,
697
+ target_path=local_file_dir,
698
+ statement_params=statement_params,
699
+ )
@@ -0,0 +1,34 @@
1
+ from typing import Optional
2
+
3
+ from snowflake.ml._internal.utils import identifier, sql_identifier
4
+ from snowflake.snowpark import session
5
+
6
+
7
+ class _BaseSQLClient:
8
+ def __init__(
9
+ self,
10
+ session: session.Session,
11
+ *,
12
+ database_name: sql_identifier.SqlIdentifier,
13
+ schema_name: sql_identifier.SqlIdentifier,
14
+ ) -> None:
15
+ self._session = session
16
+ self._database_name = database_name
17
+ self._schema_name = schema_name
18
+
19
+ def __eq__(self, __value: object) -> bool:
20
+ if not isinstance(__value, _BaseSQLClient):
21
+ return False
22
+ return self._database_name == __value._database_name and self._schema_name == __value._schema_name
23
+
24
+ def fully_qualified_object_name(
25
+ self,
26
+ database_name: Optional[sql_identifier.SqlIdentifier],
27
+ schema_name: Optional[sql_identifier.SqlIdentifier],
28
+ object_name: sql_identifier.SqlIdentifier,
29
+ ) -> str:
30
+ actual_database_name = database_name or self._database_name
31
+ actual_schema_name = schema_name or self._schema_name
32
+ return identifier.get_schema_level_object_identifier(
33
+ actual_database_name.identifier(), actual_schema_name.identifier(), object_name.identifier()
34
+ )