snowflake-ml-python 1.5.2__py3-none-any.whl → 1.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (250)
  1. snowflake/cortex/__init__.py +2 -1
  2. snowflake/cortex/_complete.py +240 -16
  3. snowflake/cortex/_extract_answer.py +0 -1
  4. snowflake/cortex/_sentiment.py +0 -1
  5. snowflake/cortex/_sse_client.py +81 -0
  6. snowflake/cortex/_summarize.py +0 -1
  7. snowflake/cortex/_translate.py +0 -1
  8. snowflake/cortex/_util.py +34 -10
  9. snowflake/ml/_internal/container_services/image_registry/http_client.py +10 -3
  10. snowflake/ml/_internal/container_services/image_registry/imagelib.py +23 -10
  11. snowflake/ml/_internal/container_services/image_registry/registry_client.py +7 -1
  12. snowflake/ml/_internal/exceptions/dataset_errors.py +7 -7
  13. snowflake/ml/_internal/exceptions/fileset_errors.py +3 -3
  14. snowflake/ml/_internal/exceptions/sql_error_codes.py +6 -0
  15. snowflake/ml/_internal/lineage/lineage_utils.py +34 -25
  16. snowflake/ml/_internal/telemetry.py +26 -0
  17. snowflake/ml/_internal/utils/identifier.py +14 -0
  18. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +15 -4
  19. snowflake/ml/dataset/dataset.py +54 -32
  20. snowflake/ml/dataset/dataset_factory.py +3 -4
  21. snowflake/ml/feature_store/feature_store.py +440 -243
  22. snowflake/ml/feature_store/feature_view.py +61 -9
  23. snowflake/ml/fileset/embedded_stage_fs.py +25 -21
  24. snowflake/ml/fileset/fileset.py +2 -2
  25. snowflake/ml/fileset/snowfs.py +4 -15
  26. snowflake/ml/fileset/stage_fs.py +6 -8
  27. snowflake/ml/lineage/__init__.py +3 -0
  28. snowflake/ml/lineage/lineage_node.py +139 -0
  29. snowflake/ml/model/_client/model/model_impl.py +47 -14
  30. snowflake/ml/model/_client/model/model_version_impl.py +82 -2
  31. snowflake/ml/model/_client/ops/model_ops.py +77 -5
  32. snowflake/ml/model/_client/sql/model.py +1 -0
  33. snowflake/ml/model/_client/sql/model_version.py +47 -4
  34. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
  35. snowflake/ml/model/_model_composer/model_composer.py +7 -6
  36. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +7 -1
  37. snowflake/ml/model/_model_composer/model_method/function_generator.py +17 -1
  38. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +79 -0
  39. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +5 -3
  40. snowflake/ml/model/_model_composer/model_method/model_method.py +5 -5
  41. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  42. snowflake/ml/model/_packager/model_handlers/_utils.py +1 -0
  43. snowflake/ml/model/_packager/model_handlers/catboost.py +2 -2
  44. snowflake/ml/model/_packager/model_handlers/custom.py +12 -4
  45. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +18 -15
  46. snowflake/ml/model/_packager/model_handlers/lightgbm.py +2 -2
  47. snowflake/ml/model/_packager/model_handlers/llm.py +2 -2
  48. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -2
  49. snowflake/ml/model/_packager/model_handlers/pytorch.py +2 -2
  50. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +2 -2
  51. snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -2
  52. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +2 -2
  53. snowflake/ml/model/_packager/model_handlers/tensorflow.py +2 -2
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
  56. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  57. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -0
  58. snowflake/ml/model/_packager/model_meta/model_meta.py +21 -1
  59. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
  60. snowflake/ml/model/_packager/model_packager.py +9 -4
  61. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  62. snowflake/ml/model/_signatures/builtins_handler.py +2 -1
  63. snowflake/ml/model/_signatures/core.py +13 -1
  64. snowflake/ml/model/_signatures/pandas_handler.py +2 -0
  65. snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
  66. snowflake/ml/model/custom_model.py +22 -2
  67. snowflake/ml/model/model_signature.py +2 -0
  68. snowflake/ml/model/type_hints.py +74 -4
  69. snowflake/ml/modeling/_internal/estimator_utils.py +58 -1
  70. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +158 -121
  71. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +2 -0
  72. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +39 -18
  73. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +88 -134
  74. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +22 -17
  75. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
  76. snowflake/ml/modeling/cluster/affinity_propagation.py +5 -3
  77. snowflake/ml/modeling/cluster/agglomerative_clustering.py +5 -3
  78. snowflake/ml/modeling/cluster/birch.py +5 -3
  79. snowflake/ml/modeling/cluster/bisecting_k_means.py +5 -3
  80. snowflake/ml/modeling/cluster/dbscan.py +5 -3
  81. snowflake/ml/modeling/cluster/feature_agglomeration.py +5 -3
  82. snowflake/ml/modeling/cluster/k_means.py +5 -3
  83. snowflake/ml/modeling/cluster/mean_shift.py +5 -3
  84. snowflake/ml/modeling/cluster/mini_batch_k_means.py +5 -3
  85. snowflake/ml/modeling/cluster/optics.py +5 -3
  86. snowflake/ml/modeling/cluster/spectral_biclustering.py +5 -3
  87. snowflake/ml/modeling/cluster/spectral_clustering.py +5 -3
  88. snowflake/ml/modeling/cluster/spectral_coclustering.py +5 -3
  89. snowflake/ml/modeling/compose/column_transformer.py +5 -3
  90. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
  91. snowflake/ml/modeling/covariance/elliptic_envelope.py +5 -3
  92. snowflake/ml/modeling/covariance/empirical_covariance.py +5 -3
  93. snowflake/ml/modeling/covariance/graphical_lasso.py +5 -3
  94. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +5 -3
  95. snowflake/ml/modeling/covariance/ledoit_wolf.py +5 -3
  96. snowflake/ml/modeling/covariance/min_cov_det.py +5 -3
  97. snowflake/ml/modeling/covariance/oas.py +5 -3
  98. snowflake/ml/modeling/covariance/shrunk_covariance.py +5 -3
  99. snowflake/ml/modeling/decomposition/dictionary_learning.py +5 -3
  100. snowflake/ml/modeling/decomposition/factor_analysis.py +5 -3
  101. snowflake/ml/modeling/decomposition/fast_ica.py +5 -3
  102. snowflake/ml/modeling/decomposition/incremental_pca.py +5 -3
  103. snowflake/ml/modeling/decomposition/kernel_pca.py +5 -3
  104. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +5 -3
  105. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +5 -3
  106. snowflake/ml/modeling/decomposition/pca.py +5 -3
  107. snowflake/ml/modeling/decomposition/sparse_pca.py +5 -3
  108. snowflake/ml/modeling/decomposition/truncated_svd.py +5 -3
  109. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
  110. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
  111. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
  112. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
  113. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
  114. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
  115. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
  116. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
  117. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
  118. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
  119. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
  120. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
  121. snowflake/ml/modeling/ensemble/isolation_forest.py +5 -3
  122. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
  123. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
  124. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
  125. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
  126. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
  127. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
  128. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
  129. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
  130. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
  131. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
  132. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
  133. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +5 -3
  134. snowflake/ml/modeling/feature_selection/variance_threshold.py +5 -3
  135. snowflake/ml/modeling/framework/base.py +3 -8
  136. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
  137. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
  138. snowflake/ml/modeling/impute/iterative_imputer.py +5 -3
  139. snowflake/ml/modeling/impute/knn_imputer.py +5 -3
  140. snowflake/ml/modeling/impute/missing_indicator.py +5 -3
  141. snowflake/ml/modeling/impute/simple_imputer.py +8 -4
  142. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +5 -3
  143. snowflake/ml/modeling/kernel_approximation/nystroem.py +5 -3
  144. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +5 -3
  145. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +5 -3
  146. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +5 -3
  147. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
  148. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
  149. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
  150. snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
  151. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
  152. snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
  153. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  154. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
  155. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
  156. snowflake/ml/modeling/linear_model/lars.py +1 -1
  157. snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
  158. snowflake/ml/modeling/linear_model/lasso.py +1 -1
  159. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  160. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
  161. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
  162. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
  163. snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
  164. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
  165. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
  166. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
  167. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  168. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
  169. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  170. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
  171. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
  172. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  173. snowflake/ml/modeling/linear_model/perceptron.py +1 -1
  174. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
  175. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
  176. snowflake/ml/modeling/linear_model/ridge.py +1 -1
  177. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
  178. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
  179. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
  180. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
  181. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +5 -3
  182. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
  183. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
  184. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
  185. snowflake/ml/modeling/manifold/isomap.py +5 -3
  186. snowflake/ml/modeling/manifold/mds.py +5 -3
  187. snowflake/ml/modeling/manifold/spectral_embedding.py +5 -3
  188. snowflake/ml/modeling/manifold/tsne.py +5 -3
  189. snowflake/ml/modeling/metrics/ranking.py +3 -0
  190. snowflake/ml/modeling/metrics/regression.py +3 -0
  191. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +5 -3
  192. snowflake/ml/modeling/mixture/gaussian_mixture.py +5 -3
  193. snowflake/ml/modeling/model_selection/grid_search_cv.py +1 -5
  194. snowflake/ml/modeling/model_selection/randomized_search_cv.py +1 -5
  195. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
  196. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
  197. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
  198. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
  199. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
  200. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
  201. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
  202. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
  203. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
  204. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
  205. snowflake/ml/modeling/neighbors/kernel_density.py +5 -3
  206. snowflake/ml/modeling/neighbors/local_outlier_factor.py +5 -3
  207. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
  208. snowflake/ml/modeling/neighbors/nearest_neighbors.py +5 -3
  209. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
  210. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
  211. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  212. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +5 -3
  213. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
  214. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
  215. snowflake/ml/modeling/pipeline/pipeline.py +6 -0
  216. snowflake/ml/modeling/preprocessing/binarizer.py +7 -3
  217. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +7 -2
  218. snowflake/ml/modeling/preprocessing/label_encoder.py +8 -7
  219. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +7 -3
  220. snowflake/ml/modeling/preprocessing/min_max_scaler.py +7 -4
  221. snowflake/ml/modeling/preprocessing/normalizer.py +7 -3
  222. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +53 -11
  223. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +44 -13
  224. snowflake/ml/modeling/preprocessing/polynomial_features.py +5 -3
  225. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -4
  226. snowflake/ml/modeling/preprocessing/standard_scaler.py +7 -3
  227. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
  228. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
  229. snowflake/ml/modeling/svm/linear_svc.py +1 -1
  230. snowflake/ml/modeling/svm/linear_svr.py +1 -1
  231. snowflake/ml/modeling/svm/nu_svc.py +1 -1
  232. snowflake/ml/modeling/svm/nu_svr.py +1 -1
  233. snowflake/ml/modeling/svm/svc.py +1 -1
  234. snowflake/ml/modeling/svm/svr.py +1 -1
  235. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
  236. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
  237. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
  238. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
  239. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
  240. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
  241. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
  242. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
  243. snowflake/ml/registry/_manager/model_manager.py +16 -3
  244. snowflake/ml/version.py +1 -1
  245. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/METADATA +51 -7
  246. snowflake_ml_python-1.5.4.dist-info/RECORD +389 -0
  247. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/WHEEL +1 -1
  248. snowflake_ml_python-1.5.2.dist-info/RECORD +0 -384
  249. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/LICENSE.txt +0 -0
  250. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.4.dist-info}/top_level.txt +0 -0
