snowflake-ml-python 1.1.0__py3-none-any.whl → 1.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +1 -1
- snowflake/cortex/_extract_answer.py +1 -1
- snowflake/cortex/_sentiment.py +1 -1
- snowflake/cortex/_summarize.py +1 -1
- snowflake/cortex/_translate.py +1 -1
- snowflake/ml/_internal/env_utils.py +68 -6
- snowflake/ml/_internal/file_utils.py +34 -4
- snowflake/ml/_internal/telemetry.py +79 -91
- snowflake/ml/_internal/utils/identifier.py +78 -72
- snowflake/ml/_internal/utils/retryable_http.py +16 -4
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +122 -0
- snowflake/ml/dataset/dataset.py +1 -1
- snowflake/ml/model/_api.py +21 -14
- snowflake/ml/model/_client/model/model_impl.py +176 -0
- snowflake/ml/model/_client/model/model_method_info.py +19 -0
- snowflake/ml/model/_client/model/model_version_impl.py +291 -0
- snowflake/ml/model/_client/ops/metadata_ops.py +107 -0
- snowflake/ml/model/_client/ops/model_ops.py +308 -0
- snowflake/ml/model/_client/sql/model.py +75 -0
- snowflake/ml/model/_client/sql/model_version.py +213 -0
- snowflake/ml/model/_client/sql/stage.py +40 -0
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -4
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +24 -8
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +23 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +14 -2
- snowflake/ml/model/_deploy_client/utils/constants.py +1 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +2 -2
- snowflake/ml/model/_model_composer/model_composer.py +31 -9
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +25 -10
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -2
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +34 -3
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +1 -1
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -1
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +10 -28
- snowflake/ml/model/_packager/model_meta/model_meta.py +18 -16
- snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
- snowflake/ml/model/model_signature.py +108 -53
- snowflake/ml/model/type_hints.py +1 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +554 -0
- snowflake/ml/modeling/_internal/estimator_protocols.py +1 -60
- snowflake/ml/modeling/_internal/model_specifications.py +146 -0
- snowflake/ml/modeling/_internal/model_trainer.py +13 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +78 -0
- snowflake/ml/modeling/_internal/pandas_trainer.py +54 -0
- snowflake/ml/modeling/_internal/snowpark_handlers.py +6 -760
- snowflake/ml/modeling/_internal/snowpark_trainer.py +331 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +108 -135
- snowflake/ml/modeling/cluster/affinity_propagation.py +106 -135
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +106 -135
- snowflake/ml/modeling/cluster/birch.py +106 -135
- snowflake/ml/modeling/cluster/bisecting_k_means.py +106 -135
- snowflake/ml/modeling/cluster/dbscan.py +106 -135
- snowflake/ml/modeling/cluster/feature_agglomeration.py +106 -135
- snowflake/ml/modeling/cluster/k_means.py +105 -135
- snowflake/ml/modeling/cluster/mean_shift.py +106 -135
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +105 -135
- snowflake/ml/modeling/cluster/optics.py +106 -135
- snowflake/ml/modeling/cluster/spectral_biclustering.py +106 -135
- snowflake/ml/modeling/cluster/spectral_clustering.py +106 -135
- snowflake/ml/modeling/cluster/spectral_coclustering.py +106 -135
- snowflake/ml/modeling/compose/column_transformer.py +106 -135
- snowflake/ml/modeling/compose/transformed_target_regressor.py +108 -135
- snowflake/ml/modeling/covariance/elliptic_envelope.py +106 -135
- snowflake/ml/modeling/covariance/empirical_covariance.py +99 -128
- snowflake/ml/modeling/covariance/graphical_lasso.py +106 -135
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +106 -135
- snowflake/ml/modeling/covariance/ledoit_wolf.py +104 -133
- snowflake/ml/modeling/covariance/min_cov_det.py +106 -135
- snowflake/ml/modeling/covariance/oas.py +99 -128
- snowflake/ml/modeling/covariance/shrunk_covariance.py +103 -132
- snowflake/ml/modeling/decomposition/dictionary_learning.py +106 -135
- snowflake/ml/modeling/decomposition/factor_analysis.py +106 -135
- snowflake/ml/modeling/decomposition/fast_ica.py +106 -135
- snowflake/ml/modeling/decomposition/incremental_pca.py +106 -135
- snowflake/ml/modeling/decomposition/kernel_pca.py +106 -135
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +106 -135
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +106 -135
- snowflake/ml/modeling/decomposition/pca.py +106 -135
- snowflake/ml/modeling/decomposition/sparse_pca.py +106 -135
- snowflake/ml/modeling/decomposition/truncated_svd.py +106 -135
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +108 -135
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +108 -135
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +108 -135
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +108 -135
- snowflake/ml/modeling/ensemble/bagging_classifier.py +108 -135
- snowflake/ml/modeling/ensemble/bagging_regressor.py +108 -135
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +108 -135
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +108 -135
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +108 -135
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +108 -135
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +108 -135
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +108 -135
- snowflake/ml/modeling/ensemble/isolation_forest.py +106 -135
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +108 -135
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +108 -135
- snowflake/ml/modeling/ensemble/stacking_regressor.py +108 -135
- snowflake/ml/modeling/ensemble/voting_classifier.py +108 -135
- snowflake/ml/modeling/ensemble/voting_regressor.py +108 -135
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +101 -128
- snowflake/ml/modeling/feature_selection/select_fdr.py +99 -126
- snowflake/ml/modeling/feature_selection/select_fpr.py +99 -126
- snowflake/ml/modeling/feature_selection/select_fwe.py +99 -126
- snowflake/ml/modeling/feature_selection/select_k_best.py +100 -127
- snowflake/ml/modeling/feature_selection/select_percentile.py +99 -126
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +106 -135
- snowflake/ml/modeling/feature_selection/variance_threshold.py +95 -124
- snowflake/ml/modeling/framework/base.py +83 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +108 -135
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +108 -135
- snowflake/ml/modeling/impute/iterative_imputer.py +106 -135
- snowflake/ml/modeling/impute/knn_imputer.py +106 -135
- snowflake/ml/modeling/impute/missing_indicator.py +106 -135
- snowflake/ml/modeling/impute/simple_imputer.py +9 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +96 -125
- snowflake/ml/modeling/kernel_approximation/nystroem.py +106 -135
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +106 -135
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +105 -134
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +103 -132
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +108 -135
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +90 -118
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +90 -118
- snowflake/ml/modeling/linear_model/ard_regression.py +108 -135
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +108 -135
- snowflake/ml/modeling/linear_model/elastic_net.py +108 -135
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +108 -135
- snowflake/ml/modeling/linear_model/gamma_regressor.py +108 -135
- snowflake/ml/modeling/linear_model/huber_regressor.py +108 -135
- snowflake/ml/modeling/linear_model/lars.py +108 -135
- snowflake/ml/modeling/linear_model/lars_cv.py +108 -135
- snowflake/ml/modeling/linear_model/lasso.py +108 -135
- snowflake/ml/modeling/linear_model/lasso_cv.py +108 -135
- snowflake/ml/modeling/linear_model/lasso_lars.py +108 -135
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +108 -135
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +108 -135
- snowflake/ml/modeling/linear_model/linear_regression.py +108 -135
- snowflake/ml/modeling/linear_model/logistic_regression.py +108 -135
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +108 -135
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +108 -135
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +108 -135
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +108 -135
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +108 -135
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +108 -135
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +108 -135
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +107 -135
- snowflake/ml/modeling/linear_model/perceptron.py +107 -135
- snowflake/ml/modeling/linear_model/poisson_regressor.py +108 -135
- snowflake/ml/modeling/linear_model/ransac_regressor.py +108 -135
- snowflake/ml/modeling/linear_model/ridge.py +108 -135
- snowflake/ml/modeling/linear_model/ridge_classifier.py +108 -135
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +108 -135
- snowflake/ml/modeling/linear_model/ridge_cv.py +108 -135
- snowflake/ml/modeling/linear_model/sgd_classifier.py +108 -135
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +106 -135
- snowflake/ml/modeling/linear_model/sgd_regressor.py +108 -135
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +108 -135
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +108 -135
- snowflake/ml/modeling/manifold/isomap.py +106 -135
- snowflake/ml/modeling/manifold/mds.py +106 -135
- snowflake/ml/modeling/manifold/spectral_embedding.py +106 -135
- snowflake/ml/modeling/manifold/tsne.py +106 -135
- snowflake/ml/modeling/metrics/classification.py +196 -55
- snowflake/ml/modeling/metrics/correlation.py +4 -2
- snowflake/ml/modeling/metrics/covariance.py +7 -4
- snowflake/ml/modeling/metrics/ranking.py +32 -16
- snowflake/ml/modeling/metrics/regression.py +60 -32
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +106 -135
- snowflake/ml/modeling/mixture/gaussian_mixture.py +106 -135
- snowflake/ml/modeling/model_selection/grid_search_cv.py +91 -148
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +93 -154
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +105 -132
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +108 -135
- snowflake/ml/modeling/multiclass/output_code_classifier.py +108 -135
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +108 -135
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +108 -135
- snowflake/ml/modeling/naive_bayes/complement_nb.py +108 -135
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +98 -125
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +107 -134
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +108 -135
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +108 -135
- snowflake/ml/modeling/neighbors/kernel_density.py +106 -135
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +106 -135
- snowflake/ml/modeling/neighbors/nearest_centroid.py +108 -135
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +106 -135
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +108 -135
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +108 -135
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +108 -135
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +106 -135
- snowflake/ml/modeling/neural_network/mlp_classifier.py +108 -135
- snowflake/ml/modeling/neural_network/mlp_regressor.py +108 -135
- snowflake/ml/modeling/parameters/disable_distributed_hpo.py +2 -6
- snowflake/ml/modeling/preprocessing/binarizer.py +25 -8
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +9 -4
- snowflake/ml/modeling/preprocessing/label_encoder.py +31 -11
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +27 -9
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +42 -14
- snowflake/ml/modeling/preprocessing/normalizer.py +9 -4
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +26 -10
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +37 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +106 -135
- snowflake/ml/modeling/preprocessing/robust_scaler.py +39 -13
- snowflake/ml/modeling/preprocessing/standard_scaler.py +36 -12
- snowflake/ml/modeling/semi_supervised/label_propagation.py +108 -135
- snowflake/ml/modeling/semi_supervised/label_spreading.py +108 -135
- snowflake/ml/modeling/svm/linear_svc.py +108 -135
- snowflake/ml/modeling/svm/linear_svr.py +108 -135
- snowflake/ml/modeling/svm/nu_svc.py +108 -135
- snowflake/ml/modeling/svm/nu_svr.py +108 -135
- snowflake/ml/modeling/svm/svc.py +108 -135
- snowflake/ml/modeling/svm/svr.py +108 -135
- snowflake/ml/modeling/tree/decision_tree_classifier.py +108 -135
- snowflake/ml/modeling/tree/decision_tree_regressor.py +108 -135
- snowflake/ml/modeling/tree/extra_tree_classifier.py +108 -135
- snowflake/ml/modeling/tree/extra_tree_regressor.py +108 -135
- snowflake/ml/modeling/xgboost/xgb_classifier.py +108 -136
- snowflake/ml/modeling/xgboost/xgb_regressor.py +108 -136
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +108 -136
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +108 -136
- snowflake/ml/registry/model_registry.py +2 -0
- snowflake/ml/registry/registry.py +215 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.1.0.dist-info → snowflake_ml_python-1.1.2.dist-info}/METADATA +34 -1
- snowflake_ml_python-1.1.2.dist-info/RECORD +347 -0
- snowflake_ml_python-1.1.0.dist-info/RECORD +0 -331
- {snowflake_ml_python-1.1.0.dist-info → snowflake_ml_python-1.1.2.dist-info}/WHEEL +0 -0
@@ -2,13 +2,13 @@
|
|
2
2
|
# This code is auto-generated using the sklearn_wrapper_template.py_template template.
|
3
3
|
# Do not modify the auto-generated code(except automatic reformatting by precommit hooks).
|
4
4
|
#
|
5
|
-
from typing import Dict, Iterable, List, Optional, Set, Union
|
5
|
+
from typing import Any, Dict, Iterable, List, Optional, Set, Union
|
6
6
|
from uuid import uuid4
|
7
7
|
|
8
|
+
import cloudpickle as cp
|
8
9
|
import numpy as np
|
9
10
|
import pandas as pd
|
10
11
|
import sklearn.model_selection
|
11
|
-
from sklearn.model_selection import ParameterGrid
|
12
12
|
from sklearn.utils.metaestimators import available_if
|
13
13
|
|
14
14
|
from snowflake.ml._internal import telemetry
|
@@ -25,13 +25,12 @@ from snowflake.ml.model.model_signature import (
|
|
25
25
|
from snowflake.ml.modeling._internal.estimator_protocols import CVHandlers
|
26
26
|
from snowflake.ml.modeling._internal.estimator_utils import (
|
27
27
|
gather_dependencies,
|
28
|
-
is_single_node,
|
29
28
|
original_estimator_has_callable,
|
30
29
|
transform_snowml_obj_to_sklearn_obj,
|
31
30
|
validate_sklearn_args,
|
32
31
|
)
|
32
|
+
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
33
33
|
from snowflake.ml.modeling._internal.snowpark_handlers import (
|
34
|
-
SklearnModelSelectionWrapperProvider,
|
35
34
|
SnowparkHandlers as HandlersImpl,
|
36
35
|
)
|
37
36
|
from snowflake.ml.modeling.framework.base import BaseTransformer
|
@@ -53,19 +52,54 @@ class GridSearchCV(BaseTransformer):
|
|
53
52
|
|
54
53
|
Parameters
|
55
54
|
----------
|
56
|
-
estimator
|
55
|
+
estimator: estimator object
|
57
56
|
This is assumed to implement the scikit-learn estimator interface.
|
58
57
|
Either estimator needs to provide a ``score`` function,
|
59
58
|
or ``scoring`` must be passed.
|
60
59
|
|
61
|
-
param_grid
|
60
|
+
param_grid: dict or list of dictionaries
|
62
61
|
Dictionary with parameters names (`str`) as keys and lists of
|
63
62
|
parameter settings to try as values, or a list of such
|
64
63
|
dictionaries, in which case the grids spanned by each dictionary
|
65
64
|
in the list are explored. This enables searching over any sequence
|
66
65
|
of parameter settings.
|
67
66
|
|
68
|
-
|
67
|
+
input_cols: Optional[Union[str, List[str]]]
|
68
|
+
A string or list of strings representing column names that contain features.
|
69
|
+
If this parameter is not specified, all columns in the input DataFrame except
|
70
|
+
the columns specified by label_cols and sample_weight_col parameters are
|
71
|
+
considered input columns.
|
72
|
+
|
73
|
+
label_cols: Optional[Union[str, List[str]]]
|
74
|
+
A string or list of strings representing column names that contain labels.
|
75
|
+
This is a required param for estimators, as there is no way to infer these
|
76
|
+
columns. If this parameter is not specified, then object is fitted without
|
77
|
+
labels (like a transformer).
|
78
|
+
|
79
|
+
output_cols: Optional[Union[str, List[str]]]
|
80
|
+
A string or list of strings representing column names that will store the
|
81
|
+
output of predict and transform operations. The length of output_cols must
|
82
|
+
match the expected number of output columns from the specific estimator or
|
83
|
+
transformer class used.
|
84
|
+
If this parameter is not specified, output column names are derived by
|
85
|
+
adding an OUTPUT_ prefix to the label column names. These inferred output
|
86
|
+
column names work for estimator's predict() method, but output_cols must
|
87
|
+
be set explicitly for transformers.
|
88
|
+
|
89
|
+
passthrough_cols: A string or a list of strings indicating column names to be excluded from any
|
90
|
+
operations (such as train, transform, or inference). These specified column(s)
|
91
|
+
will remain untouched throughout the process. This option is helpful in scenarios
|
92
|
+
requiring automatic input_cols inference, but need to avoid using specific
|
93
|
+
columns, like index columns, during training or inference.
|
94
|
+
|
95
|
+
sample_weight_col: Optional[str]
|
96
|
+
A string representing the column name containing the examples’ weights.
|
97
|
+
This argument is only required when working with weighted datasets.
|
98
|
+
|
99
|
+
drop_input_cols: Optional[bool], default=False
|
100
|
+
If set, the response of predict(), transform() methods will not contain input columns.
|
101
|
+
|
102
|
+
scoring: str, callable, list, tuple or dict, default=None
|
69
103
|
Strategy to evaluate the performance of the cross-validated model on
|
70
104
|
the test set.
|
71
105
|
|
@@ -83,13 +117,13 @@ class GridSearchCV(BaseTransformer):
|
|
83
117
|
|
84
118
|
See :ref:`multimetric_grid_search` for an example.
|
85
119
|
|
86
|
-
n_jobs
|
120
|
+
n_jobs: int, default=None
|
87
121
|
Number of jobs to run in parallel.
|
88
122
|
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
|
89
123
|
``-1`` means using all processors. See :term:`Glossary <n_jobs>`
|
90
124
|
for more details.
|
91
125
|
|
92
|
-
refit
|
126
|
+
refit: bool, str, or callable, default=True
|
93
127
|
Refit an estimator using the best found parameters on the whole
|
94
128
|
dataset.
|
95
129
|
|
@@ -120,7 +154,7 @@ class GridSearchCV(BaseTransformer):
|
|
120
154
|
to see how to design a custom selection strategy using a callable
|
121
155
|
via `refit`.
|
122
156
|
|
123
|
-
cv
|
157
|
+
cv: int, cross-validation generator or an iterable, default=None
|
124
158
|
Determines the cross-validation splitting strategy.
|
125
159
|
Possible inputs for cv are:
|
126
160
|
|
@@ -137,7 +171,7 @@ class GridSearchCV(BaseTransformer):
|
|
137
171
|
Refer :ref:`User Guide <cross_validation>` for the various
|
138
172
|
cross-validation strategies that can be used here.
|
139
173
|
|
140
|
-
verbose
|
174
|
+
verbose: int
|
141
175
|
Controls the verbosity: the higher, the more messages.
|
142
176
|
|
143
177
|
- >1 : the computation time for each fold and parameter candidate is
|
@@ -146,7 +180,7 @@ class GridSearchCV(BaseTransformer):
|
|
146
180
|
- >3 : the fold and candidate parameter indexes are also displayed
|
147
181
|
together with the starting time of the computation.
|
148
182
|
|
149
|
-
pre_dispatch
|
183
|
+
pre_dispatch: int, or str, default='2*n_jobs'
|
150
184
|
Controls the number of jobs that get dispatched during parallel
|
151
185
|
execution. Reducing this number can be useful to avoid an
|
152
186
|
explosion of memory consumption when more jobs get dispatched
|
@@ -163,13 +197,13 @@ class GridSearchCV(BaseTransformer):
|
|
163
197
|
- A str, giving an expression as a function of n_jobs,
|
164
198
|
as in '2*n_jobs'
|
165
199
|
|
166
|
-
error_score
|
200
|
+
error_score: 'raise' or numeric, default=np.nan
|
167
201
|
Value to assign to the score if an error occurs in estimator fitting.
|
168
202
|
If set to 'raise', the error is raised. If a numeric value is given,
|
169
203
|
FitFailedWarning is raised. This parameter does not affect the refit
|
170
204
|
step, which will always raise the error.
|
171
205
|
|
172
|
-
return_train_score
|
206
|
+
return_train_score: bool, default=False
|
173
207
|
If ``False``, the ``cv_results_`` attribute will not include training
|
174
208
|
scores.
|
175
209
|
Computing training scores is used to get insights on how different
|
@@ -177,35 +211,6 @@ class GridSearchCV(BaseTransformer):
|
|
177
211
|
However computing the scores on the training set can be computationally
|
178
212
|
expensive and is not strictly required to select the parameters that
|
179
213
|
yield the best generalization performance.
|
180
|
-
|
181
|
-
input_cols : Optional[Union[str, List[str]]]
|
182
|
-
A string or list of strings representing column names that contain features.
|
183
|
-
If this parameter is not specified, all columns in the input DataFrame except
|
184
|
-
the columns specified by label_cols and sample-weight_col parameters are
|
185
|
-
considered input columns.
|
186
|
-
|
187
|
-
label_cols : Optional[Union[str, List[str]]]
|
188
|
-
A string or list of strings representing column names that contain labels.
|
189
|
-
This is a required param for estimators, as there is no way to infer these
|
190
|
-
columns. If this parameter is not specified, then object is fitted without
|
191
|
-
labels(Like a transformer).
|
192
|
-
|
193
|
-
output_cols: Optional[Union[str, List[str]]]
|
194
|
-
A string or list of strings representing column names that will store the
|
195
|
-
output of predict and transform operations. The length of output_cols mus
|
196
|
-
match the expected number of output columns from the specific estimator or
|
197
|
-
transformer class used.
|
198
|
-
If this parameter is not specified, output column names are derived by
|
199
|
-
adding an OUTPUT_ prefix to the label column names. These inferred output
|
200
|
-
column names work for estimator's predict() method, but output_cols must
|
201
|
-
be set explicitly for transformers.
|
202
|
-
|
203
|
-
sample_weight_col: Optional[str]
|
204
|
-
A string representing the column name containing the examples’ weights.
|
205
|
-
This argument is only required when working with weighted datasets.
|
206
|
-
|
207
|
-
drop_input_cols: Optional[bool], default=False
|
208
|
-
If set, the response of predict(), transform() methods will not contain input columns.
|
209
214
|
"""
|
210
215
|
_ENABLE_DISTRIBUTED = True
|
211
216
|
|
@@ -225,11 +230,16 @@ class GridSearchCV(BaseTransformer):
|
|
225
230
|
input_cols: Optional[Union[str, Iterable[str]]] = None,
|
226
231
|
output_cols: Optional[Union[str, Iterable[str]]] = None,
|
227
232
|
label_cols: Optional[Union[str, Iterable[str]]] = None,
|
233
|
+
passthrough_cols: Optional[Union[str, Iterable[str]]] = None,
|
228
234
|
drop_input_cols: Optional[bool] = False,
|
229
235
|
sample_weight_col: Optional[str] = None,
|
230
236
|
) -> None:
|
231
237
|
super().__init__()
|
232
|
-
deps: Set[str] =
|
238
|
+
deps: Set[str] = {
|
239
|
+
f"numpy=={np.__version__}",
|
240
|
+
f"scikit-learn=={sklearn.__version__}",
|
241
|
+
f"cloudpickle=={cp.__version__}",
|
242
|
+
}
|
233
243
|
deps = deps | gather_dependencies(estimator)
|
234
244
|
self._deps = list(deps)
|
235
245
|
estimator = transform_snowml_obj_to_sklearn_obj(estimator)
|
@@ -246,7 +256,7 @@ class GridSearchCV(BaseTransformer):
|
|
246
256
|
"return_train_score": (return_train_score, False, False),
|
247
257
|
}
|
248
258
|
cleaned_up_init_args = validate_sklearn_args(args=init_args, klass=sklearn.model_selection.GridSearchCV)
|
249
|
-
self._sklearn_object = sklearn.model_selection.GridSearchCV(
|
259
|
+
self._sklearn_object: Any = sklearn.model_selection.GridSearchCV(
|
250
260
|
**cleaned_up_init_args,
|
251
261
|
)
|
252
262
|
self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
|
@@ -255,10 +265,10 @@ class GridSearchCV(BaseTransformer):
|
|
255
265
|
self.set_label_cols(label_cols)
|
256
266
|
self.set_drop_input_cols(drop_input_cols)
|
257
267
|
self.set_sample_weight_col(sample_weight_col)
|
268
|
+
self.set_passthrough_cols(passthrough_cols)
|
258
269
|
self._handlers: CVHandlers = HandlersImpl(
|
259
270
|
class_name=self.__class__.__name__,
|
260
271
|
subproject=_SUBPROJECT,
|
261
|
-
wrapper_provider=SklearnModelSelectionWrapperProvider(),
|
262
272
|
)
|
263
273
|
|
264
274
|
def _get_rand_id(self) -> str:
|
@@ -270,21 +280,6 @@ class GridSearchCV(BaseTransformer):
|
|
270
280
|
"""
|
271
281
|
return str(uuid4()).replace("-", "_").upper()
|
272
282
|
|
273
|
-
def _infer_input_output_cols(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
274
|
-
"""
|
275
|
-
Infer `self.input_cols` and `self.output_cols` if they are not explicitly set.
|
276
|
-
|
277
|
-
Args:
|
278
|
-
dataset: Input dataset.
|
279
|
-
"""
|
280
|
-
if not self.input_cols:
|
281
|
-
cols = [c for c in dataset.columns if c not in self.get_label_cols() and c != self.sample_weight_col]
|
282
|
-
self.set_input_cols(input_cols=cols)
|
283
|
-
|
284
|
-
if not self.output_cols:
|
285
|
-
cols = [identifier.concat_names(ids=["OUTPUT_", c]) for c in self.label_cols]
|
286
|
-
self.set_output_cols(output_cols=cols)
|
287
|
-
|
288
283
|
def _get_active_columns(self) -> List[str]:
|
289
284
|
""" "Get the list of columns that are relevant to the transformer."""
|
290
285
|
selected_cols = (
|
@@ -301,10 +296,6 @@ class GridSearchCV(BaseTransformer):
|
|
301
296
|
For more details on this function, see [sklearn.model_selection.GridSearchCV.fit]
|
302
297
|
(https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html#sklearn.model_selection.GridSearchCV.fit)
|
303
298
|
|
304
|
-
|
305
|
-
Raises:
|
306
|
-
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
307
|
-
|
308
299
|
Args:
|
309
300
|
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
310
301
|
Snowpark or Pandas DataFrame.
|
@@ -313,70 +304,37 @@ class GridSearchCV(BaseTransformer):
|
|
313
304
|
self
|
314
305
|
"""
|
315
306
|
self._infer_input_output_cols(dataset)
|
316
|
-
if
|
317
|
-
self.
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
"Supported dataset types: snowpark.DataFrame, pandas.DataFrame."
|
307
|
+
if self._sklearn_object.n_jobs is None:
|
308
|
+
self._sklearn_object.n_jobs = -1
|
309
|
+
if isinstance(dataset, DataFrame):
|
310
|
+
session = dataset._session
|
311
|
+
assert session is not None # keep mypy happy
|
312
|
+
# Validate that key package versions in the user workspace are supported in the Snowflake conda channel
|
313
|
+
# If the customer's package version is not in the conda channel, replace it with the closest available version
|
314
|
+
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
315
|
+
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT
|
326
316
|
)
|
327
|
-
self._is_fitted = True
|
328
|
-
self._get_model_signatures(dataset)
|
329
|
-
return self
|
330
317
|
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
336
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
337
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT
|
338
|
-
)
|
318
|
+
# Specify input columns so column pruning will be enforced
|
319
|
+
selected_cols = self._get_active_columns()
|
320
|
+
if len(selected_cols) > 0:
|
321
|
+
dataset = dataset.select(selected_cols)
|
339
322
|
|
340
|
-
|
341
|
-
if len(selected_cols) > 0:
|
342
|
-
dataset = dataset.select(selected_cols)
|
323
|
+
self._snowpark_cols = dataset.select(self.input_cols).columns
|
343
324
|
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
session=session,
|
358
|
-
estimator=self._sklearn_object,
|
359
|
-
dependencies=self._get_dependencies(),
|
360
|
-
udf_imports=["sklearn"],
|
361
|
-
input_cols=self.input_cols,
|
362
|
-
label_cols=self.label_cols,
|
363
|
-
sample_weight_col=self.sample_weight_col,
|
364
|
-
)
|
365
|
-
else:
|
366
|
-
# Fall back with stored procedure implementation
|
367
|
-
# set the parallel factor to default to minus one, to fully accelerate the cores in single node
|
368
|
-
if self._sklearn_object.n_jobs is None:
|
369
|
-
self._sklearn_object.n_jobs = -1
|
370
|
-
|
371
|
-
self._sklearn_object = self._handlers.fit_snowpark(
|
372
|
-
dataset,
|
373
|
-
session,
|
374
|
-
self._sklearn_object,
|
375
|
-
["snowflake-snowpark-python"] + self._get_dependencies(),
|
376
|
-
self.input_cols,
|
377
|
-
self.label_cols,
|
378
|
-
self.sample_weight_col,
|
379
|
-
)
|
325
|
+
model_trainer = ModelTrainerBuilder.build(
|
326
|
+
estimator=self._sklearn_object,
|
327
|
+
dataset=dataset,
|
328
|
+
input_cols=self.input_cols,
|
329
|
+
label_cols=self.label_cols,
|
330
|
+
sample_weight_col=self.sample_weight_col,
|
331
|
+
autogenerated=False,
|
332
|
+
subproject=_SUBPROJECT,
|
333
|
+
)
|
334
|
+
self._sklearn_object = model_trainer.train()
|
335
|
+
self._is_fitted = True
|
336
|
+
self._get_model_signatures(dataset)
|
337
|
+
return self
|
380
338
|
|
381
339
|
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
382
340
|
if self._drop_input_cols:
|
@@ -433,7 +391,7 @@ class GridSearchCV(BaseTransformer):
|
|
433
391
|
# input cols need to match unquoted / quoted
|
434
392
|
input_cols = self.input_cols
|
435
393
|
unquoted_input_cols = identifier.get_unescaped_names(self.input_cols)
|
436
|
-
quoted_input_cols = identifier.
|
394
|
+
quoted_input_cols = identifier.get_inferred_names(unquoted_input_cols)
|
437
395
|
|
438
396
|
estimator = self._sklearn_object
|
439
397
|
|
@@ -530,10 +488,6 @@ class GridSearchCV(BaseTransformer):
|
|
530
488
|
project=_PROJECT,
|
531
489
|
subproject=_SUBPROJECT,
|
532
490
|
)
|
533
|
-
@telemetry.add_stmt_params_to_df(
|
534
|
-
project=_PROJECT,
|
535
|
-
subproject=_SUBPROJECT,
|
536
|
-
)
|
537
491
|
def predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]:
|
538
492
|
"""Call predict on the estimator with the best found parameters
|
539
493
|
For more details on this function, see [sklearn.model_selection.GridSearchCV.predict]
|
@@ -576,10 +530,6 @@ class GridSearchCV(BaseTransformer):
|
|
576
530
|
project=_PROJECT,
|
577
531
|
subproject=_SUBPROJECT,
|
578
532
|
)
|
579
|
-
@telemetry.add_stmt_params_to_df(
|
580
|
-
project=_PROJECT,
|
581
|
-
subproject=_SUBPROJECT,
|
582
|
-
)
|
583
533
|
def transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]:
|
584
534
|
"""Call transform on the estimator with the best found parameters
|
585
535
|
For more details on this function, see [sklearn.model_selection.GridSearchCV.transform]
|
@@ -643,10 +593,6 @@ class GridSearchCV(BaseTransformer):
|
|
643
593
|
project=_PROJECT,
|
644
594
|
subproject=_SUBPROJECT,
|
645
595
|
)
|
646
|
-
@telemetry.add_stmt_params_to_df(
|
647
|
-
project=_PROJECT,
|
648
|
-
subproject=_SUBPROJECT,
|
649
|
-
)
|
650
596
|
def predict_proba(
|
651
597
|
self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "predict_proba_"
|
652
598
|
) -> Union[DataFrame, pd.DataFrame]:
|
@@ -684,10 +630,6 @@ class GridSearchCV(BaseTransformer):
|
|
684
630
|
project=_PROJECT,
|
685
631
|
subproject=_SUBPROJECT,
|
686
632
|
)
|
687
|
-
@telemetry.add_stmt_params_to_df(
|
688
|
-
project=_PROJECT,
|
689
|
-
subproject=_SUBPROJECT,
|
690
|
-
)
|
691
633
|
def predict_log_proba(
|
692
634
|
self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "predict_log_proba_"
|
693
635
|
) -> Union[DataFrame, pd.DataFrame]:
|
@@ -726,10 +668,6 @@ class GridSearchCV(BaseTransformer):
|
|
726
668
|
project=_PROJECT,
|
727
669
|
subproject=_SUBPROJECT,
|
728
670
|
)
|
729
|
-
@telemetry.add_stmt_params_to_df(
|
730
|
-
project=_PROJECT,
|
731
|
-
subproject=_SUBPROJECT,
|
732
|
-
)
|
733
671
|
def decision_function(
|
734
672
|
self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "decision_function_"
|
735
673
|
) -> Union[DataFrame, pd.DataFrame]:
|
@@ -766,6 +704,8 @@ class GridSearchCV(BaseTransformer):
|
|
766
704
|
@available_if(original_estimator_has_callable("score")) # type: ignore[misc]
|
767
705
|
def score(self, dataset: Union[DataFrame, pd.DataFrame]) -> float:
|
768
706
|
"""
|
707
|
+
If implemented by the original estimator, return the score for the dataset.
|
708
|
+
|
769
709
|
Args:
|
770
710
|
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
771
711
|
Snowpark or Pandas DataFrame.
|
@@ -818,9 +758,9 @@ class GridSearchCV(BaseTransformer):
|
|
818
758
|
# For classifier, the type of predict is the same as the type of label
|
819
759
|
if self._sklearn_object._estimator_type == "classifier":
|
820
760
|
# label columns is the desired type for output
|
821
|
-
outputs = _infer_signature(dataset[self.label_cols], "output")
|
761
|
+
outputs = list(_infer_signature(dataset[self.label_cols], "output"))
|
822
762
|
# rename the output columns
|
823
|
-
outputs = model_signature_utils.rename_features(outputs, self.output_cols)
|
763
|
+
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
824
764
|
self._model_signature_dict["predict"] = ModelSignature(
|
825
765
|
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
826
766
|
)
|
@@ -857,6 +797,9 @@ class GridSearchCV(BaseTransformer):
|
|
857
797
|
return self._model_signature_dict
|
858
798
|
|
859
799
|
def to_sklearn(self) -> sklearn.model_selection.GridSearchCV:
|
800
|
+
"""
|
801
|
+
Get sklearn.model_selection.GridSearchCV object.
|
802
|
+
"""
|
860
803
|
assert self._sklearn_object is not None
|
861
804
|
return self._sklearn_object
|
862
805
|
|