snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -34,6 +34,15 @@ from snowflake.ml.modeling._internal.transformer_protocols import (
|
|
34
34
|
BatchInferenceKwargsTypedDict,
|
35
35
|
ScoreKwargsTypedDict
|
36
36
|
)
|
37
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
38
|
+
from snowflake.ml.model.model_signature import (
|
39
|
+
BaseFeatureSpec,
|
40
|
+
DataType,
|
41
|
+
FeatureSpec,
|
42
|
+
ModelSignature,
|
43
|
+
_infer_signature,
|
44
|
+
_rename_signature_with_snowflake_identifiers,
|
45
|
+
)
|
37
46
|
|
38
47
|
from snowflake.ml.modeling._internal.model_transformer_builder import ModelTransformerBuilder
|
39
48
|
|
@@ -44,16 +53,6 @@ from snowflake.ml.modeling._internal.estimator_utils import (
|
|
44
53
|
validate_sklearn_args,
|
45
54
|
)
|
46
55
|
|
47
|
-
from snowflake.ml.model.model_signature import (
|
48
|
-
DataType,
|
49
|
-
FeatureSpec,
|
50
|
-
ModelSignature,
|
51
|
-
_infer_signature,
|
52
|
-
_rename_signature_with_snowflake_identifiers,
|
53
|
-
BaseFeatureSpec,
|
54
|
-
)
|
55
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
56
|
-
|
57
56
|
_PROJECT = "ModelDevelopment"
|
58
57
|
# Derive subproject from module name by removing "sklearn"
|
59
58
|
# and converting module name from underscore to CamelCase
|
@@ -62,12 +61,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".repla
|
|
62
61
|
|
63
62
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
64
63
|
|
65
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
66
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
67
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
68
|
-
return check
|
69
|
-
|
70
|
-
|
71
64
|
class SelectKBest(BaseTransformer):
|
72
65
|
r"""Select features according to the k highest scores
|
73
66
|
For more details on this class, see [sklearn.feature_selection.SelectKBest]
|
@@ -206,12 +199,7 @@ class SelectKBest(BaseTransformer):
|
|
206
199
|
)
|
207
200
|
return selected_cols
|
208
201
|
|
209
|
-
|
210
|
-
project=_PROJECT,
|
211
|
-
subproject=_SUBPROJECT,
|
212
|
-
custom_tags=dict([("autogen", True)]),
|
213
|
-
)
|
214
|
-
def fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "SelectKBest":
|
202
|
+
def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> "SelectKBest":
|
215
203
|
"""Run score function on (X, y) and get the appropriate features
|
216
204
|
For more details on this function, see [sklearn.feature_selection.SelectKBest.fit]
|
217
205
|
(https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.SelectKBest.html#sklearn.feature_selection.SelectKBest.fit)
|
@@ -238,12 +226,14 @@ class SelectKBest(BaseTransformer):
|
|
238
226
|
|
239
227
|
self._snowpark_cols = dataset.select(self.input_cols).columns
|
240
228
|
|
241
|
-
|
229
|
+
# If we are already in a stored procedure, no need to kick off another one.
|
242
230
|
if SNOWML_SPROC_ENV in os.environ:
|
243
231
|
statement_params = telemetry.get_function_usage_statement_params(
|
244
232
|
project=_PROJECT,
|
245
233
|
subproject=_SUBPROJECT,
|
246
|
-
function_name=telemetry.get_statement_params_full_func_name(
|
234
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
235
|
+
inspect.currentframe(), SelectKBest.__class__.__name__
|
236
|
+
),
|
247
237
|
api_calls=[Session.call],
|
248
238
|
custom_tags=dict([("autogen", True)]) if self._autogenerated else None,
|
249
239
|
)
|
@@ -264,27 +254,24 @@ class SelectKBest(BaseTransformer):
|
|
264
254
|
)
|
265
255
|
self._sklearn_object = model_trainer.train()
|
266
256
|
self._is_fitted = True
|
267
|
-
self.
|
257
|
+
self._generate_model_signatures(dataset)
|
268
258
|
return self
|
269
259
|
|
270
260
|
def _batch_inference_validate_snowpark(
|
271
261
|
self,
|
272
262
|
dataset: DataFrame,
|
273
263
|
inference_method: str,
|
274
|
-
) ->
|
275
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
276
|
-
return the available package that exists in the snowflake anaconda channel
|
264
|
+
) -> None:
|
265
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
277
266
|
|
278
267
|
Args:
|
279
268
|
dataset: snowpark dataframe
|
280
269
|
inference_method: the inference method such as predict, score...
|
281
|
-
|
270
|
+
|
282
271
|
Raises:
|
283
272
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
284
273
|
SnowflakeMLException: If the session is None, raise error
|
285
274
|
|
286
|
-
Returns:
|
287
|
-
A list of available package that exists in the snowflake anaconda channel
|
288
275
|
"""
|
289
276
|
if not self._is_fitted:
|
290
277
|
raise exceptions.SnowflakeMLException(
|
@@ -302,9 +289,7 @@ class SelectKBest(BaseTransformer):
|
|
302
289
|
"Session must not specified for snowpark dataset."
|
303
290
|
),
|
304
291
|
)
|
305
|
-
|
306
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
307
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
292
|
+
|
308
293
|
|
309
294
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
310
295
|
@telemetry.send_api_usage_telemetry(
|
@@ -338,7 +323,9 @@ class SelectKBest(BaseTransformer):
|
|
338
323
|
# when it is classifier, infer the datatype from label columns
|
339
324
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
340
325
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
341
|
-
label_cols_signatures = [
|
326
|
+
label_cols_signatures = [
|
327
|
+
row for row in self.model_signatures['predict'].outputs if row.name in self.output_cols
|
328
|
+
]
|
342
329
|
if len(label_cols_signatures) == 0:
|
343
330
|
error_str = f"Output columns {self.output_cols} do not match model signatures {self.model_signatures['predict'].outputs}."
|
344
331
|
raise exceptions.SnowflakeMLException(
|
@@ -346,25 +333,23 @@ class SelectKBest(BaseTransformer):
|
|
346
333
|
original_exception=ValueError(error_str),
|
347
334
|
)
|
348
335
|
|
349
|
-
expected_type_inferred = convert_sp_to_sf_type(
|
350
|
-
label_cols_signatures[0].as_snowpark_type()
|
351
|
-
)
|
336
|
+
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
352
337
|
|
353
|
-
self.
|
354
|
-
|
338
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
339
|
+
self._deps = self._get_dependencies()
|
340
|
+
assert isinstance(
|
341
|
+
dataset._session, Session
|
342
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
355
343
|
|
356
344
|
transform_kwargs = dict(
|
357
|
-
session
|
358
|
-
dependencies
|
359
|
-
drop_input_cols
|
360
|
-
expected_output_cols_type
|
345
|
+
session=dataset._session,
|
346
|
+
dependencies=self._deps,
|
347
|
+
drop_input_cols=self._drop_input_cols,
|
348
|
+
expected_output_cols_type=expected_type_inferred,
|
361
349
|
)
|
362
350
|
|
363
351
|
elif isinstance(dataset, pd.DataFrame):
|
364
|
-
transform_kwargs = dict(
|
365
|
-
snowpark_input_cols = self._snowpark_cols,
|
366
|
-
drop_input_cols = self._drop_input_cols
|
367
|
-
)
|
352
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
368
353
|
|
369
354
|
transform_handlers = ModelTransformerBuilder.build(
|
370
355
|
dataset=dataset,
|
@@ -406,7 +391,7 @@ class SelectKBest(BaseTransformer):
|
|
406
391
|
Transformed dataset.
|
407
392
|
"""
|
408
393
|
super()._check_dataset_type(dataset)
|
409
|
-
inference_method="transform"
|
394
|
+
inference_method = "transform"
|
410
395
|
|
411
396
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
412
397
|
# are specific to the type of dataset used.
|
@@ -436,24 +421,19 @@ class SelectKBest(BaseTransformer):
|
|
436
421
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
437
422
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
438
423
|
|
439
|
-
self.
|
440
|
-
|
441
|
-
inference_method=inference_method,
|
442
|
-
)
|
424
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
425
|
+
self._deps = self._get_dependencies()
|
443
426
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
444
427
|
|
445
428
|
transform_kwargs = dict(
|
446
|
-
session
|
447
|
-
dependencies
|
448
|
-
drop_input_cols
|
449
|
-
expected_output_cols_type
|
429
|
+
session=dataset._session,
|
430
|
+
dependencies=self._deps,
|
431
|
+
drop_input_cols=self._drop_input_cols,
|
432
|
+
expected_output_cols_type=expected_dtype,
|
450
433
|
)
|
451
434
|
|
452
435
|
elif isinstance(dataset, pd.DataFrame):
|
453
|
-
transform_kwargs = dict(
|
454
|
-
snowpark_input_cols = self._snowpark_cols,
|
455
|
-
drop_input_cols = self._drop_input_cols
|
456
|
-
)
|
436
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
457
437
|
|
458
438
|
transform_handlers = ModelTransformerBuilder.build(
|
459
439
|
dataset=dataset,
|
@@ -472,7 +452,11 @@ class SelectKBest(BaseTransformer):
|
|
472
452
|
return output_df
|
473
453
|
|
474
454
|
@available_if(original_estimator_has_callable("fit_predict")) # type: ignore[misc]
|
475
|
-
def fit_predict(
|
455
|
+
def fit_predict(
|
456
|
+
self,
|
457
|
+
dataset: Union[DataFrame, pd.DataFrame],
|
458
|
+
output_cols_prefix: str = "fit_predict_",
|
459
|
+
) -> Union[DataFrame, pd.DataFrame]:
|
476
460
|
""" Method not supported for this class.
|
477
461
|
|
478
462
|
|
@@ -497,22 +481,106 @@ class SelectKBest(BaseTransformer):
|
|
497
481
|
)
|
498
482
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
499
483
|
drop_input_cols=self._drop_input_cols,
|
500
|
-
expected_output_cols_list=
|
484
|
+
expected_output_cols_list=(
|
485
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
486
|
+
),
|
501
487
|
)
|
502
488
|
self._sklearn_object = fitted_estimator
|
503
489
|
self._is_fitted = True
|
504
490
|
return output_result
|
505
491
|
|
492
|
+
|
493
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
494
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
495
|
+
""" Fit to data, then transform it
|
496
|
+
For more details on this function, see [sklearn.feature_selection.SelectKBest.fit_transform]
|
497
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.SelectKBest.html#sklearn.feature_selection.SelectKBest.fit_transform)
|
498
|
+
|
499
|
+
|
500
|
+
Raises:
|
501
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
506
502
|
|
507
|
-
|
508
|
-
|
509
|
-
|
503
|
+
Args:
|
504
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
505
|
+
Snowpark or Pandas DataFrame.
|
506
|
+
output_cols_prefix: Prefix for the response columns
|
510
507
|
Returns:
|
511
508
|
Transformed dataset.
|
512
509
|
"""
|
513
|
-
self.
|
514
|
-
|
515
|
-
|
510
|
+
self._infer_input_output_cols(dataset)
|
511
|
+
super()._check_dataset_type(dataset)
|
512
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
513
|
+
estimator=self._sklearn_object,
|
514
|
+
dataset=dataset,
|
515
|
+
input_cols=self.input_cols,
|
516
|
+
label_cols=self.label_cols,
|
517
|
+
sample_weight_col=self.sample_weight_col,
|
518
|
+
autogenerated=self._autogenerated,
|
519
|
+
subproject=_SUBPROJECT,
|
520
|
+
)
|
521
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
522
|
+
drop_input_cols=self._drop_input_cols,
|
523
|
+
expected_output_cols_list=self.output_cols,
|
524
|
+
)
|
525
|
+
self._sklearn_object = fitted_estimator
|
526
|
+
self._is_fitted = True
|
527
|
+
return output_result
|
528
|
+
|
529
|
+
|
530
|
+
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
531
|
+
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
532
|
+
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
533
|
+
"""
|
534
|
+
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
535
|
+
# The following condition is introduced for kneighbors methods, and not used in other methods
|
536
|
+
if output_cols:
|
537
|
+
output_cols = [
|
538
|
+
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
539
|
+
for c in output_cols
|
540
|
+
]
|
541
|
+
elif getattr(self._sklearn_object, "classes_", None) is None:
|
542
|
+
output_cols = [output_cols_prefix]
|
543
|
+
elif self._sklearn_object is not None:
|
544
|
+
classes = self._sklearn_object.classes_
|
545
|
+
if isinstance(classes, numpy.ndarray):
|
546
|
+
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
547
|
+
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
548
|
+
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
549
|
+
output_cols = []
|
550
|
+
for i, cl in enumerate(classes):
|
551
|
+
# For binary classification, there is only one output column for each class
|
552
|
+
# ndarray as the two classes are complementary.
|
553
|
+
if len(cl) == 2:
|
554
|
+
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
555
|
+
else:
|
556
|
+
output_cols.extend([
|
557
|
+
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
558
|
+
])
|
559
|
+
else:
|
560
|
+
output_cols = []
|
561
|
+
|
562
|
+
# Make sure column names are valid snowflake identifiers.
|
563
|
+
assert output_cols is not None # Make MyPy happy
|
564
|
+
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
565
|
+
|
566
|
+
return rv
|
567
|
+
|
568
|
+
def _align_expected_output_names(
|
569
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
570
|
+
) -> List[str]:
|
571
|
+
# in case the inferred output column names dimension is different
|
572
|
+
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
573
|
+
output_df_pd = getattr(self, method)(dataset.limit(1).to_pandas(), output_cols_prefix)
|
574
|
+
output_df_columns = list(output_df_pd.columns)
|
575
|
+
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
576
|
+
if self.sample_weight_col:
|
577
|
+
output_df_columns_set -= set(self.sample_weight_col)
|
578
|
+
# if the dimension of inferred output column names is correct; use it
|
579
|
+
if len(expected_output_cols_list) == len(output_df_columns_set):
|
580
|
+
return expected_output_cols_list
|
581
|
+
# otherwise, use the sklearn estimator's output
|
582
|
+
else:
|
583
|
+
return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
516
584
|
|
517
585
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
518
586
|
@telemetry.send_api_usage_telemetry(
|
@@ -544,24 +612,26 @@ class SelectKBest(BaseTransformer):
|
|
544
612
|
# are specific to the type of dataset used.
|
545
613
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
546
614
|
|
615
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
616
|
+
|
547
617
|
if isinstance(dataset, DataFrame):
|
548
|
-
self.
|
549
|
-
|
550
|
-
|
551
|
-
|
552
|
-
|
618
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
619
|
+
self._deps = self._get_dependencies()
|
620
|
+
assert isinstance(
|
621
|
+
dataset._session, Session
|
622
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
553
623
|
transform_kwargs = dict(
|
554
624
|
session=dataset._session,
|
555
625
|
dependencies=self._deps,
|
556
|
-
drop_input_cols
|
626
|
+
drop_input_cols=self._drop_input_cols,
|
557
627
|
expected_output_cols_type="float",
|
558
628
|
)
|
629
|
+
expected_output_cols = self._align_expected_output_names(
|
630
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
631
|
+
)
|
559
632
|
|
560
633
|
elif isinstance(dataset, pd.DataFrame):
|
561
|
-
transform_kwargs = dict(
|
562
|
-
snowpark_input_cols = self._snowpark_cols,
|
563
|
-
drop_input_cols = self._drop_input_cols
|
564
|
-
)
|
634
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
565
635
|
|
566
636
|
transform_handlers = ModelTransformerBuilder.build(
|
567
637
|
dataset=dataset,
|
@@ -573,7 +643,7 @@ class SelectKBest(BaseTransformer):
|
|
573
643
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
574
644
|
inference_method=inference_method,
|
575
645
|
input_cols=self.input_cols,
|
576
|
-
expected_output_cols=
|
646
|
+
expected_output_cols=expected_output_cols,
|
577
647
|
**transform_kwargs
|
578
648
|
)
|
579
649
|
return output_df
|
@@ -603,29 +673,30 @@ class SelectKBest(BaseTransformer):
|
|
603
673
|
Output dataset with log probability of the sample for each class in the model.
|
604
674
|
"""
|
605
675
|
super()._check_dataset_type(dataset)
|
606
|
-
inference_method="predict_log_proba"
|
676
|
+
inference_method = "predict_log_proba"
|
677
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
607
678
|
|
608
679
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
609
680
|
# are specific to the type of dataset used.
|
610
681
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
611
682
|
|
612
683
|
if isinstance(dataset, DataFrame):
|
613
|
-
self.
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
|
684
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
685
|
+
self._deps = self._get_dependencies()
|
686
|
+
assert isinstance(
|
687
|
+
dataset._session, Session
|
688
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
618
689
|
transform_kwargs = dict(
|
619
690
|
session=dataset._session,
|
620
691
|
dependencies=self._deps,
|
621
|
-
drop_input_cols
|
692
|
+
drop_input_cols=self._drop_input_cols,
|
622
693
|
expected_output_cols_type="float",
|
623
694
|
)
|
695
|
+
expected_output_cols = self._align_expected_output_names(
|
696
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
697
|
+
)
|
624
698
|
elif isinstance(dataset, pd.DataFrame):
|
625
|
-
transform_kwargs = dict(
|
626
|
-
snowpark_input_cols = self._snowpark_cols,
|
627
|
-
drop_input_cols = self._drop_input_cols
|
628
|
-
)
|
699
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
629
700
|
|
630
701
|
transform_handlers = ModelTransformerBuilder.build(
|
631
702
|
dataset=dataset,
|
@@ -638,7 +709,7 @@ class SelectKBest(BaseTransformer):
|
|
638
709
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
639
710
|
inference_method=inference_method,
|
640
711
|
input_cols=self.input_cols,
|
641
|
-
expected_output_cols=
|
712
|
+
expected_output_cols=expected_output_cols,
|
642
713
|
**transform_kwargs
|
643
714
|
)
|
644
715
|
return output_df
|
@@ -664,30 +735,32 @@ class SelectKBest(BaseTransformer):
|
|
664
735
|
Output dataset with results of the decision function for the samples in input dataset.
|
665
736
|
"""
|
666
737
|
super()._check_dataset_type(dataset)
|
667
|
-
inference_method="decision_function"
|
738
|
+
inference_method = "decision_function"
|
668
739
|
|
669
740
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
670
741
|
# are specific to the type of dataset used.
|
671
742
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
672
743
|
|
744
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
745
|
+
|
673
746
|
if isinstance(dataset, DataFrame):
|
674
|
-
self.
|
675
|
-
|
676
|
-
|
677
|
-
|
678
|
-
|
747
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
748
|
+
self._deps = self._get_dependencies()
|
749
|
+
assert isinstance(
|
750
|
+
dataset._session, Session
|
751
|
+
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
679
752
|
transform_kwargs = dict(
|
680
753
|
session=dataset._session,
|
681
754
|
dependencies=self._deps,
|
682
|
-
drop_input_cols
|
755
|
+
drop_input_cols=self._drop_input_cols,
|
683
756
|
expected_output_cols_type="float",
|
684
757
|
)
|
758
|
+
expected_output_cols = self._align_expected_output_names(
|
759
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
760
|
+
)
|
685
761
|
|
686
762
|
elif isinstance(dataset, pd.DataFrame):
|
687
|
-
transform_kwargs = dict(
|
688
|
-
snowpark_input_cols = self._snowpark_cols,
|
689
|
-
drop_input_cols = self._drop_input_cols
|
690
|
-
)
|
763
|
+
transform_kwargs = dict(snowpark_input_cols=self._snowpark_cols, drop_input_cols=self._drop_input_cols)
|
691
764
|
|
692
765
|
transform_handlers = ModelTransformerBuilder.build(
|
693
766
|
dataset=dataset,
|
@@ -700,7 +773,7 @@ class SelectKBest(BaseTransformer):
|
|
700
773
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
701
774
|
inference_method=inference_method,
|
702
775
|
input_cols=self.input_cols,
|
703
|
-
expected_output_cols=
|
776
|
+
expected_output_cols=expected_output_cols,
|
704
777
|
**transform_kwargs
|
705
778
|
)
|
706
779
|
return output_df
|
@@ -729,17 +802,17 @@ class SelectKBest(BaseTransformer):
|
|
729
802
|
Output dataset with probability of the sample for each class in the model.
|
730
803
|
"""
|
731
804
|
super()._check_dataset_type(dataset)
|
732
|
-
inference_method="score_samples"
|
805
|
+
inference_method = "score_samples"
|
733
806
|
|
734
807
|
# This dictionary contains optional kwargs for batch inference. These kwargs
|
735
808
|
# are specific to the type of dataset used.
|
736
809
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
737
810
|
|
811
|
+
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
812
|
+
|
738
813
|
if isinstance(dataset, DataFrame):
|
739
|
-
self.
|
740
|
-
|
741
|
-
inference_method=inference_method,
|
742
|
-
)
|
814
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
815
|
+
self._deps = self._get_dependencies()
|
743
816
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
744
817
|
transform_kwargs = dict(
|
745
818
|
session=dataset._session,
|
@@ -747,6 +820,9 @@ class SelectKBest(BaseTransformer):
|
|
747
820
|
drop_input_cols = self._drop_input_cols,
|
748
821
|
expected_output_cols_type="float",
|
749
822
|
)
|
823
|
+
expected_output_cols = self._align_expected_output_names(
|
824
|
+
inference_method, dataset, expected_output_cols, output_cols_prefix
|
825
|
+
)
|
750
826
|
|
751
827
|
elif isinstance(dataset, pd.DataFrame):
|
752
828
|
transform_kwargs = dict(
|
@@ -765,7 +841,7 @@ class SelectKBest(BaseTransformer):
|
|
765
841
|
output_df: DATAFRAME_TYPE = transform_handlers.batch_inference(
|
766
842
|
inference_method=inference_method,
|
767
843
|
input_cols=self.input_cols,
|
768
|
-
expected_output_cols=
|
844
|
+
expected_output_cols=expected_output_cols,
|
769
845
|
**transform_kwargs
|
770
846
|
)
|
771
847
|
return output_df
|
@@ -798,17 +874,15 @@ class SelectKBest(BaseTransformer):
|
|
798
874
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
799
875
|
|
800
876
|
if isinstance(dataset, DataFrame):
|
801
|
-
self.
|
802
|
-
|
803
|
-
inference_method="score",
|
804
|
-
)
|
877
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
878
|
+
self._deps = self._get_dependencies()
|
805
879
|
selected_cols = self._get_active_columns()
|
806
880
|
if len(selected_cols) > 0:
|
807
881
|
dataset = dataset.select(selected_cols)
|
808
882
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
809
883
|
transform_kwargs = dict(
|
810
884
|
session=dataset._session,
|
811
|
-
dependencies=
|
885
|
+
dependencies=self._deps,
|
812
886
|
score_sproc_imports=['sklearn'],
|
813
887
|
)
|
814
888
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -873,11 +947,8 @@ class SelectKBest(BaseTransformer):
|
|
873
947
|
|
874
948
|
if isinstance(dataset, DataFrame):
|
875
949
|
|
876
|
-
self.
|
877
|
-
|
878
|
-
inference_method=inference_method,
|
879
|
-
|
880
|
-
)
|
950
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
951
|
+
self._deps = self._get_dependencies()
|
881
952
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
882
953
|
transform_kwargs = dict(
|
883
954
|
session = dataset._session,
|
@@ -910,50 +981,84 @@ class SelectKBest(BaseTransformer):
|
|
910
981
|
)
|
911
982
|
return output_df
|
912
983
|
|
984
|
+
|
985
|
+
|
986
|
+
def to_sklearn(self) -> Any:
|
987
|
+
"""Get sklearn.feature_selection.SelectKBest object.
|
988
|
+
"""
|
989
|
+
if self._sklearn_object is None:
|
990
|
+
self._sklearn_object = self._create_sklearn_object()
|
991
|
+
return self._sklearn_object
|
992
|
+
|
993
|
+
def to_xgboost(self) -> Any:
|
994
|
+
raise exceptions.SnowflakeMLException(
|
995
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
996
|
+
original_exception=AttributeError(
|
997
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
998
|
+
"to_xgboost()",
|
999
|
+
"to_sklearn()"
|
1000
|
+
)
|
1001
|
+
),
|
1002
|
+
)
|
913
1003
|
|
914
|
-
def
|
1004
|
+
def to_lightgbm(self) -> Any:
|
1005
|
+
raise exceptions.SnowflakeMLException(
|
1006
|
+
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1007
|
+
original_exception=AttributeError(
|
1008
|
+
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1009
|
+
"to_lightgbm()",
|
1010
|
+
"to_sklearn()"
|
1011
|
+
)
|
1012
|
+
),
|
1013
|
+
)
|
1014
|
+
|
1015
|
+
def _get_dependencies(self) -> List[str]:
|
1016
|
+
return self._deps
|
1017
|
+
|
1018
|
+
|
1019
|
+
def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
|
915
1020
|
self._model_signature_dict = dict()
|
916
1021
|
|
917
1022
|
PROB_FUNCTIONS = ["predict_log_proba", "predict_proba", "decision_function"]
|
918
1023
|
|
919
|
-
inputs = list(_infer_signature(dataset[self.input_cols], "input"))
|
1024
|
+
inputs = list(_infer_signature(dataset[self.input_cols], "input", use_snowflake_identifiers=True))
|
920
1025
|
outputs: List[BaseFeatureSpec] = []
|
921
1026
|
if hasattr(self, "predict"):
|
922
1027
|
# keep mypy happy
|
923
|
-
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
1028
|
+
assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
|
924
1029
|
# For classifier, the type of predict is the same as the type of label
|
925
|
-
if self._sklearn_object._estimator_type ==
|
926
|
-
|
1030
|
+
if self._sklearn_object._estimator_type == "classifier":
|
1031
|
+
# label columns is the desired type for output
|
927
1032
|
outputs = list(_infer_signature(dataset[self.label_cols], "output", use_snowflake_identifiers=True))
|
928
1033
|
# rename the output columns
|
929
1034
|
outputs = list(model_signature_utils.rename_features(outputs, self.output_cols))
|
930
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
931
|
-
|
932
|
-
|
1035
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1036
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1037
|
+
)
|
933
1038
|
# For mixture models that use the density mixin, `predict` returns the argmax of the log prob.
|
934
1039
|
# For outlier models, returns -1 for outliers and 1 for inliers.
|
935
|
-
# Clusterer returns int64 cluster labels.
|
1040
|
+
# Clusterer returns int64 cluster labels.
|
936
1041
|
elif self._sklearn_object._estimator_type in ["DensityEstimator", "clusterer", "outlier_detector"]:
|
937
1042
|
outputs = [FeatureSpec(dtype=DataType.INT64, name=c) for c in self.output_cols]
|
938
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
939
|
-
|
940
|
-
|
941
|
-
|
1043
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1044
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1045
|
+
)
|
1046
|
+
|
942
1047
|
# For regressor, the type of predict is float64
|
943
|
-
elif self._sklearn_object._estimator_type ==
|
1048
|
+
elif self._sklearn_object._estimator_type == "regressor":
|
944
1049
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in self.output_cols]
|
945
|
-
self._model_signature_dict["predict"] = ModelSignature(
|
946
|
-
|
947
|
-
|
948
|
-
|
1050
|
+
self._model_signature_dict["predict"] = ModelSignature(
|
1051
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1052
|
+
)
|
1053
|
+
|
949
1054
|
for prob_func in PROB_FUNCTIONS:
|
950
1055
|
if hasattr(self, prob_func):
|
951
1056
|
output_cols_prefix: str = f"{prob_func}_"
|
952
1057
|
output_column_names = self._get_output_column_names(output_cols_prefix)
|
953
1058
|
outputs = [FeatureSpec(dtype=DataType.DOUBLE, name=c) for c in output_column_names]
|
954
|
-
self._model_signature_dict[prob_func] = ModelSignature(
|
955
|
-
|
956
|
-
|
1059
|
+
self._model_signature_dict[prob_func] = ModelSignature(
|
1060
|
+
inputs, ([] if self._drop_input_cols else inputs) + outputs
|
1061
|
+
)
|
957
1062
|
|
958
1063
|
# Output signature names may still need to be renamed, since they were not created with `_infer_signature`.
|
959
1064
|
items = list(self._model_signature_dict.items())
|
@@ -966,10 +1071,10 @@ class SelectKBest(BaseTransformer):
|
|
966
1071
|
"""Returns model signature of current class.
|
967
1072
|
|
968
1073
|
Raises:
|
969
|
-
|
1074
|
+
SnowflakeMLException: If estimator is not fitted, then model signature cannot be inferred
|
970
1075
|
|
971
1076
|
Returns:
|
972
|
-
Dict
|
1077
|
+
Dict with each method and its input output signature
|
973
1078
|
"""
|
974
1079
|
if self._model_signature_dict is None:
|
975
1080
|
raise exceptions.SnowflakeMLException(
|
@@ -977,35 +1082,3 @@ class SelectKBest(BaseTransformer):
|
|
977
1082
|
original_exception=RuntimeError("Estimator not fitted before accessing property model_signatures!"),
|
978
1083
|
)
|
979
1084
|
return self._model_signature_dict
|
980
|
-
|
981
|
-
def to_sklearn(self) -> Any:
|
982
|
-
"""Get sklearn.feature_selection.SelectKBest object.
|
983
|
-
"""
|
984
|
-
if self._sklearn_object is None:
|
985
|
-
self._sklearn_object = self._create_sklearn_object()
|
986
|
-
return self._sklearn_object
|
987
|
-
|
988
|
-
def to_xgboost(self) -> Any:
|
989
|
-
raise exceptions.SnowflakeMLException(
|
990
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
991
|
-
original_exception=AttributeError(
|
992
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
993
|
-
"to_xgboost()",
|
994
|
-
"to_sklearn()"
|
995
|
-
)
|
996
|
-
),
|
997
|
-
)
|
998
|
-
|
999
|
-
def to_lightgbm(self) -> Any:
|
1000
|
-
raise exceptions.SnowflakeMLException(
|
1001
|
-
error_code=error_codes.METHOD_NOT_ALLOWED,
|
1002
|
-
original_exception=AttributeError(
|
1003
|
-
modeling_error_messages.UNSUPPORTED_MODEL_CONVERSION.format(
|
1004
|
-
"to_lightgbm()",
|
1005
|
-
"to_sklearn()"
|
1006
|
-
)
|
1007
|
-
),
|
1008
|
-
)
|
1009
|
-
|
1010
|
-
def _get_dependencies(self) -> List[str]:
|
1011
|
-
return self._deps
|