snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/impute/knn_imputer.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.impute".replace("sklearn

 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]

-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class KNNImputer(BaseTransformer):
     r"""Imputation for completing missing values using k-Nearest Neighbors
     For more details on this class, see [sklearn.impute.KNNImputer]
@@ -311,20 +305,17 @@ class KNNImputer(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.

         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error

-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -342,9 +333,7 @@ class KNNImputer(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+

     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -390,7 +379,8 @@ class KNNImputer(BaseTransformer):

             expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())

-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -475,10 +465,8 @@ class KNNImputer(BaseTransformer):
         if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
             expected_dtype = convert_sp_to_sf_type(output_types[0])

-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()

             transform_kwargs = dict(
@@ -545,16 +533,42 @@ class KNNImputer(BaseTransformer):
         self._is_fitted = True
         return output_result

+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Fit to data, then transform it
+        For more details on this function, see [sklearn.impute.KNNImputer.fit_transform]
+        (https://scikit-learn.org/stable/modules/generated/sklearn.impute.KNNImputer.html#sklearn.impute.KNNImputer.fit_transform)
+

-
-
-
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result


     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -645,10 +659,8 @@ class KNNImputer(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -713,10 +725,8 @@ class KNNImputer(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -778,10 +788,8 @@ class KNNImputer(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -847,10 +855,8 @@ class KNNImputer(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -912,17 +918,15 @@ class KNNImputer(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -987,11 +991,8 @@ class KNNImputer(BaseTransformer):

         if isinstance(dataset, DataFrame):

-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
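A note on the fit_transform hunk above: in 1.4.1 the method was gated behind _is_fit_transform_method_enabled(), whose check always returned False, while 1.5.0 routes it through ModelTrainerBuilder.build_fit_transform, so the method is actually callable. The Python sketch below is only an illustration of that call path, assuming snowflake-ml-python 1.5.0 is installed; the column names, sample values, and n_neighbors setting are invented for the example and are not taken from the package.

import numpy as np
import pandas as pd
from snowflake.ml.modeling.impute import KNNImputer

# Toy frame with missing values; the column names here are hypothetical.
df = pd.DataFrame({
    "A": [1.0, 2.0, np.nan, 4.0],
    "B": [10.0, np.nan, 30.0, 40.0],
})

imputer = KNNImputer(
    input_cols=["A", "B"],
    output_cols=["A_IMPUTED", "B_IMPUTED"],
    n_neighbors=2,
)

# With 1.5.0, fit_transform is dispatched via ModelTrainerBuilder.build_fit_transform
# (see the hunk above): a pandas DataFrame is fitted and transformed locally, while a
# snowpark DataFrame would be processed in Snowflake.
result = imputer.fit_transform(df)
print(result)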
snowflake/ml/modeling/impute/missing_indicator.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.impute".replace("sklearn

 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]

-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class MissingIndicator(BaseTransformer):
     r"""Binary indicators for missing values
     For more details on this class, see [sklearn.impute.MissingIndicator]
@@ -285,20 +279,17 @@ class MissingIndicator(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.

         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error

-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -316,9 +307,7 @@ class MissingIndicator(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+

     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -364,7 +353,8 @@ class MissingIndicator(BaseTransformer):

             expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())

-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -449,10 +439,8 @@ class MissingIndicator(BaseTransformer):
         if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
             expected_dtype = convert_sp_to_sf_type(output_types[0])

-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()

             transform_kwargs = dict(
@@ -519,16 +507,42 @@ class MissingIndicator(BaseTransformer):
         self._is_fitted = True
         return output_result

+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Generate missing values indicator for `X`
+        For more details on this function, see [sklearn.impute.MissingIndicator.fit_transform]
+        (https://scikit-learn.org/stable/modules/generated/sklearn.impute.MissingIndicator.html#sklearn.impute.MissingIndicator.fit_transform)
+

-
-
-
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result


     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -619,10 +633,8 @@ class MissingIndicator(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -687,10 +699,8 @@ class MissingIndicator(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -752,10 +762,8 @@ class MissingIndicator(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -821,10 +829,8 @@ class MissingIndicator(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -886,17 +892,15 @@ class MissingIndicator(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -961,11 +965,8 @@ class MissingIndicator(BaseTransformer):

         if isinstance(dataset, DataFrame):

-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.kernel_approximation".re

 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]

-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class AdditiveChi2Sampler(BaseTransformer):
     r"""Approximate feature map for additive chi2 kernel
     For more details on this class, see [sklearn.kernel_approximation.AdditiveChi2Sampler]
@@ -260,20 +254,17 @@ class AdditiveChi2Sampler(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.

         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error

-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -291,9 +282,7 @@ class AdditiveChi2Sampler(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+

     @available_if(original_estimator_has_callable("predict"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -339,7 +328,8 @@ class AdditiveChi2Sampler(BaseTransformer):

             expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())

-            self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -424,10 +414,8 @@ class AdditiveChi2Sampler(BaseTransformer):
         if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
             expected_dtype = convert_sp_to_sf_type(output_types[0])

-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()

             transform_kwargs = dict(
@@ -494,16 +482,42 @@ class AdditiveChi2Sampler(BaseTransformer):
         self._is_fitted = True
         return output_result

+
+    @available_if(original_estimator_has_callable("fit_transform"))  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Fit to data, then transform it
+        For more details on this function, see [sklearn.kernel_approximation.AdditiveChi2Sampler.fit_transform]
+        (https://scikit-learn.org/stable/modules/generated/sklearn.kernel_approximation.AdditiveChi2Sampler.html#sklearn.kernel_approximation.AdditiveChi2Sampler.fit_transform)
+

-
-
-
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result


     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -594,10 +608,8 @@ class AdditiveChi2Sampler(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -662,10 +674,8 @@ class AdditiveChi2Sampler(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -727,10 +737,8 @@ class AdditiveChi2Sampler(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             )  # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -796,10 +804,8 @@ class AdditiveChi2Sampler(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)

         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -861,17 +867,15 @@ class AdditiveChi2Sampler(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -936,11 +940,8 @@ class AdditiveChi2Sampler(BaseTransformer):

         if isinstance(dataset, DataFrame):

-            self._deps = self._batch_inference_validate_snowpark(
-                dataset=dataset,
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session)  # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,