snowflake-ml-python 1.5.0__py3-none-any.whl → 1.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. snowflake/cortex/_sentiment.py +7 -4
  2. snowflake/ml/_internal/env_utils.py +6 -0
  3. snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
  4. snowflake/ml/_internal/telemetry.py +1 -0
  5. snowflake/ml/_internal/utils/identifier.py +1 -1
  6. snowflake/ml/_internal/utils/sql_identifier.py +14 -1
  7. snowflake/ml/_internal/utils/temp_file_utils.py +5 -2
  8. snowflake/ml/dataset/__init__.py +2 -1
  9. snowflake/ml/dataset/dataset.py +4 -3
  10. snowflake/ml/dataset/dataset_reader.py +5 -8
  11. snowflake/ml/feature_store/__init__.py +6 -0
  12. snowflake/ml/feature_store/access_manager.py +283 -0
  13. snowflake/ml/feature_store/feature_store.py +160 -100
  14. snowflake/ml/feature_store/feature_view.py +30 -19
  15. snowflake/ml/fileset/embedded_stage_fs.py +15 -12
  16. snowflake/ml/fileset/snowfs.py +2 -30
  17. snowflake/ml/fileset/stage_fs.py +25 -7
  18. snowflake/ml/model/_client/model/model_impl.py +46 -39
  19. snowflake/ml/model/_client/model/model_version_impl.py +24 -2
  20. snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
  21. snowflake/ml/model/_client/ops/model_ops.py +174 -16
  22. snowflake/ml/model/_client/sql/_base.py +34 -0
  23. snowflake/ml/model/_client/sql/model.py +32 -39
  24. snowflake/ml/model/_client/sql/model_version.py +111 -42
  25. snowflake/ml/model/_client/sql/stage.py +6 -32
  26. snowflake/ml/model/_client/sql/tag.py +32 -56
  27. snowflake/ml/model/_model_composer/model_composer.py +8 -4
  28. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
  29. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -3
  30. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -27
  31. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +90 -142
  32. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +159 -0
  33. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +81 -3
  34. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +8 -1
  35. snowflake/ml/modeling/cluster/affinity_propagation.py +8 -1
  36. snowflake/ml/modeling/cluster/agglomerative_clustering.py +8 -1
  37. snowflake/ml/modeling/cluster/birch.py +8 -1
  38. snowflake/ml/modeling/cluster/bisecting_k_means.py +8 -1
  39. snowflake/ml/modeling/cluster/dbscan.py +8 -1
  40. snowflake/ml/modeling/cluster/feature_agglomeration.py +8 -1
  41. snowflake/ml/modeling/cluster/k_means.py +8 -1
  42. snowflake/ml/modeling/cluster/mean_shift.py +8 -1
  43. snowflake/ml/modeling/cluster/mini_batch_k_means.py +8 -1
  44. snowflake/ml/modeling/cluster/optics.py +8 -1
  45. snowflake/ml/modeling/cluster/spectral_biclustering.py +8 -1
  46. snowflake/ml/modeling/cluster/spectral_clustering.py +8 -1
  47. snowflake/ml/modeling/cluster/spectral_coclustering.py +8 -1
  48. snowflake/ml/modeling/compose/column_transformer.py +8 -1
  49. snowflake/ml/modeling/compose/transformed_target_regressor.py +8 -1
  50. snowflake/ml/modeling/covariance/elliptic_envelope.py +8 -1
  51. snowflake/ml/modeling/covariance/empirical_covariance.py +8 -1
  52. snowflake/ml/modeling/covariance/graphical_lasso.py +8 -1
  53. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +8 -1
  54. snowflake/ml/modeling/covariance/ledoit_wolf.py +8 -1
  55. snowflake/ml/modeling/covariance/min_cov_det.py +8 -1
  56. snowflake/ml/modeling/covariance/oas.py +8 -1
  57. snowflake/ml/modeling/covariance/shrunk_covariance.py +8 -1
  58. snowflake/ml/modeling/decomposition/dictionary_learning.py +8 -1
  59. snowflake/ml/modeling/decomposition/factor_analysis.py +8 -1
  60. snowflake/ml/modeling/decomposition/fast_ica.py +8 -1
  61. snowflake/ml/modeling/decomposition/incremental_pca.py +8 -1
  62. snowflake/ml/modeling/decomposition/kernel_pca.py +8 -1
  63. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +8 -1
  64. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +8 -1
  65. snowflake/ml/modeling/decomposition/pca.py +8 -1
  66. snowflake/ml/modeling/decomposition/sparse_pca.py +8 -1
  67. snowflake/ml/modeling/decomposition/truncated_svd.py +8 -1
  68. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +8 -1
  69. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +8 -1
  70. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +8 -1
  71. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +8 -1
  72. snowflake/ml/modeling/ensemble/bagging_classifier.py +8 -1
  73. snowflake/ml/modeling/ensemble/bagging_regressor.py +8 -1
  74. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +8 -1
  75. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +8 -1
  76. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +8 -1
  77. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +8 -1
  78. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +8 -1
  79. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +8 -1
  80. snowflake/ml/modeling/ensemble/isolation_forest.py +8 -1
  81. snowflake/ml/modeling/ensemble/random_forest_classifier.py +8 -1
  82. snowflake/ml/modeling/ensemble/random_forest_regressor.py +8 -1
  83. snowflake/ml/modeling/ensemble/stacking_regressor.py +8 -1
  84. snowflake/ml/modeling/ensemble/voting_classifier.py +8 -1
  85. snowflake/ml/modeling/ensemble/voting_regressor.py +8 -1
  86. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +8 -1
  87. snowflake/ml/modeling/feature_selection/select_fdr.py +8 -1
  88. snowflake/ml/modeling/feature_selection/select_fpr.py +8 -1
  89. snowflake/ml/modeling/feature_selection/select_fwe.py +8 -1
  90. snowflake/ml/modeling/feature_selection/select_k_best.py +8 -1
  91. snowflake/ml/modeling/feature_selection/select_percentile.py +8 -1
  92. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +8 -1
  93. snowflake/ml/modeling/feature_selection/variance_threshold.py +8 -1
  94. snowflake/ml/modeling/framework/base.py +4 -3
  95. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +8 -1
  96. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +8 -1
  97. snowflake/ml/modeling/impute/iterative_imputer.py +8 -1
  98. snowflake/ml/modeling/impute/knn_imputer.py +8 -1
  99. snowflake/ml/modeling/impute/missing_indicator.py +8 -1
  100. snowflake/ml/modeling/impute/simple_imputer.py +21 -2
  101. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +8 -1
  102. snowflake/ml/modeling/kernel_approximation/nystroem.py +8 -1
  103. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +8 -1
  104. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +8 -1
  105. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +8 -1
  106. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +8 -1
  107. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +8 -1
  108. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +8 -1
  109. snowflake/ml/modeling/linear_model/ard_regression.py +8 -1
  110. snowflake/ml/modeling/linear_model/bayesian_ridge.py +8 -1
  111. snowflake/ml/modeling/linear_model/elastic_net.py +8 -1
  112. snowflake/ml/modeling/linear_model/elastic_net_cv.py +8 -1
  113. snowflake/ml/modeling/linear_model/gamma_regressor.py +8 -1
  114. snowflake/ml/modeling/linear_model/huber_regressor.py +8 -1
  115. snowflake/ml/modeling/linear_model/lars.py +8 -1
  116. snowflake/ml/modeling/linear_model/lars_cv.py +8 -1
  117. snowflake/ml/modeling/linear_model/lasso.py +8 -1
  118. snowflake/ml/modeling/linear_model/lasso_cv.py +8 -1
  119. snowflake/ml/modeling/linear_model/lasso_lars.py +8 -1
  120. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +8 -1
  121. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +8 -1
  122. snowflake/ml/modeling/linear_model/linear_regression.py +8 -1
  123. snowflake/ml/modeling/linear_model/logistic_regression.py +8 -1
  124. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +8 -1
  125. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +8 -1
  126. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +8 -1
  127. snowflake/ml/modeling/linear_model/multi_task_lasso.py +8 -1
  128. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +8 -1
  129. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +8 -1
  130. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +8 -1
  131. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +8 -1
  132. snowflake/ml/modeling/linear_model/perceptron.py +8 -1
  133. snowflake/ml/modeling/linear_model/poisson_regressor.py +8 -1
  134. snowflake/ml/modeling/linear_model/ransac_regressor.py +8 -1
  135. snowflake/ml/modeling/linear_model/ridge.py +8 -1
  136. snowflake/ml/modeling/linear_model/ridge_classifier.py +8 -1
  137. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +8 -1
  138. snowflake/ml/modeling/linear_model/ridge_cv.py +8 -1
  139. snowflake/ml/modeling/linear_model/sgd_classifier.py +8 -1
  140. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +8 -1
  141. snowflake/ml/modeling/linear_model/sgd_regressor.py +8 -1
  142. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +8 -1
  143. snowflake/ml/modeling/linear_model/tweedie_regressor.py +8 -1
  144. snowflake/ml/modeling/manifold/isomap.py +8 -1
  145. snowflake/ml/modeling/manifold/mds.py +8 -1
  146. snowflake/ml/modeling/manifold/spectral_embedding.py +8 -1
  147. snowflake/ml/modeling/manifold/tsne.py +8 -1
  148. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +8 -1
  149. snowflake/ml/modeling/mixture/gaussian_mixture.py +8 -1
  150. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +8 -1
  151. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +8 -1
  152. snowflake/ml/modeling/multiclass/output_code_classifier.py +8 -1
  153. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +8 -1
  154. snowflake/ml/modeling/naive_bayes/categorical_nb.py +8 -1
  155. snowflake/ml/modeling/naive_bayes/complement_nb.py +8 -1
  156. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +8 -1
  157. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +8 -1
  158. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +8 -1
  159. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +8 -1
  160. snowflake/ml/modeling/neighbors/kernel_density.py +8 -1
  161. snowflake/ml/modeling/neighbors/local_outlier_factor.py +8 -1
  162. snowflake/ml/modeling/neighbors/nearest_centroid.py +8 -1
  163. snowflake/ml/modeling/neighbors/nearest_neighbors.py +8 -1
  164. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +8 -1
  165. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +8 -1
  166. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +8 -1
  167. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +8 -1
  168. snowflake/ml/modeling/neural_network/mlp_classifier.py +8 -1
  169. snowflake/ml/modeling/neural_network/mlp_regressor.py +8 -1
  170. snowflake/ml/modeling/parameters/enable_anonymous_sproc.py +5 -0
  171. snowflake/ml/modeling/pipeline/pipeline.py +27 -7
  172. snowflake/ml/modeling/preprocessing/polynomial_features.py +8 -1
  173. snowflake/ml/modeling/semi_supervised/label_propagation.py +8 -1
  174. snowflake/ml/modeling/semi_supervised/label_spreading.py +8 -1
  175. snowflake/ml/modeling/svm/linear_svc.py +8 -1
  176. snowflake/ml/modeling/svm/linear_svr.py +8 -1
  177. snowflake/ml/modeling/svm/nu_svc.py +8 -1
  178. snowflake/ml/modeling/svm/nu_svr.py +8 -1
  179. snowflake/ml/modeling/svm/svc.py +8 -1
  180. snowflake/ml/modeling/svm/svr.py +8 -1
  181. snowflake/ml/modeling/tree/decision_tree_classifier.py +8 -1
  182. snowflake/ml/modeling/tree/decision_tree_regressor.py +8 -1
  183. snowflake/ml/modeling/tree/extra_tree_classifier.py +8 -1
  184. snowflake/ml/modeling/tree/extra_tree_regressor.py +8 -1
  185. snowflake/ml/modeling/xgboost/xgb_classifier.py +8 -1
  186. snowflake/ml/modeling/xgboost/xgb_regressor.py +8 -1
  187. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +8 -1
  188. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +8 -1
  189. snowflake/ml/registry/_manager/model_manager.py +95 -8
  190. snowflake/ml/registry/registry.py +10 -1
  191. snowflake/ml/version.py +1 -1
  192. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/METADATA +66 -10
  193. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/RECORD +196 -192
  194. snowflake/ml/_internal/lineage/dataset_dataframe.py +0 -44
  195. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/LICENSE.txt +0 -0
  196. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/WHEEL +0 -0
  197. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/top_level.txt +0 -0
