snowflake-ml-python 1.5.1__py3-none-any.whl → 1.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (207) hide show
  1. snowflake/cortex/_complete.py +26 -5
  2. snowflake/cortex/_sentiment.py +7 -4
  3. snowflake/cortex/_sse_client.py +81 -0
  4. snowflake/cortex/_util.py +105 -8
  5. snowflake/ml/_internal/lineage/lineage_utils.py +34 -25
  6. snowflake/ml/_internal/utils/temp_file_utils.py +5 -2
  7. snowflake/ml/dataset/dataset.py +15 -12
  8. snowflake/ml/dataset/dataset_factory.py +3 -4
  9. snowflake/ml/feature_store/access_manager.py +34 -30
  10. snowflake/ml/feature_store/feature_store.py +3 -3
  11. snowflake/ml/feature_store/feature_view.py +12 -11
  12. snowflake/ml/fileset/snowfs.py +2 -31
  13. snowflake/ml/model/_client/ops/model_ops.py +43 -0
  14. snowflake/ml/model/_client/sql/model_version.py +55 -3
  15. snowflake/ml/model/_model_composer/model_composer.py +7 -3
  16. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -1
  17. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  18. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -3
  19. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  20. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -27
  21. snowflake/ml/model/_signatures/builtins_handler.py +2 -1
  22. snowflake/ml/model/_signatures/core.py +13 -1
  23. snowflake/ml/model/_signatures/pandas_handler.py +2 -0
  24. snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
  25. snowflake/ml/model/model_signature.py +2 -0
  26. snowflake/ml/model/type_hints.py +1 -0
  27. snowflake/ml/modeling/_internal/estimator_utils.py +58 -1
  28. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +196 -242
  29. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +161 -0
  30. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +38 -18
  31. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +82 -134
  32. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +21 -17
  33. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -2
  34. snowflake/ml/modeling/cluster/affinity_propagation.py +9 -2
  35. snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -2
  36. snowflake/ml/modeling/cluster/birch.py +9 -2
  37. snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -2
  38. snowflake/ml/modeling/cluster/dbscan.py +9 -2
  39. snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -2
  40. snowflake/ml/modeling/cluster/k_means.py +9 -2
  41. snowflake/ml/modeling/cluster/mean_shift.py +9 -2
  42. snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -2
  43. snowflake/ml/modeling/cluster/optics.py +9 -2
  44. snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -2
  45. snowflake/ml/modeling/cluster/spectral_clustering.py +9 -2
  46. snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -2
  47. snowflake/ml/modeling/compose/column_transformer.py +9 -2
  48. snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -2
  49. snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -2
  50. snowflake/ml/modeling/covariance/empirical_covariance.py +9 -2
  51. snowflake/ml/modeling/covariance/graphical_lasso.py +9 -2
  52. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -2
  53. snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -2
  54. snowflake/ml/modeling/covariance/min_cov_det.py +9 -2
  55. snowflake/ml/modeling/covariance/oas.py +9 -2
  56. snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -2
  57. snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -2
  58. snowflake/ml/modeling/decomposition/factor_analysis.py +9 -2
  59. snowflake/ml/modeling/decomposition/fast_ica.py +9 -2
  60. snowflake/ml/modeling/decomposition/incremental_pca.py +9 -2
  61. snowflake/ml/modeling/decomposition/kernel_pca.py +9 -2
  62. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -2
  63. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -2
  64. snowflake/ml/modeling/decomposition/pca.py +9 -2
  65. snowflake/ml/modeling/decomposition/sparse_pca.py +9 -2
  66. snowflake/ml/modeling/decomposition/truncated_svd.py +9 -2
  67. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -2
  68. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -2
  69. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -2
  70. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -2
  71. snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -2
  72. snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -2
  73. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -2
  74. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -2
  75. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -2
  76. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -2
  77. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -2
  78. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -2
  79. snowflake/ml/modeling/ensemble/isolation_forest.py +9 -2
  80. snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -2
  81. snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -2
  82. snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -2
  83. snowflake/ml/modeling/ensemble/voting_classifier.py +9 -2
  84. snowflake/ml/modeling/ensemble/voting_regressor.py +9 -2
  85. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -2
  86. snowflake/ml/modeling/feature_selection/select_fdr.py +9 -2
  87. snowflake/ml/modeling/feature_selection/select_fpr.py +9 -2
  88. snowflake/ml/modeling/feature_selection/select_fwe.py +9 -2
  89. snowflake/ml/modeling/feature_selection/select_k_best.py +9 -2
  90. snowflake/ml/modeling/feature_selection/select_percentile.py +9 -2
  91. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -2
  92. snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -2
  93. snowflake/ml/modeling/framework/base.py +3 -8
  94. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -2
  95. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -2
  96. snowflake/ml/modeling/impute/iterative_imputer.py +9 -2
  97. snowflake/ml/modeling/impute/knn_imputer.py +9 -2
  98. snowflake/ml/modeling/impute/missing_indicator.py +9 -2
  99. snowflake/ml/modeling/impute/simple_imputer.py +28 -5
  100. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -2
  101. snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -2
  102. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -2
  103. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -2
  104. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -2
  105. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -2
  106. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -2
  107. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -2
  108. snowflake/ml/modeling/linear_model/ard_regression.py +9 -2
  109. snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -2
  110. snowflake/ml/modeling/linear_model/elastic_net.py +9 -2
  111. snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -2
  112. snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -2
  113. snowflake/ml/modeling/linear_model/huber_regressor.py +9 -2
  114. snowflake/ml/modeling/linear_model/lars.py +9 -2
  115. snowflake/ml/modeling/linear_model/lars_cv.py +9 -2
  116. snowflake/ml/modeling/linear_model/lasso.py +9 -2
  117. snowflake/ml/modeling/linear_model/lasso_cv.py +9 -2
  118. snowflake/ml/modeling/linear_model/lasso_lars.py +9 -2
  119. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -2
  120. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -2
  121. snowflake/ml/modeling/linear_model/linear_regression.py +9 -2
  122. snowflake/ml/modeling/linear_model/logistic_regression.py +9 -2
  123. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -2
  124. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -2
  125. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -2
  126. snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -2
  127. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -2
  128. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -2
  129. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -2
  130. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -2
  131. snowflake/ml/modeling/linear_model/perceptron.py +9 -2
  132. snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -2
  133. snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -2
  134. snowflake/ml/modeling/linear_model/ridge.py +9 -2
  135. snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -2
  136. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -2
  137. snowflake/ml/modeling/linear_model/ridge_cv.py +9 -2
  138. snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -2
  139. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -2
  140. snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -2
  141. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -2
  142. snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -2
  143. snowflake/ml/modeling/manifold/isomap.py +9 -2
  144. snowflake/ml/modeling/manifold/mds.py +9 -2
  145. snowflake/ml/modeling/manifold/spectral_embedding.py +9 -2
  146. snowflake/ml/modeling/manifold/tsne.py +9 -2
  147. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -2
  148. snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -2
  149. snowflake/ml/modeling/model_selection/grid_search_cv.py +1 -5
  150. snowflake/ml/modeling/model_selection/randomized_search_cv.py +1 -5
  151. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -2
  152. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -2
  153. snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -2
  154. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -2
  155. snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -2
  156. snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -2
  157. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -2
  158. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -2
  159. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -2
  160. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -2
  161. snowflake/ml/modeling/neighbors/kernel_density.py +9 -2
  162. snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -2
  163. snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -2
  164. snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -2
  165. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -2
  166. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -2
  167. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -2
  168. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -2
  169. snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -2
  170. snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -2
  171. snowflake/ml/modeling/parameters/enable_anonymous_sproc.py +5 -0
  172. snowflake/ml/modeling/pipeline/pipeline.py +5 -0
  173. snowflake/ml/modeling/preprocessing/binarizer.py +7 -3
  174. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +7 -2
  175. snowflake/ml/modeling/preprocessing/label_encoder.py +8 -7
  176. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +7 -3
  177. snowflake/ml/modeling/preprocessing/min_max_scaler.py +7 -4
  178. snowflake/ml/modeling/preprocessing/normalizer.py +7 -3
  179. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +10 -2
  180. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +8 -5
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -2
  182. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -4
  183. snowflake/ml/modeling/preprocessing/standard_scaler.py +7 -3
  184. snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -2
  185. snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -2
  186. snowflake/ml/modeling/svm/linear_svc.py +9 -2
  187. snowflake/ml/modeling/svm/linear_svr.py +9 -2
  188. snowflake/ml/modeling/svm/nu_svc.py +9 -2
  189. snowflake/ml/modeling/svm/nu_svr.py +9 -2
  190. snowflake/ml/modeling/svm/svc.py +9 -2
  191. snowflake/ml/modeling/svm/svr.py +9 -2
  192. snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -2
  193. snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -2
  194. snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -2
  195. snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -2
  196. snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -2
  197. snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -2
  198. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -2
  199. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -2
  200. snowflake/ml/registry/_manager/model_manager.py +59 -1
  201. snowflake/ml/registry/registry.py +10 -1
  202. snowflake/ml/version.py +1 -1
  203. {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/METADATA +32 -4
  204. {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/RECORD +207 -204
  205. {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/LICENSE.txt +0 -0
  206. {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/WHEEL +0 -0
  207. {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/top_level.txt +0 -0
@@ -17,30 +17,19 @@ from snowflake.ml._internal.utils import (
17
17
  identifier,
18
18
  pkg_version_utils,
19
19
  snowpark_dataframe_utils,
20
+ temp_file_utils,
20
21
  )
21
- from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator
22
- from snowflake.ml._internal.utils.temp_file_utils import (
23
- cleanup_temp_files,
24
- get_temp_file_path,
25
- )
22
+ from snowflake.ml.modeling._internal import estimator_utils
26
23
  from snowflake.ml.modeling._internal.estimator_utils import handle_inference_result
27
24
  from snowflake.ml.modeling._internal.model_specifications import (
28
25
  ModelSpecifications,
29
26
  ModelSpecificationsBuilder,
30
27
  )
31
- from snowflake.snowpark import (
32
- DataFrame,
33
- Session,
34
- exceptions as snowpark_exceptions,
35
- functions as F,
36
- )
37
- from snowflake.snowpark._internal.utils import (
38
- TempObjectType,
39
- random_name_for_temp_object,
40
- )
28
+ from snowflake.snowpark import DataFrame, Session, exceptions as snowpark_exceptions
29
+ from snowflake.snowpark._internal import utils as snowpark_utils
41
30
  from snowflake.snowpark.stored_procedure import StoredProcedure
42
31
 
43
- cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
32
+ cp.register_pickle_by_value(inspect.getmodule(temp_file_utils.get_temp_file_path))
44
33
  cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
45
34
  cp.register_pickle_by_value(inspect.getmodule(handle_inference_result))
46
35
 
@@ -90,60 +79,6 @@ class SnowparkModelTrainer:
90
79
  self._subproject = subproject
91
80
  self._class_name = estimator.__class__.__name__
92
81
 
93
- def _create_temp_stage(self) -> str:
94
- """
95
- Creates temporary stage.
96
-
97
- Returns:
98
- Temp stage name.
99
- """
100
- # Create temp stage to upload pickled model file.
101
- transform_stage_name = random_name_for_temp_object(TempObjectType.STAGE)
102
- stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
103
- SqlResultValidator(session=self.session, query=stage_creation_query).has_dimensions(
104
- expected_rows=1, expected_cols=1
105
- ).validate()
106
- return transform_stage_name
107
-
108
- def _upload_model_to_stage(self, stage_name: str) -> Tuple[str, str]:
109
- """
110
- Util method to pickle and upload the model to a temp Snowflake stage.
111
-
112
- Args:
113
- stage_name: Stage name to save model.
114
-
115
- Returns:
116
- a tuple containing stage file paths for pickled input model for training and location to store trained
117
- models(response from training sproc).
118
- """
119
- # Create a temp file and dump the transform to that file.
120
- local_transform_file_name = get_temp_file_path()
121
- with open(local_transform_file_name, mode="w+b") as local_transform_file:
122
- cp.dump(self.estimator, local_transform_file)
123
-
124
- # Use posixpath to construct stage paths
125
- stage_transform_file_name = posixpath.join(stage_name, os.path.basename(local_transform_file_name))
126
- stage_result_file_name = posixpath.join(stage_name, os.path.basename(local_transform_file_name))
127
-
128
- statement_params = telemetry.get_function_usage_statement_params(
129
- project=_PROJECT,
130
- subproject=self._subproject,
131
- function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
132
- api_calls=[F.sproc],
133
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
134
- )
135
- # Put locally serialized transform on stage.
136
- self.session.file.put(
137
- local_transform_file_name,
138
- stage_transform_file_name,
139
- auto_compress=False,
140
- overwrite=True,
141
- statement_params=statement_params,
142
- )
143
-
144
- cleanup_temp_files([local_transform_file_name])
145
- return (stage_transform_file_name, stage_result_file_name)
146
-
147
82
  def _fetch_model_from_stage(self, dir_path: str, file_name: str, statement_params: Dict[str, str]) -> object:
148
83
  """
149
84
  Downloads the serialized model from a stage location and unpickles it.
@@ -156,7 +91,7 @@ class SnowparkModelTrainer:
156
91
  Returns:
157
92
  Deserialized model object.
158
93
  """
159
- local_result_file_name = get_temp_file_path()
94
+ local_result_file_name = temp_file_utils.get_temp_file_path()
160
95
  self.session.file.get(
161
96
  posixpath.join(dir_path, file_name),
162
97
  local_result_file_name,
@@ -166,13 +101,13 @@ class SnowparkModelTrainer:
166
101
  with open(os.path.join(local_result_file_name, file_name), mode="r+b") as result_file_obj:
167
102
  fit_estimator = cp.load(result_file_obj)
168
103
 
169
- cleanup_temp_files([local_result_file_name])
104
+ temp_file_utils.cleanup_temp_files([local_result_file_name])
170
105
  return fit_estimator
171
106
 
172
107
  def _build_fit_wrapper_sproc(
173
108
  self,
174
109
  model_spec: ModelSpecifications,
175
- ) -> Callable[[Any, List[str], str, str, List[str], List[str], Optional[str], Dict[str, str]], str]:
110
+ ) -> Callable[[Any, List[str], str, List[str], List[str], Optional[str], Dict[str, str]], str]:
176
111
  """
177
112
  Constructs and returns a python stored procedure function to be used for training model.
178
113
 
@@ -188,8 +123,7 @@ class SnowparkModelTrainer:
188
123
  def fit_wrapper_function(
189
124
  session: Session,
190
125
  sql_queries: List[str],
191
- stage_transform_file_name: str,
192
- stage_result_file_name: str,
126
+ temp_stage_name: str,
193
127
  input_cols: List[str],
194
128
  label_cols: List[str],
195
129
  sample_weight_col: Optional[str],
@@ -212,9 +146,13 @@ class SnowparkModelTrainer:
212
146
  df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params)
213
147
  df.columns = sp_df.columns
214
148
 
215
- local_transform_file_name = get_temp_file_path()
149
+ local_transform_file_name = temp_file_utils.get_temp_file_path()
216
150
 
217
- session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params)
151
+ session.file.get(
152
+ stage_location=temp_stage_name,
153
+ target_directory=local_transform_file_name,
154
+ statement_params=statement_params,
155
+ )
218
156
 
219
157
  local_transform_file_path = os.path.join(
220
158
  local_transform_file_name, os.listdir(local_transform_file_name)[0]
@@ -233,14 +171,14 @@ class SnowparkModelTrainer:
233
171
 
234
172
  estimator.fit(**args)
235
173
 
236
- local_result_file_name = get_temp_file_path()
174
+ local_result_file_name = temp_file_utils.get_temp_file_path()
237
175
 
238
176
  with open(local_result_file_name, mode="w+b") as local_result_file_obj:
239
177
  cp.dump(estimator, local_result_file_obj)
240
178
 
241
179
  session.file.put(
242
- local_result_file_name,
243
- stage_result_file_name,
180
+ local_file_name=local_result_file_name,
181
+ stage_location=temp_stage_name,
244
182
  auto_compress=False,
245
183
  overwrite=True,
246
184
  statement_params=statement_params,
@@ -254,7 +192,7 @@ class SnowparkModelTrainer:
254
192
 
255
193
  def _get_fit_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
256
194
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
257
- fit_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
195
+ fit_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
258
196
 
259
197
  relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
260
198
  pkg_versions=model_spec.pkgDependencies, session=self.session
@@ -284,7 +222,7 @@ class SnowparkModelTrainer:
284
222
  fit_sproc: StoredProcedure = self.session._FIT_WRAPPER_SPROCS[fit_sproc_key] # type: ignore[attr-defined]
285
223
  return fit_sproc
286
224
 
287
- fit_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
225
+ fit_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
288
226
 
289
227
  relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
290
228
  pkg_versions=model_spec.pkgDependencies, session=self.session
@@ -307,7 +245,7 @@ class SnowparkModelTrainer:
307
245
  def _build_fit_predict_wrapper_sproc(
308
246
  self,
309
247
  model_spec: ModelSpecifications,
310
- ) -> Callable[[Session, List[str], str, str, List[str], Dict[str, str], bool, List[str], str], str]:
248
+ ) -> Callable[[Session, List[str], str, List[str], Dict[str, str], bool, List[str], str], str]:
311
249
  """
312
250
  Constructs and returns a python stored procedure function to be used for training model.
313
251
 
@@ -323,8 +261,7 @@ class SnowparkModelTrainer:
323
261
  def fit_predict_wrapper_function(
324
262
  session: Session,
325
263
  sql_queries: List[str],
326
- stage_transform_file_name: str,
327
- stage_result_file_name: str,
264
+ temp_stage_name: str,
328
265
  input_cols: List[str],
329
266
  statement_params: Dict[str, str],
330
267
  drop_input_cols: bool,
@@ -347,9 +284,13 @@ class SnowparkModelTrainer:
347
284
  df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params)
348
285
  df.columns = sp_df.columns
349
286
 
350
- local_transform_file_name = get_temp_file_path()
287
+ local_transform_file_name = temp_file_utils.get_temp_file_path()
351
288
 
352
- session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params)
289
+ session.file.get(
290
+ stage_location=temp_stage_name,
291
+ target_directory=local_transform_file_name,
292
+ statement_params=statement_params,
293
+ )
353
294
 
354
295
  local_transform_file_path = os.path.join(
355
296
  local_transform_file_name, os.listdir(local_transform_file_name)[0]
@@ -359,14 +300,14 @@ class SnowparkModelTrainer:
359
300
 
360
301
  fit_predict_result = estimator.fit_predict(X=df[input_cols])
361
302
 
362
- local_result_file_name = get_temp_file_path()
303
+ local_result_file_name = temp_file_utils.get_temp_file_path()
363
304
 
364
305
  with open(local_result_file_name, mode="w+b") as local_result_file_obj:
365
306
  cp.dump(estimator, local_result_file_obj)
366
307
 
367
308
  session.file.put(
368
- local_result_file_name,
369
- stage_result_file_name,
309
+ local_file_name=local_result_file_name,
310
+ stage_location=temp_stage_name,
370
311
  auto_compress=False,
371
312
  overwrite=True,
372
313
  statement_params=statement_params,
@@ -407,7 +348,6 @@ class SnowparkModelTrainer:
407
348
  Session,
408
349
  List[str],
409
350
  str,
410
- str,
411
351
  List[str],
412
352
  Optional[List[str]],
413
353
  Optional[str],
@@ -433,8 +373,7 @@ class SnowparkModelTrainer:
433
373
  def fit_transform_wrapper_function(
434
374
  session: Session,
435
375
  sql_queries: List[str],
436
- stage_transform_file_name: str,
437
- stage_result_file_name: str,
376
+ temp_stage_name: str,
438
377
  input_cols: List[str],
439
378
  label_cols: Optional[List[str]],
440
379
  sample_weight_col: Optional[str],
@@ -459,9 +398,13 @@ class SnowparkModelTrainer:
459
398
  df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params)
460
399
  df.columns = sp_df.columns
461
400
 
462
- local_transform_file_name = get_temp_file_path()
401
+ local_transform_file_name = temp_file_utils.get_temp_file_path()
463
402
 
464
- session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params)
403
+ session.file.get(
404
+ stage_location=temp_stage_name,
405
+ target_directory=local_transform_file_name,
406
+ statement_params=statement_params,
407
+ )
465
408
 
466
409
  local_transform_file_path = os.path.join(
467
410
  local_transform_file_name, os.listdir(local_transform_file_name)[0]
@@ -480,14 +423,14 @@ class SnowparkModelTrainer:
480
423
 
481
424
  fit_transform_result = estimator.fit_transform(**args)
482
425
 
483
- local_result_file_name = get_temp_file_path()
426
+ local_result_file_name = temp_file_utils.get_temp_file_path()
484
427
 
485
428
  with open(local_result_file_name, mode="w+b") as local_result_file_obj:
486
429
  cp.dump(estimator, local_result_file_obj)
487
430
 
488
431
  session.file.put(
489
- local_result_file_name,
490
- stage_result_file_name,
432
+ local_file_name=local_result_file_name,
433
+ stage_location=temp_stage_name,
491
434
  auto_compress=False,
492
435
  overwrite=True,
493
436
  statement_params=statement_params,
@@ -535,7 +478,7 @@ class SnowparkModelTrainer:
535
478
  def _get_fit_predict_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
536
479
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
537
480
 
538
- fit_predict_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
481
+ fit_predict_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
539
482
 
540
483
  relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
541
484
  pkg_versions=model_spec.pkgDependencies, session=self.session
@@ -567,7 +510,7 @@ class SnowparkModelTrainer:
567
510
  ]
568
511
  return fit_sproc
569
512
 
570
- fit_predict_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
513
+ fit_predict_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
571
514
 
572
515
  relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
573
516
  pkg_versions=model_spec.pkgDependencies, session=self.session
@@ -592,7 +535,7 @@ class SnowparkModelTrainer:
592
535
  def _get_fit_transform_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
593
536
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
594
537
 
595
- fit_transform_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
538
+ fit_transform_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
596
539
 
597
540
  relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
598
541
  pkg_versions=model_spec.pkgDependencies, session=self.session
@@ -623,7 +566,7 @@ class SnowparkModelTrainer:
623
566
  ]
624
567
  return fit_sproc
625
568
 
626
- fit_transform_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
569
+ fit_transform_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
627
570
 
628
571
  relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
629
572
  pkg_versions=model_spec.pkgDependencies, session=self.session
@@ -663,19 +606,21 @@ class SnowparkModelTrainer:
663
606
  # Extract query that generated the dataframe. We will need to pass it to the fit procedure.
664
607
  queries = dataset.queries["queries"]
665
608
 
666
- transform_stage_name = self._create_temp_stage()
667
- (stage_transform_file_name, stage_result_file_name) = self._upload_model_to_stage(
668
- stage_name=transform_stage_name
669
- )
670
-
671
- # Call fit sproc
609
+ temp_stage_name = estimator_utils.create_temp_stage(self.session)
672
610
  statement_params = telemetry.get_function_usage_statement_params(
673
611
  project=_PROJECT,
674
612
  subproject=self._subproject,
675
613
  function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
676
614
  api_calls=[Session.call],
677
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
615
+ custom_tags={"autogen": True} if self._autogenerated else None,
678
616
  )
617
+ estimator_utils.upload_model_to_stage(
618
+ stage_name=temp_stage_name,
619
+ estimator=self.estimator,
620
+ session=self.session,
621
+ statement_params=statement_params,
622
+ )
623
+ # Call fit sproc
679
624
 
680
625
  if _ENABLE_ANONYMOUS_SPROC:
681
626
  fit_wrapper_sproc = self._get_fit_wrapper_sproc_anonymous(statement_params=statement_params)
@@ -686,8 +631,7 @@ class SnowparkModelTrainer:
686
631
  sproc_export_file_name: str = fit_wrapper_sproc(
687
632
  self.session,
688
633
  queries,
689
- stage_transform_file_name,
690
- stage_result_file_name,
634
+ temp_stage_name,
691
635
  self.input_cols,
692
636
  self.label_cols,
693
637
  self.sample_weight_col,
@@ -706,7 +650,7 @@ class SnowparkModelTrainer:
706
650
  sproc_export_file_name = fields[0]
707
651
 
708
652
  return self._fetch_model_from_stage(
709
- dir_path=stage_result_file_name,
653
+ dir_path=temp_stage_name,
710
654
  file_name=sproc_export_file_name,
711
655
  statement_params=statement_params,
712
656
  )
@@ -734,32 +678,34 @@ class SnowparkModelTrainer:
734
678
  # Extract query that generated the dataframe. We will need to pass it to the fit procedure.
735
679
  queries = dataset.queries["queries"]
736
680
 
737
- transform_stage_name = self._create_temp_stage()
738
- (stage_transform_file_name, stage_result_file_name) = self._upload_model_to_stage(
739
- stage_name=transform_stage_name
740
- )
741
-
742
- # Call fit sproc
743
681
  statement_params = telemetry.get_function_usage_statement_params(
744
682
  project=_PROJECT,
745
683
  subproject=self._subproject,
746
684
  function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
747
685
  api_calls=[Session.call],
748
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
686
+ custom_tags={"autogen": True} if self._autogenerated else None,
749
687
  )
750
688
 
689
+ temp_stage_name = estimator_utils.create_temp_stage(self.session)
690
+ estimator_utils.upload_model_to_stage(
691
+ stage_name=temp_stage_name,
692
+ estimator=self.estimator,
693
+ session=self.session,
694
+ statement_params=statement_params,
695
+ )
696
+
697
+ # Call fit sproc
751
698
  if _ENABLE_ANONYMOUS_SPROC:
752
699
  fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc_anonymous(statement_params=statement_params)
753
700
  else:
754
701
  fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc(statement_params=statement_params)
755
702
 
756
- fit_predict_result_name = random_name_for_temp_object(TempObjectType.TABLE)
703
+ fit_predict_result_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
757
704
 
758
705
  sproc_export_file_name: str = fit_predict_wrapper_sproc(
759
706
  self.session,
760
707
  queries,
761
- stage_transform_file_name,
762
- stage_result_file_name,
708
+ temp_stage_name,
763
709
  self.input_cols,
764
710
  statement_params,
765
711
  drop_input_cols,
@@ -769,7 +715,7 @@ class SnowparkModelTrainer:
769
715
 
770
716
  output_result_sp = self.session.table(fit_predict_result_name)
771
717
  fitted_estimator = self._fetch_model_from_stage(
772
- dir_path=stage_result_file_name,
718
+ dir_path=temp_stage_name,
773
719
  file_name=sproc_export_file_name,
774
720
  statement_params=statement_params,
775
721
  )
@@ -799,20 +745,23 @@ class SnowparkModelTrainer:
799
745
  # Extract query that generated the dataframe. We will need to pass it to the fit procedure.
800
746
  queries = dataset.queries["queries"]
801
747
 
802
- transform_stage_name = self._create_temp_stage()
803
- (stage_transform_file_name, stage_result_file_name) = self._upload_model_to_stage(
804
- stage_name=transform_stage_name
805
- )
806
-
807
- # Call fit sproc
808
748
  statement_params = telemetry.get_function_usage_statement_params(
809
749
  project=_PROJECT,
810
750
  subproject=self._subproject,
811
751
  function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
812
752
  api_calls=[Session.call],
813
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
753
+ custom_tags={"autogen": True} if self._autogenerated else None,
754
+ )
755
+
756
+ temp_stage_name = estimator_utils.create_temp_stage(self.session)
757
+ estimator_utils.upload_model_to_stage(
758
+ stage_name=temp_stage_name,
759
+ estimator=self.estimator,
760
+ session=self.session,
761
+ statement_params=statement_params,
814
762
  )
815
763
 
764
+ # Call fit sproc
816
765
  if _ENABLE_ANONYMOUS_SPROC:
817
766
  fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc_anonymous(
818
767
  statement_params=statement_params
@@ -820,13 +769,12 @@ class SnowparkModelTrainer:
820
769
  else:
821
770
  fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc(statement_params=statement_params)
822
771
 
823
- fit_transform_result_name = random_name_for_temp_object(TempObjectType.TABLE)
772
+ fit_transform_result_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
824
773
 
825
774
  sproc_export_file_name: str = fit_transform_wrapper_sproc(
826
775
  self.session,
827
776
  queries,
828
- stage_transform_file_name,
829
- stage_result_file_name,
777
+ temp_stage_name,
830
778
  self.input_cols,
831
779
  self.label_cols,
832
780
  self.sample_weight_col,
@@ -838,7 +786,7 @@ class SnowparkModelTrainer:
838
786
 
839
787
  output_result_sp = self.session.table(fit_transform_result_name)
840
788
  fitted_estimator = self._fetch_model_from_stage(
841
- dir_path=stage_result_file_name,
789
+ dir_path=temp_stage_name,
842
790
  file_name=sproc_export_file_name,
843
791
  statement_params=statement_params,
844
792
  )
@@ -13,12 +13,12 @@ from snowflake.ml._internal.exceptions import (
13
13
  exceptions,
14
14
  modeling_error_messages,
15
15
  )
16
- from snowflake.ml._internal.utils import pkg_version_utils
16
+ from snowflake.ml._internal.utils import pkg_version_utils, temp_file_utils
17
17
  from snowflake.ml._internal.utils.query_result_checker import ResultValidator
18
18
  from snowflake.ml._internal.utils.snowpark_dataframe_utils import (
19
19
  cast_snowpark_dataframe,
20
20
  )
21
- from snowflake.ml._internal.utils.temp_file_utils import get_temp_file_path
21
+ from snowflake.ml.modeling._internal import estimator_utils
22
22
  from snowflake.ml.modeling._internal.model_specifications import (
23
23
  ModelSpecifications,
24
24
  ModelSpecificationsBuilder,
@@ -306,8 +306,6 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
306
306
  ) # type: ignore[misc]
307
307
  def fit_wrapper_sproc(
308
308
  session: Session,
309
- stage_transform_file_name: str,
310
- stage_result_file_name: str,
311
309
  dataset_stage_name: str,
312
310
  batch_size: int,
313
311
  input_cols: List[str],
@@ -320,9 +318,13 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
320
318
 
321
319
  import cloudpickle as cp
322
320
 
323
- local_transform_file_name = get_temp_file_path()
321
+ local_transform_file_name = temp_file_utils.get_temp_file_path()
324
322
 
325
- session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params)
323
+ session.file.get(
324
+ stage_location=dataset_stage_name,
325
+ target_directory=local_transform_file_name,
326
+ statement_params=statement_params,
327
+ )
326
328
 
327
329
  local_transform_file_path = os.path.join(
328
330
  local_transform_file_name, os.listdir(local_transform_file_name)[0]
@@ -345,13 +347,13 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
345
347
  sample_weight_col=sample_weight_col,
346
348
  )
347
349
 
348
- local_result_file_name = get_temp_file_path()
350
+ local_result_file_name = temp_file_utils.get_temp_file_path()
349
351
  with open(local_result_file_name, mode="w+b") as local_result_file_obj:
350
352
  cp.dump(estimator, local_result_file_obj)
351
353
 
352
354
  session.file.put(
353
- local_result_file_name,
354
- stage_result_file_name,
355
+ local_file_name=local_result_file_name,
356
+ stage_location=dataset_stage_name,
355
357
  auto_compress=False,
356
358
  overwrite=True,
357
359
  statement_params=statement_params,
@@ -394,11 +396,6 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
394
396
  SnowflakeMLException: For known types of user and system errors.
395
397
  e: For every unexpected exception from SnowflakeClient.
396
398
  """
397
- temp_stage_name = self._create_temp_stage()
398
- (stage_transform_file_name, stage_result_file_name) = self._upload_model_to_stage(stage_name=temp_stage_name)
399
- data_file_paths = self._write_training_data_to_stage(dataset_stage_name=temp_stage_name)
400
-
401
- # Call fit sproc
402
399
  statement_params = telemetry.get_function_usage_statement_params(
403
400
  project=_PROJECT,
404
401
  subproject=self._subproject,
@@ -406,7 +403,16 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
406
403
  api_calls=[Session.call],
407
404
  custom_tags=None,
408
405
  )
406
+ temp_stage_name = estimator_utils.create_temp_stage(self.session)
407
+ estimator_utils.upload_model_to_stage(
408
+ stage_name=temp_stage_name,
409
+ estimator=self.estimator,
410
+ session=self.session,
411
+ statement_params=statement_params,
412
+ )
413
+ data_file_paths = self._write_training_data_to_stage(dataset_stage_name=temp_stage_name)
409
414
 
415
+ # Call fit sproc
410
416
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
411
417
  fit_wrapper = self._get_xgb_external_memory_fit_wrapper_sproc(
412
418
  model_spec=model_spec,
@@ -418,8 +424,6 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
418
424
  try:
419
425
  sproc_export_file_name = fit_wrapper(
420
426
  self.session,
421
- stage_transform_file_name,
422
- stage_result_file_name,
423
427
  temp_stage_name,
424
428
  self._batch_size,
425
429
  self.input_cols,
@@ -440,7 +444,7 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
440
444
  sproc_export_file_name = fields[0]
441
445
 
442
446
  return self._fetch_model_from_stage(
443
- dir_path=stage_result_file_name,
447
+ dir_path=temp_stage_name,
444
448
  file_name=sproc_export_file_name,
445
449
  statement_params=statement_params,
446
450
  )
@@ -296,7 +296,7 @@ class CalibratedClassifierCV(BaseTransformer):
296
296
  inspect.currentframe(), CalibratedClassifierCV.__class__.__name__
297
297
  ),
298
298
  api_calls=[Session.call],
299
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
299
+ custom_tags={"autogen": True} if self._autogenerated else None,
300
300
  )
301
301
  pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
302
302
  pd_df.columns = dataset.columns
@@ -629,7 +629,14 @@ class CalibratedClassifierCV(BaseTransformer):
629
629
  ) -> List[str]:
630
630
  # in case the inferred output column names dimension is different
631
631
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
632
- output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
632
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
633
+
634
+ # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
635
+ # seen during the fit.
636
+ snowpark_column_names = dataset.select(self.input_cols).columns
637
+ sample_pd_df.columns = snowpark_column_names
638
+
639
+ output_df_pd = getattr(self, method)(sample_pd_df, output_cols_prefix)
633
640
  output_df_columns = list(output_df_pd.columns)
634
641
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
635
642
  if self.sample_weight_col:
@@ -271,7 +271,7 @@ class AffinityPropagation(BaseTransformer):
271
271
  inspect.currentframe(), AffinityPropagation.__class__.__name__
272
272
  ),
273
273
  api_calls=[Session.call],
274
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
274
+ custom_tags={"autogen": True} if self._autogenerated else None,
275
275
  )
276
276
  pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
277
277
  pd_df.columns = dataset.columns
@@ -606,7 +606,14 @@ class AffinityPropagation(BaseTransformer):
606
606
  ) -> List[str]:
607
607
  # in case the inferred output column names dimension is different
608
608
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
609
- output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
609
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
610
+
611
+ # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
612
+ # seen during the fit.
613
+ snowpark_column_names = dataset.select(self.input_cols).columns
614
+ sample_pd_df.columns = snowpark_column_names
615
+
616
+ output_df_pd = getattr(self, method)(sample_pd_df, output_cols_prefix)
610
617
  output_df_columns = list(output_df_pd.columns)
611
618
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
612
619
  if self.sample_weight_col:
@@ -304,7 +304,7 @@ class AgglomerativeClustering(BaseTransformer):
304
304
  inspect.currentframe(), AgglomerativeClustering.__class__.__name__
305
305
  ),
306
306
  api_calls=[Session.call],
307
- custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
307
+ custom_tags={"autogen": True} if self._autogenerated else None,
308
308
  )
309
309
  pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
310
310
  pd_df.columns = dataset.columns
@@ -637,7 +637,14 @@ class AgglomerativeClustering(BaseTransformer):
637
637
  ) -> List[str]:
638
638
  # in case the inferred output column names dimension is different
639
639
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
640
- output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
640
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
641
+
642
+ # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
643
+ # seen during the fit.
644
+ snowpark_column_names = dataset.select(self.input_cols).columns
645
+ sample_pd_df.columns = snowpark_column_names
646
+
647
+ output_df_pd = getattr(self, method)(sample_pd_df, output_cols_prefix)
641
648
  output_df_columns = list(output_df_pd.columns)
642
649
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
643
650
  if self.sample_weight_col: