snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. snowflake/ml/_internal/file_utils.py +3 -3
  2. snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
  3. snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
  4. snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
  5. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
  6. snowflake/ml/_internal/telemetry.py +11 -2
  7. snowflake/ml/_internal/utils/formatting.py +1 -1
  8. snowflake/ml/feature_store/feature_store.py +15 -106
  9. snowflake/ml/fileset/sfcfs.py +4 -3
  10. snowflake/ml/fileset/stage_fs.py +18 -0
  11. snowflake/ml/model/_api.py +9 -9
  12. snowflake/ml/model/_client/model/model_version_impl.py +20 -15
  13. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
  14. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
  15. snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
  16. snowflake/ml/model/_model_composer/model_composer.py +10 -8
  17. snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
  18. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
  19. snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
  20. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
  21. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  22. snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
  23. snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
  24. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
  25. snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
  26. snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
  27. snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
  28. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
  29. snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
  30. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
  31. snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
  32. snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
  33. snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
  34. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  35. snowflake/ml/model/_packager/model_packager.py +8 -6
  36. snowflake/ml/model/custom_model.py +3 -1
  37. snowflake/ml/model/type_hints.py +13 -0
  38. snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
  39. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
  40. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
  41. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
  42. snowflake/ml/modeling/_internal/model_specifications.py +3 -1
  43. snowflake/ml/modeling/_internal/model_trainer.py +2 -2
  44. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
  45. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
  46. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
  47. snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
  48. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
  49. snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
  50. snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
  51. snowflake/ml/modeling/cluster/birch.py +33 -61
  52. snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
  53. snowflake/ml/modeling/cluster/dbscan.py +33 -61
  54. snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
  55. snowflake/ml/modeling/cluster/k_means.py +33 -61
  56. snowflake/ml/modeling/cluster/mean_shift.py +33 -61
  57. snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
  58. snowflake/ml/modeling/cluster/optics.py +33 -61
  59. snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
  60. snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
  61. snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
  62. snowflake/ml/modeling/compose/column_transformer.py +33 -61
  63. snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
  64. snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
  65. snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
  66. snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
  67. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
  68. snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
  69. snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
  70. snowflake/ml/modeling/covariance/oas.py +33 -61
  71. snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
  72. snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
  73. snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
  74. snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
  75. snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
  76. snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
  77. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
  78. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
  79. snowflake/ml/modeling/decomposition/pca.py +33 -61
  80. snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
  81. snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
  82. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
  83. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
  84. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
  85. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
  86. snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
  87. snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
  88. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
  89. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
  90. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
  91. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
  92. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
  93. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
  94. snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
  95. snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
  96. snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
  97. snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
  98. snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
  99. snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
  100. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
  101. snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
  102. snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
  103. snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
  104. snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
  105. snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
  106. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
  107. snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
  108. snowflake/ml/modeling/framework/base.py +55 -5
  109. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
  111. snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
  112. snowflake/ml/modeling/impute/knn_imputer.py +33 -61
  113. snowflake/ml/modeling/impute/missing_indicator.py +33 -61
  114. snowflake/ml/modeling/impute/simple_imputer.py +4 -15
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
  123. snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
  125. snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
  129. snowflake/ml/modeling/linear_model/lars.py +33 -61
  130. snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
  131. snowflake/ml/modeling/linear_model/lasso.py +33 -61
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
  136. snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
  146. snowflake/ml/modeling/linear_model/perceptron.py +33 -61
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
  149. snowflake/ml/modeling/linear_model/ridge.py +33 -61
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
  158. snowflake/ml/modeling/manifold/isomap.py +33 -61
  159. snowflake/ml/modeling/manifold/mds.py +33 -61
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
  161. snowflake/ml/modeling/manifold/tsne.py +33 -61
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
  164. snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
  165. snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
  166. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
  167. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
  168. snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
  169. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
  170. snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
  171. snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
  172. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
  173. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
  174. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
  175. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
  176. snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
  177. snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
  178. snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
  179. snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
  180. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
  181. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
  182. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
  183. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
  184. snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
  185. snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
  189. snowflake/ml/modeling/svm/linear_svc.py +33 -61
  190. snowflake/ml/modeling/svm/linear_svr.py +33 -61
  191. snowflake/ml/modeling/svm/nu_svc.py +33 -61
  192. snowflake/ml/modeling/svm/nu_svr.py +33 -61
  193. snowflake/ml/modeling/svm/svc.py +33 -61
  194. snowflake/ml/modeling/svm/svr.py +33 -61
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
  203. snowflake/ml/registry/_manager/model_manager.py +6 -2
  204. snowflake/ml/registry/model_registry.py +100 -27
  205. snowflake/ml/registry/registry.py +6 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
  208. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
  209. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
  210. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
  211. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
