snowflake-ml-python 1.5.0__py3-none-any.whl → 1.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_sentiment.py +7 -4
- snowflake/ml/_internal/env_utils.py +6 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
- snowflake/ml/_internal/telemetry.py +1 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/sql_identifier.py +14 -1
- snowflake/ml/_internal/utils/temp_file_utils.py +5 -2
- snowflake/ml/dataset/__init__.py +2 -1
- snowflake/ml/dataset/dataset.py +4 -3
- snowflake/ml/dataset/dataset_reader.py +5 -8
- snowflake/ml/feature_store/__init__.py +6 -0
- snowflake/ml/feature_store/access_manager.py +283 -0
- snowflake/ml/feature_store/feature_store.py +160 -100
- snowflake/ml/feature_store/feature_view.py +30 -19
- snowflake/ml/fileset/embedded_stage_fs.py +15 -12
- snowflake/ml/fileset/snowfs.py +2 -30
- snowflake/ml/fileset/stage_fs.py +25 -7
- snowflake/ml/model/_client/model/model_impl.py +46 -39
- snowflake/ml/model/_client/model/model_version_impl.py +24 -2
- snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
- snowflake/ml/model/_client/ops/model_ops.py +174 -16
- snowflake/ml/model/_client/sql/_base.py +34 -0
- snowflake/ml/model/_client/sql/model.py +32 -39
- snowflake/ml/model/_client/sql/model_version.py +111 -42
- snowflake/ml/model/_client/sql/stage.py +6 -32
- snowflake/ml/model/_client/sql/tag.py +32 -56
- snowflake/ml/model/_model_composer/model_composer.py +8 -4
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -3
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -27
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +90 -142
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +159 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +81 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +8 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +8 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +8 -1
- snowflake/ml/modeling/cluster/birch.py +8 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +8 -1
- snowflake/ml/modeling/cluster/dbscan.py +8 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +8 -1
- snowflake/ml/modeling/cluster/k_means.py +8 -1
- snowflake/ml/modeling/cluster/mean_shift.py +8 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +8 -1
- snowflake/ml/modeling/cluster/optics.py +8 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +8 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +8 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +8 -1
- snowflake/ml/modeling/compose/column_transformer.py +8 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +8 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +8 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +8 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +8 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +8 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +8 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +8 -1
- snowflake/ml/modeling/covariance/oas.py +8 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +8 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +8 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +8 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +8 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +8 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +8 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +8 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +8 -1
- snowflake/ml/modeling/decomposition/pca.py +8 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +8 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +8 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +8 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +8 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +8 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +8 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +8 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +8 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +8 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +8 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +8 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +8 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +8 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +8 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +8 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +8 -1
- snowflake/ml/modeling/framework/base.py +4 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +8 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +8 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +8 -1
- snowflake/ml/modeling/impute/knn_imputer.py +8 -1
- snowflake/ml/modeling/impute/missing_indicator.py +8 -1
- snowflake/ml/modeling/impute/simple_imputer.py +21 -2
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +8 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +8 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +8 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +8 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +8 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +8 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +8 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +8 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +8 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +8 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +8 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/lars.py +8 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +8 -1
- snowflake/ml/modeling/linear_model/lasso.py +8 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +8 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +8 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +8 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +8 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +8 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +8 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +8 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +8 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +8 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +8 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +8 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +8 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +8 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/perceptron.py +8 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/ridge.py +8 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +8 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +8 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +8 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +8 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +8 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +8 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +8 -1
- snowflake/ml/modeling/manifold/isomap.py +8 -1
- snowflake/ml/modeling/manifold/mds.py +8 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +8 -1
- snowflake/ml/modeling/manifold/tsne.py +8 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +8 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +8 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +8 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +8 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +8 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +8 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +8 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +8 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +8 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +8 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +8 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +8 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +8 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +8 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +8 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +8 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +8 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +8 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +8 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +8 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +8 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +8 -1
- snowflake/ml/modeling/parameters/enable_anonymous_sproc.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +27 -7
- snowflake/ml/modeling/preprocessing/polynomial_features.py +8 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +8 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +8 -1
- snowflake/ml/modeling/svm/linear_svc.py +8 -1
- snowflake/ml/modeling/svm/linear_svr.py +8 -1
- snowflake/ml/modeling/svm/nu_svc.py +8 -1
- snowflake/ml/modeling/svm/nu_svr.py +8 -1
- snowflake/ml/modeling/svm/svc.py +8 -1
- snowflake/ml/modeling/svm/svr.py +8 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +8 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +8 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +8 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +8 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +8 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +8 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +8 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +8 -1
- snowflake/ml/registry/_manager/model_manager.py +95 -8
- snowflake/ml/registry/registry.py +10 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/METADATA +66 -10
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/RECORD +196 -192
- snowflake/ml/_internal/lineage/dataset_dataframe.py +0 -44
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.5.0.dist-info → snowflake_ml_python-1.5.2.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py

@@ -8,7 +8,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union

 import cloudpickle as cp
 import numpy as np
-import numpy.typing as npt
 from sklearn import model_selection
 from sklearn.model_selection import GridSearchCV, RandomizedSearchCV

@@ -36,6 +35,7 @@ from snowflake.snowpark._internal.utils import (
 from snowflake.snowpark.functions import sproc, udtf
 from snowflake.snowpark.row import Row
 from snowflake.snowpark.types import IntegerType, StringType, StructField, StructType
+from snowflake.snowpark.udtf import UDTFRegistration

 cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
 cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
@@ -154,7 +154,7 @@ def construct_cv_results(
     return multimetric, estimator._format_results(param_grid, n_split, out)


-def construct_cv_results_new_implementation(
+def construct_cv_results_memory_efficient_version(
     estimator: Union[GridSearchCV, RandomizedSearchCV],
     n_split: int,
     param_grid: List[Dict[str, Any]],
@@ -205,12 +205,35 @@ def construct_cv_results_new_implementation(
         with io.BytesIO(hex_str) as f_reload:
             out = cp.load(f_reload)
             all_out.extend(out)
+
+    # because original SearchCV is ranked by parameter first and cv second,
+    # to make the memory efficient, we implemented by fitting on cv first and parameter second
+    # when retrieving the results back, the ordering should revert back to remain the same result as original SearchCV
+    def generate_the_order_by_parameter_index(all_combination_length: int) -> List[int]:
+        pattern = []
+        for i in range(all_combination_length):
+            if i % parameter_grid_length == 0:
+                pattern.append(i)
+        for i in range(1, parameter_grid_length):
+            for j in range(all_combination_length):
+                if j % parameter_grid_length == i:
+                    pattern.append(j)
+        return pattern
+
+    def rerank_array(original_array: List[Any], pattern: List[int]) -> List[Any]:
+        reranked_array = []
+        for index in pattern:
+            reranked_array.append(original_array[index])
+        return reranked_array
+
+    pattern = generate_the_order_by_parameter_index(len(all_out))
+    reranked_all_out = rerank_array(all_out, pattern)
     first_test_score = all_out[0]["test_scores"]
-    return first_test_score, estimator._format_results(param_grid, n_split,
+    return first_test_score, estimator._format_results(param_grid, n_split, reranked_all_out)


 cp.register_pickle_by_value(inspect.getmodule(construct_cv_results))
-cp.register_pickle_by_value(inspect.getmodule(
+cp.register_pickle_by_value(inspect.getmodule(construct_cv_results_memory_efficient_version))


 class DistributedHPOTrainer(SnowparkModelTrainer):
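Note: the reordering helpers added above exist because the memory-efficient path produces results in cv-fold-major order, while scikit-learn's _format_results expects candidate-major order. A standalone sketch (not package code; the grid sizes are assumed for illustration) of what the reranking does:

# Standalone sketch: reproduce the reordering helpers above on labelled (cv_idx, param_idx) pairs.
from itertools import product
from typing import Any, List

parameter_grid_length = 2  # two parameter candidates (assumed for the example)
cv_folds = 3               # three CV splits (assumed for the example)

def generate_the_order_by_parameter_index(all_combination_length: int) -> List[int]:
    pattern = []
    for i in range(all_combination_length):
        if i % parameter_grid_length == 0:
            pattern.append(i)
    for i in range(1, parameter_grid_length):
        for j in range(all_combination_length):
            if j % parameter_grid_length == i:
                pattern.append(j)
    return pattern

def rerank_array(original_array: List[Any], pattern: List[int]) -> List[Any]:
    return [original_array[index] for index in pattern]

# Results arrive cv-major: (cv0, p0), (cv0, p1), (cv1, p0), ...
cv_major = [(cv, p) for cv, p in product(range(cv_folds), range(parameter_grid_length))]
pattern = generate_the_order_by_parameter_index(len(cv_major))
print(rerank_array(cv_major, pattern))
# [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)]  -> parameter-major, as _format_results expects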
@@ -661,7 +684,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):

         return fit_estimator

-    def
+    def fit_search_snowpark_enable_efficient_memory_usage(
         self,
         param_grid: Union[model_selection.ParameterGrid, model_selection.ParameterSampler],
         dataset: DataFrame,
@@ -675,7 +698,6 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
     ) -> Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV]:
         from itertools import product

-        import cachetools
         from sklearn.base import clone, is_classifier
         from sklearn.calibration import check_cv

@@ -696,9 +718,11 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         # Create a temp file and dump the estimator to that file.
         estimator_file_name = get_temp_file_path()
         params_to_evaluate = list(param_grid)
-
-
-
+        CONSTANTS: Dict[str, Any] = dict()
+        CONSTANTS["dataset_snowpark_cols"] = dataset.columns
+        CONSTANTS["n_candidates"] = len(params_to_evaluate)
+        CONSTANTS["_N_JOBS"] = estimator.n_jobs
+        CONSTANTS["_PRE_DISPATCH"] = estimator.pre_dispatch

         with open(estimator_file_name, mode="w+b") as local_estimator_file_obj:
             cp.dump(dict(estimator=estimator, param_grid=params_to_evaluate), local_estimator_file_obj)
@@ -718,7 +742,10 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
                 inspect.currentframe(), self.__class__.__name__
             ),
             api_calls=[udtf],
-            custom_tags=dict([("
+            custom_tags=dict([("hpo_memory_efficient", True)]),
+        )
+        from snowflake.ml.modeling._internal.snowpark_implementations.distributed_search_udf_file import (
+            execute_template,
         )

         # Put locally serialized estimator on stage.
@@ -730,6 +757,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         )
         estimator_location = os.path.basename(estimator_file_name)
         imports.append(f"@{temp_stage_name}/{estimator_location}")
+        CONSTANTS["estimator_location"] = estimator_location

         search_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
         random_udtf_name = random_name_for_temp_object(TempObjectType.FUNCTION)
@@ -760,7 +788,6 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         ) -> str:
             import os
             import time
-            from typing import Iterator

             import cloudpickle as cp
             import pandas as pd
@@ -882,146 +909,67 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         fit_and_score_kwargs_location = os.path.basename(local_fit_and_score_kwargs_file_name)
         imports.append(f"@{temp_stage_name}/{fit_and_score_kwargs_location}")

-
-
-
-
-
-
-
-            npt.NDArray[Any],
-            npt.NDArray[Any],
-            List[List[int]],
-            List[Dict[str, Any]],
-            object,
-            Dict[str, Any],
-        ]:
-            import pyarrow.parquet as pq
+        CONSTANTS["input_cols"] = input_cols
+        CONSTANTS["label_cols"] = label_cols
+        CONSTANTS["DATA_LENGTH"] = DATA_LENGTH
+        CONSTANTS["n_splits"] = n_splits
+        CONSTANTS["indices_location"] = indices_location
+        CONSTANTS["base_estimator_location"] = base_estimator_location
+        CONSTANTS["fit_and_score_kwargs_location"] = fit_and_score_kwargs_location

-
-
-
-
-            ]
-            partial_df = [
-                pq.read_table(os.path.join(sys._xoptions["snowflake_import_directory"], file_name)).to_pandas()
-                for file_name in data_files
-            ]
-            df = pd.concat(partial_df, ignore_index=True)
-            df.columns = [identifier.get_inferred_name(col_) for col_ in df.columns]
+        # (6) store the constants
+        local_constant_file_name = get_temp_file_path(prefix="constant")
+        with open(local_constant_file_name, mode="w+b") as local_indices_file_obj:
+            cp.dump(CONSTANTS, local_indices_file_obj)

-
-
-
-
-
-
-
+        # Put locally serialized indices on stage.
+        session.file.put(
+            local_constant_file_name,
+            temp_stage_name,
+            auto_compress=False,
+            overwrite=True,
+        )
+        constant_location = os.path.basename(local_constant_file_name)
+        imports.append(f"@{temp_stage_name}/{constant_location}")

-
-
-                sys._xoptions["snowflake_import_directory"], f"{indices_location}"
-            )
-            with open(local_indices_file_path, mode="rb") as local_indices_file_obj:
-                indices = cp.load(local_indices_file_obj)
+        cross_validator_indices_length = int(len(cross_validator_indices))
+        parameter_grid_length = len(param_grid)

-
-            local_base_estimator_file_path = os.path.join(
-                sys._xoptions["snowflake_import_directory"], f"{base_estimator_location}"
-            )
-            with open(local_base_estimator_file_path, mode="rb") as local_base_estimator_file_obj:
-                base_estimator = cp.load(local_base_estimator_file_obj)
+        assert estimator is not None

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        # Instantiate UDTFRegistration with the session object
+        udtf_registration = UDTFRegistration(session)
+
+        import tempfile
+
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8") as f:
+            udf_code = execute_template
+            f.file.write(udf_code)
+            f.file.flush()
+
+            # Register the UDTF function from the file
+            udtf_registration.register_from_file(
+                file_path=f.name,
+                handler_name="SearchCV",
+                name=random_udtf_name,
+                output_schema=StructType(
+                    [StructField("FIRST_IDX", IntegerType()), StructField("EACH_CV_RESULTS", StringType())]
+                ),
+                input_types=[IntegerType(), IntegerType(), IntegerType()],
+                replace=True,
+                imports=imports,  # type: ignore[arg-type]
+                is_permanent=False,
+                packages=required_deps,  # type: ignore[arg-type]
+                statement_params=udtf_statement_params,
            )

-        # Note Table functions (UDTFs) have a limit of 500 input arguments and 500 output columns.
-        class SearchCV:
-            def __init__(self) -> None:
-                X, y, indices, params_to_evaluate, base_estimator, fit_and_score_kwargs = _load_data_into_udf()
-                self.X = X
-                self.y = y
-                self.test_indices = indices
-                self.params_to_evaluate = params_to_evaluate
-                self.base_estimator = base_estimator
-                self.fit_and_score_kwargs = fit_and_score_kwargs
-                self.fit_score_params: List[Any] = []
-                self.cached_train_test_indices = []
-                # Calculate the full index here to avoid duplicate calculation (which consumes a lot of memory)
-                full_index = np.arange(DATA_LENGTH)
-                for i in range(n_splits):
-                    self.cached_train_test_indices.extend(
-                        [[np.setdiff1d(full_index, self.test_indices[i]), self.test_indices[i]]]
-                    )
-
-            def process(self, idx: int, params_idx: int, cv_idx: int) -> None:
-                self.fit_score_params.extend([[idx, params_idx, cv_idx]])
-
-            def end_partition(self) -> Iterator[Tuple[int, str]]:
-                from sklearn.base import clone
-                from sklearn.model_selection._validation import _fit_and_score
-                from sklearn.utils.parallel import Parallel, delayed
-
-                parallel = Parallel(n_jobs=_N_JOBS, pre_dispatch=_PRE_DISPATCH)
-
-                out = parallel(
-                    delayed(_fit_and_score)(
-                        clone(self.base_estimator),
-                        self.X,
-                        self.y,
-                        train=self.cached_train_test_indices[split_idx][0],
-                        test=self.cached_train_test_indices[split_idx][1],
-                        parameters=self.params_to_evaluate[cand_idx],
-                        split_progress=(split_idx, n_splits),
-                        candidate_progress=(cand_idx, n_candidates),
-                        **self.fit_and_score_kwargs,  # load sample weight here
-                    )
-                    for _, cand_idx, split_idx in self.fit_score_params
-                )
-
-                binary_cv_results = None
-                with io.BytesIO() as f:
-                    cp.dump(out, f)
-                    f.seek(0)
-                    binary_cv_results = f.getvalue().hex()
-                yield (
-                    self.fit_score_params[0][0],
-                    binary_cv_results,
-                )
-
-        session.udtf.register(
-            SearchCV,
-            output_schema=StructType([StructField("IDX", IntegerType()), StructField("CV_RESULTS", StringType())]),
-            input_types=[IntegerType(), IntegerType(), IntegerType()],
-            name=random_udtf_name,
-            packages=required_deps,  # type: ignore[arg-type]
-            replace=True,
-            is_permanent=False,
-            imports=imports,  # type: ignore[arg-type]
-            statement_params=udtf_statement_params,
-        )
-
         HP_TUNING = F.table_function(random_udtf_name)

         # param_indices is for the index for each parameter grid;
         # cv_indices is for the index for each cross_validator's fold;
         # param_cv_indices is for the index for the product of (len(param_indices) * len(cv_indices))
-
-            *product(range(
+        cv_indices, param_indices = zip(
+            *product(range(cross_validator_indices_length), range(parameter_grid_length))
         )

         indices_info_pandas = pd.DataFrame(
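Note: the hunk above replaces the inline `session.udtf.register(SearchCV, ...)` call with writing the handler source to a temporary .py file and registering it via `UDTFRegistration.register_from_file`. A hedged, self-contained sketch of that registration flow (the handler, schema, and names are illustrative, not the package's actual SearchCV handler):

# Sketch of registering a UDTF handler from a generated file, mirroring the calls used above.
import tempfile

from snowflake.snowpark import Session
from snowflake.snowpark.types import IntegerType, StringType, StructField, StructType
from snowflake.snowpark.udtf import UDTFRegistration

# Handler source kept as a string, as distributed_search_udf_file.py does with execute_template.
handler_source = '''
class MyHandler:
    def process(self, idx: int):
        yield (idx, "ok")
'''

def register_handler(session: Session, udtf_name: str) -> None:
    # Write the handler to a temp .py file, then register it so each node imports it once.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, encoding="utf-8") as f:
        f.write(handler_source)
        f.flush()
        UDTFRegistration(session).register_from_file(
            file_path=f.name,
            handler_name="MyHandler",
            name=udtf_name,
            output_schema=StructType([StructField("IDX", IntegerType()), StructField("VAL", StringType())]),
            input_types=[IntegerType()],
            replace=True,
            is_permanent=False,
        )
    # The registered UDTF can then be invoked as a table function, e.g.
    # df.select(functions.table_function(udtf_name)(df["IDX"])).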
@@ -1042,11 +990,11 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             ),
         )

-        first_test_score, cv_results_ =
+        first_test_score, cv_results_ = construct_cv_results_memory_efficient_version(
             estimator,
             n_splits,
             list(param_grid),
-            HP_raw_results.select("
+            HP_raw_results.select("EACH_CV_RESULTS").sort(F.col("FIRST_IDX")).collect(),
             cross_validator_indices_length,
             parameter_grid_length,
         )
@@ -1163,7 +1111,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
             pkg_versions=model_spec.pkgDependencies, session=self.session
         )
         if ENABLE_EFFICIENT_MEMORY_USAGE:
-            return self.
+            return self.fit_search_snowpark_enable_efficient_memory_usage(
                 param_grid=param_grid,
                 dataset=self.dataset,
                 session=self.session,
snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py (new file)

@@ -0,0 +1,159 @@
+"""
+Description:
+    This is the helper file for distributed_hpo_trainer.py to create UDTF by `register_from_file`.
+Performance Benefits:
+    The performance benefits come from two aspects,
+    1. register_from_file can reduce duplicating loading data by only loading data once in each node
+    2. register_from_file enable user to load data in global variable, whereas writing UDF in python script cannot.
+Developer Tips:
+    Because this script is now a string, so there's no type hinting, linting, etc. It is highly recommended
+    to develop in a python script, test the type hinting, and then convert it into a string.
+"""
+
+execute_template = """
+from typing import Tuple, Any, List, Dict, Set, Iterator
+import os
+import sys
+import pandas as pd
+import numpy as np
+import numpy.typing as npt
+import cloudpickle as cp
+import io
+
+
+def _load_data_into_udf() -> Tuple[
+    npt.NDArray[Any],
+    npt.NDArray[Any],
+    List[List[int]],
+    List[Dict[str, Any]],
+    object,
+    Dict[str, Any],
+    Dict[str, Any],
+]:
+    import pyarrow.parquet as pq
+
+    data_files = [
+        filename
+        for filename in os.listdir(sys._xoptions["snowflake_import_directory"])
+        if filename.startswith("dataset")
+    ]
+    partial_df = [
+        pq.read_table(os.path.join(sys._xoptions["snowflake_import_directory"], file_name)).to_pandas()
+        for file_name in data_files
+    ]
+    df = pd.concat(partial_df, ignore_index=True)
+    constant_file_path = None
+    for filename in os.listdir(sys._xoptions["snowflake_import_directory"]):
+        if filename.startswith("constant"):
+            constant_file_path = os.path.join(sys._xoptions["snowflake_import_directory"], f"{filename}")
+    if constant_file_path is None:
+        raise ValueError("UDTF cannot find the constant location, abort!")
+    with open(constant_file_path, mode="rb") as constant_file_obj:
+        CONSTANTS = cp.load(constant_file_obj)
+    df.columns = CONSTANTS['dataset_snowpark_cols']
+
+    # load parameter grid
+    local_estimator_file_path = os.path.join(
+        sys._xoptions["snowflake_import_directory"],
+        f"{CONSTANTS['estimator_location']}"
+    )
+    with open(local_estimator_file_path, mode="rb") as local_estimator_file_obj:
+        estimator_objects = cp.load(local_estimator_file_obj)
+    params_to_evaluate = estimator_objects["param_grid"]
+
+    # load indices
+    local_indices_file_path = os.path.join(
+        sys._xoptions["snowflake_import_directory"],
+        f"{CONSTANTS['indices_location']}"
+    )
+    with open(local_indices_file_path, mode="rb") as local_indices_file_obj:
+        indices = cp.load(local_indices_file_obj)
+
+    # load base estimator
+    local_base_estimator_file_path = os.path.join(
+        sys._xoptions["snowflake_import_directory"], f"{CONSTANTS['base_estimator_location']}"
+    )
+    with open(local_base_estimator_file_path, mode="rb") as local_base_estimator_file_obj:
+        base_estimator = cp.load(local_base_estimator_file_obj)
+
+    # load fit_and_score_kwargs
+    local_fit_and_score_kwargs_file_path = os.path.join(
+        sys._xoptions["snowflake_import_directory"], f"{CONSTANTS['fit_and_score_kwargs_location']}"
+    )
+    with open(local_fit_and_score_kwargs_file_path, mode="rb") as local_fit_and_score_kwargs_file_obj:
+        fit_and_score_kwargs = cp.load(local_fit_and_score_kwargs_file_obj)
+
+    # convert dataframe to numpy would save memory consumption
+    return (
+        df[CONSTANTS['input_cols']].to_numpy(),
+        df[CONSTANTS['label_cols']].squeeze().to_numpy(),
+        indices,
+        params_to_evaluate,
+        base_estimator,
+        fit_and_score_kwargs,
+        CONSTANTS
+    )
+
+
+global_load_data = _load_data_into_udf()
+
+
+# Note Table functions (UDTFs) have a limit of 500 input arguments and 500 output columns.
+class SearchCV:
+    def __init__(self) -> None:
+        X, y, indices, params_to_evaluate, base_estimator, fit_and_score_kwargs, CONSTANTS = global_load_data
+        self.X = X
+        self.y = y
+        self.test_indices = indices
+        self.params_to_evaluate = params_to_evaluate
+        self.base_estimator = base_estimator
+        self.fit_and_score_kwargs = fit_and_score_kwargs
+        self.fit_score_params: List[Any] = []
+        self.CONSTANTS = CONSTANTS
+        self.cv_indices_set: Set[int] = set()
+
+    def process(self, idx: int, params_idx: int, cv_idx: int) -> None:
+        self.fit_score_params.extend([[idx, params_idx, cv_idx]])
+        self.cv_indices_set.add(cv_idx)
+
+    def end_partition(self) -> Iterator[Tuple[int, str]]:
+        from sklearn.base import clone
+        from sklearn.model_selection._validation import _fit_and_score
+        from sklearn.utils.parallel import Parallel, delayed
+
+        cached_train_test_indices = {}
+        # Calculate the full index here to avoid duplicate calculation (which consumes a lot of memory)
+        full_index = np.arange(self.CONSTANTS['DATA_LENGTH'])
+        for i in self.cv_indices_set:
+            cached_train_test_indices[i] = [
+                np.setdiff1d(full_index, self.test_indices[i]),
+                self.test_indices[i],
+            ]
+
+        parallel = Parallel(n_jobs=self.CONSTANTS['_N_JOBS'], pre_dispatch=self.CONSTANTS['_PRE_DISPATCH'])
+
+        out = parallel(
+            delayed(_fit_and_score)(
+                clone(self.base_estimator),
+                self.X,
+                self.y,
+                train=cached_train_test_indices[split_idx][0],
+                test=cached_train_test_indices[split_idx][1],
+                parameters=self.params_to_evaluate[cand_idx],
+                split_progress=(split_idx, self.CONSTANTS['n_splits']),
+                candidate_progress=(cand_idx, self.CONSTANTS['n_candidates']),
+                **self.fit_and_score_kwargs,  # load sample weight here
+            )
+            for _, cand_idx, split_idx in self.fit_score_params
+        )
+
+        binary_cv_results = None
+        with io.BytesIO() as f:
+            cp.dump(out, f)
+            f.seek(0)
+            binary_cv_results = f.getvalue().hex()
+        yield (
+            self.fit_score_params[0][0],
+            binary_cv_results,
+        )
+"""
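Note: inside the execute_template string, `global_load_data = _load_data_into_udf()` runs at module import time, so the staged dataset and pickled objects are loaded once per worker process and shared by every SearchCV instance; this is the "load data in global variable" benefit the docstring describes. A minimal sketch of the same pattern outside Snowflake (hypothetical module and file names, not package code):

# my_handler.py -- hypothetical sketch of the "load once per process" pattern.
import json

def _load_data_once():
    # Expensive work (reading files, deserializing models) happens at import time.
    with open("big_config.json") as f:  # assumed local file, for illustration only
        return json.load(f)

GLOBAL_DATA = _load_data_once()  # executed once per Python process

class Handler:
    def __init__(self) -> None:
        # Every handler instance reuses the module-level data instead of reloading it.
        self.data = GLOBAL_DATA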
snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py

@@ -45,6 +45,7 @@ cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
 cp.register_pickle_by_value(inspect.getmodule(handle_inference_result))

 _PROJECT = "ModelDevelopment"
+_ENABLE_ANONYMOUS_SPROC = False


 class SnowparkModelTrainer:
@@ -251,6 +252,27 @@ class SnowparkModelTrainer:

         return fit_wrapper_function

+    def _get_fit_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
+        model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
+        fit_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
+
+        relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
+            pkg_versions=model_spec.pkgDependencies, session=self.session
+        )
+
+        fit_wrapper_sproc = self.session.sproc.register(
+            func=self._build_fit_wrapper_sproc(model_spec=model_spec),
+            is_permanent=False,
+            name=fit_sproc_name,
+            packages=["snowflake-snowpark-python"] + relaxed_dependencies,  # type: ignore[arg-type]
+            replace=True,
+            session=self.session,
+            statement_params=statement_params,
+            anonymous=True,
+        )
+
+        return fit_wrapper_sproc
+
     def _get_fit_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure:
         # If the sproc already exists, don't register.
         if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"):
@@ -510,6 +532,28 @@ class SnowparkModelTrainer:

         return fit_transform_wrapper_function

+    def _get_fit_predict_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
+        model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
+
+        fit_predict_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
+
+        relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
+            pkg_versions=model_spec.pkgDependencies, session=self.session
+        )
+
+        fit_predict_wrapper_sproc = self.session.sproc.register(
+            func=self._build_fit_predict_wrapper_sproc(model_spec=model_spec),
+            is_permanent=False,
+            name=fit_predict_sproc_name,
+            packages=["snowflake-snowpark-python"] + relaxed_dependencies,  # type: ignore[arg-type]
+            replace=True,
+            session=self.session,
+            statement_params=statement_params,
+            anonymous=True,
+        )
+
+        return fit_predict_wrapper_sproc
+
     def _get_fit_predict_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure:
         # If the sproc already exists, don't register.
         if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"):
@@ -545,6 +589,27 @@ class SnowparkModelTrainer:

         return fit_predict_wrapper_sproc

+    def _get_fit_transform_wrapper_sproc_anonymous(self, statement_params: Dict[str, str]) -> StoredProcedure:
+        model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
+
+        fit_transform_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
+
+        relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
+            pkg_versions=model_spec.pkgDependencies, session=self.session
+        )
+
+        fit_transform_wrapper_sproc = self.session.sproc.register(
+            func=self._build_fit_transform_wrapper_sproc(model_spec=model_spec),
+            is_permanent=False,
+            name=fit_transform_sproc_name,
+            packages=["snowflake-snowpark-python"] + relaxed_dependencies,  # type: ignore[arg-type]
+            replace=True,
+            session=self.session,
+            statement_params=statement_params,
+            anonymous=True,
+        )
+        return fit_transform_wrapper_sproc
+
     def _get_fit_transform_wrapper_sproc(self, statement_params: Dict[str, str]) -> StoredProcedure:
         # If the sproc already exists, don't register.
         if not hasattr(self.session, "_FIT_WRAPPER_SPROCS"):
@@ -612,7 +677,10 @@ class SnowparkModelTrainer:
             custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
         )

-
+        if _ENABLE_ANONYMOUS_SPROC:
+            fit_wrapper_sproc = self._get_fit_wrapper_sproc_anonymous(statement_params=statement_params)
+        else:
+            fit_wrapper_sproc = self._get_fit_wrapper_sproc(statement_params=statement_params)

         try:
             sproc_export_file_name: str = fit_wrapper_sproc(
@@ -680,7 +748,11 @@ class SnowparkModelTrainer:
             custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
         )

-
+        if _ENABLE_ANONYMOUS_SPROC:
+            fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc_anonymous(statement_params=statement_params)
+        else:
+            fit_predict_wrapper_sproc = self._get_fit_predict_wrapper_sproc(statement_params=statement_params)
+
         fit_predict_result_name = random_name_for_temp_object(TempObjectType.TABLE)

         sproc_export_file_name: str = fit_predict_wrapper_sproc(
@@ -741,7 +813,13 @@ class SnowparkModelTrainer:
             custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
         )

-
+        if _ENABLE_ANONYMOUS_SPROC:
+            fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc_anonymous(
+                statement_params=statement_params
+            )
+        else:
+            fit_transform_wrapper_sproc = self._get_fit_transform_wrapper_sproc(statement_params=statement_params)
+
         fit_transform_result_name = random_name_for_temp_object(TempObjectType.TABLE)

         sproc_export_file_name: str = fit_transform_wrapper_sproc(
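Note: _ENABLE_ANONYMOUS_SPROC defaults to False, and the changed-files list above also adds snowflake/ml/modeling/parameters/enable_anonymous_sproc.py (+5 lines), which presumably flips this flag when imported. A hedged sketch of how such an opt-in module would work (not the verbatim package source):

# enable_anonymous_sproc.py -- hedged sketch, not the verbatim package source.
from snowflake.ml.modeling._internal.snowpark_implementations import snowpark_trainer

# Make SnowparkModelTrainer register its fit wrappers as anonymous stored procedures.
snowpark_trainer._ENABLE_ANONYMOUS_SPROC = True

Users would then opt in with a single import, e.g. `import snowflake.ml.modeling.parameters.enable_anonymous_sproc  # noqa: F401` (assumed usage, mirroring the module's apparent purpose).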
snowflake/ml/modeling/calibration/calibrated_classifier_cv.py

@@ -629,7 +629,14 @@ class CalibratedClassifierCV(BaseTransformer):
     ) -> List[str]:
         # in case the inferred output column names dimension is different
         # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
-
+        sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+        # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
+        # seen during the fit.
+        snowpark_column_names = dataset.select(self.input_cols).columns
+        sample_pd_df.columns = snowpark_column_names
+
+        output_df_pd = getattr(self, method)(sample_pd_df, output_cols_prefix)
         output_df_columns = list(output_df_pd.columns)
         output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
         if self.sample_weight_col:
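Note: this same eight-line change is repeated across the autogenerated modeling wrappers listed above (the many +8 -1 entries). A toy sketch of why the one-row sample is relabelled before inference (column names are assumed for illustration; the point is that to_pandas() output may not match the quoted Snowpark identifiers the estimator saw during fit):

# Toy sketch, not package code: relabel a one-row pandas sample with the Snowpark identifiers.
import pandas as pd

snowpark_column_names = ['"feature_a"', '"feature_b"']  # assumed result of dataset.select(...).columns
sample_pd_df = pd.DataFrame([[1.0, 2.0]], columns=["FEATURE_A", "FEATURE_B"])  # as to_pandas() might return

# Rename the sample's columns so the inference method sees the same names used during fit.
sample_pd_df.columns = snowpark_column_names
print(list(sample_pd_df.columns))  # ['"feature_a"', '"feature_b"']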
snowflake/ml/modeling/cluster/affinity_propagation.py

@@ -606,7 +606,14 @@ class AffinityPropagation(BaseTransformer):
     ) -> List[str]:
         # in case the inferred output column names dimension is different
         # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
-
+        sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+        # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
+        # seen during the fit.
+        snowpark_column_names = dataset.select(self.input_cols).columns
+        sample_pd_df.columns = snowpark_column_names
+
+        output_df_pd = getattr(self, method)(sample_pd_df, output_cols_prefix)
         output_df_columns = list(output_df_pd.columns)
         output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
         if self.sample_weight_col: