snowflake-ml-python 1.5.1__py3-none-any.whl → 1.5.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- snowflake/cortex/_complete.py +26 -5
- snowflake/cortex/_sentiment.py +7 -4
- snowflake/cortex/_sse_client.py +81 -0
- snowflake/cortex/_util.py +105 -8
- snowflake/ml/_internal/lineage/lineage_utils.py +34 -25
- snowflake/ml/_internal/utils/temp_file_utils.py +5 -2
- snowflake/ml/dataset/dataset.py +15 -12
- snowflake/ml/dataset/dataset_factory.py +3 -4
- snowflake/ml/feature_store/access_manager.py +34 -30
- snowflake/ml/feature_store/feature_store.py +3 -3
- snowflake/ml/feature_store/feature_view.py +12 -11
- snowflake/ml/fileset/snowfs.py +2 -31
- snowflake/ml/model/_client/ops/model_ops.py +43 -0
- snowflake/ml/model/_client/sql/model_version.py +55 -3
- snowflake/ml/model/_model_composer/model_composer.py +7 -3
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -1
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -3
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -27
- snowflake/ml/model/_signatures/builtins_handler.py +2 -1
- snowflake/ml/model/_signatures/core.py +13 -1
- snowflake/ml/model/_signatures/pandas_handler.py +2 -0
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
- snowflake/ml/model/model_signature.py +2 -0
- snowflake/ml/model/type_hints.py +1 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +58 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +196 -242
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +161 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +38 -18
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +82 -134
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +21 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -2
- snowflake/ml/modeling/cluster/affinity_propagation.py +9 -2
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -2
- snowflake/ml/modeling/cluster/birch.py +9 -2
- snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -2
- snowflake/ml/modeling/cluster/dbscan.py +9 -2
- snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -2
- snowflake/ml/modeling/cluster/k_means.py +9 -2
- snowflake/ml/modeling/cluster/mean_shift.py +9 -2
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -2
- snowflake/ml/modeling/cluster/optics.py +9 -2
- snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -2
- snowflake/ml/modeling/cluster/spectral_clustering.py +9 -2
- snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -2
- snowflake/ml/modeling/compose/column_transformer.py +9 -2
- snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -2
- snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -2
- snowflake/ml/modeling/covariance/empirical_covariance.py +9 -2
- snowflake/ml/modeling/covariance/graphical_lasso.py +9 -2
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -2
- snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -2
- snowflake/ml/modeling/covariance/min_cov_det.py +9 -2
- snowflake/ml/modeling/covariance/oas.py +9 -2
- snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -2
- snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -2
- snowflake/ml/modeling/decomposition/factor_analysis.py +9 -2
- snowflake/ml/modeling/decomposition/fast_ica.py +9 -2
- snowflake/ml/modeling/decomposition/incremental_pca.py +9 -2
- snowflake/ml/modeling/decomposition/kernel_pca.py +9 -2
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -2
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -2
- snowflake/ml/modeling/decomposition/pca.py +9 -2
- snowflake/ml/modeling/decomposition/sparse_pca.py +9 -2
- snowflake/ml/modeling/decomposition/truncated_svd.py +9 -2
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -2
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -2
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -2
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -2
- snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -2
- snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -2
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -2
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -2
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -2
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -2
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -2
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -2
- snowflake/ml/modeling/ensemble/isolation_forest.py +9 -2
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -2
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -2
- snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -2
- snowflake/ml/modeling/ensemble/voting_classifier.py +9 -2
- snowflake/ml/modeling/ensemble/voting_regressor.py +9 -2
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -2
- snowflake/ml/modeling/feature_selection/select_fdr.py +9 -2
- snowflake/ml/modeling/feature_selection/select_fpr.py +9 -2
- snowflake/ml/modeling/feature_selection/select_fwe.py +9 -2
- snowflake/ml/modeling/feature_selection/select_k_best.py +9 -2
- snowflake/ml/modeling/feature_selection/select_percentile.py +9 -2
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -2
- snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -2
- snowflake/ml/modeling/framework/base.py +3 -8
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -2
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -2
- snowflake/ml/modeling/impute/iterative_imputer.py +9 -2
- snowflake/ml/modeling/impute/knn_imputer.py +9 -2
- snowflake/ml/modeling/impute/missing_indicator.py +9 -2
- snowflake/ml/modeling/impute/simple_imputer.py +28 -5
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -2
- snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -2
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -2
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -2
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -2
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -2
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -2
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -2
- snowflake/ml/modeling/linear_model/ard_regression.py +9 -2
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -2
- snowflake/ml/modeling/linear_model/elastic_net.py +9 -2
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -2
- snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -2
- snowflake/ml/modeling/linear_model/huber_regressor.py +9 -2
- snowflake/ml/modeling/linear_model/lars.py +9 -2
- snowflake/ml/modeling/linear_model/lars_cv.py +9 -2
- snowflake/ml/modeling/linear_model/lasso.py +9 -2
- snowflake/ml/modeling/linear_model/lasso_cv.py +9 -2
- snowflake/ml/modeling/linear_model/lasso_lars.py +9 -2
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -2
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -2
- snowflake/ml/modeling/linear_model/linear_regression.py +9 -2
- snowflake/ml/modeling/linear_model/logistic_regression.py +9 -2
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -2
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -2
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -2
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -2
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -2
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -2
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -2
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -2
- snowflake/ml/modeling/linear_model/perceptron.py +9 -2
- snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -2
- snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -2
- snowflake/ml/modeling/linear_model/ridge.py +9 -2
- snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -2
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -2
- snowflake/ml/modeling/linear_model/ridge_cv.py +9 -2
- snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -2
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -2
- snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -2
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -2
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -2
- snowflake/ml/modeling/manifold/isomap.py +9 -2
- snowflake/ml/modeling/manifold/mds.py +9 -2
- snowflake/ml/modeling/manifold/spectral_embedding.py +9 -2
- snowflake/ml/modeling/manifold/tsne.py +9 -2
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -2
- snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -2
- snowflake/ml/modeling/model_selection/grid_search_cv.py +1 -5
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +1 -5
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -2
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -2
- snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -2
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -2
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -2
- snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -2
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -2
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -2
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -2
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -2
- snowflake/ml/modeling/neighbors/kernel_density.py +9 -2
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -2
- snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -2
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -2
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -2
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -2
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -2
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -2
- snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -2
- snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -2
- snowflake/ml/modeling/parameters/enable_anonymous_sproc.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +5 -0
- snowflake/ml/modeling/preprocessing/binarizer.py +7 -3
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +7 -2
- snowflake/ml/modeling/preprocessing/label_encoder.py +8 -7
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +7 -3
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +7 -4
- snowflake/ml/modeling/preprocessing/normalizer.py +7 -3
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +10 -2
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +8 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -2
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -4
- snowflake/ml/modeling/preprocessing/standard_scaler.py +7 -3
- snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -2
- snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -2
- snowflake/ml/modeling/svm/linear_svc.py +9 -2
- snowflake/ml/modeling/svm/linear_svr.py +9 -2
- snowflake/ml/modeling/svm/nu_svc.py +9 -2
- snowflake/ml/modeling/svm/nu_svr.py +9 -2
- snowflake/ml/modeling/svm/svc.py +9 -2
- snowflake/ml/modeling/svm/svr.py +9 -2
- snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -2
- snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -2
- snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -2
- snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -2
- snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -2
- snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -2
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -2
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -2
- snowflake/ml/registry/_manager/model_manager.py +59 -1
- snowflake/ml/registry/registry.py +10 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/METADATA +32 -4
- {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/RECORD +207 -204
- {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/top_level.txt +0 -0
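A recurring mechanical change visible in the file-level diff below is that the trainer code now imports the `temp_file_utils` module instead of its individual functions, and registers that module with cloudpickle so the helpers travel with code unpickled inside stored procedures and UDTFs. A minimal sketch of the pattern, assuming snowflake-ml-python is installed (the staging workflow around the two calls is illustrative, not the package's exact code):

```python
import inspect

import cloudpickle as cp

from snowflake.ml._internal.utils import temp_file_utils

# Registering the module (rather than a bare function) by value makes
# cloudpickle embed temp_file_utils itself in anything pickled here, so
# server-side sprocs/UDTFs do not need a matching client copy of the helper.
cp.register_pickle_by_value(inspect.getmodule(temp_file_utils.get_temp_file_path))

# Call sites then resolve the helpers through the module attribute:
estimator_file_name = temp_file_utils.get_temp_file_path()
# ... dump the estimator to estimator_file_name and upload it to a stage ...
temp_file_utils.cleanup_temp_files([estimator_file_name])
```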
snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py:

@@ -4,11 +4,11 @@ import io
 import os
 import posixpath
 import sys
-from typing import Any, Dict, List, Optional, Set, Tuple, Union
+import uuid
+from typing import Any, Dict, List, Optional, Tuple, Union

 import cloudpickle as cp
 import numpy as np
-import numpy.typing as npt
 from sklearn import model_selection
 from sklearn.model_selection import GridSearchCV, RandomizedSearchCV

@@ -17,10 +17,7 @@ from snowflake.ml._internal.utils import (
     identifier,
     pkg_version_utils,
     snowpark_dataframe_utils,
-)
-from snowflake.ml._internal.utils.temp_file_utils import (
-    cleanup_temp_files,
-    get_temp_file_path,
+    temp_file_utils,
 )
 from snowflake.ml.modeling._internal.model_specifications import (
     ModelSpecificationsBuilder,
@@ -36,14 +33,16 @@ from snowflake.snowpark._internal.utils import (
 from snowflake.snowpark.functions import sproc, udtf
 from snowflake.snowpark.row import Row
 from snowflake.snowpark.types import IntegerType, StringType, StructField, StructType
+from snowflake.snowpark.udtf import UDTFRegistration

-cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
+cp.register_pickle_by_value(inspect.getmodule(temp_file_utils.get_temp_file_path))
 cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
 cp.register_pickle_by_value(inspect.getmodule(snowpark_dataframe_utils.cast_snowpark_dataframe))

 _PROJECT = "ModelDevelopment"
 DEFAULT_UDTF_NJOBS = 3
 ENABLE_EFFICIENT_MEMORY_USAGE = False
+_UDTF_STAGE_NAME = f"MEMORY_EFFICIENT_UDTF_{str(uuid.uuid4()).replace('-', '_')}"


 def construct_cv_results(
@@ -318,7 +317,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         original_refit = estimator.refit

         # Create a temp file and dump the estimator to that file.
-        estimator_file_name = get_temp_file_path()
+        estimator_file_name = temp_file_utils.get_temp_file_path()
         params_to_evaluate = []
         for param_to_eval in list(param_grid):
             for k, v in param_to_eval.items():
@@ -357,6 +356,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         )
         estimator_location = put_result[0].target
         imports.append(f"@{temp_stage_name}/{estimator_location}")
+        temp_file_utils.cleanup_temp_files([estimator_file_name])

         search_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
         random_udtf_name = random_name_for_temp_object(TempObjectType.TABLE_FUNCTION)
@@ -413,7 +413,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             X = df[input_cols]
             y = df[label_cols].squeeze() if label_cols else None

-            local_estimator_file_name = get_temp_file_path()
+            local_estimator_file_name = temp_file_utils.get_temp_file_path()
             session.file.get(stage_estimator_file_name, local_estimator_file_name)

             local_estimator_file_path = os.path.join(
@@ -429,7 +429,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             n_splits = build_cross_validator.get_n_splits(X, y, None)
             # store the cross_validator's test indices only to save space
             cross_validator_indices = [test for _, test in build_cross_validator.split(X, y, None)]
-            local_indices_file_name = get_temp_file_path()
+            local_indices_file_name = temp_file_utils.get_temp_file_path()
             with open(local_indices_file_name, mode="w+b") as local_indices_file_obj:
                 cp.dump(cross_validator_indices, local_indices_file_obj)

@@ -445,6 +445,8 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         cross_validator_indices_length = int(len(cross_validator_indices))
         parameter_grid_length = len(param_grid)

+        temp_file_utils.cleanup_temp_files([local_estimator_file_name, local_indices_file_name])
+
         assert estimator is not None

         @cachetools.cached(cache={})
@@ -647,7 +649,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             if hasattr(estimator.best_estimator_, "feature_names_in_"):
                 estimator.feature_names_in_ = estimator.best_estimator_.feature_names_in_

-            local_result_file_name = get_temp_file_path()
+            local_result_file_name = temp_file_utils.get_temp_file_path()

             with open(local_result_file_name, mode="w+b") as local_result_file_obj:
                 cp.dump(estimator, local_result_file_obj)
@@ -658,6 +660,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
                 auto_compress=False,
                 overwrite=True,
             )
+            temp_file_utils.cleanup_temp_files([local_result_file_name])

             # Note: you can add something like + "|" + str(df) to the return string
             # to pass debug information to the caller.
@@ -671,7 +674,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             label_cols,
         )

-        local_estimator_path = get_temp_file_path()
+        local_estimator_path = temp_file_utils.get_temp_file_path()
         session.file.get(
             posixpath.join(temp_stage_name, sproc_export_file_name),
             local_estimator_path,
@@ -680,7 +683,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         with open(os.path.join(local_estimator_path, sproc_export_file_name), mode="r+b") as result_file_obj:
             fit_estimator = cp.load(result_file_obj)

-        cleanup_temp_files([local_estimator_path])
+        temp_file_utils.cleanup_temp_files([local_estimator_path])

         return fit_estimator

@@ -698,7 +701,6 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
     ) -> Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV]:
         from itertools import product

-        import cachetools
         from sklearn.base import clone, is_classifier
         from sklearn.calibration import check_cv

@@ -717,11 +719,13 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         imports = [f"@{row.name}" for row in session.sql(f"LIST @{temp_stage_name}/{dataset_file_name}").collect()]

         # Create a temp file and dump the estimator to that file.
-        estimator_file_name = get_temp_file_path()
+        estimator_file_name = temp_file_utils.get_temp_file_path()
         params_to_evaluate = list(param_grid)
-… (3 removed lines not captured in this view)
+        CONSTANTS: Dict[str, Any] = dict()
+        CONSTANTS["dataset_snowpark_cols"] = dataset.columns
+        CONSTANTS["n_candidates"] = len(params_to_evaluate)
+        CONSTANTS["_N_JOBS"] = estimator.n_jobs
+        CONSTANTS["_PRE_DISPATCH"] = estimator.pre_dispatch

         with open(estimator_file_name, mode="w+b") as local_estimator_file_obj:
             cp.dump(dict(estimator=estimator, param_grid=params_to_evaluate), local_estimator_file_obj)
@@ -743,6 +747,9 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             api_calls=[udtf],
             custom_tags=dict([("hpo_memory_efficient", True)]),
         )
+        from snowflake.ml.modeling._internal.snowpark_implementations.distributed_search_udf_file import (
+            execute_template,
+        )

         # Put locally serialized estimator on stage.
         session.file.put(
@@ -753,6 +760,8 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         )
         estimator_location = os.path.basename(estimator_file_name)
         imports.append(f"@{temp_stage_name}/{estimator_location}")
+        temp_file_utils.cleanup_temp_files([estimator_file_name])
+        CONSTANTS["estimator_location"] = estimator_location

         search_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
         random_udtf_name = random_name_for_temp_object(TempObjectType.FUNCTION)
@@ -783,7 +792,6 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         ) -> str:
             import os
             import time
-            from typing import Iterator

             import cloudpickle as cp
             import pandas as pd
@@ -819,7 +827,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             if sample_weight_col:
                 fit_params["sample_weight"] = df[sample_weight_col].squeeze()

-            local_estimator_file_folder_name = get_temp_file_path()
+            local_estimator_file_folder_name = temp_file_utils.get_temp_file_path()
             session.file.get(stage_estimator_file_name, local_estimator_file_folder_name)

             local_estimator_file_path = os.path.join(
@@ -865,7 +873,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):

             # (1) store the cross_validator's test indices only to save space
             cross_validator_indices = [test for _, test in build_cross_validator.split(X, y, None)]
-            local_indices_file_name = get_temp_file_path()
+            local_indices_file_name = temp_file_utils.get_temp_file_path()
             with open(local_indices_file_name, mode="w+b") as local_indices_file_obj:
                 cp.dump(cross_validator_indices, local_indices_file_obj)

@@ -880,7 +888,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             imports.append(f"@{temp_stage_name}/{indices_location}")

             # (2) store the base estimator
-            local_base_estimator_file_name = get_temp_file_path()
+            local_base_estimator_file_name = temp_file_utils.get_temp_file_path()
             with open(local_base_estimator_file_name, mode="w+b") as local_base_estimator_file_obj:
                 cp.dump(base_estimator, local_base_estimator_file_obj)
             session.file.put(
@@ -893,7 +901,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             imports.append(f"@{temp_stage_name}/{base_estimator_location}")

             # (3) store the fit_and_score_kwargs
-            local_fit_and_score_kwargs_file_name = get_temp_file_path()
+            local_fit_and_score_kwargs_file_name = temp_file_utils.get_temp_file_path()
             with open(local_fit_and_score_kwargs_file_name, mode="w+b") as local_fit_and_score_kwargs_file_obj:
                 cp.dump(fit_and_score_kwargs, local_fit_and_score_kwargs_file_obj)
             session.file.put(
@@ -905,242 +913,188 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             fit_and_score_kwargs_location = os.path.basename(local_fit_and_score_kwargs_file_name)
             imports.append(f"@{temp_stage_name}/{fit_and_score_kwargs_location}")

-… (4 removed lines not captured in this view)
+            CONSTANTS["input_cols"] = input_cols
+            CONSTANTS["label_cols"] = label_cols
+            CONSTANTS["DATA_LENGTH"] = DATA_LENGTH
+            CONSTANTS["n_splits"] = n_splits
+            CONSTANTS["indices_location"] = indices_location
+            CONSTANTS["base_estimator_location"] = base_estimator_location
+            CONSTANTS["fit_and_score_kwargs_location"] = fit_and_score_kwargs_location

-… (4 removed lines not captured in this view)
-                List[List[int]],
-                List[Dict[str, Any]],
-                object,
-                Dict[str, Any],
-            ]:
-                import pyarrow.parquet as pq
+            # (6) store the constants
+            local_constant_file_name = temp_file_utils.get_temp_file_path(prefix="constant")
+            with open(local_constant_file_name, mode="w+b") as local_indices_file_obj:
+                cp.dump(CONSTANTS, local_indices_file_obj)

-… (8 removed lines not captured in this view)
+            # Put locally serialized indices on stage.
+            session.file.put(
+                local_constant_file_name,
+                temp_stage_name,
+                auto_compress=False,
+                overwrite=True,
+            )
+            constant_location = os.path.basename(local_constant_file_name)
+            imports.append(f"@{temp_stage_name}/{constant_location}")
+
+            temp_file_utils.cleanup_temp_files(
+                [
+                    local_estimator_file_folder_name,
+                    local_indices_file_name,
+                    local_base_estimator_file_name,
+                    local_base_estimator_file_name,
+                    local_fit_and_score_kwargs_file_name,
+                    local_constant_file_name,
                 ]
-… (1 removed line not captured in this view)
-                df.columns = [identifier.get_inferred_name(col_) for col_ in df.columns]
+            )

-
-                local_estimator_file_path = os.path.join(
-                    sys._xoptions["snowflake_import_directory"], f"{estimator_location}"
-                )
-                with open(local_estimator_file_path, mode="rb") as local_estimator_file_obj:
-                    estimator_objects = cp.load(local_estimator_file_obj)
-                params_to_evaluate = estimator_objects["param_grid"]
+            cross_validator_indices_length = int(len(cross_validator_indices))
+            parameter_grid_length = len(param_grid)

-
-                local_indices_file_path = os.path.join(
-                    sys._xoptions["snowflake_import_directory"], f"{indices_location}"
-                )
-                with open(local_indices_file_path, mode="rb") as local_indices_file_obj:
-                    indices = cp.load(local_indices_file_obj)
+            assert estimator is not None

-… (6 removed lines not captured in this view)
+            # Instantiate UDTFRegistration with the session object
+            udtf_registration = UDTFRegistration(session)
+
+            import tempfile
+
+            # delete is set to False to support Windows environment
+            with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8") as f:
+                udf_code = execute_template
+                f.file.write(udf_code)
+                f.file.flush()
+
+            # Use catchall exception handling and a finally block to clean up the _UDTF_STAGE_NAME
+            try:
+                # Create one stage for data and for estimators.
+                # Because only permanent functions support _sf_node_singleton for now, therefore,
+                # UDTF creation would change to is_permanent=True, and manually drop the stage after UDTF is done
+                _stage_creation_query_udtf = f"CREATE OR REPLACE STAGE {_UDTF_STAGE_NAME};"
+                session.sql(_stage_creation_query_udtf).collect()
+
+                # Register the UDTF function from the file
+                udtf_registration.register_from_file(
+                    file_path=f.name,
+                    handler_name="SearchCV",
+                    name=random_udtf_name,
+                    output_schema=StructType(
+                        [StructField("FIRST_IDX", IntegerType()), StructField("EACH_CV_RESULTS", StringType())]
+                    ),
+                    input_types=[IntegerType(), IntegerType(), IntegerType()],
+                    replace=True,
+                    imports=imports,  # type: ignore[arg-type]
+                    stage_location=_UDTF_STAGE_NAME,
+                    is_permanent=True,
+                    packages=required_deps,  # type: ignore[arg-type]
+                    statement_params=udtf_statement_params,
+                )

-
-                local_fit_and_score_kwargs_file_path = os.path.join(
-                    sys._xoptions["snowflake_import_directory"], f"{fit_and_score_kwargs_location}"
-                )
-                with open(local_fit_and_score_kwargs_file_path, mode="rb") as local_fit_and_score_kwargs_file_obj:
-                    fit_and_score_kwargs = cp.load(local_fit_and_score_kwargs_file_obj)
-
-                # convert dataframe to numpy would save memory consumption
-                return (
-                    df[input_cols].to_numpy(),
-                    df[label_cols].squeeze().to_numpy(),
-                    indices,
-                    params_to_evaluate,
-                    base_estimator,
-                    fit_and_score_kwargs,
-                )
+                HP_TUNING = F.table_function(random_udtf_name)

-… (5 removed lines not captured in this view)
-                    self.y = y
-                    self.test_indices = indices
-                    self.params_to_evaluate = params_to_evaluate
-                    self.base_estimator = base_estimator
-                    self.fit_and_score_kwargs = fit_and_score_kwargs
-                    self.fit_score_params: List[Any] = []
-                    self.cv_indices_set: Set[int] = set()
-
-                def process(self, idx: int, params_idx: int, cv_idx: int) -> None:
-                    self.fit_score_params.extend([[idx, params_idx, cv_idx]])
-                    self.cv_indices_set.add(cv_idx)
-
-                def end_partition(self) -> Iterator[Tuple[int, str]]:
-                    from sklearn.base import clone
-                    from sklearn.model_selection._validation import _fit_and_score
-                    from sklearn.utils.parallel import Parallel, delayed
-
-                    cached_train_test_indices = {}
-                    # Calculate the full index here to avoid duplicate calculation (which consumes a lot of memory)
-                    full_index = np.arange(DATA_LENGTH)
-                    for i in self.cv_indices_set:
-                        cached_train_test_indices[i] = [
-                            np.setdiff1d(full_index, self.test_indices[i]),
-                            self.test_indices[i],
-                        ]
-
-                    parallel = Parallel(n_jobs=_N_JOBS, pre_dispatch=_PRE_DISPATCH)
-
-                    out = parallel(
-                        delayed(_fit_and_score)(
-                            clone(self.base_estimator),
-                            self.X,
-                            self.y,
-                            train=cached_train_test_indices[split_idx][0],
-                            test=cached_train_test_indices[split_idx][1],
-                            parameters=self.params_to_evaluate[cand_idx],
-                            split_progress=(split_idx, n_splits),
-                            candidate_progress=(cand_idx, n_candidates),
-                            **self.fit_and_score_kwargs,  # load sample weight here
-                        )
-                        for _, cand_idx, split_idx in self.fit_score_params
+                # param_indices is for the index for each parameter grid;
+                # cv_indices is for the index for each cross_validator's fold;
+                # param_cv_indices is for the index for the product of (len(param_indices) * len(cv_indices))
+                cv_indices, param_indices = zip(
+                    *product(range(cross_validator_indices_length), range(parameter_grid_length))
                )

-… (8 removed lines not captured in this view)
+                indices_info_pandas = pd.DataFrame(
+                    {
+                        "IDX": [
+                            i // _NUM_CPUs for i in range(parameter_grid_length * cross_validator_indices_length)
+                        ],
+                        "PARAM_IND": param_indices,
+                        "CV_IND": cv_indices,
+                    }
                )

-… (8 removed lines not captured in this view)
-                    replace=True,
-                    is_permanent=False,
-                    imports=imports,  # type: ignore[arg-type]
-                    statement_params=udtf_statement_params,
-                )
-
-                HP_TUNING = F.table_function(random_udtf_name)
-
-                # param_indices is for the index for each parameter grid;
-                # cv_indices is for the index for each cross_validator's fold;
-                # param_cv_indices is for the index for the product of (len(param_indices) * len(cv_indices))
-                cv_indices, param_indices = zip(
-                    *product(range(cross_validator_indices_length), range(parameter_grid_length))
-                )
-
-                indices_info_pandas = pd.DataFrame(
-                    {
-                        "IDX": [i // _NUM_CPUs for i in range(parameter_grid_length * cross_validator_indices_length)],
-                        "PARAM_IND": param_indices,
-                        "CV_IND": cv_indices,
-                    }
-                )
-
-                indices_info_sp = session.create_dataframe(indices_info_pandas)
-                # execute udtf by querying HP_TUNING table
-                HP_raw_results = indices_info_sp.select(
-                    (
-                        HP_TUNING(indices_info_sp["IDX"], indices_info_sp["PARAM_IND"], indices_info_sp["CV_IND"]).over(
-                            partition_by="IDX"
+                indices_info_sp = session.create_dataframe(indices_info_pandas)
+                # execute udtf by querying HP_TUNING table
+                HP_raw_results = indices_info_sp.select(
+                    (
+                        HP_TUNING(
+                            indices_info_sp["IDX"], indices_info_sp["PARAM_IND"], indices_info_sp["CV_IND"]
+                        ).over(partition_by="IDX")
+                    ),
                )
-                    ),
-                )
-
-                first_test_score, cv_results_ = construct_cv_results_memory_efficient_version(
-                    estimator,
-                    n_splits,
-                    list(param_grid),
-                    HP_raw_results.select("EACH_CV_RESULTS").sort(F.col("FIRST_IDX")).collect(),
-                    cross_validator_indices_length,
-                    parameter_grid_length,
-                )
-
-                estimator.cv_results_ = cv_results_
-                estimator.multimetric_ = isinstance(first_test_score, dict)
-
-                # check refit_metric now for a callable scorer that is multimetric
-                if callable(estimator.scoring) and estimator.multimetric_:
-                    estimator._check_refit_for_multimetric(first_test_score)
-                    refit_metric = estimator.refit
-
-                # For multi-metric evaluation, store the best_index_, best_params_ and
-                # best_score_ iff refit is one of the scorer names
-                # In single metric evaluation, refit_metric is "score"
-                if estimator.refit or not estimator.multimetric_:
-                    estimator.best_index_ = estimator._select_best_index(estimator.refit, refit_metric, cv_results_)
-                    if not callable(estimator.refit):
-                        # With a non-custom callable, we can select the best score
-                        # based on the best index
-                        estimator.best_score_ = cv_results_[f"mean_test_{refit_metric}"][estimator.best_index_]
-                    estimator.best_params_ = cv_results_["params"][estimator.best_index_]
-
-                if estimator.refit:
-                    estimator.best_estimator_ = clone(base_estimator).set_params(
-                        **clone(estimator.best_params_, safe=False)
-                    )

-                # Let the sproc use all cores to refit.
-                estimator.n_jobs = estimator.n_jobs or -1
-
-                # process the input as args
-                argspec = inspect.getfullargspec(estimator.fit)
-                args = {"X": X}
-                if label_cols:
-                    label_arg_name = "Y" if "Y" in argspec.args else "y"
-                    args[label_arg_name] = y
-                if sample_weight_col is not None and "sample_weight" in argspec.args:
-                    args["sample_weight"] = df[sample_weight_col].squeeze()
-                # estimator.refit = original_refit
-                refit_start_time = time.time()
-                estimator.best_estimator_.fit(**args)
-                refit_end_time = time.time()
-                estimator.refit_time_ = refit_end_time - refit_start_time
+                first_test_score, cv_results_ = construct_cv_results_memory_efficient_version(
+                    estimator,
+                    n_splits,
+                    list(param_grid),
+                    HP_raw_results.select("EACH_CV_RESULTS").sort(F.col("FIRST_IDX")).collect(),
+                    cross_validator_indices_length,
+                    parameter_grid_length,
+                )

-                if hasattr(estimator.best_estimator_, "feature_names_in_"):
-                    estimator.feature_names_in_ = estimator.best_estimator_.feature_names_in_
+                estimator.cv_results_ = cv_results_
+                estimator.multimetric_ = isinstance(first_test_score, dict)
+
+                # check refit_metric now for a callable scorer that is multimetric
+                if callable(estimator.scoring) and estimator.multimetric_:
+                    estimator._check_refit_for_multimetric(first_test_score)
+                    refit_metric = estimator.refit
+
+                # For multi-metric evaluation, store the best_index_, best_params_ and
+                # best_score_ iff refit is one of the scorer names
+                # In single metric evaluation, refit_metric is "score"
+                if estimator.refit or not estimator.multimetric_:
+                    estimator.best_index_ = estimator._select_best_index(estimator.refit, refit_metric, cv_results_)
+                    if not callable(estimator.refit):
+                        # With a non-custom callable, we can select the best score
+                        # based on the best index
+                        estimator.best_score_ = cv_results_[f"mean_test_{refit_metric}"][estimator.best_index_]
+                    estimator.best_params_ = cv_results_["params"][estimator.best_index_]
+
+                if estimator.refit:
+                    estimator.best_estimator_ = clone(base_estimator).set_params(
+                        **clone(estimator.best_params_, safe=False)
+                    )

-                # Store the only scorer not as a dict for single metric evaluation
-                estimator.scorer_ = scorers
-                estimator.n_splits_ = n_splits
+                # Let the sproc use all cores to refit.
+                estimator.n_jobs = estimator.n_jobs or -1
+
+                # process the input as args
+                argspec = inspect.getfullargspec(estimator.fit)
+                args = {"X": X}
+                if label_cols:
+                    label_arg_name = "Y" if "Y" in argspec.args else "y"
+                    args[label_arg_name] = y
+                if sample_weight_col is not None and "sample_weight" in argspec.args:
+                    args["sample_weight"] = df[sample_weight_col].squeeze()
+                # estimator.refit = original_refit
+                refit_start_time = time.time()
+                estimator.best_estimator_.fit(**args)
+                refit_end_time = time.time()
+                estimator.refit_time_ = refit_end_time - refit_start_time
+
+                if hasattr(estimator.best_estimator_, "feature_names_in_"):
+                    estimator.feature_names_in_ = estimator.best_estimator_.feature_names_in_
+
+                # Store the only scorer not as a dict for single metric evaluation
+                estimator.scorer_ = scorers
+                estimator.n_splits_ = n_splits
+
+                local_result_file_name = temp_file_utils.get_temp_file_path()
+
+                with open(local_result_file_name, mode="w+b") as local_result_file_obj:
+                    cp.dump(estimator, local_result_file_obj)
+
+                session.file.put(
+                    local_result_file_name,
+                    temp_stage_name,
+                    auto_compress=False,
+                    overwrite=True,
+                )

-            local_result_file_name = get_temp_file_path()
+                # Clean up the stages and files
+                session.sql(f"DROP STAGE IF EXISTS {_UDTF_STAGE_NAME}")

-            with open(local_result_file_name, mode="w+b") as local_result_file_obj:
-                cp.dump(estimator, local_result_file_obj)
+                temp_file_utils.cleanup_temp_files([local_result_file_name])

-            session.file.put(
-                local_result_file_name,
-                temp_stage_name,
-                auto_compress=False,
-                overwrite=True,
-            )
-
-            return str(os.path.basename(local_result_file_name))
+                return str(os.path.basename(local_result_file_name))
+            finally:
+                # Clean up the stages
+                session.sql(f"DROP STAGE IF EXISTS {_UDTF_STAGE_NAME}")

         sproc_export_file_name = _distributed_search(
             session,
@@ -1150,7 +1104,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             label_cols,
         )

-        local_estimator_path = get_temp_file_path()
+        local_estimator_path = temp_file_utils.get_temp_file_path()
         session.file.get(
             posixpath.join(temp_stage_name, sproc_export_file_name),
             local_estimator_path,
@@ -1159,7 +1113,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         with open(os.path.join(local_estimator_path, sproc_export_file_name), mode="r+b") as result_file_obj:
             fit_estimator = cp.load(result_file_obj)

-        cleanup_temp_files(local_estimator_path)
+        temp_file_utils.cleanup_temp_files(local_estimator_path)

         return fit_estimator

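The substantive change in the memory-efficient HPO path above is how the search UDTF is created: instead of registering an anonymous UDTF inline, 1.5.3 writes a handler template (`execute_template` from the new distributed_search_udf_file.py) to a temporary `.py` file, registers it as a permanent UDTF backed by a one-off stage (`_UDTF_STAGE_NAME`), and drops that stage in a `finally` block. A stripped-down sketch of the same pattern, assuming an open Snowpark `session`; `HANDLER_SOURCE`, `run_distributed_search`, and the demo input rows are illustrative stand-ins, not the package's actual template:

```python
import tempfile
import uuid

from snowflake.snowpark import Session, functions as F
from snowflake.snowpark.types import IntegerType, StringType, StructField, StructType
from snowflake.snowpark.udtf import UDTFRegistration

# Illustrative stand-in for the handler source shipped as execute_template.
HANDLER_SOURCE = '''
class SearchCV:
    def process(self, idx: int, params_idx: int, cv_idx: int):
        # A real handler would run _fit_and_score here; this one echoes its inputs.
        yield (idx, f"params_idx={params_idx}, cv_idx={cv_idx}")
'''


def run_distributed_search(session: Session, udtf_name: str) -> list:
    stage_name = f"MEMORY_EFFICIENT_UDTF_{str(uuid.uuid4()).replace('-', '_')}"

    # delete=False so the file can be re-opened by name on Windows.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8") as f:
        f.write(HANDLER_SOURCE)

    try:
        # Permanent UDTFs need a stage for their code, so create a throwaway one.
        session.sql(f"CREATE OR REPLACE STAGE {stage_name}").collect()
        UDTFRegistration(session).register_from_file(
            file_path=f.name,
            handler_name="SearchCV",
            name=udtf_name,
            output_schema=StructType(
                [StructField("FIRST_IDX", IntegerType()), StructField("EACH_CV_RESULTS", StringType())]
            ),
            input_types=[IntegerType(), IntegerType(), IntegerType()],
            replace=True,
            stage_location=stage_name,
            is_permanent=True,
        )

        # Query the UDTF while its stage still exists, partitioning work by IDX
        # the same way the real code fans out (param_idx, cv_idx) pairs.
        hp_tuning = F.table_function(udtf_name)
        work = session.create_dataframe(
            [(0, 0, 0), (0, 1, 0), (1, 0, 1), (1, 1, 1)], schema=["IDX", "PARAM_IND", "CV_IND"]
        )
        return work.select(
            hp_tuning(work["IDX"], work["PARAM_IND"], work["CV_IND"]).over(partition_by="IDX")
        ).collect()
    finally:
        # Dropping the stage also removes the registered UDTF's code, mirroring
        # the cleanup the diff performs on both the success and failure paths.
        session.sql(f"DROP STAGE IF EXISTS {stage_name}").collect()
```

Collecting the results inside the `try` block matters: Snowpark dataframes are lazy, so the query must run before the stage holding the UDTF's code is dropped.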