snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -61,12 +61,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".repla
|
|
61
61
|
|
62
62
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
63
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
64
|
class GenericUnivariateSelect(BaseTransformer):
|
71
65
|
r"""Univariate feature selector with configurable strategy
|
72
66
|
For more details on this class, see [sklearn.feature_selection.GenericUnivariateSelect]
|
@@ -270,20 +264,17 @@ class GenericUnivariateSelect(BaseTransformer):
|
|
270
264
|
self,
|
271
265
|
dataset: DataFrame,
|
272
266
|
inference_method: str,
|
273
|
-
) ->
|
274
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
275
|
-
return the available package that exists in the snowflake anaconda channel
|
267
|
+
) -> None:
|
268
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
276
269
|
|
277
270
|
Args:
|
278
271
|
dataset: snowpark dataframe
|
279
272
|
inference_method: the inference method such as predict, score...
|
280
|
-
|
273
|
+
|
281
274
|
Raises:
|
282
275
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
283
276
|
SnowflakeMLException: If the session is None, raise error
|
284
277
|
|
285
|
-
Returns:
|
286
|
-
A list of available package that exists in the snowflake anaconda channel
|
287
278
|
"""
|
288
279
|
if not self._is_fitted:
|
289
280
|
raise exceptions.SnowflakeMLException(
|
@@ -301,9 +292,7 @@ class GenericUnivariateSelect(BaseTransformer):
|
|
301
292
|
"Session must not specified for snowpark dataset."
|
302
293
|
),
|
303
294
|
)
|
304
|
-
|
305
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
306
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
295
|
+
|
307
296
|
|
308
297
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
309
298
|
@telemetry.send_api_usage_telemetry(
|
@@ -349,7 +338,8 @@ class GenericUnivariateSelect(BaseTransformer):
|
|
349
338
|
|
350
339
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
351
340
|
|
352
|
-
self.
|
341
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
342
|
+
self._deps = self._get_dependencies()
|
353
343
|
assert isinstance(
|
354
344
|
dataset._session, Session
|
355
345
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -434,10 +424,8 @@ class GenericUnivariateSelect(BaseTransformer):
|
|
434
424
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
435
425
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
436
426
|
|
437
|
-
self.
|
438
|
-
|
439
|
-
inference_method=inference_method,
|
440
|
-
)
|
427
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
428
|
+
self._deps = self._get_dependencies()
|
441
429
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
442
430
|
|
443
431
|
transform_kwargs = dict(
|
@@ -504,16 +492,42 @@ class GenericUnivariateSelect(BaseTransformer):
|
|
504
492
|
self._is_fitted = True
|
505
493
|
return output_result
|
506
494
|
|
495
|
+
|
496
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
497
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
498
|
+
""" Fit to data, then transform it
|
499
|
+
For more details on this function, see [sklearn.feature_selection.GenericUnivariateSelect.fit_transform]
|
500
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.GenericUnivariateSelect.html#sklearn.feature_selection.GenericUnivariateSelect.fit_transform)
|
501
|
+
|
507
502
|
|
508
|
-
|
509
|
-
|
510
|
-
|
503
|
+
Raises:
|
504
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
505
|
+
|
506
|
+
Args:
|
507
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
508
|
+
Snowpark or Pandas DataFrame.
|
509
|
+
output_cols_prefix: Prefix for the response columns
|
511
510
|
Returns:
|
512
511
|
Transformed dataset.
|
513
512
|
"""
|
514
|
-
self.
|
515
|
-
|
516
|
-
|
513
|
+
self._infer_input_output_cols(dataset)
|
514
|
+
super()._check_dataset_type(dataset)
|
515
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
516
|
+
estimator=self._sklearn_object,
|
517
|
+
dataset=dataset,
|
518
|
+
input_cols=self.input_cols,
|
519
|
+
label_cols=self.label_cols,
|
520
|
+
sample_weight_col=self.sample_weight_col,
|
521
|
+
autogenerated=self._autogenerated,
|
522
|
+
subproject=_SUBPROJECT,
|
523
|
+
)
|
524
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
525
|
+
drop_input_cols=self._drop_input_cols,
|
526
|
+
expected_output_cols_list=self.output_cols,
|
527
|
+
)
|
528
|
+
self._sklearn_object = fitted_estimator
|
529
|
+
self._is_fitted = True
|
530
|
+
return output_result
|
517
531
|
|
518
532
|
|
519
533
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -604,10 +618,8 @@ class GenericUnivariateSelect(BaseTransformer):
|
|
604
618
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
605
619
|
|
606
620
|
if isinstance(dataset, DataFrame):
|
607
|
-
self.
|
608
|
-
|
609
|
-
inference_method=inference_method,
|
610
|
-
)
|
621
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
622
|
+
self._deps = self._get_dependencies()
|
611
623
|
assert isinstance(
|
612
624
|
dataset._session, Session
|
613
625
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -672,10 +684,8 @@ class GenericUnivariateSelect(BaseTransformer):
|
|
672
684
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
673
685
|
|
674
686
|
if isinstance(dataset, DataFrame):
|
675
|
-
self.
|
676
|
-
|
677
|
-
inference_method=inference_method,
|
678
|
-
)
|
687
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
688
|
+
self._deps = self._get_dependencies()
|
679
689
|
assert isinstance(
|
680
690
|
dataset._session, Session
|
681
691
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -737,10 +747,8 @@ class GenericUnivariateSelect(BaseTransformer):
|
|
737
747
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
738
748
|
|
739
749
|
if isinstance(dataset, DataFrame):
|
740
|
-
self.
|
741
|
-
|
742
|
-
inference_method=inference_method,
|
743
|
-
)
|
750
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
751
|
+
self._deps = self._get_dependencies()
|
744
752
|
assert isinstance(
|
745
753
|
dataset._session, Session
|
746
754
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -806,10 +814,8 @@ class GenericUnivariateSelect(BaseTransformer):
|
|
806
814
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
807
815
|
|
808
816
|
if isinstance(dataset, DataFrame):
|
809
|
-
self.
|
810
|
-
|
811
|
-
inference_method=inference_method,
|
812
|
-
)
|
817
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
818
|
+
self._deps = self._get_dependencies()
|
813
819
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
814
820
|
transform_kwargs = dict(
|
815
821
|
session=dataset._session,
|
@@ -871,17 +877,15 @@ class GenericUnivariateSelect(BaseTransformer):
|
|
871
877
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
872
878
|
|
873
879
|
if isinstance(dataset, DataFrame):
|
874
|
-
self.
|
875
|
-
|
876
|
-
inference_method="score",
|
877
|
-
)
|
880
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
881
|
+
self._deps = self._get_dependencies()
|
878
882
|
selected_cols = self._get_active_columns()
|
879
883
|
if len(selected_cols) > 0:
|
880
884
|
dataset = dataset.select(selected_cols)
|
881
885
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
882
886
|
transform_kwargs = dict(
|
883
887
|
session=dataset._session,
|
884
|
-
dependencies=
|
888
|
+
dependencies=self._deps,
|
885
889
|
score_sproc_imports=['sklearn'],
|
886
890
|
)
|
887
891
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -946,11 +950,8 @@ class GenericUnivariateSelect(BaseTransformer):
|
|
946
950
|
|
947
951
|
if isinstance(dataset, DataFrame):
|
948
952
|
|
949
|
-
self.
|
950
|
-
|
951
|
-
inference_method=inference_method,
|
952
|
-
|
953
|
-
)
|
953
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
954
|
+
self._deps = self._get_dependencies()
|
954
955
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
955
956
|
transform_kwargs = dict(
|
956
957
|
session = dataset._session,
|
@@ -61,12 +61,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".repla
|
|
61
61
|
|
62
62
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
63
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
64
|
class SelectFdr(BaseTransformer):
|
71
65
|
r"""Filter: Select the p-values for an estimated false discovery rate
|
72
66
|
For more details on this class, see [sklearn.feature_selection.SelectFdr]
|
@@ -266,20 +260,17 @@ class SelectFdr(BaseTransformer):
|
|
266
260
|
self,
|
267
261
|
dataset: DataFrame,
|
268
262
|
inference_method: str,
|
269
|
-
) ->
|
270
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
271
|
-
return the available package that exists in the snowflake anaconda channel
|
263
|
+
) -> None:
|
264
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
272
265
|
|
273
266
|
Args:
|
274
267
|
dataset: snowpark dataframe
|
275
268
|
inference_method: the inference method such as predict, score...
|
276
|
-
|
269
|
+
|
277
270
|
Raises:
|
278
271
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
279
272
|
SnowflakeMLException: If the session is None, raise error
|
280
273
|
|
281
|
-
Returns:
|
282
|
-
A list of available package that exists in the snowflake anaconda channel
|
283
274
|
"""
|
284
275
|
if not self._is_fitted:
|
285
276
|
raise exceptions.SnowflakeMLException(
|
@@ -297,9 +288,7 @@ class SelectFdr(BaseTransformer):
|
|
297
288
|
"Session must not specified for snowpark dataset."
|
298
289
|
),
|
299
290
|
)
|
300
|
-
|
301
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
302
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
291
|
+
|
303
292
|
|
304
293
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
305
294
|
@telemetry.send_api_usage_telemetry(
|
@@ -345,7 +334,8 @@ class SelectFdr(BaseTransformer):
|
|
345
334
|
|
346
335
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
347
336
|
|
348
|
-
self.
|
337
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
338
|
+
self._deps = self._get_dependencies()
|
349
339
|
assert isinstance(
|
350
340
|
dataset._session, Session
|
351
341
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -430,10 +420,8 @@ class SelectFdr(BaseTransformer):
|
|
430
420
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
431
421
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
432
422
|
|
433
|
-
self.
|
434
|
-
|
435
|
-
inference_method=inference_method,
|
436
|
-
)
|
423
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
424
|
+
self._deps = self._get_dependencies()
|
437
425
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
438
426
|
|
439
427
|
transform_kwargs = dict(
|
@@ -500,16 +488,42 @@ class SelectFdr(BaseTransformer):
|
|
500
488
|
self._is_fitted = True
|
501
489
|
return output_result
|
502
490
|
|
491
|
+
|
492
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
493
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
494
|
+
""" Fit to data, then transform it
|
495
|
+
For more details on this function, see [sklearn.feature_selection.SelectFdr.fit_transform]
|
496
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.SelectFdr.html#sklearn.feature_selection.SelectFdr.fit_transform)
|
497
|
+
|
503
498
|
|
504
|
-
|
505
|
-
|
506
|
-
|
499
|
+
Raises:
|
500
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
501
|
+
|
502
|
+
Args:
|
503
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
504
|
+
Snowpark or Pandas DataFrame.
|
505
|
+
output_cols_prefix: Prefix for the response columns
|
507
506
|
Returns:
|
508
507
|
Transformed dataset.
|
509
508
|
"""
|
510
|
-
self.
|
511
|
-
|
512
|
-
|
509
|
+
self._infer_input_output_cols(dataset)
|
510
|
+
super()._check_dataset_type(dataset)
|
511
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
512
|
+
estimator=self._sklearn_object,
|
513
|
+
dataset=dataset,
|
514
|
+
input_cols=self.input_cols,
|
515
|
+
label_cols=self.label_cols,
|
516
|
+
sample_weight_col=self.sample_weight_col,
|
517
|
+
autogenerated=self._autogenerated,
|
518
|
+
subproject=_SUBPROJECT,
|
519
|
+
)
|
520
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
521
|
+
drop_input_cols=self._drop_input_cols,
|
522
|
+
expected_output_cols_list=self.output_cols,
|
523
|
+
)
|
524
|
+
self._sklearn_object = fitted_estimator
|
525
|
+
self._is_fitted = True
|
526
|
+
return output_result
|
513
527
|
|
514
528
|
|
515
529
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -600,10 +614,8 @@ class SelectFdr(BaseTransformer):
|
|
600
614
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
601
615
|
|
602
616
|
if isinstance(dataset, DataFrame):
|
603
|
-
self.
|
604
|
-
|
605
|
-
inference_method=inference_method,
|
606
|
-
)
|
617
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
618
|
+
self._deps = self._get_dependencies()
|
607
619
|
assert isinstance(
|
608
620
|
dataset._session, Session
|
609
621
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -668,10 +680,8 @@ class SelectFdr(BaseTransformer):
|
|
668
680
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
669
681
|
|
670
682
|
if isinstance(dataset, DataFrame):
|
671
|
-
self.
|
672
|
-
|
673
|
-
inference_method=inference_method,
|
674
|
-
)
|
683
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
684
|
+
self._deps = self._get_dependencies()
|
675
685
|
assert isinstance(
|
676
686
|
dataset._session, Session
|
677
687
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -733,10 +743,8 @@ class SelectFdr(BaseTransformer):
|
|
733
743
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
734
744
|
|
735
745
|
if isinstance(dataset, DataFrame):
|
736
|
-
self.
|
737
|
-
|
738
|
-
inference_method=inference_method,
|
739
|
-
)
|
746
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
747
|
+
self._deps = self._get_dependencies()
|
740
748
|
assert isinstance(
|
741
749
|
dataset._session, Session
|
742
750
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -802,10 +810,8 @@ class SelectFdr(BaseTransformer):
|
|
802
810
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
803
811
|
|
804
812
|
if isinstance(dataset, DataFrame):
|
805
|
-
self.
|
806
|
-
|
807
|
-
inference_method=inference_method,
|
808
|
-
)
|
813
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
814
|
+
self._deps = self._get_dependencies()
|
809
815
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
810
816
|
transform_kwargs = dict(
|
811
817
|
session=dataset._session,
|
@@ -867,17 +873,15 @@ class SelectFdr(BaseTransformer):
|
|
867
873
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
868
874
|
|
869
875
|
if isinstance(dataset, DataFrame):
|
870
|
-
self.
|
871
|
-
|
872
|
-
inference_method="score",
|
873
|
-
)
|
876
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
877
|
+
self._deps = self._get_dependencies()
|
874
878
|
selected_cols = self._get_active_columns()
|
875
879
|
if len(selected_cols) > 0:
|
876
880
|
dataset = dataset.select(selected_cols)
|
877
881
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
878
882
|
transform_kwargs = dict(
|
879
883
|
session=dataset._session,
|
880
|
-
dependencies=
|
884
|
+
dependencies=self._deps,
|
881
885
|
score_sproc_imports=['sklearn'],
|
882
886
|
)
|
883
887
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -942,11 +946,8 @@ class SelectFdr(BaseTransformer):
|
|
942
946
|
|
943
947
|
if isinstance(dataset, DataFrame):
|
944
948
|
|
945
|
-
self.
|
946
|
-
|
947
|
-
inference_method=inference_method,
|
948
|
-
|
949
|
-
)
|
949
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
950
|
+
self._deps = self._get_dependencies()
|
950
951
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
951
952
|
transform_kwargs = dict(
|
952
953
|
session = dataset._session,
|
@@ -61,12 +61,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.feature_selection".repla
|
|
61
61
|
|
62
62
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
63
63
|
|
64
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
65
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
66
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
67
|
-
return check
|
68
|
-
|
69
|
-
|
70
64
|
class SelectFpr(BaseTransformer):
|
71
65
|
r"""Filter: Select the pvalues below alpha based on a FPR test
|
72
66
|
For more details on this class, see [sklearn.feature_selection.SelectFpr]
|
@@ -266,20 +260,17 @@ class SelectFpr(BaseTransformer):
|
|
266
260
|
self,
|
267
261
|
dataset: DataFrame,
|
268
262
|
inference_method: str,
|
269
|
-
) ->
|
270
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
271
|
-
return the available package that exists in the snowflake anaconda channel
|
263
|
+
) -> None:
|
264
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
272
265
|
|
273
266
|
Args:
|
274
267
|
dataset: snowpark dataframe
|
275
268
|
inference_method: the inference method such as predict, score...
|
276
|
-
|
269
|
+
|
277
270
|
Raises:
|
278
271
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
279
272
|
SnowflakeMLException: If the session is None, raise error
|
280
273
|
|
281
|
-
Returns:
|
282
|
-
A list of available package that exists in the snowflake anaconda channel
|
283
274
|
"""
|
284
275
|
if not self._is_fitted:
|
285
276
|
raise exceptions.SnowflakeMLException(
|
@@ -297,9 +288,7 @@ class SelectFpr(BaseTransformer):
|
|
297
288
|
"Session must not specified for snowpark dataset."
|
298
289
|
),
|
299
290
|
)
|
300
|
-
|
301
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
302
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
291
|
+
|
303
292
|
|
304
293
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
305
294
|
@telemetry.send_api_usage_telemetry(
|
@@ -345,7 +334,8 @@ class SelectFpr(BaseTransformer):
|
|
345
334
|
|
346
335
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
347
336
|
|
348
|
-
self.
|
337
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
338
|
+
self._deps = self._get_dependencies()
|
349
339
|
assert isinstance(
|
350
340
|
dataset._session, Session
|
351
341
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -430,10 +420,8 @@ class SelectFpr(BaseTransformer):
|
|
430
420
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
431
421
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
432
422
|
|
433
|
-
self.
|
434
|
-
|
435
|
-
inference_method=inference_method,
|
436
|
-
)
|
423
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
424
|
+
self._deps = self._get_dependencies()
|
437
425
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
438
426
|
|
439
427
|
transform_kwargs = dict(
|
@@ -500,16 +488,42 @@ class SelectFpr(BaseTransformer):
|
|
500
488
|
self._is_fitted = True
|
501
489
|
return output_result
|
502
490
|
|
491
|
+
|
492
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
493
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
494
|
+
""" Fit to data, then transform it
|
495
|
+
For more details on this function, see [sklearn.feature_selection.SelectFpr.fit_transform]
|
496
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.SelectFpr.html#sklearn.feature_selection.SelectFpr.fit_transform)
|
497
|
+
|
503
498
|
|
504
|
-
|
505
|
-
|
506
|
-
|
499
|
+
Raises:
|
500
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
501
|
+
|
502
|
+
Args:
|
503
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
504
|
+
Snowpark or Pandas DataFrame.
|
505
|
+
output_cols_prefix: Prefix for the response columns
|
507
506
|
Returns:
|
508
507
|
Transformed dataset.
|
509
508
|
"""
|
510
|
-
self.
|
511
|
-
|
512
|
-
|
509
|
+
self._infer_input_output_cols(dataset)
|
510
|
+
super()._check_dataset_type(dataset)
|
511
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
512
|
+
estimator=self._sklearn_object,
|
513
|
+
dataset=dataset,
|
514
|
+
input_cols=self.input_cols,
|
515
|
+
label_cols=self.label_cols,
|
516
|
+
sample_weight_col=self.sample_weight_col,
|
517
|
+
autogenerated=self._autogenerated,
|
518
|
+
subproject=_SUBPROJECT,
|
519
|
+
)
|
520
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
521
|
+
drop_input_cols=self._drop_input_cols,
|
522
|
+
expected_output_cols_list=self.output_cols,
|
523
|
+
)
|
524
|
+
self._sklearn_object = fitted_estimator
|
525
|
+
self._is_fitted = True
|
526
|
+
return output_result
|
513
527
|
|
514
528
|
|
515
529
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -600,10 +614,8 @@ class SelectFpr(BaseTransformer):
|
|
600
614
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
601
615
|
|
602
616
|
if isinstance(dataset, DataFrame):
|
603
|
-
self.
|
604
|
-
|
605
|
-
inference_method=inference_method,
|
606
|
-
)
|
617
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
618
|
+
self._deps = self._get_dependencies()
|
607
619
|
assert isinstance(
|
608
620
|
dataset._session, Session
|
609
621
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -668,10 +680,8 @@ class SelectFpr(BaseTransformer):
|
|
668
680
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
669
681
|
|
670
682
|
if isinstance(dataset, DataFrame):
|
671
|
-
self.
|
672
|
-
|
673
|
-
inference_method=inference_method,
|
674
|
-
)
|
683
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
684
|
+
self._deps = self._get_dependencies()
|
675
685
|
assert isinstance(
|
676
686
|
dataset._session, Session
|
677
687
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -733,10 +743,8 @@ class SelectFpr(BaseTransformer):
|
|
733
743
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
734
744
|
|
735
745
|
if isinstance(dataset, DataFrame):
|
736
|
-
self.
|
737
|
-
|
738
|
-
inference_method=inference_method,
|
739
|
-
)
|
746
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
747
|
+
self._deps = self._get_dependencies()
|
740
748
|
assert isinstance(
|
741
749
|
dataset._session, Session
|
742
750
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -802,10 +810,8 @@ class SelectFpr(BaseTransformer):
|
|
802
810
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
803
811
|
|
804
812
|
if isinstance(dataset, DataFrame):
|
805
|
-
self.
|
806
|
-
|
807
|
-
inference_method=inference_method,
|
808
|
-
)
|
813
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
814
|
+
self._deps = self._get_dependencies()
|
809
815
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
810
816
|
transform_kwargs = dict(
|
811
817
|
session=dataset._session,
|
@@ -867,17 +873,15 @@ class SelectFpr(BaseTransformer):
|
|
867
873
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
868
874
|
|
869
875
|
if isinstance(dataset, DataFrame):
|
870
|
-
self.
|
871
|
-
|
872
|
-
inference_method="score",
|
873
|
-
)
|
876
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
877
|
+
self._deps = self._get_dependencies()
|
874
878
|
selected_cols = self._get_active_columns()
|
875
879
|
if len(selected_cols) > 0:
|
876
880
|
dataset = dataset.select(selected_cols)
|
877
881
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
878
882
|
transform_kwargs = dict(
|
879
883
|
session=dataset._session,
|
880
|
-
dependencies=
|
884
|
+
dependencies=self._deps,
|
881
885
|
score_sproc_imports=['sklearn'],
|
882
886
|
)
|
883
887
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -942,11 +946,8 @@ class SelectFpr(BaseTransformer):
|
|
942
946
|
|
943
947
|
if isinstance(dataset, DataFrame):
|
944
948
|
|
945
|
-
self.
|
946
|
-
|
947
|
-
inference_method=inference_method,
|
948
|
-
|
949
|
-
)
|
949
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
950
|
+
self._deps = self._get_dependencies()
|
950
951
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
951
952
|
transform_kwargs = dict(
|
952
953
|
session = dataset._session,
|