@@ -74,37 +74,57 @@ class ModelOperator:
74
74
  and self._model_version_client == __value._model_version_client
75
75
  )
76
76
 
77
- def prepare_model_stage_path(self, *, statement_params: Optional[Dict[str, Any]] = None) -> str:
77
+ def prepare_model_stage_path(
78
+ self,
79
+ *,
80
+ database_name: Optional[sql_identifier.SqlIdentifier],
81
+ schema_name: Optional[sql_identifier.SqlIdentifier],
82
+ statement_params: Optional[Dict[str, Any]] = None,
83
+ ) -> str:
78
84
  stage_name = sql_identifier.SqlIdentifier(
79
85
  snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
80
86
  )
81
- self._stage_client.create_tmp_stage(stage_name=stage_name, statement_params=statement_params)
82
- return f"@{self._stage_client.fully_qualified_stage_name(stage_name)}/model"
87
+ self._stage_client.create_tmp_stage(
88
+ database_name=database_name,
89
+ schema_name=schema_name,
90
+ stage_name=stage_name,
91
+ statement_params=statement_params,
92
+ )
93
+ return f"@{self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)}/model"
83
94
 
84
95
  def create_from_stage(
85
96
  self,
86
97
  composed_model: model_composer.ModelComposer,
87
98
  *,
99
+ database_name: Optional[sql_identifier.SqlIdentifier],
100
+ schema_name: Optional[sql_identifier.SqlIdentifier],
88
101
  model_name: sql_identifier.SqlIdentifier,
89
102
  version_name: sql_identifier.SqlIdentifier,
90
103
  statement_params: Optional[Dict[str, Any]] = None,
91
104
  ) -> None:
