snowflake-ml-python 1.5.2__py3-none-any.whl → 1.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. snowflake/cortex/_complete.py +26 -5
  2. snowflake/cortex/_sse_client.py +81 -0
  3. snowflake/cortex/_util.py +105 -8
  4. snowflake/ml/_internal/lineage/lineage_utils.py +34 -25
  5. snowflake/ml/dataset/dataset.py +15 -12
  6. snowflake/ml/dataset/dataset_factory.py +3 -4
  7. snowflake/ml/feature_store/feature_store.py +2 -2
  8. snowflake/ml/model/_client/sql/model_version.py +2 -2
  9. snowflake/ml/model/_model_composer/model_composer.py +2 -2
  10. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -1
  11. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  12. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  13. snowflake/ml/model/_signatures/builtins_handler.py +2 -1
  14. snowflake/ml/model/_signatures/core.py +13 -1
  15. snowflake/ml/model/_signatures/pandas_handler.py +2 -0
  16. snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
  17. snowflake/ml/model/model_signature.py +2 -0
  18. snowflake/ml/model/type_hints.py +1 -0
  19. snowflake/ml/modeling/_internal/estimator_utils.py +58 -1
  20. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +156 -121
  21. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +2 -0
  22. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +38 -18
  23. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +82 -134
  24. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +21 -17
  25. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
  26. snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
  27. snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
  28. snowflake/ml/modeling/cluster/birch.py +1 -1
  29. snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
  30. snowflake/ml/modeling/cluster/dbscan.py +1 -1
  31. snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
  32. snowflake/ml/modeling/cluster/k_means.py +1 -1
  33. snowflake/ml/modeling/cluster/mean_shift.py +1 -1
  34. snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
  35. snowflake/ml/modeling/cluster/optics.py +1 -1
  36. snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
  37. snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
  38. snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
  39. snowflake/ml/modeling/compose/column_transformer.py +1 -1
  40. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
  41. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  42. snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
  43. snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
  44. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  45. snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
  46. snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
  47. snowflake/ml/modeling/covariance/oas.py +1 -1
  48. snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
  49. snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
  50. snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
  51. snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
  52. snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
  53. snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
  54. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
  55. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
  56. snowflake/ml/modeling/decomposition/pca.py +1 -1
  57. snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
  58. snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
  59. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
  60. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
  61. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
  62. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
  63. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
  64. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
  65. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
  66. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
  67. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
  68. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
  69. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
  70. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
  71. snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
  72. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
  73. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
  74. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
  75. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
  76. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
  77. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
  78. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
  79. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
  80. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
  81. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
  82. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
  83. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
  84. snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
  85. snowflake/ml/modeling/framework/base.py +3 -8
  86. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
  87. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
  88. snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
  89. snowflake/ml/modeling/impute/knn_imputer.py +1 -1
  90. snowflake/ml/modeling/impute/missing_indicator.py +1 -1
  91. snowflake/ml/modeling/impute/simple_imputer.py +8 -4
  92. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
  93. snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
  94. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
  95. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
  96. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
  97. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
  98. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
  99. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
  100. snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
  101. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
  102. snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
  103. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  104. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
  105. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
  106. snowflake/ml/modeling/linear_model/lars.py +1 -1
  107. snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
  108. snowflake/ml/modeling/linear_model/lasso.py +1 -1
  109. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  110. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
  111. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
  112. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
  113. snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
  114. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
  115. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
  116. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
  117. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  118. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
  119. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  120. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
  121. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
  122. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  123. snowflake/ml/modeling/linear_model/perceptron.py +1 -1
  124. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
  125. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
  126. snowflake/ml/modeling/linear_model/ridge.py +1 -1
  127. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
  128. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
  129. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
  130. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
  131. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
  132. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
  133. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
  134. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
  135. snowflake/ml/modeling/manifold/isomap.py +1 -1
  136. snowflake/ml/modeling/manifold/mds.py +1 -1
  137. snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
  138. snowflake/ml/modeling/manifold/tsne.py +1 -1
  139. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
  140. snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
  141. snowflake/ml/modeling/model_selection/grid_search_cv.py +1 -5
  142. snowflake/ml/modeling/model_selection/randomized_search_cv.py +1 -5
  143. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
  144. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
  145. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
  146. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
  147. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
  148. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
  149. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
  150. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
  151. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
  152. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
  153. snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
  154. snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
  155. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
  156. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  157. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
  158. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
  159. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  160. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
  161. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
  162. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
  163. snowflake/ml/modeling/pipeline/pipeline.py +5 -0
  164. snowflake/ml/modeling/preprocessing/binarizer.py +7 -3
  165. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +7 -2
  166. snowflake/ml/modeling/preprocessing/label_encoder.py +8 -7
  167. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +7 -3
  168. snowflake/ml/modeling/preprocessing/min_max_scaler.py +7 -4
  169. snowflake/ml/modeling/preprocessing/normalizer.py +7 -3
  170. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +10 -2
  171. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +8 -5
  172. snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
  173. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -4
  174. snowflake/ml/modeling/preprocessing/standard_scaler.py +7 -3
  175. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
  176. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
  177. snowflake/ml/modeling/svm/linear_svc.py +1 -1
  178. snowflake/ml/modeling/svm/linear_svr.py +1 -1
  179. snowflake/ml/modeling/svm/nu_svc.py +1 -1
  180. snowflake/ml/modeling/svm/nu_svr.py +1 -1
  181. snowflake/ml/modeling/svm/svc.py +1 -1
  182. snowflake/ml/modeling/svm/svr.py +1 -1
  183. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
  184. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
  185. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
  186. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
  187. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
  188. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
  189. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
  190. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
  191. snowflake/ml/version.py +1 -1
  192. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.3.dist-info}/METADATA +21 -5
  193. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.3.dist-info}/RECORD +196 -195
  194. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.3.dist-info}/LICENSE.txt +0 -0
  195. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.3.dist-info}/WHEEL +0 -0
  196. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.3.dist-info}/top_level.txt +0 -0
@@ -1,15 +1,19 @@
1
1
  import inspect
2
2
  import numbers
3
+ import os
3
4
  from typing import Any, Callable, Dict, List, Set, Tuple
4
5
 
6
+ import cloudpickle as cp
5
7
  import numpy as np
6
8
  from numpy import typing as npt
7
- from typing_extensions import TypeGuard
8
9
 
9
10
  from snowflake.ml._internal.exceptions import error_codes, exceptions
11
+ from snowflake.ml._internal.utils import temp_file_utils
12
+ from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator
10
13
  from snowflake.ml.modeling.framework._utils import to_native_format
11
14
  from snowflake.ml.modeling.framework.base import BaseTransformer
12
15
  from snowflake.snowpark import Session
16
+ from snowflake.snowpark._internal import utils as snowpark_utils
13
17
 
14
18
 
15
19
  def validate_sklearn_args(args: Dict[str, Tuple[Any, Any, bool]], klass: type) -> Dict[str, Any]:
@@ -97,6 +101,7 @@ def original_estimator_has_callable(attr: str) -> Callable[[Any], bool]:
97
101
  Returns:
98
102
  A function which checks for the existence of callable `attr` on the given object.
99
103
  """
104
+ from typing_extensions import TypeGuard
100
105
 
101
106
  def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
102
107
  """Check for the existence of callable `attr` in self.
@@ -218,3 +223,55 @@ def handle_inference_result(
218
223
  )
219
224
 
220
225
  return transformed_numpy_array, output_cols
226
+
227
+
228
+ def create_temp_stage(session: Session) -> str:
229
+ """Creates temporary stage.
230
+
231
+ Args:
232
+ session: Session
233
+
234
+ Returns:
235
+ Temp stage name.
236
+ """
237
+ # Create temp stage to upload pickled model file.
238
+ transform_stage_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
239
+ stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
240
+ SqlResultValidator(session=session, query=stage_creation_query).has_dimensions(
241
+ expected_rows=1, expected_cols=1
242
+ ).validate()
243
+ return transform_stage_name
244
+
245
+
246
+ def upload_model_to_stage(
247
+ stage_name: str, estimator: object, session: Session, statement_params: Dict[str, str]
248
+ ) -> str:
249
+ """Util method to pickle and upload the model to a temp Snowflake stage.
250
+
251
+
252
+ Args:
253
+ stage_name: Stage name to save model.
254
+ estimator: Estimator object to upload to stage (sklearn model object)
255
+ session: The snowpark session to use.
256
+ statement_params: Statement parameters for query telemetry.
257
+
258
+ Returns:
259
+ a tuple containing stage file paths for pickled input model for training and location to store trained
260
+ models(response from training sproc).
261
+ """
262
+ # Create a temp file and dump the transform to that file.
263
+ local_transform_file_name = temp_file_utils.get_temp_file_path()
264
+ with open(local_transform_file_name, mode="w+b") as local_transform_file:
265
+ cp.dump(estimator, local_transform_file)
266
+
267
+ # Put locally serialized transform on stage.
268
+ session.file.put(
269
+ local_file_name=local_transform_file_name,
270
+ stage_location=stage_name,
271
+ auto_compress=False,
272
+ overwrite=True,
273
+ statement_params=statement_params,
274
+ )
275
+
276
+ temp_file_utils.cleanup_temp_files([local_transform_file_name])
277
+ return os.path.basename(local_transform_file_name)
@@ -4,6 +4,7 @@ import io
4
4
  import os
5
5
  import posixpath
6
6
  import sys
7
+ import uuid
7
8
  from typing import Any, Dict, List, Optional, Tuple, Union
8
9
 
9
10
  import cloudpickle as cp
@@ -16,10 +17,7 @@ from snowflake.ml._internal.utils import (
16
17
  identifier,
17
18
  pkg_version_utils,
18
19
  snowpark_dataframe_utils,
19
- )
20
- from snowflake.ml._internal.utils.temp_file_utils import (
21
- cleanup_temp_files,
22
- get_temp_file_path,
20
+ temp_file_utils,
23
21
  )