@@ -54,6 +54,7 @@ _SupportedNumpyDtype = Union[
54
54
  "np.bool_",
55
55
  "np.str_",
56
56
  "np.bytes_",
57
+ "np.datetime64",
57
58
  ]
58
59
  _SupportedNumpyArray = npt.NDArray[_SupportedNumpyDtype]
59
60
  _SupportedBuiltinsList = Sequence[_SupportedBuiltins]
@@ -312,15 +313,84 @@ ModelSaveOption = Union[
312
313
  ]
313
314
 
314
315
 
315
- class ModelLoadOption(TypedDict):
316
- """Options for loading the model.
316
+ class BaseModelLoadOption(TypedDict):
317
+ """Options for loading the model."""
318
+
319
+ ...
320
+
321
+
322
+ class CatBoostModelLoadOptions(BaseModelLoadOption):
323
+ use_gpu: NotRequired[bool]
324
+
325
+
326
+ class CustomModelLoadOption(BaseModelLoadOption):
327
+ ...
328
+
329
+
330
+ class SKLModelLoadOptions(BaseModelLoadOption):
331
+ ...
332
+
333
+
334
+ class XGBModelLoadOptions(BaseModelLoadOption):
335
+ use_gpu: NotRequired[bool]
336
+
337
+
338
+ class LGBMModelLoadOptions(BaseModelLoadOption):
339
+ ...
340
+
341
+
342
+ class SNOWModelLoadOptions(BaseModelLoadOption):
343
+ ...
317
344
 
318
- use_gpu: Enable GPU-specific loading logic.
319
- """
320
345
 