92
105
  stage_path = str(composed_model.stage_path)
93
106
  if self.validate_existence(
107
+ database_name=database_name,
108
+ schema_name=schema_name,
94
109
  model_name=model_name,
95
110
  statement_params=statement_params,
96
111
  ):
97
112
  if self.validate_existence(
113
+ database_name=database_name,
114
+ schema_name=schema_name,
98
115
  model_name=model_name,
99
116
  version_name=version_name,
100
117
  statement_params=statement_params,
101
118
  ):
102
119
  raise ValueError(
103
- f"Model {self._model_version_client.fully_qualified_model_name(model_name)} "
104
- f"version {version_name} already existed."
120
+ "Model "
121
+ f"{self._model_version_client.fully_qualified_object_name(database_name, schema_name, model_name)}"
122
+ f" version {version_name} already existed."
105
123
  )
106
124
  else:
107
125
  self._model_version_client.add_version_from_stage(
126
+ database_name=database_name,
127
+ schema_name=schema_name,
108
128
  stage_path=stage_path,
109
129
  model_name=model_name,
110
130
  version_name=version_name,
@@ -112,26 +132,77 @@ class ModelOperator:
112
132
  )
113
133
  else:
114
134
  self._model_version_client.create_from_stage(
135
+ database_name=database_name,
136
+ schema_name=schema_name,
115
137
  stage_path=stage_path,
116
138
  model_name=model_name,
117
139
  version_name=version_name,
118
140
  statement_params=statement_params,
119
141
  )
120
142
 
