snowflake-ml-python 1.1.1__py3-none-any.whl → 1.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +1 -1
- snowflake/cortex/_extract_answer.py +1 -1
- snowflake/cortex/_sentiment.py +1 -1
- snowflake/cortex/_summarize.py +1 -1
- snowflake/cortex/_translate.py +1 -1
- snowflake/ml/_internal/env_utils.py +68 -6
- snowflake/ml/_internal/file_utils.py +34 -4
- snowflake/ml/_internal/telemetry.py +79 -91
- snowflake/ml/_internal/utils/retryable_http.py +16 -4
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +122 -0
- snowflake/ml/dataset/dataset.py +1 -1
- snowflake/ml/model/_api.py +21 -14
- snowflake/ml/model/_client/model/model_impl.py +176 -0
- snowflake/ml/model/_client/model/model_method_info.py +19 -0
- snowflake/ml/model/_client/model/model_version_impl.py +291 -0
- snowflake/ml/model/_client/ops/metadata_ops.py +107 -0
- snowflake/ml/model/_client/ops/model_ops.py +308 -0
- snowflake/ml/model/_client/sql/model.py +75 -0
- snowflake/ml/model/_client/sql/model_version.py +213 -0
- snowflake/ml/model/_client/sql/stage.py +40 -0
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -4
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +24 -8
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +23 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +14 -2
- snowflake/ml/model/_deploy_client/utils/constants.py +1 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +2 -2
- snowflake/ml/model/_model_composer/model_composer.py +31 -9
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +25 -10
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -2
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +34 -3
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +1 -1
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -1
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +10 -28
- snowflake/ml/model/_packager/model_meta/model_meta.py +18 -16
- snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
- snowflake/ml/model/model_signature.py +108 -53
- snowflake/ml/model/type_hints.py +1 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +554 -0
- snowflake/ml/modeling/_internal/estimator_protocols.py +1 -60
- snowflake/ml/modeling/_internal/model_specifications.py +146 -0
- snowflake/ml/modeling/_internal/model_trainer.py +13 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +78 -0
- snowflake/ml/modeling/_internal/pandas_trainer.py +54 -0
- snowflake/ml/modeling/_internal/snowpark_handlers.py +6 -760
- snowflake/ml/modeling/_internal/snowpark_trainer.py +331 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +96 -124
- snowflake/ml/modeling/cluster/affinity_propagation.py +94 -124
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +94 -124
- snowflake/ml/modeling/cluster/birch.py +94 -124
- snowflake/ml/modeling/cluster/bisecting_k_means.py +94 -124
- snowflake/ml/modeling/cluster/dbscan.py +94 -124
- snowflake/ml/modeling/cluster/feature_agglomeration.py +94 -124
- snowflake/ml/modeling/cluster/k_means.py +93 -124
- snowflake/ml/modeling/cluster/mean_shift.py +94 -124
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +93 -124
- snowflake/ml/modeling/cluster/optics.py +94 -124
- snowflake/ml/modeling/cluster/spectral_biclustering.py +94 -124
- snowflake/ml/modeling/cluster/spectral_clustering.py +94 -124
- snowflake/ml/modeling/cluster/spectral_coclustering.py +94 -124
- snowflake/ml/modeling/compose/column_transformer.py +94 -124
- snowflake/ml/modeling/compose/transformed_target_regressor.py +96 -124
- snowflake/ml/modeling/covariance/elliptic_envelope.py +94 -124
- snowflake/ml/modeling/covariance/empirical_covariance.py +80 -110
- snowflake/ml/modeling/covariance/graphical_lasso.py +94 -124
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +94 -124
- snowflake/ml/modeling/covariance/ledoit_wolf.py +85 -115
- snowflake/ml/modeling/covariance/min_cov_det.py +94 -124
- snowflake/ml/modeling/covariance/oas.py +80 -110
- snowflake/ml/modeling/covariance/shrunk_covariance.py +84 -114
- snowflake/ml/modeling/decomposition/dictionary_learning.py +94 -124
- snowflake/ml/modeling/decomposition/factor_analysis.py +94 -124
- snowflake/ml/modeling/decomposition/fast_ica.py +94 -124
- snowflake/ml/modeling/decomposition/incremental_pca.py +94 -124
- snowflake/ml/modeling/decomposition/kernel_pca.py +94 -124
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +94 -124
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +94 -124
- snowflake/ml/modeling/decomposition/pca.py +94 -124
- snowflake/ml/modeling/decomposition/sparse_pca.py +94 -124
- snowflake/ml/modeling/decomposition/truncated_svd.py +94 -124
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +96 -124
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +91 -119
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/bagging_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/bagging_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/isolation_forest.py +94 -124
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/stacking_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/voting_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/voting_regressor.py +91 -119
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +82 -110
- snowflake/ml/modeling/feature_selection/select_fdr.py +80 -108
- snowflake/ml/modeling/feature_selection/select_fpr.py +80 -108
- snowflake/ml/modeling/feature_selection/select_fwe.py +80 -108
- snowflake/ml/modeling/feature_selection/select_k_best.py +81 -109
- snowflake/ml/modeling/feature_selection/select_percentile.py +80 -108
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +94 -124
- snowflake/ml/modeling/feature_selection/variance_threshold.py +76 -106
- snowflake/ml/modeling/framework/base.py +2 -2
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +96 -124
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +96 -124
- snowflake/ml/modeling/impute/iterative_imputer.py +94 -124
- snowflake/ml/modeling/impute/knn_imputer.py +94 -124
- snowflake/ml/modeling/impute/missing_indicator.py +94 -124
- snowflake/ml/modeling/impute/simple_imputer.py +1 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +77 -107
- snowflake/ml/modeling/kernel_approximation/nystroem.py +94 -124
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +94 -124
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +86 -116
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +84 -114
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +96 -124
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +71 -100
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +71 -100
- snowflake/ml/modeling/linear_model/ard_regression.py +96 -124
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +96 -124
- snowflake/ml/modeling/linear_model/elastic_net.py +96 -124
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +96 -124
- snowflake/ml/modeling/linear_model/gamma_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/huber_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/lars.py +96 -124
- snowflake/ml/modeling/linear_model/lars_cv.py +96 -124
- snowflake/ml/modeling/linear_model/lasso.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_cv.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_lars.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +96 -124
- snowflake/ml/modeling/linear_model/linear_regression.py +91 -119
- snowflake/ml/modeling/linear_model/logistic_regression.py +96 -124
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +96 -124
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +96 -124
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +96 -124
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +95 -124
- snowflake/ml/modeling/linear_model/perceptron.py +95 -124
- snowflake/ml/modeling/linear_model/poisson_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/ransac_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/ridge.py +96 -124
- snowflake/ml/modeling/linear_model/ridge_classifier.py +96 -124
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +96 -124
- snowflake/ml/modeling/linear_model/ridge_cv.py +96 -124
- snowflake/ml/modeling/linear_model/sgd_classifier.py +96 -124
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +94 -124
- snowflake/ml/modeling/linear_model/sgd_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +96 -124
- snowflake/ml/modeling/manifold/isomap.py +94 -124
- snowflake/ml/modeling/manifold/mds.py +94 -124
- snowflake/ml/modeling/manifold/spectral_embedding.py +94 -124
- snowflake/ml/modeling/manifold/tsne.py +94 -124
- snowflake/ml/modeling/metrics/classification.py +187 -52
- snowflake/ml/modeling/metrics/correlation.py +4 -2
- snowflake/ml/modeling/metrics/covariance.py +7 -4
- snowflake/ml/modeling/metrics/ranking.py +32 -16
- snowflake/ml/modeling/metrics/regression.py +60 -32
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +94 -124
- snowflake/ml/modeling/mixture/gaussian_mixture.py +94 -124
- snowflake/ml/modeling/model_selection/grid_search_cv.py +88 -138
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +90 -144
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +86 -114
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +93 -121
- snowflake/ml/modeling/multiclass/output_code_classifier.py +94 -122
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +92 -120
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +96 -124
- snowflake/ml/modeling/naive_bayes/complement_nb.py +92 -120
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -107
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +88 -116
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +96 -124
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +96 -124
- snowflake/ml/modeling/neighbors/kernel_density.py +94 -124
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +94 -124
- snowflake/ml/modeling/neighbors/nearest_centroid.py +89 -117
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +94 -124
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +96 -124
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +96 -124
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +96 -124
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +94 -124
- snowflake/ml/modeling/neural_network/mlp_classifier.py +96 -124
- snowflake/ml/modeling/neural_network/mlp_regressor.py +96 -124
- snowflake/ml/modeling/parameters/disable_distributed_hpo.py +2 -6
- snowflake/ml/modeling/preprocessing/binarizer.py +14 -9
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +0 -4
- snowflake/ml/modeling/preprocessing/label_encoder.py +21 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +20 -14
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +35 -19
- snowflake/ml/modeling/preprocessing/normalizer.py +6 -9
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +20 -13
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +25 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +94 -124
- snowflake/ml/modeling/preprocessing/robust_scaler.py +28 -14
- snowflake/ml/modeling/preprocessing/standard_scaler.py +25 -13
- snowflake/ml/modeling/semi_supervised/label_propagation.py +96 -124
- snowflake/ml/modeling/semi_supervised/label_spreading.py +96 -124
- snowflake/ml/modeling/svm/linear_svc.py +96 -124
- snowflake/ml/modeling/svm/linear_svr.py +96 -124
- snowflake/ml/modeling/svm/nu_svc.py +96 -124
- snowflake/ml/modeling/svm/nu_svr.py +96 -124
- snowflake/ml/modeling/svm/svc.py +96 -124
- snowflake/ml/modeling/svm/svr.py +96 -124
- snowflake/ml/modeling/tree/decision_tree_classifier.py +96 -124
- snowflake/ml/modeling/tree/decision_tree_regressor.py +96 -124
- snowflake/ml/modeling/tree/extra_tree_classifier.py +96 -124
- snowflake/ml/modeling/tree/extra_tree_regressor.py +96 -124
- snowflake/ml/modeling/xgboost/xgb_classifier.py +96 -125
- snowflake/ml/modeling/xgboost/xgb_regressor.py +96 -125
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +96 -125
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +96 -125
- snowflake/ml/registry/model_registry.py +2 -0
- snowflake/ml/registry/registry.py +215 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.1.1.dist-info → snowflake_ml_python-1.1.2.dist-info}/METADATA +21 -3
- snowflake_ml_python-1.1.2.dist-info/RECORD +347 -0
- snowflake_ml_python-1.1.1.dist-info/RECORD +0 -331
- {snowflake_ml_python-1.1.1.dist-info → snowflake_ml_python-1.1.2.dist-info}/WHEEL +0 -0
@@ -2,13 +2,13 @@
|
|
2
2
|
# This code is auto-generated using the sklearn_wrapper_template.py_template template.
|
3
3
|
# Do not modify the auto-generated code(except automatic reformatting by precommit hooks).
|
4
4
|
#
|
5
|
-
from typing import Dict, Iterable, List, Optional, Set, Union
|
5
|
+
from typing import Any, Dict, Iterable, List, Optional, Set, Union
|
6
6
|
from uuid import uuid4
|
7
7
|
|
8
|
+
import cloudpickle as cp
|
8
9
|
import numpy as np
|
9
10
|
import pandas as pd
|
10
11
|
import sklearn.model_selection
|
11
|
-
from sklearn.model_selection import ParameterGrid
|
12
12
|
from sklearn.utils.metaestimators import available_if
|
13
13
|
|
14
14
|
from snowflake.ml._internal import telemetry
|
@@ -25,13 +25,12 @@ from snowflake.ml.model.model_signature import (
|
|
25
25
|
from snowflake.ml.modeling._internal.estimator_protocols import CVHandlers
|
26
26
|
from snowflake.ml.modeling._internal.estimator_utils import (
|
27
27
|
gather_dependencies,
|
28
|
-
is_single_node,
|
29
28
|
original_estimator_has_callable,
|
30
29
|
transform_snowml_obj_to_sklearn_obj,
|
31
30
|
validate_sklearn_args,
|
32
31
|
)
|
32
|
+
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
33
33
|
from snowflake.ml.modeling._internal.snowpark_handlers import (
|
34
|
-
SklearnModelSelectionWrapperProvider,
|
35
34
|
SnowparkHandlers as HandlersImpl,
|
36
35
|
)
|
37
36
|
from snowflake.ml.modeling.framework.base import BaseTransformer
|
@@ -53,19 +52,54 @@ class GridSearchCV(BaseTransformer):
|
|
53
52
|
|
54
53
|
Parameters
|
55
54
|
----------
|
56
|
-
estimator
|
55
|
+
estimator: estimator object
|
57
56
|
This is assumed to implement the scikit-learn estimator interface.
|
58
57
|
Either estimator needs to provide a ``score`` function,
|
59
58
|
or ``scoring`` must be passed.
|
60
59
|
|
61
|
-
param_grid
|
60
|
+
param_grid: dict or list of dictionaries
|
62
61
|
Dictionary with parameters names (`str`) as keys and lists of
|
63
62
|
parameter settings to try as values, or a list of such
|
64
63
|
dictionaries, in which case the grids spanned by each dictionary
|
65
64
|
in the list are explored. This enables searching over any sequence
|
66
65
|
of parameter settings.
|
67
66
|
|
68
|
-
|
67
|
+
input_cols: Optional[Union[str, List[str]]]
|
68
|
+
A string or list of strings representing column names that contain features.
|
69
|
+
If this parameter is not specified, all columns in the input DataFrame except
|
70
|
+
the columns specified by label_cols and sample_weight_col parameters are
|
71
|
+
considered input columns.
|
72
|
+
|
73
|
+
label_cols: Optional[Union[str, List[str]]]
|
74
|
+
A string or list of strings representing column names that contain labels.
|
75
|
+
This is a required param for estimators, as there is no way to infer these
|
76
|
+
columns. If this parameter is not specified, then object is fitted without
|
77
|
+
labels (like a transformer).
|
78
|
+
|
79
|
+
output_cols: Optional[Union[str, List[str]]]
|
80
|
+
A string or list of strings representing column names that will store the
|
81
|
+
output of predict and transform operations. The length of output_cols must
|
82
|
+
match the expected number of output columns from the specific estimator or
|
83
|
+
transformer class used.
|
84
|
+
If this parameter is not specified, output column names are derived by
|
85
|
+
adding an OUTPUT_ prefix to the label column names. These inferred output
|
86
|
+
column names work for estimator's predict() method, but output_cols must
|
87
|
+
be set explicitly for transformers.
|
88
|
+
|
89
|
+
passthrough_cols: A string or a list of strings indicating column names to be excluded from any
|
90
|
+
operations (such as train, transform, or inference). These specified column(s)
|
91
|
+
will remain untouched throughout the process. This option is helpful in scenarios
|
92
|
+
requiring automatic input_cols inference, but need to avoid using specific
|
93
|
+
columns, like index columns, during training or inference.
|
94
|
+
|
95
|
+
sample_weight_col: Optional[str]
|
96
|
+
A string representing the column name containing the examples’ weights.
|
97
|
+
This argument is only required when working with weighted datasets.
|
98
|
+
|
99
|
+
drop_input_cols: Optional[bool], default=False
|
100
|
+
If set, the response of predict(), transform() methods will not contain input columns.
|
101
|
+
|
102
|
+
scoring: str, callable, list, tuple or dict, default=None
|
69
103
|
Strategy to evaluate the performance of the cross-validated model on
|
70
104
|
the test set.
|
71
105
|
|
@@ -83,13 +117,13 @@ class GridSearchCV(BaseTransformer):
|
|
83
117
|
|
84
118
|
See :ref:`multimetric_grid_search` for an example.
|
85
119
|
|
86
|
-
n_jobs
|
120
|
+
n_jobs: int, default=None
|
87
121
|
Number of jobs to run in parallel.
|
88
122
|
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
|
89
123
|
``-1`` means using all processors. See :term:`Glossary <n_jobs>`
|
90
124
|
for more details.
|
91
125
|
|
92
|
-
refit
|
126
|
+
refit: bool, str, or callable, default=True
|
93
127
|
Refit an estimator using the best found parameters on the whole
|
94
128
|
dataset.
|
95
129
|
|
@@ -120,7 +154,7 @@ class GridSearchCV(BaseTransformer):
|
|
120
154
|
to see how to design a custom selection strategy using a callable
|
121
155
|
via `refit`.
|
122
156
|
|
123
|
-
cv
|
157
|
+
cv: int, cross-validation generator or an iterable, default=None
|
124
158
|
Determines the cross-validation splitting strategy.
|
125
159
|
Possible inputs for cv are:
|
126
160
|
|
@@ -137,7 +171,7 @@ class GridSearchCV(BaseTransformer):
|
|
137
171
|
Refer :ref:`User Guide <cross_validation>` for the various
|
138
172
|
cross-validation strategies that can be used here.
|
139
173
|
|
140
|
-
verbose
|
174
|
+
verbose: int
|
141
175
|
Controls the verbosity: the higher, the more messages.
|
142
176
|
|
143
177
|
- >1 : the computation time for each fold and parameter candidate is
|
@@ -146,7 +180,7 @@ class GridSearchCV(BaseTransformer):
|
|
146
180
|
- >3 : the fold and candidate parameter indexes are also displayed
|
147
181
|
together with the starting time of the computation.
|
148
182
|
|
149
|
-
pre_dispatch
|
183
|
+
pre_dispatch: int, or str, default='2*n_jobs'
|
150
184
|
Controls the number of jobs that get dispatched during parallel
|
151
185
|
execution. Reducing this number can be useful to avoid an
|
152
186
|
explosion of memory consumption when more jobs get dispatched
|
@@ -163,13 +197,13 @@ class GridSearchCV(BaseTransformer):
|
|
163
197
|
- A str, giving an expression as a function of n_jobs,
|
164
198
|
as in '2*n_jobs'
|
165
199
|
|
166
|
-
error_score
|
200
|
+
error_score: 'raise' or numeric, default=np.nan
|
167
201
|
Value to assign to the score if an error occurs in estimator fitting.
|
168
202
|
If set to 'raise', the error is raised. If a numeric value is given,
|
169
203
|
FitFailedWarning is raised. This parameter does not affect the refit
|
170
204
|
step, which will always raise the error.
|
171
205
|
|
172
|
-
return_train_score
|
206
|
+
return_train_score: bool, default=False
|
173
207
|
If ``False``, the ``cv_results_`` attribute will not include training
|
174
208
|
scores.
|
175
209
|
Computing training scores is used to get insights on how different
|
@@ -177,41 +211,6 @@ class GridSearchCV(BaseTransformer):
|
|
177
211
|
However computing the scores on the training set can be computationally
|
178
212
|
expensive and is not strictly required to select the parameters that
|
179
213
|
yield the best generalization performance.
|
180
|
-
|
181
|
-
input_cols : Optional[Union[str, List[str]]]
|
182
|
-
A string or list of strings representing column names that contain features.
|
183
|
-
If this parameter is not specified, all columns in the input DataFrame except
|
184
|
-
the columns specified by label_cols and sample_weight_col parameters are
|
185
|
-
considered input columns.
|
186
|
-
|
187
|
-
label_cols : Optional[Union[str, List[str]]]
|
188
|
-
A string or list of strings representing column names that contain labels.
|
189
|
-
This is a required param for estimators, as there is no way to infer these
|
190
|
-
columns. If this parameter is not specified, then object is fitted without
|
191
|
-
labels (like a transformer).
|
192
|
-
|
193
|
-
output_cols: Optional[Union[str, List[str]]]
|
194
|
-
A string or list of strings representing column names that will store the
|
195
|
-
output of predict and transform operations. The length of output_cols must
|
196
|
-
match the expected number of output columns from the specific estimator or
|
197
|
-
transformer class used.
|
198
|
-
If this parameter is not specified, output column names are derived by
|
199
|
-
adding an OUTPUT_ prefix to the label column names. These inferred output
|
200
|
-
column names work for estimator's predict() method, but output_cols must
|
201
|
-
be set explicitly for transformers.
|
202
|
-
|
203
|
-
passthrough_cols: A string or a list of strings indicating column names to be excluded from any
|
204
|
-
operations (such as train, transform, or inference). These specified column(s)
|
205
|
-
will remain untouched throughout the process. This option is helpful in scenarios
|
206
|
-
requiring automatic input_cols inference, but need to avoid using specific
|
207
|
-
columns, like index columns, during training or inference.
|
208
|
-
|
209
|
-
sample_weight_col: Optional[str]
|
210
|
-
A string representing the column name containing the examples’ weights.
|
211
|
-
This argument is only required when working with weighted datasets.
|
212
|
-
|
213
|
-
drop_input_cols: Optional[bool], default=False
|
214
|
-
If set, the response of predict(), transform() methods will not contain input columns.
|
215
214
|
"""
|
216
215
|
_ENABLE_DISTRIBUTED = True
|
217
216
|
|
@@ -236,7 +235,11 @@ class GridSearchCV(BaseTransformer):
|
|
236
235
|
sample_weight_col: Optional[str] = None,
|
237
236
|
) -> None:
|
238
237
|
super().__init__()
|
239
|
-
deps: Set[str] =
|
238
|
+
deps: Set[str] = {
|
239
|
+
f"numpy=={np.__version__}",
|
240
|
+
f"scikit-learn=={sklearn.__version__}",
|
241
|
+
f"cloudpickle=={cp.__version__}",
|
242
|
+
}
|
240
243
|
deps = deps | gather_dependencies(estimator)
|
241
244
|
self._deps = list(deps)
|
242
245
|
estimator = transform_snowml_obj_to_sklearn_obj(estimator)
|
@@ -253,7 +256,7 @@ class GridSearchCV(BaseTransformer):
|
|
253
256
|
"return_train_score": (return_train_score, False, False),
|
254
257
|
}
|
255
258
|
cleaned_up_init_args = validate_sklearn_args(args=init_args, klass=sklearn.model_selection.GridSearchCV)
|
256
|
-
self._sklearn_object = sklearn.model_selection.GridSearchCV(
|
259
|
+
self._sklearn_object: Any = sklearn.model_selection.GridSearchCV(
|
257
260
|
**cleaned_up_init_args,
|
258
261
|
)
|
259
262
|
self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
|
@@ -266,7 +269,6 @@ class GridSearchCV(BaseTransformer):
|
|
266
269
|
self._handlers: CVHandlers = HandlersImpl(
|
267
270
|
class_name=self.__class__.__name__,
|
268
271
|
subproject=_SUBPROJECT,
|
269
|
-
wrapper_provider=SklearnModelSelectionWrapperProvider(),
|
270
272
|
)
|
271
273
|
|
272
274
|
def _get_rand_id(self) -> str:
|
@@ -294,10 +296,6 @@ class GridSearchCV(BaseTransformer):
|
|
294
296
|
For more details on this function, see [sklearn.model_selection.GridSearchCV.fit]
|
295
297
|
(https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html#sklearn.model_selection.GridSearchCV.fit)
|
296
298
|
|
297
|
-
|
298
|
-
Raises:
|
299
|
-
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
300
|
-
|
301
299
|
Args:
|
302
300
|
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
303
301
|
Snowpark or Pandas DataFrame.
|
@@ -306,70 +304,37 @@ class GridSearchCV(BaseTransformer):
|
|
306
304
|
self
|
307
305
|
"""
|
308
306
|
self._infer_input_output_cols(dataset)
|
309
|
-
if
|
310
|
-
self.
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
"Supported dataset types: snowpark.DataFrame, pandas.DataFrame."
|
307
|
+
if self._sklearn_object.n_jobs is None:
|
308
|
+
self._sklearn_object.n_jobs = -1
|
309
|
+
if isinstance(dataset, DataFrame):
|
310
|
+
session = dataset._session
|
311
|
+
assert session is not None # keep mypy happy
|
312
|
+
# Validate that key package version in user workspace are supported in snowflake conda channel
|
313
|
+
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
314
|
+
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
315
|
+
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT
|
319
316
|
)
|
320
|
-
self._is_fitted = True
|
321
|
-
self._get_model_signatures(dataset)
|
322
|
-
return self
|
323
317
|
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
329
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
330
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT
|
331
|
-
)
|
318
|
+
# Specify input columns so column pruning will be enforced
|
319
|
+
selected_cols = self._get_active_columns()
|
320
|
+
if len(selected_cols) > 0:
|
321
|
+
dataset = dataset.select(selected_cols)
|
332
322
|
|
333
|
-
|
334
|
-
if len(selected_cols) > 0:
|
335
|
-
dataset = dataset.select(selected_cols)
|
323
|
+
self._snowpark_cols = dataset.select(self.input_cols).columns
|
336
324
|
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
session=session,
|
351
|
-
estimator=self._sklearn_object,
|
352
|
-
dependencies=self._get_dependencies(),
|
353
|
-
udf_imports=["sklearn"],
|
354
|
-
input_cols=self.input_cols,
|
355
|
-
label_cols=self.label_cols,
|
356
|
-
sample_weight_col=self.sample_weight_col,
|
357
|
-
)
|
358
|
-
else:
|
359
|
-
# Fall back with stored procedure implementation
|
360
|
-
# set the parallel factor to default to minus one, to fully accelerate the cores in single node
|
361
|
-
if self._sklearn_object.n_jobs is None:
|
362
|
-
self._sklearn_object.n_jobs = -1
|
363
|
-
|
364
|
-
self._sklearn_object = self._handlers.fit_snowpark(
|
365
|
-
dataset,
|
366
|
-
session,
|
367
|
-
self._sklearn_object,
|
368
|
-
["snowflake-snowpark-python"] + self._get_dependencies(),
|
369
|
-
self.input_cols,
|
370
|
-
self.label_cols,
|
371
|
-
self.sample_weight_col,
|
372
|
-
)
|
325
|
+
model_trainer = ModelTrainerBuilder.build(
|
326
|
+
estimator=self._sklearn_object,
|
327
|
+
dataset=dataset,
|
328
|
+
input_cols=self.input_cols,
|
329
|
+
label_cols=self.label_cols,
|
330
|
+
sample_weight_col=self.sample_weight_col,
|
331
|
+
autogenerated=False,
|
332
|
+
subproject=_SUBPROJECT,
|
333
|
+
)
|
334
|
+
self._sklearn_object = model_trainer.train()
|
335
|
+
self._is_fitted = True
|
336
|
+
self._get_model_signatures(dataset)
|
337
|
+
return self
|
373
338
|
|
374
339
|
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
375
340
|
if self._drop_input_cols:
|
@@ -523,10 +488,6 @@ class GridSearchCV(BaseTransformer):
|
|
523
488
|
project=_PROJECT,
|
524
489
|
subproject=_SUBPROJECT,
|
525
490
|
)
|
526
|
-
@telemetry.add_stmt_params_to_df(
|
527
|
-
project=_PROJECT,
|
528
|
-
subproject=_SUBPROJECT,
|
529
|
-
)
|
530
491
|
def predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]:
|
531
492
|
"""Call predict on the estimator with the best found parameters
|
532
493
|
For more details on this function, see [sklearn.model_selection.GridSearchCV.predict]
|
@@ -569,10 +530,6 @@ class GridSearchCV(BaseTransformer):
|
|
569
530
|
project=_PROJECT,
|
570
531
|
subproject=_SUBPROJECT,
|
571
532
|
)
|
572
|
-
@telemetry.add_stmt_params_to_df(
|
573
|
-
project=_PROJECT,
|
574
|
-
subproject=_SUBPROJECT,
|
575
|
-
)
|
576
533
|
def transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]:
|
577
534
|
"""Call transform on the estimator with the best found parameters
|
578
535
|
For more details on this function, see [sklearn.model_selection.GridSearchCV.transform]
|
@@ -636,10 +593,6 @@ class GridSearchCV(BaseTransformer):
|
|
636
593
|
project=_PROJECT,
|
637
594
|
subproject=_SUBPROJECT,
|
638
595
|
)
|
639
|
-
@telemetry.add_stmt_params_to_df(
|
640
|
-
project=_PROJECT,
|
641
|
-
subproject=_SUBPROJECT,
|
642
|
-
)
|
643
596
|
def predict_proba(
|
644
597
|
self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "predict_proba_"
|
645
598
|
) -> Union[DataFrame, pd.DataFrame]:
|
@@ -677,10 +630,6 @@ class GridSearchCV(BaseTransformer):
|
|
677
630
|
project=_PROJECT,
|
678
631
|
subproject=_SUBPROJECT,
|
679
632
|
)
|
680
|
-
@telemetry.add_stmt_params_to_df(
|
681
|
-
project=_PROJECT,
|
682
|
-
subproject=_SUBPROJECT,
|
683
|
-
)
|
684
633
|
def predict_log_proba(
|
685
634
|
self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "predict_log_proba_"
|
686
635
|
) -> Union[DataFrame, pd.DataFrame]:
|
@@ -719,10 +668,6 @@ class GridSearchCV(BaseTransformer):
|
|
719
668
|
project=_PROJECT,
|
720
669
|
subproject=_SUBPROJECT,
|
721
670
|
)
|
722
|
-
@telemetry.add_stmt_params_to_df(
|
723
|
-
project=_PROJECT,
|
724
|
-
subproject=_SUBPROJECT,
|
725
|
-
)
|
726
671
|
def decision_function(
|
727
672
|
self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "decision_function_"
|
728
673
|
) -> Union[DataFrame, pd.DataFrame]:
|
@@ -759,6 +704,8 @@ class GridSearchCV(BaseTransformer):
|
|
759
704
|
@available_if(original_estimator_has_callable("score")) # type: ignore[misc]
|
760
705
|
def score(self, dataset: Union[DataFrame, pd.DataFrame]) -> float:
|
761
706
|
"""
|
707
|
+
If implemented by the original estimator, return the score for the dataset.
|
708
|
+
|
762
709
|
Args:
|
763
710
|
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
764
711
|
Snowpark or Pandas DataFrame.
|
@@ -811,9 +758,9 @@ class GridSearchCV(BaseTransformer):
|
|
811
758
|
# For classifier, the type of predict is the same as the type of label
|
812
759
|
if self._sklearn_object._estimator_type == "classifier":
|
813
760
|
# label columns is the desired type for output
|
814
|
-
outputs = _infer_signature(dataset[self.label_cols], "output")
|
761
|
+
outputs = list(_infer_signature(dataset[self.label_cols], "output"))
|
815
762
|
# rename the output columns
|
816
|
-
outputs = model_signature_utils.rename_features(outputs, self.output_cols)
|
763
|
+
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
817
764
|
self._model_signature_dict["predict"] = ModelSignature(
|
818
765
|
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
819
766
|
)
|
@@ -850,6 +797,9 @@ class GridSearchCV(BaseTransformer):
|
|
850
797
|
return self._model_signature_dict
|
851
798
|
|
852
799
|
def to_sklearn(self) -> sklearn.model_selection.GridSearchCV:
|
800
|
+
"""
|
801
|
+
Get sklearn.model_selection.GridSearchCV object.
|
802
|
+
"""
|
853
803
|
assert self._sklearn_object is not None
|
854
804
|
return self._sklearn_object
|
855
805
|
|