snowflake-ml-python 1.5.1__py3-none-any.whl → 1.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (207)
  1. snowflake/cortex/_complete.py +26 -5
  2. snowflake/cortex/_sentiment.py +7 -4
  3. snowflake/cortex/_sse_client.py +81 -0
  4. snowflake/cortex/_util.py +105 -8
  5. snowflake/ml/_internal/lineage/lineage_utils.py +34 -25
  6. snowflake/ml/_internal/utils/temp_file_utils.py +5 -2
  7. snowflake/ml/dataset/dataset.py +15 -12
  8. snowflake/ml/dataset/dataset_factory.py +3 -4
  9. snowflake/ml/feature_store/access_manager.py +34 -30
  10. snowflake/ml/feature_store/feature_store.py +3 -3
  11. snowflake/ml/feature_store/feature_view.py +12 -11
  12. snowflake/ml/fileset/snowfs.py +2 -31
  13. snowflake/ml/model/_client/ops/model_ops.py +43 -0
  14. snowflake/ml/model/_client/sql/model_version.py +55 -3
  15. snowflake/ml/model/_model_composer/model_composer.py +7 -3
  16. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -1
  17. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  18. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -3
  19. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  20. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -27
  21. snowflake/ml/model/_signatures/builtins_handler.py +2 -1
  22. snowflake/ml/model/_signatures/core.py +13 -1
  23. snowflake/ml/model/_signatures/pandas_handler.py +2 -0
  24. snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
  25. snowflake/ml/model/model_signature.py +2 -0
  26. snowflake/ml/model/type_hints.py +1 -0
  27. snowflake/ml/modeling/_internal/estimator_utils.py +58 -1
  28. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +196 -242
  29. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +161 -0
  30. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +38 -18
  31. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +82 -134
  32. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +21 -17
  33. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -2
  34. snowflake/ml/modeling/cluster/affinity_propagation.py +9 -2
  35. snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -2
  36. snowflake/ml/modeling/cluster/birch.py +9 -2
  37. snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -2
  38. snowflake/ml/modeling/cluster/dbscan.py +9 -2
  39. snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -2
  40. snowflake/ml/modeling/cluster/k_means.py +9 -2
  41. snowflake/ml/modeling/cluster/mean_shift.py +9 -2
  42. snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -2
  43. snowflake/ml/modeling/cluster/optics.py +9 -2
  44. snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -2
  45. snowflake/ml/modeling/cluster/spectral_clustering.py +9 -2
  46. snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -2
  47. snowflake/ml/modeling/compose/column_transformer.py +9 -2
  48. snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -2
  49. snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -2
  50. snowflake/ml/modeling/covariance/empirical_covariance.py +9 -2
  51. snowflake/ml/modeling/covariance/graphical_lasso.py +9 -2
  52. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -2
  53. snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -2
  54. snowflake/ml/modeling/covariance/min_cov_det.py +9 -2
  55. snowflake/ml/modeling/covariance/oas.py +9 -2
  56. snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -2
  57. snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -2
  58. snowflake/ml/modeling/decomposition/factor_analysis.py +9 -2
  59. snowflake/ml/modeling/decomposition/fast_ica.py +9 -2
  60. snowflake/ml/modeling/decomposition/incremental_pca.py +9 -2
  61. snowflake/ml/modeling/decomposition/kernel_pca.py +9 -2
  62. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -2
  63. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -2
  64. snowflake/ml/modeling/decomposition/pca.py +9 -2
  65. snowflake/ml/modeling/decomposition/sparse_pca.py +9 -2
  66. snowflake/ml/modeling/decomposition/truncated_svd.py +9 -2
  67. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -2
  68. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -2
  69. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -2
  70. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -2
  71. snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -2
  72. snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -2
  73. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -2
  74. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -2
  75. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -2
  76. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -2
  77. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -2
  78. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -2
  79. snowflake/ml/modeling/ensemble/isolation_forest.py +9 -2
  80. snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -2
  81. snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -2
  82. snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -2
  83. snowflake/ml/modeling/ensemble/voting_classifier.py +9 -2
  84. snowflake/ml/modeling/ensemble/voting_regressor.py +9 -2
  85. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -2
  86. snowflake/ml/modeling/feature_selection/select_fdr.py +9 -2
  87. snowflake/ml/modeling/feature_selection/select_fpr.py +9 -2
  88. snowflake/ml/modeling/feature_selection/select_fwe.py +9 -2
  89. snowflake/ml/modeling/feature_selection/select_k_best.py +9 -2
  90. snowflake/ml/modeling/feature_selection/select_percentile.py +9 -2
  91. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -2
  92. snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -2
  93. snowflake/ml/modeling/framework/base.py +3 -8
  94. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -2
  95. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -2
  96. snowflake/ml/modeling/impute/iterative_imputer.py +9 -2
  97. snowflake/ml/modeling/impute/knn_imputer.py +9 -2
  98. snowflake/ml/modeling/impute/missing_indicator.py +9 -2
  99. snowflake/ml/modeling/impute/simple_imputer.py +28 -5
  100. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -2
  101. snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -2
  102. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -2
  103. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -2
  104. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -2
  105. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -2
  106. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -2
  107. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -2
  108. snowflake/ml/modeling/linear_model/ard_regression.py +9 -2
  109. snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -2
  110. snowflake/ml/modeling/linear_model/elastic_net.py +9 -2
  111. snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -2
  112. snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -2
  113. snowflake/ml/modeling/linear_model/huber_regressor.py +9 -2
  114. snowflake/ml/modeling/linear_model/lars.py +9 -2
  115. snowflake/ml/modeling/linear_model/lars_cv.py +9 -2
  116. snowflake/ml/modeling/linear_model/lasso.py +9 -2
  117. snowflake/ml/modeling/linear_model/lasso_cv.py +9 -2
  118. snowflake/ml/modeling/linear_model/lasso_lars.py +9 -2
  119. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -2
  120. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -2
  121. snowflake/ml/modeling/linear_model/linear_regression.py +9 -2
  122. snowflake/ml/modeling/linear_model/logistic_regression.py +9 -2
  123. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -2
  124. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -2
  125. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -2
  126. snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -2
  127. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -2
  128. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -2
  129. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -2
  130. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -2
  131. snowflake/ml/modeling/linear_model/perceptron.py +9 -2
  132. snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -2
  133. snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -2
  134. snowflake/ml/modeling/linear_model/ridge.py +9 -2
  135. snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -2
  136. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -2
  137. snowflake/ml/modeling/linear_model/ridge_cv.py +9 -2
  138. snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -2
  139. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -2
  140. snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -2
  141. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -2
  142. snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -2
  143. snowflake/ml/modeling/manifold/isomap.py +9 -2
  144. snowflake/ml/modeling/manifold/mds.py +9 -2
  145. snowflake/ml/modeling/manifold/spectral_embedding.py +9 -2
  146. snowflake/ml/modeling/manifold/tsne.py +9 -2
  147. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -2
  148. snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -2
  149. snowflake/ml/modeling/model_selection/grid_search_cv.py +1 -5
  150. snowflake/ml/modeling/model_selection/randomized_search_cv.py +1 -5
  151. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -2
  152. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -2
  153. snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -2
  154. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -2
  155. snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -2
  156. snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -2
  157. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -2
  158. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -2
  159. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -2
  160. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -2
  161. snowflake/ml/modeling/neighbors/kernel_density.py +9 -2
  162. snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -2
  163. snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -2
  164. snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -2
  165. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -2
  166. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -2
  167. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -2
  168. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -2
  169. snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -2
  170. snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -2
  171. snowflake/ml/modeling/parameters/enable_anonymous_sproc.py +5 -0
  172. snowflake/ml/modeling/pipeline/pipeline.py +5 -0
  173. snowflake/ml/modeling/preprocessing/binarizer.py +7 -3
  174. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +7 -2
  175. snowflake/ml/modeling/preprocessing/label_encoder.py +8 -7
  176. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +7 -3
  177. snowflake/ml/modeling/preprocessing/min_max_scaler.py +7 -4
  178. snowflake/ml/modeling/preprocessing/normalizer.py +7 -3
  179. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +10 -2
  180. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +8 -5
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -2
  182. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -4
  183. snowflake/ml/modeling/preprocessing/standard_scaler.py +7 -3
  184. snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -2
  185. snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -2
  186. snowflake/ml/modeling/svm/linear_svc.py +9 -2
  187. snowflake/ml/modeling/svm/linear_svr.py +9 -2
  188. snowflake/ml/modeling/svm/nu_svc.py +9 -2
  189. snowflake/ml/modeling/svm/nu_svr.py +9 -2
  190. snowflake/ml/modeling/svm/svc.py +9 -2
  191. snowflake/ml/modeling/svm/svr.py +9 -2
  192. snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -2
  193. snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -2
  194. snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -2
  195. snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -2
  196. snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -2
  197. snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -2
  198. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -2
  199. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -2
  200. snowflake/ml/registry/_manager/model_manager.py +59 -1
  201. snowflake/ml/registry/registry.py +10 -1
  202. snowflake/ml/version.py +1 -1
  203. {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/METADATA +32 -4
  204. {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/RECORD +207 -204
  205. {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/LICENSE.txt +0 -0
  206. {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/WHEEL +0 -0
  207. {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py

@@ -4,11 +4,11 @@ import io
 import os
 import posixpath
 import sys
-from typing import Any, Dict, List, Optional, Set, Tuple, Union
+import uuid
+from typing import Any, Dict, List, Optional, Tuple, Union

 import cloudpickle as cp
 import numpy as np
-import numpy.typing as npt
 from sklearn import model_selection
 from sklearn.model_selection import GridSearchCV, RandomizedSearchCV

@@ -17,10 +17,7 @@ from snowflake.ml._internal.utils import (
     identifier,
     pkg_version_utils,
     snowpark_dataframe_utils,
-)
-from snowflake.ml._internal.utils.temp_file_utils import (
-    cleanup_temp_files,
-    get_temp_file_path,
+    temp_file_utils,
 )
 from snowflake.ml.modeling._internal.model_specifications import (
     ModelSpecificationsBuilder,
@@ -36,14 +33,16 @@ from snowflake.snowpark._internal.utils import (
 from snowflake.snowpark.functions import sproc, udtf
 from snowflake.snowpark.row import Row
 from snowflake.snowpark.types import IntegerType, StringType, StructField, StructType
+from snowflake.snowpark.udtf import UDTFRegistration

-cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
+cp.register_pickle_by_value(inspect.getmodule(temp_file_utils.get_temp_file_path))
 cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
 cp.register_pickle_by_value(inspect.getmodule(snowpark_dataframe_utils.cast_snowpark_dataframe))

 _PROJECT = "ModelDevelopment"
 DEFAULT_UDTF_NJOBS = 3
 ENABLE_EFFICIENT_MEMORY_USAGE = False
+_UDTF_STAGE_NAME = f"MEMORY_EFFICIENT_UDTF_{str(uuid.uuid4()).replace('-', '_')}"


 def construct_cv_results(
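Two details in this hunk set up the memory-efficient path that follows. First, importing temp_file_utils as a module rather than as bare functions keeps the cp.register_pickle_by_value(inspect.getmodule(...)) calls working: cloudpickle registers whole modules by value, so the helper module travels with the pickled payload into the sproc and UDTF. Second, _UDTF_STAGE_NAME gives each run its own stage. A minimal runnable sketch of the naming scheme; the note about identifier characters is our reading, not stated in the diff:

import uuid

# Sketch: one unique stage name per run, so concurrent HPO jobs do not collide
# on the same stage. Hyphens become underscores so the name remains a valid
# unquoted Snowflake identifier.
stage_name = f"MEMORY_EFFICIENT_UDTF_{str(uuid.uuid4()).replace('-', '_')}"
print(stage_name)  # e.g. MEMORY_EFFICIENT_UDTF_1b9d6bcd_bbfd_4b2d_9b5d_ab8dfbbd4bed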
@@ -318,7 +317,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         original_refit = estimator.refit

         # Create a temp file and dump the estimator to that file.
-        estimator_file_name = get_temp_file_path()
+        estimator_file_name = temp_file_utils.get_temp_file_path()
         params_to_evaluate = []
         for param_to_eval in list(param_grid):
             for k, v in param_to_eval.items():
@@ -357,6 +356,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         )
         estimator_location = put_result[0].target
         imports.append(f"@{temp_stage_name}/{estimator_location}")
+        temp_file_utils.cleanup_temp_files([estimator_file_name])

         search_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
         random_udtf_name = random_name_for_temp_object(TempObjectType.TABLE_FUNCTION)
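This hunk and the similar ones below follow one lifecycle: reserve a local temp file, dump it with cloudpickle, upload it to the stage with session.file.put, then delete the local copy right away instead of waiting for process exit. A self-contained sketch of that lifecycle, with stand-ins for the temp_file_utils helpers (whose real implementations live in snowflake/ml/_internal/utils/temp_file_utils.py, also touched in this release):

import os
import tempfile

import cloudpickle as cp


def get_temp_file_path(prefix: str = "") -> str:
    # Stand-in for temp_file_utils.get_temp_file_path: reserve a unique local
    # path and close the handle so the caller can reopen it freely.
    fd, path = tempfile.mkstemp(prefix=prefix)
    os.close(fd)
    return path


def cleanup_temp_files(file_paths) -> None:
    # Stand-in for temp_file_utils.cleanup_temp_files.
    for path in file_paths:
        if os.path.exists(path):
            os.remove(path)


estimator_file_name = get_temp_file_path()
with open(estimator_file_name, mode="w+b") as f:
    cp.dump({"estimator": "placeholder"}, f)
# ... in the trainer, session.file.put(estimator_file_name, temp_stage_name) runs here ...
cleanup_temp_files([estimator_file_name])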
@@ -413,7 +413,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             X = df[input_cols]
             y = df[label_cols].squeeze() if label_cols else None

-            local_estimator_file_name = get_temp_file_path()
+            local_estimator_file_name = temp_file_utils.get_temp_file_path()
             session.file.get(stage_estimator_file_name, local_estimator_file_name)

             local_estimator_file_path = os.path.join(
@@ -429,7 +429,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             n_splits = build_cross_validator.get_n_splits(X, y, None)
             # store the cross_validator's test indices only to save space
             cross_validator_indices = [test for _, test in build_cross_validator.split(X, y, None)]
-            local_indices_file_name = get_temp_file_path()
+            local_indices_file_name = temp_file_utils.get_temp_file_path()
             with open(local_indices_file_name, mode="w+b") as local_indices_file_obj:
                 cp.dump(cross_validator_indices, local_indices_file_obj)

@@ -445,6 +445,8 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             cross_validator_indices_length = int(len(cross_validator_indices))
             parameter_grid_length = len(param_grid)

+            temp_file_utils.cleanup_temp_files([local_estimator_file_name, local_indices_file_name])
+
             assert estimator is not None

             @cachetools.cached(cache={})
@@ -647,7 +649,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             if hasattr(estimator.best_estimator_, "feature_names_in_"):
                 estimator.feature_names_in_ = estimator.best_estimator_.feature_names_in_

-            local_result_file_name = get_temp_file_path()
+            local_result_file_name = temp_file_utils.get_temp_file_path()

             with open(local_result_file_name, mode="w+b") as local_result_file_obj:
                 cp.dump(estimator, local_result_file_obj)
@@ -658,6 +660,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
                 auto_compress=False,
                 overwrite=True,
             )
+            temp_file_utils.cleanup_temp_files([local_result_file_name])

             # Note: you can add something like + "|" + str(df) to the return string
             # to pass debug information to the caller.
@@ -671,7 +674,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             label_cols,
         )

-        local_estimator_path = get_temp_file_path()
+        local_estimator_path = temp_file_utils.get_temp_file_path()
        session.file.get(
             posixpath.join(temp_stage_name, sproc_export_file_name),
             local_estimator_path,
@@ -680,7 +683,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         with open(os.path.join(local_estimator_path, sproc_export_file_name), mode="r+b") as result_file_obj:
             fit_estimator = cp.load(result_file_obj)

-        cleanup_temp_files([local_estimator_path])
+        temp_file_utils.cleanup_temp_files([local_estimator_path])

         return fit_estimator

@@ -698,7 +701,6 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
     ) -> Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV]:
         from itertools import product

-        import cachetools
         from sklearn.base import clone, is_classifier
         from sklearn.calibration import check_cv

@@ -717,11 +719,13 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         imports = [f"@{row.name}" for row in session.sql(f"LIST @{temp_stage_name}/{dataset_file_name}").collect()]

         # Create a temp file and dump the estimator to that file.
-        estimator_file_name = get_temp_file_path()
+        estimator_file_name = temp_file_utils.get_temp_file_path()
         params_to_evaluate = list(param_grid)
-        n_candidates = len(params_to_evaluate)
-        _N_JOBS = estimator.n_jobs
-        _PRE_DISPATCH = estimator.pre_dispatch
+        CONSTANTS: Dict[str, Any] = dict()
+        CONSTANTS["dataset_snowpark_cols"] = dataset.columns
+        CONSTANTS["n_candidates"] = len(params_to_evaluate)
+        CONSTANTS["_N_JOBS"] = estimator.n_jobs
+        CONSTANTS["_PRE_DISPATCH"] = estimator.pre_dispatch

         with open(estimator_file_name, mode="w+b") as local_estimator_file_obj:
             cp.dump(dict(estimator=estimator, param_grid=params_to_evaluate), local_estimator_file_obj)
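The scalars that the old inline UDTF captured from its enclosing closure (n_candidates, _N_JOBS, _PRE_DISPATCH) now go into a single CONSTANTS dict; later hunks pickle it, upload it to the stage, and record it as an import so the file-based handler in distributed_search_udf_file.py can reload every value by key. A sketch of that round trip with illustrative values, not ones from the diff:

from typing import Any, Dict

import cloudpickle as cp

# A closure variable becomes a dict entry that survives serialization to a
# stage file and deserialization inside the UDTF handler.
CONSTANTS: Dict[str, Any] = {
    "n_candidates": 12,
    "_N_JOBS": -1,
    "_PRE_DISPATCH": "2*n_jobs",
}

blob = cp.dumps(CONSTANTS)  # what cp.dump(CONSTANTS, file_obj) writes to disk
restored = cp.loads(blob)   # what the handler reads back from its imports
assert restored["_N_JOBS"] == -1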
@@ -743,6 +747,9 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             api_calls=[udtf],
             custom_tags=dict([("hpo_memory_efficient", True)]),
         )
+        from snowflake.ml.modeling._internal.snowpark_implementations.distributed_search_udf_file import (
+            execute_template,
+        )

         # Put locally serialized estimator on stage.
         session.file.put(
@@ -753,6 +760,8 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         )
         estimator_location = os.path.basename(estimator_file_name)
         imports.append(f"@{temp_stage_name}/{estimator_location}")
+        temp_file_utils.cleanup_temp_files([estimator_file_name])
+        CONSTANTS["estimator_location"] = estimator_location

         search_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
         random_udtf_name = random_name_for_temp_object(TempObjectType.FUNCTION)
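Recording estimator_location in CONSTANTS matters because the handler no longer shares the trainer's namespace: it has to locate the pickled estimator among its staged imports. Inside a Snowflake UDTF those imports are materialized under the directory named by sys._xoptions["snowflake_import_directory"], the same mechanism the removed _load_data_into_udf code in the large hunk below relies on. A hedged sketch of the lookup; the fallback and file name are ours, for running outside Snowflake:

import os
import sys

# Inside a Snowflake UDTF, files imported from a stage appear under this
# directory; outside Snowflake we fall back to a local path so the sketch runs.
import_dir = sys._xoptions.get("snowflake_import_directory", "/tmp")
estimator_location = "estimator.pkl"  # hypothetical; the real name comes from CONSTANTS
estimator_path = os.path.join(import_dir, estimator_location)
print(estimator_path)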
@@ -783,7 +792,6 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         ) -> str:
             import os
             import time
-            from typing import Iterator

             import cloudpickle as cp
             import pandas as pd
@@ -819,7 +827,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             if sample_weight_col:
                 fit_params["sample_weight"] = df[sample_weight_col].squeeze()

-            local_estimator_file_folder_name = get_temp_file_path()
+            local_estimator_file_folder_name = temp_file_utils.get_temp_file_path()
             session.file.get(stage_estimator_file_name, local_estimator_file_folder_name)

             local_estimator_file_path = os.path.join(
@@ -865,7 +873,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):

             # (1) store the cross_validator's test indices only to save space
             cross_validator_indices = [test for _, test in build_cross_validator.split(X, y, None)]
-            local_indices_file_name = get_temp_file_path()
+            local_indices_file_name = temp_file_utils.get_temp_file_path()
             with open(local_indices_file_name, mode="w+b") as local_indices_file_obj:
                 cp.dump(cross_validator_indices, local_indices_file_obj)

@@ -880,7 +888,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             imports.append(f"@{temp_stage_name}/{indices_location}")

             # (2) store the base estimator
-            local_base_estimator_file_name = get_temp_file_path()
+            local_base_estimator_file_name = temp_file_utils.get_temp_file_path()
             with open(local_base_estimator_file_name, mode="w+b") as local_base_estimator_file_obj:
                 cp.dump(base_estimator, local_base_estimator_file_obj)
             session.file.put(
@@ -893,7 +901,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             imports.append(f"@{temp_stage_name}/{base_estimator_location}")

             # (3) store the fit_and_score_kwargs
-            local_fit_and_score_kwargs_file_name = get_temp_file_path()
+            local_fit_and_score_kwargs_file_name = temp_file_utils.get_temp_file_path()
             with open(local_fit_and_score_kwargs_file_name, mode="w+b") as local_fit_and_score_kwargs_file_obj:
                 cp.dump(fit_and_score_kwargs, local_fit_and_score_kwargs_file_obj)
             session.file.put(
@@ -905,242 +913,188 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             fit_and_score_kwargs_location = os.path.basename(local_fit_and_score_kwargs_file_name)
             imports.append(f"@{temp_stage_name}/{fit_and_score_kwargs_location}")

-            cross_validator_indices_length = int(len(cross_validator_indices))
-            parameter_grid_length = len(param_grid)
-
-            assert estimator is not None
+            CONSTANTS["input_cols"] = input_cols
+            CONSTANTS["label_cols"] = label_cols
+            CONSTANTS["DATA_LENGTH"] = DATA_LENGTH
+            CONSTANTS["n_splits"] = n_splits
+            CONSTANTS["indices_location"] = indices_location
+            CONSTANTS["base_estimator_location"] = base_estimator_location
+            CONSTANTS["fit_and_score_kwargs_location"] = fit_and_score_kwargs_location

-            @cachetools.cached(cache={})
-            def _load_data_into_udf() -> Tuple[
-                npt.NDArray[Any],
-                npt.NDArray[Any],
-                List[List[int]],
-                List[Dict[str, Any]],
-                object,
-                Dict[str, Any],
-            ]:
-                import pyarrow.parquet as pq
+            # (6) store the constants
+            local_constant_file_name = temp_file_utils.get_temp_file_path(prefix="constant")
+            with open(local_constant_file_name, mode="w+b") as local_indices_file_obj:
+                cp.dump(CONSTANTS, local_indices_file_obj)

-                data_files = [
-                    filename
-                    for filename in os.listdir(sys._xoptions["snowflake_import_directory"])
-                    if filename.startswith(dataset_file_name)
-                ]
-                partial_df = [
-                    pq.read_table(os.path.join(sys._xoptions["snowflake_import_directory"], file_name)).to_pandas()
-                    for file_name in data_files
+            # Put locally serialized indices on stage.
+            session.file.put(
+                local_constant_file_name,
+                temp_stage_name,
+                auto_compress=False,
+                overwrite=True,
+            )
+            constant_location = os.path.basename(local_constant_file_name)
+            imports.append(f"@{temp_stage_name}/{constant_location}")
+
+            temp_file_utils.cleanup_temp_files(
+                [
+                    local_estimator_file_folder_name,
+                    local_indices_file_name,
+                    local_base_estimator_file_name,
+                    local_base_estimator_file_name,
+                    local_fit_and_score_kwargs_file_name,
+                    local_constant_file_name,
                 ]
-                df = pd.concat(partial_df, ignore_index=True)
-                df.columns = [identifier.get_inferred_name(col_) for col_ in df.columns]
+            )

-                # load parameter grid
-                local_estimator_file_path = os.path.join(
-                    sys._xoptions["snowflake_import_directory"], f"{estimator_location}"
-                )
-                with open(local_estimator_file_path, mode="rb") as local_estimator_file_obj:
-                    estimator_objects = cp.load(local_estimator_file_obj)
-                params_to_evaluate = estimator_objects["param_grid"]
+            cross_validator_indices_length = int(len(cross_validator_indices))
+            parameter_grid_length = len(param_grid)

-                # load indices
-                local_indices_file_path = os.path.join(
-                    sys._xoptions["snowflake_import_directory"], f"{indices_location}"
-                )
-                with open(local_indices_file_path, mode="rb") as local_indices_file_obj:
-                    indices = cp.load(local_indices_file_obj)
+            assert estimator is not None

-                # load base estimator
-                local_base_estimator_file_path = os.path.join(
-                    sys._xoptions["snowflake_import_directory"], f"{base_estimator_location}"
-                )
-                with open(local_base_estimator_file_path, mode="rb") as local_base_estimator_file_obj:
-                    base_estimator = cp.load(local_base_estimator_file_obj)
+            # Instantiate UDTFRegistration with the session object
+            udtf_registration = UDTFRegistration(session)
+
+            import tempfile
+
+            # delete is set to False to support Windows environment
+            with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8") as f:
+                udf_code = execute_template
+                f.file.write(udf_code)
+                f.file.flush()
+
+            # Use catchall exception handling and a finally block to clean up the _UDTF_STAGE_NAME
+            try:
+                # Create one stage for data and for estimators.
+                # Because only permanent functions support _sf_node_singleton for now, therefore,
+                # UDTF creation would change to is_permanent=True, and manually drop the stage after UDTF is done
+                _stage_creation_query_udtf = f"CREATE OR REPLACE STAGE {_UDTF_STAGE_NAME};"
+                session.sql(_stage_creation_query_udtf).collect()
+
+                # Register the UDTF function from the file
+                udtf_registration.register_from_file(
+                    file_path=f.name,
+                    handler_name="SearchCV",
+                    name=random_udtf_name,
+                    output_schema=StructType(
+                        [StructField("FIRST_IDX", IntegerType()), StructField("EACH_CV_RESULTS", StringType())]
+                    ),
+                    input_types=[IntegerType(), IntegerType(), IntegerType()],
+                    replace=True,
+                    imports=imports,  # type: ignore[arg-type]
+                    stage_location=_UDTF_STAGE_NAME,
+                    is_permanent=True,
+                    packages=required_deps,  # type: ignore[arg-type]
+                    statement_params=udtf_statement_params,
+                )

-                # load fit_and_score_kwargs
-                local_fit_and_score_kwargs_file_path = os.path.join(
-                    sys._xoptions["snowflake_import_directory"], f"{fit_and_score_kwargs_location}"
-                )
-                with open(local_fit_and_score_kwargs_file_path, mode="rb") as local_fit_and_score_kwargs_file_obj:
-                    fit_and_score_kwargs = cp.load(local_fit_and_score_kwargs_file_obj)
-
-                # convert dataframe to numpy would save memory consumption
-                return (
-                    df[input_cols].to_numpy(),
-                    df[label_cols].squeeze().to_numpy(),
-                    indices,
-                    params_to_evaluate,
-                    base_estimator,
-                    fit_and_score_kwargs,
-                )
+                HP_TUNING = F.table_function(random_udtf_name)

-            # Note Table functions (UDTFs) have a limit of 500 input arguments and 500 output columns.
-            class SearchCV:
-                def __init__(self) -> None:
-                    X, y, indices, params_to_evaluate, base_estimator, fit_and_score_kwargs = _load_data_into_udf()
-                    self.X = X
-                    self.y = y
-                    self.test_indices = indices
-                    self.params_to_evaluate = params_to_evaluate
-                    self.base_estimator = base_estimator
-                    self.fit_and_score_kwargs = fit_and_score_kwargs
-                    self.fit_score_params: List[Any] = []
-                    self.cv_indices_set: Set[int] = set()
-
-                def process(self, idx: int, params_idx: int, cv_idx: int) -> None:
-                    self.fit_score_params.extend([[idx, params_idx, cv_idx]])
-                    self.cv_indices_set.add(cv_idx)
-
-                def end_partition(self) -> Iterator[Tuple[int, str]]:
-                    from sklearn.base import clone
-                    from sklearn.model_selection._validation import _fit_and_score
-                    from sklearn.utils.parallel import Parallel, delayed
-
-                    cached_train_test_indices = {}
-                    # Calculate the full index here to avoid duplicate calculation (which consumes a lot of memory)
-                    full_index = np.arange(DATA_LENGTH)
-                    for i in self.cv_indices_set:
-                        cached_train_test_indices[i] = [
-                            np.setdiff1d(full_index, self.test_indices[i]),
-                            self.test_indices[i],
-                        ]
-
-                    parallel = Parallel(n_jobs=_N_JOBS, pre_dispatch=_PRE_DISPATCH)
-
-                    out = parallel(
-                        delayed(_fit_and_score)(
-                            clone(self.base_estimator),
-                            self.X,
-                            self.y,
-                            train=cached_train_test_indices[split_idx][0],
-                            test=cached_train_test_indices[split_idx][1],
-                            parameters=self.params_to_evaluate[cand_idx],
-                            split_progress=(split_idx, n_splits),
-                            candidate_progress=(cand_idx, n_candidates),
-                            **self.fit_and_score_kwargs,  # load sample weight here
-                        )
-                        for _, cand_idx, split_idx in self.fit_score_params
+                # param_indices is for the index for each parameter grid;
+                # cv_indices is for the index for each cross_validator's fold;
+                # param_cv_indices is for the index for the product of (len(param_indices) * len(cv_indices))
+                cv_indices, param_indices = zip(
+                    *product(range(cross_validator_indices_length), range(parameter_grid_length))
                )

-                    binary_cv_results = None
-                    with io.BytesIO() as f:
-                        cp.dump(out, f)
-                        f.seek(0)
-                        binary_cv_results = f.getvalue().hex()
-                    yield (
-                        self.fit_score_params[0][0],
-                        binary_cv_results,
+                indices_info_pandas = pd.DataFrame(
+                    {
+                        "IDX": [
+                            i // _NUM_CPUs for i in range(parameter_grid_length * cross_validator_indices_length)
+                        ],
+                        "PARAM_IND": param_indices,
+                        "CV_IND": cv_indices,
+                    }
                )

-            session.udtf.register(
-                SearchCV,
-                output_schema=StructType(
-                    [StructField("FIRST_IDX", IntegerType()), StructField("EACH_CV_RESULTS", StringType())]
-                ),
-                input_types=[IntegerType(), IntegerType(), IntegerType()],
-                name=random_udtf_name,
-                packages=required_deps,  # type: ignore[arg-type]
-                replace=True,
-                is_permanent=False,
-                imports=imports,  # type: ignore[arg-type]
-                statement_params=udtf_statement_params,
-            )
-
-            HP_TUNING = F.table_function(random_udtf_name)
-
-            # param_indices is for the index for each parameter grid;
-            # cv_indices is for the index for each cross_validator's fold;
-            # param_cv_indices is for the index for the product of (len(param_indices) * len(cv_indices))
-            cv_indices, param_indices = zip(
-                *product(range(cross_validator_indices_length), range(parameter_grid_length))
-            )
-
-            indices_info_pandas = pd.DataFrame(
-                {
-                    "IDX": [i // _NUM_CPUs for i in range(parameter_grid_length * cross_validator_indices_length)],
-                    "PARAM_IND": param_indices,
-                    "CV_IND": cv_indices,
-                }
-            )
-
-            indices_info_sp = session.create_dataframe(indices_info_pandas)
-            # execute udtf by querying HP_TUNING table
-            HP_raw_results = indices_info_sp.select(
-                (
-                    HP_TUNING(indices_info_sp["IDX"], indices_info_sp["PARAM_IND"], indices_info_sp["CV_IND"]).over(
-                        partition_by="IDX"
+                indices_info_sp = session.create_dataframe(indices_info_pandas)
+                # execute udtf by querying HP_TUNING table
+                HP_raw_results = indices_info_sp.select(
+                    (
+                        HP_TUNING(
+                            indices_info_sp["IDX"], indices_info_sp["PARAM_IND"], indices_info_sp["CV_IND"]
+                        ).over(partition_by="IDX")
+                    ),
                )
-                ),
-            )
-
-            first_test_score, cv_results_ = construct_cv_results_memory_efficient_version(
-                estimator,
-                n_splits,
-                list(param_grid),
-                HP_raw_results.select("EACH_CV_RESULTS").sort(F.col("FIRST_IDX")).collect(),
-                cross_validator_indices_length,
-                parameter_grid_length,
-            )
-
-            estimator.cv_results_ = cv_results_
-            estimator.multimetric_ = isinstance(first_test_score, dict)
-
-            # check refit_metric now for a callable scorer that is multimetric
-            if callable(estimator.scoring) and estimator.multimetric_:
-                estimator._check_refit_for_multimetric(first_test_score)
-                refit_metric = estimator.refit
-
-            # For multi-metric evaluation, store the best_index_, best_params_ and
-            # best_score_ iff refit is one of the scorer names
-            # In single metric evaluation, refit_metric is "score"
-            if estimator.refit or not estimator.multimetric_:
-                estimator.best_index_ = estimator._select_best_index(estimator.refit, refit_metric, cv_results_)
-                if not callable(estimator.refit):
-                    # With a non-custom callable, we can select the best score
-                    # based on the best index
-                    estimator.best_score_ = cv_results_[f"mean_test_{refit_metric}"][estimator.best_index_]
-                estimator.best_params_ = cv_results_["params"][estimator.best_index_]
-
-            if estimator.refit:
-                estimator.best_estimator_ = clone(base_estimator).set_params(
-                    **clone(estimator.best_params_, safe=False)
-                )

-            # Let the sproc use all cores to refit.
-            estimator.n_jobs = estimator.n_jobs or -1
-
-            # process the input as args
-            argspec = inspect.getfullargspec(estimator.fit)
-            args = {"X": X}
-            if label_cols:
-                label_arg_name = "Y" if "Y" in argspec.args else "y"
-                args[label_arg_name] = y
-            if sample_weight_col is not None and "sample_weight" in argspec.args:
-                args["sample_weight"] = df[sample_weight_col].squeeze()
-            # estimator.refit = original_refit
-            refit_start_time = time.time()
-            estimator.best_estimator_.fit(**args)
-            refit_end_time = time.time()
-            estimator.refit_time_ = refit_end_time - refit_start_time
+                first_test_score, cv_results_ = construct_cv_results_memory_efficient_version(
+                    estimator,
+                    n_splits,
+                    list(param_grid),
+                    HP_raw_results.select("EACH_CV_RESULTS").sort(F.col("FIRST_IDX")).collect(),
+                    cross_validator_indices_length,
+                    parameter_grid_length,
+                )

-            if hasattr(estimator.best_estimator_, "feature_names_in_"):
-                estimator.feature_names_in_ = estimator.best_estimator_.feature_names_in_
+                estimator.cv_results_ = cv_results_
+                estimator.multimetric_ = isinstance(first_test_score, dict)
+
+                # check refit_metric now for a callable scorer that is multimetric
+                if callable(estimator.scoring) and estimator.multimetric_:
+                    estimator._check_refit_for_multimetric(first_test_score)
+                    refit_metric = estimator.refit
+
+                # For multi-metric evaluation, store the best_index_, best_params_ and
+                # best_score_ iff refit is one of the scorer names
+                # In single metric evaluation, refit_metric is "score"
+                if estimator.refit or not estimator.multimetric_:
+                    estimator.best_index_ = estimator._select_best_index(estimator.refit, refit_metric, cv_results_)
+                    if not callable(estimator.refit):
+                        # With a non-custom callable, we can select the best score
+                        # based on the best index
+                        estimator.best_score_ = cv_results_[f"mean_test_{refit_metric}"][estimator.best_index_]
+                    estimator.best_params_ = cv_results_["params"][estimator.best_index_]
+
+                if estimator.refit:
+                    estimator.best_estimator_ = clone(base_estimator).set_params(
+                        **clone(estimator.best_params_, safe=False)
+                    )

-            # Store the only scorer not as a dict for single metric evaluation
-            estimator.scorer_ = scorers
-            estimator.n_splits_ = n_splits
+                # Let the sproc use all cores to refit.
+                estimator.n_jobs = estimator.n_jobs or -1
+
+                # process the input as args
+                argspec = inspect.getfullargspec(estimator.fit)
+                args = {"X": X}
+                if label_cols:
+                    label_arg_name = "Y" if "Y" in argspec.args else "y"
+                    args[label_arg_name] = y
+                if sample_weight_col is not None and "sample_weight" in argspec.args:
+                    args["sample_weight"] = df[sample_weight_col].squeeze()
+                # estimator.refit = original_refit
+                refit_start_time = time.time()
+                estimator.best_estimator_.fit(**args)
+                refit_end_time = time.time()
+                estimator.refit_time_ = refit_end_time - refit_start_time
+
+                if hasattr(estimator.best_estimator_, "feature_names_in_"):
+                    estimator.feature_names_in_ = estimator.best_estimator_.feature_names_in_
+
+                # Store the only scorer not as a dict for single metric evaluation
+                estimator.scorer_ = scorers
+                estimator.n_splits_ = n_splits
+
+                local_result_file_name = temp_file_utils.get_temp_file_path()
+
+                with open(local_result_file_name, mode="w+b") as local_result_file_obj:
+                    cp.dump(estimator, local_result_file_obj)
+
+                session.file.put(
+                    local_result_file_name,
+                    temp_stage_name,
+                    auto_compress=False,
+                    overwrite=True,
+                )

-            local_result_file_name = get_temp_file_path()
+                # Clean up the stages and files
+                session.sql(f"DROP STAGE IF EXISTS {_UDTF_STAGE_NAME}")

-            with open(local_result_file_name, mode="w+b") as local_result_file_obj:
-                cp.dump(estimator, local_result_file_obj)
+                temp_file_utils.cleanup_temp_files([local_result_file_name])

-            session.file.put(
-                local_result_file_name,
-                temp_stage_name,
-                auto_compress=False,
-                overwrite=True,
-            )
-
-            return str(os.path.basename(local_result_file_name))
+                return str(os.path.basename(local_result_file_name))
+            finally:
+                # Clean up the stages
+                session.sql(f"DROP STAGE IF EXISTS {_UDTF_STAGE_NAME}")

         sproc_export_file_name = _distributed_search(
             session,
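The core of this hunk: the UDTF handler moves out of the trainer into distributed_search_udf_file.py, is written to a temporary .py file, and is registered with register_from_file as a permanent function backed by a dedicated stage (the in-code comment ties this to _sf_node_singleton support only working for permanent functions); the try/finally guarantees the stage, and with it the function, is dropped afterwards. A condensed sketch of that registration flow, assuming an established Snowpark session and handler source; the function and variable names here are ours:

import tempfile

from snowflake.snowpark import Session
from snowflake.snowpark.types import IntegerType, StringType, StructField, StructType


def run_with_file_based_udtf(session: Session, handler_source: str, udtf_name: str, stage: str) -> None:
    # Write the handler module to a local .py file (delete=False keeps it
    # readable on Windows, mirroring the diff).
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8") as f:
        f.write(handler_source)
        f.flush()
    try:
        # A dedicated stage backs the permanent UDTF for the lifetime of the search.
        session.sql(f"CREATE OR REPLACE STAGE {stage}").collect()
        session.udtf.register_from_file(
            file_path=f.name,
            handler_name="SearchCV",
            name=udtf_name,
            output_schema=StructType(
                [StructField("FIRST_IDX", IntegerType()), StructField("EACH_CV_RESULTS", StringType())]
            ),
            input_types=[IntegerType(), IntegerType(), IntegerType()],
            replace=True,
            stage_location=stage,
            is_permanent=True,
        )
        # ... query the UDTF and collect results here, as the hunk does ...
    finally:
        # Dropping the stage also removes the permanent UDTF's artifacts.
        session.sql(f"DROP STAGE IF EXISTS {stage}").collect()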
@@ -1150,7 +1104,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             label_cols,
         )

-        local_estimator_path = get_temp_file_path()
+        local_estimator_path = temp_file_utils.get_temp_file_path()
         session.file.get(
             posixpath.join(temp_stage_name, sproc_export_file_name),
             local_estimator_path,
@@ -1159,7 +1113,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         with open(os.path.join(local_estimator_path, sproc_export_file_name), mode="r+b") as result_file_obj:
             fit_estimator = cp.load(result_file_obj)

-        cleanup_temp_files(local_estimator_path)
+        temp_file_utils.cleanup_temp_files(local_estimator_path)

         return fit_estimator

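Both the removed inline UDTF and the surviving file-based path share the same fan-out scheme, visible in the large hunk above: every (fold, parameter) pair becomes one input row for the UDTF, and IDX groups rows so that partition_by="IDX" hands each partition a bundle of _NUM_CPUs tasks. A runnable sketch with illustrative sizes:

from itertools import product

import pandas as pd

# Illustrative sizes; in the trainer these come from the cross-validator and
# the parameter grid, and _NUM_CPUs from the warehouse node.
n_folds, n_params, _NUM_CPUs = 3, 4, 2

cv_indices, param_indices = zip(*product(range(n_folds), range(n_params)))
indices_info = pd.DataFrame(
    {
        # i // _NUM_CPUs buckets consecutive tasks so each UDTF partition
        # (one distinct IDX) processes _NUM_CPUs (fold, parameter) pairs.
        "IDX": [i // _NUM_CPUs for i in range(n_folds * n_params)],
        "PARAM_IND": param_indices,
        "CV_IND": cv_indices,
    }
)
print(indices_info)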