snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class PCA(BaseTransformer):
|
70
64
|
r"""Principal component analysis (PCA)
|
71
65
|
For more details on this class, see [sklearn.decomposition.PCA]
|
@@ -347,20 +341,17 @@ class PCA(BaseTransformer):
|
|
347
341
|
self,
|
348
342
|
dataset: DataFrame,
|
349
343
|
inference_method: str,
|
350
|
-
) ->
|
351
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
352
|
-
return the available package that exists in the snowflake anaconda channel
|
344
|
+
) -> None:
|
345
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
353
346
|
|
354
347
|
Args:
|
355
348
|
dataset: snowpark dataframe
|
356
349
|
inference_method: the inference method such as predict, score...
|
357
|
-
|
350
|
+
|
358
351
|
Raises:
|
359
352
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
360
353
|
SnowflakeMLException: If the session is None, raise error
|
361
354
|
|
362
|
-
Returns:
|
363
|
-
A list of available package that exists in the snowflake anaconda channel
|
364
355
|
"""
|
365
356
|
if not self._is_fitted:
|
366
357
|
raise exceptions.SnowflakeMLException(
|
@@ -378,9 +369,7 @@ class PCA(BaseTransformer):
|
|
378
369
|
"Session must not specified for snowpark dataset."
|
379
370
|
),
|
380
371
|
)
|
381
|
-
|
382
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
383
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
372
|
+
|
384
373
|
|
385
374
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
386
375
|
@telemetry.send_api_usage_telemetry(
|
@@ -426,7 +415,8 @@ class PCA(BaseTransformer):
|
|
426
415
|
|
427
416
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
428
417
|
|
429
|
-
self.
|
418
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
419
|
+
self._deps = self._get_dependencies()
|
430
420
|
assert isinstance(
|
431
421
|
dataset._session, Session
|
432
422
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -511,10 +501,8 @@ class PCA(BaseTransformer):
|
|
511
501
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
512
502
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
513
503
|
|
514
|
-
self.
|
515
|
-
|
516
|
-
inference_method=inference_method,
|
517
|
-
)
|
504
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
505
|
+
self._deps = self._get_dependencies()
|
518
506
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
519
507
|
|
520
508
|
transform_kwargs = dict(
|
@@ -581,16 +569,42 @@ class PCA(BaseTransformer):
|
|
581
569
|
self._is_fitted = True
|
582
570
|
return output_result
|
583
571
|
|
572
|
+
|
573
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
574
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
575
|
+
""" Fit the model with X and apply the dimensionality reduction on X
|
576
|
+
For more details on this function, see [sklearn.decomposition.PCA.fit_transform]
|
577
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html#sklearn.decomposition.PCA.fit_transform)
|
578
|
+
|
584
579
|
|
585
|
-
|
586
|
-
|
587
|
-
|
580
|
+
Raises:
|
581
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
582
|
+
|
583
|
+
Args:
|
584
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
585
|
+
Snowpark or Pandas DataFrame.
|
586
|
+
output_cols_prefix: Prefix for the response columns
|
588
587
|
Returns:
|
589
588
|
Transformed dataset.
|
590
589
|
"""
|
591
|
-
self.
|
592
|
-
|
593
|
-
|
590
|
+
self._infer_input_output_cols(dataset)
|
591
|
+
super()._check_dataset_type(dataset)
|
592
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
593
|
+
estimator=self._sklearn_object,
|
594
|
+
dataset=dataset,
|
595
|
+
input_cols=self.input_cols,
|
596
|
+
label_cols=self.label_cols,
|
597
|
+
sample_weight_col=self.sample_weight_col,
|
598
|
+
autogenerated=self._autogenerated,
|
599
|
+
subproject=_SUBPROJECT,
|
600
|
+
)
|
601
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
602
|
+
drop_input_cols=self._drop_input_cols,
|
603
|
+
expected_output_cols_list=self.output_cols,
|
604
|
+
)
|
605
|
+
self._sklearn_object = fitted_estimator
|
606
|
+
self._is_fitted = True
|
607
|
+
return output_result
|
594
608
|
|
595
609
|
|
596
610
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -681,10 +695,8 @@ class PCA(BaseTransformer):
|
|
681
695
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
682
696
|
|
683
697
|
if isinstance(dataset, DataFrame):
|
684
|
-
self.
|
685
|
-
|
686
|
-
inference_method=inference_method,
|
687
|
-
)
|
698
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
699
|
+
self._deps = self._get_dependencies()
|
688
700
|
assert isinstance(
|
689
701
|
dataset._session, Session
|
690
702
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -749,10 +761,8 @@ class PCA(BaseTransformer):
|
|
749
761
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
750
762
|
|
751
763
|
if isinstance(dataset, DataFrame):
|
752
|
-
self.
|
753
|
-
|
754
|
-
inference_method=inference_method,
|
755
|
-
)
|
764
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
765
|
+
self._deps = self._get_dependencies()
|
756
766
|
assert isinstance(
|
757
767
|
dataset._session, Session
|
758
768
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -814,10 +824,8 @@ class PCA(BaseTransformer):
|
|
814
824
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
815
825
|
|
816
826
|
if isinstance(dataset, DataFrame):
|
817
|
-
self.
|
818
|
-
|
819
|
-
inference_method=inference_method,
|
820
|
-
)
|
827
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
828
|
+
self._deps = self._get_dependencies()
|
821
829
|
assert isinstance(
|
822
830
|
dataset._session, Session
|
823
831
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -885,10 +893,8 @@ class PCA(BaseTransformer):
|
|
885
893
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
886
894
|
|
887
895
|
if isinstance(dataset, DataFrame):
|
888
|
-
self.
|
889
|
-
|
890
|
-
inference_method=inference_method,
|
891
|
-
)
|
896
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
897
|
+
self._deps = self._get_dependencies()
|
892
898
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
893
899
|
transform_kwargs = dict(
|
894
900
|
session=dataset._session,
|
@@ -952,17 +958,15 @@ class PCA(BaseTransformer):
|
|
952
958
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
953
959
|
|
954
960
|
if isinstance(dataset, DataFrame):
|
955
|
-
self.
|
956
|
-
|
957
|
-
inference_method="score",
|
958
|
-
)
|
961
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
962
|
+
self._deps = self._get_dependencies()
|
959
963
|
selected_cols = self._get_active_columns()
|
960
964
|
if len(selected_cols) > 0:
|
961
965
|
dataset = dataset.select(selected_cols)
|
962
966
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
963
967
|
transform_kwargs = dict(
|
964
968
|
session=dataset._session,
|
965
|
-
dependencies=
|
969
|
+
dependencies=self._deps,
|
966
970
|
score_sproc_imports=['sklearn'],
|
967
971
|
)
|
968
972
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1027,11 +1031,8 @@ class PCA(BaseTransformer):
|
|
1027
1031
|
|
1028
1032
|
if isinstance(dataset, DataFrame):
|
1029
1033
|
|
1030
|
-
self.
|
1031
|
-
|
1032
|
-
inference_method=inference_method,
|
1033
|
-
|
1034
|
-
)
|
1034
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1035
|
+
self._deps = self._get_dependencies()
|
1035
1036
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1036
1037
|
transform_kwargs = dict(
|
1037
1038
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class SparsePCA(BaseTransformer):
|
70
64
|
r"""Sparse Principal Components Analysis (SparsePCA)
|
71
65
|
For more details on this class, see [sklearn.decomposition.SparsePCA]
|
@@ -320,20 +314,17 @@ class SparsePCA(BaseTransformer):
|
|
320
314
|
self,
|
321
315
|
dataset: DataFrame,
|
322
316
|
inference_method: str,
|
323
|
-
) ->
|
324
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
325
|
-
return the available package that exists in the snowflake anaconda channel
|
317
|
+
) -> None:
|
318
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
326
319
|
|
327
320
|
Args:
|
328
321
|
dataset: snowpark dataframe
|
329
322
|
inference_method: the inference method such as predict, score...
|
330
|
-
|
323
|
+
|
331
324
|
Raises:
|
332
325
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
333
326
|
SnowflakeMLException: If the session is None, raise error
|
334
327
|
|
335
|
-
Returns:
|
336
|
-
A list of available package that exists in the snowflake anaconda channel
|
337
328
|
"""
|
338
329
|
if not self._is_fitted:
|
339
330
|
raise exceptions.SnowflakeMLException(
|
@@ -351,9 +342,7 @@ class SparsePCA(BaseTransformer):
|
|
351
342
|
"Session must not specified for snowpark dataset."
|
352
343
|
),
|
353
344
|
)
|
354
|
-
|
355
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
356
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
345
|
+
|
357
346
|
|
358
347
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
359
348
|
@telemetry.send_api_usage_telemetry(
|
@@ -399,7 +388,8 @@ class SparsePCA(BaseTransformer):
|
|
399
388
|
|
400
389
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
401
390
|
|
402
|
-
self.
|
391
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
392
|
+
self._deps = self._get_dependencies()
|
403
393
|
assert isinstance(
|
404
394
|
dataset._session, Session
|
405
395
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -484,10 +474,8 @@ class SparsePCA(BaseTransformer):
|
|
484
474
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
485
475
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
486
476
|
|
487
|
-
self.
|
488
|
-
|
489
|
-
inference_method=inference_method,
|
490
|
-
)
|
477
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
478
|
+
self._deps = self._get_dependencies()
|
491
479
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
492
480
|
|
493
481
|
transform_kwargs = dict(
|
@@ -554,16 +542,42 @@ class SparsePCA(BaseTransformer):
|
|
554
542
|
self._is_fitted = True
|
555
543
|
return output_result
|
556
544
|
|
545
|
+
|
546
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
547
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
548
|
+
""" Fit to data, then transform it
|
549
|
+
For more details on this function, see [sklearn.decomposition.SparsePCA.fit_transform]
|
550
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.SparsePCA.html#sklearn.decomposition.SparsePCA.fit_transform)
|
551
|
+
|
557
552
|
|
558
|
-
|
559
|
-
|
560
|
-
|
553
|
+
Raises:
|
554
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
555
|
+
|
556
|
+
Args:
|
557
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
558
|
+
Snowpark or Pandas DataFrame.
|
559
|
+
output_cols_prefix: Prefix for the response columns
|
561
560
|
Returns:
|
562
561
|
Transformed dataset.
|
563
562
|
"""
|
564
|
-
self.
|
565
|
-
|
566
|
-
|
563
|
+
self._infer_input_output_cols(dataset)
|
564
|
+
super()._check_dataset_type(dataset)
|
565
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
566
|
+
estimator=self._sklearn_object,
|
567
|
+
dataset=dataset,
|
568
|
+
input_cols=self.input_cols,
|
569
|
+
label_cols=self.label_cols,
|
570
|
+
sample_weight_col=self.sample_weight_col,
|
571
|
+
autogenerated=self._autogenerated,
|
572
|
+
subproject=_SUBPROJECT,
|
573
|
+
)
|
574
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
575
|
+
drop_input_cols=self._drop_input_cols,
|
576
|
+
expected_output_cols_list=self.output_cols,
|
577
|
+
)
|
578
|
+
self._sklearn_object = fitted_estimator
|
579
|
+
self._is_fitted = True
|
580
|
+
return output_result
|
567
581
|
|
568
582
|
|
569
583
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -654,10 +668,8 @@ class SparsePCA(BaseTransformer):
|
|
654
668
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
655
669
|
|
656
670
|
if isinstance(dataset, DataFrame):
|
657
|
-
self.
|
658
|
-
|
659
|
-
inference_method=inference_method,
|
660
|
-
)
|
671
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
672
|
+
self._deps = self._get_dependencies()
|
661
673
|
assert isinstance(
|
662
674
|
dataset._session, Session
|
663
675
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -722,10 +734,8 @@ class SparsePCA(BaseTransformer):
|
|
722
734
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
723
735
|
|
724
736
|
if isinstance(dataset, DataFrame):
|
725
|
-
self.
|
726
|
-
|
727
|
-
inference_method=inference_method,
|
728
|
-
)
|
737
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
738
|
+
self._deps = self._get_dependencies()
|
729
739
|
assert isinstance(
|
730
740
|
dataset._session, Session
|
731
741
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -787,10 +797,8 @@ class SparsePCA(BaseTransformer):
|
|
787
797
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
788
798
|
|
789
799
|
if isinstance(dataset, DataFrame):
|
790
|
-
self.
|
791
|
-
|
792
|
-
inference_method=inference_method,
|
793
|
-
)
|
800
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
801
|
+
self._deps = self._get_dependencies()
|
794
802
|
assert isinstance(
|
795
803
|
dataset._session, Session
|
796
804
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -856,10 +864,8 @@ class SparsePCA(BaseTransformer):
|
|
856
864
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
857
865
|
|
858
866
|
if isinstance(dataset, DataFrame):
|
859
|
-
self.
|
860
|
-
|
861
|
-
inference_method=inference_method,
|
862
|
-
)
|
867
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
868
|
+
self._deps = self._get_dependencies()
|
863
869
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
864
870
|
transform_kwargs = dict(
|
865
871
|
session=dataset._session,
|
@@ -921,17 +927,15 @@ class SparsePCA(BaseTransformer):
|
|
921
927
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
922
928
|
|
923
929
|
if isinstance(dataset, DataFrame):
|
924
|
-
self.
|
925
|
-
|
926
|
-
inference_method="score",
|
927
|
-
)
|
930
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
931
|
+
self._deps = self._get_dependencies()
|
928
932
|
selected_cols = self._get_active_columns()
|
929
933
|
if len(selected_cols) > 0:
|
930
934
|
dataset = dataset.select(selected_cols)
|
931
935
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
932
936
|
transform_kwargs = dict(
|
933
937
|
session=dataset._session,
|
934
|
-
dependencies=
|
938
|
+
dependencies=self._deps,
|
935
939
|
score_sproc_imports=['sklearn'],
|
936
940
|
)
|
937
941
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -996,11 +1000,8 @@ class SparsePCA(BaseTransformer):
|
|
996
1000
|
|
997
1001
|
if isinstance(dataset, DataFrame):
|
998
1002
|
|
999
|
-
self.
|
1000
|
-
|
1001
|
-
inference_method=inference_method,
|
1002
|
-
|
1003
|
-
)
|
1003
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1004
|
+
self._deps = self._get_dependencies()
|
1004
1005
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1005
1006
|
transform_kwargs = dict(
|
1006
1007
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class TruncatedSVD(BaseTransformer):
|
70
64
|
r"""Dimensionality reduction using truncated SVD (aka LSA)
|
71
65
|
For more details on this class, see [sklearn.decomposition.TruncatedSVD]
|
@@ -301,20 +295,17 @@ class TruncatedSVD(BaseTransformer):
|
|
301
295
|
self,
|
302
296
|
dataset: DataFrame,
|
303
297
|
inference_method: str,
|
304
|
-
) ->
|
305
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
306
|
-
return the available package that exists in the snowflake anaconda channel
|
298
|
+
) -> None:
|
299
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
307
300
|
|
308
301
|
Args:
|
309
302
|
dataset: snowpark dataframe
|
310
303
|
inference_method: the inference method such as predict, score...
|
311
|
-
|
304
|
+
|
312
305
|
Raises:
|
313
306
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
314
307
|
SnowflakeMLException: If the session is None, raise error
|
315
308
|
|
316
|
-
Returns:
|
317
|
-
A list of available package that exists in the snowflake anaconda channel
|
318
309
|
"""
|
319
310
|
if not self._is_fitted:
|
320
311
|
raise exceptions.SnowflakeMLException(
|
@@ -332,9 +323,7 @@ class TruncatedSVD(BaseTransformer):
|
|
332
323
|
"Session must not specified for snowpark dataset."
|
333
324
|
),
|
334
325
|
)
|
335
|
-
|
336
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
337
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
326
|
+
|
338
327
|
|
339
328
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
340
329
|
@telemetry.send_api_usage_telemetry(
|
@@ -380,7 +369,8 @@ class TruncatedSVD(BaseTransformer):
|
|
380
369
|
|
381
370
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
382
371
|
|
383
|
-
self.
|
372
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
373
|
+
self._deps = self._get_dependencies()
|
384
374
|
assert isinstance(
|
385
375
|
dataset._session, Session
|
386
376
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -465,10 +455,8 @@ class TruncatedSVD(BaseTransformer):
|
|
465
455
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
466
456
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
467
457
|
|
468
|
-
self.
|
469
|
-
|
470
|
-
inference_method=inference_method,
|
471
|
-
)
|
458
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
459
|
+
self._deps = self._get_dependencies()
|
472
460
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
473
461
|
|
474
462
|
transform_kwargs = dict(
|
@@ -535,16 +523,42 @@ class TruncatedSVD(BaseTransformer):
|
|
535
523
|
self._is_fitted = True
|
536
524
|
return output_result
|
537
525
|
|
526
|
+
|
527
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
528
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
529
|
+
""" Fit model to X and perform dimensionality reduction on X
|
530
|
+
For more details on this function, see [sklearn.decomposition.TruncatedSVD.fit_transform]
|
531
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.TruncatedSVD.html#sklearn.decomposition.TruncatedSVD.fit_transform)
|
532
|
+
|
538
533
|
|
539
|
-
|
540
|
-
|
541
|
-
|
534
|
+
Raises:
|
535
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
536
|
+
|
537
|
+
Args:
|
538
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
539
|
+
Snowpark or Pandas DataFrame.
|
540
|
+
output_cols_prefix: Prefix for the response columns
|
542
541
|
Returns:
|
543
542
|
Transformed dataset.
|
544
543
|
"""
|
545
|
-
self.
|
546
|
-
|
547
|
-
|
544
|
+
self._infer_input_output_cols(dataset)
|
545
|
+
super()._check_dataset_type(dataset)
|
546
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
547
|
+
estimator=self._sklearn_object,
|
548
|
+
dataset=dataset,
|
549
|
+
input_cols=self.input_cols,
|
550
|
+
label_cols=self.label_cols,
|
551
|
+
sample_weight_col=self.sample_weight_col,
|
552
|
+
autogenerated=self._autogenerated,
|
553
|
+
subproject=_SUBPROJECT,
|
554
|
+
)
|
555
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
556
|
+
drop_input_cols=self._drop_input_cols,
|
557
|
+
expected_output_cols_list=self.output_cols,
|
558
|
+
)
|
559
|
+
self._sklearn_object = fitted_estimator
|
560
|
+
self._is_fitted = True
|
561
|
+
return output_result
|
548
562
|
|
549
563
|
|
550
564
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -635,10 +649,8 @@ class TruncatedSVD(BaseTransformer):
|
|
635
649
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
636
650
|
|
637
651
|
if isinstance(dataset, DataFrame):
|
638
|
-
self.
|
639
|
-
|
640
|
-
inference_method=inference_method,
|
641
|
-
)
|
652
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
653
|
+
self._deps = self._get_dependencies()
|
642
654
|
assert isinstance(
|
643
655
|
dataset._session, Session
|
644
656
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -703,10 +715,8 @@ class TruncatedSVD(BaseTransformer):
|
|
703
715
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
704
716
|
|
705
717
|
if isinstance(dataset, DataFrame):
|
706
|
-
self.
|
707
|
-
|
708
|
-
inference_method=inference_method,
|
709
|
-
)
|
718
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
719
|
+
self._deps = self._get_dependencies()
|
710
720
|
assert isinstance(
|
711
721
|
dataset._session, Session
|
712
722
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -768,10 +778,8 @@ class TruncatedSVD(BaseTransformer):
|
|
768
778
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
769
779
|
|
770
780
|
if isinstance(dataset, DataFrame):
|
771
|
-
self.
|
772
|
-
|
773
|
-
inference_method=inference_method,
|
774
|
-
)
|
781
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
782
|
+
self._deps = self._get_dependencies()
|
775
783
|
assert isinstance(
|
776
784
|
dataset._session, Session
|
777
785
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -837,10 +845,8 @@ class TruncatedSVD(BaseTransformer):
|
|
837
845
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
838
846
|
|
839
847
|
if isinstance(dataset, DataFrame):
|
840
|
-
self.
|
841
|
-
|
842
|
-
inference_method=inference_method,
|
843
|
-
)
|
848
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
849
|
+
self._deps = self._get_dependencies()
|
844
850
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
845
851
|
transform_kwargs = dict(
|
846
852
|
session=dataset._session,
|
@@ -902,17 +908,15 @@ class TruncatedSVD(BaseTransformer):
|
|
902
908
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
903
909
|
|
904
910
|
if isinstance(dataset, DataFrame):
|
905
|
-
self.
|
906
|
-
|
907
|
-
inference_method="score",
|
908
|
-
)
|
911
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
912
|
+
self._deps = self._get_dependencies()
|
909
913
|
selected_cols = self._get_active_columns()
|
910
914
|
if len(selected_cols) > 0:
|
911
915
|
dataset = dataset.select(selected_cols)
|
912
916
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
913
917
|
transform_kwargs = dict(
|
914
918
|
session=dataset._session,
|
915
|
-
dependencies=
|
919
|
+
dependencies=self._deps,
|
916
920
|
score_sproc_imports=['sklearn'],
|
917
921
|
)
|
918
922
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -977,11 +981,8 @@ class TruncatedSVD(BaseTransformer):
|
|
977
981
|
|
978
982
|
if isinstance(dataset, DataFrame):
|
979
983
|
|
980
|
-
self.
|
981
|
-
|
982
|
-
inference_method=inference_method,
|
983
|
-
|
984
|
-
)
|
984
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
985
|
+
self._deps = self._get_dependencies()
|
985
986
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
986
987
|
transform_kwargs = dict(
|
987
988
|
session = dataset._session,
|