24
22
  from snowflake.ml.modeling._internal.model_specifications import (
25
23
  ModelSpecificationsBuilder,
@@ -37,13 +35,14 @@ from snowflake.snowpark.row import Row
37
35
  from snowflake.snowpark.types import IntegerType, StringType, StructField, StructType
38
36
  from snowflake.snowpark.udtf import UDTFRegistration
39
37
 
40
- cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
38
+ cp.register_pickle_by_value(inspect.getmodule(temp_file_utils.get_temp_file_path))
41
39
  cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
42
40
  cp.register_pickle_by_value(inspect.getmodule(snowpark_dataframe_utils.cast_snowpark_dataframe))
43
41
 
44
42
  _PROJECT = "ModelDevelopment"
45
43
  DEFAULT_UDTF_NJOBS = 3
46
44
  ENABLE_EFFICIENT_MEMORY_USAGE = False
45
+ _UDTF_STAGE_NAME = f"MEMORY_EFFICIENT_UDTF_{str(uuid.uuid4()).replace('-', '_')}"
47
46
 
48
47
 
49
48
  def construct_cv_results(
@@ -318,7 +317,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
318
317
  original_refit = estimator.refit
319
318
 
320
319
  # Create a temp file and dump the estimator to that file.
321
- estimator_file_name = get_temp_file_path()
320
+ estimator_file_name = temp_file_utils.get_temp_file_path()
322
321
  params_to_evaluate = []
323
322
  for param_to_eval in list(param_grid):
324
323
  for k, v in param_to_eval.items():
@@ -357,6 +356,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
357
356
  )
358
357
  estimator_location = put_result[0].target
359
358
  imports.append(f"@{temp_stage_name}/{estimator_location}")
359
+ temp_file_utils.cleanup_temp_files([estimator_file_name])
360
360
 
361
361
  search_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
362
362
  random_udtf_name = random_name_for_temp_object(TempObjectType.TABLE_FUNCTION)
@@ -413,7 +413,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
413
413
  X = df[input_cols]
414
414
  y = df[label_cols].squeeze() if label_cols else None
415
415
 
416
- local_estimator_file_name = get_temp_file_path()
416
+ local_estimator_file_name = temp_file_utils.get_temp_file_path()
417
417
  session.file.get(stage_estimator_file_name, local_estimator_file_name)
418
418
 
419
419
  local_estimator_file_path = os.path.join(
@@ -429,7 +429,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
429
429
  n_splits = build_cross_validator.get_n_splits(X, y, None)
430
430
  # store the cross_validator's test indices only to save space
431
431
  cross_validator_indices = [test for _, test in build_cross_validator.split(X, y, None)]
432
- local_indices_file_name = get_temp_file_path()
432
+ local_indices_file_name = temp_file_utils.get_temp_file_path()
433
433
  with open(local_indices_file_name, mode="w+b") as local_indices_file_obj:
434
434
  cp.dump(cross_validator_indices, local_indices_file_obj)
435
435
 
@@ -445,6 +445,8 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
445
445
  cross_validator_indices_length = int(len(cross_validator_indices))
446
446
  parameter_grid_length = len(param_grid)
447
447
 
448
+ temp_file_utils.cleanup_temp_files([local_estimator_file_name, local_indices_file_name])
449
+
448
450
  assert estimator is not None
449
451
 
450
452
  @cachetools.cached(cache={})
@@ -647,7 +649,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
647
649
  if hasattr(estimator.best_estimator_, "feature_names_in_"):
648
650
  estimator.feature_names_in_ = estimator.best_estimator_.feature_names_in_
649
651
 
650
- local_result_file_name = get_temp_file_path()
652
+ local_result_file_name = temp_file_utils.get_temp_file_path()
651
653
 
652
654
  with open(local_result_file_name, mode="w+b") as local_result_file_obj:
653
655
  cp.dump(estimator, local_result_file_obj)
@@ -658,6 +660,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
658
660
  auto_compress=False,
659
661
  overwrite=True,
660
662
  )
663
+ temp_file_utils.cleanup_temp_files([local_result_file_name])
661
664
 
662
665
  # Note: you can add something like + "|" + str(df) to the return string
663
666
  # to pass debug information to the caller.
@@ -671,7 +674,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
671
674
  label_cols,
672
675
  )
673
676
 
674
- local_estimator_path = get_temp_file_path()
677
+ local_estimator_path = temp_file_utils.get_temp_file_path()
675
678
  session.file.get(
676
679
  posixpath.join(temp_stage_name, sproc_export_file_name),
677
680
  local_estimator_path,
@@ -680,7 +683,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
680
683
  with open(os.path.join(local_estimator_path, sproc_export_file_name), mode="r+b") as result_file_obj:
681
684
  fit_estimator = cp.load(result_file_obj)
682
685
 
683
- cleanup_temp_files([local_estimator_path])
686
+ temp_file_utils.cleanup_temp_files([local_estimator_path])
684
687
 
685
688
  return fit_estimator
686
689
 
@@ -716,7 +719,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
716
719
  imports = [f"@{row.name}" for row in session.sql(f"LIST @{temp_stage_name}/{dataset_file_name}").collect()]
717
720
 
718
721
  # Create a temp file and dump the estimator to that file.
719
- estimator_file_name = get_temp_file_path()
722
+ estimator_file_name = temp_file_utils.get_temp_file_path()
720
723
  params_to_evaluate = list(param_grid)
721
724
  CONSTANTS: Dict[str, Any] = dict()
722
725
  CONSTANTS["dataset_snowpark_cols"] = dataset.columns
@@ -757,6 +760,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
757
760
  )
758
761
  estimator_location = os.path.basename(estimator_file_name)
759
762
  imports.append(f"@{temp_stage_name}/{estimator_location}")
763
+ temp_file_utils.cleanup_temp_files([estimator_file_name])
760
764
  CONSTANTS["estimator_location"] = estimator_location
761
765
 
762
766
  search_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
@@ -823,7 +827,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
823
827
  if sample_weight_col:
824
828
  fit_params["sample_weight"] = df[sample_weight_col].squeeze()
825
829
 
