snowflake-ml-python 1.4.1-py3-none-any.whl → 1.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.discriminant_analysis".r
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class LinearDiscriminantAnalysis(BaseTransformer):
     r"""Linear Discriminant Analysis
     For more details on this class, see [sklearn.discriminant_analysis.LinearDiscriminantAnalysis]
@@ -318,20 +312,17 @@ class LinearDiscriminantAnalysis(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -349,9 +340,7 @@ class LinearDiscriminantAnalysis(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -399,7 +388,8 @@ class LinearDiscriminantAnalysis(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -484,10 +474,8 @@ class LinearDiscriminantAnalysis(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -554,16 +542,42 @@ class LinearDiscriminantAnalysis(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Fit to data, then transform it
+        For more details on this function, see [sklearn.discriminant_analysis.LinearDiscriminantAnalysis.fit_transform]
+        (https://scikit-learn.org/stable/modules/generated/sklearn.discriminant_analysis.LinearDiscriminantAnalysis.html#sklearn.discriminant_analysis.LinearDiscriminantAnalysis.fit_transform)
+
 
-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -656,10 +670,8 @@ class LinearDiscriminantAnalysis(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -726,10 +738,8 @@ class LinearDiscriminantAnalysis(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -793,10 +803,8 @@ class LinearDiscriminantAnalysis(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -862,10 +870,8 @@ class LinearDiscriminantAnalysis(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -929,17 +935,15 @@ class LinearDiscriminantAnalysis(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -1004,11 +1008,8 @@ class LinearDiscriminantAnalysis(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
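The hunks above give the generated wrapper a working fit_transform entry point (1.4.1 shipped a _is_fit_transform_method_enabled gate that always returned False) and split Snowpark batch-inference validation from dependency resolution. A minimal usage sketch of the new entry point follows; it is not part of the diff, it assumes an active Snowpark session and an existing feature table, and the constructor arguments shown (input_cols, label_cols, output_cols) follow the wrapper convention rather than anything stated in this diff.

```python
# Hypothetical usage sketch of the fit_transform() added in 1.5.0.
# Connection parameters, table name, and column names are placeholders.
from snowflake.snowpark import Session
from snowflake.ml.modeling.discriminant_analysis import LinearDiscriminantAnalysis

connection_parameters = {"account": "...", "user": "...", "password": "...", "warehouse": "..."}
session = Session.builder.configs(connection_parameters).create()

features_df = session.table("ML_DB.PUBLIC.IRIS_FEATURES")  # Snowpark DataFrame

lda = LinearDiscriminantAnalysis(
    input_cols=["SEPAL_LENGTH", "SEPAL_WIDTH", "PETAL_LENGTH", "PETAL_WIDTH"],
    label_cols=["SPECIES"],
    output_cols=["LD1", "LD2"],
)

# New in 1.5.0: fit and project in one call; internally this dispatches through
# ModelTrainerBuilder.build_fit_transform() / train_fit_transform().
projected = lda.fit_transform(features_df)
projected.show()
```

Passing a pandas.DataFrame instead of a Snowpark DataFrame should route through the local trainer instead, consistent with the pandas_trainer.py changes listed in the manifest above.
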
snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.discriminant_analysis".r
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class QuadraticDiscriminantAnalysis(BaseTransformer):
     r"""Quadratic Discriminant Analysis
     For more details on this class, see [sklearn.discriminant_analysis.QuadraticDiscriminantAnalysis]
@@ -280,20 +274,17 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -311,9 +302,7 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -361,7 +350,8 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -444,10 +434,8 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -514,16 +502,40 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -616,10 +628,8 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -686,10 +696,8 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -753,10 +761,8 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -822,10 +828,8 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -889,17 +893,15 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -964,11 +966,8 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,
snowflake/ml/modeling/ensemble/ada_boost_classifier.py

@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.ensemble".replace("sklea
 
 DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
-def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
-    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
-        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
-    return check
-
-
 class AdaBoostClassifier(BaseTransformer):
     r"""An AdaBoost classifier
     For more details on this class, see [sklearn.ensemble.AdaBoostClassifier]
@@ -305,20 +299,17 @@ class AdaBoostClassifier(BaseTransformer):
         self,
         dataset: DataFrame,
         inference_method: str,
-    ) ->
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
-        return the available package that exists in the snowflake anaconda channel
+    ) -> None:
+        """Util method to run validate that batch inference can be run on a snowpark dataframe.
 
         Args:
             dataset: snowpark dataframe
             inference_method: the inference method such as predict, score...
-
+
         Raises:
             SnowflakeMLException: If the estimator is not fitted, raise error
             SnowflakeMLException: If the session is None, raise error
 
-        Returns:
-            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -336,9 +327,7 @@ class AdaBoostClassifier(BaseTransformer):
                     "Session must not specified for snowpark dataset."
                 ),
             )
-
-        return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-            pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
+
 
     @available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -386,7 +375,8 @@ class AdaBoostClassifier(BaseTransformer):
 
                 expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
 
-            self.
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -469,10 +459,8 @@ class AdaBoostClassifier(BaseTransformer):
             if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
                 expected_dtype = convert_sp_to_sf_type(output_types[0])
 
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
 
             transform_kwargs = dict(
@@ -539,16 +527,40 @@ class AdaBoostClassifier(BaseTransformer):
         self._is_fitted = True
         return output_result
 
+
+    @available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
+        """ Method not supported for this class.
 
-
-
-
+
+        Raises:
+            TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
+
+        Args:
+            dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
+                Snowpark or Pandas DataFrame.
+            output_cols_prefix: Prefix for the response columns
         Returns:
             Transformed dataset.
         """
-        self.
-
-
+        self._infer_input_output_cols(dataset)
+        super()._check_dataset_type(dataset)
+        model_trainer = ModelTrainerBuilder.build_fit_transform(
+            estimator=self._sklearn_object,
+            dataset=dataset,
+            input_cols=self.input_cols,
+            label_cols=self.label_cols,
+            sample_weight_col=self.sample_weight_col,
+            autogenerated=self._autogenerated,
+            subproject=_SUBPROJECT,
+        )
+        output_result, fitted_estimator = model_trainer.train_fit_transform(
+            drop_input_cols=self._drop_input_cols,
+            expected_output_cols_list=self.output_cols,
+        )
+        self._sklearn_object = fitted_estimator
+        self._is_fitted = True
+        return output_result
 
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
@@ -641,10 +653,8 @@ class AdaBoostClassifier(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -711,10 +721,8 @@ class AdaBoostClassifier(BaseTransformer):
         transform_kwargs: BatchInferenceKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -778,10 +786,8 @@ class AdaBoostClassifier(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(
                 dataset._session, Session
             ) # mypy does not recognize the check in _batch_inference_validate_snowpark()
@@ -847,10 +853,8 @@ class AdaBoostClassifier(BaseTransformer):
         expected_output_cols = self._get_output_column_names(output_cols_prefix)
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method=inference_method,
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session=dataset._session,
@@ -914,17 +918,15 @@ class AdaBoostClassifier(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
-            self.
-
-                inference_method="score",
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
+            self._deps = self._get_dependencies()
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session) # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=
+                dependencies=self._deps,
                 score_sproc_imports=['sklearn'],
            )
         elif isinstance(dataset, pd.DataFrame):
@@ -989,11 +991,8 @@ class AdaBoostClassifier(BaseTransformer):
 
         if isinstance(dataset, DataFrame):
 
-            self.
-
-                inference_method=inference_method,
-
-            )
+            self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
+            self._deps = self._get_dependencies()
             assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
             transform_kwargs = dict(
                 session = dataset._session,