snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +72 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +95 -0
- snowflake/ml/_internal/telemetry.py +1 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/sql_identifier.py +14 -1
- snowflake/ml/dataset/__init__.py +11 -0
- snowflake/ml/dataset/dataset.py +455 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +199 -0
- snowflake/ml/feature_store/__init__.py +6 -0
- snowflake/ml/feature_store/access_manager.py +279 -0
- snowflake/ml/feature_store/feature_store.py +544 -358
- snowflake/ml/feature_store/feature_view.py +55 -16
- snowflake/ml/fileset/embedded_stage_fs.py +149 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +160 -0
- snowflake/ml/fileset/stage_fs.py +25 -10
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +65 -31
- snowflake/ml/model/_client/model/model_version_impl.py +159 -2
- snowflake/ml/model/_client/ops/metadata_ops.py +27 -4
- snowflake/ml/model/_client/ops/model_ops.py +268 -83
- snowflake/ml/model/_client/sql/_base.py +34 -0
- snowflake/ml/model/_client/sql/model.py +42 -47
- snowflake/ml/model/_client/sql/model_version.py +164 -39
- snowflake/ml/model/_client/sql/stage.py +6 -32
- snowflake/ml/model/_client/sql/tag.py +32 -56
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +50 -21
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +340 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +64 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +538 -36
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/_manager/model_manager.py +36 -7
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/METADATA +112 -7
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/RECORD +216 -206
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.1.dist-info}/top_level.txt +0 -0
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class FactorAnalysis(BaseTransformer):
|
70
64
|
r"""Factor Analysis (FA)
|
71
65
|
For more details on this class, see [sklearn.decomposition.FactorAnalysis]
|
@@ -312,20 +306,17 @@ class FactorAnalysis(BaseTransformer):
|
|
312
306
|
self,
|
313
307
|
dataset: DataFrame,
|
314
308
|
inference_method: str,
|
315
|
-
) ->
|
316
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
317
|
-
return the available package that exists in the snowflake anaconda channel
|
309
|
+
) -> None:
|
310
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
318
311
|
|
319
312
|
Args:
|
320
313
|
dataset: snowpark dataframe
|
321
314
|
inference_method: the inference method such as predict, score...
|
322
|
-
|
315
|
+
|
323
316
|
Raises:
|
324
317
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
325
318
|
SnowflakeMLException: If the session is None, raise error
|
326
319
|
|
327
|
-
Returns:
|
328
|
-
A list of available package that exists in the snowflake anaconda channel
|
329
320
|
"""
|
330
321
|
if not self._is_fitted:
|
331
322
|
raise exceptions.SnowflakeMLException(
|
@@ -343,9 +334,7 @@ class FactorAnalysis(BaseTransformer):
|
|
343
334
|
"Session must not specified for snowpark dataset."
|
344
335
|
),
|
345
336
|
)
|
346
|
-
|
347
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
348
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
337
|
+
|
349
338
|
|
350
339
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
351
340
|
@telemetry.send_api_usage_telemetry(
|
@@ -391,7 +380,8 @@ class FactorAnalysis(BaseTransformer):
|
|
391
380
|
|
392
381
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
393
382
|
|
394
|
-
self.
|
383
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
384
|
+
self._deps = self._get_dependencies()
|
395
385
|
assert isinstance(
|
396
386
|
dataset._session, Session
|
397
387
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -476,10 +466,8 @@ class FactorAnalysis(BaseTransformer):
|
|
476
466
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
477
467
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
478
468
|
|
479
|
-
self.
|
480
|
-
|
481
|
-
inference_method=inference_method,
|
482
|
-
)
|
469
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
470
|
+
self._deps = self._get_dependencies()
|
483
471
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
484
472
|
|
485
473
|
transform_kwargs = dict(
|
@@ -546,16 +534,42 @@ class FactorAnalysis(BaseTransformer):
|
|
546
534
|
self._is_fitted = True
|
547
535
|
return output_result
|
548
536
|
|
537
|
+
|
538
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
539
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
540
|
+
""" Fit to data, then transform it
|
541
|
+
For more details on this function, see [sklearn.decomposition.FactorAnalysis.fit_transform]
|
542
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.FactorAnalysis.html#sklearn.decomposition.FactorAnalysis.fit_transform)
|
543
|
+
|
549
544
|
|
550
|
-
|
551
|
-
|
552
|
-
|
545
|
+
Raises:
|
546
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
547
|
+
|
548
|
+
Args:
|
549
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
550
|
+
Snowpark or Pandas DataFrame.
|
551
|
+
output_cols_prefix: Prefix for the response columns
|
553
552
|
Returns:
|
554
553
|
Transformed dataset.
|
555
554
|
"""
|
556
|
-
self.
|
557
|
-
|
558
|
-
|
555
|
+
self._infer_input_output_cols(dataset)
|
556
|
+
super()._check_dataset_type(dataset)
|
557
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
558
|
+
estimator=self._sklearn_object,
|
559
|
+
dataset=dataset,
|
560
|
+
input_cols=self.input_cols,
|
561
|
+
label_cols=self.label_cols,
|
562
|
+
sample_weight_col=self.sample_weight_col,
|
563
|
+
autogenerated=self._autogenerated,
|
564
|
+
subproject=_SUBPROJECT,
|
565
|
+
)
|
566
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
567
|
+
drop_input_cols=self._drop_input_cols,
|
568
|
+
expected_output_cols_list=self.output_cols,
|
569
|
+
)
|
570
|
+
self._sklearn_object = fitted_estimator
|
571
|
+
self._is_fitted = True
|
572
|
+
return output_result
|
559
573
|
|
560
574
|
|
561
575
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -646,10 +660,8 @@ class FactorAnalysis(BaseTransformer):
|
|
646
660
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
647
661
|
|
648
662
|
if isinstance(dataset, DataFrame):
|
649
|
-
self.
|
650
|
-
|
651
|
-
inference_method=inference_method,
|
652
|
-
)
|
663
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
664
|
+
self._deps = self._get_dependencies()
|
653
665
|
assert isinstance(
|
654
666
|
dataset._session, Session
|
655
667
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -714,10 +726,8 @@ class FactorAnalysis(BaseTransformer):
|
|
714
726
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
715
727
|
|
716
728
|
if isinstance(dataset, DataFrame):
|
717
|
-
self.
|
718
|
-
|
719
|
-
inference_method=inference_method,
|
720
|
-
)
|
729
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
730
|
+
self._deps = self._get_dependencies()
|
721
731
|
assert isinstance(
|
722
732
|
dataset._session, Session
|
723
733
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -779,10 +789,8 @@ class FactorAnalysis(BaseTransformer):
|
|
779
789
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
780
790
|
|
781
791
|
if isinstance(dataset, DataFrame):
|
782
|
-
self.
|
783
|
-
|
784
|
-
inference_method=inference_method,
|
785
|
-
)
|
792
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
793
|
+
self._deps = self._get_dependencies()
|
786
794
|
assert isinstance(
|
787
795
|
dataset._session, Session
|
788
796
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -850,10 +858,8 @@ class FactorAnalysis(BaseTransformer):
|
|
850
858
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
851
859
|
|
852
860
|
if isinstance(dataset, DataFrame):
|
853
|
-
self.
|
854
|
-
|
855
|
-
inference_method=inference_method,
|
856
|
-
)
|
861
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
862
|
+
self._deps = self._get_dependencies()
|
857
863
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
858
864
|
transform_kwargs = dict(
|
859
865
|
session=dataset._session,
|
@@ -917,17 +923,15 @@ class FactorAnalysis(BaseTransformer):
|
|
917
923
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
918
924
|
|
919
925
|
if isinstance(dataset, DataFrame):
|
920
|
-
self.
|
921
|
-
|
922
|
-
inference_method="score",
|
923
|
-
)
|
926
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
927
|
+
self._deps = self._get_dependencies()
|
924
928
|
selected_cols = self._get_active_columns()
|
925
929
|
if len(selected_cols) > 0:
|
926
930
|
dataset = dataset.select(selected_cols)
|
927
931
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
928
932
|
transform_kwargs = dict(
|
929
933
|
session=dataset._session,
|
930
|
-
dependencies=
|
934
|
+
dependencies=self._deps,
|
931
935
|
score_sproc_imports=['sklearn'],
|
932
936
|
)
|
933
937
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -992,11 +996,8 @@ class FactorAnalysis(BaseTransformer):
|
|
992
996
|
|
993
997
|
if isinstance(dataset, DataFrame):
|
994
998
|
|
995
|
-
self.
|
996
|
-
|
997
|
-
inference_method=inference_method,
|
998
|
-
|
999
|
-
)
|
999
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1000
|
+
self._deps = self._get_dependencies()
|
1000
1001
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1001
1002
|
transform_kwargs = dict(
|
1002
1003
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class FastICA(BaseTransformer):
|
70
64
|
r"""FastICA: a fast algorithm for Independent Component Analysis
|
71
65
|
For more details on this class, see [sklearn.decomposition.FastICA]
|
@@ -330,20 +324,17 @@ class FastICA(BaseTransformer):
|
|
330
324
|
self,
|
331
325
|
dataset: DataFrame,
|
332
326
|
inference_method: str,
|
333
|
-
) ->
|
334
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
335
|
-
return the available package that exists in the snowflake anaconda channel
|
327
|
+
) -> None:
|
328
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
336
329
|
|
337
330
|
Args:
|
338
331
|
dataset: snowpark dataframe
|
339
332
|
inference_method: the inference method such as predict, score...
|
340
|
-
|
333
|
+
|
341
334
|
Raises:
|
342
335
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
343
336
|
SnowflakeMLException: If the session is None, raise error
|
344
337
|
|
345
|
-
Returns:
|
346
|
-
A list of available package that exists in the snowflake anaconda channel
|
347
338
|
"""
|
348
339
|
if not self._is_fitted:
|
349
340
|
raise exceptions.SnowflakeMLException(
|
@@ -361,9 +352,7 @@ class FastICA(BaseTransformer):
|
|
361
352
|
"Session must not specified for snowpark dataset."
|
362
353
|
),
|
363
354
|
)
|
364
|
-
|
365
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
366
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
355
|
+
|
367
356
|
|
368
357
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
369
358
|
@telemetry.send_api_usage_telemetry(
|
@@ -409,7 +398,8 @@ class FastICA(BaseTransformer):
|
|
409
398
|
|
410
399
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
411
400
|
|
412
|
-
self.
|
401
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
402
|
+
self._deps = self._get_dependencies()
|
413
403
|
assert isinstance(
|
414
404
|
dataset._session, Session
|
415
405
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -494,10 +484,8 @@ class FastICA(BaseTransformer):
|
|
494
484
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
495
485
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
496
486
|
|
497
|
-
self.
|
498
|
-
|
499
|
-
inference_method=inference_method,
|
500
|
-
)
|
487
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
488
|
+
self._deps = self._get_dependencies()
|
501
489
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
502
490
|
|
503
491
|
transform_kwargs = dict(
|
@@ -564,16 +552,42 @@ class FastICA(BaseTransformer):
|
|
564
552
|
self._is_fitted = True
|
565
553
|
return output_result
|
566
554
|
|
555
|
+
|
556
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
557
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
558
|
+
""" Fit the model and recover the sources from X
|
559
|
+
For more details on this function, see [sklearn.decomposition.FastICA.fit_transform]
|
560
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.FastICA.html#sklearn.decomposition.FastICA.fit_transform)
|
561
|
+
|
567
562
|
|
568
|
-
|
569
|
-
|
570
|
-
|
563
|
+
Raises:
|
564
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
565
|
+
|
566
|
+
Args:
|
567
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
568
|
+
Snowpark or Pandas DataFrame.
|
569
|
+
output_cols_prefix: Prefix for the response columns
|
571
570
|
Returns:
|
572
571
|
Transformed dataset.
|
573
572
|
"""
|
574
|
-
self.
|
575
|
-
|
576
|
-
|
573
|
+
self._infer_input_output_cols(dataset)
|
574
|
+
super()._check_dataset_type(dataset)
|
575
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
576
|
+
estimator=self._sklearn_object,
|
577
|
+
dataset=dataset,
|
578
|
+
input_cols=self.input_cols,
|
579
|
+
label_cols=self.label_cols,
|
580
|
+
sample_weight_col=self.sample_weight_col,
|
581
|
+
autogenerated=self._autogenerated,
|
582
|
+
subproject=_SUBPROJECT,
|
583
|
+
)
|
584
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
585
|
+
drop_input_cols=self._drop_input_cols,
|
586
|
+
expected_output_cols_list=self.output_cols,
|
587
|
+
)
|
588
|
+
self._sklearn_object = fitted_estimator
|
589
|
+
self._is_fitted = True
|
590
|
+
return output_result
|
577
591
|
|
578
592
|
|
579
593
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -664,10 +678,8 @@ class FastICA(BaseTransformer):
|
|
664
678
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
665
679
|
|
666
680
|
if isinstance(dataset, DataFrame):
|
667
|
-
self.
|
668
|
-
|
669
|
-
inference_method=inference_method,
|
670
|
-
)
|
681
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
682
|
+
self._deps = self._get_dependencies()
|
671
683
|
assert isinstance(
|
672
684
|
dataset._session, Session
|
673
685
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -732,10 +744,8 @@ class FastICA(BaseTransformer):
|
|
732
744
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
733
745
|
|
734
746
|
if isinstance(dataset, DataFrame):
|
735
|
-
self.
|
736
|
-
|
737
|
-
inference_method=inference_method,
|
738
|
-
)
|
747
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
748
|
+
self._deps = self._get_dependencies()
|
739
749
|
assert isinstance(
|
740
750
|
dataset._session, Session
|
741
751
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -797,10 +807,8 @@ class FastICA(BaseTransformer):
|
|
797
807
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
798
808
|
|
799
809
|
if isinstance(dataset, DataFrame):
|
800
|
-
self.
|
801
|
-
|
802
|
-
inference_method=inference_method,
|
803
|
-
)
|
810
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
811
|
+
self._deps = self._get_dependencies()
|
804
812
|
assert isinstance(
|
805
813
|
dataset._session, Session
|
806
814
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -866,10 +874,8 @@ class FastICA(BaseTransformer):
|
|
866
874
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
867
875
|
|
868
876
|
if isinstance(dataset, DataFrame):
|
869
|
-
self.
|
870
|
-
|
871
|
-
inference_method=inference_method,
|
872
|
-
)
|
877
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
878
|
+
self._deps = self._get_dependencies()
|
873
879
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
874
880
|
transform_kwargs = dict(
|
875
881
|
session=dataset._session,
|
@@ -931,17 +937,15 @@ class FastICA(BaseTransformer):
|
|
931
937
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
932
938
|
|
933
939
|
if isinstance(dataset, DataFrame):
|
934
|
-
self.
|
935
|
-
|
936
|
-
inference_method="score",
|
937
|
-
)
|
940
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
941
|
+
self._deps = self._get_dependencies()
|
938
942
|
selected_cols = self._get_active_columns()
|
939
943
|
if len(selected_cols) > 0:
|
940
944
|
dataset = dataset.select(selected_cols)
|
941
945
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
942
946
|
transform_kwargs = dict(
|
943
947
|
session=dataset._session,
|
944
|
-
dependencies=
|
948
|
+
dependencies=self._deps,
|
945
949
|
score_sproc_imports=['sklearn'],
|
946
950
|
)
|
947
951
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1006,11 +1010,8 @@ class FastICA(BaseTransformer):
|
|
1006
1010
|
|
1007
1011
|
if isinstance(dataset, DataFrame):
|
1008
1012
|
|
1009
|
-
self.
|
1010
|
-
|
1011
|
-
inference_method=inference_method,
|
1012
|
-
|
1013
|
-
)
|
1013
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
1014
|
+
self._deps = self._get_dependencies()
|
1014
1015
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
1015
1016
|
transform_kwargs = dict(
|
1016
1017
|
session = dataset._session,
|
@@ -60,12 +60,6 @@ _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.decomposition".replace("
|
|
60
60
|
|
61
61
|
DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
|
62
62
|
|
63
|
-
def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
|
64
|
-
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
65
|
-
return False and callable(getattr(self._sklearn_object, "fit_transform", None))
|
66
|
-
return check
|
67
|
-
|
68
|
-
|
69
63
|
class IncrementalPCA(BaseTransformer):
|
70
64
|
r"""Incremental principal components analysis (IPCA)
|
71
65
|
For more details on this class, see [sklearn.decomposition.IncrementalPCA]
|
@@ -282,20 +276,17 @@ class IncrementalPCA(BaseTransformer):
|
|
282
276
|
self,
|
283
277
|
dataset: DataFrame,
|
284
278
|
inference_method: str,
|
285
|
-
) ->
|
286
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
287
|
-
return the available package that exists in the snowflake anaconda channel
|
279
|
+
) -> None:
|
280
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe.
|
288
281
|
|
289
282
|
Args:
|
290
283
|
dataset: snowpark dataframe
|
291
284
|
inference_method: the inference method such as predict, score...
|
292
|
-
|
285
|
+
|
293
286
|
Raises:
|
294
287
|
SnowflakeMLException: If the estimator is not fitted, raise error
|
295
288
|
SnowflakeMLException: If the session is None, raise error
|
296
289
|
|
297
|
-
Returns:
|
298
|
-
A list of available package that exists in the snowflake anaconda channel
|
299
290
|
"""
|
300
291
|
if not self._is_fitted:
|
301
292
|
raise exceptions.SnowflakeMLException(
|
@@ -313,9 +304,7 @@ class IncrementalPCA(BaseTransformer):
|
|
313
304
|
"Session must not specified for snowpark dataset."
|
314
305
|
),
|
315
306
|
)
|
316
|
-
|
317
|
-
return pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
318
|
-
pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
|
307
|
+
|
319
308
|
|
320
309
|
@available_if(original_estimator_has_callable("predict")) # type: ignore[misc]
|
321
310
|
@telemetry.send_api_usage_telemetry(
|
@@ -361,7 +350,8 @@ class IncrementalPCA(BaseTransformer):
|
|
361
350
|
|
362
351
|
expected_type_inferred = convert_sp_to_sf_type(label_cols_signatures[0].as_snowpark_type())
|
363
352
|
|
364
|
-
self.
|
353
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
354
|
+
self._deps = self._get_dependencies()
|
365
355
|
assert isinstance(
|
366
356
|
dataset._session, Session
|
367
357
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -446,10 +436,8 @@ class IncrementalPCA(BaseTransformer):
|
|
446
436
|
if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
|
447
437
|
expected_dtype = convert_sp_to_sf_type(output_types[0])
|
448
438
|
|
449
|
-
self.
|
450
|
-
|
451
|
-
inference_method=inference_method,
|
452
|
-
)
|
439
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
440
|
+
self._deps = self._get_dependencies()
|
453
441
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
454
442
|
|
455
443
|
transform_kwargs = dict(
|
@@ -516,16 +504,42 @@ class IncrementalPCA(BaseTransformer):
|
|
516
504
|
self._is_fitted = True
|
517
505
|
return output_result
|
518
506
|
|
507
|
+
|
508
|
+
@available_if(original_estimator_has_callable("fit_transform")) # type: ignore[misc]
|
509
|
+
def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame], output_cols_prefix: str = "fit_transform_",) -> Union[DataFrame, pd.DataFrame]:
|
510
|
+
""" Fit to data, then transform it
|
511
|
+
For more details on this function, see [sklearn.decomposition.IncrementalPCA.fit_transform]
|
512
|
+
(https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.IncrementalPCA.html#sklearn.decomposition.IncrementalPCA.fit_transform)
|
513
|
+
|
519
514
|
|
520
|
-
|
521
|
-
|
522
|
-
|
515
|
+
Raises:
|
516
|
+
TypeError: Supported dataset types: snowpark.DataFrame, pandas.DataFrame.
|
517
|
+
|
518
|
+
Args:
|
519
|
+
dataset: Union[snowflake.snowpark.DataFrame, pandas.DataFrame]
|
520
|
+
Snowpark or Pandas DataFrame.
|
521
|
+
output_cols_prefix: Prefix for the response columns
|
523
522
|
Returns:
|
524
523
|
Transformed dataset.
|
525
524
|
"""
|
526
|
-
self.
|
527
|
-
|
528
|
-
|
525
|
+
self._infer_input_output_cols(dataset)
|
526
|
+
super()._check_dataset_type(dataset)
|
527
|
+
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
528
|
+
estimator=self._sklearn_object,
|
529
|
+
dataset=dataset,
|
530
|
+
input_cols=self.input_cols,
|
531
|
+
label_cols=self.label_cols,
|
532
|
+
sample_weight_col=self.sample_weight_col,
|
533
|
+
autogenerated=self._autogenerated,
|
534
|
+
subproject=_SUBPROJECT,
|
535
|
+
)
|
536
|
+
output_result, fitted_estimator = model_trainer.train_fit_transform(
|
537
|
+
drop_input_cols=self._drop_input_cols,
|
538
|
+
expected_output_cols_list=self.output_cols,
|
539
|
+
)
|
540
|
+
self._sklearn_object = fitted_estimator
|
541
|
+
self._is_fitted = True
|
542
|
+
return output_result
|
529
543
|
|
530
544
|
|
531
545
|
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
@@ -616,10 +630,8 @@ class IncrementalPCA(BaseTransformer):
|
|
616
630
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
617
631
|
|
618
632
|
if isinstance(dataset, DataFrame):
|
619
|
-
self.
|
620
|
-
|
621
|
-
inference_method=inference_method,
|
622
|
-
)
|
633
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
634
|
+
self._deps = self._get_dependencies()
|
623
635
|
assert isinstance(
|
624
636
|
dataset._session, Session
|
625
637
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -684,10 +696,8 @@ class IncrementalPCA(BaseTransformer):
|
|
684
696
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
685
697
|
|
686
698
|
if isinstance(dataset, DataFrame):
|
687
|
-
self.
|
688
|
-
|
689
|
-
inference_method=inference_method,
|
690
|
-
)
|
699
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
700
|
+
self._deps = self._get_dependencies()
|
691
701
|
assert isinstance(
|
692
702
|
dataset._session, Session
|
693
703
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -749,10 +759,8 @@ class IncrementalPCA(BaseTransformer):
|
|
749
759
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
750
760
|
|
751
761
|
if isinstance(dataset, DataFrame):
|
752
|
-
self.
|
753
|
-
|
754
|
-
inference_method=inference_method,
|
755
|
-
)
|
762
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
763
|
+
self._deps = self._get_dependencies()
|
756
764
|
assert isinstance(
|
757
765
|
dataset._session, Session
|
758
766
|
) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
@@ -818,10 +826,8 @@ class IncrementalPCA(BaseTransformer):
|
|
818
826
|
expected_output_cols = self._get_output_column_names(output_cols_prefix)
|
819
827
|
|
820
828
|
if isinstance(dataset, DataFrame):
|
821
|
-
self.
|
822
|
-
|
823
|
-
inference_method=inference_method,
|
824
|
-
)
|
829
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
830
|
+
self._deps = self._get_dependencies()
|
825
831
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
826
832
|
transform_kwargs = dict(
|
827
833
|
session=dataset._session,
|
@@ -883,17 +889,15 @@ class IncrementalPCA(BaseTransformer):
|
|
883
889
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
884
890
|
|
885
891
|
if isinstance(dataset, DataFrame):
|
886
|
-
self.
|
887
|
-
|
888
|
-
inference_method="score",
|
889
|
-
)
|
892
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
|
893
|
+
self._deps = self._get_dependencies()
|
890
894
|
selected_cols = self._get_active_columns()
|
891
895
|
if len(selected_cols) > 0:
|
892
896
|
dataset = dataset.select(selected_cols)
|
893
897
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
894
898
|
transform_kwargs = dict(
|
895
899
|
session=dataset._session,
|
896
|
-
dependencies=
|
900
|
+
dependencies=self._deps,
|
897
901
|
score_sproc_imports=['sklearn'],
|
898
902
|
)
|
899
903
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -958,11 +962,8 @@ class IncrementalPCA(BaseTransformer):
|
|
958
962
|
|
959
963
|
if isinstance(dataset, DataFrame):
|
960
964
|
|
961
|
-
self.
|
962
|
-
|
963
|
-
inference_method=inference_method,
|
964
|
-
|
965
|
-
)
|
965
|
+
self._batch_inference_validate_snowpark(dataset=dataset, inference_method=inference_method)
|
966
|
+
self._deps = self._get_dependencies()
|
966
967
|
assert isinstance(dataset._session, Session) # mypy does not recognize the check in _batch_inference_validate_snowpark()
|
967
968
|
transform_kwargs = dict(
|
968
969
|
session = dataset._session,
|