snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (218) hide show
  1. snowflake/ml/_internal/env_utils.py +72 -31
  2. snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
  3. snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
  4. snowflake/ml/_internal/exceptions/error_codes.py +3 -0
  5. snowflake/ml/_internal/lineage/data_source.py +10 -0
  6. snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
  7. snowflake/ml/_internal/telemetry.py +1 -0
  8. snowflake/ml/_internal/utils/identifier.py +1 -1
  9. snowflake/ml/_internal/utils/sql_identifier.py +14 -1
  10. snowflake/ml/dataset/__init__.py +11 -0
  11. snowflake/ml/dataset/dataset.py +455 -129
  12. snowflake/ml/dataset/dataset_factory.py +53 -0
  13. snowflake/ml/dataset/dataset_metadata.py +103 -0
  14. snowflake/ml/dataset/dataset_reader.py +199 -0
  15. snowflake/ml/feature_store/__init__.py +6 -0
  16. snowflake/ml/feature_store/access_manager.py +279 -0
  17. snowflake/ml/feature_store/feature_store.py +544 -358
  18. snowflake/ml/feature_store/feature_view.py +55 -16
  19. snowflake/ml/fileset/embedded_stage_fs.py +149 -0
  20. snowflake/ml/fileset/sfcfs.py +0 -4
  21. snowflake/ml/fileset/snowfs.py +160 -0
  22. snowflake/ml/fileset/stage_fs.py +25 -10
  23. snowflake/ml/model/__init__.py +2 -2
  24. snowflake/ml/model/_api.py +16 -1
  25. snowflake/ml/model/_client/model/model_impl.py +65 -31
  26. snowflake/ml/model/_client/model/model_version_impl.py +159 -2
  27. snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
  28. snowflake/ml/model/_client/ops/model_ops.py +268 -83
  29. snowflake/ml/model/_client/sql/_base.py +34 -0
  30. snowflake/ml/model/_client/sql/model.py +42 -47
  31. snowflake/ml/model/_client/sql/model_version.py +164 -39
  32. snowflake/ml/model/_client/sql/stage.py +6 -32
  33. snowflake/ml/model/_client/sql/tag.py +32 -56
  34. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
  35. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
  36. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  37. snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
  38. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
  39. snowflake/ml/model/_model_composer/model_composer.py +22 -1
  40. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
  41. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
  42. snowflake/ml/model/_packager/model_env/model_env.py +41 -0
  43. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
  45. snowflake/ml/model/_packager/model_packager.py +0 -3
  46. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
  47. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
  48. snowflake/ml/modeling/_internal/model_trainer.py +7 -0
  49. snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
  50. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
  51. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
  52. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
  53. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
  54. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
  55. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
  56. snowflake/ml/modeling/cluster/birch.py +53 -52
  57. snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
  58. snowflake/ml/modeling/cluster/dbscan.py +51 -52
  59. snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
  60. snowflake/ml/modeling/cluster/k_means.py +53 -52
  61. snowflake/ml/modeling/cluster/mean_shift.py +51 -52
  62. snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
  63. snowflake/ml/modeling/cluster/optics.py +51 -52
  64. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
  65. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
  66. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
  67. snowflake/ml/modeling/compose/column_transformer.py +53 -52
  68. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
  69. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
  70. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
  71. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
  72. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
  73. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
  74. snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
  75. snowflake/ml/modeling/covariance/oas.py +51 -52
  76. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
  77. snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
  78. snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
  79. snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
  80. snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
  81. snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
  82. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
  83. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
  84. snowflake/ml/modeling/decomposition/pca.py +53 -52
  85. snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
  86. snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
  87. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
  88. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
  89. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
  90. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
  91. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
  92. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
  93. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
  94. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
  95. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
  96. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
  97. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
  98. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
  99. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
  100. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
  101. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
  102. snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
  103. snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
  104. snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
  105. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
  106. snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
  107. snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
  108. snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
  109. snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
  110. snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
  111. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
  112. snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
  113. snowflake/ml/modeling/framework/base.py +64 -36
  114. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
  115. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
  116. snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
  117. snowflake/ml/modeling/impute/knn_imputer.py +53 -52
  118. snowflake/ml/modeling/impute/missing_indicator.py +53 -52
  119. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
  120. snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
  121. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
  122. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
  123. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
  124. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
  125. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
  126. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
  127. snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
  128. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
  129. snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
  130. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
  131. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
  132. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
  133. snowflake/ml/modeling/linear_model/lars.py +51 -52
  134. snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
  135. snowflake/ml/modeling/linear_model/lasso.py +51 -52
  136. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
  137. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
  138. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
  139. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
  140. snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
  141. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
  142. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
  143. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
  144. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
  145. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
  146. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
  147. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
  148. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
  149. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
  150. snowflake/ml/modeling/linear_model/perceptron.py +51 -52
  151. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
  152. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
  153. snowflake/ml/modeling/linear_model/ridge.py +51 -52
  154. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
  155. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
  156. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
  157. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
  158. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
  159. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
  160. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
  161. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
  162. snowflake/ml/modeling/manifold/isomap.py +53 -52
  163. snowflake/ml/modeling/manifold/mds.py +53 -52
  164. snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
  165. snowflake/ml/modeling/manifold/tsne.py +53 -52
  166. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
  167. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
  168. snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
  169. snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
  170. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
  171. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
  172. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
  173. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
  174. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
  175. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
  176. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
  177. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
  178. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
  179. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
  180. snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
  181. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
  182. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
  183. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
  184. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
  185. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
  186. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
  187. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
  188. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
  189. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
  190. snowflake/ml/modeling/pipeline/pipeline.py +538 -36
  191. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
  192. snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
  193. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
  194. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
  195. snowflake/ml/modeling/svm/linear_svc.py +51 -52
  196. snowflake/ml/modeling/svm/linear_svr.py +51 -52
  197. snowflake/ml/modeling/svm/nu_svc.py +51 -52
  198. snowflake/ml/modeling/svm/nu_svr.py +51 -52
  199. snowflake/ml/modeling/svm/svc.py +51 -52
  200. snowflake/ml/modeling/svm/svr.py +51 -52
  201. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
  202. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
  203. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
  204. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
  205. snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
  206. snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
  207. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
  208. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
  209. snowflake/ml/registry/_manager/model_manager.py +36 -7
  210. snowflake/ml/registry/model_registry.py +3 -149
  211. snowflake/ml/version.py +1 -1
  212. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
  213. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
  214. snowflake/ml/registry/_artifact_manager.py +0 -156
  215. snowflake/ml/registry/artifact.py +0 -46
  216. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
  217. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
  218. {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -72,24 +72,40 @@ class MLRuntimeTransformHandlers:
72
72
 
73
73
  """
74
74
 
75
- handler = SnowparkTransformHandlers(
76
- dataset=self.dataset,
77
- estimator=self.estimator,
78
- class_name=self._class_name,
79
- subproject=self._subproject,
80
- autogenerated=self._autogenerated,
81
- )
82
- return handler.batch_inference(
83
- inference_method,
84
- input_cols,
85
- expected_output_cols,
86
- session,
87
- dependencies,
88
- drop_input_cols,
89
- expected_output_cols_type,
90
- *args,
91
- **kwargs,
92
- )
75
+ mlrs_inference_methods = ["predict", "predict_proba", "predict_log_proba"]
76
+
77
+ if inference_method in mlrs_inference_methods:
78
+ result_df = self.client.inference(
79
+ estimator=self.estimator,
80
+ dataset=self.dataset,
81
+ inference_method=inference_method,
82
+ input_cols=input_cols,
83
+ output_cols=expected_output_cols,
84
+ drop_input_cols=drop_input_cols,
85
+ )
86
+
87
+ else:
88
+ handler = SnowparkTransformHandlers(
89
+ dataset=self.dataset,
90
+ estimator=self.estimator,
91
+ class_name=self._class_name,
92
+ subproject=self._subproject,
93
+ autogenerated=self._autogenerated,
94
+ )
95
+ result_df = handler.batch_inference(
96
+ inference_method,
97
+ input_cols,
98
+ expected_output_cols,
99
+ session,
100
+ dependencies,
101
+ drop_input_cols,
102
+ expected_output_cols_type,
103
+ *args,
104
+ **kwargs,
105
+ )
106
+
107
+ assert isinstance(result_df, DataFrame) # mypy - The MLRS return types are annotated as `object`.
108
+ return result_df
93
109
 
94
110
  def score(
95
111
  self,
@@ -22,3 +22,10 @@ class ModelTrainer(Protocol):
22
22
  drop_input_cols: Optional[bool] = False,
23
23
  ) -> Tuple[Union[DataFrame, pd.DataFrame], object]:
24
24
  raise NotImplementedError
25
+
26
+ def train_fit_transform(
27
+ self,
28
+ expected_output_cols_list: List[str],
29
+ drop_input_cols: Optional[bool] = False,
30
+ ) -> Tuple[Union[DataFrame, pd.DataFrame], object]:
31
+ raise NotImplementedError
@@ -138,21 +138,13 @@ class ModelTrainerBuilder:
138
138
  cls,
139
139
  estimator: object,
140
140
  dataset: Union[DataFrame, pd.DataFrame],
141
- input_cols: Optional[List[str]] = None,
141
+ input_cols: List[str],
142
142
  autogenerated: bool = False,
143
143
  subproject: str = "",
144
144
  ) -> ModelTrainer:
145
145
  """
146
146
  Builder method that creates an appropriate ModelTrainer instance based on the given params.
147
147
  """
148
- if input_cols is None:
149
- raise exceptions.SnowflakeMLException(
150
- error_code=error_codes.NOT_FOUND,
151
- original_exception=ValueError(
152
- "The input column names (input_cols) is None.\n"
153
- "Please put your input_cols when initializing the estimator\n"
154
- ),
155
- )
156
148
  if isinstance(dataset, pd.DataFrame):
157
149
  return PandasModelTrainer(
158
150
  estimator=estimator,
@@ -179,3 +171,44 @@ class ModelTrainerBuilder:
179
171
  f"Unexpected dataset type: {type(dataset)}."
180
172
  "Supported dataset types: snowpark.DataFrame, pandas.DataFrame."
181
173
  )
174
+
175
+ @classmethod
176
+ def build_fit_transform(
177
+ cls,
178
+ estimator: object,
179
+ dataset: Union[DataFrame, pd.DataFrame],
180
+ input_cols: List[str],
181
+ label_cols: Optional[List[str]] = None,
182
+ sample_weight_col: Optional[str] = None,
183
+ autogenerated: bool = False,
184
+ subproject: str = "",
185
+ ) -> ModelTrainer:
186
+ """
187
+ Builder method that creates an appropriate ModelTrainer instance based on the given params.
188
+ """
189
+ if isinstance(dataset, pd.DataFrame):
190
+ return PandasModelTrainer(
191
+ estimator=estimator,
192
+ dataset=dataset,
193
+ input_cols=input_cols,
194
+ label_cols=label_cols,
195
+ sample_weight_col=sample_weight_col,
196
+ )
197
+ elif isinstance(dataset, DataFrame):
198
+ trainer_klass = SnowparkModelTrainer
199
+ init_args = {
200
+ "estimator": estimator,
201
+ "dataset": dataset,
202
+ "session": dataset._session,
203
+ "input_cols": input_cols,
204
+ "label_cols": label_cols,
205
+ "sample_weight_col": sample_weight_col,
206
+ "autogenerated": autogenerated,
207
+ "subproject": subproject,
208
+ }
209
+ return trainer_klass(**init_args) # type: ignore[arg-type]
210
+ else:
211
+ raise TypeError(
212
+ f"Unexpected dataset type: {type(dataset)}."
213
+ "Supported dataset types: snowpark.DataFrame, pandas.DataFrame."
214
+ )
@@ -4,7 +4,7 @@ import io
4
4
  import os
5
5
  import posixpath
6
6
  import sys
7
- from typing import Any, Dict, List, Optional, Tuple, Union
7
+ from typing import Any, Dict, List, Optional, Set, Tuple, Union
8
8
 
9
9
  import cloudpickle as cp
10
10
  import numpy as np
@@ -154,7 +154,7 @@ def construct_cv_results(
154
154
  return multimetric, estimator._format_results(param_grid, n_split, out)
155
155
 
156
156
 
157
- def construct_cv_results_new_implementation(
157
+ def construct_cv_results_memory_efficient_version(
158
158
  estimator: Union[GridSearchCV, RandomizedSearchCV],
159
159
  n_split: int,
160
160
  param_grid: List[Dict[str, Any]],
@@ -205,12 +205,35 @@ def construct_cv_results_new_implementation(
205
205
  with io.BytesIO(hex_str) as f_reload:
206
206
  out = cp.load(f_reload)
207
207
  all_out.extend(out)
208
+
209
+ # The original SearchCV orders its results by parameter first and cv fold second.
210
+ # To be memory efficient, this implementation fits by cv fold first and parameter second,
211
+ # so the retrieved results must be reordered to match the ordering of the original SearchCV.
212
+ def generate_the_order_by_parameter_index(all_combination_length: int) -> List[int]:
213
+ pattern = []
214
+ for i in range(all_combination_length):
215
+ if i % parameter_grid_length == 0:
216
+ pattern.append(i)
217
+ for i in range(1, parameter_grid_length):
218
+ for j in range(all_combination_length):
219
+ if j % parameter_grid_length == i:
220
+ pattern.append(j)
221
+ return pattern
222
+
223
+ def rerank_array(original_array: List[Any], pattern: List[int]) -> List[Any]:
224
+ reranked_array = []
225
+ for index in pattern:
226
+ reranked_array.append(original_array[index])
227
+ return reranked_array
228
+
229
+ pattern = generate_the_order_by_parameter_index(len(all_out))
230
+ reranked_all_out = rerank_array(all_out, pattern)
208
231
  first_test_score = all_out[0]["test_scores"]
209
- return first_test_score, estimator._format_results(param_grid, n_split, all_out)
232
+ return first_test_score, estimator._format_results(param_grid, n_split, reranked_all_out)
210
233
 
211
234
 
212
235
  cp.register_pickle_by_value(inspect.getmodule(construct_cv_results))
213
- cp.register_pickle_by_value(inspect.getmodule(construct_cv_results_new_implementation))
236
+ cp.register_pickle_by_value(inspect.getmodule(construct_cv_results_memory_efficient_version))
214
237
 
215
238
 
216
239
  class DistributedHPOTrainer(SnowparkModelTrainer):
@@ -661,7 +684,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
661
684
 
662
685
  return fit_estimator
663
686
 
664
- def fit_search_snowpark_new_implementation(
687
+ def fit_search_snowpark_enable_efficient_memory_usage(
665
688
  self,
666
689
  param_grid: Union[model_selection.ParameterGrid, model_selection.ParameterSampler],
667
690
  dataset: DataFrame,
@@ -718,7 +741,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
718
741
  inspect.currentframe(), self.__class__.__name__
719
742
  ),
720
743
  api_calls=[udtf],
721
- custom_tags=dict([("hpo_udtf", True)]),
744
+ custom_tags=dict([("hpo_memory_efficient", True)]),
722
745
  )
723
746
 
724
747
  # Put locally serialized estimator on stage.
@@ -960,22 +983,26 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
960
983
  self.base_estimator = base_estimator
961
984
  self.fit_and_score_kwargs = fit_and_score_kwargs
962
985
  self.fit_score_params: List[Any] = []
963
- self.cached_train_test_indices = []
964
- # Calculate the full index here to avoid duplicate calculation (which consumes a lot of memory)
965
- full_index = np.arange(DATA_LENGTH)
966
- for i in range(n_splits):
967
- self.cached_train_test_indices.extend(
968
- [[np.setdiff1d(full_index, self.test_indices[i]), self.test_indices[i]]]
969
- )
986
+ self.cv_indices_set: Set[int] = set()
970
987
 
971
988
  def process(self, idx: int, params_idx: int, cv_idx: int) -> None:
972
989
  self.fit_score_params.extend([[idx, params_idx, cv_idx]])
990
+ self.cv_indices_set.add(cv_idx)
973
991
 
974
992
  def end_partition(self) -> Iterator[Tuple[int, str]]:
975
993
  from sklearn.base import clone
976
994
  from sklearn.model_selection._validation import _fit_and_score
977
995
  from sklearn.utils.parallel import Parallel, delayed
978
996
 
997
+ cached_train_test_indices = {}
998
+ # Calculate the full index here to avoid duplicate calculation (which consumes a lot of memory)
999
+ full_index = np.arange(DATA_LENGTH)
1000
+ for i in self.cv_indices_set:
1001
+ cached_train_test_indices[i] = [
1002
+ np.setdiff1d(full_index, self.test_indices[i]),
1003
+ self.test_indices[i],
1004
+ ]
1005
+
979
1006
  parallel = Parallel(n_jobs=_N_JOBS, pre_dispatch=_PRE_DISPATCH)
980
1007
 
981
1008
  out = parallel(
@@ -983,8 +1010,8 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
983
1010
  clone(self.base_estimator),
984
1011
  self.X,
985
1012
  self.y,
986
- train=self.cached_train_test_indices[split_idx][0],
987
- test=self.cached_train_test_indices[split_idx][1],
1013
+ train=cached_train_test_indices[split_idx][0],
1014
+ test=cached_train_test_indices[split_idx][1],
988
1015
  parameters=self.params_to_evaluate[cand_idx],
989
1016
  split_progress=(split_idx, n_splits),
990
1017
  candidate_progress=(cand_idx, n_candidates),
@@ -1005,7 +1032,9 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
1005
1032
 
1006
1033
  session.udtf.register(
1007
1034
  SearchCV,
1008
- output_schema=StructType([StructField("IDX", IntegerType()), StructField("CV_RESULTS", StringType())]),
1035
+ output_schema=StructType(
1036
+ [StructField("FIRST_IDX", IntegerType()), StructField("EACH_CV_RESULTS", StringType())]
1037
+ ),
1009
1038
  input_types=[IntegerType(), IntegerType(), IntegerType()],
1010
1039
  name=random_udtf_name,
1011
1040
  packages=required_deps, # type: ignore[arg-type]
@@ -1020,8 +1049,8 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
1020
1049
  # param_indices is for the index for each parameter grid;
1021
1050
  # cv_indices is for the index for each cross_validator's fold;
1022
1051
  # param_cv_indices is for the index for the product of (len(param_indices) * len(cv_indices))
1023
- param_indices, cv_indices = zip(
1024
- *product(range(parameter_grid_length), range(cross_validator_indices_length))
1052
+ cv_indices, param_indices = zip(
1053
+ *product(range(cross_validator_indices_length), range(parameter_grid_length))
1025
1054
  )
1026
1055
 
1027
1056
  indices_info_pandas = pd.DataFrame(
@@ -1042,11 +1071,11 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
1042
1071
  ),
1043
1072
  )
1044
1073
 
1045
- first_test_score, cv_results_ = construct_cv_results_new_implementation(
1074
+ first_test_score, cv_results_ = construct_cv_results_memory_efficient_version(
1046
1075
  estimator,
1047
1076
  n_splits,
1048
1077
  list(param_grid),
1049
- HP_raw_results.select("CV_RESULTS").sort(F.col("IDX")).collect(),
1078
+ HP_raw_results.select("EACH_CV_RESULTS").sort(F.col("FIRST_IDX")).collect(),
1050
1079
  cross_validator_indices_length,
1051
1080
  parameter_grid_length,
1052
1081
  )
@@ -1163,7 +1192,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
1163
1192
  pkg_versions=model_spec.pkgDependencies, session=self.session
1164
1193
  )
1165
1194
  if ENABLE_EFFICIENT_MEMORY_USAGE:
1166
- return self.fit_search_snowpark_new_implementation(
1195
+ return self.fit_search_snowpark_enable_efficient_memory_usage(
1167
1196
  param_grid=param_grid,
1168
1197
  dataset=self.dataset,
1169
1198
  session=self.session,
@@ -9,7 +9,11 @@ import cloudpickle as cp
9
9
  import pandas as pd
10
10
 
11
11
  from snowflake.ml._internal import telemetry
12
- from snowflake.ml._internal.utils import identifier, snowpark_dataframe_utils
12
+ from snowflake.ml._internal.utils import (
13
+ identifier,
14
+ pkg_version_utils,
15
+ snowpark_dataframe_utils,
16
+ )
13
17
  from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator
14
18
  from snowflake.ml._internal.utils.temp_file_utils import (
15
19
  cleanup_temp_files,
@@ -91,6 +95,7 @@ class SnowparkTransformHandlers:
91
95
  A new dataset of the same type as the input dataset.
92
96
  """
93
97
 
98
+ dependencies = self._get_validated_snowpark_dependencies(session, dependencies)
94
99
  dataset = self.dataset
95
100
  estimator = self.estimator
96
101
  # Register vectorized UDF for batch inference
@@ -210,7 +215,8 @@ class SnowparkTransformHandlers:
210
215
  Returns:
211
216
  An accuracy score for the model on the given test data.
212
217
  """
213
-
218
+ dependencies = self._get_validated_snowpark_dependencies(session, dependencies)
219
+ dependencies.append("snowflake-snowpark-python")
214
220
  dataset = self.dataset
215
221
  estimator = self.estimator
216
222
  dataset = snowpark_dataframe_utils.cast_snowpark_dataframe_column_types(dataset)
@@ -335,3 +341,19 @@ class SnowparkTransformHandlers:
335
341
  cleanup_temp_files([local_score_file_name])
336
342
 
337
343
  return score
344
+
345
+ def _get_validated_snowpark_dependencies(self, session: Session, dependencies: List[str]) -> List[str]:
346
+ """A helper function to validate dependencies and return the available packages that exist
347
+ in the snowflake anaconda channel
348
+
349
+ Args:
350
+ session: the active snowpark Session
351
+ dependencies: unvalidated dependencies
352
+
353
+ Returns:
354
+ A list of packages present in the snowflake conda channel.
355
+ """
356
+
357
+ return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
358
+ pkg_versions=dependencies, session=session, subproject=self._subproject
359
+ )