826
- local_estimator_file_folder_name = get_temp_file_path()
830
+ local_estimator_file_folder_name = temp_file_utils.get_temp_file_path()
827
831
  session.file.get(stage_estimator_file_name, local_estimator_file_folder_name)
828
832
 
829
833
  local_estimator_file_path = os.path.join(
@@ -869,7 +873,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
869
873
 
870
874
  # (1) store the cross_validator's test indices only to save space
871
875
  cross_validator_indices = [test for _, test in build_cross_validator.split(X, y, None)]
872
- local_indices_file_name = get_temp_file_path()
876
+ local_indices_file_name = temp_file_utils.get_temp_file_path()
873
877
  with open(local_indices_file_name, mode="w+b") as local_indices_file_obj:
874
878
  cp.dump(cross_validator_indices, local_indices_file_obj)
875
879
 
@@ -884,7 +888,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
884
888
  imports.append(f"@{temp_stage_name}/{indices_location}")
885
889
 
886
890
  # (2) store the base estimator
887
- local_base_estimator_file_name = get_temp_file_path()
891
+ local_base_estimator_file_name = temp_file_utils.get_temp_file_path()
888
892
  with open(local_base_estimator_file_name, mode="w+b") as local_base_estimator_file_obj:
889
893
  cp.dump(base_estimator, local_base_estimator_file_obj)
890
894
  session.file.put(
@@ -897,7 +901,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
897
901
  imports.append(f"@{temp_stage_name}/{base_estimator_location}")
898
902
 
899
903
  # (3) store the fit_and_score_kwargs
900
- local_fit_and_score_kwargs_file_name = get_temp_file_path()
904
+ local_fit_and_score_kwargs_file_name = temp_file_utils.get_temp_file_path()
901
905
  with open(local_fit_and_score_kwargs_file_name, mode="w+b") as local_fit_and_score_kwargs_file_obj:
902
906
  cp.dump(fit_and_score_kwargs, local_fit_and_score_kwargs_file_obj)
903
907
  session.file.put(
@@ -918,7 +922,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
918
922
  CONSTANTS["fit_and_score_kwargs_location"] = fit_and_score_kwargs_location
919
923
 
920
924
  # (6) store the constants
921
- local_constant_file_name = get_temp_file_path(prefix="constant")
925
+ local_constant_file_name = temp_file_utils.get_temp_file_path(prefix="constant")
922
926
  with open(local_constant_file_name, mode="w+b") as local_indices_file_obj:
923
927
  cp.dump(CONSTANTS, local_indices_file_obj)
924
928
 
@@ -932,6 +936,17 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
932
936
  constant_location = os.path.basename(local_constant_file_name)
933
937
  imports.append(f"@{temp_stage_name}/{constant_location}")
934
938
 
939
+ temp_file_utils.cleanup_temp_files(
940
+ [
941
+ local_estimator_file_folder_name,
942
+ local_indices_file_name,
943
+ local_base_estimator_file_name,
944
+ local_base_estimator_file_name,
945
+ local_fit_and_score_kwargs_file_name,
946
+ local_constant_file_name,
947
+ ]
948
+ )
949
+
935
950
  cross_validator_indices_length = int(len(cross_validator_indices))
936
951
  parameter_grid_length = len(param_grid)
937
952
 
@@ -942,124 +957,144 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
942
957
 
943
958
  import tempfile
944
959
 
960
+ # delete is set to False to support Windows environment
945
961
  with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8") as f:
946
962
  udf_code = execute_template
947
963
  f.file.write(udf_code)
948
964
  f.file.flush()
949
965
 
950
- # Register the UDTF function from the file
951
- udtf_registration.register_from_file(
952
- file_path=f.name,
953
- handler_name="SearchCV",
954
- name=random_udtf_name,
955
- output_schema=StructType(
956
- [StructField("FIRST_IDX", IntegerType()), StructField("EACH_CV_RESULTS", StringType())]
957
- ),
958
- input_types=[IntegerType(), IntegerType(), IntegerType()],
959
- replace=True,
960
- imports=imports, # type: ignore[arg-type]
961
- is_permanent=False,
962
- packages=required_deps, # type: ignore[arg-type]
963
- statement_params=udtf_statement_params,
964
- )
965
-
966
- HP_TUNING = F.table_function(random_udtf_name)
967
-
968
- # param_indices is for the index for each parameter grid;
969
- # cv_indices is for the index for each cross_validator's fold;
970
- # param_cv_indices is for the index for the product of (len(param_indices) * len(cv_indices))
971
- cv_indices, param_indices = zip(
972
- *product(range(cross_validator_indices_length), range(parameter_grid_length))
973
- )
974
-
975
- indices_info_pandas = pd.DataFrame(
976
- {
977
- "IDX": [i // _NUM_CPUs for i in range(parameter_grid_length * cross_validator_indices_length)],
978
- "PARAM_IND": param_indices,
979
- "CV_IND": cv_indices,
980
- }
981
- )
982
-
983
- indices_info_sp = session.create_dataframe(indices_info_pandas)
984
- # execute udtf by querying HP_TUNING table
985
- HP_raw_results = indices_info_sp.select(
986
- (
987
- HP_TUNING(indices_info_sp["IDX"], indices_info_sp["PARAM_IND"], indices_info_sp["CV_IND"]).over(
988
- partition_by="IDX"
966
+ # Use catchall exception handling and a finally block to clean up the _UDTF_STAGE_NAME
967
+ try:
968
+ # Create one stage for data and for estimators.
969
+ # Because only permanent functions support _sf_node_singleton for now, therefore,
970
+ # UDTF creation would change to is_permanent=True, and manually drop the stage after UDTF is done
971
+ _stage_creation_query_udtf = f"CREATE OR REPLACE STAGE {_UDTF_STAGE_NAME};"
972
+ session.sql(_stage_creation_query_udtf).collect()
973
+
974
+ # Register the UDTF function from the file
975
+ udtf_registration.register_from_file(
976
+ file_path=f.name,
977
+ handler_name="SearchCV",
978
+ name=random_udtf_name,
979
+ output_schema=StructType(
980
+ [StructField("FIRST_IDX", IntegerType()), StructField("EACH_CV_RESULTS", StringType())]
981
+ ),
982
+ input_types=[IntegerType(), IntegerType(), IntegerType()],
983
+ replace=True,
984
+ imports=imports, # type: ignore[arg-type]
985
+ stage_location=_UDTF_STAGE_NAME,
986
+ is_permanent=True,
987
+ packages=required_deps, # type: ignore[arg-type]
988
+ statement_params=udtf_statement_params,
989
989
  )
990
- ),
991
- )
992
-
993
- first_test_score, cv_results_ = construct_cv_results_memory_efficient_version(
994
- estimator,
995
- n_splits,
996
- list(param_grid),
997
- HP_raw_results.select("EACH_CV_RESULTS").sort(F.col("FIRST_IDX")).collect(),
998
- cross_validator_indices_length,
999
- parameter_grid_length,
1000
- )
1001
-
1002
- estimator.cv_results_ = cv_results_
1003
- estimator.multimetric_ = isinstance(first_test_score, dict)
1004
990
 
1005
- # check refit_metric now for a callable scorer that is multimetric
1006
- if callable(estimator.scoring) and estimator.multimetric_:
1007
- estimator._check_refit_for_multimetric(first_test_score)
1008
- refit_metric = estimator.refit
991
+ HP_TUNING = F.table_function(random_udtf_name)
1009
992
 
1010
- # For multi-metric evaluation, store the best_index_, best_params_ and
1011
- # best_score_ iff refit is one of the scorer names
1012
- # In single metric evaluation, refit_metric is "score"
1013
- if estimator.refit or not estimator.multimetric_:
1014
- estimator.best_index_ = estimator._select_best_index(estimator.refit, refit_metric, cv_results_)
1015
- if not callable(estimator.refit):
1016
- # With a non-custom callable, we can select the best score
1017
- # based on the best index
1018
- estimator.best_score_ = cv_results_[f"mean_test_{refit_metric}"][estimator.best_index_]
1019
- estimator.best_params_ = cv_results_["params"][estimator.best_index_]
1020
-
1021
- if estimator.refit:
1022
- estimator.best_estimator_ = clone(base_estimator).set_params(
1023
- **clone(estimator.best_params_, safe=False)
1024
- )
993
+ # param_indices is for the index for each parameter grid;
994
+ # cv_indices is for the index for each cross_validator's fold;
995
+ # param_cv_indices is for the index for the product of (len(param_indices) * len(cv_indices))
996
+ cv_indices, param_indices = zip(
997
+ *product(range(cross_validator_indices_length), range(parameter_grid_length))
998
+ )
1025
999
 
1026
- # Let the sproc use all cores to refit.
1027
- estimator.n_jobs = estimator.n_jobs or -1
1000
+ indices_info_pandas = pd.DataFrame(
1001
+ {
1002
+ "IDX": [
1003
+ i // _NUM_CPUs for i in range(parameter_grid_length * cross_validator_indices_length)
1004
+ ],
1005
+ "PARAM_IND": param_indices,
1006
+ "CV_IND": cv_indices,
1007
+ }
1008
+ )
1028
1009
 
1029
- # process the input as args
1030
- argspec = inspect.getfullargspec(estimator.fit)
1031
- args = {"X": X}
1032
- if label_cols:
1033
- label_arg_name = "Y" if "Y" in argspec.args else "y"
1034
- args[label_arg_name] = y
1035
- if sample_weight_col is not None and "sample_weight" in argspec.args:
1036
- args["sample_weight"] = df[sample_weight_col].squeeze()
1037
- # estimator.refit = original_refit
1038
- refit_start_time = time.time()
1039
- estimator.best_estimator_.fit(**args)
1040
- refit_end_time = time.time()
1041
- estimator.refit_time_ = refit_end_time - refit_start_time
1010
+ indices_info_sp = session.create_dataframe(indices_info_pandas)
1011
+ # execute udtf by querying HP_TUNING table
1012
+ HP_raw_results = indices_info_sp.select(
1013
+ (
1014
+ HP_TUNING(
1015
+ indices_info_sp["IDX"], indices_info_sp["PARAM_IND"], indices_info_sp["CV_IND"]
1016
+ ).over(partition_by="IDX")
1017
+ ),
1018
+ )
1042
1019
 
1043
- if hasattr(estimator.best_estimator_, "feature_names_in_"):
1044
- estimator.feature_names_in_ = estimator.best_estimator_.feature_names_in_
1020
+ first_test_score, cv_results_ = construct_cv_results_memory_efficient_version(
1021
+ estimator,
1022
+ n_splits,
1023
+ list(param_grid),
1024
+ HP_raw_results.select("EACH_CV_RESULTS").sort(F.col("FIRST_IDX")).collect(),
1025
+ cross_validator_indices_length,
1026
+ parameter_grid_length,
1027
+ )
1045
1028
 
1046
- # Store the only scorer not as a dict for single metric evaluation
1047
- estimator.scorer_ = scorers
1048
- estimator.n_splits_ = n_splits
1029
+ estimator.cv_results_ = cv_results_
1030
+ estimator.multimetric_ = isinstance(first_test_score, dict)
1031
+
1032
+ # check refit_metric now for a callable scorer that is multimetric
1033
+ if callable(estimator.scoring) and estimator.multimetric_:
1034
+ estimator._check_refit_for_multimetric(first_test_score)
1035
+ refit_metric = estimator.refit
1036
+
1037
+ # For multi-metric evaluation, store the best_index_, best_params_ and
1038
+ # best_score_ iff refit is one of the scorer names
1039
+ # In single metric evaluation, refit_metric is "score"
1040
+ if estimator.refit or not estimator.multimetric_:
1041
+ estimator.best_index_ = estimator._select_best_index(estimator.refit, refit_metric, cv_results_)
1042
+ if not callable(estimator.refit):
1043
+ # With a non-custom callable, we can select the best score
1044
+ # based on the best index
1045
+ estimator.best_score_ = cv_results_[f"mean_test_{refit_metric}"][estimator.best_index_]
1046
+ estimator.best_params_ = cv_results_["params"][estimator.best_index_]
1047
+
1048
+ if estimator.refit:
1049
+ estimator.best_estimator_ = clone(base_estimator).set_params(
1050
+ **clone(estimator.best_params_, safe=False)
1051
+ )
1049
1052
 
1050
- local_result_file_name = get_temp_file_path()
1053
+ # Let the sproc use all cores to refit.
1054
+ estimator.n_jobs = estimator.n_jobs or -1
1055
+
1056
+ # process the input as args
1057
+ argspec = inspect.getfullargspec(estimator.fit)
1058
+ args = {"X": X}
1059
+ if label_cols:
1060
+ label_arg_name = "Y" if "Y" in argspec.args else "y"
1061
+ args[label_arg_name] = y
1062
+ if sample_weight_col is not None and "sample_weight" in argspec.args:
1063
+ args["sample_weight"] = df[sample_weight_col].squeeze()
1064
+ # estimator.refit = original_refit
1065
+ refit_start_time = time.time()
1066
+ estimator.best_estimator_.fit(**args)
1067
+ refit_end_time = time.time()
1068
+ estimator.refit_time_ = refit_end_time - refit_start_time
1069
+
1070
+ if hasattr(estimator.best_estimator_, "feature_names_in_"):
1071
+ estimator.feature_names_in_ = estimator.best_estimator_.feature_names_in_
1072
+
1073
+ # Store the only scorer not as a dict for single metric evaluation
1074
+ estimator.scorer_ = scorers
1075
+ estimator.n_splits_ = n_splits
1076
+
1077
+ local_result_file_name = temp_file_utils.get_temp_file_path()
1078
+
1079
+ with open(local_result_file_name, mode="w+b") as local_result_file_obj:
1080
+ cp.dump(estimator, local_result_file_obj)
1081
+
1082
+ session.file.put(
1083
+ local_result_file_name,
1084
+ temp_stage_name,
1085
+ auto_compress=False,
1086
+ overwrite=True,
1087
+ )
1051
1088
 
1052
- with open(local_result_file_name, mode="w+b") as local_result_file_obj:
1053
- cp.dump(estimator, local_result_file_obj)
1089
+ # Clean up the stages and files
1090
+ session.sql(f"DROP STAGE IF EXISTS {_UDTF_STAGE_NAME}")
1054
1091
 
1055
- session.file.put(
1056
- local_result_file_name,
1057
- temp_stage_name,
1058
- auto_compress=False,
1059
- overwrite=True,
1060
- )
1092
+ temp_file_utils.cleanup_temp_files([local_result_file_name])
1061
1093
 
1062
- return str(os.path.basename(local_result_file_name))
1094
+ return str(os.path.basename(local_result_file_name))
1095
+ finally:
1096
+ # Clean up the stages
1097
+ session.sql(f"DROP STAGE IF EXISTS {_UDTF_STAGE_NAME}")
1063
1098
 
1064
1099
  sproc_export_file_name = _distributed_search(
1065
1100
  session,
@@ -1069,7 +1104,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
1069
1104
  label_cols,
1070
1105
  )
1071
1106
 
1072
- local_estimator_path = get_temp_file_path()
1107
+ local_estimator_path = temp_file_utils.get_temp_file_path()
1073
1108
  session.file.get(
1074
1109
  posixpath.join(temp_stage_name, sproc_export_file_name),
1075
1110
  local_estimator_path,
@@ -1078,7 +1113,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
1078
1113
  with open(os.path.join(local_estimator_path, sproc_export_file_name), mode="r+b") as result_file_obj:
1079
1114
  fit_estimator = cp.load(result_file_obj)
1080
1115
 
1081
- cleanup_temp_files(local_estimator_path)
1116
+ temp_file_utils.cleanup_temp_files(local_estimator_path)
1082
1117
 
1083
1118
  return fit_estimator
1084
1119
 
@@ -156,4 +156,6 @@ class SearchCV:
156
156
  self.fit_score_params[0][0],
157
157
  binary_cv_results,
158
158
  )
159
+
160
+ SearchCV._sf_node_singleton = True
159
161
  """