snowflake-ml-python 1.3.1__py3-none-any.whl → 1.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +11 -1
- snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
- snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
- snowflake/ml/_internal/utils/formatting.py +1 -1
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/feature_store/feature_store.py +166 -184
- snowflake/ml/feature_store/feature_view.py +12 -24
- snowflake/ml/fileset/sfcfs.py +56 -50
- snowflake/ml/fileset/stage_fs.py +48 -13
- snowflake/ml/model/_client/model/model_version_impl.py +6 -49
- snowflake/ml/model/_client/ops/model_ops.py +78 -29
- snowflake/ml/model/_client/sql/model.py +23 -2
- snowflake/ml/model/_client/sql/model_version.py +22 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -3
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +5 -2
- snowflake/ml/model/_model_composer/model_composer.py +7 -5
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +19 -54
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +8 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +13 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +1 -1
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +36 -6
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -2
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/custom_model.py +3 -1
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/model_specifications.py +3 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +545 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +8 -5
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +195 -123
- snowflake/ml/modeling/cluster/affinity_propagation.py +195 -123
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +195 -123
- snowflake/ml/modeling/cluster/birch.py +195 -123
- snowflake/ml/modeling/cluster/bisecting_k_means.py +195 -123
- snowflake/ml/modeling/cluster/dbscan.py +195 -123
- snowflake/ml/modeling/cluster/feature_agglomeration.py +195 -123
- snowflake/ml/modeling/cluster/k_means.py +195 -123
- snowflake/ml/modeling/cluster/mean_shift.py +195 -123
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +195 -123
- snowflake/ml/modeling/cluster/optics.py +195 -123
- snowflake/ml/modeling/cluster/spectral_biclustering.py +195 -123
- snowflake/ml/modeling/cluster/spectral_clustering.py +195 -123
- snowflake/ml/modeling/cluster/spectral_coclustering.py +195 -123
- snowflake/ml/modeling/compose/column_transformer.py +195 -123
- snowflake/ml/modeling/compose/transformed_target_regressor.py +195 -123
- snowflake/ml/modeling/covariance/elliptic_envelope.py +195 -123
- snowflake/ml/modeling/covariance/empirical_covariance.py +195 -123
- snowflake/ml/modeling/covariance/graphical_lasso.py +195 -123
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +195 -123
- snowflake/ml/modeling/covariance/ledoit_wolf.py +195 -123
- snowflake/ml/modeling/covariance/min_cov_det.py +195 -123
- snowflake/ml/modeling/covariance/oas.py +195 -123
- snowflake/ml/modeling/covariance/shrunk_covariance.py +195 -123
- snowflake/ml/modeling/decomposition/dictionary_learning.py +195 -123
- snowflake/ml/modeling/decomposition/factor_analysis.py +195 -123
- snowflake/ml/modeling/decomposition/fast_ica.py +195 -123
- snowflake/ml/modeling/decomposition/incremental_pca.py +195 -123
- snowflake/ml/modeling/decomposition/kernel_pca.py +195 -123
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +195 -123
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +195 -123
- snowflake/ml/modeling/decomposition/pca.py +195 -123
- snowflake/ml/modeling/decomposition/sparse_pca.py +195 -123
- snowflake/ml/modeling/decomposition/truncated_svd.py +195 -123
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +195 -123
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +195 -123
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/bagging_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/bagging_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/isolation_forest.py +195 -123
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/stacking_regressor.py +195 -123
- snowflake/ml/modeling/ensemble/voting_classifier.py +195 -123
- snowflake/ml/modeling/ensemble/voting_regressor.py +195 -123
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +195 -123
- snowflake/ml/modeling/feature_selection/select_fdr.py +195 -123
- snowflake/ml/modeling/feature_selection/select_fpr.py +195 -123
- snowflake/ml/modeling/feature_selection/select_fwe.py +195 -123
- snowflake/ml/modeling/feature_selection/select_k_best.py +195 -123
- snowflake/ml/modeling/feature_selection/select_percentile.py +195 -123
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +195 -123
- snowflake/ml/modeling/feature_selection/variance_threshold.py +195 -123
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +24 -6
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +195 -123
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +195 -123
- snowflake/ml/modeling/impute/iterative_imputer.py +195 -123
- snowflake/ml/modeling/impute/knn_imputer.py +195 -123
- snowflake/ml/modeling/impute/missing_indicator.py +195 -123
- snowflake/ml/modeling/impute/simple_imputer.py +4 -15
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +195 -123
- snowflake/ml/modeling/kernel_approximation/nystroem.py +195 -123
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +195 -123
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +195 -123
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +195 -123
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +195 -123
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +198 -125
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +198 -125
- snowflake/ml/modeling/linear_model/ard_regression.py +195 -123
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +195 -123
- snowflake/ml/modeling/linear_model/elastic_net.py +195 -123
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +195 -123
- snowflake/ml/modeling/linear_model/gamma_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/huber_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/lars.py +195 -123
- snowflake/ml/modeling/linear_model/lars_cv.py +195 -123
- snowflake/ml/modeling/linear_model/lasso.py +195 -123
- snowflake/ml/modeling/linear_model/lasso_cv.py +195 -123
- snowflake/ml/modeling/linear_model/lasso_lars.py +195 -123
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +195 -123
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +195 -123
- snowflake/ml/modeling/linear_model/linear_regression.py +195 -123
- snowflake/ml/modeling/linear_model/logistic_regression.py +195 -123
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +195 -123
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +195 -123
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +195 -123
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +195 -123
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +195 -123
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +195 -123
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +195 -123
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/perceptron.py +195 -123
- snowflake/ml/modeling/linear_model/poisson_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/ransac_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/ridge.py +195 -123
- snowflake/ml/modeling/linear_model/ridge_classifier.py +195 -123
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +195 -123
- snowflake/ml/modeling/linear_model/ridge_cv.py +195 -123
- snowflake/ml/modeling/linear_model/sgd_classifier.py +195 -123
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +195 -123
- snowflake/ml/modeling/linear_model/sgd_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +195 -123
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +195 -123
- snowflake/ml/modeling/manifold/isomap.py +195 -123
- snowflake/ml/modeling/manifold/mds.py +195 -123
- snowflake/ml/modeling/manifold/spectral_embedding.py +195 -123
- snowflake/ml/modeling/manifold/tsne.py +195 -123
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +195 -123
- snowflake/ml/modeling/mixture/gaussian_mixture.py +195 -123
- snowflake/ml/modeling/model_selection/grid_search_cv.py +42 -18
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +42 -18
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +195 -123
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +195 -123
- snowflake/ml/modeling/multiclass/output_code_classifier.py +195 -123
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +195 -123
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +195 -123
- snowflake/ml/modeling/naive_bayes/complement_nb.py +195 -123
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +195 -123
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +195 -123
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +195 -123
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +195 -123
- snowflake/ml/modeling/neighbors/kernel_density.py +195 -123
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +195 -123
- snowflake/ml/modeling/neighbors/nearest_centroid.py +195 -123
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +195 -123
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +195 -123
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +195 -123
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +195 -123
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +195 -123
- snowflake/ml/modeling/neural_network/mlp_classifier.py +195 -123
- snowflake/ml/modeling/neural_network/mlp_regressor.py +195 -123
- snowflake/ml/modeling/pipeline/pipeline.py +4 -4
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +195 -123
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +195 -123
- snowflake/ml/modeling/semi_supervised/label_spreading.py +195 -123
- snowflake/ml/modeling/svm/linear_svc.py +195 -123
- snowflake/ml/modeling/svm/linear_svr.py +195 -123
- snowflake/ml/modeling/svm/nu_svc.py +195 -123
- snowflake/ml/modeling/svm/nu_svr.py +195 -123
- snowflake/ml/modeling/svm/svc.py +195 -123
- snowflake/ml/modeling/svm/svr.py +195 -123
- snowflake/ml/modeling/tree/decision_tree_classifier.py +195 -123
- snowflake/ml/modeling/tree/decision_tree_regressor.py +195 -123
- snowflake/ml/modeling/tree/extra_tree_classifier.py +195 -123
- snowflake/ml/modeling/tree/extra_tree_regressor.py +195 -123
- snowflake/ml/modeling/xgboost/xgb_classifier.py +195 -123
- snowflake/ml/modeling/xgboost/xgb_regressor.py +195 -123
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +195 -123
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +195 -123
- snowflake/ml/registry/_manager/model_manager.py +5 -1
- snowflake/ml/registry/model_registry.py +99 -26
- snowflake/ml/registry/registry.py +3 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/METADATA +94 -55
- {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/RECORD +218 -212
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.3.1.dist-info → snowflake_ml_python-1.4.1.dist-info}/top_level.txt +0 -0
@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|
8
8
|
|
9
9
|
import cloudpickle as cp
|
10
10
|
import numpy as np
|
11
|
+
import numpy.typing as npt
|
11
12
|
from sklearn import model_selection
|
12
13
|
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
|
13
14
|
|
@@ -38,9 +39,11 @@ from snowflake.snowpark.types import IntegerType, StringType, StructField, Struc
|
|
38
39
|
|
39
40
|
cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
|
40
41
|
cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
|
42
|
+
cp.register_pickle_by_value(inspect.getmodule(snowpark_dataframe_utils.cast_snowpark_dataframe))
|
41
43
|
|
42
44
|
_PROJECT = "ModelDevelopment"
|
43
45
|
DEFAULT_UDTF_NJOBS = 3
|
46
|
+
ENABLE_EFFICIENT_MEMORY_USAGE = False
|
44
47
|
|
45
48
|
|
46
49
|
def construct_cv_results(
|
@@ -151,7 +154,63 @@ def construct_cv_results(
|
|
151
154
|
return multimetric, estimator._format_results(param_grid, n_split, out)
|
152
155
|
|
153
156
|
|
157
|
+
def construct_cv_results_new_implementation(
|
158
|
+
estimator: Union[GridSearchCV, RandomizedSearchCV],
|
159
|
+
n_split: int,
|
160
|
+
param_grid: List[Dict[str, Any]],
|
161
|
+
cv_results_raw_hex: List[Row],
|
162
|
+
cross_validator_indices_length: int,
|
163
|
+
parameter_grid_length: int,
|
164
|
+
) -> Tuple[Any, Dict[str, Any]]:
|
165
|
+
"""Construct the cross validation result from the UDF.
|
166
|
+
The output is a raw dictionary generated by _fit_and_score, encoded into hex binary.
|
167
|
+
This function need to decode the string and then call _format_result to stick them back together
|
168
|
+
to align with original sklearn result.
|
169
|
+
|
170
|
+
Args:
|
171
|
+
estimator (Union[GridSearchCV, RandomizedSearchCV]): The sklearn object of estimator
|
172
|
+
GridSearchCV or RandomizedSearchCV
|
173
|
+
n_split (int): The number of split, which is determined by build_cross_validator.get_n_splits(X, y, groups)
|
174
|
+
param_grid (List[Dict[str, Any]]): the list of parameter grid or parameter sampler
|
175
|
+
cv_results_raw_hex (List[Row]): the list of cv_results from each cv and parameter grid combination.
|
176
|
+
Because UDxF can only return string, and numpy array/masked arrays cannot be encoded in a
|
177
|
+
json format. Each cv_result is encoded into hex string.
|
178
|
+
cross_validator_indices_length (int): the length of cross validator indices
|
179
|
+
parameter_grid_length (int): the length of parameter grid combination
|
180
|
+
|
181
|
+
Raises:
|
182
|
+
ValueError: Retrieved empty cross validation results
|
183
|
+
ValueError: Cross validator index length is 0
|
184
|
+
ValueError: Parameter index length is 0
|
185
|
+
ValueError: Retrieved incorrect dataframe dimension from Snowpark's UDTF.
|
186
|
+
|
187
|
+
Returns:
|
188
|
+
Tuple[Any, Dict[str, Any]]: returns first_test_score, cv_results_
|
189
|
+
"""
|
190
|
+
# Filter corner cases: either the snowpark dataframe result is empty; or index length is empty
|
191
|
+
if len(cv_results_raw_hex) == 0:
|
192
|
+
raise ValueError(
|
193
|
+
"Retrieved empty cross validation results from snowpark. Please retry or contact snowflake support."
|
194
|
+
)
|
195
|
+
if cross_validator_indices_length == 0:
|
196
|
+
raise ValueError("Cross validator index length is 0. Was the CV iterator empty? ")
|
197
|
+
if parameter_grid_length == 0:
|
198
|
+
raise ValueError("Parameter index length is 0. Were there no candidates?")
|
199
|
+
|
200
|
+
all_out = []
|
201
|
+
|
202
|
+
for each_cv_result_hex in cv_results_raw_hex:
|
203
|
+
# convert the hex string back to cv_results_
|
204
|
+
hex_str = bytes.fromhex(each_cv_result_hex[0])
|
205
|
+
with io.BytesIO(hex_str) as f_reload:
|
206
|
+
out = cp.load(f_reload)
|
207
|
+
all_out.extend(out)
|
208
|
+
first_test_score = all_out[0]["test_scores"]
|
209
|
+
return first_test_score, estimator._format_results(param_grid, n_split, all_out)
|
210
|
+
|
211
|
+
|
154
212
|
cp.register_pickle_by_value(inspect.getmodule(construct_cv_results))
|
213
|
+
cp.register_pickle_by_value(inspect.getmodule(construct_cv_results_new_implementation))
|
155
214
|
|
156
215
|
|
157
216
|
class DistributedHPOTrainer(SnowparkModelTrainer):
|
@@ -602,6 +661,479 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
602
661
|
|
603
662
|
return fit_estimator
|
604
663
|
|
664
|
+
def fit_search_snowpark_new_implementation(
|
665
|
+
self,
|
666
|
+
param_grid: Union[model_selection.ParameterGrid, model_selection.ParameterSampler],
|
667
|
+
dataset: DataFrame,
|
668
|
+
session: Session,
|
669
|
+
estimator: Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV],
|
670
|
+
dependencies: List[str],
|
671
|
+
udf_imports: List[str],
|
672
|
+
input_cols: List[str],
|
673
|
+
label_cols: Optional[List[str]],
|
674
|
+
sample_weight_col: Optional[str],
|
675
|
+
) -> Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV]:
|
676
|
+
from itertools import product
|
677
|
+
|
678
|
+
import cachetools
|
679
|
+
from sklearn.base import clone, is_classifier
|
680
|
+
from sklearn.calibration import check_cv
|
681
|
+
|
682
|
+
# Create one stage for data and for estimators.
|
683
|
+
temp_stage_name = random_name_for_temp_object(TempObjectType.STAGE)
|
684
|
+
temp_stage_creation_query = f"CREATE OR REPLACE TEMP STAGE {temp_stage_name};"
|
685
|
+
session.sql(temp_stage_creation_query).collect()
|
686
|
+
|
687
|
+
# Stage data as parquet file
|
688
|
+
dataset = snowpark_dataframe_utils.cast_snowpark_dataframe(dataset)
|
689
|
+
dataset_file_name = "dataset"
|
690
|
+
remote_file_path = f"{temp_stage_name}/{dataset_file_name}.parquet"
|
691
|
+
dataset.write.copy_into_location( # type:ignore[call-overload]
|
692
|
+
remote_file_path, file_format_type="parquet", header=True, overwrite=True
|
693
|
+
)
|
694
|
+
imports = [f"@{row.name}" for row in session.sql(f"LIST @{temp_stage_name}/{dataset_file_name}").collect()]
|
695
|
+
|
696
|
+
# Create a temp file and dump the estimator to that file.
|
697
|
+
estimator_file_name = get_temp_file_path()
|
698
|
+
params_to_evaluate = list(param_grid)
|
699
|
+
n_candidates = len(params_to_evaluate)
|
700
|
+
_N_JOBS = estimator.n_jobs
|
701
|
+
_PRE_DISPATCH = estimator.pre_dispatch
|
702
|
+
|
703
|
+
with open(estimator_file_name, mode="w+b") as local_estimator_file_obj:
|
704
|
+
cp.dump(dict(estimator=estimator, param_grid=params_to_evaluate), local_estimator_file_obj)
|
705
|
+
stage_estimator_file_name = posixpath.join(temp_stage_name, os.path.basename(estimator_file_name))
|
706
|
+
sproc_statement_params = telemetry.get_function_usage_statement_params(
|
707
|
+
project=_PROJECT,
|
708
|
+
subproject=self._subproject,
|
709
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
710
|
+
inspect.currentframe(), self.__class__.__name__
|
711
|
+
),
|
712
|
+
api_calls=[sproc],
|
713
|
+
)
|
714
|
+
udtf_statement_params = telemetry.get_function_usage_statement_params(
|
715
|
+
project=_PROJECT,
|
716
|
+
subproject=self._subproject,
|
717
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
718
|
+
inspect.currentframe(), self.__class__.__name__
|
719
|
+
),
|
720
|
+
api_calls=[udtf],
|
721
|
+
custom_tags=dict([("hpo_udtf", True)]),
|
722
|
+
)
|
723
|
+
|
724
|
+
# Put locally serialized estimator on stage.
|
725
|
+
session.file.put(
|
726
|
+
estimator_file_name,
|
727
|
+
temp_stage_name,
|
728
|
+
auto_compress=False,
|
729
|
+
overwrite=True,
|
730
|
+
)
|
731
|
+
estimator_location = os.path.basename(estimator_file_name)
|
732
|
+
imports.append(f"@{temp_stage_name}/{estimator_location}")
|
733
|
+
|
734
|
+
search_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
|
735
|
+
random_udtf_name = random_name_for_temp_object(TempObjectType.FUNCTION)
|
736
|
+
|
737
|
+
required_deps = dependencies + [
|
738
|
+
"snowflake-snowpark-python<2",
|
739
|
+
"fastparquet<2023.11",
|
740
|
+
"pyarrow<14",
|
741
|
+
"cachetools<6",
|
742
|
+
]
|
743
|
+
|
744
|
+
@sproc( # type: ignore[misc]
|
745
|
+
is_permanent=False,
|
746
|
+
name=search_sproc_name,
|
747
|
+
packages=required_deps, # type: ignore[arg-type]
|
748
|
+
replace=True,
|
749
|
+
session=session,
|
750
|
+
anonymous=True,
|
751
|
+
imports=imports, # type: ignore[arg-type]
|
752
|
+
statement_params=sproc_statement_params,
|
753
|
+
)
|
754
|
+
def _distributed_search(
|
755
|
+
session: Session,
|
756
|
+
imports: List[str],
|
757
|
+
stage_estimator_file_name: str,
|
758
|
+
input_cols: List[str],
|
759
|
+
label_cols: Optional[List[str]],
|
760
|
+
) -> str:
|
761
|
+
import os
|
762
|
+
import time
|
763
|
+
from typing import Iterator
|
764
|
+
|
765
|
+
import cloudpickle as cp
|
766
|
+
import pandas as pd
|
767
|
+
import pyarrow.parquet as pq
|
768
|
+
from sklearn.metrics import check_scoring
|
769
|
+
from sklearn.metrics._scorer import _check_multimetric_scoring
|
770
|
+
from sklearn.utils.validation import _check_fit_params, indexable
|
771
|
+
|
772
|
+
# import packages in sproc
|
773
|
+
for import_name in udf_imports:
|
774
|
+
importlib.import_module(import_name)
|
775
|
+
|
776
|
+
# os.cpu_count() returns the number of logical CPUs in the system. Returns None if undetermined.
|
777
|
+
_NUM_CPUs = os.cpu_count() or 1
|
778
|
+
|
779
|
+
# load dataset
|
780
|
+
data_files = [
|
781
|
+
filename
|
782
|
+
for filename in os.listdir(sys._xoptions["snowflake_import_directory"])
|
783
|
+
if filename.startswith(dataset_file_name)
|
784
|
+
]
|
785
|
+
partial_df = [
|
786
|
+
pq.read_table(os.path.join(sys._xoptions["snowflake_import_directory"], file_name)).to_pandas()
|
787
|
+
for file_name in data_files
|
788
|
+
]
|
789
|
+
df = pd.concat(partial_df, ignore_index=True)
|
790
|
+
df.columns = [identifier.get_inferred_name(col_) for col_ in df.columns]
|
791
|
+
|
792
|
+
X = df[input_cols]
|
793
|
+
y = df[label_cols].squeeze() if label_cols else None
|
794
|
+
DATA_LENGTH = len(df)
|
795
|
+
fit_params = {}
|
796
|
+
if sample_weight_col:
|
797
|
+
fit_params["sample_weight"] = df[sample_weight_col].squeeze()
|
798
|
+
|
799
|
+
local_estimator_file_folder_name = get_temp_file_path()
|
800
|
+
session.file.get(stage_estimator_file_name, local_estimator_file_folder_name)
|
801
|
+
|
802
|
+
local_estimator_file_path = os.path.join(
|
803
|
+
local_estimator_file_folder_name, os.listdir(local_estimator_file_folder_name)[0]
|
804
|
+
)
|
805
|
+
with open(local_estimator_file_path, mode="r+b") as local_estimator_file_obj:
|
806
|
+
estimator = cp.load(local_estimator_file_obj)["estimator"]
|
807
|
+
|
808
|
+
# preprocess the attributes - (1) scorer
|
809
|
+
refit_metric = "score"
|
810
|
+
if callable(estimator.scoring):
|
811
|
+
scorers = estimator.scoring
|
812
|
+
elif estimator.scoring is None or isinstance(estimator.scoring, str):
|
813
|
+
scorers = check_scoring(estimator.estimator, estimator.scoring)
|
814
|
+
else:
|
815
|
+
scorers = _check_multimetric_scoring(estimator.estimator, estimator.scoring)
|
816
|
+
estimator._check_refit_for_multimetric(scorers)
|
817
|
+
refit_metric = estimator.refit
|
818
|
+
|
819
|
+
# preprocess the attributes - (2) check fit_params
|
820
|
+
groups = None
|
821
|
+
X, y, _ = indexable(X, y, groups)
|
822
|
+
fit_params = _check_fit_params(X, fit_params)
|
823
|
+
|
824
|
+
# preprocess the attributes - (3) safe clone base estimator
|
825
|
+
base_estimator = clone(estimator.estimator)
|
826
|
+
|
827
|
+
# preprocess the attributes - (4) check cv
|
828
|
+
build_cross_validator = check_cv(estimator.cv, y, classifier=is_classifier(estimator.estimator))
|
829
|
+
n_splits = build_cross_validator.get_n_splits(X, y, groups)
|
830
|
+
|
831
|
+
# preprocess the attributes - (5) generate fit_and_score_kwargs
|
832
|
+
fit_and_score_kwargs = dict(
|
833
|
+
scorer=scorers,
|
834
|
+
fit_params=fit_params,
|
835
|
+
return_train_score=estimator.return_train_score,
|
836
|
+
return_n_test_samples=True,
|
837
|
+
return_times=True,
|
838
|
+
return_parameters=False,
|
839
|
+
error_score=estimator.error_score,
|
840
|
+
verbose=estimator.verbose,
|
841
|
+
)
|
842
|
+
|
843
|
+
# (1) store the cross_validator's test indices only to save space
|
844
|
+
cross_validator_indices = [test for _, test in build_cross_validator.split(X, y, None)]
|
845
|
+
local_indices_file_name = get_temp_file_path()
|
846
|
+
with open(local_indices_file_name, mode="w+b") as local_indices_file_obj:
|
847
|
+
cp.dump(cross_validator_indices, local_indices_file_obj)
|
848
|
+
|
849
|
+
# Put locally serialized indices on stage.
|
850
|
+
session.file.put(
|
851
|
+
local_indices_file_name,
|
852
|
+
temp_stage_name,
|
853
|
+
auto_compress=False,
|
854
|
+
overwrite=True,
|
855
|
+
)
|
856
|
+
indices_location = os.path.basename(local_indices_file_name)
|
857
|
+
imports.append(f"@{temp_stage_name}/{indices_location}")
|
858
|
+
|
859
|
+
# (2) store the base estimator
|
860
|
+
local_base_estimator_file_name = get_temp_file_path()
|
861
|
+
with open(local_base_estimator_file_name, mode="w+b") as local_base_estimator_file_obj:
|
862
|
+
cp.dump(base_estimator, local_base_estimator_file_obj)
|
863
|
+
session.file.put(
|
864
|
+
local_base_estimator_file_name,
|
865
|
+
temp_stage_name,
|
866
|
+
auto_compress=False,
|
867
|
+
overwrite=True,
|
868
|
+
)
|
869
|
+
base_estimator_location = os.path.basename(local_base_estimator_file_name)
|
870
|
+
imports.append(f"@{temp_stage_name}/{base_estimator_location}")
|
871
|
+
|
872
|
+
# (3) store the fit_and_score_kwargs
|
873
|
+
local_fit_and_score_kwargs_file_name = get_temp_file_path()
|
874
|
+
with open(local_fit_and_score_kwargs_file_name, mode="w+b") as local_fit_and_score_kwargs_file_obj:
|
875
|
+
cp.dump(fit_and_score_kwargs, local_fit_and_score_kwargs_file_obj)
|
876
|
+
session.file.put(
|
877
|
+
local_fit_and_score_kwargs_file_name,
|
878
|
+
temp_stage_name,
|
879
|
+
auto_compress=False,
|
880
|
+
overwrite=True,
|
881
|
+
)
|
882
|
+
fit_and_score_kwargs_location = os.path.basename(local_fit_and_score_kwargs_file_name)
|
883
|
+
imports.append(f"@{temp_stage_name}/{fit_and_score_kwargs_location}")
|
884
|
+
|
885
|
+
cross_validator_indices_length = int(len(cross_validator_indices))
|
886
|
+
parameter_grid_length = len(param_grid)
|
887
|
+
|
888
|
+
assert estimator is not None
|
889
|
+
|
890
|
+
@cachetools.cached(cache={})
|
891
|
+
def _load_data_into_udf() -> Tuple[
|
892
|
+
npt.NDArray[Any],
|
893
|
+
npt.NDArray[Any],
|
894
|
+
List[List[int]],
|
895
|
+
List[Dict[str, Any]],
|
896
|
+
object,
|
897
|
+
Dict[str, Any],
|
898
|
+
]:
|
899
|
+
import pyarrow.parquet as pq
|
900
|
+
|
901
|
+
data_files = [
|
902
|
+
filename
|
903
|
+
for filename in os.listdir(sys._xoptions["snowflake_import_directory"])
|
904
|
+
if filename.startswith(dataset_file_name)
|
905
|
+
]
|
906
|
+
partial_df = [
|
907
|
+
pq.read_table(os.path.join(sys._xoptions["snowflake_import_directory"], file_name)).to_pandas()
|
908
|
+
for file_name in data_files
|
909
|
+
]
|
910
|
+
df = pd.concat(partial_df, ignore_index=True)
|
911
|
+
df.columns = [identifier.get_inferred_name(col_) for col_ in df.columns]
|
912
|
+
|
913
|
+
# load parameter grid
|
914
|
+
local_estimator_file_path = os.path.join(
|
915
|
+
sys._xoptions["snowflake_import_directory"], f"{estimator_location}"
|
916
|
+
)
|
917
|
+
with open(local_estimator_file_path, mode="rb") as local_estimator_file_obj:
|
918
|
+
estimator_objects = cp.load(local_estimator_file_obj)
|
919
|
+
params_to_evaluate = estimator_objects["param_grid"]
|
920
|
+
|
921
|
+
# load indices
|
922
|
+
local_indices_file_path = os.path.join(
|
923
|
+
sys._xoptions["snowflake_import_directory"], f"{indices_location}"
|
924
|
+
)
|
925
|
+
with open(local_indices_file_path, mode="rb") as local_indices_file_obj:
|
926
|
+
indices = cp.load(local_indices_file_obj)
|
927
|
+
|
928
|
+
# load base estimator
|
929
|
+
local_base_estimator_file_path = os.path.join(
|
930
|
+
sys._xoptions["snowflake_import_directory"], f"{base_estimator_location}"
|
931
|
+
)
|
932
|
+
with open(local_base_estimator_file_path, mode="rb") as local_base_estimator_file_obj:
|
933
|
+
base_estimator = cp.load(local_base_estimator_file_obj)
|
934
|
+
|
935
|
+
# load fit_and_score_kwargs
|
936
|
+
local_fit_and_score_kwargs_file_path = os.path.join(
|
937
|
+
sys._xoptions["snowflake_import_directory"], f"{fit_and_score_kwargs_location}"
|
938
|
+
)
|
939
|
+
with open(local_fit_and_score_kwargs_file_path, mode="rb") as local_fit_and_score_kwargs_file_obj:
|
940
|
+
fit_and_score_kwargs = cp.load(local_fit_and_score_kwargs_file_obj)
|
941
|
+
|
942
|
+
# convert dataframe to numpy would save memory consumption
|
943
|
+
return (
|
944
|
+
df[input_cols].to_numpy(),
|
945
|
+
df[label_cols].squeeze().to_numpy(),
|
946
|
+
indices,
|
947
|
+
params_to_evaluate,
|
948
|
+
base_estimator,
|
949
|
+
fit_and_score_kwargs,
|
950
|
+
)
|
951
|
+
|
952
|
+
# Note Table functions (UDTFs) have a limit of 500 input arguments and 500 output columns.
|
953
|
+
class SearchCV:
|
954
|
+
def __init__(self) -> None:
|
955
|
+
X, y, indices, params_to_evaluate, base_estimator, fit_and_score_kwargs = _load_data_into_udf()
|
956
|
+
self.X = X
|
957
|
+
self.y = y
|
958
|
+
self.test_indices = indices
|
959
|
+
self.params_to_evaluate = params_to_evaluate
|
960
|
+
self.base_estimator = base_estimator
|
961
|
+
self.fit_and_score_kwargs = fit_and_score_kwargs
|
962
|
+
self.fit_score_params: List[Any] = []
|
963
|
+
self.cached_train_test_indices = []
|
964
|
+
# Calculate the full index here to avoid duplicate calculation (which consumes a lot of memory)
|
965
|
+
full_index = np.arange(DATA_LENGTH)
|
966
|
+
for i in range(n_splits):
|
967
|
+
self.cached_train_test_indices.extend(
|
968
|
+
[[np.setdiff1d(full_index, self.test_indices[i]), self.test_indices[i]]]
|
969
|
+
)
|
970
|
+
|
971
|
+
def process(self, idx: int, params_idx: int, cv_idx: int) -> None:
|
972
|
+
self.fit_score_params.extend([[idx, params_idx, cv_idx]])
|
973
|
+
|
974
|
+
def end_partition(self) -> Iterator[Tuple[int, str]]:
|
975
|
+
from sklearn.base import clone
|
976
|
+
from sklearn.model_selection._validation import _fit_and_score
|
977
|
+
from sklearn.utils.parallel import Parallel, delayed
|
978
|
+
|
979
|
+
parallel = Parallel(n_jobs=_N_JOBS, pre_dispatch=_PRE_DISPATCH)
|
980
|
+
|
981
|
+
out = parallel(
|
982
|
+
delayed(_fit_and_score)(
|
983
|
+
clone(self.base_estimator),
|
984
|
+
self.X,
|
985
|
+
self.y,
|
986
|
+
train=self.cached_train_test_indices[split_idx][0],
|
987
|
+
test=self.cached_train_test_indices[split_idx][1],
|
988
|
+
parameters=self.params_to_evaluate[cand_idx],
|
989
|
+
split_progress=(split_idx, n_splits),
|
990
|
+
candidate_progress=(cand_idx, n_candidates),
|
991
|
+
**self.fit_and_score_kwargs, # load sample weight here
|
992
|
+
)
|
993
|
+
for _, cand_idx, split_idx in self.fit_score_params
|
994
|
+
)
|
995
|
+
|
996
|
+
binary_cv_results = None
|
997
|
+
with io.BytesIO() as f:
|
998
|
+
cp.dump(out, f)
|
999
|
+
f.seek(0)
|
1000
|
+
binary_cv_results = f.getvalue().hex()
|
1001
|
+
yield (
|
1002
|
+
self.fit_score_params[0][0],
|
1003
|
+
binary_cv_results,
|
1004
|
+
)
|
1005
|
+
|
1006
|
+
session.udtf.register(
|
1007
|
+
SearchCV,
|
1008
|
+
output_schema=StructType([StructField("IDX", IntegerType()), StructField("CV_RESULTS", StringType())]),
|
1009
|
+
input_types=[IntegerType(), IntegerType(), IntegerType()],
|
1010
|
+
name=random_udtf_name,
|
1011
|
+
packages=required_deps, # type: ignore[arg-type]
|
1012
|
+
replace=True,
|
1013
|
+
is_permanent=False,
|
1014
|
+
imports=imports, # type: ignore[arg-type]
|
1015
|
+
statement_params=udtf_statement_params,
|
1016
|
+
)
|
1017
|
+
|
1018
|
+
HP_TUNING = F.table_function(random_udtf_name)
|
1019
|
+
|
1020
|
+
# param_indices is for the index for each parameter grid;
|
1021
|
+
# cv_indices is for the index for each cross_validator's fold;
|
1022
|
+
# param_cv_indices is for the index for the product of (len(param_indices) * len(cv_indices))
|
1023
|
+
param_indices, cv_indices = zip(
|
1024
|
+
*product(range(parameter_grid_length), range(cross_validator_indices_length))
|
1025
|
+
)
|
1026
|
+
|
1027
|
+
indices_info_pandas = pd.DataFrame(
|
1028
|
+
{
|
1029
|
+
"IDX": [i // _NUM_CPUs for i in range(parameter_grid_length * cross_validator_indices_length)],
|
1030
|
+
"PARAM_IND": param_indices,
|
1031
|
+
"CV_IND": cv_indices,
|
1032
|
+
}
|
1033
|
+
)
|
1034
|
+
|
1035
|
+
indices_info_sp = session.create_dataframe(indices_info_pandas)
|
1036
|
+
# execute udtf by querying HP_TUNING table
|
1037
|
+
HP_raw_results = indices_info_sp.select(
|
1038
|
+
(
|
1039
|
+
HP_TUNING(indices_info_sp["IDX"], indices_info_sp["PARAM_IND"], indices_info_sp["CV_IND"]).over(
|
1040
|
+
partition_by="IDX"
|
1041
|
+
)
|
1042
|
+
),
|
1043
|
+
)
|
1044
|
+
|
1045
|
+
first_test_score, cv_results_ = construct_cv_results_new_implementation(
|
1046
|
+
estimator,
|
1047
|
+
n_splits,
|
1048
|
+
list(param_grid),
|
1049
|
+
HP_raw_results.select("CV_RESULTS").sort(F.col("IDX")).collect(),
|
1050
|
+
cross_validator_indices_length,
|
1051
|
+
parameter_grid_length,
|
1052
|
+
)
|
1053
|
+
|
1054
|
+
estimator.cv_results_ = cv_results_
|
1055
|
+
estimator.multimetric_ = isinstance(first_test_score, dict)
|
1056
|
+
|
1057
|
+
# check refit_metric now for a callable scorer that is multimetric
|
1058
|
+
if callable(estimator.scoring) and estimator.multimetric_:
|
1059
|
+
estimator._check_refit_for_multimetric(first_test_score)
|
1060
|
+
refit_metric = estimator.refit
|
1061
|
+
|
1062
|
+
# For multi-metric evaluation, store the best_index_, best_params_ and
|
1063
|
+
# best_score_ iff refit is one of the scorer names
|
1064
|
+
# In single metric evaluation, refit_metric is "score"
|
1065
|
+
if estimator.refit or not estimator.multimetric_:
|
1066
|
+
estimator.best_index_ = estimator._select_best_index(estimator.refit, refit_metric, cv_results_)
|
1067
|
+
if not callable(estimator.refit):
|
1068
|
+
# With a non-custom callable, we can select the best score
|
1069
|
+
# based on the best index
|
1070
|
+
estimator.best_score_ = cv_results_[f"mean_test_{refit_metric}"][estimator.best_index_]
|
1071
|
+
estimator.best_params_ = cv_results_["params"][estimator.best_index_]
|
1072
|
+
|
1073
|
+
if estimator.refit:
|
1074
|
+
estimator.best_estimator_ = clone(base_estimator).set_params(
|
1075
|
+
**clone(estimator.best_params_, safe=False)
|
1076
|
+
)
|
1077
|
+
|
1078
|
+
# Let the sproc use all cores to refit.
|
1079
|
+
estimator.n_jobs = estimator.n_jobs or -1
|
1080
|
+
|
1081
|
+
# process the input as args
|
1082
|
+
argspec = inspect.getfullargspec(estimator.fit)
|
1083
|
+
args = {"X": X}
|
1084
|
+
if label_cols:
|
1085
|
+
label_arg_name = "Y" if "Y" in argspec.args else "y"
|
1086
|
+
args[label_arg_name] = y
|
1087
|
+
if sample_weight_col is not None and "sample_weight" in argspec.args:
|
1088
|
+
args["sample_weight"] = df[sample_weight_col].squeeze()
|
1089
|
+
# estimator.refit = original_refit
|
1090
|
+
refit_start_time = time.time()
|
1091
|
+
estimator.best_estimator_.fit(**args)
|
1092
|
+
refit_end_time = time.time()
|
1093
|
+
estimator.refit_time_ = refit_end_time - refit_start_time
|
1094
|
+
|
1095
|
+
if hasattr(estimator.best_estimator_, "feature_names_in_"):
|
1096
|
+
estimator.feature_names_in_ = estimator.best_estimator_.feature_names_in_
|
1097
|
+
|
1098
|
+
# Store the only scorer not as a dict for single metric evaluation
|
1099
|
+
estimator.scorer_ = scorers
|
1100
|
+
estimator.n_splits_ = n_splits
|
1101
|
+
|
1102
|
+
local_result_file_name = get_temp_file_path()
|
1103
|
+
|
1104
|
+
with open(local_result_file_name, mode="w+b") as local_result_file_obj:
|
1105
|
+
cp.dump(estimator, local_result_file_obj)
|
1106
|
+
|
1107
|
+
session.file.put(
|
1108
|
+
local_result_file_name,
|
1109
|
+
temp_stage_name,
|
1110
|
+
auto_compress=False,
|
1111
|
+
overwrite=True,
|
1112
|
+
)
|
1113
|
+
|
1114
|
+
return str(os.path.basename(local_result_file_name))
|
1115
|
+
|
1116
|
+
sproc_export_file_name = _distributed_search(
|
1117
|
+
session,
|
1118
|
+
imports,
|
1119
|
+
stage_estimator_file_name,
|
1120
|
+
input_cols,
|
1121
|
+
label_cols,
|
1122
|
+
)
|
1123
|
+
|
1124
|
+
local_estimator_path = get_temp_file_path()
|
1125
|
+
session.file.get(
|
1126
|
+
posixpath.join(temp_stage_name, sproc_export_file_name),
|
1127
|
+
local_estimator_path,
|
1128
|
+
)
|
1129
|
+
|
1130
|
+
with open(os.path.join(local_estimator_path, sproc_export_file_name), mode="r+b") as result_file_obj:
|
1131
|
+
fit_estimator = cp.load(result_file_obj)
|
1132
|
+
|
1133
|
+
cleanup_temp_files(local_estimator_path)
|
1134
|
+
|
1135
|
+
return fit_estimator
|
1136
|
+
|
605
1137
|
def train(self) -> object:
|
606
1138
|
"""
|
607
1139
|
Runs hyper parameter optimization by distributing the tasks across warehouse.
|
@@ -630,6 +1162,19 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
|
|
630
1162
|
relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
631
1163
|
pkg_versions=model_spec.pkgDependencies, session=self.session
|
632
1164
|
)
|
1165
|
+
if ENABLE_EFFICIENT_MEMORY_USAGE:
|
1166
|
+
return self.fit_search_snowpark_new_implementation(
|
1167
|
+
param_grid=param_grid,
|
1168
|
+
dataset=self.dataset,
|
1169
|
+
session=self.session,
|
1170
|
+
estimator=self.estimator,
|
1171
|
+
dependencies=relaxed_dependencies,
|
1172
|
+
udf_imports=["sklearn"],
|
1173
|
+
input_cols=self.input_cols,
|
1174
|
+
label_cols=self.label_cols,
|
1175
|
+
sample_weight_col=self.sample_weight_col,
|
1176
|
+
)
|
1177
|
+
|
633
1178
|
return self.fit_search_snowpark(
|
634
1179
|
param_grid=param_grid,
|
635
1180
|
dataset=self.dataset,
|
@@ -131,9 +131,12 @@ class SnowparkTransformHandlers:
|
|
131
131
|
|
132
132
|
input_df.columns = snowpark_cols
|
133
133
|
|
134
|
+
if hasattr(estimator, "n_jobs"):
|
135
|
+
# Vectorized UDF cannot handle joblib multiprocessing right now, deactivate the n_jobs
|
136
|
+
estimator.n_jobs = 1
|
134
137
|
inference_res = getattr(estimator, inference_method)(input_df, *args, **kwargs)
|
135
138
|
|
136
|
-
transformed_numpy_array,
|
139
|
+
transformed_numpy_array, _ = handle_inference_result(
|
137
140
|
inference_res=inference_res,
|
138
141
|
output_cols=expected_output_cols,
|
139
142
|
inference_method=inference_method,
|
@@ -141,13 +144,13 @@ class SnowparkTransformHandlers:
|
|
141
144
|
)
|
142
145
|
|
143
146
|
if len(transformed_numpy_array.shape) > 1:
|
144
|
-
if transformed_numpy_array.shape[1] != len(
|
147
|
+
if transformed_numpy_array.shape[1] != len(expected_output_cols):
|
145
148
|
series = pd.Series(transformed_numpy_array.tolist())
|
146
|
-
transformed_pandas_df = pd.DataFrame(series, columns=
|
149
|
+
transformed_pandas_df = pd.DataFrame(series, columns=expected_output_cols)
|
147
150
|
else:
|
148
|
-
transformed_pandas_df = pd.DataFrame(transformed_numpy_array.tolist(), columns=
|
151
|
+
transformed_pandas_df = pd.DataFrame(transformed_numpy_array.tolist(), columns=expected_output_cols)
|
149
152
|
else:
|
150
|
-
transformed_pandas_df = pd.DataFrame(transformed_numpy_array, columns=
|
153
|
+
transformed_pandas_df = pd.DataFrame(transformed_numpy_array, columns=expected_output_cols)
|
151
154
|
|
152
155
|
return transformed_pandas_df.to_dict("records") # type: ignore[no-any-return]
|
153
156
|
|