8
8
 
9
9
  import cloudpickle as cp
10
10
  import numpy as np
11
+ import numpy.typing as npt
11
12
  from sklearn import model_selection
12
13
  from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
13
14
 
@@ -38,9 +39,11 @@ from snowflake.snowpark.types import IntegerType, StringType, StructField, Struc
38
39
 
39
40
  # Ask cloudpickle to pickle these helper modules by value (embedding their code in
  # the payload) rather than by import reference. Presumably this is because they are
  # not importable inside the Snowflake sproc/UDTF environment — NOTE(review): confirm.
  cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
40
41
  cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
42
+ cp.register_pickle_by_value(inspect.getmodule(snowpark_dataframe_utils.cast_snowpark_dataframe))
41
43
 
42
44
  # Telemetry project name passed as `project=` to telemetry.get_function_usage_statement_params.
  _PROJECT = "ModelDevelopment"
43
45
  # Default n_jobs value for the HPO UDTF; its reader is not visible in this chunk.
  DEFAULT_UDTF_NJOBS = 3
46
+ # Flag added in this release; its reader is not visible in this chunk. Presumably it
+ # gates the *_new_implementation code path — TODO confirm.
+ ENABLE_EFFICIENT_MEMORY_USAGE = False
44
47
 
45
48
 
