snowflake-ml-python 1.1.1__py3-none-any.whl → 1.1.2__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- snowflake/cortex/_complete.py +1 -1
- snowflake/cortex/_extract_answer.py +1 -1
- snowflake/cortex/_sentiment.py +1 -1
- snowflake/cortex/_summarize.py +1 -1
- snowflake/cortex/_translate.py +1 -1
- snowflake/ml/_internal/env_utils.py +68 -6
- snowflake/ml/_internal/file_utils.py +34 -4
- snowflake/ml/_internal/telemetry.py +79 -91
- snowflake/ml/_internal/utils/retryable_http.py +16 -4
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +122 -0
- snowflake/ml/dataset/dataset.py +1 -1
- snowflake/ml/model/_api.py +21 -14
- snowflake/ml/model/_client/model/model_impl.py +176 -0
- snowflake/ml/model/_client/model/model_method_info.py +19 -0
- snowflake/ml/model/_client/model/model_version_impl.py +291 -0
- snowflake/ml/model/_client/ops/metadata_ops.py +107 -0
- snowflake/ml/model/_client/ops/model_ops.py +308 -0
- snowflake/ml/model/_client/sql/model.py +75 -0
- snowflake/ml/model/_client/sql/model_version.py +213 -0
- snowflake/ml/model/_client/sql/stage.py +40 -0
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -4
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +24 -8
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +23 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +14 -2
- snowflake/ml/model/_deploy_client/utils/constants.py +1 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +2 -2
- snowflake/ml/model/_model_composer/model_composer.py +31 -9
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +25 -10
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -2
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +34 -3
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +1 -1
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -1
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +10 -28
- snowflake/ml/model/_packager/model_meta/model_meta.py +18 -16
- snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
- snowflake/ml/model/model_signature.py +108 -53
- snowflake/ml/model/type_hints.py +1 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +554 -0
- snowflake/ml/modeling/_internal/estimator_protocols.py +1 -60
- snowflake/ml/modeling/_internal/model_specifications.py +146 -0
- snowflake/ml/modeling/_internal/model_trainer.py +13 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +78 -0
- snowflake/ml/modeling/_internal/pandas_trainer.py +54 -0
- snowflake/ml/modeling/_internal/snowpark_handlers.py +6 -760
- snowflake/ml/modeling/_internal/snowpark_trainer.py +331 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +96 -124
- snowflake/ml/modeling/cluster/affinity_propagation.py +94 -124
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +94 -124
- snowflake/ml/modeling/cluster/birch.py +94 -124
- snowflake/ml/modeling/cluster/bisecting_k_means.py +94 -124
- snowflake/ml/modeling/cluster/dbscan.py +94 -124
- snowflake/ml/modeling/cluster/feature_agglomeration.py +94 -124
- snowflake/ml/modeling/cluster/k_means.py +93 -124
- snowflake/ml/modeling/cluster/mean_shift.py +94 -124
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +93 -124
- snowflake/ml/modeling/cluster/optics.py +94 -124
- snowflake/ml/modeling/cluster/spectral_biclustering.py +94 -124
- snowflake/ml/modeling/cluster/spectral_clustering.py +94 -124
- snowflake/ml/modeling/cluster/spectral_coclustering.py +94 -124
- snowflake/ml/modeling/compose/column_transformer.py +94 -124
- snowflake/ml/modeling/compose/transformed_target_regressor.py +96 -124
- snowflake/ml/modeling/covariance/elliptic_envelope.py +94 -124
- snowflake/ml/modeling/covariance/empirical_covariance.py +80 -110
- snowflake/ml/modeling/covariance/graphical_lasso.py +94 -124
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +94 -124
- snowflake/ml/modeling/covariance/ledoit_wolf.py +85 -115
- snowflake/ml/modeling/covariance/min_cov_det.py +94 -124
- snowflake/ml/modeling/covariance/oas.py +80 -110
- snowflake/ml/modeling/covariance/shrunk_covariance.py +84 -114
- snowflake/ml/modeling/decomposition/dictionary_learning.py +94 -124
- snowflake/ml/modeling/decomposition/factor_analysis.py +94 -124
- snowflake/ml/modeling/decomposition/fast_ica.py +94 -124
- snowflake/ml/modeling/decomposition/incremental_pca.py +94 -124
- snowflake/ml/modeling/decomposition/kernel_pca.py +94 -124
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +94 -124
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +94 -124
- snowflake/ml/modeling/decomposition/pca.py +94 -124
- snowflake/ml/modeling/decomposition/sparse_pca.py +94 -124
- snowflake/ml/modeling/decomposition/truncated_svd.py +94 -124
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +96 -124
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +91 -119
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/bagging_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/bagging_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/isolation_forest.py +94 -124
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/stacking_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/voting_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/voting_regressor.py +91 -119
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +82 -110
- snowflake/ml/modeling/feature_selection/select_fdr.py +80 -108
- snowflake/ml/modeling/feature_selection/select_fpr.py +80 -108
- snowflake/ml/modeling/feature_selection/select_fwe.py +80 -108
- snowflake/ml/modeling/feature_selection/select_k_best.py +81 -109
- snowflake/ml/modeling/feature_selection/select_percentile.py +80 -108
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +94 -124
- snowflake/ml/modeling/feature_selection/variance_threshold.py +76 -106
- snowflake/ml/modeling/framework/base.py +2 -2
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +96 -124
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +96 -124
- snowflake/ml/modeling/impute/iterative_imputer.py +94 -124
- snowflake/ml/modeling/impute/knn_imputer.py +94 -124
- snowflake/ml/modeling/impute/missing_indicator.py +94 -124
- snowflake/ml/modeling/impute/simple_imputer.py +1 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +77 -107
- snowflake/ml/modeling/kernel_approximation/nystroem.py +94 -124
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +94 -124
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +86 -116
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +84 -114
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +96 -124
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +71 -100
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +71 -100
- snowflake/ml/modeling/linear_model/ard_regression.py +96 -124
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +96 -124
- snowflake/ml/modeling/linear_model/elastic_net.py +96 -124
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +96 -124
- snowflake/ml/modeling/linear_model/gamma_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/huber_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/lars.py +96 -124
- snowflake/ml/modeling/linear_model/lars_cv.py +96 -124
- snowflake/ml/modeling/linear_model/lasso.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_cv.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_lars.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +96 -124
- snowflake/ml/modeling/linear_model/linear_regression.py +91 -119
- snowflake/ml/modeling/linear_model/logistic_regression.py +96 -124
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +96 -124
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +96 -124
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +96 -124
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +95 -124
- snowflake/ml/modeling/linear_model/perceptron.py +95 -124
- snowflake/ml/modeling/linear_model/poisson_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/ransac_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/ridge.py +96 -124
- snowflake/ml/modeling/linear_model/ridge_classifier.py +96 -124
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +96 -124
- snowflake/ml/modeling/linear_model/ridge_cv.py +96 -124
- snowflake/ml/modeling/linear_model/sgd_classifier.py +96 -124
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +94 -124
- snowflake/ml/modeling/linear_model/sgd_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +96 -124
- snowflake/ml/modeling/manifold/isomap.py +94 -124
- snowflake/ml/modeling/manifold/mds.py +94 -124
- snowflake/ml/modeling/manifold/spectral_embedding.py +94 -124
- snowflake/ml/modeling/manifold/tsne.py +94 -124
- snowflake/ml/modeling/metrics/classification.py +187 -52
- snowflake/ml/modeling/metrics/correlation.py +4 -2
- snowflake/ml/modeling/metrics/covariance.py +7 -4
- snowflake/ml/modeling/metrics/ranking.py +32 -16
- snowflake/ml/modeling/metrics/regression.py +60 -32
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +94 -124
- snowflake/ml/modeling/mixture/gaussian_mixture.py +94 -124
- snowflake/ml/modeling/model_selection/grid_search_cv.py +88 -138
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +90 -144
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +86 -114
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +93 -121
- snowflake/ml/modeling/multiclass/output_code_classifier.py +94 -122
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +92 -120
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +96 -124
- snowflake/ml/modeling/naive_bayes/complement_nb.py +92 -120
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -107
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +88 -116
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +96 -124
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +96 -124
- snowflake/ml/modeling/neighbors/kernel_density.py +94 -124
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +94 -124
- snowflake/ml/modeling/neighbors/nearest_centroid.py +89 -117
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +94 -124
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +96 -124
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +96 -124
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +96 -124
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +94 -124
- snowflake/ml/modeling/neural_network/mlp_classifier.py +96 -124
- snowflake/ml/modeling/neural_network/mlp_regressor.py +96 -124
- snowflake/ml/modeling/parameters/disable_distributed_hpo.py +2 -6
- snowflake/ml/modeling/preprocessing/binarizer.py +14 -9
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +0 -4
- snowflake/ml/modeling/preprocessing/label_encoder.py +21 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +20 -14
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +35 -19
- snowflake/ml/modeling/preprocessing/normalizer.py +6 -9
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +20 -13
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +25 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +94 -124
- snowflake/ml/modeling/preprocessing/robust_scaler.py +28 -14
- snowflake/ml/modeling/preprocessing/standard_scaler.py +25 -13
- snowflake/ml/modeling/semi_supervised/label_propagation.py +96 -124
- snowflake/ml/modeling/semi_supervised/label_spreading.py +96 -124
- snowflake/ml/modeling/svm/linear_svc.py +96 -124
- snowflake/ml/modeling/svm/linear_svr.py +96 -124
- snowflake/ml/modeling/svm/nu_svc.py +96 -124
- snowflake/ml/modeling/svm/nu_svr.py +96 -124
- snowflake/ml/modeling/svm/svc.py +96 -124
- snowflake/ml/modeling/svm/svr.py +96 -124
- snowflake/ml/modeling/tree/decision_tree_classifier.py +96 -124
- snowflake/ml/modeling/tree/decision_tree_regressor.py +96 -124
- snowflake/ml/modeling/tree/extra_tree_classifier.py +96 -124
- snowflake/ml/modeling/tree/extra_tree_regressor.py +96 -124
- snowflake/ml/modeling/xgboost/xgb_classifier.py +96 -125
- snowflake/ml/modeling/xgboost/xgb_regressor.py +96 -125
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +96 -125
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +96 -125
- snowflake/ml/registry/model_registry.py +2 -0
- snowflake/ml/registry/registry.py +215 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.1.1.dist-info → snowflake_ml_python-1.1.2.dist-info}/METADATA +21 -3
- snowflake_ml_python-1.1.2.dist-info/RECORD +347 -0
- snowflake_ml_python-1.1.1.dist-info/RECORD +0 -331
- {snowflake_ml_python-1.1.1.dist-info → snowflake_ml_python-1.1.2.dist-info}/WHEEL +0 -0
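
Most notable among the new files is the model registry client: `snowflake/ml/registry/registry.py` together with the new `snowflake/ml/model/_client/` package (`model_impl.py`, `model_version_impl.py`, and their ops/sql layers). A minimal usage sketch of the API those files suggest; the `Registry` constructor and the `log_model`/`run` parameter names below are assumptions based on the file names, not confirmed by this diff:

```python
# Sketch only: the shape of the new registry API implied by the files above.
# Registry, log_model, and run arguments are assumptions, not verified here.
import pandas as pd
from sklearn.linear_model import LinearRegression
from snowflake.ml.registry import Registry
from snowflake.snowpark import Session

connection_parameters = {"account": "...", "user": "...", "password": "..."}  # placeholder
session = Session.builder.configs(connection_parameters).create()

# Train any supported model locally.
train_X = pd.DataFrame({"X": [0.0, 1.0, 2.0]})
model = LinearRegression().fit(train_X, [0.0, 1.0, 2.0])

reg = Registry(session=session, database_name="ML_DB", schema_name="PUBLIC")

# Log the model; returns a model version handle (model_version_impl.py).
mv = reg.log_model(
    model,
    model_name="MY_MODEL",
    version_name="V1",
    sample_input_data=train_X,  # assumed: used to infer the model signature
)

# Invoke a method of the logged version against a Snowpark DataFrame.
predictions = mv.run(session.create_dataframe(train_X), function_name="predict")
```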
snowflake/ml/modeling/_internal/snowpark_handlers.py

```diff
--- a/snowflake/ml/modeling/_internal/snowpark_handlers.py
+++ b/snowflake/ml/modeling/_internal/snowpark_handlers.py
@@ -1,51 +1,29 @@
 import importlib
 import inspect
-import io
 import os
 import posixpath
-import sys
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional
 from uuid import uuid4
 
 import cloudpickle as cp
-import numpy as np
 import pandas as pd
-import sklearn
-from scipy.stats import rankdata
-from sklearn import model_selection
 
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
-from snowflake.ml._internal.exceptions import (
-    error_codes,
-    exceptions,
-    modeling_error_messages,
-)
+from snowflake.ml._internal.exceptions import error_codes, exceptions
 from snowflake.ml._internal.utils import identifier, snowpark_dataframe_utils
 from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator
 from snowflake.ml._internal.utils.temp_file_utils import (
     cleanup_temp_files,
     get_temp_file_path,
 )
-from snowflake.snowpark import (
-    DataFrame,
-    Session,
-    exceptions as snowpark_exceptions,
-    functions as F,
-)
+from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark._internal.utils import (
     TempObjectType,
     random_name_for_temp_object,
 )
-from snowflake.snowpark.functions import col, pandas_udf, sproc, udtf
-from snowflake.snowpark.stored_procedure import StoredProcedure
-from snowflake.snowpark.types import (
-    IntegerType,
-    PandasSeries,
-    StringType,
-    StructField,
-    StructType,
-)
+from snowflake.snowpark.functions import pandas_udf, sproc
+from snowflake.snowpark.types import PandasSeries
 
 cp.register_pickle_by_value(inspect.getmodule(get_temp_file_path))
 cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
```
```diff
@@ -53,144 +31,6 @@ cp.register_pickle_by_value(inspect.getmodule(identifier.get_inferred_name))
 _PROJECT = "ModelDevelopment"
 
 
-class WrapperProvider:
-    def __init__(self) -> None:
-        self.imports: List[str] = []
-        self.dependencies: List[str] = []
-
-    def get_fit_wrapper_function(
-        self,
-    ) -> Callable[[Any, List[str], str, str, List[str], List[str], Optional[str], Dict[str, str]], str]:
-        imports = self.imports  # In order for the sproc to not resolve this reference in snowflake.ml
-
-        def fit_wrapper_function(
-            session: Session,
-            sql_queries: List[str],
-            stage_transform_file_name: str,
-            stage_result_file_name: str,
-            input_cols: List[str],
-            label_cols: List[str],
-            sample_weight_col: Optional[str],
-            statement_params: Dict[str, str],
-        ) -> str:
-            import inspect
-            import os
-
-            import cloudpickle as cp
-            import pandas as pd
-
-            for import_name in imports:
-                importlib.import_module(import_name)
-
-            # Execute snowpark queries and obtain the results as pandas dataframe
-            # NB: this implies that the result data must fit into memory.
-            for query in sql_queries[:-1]:
-                _ = session.sql(query).collect(statement_params=statement_params)
-            sp_df = session.sql(sql_queries[-1])
-            df: pd.DataFrame = sp_df.to_pandas(statement_params=statement_params)
-            df.columns = sp_df.columns
-
-            local_transform_file_name = get_temp_file_path()
-
-            session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params)
-
-            local_transform_file_path = os.path.join(
-                local_transform_file_name, os.listdir(local_transform_file_name)[0]
-            )
-            with open(local_transform_file_path, mode="r+b") as local_transform_file_obj:
-                estimator = cp.load(local_transform_file_obj)
-
-            argspec = inspect.getfullargspec(estimator.fit)
-            args = {"X": df[input_cols]}
-            if label_cols:
-                label_arg_name = "Y" if "Y" in argspec.args else "y"
-                args[label_arg_name] = df[label_cols].squeeze()
-
-            if sample_weight_col is not None and "sample_weight" in argspec.args:
-                args["sample_weight"] = df[sample_weight_col].squeeze()
-
-            estimator.fit(**args)
-
-            local_result_file_name = get_temp_file_path()
-
-            with open(local_result_file_name, mode="w+b") as local_result_file_obj:
-                cp.dump(estimator, local_result_file_obj)
-
-            session.file.put(
-                local_result_file_name,
-                stage_result_file_name,
-                auto_compress=False,
-                overwrite=True,
-                statement_params=statement_params,
-            )
-
-            # Note: you can add something like + "|" + str(df) to the return string
-            # to pass debug information to the caller.
-            return str(os.path.basename(local_result_file_name))
-
-        return fit_wrapper_function
-
-
-class SklearnWrapperProvider(WrapperProvider):
-    def __init__(self) -> None:
-        import sklearn
-
-        self.imports: List[str] = ["sklearn"]
-
-        # TODO(snandamuri): Replace cloudpickle with joblib after latest version of joblib is added to snowflake conda.
-        self.dependencies: List[str] = [
-            f"numpy=={np.__version__}",
-            f"scikit-learn=={sklearn.__version__}",
-            f"cloudpickle=={cp.__version__}",
-        ]
-
-
-class XGBoostWrapperProvider(WrapperProvider):
-    def __init__(self) -> None:
-        import xgboost
-
-        self.imports: List[str] = ["xgboost"]
-        self.dependencies = [
-            f"numpy=={np.__version__}",
-            f"xgboost=={xgboost.__version__}",
-            f"cloudpickle=={cp.__version__}",
-        ]
-
-
-class LightGBMWrapperProvider(WrapperProvider):
-    def __init__(self) -> None:
-        import lightgbm
-
-        self.imports: List[str] = ["lightgbm"]
-        self.dependencies = [
-            f"numpy=={np.__version__}",
-            f"lightgbm=={lightgbm.__version__}",
-            f"cloudpickle=={cp.__version__}",
-        ]
-
-
-class SklearnModelSelectionWrapperProvider(WrapperProvider):
-    def __init__(self) -> None:
-        import xgboost
-
-        self.imports: List[str] = ["sklearn", "xgboost"]
-        self.dependencies = [
-            f"numpy=={np.__version__}",
-            f"scikit-learn=={sklearn.__version__}",
-            f"cloudpickle=={cp.__version__}",
-            f"xgboost=={xgboost.__version__}",
-        ]
-
-        # Only include lightgbm in the dependencies if it is installed.
-        try:
-            import lightgbm
-        except ModuleNotFoundError:
-            pass
-        else:
-            self.imports.append("lightgbm")
-            self.dependencies.append(f"lightgbm=={lightgbm.__version__}")
-
-
 def _get_rand_id() -> str:
     """
     Generate random id to be used in sproc and stage names.
```
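The removed `fit_wrapper_function` (and `fit_pandas` in the next hunk) builds the estimator's fit arguments by introspecting the fit signature, so one wrapper serves estimators whose label argument is `y` as well as those that use `Y`, and passes `sample_weight` only when the estimator accepts it. The same dispatch in isolation, as a runnable sketch outside Snowflake:

```python
# Standalone illustration of the argspec-based fit dispatch used above:
# pick "Y" vs "y" and pass sample_weight only if the estimator accepts it.
import inspect

import pandas as pd
from sklearn.linear_model import LinearRegression

df = pd.DataFrame({"F1": [0.0, 1.0, 2.0], "TARGET": [0.0, 1.0, 2.0], "W": [1.0, 1.0, 2.0]})
estimator = LinearRegression()

argspec = inspect.getfullargspec(estimator.fit)
args = {"X": df[["F1"]]}
label_arg_name = "Y" if "Y" in argspec.args else "y"  # sklearn estimators use "y"
args[label_arg_name] = df[["TARGET"]].squeeze()
if "sample_weight" in argspec.args:  # LinearRegression.fit accepts sample_weight
    args["sample_weight"] = df["W"].squeeze()

estimator.fit(**args)
```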
```diff
@@ -202,171 +42,11 @@ def _get_rand_id() -> str:
 
 
 class SnowparkHandlers:
-    def __init__(
-        self, class_name: str, subproject: str, wrapper_provider: WrapperProvider, autogenerated: Optional[bool] = False
-    ) -> None:
+    def __init__(self, class_name: str, subproject: str, autogenerated: Optional[bool] = False) -> None:
         self._class_name = class_name
         self._subproject = subproject
-        self._wrapper_provider = wrapper_provider
         self._autogenerated = autogenerated
 
-    def _get_fit_wrapper_sproc(
-        self, dependencies: List[str], session: Session, statement_params: Dict[str, str]
-    ) -> StoredProcedure:
-        # If the sproc already exists, don't register.
-        if not hasattr(session, "_FIT_WRAPPER_SPROCS"):
-            session._FIT_WRAPPER_SPROCS: Dict[str, StoredProcedure] = {}  # type: ignore[attr-defined, misc]
-
-        fit_sproc_key = self._wrapper_provider.__class__.__name__
-        if fit_sproc_key in session._FIT_WRAPPER_SPROCS:  # type: ignore[attr-defined]
-            fit_sproc: StoredProcedure = session._FIT_WRAPPER_SPROCS[fit_sproc_key]  # type: ignore[attr-defined]
-            return fit_sproc
-
-        fit_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
-
-        fit_wrapper_sproc = session.sproc.register(
-            func=self._wrapper_provider.get_fit_wrapper_function(),
-            is_permanent=False,
-            name=fit_sproc_name,
-            packages=dependencies,  # type: ignore[arg-type]
-            replace=True,
-            session=session,
-            statement_params=statement_params,
-        )
-
-        session._FIT_WRAPPER_SPROCS[fit_sproc_key] = fit_wrapper_sproc  # type: ignore[attr-defined]
-
-        return fit_wrapper_sproc
-
-    def fit_pandas(
-        self,
-        dataset: pd.DataFrame,
-        estimator: object,
-        input_cols: List[str],
-        label_cols: Optional[List[str]],
-        sample_weight_col: Optional[str],
-    ) -> object:
-        assert hasattr(estimator, "fit")  # Keep mypy happy
-        argspec = inspect.getfullargspec(estimator.fit)
-        args = {"X": dataset[input_cols]}
-
-        if label_cols:
-            label_arg_name = "Y" if "Y" in argspec.args else "y"
-            args[label_arg_name] = dataset[label_cols].squeeze()
-
-        if sample_weight_col is not None and "sample_weight" in argspec.args:
-            args["sample_weight"] = dataset[sample_weight_col].squeeze()
-
-        return estimator.fit(**args)
-
-    def fit_snowpark(
-        self,
-        dataset: DataFrame,
-        session: Session,
-        estimator: object,
-        dependencies: List[str],
-        input_cols: List[str],
-        label_cols: List[str],
-        sample_weight_col: Optional[str],
-    ) -> Any:
-        dataset = snowpark_dataframe_utils.cast_snowpark_dataframe_column_types(dataset)
-
-        # If we are already in a stored procedure, no need to kick off another one.
-        if SNOWML_SPROC_ENV in os.environ:
-            statement_params = telemetry.get_function_usage_statement_params(
-                project=_PROJECT,
-                subproject=self._subproject,
-                function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
-                api_calls=[Session.call],
-                custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
-            )
-            pd_df: pd.DataFrame = dataset.to_pandas(statement_params=statement_params)
-            pd_df.columns = dataset.columns
-            return self.fit_pandas(pd_df, estimator, input_cols, label_cols, sample_weight_col)
-
-        # Extract query that generated the dataframe. We will need to pass it to the fit procedure.
-        queries = dataset.queries["queries"]
-
-        # Create a temp file and dump the transform to that file.
-        local_transform_file_name = get_temp_file_path()
-        with open(local_transform_file_name, mode="w+b") as local_transform_file:
-            cp.dump(estimator, local_transform_file)
-
-        # Create temp stage to run fit.
-        transform_stage_name = random_name_for_temp_object(TempObjectType.STAGE)
-        stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
-        SqlResultValidator(session=session, query=stage_creation_query).has_dimensions(
-            expected_rows=1, expected_cols=1
-        ).validate()
-
-        # Use posixpath to construct stage paths
-        stage_transform_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
-        stage_result_file_name = posixpath.join(transform_stage_name, os.path.basename(local_transform_file_name))
-        local_result_file_name = get_temp_file_path()
-
-        statement_params = telemetry.get_function_usage_statement_params(
-            project=_PROJECT,
-            subproject=self._subproject,
-            function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
-            api_calls=[sproc],
-            custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
-        )
-        # Put locally serialized transform on stage.
-        session.file.put(
-            local_transform_file_name,
-            stage_transform_file_name,
-            auto_compress=False,
-            overwrite=True,
-            statement_params=statement_params,
-        )
-
-        # Call fit sproc
-        statement_params = telemetry.get_function_usage_statement_params(
-            project=_PROJECT,
-            subproject=self._subproject,
-            function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
-            api_calls=[Session.call],
-            custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
-        )
-
-        fit_wrapper_sproc = self._get_fit_wrapper_sproc(dependencies, session, statement_params)
-
-        try:
-            sproc_export_file_name: str = fit_wrapper_sproc(
-                session,
-                queries,
-                stage_transform_file_name,
-                stage_result_file_name,
-                input_cols,
-                label_cols,
-                sample_weight_col,
-                statement_params,
-            )
-        except snowpark_exceptions.SnowparkClientException as e:
-            if "fit() missing 1 required positional argument: 'y'" in str(e):
-                raise exceptions.SnowflakeMLException(
-                    error_code=error_codes.NOT_FOUND,
-                    original_exception=RuntimeError(modeling_error_messages.ATTRIBUTE_NOT_SET.format("label_cols")),
-                ) from e
-            raise e
-
-        if "|" in sproc_export_file_name:
-            fields = sproc_export_file_name.strip().split("|")
-            sproc_export_file_name = fields[0]
-
-        session.file.get(
-            posixpath.join(stage_result_file_name, sproc_export_file_name),
-            local_result_file_name,
-            statement_params=statement_params,
-        )
-
-        with open(os.path.join(local_result_file_name, sproc_export_file_name), mode="r+b") as result_file_obj:
-            fit_estimator = cp.load(result_file_obj)
-
-        cleanup_temp_files([local_transform_file_name, local_result_file_name])
-
-        return fit_estimator
-
     def batch_inference(
         self,
         dataset: DataFrame,
```
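In 1.1.2 these fit paths move behind the trainer abstraction added above (`model_trainer.py`, `model_trainer_builder.py`, `pandas_trainer.py`, `snowpark_trainer.py`). A sketch of that dispatch, with class names taken from the new file names but simplified, assumed bodies (the released builder also handles the distributed HPO case):

```python
# Sketch of the trainer split implied by the new files above; these bodies are
# assumptions for illustration, not the released implementations.
from typing import Protocol, Union

import pandas as pd
from snowflake.snowpark import DataFrame


class ModelTrainer(Protocol):
    """Tiny protocol in the spirit of model_trainer.py (+13 lines)."""

    def train(self) -> object: ...


class PandasModelTrainer:
    """Local fit, as in the removed fit_pandas (now pandas_trainer.py)."""

    def __init__(self, estimator: object, dataset: pd.DataFrame) -> None:
        self.estimator, self.dataset = estimator, dataset

    def train(self) -> object:
        return self.estimator.fit(self.dataset)  # real code splits X/y/sample_weight


class SnowparkModelTrainer:
    """Server-side fit via a stored procedure, as in the removed fit_snowpark."""

    def __init__(self, estimator: object, dataset: DataFrame) -> None:
        self.estimator, self.dataset = estimator, dataset

    def train(self) -> object:
        raise NotImplementedError("see snowpark_trainer.py (+331 lines)")


class ModelTrainerBuilder:
    """Dispatch on dataset type, replacing the WrapperProvider plumbing."""

    @classmethod
    def build(cls, estimator: object, dataset: Union[pd.DataFrame, DataFrame]) -> ModelTrainer:
        if isinstance(dataset, pd.DataFrame):
            return PandasModelTrainer(estimator, dataset)
        return SnowparkModelTrainer(estimator, dataset)
```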
```diff
@@ -690,437 +370,3 @@ class SnowparkHandlers:
         cleanup_temp_files([local_score_file_name])
 
         return score
-
-    def fit_search_snowpark(
-        self,
-        param_grid: Union[model_selection.ParameterGrid, model_selection.ParameterSampler],
-        dataset: DataFrame,
-        session: Session,
-        estimator: Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV],
-        dependencies: List[str],
-        udf_imports: List[str],
-        input_cols: List[str],
-        label_cols: List[str],
-        sample_weight_col: Optional[str],
-    ) -> Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV]:
-        from itertools import product
-
-        import cachetools
-        from sklearn.base import clone, is_classifier
-        from sklearn.calibration import check_cv
-
-        # Create one stage for data and for estimators.
-        temp_stage_name = random_name_for_temp_object(TempObjectType.STAGE)
-        temp_stage_creation_query = f"CREATE OR REPLACE TEMP STAGE {temp_stage_name};"
-        session.sql(temp_stage_creation_query).collect()
-
-        # Stage data.
-        dataset = snowpark_dataframe_utils.cast_snowpark_dataframe(dataset)
-        remote_file_path = f"{temp_stage_name}/{temp_stage_name}.parquet"
-        dataset.write.copy_into_location(  # type:ignore[call-overload]
-            remote_file_path, file_format_type="parquet", header=True, overwrite=True
-        )
-        imports = [f"@{row.name}" for row in session.sql(f"LIST @{temp_stage_name}").collect()]
-
-        # Store GridSearchCV's refit variable. If user set it as False, we don't need to refit it again
-        original_refit = estimator.refit
-
-        # Create a temp file and dump the estimator to that file.
-        estimator_file_name = get_temp_file_path()
-        params_to_evaluate = []
-        for param_to_eval in list(param_grid):
-            for k, v in param_to_eval.items():
-                param_to_eval[k] = [v]
-            params_to_evaluate.append([param_to_eval])
-
-        with open(estimator_file_name, mode="w+b") as local_estimator_file_obj:
-            # Set GridSearchCV refit as False and fit it again after retrieving the best param
-            estimator.refit = False
-            cp.dump(dict(estimator=estimator, param_grid=params_to_evaluate), local_estimator_file_obj)
-        stage_estimator_file_name = posixpath.join(temp_stage_name, os.path.basename(estimator_file_name))
-        sproc_statement_params = telemetry.get_function_usage_statement_params(
-            project=_PROJECT,
-            subproject=self._subproject,
-            function_name=telemetry.get_statement_params_full_func_name(
-                inspect.currentframe(), self.__class__.__name__
-            ),
-            api_calls=[sproc],
-            custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
-        )
-        udtf_statement_params = telemetry.get_function_usage_statement_params(
-            project=_PROJECT,
-            subproject=self._subproject,
-            function_name=telemetry.get_statement_params_full_func_name(
-                inspect.currentframe(), self.__class__.__name__
-            ),
-            api_calls=[udtf],
-            custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
-        )
-
-        # Put locally serialized estimator on stage.
-        put_result = session.file.put(
-            estimator_file_name,
-            temp_stage_name,
-            auto_compress=False,
-            overwrite=True,
-        )
-        estimator_location = put_result[0].target
-        imports.append(f"@{temp_stage_name}/{estimator_location}")
-
-        search_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
-        random_udtf_name = random_name_for_temp_object(TempObjectType.FUNCTION)
-
-        required_deps = dependencies + [
-            "snowflake-snowpark-python<2",
-            "fastparquet<2023.11",
-            "pyarrow<14",
-            "cachetools<5",
-        ]
-
-        @sproc(  # type: ignore[misc]
-            is_permanent=False,
-            name=search_sproc_name,
-            packages=required_deps,  # type: ignore[arg-type]
-            replace=True,
-            session=session,
-            anonymous=True,
-            imports=imports,  # type: ignore[arg-type]
-            statement_params=sproc_statement_params,
-        )
-        def _distributed_search(
-            session: Session,
-            imports: List[str],
-            stage_estimator_file_name: str,
-            input_cols: List[str],
-            label_cols: List[str],
-        ) -> str:
-            import os
-            import time
-            from typing import Iterator
-
-            import cloudpickle as cp
-            import pandas as pd
-            import pyarrow.parquet as pq
-            from sklearn.metrics import check_scoring
-            from sklearn.metrics._scorer import _check_multimetric_scoring
-
-            for import_name in udf_imports:
-                importlib.import_module(import_name)
-
-            data_files = [
-                filename
-                for filename in os.listdir(sys._xoptions["snowflake_import_directory"])
-                if filename.startswith(temp_stage_name)
-            ]
-            partial_df = [
-                pq.read_table(os.path.join(sys._xoptions["snowflake_import_directory"], file_name)).to_pandas()
-                for file_name in data_files
-            ]
-            df = pd.concat(partial_df, ignore_index=True)
-            df.columns = [identifier.get_inferred_name(col) for col in df.columns]
-
-            X = df[input_cols]
-            y = df[label_cols].squeeze()
-
-            local_estimator_file_name = get_temp_file_path()
-            session.file.get(stage_estimator_file_name, local_estimator_file_name)
-
-            local_estimator_file_path = os.path.join(
-                local_estimator_file_name, os.listdir(local_estimator_file_name)[0]
-            )
-            with open(local_estimator_file_path, mode="r+b") as local_estimator_file_obj:
-                estimator = cp.load(local_estimator_file_obj)["estimator"]
-
-            cv_orig = check_cv(estimator.cv, y, classifier=is_classifier(estimator.estimator))
-            indices = [test for _, test in cv_orig.split(X, y)]
-            local_indices_file_name = get_temp_file_path()
-            with open(local_indices_file_name, mode="w+b") as local_indices_file_obj:
-                cp.dump(indices, local_indices_file_obj)
-
-            # Put locally serialized indices on stage.
-            put_result = session.file.put(
-                local_indices_file_name,
-                temp_stage_name,
-                auto_compress=False,
-                overwrite=True,
-            )
-            indices_location = put_result[0].target
-            imports.append(f"@{temp_stage_name}/{indices_location}")
-            indices_len = len(indices)
-
-            assert estimator is not None
-
-            @cachetools.cached(cache={})
-            def _load_data_into_udf() -> Tuple[
-                Dict[str, pd.DataFrame],
-                Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV],
-                pd.DataFrame,
-                int,
-                List[Dict[str, Any]],
-            ]:
-                import pyarrow.parquet as pq
-
-                data_files = [
-                    filename
-                    for filename in os.listdir(sys._xoptions["snowflake_import_directory"])
-                    if filename.startswith(temp_stage_name)
-                ]
-                partial_df = [
-                    pq.read_table(os.path.join(sys._xoptions["snowflake_import_directory"], file_name)).to_pandas()
-                    for file_name in data_files
-                ]
-                df = pd.concat(partial_df, ignore_index=True)
-                df.columns = [identifier.get_inferred_name(col) for col in df.columns]
-
-                # load estimator
-                local_estimator_file_path = os.path.join(
-                    sys._xoptions["snowflake_import_directory"], f"{estimator_location}"
-                )
-                with open(local_estimator_file_path, mode="rb") as local_estimator_file_obj:
-                    estimator_objects = cp.load(local_estimator_file_obj)
-                    estimator = estimator_objects["estimator"]
-                    params_to_evaluate = estimator_objects["param_grid"]
-
-                # load indices
-                local_indices_file_path = os.path.join(
-                    sys._xoptions["snowflake_import_directory"], f"{indices_location}"
-                )
-                with open(local_indices_file_path, mode="rb") as local_indices_file_obj:
-                    indices = cp.load(local_indices_file_obj)
-
-                argspec = inspect.getfullargspec(estimator.fit)
-                args = {"X": df[input_cols]}
-
-                if label_cols:
-                    label_arg_name = "Y" if "Y" in argspec.args else "y"
-                    args[label_arg_name] = df[label_cols].squeeze()
-
-                if sample_weight_col is not None and "sample_weight" in argspec.args:
-                    args["sample_weight"] = df[sample_weight_col].squeeze()
-                return args, estimator, indices, len(df), params_to_evaluate
-
-            class SearchCV:
-                def __init__(self) -> None:
-                    args, estimator, indices, data_length, params_to_evaluate = _load_data_into_udf()
-                    self.args = args
-                    self.estimator = estimator
-                    self.indices = indices
-                    self.data_length = data_length
-                    self.params_to_evaluate = params_to_evaluate
-
-                def process(self, params_idx: int, idx: int) -> Iterator[Tuple[str]]:
-                    if hasattr(estimator, "param_grid"):
-                        self.estimator.param_grid = self.params_to_evaluate[params_idx]
-                    else:
-                        self.estimator.param_distributions = self.params_to_evaluate[params_idx]
-                    full_indices = np.array([i for i in range(self.data_length)])
-                    test_indice = self.indices[idx]
-                    train_indice = np.setdiff1d(full_indices, test_indice)
-                    self.estimator.cv = [(train_indice, test_indice)]
-                    self.estimator.fit(**self.args)
-                    binary_cv_results = None
-                    with io.BytesIO() as f:
-                        cp.dump(self.estimator.cv_results_, f)
-                        f.seek(0)
-                        binary_cv_results = f.getvalue().hex()
-                    yield (binary_cv_results,)
-
-                def end_partition(self) -> None:
-                    ...
-
-            session.udtf.register(
-                SearchCV,
-                output_schema=StructType([StructField("CV_RESULTS", StringType())]),
-                input_types=[IntegerType(), IntegerType()],
-                name=random_udtf_name,
-                packages=required_deps,  # type: ignore[arg-type]
-                replace=True,
-                is_permanent=False,
-                imports=imports,  # type: ignore[arg-type]
-                statement_params=udtf_statement_params,
-            )
-
-            HP_TUNING = F.table_function(random_udtf_name)
-
-            idx_length = int(indices_len)
-            params_length = len(param_grid)
-            idxs = [i for i in range(idx_length)]
-            param_indices, training_indices = [], []
-            for param_idx, cv_idx in product([param_index for param_index in range(params_length)], idxs):
-                param_indices.append(param_idx)
-                training_indices.append(cv_idx)
-
-            pd_df = pd.DataFrame(
-                {
-                    "PARAMS": param_indices,
-                    "TRAIN_IND": training_indices,
-                    "PARAM_INDEX": [i for i in range(idx_length * params_length)],
-                }
-            )
-            df = session.create_dataframe(pd_df)
-            results = df.select(
-                F.cast(df["PARAM_INDEX"], IntegerType()).as_("PARAM_INDEX"),
-                (HP_TUNING(df["PARAMS"], df["TRAIN_IND"]).over(partition_by=df["PARAM_INDEX"])),
-            )
-
-            # cv_result maintains the original order
-            multimetric = False
-            cv_results_ = dict()
-            scorers = set()
-            for i, val in enumerate(results.select("CV_RESULTS").sort(col("PARAM_INDEX")).collect()):
-                # retrieved string had one more double quote in the front and end of the string.
-                # use [1:-1] to remove the extra double quotes
-                hex_str = bytes.fromhex(val[0])
-                with io.BytesIO(hex_str) as f_reload:
-                    each_cv_result = cp.load(f_reload)
-                    for k, v in each_cv_result.items():
-                        cur_cv = i % idx_length
-                        key = k
-                        if "split0_test_" in k:
-                            # For multi-metric evaluation, the scores for all the scorers are available in the
-                            # cv_results_ dict at the keys ending with that scorer's name ('_<scorer_name>')
-                            # instead of '_score'.
-                            scorers.add(k[len("split0_test_") :])
-                            key = k.replace("split0_test", f"split{cur_cv}_test")
-                        elif k.startswith("param"):
-                            if cur_cv != 0:
-                                key = False
-                        if key:
-                            if key not in cv_results_:
-                                cv_results_[key] = v
-                            else:
-                                cv_results_[key] = np.concatenate([cv_results_[key], v])
-
-            multimetric = len(scorers) > 1
-            # Use numpy to re-calculate all the information in cv_results_ again
-            # Generally speaking, reshape all the results into the (scorers+2, idx_length, params_length) shape,
-            # and average them by the idx_length;
-            # idx_length is the number of cv folds; params_length is the number of parameter combinations
-            scores = [
-                np.reshape(
-                    np.concatenate([cv_results_[f"split{cur_cv}_test_{score}"] for cur_cv in range(idx_length)]),
-                    (idx_length, -1),
-                )
-                for score in scorers
-            ]
-
-            fit_score_test_matrix = np.stack(
-                [
-                    np.reshape(cv_results_["mean_fit_time"], (idx_length, -1)),
-                    np.reshape(cv_results_["mean_score_time"], (idx_length, -1)),
-                ]
-                + scores
-            )
-
-            mean_fit_score_test_matrix = np.mean(fit_score_test_matrix, axis=1)
-            std_fit_score_test_matrix = np.std(fit_score_test_matrix, axis=1)
-            cv_results_["std_fit_time"] = std_fit_score_test_matrix[0]
-            cv_results_["mean_fit_time"] = mean_fit_score_test_matrix[0]
-            cv_results_["std_score_time"] = std_fit_score_test_matrix[1]
-            cv_results_["mean_score_time"] = mean_fit_score_test_matrix[1]
-            for idx, score in enumerate(scorers):
-                cv_results_[f"std_test_{score}"] = std_fit_score_test_matrix[idx + 2]
-                cv_results_[f"mean_test_{score}"] = mean_fit_score_test_matrix[idx + 2]
-                # re-compute the ranking again with mean_test_<score>.
-                cv_results_[f"rank_test_{score}"] = rankdata(-cv_results_[f"mean_test_{score}"], method="min")
-                # The best param is the highest ranking (which is 1) and we choose the first time ranking 1 appeared.
-                # If all scores are `nan`, `rankdata` will also produce an array of `nan` values.
-                # In that case, default to first index.
-                best_param_index = (
-                    np.where(cv_results_[f"rank_test_{score}"] == 1)[0][0]
-                    if not np.isnan(cv_results_[f"rank_test_{score}"]).all()
-                    else 0
-                )
-
-            estimator.cv_results_ = cv_results_
-            estimator.multimetric_ = multimetric
-
-            # Reconstruct the sklearn estimator.
-            refit_metric = "score"
-            if callable(estimator.scoring):
-                scorers = estimator.scoring
-            elif estimator.scoring is None or isinstance(estimator.scoring, str):
-                scorers = check_scoring(estimator.estimator, estimator.scoring)
-            else:
-                scorers = _check_multimetric_scoring(estimator.estimator, estimator.scoring)
-                estimator._check_refit_for_multimetric(scorers)
-                refit_metric = original_refit
-
-            estimator.scorer_ = scorers
-
-            # check refit_metric now for a callabe scorer that is multimetric
-            if callable(estimator.scoring) and estimator.multimetric_:
-                refit_metric = original_refit
-
-            # For multi-metric evaluation, store the best_index_, best_params_ and
-            # best_score_ iff refit is one of the scorer names
-            # In single metric evaluation, refit_metric is "score"
-            if original_refit or not estimator.multimetric_:
-                estimator.best_index_ = estimator._select_best_index(original_refit, refit_metric, cv_results_)
-                if not callable(original_refit):
-                    # With a non-custom callable, we can select the best score
-                    # based on the best index
-                    estimator.best_score_ = cv_results_[f"mean_test_{refit_metric}"][estimator.best_index_]
-                estimator.best_params_ = cv_results_["params"][best_param_index]
-
-            if original_refit:
-                estimator.best_estimator_ = clone(estimator.estimator).set_params(
-                    **clone(estimator.best_params_, safe=False)
-                )
-
-                # Let the sproc use all cores to refit.
-                estimator.n_jobs = -1 if not estimator.n_jobs else estimator.n_jobs
-
-                # process the input as args
-                argspec = inspect.getfullargspec(estimator.fit)
-                args = {"X": X}
-                if label_cols:
-                    label_arg_name = "Y" if "Y" in argspec.args else "y"
-                    args[label_arg_name] = y
-                if sample_weight_col is not None and "sample_weight" in argspec.args:
-                    args["sample_weight"] = df[sample_weight_col].squeeze()
-                estimator.refit = original_refit
-                refit_start_time = time.time()
-                estimator.best_estimator_.fit(**args)
-                refit_end_time = time.time()
-                estimator.refit_time_ = refit_end_time - refit_start_time
-
-            if hasattr(estimator.best_estimator_, "feature_names_in_"):
-                estimator.feature_names_in_ = estimator.best_estimator_.feature_names_in_
-
-            local_result_file_name = get_temp_file_path()
-
-            with open(local_result_file_name, mode="w+b") as local_result_file_obj:
-                cp.dump(estimator, local_result_file_obj)
-
-            session.file.put(
-                local_result_file_name,
-                temp_stage_name,
-                auto_compress=False,
-                overwrite=True,
-            )
-
-            # Note: you can add something like + "|" + str(df) to the return string
-            # to pass debug information to the caller.
-            return str(os.path.basename(local_result_file_name))
-
-        sproc_export_file_name = _distributed_search(
-            session,
-            imports,
-            stage_estimator_file_name,
-            input_cols,
-            label_cols,
-        )
-
-        local_estimator_path = get_temp_file_path()
-        session.file.get(
-            posixpath.join(temp_stage_name, sproc_export_file_name),
-            local_estimator_path,
-        )
-
-        with open(os.path.join(local_estimator_path, sproc_export_file_name), mode="r+b") as result_file_obj:
-            fit_estimator = cp.load(result_file_obj)
-
-        cleanup_temp_files([local_estimator_path])
-
-        return fit_estimator
```