snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/file_utils.py +3 -3
- snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
- snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
- snowflake/ml/_internal/telemetry.py +11 -2
- snowflake/ml/_internal/utils/formatting.py +1 -1
- snowflake/ml/feature_store/feature_store.py +15 -106
- snowflake/ml/fileset/sfcfs.py +4 -3
- snowflake/ml/fileset/stage_fs.py +18 -0
- snowflake/ml/model/_api.py +9 -9
- snowflake/ml/model/_client/model/model_version_impl.py +20 -15
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
- snowflake/ml/model/_model_composer/model_composer.py +10 -8
- snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
- snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
- snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
- snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_packager.py +8 -6
- snowflake/ml/model/custom_model.py +3 -1
- snowflake/ml/model/type_hints.py +13 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
- snowflake/ml/modeling/_internal/model_specifications.py +3 -1
- snowflake/ml/modeling/_internal/model_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
- snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
- snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
- snowflake/ml/modeling/cluster/birch.py +33 -61
- snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
- snowflake/ml/modeling/cluster/dbscan.py +33 -61
- snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
- snowflake/ml/modeling/cluster/k_means.py +33 -61
- snowflake/ml/modeling/cluster/mean_shift.py +33 -61
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
- snowflake/ml/modeling/cluster/optics.py +33 -61
- snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
- snowflake/ml/modeling/compose/column_transformer.py +33 -61
- snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
- snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
- snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
- snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
- snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
- snowflake/ml/modeling/covariance/oas.py +33 -61
- snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
- snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
- snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
- snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
- snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/pca.py +33 -61
- snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
- snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
- snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
- snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
- snowflake/ml/modeling/framework/base.py +55 -5
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
- snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
- snowflake/ml/modeling/impute/knn_imputer.py +33 -61
- snowflake/ml/modeling/impute/missing_indicator.py +33 -61
- snowflake/ml/modeling/impute/simple_imputer.py +4 -15
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
- snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/lars.py +33 -61
- snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
- snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/perceptron.py +33 -61
- snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ridge.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
- snowflake/ml/modeling/manifold/isomap.py +33 -61
- snowflake/ml/modeling/manifold/mds.py +33 -61
- snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
- snowflake/ml/modeling/manifold/tsne.py +33 -61
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
- snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
- snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
- snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
- snowflake/ml/modeling/svm/linear_svc.py +33 -61
- snowflake/ml/modeling/svm/linear_svr.py +33 -61
- snowflake/ml/modeling/svm/nu_svc.py +33 -61
- snowflake/ml/modeling/svm/nu_svr.py +33 -61
- snowflake/ml/modeling/svm/svc.py +33 -61
- snowflake/ml/modeling/svm/svr.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
- snowflake/ml/registry/_manager/model_manager.py +6 -2
- snowflake/ml/registry/model_registry.py +100 -27
- snowflake/ml/registry/registry.py +6 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
@@ -303,18 +303,24 @@ class AffinityPropagation(BaseTransformer):
|
|
303
303
|
self._get_model_signatures(dataset)
|
304
304
|
return self
|
305
305
|
|
306
|
-
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
307
|
-
if self._drop_input_cols:
|
308
|
-
return []
|
309
|
-
else:
|
310
|
-
return list(set(dataset.columns) - set(self.output_cols))
|
311
|
-
|
312
306
|
def _batch_inference_validate_snowpark(
|
313
307
|
self,
|
314
308
|
dataset: DataFrame,
|
315
309
|
inference_method: str,
|
316
310
|
) -> List[str]:
|
317
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
311
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe and
|
312
|
+
return the available package that exists in the snowflake anaconda channel
|
313
|
+
|
314
|
+
Args:
|
315
|
+
dataset: snowpark dataframe
|
316
|
+
inference_method: the inference method such as predict, score...
|
317
|
+
|
318
|
+
Raises:
|
319
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
320
|
+
SnowflakeMLException: If the session is None, raise error
|
321
|
+
|
322
|
+
Returns:
|
323
|
+
A list of available package that exists in the snowflake anaconda channel
|
318
324
|
"""
|
319
325
|
if not self._is_fitted:
|
320
326
|
raise exceptions.SnowflakeMLException(
|
@@ -388,7 +394,7 @@ class AffinityPropagation(BaseTransformer):
|
|
388
394
|
transform_kwargs = dict(
|
389
395
|
session = dataset._session,
|
390
396
|
dependencies = self._deps,
|
391
|
-
|
397
|
+
drop_input_cols = self._drop_input_cols,
|
392
398
|
expected_output_cols_type = expected_type_inferred,
|
393
399
|
)
|
394
400
|
|
@@ -448,16 +454,16 @@ class AffinityPropagation(BaseTransformer):
|
|
448
454
|
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
|
449
455
|
# based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
|
450
456
|
# each row containing a list of values.
|
451
|
-
expected_dtype = "
|
457
|
+
expected_dtype = "array"
|
452
458
|
|
453
459
|
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
454
460
|
if expected_dtype == "":
|
455
|
-
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
|
461
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
|
456
462
|
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
457
|
-
expected_dtype = "
|
458
|
-
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
|
463
|
+
expected_dtype = "array"
|
464
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
|
459
465
|
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
460
|
-
expected_dtype = "
|
466
|
+
expected_dtype = "array"
|
461
467
|
else:
|
462
468
|
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
463
469
|
# We can only infer the output types from the input types if the following two statemetns are true:
|
@@ -475,7 +481,7 @@ class AffinityPropagation(BaseTransformer):
|
|
475
481
|
transform_kwargs = dict(
|
476
482
|
session = dataset._session,
|
477
483
|
dependencies = self._deps,
|
478
|
-
|
484
|
+
drop_input_cols = self._drop_input_cols,
|
479
485
|
expected_output_cols_type = expected_dtype,
|
480
486
|
)
|
481
487
|
|
@@ -528,7 +534,7 @@ class AffinityPropagation(BaseTransformer):
|
|
528
534
|
subproject=_SUBPROJECT,
|
529
535
|
)
|
530
536
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
531
|
-
|
537
|
+
drop_input_cols=self._drop_input_cols,
|
532
538
|
expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
|
533
539
|
)
|
534
540
|
self._sklearn_object = fitted_estimator
|
@@ -546,44 +552,6 @@ class AffinityPropagation(BaseTransformer):
|
|
546
552
|
assert self._sklearn_object is not None
|
547
553
|
return self._sklearn_object.embedding_
|
548
554
|
|
549
|
-
|
550
|
-
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
551
|
-
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
552
|
-
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
553
|
-
"""
|
554
|
-
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
555
|
-
if output_cols:
|
556
|
-
output_cols = [
|
557
|
-
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
558
|
-
for c in output_cols
|
559
|
-
]
|
560
|
-
elif getattr(self._sklearn_object, "classes_", None) is None:
|
561
|
-
output_cols = [output_cols_prefix]
|
562
|
-
elif self._sklearn_object is not None:
|
563
|
-
classes = self._sklearn_object.classes_
|
564
|
-
if isinstance(classes, numpy.ndarray):
|
565
|
-
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
566
|
-
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
567
|
-
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
568
|
-
output_cols = []
|
569
|
-
for i, cl in enumerate(classes):
|
570
|
-
# For binary classification, there is only one output column for each class
|
571
|
-
# ndarray as the two classes are complementary.
|
572
|
-
if len(cl) == 2:
|
573
|
-
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
574
|
-
else:
|
575
|
-
output_cols.extend([
|
576
|
-
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
577
|
-
])
|
578
|
-
else:
|
579
|
-
output_cols = []
|
580
|
-
|
581
|
-
# Make sure column names are valid snowflake identifiers.
|
582
|
-
assert output_cols is not None # Make MyPy happy
|
583
|
-
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
584
|
-
|
585
|
-
return rv
|
586
|
-
|
587
555
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
588
556
|
@telemetry.send_api_usage_telemetry(
|
589
557
|
project=_PROJECT,
|
@@ -623,7 +591,7 @@ class AffinityPropagation(BaseTransformer):
|
|
623
591
|
transform_kwargs = dict(
|
624
592
|
session=dataset._session,
|
625
593
|
dependencies=self._deps,
|
626
|
-
|
594
|
+
drop_input_cols = self._drop_input_cols,
|
627
595
|
expected_output_cols_type="float",
|
628
596
|
)
|
629
597
|
|
@@ -688,7 +656,7 @@ class AffinityPropagation(BaseTransformer):
|
|
688
656
|
transform_kwargs = dict(
|
689
657
|
session=dataset._session,
|
690
658
|
dependencies=self._deps,
|
691
|
-
|
659
|
+
drop_input_cols = self._drop_input_cols,
|
692
660
|
expected_output_cols_type="float",
|
693
661
|
)
|
694
662
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -749,7 +717,7 @@ class AffinityPropagation(BaseTransformer):
|
|
749
717
|
transform_kwargs = dict(
|
750
718
|
session=dataset._session,
|
751
719
|
dependencies=self._deps,
|
752
|
-
|
720
|
+
drop_input_cols = self._drop_input_cols,
|
753
721
|
expected_output_cols_type="float",
|
754
722
|
)
|
755
723
|
|
@@ -814,7 +782,7 @@ class AffinityPropagation(BaseTransformer):
|
|
814
782
|
transform_kwargs = dict(
|
815
783
|
session=dataset._session,
|
816
784
|
dependencies=self._deps,
|
817
|
-
|
785
|
+
drop_input_cols = self._drop_input_cols,
|
818
786
|
expected_output_cols_type="float",
|
819
787
|
)
|
820
788
|
|
@@ -868,13 +836,17 @@ class AffinityPropagation(BaseTransformer):
|
|
868
836
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
869
837
|
|
870
838
|
if isinstance(dataset, DataFrame):
|
839
|
+
self._deps = self._batch_inference_validate_snowpark(
|
840
|
+
dataset=dataset,
|
841
|
+
inference_method="score",
|
842
|
+
)
|
871
843
|
selected_cols = self._get_active_columns()
|
872
844
|
if len(selected_cols) > 0:
|
873
845
|
dataset = dataset.select(selected_cols)
|
874
846
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
875
847
|
transform_kwargs = dict(
|
876
848
|
session=dataset._session,
|
877
|
-
dependencies=["snowflake-snowpark-python"] + self.
|
849
|
+
dependencies=["snowflake-snowpark-python"] + self._deps,
|
878
850
|
score_sproc_imports=['sklearn'],
|
879
851
|
)
|
880
852
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -948,9 +920,9 @@ class AffinityPropagation(BaseTransformer):
|
|
948
920
|
transform_kwargs = dict(
|
949
921
|
session = dataset._session,
|
950
922
|
dependencies = self._deps,
|
951
|
-
|
952
|
-
expected_output_cols_type
|
953
|
-
n_neighbors =
|
923
|
+
drop_input_cols = self._drop_input_cols,
|
924
|
+
expected_output_cols_type="array",
|
925
|
+
n_neighbors = n_neighbors,
|
954
926
|
return_distance = return_distance
|
955
927
|
)
|
956
928
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -336,18 +336,24 @@ class AgglomerativeClustering(BaseTransformer):
|
|
336
336
|
self._get_model_signatures(dataset)
|
337
337
|
return self
|
338
338
|
|
339
|
-
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
340
|
-
if self._drop_input_cols:
|
341
|
-
return []
|
342
|
-
else:
|
343
|
-
return list(set(dataset.columns) - set(self.output_cols))
|
344
|
-
|
345
339
|
def _batch_inference_validate_snowpark(
|
346
340
|
self,
|
347
341
|
dataset: DataFrame,
|
348
342
|
inference_method: str,
|
349
343
|
) -> List[str]:
|
350
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
344
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe and
|
345
|
+
return the available package that exists in the snowflake anaconda channel
|
346
|
+
|
347
|
+
Args:
|
348
|
+
dataset: snowpark dataframe
|
349
|
+
inference_method: the inference method such as predict, score...
|
350
|
+
|
351
|
+
Raises:
|
352
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
353
|
+
SnowflakeMLException: If the session is None, raise error
|
354
|
+
|
355
|
+
Returns:
|
356
|
+
A list of available package that exists in the snowflake anaconda channel
|
351
357
|
"""
|
352
358
|
if not self._is_fitted:
|
353
359
|
raise exceptions.SnowflakeMLException(
|
@@ -419,7 +425,7 @@ class AgglomerativeClustering(BaseTransformer):
|
|
419
425
|
transform_kwargs = dict(
|
420
426
|
session = dataset._session,
|
421
427
|
dependencies = self._deps,
|
422
|
-
|
428
|
+
drop_input_cols = self._drop_input_cols,
|
423
429
|
expected_output_cols_type = expected_type_inferred,
|
424
430
|
)
|
425
431
|
|
@@ -479,16 +485,16 @@ class AgglomerativeClustering(BaseTransformer):
|
|
479
485
|
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
|
480
486
|
# based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
|
481
487
|
# each row containing a list of values.
|
482
|
-
expected_dtype = "
|
488
|
+
expected_dtype = "array"
|
483
489
|
|
484
490
|
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
485
491
|
if expected_dtype == "":
|
486
|
-
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
|
492
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
|
487
493
|
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
488
|
-
expected_dtype = "
|
489
|
-
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
|
494
|
+
expected_dtype = "array"
|
495
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
|
490
496
|
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
491
|
-
expected_dtype = "
|
497
|
+
expected_dtype = "array"
|
492
498
|
else:
|
493
499
|
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
494
500
|
# We can only infer the output types from the input types if the following two statemetns are true:
|
@@ -506,7 +512,7 @@ class AgglomerativeClustering(BaseTransformer):
|
|
506
512
|
transform_kwargs = dict(
|
507
513
|
session = dataset._session,
|
508
514
|
dependencies = self._deps,
|
509
|
-
|
515
|
+
drop_input_cols = self._drop_input_cols,
|
510
516
|
expected_output_cols_type = expected_dtype,
|
511
517
|
)
|
512
518
|
|
@@ -559,7 +565,7 @@ class AgglomerativeClustering(BaseTransformer):
|
|
559
565
|
subproject=_SUBPROJECT,
|
560
566
|
)
|
561
567
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
562
|
-
|
568
|
+
drop_input_cols=self._drop_input_cols,
|
563
569
|
expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
|
564
570
|
)
|
565
571
|
self._sklearn_object = fitted_estimator
|
@@ -577,44 +583,6 @@ class AgglomerativeClustering(BaseTransformer):
|
|
577
583
|
assert self._sklearn_object is not None
|
578
584
|
return self._sklearn_object.embedding_
|
579
585
|
|
580
|
-
|
581
|
-
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
582
|
-
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
583
|
-
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
584
|
-
"""
|
585
|
-
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
586
|
-
if output_cols:
|
587
|
-
output_cols = [
|
588
|
-
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
589
|
-
for c in output_cols
|
590
|
-
]
|
591
|
-
elif getattr(self._sklearn_object, "classes_", None) is None:
|
592
|
-
output_cols = [output_cols_prefix]
|
593
|
-
elif self._sklearn_object is not None:
|
594
|
-
classes = self._sklearn_object.classes_
|
595
|
-
if isinstance(classes, numpy.ndarray):
|
596
|
-
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
597
|
-
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
598
|
-
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
599
|
-
output_cols = []
|
600
|
-
for i, cl in enumerate(classes):
|
601
|
-
# For binary classification, there is only one output column for each class
|
602
|
-
# ndarray as the two classes are complementary.
|
603
|
-
if len(cl) == 2:
|
604
|
-
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
605
|
-
else:
|
606
|
-
output_cols.extend([
|
607
|
-
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
608
|
-
])
|
609
|
-
else:
|
610
|
-
output_cols = []
|
611
|
-
|
612
|
-
# Make sure column names are valid snowflake identifiers.
|
613
|
-
assert output_cols is not None # Make MyPy happy
|
614
|
-
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
615
|
-
|
616
|
-
return rv
|
617
|
-
|
618
586
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
619
587
|
@telemetry.send_api_usage_telemetry(
|
620
588
|
project=_PROJECT,
|
@@ -654,7 +622,7 @@ class AgglomerativeClustering(BaseTransformer):
|
|
654
622
|
transform_kwargs = dict(
|
655
623
|
session=dataset._session,
|
656
624
|
dependencies=self._deps,
|
657
|
-
|
625
|
+
drop_input_cols = self._drop_input_cols,
|
658
626
|
expected_output_cols_type="float",
|
659
627
|
)
|
660
628
|
|
@@ -719,7 +687,7 @@ class AgglomerativeClustering(BaseTransformer):
|
|
719
687
|
transform_kwargs = dict(
|
720
688
|
session=dataset._session,
|
721
689
|
dependencies=self._deps,
|
722
|
-
|
690
|
+
drop_input_cols = self._drop_input_cols,
|
723
691
|
expected_output_cols_type="float",
|
724
692
|
)
|
725
693
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -780,7 +748,7 @@ class AgglomerativeClustering(BaseTransformer):
|
|
780
748
|
transform_kwargs = dict(
|
781
749
|
session=dataset._session,
|
782
750
|
dependencies=self._deps,
|
783
|
-
|
751
|
+
drop_input_cols = self._drop_input_cols,
|
784
752
|
expected_output_cols_type="float",
|
785
753
|
)
|
786
754
|
|
@@ -845,7 +813,7 @@ class AgglomerativeClustering(BaseTransformer):
|
|
845
813
|
transform_kwargs = dict(
|
846
814
|
session=dataset._session,
|
847
815
|
dependencies=self._deps,
|
848
|
-
|
816
|
+
drop_input_cols = self._drop_input_cols,
|
849
817
|
expected_output_cols_type="float",
|
850
818
|
)
|
851
819
|
|
@@ -899,13 +867,17 @@ class AgglomerativeClustering(BaseTransformer):
|
|
899
867
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
900
868
|
|
901
869
|
if isinstance(dataset, DataFrame):
|
870
|
+
self._deps = self._batch_inference_validate_snowpark(
|
871
|
+
dataset=dataset,
|
872
|
+
inference_method="score",
|
873
|
+
)
|
902
874
|
selected_cols = self._get_active_columns()
|
903
875
|
if len(selected_cols) > 0:
|
904
876
|
dataset = dataset.select(selected_cols)
|
905
877
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
906
878
|
transform_kwargs = dict(
|
907
879
|
session=dataset._session,
|
908
|
-
dependencies=["snowflake-snowpark-python"] + self.
|
880
|
+
dependencies=["snowflake-snowpark-python"] + self._deps,
|
909
881
|
score_sproc_imports=['sklearn'],
|
910
882
|
)
|
911
883
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -979,9 +951,9 @@ class AgglomerativeClustering(BaseTransformer):
|
|
979
951
|
transform_kwargs = dict(
|
980
952
|
session = dataset._session,
|
981
953
|
dependencies = self._deps,
|
982
|
-
|
983
|
-
expected_output_cols_type
|
984
|
-
n_neighbors =
|
954
|
+
drop_input_cols = self._drop_input_cols,
|
955
|
+
expected_output_cols_type="array",
|
956
|
+
n_neighbors = n_neighbors,
|
985
957
|
return_distance = return_distance
|
986
958
|
)
|
987
959
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -294,18 +294,24 @@ class Birch(BaseTransformer):
|
|
294
294
|
self._get_model_signatures(dataset)
|
295
295
|
return self
|
296
296
|
|
297
|
-
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
298
|
-
if self._drop_input_cols:
|
299
|
-
return []
|
300
|
-
else:
|
301
|
-
return list(set(dataset.columns) - set(self.output_cols))
|
302
|
-
|
303
297
|
def _batch_inference_validate_snowpark(
|
304
298
|
self,
|
305
299
|
dataset: DataFrame,
|
306
300
|
inference_method: str,
|
307
301
|
) -> List[str]:
|
308
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
302
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe and
|
303
|
+
return the available package that exists in the snowflake anaconda channel
|
304
|
+
|
305
|
+
Args:
|
306
|
+
dataset: snowpark dataframe
|
307
|
+
inference_method: the inference method such as predict, score...
|
308
|
+
|
309
|
+
Raises:
|
310
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
311
|
+
SnowflakeMLException: If the session is None, raise error
|
312
|
+
|
313
|
+
Returns:
|
314
|
+
A list of available package that exists in the snowflake anaconda channel
|
309
315
|
"""
|
310
316
|
if not self._is_fitted:
|
311
317
|
raise exceptions.SnowflakeMLException(
|
@@ -379,7 +385,7 @@ class Birch(BaseTransformer):
|
|
379
385
|
transform_kwargs = dict(
|
380
386
|
session = dataset._session,
|
381
387
|
dependencies = self._deps,
|
382
|
-
|
388
|
+
drop_input_cols = self._drop_input_cols,
|
383
389
|
expected_output_cols_type = expected_type_inferred,
|
384
390
|
)
|
385
391
|
|
@@ -441,16 +447,16 @@ class Birch(BaseTransformer):
|
|
441
447
|
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
|
442
448
|
# based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
|
443
449
|
# each row containing a list of values.
|
444
|
-
expected_dtype = "
|
450
|
+
expected_dtype = "array"
|
445
451
|
|
446
452
|
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
447
453
|
if expected_dtype == "":
|
448
|
-
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
|
454
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
|
449
455
|
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
450
|
-
expected_dtype = "
|
451
|
-
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
|
456
|
+
expected_dtype = "array"
|
457
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
|
452
458
|
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
453
|
-
expected_dtype = "
|
459
|
+
expected_dtype = "array"
|
454
460
|
else:
|
455
461
|
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
456
462
|
# We can only infer the output types from the input types if the following two statemetns are true:
|
@@ -468,7 +474,7 @@ class Birch(BaseTransformer):
|
|
468
474
|
transform_kwargs = dict(
|
469
475
|
session = dataset._session,
|
470
476
|
dependencies = self._deps,
|
471
|
-
|
477
|
+
drop_input_cols = self._drop_input_cols,
|
472
478
|
expected_output_cols_type = expected_dtype,
|
473
479
|
)
|
474
480
|
|
@@ -521,7 +527,7 @@ class Birch(BaseTransformer):
|
|
521
527
|
subproject=_SUBPROJECT,
|
522
528
|
)
|
523
529
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
524
|
-
|
530
|
+
drop_input_cols=self._drop_input_cols,
|
525
531
|
expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
|
526
532
|
)
|
527
533
|
self._sklearn_object = fitted_estimator
|
@@ -539,44 +545,6 @@ class Birch(BaseTransformer):
|
|
539
545
|
assert self._sklearn_object is not None
|
540
546
|
return self._sklearn_object.embedding_
|
541
547
|
|
542
|
-
|
543
|
-
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
544
|
-
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
545
|
-
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
546
|
-
"""
|
547
|
-
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
548
|
-
if output_cols:
|
549
|
-
output_cols = [
|
550
|
-
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
551
|
-
for c in output_cols
|
552
|
-
]
|
553
|
-
elif getattr(self._sklearn_object, "classes_", None) is None:
|
554
|
-
output_cols = [output_cols_prefix]
|
555
|
-
elif self._sklearn_object is not None:
|
556
|
-
classes = self._sklearn_object.classes_
|
557
|
-
if isinstance(classes, numpy.ndarray):
|
558
|
-
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
559
|
-
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
560
|
-
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
561
|
-
output_cols = []
|
562
|
-
for i, cl in enumerate(classes):
|
563
|
-
# For binary classification, there is only one output column for each class
|
564
|
-
# ndarray as the two classes are complementary.
|
565
|
-
if len(cl) == 2:
|
566
|
-
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
567
|
-
else:
|
568
|
-
output_cols.extend([
|
569
|
-
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
570
|
-
])
|
571
|
-
else:
|
572
|
-
output_cols = []
|
573
|
-
|
574
|
-
# Make sure column names are valid snowflake identifiers.
|
575
|
-
assert output_cols is not None # Make MyPy happy
|
576
|
-
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
577
|
-
|
578
|
-
return rv
|
579
|
-
|
580
548
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
581
549
|
@telemetry.send_api_usage_telemetry(
|
582
550
|
project=_PROJECT,
|
@@ -616,7 +584,7 @@ class Birch(BaseTransformer):
|
|
616
584
|
transform_kwargs = dict(
|
617
585
|
session=dataset._session,
|
618
586
|
dependencies=self._deps,
|
619
|
-
|
587
|
+
drop_input_cols = self._drop_input_cols,
|
620
588
|
expected_output_cols_type="float",
|
621
589
|
)
|
622
590
|
|
@@ -681,7 +649,7 @@ class Birch(BaseTransformer):
|
|
681
649
|
transform_kwargs = dict(
|
682
650
|
session=dataset._session,
|
683
651
|
dependencies=self._deps,
|
684
|
-
|
652
|
+
drop_input_cols = self._drop_input_cols,
|
685
653
|
expected_output_cols_type="float",
|
686
654
|
)
|
687
655
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -742,7 +710,7 @@ class Birch(BaseTransformer):
|
|
742
710
|
transform_kwargs = dict(
|
743
711
|
session=dataset._session,
|
744
712
|
dependencies=self._deps,
|
745
|
-
|
713
|
+
drop_input_cols = self._drop_input_cols,
|
746
714
|
expected_output_cols_type="float",
|
747
715
|
)
|
748
716
|
|
@@ -807,7 +775,7 @@ class Birch(BaseTransformer):
|
|
807
775
|
transform_kwargs = dict(
|
808
776
|
session=dataset._session,
|
809
777
|
dependencies=self._deps,
|
810
|
-
|
778
|
+
drop_input_cols = self._drop_input_cols,
|
811
779
|
expected_output_cols_type="float",
|
812
780
|
)
|
813
781
|
|
@@ -861,13 +829,17 @@ class Birch(BaseTransformer):
|
|
861
829
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
862
830
|
|
863
831
|
if isinstance(dataset, DataFrame):
|
832
|
+
self._deps = self._batch_inference_validate_snowpark(
|
833
|
+
dataset=dataset,
|
834
|
+
inference_method="score",
|
835
|
+
)
|
864
836
|
selected_cols = self._get_active_columns()
|
865
837
|
if len(selected_cols) > 0:
|
866
838
|
dataset = dataset.select(selected_cols)
|
867
839
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
868
840
|
transform_kwargs = dict(
|
869
841
|
session=dataset._session,
|
870
|
-
dependencies=["snowflake-snowpark-python"] + self.
|
842
|
+
dependencies=["snowflake-snowpark-python"] + self._deps,
|
871
843
|
score_sproc_imports=['sklearn'],
|
872
844
|
)
|
873
845
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -941,9 +913,9 @@ class Birch(BaseTransformer):
|
|
941
913
|
transform_kwargs = dict(
|
942
914
|
session = dataset._session,
|
943
915
|
dependencies = self._deps,
|
944
|
-
|
945
|
-
expected_output_cols_type
|
946
|
-
n_neighbors =
|
916
|
+
drop_input_cols = self._drop_input_cols,
|
917
|
+
expected_output_cols_type="array",
|
918
|
+
n_neighbors = n_neighbors,
|
947
919
|
return_distance = return_distance
|
948
920
|
)
|
949
921
|
elif isinstance(dataset, pd.DataFrame):
|