46
49
  def construct_cv_results(
@@ -151,7 +154,63 @@ def construct_cv_results(
151
154
  return multimetric, estimator._format_results(param_grid, n_split, out)
152
155
 
153
156
 
157
def construct_cv_results_new_implementation(
    estimator: Union[GridSearchCV, RandomizedSearchCV],
    n_split: int,
    param_grid: List[Dict[str, Any]],
    cv_results_raw_hex: List[Row],
    cross_validator_indices_length: int,
    parameter_grid_length: int,
) -> Tuple[Any, Dict[str, Any]]:
    """Reassemble sklearn's ``cv_results_`` from the hex-encoded output of the HPO UDTF.

    Each UDTF output row carries, in its first column, a hex string. That string
    encodes a cloudpickled list of the raw dictionaries produced by sklearn's
    ``_fit_and_score`` (one per (candidate, fold) pair evaluated in that partition).
    This function decodes every row, concatenates the per-fit dictionaries, and passes
    them to ``estimator._format_results`` so the result matches sklearn's own output.

    Args:
        estimator (Union[GridSearchCV, RandomizedSearchCV]): The sklearn search object,
            GridSearchCV or RandomizedSearchCV.
        n_split (int): The number of splits, as determined by
            ``build_cross_validator.get_n_splits(X, y, groups)``.
        param_grid (List[Dict[str, Any]]): The list of parameter candidates (parameter
            grid or parameter sampler).
        cv_results_raw_hex (List[Row]): The rows returned by the UDTF. A UDxF can only
            return strings, and numpy arrays / masked arrays cannot be JSON-encoded, so
            each row's result list is cloudpickled and hex-encoded.
        cross_validator_indices_length (int): The number of cross-validator folds.
        parameter_grid_length (int): The number of parameter candidates.

    Raises:
        ValueError: Retrieved empty cross validation results.
        ValueError: Cross validator index length is 0.
        ValueError: Parameter index length is 0.
        ValueError: Retrieved incorrect dataframe dimension from Snowpark's UDTF, i.e.
            the number of decoded fit results differs from
            ``cross_validator_indices_length * parameter_grid_length``.

    Returns:
        Tuple[Any, Dict[str, Any]]: ``(first_test_score, cv_results_)``, where
        ``first_test_score`` is the ``"test_scores"`` entry of the first decoded result.
    """
    # Filter corner cases: either the Snowpark dataframe result is empty, or an index length is 0.
    if len(cv_results_raw_hex) == 0:
        raise ValueError(
            "Retrieved empty cross validation results from snowpark. Please retry or contact snowflake support."
        )
    if cross_validator_indices_length == 0:
        raise ValueError("Cross validator index length is 0. Was the CV iterator empty? ")
    if parameter_grid_length == 0:
        raise ValueError("Parameter index length is 0. Were there no candidates?")

    all_out: List[Dict[str, Any]] = []
    for each_cv_result_hex in cv_results_raw_hex:
        # Decode hex -> bytes -> list of `_fit_and_score` dicts.
        # NOTE: unpickling can execute arbitrary code; this must only ever be fed the
        # output of this module's own UDTF, never externally supplied data.
        all_out.extend(cp.loads(bytes.fromhex(each_cv_result_hex[0])))

    # Every (candidate, fold) pair must yield exactly one result. A mismatch means UDTF
    # rows were lost or duplicated, and `_format_results` would receive a result list of
    # the wrong size. This check also guarantees `all_out` is non-empty below.
    # NOTE(review): `_format_results` also relies on result ORDER (candidate-major);
    # presumably the caller sorts the rows by IDX — confirm.
    expected_length = cross_validator_indices_length * parameter_grid_length
    if len(all_out) != expected_length:
        raise ValueError(
            "Retrieved incorrect dataframe dimension from Snowpark's UDTF. "
            f"Expected {expected_length}, got {len(all_out)}. "
            "Please retry or contact snowflake support."
        )

    first_test_score = all_out[0]["test_scores"]
    return first_test_score, estimator._format_results(param_grid, n_split, all_out)
210
+
211
+
154
212
  # Register this module for pickle-by-value so the functions defined here can be
  # shipped with cloudpickle. construct_cv_results and
  # construct_cv_results_new_implementation are both top-level functions of this same
  # module, so the second call registers the same module again. It is redundant and
  # presumably harmless (re-registration is expected to be idempotent — confirm
  # against cloudpickle).
  cp.register_pickle_by_value(inspect.getmodule(construct_cv_results))
213
+ cp.register_pickle_by_value(inspect.getmodule(construct_cv_results_new_implementation))
155
214
 
156
215
 
157
216
  class DistributedHPOTrainer(SnowparkModelTrainer):
@@ -277,7 +336,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
277
336
  imports.append(f"@{temp_stage_name}/{estimator_location}")
278
337
 
279
338
  search_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
280
- random_udtf_name = random_name_for_temp_object(TempObjectType.FUNCTION)
339
+ random_udtf_name = random_name_for_temp_object(TempObjectType.TABLE_FUNCTION)
281
340
 
282
341
  required_deps = dependencies + [
283
342
  "snowflake-snowpark-python<2",
@@ -602,6 +661,480 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
602
661
 
603
662
  return fit_estimator
604
663
 
664
+ def fit_search_snowpark_new_implementation(
665
+ self,
666
+ param_grid: Union[model_selection.ParameterGrid, model_selection.ParameterSampler],
667
+ dataset: DataFrame,
668
+ session: Session,
669
+ estimator: Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV],
670
+ dependencies: List[str],
671
+ udf_imports: List[str],
672
+ input_cols: List[str],
673
+ label_cols: Optional[List[str]],
674
+ sample_weight_col: Optional[str],
675
+ ) -> Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV]:
676
+ from itertools import product
677
+
678
+ import cachetools
679
+ from sklearn.base import clone, is_classifier
680
+ from sklearn.calibration import check_cv
681
+
682
+ # Create one stage for data and for estimators.
683
+ temp_stage_name = random_name_for_temp_object(TempObjectType.STAGE)
684
+ temp_stage_creation_query = f"CREATE OR REPLACE TEMP STAGE {temp_stage_name};"
685
+ session.sql(temp_stage_creation_query).collect()
686
+
687
+ # Stage data as parquet file
688
+ dataset = snowpark_dataframe_utils.cast_snowpark_dataframe(dataset)
689
+ dataset_file_name = "dataset"
690
+ remote_file_path = f"{temp_stage_name}/{dataset_file_name}.parquet"
691
+ dataset.write.copy_into_location( # type:ignore[call-overload]
692
+ remote_file_path, file_format_type="parquet", header=True, overwrite=True
693
+ )
694
+ imports = [f"@{row.name}" for row in session.sql(f"LIST @{temp_stage_name}/{dataset_file_name}").collect()]
695
+
696
+ # Create a temp file and dump the estimator to that file.
697
+ estimator_file_name = get_temp_file_path()
698
+ params_to_evaluate = list(param_grid)
699
+ n_candidates = len(params_to_evaluate)
700
+ _N_JOBS = estimator.n_jobs
701
+ _PRE_DISPATCH = estimator.pre_dispatch
702
+
703
+ with open(estimator_file_name, mode="w+b") as local_estimator_file_obj:
704
+ cp.dump(dict(estimator=estimator, param_grid=params_to_evaluate), local_estimator_file_obj)
705
+ stage_estimator_file_name = posixpath.join(temp_stage_name, os.path.basename(estimator_file_name))
706
+ sproc_statement_params = telemetry.get_function_usage_statement_params(
707
+ project=_PROJECT,
708
+ subproject=self._subproject,
709
+ function_name=telemetry.get_statement_params_full_func_name(
710
+ inspect.currentframe(), self.__class__.__name__
711
+ ),
712
+ api_calls=[sproc],
713
+ )
714
+ udtf_statement_params = telemetry.get_function_usage_statement_params(
715
+ project=_PROJECT,
716
+ subproject=self._subproject,
717
+ function_name=telemetry.get_statement_params_full_func_name(
718
+ inspect.currentframe(), self.__class__.__name__
719
+ ),
720
+ api_calls=[udtf],
721
+ custom_tags=dict([("hpo_udtf", True)]),
722
+ )
723
+
724
+ # Put locally serialized estimator on stage.
725
+ session.file.put(
726
+ estimator_file_name,
727
+ temp_stage_name,
728
+ auto_compress=False,
729
+ overwrite=True,
730
+ )
731
+ estimator_location = os.path.basename(estimator_file_name)
732
+ imports.append(f"@{temp_stage_name}/{estimator_location}")
733
+
734
+ search_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
735
+ random_udtf_name = random_name_for_temp_object(TempObjectType.FUNCTION)
736
+
737
+ required_deps = dependencies + [
738
+ "snowflake-snowpark-python<2",
739
+ "fastparquet<2023.11",
740
+ "pyarrow<14",
741
+ "cachetools<6",
742
+ ]
743
+
744
+ @sproc( # type: ignore[misc]
745
+ is_permanent=False,
746
+ name=search_sproc_name,
747
+ packages=required_deps, # type: ignore[arg-type]
748
+ replace=True,
749
+ session=session,
750
+ anonymous=True,
751
+ imports=imports, # type: ignore[arg-type]
752
+ statement_params=sproc_statement_params,
753
+ )
754
+ def _distributed_search(
755
+ session: Session,
756
+ imports: List[str],
757
+ stage_estimator_file_name: str,
758
+ input_cols: List[str],
759
+ label_cols: Optional[List[str]],
760
+ ) -> str:
761
+ import os
762
+ import time
763
+ from typing import Iterator
764
+
765
+ import cloudpickle as cp
766
+ import pandas as pd
767
+ import pyarrow.parquet as pq
768
+ from sklearn.metrics import check_scoring
769
+ from sklearn.metrics._scorer import _check_multimetric_scoring
770
+ from sklearn.utils.validation import _check_fit_params, indexable
771
+
772
+ # import packages in sproc
773
+ for import_name in udf_imports:
774
+ importlib.import_module(import_name)
775
+
776
+ # os.cpu_count() returns the number of logical CPUs in the system. Returns None if undetermined.
777
+ _NUM_CPUs = os.cpu_count() or 1
778
+
779
+ # load dataset
780
+ data_files = [
781
+ filename
782
+ for filename in os.listdir(sys._xoptions["snowflake_import_directory"])
783
+ if filename.startswith(dataset_file_name)
784
+ ]
785
+ partial_df = [
786
+ pq.read_table(os.path.join(sys._xoptions["snowflake_import_directory"], file_name)).to_pandas()
787
+ for file_name in data_files
788
+ ]
789
+ df = pd.concat(partial_df, ignore_index=True)
790
+ df.columns = [identifier.get_inferred_name(col_) for col_ in df.columns]
791
+
792
+ X = df[input_cols]
793
+ y = df[label_cols].squeeze() if label_cols else None
794
+ DATA_LENGTH = len(df)
795
+ fit_params = {}
796
+ if sample_weight_col:
797
+ fit_params["sample_weight"] = df[sample_weight_col].squeeze()
798
+
799
+ local_estimator_file_folder_name = get_temp_file_path()
800
+ session.file.get(stage_estimator_file_name, local_estimator_file_folder_name)
801
+
802
+ local_estimator_file_path = os.path.join(
803
+ local_estimator_file_folder_name, os.listdir(local_estimator_file_folder_name)[0]
804
+ )
805
+ with open(local_estimator_file_path, mode="r+b") as local_estimator_file_obj:
806
+ estimator = cp.load(local_estimator_file_obj)["estimator"]
807
+
808
+ # preprocess the attributes - (1) scorer
809
+ refit_metric = "score"
810
+ if callable(estimator.scoring):
811
+ scorers = estimator.scoring
812
+ elif estimator.scoring is None or isinstance(estimator.scoring, str):
813
+ scorers = check_scoring(estimator.estimator, estimator.scoring)
814
+ else:
815
+ scorers = _check_multimetric_scoring(estimator.estimator, estimator.scoring)
816
+ estimator._check_refit_for_multimetric(scorers)
817
+ refit_metric = estimator.refit
818
+
819
+ # preprocess the attributes - (2) check fit_params
820
+ groups = None
821
+ X, y, _ = indexable(X, y, groups)
822
+ fit_params = _check_fit_params(X, fit_params)
823
+
824
+ # preprocess the attributes - (3) safe clone base estimator
825
+ base_estimator = clone(estimator.estimator)
826
+
827
+ # preprocess the attributes - (4) check cv
828
+ build_cross_validator = check_cv(estimator.cv, y, classifier=is_classifier(estimator.estimator))
829
+ n_splits = build_cross_validator.get_n_splits(X, y, groups)
830
+
831
+ # preprocess the attributes - (5) generate fit_and_score_kwargs
832
+ fit_and_score_kwargs = dict(
833
+ scorer=scorers,
834
+ fit_params=fit_params,
835
+ return_train_score=estimator.return_train_score,
836
+ return_n_test_samples=True,
837
+ return_times=True,
838
+ return_parameters=False,
839
+ error_score=estimator.error_score,
840
+ verbose=estimator.verbose,
841
+ )
842
+
843
+ # (1) store the cross_validator's test indices only to save space
844
+ cross_validator_indices = [test for _, test in build_cross_validator.split(X, y, None)]
845
+ local_indices_file_name = get_temp_file_path()
846
+ with open(local_indices_file_name, mode="w+b") as local_indices_file_obj:
847
+ cp.dump(cross_validator_indices, local_indices_file_obj)
848
+
849
+ # Put locally serialized indices on stage.
850
+ session.file.put(
851
+ local_indices_file_name,
852
+ temp_stage_name,
853
+ auto_compress=False,
854
+ overwrite=True,
855
+ )
856
+ indices_location = os.path.basename(local_indices_file_name)
857
+ imports.append(f"@{temp_stage_name}/{indices_location}")
858
+
859
+ # (2) store the base estimator
860
+ local_base_estimator_file_name = get_temp_file_path()
861
+ with open(local_base_estimator_file_name, mode="w+b") as local_base_estimator_file_obj:
862
+ cp.dump(base_estimator, local_base_estimator_file_obj)
863
+ session.file.put(
864
+ local_base_estimator_file_name,
865
+ temp_stage_name,
866
+ auto_compress=False,
867
+ overwrite=True,
868
+ )
869
+ base_estimator_location = os.path.basename(local_base_estimator_file_name)
870
+ imports.append(f"@{temp_stage_name}/{base_estimator_location}")
871
+
872
+ # (3) store the fit_and_score_kwargs
873
+ local_fit_and_score_kwargs_file_name = get_temp_file_path()
874
+ with open(local_fit_and_score_kwargs_file_name, mode="w+b") as local_fit_and_score_kwargs_file_obj:
875
+ cp.dump(fit_and_score_kwargs, local_fit_and_score_kwargs_file_obj)
876
+ session.file.put(
877
+ local_fit_and_score_kwargs_file_name,
878
+ temp_stage_name,
879
+ auto_compress=False,
880
+ overwrite=True,
881
+ )
882
+ fit_and_score_kwargs_location = os.path.basename(local_fit_and_score_kwargs_file_name)
883
+ imports.append(f"@{temp_stage_name}/{fit_and_score_kwargs_location}")
884
+
885
+ cross_validator_indices_length = int(len(cross_validator_indices))
886
+ parameter_grid_length = len(param_grid)
887
+
888
+ assert estimator is not None
889
+
890
+ @cachetools.cached(cache={})
891
+ def _load_data_into_udf() -> Tuple[
892
+ npt.NDArray[Any],
893
+ npt.NDArray[Any],
894
+ List[List[int]],
895
+ List[Dict[str, Any]],
896
+ object,
897
+ Dict[str, Any],
898
+ ]:
899
+ import pyarrow.parquet as pq
900
+
901
+ data_files = [
902
+ filename
903
+ for filename in os.listdir(sys._xoptions["snowflake_import_directory"])
904
+ if filename.startswith(dataset_file_name)
905
+ ]
906
+ partial_df = [
907
+ pq.read_table(os.path.join(sys._xoptions["snowflake_import_directory"], file_name)).to_pandas()
908
+ for file_name in data_files
909
+ ]
910
+ df = pd.concat(partial_df, ignore_index=True)
911
+ df.columns = [identifier.get_inferred_name(col_) for col_ in df.columns]
912
+
913
+ # load parameter grid
914
+ local_estimator_file_path = os.path.join(
915
+ sys._xoptions["snowflake_import_directory"], f"{estimator_location}"
916
+ )
917
+ with open(local_estimator_file_path, mode="rb") as local_estimator_file_obj:
918
+ estimator_objects = cp.load(local_estimator_file_obj)
919
+ params_to_evaluate = estimator_objects["param_grid"]
920
+
921
+ # load indices
922
+ local_indices_file_path = os.path.join(
923
+ sys._xoptions["snowflake_import_directory"], f"{indices_location}"
924
+ )
925
+ with open(local_indices_file_path, mode="rb") as local_indices_file_obj:
926
+ indices = cp.load(local_indices_file_obj)
927
+
928
+ # load base estimator
929
+ local_base_estimator_file_path = os.path.join(
930
+ sys._xoptions["snowflake_import_directory"], f"{base_estimator_location}"
931
+ )
932
+ with open(local_base_estimator_file_path, mode="rb") as local_base_estimator_file_obj:
933
+ base_estimator = cp.load(local_base_estimator_file_obj)
934
+
935
+ # load fit_and_score_kwargs
936
+ local_fit_and_score_kwargs_file_path = os.path.join(
937
+ sys._xoptions["snowflake_import_directory"], f"{fit_and_score_kwargs_location}"
938
+ )
939
+ with open(local_fit_and_score_kwargs_file_path, mode="rb") as local_fit_and_score_kwargs_file_obj:
940
+ fit_and_score_kwargs = cp.load(local_fit_and_score_kwargs_file_obj)
941
+
942
+ # convert dataframe to numpy would save memory consumption
943
+ return (
944
+ df[input_cols].to_numpy(),
945
+ df[label_cols].squeeze().to_numpy(),
946
+ indices,
947
+ params_to_evaluate,
948
+ base_estimator,
949
+ fit_and_score_kwargs,
950
+ )
951
+
952
+ # Note Table functions (UDTFs) have a limit of 500 input arguments and 500 output columns.
953
class SearchCV:
    """UDTF that fits and scores the (parameter set, CV fold) tasks of one partition.

    Each input row carries a partition id, an index into the parameter grid
    and an index into the stored cross-validation test indices. Rows are
    buffered in ``process``; all fits for the partition run together in
    ``end_partition``, which emits one output row per partition.
    """

    def __init__(self) -> None:
        X, y, indices, params_to_evaluate, base_estimator, fit_and_score_kwargs = _load_data_into_udf()
        self.X = X
        self.y = y
        self.indices = indices
        self.params_to_evaluate = params_to_evaluate
        self.base_estimator = base_estimator
        self.fit_and_score_kwargs = fit_and_score_kwargs
        # Buffered [idx, (params_idx, parameters), (cv_idx, (train_index, test_index))] entries.
        self.fit_score_params: List[Any] = []
        # Row positions 0..DATA_LENGTH-1. Built once per UDTF instance rather than
        # once per input row; used to reconstruct each fold's train index.
        self.full_index = np.arange(DATA_LENGTH)

    def process(self, idx: int, params_idx: int, cv_idx: int) -> None:
        """Buffer one (parameter set, CV fold) task for this partition.

        Args:
            idx: Partition id; the first buffered value is echoed in the output row.
            params_idx: Index into ``params_to_evaluate``.
            cv_idx: Index into the stored cross-validator test indices.
        """
        parameters = self.params_to_evaluate[params_idx]
        # Only test indices were stored (to save space); the train index is every
        # row position that is not in the test index.
        test_index = self.indices[cv_idx]
        train_index = np.setdiff1d(self.full_index, test_index)
        self.fit_score_params.append([idx, (params_idx, parameters), (cv_idx, (train_index, test_index))])

    def end_partition(self) -> Iterator[Tuple[int, str]]:
        """Run every buffered fit for the partition and emit the serialized results.

        Yields:
            A single ``(idx, hex_string)`` row. ``hex_string`` is the hex encoding
            of the list of ``_fit_and_score`` results, serialized with ``cp.dump``.
        """
        from sklearn.base import clone
        from sklearn.model_selection._validation import _fit_and_score
        from sklearn.utils.parallel import Parallel, delayed

        parallel = Parallel(n_jobs=_N_JOBS, pre_dispatch=_PRE_DISPATCH)

        out = parallel(
            delayed(_fit_and_score)(
                clone(self.base_estimator),
                self.X,
                self.y,
                train=train,
                test=test,
                parameters=parameters,
                split_progress=(split_idx, n_splits),
                candidate_progress=(cand_idx, n_candidates),
                **self.fit_and_score_kwargs,  # load sample weight here
            )
            for _, (cand_idx, parameters), (split_idx, (train, test)) in self.fit_score_params
        )

        # getvalue() returns the whole buffer regardless of stream position,
        # so no seek is needed before reading it back.
        with io.BytesIO() as f:
            cp.dump(out, f)
            binary_cv_results = f.getvalue().hex()
        yield (
            self.fit_score_params[0][0],
            binary_cv_results,
        )
1006
+
1007
+ session.udtf.register(
1008
+ SearchCV,
1009
+ output_schema=StructType([StructField("IDX", IntegerType()), StructField("CV_RESULTS", StringType())]),
1010
+ input_types=[IntegerType(), IntegerType(), IntegerType()],
1011
+ name=random_udtf_name,
1012
+ packages=required_deps, # type: ignore[arg-type]
1013
+ replace=True,
1014
+ is_permanent=False,
1015
+ imports=imports, # type: ignore[arg-type]
1016
+ statement_params=udtf_statement_params,
1017
+ )
1018
+
1019
+ HP_TUNING = F.table_function(random_udtf_name)
1020
+
1021
+ # param_indices is for the index for each parameter grid;
1022
+ # cv_indices is for the index for each cross_validator's fold;
1023
+ # param_cv_indices is for the index for the product of (len(param_indices) * len(cv_indices))
1024
+ param_indices, cv_indices = zip(
1025
+ *product(range(parameter_grid_length), range(cross_validator_indices_length))
1026
+ )
1027
+
1028
+ indices_info_pandas = pd.DataFrame(
1029
+ {
1030
+ "IDX": [i // _NUM_CPUs for i in range(parameter_grid_length * cross_validator_indices_length)],
1031
+ "PARAM_IND": param_indices,
1032
+ "CV_IND": cv_indices,
1033
+ }
1034
+ )
1035
+
1036
+ indices_info_sp = session.create_dataframe(indices_info_pandas)
1037
+ # execute udtf by querying HP_TUNING table
1038
+ HP_raw_results = indices_info_sp.select(
1039
+ (
1040
+ HP_TUNING(indices_info_sp["IDX"], indices_info_sp["PARAM_IND"], indices_info_sp["CV_IND"]).over(
1041
+ partition_by="IDX"
1042
+ )
1043
+ ),
1044
+ )
1045
+
1046
+ first_test_score, cv_results_ = construct_cv_results_new_implementation(
1047
+ estimator,
1048
+ n_splits,
1049
+ list(param_grid),
1050
+ HP_raw_results.select("CV_RESULTS").sort(F.col("IDX")).collect(),
1051
+ cross_validator_indices_length,
1052
+ parameter_grid_length,
1053
+ )
1054
+
1055
+ estimator.cv_results_ = cv_results_
1056
+ estimator.multimetric_ = isinstance(first_test_score, dict)
1057
+
1058
+ # check refit_metric now for a callable scorer that is multimetric
1059
+ if callable(estimator.scoring) and estimator.multimetric_:
1060
+ estimator._check_refit_for_multimetric(first_test_score)
1061
+ refit_metric = estimator.refit
1062
+
1063
+ # For multi-metric evaluation, store the best_index_, best_params_ and
1064
+ # best_score_ iff refit is one of the scorer names
1065
+ # In single metric evaluation, refit_metric is "score"
1066
+ if estimator.refit or not estimator.multimetric_:
1067
+ estimator.best_index_ = estimator._select_best_index(estimator.refit, refit_metric, cv_results_)
1068
+ if not callable(estimator.refit):
1069
+ # With a non-custom callable, we can select the best score
1070
+ # based on the best index
1071
+ estimator.best_score_ = cv_results_[f"mean_test_{refit_metric}"][estimator.best_index_]
1072
+ estimator.best_params_ = cv_results_["params"][estimator.best_index_]
1073
+
1074
+ if estimator.refit:
1075
+ estimator.best_estimator_ = clone(base_estimator).set_params(
1076
+ **clone(estimator.best_params_, safe=False)
1077
+ )
1078
+
1079
+ # Let the sproc use all cores to refit.
1080
+ estimator.n_jobs = estimator.n_jobs or -1
1081
+
1082
+ # process the input as args
1083
+ argspec = inspect.getfullargspec(estimator.fit)
1084
+ args = {"X": X}
1085
+ if label_cols:
1086
+ label_arg_name = "Y" if "Y" in argspec.args else "y"
1087
+ args[label_arg_name] = y
1088
+ if sample_weight_col is not None and "sample_weight" in argspec.args:
1089
+ args["sample_weight"] = df[sample_weight_col].squeeze()
1090
+ # estimator.refit = original_refit
1091
+ refit_start_time = time.time()
1092
+ estimator.best_estimator_.fit(**args)
1093
+ refit_end_time = time.time()
1094
+ estimator.refit_time_ = refit_end_time - refit_start_time
1095
+
1096
+ if hasattr(estimator.best_estimator_, "feature_names_in_"):
1097
+ estimator.feature_names_in_ = estimator.best_estimator_.feature_names_in_
1098
+
1099
+ # Store the only scorer not as a dict for single metric evaluation
1100
+ estimator.scorer_ = scorers
1101
+ estimator.n_splits_ = n_splits
1102
+
1103
+ local_result_file_name = get_temp_file_path()
1104
+
1105
+ with open(local_result_file_name, mode="w+b") as local_result_file_obj:
1106
+ cp.dump(estimator, local_result_file_obj)
1107
+
1108
+ session.file.put(
1109
+ local_result_file_name,
1110
+ temp_stage_name,
1111
+ auto_compress=False,
1112
+ overwrite=True,
1113
+ )
1114
+
1115
+ return str(os.path.basename(local_result_file_name))
1116
+
1117
+ sproc_export_file_name = _distributed_search(
1118
+ session,
1119
+ imports,
1120
+ stage_estimator_file_name,
1121
+ input_cols,
1122
+ label_cols,
1123
+ )
1124
+
1125
+ local_estimator_path = get_temp_file_path()
1126
+ session.file.get(
1127
+ posixpath.join(temp_stage_name, sproc_export_file_name),
1128
+ local_estimator_path,
1129
+ )
1130
+
1131
+ with open(os.path.join(local_estimator_path, sproc_export_file_name), mode="r+b") as result_file_obj:
1132
+ fit_estimator = cp.load(result_file_obj)
1133
+
1134
+ cleanup_temp_files(local_estimator_path)
1135
+
1136
+ return fit_estimator
1137
+
605
1138
  def train(self) -> object:
606
1139
  """
607
1140
  Runs hyper parameter optimization by distributing the tasks across warehouse.
@@ -630,6 +1163,19 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
630
1163
  relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
631
1164
  pkg_versions=model_spec.pkgDependencies, session=self.session
632
1165
  )
1166
+ if ENABLE_EFFICIENT_MEMORY_USAGE:
1167
+ return self.fit_search_snowpark_new_implementation(
1168
+ param_grid=param_grid,
1169
+ dataset=self.dataset,
1170
+ session=self.session,
1171
+ estimator=self.estimator,
1172
+ dependencies=relaxed_dependencies,
1173
+ udf_imports=["sklearn"],
1174
+ input_cols=self.input_cols,
1175
+ label_cols=self.label_cols,
1176
+ sample_weight_col=self.sample_weight_col,
1177
+ )
1178
+
633
1179
  return self.fit_search_snowpark(
634
1180
  param_grid=param_grid,
635
1181
  dataset=self.dataset,