snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class StackingRegressor(BaseTransformer):
|
70
64
|
r"""Stack of estimators with a final regressor
|
71
65
|
For more details on this class, see [sklearn.ensemble.StackingRegressor]
|
@@ -316,20 +310,17 @@ class StackingRegressor(BaseTransformer):
|
|
316
310
|
self,
|
317
311
|
dataset: DataFrame,
|
318
312
|
inference_method: str,
|
319
|
-
) ->
|
320
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
321
|
-
return the available package that exists in the snowflake anaconda channel
|
313
|
+
) -> None:
|
314
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
322
315
|
|
323
316
|
Args:
|
324
317
|
dataset: snowpark dataframe
|
325
318
|
inference_method: the inference method such as predict, score...
|
326
|
-
|
319
|
+
|
327
320
|
Raises:
|
328
321
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
329
322
|
SnowflakeMLException: If the session is None, raise error
|
330
323
|
|
331
|
-
Returns:
|
332
|
-
A list of available package that exists in the snowflake anaconda channel
|
333
324
|
"""
|
334
325
|
if not self._is_fitted:
|
335
326
|
raise exceptions.SnowflakeMLException(
|
@@ -347,9 +338,7 @@ class StackingRegressor(BaseTransformer):
|
|
347
338
|
"Session must not specified for snowpark dataset."
|
348
339
|
),
|
349
340
|
)
|
350
|
-
|
351
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
352
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
341
|
+
|
353
342
|
|
354
343
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
355
344
|
@telemetry.send_api_usage_telemetry(
|
@@ -397,7 +386,8 @@ class StackingRegressor(BaseTransformer):
|
|
397
386
|
|
398
387
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
399
388
|
|
400
|
-
self.
|
389
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
390
|
+
self._deps = self._get_dependencies()
|
401
391
|
assert isinstance(
|
402
392
|
dataset._session, Session
|
403
393
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -482,10 +472,8 @@ class StackingRegressor(BaseTransformer):
|
|
482
472
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
483
473
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
484
474
|
|
485
|
-
self.
|
486
|
-
|
487
|
-
inference_method=inference_method,
|
488
|
-
)
|
475
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
476
|
+
self._deps = self._get_dependencies()
|
489
477
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
490
478
|
|
491
479
|
transform_kwargs = dict(
|
@@ -552,16 +540,42 @@ class StackingRegressor(BaseTransformer):
|
|
552
540
|
self._is_fitted = True
|
553
541
|
return output_result
|
554
542
|
|
543
|
+
|
544
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
545
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
546
|
+
""" Fit the estimators and return the predictions for X for each estimator
|
547
|
+
For more details on this function, see [sklearn.ensemble.StackingRegressor.fit_transform]
|
548
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.StackingRegressor.html#sklearn.ensemble.StackingRegressor.fit_transform)
|
549
|
+
|
555
550
|
|
556
|
-
|
557
|
-
|
558
|
-
|
551
|
+
Raises:
|
552
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
553
|
+
|
554
|
+
Args:
|
555
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
556
|
+
Snowpark or Pandas DataFrame.
|
557
|
+
output_cols_prefix: Prefix for the response columns
|
559
558
|
Returns:
|
560
559
|
Transformed dataset.
|
561
560
|
"""
|
562
|
-
self.
|
563
|
-
|
564
|
-
|
561
|
+
self._infer_input_output_cols(dataset)
|
562
|
+
super()._check_dataset_type(dataset)
|
563
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
564
|
+
estimator=self._sklearn_object,
|
565
|
+
dataset=dataset,
|
566
|
+
input_cols=self.input_cols,
|
567
|
+
label_cols=self.label_cols,
|
568
|
+
sample_weight_col=self.sample_weight_col,
|
569
|
+
autogenerated=self._autogenerated,
|
570
|
+
subproject=_SUBPROJECT,
|
571
|
+
)
|
572
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
573
|
+
drop_input_cols=self._drop_input_cols,
|
574
|
+
expected_output_cols_list=self.output_cols,
|
575
|
+
)
|
576
|
+
self._sklearn_object = fitted_estimator
|
577
|
+
self._is_fitted = True
|
578
|
+
return output_result
|
565
579
|
|
566
580
|
|
567
581
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -652,10 +666,8 @@ class StackingRegressor(BaseTransformer):
|
|
652
666
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
653
667
|
|
654
668
|
if isinstance(dataset, DataFrame):
|
655
|
-
self.
|
656
|
-
|
657
|
-
inference_method=inference_method,
|
658
|
-
)
|
669
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
670
|
+
self._deps = self._get_dependencies()
|
659
671
|
assert isinstance(
|
660
672
|
dataset._session, Session
|
661
673
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -720,10 +732,8 @@ class StackingRegressor(BaseTransformer):
|
|
720
732
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
721
733
|
|
722
734
|
if isinstance(dataset, DataFrame):
|
723
|
-
self.
|
724
|
-
|
725
|
-
inference_method=inference_method,
|
726
|
-
)
|
735
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
736
|
+
self._deps = self._get_dependencies()
|
727
737
|
assert isinstance(
|
728
738
|
dataset._session, Session
|
729
739
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -785,10 +795,8 @@ class StackingRegressor(BaseTransformer):
|
|
785
795
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
786
796
|
|
787
797
|
if isinstance(dataset, DataFrame):
|
788
|
-
self.
|
789
|
-
|
790
|
-
inference_method=inference_method,
|
791
|
-
)
|
798
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
799
|
+
self._deps = self._get_dependencies()
|
792
800
|
assert isinstance(
|
793
801
|
dataset._session, Session
|
794
802
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -854,10 +862,8 @@ class StackingRegressor(BaseTransformer):
|
|
854
862
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
855
863
|
|
856
864
|
if isinstance(dataset, DataFrame):
|
857
|
-
self.
|
858
|
-
|
859
|
-
inference_method=inference_method,
|
860
|
-
)
|
865
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
866
|
+
self._deps = self._get_dependencies()
|
861
867
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
862
868
|
transform_kwargs = dict(
|
863
869
|
session=dataset._session,
|
@@ -921,17 +927,15 @@ class StackingRegressor(BaseTransformer):
|
|
921
927
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
922
928
|
|
923
929
|
if isinstance(dataset, DataFrame):
|
924
|
-
self.
|
925
|
-
|
926
|
-
inference_method="score",
|
927
|
-
)
|
930
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
931
|
+
self._deps = self._get_dependencies()
|
928
932
|
selected_cols = self._get_active_columns()
|
929
933
|
if len(selected_cols) > 0:
|
930
934
|
dataset = dataset.select(selected_cols)
|
931
935
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
932
936
|
transform_kwargs = dict(
|
933
937
|
session=dataset._session,
|
934
|
-
dependencies=
|
938
|
+
dependencies=self._deps,
|
935
939
|
score_sproc_imports=['sklearn'],
|
936
940
|
)
|
937
941
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -996,11 +1000,8 @@ class StackingRegressor(BaseTransformer):
|
|
996
1000
|
|
997
1001
|
if isinstance(dataset, DataFrame):
|
998
1002
|
|
999
|
-
self.
|
1000
|
-
|
1001
|
-
inference_method=inference_method,
|
1002
|
-
|
1003
|
-
)
|
1003
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1004
|
+
self._deps = self._get_dependencies()
|
1004
1005
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1005
1006
|
transform_kwargs = dict(
|
1006
1007
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class VotingClassifier(BaseTransformer):
|
70
64
|
r"""Soft Voting/Majority Rule classifier for unfitted estimators
|
71
65
|
For more details on this class, see [sklearn.ensemble.VotingClassifier]
|
@@ -298,20 +292,17 @@ class VotingClassifier(BaseTransformer):
|
|
298
292
|
self,
|
299
293
|
dataset: DataFrame,
|
300
294
|
inference_method: str,
|
301
|
-
) ->
|
302
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
303
|
-
return the available package that exists in the snowflake anaconda channel
|
295
|
+
) -> None:
|
296
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
304
297
|
|
305
298
|
Args:
|
306
299
|
dataset: snowpark dataframe
|
307
300
|
inference_method: the inference method such as predict, score...
|
308
|
-
|
301
|
+
|
309
302
|
Raises:
|
310
303
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
311
304
|
SnowflakeMLException: If the session is None, raise error
|
312
305
|
|
313
|
-
Returns:
|
314
|
-
A list of available package that exists in the snowflake anaconda channel
|
315
306
|
"""
|
316
307
|
if not self._is_fitted:
|
317
308
|
raise exceptions.SnowflakeMLException(
|
@@ -329,9 +320,7 @@ class VotingClassifier(BaseTransformer):
|
|
329
320
|
"Session must not specified for snowpark dataset."
|
330
321
|
),
|
331
322
|
)
|
332
|
-
|
333
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
334
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
323
|
+
|
335
324
|
|
336
325
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
337
326
|
@telemetry.send_api_usage_telemetry(
|
@@ -379,7 +368,8 @@ class VotingClassifier(BaseTransformer):
|
|
379
368
|
|
380
369
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
381
370
|
|
382
|
-
self.
|
371
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
372
|
+
self._deps = self._get_dependencies()
|
383
373
|
assert isinstance(
|
384
374
|
dataset._session, Session
|
385
375
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -464,10 +454,8 @@ class VotingClassifier(BaseTransformer):
|
|
464
454
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
465
455
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
466
456
|
|
467
|
-
self.
|
468
|
-
|
469
|
-
inference_method=inference_method,
|
470
|
-
)
|
457
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
458
|
+
self._deps = self._get_dependencies()
|
471
459
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
472
460
|
|
473
461
|
transform_kwargs = dict(
|
@@ -534,16 +522,42 @@ class VotingClassifier(BaseTransformer):
|
|
534
522
|
self._is_fitted = True
|
535
523
|
return output_result
|
536
524
|
|
525
|
+
|
526
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
527
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
528
|
+
""" Return class labels or probabilities for each estimator
|
529
|
+
For more details on this function, see [sklearn.ensemble.VotingClassifier.fit_transform]
|
530
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.VotingClassifier.html#sklearn.ensemble.VotingClassifier.fit_transform)
|
531
|
+
|
537
532
|
|
538
|
-
|
539
|
-
|
540
|
-
|
533
|
+
Raises:
|
534
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
535
|
+
|
536
|
+
Args:
|
537
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
538
|
+
Snowpark or Pandas DataFrame.
|
539
|
+
output_cols_prefix: Prefix for the response columns
|
541
540
|
Returns:
|
542
541
|
Transformed dataset.
|
543
542
|
"""
|
544
|
-
self.
|
545
|
-
|
546
|
-
|
543
|
+
self._infer_input_output_cols(dataset)
|
544
|
+
super()._check_dataset_type(dataset)
|
545
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
546
|
+
estimator=self._sklearn_object,
|
547
|
+
dataset=dataset,
|
548
|
+
input_cols=self.input_cols,
|
549
|
+
label_cols=self.label_cols,
|
550
|
+
sample_weight_col=self.sample_weight_col,
|
551
|
+
autogenerated=self._autogenerated,
|
552
|
+
subproject=_SUBPROJECT,
|
553
|
+
)
|
554
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
555
|
+
drop_input_cols=self._drop_input_cols,
|
556
|
+
expected_output_cols_list=self.output_cols,
|
557
|
+
)
|
558
|
+
self._sklearn_object = fitted_estimator
|
559
|
+
self._is_fitted = True
|
560
|
+
return output_result
|
547
561
|
|
548
562
|
|
549
563
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -636,10 +650,8 @@ class VotingClassifier(BaseTransformer):
|
|
636
650
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
637
651
|
|
638
652
|
if isinstance(dataset, DataFrame):
|
639
|
-
self.
|
640
|
-
|
641
|
-
inference_method=inference_method,
|
642
|
-
)
|
653
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
654
|
+
self._deps = self._get_dependencies()
|
643
655
|
assert isinstance(
|
644
656
|
dataset._session, Session
|
645
657
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -706,10 +718,8 @@ class VotingClassifier(BaseTransformer):
|
|
706
718
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
707
719
|
|
708
720
|
if isinstance(dataset, DataFrame):
|
709
|
-
self.
|
710
|
-
|
711
|
-
inference_method=inference_method,
|
712
|
-
)
|
721
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
722
|
+
self._deps = self._get_dependencies()
|
713
723
|
assert isinstance(
|
714
724
|
dataset._session, Session
|
715
725
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -771,10 +781,8 @@ class VotingClassifier(BaseTransformer):
|
|
771
781
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
772
782
|
|
773
783
|
if isinstance(dataset, DataFrame):
|
774
|
-
self.
|
775
|
-
|
776
|
-
inference_method=inference_method,
|
777
|
-
)
|
784
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
785
|
+
self._deps = self._get_dependencies()
|
778
786
|
assert isinstance(
|
779
787
|
dataset._session, Session
|
780
788
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -840,10 +848,8 @@ class VotingClassifier(BaseTransformer):
|
|
840
848
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
841
849
|
|
842
850
|
if isinstance(dataset, DataFrame):
|
843
|
-
self.
|
844
|
-
|
845
|
-
inference_method=inference_method,
|
846
|
-
)
|
851
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
852
|
+
self._deps = self._get_dependencies()
|
847
853
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
848
854
|
transform_kwargs = dict(
|
849
855
|
session=dataset._session,
|
@@ -907,17 +913,15 @@ class VotingClassifier(BaseTransformer):
|
|
907
913
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
908
914
|
|
909
915
|
if isinstance(dataset, DataFrame):
|
910
|
-
self.
|
911
|
-
|
912
|
-
inference_method="score",
|
913
|
-
)
|
916
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
917
|
+
self._deps = self._get_dependencies()
|
914
918
|
selected_cols = self._get_active_columns()
|
915
919
|
if len(selected_cols) > 0:
|
916
920
|
dataset = dataset.select(selected_cols)
|
917
921
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
918
922
|
transform_kwargs = dict(
|
919
923
|
session=dataset._session,
|
920
|
-
dependencies=
|
924
|
+
dependencies=self._deps,
|
921
925
|
score_sproc_imports=['sklearn'],
|
922
926
|
)
|
923
927
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -982,11 +986,8 @@ class VotingClassifier(BaseTransformer):
|
|
982
986
|
|
983
987
|
if isinstance(dataset, DataFrame):
|
984
988
|
|
985
|
-
self.
|
986
|
-
|
987
|
-
inference_method=inference_method,
|
988
|
-
|
989
|
-
)
|
989
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
990
|
+
self._deps = self._get_dependencies()
|
990
991
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
991
992
|
transform_kwargs = dict(
|
992
993
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class VotingRegressor(BaseTransformer):
|
70
64
|
r"""Prediction voting regressor for unfitted estimators
|
71
65
|
For more details on this class, see [sklearn.ensemble.VotingRegressor]
|
@@ -280,20 +274,17 @@ class VotingRegressor(BaseTransformer):
|
|
280
274
|
self,
|
281
275
|
dataset: DataFrame,
|
282
276
|
inference_method: str,
|
283
|
-
) ->
|
284
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
285
|
-
return the available package that exists in the snowflake anaconda channel
|
277
|
+
) -> None:
|
278
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
286
279
|
|
287
280
|
Args:
|
288
281
|
dataset: snowpark dataframe
|
289
282
|
inference_method: the inference method such as predict, score...
|
290
|
-
|
283
|
+
|
291
284
|
Raises:
|
292
285
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
293
286
|
SnowflakeMLException: If the session is None, raise error
|
294
287
|
|
295
|
-
Returns:
|
296
|
-
A list of available package that exists in the snowflake anaconda channel
|
297
288
|
"""
|
298
289
|
if not self._is_fitted:
|
299
290
|
raise exceptions.SnowflakeMLException(
|
@@ -311,9 +302,7 @@ class VotingRegressor(BaseTransformer):
|
|
311
302
|
"Session must not specified for snowpark dataset."
|
312
303
|
),
|
313
304
|
)
|
314
|
-
|
315
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
316
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
305
|
+
|
317
306
|
|
318
307
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
319
308
|
@telemetry.send_api_usage_telemetry(
|
@@ -361,7 +350,8 @@ class VotingRegressor(BaseTransformer):
|
|
361
350
|
|
362
351
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
363
352
|
|
364
|
-
self.
|
353
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
354
|
+
self._deps = self._get_dependencies()
|
365
355
|
assert isinstance(
|
366
356
|
dataset._session, Session
|
367
357
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -446,10 +436,8 @@ class VotingRegressor(BaseTransformer):
|
|
446
436
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
447
437
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
448
438
|
|
449
|
-
self.
|
450
|
-
|
451
|
-
inference_method=inference_method,
|
452
|
-
)
|
439
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
440
|
+
self._deps = self._get_dependencies()
|
453
441
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
454
442
|
|
455
443
|
transform_kwargs = dict(
|
@@ -516,16 +504,42 @@ class VotingRegressor(BaseTransformer):
|
|
516
504
|
self._is_fitted = True
|
517
505
|
return output_result
|
518
506
|
|
507
|
+
|
508
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
509
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
510
|
+
""" Return class labels or probabilities for each estimator
|
511
|
+
For more details on this function, see [sklearn.ensemble.VotingRegressor.fit_transform]
|
512
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.VotingRegressor.html#sklearn.ensemble.VotingRegressor.fit_transform)
|
513
|
+
|
519
514
|
|
520
|
-
|
521
|
-
|
522
|
-
|
515
|
+
Raises:
|
516
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
517
|
+
|
518
|
+
Args:
|
519
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
520
|
+
Snowpark or Pandas DataFrame.
|
521
|
+
output_cols_prefix: Prefix for the response columns
|
523
522
|
Returns:
|
524
523
|
Transformed dataset.
|
525
524
|
"""
|
526
|
-
self.
|
527
|
-
|
528
|
-
|
525
|
+
self._infer_input_output_cols(dataset)
|
526
|
+
super()._check_dataset_type(dataset)
|
527
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
528
|
+
estimator=self._sklearn_object,
|
529
|
+
dataset=dataset,
|
530
|
+
input_cols=self.input_cols,
|
531
|
+
label_cols=self.label_cols,
|
532
|
+
sample_weight_col=self.sample_weight_col,
|
533
|
+
autogenerated=self._autogenerated,
|
534
|
+
subproject=_SUBPROJECT,
|
535
|
+
)
|
536
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
537
|
+
drop_input_cols=self._drop_input_cols,
|
538
|
+
expected_output_cols_list=self.output_cols,
|
539
|
+
)
|
540
|
+
self._sklearn_object = fitted_estimator
|
541
|
+
self._is_fitted = True
|
542
|
+
return output_result
|
529
543
|
|
530
544
|
|
531
545
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -616,10 +630,8 @@ class VotingRegressor(BaseTransformer):
|
|
616
630
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
617
631
|
|
618
632
|
if isinstance(dataset, DataFrame):
|
619
|
-
self.
|
620
|
-
|
621
|
-
inference_method=inference_method,
|
622
|
-
)
|
633
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
634
|
+
self._deps = self._get_dependencies()
|
623
635
|
assert isinstance(
|
624
636
|
dataset._session, Session
|
625
637
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -684,10 +696,8 @@ class VotingRegressor(BaseTransformer):
|
|
684
696
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
685
697
|
|
686
698
|
if isinstance(dataset, DataFrame):
|
687
|
-
self.
|
688
|
-
|
689
|
-
inference_method=inference_method,
|
690
|
-
)
|
699
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
700
|
+
self._deps = self._get_dependencies()
|
691
701
|
assert isinstance(
|
692
702
|
dataset._session, Session
|
693
703
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -749,10 +759,8 @@ class VotingRegressor(BaseTransformer):
|
|
749
759
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
750
760
|
|
751
761
|
if isinstance(dataset, DataFrame):
|
752
|
-
self.
|
753
|
-
|
754
|
-
inference_method=inference_method,
|
755
|
-
)
|
762
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
763
|
+
self._deps = self._get_dependencies()
|
756
764
|
assert isinstance(
|
757
765
|
dataset._session, Session
|
758
766
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -818,10 +826,8 @@ class VotingRegressor(BaseTransformer):
|
|
818
826
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
819
827
|
|
820
828
|
if isinstance(dataset, DataFrame):
|
821
|
-
self.
|
822
|
-
|
823
|
-
inference_method=inference_method,
|
824
|
-
)
|
829
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
830
|
+
self._deps = self._get_dependencies()
|
825
831
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
826
832
|
transform_kwargs = dict(
|
827
833
|
session=dataset._session,
|
@@ -885,17 +891,15 @@ class VotingRegressor(BaseTransformer):
|
|
885
891
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
886
892
|
|
887
893
|
if isinstance(dataset, DataFrame):
|
888
|
-
self.
|
889
|
-
|
890
|
-
inference_method="score",
|
891
|
-
)
|
894
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
895
|
+
self._deps = self._get_dependencies()
|
892
896
|
selected_cols = self._get_active_columns()
|
893
897
|
if len(selected_cols) > 0:
|
894
898
|
dataset = dataset.select(selected_cols)
|
895
899
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
896
900
|
transform_kwargs = dict(
|
897
901
|
session=dataset._session,
|
898
|
-
dependencies=
|
902
|
+
dependencies=self._deps,
|
899
903
|
score_sproc_imports=['sklearn'],
|
900
904
|
)
|
901
905
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -960,11 +964,8 @@ class VotingRegressor(BaseTransformer):
|
|
960
964
|
|
961
965
|
if isinstance(dataset, DataFrame):
|
962
966
|
|
963
|
-
self.
|
964
|
-
|
965
|
-
inference_method=inference_method,
|
966
|
-
|
967
|
-
)
|
967
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
968
|
+
self._deps = self._get_dependencies()
|
968
969
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
969
970
|
transform_kwargs = dict(
|
970
971
|
session = dataset._session,
|