snowflake-ml-python 1.5.0__py3-none-any.whl → 1.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. snowflake/cortex/_sentiment.py +7 -4
  2. snowflake/ml/_internal/env_utils.py +6 -0
  3. snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
  4. snowflake/ml/_internal/telemetry.py +1 -0
  5. snowflake/ml/_internal/utils/identifier.py +1 -1
  6. snowflake/ml/_internal/utils/sql_identifier.py +14 -1
  7. snowflake/ml/_internal/utils/temp_file_utils.py +5 -2
  8. snowflake/ml/dataset/__init__.py +2 -1
  9. snowflake/ml/dataset/dataset.py +4 -3
  10. snowflake/ml/dataset/dataset_reader.py +5 -8
  11. snowflake/ml/feature_store/__init__.py +6 -0
  12. snowflake/ml/feature_store/access_manager.py +283 -0
  13. snowflake/ml/feature_store/feature_store.py +160 -100
  14. snowflake/ml/feature_store/feature_view.py +30 -19
  15. snowflake/ml/fileset/embedded_stage_fs.py +15 -12
  16. snowflake/ml/fileset/snowfs.py +2 -30
  17. snowflake/ml/fileset/stage_fs.py +25 -7
  18. snowflake/ml/model/_client/model/model_impl.py +46 -39
  19. snowflake/ml/model/_client/model/model_version_impl.py +24 -2
  20. snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
  21. snowflake/ml/model/_client/ops/model_ops.py +174 -16
  22. snowflake/ml/model/_client/sql/_base.py +34 -0
  23. snowflake/ml/model/_client/sql/model.py +32 -39
  24. snowflake/ml/model/_client/sql/model_version.py +111 -42
  25. snowflake/ml/model/_client/sql/stage.py +6 -32
  26. snowflake/ml/model/_client/sql/tag.py +32 -56
  27. snowflake/ml/model/_model_composer/model_composer.py +8 -4
  28. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
  29. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -3
  30. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -27
  31. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +90 -142
  32. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +159 -0
  33. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +81 -3
  34. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +8 -1
  35. snowflake/ml/modeling/cluster/affinity_propagation.py +8 -1
  36. snowflake/ml/modeling/cluster/agglomerative_clustering.py +8 -1
  37. snowflake/ml/modeling/cluster/birch.py +8 -1
  38. snowflake/ml/modeling/cluster/bisecting_k_means.py +8 -1
  39. snowflake/ml/modeling/cluster/dbscan.py +8 -1
  40. snowflake/ml/modeling/cluster/feature_agglomeration.py +8 -1
  41. snowflake/ml/modeling/cluster/k_means.py +8 -1
  42. snowflake/ml/modeling/cluster/mean_shift.py +8 -1
  43. snowflake/ml/modeling/cluster/mini_batch_k_means.py +8 -1
  44. snowflake/ml/modeling/cluster/optics.py +8 -1
  45. snowflake/ml/modeling/cluster/spectral_biclustering.py +8 -1
  46. snowflake/ml/modeling/cluster/spectral_clustering.py +8 -1
  47. snowflake/ml/modeling/cluster/spectral_coclustering.py +8 -1
  48. snowflake/ml/modeling/compose/column_transformer.py +8 -1
  49. snowflake/ml/modeling/compose/transformed_target_regressor.py +8 -1
  50. snowflake/ml/modeling/covariance/elliptic_envelope.py +8 -1
  51. snowflake/ml/modeling/covariance/empirical_covariance.py +8 -1
  52. snowflake/ml/modeling/covariance/graphical_lasso.py +8 -1
  53. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +8 -1
  54. snowflake/ml/modeling/covariance/ledoit_wolf.py +8 -1
  55. snowflake/ml/modeling/covariance/min_cov_det.py +8 -1
  56. snowflake/ml/modeling/covariance/oas.py +8 -1
  57. snowflake/ml/modeling/covariance/shrunk_covariance.py +8 -1
  58. snowflake/ml/modeling/decomposition/dictionary_learning.py +8 -1
  59. snowflake/ml/modeling/decomposition/factor_analysis.py +8 -1
  60. snowflake/ml/modeling/decomposition/fast_ica.py +8 -1
  61. snowflake/ml/modeling/decomposition/incremental_pca.py +8 -1
  62. snowflake/ml/modeling/decomposition/kernel_pca.py +8 -1
  63. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +8 -1
  64. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +8 -1
  65. snowflake/ml/modeling/decomposition/pca.py +8 -1
  66. snowflake/ml/modeling/decomposition/sparse_pca.py +8 -1
  67. snowflake/ml/modeling/decomposition/truncated_svd.py +8 -1
  68. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +8 -1
  69. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +8 -1
  70. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +8 -1
  71. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +8 -1
  72. snowflake/ml/modeling/ensemble/bagging_classifier.py +8 -1
  73. snowflake/ml/modeling/ensemble/bagging_regressor.py +8 -1
  74. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +8 -1
  75. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +8 -1
  76. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +8 -1
  77. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +8 -1
  78. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +8 -1
  79. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +8 -1
  80. snowflake/ml/modeling/ensemble/isolation_forest.py +8 -1
  81. snowflake/ml/modeling/ensemble/random_forest_classifier.py +8 -1
  82. snowflake/ml/modeling/ensemble/random_forest_regressor.py +8 -1
  83. snowflake/ml/modeling/ensemble/stacking_regressor.py +8 -1
  84. snowflake/ml/modeling/ensemble/voting_classifier.py +8 -1
  85. snowflake/ml/modeling/ensemble/voting_regressor.py +8 -1
  86. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +8 -1
  87. snowflake/ml/modeling/feature_selection/select_fdr.py +8 -1
  88. snowflake/ml/modeling/feature_selection/select_fpr.py +8 -1
  89. snowflake/ml/modeling/feature_selection/select_fwe.py +8 -1
  90. snowflake/ml/modeling/feature_selection/select_k_best.py +8 -1
  91. snowflake/ml/modeling/feature_selection/select_percentile.py +8 -1
  92. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +8 -1
  93. snowflake/ml/modeling/feature_selection/variance_threshold.py +8 -1
  94. snowflake/ml/modeling/framework/base.py +4 -3
  95. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +8 -1
  96. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +8 -1
  97. snowflake/ml/modeling/impute/iterative_imputer.py +8 -1
  98. snowflake/ml/modeling/impute/knn_imputer.py +8 -1
  99. snowflake/ml/modeling/impute/missing_indicator.py +8 -1
  100. snowflake/ml/modeling/impute/simple_imputer.py +21 -2
  101. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +8 -1
  102. snowflake/ml/modeling/kernel_approximation/nystroem.py +8 -1
  103. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +8 -1
  104. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +8 -1
  105. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +8 -1
  106. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +8 -1
  107. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +8 -1
  108. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +8 -1
  109. snowflake/ml/modeling/linear_model/ard_regression.py +8 -1
  110. snowflake/ml/modeling/linear_model/bayesian_ridge.py +8 -1
  111. snowflake/ml/modeling/linear_model/elastic_net.py +8 -1
  112. snowflake/ml/modeling/linear_model/elastic_net_cv.py +8 -1
  113. snowflake/ml/modeling/linear_model/gamma_regressor.py +8 -1
  114. snowflake/ml/modeling/linear_model/huber_regressor.py +8 -1
  115. snowflake/ml/modeling/linear_model/lars.py +8 -1
  116. snowflake/ml/modeling/linear_model/lars_cv.py +8 -1
  117. snowflake/ml/modeling/linear_model/lasso.py +8 -1
  118. snowflake/ml/modeling/linear_model/lasso_cv.py +8 -1
  119. snowflake/ml/modeling/linear_model/lasso_lars.py +8 -1
  120. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +8 -1
  121. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +8 -1
  122. snowflake/ml/modeling/linear_model/linear_regression.py +8 -1
  123. snowflake/ml/modeling/linear_model/logistic_regression.py +8 -1
  124. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +8 -1
  125. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +8 -1
  126. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +8 -1
  127. snowflake/ml/modeling/linear_model/multi_task_lasso.py +8 -1
  128. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +8 -1
  129. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +8 -1
  130. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +8 -1
  131. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +8 -1
  132. snowflake/ml/modeling/linear_model/perceptron.py +8 -1
  133. snowflake/ml/modeling/linear_model/poisson_regressor.py +8 -1
  134. snowflake/ml/modeling/linear_model/ransac_regressor.py +8 -1
  135. snowflake/ml/modeling/linear_model/ridge.py +8 -1
  136. snowflake/ml/modeling/linear_model/ridge_classifier.py +8 -1
  137. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +8 -1
  138. snowflake/ml/modeling/linear_model/ridge_cv.py +8 -1
  139. snowflake/ml/modeling/linear_model/sgd_classifier.py +8 -1
  140. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +8 -1
  141. snowflake/ml/modeling/linear_model/sgd_regressor.py +8 -1
  142. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +8 -1
  143. snowflake/ml/modeling/linear_model/tweedie_regressor.py +8 -1
  144. snowflake/ml/modeling/manifold/isomap.py +8 -1
  145. snowflake/ml/modeling/manifold/mds.py +8 -1
  146. snowflake/ml/modeling/manifold/spectral_embedding.py +8 -1
  147. snowflake/ml/modeling/manifold/tsne.py +8 -1
  148. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +8 -1
  149. snowflake/ml/modeling/mixture/gaussian_mixture.py +8 -1
  150. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +8 -1
  151. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +8 -1
  152. snowflake/ml/modeling/multiclass/output_code_classifier.py +8 -1
  153. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +8 -1
  154. snowflake/ml/modeling/naive_bayes/categorical_nb.py +8 -1
  155. snowflake/ml/modeling/naive_bayes/complement_nb.py +8 -1
  156. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +8 -1
  157. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +8 -1
  158. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +8 -1
  159. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +8 -1
  160. snowflake/ml/modeling/neighbors/kernel_density.py +8 -1
  161. snowflake/ml/modeling/neighbors/local_outlier_factor.py +8 -1
  162. snowflake/ml/modeling/neighbors/nearest_centroid.py +8 -1
  163. snowflake/ml/modeling/neighbors/nearest_neighbors.py +8 -1
  164. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +8 -1
  165. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +8 -1
  166. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +8 -1
  167. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +8 -1
  168. snowflake/ml/modeling/neural_network/mlp_classifier.py +8 -1
  169. snowflake/ml/modeling/neural_network/mlp_regressor.py +8 -1
  170. snowflake/ml/modeling/parameters/enable_anonymous_sproc.py +5 -0
  171. snowflake/ml/modeling/pipeline/pipeline.py +27 -7
  172. snowflake/ml/modeling/preprocessing/polynomial_features.py +8 -1
  173. snowflake/ml/modeling/semi_supervised/label_propagation.py +8 -1
  174. snowflake/ml/modeling/semi_supervised/label_spreading.py +8 -1
  175. snowflake/ml/modeling/svm/linear_svc.py +8 -1
  176. snowflake/ml/modeling/svm/linear_svr.py +8 -1
  177. snowflake/ml/modeling/svm/nu_svc.py +8 -1
  178. snowflake/ml/modeling/svm/nu_svr.py +8 -1
  179. snowflake/ml/modeling/svm/svc.py +8 -1
  180. snowflake/ml/modeling/svm/svr.py +8 -1
  181. snowflake/ml/modeling/tree/decision_tree_classifier.py +8 -1
  182. snowflake/ml/modeling/tree/decision_tree_regressor.py +8 -1
  183. snowflake/ml/modeling/tree/extra_tree_classifier.py +8 -1
  184. snowflake/ml/modeling/tree/extra_tree_regressor.py +8 -1
  185. snowflake/ml/modeling/xgboost/xgb_classifier.py +8 -1
  186. snowflake/ml/modeling/xgboost/xgb_regressor.py +8 -1
  187. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +8 -1
  188. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +8 -1
  189. snowflake/ml/registry/_manager/model_manager.py +95 -8
  190. snowflake/ml/registry/registry.py +10 -1
  191. snowflake/ml/version.py +1 -1
  192. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/METADATA +66 -10
  193. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/RECORD +196 -192
  194. snowflake/ml/_internal/lineage/dataset_dataframe.py +0 -44
  195. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/LICENSE.txt +0 -0
  196. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/WHEEL +0 -0
  197. {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/top_level.txt +0 -0
@@ -8,7 +8,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import cloudpickle as cp
 import numpy as np
-import numpy.typing as npt
 from sklearn import model_selection
 from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
 
@@ -36,6 +35,7 @@ from snowflake.snowpark._internal.utils import (
 from snowflake.snowpark.functions import sproc, udtf
 from snowflake.snowpark.row import Row
 from snowflake.snowpark.types import IntegerType, StringType, StructField, StructType
+from snowflake.snowpark.udtf import UDTFRegistration
 
 cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
 cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
@@ -154,7 +154,7 @@ def construct_cv_results(
     return multimetric, estimator._format_results(param_grid, n_split, out)
 
 
-def construct_cv_results_new_implementation(
+def construct_cv_results_memory_efficient_version(
    estimator: Union[GridSearchCV, RandomizedSearchCV],
    n_split: int,
    param_grid: List[Dict[str, Any]],
@@ -205,12 +205,35 @@ def construct_cv_results_new_implementation(
         with io.BytesIO(hex_str) as f_reload:
             out = cp.load(f_reload)
             all_out.extend(out)
+
+    # The original SearchCV ranks results by parameter first and cv fold second. To be memory
+    # efficient, this implementation fits on cv folds first and parameters second, so when the
+    # results are retrieved the ordering must be reverted to match the original SearchCV output.
+    def generate_the_order_by_parameter_index(all_combination_length: int) -> List[int]:
+        pattern = []
+        for i in range(all_combination_length):
+            if i % parameter_grid_length == 0:
+                pattern.append(i)
+        for i in range(1, parameter_grid_length):
+            for j in range(all_combination_length):
+                if j % parameter_grid_length == i:
+                    pattern.append(j)
+        return pattern
+
+    def rerank_array(original_array: List[Any], pattern: List[int]) -> List[Any]:
+        reranked_array = []
+        for index in pattern:
+            reranked_array.append(original_array[index])
+        return reranked_array
+
+    pattern = generate_the_order_by_parameter_index(len(all_out))
+    reranked_all_out = rerank_array(all_out, pattern)
     first_test_score = all_out[0]["test_scores"]
-    return first_test_score, estimator._format_results(param_grid, n_split, all_out)
+    return first_test_score, estimator._format_results(param_grid, n_split, reranked_all_out)
 
 
 cp.register_pickle_by_value(inspect.getmodule(construct_cv_results))
-cp.register_pickle_by_value(inspect.getmodule(construct_cv_results_new_implementation))
+cp.register_pickle_by_value(inspect.getmodule(construct_cv_results_memory_efficient_version))
 
 
 class DistributedHPOTrainer(SnowparkModelTrainer):
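
To make the reordering concrete, here is a small standalone sketch (an editorial illustration, not part of the package) of the two helpers above with parameter_grid_length = 3 and two cv folds. The UDTF output arrives fold-major, and the generated pattern restores the parameter-major ordering that SearchCV reports:

    # Standalone sketch of the reranking above; names mirror the diff, the data is made up.
    from typing import Any, List

    parameter_grid_length = 3  # three hyperparameter candidates

    def generate_the_order_by_parameter_index(all_combination_length: int) -> List[int]:
        pattern = []
        for i in range(all_combination_length):
            if i % parameter_grid_length == 0:
                pattern.append(i)
        for i in range(1, parameter_grid_length):
            for j in range(all_combination_length):
                if j % parameter_grid_length == i:
                    pattern.append(j)
        return pattern

    def rerank_array(original_array: List[Any], pattern: List[int]) -> List[Any]:
        return [original_array[index] for index in pattern]

    # Fold-major input: (fold0, cand0), (fold0, cand1), ... as produced by the cv-first fit.
    all_out = ["f0c0", "f0c1", "f0c2", "f1c0", "f1c1", "f1c2"]
    pattern = generate_the_order_by_parameter_index(len(all_out))  # [0, 3, 1, 4, 2, 5]
    print(rerank_array(all_out, pattern))  # ['f0c0', 'f1c0', 'f0c1', 'f1c1', 'f0c2', 'f1c2']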
@@ -661,7 +684,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
 
         return fit_estimator
 
-    def fit_search_snowpark_new_implementation(
+    def fit_search_snowpark_enable_efficient_memory_usage(
         self,
         param_grid: Union[model_selection.ParameterGrid, model_selection.ParameterSampler],
         dataset: DataFrame,
@@ -675,7 +698,6 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
     ) -> Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV]:
         from itertools import product
 
-        import cachetools
         from sklearn.base import clone, is_classifier
         from sklearn.calibration import check_cv
 
@@ -696,9 +718,11 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         # Create a temp file and dump the estimator to that file.
         estimator_file_name = get_temp_file_path()
         params_to_evaluate = list(param_grid)
-        n_candidates = len(params_to_evaluate)
-        _N_JOBS = estimator.n_jobs
-        _PRE_DISPATCH = estimator.pre_dispatch
+        CONSTANTS: Dict[str, Any] = dict()
+        CONSTANTS["dataset_snowpark_cols"] = dataset.columns
+        CONSTANTS["n_candidates"] = len(params_to_evaluate)
+        CONSTANTS["_N_JOBS"] = estimator.n_jobs
+        CONSTANTS["_PRE_DISPATCH"] = estimator.pre_dispatch
 
         with open(estimator_file_name, mode="w+b") as local_estimator_file_obj:
             cp.dump(dict(estimator=estimator, param_grid=params_to_evaluate), local_estimator_file_obj)
@@ -718,7 +742,10 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
                 inspect.currentframe(), self.__class__.__name__
             ),
             api_calls=[udtf],
-            custom_tags=dict([("hpo_udtf", True)]),
+            custom_tags=dict([("hpo_memory_efficient", True)]),
+        )
+        from snowflake.ml.modeling._internal.snowpark_implementations.distributed_search_udf_file import (
+            execute_template,
         )
 
         # Put locally serialized estimator on stage.
@@ -730,6 +757,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         )
         estimator_location = os.path.basename(estimator_file_name)
         imports.append(f"@{temp_stage_name}/{estimator_location}")
+        CONSTANTS["estimator_location"] = estimator_location
 
         search_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
         random_udtf_name = random_name_for_temp_object(TempObjectType.FUNCTION)
@@ -760,7 +788,6 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         ) -> str:
             import os
             import time
-            from typing import Iterator
 
             import cloudpickle as cp
             import pandas as pd
@@ -882,146 +909,67 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         fit_and_score_kwargs_location = os.path.basename(local_fit_and_score_kwargs_file_name)
         imports.append(f"@{temp_stage_name}/{fit_and_score_kwargs_location}")
 
-        cross_validator_indices_length = int(len(cross_validator_indices))
-        parameter_grid_length = len(param_grid)
-
-        assert estimator is not None
-
-        @cachetools.cached(cache={})
-        def _load_data_into_udf() -> Tuple[
-            npt.NDArray[Any],
-            npt.NDArray[Any],
-            List[List[int]],
-            List[Dict[str, Any]],
-            object,
-            Dict[str, Any],
-        ]:
-            import pyarrow.parquet as pq
+        CONSTANTS["input_cols"] = input_cols
+        CONSTANTS["label_cols"] = label_cols
+        CONSTANTS["DATA_LENGTH"] = DATA_LENGTH
+        CONSTANTS["n_splits"] = n_splits
+        CONSTANTS["indices_location"] = indices_location
+        CONSTANTS["base_estimator_location"] = base_estimator_location
+        CONSTANTS["fit_and_score_kwargs_location"] = fit_and_score_kwargs_location
 
-            data_files = [
-                filename
-                for filename in os.listdir(sys._xoptions["snowflake_import_directory"])
-                if filename.startswith(dataset_file_name)
-            ]
-            partial_df = [
-                pq.read_table(os.path.join(sys._xoptions["snowflake_import_directory"], file_name)).to_pandas()
-                for file_name in data_files
-            ]
-            df = pd.concat(partial_df, ignore_index=True)
-            df.columns = [identifier.get_inferred_name(col_) for col_ in df.columns]
+        # (6) store the constants
+        local_constant_file_name = get_temp_file_path(prefix="constant")
+        with open(local_constant_file_name, mode="w+b") as local_indices_file_obj:
+            cp.dump(CONSTANTS, local_indices_file_obj)
 
-            # load parameter grid
-            local_estimator_file_path = os.path.join(
-                sys._xoptions["snowflake_import_directory"], f"{estimator_location}"
-            )
-            with open(local_estimator_file_path, mode="rb") as local_estimator_file_obj:
-                estimator_objects = cp.load(local_estimator_file_obj)
-                params_to_evaluate = estimator_objects["param_grid"]
+        # Put locally serialized constants on stage.
+        session.file.put(
+            local_constant_file_name,
+            temp_stage_name,
+            auto_compress=False,
+            overwrite=True,
+        )
+        constant_location = os.path.basename(local_constant_file_name)
+        imports.append(f"@{temp_stage_name}/{constant_location}")
 
-            # load indices
-            local_indices_file_path = os.path.join(
-                sys._xoptions["snowflake_import_directory"], f"{indices_location}"
-            )
-            with open(local_indices_file_path, mode="rb") as local_indices_file_obj:
-                indices = cp.load(local_indices_file_obj)
+        cross_validator_indices_length = int(len(cross_validator_indices))
+        parameter_grid_length = len(param_grid)
 
-            # load base estimator
-            local_base_estimator_file_path = os.path.join(
-                sys._xoptions["snowflake_import_directory"], f"{base_estimator_location}"
-            )
-            with open(local_base_estimator_file_path, mode="rb") as local_base_estimator_file_obj:
-                base_estimator = cp.load(local_base_estimator_file_obj)
+        assert estimator is not None
 
-            # load fit_and_score_kwargs
-            local_fit_and_score_kwargs_file_path = os.path.join(
-                sys._xoptions["snowflake_import_directory"], f"{fit_and_score_kwargs_location}"
-            )
-            with open(local_fit_and_score_kwargs_file_path, mode="rb") as local_fit_and_score_kwargs_file_obj:
-                fit_and_score_kwargs = cp.load(local_fit_and_score_kwargs_file_obj)
-
-            # convert dataframe to numpy would save memory consumption
-            return (
-                df[input_cols].to_numpy(),
-                df[label_cols].squeeze().to_numpy(),
-                indices,
-                params_to_evaluate,
-                base_estimator,
-                fit_and_score_kwargs,
+        # Instantiate UDTFRegistration with the session object
+        udtf_registration = UDTFRegistration(session)
+
+        import tempfile
+
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8") as f:
+            udf_code = execute_template
+            f.file.write(udf_code)
+            f.file.flush()
+
+            # Register the UDTF function from the file
+            udtf_registration.register_from_file(
+                file_path=f.name,
+                handler_name="SearchCV",
+                name=random_udtf_name,
+                output_schema=StructType(
+                    [StructField("FIRST_IDX", IntegerType()), StructField("EACH_CV_RESULTS", StringType())]
+                ),
+                input_types=[IntegerType(), IntegerType(), IntegerType()],
+                replace=True,
+                imports=imports,  # type: ignore[arg-type]
+                is_permanent=False,
+                packages=required_deps,  # type: ignore[arg-type]
+                statement_params=udtf_statement_params,
             )
 
-        # Note Table functions (UDTFs) have a limit of 500 input arguments and 500 output columns.
-        class SearchCV:
-            def __init__(self) -> None:
-                X, y, indices, params_to_evaluate, base_estimator, fit_and_score_kwargs = _load_data_into_udf()
-                self.X = X
-                self.y = y
-                self.test_indices = indices
-                self.params_to_evaluate = params_to_evaluate
-                self.base_estimator = base_estimator
-                self.fit_and_score_kwargs = fit_and_score_kwargs
-                self.fit_score_params: List[Any] = []
-                self.cached_train_test_indices = []
-                # Calculate the full index here to avoid duplicate calculation (which consumes a lot of memory)
-                full_index = np.arange(DATA_LENGTH)
-                for i in range(n_splits):
-                    self.cached_train_test_indices.extend(
-                        [[np.setdiff1d(full_index, self.test_indices[i]), self.test_indices[i]]]
-                    )
-
-            def process(self, idx: int, params_idx: int, cv_idx: int) -> None:
-                self.fit_score_params.extend([[idx, params_idx, cv_idx]])
-
-            def end_partition(self) -> Iterator[Tuple[int, str]]:
-                from sklearn.base import clone
-                from sklearn.model_selection._validation import _fit_and_score
-                from sklearn.utils.parallel import Parallel, delayed
-
-                parallel = Parallel(n_jobs=_N_JOBS, pre_dispatch=_PRE_DISPATCH)
-
-                out = parallel(
-                    delayed(_fit_and_score)(
-                        clone(self.base_estimator),
-                        self.X,
-                        self.y,
-                        train=self.cached_train_test_indices[split_idx][0],
-                        test=self.cached_train_test_indices[split_idx][1],
-                        parameters=self.params_to_evaluate[cand_idx],
-                        split_progress=(split_idx, n_splits),
-                        candidate_progress=(cand_idx, n_candidates),
-                        **self.fit_and_score_kwargs,  # load sample weight here
-                    )
-                    for _, cand_idx, split_idx in self.fit_score_params
-                )
-
-                binary_cv_results = None
-                with io.BytesIO() as f:
-                    cp.dump(out, f)
-                    f.seek(0)
-                    binary_cv_results = f.getvalue().hex()
-                yield (
-                    self.fit_score_params[0][0],
-                    binary_cv_results,
-                )
-
-        session.udtf.register(
-            SearchCV,
-            output_schema=StructType([StructField("IDX", IntegerType()), StructField("CV_RESULTS", StringType())]),
-            input_types=[IntegerType(), IntegerType(), IntegerType()],
-            name=random_udtf_name,
-            packages=required_deps,  # type: ignore[arg-type]
-            replace=True,
-            is_permanent=False,
-            imports=imports,  # type: ignore[arg-type]
-            statement_params=udtf_statement_params,
-        )
-
         HP_TUNING = F.table_function(random_udtf_name)
 
         # param_indices is for the index for each parameter grid;
         # cv_indices is for the index for each cross_validator's fold;
         # param_cv_indices is for the index for the product of (len(param_indices) * len(cv_indices))
-        param_indices, cv_indices = zip(
-            *product(range(parameter_grid_length), range(cross_validator_indices_length))
+        cv_indices, param_indices = zip(
+            *product(range(cross_validator_indices_length), range(parameter_grid_length))
         )
 
         indices_info_pandas = pd.DataFrame(
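
For orientation, the shape of the `register_from_file` flow used above, reduced to a minimal hedged sketch. It assumes an existing Snowpark `session`; the handler class and names here are illustrative, not the package's:

    # Minimal UDTF registration from a source file, mirroring the diff's flow.
    import tempfile

    from snowflake.snowpark.types import IntegerType, StringType, StructField, StructType
    from snowflake.snowpark.udtf import UDTFRegistration

    handler_source = '''
    class Echo:
        def process(self, x: int):
            yield (x, str(x))
    '''

    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8") as f:
        f.write(handler_source)
        f.flush()
        # Register the handler class defined in the temp file as a temporary UDTF.
        UDTFRegistration(session).register_from_file(
            file_path=f.name,
            handler_name="Echo",
            name="ECHO_UDTF",
            output_schema=StructType([StructField("IDX", IntegerType()), StructField("VAL", StringType())]),
            input_types=[IntegerType()],
            replace=True,
            is_permanent=False,
        )

Shipping the handler as a real file is what lets its module-level code run once per node, which is the point of the new distributed_search_udf_file.py helper below.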
@@ -1042,11 +990,11 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             ),
         )
 
-        first_test_score, cv_results_ = construct_cv_results_new_implementation(
+        first_test_score, cv_results_ = construct_cv_results_memory_efficient_version(
             estimator,
             n_splits,
             list(param_grid),
-            HP_raw_results.select("CV_RESULTS").sort(F.col("IDX")).collect(),
+            HP_raw_results.select("EACH_CV_RESULTS").sort(F.col("FIRST_IDX")).collect(),
             cross_validator_indices_length,
             parameter_grid_length,
         )
@@ -1163,7 +1111,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             pkg_versions=model_spec.pkgDependencies, session=self.session
         )
         if ENABLE_EFFICIENT_MEMORY_USAGE:
-            return self.fit_search_snowpark_new_implementation(
+            return self.fit_search_snowpark_enable_efficient_memory_usage(
                 param_grid=param_grid,
                 dataset=self.dataset,
                 session=self.session,
@@ -0,0 +1,159 @@
+"""
+Description:
+    This is the helper file for distributed_hpo_trainer.py to create the UDTF via `register_from_file`.
+Performance Benefits:
+    The performance benefits come from two aspects:
+    1. register_from_file avoids loading the data repeatedly, by loading it only once on each node.
+    2. register_from_file allows data to be loaded into a global variable, which an inline Python UDTF cannot do.
+Developer Tips:
+    Because this script is shipped as a string, there is no type hinting, linting, etc. It is highly recommended
+    to develop it as a regular Python script, verify the type hints, and then convert it into a string.
+"""
+
+execute_template = """
+from typing import Tuple, Any, List, Dict, Set, Iterator
+import os
+import sys
+import pandas as pd
+import numpy as np
+import numpy.typing as npt
+import cloudpickle as cp
+import io
+
+
+def _load_data_into_udf() -> Tuple[
+    npt.NDArray[Any],
+    npt.NDArray[Any],
+    List[List[int]],
+    List[Dict[str, Any]],
+    object,
+    Dict[str, Any],
+    Dict[str, Any],
+]:
+    import pyarrow.parquet as pq
+
+    data_files = [
+        filename
+        for filename in os.listdir(sys._xoptions["snowflake_import_directory"])
+        if filename.startswith("dataset")
+    ]
+    partial_df = [
+        pq.read_table(os.path.join(sys._xoptions["snowflake_import_directory"], file_name)).to_pandas()
+        for file_name in data_files
+    ]
+    df = pd.concat(partial_df, ignore_index=True)
+    constant_file_path = None
+    for filename in os.listdir(sys._xoptions["snowflake_import_directory"]):
+        if filename.startswith("constant"):
+            constant_file_path = os.path.join(sys._xoptions["snowflake_import_directory"], f"{filename}")
+    if constant_file_path is None:
+        raise ValueError("UDTF cannot find the constant location, abort!")
+    with open(constant_file_path, mode="rb") as constant_file_obj:
+        CONSTANTS = cp.load(constant_file_obj)
+    df.columns = CONSTANTS['dataset_snowpark_cols']
+
+    # load parameter grid
+    local_estimator_file_path = os.path.join(
+        sys._xoptions["snowflake_import_directory"],
+        f"{CONSTANTS['estimator_location']}"
+    )
+    with open(local_estimator_file_path, mode="rb") as local_estimator_file_obj:
+        estimator_objects = cp.load(local_estimator_file_obj)
+        params_to_evaluate = estimator_objects["param_grid"]
+
+    # load indices
+    local_indices_file_path = os.path.join(
+        sys._xoptions["snowflake_import_directory"],
+        f"{CONSTANTS['indices_location']}"
+    )
+    with open(local_indices_file_path, mode="rb") as local_indices_file_obj:
+        indices = cp.load(local_indices_file_obj)
+
+    # load base estimator
+    local_base_estimator_file_path = os.path.join(
+        sys._xoptions["snowflake_import_directory"], f"{CONSTANTS['base_estimator_location']}"
+    )
+    with open(local_base_estimator_file_path, mode="rb") as local_base_estimator_file_obj:
+        base_estimator = cp.load(local_base_estimator_file_obj)
+
+    # load fit_and_score_kwargs
+    local_fit_and_score_kwargs_file_path = os.path.join(
+        sys._xoptions["snowflake_import_directory"], f"{CONSTANTS['fit_and_score_kwargs_location']}"
+    )
+    with open(local_fit_and_score_kwargs_file_path, mode="rb") as local_fit_and_score_kwargs_file_obj:
+        fit_and_score_kwargs = cp.load(local_fit_and_score_kwargs_file_obj)
+
+    # converting the dataframe to numpy reduces memory consumption
+    return (
+        df[CONSTANTS['input_cols']].to_numpy(),
+        df[CONSTANTS['label_cols']].squeeze().to_numpy(),
+        indices,
+        params_to_evaluate,
+        base_estimator,
+        fit_and_score_kwargs,
+        CONSTANTS
+    )
+
+
+global_load_data = _load_data_into_udf()
+
+
+# Note: table functions (UDTFs) have a limit of 500 input arguments and 500 output columns.
+class SearchCV:
+    def __init__(self) -> None:
+        X, y, indices, params_to_evaluate, base_estimator, fit_and_score_kwargs, CONSTANTS = global_load_data
+        self.X = X
+        self.y = y
+        self.test_indices = indices
+        self.params_to_evaluate = params_to_evaluate
+        self.base_estimator = base_estimator
+        self.fit_and_score_kwargs = fit_and_score_kwargs
+        self.fit_score_params: List[Any] = []
+        self.CONSTANTS = CONSTANTS
+        self.cv_indices_set: Set[int] = set()
+
+    def process(self, idx: int, params_idx: int, cv_idx: int) -> None:
+        self.fit_score_params.extend([[idx, params_idx, cv_idx]])
+        self.cv_indices_set.add(cv_idx)
+
+    def end_partition(self) -> Iterator[Tuple[int, str]]:
+        from sklearn.base import clone
+        from sklearn.model_selection._validation import _fit_and_score
+        from sklearn.utils.parallel import Parallel, delayed
+
+        cached_train_test_indices = {}
+        # Calculate the full index here to avoid duplicate calculation (which consumes a lot of memory).
+        full_index = np.arange(self.CONSTANTS['DATA_LENGTH'])
+        for i in self.cv_indices_set:
+            cached_train_test_indices[i] = [
+                np.setdiff1d(full_index, self.test_indices[i]),
+                self.test_indices[i],
+            ]
+
+        parallel = Parallel(n_jobs=self.CONSTANTS['_N_JOBS'], pre_dispatch=self.CONSTANTS['_PRE_DISPATCH'])
+
+        out = parallel(
+            delayed(_fit_and_score)(
+                clone(self.base_estimator),
+                self.X,
+                self.y,
+                train=cached_train_test_indices[split_idx][0],
+                test=cached_train_test_indices[split_idx][1],
+                parameters=self.params_to_evaluate[cand_idx],
+                split_progress=(split_idx, self.CONSTANTS['n_splits']),
+                candidate_progress=(cand_idx, self.CONSTANTS['n_candidates']),
+                **self.fit_and_score_kwargs,  # load sample weight here
+            )
+            for _, cand_idx, split_idx in self.fit_score_params
+        )
+
+        binary_cv_results = None
+        with io.BytesIO() as f:
+            cp.dump(out, f)
+            f.seek(0)
+            binary_cv_results = f.getvalue().hex()
+        yield (
+            self.fit_score_params[0][0],
+            binary_cv_results,
+        )
+"""
@@ -45,6 +45,7 @@ cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
 cp.register_pickle_by_value(inspect.getmodule(handle_inference_result))
 
 _PROJECT = "ModelDevelopment"
+_ENABLE_ANONYMOUS_SPROC = False
 
 
 class SnowparkModelTrainer:
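
_ENABLE_ANONYMOUS_SPROC defaults to False. Judging from the new snowflake/ml/modeling/parameters/enable_anonymous_sproc.py module (+5 lines in the file list), the opt-in is presumably a side-effect import that flips this flag; a hypothetical sketch:

    # Hypothetical reconstruction of enable_anonymous_sproc.py; the released file may differ.
    from snowflake.ml.modeling._internal.snowpark_implementations import snowpark_trainer

    snowpark_trainer._ENABLE_ANONYMOUS_SPROC = True

so that user code would opt in with a bare `import snowflake.ml.modeling.parameters.enable_anonymous_sproc` before calling fit.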
@@ -251,6 +252,27 @@ class SnowparkModelTrainer:
 
         return fit_wrapper_function
 
+    def _get_fit_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
+        model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
+        fit_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
+
+        relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
+            pkg_versions=model_spec.pkgDependencies, session=self.session
+        )
+
+        fit_wrapper_sproc = self.session.sproc.register(
+            func=self._build_fit_wrapper_sproc(model_spec=model_spec),
+            is_permanent=False,
+            name=fit_sproc_name,
+            packages=["snowflake-snowpark-python"] + relaxed_dependencies,  # type: ignore[arg-type]
+            replace=True,
+            session=self.session,
+            statement_params=statement_params,
+            anonymous=True,
+        )
+
+        return fit_wrapper_sproc
+
     def _get_fit_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure:
         # If the sproc already exists, don't register.
         if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"):
@@ -510,6 +532,28 @@ class SnowparkModelTrainer:
 
         return fit_transform_wrapper_function
 
+    def _get_fit_predict_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
+        model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
+
+        fit_predict_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
+
+        relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
+            pkg_versions=model_spec.pkgDependencies, session=self.session
+        )
+
+        fit_predict_wrapper_sproc = self.session.sproc.register(
+            func=self._build_fit_predict_wrapper_sproc(model_spec=model_spec),
+            is_permanent=False,
+            name=fit_predict_sproc_name,
+            packages=["snowflake-snowpark-python"] + relaxed_dependencies,  # type: ignore[arg-type]
+            replace=True,
+            session=self.session,
+            statement_params=statement_params,
+            anonymous=True,
+        )
+
+        return fit_predict_wrapper_sproc
+
     def _get_fit_predict_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure:
         # If the sproc already exists, don't register.
         if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"):
@@ -545,6 +589,27 @@ class SnowparkModelTrainer:
 
         return fit_predict_wrapper_sproc
 
+    def _get_fit_transform_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
+        model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
+
+        fit_transform_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
+
+        relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
+            pkg_versions=model_spec.pkgDependencies, session=self.session
+        )
+
+        fit_transform_wrapper_sproc = self.session.sproc.register(
+            func=self._build_fit_transform_wrapper_sproc(model_spec=model_spec),
+            is_permanent=False,
+            name=fit_transform_sproc_name,
+            packages=["snowflake-snowpark-python"] + relaxed_dependencies,  # type: ignore[arg-type]
+            replace=True,
+            session=self.session,
+            statement_params=statement_params,
+            anonymous=True,
+        )
+        return fit_transform_wrapper_sproc
+
     def _get_fit_transform_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure:
         # If the sproc already exists, don't register.
         if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"):
@@ -612,7 +677,10 @@ class SnowparkModelTrainer:
             custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
         )
 
-        fit_wrapper_sproc = self._get_fit_wrapper_sproc(statement_params=statement_params)
+        if _ENABLE_ANONYMOUS_SPROC:
+            fit_wrapper_sproc = self._get_fit_wrapper_sproc_anonymous(statement_params=statement_params)
+        else:
+            fit_wrapper_sproc = self._get_fit_wrapper_sproc(statement_params=statement_params)
 
         try:
             sproc_export_file_name: str = fit_wrapper_sproc(
@@ -680,7 +748,11 @@ class SnowparkModelTrainer:
             custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
         )
 
-        fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc(statement_params=statement_params)
+        if _ENABLE_ANONYMOUS_SPROC:
+            fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc_anonymous(statement_params=statement_params)
+        else:
+            fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc(statement_params=statement_params)
+
         fit_predict_result_name = random_name_for_temp_object(TempObjectType.TABLE)
 
         sproc_export_file_name: str = fit_predict_wrapper_sproc(
@@ -741,7 +813,13 @@ class SnowparkModelTrainer:
             custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
         )
 
-        fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc(statement_params=statement_params)
+        if _ENABLE_ANONYMOUS_SPROC:
+            fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc_anonymous(
+                statement_params=statement_params
+            )
+        else:
+            fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc(statement_params=statement_params)
+
         fit_transform_result_name = random_name_for_temp_object(TempObjectType.TABLE)
 
         sproc_export_file_name: str = fit_transform_wrapper_sproc(
@@ -629,7 +629,14 @@ class CalibratedClassifierCV(BaseTransformer):
     ) -> List[str]:
         # in case the inferred output column names dimension is different
         # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
-        output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
+        sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+        # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
+        # seen during the fit.
+        snowpark_column_names = dataset.select(self.input_cols).columns
+        sample_pd_df.columns = snowpark_column_names
+
+        output_df_pd = getattr(self, method)(sample_pd_df, output_cols_prefix)
         output_df_columns = list(output_df_pd.columns)
         output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
         if self.sample_weight_col:
@@ -606,7 +606,14 @@ class AffinityPropagation(BaseTransformer):
     ) -> List[str]:
         # in case the inferred output column names dimension is different
         # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
-        output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
+        sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+        # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
+        # seen during the fit.
+        snowpark_column_names = dataset.select(self.input_cols).columns
+        sample_pd_df.columns = snowpark_column_names
+
+        output_df_pd = getattr(self, method)(sample_pd_df, output_cols_prefix)
         output_df_columns = list(output_df_pd.columns)
         output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
         if self.sample_weight_col:
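
The same sampling fix is stamped across the ~188 autogenerated estimator files in the list above. A hedged toy illustration of the mismatch it guards against (hypothetical `session`): `DataFrame.to_pandas()` can yield display names that differ from the Snowpark identifiers the estimator saw during fit, so the sample frame's columns are rewritten before inference:

    # Toy illustration; a quoted identifier round-trips differently than its display name.
    df = session.create_dataframe([[1, 2]], schema=['"a b"', "C"])
    pdf = df.limit(1).to_pandas()
    print(list(pdf.columns))  # e.g. ['a b', 'C'] -- display names from the result set
    pdf.columns = df.columns  # restore Snowpark identifiers: ['"a b"', 'C']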