346
+ class PyTorchLoadOptions(BaseModelLoadOption):
321
347
  use_gpu: NotRequired[bool]
322
348
 
323
349
 
350
+ class TorchScriptLoadOptions(BaseModelLoadOption):
351
+ use_gpu: NotRequired[bool]
352
+
353
+
354
+ class TensorflowLoadOptions(BaseModelLoadOption):
355
+ ...
356
+
357
+
358
+ class MLFlowLoadOptions(BaseModelLoadOption):
359
+ ...
360
+
361
+
362
+ class HuggingFaceLoadOptions(BaseModelLoadOption):
363
+ use_gpu: NotRequired[bool]
364
+ device_map: NotRequired[str]
365
+ device: NotRequired[Union[str, int]]
366
+
367
+
368
+ class SentenceTransformersLoadOptions(BaseModelLoadOption):
369
+ use_gpu: NotRequired[bool]
370
+
371
+
372
+ class LLMLoadOptions(BaseModelLoadOption):
373
+ ...
374
+
375
+
376
+ ModelLoadOption = Union[
377
+ BaseModelLoadOption,
378
+ CatBoostModelLoadOptions,
379
+ CustomModelLoadOption,
380
+ LGBMModelLoadOptions,
381
+ SKLModelLoadOptions,
382
+ XGBModelLoadOptions,
383
+ SNOWModelLoadOptions,
384
+ PyTorchLoadOptions,
385
+ TorchScriptLoadOptions,
386
+ TensorflowLoadOptions,
387
+ MLFlowLoadOptions,
388
+ HuggingFaceLoadOptions,
389
+ SentenceTransformersLoadOptions,
390
+ LLMLoadOptions,
391
+ ]
392
+
393
+
324
394
  class SnowparkContainerServiceDeployDetails(TypedDict):
325
395
  """
326
396
  Attributes:
@@ -1,15 +1,19 @@
1
1
  import inspect
2
2
  import numbers
3
+ import os
3
4
  from typing import Any, Callable, Dict, List, Set, Tuple
4
5
 
6
+ import cloudpickle as cp
5
7
  import numpy as np
6
8
  from numpy import typing as npt
7
- from typing_extensions import TypeGuard
8
9
 
9
10
  from snowflake.ml._internal.exceptions import error_codes, exceptions
11
+ from snowflake.ml._internal.utils import temp_file_utils
12
+ from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator
10
13
  from snowflake.ml.modeling.framework._utils import to_native_format
11
14
  from snowflake.ml.modeling.framework.base import BaseTransformer
12
15
  from snowflake.snowpark import Session
16
+ from snowflake.snowpark._internal import utils as snowpark_utils
13
17
 
14
18
 
15
19
  def validate_sklearn_args(args: Dict[str, Tuple[Any, Any, bool]], klass: type) -> Dict[str, Any]:
@@ -97,6 +101,7 @@ def original_estimator_has_callable(attr: str) -> Callable[[Any], bool]:
97
101
  Returns:
98
102
  A function which checks for the existence of callable `attr` on the given object.
99
103
  """
