snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
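The largest additions in this release are the new dataset modules (snowflake/ml/dataset/dataset.py, dataset_factory.py, dataset_reader.py, plus the lineage helpers) and the expanded fit_transform support in the autogenerated modeling classes, shown in the hunks below. As a rough orientation only, the following sketch shows how a Dataset-backed workflow might look; the helper names create_from_dataframe and load_dataset and the ds.read accessor are assumptions inferred from the new dataset_factory.py and dataset_reader.py file names, and are not confirmed by this diff.

```python
# Hypothetical sketch: create_from_dataframe, load_dataset, and ds.read.to_pandas
# are assumed names inferred from the new dataset_factory.py / dataset_reader.py
# modules listed above; they are not confirmed by this diff.
from snowflake.snowpark import Session
from snowflake.ml import dataset  # new package in 1.5.0 per the file list

connection_parameters: dict = {}  # fill in account, user, password, etc.
session = Session.builder.configs(connection_parameters).create()

source_df = session.table("MY_DB.MY_SCHEMA.TRAINING_ROWS")  # placeholder table

# Materialize a versioned dataset from a Snowpark DataFrame (assumed helper).
ds = dataset.create_from_dataframe(
    session,
    name="MY_DB.MY_SCHEMA.MY_DATASET",
    version="v1",
    input_dataframe=source_df,
)

# Later, load it back and read it into pandas (assumed reader accessor).
ds = dataset.load_dataset(session, name="MY_DB.MY_SCHEMA.MY_DATASET", version="v1")
pdf = ds.read.to_pandas()
```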
snowflake/ml/modeling/feature_selection/select_fwe.py

@@ -61,12 +61,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".repla
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-return check
-
-
 class SelectFwe(BaseTransformer):
 r"""Filter: Select the p-values corresponding to Family-wise error rate
 For more details on this class, see [sklearn.feature_selection.SelectFwe]

@@ -266,20 +260,17 @@ class SelectFwe(BaseTransformer):
 self,
 dataset: DataFrame,
 inference_method: str,
-) ->
-"""Util method to run validate that batch inference can be run on a snowpark dataframe
-return the available package that exists in the snowflake anaconda channel
+) -> None:
+"""Util method to run validate that batch inference can be run on a snowpark dataframe.
 
 Args:
 dataset: snowpark dataframe
 inference_method: the inference method such as predict, score...
-
+
 Raises:
 SnowflakeMLException: If the estimator is not fitted, raise error
 SnowflakeMLException: If the session is None, raise error
 
-Returns:
-A list of available package that exists in the snowflake anaconda channel
 """
 if not self._is_fitted:
 raise exceptions.SnowflakeMLException(

@@ -297,9 +288,7 @@ class SelectFwe(BaseTransformer):
 "Session must not specified for snowpark dataset."
 ),
 )
-
-return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
 @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
 @telemetry.send_api_usage_telemetry(

@@ -345,7 +334,8 @@ class SelectFwe(BaseTransformer):
 
 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-self.
+self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+self._deps = self._get_dependencies()
 assert isinstance(
 dataset._session, Session
 )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -430,10 +420,8 @@ class SelectFwe(BaseTransformer):
 if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-self.
-
-inference_method=inference_method,
-)
+self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+self._deps = self._get_dependencies()
 assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
 transform_kwargs = dict(

@@ -500,16 +488,42 @@ class SelectFwe(BaseTransformer):
 self._is_fitted = True
 return output_result
 
+
+@available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+""" Fit to data, then transform it
+For more details on this function, see [sklearn.feature_selection.SelectFwe.fit_transform]
+(https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.SelectFwe.html#sklearn.feature_selection.SelectFwe.fit_transform)
+
 
-
-
-
+Raises:
+TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+Args:
+dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+Snowpark or Pandas DataFrame.
+output_cols_prefix: Prefix for the response columns
 Returns:
 Transformed dataset.
 """
-self.
-
-
+self._infer_input_output_cols(dataset)
+super()._check_dataset_type(dataset)
+model_trainer = ModelTrainerBuilder.build_fit_transform(
+estimator=self._sklearn_object,
+dataset=dataset,
+input_cols=self.input_cols,
+label_cols=self.label_cols,
+sample_weight_col=self.sample_weight_col,
+autogenerated=self._autogenerated,
+subproject=_SUBPROJECT,
+)
+output_result, fitted_estimator = model_trainer.train_fit_transform(
+drop_input_cols=self._drop_input_cols,
+expected_output_cols_list=self.output_cols,
+)
+self._sklearn_object = fitted_estimator
+self._is_fitted = True
+return output_result
 
 
 def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:

@@ -600,10 +614,8 @@ class SelectFwe(BaseTransformer):
 expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
 if isinstance(dataset, DataFrame):
-self.
-
-inference_method=inference_method,
-)
+self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+self._deps = self._get_dependencies()
 assert isinstance(
 dataset._session, Session
 )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -668,10 +680,8 @@ class SelectFwe(BaseTransformer):
 transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
 if isinstance(dataset, DataFrame):
-self.
-
-inference_method=inference_method,
-)
+self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+self._deps = self._get_dependencies()
 assert isinstance(
 dataset._session, Session
 )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -733,10 +743,8 @@ class SelectFwe(BaseTransformer):
 expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
 if isinstance(dataset, DataFrame):
-self.
-
-inference_method=inference_method,
-)
+self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+self._deps = self._get_dependencies()
 assert isinstance(
 dataset._session, Session
 )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -802,10 +810,8 @@ class SelectFwe(BaseTransformer):
 expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
 if isinstance(dataset, DataFrame):
-self.
-
-inference_method=inference_method,
-)
+self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+self._deps = self._get_dependencies()
 assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 transform_kwargs = dict(
 session=dataset._session,

@@ -867,17 +873,15 @@ class SelectFwe(BaseTransformer):
 transform_kwargs: ScoreKwargsTypedDict = dict()
 
 if isinstance(dataset, DataFrame):
-self.
-
-inference_method="score",
-)
+self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+self._deps = self._get_dependencies()
 selected_cols = self._get_active_columns()
 if len(selected_cols) > 0:
 dataset = dataset.select(selected_cols)
 assert isinstance(dataset._session, Session)  # keep mypy happy
 transform_kwargs = dict(
 session=dataset._session,
-dependencies=
+dependencies=self._deps,
 score_sproc_imports=['sklearn'],
 )
 elif isinstance(dataset, pd.DataFrame):

@@ -942,11 +946,8 @@ class SelectFwe(BaseTransformer):
 
 if isinstance(dataset, DataFrame):
 
-self.
-
-inference_method=inference_method,
-
-)
+self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+self._deps = self._get_dependencies()
 assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 transform_kwargs = dict(
 session = dataset._session,
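The hunks above add a public fit_transform method to the autogenerated SelectFwe wrapper (fit and transform in one call, routed through ModelTrainerBuilder.build_fit_transform). A minimal, hypothetical usage sketch follows; the column names, data values, and alpha are illustrative, and the constructor keywords input_cols/label_cols are taken from the attributes referenced in the diff (self.input_cols, self.label_cols) rather than from the released API reference.

```python
# Sketch of calling the new fit_transform surface added in the hunks above.
# Column names, data values, and alpha=0.05 are illustrative placeholders.
import pandas as pd
from snowflake.ml.modeling.feature_selection import SelectFwe

train_df = pd.DataFrame(
    {
        "F1": [0.1, 0.4, 0.5, 0.9, 0.2, 0.8],
        "F2": [1.0, 1.1, 0.9, 1.2, 1.0, 1.3],
        "TARGET": [0, 1, 0, 1, 0, 1],
    }
)

selector = SelectFwe(
    alpha=0.05,               # sklearn SelectFwe parameter
    input_cols=["F1", "F2"],  # feature columns (self.input_cols in the diff)
    label_cols=["TARGET"],    # label column (self.label_cols in the diff)
)

# New in 1.5.0 per the added method above: fit and transform in a single call;
# output_cols_prefix matches the keyword shown in the added signature.
result = selector.fit_transform(train_df, output_cols_prefix="fit_transform_")
print(result.head())
```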
snowflake/ml/modeling/feature_selection/select_k_best.py

@@ -61,12 +61,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".repla
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-return check
-
-
 class SelectKBest(BaseTransformer):
 r"""Select features according to the k highest scores
 For more details on this class, see [sklearn.feature_selection.SelectKBest]

@@ -267,20 +261,17 @@ class SelectKBest(BaseTransformer):
 self,
 dataset: DataFrame,
 inference_method: str,
-) ->
-"""Util method to run validate that batch inference can be run on a snowpark dataframe
-return the available package that exists in the snowflake anaconda channel
+) -> None:
+"""Util method to run validate that batch inference can be run on a snowpark dataframe.
 
 Args:
 dataset: snowpark dataframe
 inference_method: the inference method such as predict, score...
-
+
 Raises:
 SnowflakeMLException: If the estimator is not fitted, raise error
 SnowflakeMLException: If the session is None, raise error
 
-Returns:
-A list of available package that exists in the snowflake anaconda channel
 """
 if not self._is_fitted:
 raise exceptions.SnowflakeMLException(

@@ -298,9 +289,7 @@ class SelectKBest(BaseTransformer):
 "Session must not specified for snowpark dataset."
 ),
 )
-
-return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
 @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
 @telemetry.send_api_usage_telemetry(

@@ -346,7 +335,8 @@ class SelectKBest(BaseTransformer):
 
 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-self.
+self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+self._deps = self._get_dependencies()
 assert isinstance(
 dataset._session, Session
 )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -431,10 +421,8 @@ class SelectKBest(BaseTransformer):
 if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-self.
-
-inference_method=inference_method,
-)
+self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+self._deps = self._get_dependencies()
 assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
 transform_kwargs = dict(

@@ -501,16 +489,42 @@ class SelectKBest(BaseTransformer):
 self._is_fitted = True
 return output_result
 
+
+@available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+""" Fit to data, then transform it
+For more details on this function, see [sklearn.feature_selection.SelectKBest.fit_transform]
+(https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.SelectKBest.html#sklearn.feature_selection.SelectKBest.fit_transform)
+
 
-
-
-
+Raises:
+TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+Args:
+dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+Snowpark or Pandas DataFrame.
+output_cols_prefix: Prefix for the response columns
 Returns:
 Transformed dataset.
 """
-self.
-
-
+self._infer_input_output_cols(dataset)
+super()._check_dataset_type(dataset)
+model_trainer = ModelTrainerBuilder.build_fit_transform(
+estimator=self._sklearn_object,
+dataset=dataset,
+input_cols=self.input_cols,
+label_cols=self.label_cols,
+sample_weight_col=self.sample_weight_col,
+autogenerated=self._autogenerated,
+subproject=_SUBPROJECT,
+)
+output_result, fitted_estimator = model_trainer.train_fit_transform(
+drop_input_cols=self._drop_input_cols,
+expected_output_cols_list=self.output_cols,
+)
+self._sklearn_object = fitted_estimator
+self._is_fitted = True
+return output_result
 
 
 def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:

@@ -601,10 +615,8 @@ class SelectKBest(BaseTransformer):
 expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
 if isinstance(dataset, DataFrame):
-self.
-
-inference_method=inference_method,
-)
+self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+self._deps = self._get_dependencies()
 assert isinstance(
 dataset._session, Session
 )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -669,10 +681,8 @@ class SelectKBest(BaseTransformer):
 transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
 if isinstance(dataset, DataFrame):
-self.
-
-inference_method=inference_method,
-)
+self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+self._deps = self._get_dependencies()
 assert isinstance(
 dataset._session, Session
 )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -734,10 +744,8 @@ class SelectKBest(BaseTransformer):
 expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
 if isinstance(dataset, DataFrame):
-self.
-
-inference_method=inference_method,
-)
+self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+self._deps = self._get_dependencies()
 assert isinstance(
 dataset._session, Session
 )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -803,10 +811,8 @@ class SelectKBest(BaseTransformer):
 expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
 if isinstance(dataset, DataFrame):
-self.
-
-inference_method=inference_method,
-)
+self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+self._deps = self._get_dependencies()
 assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 transform_kwargs = dict(
 session=dataset._session,

@@ -868,17 +874,15 @@ class SelectKBest(BaseTransformer):
 transform_kwargs: ScoreKwargsTypedDict = dict()
 
 if isinstance(dataset, DataFrame):
-self.
-
-inference_method="score",
-)
+self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+self._deps = self._get_dependencies()
 selected_cols = self._get_active_columns()
 if len(selected_cols) > 0:
 dataset = dataset.select(selected_cols)
 assert isinstance(dataset._session, Session)  # keep mypy happy
 transform_kwargs = dict(
 session=dataset._session,
-dependencies=
+dependencies=self._deps,
 score_sproc_imports=['sklearn'],
 )
 elif isinstance(dataset, pd.DataFrame):

@@ -943,11 +947,8 @@ class SelectKBest(BaseTransformer):
 
 if isinstance(dataset, DataFrame):
 
-self.
-
-inference_method=inference_method,
-
-)
+self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+self._deps = self._get_dependencies()
 assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 transform_kwargs = dict(
 session = dataset._session,
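The SelectKBest hunks repeat the same refactor seen in SelectFwe above: _batch_inference_validate_snowpark now returns None instead of the validated package list, and each caller resolves dependencies separately via self._get_dependencies() into self._deps, which is then passed as dependencies=self._deps. As a generic illustration of that split (a stand-in sketch, not the library's actual internals):

```python
# Generic illustration of the refactor shown in the hunks above:
# validation and dependency resolution become two explicit steps.
# The class and method bodies here are stand-ins, not snowflake-ml code.
from typing import List


class EstimatorSketch:
    def __init__(self, pinned_deps: List[str]) -> None:
        self._pinned_deps = pinned_deps
        self._is_fitted = True
        self._deps: List[str] = []

    # Per the removed lines, validation previously also returned the
    # validated package list; per the added lines, it now only validates.
    def _batch_inference_validate_snowpark(self, dataset: object, inference_method: str) -> None:
        if not self._is_fitted:
            raise RuntimeError(f"Estimator not fitted; cannot run {inference_method}.")

    # Dependencies are now resolved in a separate, explicit call.
    def _get_dependencies(self) -> List[str]:
        return list(self._pinned_deps)

    def predict(self, dataset: object) -> None:
        self._batch_inference_validate_snowpark(dataset=dataset, inference_method="predict")
        self._deps = self._get_dependencies()  # later passed as dependencies=self._deps


EstimatorSketch(["scikit-learn==1.3.0"]).predict(dataset=None)
```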
snowflake/ml/modeling/feature_selection/select_percentile.py

@@ -61,12 +61,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".repla
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-return check
-
-
 class SelectPercentile(BaseTransformer):
 r"""Select features according to a percentile of the highest scores
 For more details on this class, see [sklearn.feature_selection.SelectPercentile]

@@ -266,20 +260,17 @@ class SelectPercentile(BaseTransformer):
 self,
 dataset: DataFrame,
 inference_method: str,
-) ->
-"""Util method to run validate that batch inference can be run on a snowpark dataframe
-return the available package that exists in the snowflake anaconda channel
+) -> None:
+"""Util method to run validate that batch inference can be run on a snowpark dataframe.
 
 Args:
 dataset: snowpark dataframe
 inference_method: the inference method such as predict, score...
-
+
 Raises:
 SnowflakeMLException: If the estimator is not fitted, raise error
 SnowflakeMLException: If the session is None, raise error
 
-Returns:
-A list of available package that exists in the snowflake anaconda channel
 """
 if not self._is_fitted:
 raise exceptions.SnowflakeMLException(

@@ -297,9 +288,7 @@ class SelectPercentile(BaseTransformer):
 "Session must not specified for snowpark dataset."
 ),
 )
-
-return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
 @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
 @telemetry.send_api_usage_telemetry(

@@ -345,7 +334,8 @@ class SelectPercentile(BaseTransformer):
 
 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-self.
+self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+self._deps = self._get_dependencies()
 assert isinstance(
 dataset._session, Session
 )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -430,10 +420,8 @@ class SelectPercentile(BaseTransformer):
 if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-self.
-
-inference_method=inference_method,
-)
+self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+self._deps = self._get_dependencies()
 assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
 transform_kwargs = dict(

@@ -500,16 +488,42 @@ class SelectPercentile(BaseTransformer):
 self._is_fitted = True
 return output_result
 
+
+@available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+""" Fit to data, then transform it
+For more details on this function, see [sklearn.feature_selection.SelectPercentile.fit_transform]
+(https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.SelectPercentile.html#sklearn.feature_selection.SelectPercentile.fit_transform)
+
 
-
-
-
+Raises:
+TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+Args:
+dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+Snowpark or Pandas DataFrame.
+output_cols_prefix: Prefix for the response columns
 Returns:
 Transformed dataset.
 """
-self.
-
-
+self._infer_input_output_cols(dataset)
+super()._check_dataset_type(dataset)
+model_trainer = ModelTrainerBuilder.build_fit_transform(
+estimator=self._sklearn_object,
+dataset=dataset,
+input_cols=self.input_cols,
+label_cols=self.label_cols,
+sample_weight_col=self.sample_weight_col,
+autogenerated=self._autogenerated,
+subproject=_SUBPROJECT,
+)
+output_result, fitted_estimator = model_trainer.train_fit_transform(
+drop_input_cols=self._drop_input_cols,
+expected_output_cols_list=self.output_cols,
+)
+self._sklearn_object = fitted_estimator
+self._is_fitted = True
+return output_result
 
 
 def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:

@@ -600,10 +614,8 @@ class SelectPercentile(BaseTransformer):
 expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
 if isinstance(dataset, DataFrame):
-self.
-
-inference_method=inference_method,
-)
+self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+self._deps = self._get_dependencies()
 assert isinstance(
 dataset._session, Session
 )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -668,10 +680,8 @@ class SelectPercentile(BaseTransformer):
 transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
 if isinstance(dataset, DataFrame):
-self.
-
-inference_method=inference_method,
-)
+self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+self._deps = self._get_dependencies()
 assert isinstance(
 dataset._session, Session
 )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -733,10 +743,8 @@ class SelectPercentile(BaseTransformer):
 expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
 if isinstance(dataset, DataFrame):
-self.
-
-inference_method=inference_method,
-)
+self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+self._deps = self._get_dependencies()
 assert isinstance(
 dataset._session, Session
 )  # mypy does not recognize the check in _batch_inference_validate_snowpark()

@@ -802,10 +810,8 @@ class SelectPercentile(BaseTransformer):
 expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
 if isinstance(dataset, DataFrame):
-self.
-
-inference_method=inference_method,
-)
+self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+self._deps = self._get_dependencies()
 assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 transform_kwargs = dict(
 session=dataset._session,

@@ -867,17 +873,15 @@ class SelectPercentile(BaseTransformer):
 transform_kwargs: ScoreKwargsTypedDict = dict()
 
 if isinstance(dataset, DataFrame):
-self.
-
-inference_method="score",
-)
+self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+self._deps = self._get_dependencies()
 selected_cols = self._get_active_columns()
 if len(selected_cols) > 0:
 dataset = dataset.select(selected_cols)
 assert isinstance(dataset._session, Session)  # keep mypy happy
 transform_kwargs = dict(
 session=dataset._session,
-dependencies=
+dependencies=self._deps,
 score_sproc_imports=['sklearn'],
 )
 elif isinstance(dataset, pd.DataFrame):

@@ -942,11 +946,8 @@ class SelectPercentile(BaseTransformer):
 
 if isinstance(dataset, DataFrame):
 
-self.
-
-inference_method=inference_method,
-
-)
+self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+self._deps = self._get_dependencies()
 assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
 transform_kwargs = dict(
 session = dataset._session,