snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class KernelPCA(BaseTransformer):
|
70
64
|
r"""Kernel Principal component analysis (KPCA) [1]_
|
71
65
|
For more details on this class, see [sklearn.decomposition.KernelPCA]
|
@@ -378,20 +372,17 @@ class KernelPCA(BaseTransformer):
|
|
378
372
|
self,
|
379
373
|
dataset: DataFrame,
|
380
374
|
inference_method: str,
|
381
|
-
) ->
|
382
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
383
|
-
return the available package that exists in the snowflake anaconda channel
|
375
|
+
) -> None:
|
376
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
384
377
|
|
385
378
|
Args:
|
386
379
|
dataset: snowpark dataframe
|
387
380
|
inference_method: the inference method such as predict, score...
|
388
|
-
|
381
|
+
|
389
382
|
Raises:
|
390
383
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
391
384
|
SnowflakeMLException: If the session is None, raise error
|
392
385
|
|
393
|
-
Returns:
|
394
|
-
A list of available package that exists in the snowflake anaconda channel
|
395
386
|
"""
|
396
387
|
if not self._is_fitted:
|
397
388
|
raise exceptions.SnowflakeMLException(
|
@@ -409,9 +400,7 @@ class KernelPCA(BaseTransformer):
|
|
409
400
|
"Session must not specified for snowpark dataset."
|
410
401
|
),
|
411
402
|
)
|
412
|
-
|
413
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
414
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
403
|
+
|
415
404
|
|
416
405
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
417
406
|
@telemetry.send_api_usage_telemetry(
|
@@ -457,7 +446,8 @@ class KernelPCA(BaseTransformer):
|
|
457
446
|
|
458
447
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
459
448
|
|
460
|
-
self.
|
449
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
450
|
+
self._deps = self._get_dependencies()
|
461
451
|
assert isinstance(
|
462
452
|
dataset._session, Session
|
463
453
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -542,10 +532,8 @@ class KernelPCA(BaseTransformer):
|
|
542
532
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
543
533
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
544
534
|
|
545
|
-
self.
|
546
|
-
|
547
|
-
inference_method=inference_method,
|
548
|
-
)
|
535
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
536
|
+
self._deps = self._get_dependencies()
|
549
537
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
550
538
|
|
551
539
|
transform_kwargs = dict(
|
@@ -612,16 +600,42 @@ class KernelPCA(BaseTransformer):
|
|
612
600
|
self._is_fitted = True
|
613
601
|
return output_result
|
614
602
|
|
603
|
+
|
604
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
605
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
606
|
+
""" Fit the model from data in X and transform X
|
607
|
+
For more details on this function, see [sklearn.decomposition.KernelPCA.fit_transform]
|
608
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.KernelPCA.html#sklearn.decomposition.KernelPCA.fit_transform)
|
609
|
+
|
615
610
|
|
616
|
-
|
617
|
-
|
618
|
-
|
611
|
+
Raises:
|
612
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
613
|
+
|
614
|
+
Args:
|
615
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
616
|
+
Snowpark or Pandas DataFrame.
|
617
|
+
output_cols_prefix: Prefix for the response columns
|
619
618
|
Returns:
|
620
619
|
Transformed dataset.
|
621
620
|
"""
|
622
|
-
self.
|
623
|
-
|
624
|
-
|
621
|
+
self._infer_input_output_cols(dataset)
|
622
|
+
super()._check_dataset_type(dataset)
|
623
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
624
|
+
estimator=self._sklearn_object,
|
625
|
+
dataset=dataset,
|
626
|
+
input_cols=self.input_cols,
|
627
|
+
label_cols=self.label_cols,
|
628
|
+
sample_weight_col=self.sample_weight_col,
|
629
|
+
autogenerated=self._autogenerated,
|
630
|
+
subproject=_SUBPROJECT,
|
631
|
+
)
|
632
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
633
|
+
drop_input_cols=self._drop_input_cols,
|
634
|
+
expected_output_cols_list=self.output_cols,
|
635
|
+
)
|
636
|
+
self._sklearn_object = fitted_estimator
|
637
|
+
self._is_fitted = True
|
638
|
+
return output_result
|
625
639
|
|
626
640
|
|
627
641
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -712,10 +726,8 @@ class KernelPCA(BaseTransformer):
|
|
712
726
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
713
727
|
|
714
728
|
if isinstance(dataset, DataFrame):
|
715
|
-
self.
|
716
|
-
|
717
|
-
inference_method=inference_method,
|
718
|
-
)
|
729
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
730
|
+
self._deps = self._get_dependencies()
|
719
731
|
assert isinstance(
|
720
732
|
dataset._session, Session
|
721
733
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -780,10 +792,8 @@ class KernelPCA(BaseTransformer):
|
|
780
792
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
781
793
|
|
782
794
|
if isinstance(dataset, DataFrame):
|
783
|
-
self.
|
784
|
-
|
785
|
-
inference_method=inference_method,
|
786
|
-
)
|
795
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
796
|
+
self._deps = self._get_dependencies()
|
787
797
|
assert isinstance(
|
788
798
|
dataset._session, Session
|
789
799
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -845,10 +855,8 @@ class KernelPCA(BaseTransformer):
|
|
845
855
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
846
856
|
|
847
857
|
if isinstance(dataset, DataFrame):
|
848
|
-
self.
|
849
|
-
|
850
|
-
inference_method=inference_method,
|
851
|
-
)
|
858
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
859
|
+
self._deps = self._get_dependencies()
|
852
860
|
assert isinstance(
|
853
861
|
dataset._session, Session
|
854
862
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -914,10 +922,8 @@ class KernelPCA(BaseTransformer):
|
|
914
922
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
915
923
|
|
916
924
|
if isinstance(dataset, DataFrame):
|
917
|
-
self.
|
918
|
-
|
919
|
-
inference_method=inference_method,
|
920
|
-
)
|
925
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
926
|
+
self._deps = self._get_dependencies()
|
921
927
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
922
928
|
transform_kwargs = dict(
|
923
929
|
session=dataset._session,
|
@@ -979,17 +985,15 @@ class KernelPCA(BaseTransformer):
|
|
979
985
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
980
986
|
|
981
987
|
if isinstance(dataset, DataFrame):
|
982
|
-
self.
|
983
|
-
|
984
|
-
inference_method="score",
|
985
|
-
)
|
988
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
989
|
+
self._deps = self._get_dependencies()
|
986
990
|
selected_cols = self._get_active_columns()
|
987
991
|
if len(selected_cols) > 0:
|
988
992
|
dataset = dataset.select(selected_cols)
|
989
993
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
990
994
|
transform_kwargs = dict(
|
991
995
|
session=dataset._session,
|
992
|
-
dependencies=
|
996
|
+
dependencies=self._deps,
|
993
997
|
score_sproc_imports=['sklearn'],
|
994
998
|
)
|
995
999
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1054,11 +1058,8 @@ class KernelPCA(BaseTransformer):
|
|
1054
1058
|
|
1055
1059
|
if isinstance(dataset, DataFrame):
|
1056
1060
|
|
1057
|
-
self.
|
1058
|
-
|
1059
|
-
inference_method=inference_method,
|
1060
|
-
|
1061
|
-
)
|
1061
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1062
|
+
self._deps = self._get_dependencies()
|
1062
1063
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1063
1064
|
transform_kwargs = dict(
|
1064
1065
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class MiniBatchDictionaryLearning(BaseTransformer):
|
70
64
|
r"""Mini-batch dictionary learning
|
71
65
|
For more details on this class, see [sklearn.decomposition.MiniBatchDictionaryLearning]
|
@@ -400,20 +394,17 @@ class MiniBatchDictionaryLearning(BaseTransformer):
|
|
400
394
|
self,
|
401
395
|
dataset: DataFrame,
|
402
396
|
inference_method: str,
|
403
|
-
) ->
|
404
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
405
|
-
return the available package that exists in the snowflake anaconda channel
|
397
|
+
) -> None:
|
398
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
406
399
|
|
407
400
|
Args:
|
408
401
|
dataset: snowpark dataframe
|
409
402
|
inference_method: the inference method such as predict, score...
|
410
|
-
|
403
|
+
|
411
404
|
Raises:
|
412
405
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
413
406
|
SnowflakeMLException: If the session is None, raise error
|
414
407
|
|
415
|
-
Returns:
|
416
|
-
A list of available package that exists in the snowflake anaconda channel
|
417
408
|
"""
|
418
409
|
if not self._is_fitted:
|
419
410
|
raise exceptions.SnowflakeMLException(
|
@@ -431,9 +422,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
|
|
431
422
|
"Session must not specified for snowpark dataset."
|
432
423
|
),
|
433
424
|
)
|
434
|
-
|
435
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
436
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
425
|
+
|
437
426
|
|
438
427
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
439
428
|
@telemetry.send_api_usage_telemetry(
|
@@ -479,7 +468,8 @@ class MiniBatchDictionaryLearning(BaseTransformer):
|
|
479
468
|
|
480
469
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
481
470
|
|
482
|
-
self.
|
471
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
472
|
+
self._deps = self._get_dependencies()
|
483
473
|
assert isinstance(
|
484
474
|
dataset._session, Session
|
485
475
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -564,10 +554,8 @@ class MiniBatchDictionaryLearning(BaseTransformer):
|
|
564
554
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
565
555
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
566
556
|
|
567
|
-
self.
|
568
|
-
|
569
|
-
inference_method=inference_method,
|
570
|
-
)
|
557
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
558
|
+
self._deps = self._get_dependencies()
|
571
559
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
572
560
|
|
573
561
|
transform_kwargs = dict(
|
@@ -634,16 +622,42 @@ class MiniBatchDictionaryLearning(BaseTransformer):
|
|
634
622
|
self._is_fitted = True
|
635
623
|
return output_result
|
636
624
|
|
625
|
+
|
626
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
627
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
628
|
+
""" Fit to data, then transform it
|
629
|
+
For more details on this function, see [sklearn.decomposition.MiniBatchDictionaryLearning.fit_transform]
|
630
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.MiniBatchDictionaryLearning.html#sklearn.decomposition.MiniBatchDictionaryLearning.fit_transform)
|
631
|
+
|
637
632
|
|
638
|
-
|
639
|
-
|
640
|
-
|
633
|
+
Raises:
|
634
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
635
|
+
|
636
|
+
Args:
|
637
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
638
|
+
Snowpark or Pandas DataFrame.
|
639
|
+
output_cols_prefix: Prefix for the response columns
|
641
640
|
Returns:
|
642
641
|
Transformed dataset.
|
643
642
|
"""
|
644
|
-
self.
|
645
|
-
|
646
|
-
|
643
|
+
self._infer_input_output_cols(dataset)
|
644
|
+
super()._check_dataset_type(dataset)
|
645
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
646
|
+
estimator=self._sklearn_object,
|
647
|
+
dataset=dataset,
|
648
|
+
input_cols=self.input_cols,
|
649
|
+
label_cols=self.label_cols,
|
650
|
+
sample_weight_col=self.sample_weight_col,
|
651
|
+
autogenerated=self._autogenerated,
|
652
|
+
subproject=_SUBPROJECT,
|
653
|
+
)
|
654
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
655
|
+
drop_input_cols=self._drop_input_cols,
|
656
|
+
expected_output_cols_list=self.output_cols,
|
657
|
+
)
|
658
|
+
self._sklearn_object = fitted_estimator
|
659
|
+
self._is_fitted = True
|
660
|
+
return output_result
|
647
661
|
|
648
662
|
|
649
663
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -734,10 +748,8 @@ class MiniBatchDictionaryLearning(BaseTransformer):
|
|
734
748
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
735
749
|
|
736
750
|
if isinstance(dataset, DataFrame):
|
737
|
-
self.
|
738
|
-
|
739
|
-
inference_method=inference_method,
|
740
|
-
)
|
751
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
752
|
+
self._deps = self._get_dependencies()
|
741
753
|
assert isinstance(
|
742
754
|
dataset._session, Session
|
743
755
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -802,10 +814,8 @@ class MiniBatchDictionaryLearning(BaseTransformer):
|
|
802
814
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
803
815
|
|
804
816
|
if isinstance(dataset, DataFrame):
|
805
|
-
self.
|
806
|
-
|
807
|
-
inference_method=inference_method,
|
808
|
-
)
|
817
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
818
|
+
self._deps = self._get_dependencies()
|
809
819
|
assert isinstance(
|
810
820
|
dataset._session, Session
|
811
821
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -867,10 +877,8 @@ class MiniBatchDictionaryLearning(BaseTransformer):
|
|
867
877
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
868
878
|
|
869
879
|
if isinstance(dataset, DataFrame):
|
870
|
-
self.
|
871
|
-
|
872
|
-
inference_method=inference_method,
|
873
|
-
)
|
880
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
881
|
+
self._deps = self._get_dependencies()
|
874
882
|
assert isinstance(
|
875
883
|
dataset._session, Session
|
876
884
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -936,10 +944,8 @@ class MiniBatchDictionaryLearning(BaseTransformer):
|
|
936
944
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
937
945
|
|
938
946
|
if isinstance(dataset, DataFrame):
|
939
|
-
self.
|
940
|
-
|
941
|
-
inference_method=inference_method,
|
942
|
-
)
|
947
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
948
|
+
self._deps = self._get_dependencies()
|
943
949
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
944
950
|
transform_kwargs = dict(
|
945
951
|
session=dataset._session,
|
@@ -1001,17 +1007,15 @@ class MiniBatchDictionaryLearning(BaseTransformer):
|
|
1001
1007
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
1002
1008
|
|
1003
1009
|
if isinstance(dataset, DataFrame):
|
1004
|
-
self.
|
1005
|
-
|
1006
|
-
inference_method="score",
|
1007
|
-
)
|
1010
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
1011
|
+
self._deps = self._get_dependencies()
|
1008
1012
|
selected_cols = self._get_active_columns()
|
1009
1013
|
if len(selected_cols) > 0:
|
1010
1014
|
dataset = dataset.select(selected_cols)
|
1011
1015
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
1012
1016
|
transform_kwargs = dict(
|
1013
1017
|
session=dataset._session,
|
1014
|
-
dependencies=
|
1018
|
+
dependencies=self._deps,
|
1015
1019
|
score_sproc_imports=['sklearn'],
|
1016
1020
|
)
|
1017
1021
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1076,11 +1080,8 @@ class MiniBatchDictionaryLearning(BaseTransformer):
|
|
1076
1080
|
|
1077
1081
|
if isinstance(dataset, DataFrame):
|
1078
1082
|
|
1079
|
-
self.
|
1080
|
-
|
1081
|
-
inference_method=inference_method,
|
1082
|
-
|
1083
|
-
)
|
1083
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1084
|
+
self._deps = self._get_dependencies()
|
1084
1085
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1085
1086
|
transform_kwargs = dict(
|
1086
1087
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class MiniBatchSparsePCA(BaseTransformer):
|
70
64
|
r"""Mini-batch Sparse Principal Components Analysis
|
71
65
|
For more details on this class, see [sklearn.decomposition.MiniBatchSparsePCA]
|
@@ -345,20 +339,17 @@ class MiniBatchSparsePCA(BaseTransformer):
|
|
345
339
|
self,
|
346
340
|
dataset: DataFrame,
|
347
341
|
inference_method: str,
|
348
|
-
) ->
|
349
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
350
|
-
return the available package that exists in the snowflake anaconda channel
|
342
|
+
) -> None:
|
343
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
351
344
|
|
352
345
|
Args:
|
353
346
|
dataset: snowpark dataframe
|
354
347
|
inference_method: the inference method such as predict, score...
|
355
|
-
|
348
|
+
|
356
349
|
Raises:
|
357
350
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
358
351
|
SnowflakeMLException: If the session is None, raise error
|
359
352
|
|
360
|
-
Returns:
|
361
|
-
A list of available package that exists in the snowflake anaconda channel
|
362
353
|
"""
|
363
354
|
if not self._is_fitted:
|
364
355
|
raise exceptions.SnowflakeMLException(
|
@@ -376,9 +367,7 @@ class MiniBatchSparsePCA(BaseTransformer):
|
|
376
367
|
"Session must not specified for snowpark dataset."
|
377
368
|
),
|
378
369
|
)
|
379
|
-
|
380
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
381
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
370
|
+
|
382
371
|
|
383
372
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
384
373
|
@telemetry.send_api_usage_telemetry(
|
@@ -424,7 +413,8 @@ class MiniBatchSparsePCA(BaseTransformer):
|
|
424
413
|
|
425
414
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
426
415
|
|
427
|
-
self.
|
416
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
417
|
+
self._deps = self._get_dependencies()
|
428
418
|
assert isinstance(
|
429
419
|
dataset._session, Session
|
430
420
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -509,10 +499,8 @@ class MiniBatchSparsePCA(BaseTransformer):
|
|
509
499
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
510
500
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
511
501
|
|
512
|
-
self.
|
513
|
-
|
514
|
-
inference_method=inference_method,
|
515
|
-
)
|
502
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
503
|
+
self._deps = self._get_dependencies()
|
516
504
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
517
505
|
|
518
506
|
transform_kwargs = dict(
|
@@ -579,16 +567,42 @@ class MiniBatchSparsePCA(BaseTransformer):
|
|
579
567
|
self._is_fitted = True
|
580
568
|
return output_result
|
581
569
|
|
570
|
+
|
571
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
572
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
573
|
+
""" Fit to data, then transform it
|
574
|
+
For more details on this function, see [sklearn.decomposition.MiniBatchSparsePCA.fit_transform]
|
575
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.MiniBatchSparsePCA.html#sklearn.decomposition.MiniBatchSparsePCA.fit_transform)
|
576
|
+
|
582
577
|
|
583
|
-
|
584
|
-
|
585
|
-
|
578
|
+
Raises:
|
579
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
580
|
+
|
581
|
+
Args:
|
582
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
583
|
+
Snowpark or Pandas DataFrame.
|
584
|
+
output_cols_prefix: Prefix for the response columns
|
586
585
|
Returns:
|
587
586
|
Transformed dataset.
|
588
587
|
"""
|
589
|
-
self.
|
590
|
-
|
591
|
-
|
588
|
+
self._infer_input_output_cols(dataset)
|
589
|
+
super()._check_dataset_type(dataset)
|
590
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
591
|
+
estimator=self._sklearn_object,
|
592
|
+
dataset=dataset,
|
593
|
+
input_cols=self.input_cols,
|
594
|
+
label_cols=self.label_cols,
|
595
|
+
sample_weight_col=self.sample_weight_col,
|
596
|
+
autogenerated=self._autogenerated,
|
597
|
+
subproject=_SUBPROJECT,
|
598
|
+
)
|
599
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
600
|
+
drop_input_cols=self._drop_input_cols,
|
601
|
+
expected_output_cols_list=self.output_cols,
|
602
|
+
)
|
603
|
+
self._sklearn_object = fitted_estimator
|
604
|
+
self._is_fitted = True
|
605
|
+
return output_result
|
592
606
|
|
593
607
|
|
594
608
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -679,10 +693,8 @@ class MiniBatchSparsePCA(BaseTransformer):
|
|
679
693
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
680
694
|
|
681
695
|
if isinstance(dataset, DataFrame):
|
682
|
-
self.
|
683
|
-
|
684
|
-
inference_method=inference_method,
|
685
|
-
)
|
696
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
697
|
+
self._deps = self._get_dependencies()
|
686
698
|
assert isinstance(
|
687
699
|
dataset._session, Session
|
688
700
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -747,10 +759,8 @@ class MiniBatchSparsePCA(BaseTransformer):
|
|
747
759
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
748
760
|
|
749
761
|
if isinstance(dataset, DataFrame):
|
750
|
-
self.
|
751
|
-
|
752
|
-
inference_method=inference_method,
|
753
|
-
)
|
762
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
763
|
+
self._deps = self._get_dependencies()
|
754
764
|
assert isinstance(
|
755
765
|
dataset._session, Session
|
756
766
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -812,10 +822,8 @@ class MiniBatchSparsePCA(BaseTransformer):
|
|
812
822
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
813
823
|
|
814
824
|
if isinstance(dataset, DataFrame):
|
815
|
-
self.
|
816
|
-
|
817
|
-
inference_method=inference_method,
|
818
|
-
)
|
825
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
826
|
+
self._deps = self._get_dependencies()
|
819
827
|
assert isinstance(
|
820
828
|
dataset._session, Session
|
821
829
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -881,10 +889,8 @@ class MiniBatchSparsePCA(BaseTransformer):
|
|
881
889
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
882
890
|
|
883
891
|
if isinstance(dataset, DataFrame):
|
884
|
-
self.
|
885
|
-
|
886
|
-
inference_method=inference_method,
|
887
|
-
)
|
892
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
893
|
+
self._deps = self._get_dependencies()
|
888
894
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
889
895
|
transform_kwargs = dict(
|
890
896
|
session=dataset._session,
|
@@ -946,17 +952,15 @@ class MiniBatchSparsePCA(BaseTransformer):
|
|
946
952
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
947
953
|
|
948
954
|
if isinstance(dataset, DataFrame):
|
949
|
-
self.
|
950
|
-
|
951
|
-
inference_method="score",
|
952
|
-
)
|
955
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
956
|
+
self._deps = self._get_dependencies()
|
953
957
|
selected_cols = self._get_active_columns()
|
954
958
|
if len(selected_cols) > 0:
|
955
959
|
dataset = dataset.select(selected_cols)
|
956
960
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
957
961
|
transform_kwargs = dict(
|
958
962
|
session=dataset._session,
|
959
|
-
dependencies=
|
963
|
+
dependencies=self._deps,
|
960
964
|
score_sproc_imports=['sklearn'],
|
961
965
|
)
|
962
966
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1021,11 +1025,8 @@ class MiniBatchSparsePCA(BaseTransformer):
|
|
1021
1025
|
|
1022
1026
|
if isinstance(dataset, DataFrame):
|
1023
1027
|
|
1024
|
-
self.
|
1025
|
-
|
1026
|
-
inference_method=inference_method,
|
1027
|
-
|
1028
|
-
)
|
1028
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1029
|
+
self._deps = self._get_dependencies()
|
1029
1030
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1030
1031
|
transform_kwargs = dict(
|
1031
1032
|
session = dataset._session,
|