snowflake-ml-python 1.3.1__py3-none-any.whl → 1.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. snowflake/ml/_internal/env_utils.py +11 -1
  2. snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
  3. snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
  4. snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
  5. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
  6. snowflake/ml/_internal/utils/formatting.py +1 -1
  7. snowflake/ml/_internal/utils/identifier.py +3 -1
  8. snowflake/ml/_internal/utils/sql_identifier.py +2 -6
  9. snowflake/ml/feature_store/feature_store.py +166 -184
  10. snowflake/ml/feature_store/feature_view.py +12 -24
  11. snowflake/ml/fileset/sfcfs.py +56 -50
  12. snowflake/ml/fileset/stage_fs.py +48 -13
  13. snowflake/ml/model/_client/model/model_version_impl.py +6 -49
  14. snowflake/ml/model/_client/ops/model_ops.py +78 -29
  15. snowflake/ml/model/_client/sql/model.py +23 -2
  16. snowflake/ml/model/_client/sql/model_version.py +22 -1
  17. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -3
  18. snowflake/ml/model/_deploy_client/snowservice/deploy.py +5 -2
  19. snowflake/ml/model/_model_composer/model_composer.py +7 -5
  20. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +19 -54
  21. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +8 -1
  22. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
  23. snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
  24. snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
  25. snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
  26. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
  27. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +13 -1
  28. snowflake/ml/model/_packager/model_handlers/xgboost.py +1 -1
  29. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  30. snowflake/ml/model/_packager/model_meta/model_meta.py +36 -6
  31. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
  32. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
  33. snowflake/ml/model/_packager/model_packager.py +2 -2
  34. snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
  35. snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
  36. snowflake/ml/model/custom_model.py +3 -1
  37. snowflake/ml/model/type_hints.py +21 -2
  38. snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
  39. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
  40. snowflake/ml/modeling/_internal/model_specifications.py +3 -1
  41. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +545 -0
  42. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +8 -5
  43. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +195 -123
  44. snowflake/ml/modeling/cluster/affinity_propagation.py +195 -123
  45. snowflake/ml/modeling/cluster/agglomerative_clustering.py +195 -123
  46. snowflake/ml/modeling/cluster/birch.py +195 -123
  47. snowflake/ml/modeling/cluster/bisecting_k_means.py +195 -123
  48. snowflake/ml/modeling/cluster/dbscan.py +195 -123
  49. snowflake/ml/modeling/cluster/feature_agglomeration.py +195 -123
  50. snowflake/ml/modeling/cluster/k_means.py +195 -123
  51. snowflake/ml/modeling/cluster/mean_shift.py +195 -123
  52. snowflake/ml/modeling/cluster/mini_batch_k_means.py +195 -123
  53. snowflake/ml/modeling/cluster/optics.py +195 -123
  54. snowflake/ml/modeling/cluster/spectral_biclustering.py +195 -123
  55. snowflake/ml/modeling/cluster/spectral_clustering.py +195 -123
  56. snowflake/ml/modeling/cluster/spectral_coclustering.py +195 -123
  57. snowflake/ml/modeling/compose/column_transformer.py +195 -123
  58. snowflake/ml/modeling/compose/transformed_target_regressor.py +195 -123
  59. snowflake/ml/modeling/covariance/elliptic_envelope.py +195 -123
  60. snowflake/ml/modeling/covariance/empirical_covariance.py +195 -123
  61. snowflake/ml/modeling/covariance/graphical_lasso.py +195 -123
  62. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +195 -123
  63. snowflake/ml/modeling/covariance/ledoit_wolf.py +195 -123
  64. snowflake/ml/modeling/covariance/min_cov_det.py +195 -123
  65. snowflake/ml/modeling/covariance/oas.py +195 -123
  66. snowflake/ml/modeling/covariance/shrunk_covariance.py +195 -123
  67. snowflake/ml/modeling/decomposition/dictionary_learning.py +195 -123
  68. snowflake/ml/modeling/decomposition/factor_analysis.py +195 -123
  69. snowflake/ml/modeling/decomposition/fast_ica.py +195 -123
  70. snowflake/ml/modeling/decomposition/incremental_pca.py +195 -123
  71. snowflake/ml/modeling/decomposition/kernel_pca.py +195 -123
  72. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +195 -123
  73. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +195 -123
  74. snowflake/ml/modeling/decomposition/pca.py +195 -123
  75. snowflake/ml/modeling/decomposition/sparse_pca.py +195 -123
  76. snowflake/ml/modeling/decomposition/truncated_svd.py +195 -123
  77. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +195 -123
  78. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +195 -123
  79. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +195 -123
  80. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +195 -123
  81. snowflake/ml/modeling/ensemble/bagging_classifier.py +195 -123
  82. snowflake/ml/modeling/ensemble/bagging_regressor.py +195 -123
  83. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +195 -123
  84. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +195 -123
  85. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +195 -123
  86. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +195 -123
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +195 -123
  88. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +195 -123
  89. snowflake/ml/modeling/ensemble/isolation_forest.py +195 -123
  90. snowflake/ml/modeling/ensemble/random_forest_classifier.py +195 -123
  91. snowflake/ml/modeling/ensemble/random_forest_regressor.py +195 -123
  92. snowflake/ml/modeling/ensemble/stacking_regressor.py +195 -123
  93. snowflake/ml/modeling/ensemble/voting_classifier.py +195 -123
  94. snowflake/ml/modeling/ensemble/voting_regressor.py +195 -123
  95. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +195 -123
  96. snowflake/ml/modeling/feature_selection/select_fdr.py +195 -123
  97. snowflake/ml/modeling/feature_selection/select_fpr.py +195 -123
  98. snowflake/ml/modeling/feature_selection/select_fwe.py +195 -123
  99. snowflake/ml/modeling/feature_selection/select_k_best.py +195 -123
  100. snowflake/ml/modeling/feature_selection/select_percentile.py +195 -123
  101. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +195 -123
  102. snowflake/ml/modeling/feature_selection/variance_threshold.py +195 -123
  103. snowflake/ml/modeling/framework/_utils.py +8 -1
  104. snowflake/ml/modeling/framework/base.py +24 -6
  105. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +195 -123
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +195 -123
  107. snowflake/ml/modeling/impute/iterative_imputer.py +195 -123
  108. snowflake/ml/modeling/impute/knn_imputer.py +195 -123
  109. snowflake/ml/modeling/impute/missing_indicator.py +195 -123
  110. snowflake/ml/modeling/impute/simple_imputer.py +4 -15
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +195 -123
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +195 -123
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +195 -123
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +195 -123
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +195 -123
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +195 -123
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +198 -125
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +198 -125
  119. snowflake/ml/modeling/linear_model/ard_regression.py +195 -123
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +195 -123
  121. snowflake/ml/modeling/linear_model/elastic_net.py +195 -123
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +195 -123
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +195 -123
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +195 -123
  125. snowflake/ml/modeling/linear_model/lars.py +195 -123
  126. snowflake/ml/modeling/linear_model/lars_cv.py +195 -123
  127. snowflake/ml/modeling/linear_model/lasso.py +195 -123
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +195 -123
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +195 -123
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +195 -123
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +195 -123
  132. snowflake/ml/modeling/linear_model/linear_regression.py +195 -123
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +195 -123
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +195 -123
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +195 -123
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +195 -123
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +195 -123
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +195 -123
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +195 -123
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +195 -123
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +195 -123
  142. snowflake/ml/modeling/linear_model/perceptron.py +195 -123
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +195 -123
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +195 -123
  145. snowflake/ml/modeling/linear_model/ridge.py +195 -123
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +195 -123
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +195 -123
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +195 -123
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +195 -123
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +195 -123
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +195 -123
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +195 -123
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +195 -123
  154. snowflake/ml/modeling/manifold/isomap.py +195 -123
  155. snowflake/ml/modeling/manifold/mds.py +195 -123
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +195 -123
  157. snowflake/ml/modeling/manifold/tsne.py +195 -123
  158. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +195 -123
  159. snowflake/ml/modeling/mixture/gaussian_mixture.py +195 -123
  160. snowflake/ml/modeling/model_selection/grid_search_cv.py +42 -18
  161. snowflake/ml/modeling/model_selection/randomized_search_cv.py +42 -18
  162. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +195 -123
  163. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +195 -123
  164. snowflake/ml/modeling/multiclass/output_code_classifier.py +195 -123
  165. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +195 -123
  166. snowflake/ml/modeling/naive_bayes/categorical_nb.py +195 -123
  167. snowflake/ml/modeling/naive_bayes/complement_nb.py +195 -123
  168. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +195 -123
  169. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +195 -123
  170. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +195 -123
  171. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +195 -123
  172. snowflake/ml/modeling/neighbors/kernel_density.py +195 -123
  173. snowflake/ml/modeling/neighbors/local_outlier_factor.py +195 -123
  174. snowflake/ml/modeling/neighbors/nearest_centroid.py +195 -123
  175. snowflake/ml/modeling/neighbors/nearest_neighbors.py +195 -123
  176. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +195 -123
  177. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +195 -123
  178. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +195 -123
  179. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +195 -123
  180. snowflake/ml/modeling/neural_network/mlp_classifier.py +195 -123
  181. snowflake/ml/modeling/neural_network/mlp_regressor.py +195 -123
  182. snowflake/ml/modeling/pipeline/pipeline.py +4 -4
  183. snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
  184. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
  185. snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
  186. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
  187. snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
  188. snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
  189. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +1 -5
  190. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
  191. snowflake/ml/modeling/preprocessing/polynomial_features.py +195 -123
  192. snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
  193. snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
  194. snowflake/ml/modeling/semi_supervised/label_propagation.py +195 -123
  195. snowflake/ml/modeling/semi_supervised/label_spreading.py +195 -123
  196. snowflake/ml/modeling/svm/linear_svc.py +195 -123
  197. snowflake/ml/modeling/svm/linear_svr.py +195 -123
  198. snowflake/ml/modeling/svm/nu_svc.py +195 -123
  199. snowflake/ml/modeling/svm/nu_svr.py +195 -123
  200. snowflake/ml/modeling/svm/svc.py +195 -123
  201. snowflake/ml/modeling/svm/svr.py +195 -123
  202. snowflake/ml/modeling/tree/decision_tree_classifier.py +195 -123
  203. snowflake/ml/modeling/tree/decision_tree_regressor.py +195 -123
  204. snowflake/ml/modeling/tree/extra_tree_classifier.py +195 -123
  205. snowflake/ml/modeling/tree/extra_tree_regressor.py +195 -123
  206. snowflake/ml/modeling/xgboost/xgb_classifier.py +195 -123
  207. snowflake/ml/modeling/xgboost/xgb_regressor.py +195 -123
  208. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +195 -123
  209. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +195 -123
  210. snowflake/ml/registry/_manager/model_manager.py +5 -1
  211. snowflake/ml/registry/model_registry.py +99 -26
  212. snowflake/ml/registry/registry.py +3 -2
  213. snowflake/ml/version.py +1 -1
  214. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/METADATA +94 -55
  215. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/RECORD +218 -212
  216. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
  217. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/LICENSE.txt +0 -0
  218. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/WHEEL +0 -0
  219. {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/top_level.txt +0 -0
@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import cloudpickle as cp
 import numpy as np
+import numpy.typing as npt
 from sklearn import model_selection
 from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
 
@@ -38,9 +39,11 @@ from snowflake.snowpark.types import IntegerType, StringType, StructField, StructType
 
 cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
 cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
+cp.register_pickle_by_value(inspect.getmodule(snowpark_dataframe_utils.cast_snowpark_dataframe))
 
 _PROJECT = "ModelDevelopment"
 DEFAULT_UDTF_NJOBS = 3
+ENABLE_EFFICIENT_MEMORY_USAGE = False
 
 
 def construct_cv_results(
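The newly registered module matters for the new HPO path: `register_pickle_by_value` makes cloudpickle embed the helper's code in the serialized payload instead of referencing it by import path, so the sproc/UDTF sandbox can unpickle it without having that module installed. A minimal sketch of the mechanism, using a hypothetical local stand-in helper so the snippet is self-contained (the real registration targets `snowpark_dataframe_utils.cast_snowpark_dataframe`):

```python
import inspect
import io

import cloudpickle as cp


def cast_helper(x: float) -> float:
    # Hypothetical stand-in for snowpark_dataframe_utils.cast_snowpark_dataframe.
    return float(x)


# Mirror the pattern above: mark the helper's module for pickle-by-value so its
# code ships inside the pickle instead of being re-imported at load time.
cp.register_pickle_by_value(inspect.getmodule(cast_helper))

buf = io.BytesIO()
cp.dump(cast_helper, buf)
buf.seek(0)
restored = cp.load(buf)
assert restored(1) == 1.0
```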
@@ -151,7 +154,63 @@ def construct_cv_results(
     return multimetric, estimator._format_results(param_grid, n_split, out)
 
 
+def construct_cv_results_new_implementation(
+    estimator: Union[GridSearchCV, RandomizedSearchCV],
+    n_split: int,
+    param_grid: List[Dict[str, Any]],
+    cv_results_raw_hex: List[Row],
+    cross_validator_indices_length: int,
+    parameter_grid_length: int,
+) -> Tuple[Any, Dict[str, Any]]:
+    """Construct the cross validation results from the UDTF output.
+
+    The output is a list of raw dictionaries generated by _fit_and_score, encoded as hex binary.
+    This function decodes the strings and then calls _format_results to stitch them back together,
+    matching the original sklearn result.
+
+    Args:
+        estimator (Union[GridSearchCV, RandomizedSearchCV]): The sklearn estimator object,
+            either GridSearchCV or RandomizedSearchCV.
+        n_split (int): The number of splits, as determined by build_cross_validator.get_n_splits(X, y, groups).
+        param_grid (List[Dict[str, Any]]): The list of candidates from the parameter grid or parameter sampler.
+        cv_results_raw_hex (List[Row]): The list of cv_results from each CV and parameter grid combination.
+            Because a UDxF can only return strings, and numpy arrays/masked arrays cannot be encoded as
+            JSON, each cv_result is encoded into a hex string.
+        cross_validator_indices_length (int): The length of the cross validator indices.
+        parameter_grid_length (int): The length of the parameter grid combination.
+
+    Raises:
+        ValueError: Retrieved empty cross validation results.
+        ValueError: Cross validator index length is 0.
+        ValueError: Parameter index length is 0.
+
+    Returns:
+        Tuple[Any, Dict[str, Any]]: first_test_score, cv_results_
+    """
+    # Filter corner cases: either the snowpark dataframe result is empty, or an index length is 0.
+    if len(cv_results_raw_hex) == 0:
+        raise ValueError(
+            "Retrieved empty cross validation results from snowpark. Please retry or contact snowflake support."
+        )
+    if cross_validator_indices_length == 0:
+        raise ValueError("Cross validator index length is 0. Was the CV iterator empty?")
+    if parameter_grid_length == 0:
+        raise ValueError("Parameter index length is 0. Were there no candidates?")
+
+    all_out = []
+
+    for each_cv_result_hex in cv_results_raw_hex:
+        # Convert the hex string back to cv_results_.
+        hex_str = bytes.fromhex(each_cv_result_hex[0])
+        with io.BytesIO(hex_str) as f_reload:
+            out = cp.load(f_reload)
+        all_out.extend(out)
+    first_test_score = all_out[0]["test_scores"]
+    return first_test_score, estimator._format_results(param_grid, n_split, all_out)
+
+
 cp.register_pickle_by_value(inspect.getmodule(construct_cv_results))
+cp.register_pickle_by_value(inspect.getmodule(construct_cv_results_new_implementation))
 
 
 class DistributedHPOTrainer(SnowparkModelTrainer):
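The decode loop above is the receiving half of a hex round-trip: each UDTF partition cloudpickles its list of `_fit_and_score` result dicts and returns the bytes as a hex string, because UDTFs can only return supported SQL types and the numpy/masked arrays inside the results are not JSON-serializable. A small self-contained sketch of both halves, with made-up result data:

```python
import io

import cloudpickle as cp

# Made-up stand-in for one partition's list of _fit_and_score result dicts.
fake_out = [{"test_scores": 0.92, "fit_time": 0.1, "score_time": 0.01}]

# UDTF side: pickle to bytes, then encode as a hex string for the CV_RESULTS column.
with io.BytesIO() as f:
    cp.dump(fake_out, f)
    f.seek(0)
    hex_payload = f.getvalue().hex()

# Driver side (construct_cv_results_new_implementation): decode and unpickle.
with io.BytesIO(bytes.fromhex(hex_payload)) as f_reload:
    restored = cp.load(f_reload)

assert restored[0]["test_scores"] == 0.92
```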
602
661
 
603
662
  return fit_estimator
604
663
 
664
+ def fit_search_snowpark_new_implementation(
665
+ self,
666
+ param_grid: Union[model_selection.ParameterGrid, model_selection.ParameterSampler],
667
+ dataset: DataFrame,
668
+ session: Session,
669
+ estimator: Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV],
670
+ dependencies: List[str],
671
+ udf_imports: List[str],
672
+ input_cols: List[str],
673
+ label_cols: Optional[List[str]],
674
+ sample_weight_col: Optional[str],
675
+ ) -> Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV]:
676
+ from itertools import product
677
+
678
+ import cachetools
679
+ from sklearn.base import clone, is_classifier
680
+ from sklearn.calibration import check_cv
681
+
682
+ # Create one stage for data and for estimators.
683
+ temp_stage_name = random_name_for_temp_object(TempObjectType.STAGE)
684
+ temp_stage_creation_query = f"CREATE OR REPLACE TEMP STAGE {temp_stage_name};"
685
+ session.sql(temp_stage_creation_query).collect()
686
+
687
+ # Stage data as parquet file
688
+ dataset = snowpark_dataframe_utils.cast_snowpark_dataframe(dataset)
689
+ dataset_file_name = "dataset"
690
+ remote_file_path = f"{temp_stage_name}/{dataset_file_name}.parquet"
691
+ dataset.write.copy_into_location( # type:ignore[call-overload]
692
+ remote_file_path, file_format_type="parquet", header=True, overwrite=True
693
+ )
694
+ imports = [f"@{row.name}" for row in session.sql(f"LIST @{temp_stage_name}/{dataset_file_name}").collect()]
695
+
696
+ # Create a temp file and dump the estimator to that file.
697
+ estimator_file_name = get_temp_file_path()
698
+ params_to_evaluate = list(param_grid)
699
+ n_candidates = len(params_to_evaluate)
700
+ _N_JOBS = estimator.n_jobs
701
+ _PRE_DISPATCH = estimator.pre_dispatch
702
+
703
+ with open(estimator_file_name, mode="w+b") as local_estimator_file_obj:
704
+ cp.dump(dict(estimator=estimator, param_grid=params_to_evaluate), local_estimator_file_obj)
705
+ stage_estimator_file_name = posixpath.join(temp_stage_name, os.path.basename(estimator_file_name))
706
+ sproc_statement_params = telemetry.get_function_usage_statement_params(
707
+ project=_PROJECT,
708
+ subproject=self._subproject,
709
+ function_name=telemetry.get_statement_params_full_func_name(
710
+ inspect.currentframe(), self.__class__.__name__
711
+ ),
712
+ api_calls=[sproc],
713
+ )
714
+ udtf_statement_params = telemetry.get_function_usage_statement_params(
715
+ project=_PROJECT,
716
+ subproject=self._subproject,
717
+ function_name=telemetry.get_statement_params_full_func_name(
718
+ inspect.currentframe(), self.__class__.__name__
719
+ ),
720
+ api_calls=[udtf],
721
+ custom_tags=dict([("hpo_udtf", True)]),
722
+ )
723
+
724
+ # Put locally serialized estimator on stage.
725
+ session.file.put(
726
+ estimator_file_name,
727
+ temp_stage_name,
728
+ auto_compress=False,
729
+ overwrite=True,
730
+ )
731
+ estimator_location = os.path.basename(estimator_file_name)
732
+ imports.append(f"@{temp_stage_name}/{estimator_location}")
733
+
734
+ search_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
735
+ random_udtf_name = random_name_for_temp_object(TempObjectType.FUNCTION)
736
+
737
+ required_deps = dependencies + [
738
+ "snowflake-snowpark-python<2",
739
+ "fastparquet<2023.11",
740
+ "pyarrow<14",
741
+ "cachetools<6",
742
+ ]
743
+
744
+ @sproc( # type: ignore[misc]
745
+ is_permanent=False,
746
+ name=search_sproc_name,
747
+ packages=required_deps, # type: ignore[arg-type]
748
+ replace=True,
749
+ session=session,
750
+ anonymous=True,
751
+ imports=imports, # type: ignore[arg-type]
752
+ statement_params=sproc_statement_params,
753
+ )
754
+ def _distributed_search(
755
+ session: Session,
756
+ imports: List[str],
757
+ stage_estimator_file_name: str,
758
+ input_cols: List[str],
759
+ label_cols: Optional[List[str]],
760
+ ) -> str:
761
+ import os
762
+ import time
763
+ from typing import Iterator
764
+
765
+ import cloudpickle as cp
766
+ import pandas as pd
767
+ import pyarrow.parquet as pq
768
+ from sklearn.metrics import check_scoring
769
+ from sklearn.metrics._scorer import _check_multimetric_scoring
770
+ from sklearn.utils.validation import _check_fit_params, indexable
771
+
772
+ # import packages in sproc
773
+ for import_name in udf_imports:
774
+ importlib.import_module(import_name)
775
+
776
+ # os.cpu_count() returns the number of logical CPUs in the system. Returns None if undetermined.
777
+ _NUM_CPUs = os.cpu_count() or 1
778
+
779
+ # load dataset
780
+ data_files = [
781
+ filename
782
+ for filename in os.listdir(sys._xoptions["snowflake_import_directory"])
783
+ if filename.startswith(dataset_file_name)
784
+ ]
785
+ partial_df = [
786
+ pq.read_table(os.path.join(sys._xoptions["snowflake_import_directory"], file_name)).to_pandas()
787
+ for file_name in data_files
788
+ ]
789
+ df = pd.concat(partial_df, ignore_index=True)
790
+ df.columns = [identifier.get_inferred_name(col_) for col_ in df.columns]
791
+
792
+ X = df[input_cols]
793
+ y = df[label_cols].squeeze() if label_cols else None
794
+ DATA_LENGTH = len(df)
795
+ fit_params = {}
796
+ if sample_weight_col:
797
+ fit_params["sample_weight"] = df[sample_weight_col].squeeze()
798
+
799
+ local_estimator_file_folder_name = get_temp_file_path()
800
+ session.file.get(stage_estimator_file_name, local_estimator_file_folder_name)
801
+
802
+ local_estimator_file_path = os.path.join(
803
+ local_estimator_file_folder_name, os.listdir(local_estimator_file_folder_name)[0]
804
+ )
805
+ with open(local_estimator_file_path, mode="r+b") as local_estimator_file_obj:
806
+ estimator = cp.load(local_estimator_file_obj)["estimator"]
807
+
808
+ # preprocess the attributes - (1) scorer
809
+ refit_metric = "score"
810
+ if callable(estimator.scoring):
811
+ scorers = estimator.scoring
812
+ elif estimator.scoring is None or isinstance(estimator.scoring, str):
813
+ scorers = check_scoring(estimator.estimator, estimator.scoring)
814
+ else:
815
+ scorers = _check_multimetric_scoring(estimator.estimator, estimator.scoring)
816
+ estimator._check_refit_for_multimetric(scorers)
817
+ refit_metric = estimator.refit
818
+
819
+ # preprocess the attributes - (2) check fit_params
820
+ groups = None
821
+ X, y, _ = indexable(X, y, groups)
822
+ fit_params = _check_fit_params(X, fit_params)
823
+
824
+ # preprocess the attributes - (3) safe clone base estimator
825
+ base_estimator = clone(estimator.estimator)
826
+
827
+ # preprocess the attributes - (4) check cv
828
+ build_cross_validator = check_cv(estimator.cv, y, classifier=is_classifier(estimator.estimator))
829
+ n_splits = build_cross_validator.get_n_splits(X, y, groups)
830
+
831
+ # preprocess the attributes - (5) generate fit_and_score_kwargs
832
+ fit_and_score_kwargs = dict(
833
+ scorer=scorers,
834
+ fit_params=fit_params,
835
+ return_train_score=estimator.return_train_score,
836
+ return_n_test_samples=True,
837
+ return_times=True,
838
+ return_parameters=False,
839
+ error_score=estimator.error_score,
840
+ verbose=estimator.verbose,
841
+ )
842
+
843
+ # (1) store the cross_validator's test indices only to save space
844
+ cross_validator_indices = [test for _, test in build_cross_validator.split(X, y, None)]
845
+ local_indices_file_name = get_temp_file_path()
846
+ with open(local_indices_file_name, mode="w+b") as local_indices_file_obj:
847
+ cp.dump(cross_validator_indices, local_indices_file_obj)
848
+
849
+ # Put locally serialized indices on stage.
850
+ session.file.put(
851
+ local_indices_file_name,
852
+ temp_stage_name,
853
+ auto_compress=False,
854
+ overwrite=True,
855
+ )
856
+ indices_location = os.path.basename(local_indices_file_name)
857
+ imports.append(f"@{temp_stage_name}/{indices_location}")
858
+
859
+ # (2) store the base estimator
860
+ local_base_estimator_file_name = get_temp_file_path()
861
+ with open(local_base_estimator_file_name, mode="w+b") as local_base_estimator_file_obj:
862
+ cp.dump(base_estimator, local_base_estimator_file_obj)
863
+ session.file.put(
864
+ local_base_estimator_file_name,
865
+ temp_stage_name,
866
+ auto_compress=False,
867
+ overwrite=True,
868
+ )
869
+ base_estimator_location = os.path.basename(local_base_estimator_file_name)
870
+ imports.append(f"@{temp_stage_name}/{base_estimator_location}")
871
+
872
+ # (3) store the fit_and_score_kwargs
873
+ local_fit_and_score_kwargs_file_name = get_temp_file_path()
874
+ with open(local_fit_and_score_kwargs_file_name, mode="w+b") as local_fit_and_score_kwargs_file_obj:
875
+ cp.dump(fit_and_score_kwargs, local_fit_and_score_kwargs_file_obj)
876
+ session.file.put(
877
+ local_fit_and_score_kwargs_file_name,
878
+ temp_stage_name,
879
+ auto_compress=False,
880
+ overwrite=True,
881
+ )
882
+ fit_and_score_kwargs_location = os.path.basename(local_fit_and_score_kwargs_file_name)
883
+ imports.append(f"@{temp_stage_name}/{fit_and_score_kwargs_location}")
884
+
885
+ cross_validator_indices_length = int(len(cross_validator_indices))
886
+ parameter_grid_length = len(param_grid)
887
+
888
+ assert estimator is not None
889
+
890
+ @cachetools.cached(cache={})
891
+ def _load_data_into_udf() -> Tuple[
892
+ npt.NDArray[Any],
893
+ npt.NDArray[Any],
894
+ List[List[int]],
895
+ List[Dict[str, Any]],
896
+ object,
897
+ Dict[str, Any],
898
+ ]:
899
+ import pyarrow.parquet as pq
900
+
901
+ data_files = [
902
+ filename
903
+ for filename in os.listdir(sys._xoptions["snowflake_import_directory"])
904
+ if filename.startswith(dataset_file_name)
905
+ ]
906
+ partial_df = [
907
+ pq.read_table(os.path.join(sys._xoptions["snowflake_import_directory"], file_name)).to_pandas()
908
+ for file_name in data_files
909
+ ]
910
+ df = pd.concat(partial_df, ignore_index=True)
911
+ df.columns = [identifier.get_inferred_name(col_) for col_ in df.columns]
912
+
913
+ # load parameter grid
914
+ local_estimator_file_path = os.path.join(
915
+ sys._xoptions["snowflake_import_directory"], f"{estimator_location}"
916
+ )
917
+ with open(local_estimator_file_path, mode="rb") as local_estimator_file_obj:
918
+ estimator_objects = cp.load(local_estimator_file_obj)
919
+ params_to_evaluate = estimator_objects["param_grid"]
920
+
921
+ # load indices
922
+ local_indices_file_path = os.path.join(
923
+ sys._xoptions["snowflake_import_directory"], f"{indices_location}"
924
+ )
925
+ with open(local_indices_file_path, mode="rb") as local_indices_file_obj:
926
+ indices = cp.load(local_indices_file_obj)
927
+
928
+ # load base estimator
929
+ local_base_estimator_file_path = os.path.join(
930
+ sys._xoptions["snowflake_import_directory"], f"{base_estimator_location}"
931
+ )
932
+ with open(local_base_estimator_file_path, mode="rb") as local_base_estimator_file_obj:
933
+ base_estimator = cp.load(local_base_estimator_file_obj)
934
+
935
+ # load fit_and_score_kwargs
936
+ local_fit_and_score_kwargs_file_path = os.path.join(
937
+ sys._xoptions["snowflake_import_directory"], f"{fit_and_score_kwargs_location}"
938
+ )
939
+ with open(local_fit_and_score_kwargs_file_path, mode="rb") as local_fit_and_score_kwargs_file_obj:
940
+ fit_and_score_kwargs = cp.load(local_fit_and_score_kwargs_file_obj)
941
+
942
+ # convert dataframe to numpy would save memory consumption
943
+ return (
944
+ df[input_cols].to_numpy(),
945
+ df[label_cols].squeeze().to_numpy(),
946
+ indices,
947
+ params_to_evaluate,
948
+ base_estimator,
949
+ fit_and_score_kwargs,
950
+ )
951
+
952
+ # Note Table functions (UDTFs) have a limit of 500 input arguments and 500 output columns.
953
+ class SearchCV:
954
+ def __init__(self) -> None:
955
+ X, y, indices, params_to_evaluate, base_estimator, fit_and_score_kwargs = _load_data_into_udf()
956
+ self.X = X
957
+ self.y = y
958
+ self.test_indices = indices
959
+ self.params_to_evaluate = params_to_evaluate
960
+ self.base_estimator = base_estimator
961
+ self.fit_and_score_kwargs = fit_and_score_kwargs
962
+ self.fit_score_params: List[Any] = []
963
+ self.cached_train_test_indices = []
964
+ # Calculate the full index here to avoid duplicate calculation (which consumes a lot of memory)
965
+ full_index = np.arange(DATA_LENGTH)
966
+ for i in range(n_splits):
967
+ self.cached_train_test_indices.extend(
968
+ [[np.setdiff1d(full_index, self.test_indices[i]), self.test_indices[i]]]
969
+ )
970
+
971
+ def process(self, idx: int, params_idx: int, cv_idx: int) -> None:
972
+ self.fit_score_params.extend([[idx, params_idx, cv_idx]])
973
+
974
+ def end_partition(self) -> Iterator[Tuple[int, str]]:
975
+ from sklearn.base import clone
976
+ from sklearn.model_selection._validation import _fit_and_score
977
+ from sklearn.utils.parallel import Parallel, delayed
978
+
979
+ parallel = Parallel(n_jobs=_N_JOBS, pre_dispatch=_PRE_DISPATCH)
980
+
981
+ out = parallel(
982
+ delayed(_fit_and_score)(
983
+ clone(self.base_estimator),
984
+ self.X,
985
+ self.y,
986
+ train=self.cached_train_test_indices[split_idx][0],
987
+ test=self.cached_train_test_indices[split_idx][1],
988
+ parameters=self.params_to_evaluate[cand_idx],
989
+ split_progress=(split_idx, n_splits),
990
+ candidate_progress=(cand_idx, n_candidates),
991
+ **self.fit_and_score_kwargs, # load sample weight here
992
+ )
993
+ for _, cand_idx, split_idx in self.fit_score_params
994
+ )
995
+
996
+ binary_cv_results = None
997
+ with io.BytesIO() as f:
998
+ cp.dump(out, f)
999
+ f.seek(0)
1000
+ binary_cv_results = f.getvalue().hex()
1001
+ yield (
1002
+ self.fit_score_params[0][0],
1003
+ binary_cv_results,
1004
+ )
1005
+
1006
+ session.udtf.register(
1007
+ SearchCV,
1008
+ output_schema=StructType([StructField("IDX", IntegerType()), StructField("CV_RESULTS", StringType())]),
1009
+ input_types=[IntegerType(), IntegerType(), IntegerType()],
1010
+ name=random_udtf_name,
1011
+ packages=required_deps, # type: ignore[arg-type]
1012
+ replace=True,
1013
+ is_permanent=False,
1014
+ imports=imports, # type: ignore[arg-type]
1015
+ statement_params=udtf_statement_params,
1016
+ )
1017
+
1018
+ HP_TUNING = F.table_function(random_udtf_name)
1019
+
1020
+ # param_indices is for the index for each parameter grid;
1021
+ # cv_indices is for the index for each cross_validator's fold;
1022
+ # param_cv_indices is for the index for the product of (len(param_indices) * len(cv_indices))
1023
+ param_indices, cv_indices = zip(
1024
+ *product(range(parameter_grid_length), range(cross_validator_indices_length))
1025
+ )
1026
+
1027
+ indices_info_pandas = pd.DataFrame(
1028
+ {
1029
+ "IDX": [i // _NUM_CPUs for i in range(parameter_grid_length * cross_validator_indices_length)],
1030
+ "PARAM_IND": param_indices,
1031
+ "CV_IND": cv_indices,
1032
+ }
1033
+ )
1034
+
1035
+ indices_info_sp = session.create_dataframe(indices_info_pandas)
1036
+ # execute udtf by querying HP_TUNING table
1037
+ HP_raw_results = indices_info_sp.select(
1038
+ (
1039
+ HP_TUNING(indices_info_sp["IDX"], indices_info_sp["PARAM_IND"], indices_info_sp["CV_IND"]).over(
1040
+ partition_by="IDX"
1041
+ )
1042
+ ),
1043
+ )
1044
+
1045
+ first_test_score, cv_results_ = construct_cv_results_new_implementation(
1046
+ estimator,
1047
+ n_splits,
1048
+ list(param_grid),
1049
+ HP_raw_results.select("CV_RESULTS").sort(F.col("IDX")).collect(),
1050
+ cross_validator_indices_length,
1051
+ parameter_grid_length,
1052
+ )
1053
+
1054
+ estimator.cv_results_ = cv_results_
1055
+ estimator.multimetric_ = isinstance(first_test_score, dict)
1056
+
1057
+ # check refit_metric now for a callable scorer that is multimetric
1058
+ if callable(estimator.scoring) and estimator.multimetric_:
1059
+ estimator._check_refit_for_multimetric(first_test_score)
1060
+ refit_metric = estimator.refit
1061
+
1062
+ # For multi-metric evaluation, store the best_index_, best_params_ and
1063
+ # best_score_ iff refit is one of the scorer names
1064
+ # In single metric evaluation, refit_metric is "score"
1065
+ if estimator.refit or not estimator.multimetric_:
1066
+ estimator.best_index_ = estimator._select_best_index(estimator.refit, refit_metric, cv_results_)
1067
+ if not callable(estimator.refit):
1068
+ # With a non-custom callable, we can select the best score
1069
+ # based on the best index
1070
+ estimator.best_score_ = cv_results_[f"mean_test_{refit_metric}"][estimator.best_index_]
1071
+ estimator.best_params_ = cv_results_["params"][estimator.best_index_]
1072
+
1073
+ if estimator.refit:
1074
+ estimator.best_estimator_ = clone(base_estimator).set_params(
1075
+ **clone(estimator.best_params_, safe=False)
1076
+ )
1077
+
1078
+ # Let the sproc use all cores to refit.
1079
+ estimator.n_jobs = estimator.n_jobs or -1
1080
+
1081
+ # process the input as args
1082
+ argspec = inspect.getfullargspec(estimator.fit)
1083
+ args = {"X": X}
1084
+ if label_cols:
1085
+ label_arg_name = "Y" if "Y" in argspec.args else "y"
1086
+ args[label_arg_name] = y
1087
+ if sample_weight_col is not None and "sample_weight" in argspec.args:
1088
+ args["sample_weight"] = df[sample_weight_col].squeeze()
1089
+ # estimator.refit = original_refit
1090
+ refit_start_time = time.time()
1091
+ estimator.best_estimator_.fit(**args)
1092
+ refit_end_time = time.time()
1093
+ estimator.refit_time_ = refit_end_time - refit_start_time
1094
+
1095
+ if hasattr(estimator.best_estimator_, "feature_names_in_"):
1096
+ estimator.feature_names_in_ = estimator.best_estimator_.feature_names_in_
1097
+
1098
+ # Store the only scorer not as a dict for single metric evaluation
1099
+ estimator.scorer_ = scorers
1100
+ estimator.n_splits_ = n_splits
1101
+
1102
+ local_result_file_name = get_temp_file_path()
1103
+
1104
+ with open(local_result_file_name, mode="w+b") as local_result_file_obj:
1105
+ cp.dump(estimator, local_result_file_obj)
1106
+
1107
+ session.file.put(
1108
+ local_result_file_name,
1109
+ temp_stage_name,
1110
+ auto_compress=False,
1111
+ overwrite=True,
1112
+ )
1113
+
1114
+ return str(os.path.basename(local_result_file_name))
1115
+
1116
+ sproc_export_file_name = _distributed_search(
1117
+ session,
1118
+ imports,
1119
+ stage_estimator_file_name,
1120
+ input_cols,
1121
+ label_cols,
1122
+ )
1123
+
1124
+ local_estimator_path = get_temp_file_path()
1125
+ session.file.get(
1126
+ posixpath.join(temp_stage_name, sproc_export_file_name),
1127
+ local_estimator_path,
1128
+ )
1129
+
1130
+ with open(os.path.join(local_estimator_path, sproc_export_file_name), mode="r+b") as result_file_obj:
1131
+ fit_estimator = cp.load(result_file_obj)
1132
+
1133
+ cleanup_temp_files(local_estimator_path)
1134
+
1135
+ return fit_estimator
1136
+
605
1137
  def train(self) -> object:
606
1138
  """
607
1139
  Runs hyper parameter optimization by distributing the tasks across warehouse.
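Two space-saving ideas in the body above are easy to miss: only each fold's test indices are serialized (train indices are re-derived as the complement of the full index), and each (parameter, fold) combination becomes one UDTF input row, with `IDX` grouping roughly `_NUM_CPUs` combinations per partition so one partition's `Parallel` call can keep every core busy. An illustrative sketch with assumed sizes (none of these names are the library's API):

```python
from itertools import product

import numpy as np

data_length = 10                                          # stand-in for DATA_LENGTH
test_folds = [np.array([0, 1, 2]), np.array([3, 4, 5])]   # only test indices are shipped
full_index = np.arange(data_length)
# Re-derive each fold's train indices as the complement of its test indices.
train_test = [[np.setdiff1d(full_index, test), test] for test in test_folds]

n_candidates, n_folds, num_cpus = 4, len(test_folds), 8   # assumed sizes
param_ind, cv_ind = zip(*product(range(n_candidates), range(n_folds)))
idx = [i // num_cpus for i in range(n_candidates * n_folds)]
# Rows sharing an IDX value are processed by the same UDTF partition.
print(list(zip(idx, param_ind, cv_ind)))
```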
@@ -630,6 +1162,19 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
             pkg_versions=model_spec.pkgDependencies, session=self.session
         )
+        if ENABLE_EFFICIENT_MEMORY_USAGE:
+            return self.fit_search_snowpark_new_implementation(
+                param_grid=param_grid,
+                dataset=self.dataset,
+                session=self.session,
+                estimator=self.estimator,
+                dependencies=relaxed_dependencies,
+                udf_imports=["sklearn"],
+                input_cols=self.input_cols,
+                label_cols=self.label_cols,
+                sample_weight_col=self.sample_weight_col,
+            )
+
         return self.fit_search_snowpark(
             param_grid=param_grid,
             dataset=self.dataset,
@@ -131,9 +131,12 @@ class SnowparkTransformHandlers:
 
         input_df.columns = snowpark_cols
 
+        if hasattr(estimator, "n_jobs"):
+            # Vectorized UDFs cannot handle joblib multiprocessing right now, so deactivate n_jobs.
+            estimator.n_jobs = 1
         inference_res = getattr(estimator, inference_method)(input_df, *args, **kwargs)
 
-        transformed_numpy_array, output_cols = handle_inference_result(
+        transformed_numpy_array, _ = handle_inference_result(
             inference_res=inference_res,
             output_cols=expected_output_cols,
             inference_method=inference_method,
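A quick illustration of the `hasattr` guard added above, using two stock sklearn estimators (the rationale is the diff's own comment: vectorized UDFs currently cannot host joblib multiprocessing, so parallelism is disabled before inference):

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier

estimators = [RandomForestClassifier(n_jobs=4), DecisionTreeClassifier()]
for est in estimators:
    if hasattr(est, "n_jobs"):
        # Force single-process inference; estimators without n_jobs pass through untouched.
        est.n_jobs = 1

print([getattr(est, "n_jobs", None) for est in estimators])  # [1, None]
```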
@@ -141,13 +144,13 @@
         )
 
         if len(transformed_numpy_array.shape) > 1:
-            if transformed_numpy_array.shape[1] != len(output_cols):
+            if transformed_numpy_array.shape[1] != len(expected_output_cols):
                 series = pd.Series(transformed_numpy_array.tolist())
-                transformed_pandas_df = pd.DataFrame(series, columns=output_cols)
+                transformed_pandas_df = pd.DataFrame(series, columns=expected_output_cols)
             else:
-                transformed_pandas_df = pd.DataFrame(transformed_numpy_array.tolist(), columns=output_cols)
+                transformed_pandas_df = pd.DataFrame(transformed_numpy_array.tolist(), columns=expected_output_cols)
         else:
-            transformed_pandas_df = pd.DataFrame(transformed_numpy_array, columns=output_cols)
+            transformed_pandas_df = pd.DataFrame(transformed_numpy_array, columns=expected_output_cols)
 
         return transformed_pandas_df.to_dict("records")  # type: ignore[no-any-return]
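The fix above keys every branch on the caller-supplied `expected_output_cols` instead of the locally inferred `output_cols`. A minimal sketch of the branch it touches, with made-up data: when the 2-D inference result is wider than the expected column list, each row is packed into a single list-valued column:

```python
import numpy as np
import pandas as pd

expected_output_cols = ["OUT"]
transformed_numpy_array = np.array([[0.1, 0.9], [0.8, 0.2]])  # e.g. predict_proba output

if len(transformed_numpy_array.shape) > 1:
    if transformed_numpy_array.shape[1] != len(expected_output_cols):
        # Width mismatch: pack each row into one list-valued cell.
        series = pd.Series(transformed_numpy_array.tolist())
        transformed_pandas_df = pd.DataFrame(series, columns=expected_output_cols)
    else:
        transformed_pandas_df = pd.DataFrame(transformed_numpy_array.tolist(), columns=expected_output_cols)
else:
    transformed_pandas_df = pd.DataFrame(transformed_numpy_array, columns=expected_output_cols)

print(transformed_pandas_df.to_dict("records"))  # [{'OUT': [0.1, 0.9]}, {'OUT': [0.8, 0.2]}]
```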