snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/file_utils.py +3 -3
- snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
- snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
- snowflake/ml/_internal/telemetry.py +11 -2
- snowflake/ml/_internal/utils/formatting.py +1 -1
- snowflake/ml/feature_store/feature_store.py +15 -106
- snowflake/ml/fileset/sfcfs.py +4 -3
- snowflake/ml/fileset/stage_fs.py +18 -0
- snowflake/ml/model/_api.py +9 -9
- snowflake/ml/model/_client/model/model_version_impl.py +20 -15
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
- snowflake/ml/model/_model_composer/model_composer.py +10 -8
- snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
- snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
- snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
- snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_packager.py +8 -6
- snowflake/ml/model/custom_model.py +3 -1
- snowflake/ml/model/type_hints.py +13 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
- snowflake/ml/modeling/_internal/model_specifications.py +3 -1
- snowflake/ml/modeling/_internal/model_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
- snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
- snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
- snowflake/ml/modeling/cluster/birch.py +33 -61
- snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
- snowflake/ml/modeling/cluster/dbscan.py +33 -61
- snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
- snowflake/ml/modeling/cluster/k_means.py +33 -61
- snowflake/ml/modeling/cluster/mean_shift.py +33 -61
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
- snowflake/ml/modeling/cluster/optics.py +33 -61
- snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
- snowflake/ml/modeling/compose/column_transformer.py +33 -61
- snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
- snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
- snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
- snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
- snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
- snowflake/ml/modeling/covariance/oas.py +33 -61
- snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
- snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
- snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
- snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
- snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/pca.py +33 -61
- snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
- snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
- snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
- snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
- snowflake/ml/modeling/framework/base.py +55 -5
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
- snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
- snowflake/ml/modeling/impute/knn_imputer.py +33 -61
- snowflake/ml/modeling/impute/missing_indicator.py +33 -61
- snowflake/ml/modeling/impute/simple_imputer.py +4 -15
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
- snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/lars.py +33 -61
- snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
- snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/perceptron.py +33 -61
- snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ridge.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
- snowflake/ml/modeling/manifold/isomap.py +33 -61
- snowflake/ml/modeling/manifold/mds.py +33 -61
- snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
- snowflake/ml/modeling/manifold/tsne.py +33 -61
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
- snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
- snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
- snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
- snowflake/ml/modeling/svm/linear_svc.py +33 -61
- snowflake/ml/modeling/svm/linear_svr.py +33 -61
- snowflake/ml/modeling/svm/nu_svc.py +33 -61
- snowflake/ml/modeling/svm/nu_svr.py +33 -61
- snowflake/ml/modeling/svm/svc.py +33 -61
- snowflake/ml/modeling/svm/svr.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
- snowflake/ml/registry/_manager/model_manager.py +6 -2
- snowflake/ml/registry/model_registry.py +100 -27
- snowflake/ml/registry/registry.py +6 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
@@ -378,18 +378,24 @@ class KernelPCA(BaseTransformer):
         self._get_model_signatures(dataset)
         return self
 
-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
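As of 1.4.0, `_batch_inference_validate_snowpark` both validates the Snowpark DataFrame and returns the dependency packages resolved against the Snowflake Anaconda channel; the `score` hunk further down shows call sites storing the result on `self._deps`. A minimal sketch of that calling pattern, with stand-in names and a hard-coded package list rather than the real implementation:

from typing import Any, List

class SketchTransformer:
    """Illustrative stand-in for a generated BaseTransformer subclass."""

    def __init__(self) -> None:
        self._is_fitted = True
        self._deps: List[str] = []

    def _batch_inference_validate_snowpark(self, dataset: Any, inference_method: str) -> List[str]:
        # Validate preconditions, then return the packages resolved against
        # the Snowflake Anaconda channel (hard-coded here for illustration).
        if not self._is_fitted:
            raise RuntimeError("estimator is not fitted")  # stand-in for SnowflakeMLException
        return ["scikit-learn", "numpy"]

    def score(self, dataset: Any) -> None:
        # 1.4.0 pattern: validate first, keep the returned packages on
        # self._deps, then reuse them in the stored-procedure dependency list.
        self._deps = self._batch_inference_validate_snowpark(dataset=dataset, inference_method="score")
        dependencies = ["snowflake-snowpark-python"] + self._deps
        print(dependencies)

SketchTransformer().score(dataset=None)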
@@ -461,7 +467,7 @@ class KernelPCA(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )
 
@@ -523,16 +529,16 @@ class KernelPCA(BaseTransformer):
         # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
         # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
         # each row containing a list of values.
-        expected_dtype = "
+        expected_dtype = "array"
 
         # If we were unable to assign a type to this transform in the factory, infer the type here.
         if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
             if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
             elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "
+                expected_dtype = "array"
             else:
                 output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                 # We can only infer the output types from the input types if the following two statemetns are true:
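The comment in this hunk describes collapsing a wide sklearn output of shape (n_samples, n_features_out) into a single array-typed column with one list per row. A toy illustration of that conversion (the `OUTPUT` column name is hypothetical):

import numpy as np
import pandas as pd

raw = np.arange(6).reshape(3, 2)                    # e.g. a transform() output with 2 components
collapsed = pd.DataFrame({"OUTPUT": raw.tolist()})  # shape (3, 1), each cell a list
print(collapsed)
#    OUTPUT
# 0  [0, 1]
# 1  [2, 3]
# 2  [4, 5]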
@@ -550,7 +556,7 @@ class KernelPCA(BaseTransformer):
         transform_kwargs = dict(
             session = dataset._session,
             dependencies = self._deps,
-
+            drop_input_cols = self._drop_input_cols,
             expected_output_cols_type = expected_dtype,
         )
 
@@ -601,7 +607,7 @@ class KernelPCA(BaseTransformer):
             subproject=_SUBPROJECT,
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
-
+            drop_input_cols=self._drop_input_cols,
             expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
         )
         self._sklearn_object = fitted_estimator
@@ -619,44 +625,6 @@ class KernelPCA(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_
 
-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
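The per-estimator `_get_output_column_names` helper is deleted from every generated class; given the `snowflake/ml/modeling/_internal/estimator_utils.py +61 -1` entry in the file list above, it plausibly moved to a shared utility, though these hunks only show the removal. For reference, a standalone restatement of its naming rules (identifier resolution dropped; the real code routes names through `snowflake.ml._internal.utils.identifier`):

import numpy

def output_column_names(prefix, classes_=None):
    """Mimics the column-naming rules of the deleted _get_output_column_names."""
    if classes_ is None:
        return [prefix]                               # non-classifier: single column
    if isinstance(classes_, numpy.ndarray):
        return [f"{prefix}{c}" for c in classes_.tolist()]
    cols = []
    for i, cl in enumerate(classes_):                 # multioutput: list of ndarrays
        if len(cl) == 2:                              # binary: the two classes are complementary
            cols.append(f"{prefix}{i}_{cl[0]}")
        else:
            cols.extend(f"{prefix}{i}_{c}" for c in cl.tolist())
    return cols

print(output_column_names("PREDICT_PROBA_", numpy.array([0, 1, 2])))
# ['PREDICT_PROBA_0', 'PREDICT_PROBA_1', 'PREDICT_PROBA_2']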
@@ -696,7 +664,7 @@ class KernelPCA(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -761,7 +729,7 @@ class KernelPCA(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -822,7 +790,7 @@ class KernelPCA(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -887,7 +855,7 @@ class KernelPCA(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -941,13 +909,17 @@ class KernelPCA(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self.
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -1021,9 +993,9 @@ class KernelPCA(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
-                expected_output_cols_type
-                n_neighbors =
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
             )
         elif isinstance(dataset, pd.DataFrame):
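These wrappers are generated from a shared template, which is why even KernelPCA carries a `kneighbors` code path; the `array` output type matches scikit-learn's `kneighbors` contract of returning `(n_samples, n_neighbors)` arrays. An illustration using an estimator that actually exposes `kneighbors` (data and names are ad hoc):

import numpy as np
from sklearn.neighbors import NearestNeighbors

X = np.array([[0.0], [1.0], [2.0], [10.0]])
nn = NearestNeighbors(n_neighbors=2).fit(X)
dist, ind = nn.kneighbors(X, n_neighbors=2, return_distance=True)
# dist and ind are both (n_samples, n_neighbors); each row maps to one
# list-valued cell in the "array" output column.
rows = ind.tolist()
print(rows)  # e.g. [[0, 1], [1, 0], [2, 1], [3, 2]] (tie order may vary)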
@@ -400,18 +400,24 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         self._get_model_signatures(dataset)
         return self
 
-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -483,7 +489,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )
 
@@ -545,16 +551,16 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
         # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
         # each row containing a list of values.
-        expected_dtype = "
+        expected_dtype = "array"
 
         # If we were unable to assign a type to this transform in the factory, infer the type here.
         if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
             if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
             elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "
+                expected_dtype = "array"
             else:
                 output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                 # We can only infer the output types from the input types if the following two statemetns are true:
@@ -572,7 +578,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         transform_kwargs = dict(
             session = dataset._session,
             dependencies = self._deps,
-
+            drop_input_cols = self._drop_input_cols,
             expected_output_cols_type = expected_dtype,
         )
 
@@ -623,7 +629,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             subproject=_SUBPROJECT,
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
-
+            drop_input_cols=self._drop_input_cols,
             expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
         )
         self._sklearn_object = fitted_estimator
@@ -641,44 +647,6 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_
 
-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -718,7 +686,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -783,7 +751,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -844,7 +812,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -909,7 +877,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -963,13 +931,17 @@ class MiniBatchDictionaryLearning(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
        if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self.
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -1043,9 +1015,9 @@ class MiniBatchDictionaryLearning(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
-                expected_output_cols_type
-                n_neighbors =
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -345,18 +345,24 @@ class MiniBatchSparsePCA(BaseTransformer):
         self._get_model_signatures(dataset)
         return self
 
-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -428,7 +434,7 @@ class MiniBatchSparsePCA(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )
 
@@ -490,16 +496,16 @@ class MiniBatchSparsePCA(BaseTransformer):
         # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
         # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
         # each row containing a list of values.
-        expected_dtype = "
+        expected_dtype = "array"
 
         # If we were unable to assign a type to this transform in the factory, infer the type here.
         if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
             if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
             elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "
+                expected_dtype = "array"
             else:
                 output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                 # We can only infer the output types from the input types if the following two statemetns are true:
@@ -517,7 +523,7 @@ class MiniBatchSparsePCA(BaseTransformer):
         transform_kwargs = dict(
             session = dataset._session,
             dependencies = self._deps,
-
+            drop_input_cols = self._drop_input_cols,
             expected_output_cols_type = expected_dtype,
         )
 
@@ -568,7 +574,7 @@ class MiniBatchSparsePCA(BaseTransformer):
             subproject=_SUBPROJECT,
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
-
+            drop_input_cols=self._drop_input_cols,
             expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
         )
         self._sklearn_object = fitted_estimator
@@ -586,44 +592,6 @@ class MiniBatchSparsePCA(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_
 
-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -663,7 +631,7 @@ class MiniBatchSparsePCA(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -728,7 +696,7 @@ class MiniBatchSparsePCA(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -789,7 +757,7 @@ class MiniBatchSparsePCA(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -854,7 +822,7 @@ class MiniBatchSparsePCA(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -908,13 +876,17 @@ class MiniBatchSparsePCA(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self.
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -988,9 +960,9 @@ class MiniBatchSparsePCA(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
-                expected_output_cols_type
-                n_neighbors =
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
            )
         elif isinstance(dataset, pd.DataFrame):