snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/file_utils.py +3 -3
- snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
- snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
- snowflake/ml/_internal/telemetry.py +11 -2
- snowflake/ml/_internal/utils/formatting.py +1 -1
- snowflake/ml/feature_store/feature_store.py +15 -106
- snowflake/ml/fileset/sfcfs.py +4 -3
- snowflake/ml/fileset/stage_fs.py +18 -0
- snowflake/ml/model/_api.py +9 -9
- snowflake/ml/model/_client/model/model_version_impl.py +20 -15
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
- snowflake/ml/model/_model_composer/model_composer.py +10 -8
- snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
- snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
- snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
- snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_packager.py +8 -6
- snowflake/ml/model/custom_model.py +3 -1
- snowflake/ml/model/type_hints.py +13 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
- snowflake/ml/modeling/_internal/model_specifications.py +3 -1
- snowflake/ml/modeling/_internal/model_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
- snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
- snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
- snowflake/ml/modeling/cluster/birch.py +33 -61
- snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
- snowflake/ml/modeling/cluster/dbscan.py +33 -61
- snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
- snowflake/ml/modeling/cluster/k_means.py +33 -61
- snowflake/ml/modeling/cluster/mean_shift.py +33 -61
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
- snowflake/ml/modeling/cluster/optics.py +33 -61
- snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
- snowflake/ml/modeling/compose/column_transformer.py +33 -61
- snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
- snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
- snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
- snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
- snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
- snowflake/ml/modeling/covariance/oas.py +33 -61
- snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
- snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
- snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
- snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
- snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/pca.py +33 -61
- snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
- snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
- snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
- snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
- snowflake/ml/modeling/framework/base.py +55 -5
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
- snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
- snowflake/ml/modeling/impute/knn_imputer.py +33 -61
- snowflake/ml/modeling/impute/missing_indicator.py +33 -61
- snowflake/ml/modeling/impute/simple_imputer.py +4 -15
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
- snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/lars.py +33 -61
- snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
- snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/perceptron.py +33 -61
- snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ridge.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
- snowflake/ml/modeling/manifold/isomap.py +33 -61
- snowflake/ml/modeling/manifold/mds.py +33 -61
- snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
- snowflake/ml/modeling/manifold/tsne.py +33 -61
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
- snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
- snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
- snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
- snowflake/ml/modeling/svm/linear_svc.py +33 -61
- snowflake/ml/modeling/svm/linear_svr.py +33 -61
- snowflake/ml/modeling/svm/nu_svc.py +33 -61
- snowflake/ml/modeling/svm/nu_svr.py +33 -61
- snowflake/ml/modeling/svm/svc.py +33 -61
- snowflake/ml/modeling/svm/svr.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
- snowflake/ml/registry/_manager/model_manager.py +6 -2
- snowflake/ml/registry/model_registry.py +100 -27
- snowflake/ml/registry/registry.py +6 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py

@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union

 import cloudpickle as cp
 import numpy as np
+import numpy.typing as npt
 from sklearn import model_selection
 from sklearn.model_selection import GridSearchCV, RandomizedSearchCV

@@ -38,9 +39,11 @@ from snowflake.snowpark.types import IntegerType, StringType, StructField, Struc

 cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
 cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
+cp.register_pickle_by_value(inspect.getmodule(snowpark_dataframe_utils.cast_snowpark_dataframe))

 _PROJECT = "ModelDevelopment"
 DEFAULT_UDTF_NJOBS = 3
+ENABLE_EFFICIENT_MEMORY_USAGE = False


 def construct_cv_results(
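Both `register_pickle_by_value` calls above rely on cloudpickle's pickle-by-value mechanism: a registered module's code is embedded in the pickle payload instead of being re-imported by name on the receiving side, which is what lets client-side helpers such as `snowpark_dataframe_utils` load inside the stored procedure. A minimal sketch of that behavior, using the stdlib `json` module as a stand-in (assumes cloudpickle >= 2.0, where `register_pickle_by_value` was added):

import json

import cloudpickle as cp

# Registering a module makes cloudpickle serialize its functions by value,
# embedding their code in the payload rather than pickling an import reference.
cp.register_pickle_by_value(json)

payload = cp.dumps(json.dumps)       # json.dumps travels with its own code
restored = cp.loads(payload)
assert restored({"a": 1}) == '{"a": 1}'

cp.unregister_pickle_by_value(json)  # restore default pickle-by-reference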
@@ -151,7 +154,63 @@ construct_cv_results(
     return multimetric, estimator._format_results(param_grid, n_split, out)


+def construct_cv_results_new_implementation(
+    estimator: Union[GridSearchCV, RandomizedSearchCV],
+    n_split: int,
+    param_grid: List[Dict[str, Any]],
+    cv_results_raw_hex: List[Row],
+    cross_validator_indices_length: int,
+    parameter_grid_length: int,
+) -> Tuple[Any, Dict[str, Any]]:
+    """Construct the cross validation result from the UDF.
+    The output is a raw dictionary generated by _fit_and_score, encoded into hex binary.
+    This function need to decode the string and then call _format_result to stick them back together
+    to align with original sklearn result.
+
+    Args:
+        estimator (Union[GridSearchCV, RandomizedSearchCV]): The sklearn object of estimator
+            GridSearchCV or RandomizedSearchCV
+        n_split (int): The number of split, which is determined by build_cross_validator.get_n_splits(X, y, groups)
+        param_grid (List[Dict[str, Any]]): the list of parameter grid or parameter sampler
+        cv_results_raw_hex (List[Row]): the list of cv_results from each cv and parameter grid combination.
+            Because UDxF can only return string, and numpy array/masked arrays cannot be encoded in a
+            json format. Each cv_result is encoded into hex string.
+        cross_validator_indices_length (int): the length of cross validator indices
+        parameter_grid_length (int): the length of parameter grid combination
+
+    Raises:
+        ValueError: Retrieved empty cross validation results
+        ValueError: Cross validator index length is 0
+        ValueError: Parameter index length is 0
+        ValueError: Retrieved incorrect dataframe dimension from Snowpark's UDTF.
+
+    Returns:
+        Tuple[Any, Dict[str, Any]]: returns first_test_score, cv_results_
+    """
+    # Filter corner cases: either the snowpark dataframe result is empty; or index length is empty
+    if len(cv_results_raw_hex) == 0:
+        raise ValueError(
+            "Retrieved empty cross validation results from snowpark. Please retry or contact snowflake support."
+        )
+    if cross_validator_indices_length == 0:
+        raise ValueError("Cross validator index length is 0. Was the CV iterator empty? ")
+    if parameter_grid_length == 0:
+        raise ValueError("Parameter index length is 0. Were there no candidates?")
+
+    all_out = []
+
+    for each_cv_result_hex in cv_results_raw_hex:
+        # convert the hex string back to cv_results_
+        hex_str = bytes.fromhex(each_cv_result_hex[0])
+        with io.BytesIO(hex_str) as f_reload:
+            out = cp.load(f_reload)
+            all_out.extend(out)
+    first_test_score = all_out[0]["test_scores"]
+    return first_test_score, estimator._format_results(param_grid, n_split, all_out)
+
+
 cp.register_pickle_by_value(inspect.getmodule(construct_cv_results))
+cp.register_pickle_by_value(inspect.getmodule(construct_cv_results_new_implementation))


 class DistributedHPOTrainer(SnowparkModelTrainer):
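The hex decoding above is the inverse of what the UDTF's `end_partition` does further down in this diff: each worker cloudpickles its list of `_fit_and_score` result dicts and returns the bytes as a hex string, since the UDTF's output column is a plain string and cannot carry binary or numpy payloads. A minimal round-trip sketch (the result dicts are illustrative stand-ins, not the exact `_fit_and_score` schema):

import io

import cloudpickle as cp

# Illustrative worker output: a list of per-(candidate, fold) result dicts.
out = [
    {"test_scores": 0.91, "fit_time": 1.2},
    {"test_scores": 0.88, "fit_time": 1.4},
]

# Encode, as in SearchCV.end_partition: pickle to bytes, then hex-encode.
with io.BytesIO() as f:
    cp.dump(out, f)
    f.seek(0)
    binary_cv_results = f.getvalue().hex()  # a str, safe to yield from the UDTF

# Decode, as in construct_cv_results_new_implementation.
hex_str = bytes.fromhex(binary_cv_results)
with io.BytesIO(hex_str) as f_reload:
    restored = cp.load(f_reload)

assert restored == out
assert restored[0]["test_scores"] == 0.91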
@@ -277,7 +336,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         imports.append(f"@{temp_stage_name}/{estimator_location}")

         search_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
-        random_udtf_name = random_name_for_temp_object(TempObjectType.
+        random_udtf_name = random_name_for_temp_object(TempObjectType.TABLE_FUNCTION)

         required_deps = dependencies + [
             "snowflake-snowpark-python<2",
@@ -602,6 +661,480 @@ class DistributedHPOTrainer(SnowparkModelTrainer):

         return fit_estimator

+    def fit_search_snowpark_new_implementation(
+        self,
+        param_grid: Union[model_selection.ParameterGrid, model_selection.ParameterSampler],
+        dataset: DataFrame,
+        session: Session,
+        estimator: Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV],
+        dependencies: List[str],
+        udf_imports: List[str],
+        input_cols: List[str],
+        label_cols: Optional[List[str]],
+        sample_weight_col: Optional[str],
+    ) -> Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV]:
+        from itertools import product
+
+        import cachetools
+        from sklearn.base import clone, is_classifier
+        from sklearn.calibration import check_cv
+
+        # Create one stage for data and for estimators.
+        temp_stage_name = random_name_for_temp_object(TempObjectType.STAGE)
+        temp_stage_creation_query = f"CREATE OR REPLACE TEMP STAGE {temp_stage_name};"
+        session.sql(temp_stage_creation_query).collect()
+
+        # Stage data as parquet file
+        dataset = snowpark_dataframe_utils.cast_snowpark_dataframe(dataset)
+        dataset_file_name = "dataset"
+        remote_file_path = f"{temp_stage_name}/{dataset_file_name}.parquet"
+        dataset.write.copy_into_location(  # type:ignore[call-overload]
+            remote_file_path, file_format_type="parquet", header=True, overwrite=True
+        )
+        imports = [f"@{row.name}" for row in session.sql(f"LIST @{temp_stage_name}/{dataset_file_name}").collect()]
+
+        # Create a temp file and dump the estimator to that file.
+        estimator_file_name = get_temp_file_path()
+        params_to_evaluate = list(param_grid)
+        n_candidates = len(params_to_evaluate)
+        _N_JOBS = estimator.n_jobs
+        _PRE_DISPATCH = estimator.pre_dispatch
+
+        with open(estimator_file_name, mode="w+b") as local_estimator_file_obj:
+            cp.dump(dict(estimator=estimator, param_grid=params_to_evaluate), local_estimator_file_obj)
+        stage_estimator_file_name = posixpath.join(temp_stage_name, os.path.basename(estimator_file_name))
+        sproc_statement_params = telemetry.get_function_usage_statement_params(
+            project=_PROJECT,
+            subproject=self._subproject,
+            function_name=telemetry.get_statement_params_full_func_name(
+                inspect.currentframe(), self.__class__.__name__
+            ),
+            api_calls=[sproc],
+        )
+        udtf_statement_params = telemetry.get_function_usage_statement_params(
+            project=_PROJECT,
+            subproject=self._subproject,
+            function_name=telemetry.get_statement_params_full_func_name(
+                inspect.currentframe(), self.__class__.__name__
+            ),
+            api_calls=[udtf],
+            custom_tags=dict([("hpo_udtf", True)]),
+        )
+
+        # Put locally serialized estimator on stage.
+        session.file.put(
+            estimator_file_name,
+            temp_stage_name,
+            auto_compress=False,
+            overwrite=True,
+        )
+        estimator_location = os.path.basename(estimator_file_name)
+        imports.append(f"@{temp_stage_name}/{estimator_location}")
+
+        search_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
+        random_udtf_name = random_name_for_temp_object(TempObjectType.FUNCTION)
+
+        required_deps = dependencies + [
+            "snowflake-snowpark-python<2",
+            "fastparquet<2023.11",
+            "pyarrow<14",
+            "cachetools<6",
+        ]
+
+        @sproc(  # type: ignore[misc]
+            is_permanent=False,
+            name=search_sproc_name,
+            packages=required_deps,  # type: ignore[arg-type]
+            replace=True,
+            session=session,
+            anonymous=True,
+            imports=imports,  # type: ignore[arg-type]
+            statement_params=sproc_statement_params,
+        )
+        def _distributed_search(
+            session: Session,
+            imports: List[str],
+            stage_estimator_file_name: str,
+            input_cols: List[str],
+            label_cols: Optional[List[str]],
+        ) -> str:
+            import os
+            import time
+            from typing import Iterator
+
+            import cloudpickle as cp
+            import pandas as pd
+            import pyarrow.parquet as pq
+            from sklearn.metrics import check_scoring
+            from sklearn.metrics._scorer import _check_multimetric_scoring
+            from sklearn.utils.validation import _check_fit_params, indexable
+
+            # import packages in sproc
+            for import_name in udf_imports:
+                importlib.import_module(import_name)
+
+            # os.cpu_count() returns the number of logical CPUs in the system. Returns None if undetermined.
+            _NUM_CPUs = os.cpu_count() or 1
+
+            # load dataset
+            data_files = [
+                filename
+                for filename in os.listdir(sys._xoptions["snowflake_import_directory"])
+                if filename.startswith(dataset_file_name)
+            ]
+            partial_df = [
+                pq.read_table(os.path.join(sys._xoptions["snowflake_import_directory"], file_name)).to_pandas()
+                for file_name in data_files
+            ]
+            df = pd.concat(partial_df, ignore_index=True)
+            df.columns = [identifier.get_inferred_name(col_) for col_ in df.columns]
+
+            X = df[input_cols]
+            y = df[label_cols].squeeze() if label_cols else None
+            DATA_LENGTH = len(df)
+            fit_params = {}
+            if sample_weight_col:
+                fit_params["sample_weight"] = df[sample_weight_col].squeeze()
+
+            local_estimator_file_folder_name = get_temp_file_path()
+            session.file.get(stage_estimator_file_name, local_estimator_file_folder_name)
+
+            local_estimator_file_path = os.path.join(
+                local_estimator_file_folder_name, os.listdir(local_estimator_file_folder_name)[0]
+            )
+            with open(local_estimator_file_path, mode="r+b") as local_estimator_file_obj:
+                estimator = cp.load(local_estimator_file_obj)["estimator"]
+
+            # preprocess the attributes - (1) scorer
+            refit_metric = "score"
+            if callable(estimator.scoring):
+                scorers = estimator.scoring
+            elif estimator.scoring is None or isinstance(estimator.scoring, str):
+                scorers = check_scoring(estimator.estimator, estimator.scoring)
+            else:
+                scorers = _check_multimetric_scoring(estimator.estimator, estimator.scoring)
+                estimator._check_refit_for_multimetric(scorers)
+                refit_metric = estimator.refit
+
+            # preprocess the attributes - (2) check fit_params
+            groups = None
+            X, y, _ = indexable(X, y, groups)
+            fit_params = _check_fit_params(X, fit_params)
+
+            # preprocess the attributes - (3) safe clone base estimator
+            base_estimator = clone(estimator.estimator)
+
+            # preprocess the attributes - (4) check cv
+            build_cross_validator = check_cv(estimator.cv, y, classifier=is_classifier(estimator.estimator))
+            n_splits = build_cross_validator.get_n_splits(X, y, groups)
+
+            # preprocess the attributes - (5) generate fit_and_score_kwargs
+            fit_and_score_kwargs = dict(
+                scorer=scorers,
+                fit_params=fit_params,
+                return_train_score=estimator.return_train_score,
+                return_n_test_samples=True,
+                return_times=True,
+                return_parameters=False,
+                error_score=estimator.error_score,
+                verbose=estimator.verbose,
+            )
+
+            # (1) store the cross_validator's test indices only to save space
+            cross_validator_indices = [test for _, test in build_cross_validator.split(X, y, None)]
+            local_indices_file_name = get_temp_file_path()
+            with open(local_indices_file_name, mode="w+b") as local_indices_file_obj:
+                cp.dump(cross_validator_indices, local_indices_file_obj)
+
+            # Put locally serialized indices on stage.
+            session.file.put(
+                local_indices_file_name,
+                temp_stage_name,
+                auto_compress=False,
+                overwrite=True,
+            )
+            indices_location = os.path.basename(local_indices_file_name)
+            imports.append(f"@{temp_stage_name}/{indices_location}")
+
+            # (2) store the base estimator
+            local_base_estimator_file_name = get_temp_file_path()
+            with open(local_base_estimator_file_name, mode="w+b") as local_base_estimator_file_obj:
+                cp.dump(base_estimator, local_base_estimator_file_obj)
+            session.file.put(
+                local_base_estimator_file_name,
+                temp_stage_name,
+                auto_compress=False,
+                overwrite=True,
+            )
+            base_estimator_location = os.path.basename(local_base_estimator_file_name)
+            imports.append(f"@{temp_stage_name}/{base_estimator_location}")
+
+            # (3) store the fit_and_score_kwargs
+            local_fit_and_score_kwargs_file_name = get_temp_file_path()
+            with open(local_fit_and_score_kwargs_file_name, mode="w+b") as local_fit_and_score_kwargs_file_obj:
+                cp.dump(fit_and_score_kwargs, local_fit_and_score_kwargs_file_obj)
+            session.file.put(
+                local_fit_and_score_kwargs_file_name,
+                temp_stage_name,
+                auto_compress=False,
+                overwrite=True,
+            )
+            fit_and_score_kwargs_location = os.path.basename(local_fit_and_score_kwargs_file_name)
+            imports.append(f"@{temp_stage_name}/{fit_and_score_kwargs_location}")
+
+            cross_validator_indices_length = int(len(cross_validator_indices))
+            parameter_grid_length = len(param_grid)
+
+            assert estimator is not None
+
+            @cachetools.cached(cache={})
+            def _load_data_into_udf() -> Tuple[
+                npt.NDArray[Any],
+                npt.NDArray[Any],
+                List[List[int]],
+                List[Dict[str, Any]],
+                object,
+                Dict[str, Any],
+            ]:
+                import pyarrow.parquet as pq
+
+                data_files = [
+                    filename
+                    for filename in os.listdir(sys._xoptions["snowflake_import_directory"])
+                    if filename.startswith(dataset_file_name)
+                ]
+                partial_df = [
+                    pq.read_table(os.path.join(sys._xoptions["snowflake_import_directory"], file_name)).to_pandas()
+                    for file_name in data_files
+                ]
+                df = pd.concat(partial_df, ignore_index=True)
+                df.columns = [identifier.get_inferred_name(col_) for col_ in df.columns]
+
+                # load parameter grid
+                local_estimator_file_path = os.path.join(
+                    sys._xoptions["snowflake_import_directory"], f"{estimator_location}"
+                )
+                with open(local_estimator_file_path, mode="rb") as local_estimator_file_obj:
+                    estimator_objects = cp.load(local_estimator_file_obj)
+                    params_to_evaluate = estimator_objects["param_grid"]
+
+                # load indices
+                local_indices_file_path = os.path.join(
+                    sys._xoptions["snowflake_import_directory"], f"{indices_location}"
+                )
+                with open(local_indices_file_path, mode="rb") as local_indices_file_obj:
+                    indices = cp.load(local_indices_file_obj)
+
+                # load base estimator
+                local_base_estimator_file_path = os.path.join(
+                    sys._xoptions["snowflake_import_directory"], f"{base_estimator_location}"
+                )
+                with open(local_base_estimator_file_path, mode="rb") as local_base_estimator_file_obj:
+                    base_estimator = cp.load(local_base_estimator_file_obj)
+
+                # load fit_and_score_kwargs
+                local_fit_and_score_kwargs_file_path = os.path.join(
+                    sys._xoptions["snowflake_import_directory"], f"{fit_and_score_kwargs_location}"
+                )
+                with open(local_fit_and_score_kwargs_file_path, mode="rb") as local_fit_and_score_kwargs_file_obj:
+                    fit_and_score_kwargs = cp.load(local_fit_and_score_kwargs_file_obj)
+
+                # convert dataframe to numpy would save memory consumption
+                return (
+                    df[input_cols].to_numpy(),
+                    df[label_cols].squeeze().to_numpy(),
+                    indices,
+                    params_to_evaluate,
+                    base_estimator,
+                    fit_and_score_kwargs,
+                )
+
+            # Note Table functions (UDTFs) have a limit of 500 input arguments and 500 output columns.
+            class SearchCV:
+                def __init__(self) -> None:
+                    X, y, indices, params_to_evaluate, base_estimator, fit_and_score_kwargs = _load_data_into_udf()
+                    self.X = X
+                    self.y = y
+                    self.indices = indices
+                    self.params_to_evaluate = params_to_evaluate
+                    self.base_estimator = base_estimator
+                    self.fit_and_score_kwargs = fit_and_score_kwargs
+                    self.fit_score_params: List[Any] = []
+
+                def process(self, idx: int, params_idx: int, cv_idx: int) -> None:
+                    # 1. Calculate the parameter list
+                    parameters = self.params_to_evaluate[params_idx]
+                    # 2. Calculate the cross validator indices
+                    # cross validator's indices: we stored test indices only (to save space);
+                    # use the full index to re-construct each train index back.
+                    full_index = np.array([i for i in range(DATA_LENGTH)])
+                    test_index = self.indices[cv_idx]
+                    train_index = np.setdiff1d(full_index, test_index)
+                    self.fit_score_params.extend([[idx, (params_idx, parameters), (cv_idx, (train_index, test_index))]])
+
+                def end_partition(self) -> Iterator[Tuple[int, str]]:
+                    from sklearn.base import clone
+                    from sklearn.model_selection._validation import _fit_and_score
+                    from sklearn.utils.parallel import Parallel, delayed
+
+                    parallel = Parallel(n_jobs=_N_JOBS, pre_dispatch=_PRE_DISPATCH)
+
+                    out = parallel(
+                        delayed(_fit_and_score)(
+                            clone(self.base_estimator),
+                            self.X,
+                            self.y,
+                            train=train,
+                            test=test,
+                            parameters=parameters,
+                            split_progress=(split_idx, n_splits),
+                            candidate_progress=(cand_idx, n_candidates),
+                            **self.fit_and_score_kwargs,  # load sample weight here
+                        )
+                        for _, (cand_idx, parameters), (split_idx, (train, test)) in self.fit_score_params
+                    )
+
+                    binary_cv_results = None
+                    with io.BytesIO() as f:
+                        cp.dump(out, f)
+                        f.seek(0)
+                        binary_cv_results = f.getvalue().hex()
+                    yield (
+                        self.fit_score_params[0][0],
+                        binary_cv_results,
+                    )
+
+            session.udtf.register(
+                SearchCV,
+                output_schema=StructType([StructField("IDX", IntegerType()), StructField("CV_RESULTS", StringType())]),
+                input_types=[IntegerType(), IntegerType(), IntegerType()],
+                name=random_udtf_name,
+                packages=required_deps,  # type: ignore[arg-type]
+                replace=True,
+                is_permanent=False,
+                imports=imports,  # type: ignore[arg-type]
+                statement_params=udtf_statement_params,
+            )
+
+            HP_TUNING = F.table_function(random_udtf_name)
+
+            # param_indices is for the index for each parameter grid;
+            # cv_indices is for the index for each cross_validator's fold;
+            # param_cv_indices is for the index for the product of (len(param_indices) * len(cv_indices))
+            param_indices, cv_indices = zip(
+                *product(range(parameter_grid_length), range(cross_validator_indices_length))
+            )
+
+            indices_info_pandas = pd.DataFrame(
+                {
+                    "IDX": [i // _NUM_CPUs for i in range(parameter_grid_length * cross_validator_indices_length)],
+                    "PARAM_IND": param_indices,
+                    "CV_IND": cv_indices,
+                }
+            )
+
+            indices_info_sp = session.create_dataframe(indices_info_pandas)
+            # execute udtf by querying HP_TUNING table
+            HP_raw_results = indices_info_sp.select(
+                (
+                    HP_TUNING(indices_info_sp["IDX"], indices_info_sp["PARAM_IND"], indices_info_sp["CV_IND"]).over(
+                        partition_by="IDX"
+                    )
+                ),
+            )
+
+            first_test_score, cv_results_ = construct_cv_results_new_implementation(
+                estimator,
+                n_splits,
+                list(param_grid),
+                HP_raw_results.select("CV_RESULTS").sort(F.col("IDX")).collect(),
+                cross_validator_indices_length,
+                parameter_grid_length,
+            )
+
+            estimator.cv_results_ = cv_results_
+            estimator.multimetric_ = isinstance(first_test_score, dict)
+
+            # check refit_metric now for a callable scorer that is multimetric
+            if callable(estimator.scoring) and estimator.multimetric_:
+                estimator._check_refit_for_multimetric(first_test_score)
+                refit_metric = estimator.refit
+
+            # For multi-metric evaluation, store the best_index_, best_params_ and
+            # best_score_ iff refit is one of the scorer names
+            # In single metric evaluation, refit_metric is "score"
+            if estimator.refit or not estimator.multimetric_:
+                estimator.best_index_ = estimator._select_best_index(estimator.refit, refit_metric, cv_results_)
+                if not callable(estimator.refit):
+                    # With a non-custom callable, we can select the best score
+                    # based on the best index
+                    estimator.best_score_ = cv_results_[f"mean_test_{refit_metric}"][estimator.best_index_]
+                estimator.best_params_ = cv_results_["params"][estimator.best_index_]
+
+            if estimator.refit:
+                estimator.best_estimator_ = clone(base_estimator).set_params(
+                    **clone(estimator.best_params_, safe=False)
+                )
+
+                # Let the sproc use all cores to refit.
+                estimator.n_jobs = estimator.n_jobs or -1
+
+                # process the input as args
+                argspec = inspect.getfullargspec(estimator.fit)
+                args = {"X": X}
+                if label_cols:
+                    label_arg_name = "Y" if "Y" in argspec.args else "y"
+                    args[label_arg_name] = y
+                if sample_weight_col is not None and "sample_weight" in argspec.args:
+                    args["sample_weight"] = df[sample_weight_col].squeeze()
+                # estimator.refit = original_refit
+                refit_start_time = time.time()
+                estimator.best_estimator_.fit(**args)
+                refit_end_time = time.time()
+                estimator.refit_time_ = refit_end_time - refit_start_time
+
+                if hasattr(estimator.best_estimator_, "feature_names_in_"):
+                    estimator.feature_names_in_ = estimator.best_estimator_.feature_names_in_
+
+            # Store the only scorer not as a dict for single metric evaluation
+            estimator.scorer_ = scorers
+            estimator.n_splits_ = n_splits
+
+            local_result_file_name = get_temp_file_path()
+
+            with open(local_result_file_name, mode="w+b") as local_result_file_obj:
+                cp.dump(estimator, local_result_file_obj)
+
+            session.file.put(
+                local_result_file_name,
+                temp_stage_name,
+                auto_compress=False,
+                overwrite=True,
+            )
+
+            return str(os.path.basename(local_result_file_name))
+
+        sproc_export_file_name = _distributed_search(
+            session,
+            imports,
+            stage_estimator_file_name,
+            input_cols,
+            label_cols,
+        )
+
+        local_estimator_path = get_temp_file_path()
+        session.file.get(
+            posixpath.join(temp_stage_name, sproc_export_file_name),
+            local_estimator_path,
+        )
+
+        with open(os.path.join(local_estimator_path, sproc_export_file_name), mode="r+b") as result_file_obj:
+            fit_estimator = cp.load(result_file_obj)
+
+        cleanup_temp_files(local_estimator_path)
+
+        return fit_estimator
+
     def train(self) -> object:
         """
         Runs hyper parameter optimization by distributing the tasks across warehouse.
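The task fan-out in the method above is easier to see in isolation: every (parameter candidate, CV fold) pair becomes one row, and the IDX column groups consecutive rows into one UDTF partition, so a single end_partition call runs up to _NUM_CPUs _fit_and_score tasks through joblib in parallel. A standalone sketch with small illustrative sizes:

from itertools import product

import pandas as pd

parameter_grid_length = 4           # illustrative: 4 hyperparameter candidates
cross_validator_indices_length = 3  # illustrative: 3 CV folds
_NUM_CPUs = 2                       # stand-in for os.cpu_count() on the warehouse node

# One row per (candidate, fold) pair, mirroring the method above.
param_indices, cv_indices = zip(
    *product(range(parameter_grid_length), range(cross_validator_indices_length))
)
indices_info = pd.DataFrame(
    {
        # Consecutive rows share an IDX, so each UDTF partition receives
        # up to _NUM_CPUs (candidate, fold) tasks to execute in parallel.
        "IDX": [i // _NUM_CPUs for i in range(parameter_grid_length * cross_validator_indices_length)],
        "PARAM_IND": param_indices,
        "CV_IND": cv_indices,
    }
)
print(indices_info.groupby("IDX").size())  # 6 partitions of 2 tasks each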
@@ -630,6 +1163,19 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
             pkg_versions=model_spec.pkgDependencies, session=self.session
         )
+        if ENABLE_EFFICIENT_MEMORY_USAGE:
+            return self.fit_search_snowpark_new_implementation(
+                param_grid=param_grid,
+                dataset=self.dataset,
+                session=self.session,
+                estimator=self.estimator,
+                dependencies=relaxed_dependencies,
+                udf_imports=["sklearn"],
+                input_cols=self.input_cols,
+                label_cols=self.label_cols,
+                sample_weight_col=self.sample_weight_col,
+            )
+
         return self.fit_search_snowpark(
             param_grid=param_grid,
             dataset=self.dataset,
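Since ENABLE_EFFICIENT_MEMORY_USAGE ships as False, the new code path is dormant in 1.4.0 and nothing in this diff exposes a public switch for it. The sketch below shows how it could be exercised anyway by flipping the module-level flag before fitting; this is unsupported experimentation against internal modules, and the GridSearchCV/XGBClassifier wiring is a plausible setup rather than an example taken from the package docs:

# Unsupported sketch: flip the internal flag so DistributedHPOTrainer.train()
# dispatches to fit_search_snowpark_new_implementation instead of the old path.
from snowflake.ml.modeling._internal.snowpark_implementations import (
    distributed_hpo_trainer,
)
from snowflake.ml.modeling.model_selection import GridSearchCV
from snowflake.ml.modeling.xgboost import XGBClassifier

distributed_hpo_trainer.ENABLE_EFFICIENT_MEMORY_USAGE = True  # default: False

search = GridSearchCV(
    estimator=XGBClassifier(),
    param_grid={"n_estimators": [50, 100], "max_depth": [3, 5]},
    input_cols=["FEATURE_1", "FEATURE_2"],
    label_cols=["LABEL"],
    output_cols=["PREDICTION"],
)
# search.fit(training_df)  # training_df: a Snowpark DataFrame with those columns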