143
+ def create_from_model_version(
144
+ self,
145
+ *,
146
+ source_database_name: Optional[sql_identifier.SqlIdentifier],
147
+ source_schema_name: Optional[sql_identifier.SqlIdentifier],
148
+ source_model_name: sql_identifier.SqlIdentifier,
149
+ source_version_name: sql_identifier.SqlIdentifier,
150
+ database_name: Optional[sql_identifier.SqlIdentifier],
151
+ schema_name: Optional[sql_identifier.SqlIdentifier],
152
+ model_name: sql_identifier.SqlIdentifier,
153
+ version_name: sql_identifier.SqlIdentifier,
154
+ statement_params: Optional[Dict[str, Any]] = None,
155
+ ) -> None:
156
+ if self.validate_existence(
157
+ database_name=database_name,
158
+ schema_name=schema_name,
159
+ model_name=model_name,
160
+ statement_params=statement_params,
161
+ ):
162
+ return self._model_version_client.add_version_from_model_version(
163
+ source_database_name=source_database_name,
164
+ source_schema_name=source_schema_name,
165
+ source_model_name=source_model_name,
166
+ source_version_name=source_version_name,
167
+ database_name=database_name,
168
+ schema_name=schema_name,
169
+ model_name=model_name,
170
+ version_name=version_name,
171
+ statement_params=statement_params,
172
+ )
173
+ else:
174
+ return self._model_version_client.create_from_model_version(
175
+ source_database_name=source_database_name,
176
+ source_schema_name=source_schema_name,
177
+ source_model_name=source_model_name,
178
+ source_version_name=source_version_name,
179
+ database_name=database_name,
180
+ schema_name=schema_name,
181
+ model_name=model_name,
182
+ version_name=version_name,
183
+ statement_params=statement_params,
184
+ )
185
+
121
186
  def show_models_or_versions(
122
187
  self,
123
188
  *,
189
+ database_name: Optional[sql_identifier.SqlIdentifier],
190
+ schema_name: Optional[sql_identifier.SqlIdentifier],
124
191
  model_name: Optional[sql_identifier.SqlIdentifier] = None,
125
192
  statement_params: Optional[Dict[str, Any]] = None,
126
193
  ) -> List[row.Row]:
127
194
  if model_name:
128
195
  return self._model_client.show_versions(
196
+ database_name=database_name,
197
+ schema_name=schema_name,
129
198
  model_name=model_name,
130
199
  validate_result=False,
131
200
  statement_params=statement_params,
132
201
  )
133
202
  else:
134
203
  return self._model_client.show_models(
204
+ database_name=database_name,
205
+ schema_name=schema_name,
135
206
  validate_result=False,
136
207
  statement_params=statement_params,
137
208
  )
@@ -139,10 +210,14 @@ class ModelOperator:
139
210
  def list_models_or_versions(
140
211
  self,
141
212
  *,
213
+ database_name: Optional[sql_identifier.SqlIdentifier],
214
+ schema_name: Optional[sql_identifier.SqlIdentifier],
142
215
  model_name: Optional[sql_identifier.SqlIdentifier] = None,
143
216
  statement_params: Optional[Dict[str, Any]] = None,
144
217
  ) -> List[sql_identifier.SqlIdentifier]:
145
218
  res = self.show_models_or_versions(
219
+ database_name=database_name,
220
+ schema_name=schema_name,
146
221
  model_name=model_name,
147
222
  statement_params=statement_params,
148
223
  )
@@ -155,12 +230,16 @@ class ModelOperator:
155
230
  def validate_existence(
156
231
  self,
157
232
  *,
233
+ database_name: Optional[sql_identifier.SqlIdentifier],
234
+ schema_name: Optional[sql_identifier.SqlIdentifier],
158
235
  model_name: sql_identifier.SqlIdentifier,
159
236
  version_name: Optional[sql_identifier.SqlIdentifier] = None,
160
237
  statement_params: Optional[Dict[str, Any]] = None,
161
238
  ) -> bool:
162
239
  if version_name:
163
240
  res = self._model_client.show_versions(
241
+ database_name=database_name,
242
+ schema_name=schema_name,
164
243
  model_name=model_name,
165
244
  version_name=version_name,
166
245
  validate_result=False,
@@ -168,6 +247,8 @@ class ModelOperator:
168
247
  )
169
248
  else:
170
249
  res = self._model_client.show_models(
250
+ database_name=database_name,
251
+ schema_name=schema_name,
171
252
  model_name=model_name,
172
253
  validate_result=False,
173
254
  statement_params=statement_params,
@@ -177,12 +258,16 @@ class ModelOperator:
177
258
  def get_comment(
178
259
  self,
179
260
  *,
261
+ database_name: Optional[sql_identifier.SqlIdentifier],
262
+ schema_name: Optional[sql_identifier.SqlIdentifier],
180
263
  model_name: sql_identifier.SqlIdentifier,
181
264
  version_name: Optional[sql_identifier.SqlIdentifier] = None,
182
265
  statement_params: Optional[Dict[str, Any]] = None,
183
266
  ) -> str:
184
267
  if version_name:
185
268
  res = self._model_client.show_versions(
269
+ database_name=database_name,
270
+ schema_name=schema_name,
186
271
  model_name=model_name,
187
272
  version_name=version_name,
188
273
  statement_params=statement_params,
@@ -190,6 +275,8 @@ class ModelOperator:
190
275
  col_name = self._model_client.MODEL_VERSION_COMMENT_COL_NAME
191
276
  else:
192
277
  res = self._model_client.show_models(
278
+ database_name=database_name,
279
+ schema_name=schema_name,
193
280
  model_name=model_name,
194
281
  statement_params=statement_params,
195
282
  )
@@ -200,6 +287,8 @@ class ModelOperator:
200
287
  self,
201
288
  *,
202
289
  comment: str,
290
+ database_name: Optional[sql_identifier.SqlIdentifier],
291
+ schema_name: Optional[sql_identifier.SqlIdentifier],
203
292
  model_name: sql_identifier.SqlIdentifier,
204
293
  version_name: Optional[sql_identifier.SqlIdentifier] = None,
205
294
  statement_params: Optional[Dict[str, Any]] = None,
@@ -207,6 +296,8 @@ class ModelOperator:
207
296
  if version_name:
208
297
  self._model_version_client.set_comment(
209
298
  comment=comment,
299
+ database_name=database_name,
300
+ schema_name=schema_name,
210
301
  model_name=model_name,
211
302
  version_name=version_name,
212
303
  statement_params=statement_params,
@@ -214,6 +305,8 @@ class ModelOperator:
214
305
  else:
215
306
  self._model_client.set_comment(
216
307
  comment=comment,
308
+ database_name=database_name,
309
+ schema_name=schema_name,
217
310
  model_name=model_name,
218
311
  statement_params=statement_params,
219
312
  )
@@ -221,25 +314,42 @@ class ModelOperator:
221
314
  def set_default_version(
222
315
  self,
223
316
  *,
317
+ database_name: Optional[sql_identifier.SqlIdentifier],
318
+ schema_name: Optional[sql_identifier.SqlIdentifier],
224
319
  model_name: sql_identifier.SqlIdentifier,
225
320
  version_name: sql_identifier.SqlIdentifier,
226
321
  statement_params: Optional[Dict[str, Any]] = None,
227
322
  ) -> None:
228
323
  if not self.validate_existence(
229
- model_name=model_name, version_name=version_name, statement_params=statement_params
324
+ database_name=database_name,
325
+ schema_name=schema_name,
326
+ model_name=model_name,
327
+ version_name=version_name,
328
+ statement_params=statement_params,
230
329
  ):
231
330
  raise ValueError(f"You cannot set version {version_name} as default version as it does not exist.")
232
331
  self._model_version_client.set_default_version(
233
- model_name=model_name, version_name=version_name, statement_params=statement_params
332
+ database_name=database_name,
333
+ schema_name=schema_name,
334
+ model_name=model_name,
335
+ version_name=version_name,
336
+ statement_params=statement_params,
234
337
  )
235
338
 
236
339
  def get_default_version(
237
340
  self,
238
341
  *,
342
+ database_name: Optional[sql_identifier.SqlIdentifier],
343
+ schema_name: Optional[sql_identifier.SqlIdentifier],
239
344
  model_name: sql_identifier.SqlIdentifier,
240
345
  statement_params: Optional[Dict[str, Any]] = None,
241
346
  ) -> sql_identifier.SqlIdentifier:
242
- res = self._model_client.show_models(model_name=model_name, statement_params=statement_params)[0]
347
+ res = self._model_client.show_models(
348
+ database_name=database_name,
349
+ schema_name=schema_name,
350
+ model_name=model_name,
351
+ statement_params=statement_params,
352
+ )[0]
243
353
  return sql_identifier.SqlIdentifier(
244
354
  res[self._model_client.MODEL_DEFAULT_VERSION_NAME_COL_NAME], case_sensitive=True
245
355
  )
@@ -247,14 +357,18 @@ class ModelOperator:
247
357
  def get_tag_value(
248
358
  self,
249
359
  *,
360
+ database_name: Optional[sql_identifier.SqlIdentifier],
361
+ schema_name: Optional[sql_identifier.SqlIdentifier],
250
362
  model_name: sql_identifier.SqlIdentifier,
251
- tag_database_name: sql_identifier.SqlIdentifier,
252
- tag_schema_name: sql_identifier.SqlIdentifier,
363
+ tag_database_name: Optional[sql_identifier.SqlIdentifier],
364
+ tag_schema_name: Optional[sql_identifier.SqlIdentifier],
253
365
  tag_name: sql_identifier.SqlIdentifier,
254
366
  statement_params: Optional[Dict[str, Any]] = None,
255
367
  ) -> Optional[str]:
256
368
  r = self._tag_client.get_tag_value(
257
- module_name=model_name,
369
+ database_name=database_name,
370
+ schema_name=schema_name,
371
+ model_name=model_name,
258
372
  tag_database_name=tag_database_name,
259
373
  tag_schema_name=tag_schema_name,
260
374
  tag_name=tag_name,
@@ -268,11 +382,15 @@ class ModelOperator:
268
382
  def show_tags(
269
383
  self,
270
384
  *,
385
+ database_name: Optional[sql_identifier.SqlIdentifier],
386
+ schema_name: Optional[sql_identifier.SqlIdentifier],
271
387
  model_name: sql_identifier.SqlIdentifier,
272
388
  statement_params: Optional[Dict[str, Any]] = None,
273
389
  ) -> Dict[str, str]:
274
390
  tags_info = self._tag_client.get_tag_list(
275
- module_name=model_name,
391
+ database_name=database_name,
392
+ schema_name=schema_name,
393
+ model_name=model_name,
276
394
  statement_params=statement_params,
277
395
  )
278
396
  res: Dict[str, str] = {
@@ -288,14 +406,18 @@ class ModelOperator:
288
406
  def set_tag(
289
407
  self,
290
408
  *,
409
+ database_name: Optional[sql_identifier.SqlIdentifier],
410
+ schema_name: Optional[sql_identifier.SqlIdentifier],
291
411
  model_name: sql_identifier.SqlIdentifier,
292
- tag_database_name: sql_identifier.SqlIdentifier,
293
- tag_schema_name: sql_identifier.SqlIdentifier,
412
+ tag_database_name: Optional[sql_identifier.SqlIdentifier],
413
+ tag_schema_name: Optional[sql_identifier.SqlIdentifier],
294
414
  tag_name: sql_identifier.SqlIdentifier,
295
415
  tag_value: str,
296
416
  statement_params: Optional[Dict[str, Any]] = None,
297
417
  ) -> None:
298
418
  self._tag_client.set_tag_on_model(
419
+ database_name=database_name,
420
+ schema_name=schema_name,
299
421
  model_name=model_name,
300
422
  tag_database_name=tag_database_name,
301
423
  tag_schema_name=tag_schema_name,
@@ -307,13 +429,17 @@ class ModelOperator:
307
429
  def unset_tag(
308
430
  self,
309
431
  *,
432
+ database_name: Optional[sql_identifier.SqlIdentifier],
433
+ schema_name: Optional[sql_identifier.SqlIdentifier],
310
434
  model_name: sql_identifier.SqlIdentifier,
311
- tag_database_name: sql_identifier.SqlIdentifier,
312
- tag_schema_name: sql_identifier.SqlIdentifier,
435
+ tag_database_name: Optional[sql_identifier.SqlIdentifier],
436
+ tag_schema_name: Optional[sql_identifier.SqlIdentifier],
313
437
  tag_name: sql_identifier.SqlIdentifier,
314
438
  statement_params: Optional[Dict[str, Any]] = None,
315
439
  ) -> None:
316
440
  self._tag_client.unset_tag_on_model(
441
+ database_name=database_name,
442
+ schema_name=schema_name,
317
443
  model_name=model_name,
318
444
  tag_database_name=tag_database_name,
319
445
  tag_schema_name=tag_schema_name,
@@ -324,12 +450,16 @@ class ModelOperator:
324
450
  def get_model_version_manifest(
325
451
  self,
326
452
  *,
453
+ database_name: Optional[sql_identifier.SqlIdentifier],
454
+ schema_name: Optional[sql_identifier.SqlIdentifier],
327
455
  model_name: sql_identifier.SqlIdentifier,
328
456
  version_name: sql_identifier.SqlIdentifier,
329
457
  statement_params: Optional[Dict[str, Any]] = None,
330
458
  ) -> model_manifest_schema.ModelManifestDict:
331
459
  with tempfile.TemporaryDirectory() as tmpdir:
332
460
  self._model_version_client.get_file(
461
+ database_name=database_name,
462
+ schema_name=schema_name,
333
463
  model_name=model_name,
334
464
  version_name=version_name,
335
465
  file_path=pathlib.PurePosixPath(model_manifest.ModelManifest.MANIFEST_FILE_REL_PATH),
@@ -362,11 +492,15 @@ class ModelOperator:
362
492
  def get_functions(
363
493
  self,
364
494
  *,
495
+ database_name: Optional[sql_identifier.SqlIdentifier],
496
+ schema_name: Optional[sql_identifier.SqlIdentifier],
365
497
  model_name: sql_identifier.SqlIdentifier,
366
498
  version_name: sql_identifier.SqlIdentifier,
367
499
  statement_params: Optional[Dict[str, Any]] = None,
368
500
  ) -> List[model_manifest_schema.ModelFunctionInfo]:
369
501
  raw_model_spec_res = self._model_client.show_versions(
502
+ database_name=database_name,
503
+ schema_name=schema_name,
370
504
  model_name=model_name,
371
505
  version_name=version_name,
372
506
  check_model_details=True,
@@ -375,6 +509,8 @@ class ModelOperator:
375
509
  model_spec_dict = yaml.safe_load(raw_model_spec_res)
376
510
  model_spec = model_meta.ModelMetadata._validate_model_metadata(model_spec_dict)
377
511
  show_functions_res = self._model_version_client.show_functions(
512
+ database_name=database_name,
513
+ schema_name=schema_name,
378
514
  model_name=model_name,
379
515
  version_name=version_name,
380
516
  statement_params=statement_params,
@@ -419,6 +555,8 @@ class ModelOperator:
419
555
  method_function_type: str,
420
556
  signature: model_signature.ModelSignature,
421
557
  X: Union[type_hints.SupportedDataType, dataframe.DataFrame],
558
+ database_name: Optional[sql_identifier.SqlIdentifier],
559
+ schema_name: Optional[sql_identifier.SqlIdentifier],
422
560
  model_name: sql_identifier.SqlIdentifier,
423
561
  version_name: sql_identifier.SqlIdentifier,
424
562
  strict_input_validation: bool = False,
@@ -466,6 +604,8 @@ class ModelOperator:
466
604
  input_df=s_df,
467
605
  input_args=input_args,
468
606
  returns=returns,
607
+ database_name=database_name,
608
+ schema_name=schema_name,
469
609
  model_name=model_name,
470
610
  version_name=version_name,
471
611
  statement_params=statement_params,
@@ -477,6 +617,8 @@ class ModelOperator:
477
617
  input_args=input_args,
478
618
  partition_column=partition_column,
479
619
  returns=returns,
620
+ database_name=database_name,
621
+ schema_name=schema_name,
480
622
  model_name=model_name,
481
623
  version_name=version_name,
482
624
  statement_params=statement_params,
@@ -504,18 +646,24 @@ class ModelOperator:
504
646
  def delete_model_or_version(
505
647
  self,
506
648
  *,
649
+ database_name: Optional[sql_identifier.SqlIdentifier],
650
+ schema_name: Optional[sql_identifier.SqlIdentifier],
507
651
  model_name: sql_identifier.SqlIdentifier,
508
652
  version_name: Optional[sql_identifier.SqlIdentifier] = None,
509
653
  statement_params: Optional[Dict[str, Any]] = None,
510
654
  ) -> None:
511
655
  if version_name:
512
656
  self._model_version_client.drop_version(
657
+ database_name=database_name,
658
+ schema_name=schema_name,
513
659
  model_name=model_name,
514
660
  version_name=version_name,
515
661
  statement_params=statement_params,
516
662
  )
517
663
  else:
518
664
  self._model_client.drop_model(
665
+ database_name=database_name,
666
+ schema_name=schema_name,
519
667
  model_name=model_name,
520
668
  statement_params=statement_params,
521
669
  )
@@ -523,6 +671,8 @@ class ModelOperator:
523
671
  def rename(
524
672
  self,
525
673
  *,
674
+ database_name: Optional[sql_identifier.SqlIdentifier],
675
+ schema_name: Optional[sql_identifier.SqlIdentifier],
526
676
  model_name: sql_identifier.SqlIdentifier,
527
677
  new_model_db: Optional[sql_identifier.SqlIdentifier],
528
678
  new_model_schema: Optional[sql_identifier.SqlIdentifier],
@@ -530,6 +680,8 @@ class ModelOperator:
530
680
  statement_params: Optional[Dict[str, Any]] = None,
531
681
  ) -> None:
532
682
  self._model_client.rename(
683
+ database_name=database_name,
684
+ schema_name=schema_name,
533
685
  model_name=model_name,
534
686
  new_model_db=new_model_db,
535
687
  new_model_schema=new_model_schema,
@@ -554,6 +706,8 @@ class ModelOperator:
554
706
  def download_files(
555
707
  self,
556
708
  *,
709
+ database_name: Optional[sql_identifier.SqlIdentifier],
710
+ schema_name: Optional[sql_identifier.SqlIdentifier],
557
711
  model_name: sql_identifier.SqlIdentifier,
558
712
  version_name: sql_identifier.SqlIdentifier,
559
713
  target_path: pathlib.Path,
@@ -562,6 +716,8 @@ class ModelOperator:
562
716
  ) -> None:
563
717
  for remote_rel_path, is_dir in self.MODEL_FILE_DOWNLOAD_PATTERN[mode].items():
564
718
  list_file_res = self._model_version_client.list_file(
719
+ database_name=database_name,
720
+ schema_name=schema_name,
565
721
  model_name=model_name,
566
722
  version_name=version_name,
567
723
  file_path=remote_rel_path,
@@ -576,6 +732,8 @@ class ModelOperator:
576
732
  local_file_dir = target_path / stage_file_path.parent
577
733
  local_file_dir.mkdir(parents=True, exist_ok=True)
578
734
  self._model_version_client.get_file(
735
+ database_name=database_name,
736
+ schema_name=schema_name,
579
737
  model_name=model_name,
580
738
  version_name=version_name,
581
739
  file_path=stage_file_path,
@@ -0,0 +1,34 @@
1
+ from typing import Optional
2
+
3
+ from snowflake.ml._internal.utils import identifier, sql_identifier
4
+ from snowflake.snowpark import session
5
+
6
+
7
+ class _BaseSQLClient:
8
+ def __init__(
9
+ self,
10
+ session: session.Session,
11
+ *,
12
+ database_name: sql_identifier.SqlIdentifier,
13
+ schema_name: sql_identifier.SqlIdentifier,
14
+ ) -> None:
15
+ self._session = session
16
+ self._database_name = database_name
17
+ self._schema_name = schema_name
18
+
19
+ def __eq__(self, __value: object) -> bool:
20
+ if not isinstance(__value, _BaseSQLClient):
21
+ return False
22
+ return self._database_name == __value._database_name and self._schema_name == __value._schema_name
23
+
24
+ def fully_qualified_object_name(
25
+ self,
26
+ database_name: Optional[sql_identifier.SqlIdentifier],
27
+ schema_name: Optional[sql_identifier.SqlIdentifier],
28
+ object_name: sql_identifier.SqlIdentifier,
29
+ ) -> str:
30
+ actual_database_name = database_name or self._database_name
31
+ actual_schema_name = schema_name or self._schema_name
32
+ return identifier.get_schema_level_object_identifier(
33
+ actual_database_name.identifier(), actual_schema_name.identifier(), object_name.identifier()
34
+ )
@@ -1,14 +1,11 @@
1
1
  from typing import Any, Dict, List, Optional
2
2
 
3
- from snowflake.ml._internal.utils import (
4
- identifier,
5
- query_result_checker,
6
- sql_identifier,
7
- )
8
- from snowflake.snowpark import row, session
3
+ from snowflake.ml._internal.utils import query_result_checker, sql_identifier
4
+ from snowflake.ml.model._client.sql import _base
5
+ from snowflake.snowpark import row
9
6
 
10
7
 
11
- class ModelSQLClient:
8
+ class ModelSQLClient(_base._BaseSQLClient):
12
9
  MODEL_NAME_COL_NAME = "name"
13
10
  MODEL_COMMENT_COL_NAME = "comment"
14
11
  MODEL_DEFAULT_VERSION_NAME_COL_NAME = "default_version_name"
@@ -18,35 +15,18 @@ class ModelSQLClient:
18
15
  MODEL_VERSION_METADATA_COL_NAME = "metadata"
19
16
  MODEL_VERSION_MODEL_SPEC_COL_NAME = "model_spec"
20
17
 
21
- def __init__(
22
- self,
23
- session: session.Session,
24
- *,
25
- database_name: sql_identifier.SqlIdentifier,
26
- schema_name: sql_identifier.SqlIdentifier,
27
- ) -> None:
28
- self._session = session
29
- self._database_name = database_name
30
- self._schema_name = schema_name
31
-
32
- def __eq__(self, __value: object) -> bool:
33
- if not isinstance(__value, ModelSQLClient):
34
- return False
35
- return self._database_name == __value._database_name and self._schema_name == __value._schema_name
36
-
37
- def fully_qualified_model_name(self, model_name: sql_identifier.SqlIdentifier) -> str:
38
- return identifier.get_schema_level_object_identifier(
39
- self._database_name.identifier(), self._schema_name.identifier(), model_name.identifier()
40
- )
41
-
42
18
  def show_models(
43
19
  self,
44
20
  *,
21
+ database_name: Optional[sql_identifier.SqlIdentifier],
22
+ schema_name: Optional[sql_identifier.SqlIdentifier],
45
23
  model_name: Optional[sql_identifier.SqlIdentifier] = None,
46
24
  validate_result: bool = True,
47
25
  statement_params: Optional[Dict[str, Any]] = None,
48
26
  ) -> List[row.Row]:
49
- fully_qualified_schema_name = ".".join([self._database_name.identifier(), self._schema_name.identifier()])
27
+ actual_database_name = database_name or self._database_name
28
+ actual_schema_name = schema_name or self._schema_name
29
+ fully_qualified_schema_name = ".".join([actual_database_name.identifier(), actual_schema_name.identifier()])
50
30
  like_sql = ""
51
31
  if model_name:
52
32
  like_sql = f" LIKE '{model_name.resolved()}'"
@@ -69,6 +49,8 @@ class ModelSQLClient:
69
49
  def show_versions(
70
50
  self,
71
51
  *,
52
+ database_name: Optional[sql_identifier.SqlIdentifier],
53
+ schema_name: Optional[sql_identifier.SqlIdentifier],
72
54
  model_name: sql_identifier.SqlIdentifier,
73
55
  version_name: Optional[sql_identifier.SqlIdentifier] = None,
74
56
  validate_result: bool = True,
@@ -82,7 +64,10 @@ class ModelSQLClient:
82
64
  res = (
83
65
  query_result_checker.SqlResultValidator(
84
66
  self._session,
85
- f"SHOW VERSIONS{like_sql} IN MODEL {self.fully_qualified_model_name(model_name)}",
67
+ (
68
+ f"SHOW VERSIONS{like_sql} IN "
69
+ f"MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
70
+ ),
86
71
  statement_params=statement_params,
87
72
  )
88
73
  .has_column(ModelSQLClient.MODEL_VERSION_NAME_COL_NAME, allow_empty=True)
@@ -99,31 +84,40 @@ class ModelSQLClient:
99
84
  def set_comment(
100
85
  self,
101
86
  *,
102
- comment: str,
87
+ database_name: Optional[sql_identifier.SqlIdentifier],
88
+ schema_name: Optional[sql_identifier.SqlIdentifier],
103
89
  model_name: sql_identifier.SqlIdentifier,
90
+ comment: str,
104
91
  statement_params: Optional[Dict[str, Any]] = None,
105
92
  ) -> None:
106
93
  query_result_checker.SqlResultValidator(
107
94
  self._session,
108
- f"COMMENT ON MODEL {self.fully_qualified_model_name(model_name)} IS $${comment}$$",
95
+ (
96
+ f"COMMENT ON MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
97
+ f" IS $${comment}$$"
98
+ ),
109
99
  statement_params=statement_params,
110
100
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
111
101
 
112
102
  def drop_model(
113
103
  self,
114
104
  *,
105
+ database_name: Optional[sql_identifier.SqlIdentifier],
106
+ schema_name: Optional[sql_identifier.SqlIdentifier],
115
107
  model_name: sql_identifier.SqlIdentifier,
116
108
  statement_params: Optional[Dict[str, Any]] = None,
117
109
  ) -> None:
118
110
  query_result_checker.SqlResultValidator(
119
111
  self._session,
120
- f"DROP MODEL {self.fully_qualified_model_name(model_name)}",
112
+ f"DROP MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}",
121
113
  statement_params=statement_params,
122
114
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
123
115
 
124
116
  def rename(
125
117
  self,
126
118
  *,
119
+ database_name: Optional[sql_identifier.SqlIdentifier],
120
+ schema_name: Optional[sql_identifier.SqlIdentifier],
127
121
  model_name: sql_identifier.SqlIdentifier,
128
122
  new_model_db: Optional[sql_identifier.SqlIdentifier],
129
123
  new_model_schema: Optional[sql_identifier.SqlIdentifier],
@@ -131,13 +125,12 @@ class ModelSQLClient:
131
125
  statement_params: Optional[Dict[str, Any]] = None,
132
126
  ) -> None:
133
127
  # Use registry's database and schema if a non fully qualified new model name is provided.
134
- new_fully_qualified_name = identifier.get_schema_level_object_identifier(
135
- new_model_db.identifier() if new_model_db else self._database_name.identifier(),
136
- new_model_schema.identifier() if new_model_schema else self._schema_name.identifier(),
137
- new_model_name.identifier(),
138
- )
128
+ new_fully_qualified_name = self.fully_qualified_object_name(new_model_db, new_model_schema, new_model_name)
139
129
  query_result_checker.SqlResultValidator(
140
130
  self._session,
141
- f"ALTER MODEL {self.fully_qualified_model_name(model_name)} RENAME TO {new_fully_qualified_name}",
131
+ (
132
+ f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
133
+ f" RENAME TO {new_fully_qualified_name}"
134
+ ),
142
135
  statement_params=statement_params,
143
136
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()