104
+ from typing_extensions import TypeGuard
100
105
 
101
106
  def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
102
107
  """Check for the existence of callable `attr` in self.
@@ -218,3 +223,55 @@ def handle_inference_result(
218
223
  )
219
224
 
220
225
  return transformed_numpy_array, output_cols
226
+
227
+
228
+ def create_temp_stage(session: Session) -> str:
229
+ """Creates temporary stage.
230
+
231
+ Args:
232
+ session: Session
233
+
234
+ Returns:
235
+ Temp stage name.
236
+ """
237
+ # Create temp stage to upload pickled model file.
238
+ transform_stage_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
239
+ stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
240
+ SqlResultValidator(session=session, query=stage_creation_query).has_dimensions(
241
+ expected_rows=1, expected_cols=1
242
+ ).validate()
243
+ return transform_stage_name
244
+
245
+
246
+ def upload_model_to_stage(
247
+ stage_name: str, estimator: object, session: Session, statement_params: Dict[str, str]
248
+ ) -> str:
249
+ """Util method to pickle and upload the model to a temp Snowflake stage.
250
+
251
+
252
+ Args:
253
+ stage_name: Stage name to save model.
254
+ estimator: Estimator object to upload to stage (sklearn model object)
255
+ session: The snowpark session to use.
256
+ statement_params: Statement parameters for query telemetry.
257
+
258
+ Returns:
259
+ a tuple containing stage file paths for pickled input model for training and location to store trained
260
+ models(response from training sproc).
261
+ """
262
+ # Create a temp file and dump the transform to that file.
263
+ local_transform_file_name = temp_file_utils.get_temp_file_path()
264
+ with open(local_transform_file_name, mode="w+b") as local_transform_file:
265
+ cp.dump(estimator, local_transform_file)
266
+
267
+ # Put locally serialized transform on stage.
268
+ session.file.put(
269
+ local_file_name=local_transform_file_name,
270
+ stage_location=stage_name,
271
+ auto_compress=False,
272
+ overwrite=True,
273
+ statement_params=statement_params,
274
+ )
275
+
276
+ temp_file_utils.cleanup_temp_files([local_transform_file_name])
277
+ return os.path.basename(local_transform_file_name)
@@ -4,6 +4,7 @@ import io
4
4
  import os
5
5
  import posixpath
6
6
  import sys
7
+ import uuid
7
8
  from typing import Any, Dict, List, Optional, Tuple, Union
8
9
 
9
10
  import cloudpickle as cp
@@ -16,10 +17,7 @@ from snowflake.ml._internal.utils import (
16
17
  identifier,
17
18
  pkg_version_utils,
18
19
  snowpark_dataframe_utils,
19
- )
20
- from snowflake.ml._internal.utils.temp_file_utils import (
21
- cleanup_temp_files,
22
- get_temp_file_path,
20
+ temp_file_utils,
23
21
  )
24
22
  from snowflake.ml.modeling._internal.model_specifications import (
25
23
  ModelSpecificationsBuilder,
@@ -37,13 +35,14 @@ from snowflake.snowpark.row import Row
37
35
  from snowflake.snowpark.types import IntegerType, StringType, StructField, StructType
38
36
  from snowflake.snowpark.udtf import UDTFRegistration
39
37
 
40
- cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
38
+ cp.register_pickle_by_value(inspect.getmodule(temp_file_utils.get_temp_file_path))
41
39
  cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
42
40
  cp.register_pickle_by_value(inspect.getmodule(snowpark_dataframe_utils.cast_snowpark_dataframe))
43
41
 
44
42
  _PROJECT = "ModelDevelopment"
45
43
  DEFAULT_UDTF_NJOBS = 3
46
44
  ENABLE_EFFICIENT_MEMORY_USAGE = False
45
+ _UDTF_STAGE_NAME = f"MEMORY_EFFICIENT_UDTF_{str(uuid.uuid4()).replace('-', '_')}"
47
46
 
48
47
 
49
48
  def construct_cv_results(
@@ -318,7 +317,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
318
317
  original_refit = estimator.refit
319
318
 
320
319
  # Create a temp file and dump the estimator to that file.
321
- estimator_file_name = get_temp_file_path()
320
+ estimator_file_name = temp_file_utils.get_temp_file_path()
322
321
  params_to_evaluate = []
323
322
  for param_to_eval in list(param_grid):
324
323
  for k, v in param_to_eval.items():
@@ -357,6 +356,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
357
356
  )
358
357
  estimator_location = put_result[0].target
359
358
  imports.append(f"@{temp_stage_name}/{estimator_location}")
359
+ temp_file_utils.cleanup_temp_files([estimator_file_name])
360
360
 
361
361
  search_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
362
362
  random_udtf_name = random_name_for_temp_object(TempObjectType.TABLE_FUNCTION)
@@ -377,6 +377,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
377
377
  anonymous=True,
378
378
  imports=imports, # type: ignore[arg-type]
379
379
  statement_params=sproc_statement_params,
380
+ execute_as="caller",
380
381
  )
381
382
  def _distributed_search(
382
383
  session: Session,
@@ -413,7 +414,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
413
414
  X = df[input_cols]
414
415
  y = df[label_cols].squeeze() if label_cols else None
415
416
 
416
- local_estimator_file_name = get_temp_file_path()
417
+ local_estimator_file_name = temp_file_utils.get_temp_file_path()
417
418
  session.file.get(stage_estimator_file_name, local_estimator_file_name)
418
419
 
419
420
  local_estimator_file_path = os.path.join(
@@ -429,7 +430,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
429
430
  n_splits = build_cross_validator.get_n_splits(X, y, None)
430
431
  # store the cross_validator's test indices only to save space
431
432
  cross_validator_indices = [test for _, test in build_cross_validator.split(X, y, None)]
432
- local_indices_file_name = get_temp_file_path()
433
+ local_indices_file_name = temp_file_utils.get_temp_file_path()
433
434
  with open(local_indices_file_name, mode="w+b") as local_indices_file_obj:
434
435
  cp.dump(cross_validator_indices, local_indices_file_obj)
435
436
 
@@ -445,6 +446,8 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
445
446
  cross_validator_indices_length = int(len(cross_validator_indices))
446
447
  parameter_grid_length = len(param_grid)
447
448
 
449
+ temp_file_utils.cleanup_temp_files([local_estimator_file_name, local_indices_file_name])
450
+
448
451
  assert estimator is not None
449
452
 
450
453
  @cachetools.cached(cache={})
@@ -647,7 +650,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
647
650
  if hasattr(estimator.best_estimator_, "feature_names_in_"):
648
651
  estimator.feature_names_in_ = estimator.best_estimator_.feature_names_in_
649
652
 
650
- local_result_file_name = get_temp_file_path()
653
+ local_result_file_name = temp_file_utils.get_temp_file_path()
651
654
 
652
655
  with open(local_result_file_name, mode="w+b") as local_result_file_obj:
653
656
  cp.dump(estimator, local_result_file_obj)
@@ -658,6 +661,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
658
661
  auto_compress=False,
659
662
  overwrite=True,
660
663
  )
664
+ temp_file_utils.cleanup_temp_files([local_result_file_name])
661
665
 
662
666
  # Note: you can add something like + "|" + str(df) to the return string
663
667
  # to pass debug information to the caller.
@@ -671,7 +675,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
671
675
  label_cols,
672
676
  )
673
677
 
674
- local_estimator_path = get_temp_file_path()
678
+ local_estimator_path = temp_file_utils.get_temp_file_path()
675
679
  session.file.get(
676
680
  posixpath.join(temp_stage_name, sproc_export_file_name),
677
681
  local_estimator_path,
@@ -680,7 +684,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
680
684
  with open(os.path.join(local_estimator_path, sproc_export_file_name), mode="r+b") as result_file_obj:
681
685
  fit_estimator = cp.load(result_file_obj)
682
686
 
683
- cleanup_temp_files([local_estimator_path])
687
+ temp_file_utils.cleanup_temp_files([local_estimator_path])
684
688
 
685
689
  return fit_estimator
686
690
 
@@ -716,7 +720,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
716
720
  imports = [f"@{row.name}" for row in session.sql(f"LIST @{temp_stage_name}/{dataset_file_name}").collect()]
717
721
 
718
722
  # Create a temp file and dump the estimator to that file.
719
- estimator_file_name = get_temp_file_path()
723
+ estimator_file_name = temp_file_utils.get_temp_file_path()
720
724
  params_to_evaluate = list(param_grid)
721
725
  CONSTANTS: Dict[str, Any] = dict()
722
726
  CONSTANTS["dataset_snowpark_cols"] = dataset.columns
@@ -757,6 +761,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
757
761
  )
758
762
  estimator_location = os.path.basename(estimator_file_name)
759
763
  imports.append(f"@{temp_stage_name}/{estimator_location}")
764
+ temp_file_utils.cleanup_temp_files([estimator_file_name])
760
765
  CONSTANTS["estimator_location"] = estimator_location
761
766
 
762
767
  search_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
@@ -778,6 +783,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
778
783
  anonymous=True,
779
784
  imports=imports, # type: ignore[arg-type]
780
785
  statement_params=sproc_statement_params,
786
+ execute_as="caller",
781
787
  )
782
788
  def _distributed_search(
783
789
  session: Session,
@@ -823,7 +829,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
823
829
  if sample_weight_col:
824
830
  fit_params["sample_weight"] = df[sample_weight_col].squeeze()
825
831
 
826
- local_estimator_file_folder_name = get_temp_file_path()
832
+ local_estimator_file_folder_name = temp_file_utils.get_temp_file_path()
827
833
  session.file.get(stage_estimator_file_name, local_estimator_file_folder_name)
828
834
 
829
835
  local_estimator_file_path = os.path.join(
@@ -869,7 +875,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
869
875
 
870
876
  # (1) store the cross_validator's test indices only to save space
871
877
  cross_validator_indices = [test for _, test in build_cross_validator.split(X, y, None)]
872
- local_indices_file_name = get_temp_file_path()
878
+ local_indices_file_name = temp_file_utils.get_temp_file_path()
873
879
  with open(local_indices_file_name, mode="w+b") as local_indices_file_obj:
874
880
  cp.dump(cross_validator_indices, local_indices_file_obj)
875
881
 
@@ -884,7 +890,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
884
890
  imports.append(f"@{temp_stage_name}/{indices_location}")
885
891
 
886
892
  # (2) store the base estimator
887
- local_base_estimator_file_name = get_temp_file_path()
893
+ local_base_estimator_file_name = temp_file_utils.get_temp_file_path()
888
894
  with open(local_base_estimator_file_name, mode="w+b") as local_base_estimator_file_obj:
889
895
  cp.dump(base_estimator, local_base_estimator_file_obj)
890
896
  session.file.put(
@@ -897,7 +903,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
897
903
  imports.append(f"@{temp_stage_name}/{base_estimator_location}")
898
904
 
899
905
  # (3) store the fit_and_score_kwargs
900
- local_fit_and_score_kwargs_file_name = get_temp_file_path()
906
+ local_fit_and_score_kwargs_file_name = temp_file_utils.get_temp_file_path()
901
907
  with open(local_fit_and_score_kwargs_file_name, mode="w+b") as local_fit_and_score_kwargs_file_obj:
902
908
  cp.dump(fit_and_score_kwargs, local_fit_and_score_kwargs_file_obj)
903
909
  session.file.put(
@@ -918,7 +924,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
918
924
  CONSTANTS["fit_and_score_kwargs_location"] = fit_and_score_kwargs_location
919
925
 
920
926
  # (6) store the constants
921
- local_constant_file_name = get_temp_file_path(prefix="constant")
927
+ local_constant_file_name = temp_file_utils.get_temp_file_path(prefix="constant")
922
928
  with open(local_constant_file_name, mode="w+b") as local_indices_file_obj:
923
929
  cp.dump(CONSTANTS, local_indices_file_obj)
924
930
 
@@ -932,6 +938,17 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
932
938
  constant_location = os.path.basename(local_constant_file_name)
933
939
  imports.append(f"@{temp_stage_name}/{constant_location}")
934
940
 
941
+ temp_file_utils.cleanup_temp_files(
942
+ [
943
+ local_estimator_file_folder_name,
944
+ local_indices_file_name,
945
+ local_base_estimator_file_name,
946
+ local_base_estimator_file_name,
947
+ local_fit_and_score_kwargs_file_name,
948
+ local_constant_file_name,
949
+ ]
950
+ )
951
+
935
952
  cross_validator_indices_length = int(len(cross_validator_indices))
936
953
  parameter_grid_length = len(param_grid)
937
954
 
@@ -942,124 +959,144 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
942
959
 
943
960
  import tempfile
944
961
 
962
+ # delete is set to False to support Windows environment
945
963
  with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8") as f:
946
964
  udf_code = execute_template
947
965
  f.file.write(udf_code)
948
966
  f.file.flush()
949
967
 
950
- # Register the UDTF function from the file
951
- udtf_registration.register_from_file(
952
- file_path=f.name,
953
- handler_name="SearchCV",
954
- name=random_udtf_name,
955
- output_schema=StructType(
956
- [StructField("FIRST_IDX", IntegerType()), StructField("EACH_CV_RESULTS", StringType())]
957
- ),
958
- input_types=[IntegerType(), IntegerType(), IntegerType()],
959
- replace=True,
960
- imports=imports, # type: ignore[arg-type]
961
- is_permanent=False,
962
- packages=required_deps, # type: ignore[arg-type]
963
- statement_params=udtf_statement_params,
964
- )
965
-
966
- HP_TUNING = F.table_function(random_udtf_name)
967
-
968
- # param_indices is for the index for each parameter grid;
969
- # cv_indices is for the index for each cross_validator's fold;
970
- # param_cv_indices is for the index for the product of (len(param_indices) * len(cv_indices))
971
- cv_indices, param_indices = zip(
972
- *product(range(cross_validator_indices_length), range(parameter_grid_length))
973
- )
974
-
975
- indices_info_pandas = pd.DataFrame(
976
- {
977
- "IDX": [i // _NUM_CPUs for i in range(parameter_grid_length * cross_validator_indices_length)],
978
- "PARAM_IND": param_indices,
979
- "CV_IND": cv_indices,
980
- }
981
- )
982
-
983
- indices_info_sp = session.create_dataframe(indices_info_pandas)
984
- # execute udtf by querying HP_TUNING table
985
- HP_raw_results = indices_info_sp.select(
986
- (
987
- HP_TUNING(indices_info_sp["IDX"], indices_info_sp["PARAM_IND"], indices_info_sp["CV_IND"]).over(
988
- partition_by="IDX"
968
+ # Use catchall exception handling and a finally block to clean up the _UDTF_STAGE_NAME
969
+ try:
970
+ # Create one stage for data and for estimators.
971
+ # Because only permanent functions support _sf_node_singleton for now, therefore,
972
+ # UDTF creation would change to is_permanent=True, and manually drop the stage after UDTF is done
973
+ _stage_creation_query_udtf = f"CREATE OR REPLACE STAGE {_UDTF_STAGE_NAME};"
974
+ session.sql(_stage_creation_query_udtf).collect()
975
+
976
+ # Register the UDTF function from the file
977
+ udtf_registration.register_from_file(
978
+ file_path=f.name,
979
+ handler_name="SearchCV",
980
+ name=random_udtf_name,
981
+ output_schema=StructType(
982
+ [StructField("FIRST_IDX", IntegerType()), StructField("EACH_CV_RESULTS", StringType())]
983
+ ),
984
+ input_types=[IntegerType(), IntegerType(), IntegerType()],
985
+ replace=True,
986
+ imports=imports, # type: ignore[arg-type]
987
+ stage_location=_UDTF_STAGE_NAME,
988
+ is_permanent=True,
989
+ packages=required_deps, # type: ignore[arg-type]
990
+ statement_params=udtf_statement_params,
989
991
  )
990
- ),
991
- )
992
-
993
- first_test_score, cv_results_ = construct_cv_results_memory_efficient_version(
994
- estimator,
995
- n_splits,
996
- list(param_grid),
997
- HP_raw_results.select("EACH_CV_RESULTS").sort(F.col("FIRST_IDX")).collect(),
998
- cross_validator_indices_length,
999
- parameter_grid_length,
1000
- )
1001
-
1002
- estimator.cv_results_ = cv_results_
1003
- estimator.multimetric_ = isinstance(first_test_score, dict)
1004
992
 
1005
- # check refit_metric now for a callable scorer that is multimetric
1006
- if callable(estimator.scoring) and estimator.multimetric_:
1007
- estimator._check_refit_for_multimetric(first_test_score)
1008
- refit_metric = estimator.refit
993
+ HP_TUNING = F.table_function(random_udtf_name)
1009
994
 
1010
- # For multi-metric evaluation, store the best_index_, best_params_ and
1011
- # best_score_ iff refit is one of the scorer names
1012
- # In single metric evaluation, refit_metric is "score"
1013
- if estimator.refit or not estimator.multimetric_:
1014
- estimator.best_index_ = estimator._select_best_index(estimator.refit, refit_metric, cv_results_)
1015
- if not callable(estimator.refit):
1016
- # With a non-custom callable, we can select the best score
1017
- # based on the best index
1018
- estimator.best_score_ = cv_results_[f"mean_test_{refit_metric}"][estimator.best_index_]
1019
- estimator.best_params_ = cv_results_["params"][estimator.best_index_]
1020
-
1021
- if estimator.refit:
1022
- estimator.best_estimator_ = clone(base_estimator).set_params(
1023
- **clone(estimator.best_params_, safe=False)
1024
- )
995
+ # param_indices is for the index for each parameter grid;
996
+ # cv_indices is for the index for each cross_validator's fold;
997
+ # param_cv_indices is for the index for the product of (len(param_indices) * len(cv_indices))
998
+ cv_indices, param_indices = zip(
999
+ *product(range(cross_validator_indices_length), range(parameter_grid_length))
1000
+ )
1025
1001
 
1026
- # Let the sproc use all cores to refit.
1027
- estimator.n_jobs = estimator.n_jobs or -1
1002
+ indices_info_pandas = pd.DataFrame(
1003
+ {
1004
+ "IDX": [
1005
+ i // _NUM_CPUs for i in range(parameter_grid_length * cross_validator_indices_length)
1006
+ ],
1007
+ "PARAM_IND": param_indices,
1008
+ "CV_IND": cv_indices,
1009
+ }
1010
+ )
1028
1011
 
1029
- # process the input as args
1030
- argspec = inspect.getfullargspec(estimator.fit)
1031
- args = {"X": X}
1032
- if label_cols:
1033
- label_arg_name = "Y" if "Y" in argspec.args else "y"
1034
- args[label_arg_name] = y
1035
- if sample_weight_col is not None and "sample_weight" in argspec.args:
1036
- args["sample_weight"] = df[sample_weight_col].squeeze()
1037
- # estimator.refit = original_refit
1038
- refit_start_time = time.time()
1039
- estimator.best_estimator_.fit(**args)
1040
- refit_end_time = time.time()
1041
- estimator.refit_time_ = refit_end_time - refit_start_time
1012
+ indices_info_sp = session.create_dataframe(indices_info_pandas)
1013
+ # execute udtf by querying HP_TUNING table
1014
+ HP_raw_results = indices_info_sp.select(
1015
+ (
1016
+ HP_TUNING(
1017
+ indices_info_sp["IDX"], indices_info_sp["PARAM_IND"], indices_info_sp["CV_IND"]
1018
+ ).over(partition_by="IDX")
1019
+ ),
1020
+ )
1042
1021
 
1043
- if hasattr(estimator.best_estimator_, "feature_names_in_"):
1044
- estimator.feature_names_in_ = estimator.best_estimator_.feature_names_in_
1022
+ first_test_score, cv_results_ = construct_cv_results_memory_efficient_version(
1023
+ estimator,
1024
+ n_splits,
1025
+ list(param_grid),
1026
+ HP_raw_results.select("EACH_CV_RESULTS").sort(F.col("FIRST_IDX")).collect(),
1027
+ cross_validator_indices_length,
1028
+ parameter_grid_length,
1029
+ )
1045
1030
 
1046
- # Store the only scorer not as a dict for single metric evaluation
1047
- estimator.scorer_ = scorers
1048
- estimator.n_splits_ = n_splits
1031
+ estimator.cv_results_ = cv_results_
1032
+ estimator.multimetric_ = isinstance(first_test_score, dict)
1033
+
1034
+ # check refit_metric now for a callable scorer that is multimetric
1035
+ if callable(estimator.scoring) and estimator.multimetric_:
1036
+ estimator._check_refit_for_multimetric(first_test_score)
1037
+ refit_metric = estimator.refit
1038
+
1039
+ # For multi-metric evaluation, store the best_index_, best_params_ and
1040
+ # best_score_ iff refit is one of the scorer names
1041
+ # In single metric evaluation, refit_metric is "score"
1042
+ if estimator.refit or not estimator.multimetric_:
1043
+ estimator.best_index_ = estimator._select_best_index(estimator.refit, refit_metric, cv_results_)
1044
+ if not callable(estimator.refit):
1045
+ # With a non-custom callable, we can select the best score
1046
+ # based on the best index
1047
+ estimator.best_score_ = cv_results_[f"mean_test_{refit_metric}"][estimator.best_index_]
1048
+ estimator.best_params_ = cv_results_["params"][estimator.best_index_]
1049
+
1050
+ if estimator.refit:
1051
+ estimator.best_estimator_ = clone(base_estimator).set_params(
1052
+ **clone(estimator.best_params_, safe=False)
1053
+ )
1049
1054
 
1050
- local_result_file_name = get_temp_file_path()
1055
+ # Let the sproc use all cores to refit.
1056
+ estimator.n_jobs = estimator.n_jobs or -1
1057
+
1058
+ # process the input as args
1059
+ argspec = inspect.getfullargspec(estimator.fit)
1060
+ args = {"X": X}
1061
+ if label_cols:
1062
+ label_arg_name = "Y" if "Y" in argspec.args else "y"
1063
+ args[label_arg_name] = y
1064
+ if sample_weight_col is not None and "sample_weight" in argspec.args:
1065
+ args["sample_weight"] = df[sample_weight_col].squeeze()
1066
+ # estimator.refit = original_refit
1067
+ refit_start_time = time.time()
1068
+ estimator.best_estimator_.fit(**args)
1069
+ refit_end_time = time.time()
1070
+ estimator.refit_time_ = refit_end_time - refit_start_time
1071
+
1072
+ if hasattr(estimator.best_estimator_, "feature_names_in_"):
1073
+ estimator.feature_names_in_ = estimator.best_estimator_.feature_names_in_
1074
+
1075
+ # Store the only scorer not as a dict for single metric evaluation
1076
+ estimator.scorer_ = scorers
1077
+ estimator.n_splits_ = n_splits
1078
+
1079
+ local_result_file_name = temp_file_utils.get_temp_file_path()
1080
+
1081
+ with open(local_result_file_name, mode="w+b") as local_result_file_obj:
1082
+ cp.dump(estimator, local_result_file_obj)
1083
+
1084
+ session.file.put(
1085
+ local_result_file_name,
1086
+ temp_stage_name,
1087
+ auto_compress=False,
1088
+ overwrite=True,
1089
+ )
1051
1090
 
1052
- with open(local_result_file_name, mode="w+b") as local_result_file_obj:
1053
- cp.dump(estimator, local_result_file_obj)
1091
+ # Clean up the stages and files
1092
+ session.sql(f"DROP STAGE IF EXISTS {_UDTF_STAGE_NAME}")
1054
1093
 
1055
- session.file.put(
1056
- local_result_file_name,
1057
- temp_stage_name,
1058
- auto_compress=False,
1059
- overwrite=True,
1060
- )
1094
+ temp_file_utils.cleanup_temp_files([local_result_file_name])
1061
1095
 
1062
- return str(os.path.basename(local_result_file_name))
1096
+ return str(os.path.basename(local_result_file_name))
1097
+ finally:
1098
+ # Clean up the stages
1099
+ session.sql(f"DROP STAGE IF EXISTS {_UDTF_STAGE_NAME}")
1063
1100
 
1064
1101
  sproc_export_file_name = _distributed_search(
1065
1102
  session,
@@ -1069,7 +1106,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
1069
1106
  label_cols,
1070
1107
  )
1071
1108
 
1072
- local_estimator_path = get_temp_file_path()
1109
+ local_estimator_path = temp_file_utils.get_temp_file_path()
1073
1110
  session.file.get(
1074
1111
  posixpath.join(temp_stage_name, sproc_export_file_name),
1075
1112
  local_estimator_path,
@@ -1078,7 +1115,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
1078
1115
  with open(os.path.join(local_estimator_path, sproc_export_file_name), mode="r+b") as result_file_obj:
1079
1116
  fit_estimator = cp.load(result_file_obj)
1080
1117
 
1081
- cleanup_temp_files(local_estimator_path)
1118
+ temp_file_utils.cleanup_temp_files(local_estimator_path)
1082
1119
 
1083
1120
  return fit_estimator
1084
1121
 
@@ -156,4 +156,6 @@ class SearchCV:
156
156
  self.fit_score_params[0][0],
157
157
  binary_cv_results,
158
158
  )
159
+
160
+ SearchCV._sf_node_singleton = True
159
161
  """