snowflake-ml-python 1.1.1__py3-none-any.whl → 1.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +1 -1
- snowflake/cortex/_extract_answer.py +1 -1
- snowflake/cortex/_sentiment.py +1 -1
- snowflake/cortex/_summarize.py +1 -1
- snowflake/cortex/_translate.py +1 -1
- snowflake/ml/_internal/env_utils.py +68 -6
- snowflake/ml/_internal/file_utils.py +34 -4
- snowflake/ml/_internal/telemetry.py +79 -91
- snowflake/ml/_internal/utils/retryable_http.py +16 -4
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +122 -0
- snowflake/ml/dataset/dataset.py +1 -1
- snowflake/ml/model/_api.py +21 -14
- snowflake/ml/model/_client/model/model_impl.py +176 -0
- snowflake/ml/model/_client/model/model_method_info.py +19 -0
- snowflake/ml/model/_client/model/model_version_impl.py +291 -0
- snowflake/ml/model/_client/ops/metadata_ops.py +107 -0
- snowflake/ml/model/_client/ops/model_ops.py +308 -0
- snowflake/ml/model/_client/sql/model.py +75 -0
- snowflake/ml/model/_client/sql/model_version.py +213 -0
- snowflake/ml/model/_client/sql/stage.py +40 -0
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -4
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +24 -8
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +23 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +14 -2
- snowflake/ml/model/_deploy_client/utils/constants.py +1 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +2 -2
- snowflake/ml/model/_model_composer/model_composer.py +31 -9
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +25 -10
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -2
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +34 -3
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +1 -1
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -1
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +10 -28
- snowflake/ml/model/_packager/model_meta/model_meta.py +18 -16
- snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
- snowflake/ml/model/model_signature.py +108 -53
- snowflake/ml/model/type_hints.py +1 -0
- snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +554 -0
- snowflake/ml/modeling/_internal/estimator_protocols.py +1 -60
- snowflake/ml/modeling/_internal/model_specifications.py +146 -0
- snowflake/ml/modeling/_internal/model_trainer.py +13 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +78 -0
- snowflake/ml/modeling/_internal/pandas_trainer.py +54 -0
- snowflake/ml/modeling/_internal/snowpark_handlers.py +6 -760
- snowflake/ml/modeling/_internal/snowpark_trainer.py +331 -0
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +96 -124
- snowflake/ml/modeling/cluster/affinity_propagation.py +94 -124
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +94 -124
- snowflake/ml/modeling/cluster/birch.py +94 -124
- snowflake/ml/modeling/cluster/bisecting_k_means.py +94 -124
- snowflake/ml/modeling/cluster/dbscan.py +94 -124
- snowflake/ml/modeling/cluster/feature_agglomeration.py +94 -124
- snowflake/ml/modeling/cluster/k_means.py +93 -124
- snowflake/ml/modeling/cluster/mean_shift.py +94 -124
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +93 -124
- snowflake/ml/modeling/cluster/optics.py +94 -124
- snowflake/ml/modeling/cluster/spectral_biclustering.py +94 -124
- snowflake/ml/modeling/cluster/spectral_clustering.py +94 -124
- snowflake/ml/modeling/cluster/spectral_coclustering.py +94 -124
- snowflake/ml/modeling/compose/column_transformer.py +94 -124
- snowflake/ml/modeling/compose/transformed_target_regressor.py +96 -124
- snowflake/ml/modeling/covariance/elliptic_envelope.py +94 -124
- snowflake/ml/modeling/covariance/empirical_covariance.py +80 -110
- snowflake/ml/modeling/covariance/graphical_lasso.py +94 -124
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +94 -124
- snowflake/ml/modeling/covariance/ledoit_wolf.py +85 -115
- snowflake/ml/modeling/covariance/min_cov_det.py +94 -124
- snowflake/ml/modeling/covariance/oas.py +80 -110
- snowflake/ml/modeling/covariance/shrunk_covariance.py +84 -114
- snowflake/ml/modeling/decomposition/dictionary_learning.py +94 -124
- snowflake/ml/modeling/decomposition/factor_analysis.py +94 -124
- snowflake/ml/modeling/decomposition/fast_ica.py +94 -124
- snowflake/ml/modeling/decomposition/incremental_pca.py +94 -124
- snowflake/ml/modeling/decomposition/kernel_pca.py +94 -124
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +94 -124
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +94 -124
- snowflake/ml/modeling/decomposition/pca.py +94 -124
- snowflake/ml/modeling/decomposition/sparse_pca.py +94 -124
- snowflake/ml/modeling/decomposition/truncated_svd.py +94 -124
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +96 -124
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +91 -119
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/bagging_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/bagging_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/isolation_forest.py +94 -124
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/stacking_regressor.py +96 -124
- snowflake/ml/modeling/ensemble/voting_classifier.py +96 -124
- snowflake/ml/modeling/ensemble/voting_regressor.py +91 -119
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +82 -110
- snowflake/ml/modeling/feature_selection/select_fdr.py +80 -108
- snowflake/ml/modeling/feature_selection/select_fpr.py +80 -108
- snowflake/ml/modeling/feature_selection/select_fwe.py +80 -108
- snowflake/ml/modeling/feature_selection/select_k_best.py +81 -109
- snowflake/ml/modeling/feature_selection/select_percentile.py +80 -108
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +94 -124
- snowflake/ml/modeling/feature_selection/variance_threshold.py +76 -106
- snowflake/ml/modeling/framework/base.py +2 -2
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +96 -124
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +96 -124
- snowflake/ml/modeling/impute/iterative_imputer.py +94 -124
- snowflake/ml/modeling/impute/knn_imputer.py +94 -124
- snowflake/ml/modeling/impute/missing_indicator.py +94 -124
- snowflake/ml/modeling/impute/simple_imputer.py +1 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +77 -107
- snowflake/ml/modeling/kernel_approximation/nystroem.py +94 -124
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +94 -124
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +86 -116
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +84 -114
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +96 -124
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +71 -100
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +71 -100
- snowflake/ml/modeling/linear_model/ard_regression.py +96 -124
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +96 -124
- snowflake/ml/modeling/linear_model/elastic_net.py +96 -124
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +96 -124
- snowflake/ml/modeling/linear_model/gamma_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/huber_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/lars.py +96 -124
- snowflake/ml/modeling/linear_model/lars_cv.py +96 -124
- snowflake/ml/modeling/linear_model/lasso.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_cv.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_lars.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +96 -124
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +96 -124
- snowflake/ml/modeling/linear_model/linear_regression.py +91 -119
- snowflake/ml/modeling/linear_model/logistic_regression.py +96 -124
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +96 -124
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +96 -124
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +96 -124
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +96 -124
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +95 -124
- snowflake/ml/modeling/linear_model/perceptron.py +95 -124
- snowflake/ml/modeling/linear_model/poisson_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/ransac_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/ridge.py +96 -124
- snowflake/ml/modeling/linear_model/ridge_classifier.py +96 -124
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +96 -124
- snowflake/ml/modeling/linear_model/ridge_cv.py +96 -124
- snowflake/ml/modeling/linear_model/sgd_classifier.py +96 -124
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +94 -124
- snowflake/ml/modeling/linear_model/sgd_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +96 -124
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +96 -124
- snowflake/ml/modeling/manifold/isomap.py +94 -124
- snowflake/ml/modeling/manifold/mds.py +94 -124
- snowflake/ml/modeling/manifold/spectral_embedding.py +94 -124
- snowflake/ml/modeling/manifold/tsne.py +94 -124
- snowflake/ml/modeling/metrics/classification.py +187 -52
- snowflake/ml/modeling/metrics/correlation.py +4 -2
- snowflake/ml/modeling/metrics/covariance.py +7 -4
- snowflake/ml/modeling/metrics/ranking.py +32 -16
- snowflake/ml/modeling/metrics/regression.py +60 -32
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +94 -124
- snowflake/ml/modeling/mixture/gaussian_mixture.py +94 -124
- snowflake/ml/modeling/model_selection/grid_search_cv.py +88 -138
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +90 -144
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +86 -114
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +93 -121
- snowflake/ml/modeling/multiclass/output_code_classifier.py +94 -122
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +92 -120
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +96 -124
- snowflake/ml/modeling/naive_bayes/complement_nb.py +92 -120
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -107
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +88 -116
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +96 -124
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +96 -124
- snowflake/ml/modeling/neighbors/kernel_density.py +94 -124
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +94 -124
- snowflake/ml/modeling/neighbors/nearest_centroid.py +89 -117
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +94 -124
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +96 -124
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +96 -124
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +96 -124
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +94 -124
- snowflake/ml/modeling/neural_network/mlp_classifier.py +96 -124
- snowflake/ml/modeling/neural_network/mlp_regressor.py +96 -124
- snowflake/ml/modeling/parameters/disable_distributed_hpo.py +2 -6
- snowflake/ml/modeling/preprocessing/binarizer.py +14 -9
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +0 -4
- snowflake/ml/modeling/preprocessing/label_encoder.py +21 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +20 -14
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +35 -19
- snowflake/ml/modeling/preprocessing/normalizer.py +6 -9
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +20 -13
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +25 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +94 -124
- snowflake/ml/modeling/preprocessing/robust_scaler.py +28 -14
- snowflake/ml/modeling/preprocessing/standard_scaler.py +25 -13
- snowflake/ml/modeling/semi_supervised/label_propagation.py +96 -124
- snowflake/ml/modeling/semi_supervised/label_spreading.py +96 -124
- snowflake/ml/modeling/svm/linear_svc.py +96 -124
- snowflake/ml/modeling/svm/linear_svr.py +96 -124
- snowflake/ml/modeling/svm/nu_svc.py +96 -124
- snowflake/ml/modeling/svm/nu_svr.py +96 -124
- snowflake/ml/modeling/svm/svc.py +96 -124
- snowflake/ml/modeling/svm/svr.py +96 -124
- snowflake/ml/modeling/tree/decision_tree_classifier.py +96 -124
- snowflake/ml/modeling/tree/decision_tree_regressor.py +96 -124
- snowflake/ml/modeling/tree/extra_tree_classifier.py +96 -124
- snowflake/ml/modeling/tree/extra_tree_regressor.py +96 -124
- snowflake/ml/modeling/xgboost/xgb_classifier.py +96 -125
- snowflake/ml/modeling/xgboost/xgb_regressor.py +96 -125
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +96 -125
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +96 -125
- snowflake/ml/registry/model_registry.py +2 -0
- snowflake/ml/registry/registry.py +215 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.1.1.dist-info → snowflake_ml_python-1.1.2.dist-info}/METADATA +21 -3
- snowflake_ml_python-1.1.2.dist-info/RECORD +347 -0
- snowflake_ml_python-1.1.1.dist-info/RECORD +0 -331
- {snowflake_ml_python-1.1.1.dist-info → snowflake_ml_python-1.1.2.dist-info}/WHEEL +0 -0
@@ -1,11 +1,11 @@
|
|
1
|
-
from typing import Dict, Iterable, List, Optional, Set, Union
|
1
|
+
from typing import Any, Dict, Iterable, List, Optional, Set, Union
|
2
2
|
from uuid import uuid4
|
3
3
|
|
4
|
+
import cloudpickle as cp
|
4
5
|
import numpy as np
|
5
6
|
import pandas as pd
|
6
7
|
import sklearn
|
7
8
|
import sklearn.model_selection
|
8
|
-
from sklearn.model_selection import ParameterSampler
|
9
9
|
from sklearn.utils.metaestimators import available_if
|
10
10
|
|
11
11
|
from snowflake.ml._internal import telemetry
|
@@ -22,13 +22,12 @@ from snowflake.ml.model.model_signature import (
|
|
22
22
|
from snowflake.ml.modeling._internal.estimator_protocols import CVHandlers
|
23
23
|
from snowflake.ml.modeling._internal.estimator_utils import (
|
24
24
|
gather_dependencies,
|
25
|
-
is_single_node,
|
26
25
|
original_estimator_has_callable,
|
27
26
|
transform_snowml_obj_to_sklearn_obj,
|
28
27
|
validate_sklearn_args,
|
29
28
|
)
|
29
|
+
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
30
30
|
from snowflake.ml.modeling._internal.snowpark_handlers import (
|
31
|
-
SklearnModelSelectionWrapperProvider,
|
32
31
|
SnowparkHandlers as HandlersImpl,
|
33
32
|
)
|
34
33
|
from snowflake.ml.modeling.framework.base import BaseTransformer
|
@@ -50,13 +49,13 @@ class RandomizedSearchCV(BaseTransformer):
|
|
50
49
|
|
51
50
|
Parameters
|
52
51
|
----------
|
53
|
-
estimator
|
52
|
+
estimator: estimator object
|
54
53
|
An object of that type is instantiated for each grid point.
|
55
54
|
This is assumed to implement the scikit-learn estimator interface.
|
56
55
|
Either estimator needs to provide a ``score`` function,
|
57
56
|
or ``scoring`` must be passed.
|
58
57
|
|
59
|
-
param_distributions
|
58
|
+
param_distributions: dict or list of dicts
|
60
59
|
Dictionary with parameters names (`str`) as keys and distributions
|
61
60
|
or lists of parameters to try. Distributions must provide a ``rvs``
|
62
61
|
method for sampling (such as those from scipy.stats.distributions).
|
@@ -64,11 +63,46 @@ class RandomizedSearchCV(BaseTransformer):
|
|
64
63
|
If a list of dicts is given, first a dict is sampled uniformly, and
|
65
64
|
then a parameter is sampled using that dict as above.
|
66
65
|
|
67
|
-
|
66
|
+
input_cols: Optional[Union[str, List[str]]]
|
67
|
+
A string or list of strings representing column names that contain features.
|
68
|
+
If this parameter is not specified, all columns in the input DataFrame except
|
69
|
+
the columns specified by label_cols and sample_weight_col parameters are
|
70
|
+
considered input columns.
|
71
|
+
|
72
|
+
label_cols: Optional[Union[str, List[str]]]
|
73
|
+
A string or list of strings representing column names that contain labels.
|
74
|
+
This is a required param for estimators, as there is no way to infer these
|
75
|
+
columns. If this parameter is not specified, then object is fitted without
|
76
|
+
labels (like a transformer).
|
77
|
+
|
78
|
+
output_cols: Optional[Union[str, List[str]]]
|
79
|
+
A string or list of strings representing column names that will store the
|
80
|
+
output of predict and transform operations. The length of output_cols must
|
81
|
+
match the expected number of output columns from the specific estimator or
|
82
|
+
transformer class used.
|
83
|
+
If this parameter is not specified, output column names are derived by
|
84
|
+
adding an OUTPUT_ prefix to the label column names. These inferred output
|
85
|
+
column names work for estimator's predict() method, but output_cols must
|
86
|
+
be set explicitly for transformers.
|
87
|
+
|
88
|
+
passthrough_cols: A string or a list of strings indicating column names to be excluded from any
|
89
|
+
operations (such as train, transform, or inference). These specified column(s)
|
90
|
+
will remain untouched throughout the process. This option is helpful in scenarios
|
91
|
+
requiring automatic input_cols inference, but need to avoid using specific
|
92
|
+
columns, like index columns, during training or inference.
|
93
|
+
|
94
|
+
sample_weight_col: Optional[str]
|
95
|
+
A string representing the column name containing the examples’ weights.
|
96
|
+
This argument is only required when working with weighted datasets.
|
97
|
+
|
98
|
+
drop_input_cols: Optional[bool], default=False
|
99
|
+
If set, the response of predict(), transform() methods will not contain input columns.
|
100
|
+
|
101
|
+
n_iter: int, default=10
|
68
102
|
Number of parameter settings that are sampled. n_iter trades
|
69
103
|
off runtime vs quality of the solution.
|
70
104
|
|
71
|
-
scoring
|
105
|
+
scoring: str, callable, list, tuple or dict, default=None
|
72
106
|
Strategy to evaluate the performance of the cross-validated model on
|
73
107
|
the test set.
|
74
108
|
|
@@ -88,13 +122,13 @@ class RandomizedSearchCV(BaseTransformer):
|
|
88
122
|
|
89
123
|
If None, the estimator's score method is used.
|
90
124
|
|
91
|
-
n_jobs
|
125
|
+
n_jobs: int, default=None
|
92
126
|
Number of jobs to run in parallel.
|
93
127
|
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
|
94
128
|
``-1`` means using all processors. See :term:`Glossary <n_jobs>`
|
95
129
|
for more details.
|
96
130
|
|
97
|
-
refit
|
131
|
+
refit: bool, str, or callable, default=True
|
98
132
|
Refit an estimator using the best found parameters on the whole
|
99
133
|
dataset.
|
100
134
|
|
@@ -121,7 +155,7 @@ class RandomizedSearchCV(BaseTransformer):
|
|
121
155
|
See ``scoring`` parameter to know more about multiple metric
|
122
156
|
evaluation.
|
123
157
|
|
124
|
-
cv
|
158
|
+
cv: int, cross-validation generator or an iterable, default=None
|
125
159
|
Determines the cross-validation splitting strategy.
|
126
160
|
Possible inputs for cv are:
|
127
161
|
|
@@ -138,7 +172,7 @@ class RandomizedSearchCV(BaseTransformer):
|
|
138
172
|
Refer :ref:`User Guide <cross_validation>` for the various
|
139
173
|
cross-validation strategies that can be used here.
|
140
174
|
|
141
|
-
verbose
|
175
|
+
verbose: int
|
142
176
|
Controls the verbosity: the higher, the more messages.
|
143
177
|
|
144
178
|
- >1 : the computation time for each fold and parameter candidate is
|
@@ -147,7 +181,7 @@ class RandomizedSearchCV(BaseTransformer):
|
|
147
181
|
- >3 : the fold and candidate parameter indexes are also displayed
|
148
182
|
together with the starting time of the computation.
|
149
183
|
|
150
|
-
pre_dispatch
|
184
|
+
pre_dispatch: int, or str, default='2*n_jobs'
|
151
185
|
Controls the number of jobs that get dispatched during parallel
|
152
186
|
execution. Reducing this number can be useful to avoid an
|
153
187
|
explosion of memory consumption when more jobs get dispatched
|
@@ -164,20 +198,20 @@ class RandomizedSearchCV(BaseTransformer):
|
|
164
198
|
- A str, giving an expression as a function of n_jobs,
|
165
199
|
as in '2*n_jobs'
|
166
200
|
|
167
|
-
random_state
|
201
|
+
random_state: int, RandomState instance or None, default=None
|
168
202
|
Pseudo random number generator state used for random uniform sampling
|
169
203
|
from lists of possible values instead of scipy.stats distributions.
|
170
204
|
Pass an int for reproducible output across multiple
|
171
205
|
function calls.
|
172
206
|
See :term:`Glossary <random_state>`.
|
173
207
|
|
174
|
-
error_score
|
208
|
+
error_score: 'raise' or numeric, default=np.nan
|
175
209
|
Value to assign to the score if an error occurs in estimator fitting.
|
176
210
|
If set to 'raise', the error is raised. If a numeric value is given,
|
177
211
|
FitFailedWarning is raised. This parameter does not affect the refit
|
178
212
|
step, which will always raise the error.
|
179
213
|
|
180
|
-
return_train_score
|
214
|
+
return_train_score: bool, default=False
|
181
215
|
If ``False``, the ``cv_results_`` attribute will not include training
|
182
216
|
scores.
|
183
217
|
Computing training scores is used to get insights on how different
|
@@ -185,41 +219,6 @@ class RandomizedSearchCV(BaseTransformer):
|
|
185
219
|
However computing the scores on the training set can be computationally
|
186
220
|
expensive and is not strictly required to select the parameters that
|
187
221
|
yield the best generalization performance.
|
188
|
-
|
189
|
-
input_cols : Optional[Union[str, List[str]]]
|
190
|
-
A string or list of strings representing column names that contain features.
|
191
|
-
If this parameter is not specified, all columns in the input DataFrame except
|
192
|
-
the columns specified by label_cols and sample_weight_col parameters are
|
193
|
-
considered input columns.
|
194
|
-
|
195
|
-
label_cols : Optional[Union[str, List[str]]]
|
196
|
-
A string or list of strings representing column names that contain labels.
|
197
|
-
This is a required param for estimators, as there is no way to infer these
|
198
|
-
columns. If this parameter is not specified, then object is fitted without
|
199
|
-
labels (like a transformer).
|
200
|
-
|
201
|
-
output_cols: Optional[Union[str, List[str]]]
|
202
|
-
A string or list of strings representing column names that will store the
|
203
|
-
output of predict and transform operations. The length of output_cols must
|
204
|
-
match the expected number of output columns from the specific estimator or
|
205
|
-
transformer class used.
|
206
|
-
If this parameter is not specified, output column names are derived by
|
207
|
-
adding an OUTPUT_ prefix to the label column names. These inferred output
|
208
|
-
column names work for estimator's predict() method, but output_cols must
|
209
|
-
be set explicitly for transformers.
|
210
|
-
|
211
|
-
passthrough_cols: A string or a list of strings indicating column names to be excluded from any
|
212
|
-
operations (such as train, transform, or inference). These specified column(s)
|
213
|
-
will remain untouched throughout the process. This option is helpful in scenarios
|
214
|
-
requiring automatic input_cols inference, but need to avoid using specific
|
215
|
-
columns, like index columns, during training or inference.
|
216
|
-
|
217
|
-
sample_weight_col: Optional[str]
|
218
|
-
A string representing the column name containing the examples’ weights.
|
219
|
-
This argument is only required when working with weighted datasets.
|
220
|
-
|
221
|
-
drop_input_cols: Optional[bool], default=False
|
222
|
-
If set, the response of predict(), transform() methods will not contain input columns.
|
223
222
|
"""
|
224
223
|
_ENABLE_DISTRIBUTED = True
|
225
224
|
|
@@ -246,7 +245,11 @@ class RandomizedSearchCV(BaseTransformer):
|
|
246
245
|
sample_weight_col: Optional[str] = None,
|
247
246
|
) -> None:
|
248
247
|
super().__init__()
|
249
|
-
deps: Set[str] =
|
248
|
+
deps: Set[str] = {
|
249
|
+
f"numpy=={np.__version__}",
|
250
|
+
f"scikit-learn=={sklearn.__version__}",
|
251
|
+
f"cloudpickle=={cp.__version__}",
|
252
|
+
}
|
250
253
|
deps = deps | gather_dependencies(estimator)
|
251
254
|
self._deps = list(deps)
|
252
255
|
estimator = transform_snowml_obj_to_sklearn_obj(estimator)
|
@@ -265,7 +268,7 @@ class RandomizedSearchCV(BaseTransformer):
|
|
265
268
|
"return_train_score": (return_train_score, False, False),
|
266
269
|
}
|
267
270
|
cleaned_up_init_args = validate_sklearn_args(args=init_args, klass=sklearn.model_selection.RandomizedSearchCV)
|
268
|
-
self._sklearn_object = sklearn.model_selection.RandomizedSearchCV(
|
271
|
+
self._sklearn_object: Any = sklearn.model_selection.RandomizedSearchCV(
|
269
272
|
**cleaned_up_init_args,
|
270
273
|
)
|
271
274
|
self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
|
@@ -278,7 +281,6 @@ class RandomizedSearchCV(BaseTransformer):
|
|
278
281
|
self._handlers: CVHandlers = HandlersImpl(
|
279
282
|
class_name=self.__class__.__name__,
|
280
283
|
subproject=_SUBPROJECT,
|
281
|
-
wrapper_provider=SklearnModelSelectionWrapperProvider(),
|
282
284
|
)
|
283
285
|
|
284
286
|
def _get_rand_id(self) -> str:
|
@@ -306,10 +308,6 @@ class RandomizedSearchCV(BaseTransformer):
|
|
306
308
|
For more details on this function, see [sklearn.model_selection.RandomizedSearchCV.fit]
|
307
309
|
(https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html#sklearn.model_selection.RandomizedSearchCV.fit)
|
308
310
|
|
309
|
-
|
310
|
-
Raises:
|
311
|
-
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
312
|
-
|
313
311
|
Args:
|
314
312
|
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
315
313
|
Snowpark or Pandas DataFrame.
|
@@ -318,74 +316,37 @@ class RandomizedSearchCV(BaseTransformer):
|
|
318
316
|
self
|
319
317
|
"""
|
320
318
|
self._infer_input_output_cols(dataset)
|
321
|
-
if
|
322
|
-
self.
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
"Supported dataset types: snowpark.DataFrame, pandas.DataFrame."
|
319
|
+
if hasattr(self._sklearn_object, "n_jobs") and self._sklearn_object.n_jobs is None:
|
320
|
+
self._sklearn_object.n_jobs = -1
|
321
|
+
if isinstance(dataset, DataFrame):
|
322
|
+
session = dataset._session
|
323
|
+
assert session is not None # keep mypy happy
|
324
|
+
# Validate that key package version in user workspace are supported in snowflake conda channel
|
325
|
+
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
326
|
+
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
327
|
+
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT
|
331
328
|
)
|
332
|
-
self._is_fitted = True
|
333
|
-
self._get_model_signatures(dataset)
|
334
|
-
return self
|
335
329
|
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
# If customer doesn't have package in conda channel, replace the ones have the closest versions
|
341
|
-
self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
342
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT
|
343
|
-
)
|
330
|
+
# Specify input columns so column pruning will be enforced
|
331
|
+
selected_cols = self._get_active_columns()
|
332
|
+
if len(selected_cols) > 0:
|
333
|
+
dataset = dataset.select(selected_cols)
|
344
334
|
|
345
|
-
|
346
|
-
if len(selected_cols) > 0:
|
347
|
-
dataset = dataset.select(selected_cols)
|
335
|
+
self._snowpark_cols = dataset.select(self.input_cols).columns
|
348
336
|
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
n_iter=self._sklearn_object.n_iter,
|
363
|
-
random_state=self._sklearn_object.random_state,
|
364
|
-
),
|
365
|
-
dataset=dataset,
|
366
|
-
session=session,
|
367
|
-
estimator=self._sklearn_object,
|
368
|
-
dependencies=self._get_dependencies(),
|
369
|
-
udf_imports=["sklearn"],
|
370
|
-
input_cols=self.input_cols,
|
371
|
-
label_cols=self.label_cols,
|
372
|
-
sample_weight_col=self.sample_weight_col,
|
373
|
-
)
|
374
|
-
else:
|
375
|
-
# Fall back with stored procedure implementation
|
376
|
-
# set the parallel factor to default to minus one, to fully accelerate the cores in single node
|
377
|
-
if self._sklearn_object.n_jobs is None:
|
378
|
-
self._sklearn_object.n_jobs = -1
|
379
|
-
|
380
|
-
self._sklearn_object = self._handlers.fit_snowpark(
|
381
|
-
dataset,
|
382
|
-
session,
|
383
|
-
self._sklearn_object,
|
384
|
-
["snowflake-snowpark-python"] + self._get_dependencies(),
|
385
|
-
self.input_cols,
|
386
|
-
self.label_cols,
|
387
|
-
self.sample_weight_col,
|
388
|
-
)
|
337
|
+
model_trainer = ModelTrainerBuilder.build(
|
338
|
+
estimator=self._sklearn_object,
|
339
|
+
dataset=dataset,
|
340
|
+
input_cols=self.input_cols,
|
341
|
+
label_cols=self.label_cols,
|
342
|
+
sample_weight_col=self.sample_weight_col,
|
343
|
+
autogenerated=False,
|
344
|
+
subproject=_SUBPROJECT,
|
345
|
+
)
|
346
|
+
self._sklearn_object = model_trainer.train()
|
347
|
+
self._is_fitted = True
|
348
|
+
self._get_model_signatures(dataset)
|
349
|
+
return self
|
389
350
|
|
390
351
|
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
391
352
|
if self._drop_input_cols:
|
@@ -539,10 +500,6 @@ class RandomizedSearchCV(BaseTransformer):
|
|
539
500
|
project=_PROJECT,
|
540
501
|
subproject=_SUBPROJECT,
|
541
502
|
)
|
542
|
-
@telemetry.add_stmt_params_to_df(
|
543
|
-
project=_PROJECT,
|
544
|
-
subproject=_SUBPROJECT,
|
545
|
-
)
|
546
503
|
def predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]:
|
547
504
|
"""Call predict on the estimator with the best found parameters
|
548
505
|
For more details on this function, see [sklearn.model_selection.RandomizedSearchCV.predict]
|
@@ -584,10 +541,6 @@ class RandomizedSearchCV(BaseTransformer):
|
|
584
541
|
project=_PROJECT,
|
585
542
|
subproject=_SUBPROJECT,
|
586
543
|
)
|
587
|
-
@telemetry.add_stmt_params_to_df(
|
588
|
-
project=_PROJECT,
|
589
|
-
subproject=_SUBPROJECT,
|
590
|
-
)
|
591
544
|
def transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]:
|
592
545
|
"""Call transform on the estimator with the best found parameters
|
593
546
|
For more details on this function, see [sklearn.model_selection.RandomizedSearchCV.transform]
|
@@ -651,10 +604,6 @@ class RandomizedSearchCV(BaseTransformer):
|
|
651
604
|
project=_PROJECT,
|
652
605
|
subproject=_SUBPROJECT,
|
653
606
|
)
|
654
|
-
@telemetry.add_stmt_params_to_df(
|
655
|
-
project=_PROJECT,
|
656
|
-
subproject=_SUBPROJECT,
|
657
|
-
)
|
658
607
|
def predict_proba(
|
659
608
|
self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "predict_proba_"
|
660
609
|
) -> Union[DataFrame, pd.DataFrame]:
|
@@ -692,10 +641,6 @@ class RandomizedSearchCV(BaseTransformer):
|
|
692
641
|
project=_PROJECT,
|
693
642
|
subproject=_SUBPROJECT,
|
694
643
|
)
|
695
|
-
@telemetry.add_stmt_params_to_df(
|
696
|
-
project=_PROJECT,
|
697
|
-
subproject=_SUBPROJECT,
|
698
|
-
)
|
699
644
|
def predict_log_proba(
|
700
645
|
self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "predict_log_proba_"
|
701
646
|
) -> Union[DataFrame, pd.DataFrame]:
|
@@ -734,10 +679,6 @@ class RandomizedSearchCV(BaseTransformer):
|
|
734
679
|
project=_PROJECT,
|
735
680
|
subproject=_SUBPROJECT,
|
736
681
|
)
|
737
|
-
@telemetry.add_stmt_params_to_df(
|
738
|
-
project=_PROJECT,
|
739
|
-
subproject=_SUBPROJECT,
|
740
|
-
)
|
741
682
|
def decision_function(
|
742
683
|
self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "decision_function_"
|
743
684
|
) -> Union[DataFrame, pd.DataFrame]:
|
@@ -774,6 +715,8 @@ class RandomizedSearchCV(BaseTransformer):
|
|
774
715
|
@available_if(original_estimator_has_callable("score")) # type: ignore[misc]
|
775
716
|
def score(self, dataset: Union[DataFrame, pd.DataFrame]) -> float:
|
776
717
|
"""
|
718
|
+
If implemented by the original estimator, return the score for the dataset.
|
719
|
+
|
777
720
|
Args:
|
778
721
|
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
779
722
|
Snowpark or Pandas DataFrame.
|
@@ -826,9 +769,9 @@ class RandomizedSearchCV(BaseTransformer):
|
|
826
769
|
# For classifier, the type of predict is the same as the type of label
|
827
770
|
if self._sklearn_object._estimator_type == "classifier":
|
828
771
|
# label columns is the desired type for output
|
829
|
-
outputs = _infer_signature(dataset[self.label_cols], "output")
|
772
|
+
outputs = list(_infer_signature(dataset[self.label_cols], "output"))
|
830
773
|
# rename the output columns
|
831
|
-
outputs = model_signature_utils.rename_features(outputs, self.output_cols)
|
774
|
+
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
832
775
|
self._model_signature_dict["predict"] = ModelSignature(
|
833
776
|
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
834
777
|
)
|
@@ -865,6 +808,9 @@ class RandomizedSearchCV(BaseTransformer):
|
|
865
808
|
return self._model_signature_dict
|
866
809
|
|
867
810
|
def to_sklearn(self) -> sklearn.model_selection.RandomizedSearchCV:
|
811
|
+
"""
|
812
|
+
Get sklearn.model_selection.RandomizedSearchCV object.
|
813
|
+
"""
|
868
814
|
assert self._sklearn_object is not None
|
869
815
|
return self._sklearn_object
|
870
816
|
|