snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- snowflake/ml/_internal/file_utils.py +3 -3
- snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
- snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
- snowflake/ml/_internal/telemetry.py +11 -2
- snowflake/ml/_internal/utils/formatting.py +1 -1
- snowflake/ml/feature_store/feature_store.py +15 -106
- snowflake/ml/fileset/sfcfs.py +4 -3
- snowflake/ml/fileset/stage_fs.py +18 -0
- snowflake/ml/model/_api.py +9 -9
- snowflake/ml/model/_client/model/model_version_impl.py +20 -15
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
- snowflake/ml/model/_model_composer/model_composer.py +10 -8
- snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
- snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
- snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
- snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_packager.py +8 -6
- snowflake/ml/model/custom_model.py +3 -1
- snowflake/ml/model/type_hints.py +13 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
- snowflake/ml/modeling/_internal/model_specifications.py +3 -1
- snowflake/ml/modeling/_internal/model_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
- snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
- snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
- snowflake/ml/modeling/cluster/birch.py +33 -61
- snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
- snowflake/ml/modeling/cluster/dbscan.py +33 -61
- snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
- snowflake/ml/modeling/cluster/k_means.py +33 -61
- snowflake/ml/modeling/cluster/mean_shift.py +33 -61
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
- snowflake/ml/modeling/cluster/optics.py +33 -61
- snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
- snowflake/ml/modeling/compose/column_transformer.py +33 -61
- snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
- snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
- snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
- snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
- snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
- snowflake/ml/modeling/covariance/oas.py +33 -61
- snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
- snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
- snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
- snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
- snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/pca.py +33 -61
- snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
- snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
- snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
- snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
- snowflake/ml/modeling/framework/base.py +55 -5
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
- snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
- snowflake/ml/modeling/impute/knn_imputer.py +33 -61
- snowflake/ml/modeling/impute/missing_indicator.py +33 -61
- snowflake/ml/modeling/impute/simple_imputer.py +4 -15
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
- snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/lars.py +33 -61
- snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
- snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/perceptron.py +33 -61
- snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ridge.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
- snowflake/ml/modeling/manifold/isomap.py +33 -61
- snowflake/ml/modeling/manifold/mds.py +33 -61
- snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
- snowflake/ml/modeling/manifold/tsne.py +33 -61
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
- snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
- snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
- snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
- snowflake/ml/modeling/svm/linear_svc.py +33 -61
- snowflake/ml/modeling/svm/linear_svr.py +33 -61
- snowflake/ml/modeling/svm/nu_svc.py +33 -61
- snowflake/ml/modeling/svm/nu_svr.py +33 -61
- snowflake/ml/modeling/svm/svc.py +33 -61
- snowflake/ml/modeling/svm/svr.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
- snowflake/ml/registry/_manager/model_manager.py +6 -2
- snowflake/ml/registry/model_registry.py +100 -27
- snowflake/ml/registry/registry.py +6 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0

snowflake/ml/modeling/cluster/k_means.py

```diff
@@ -338,18 +338,24 @@ class KMeans(BaseTransformer):
         self._get_model_signatures(dataset)
         return self

-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -423,7 +429,7 @@ class KMeans(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )

@@ -485,16 +491,16 @@ class KMeans(BaseTransformer):
            # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
            # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
            # each row containing a list of values.
-            expected_dtype = "
+            expected_dtype = "array"

        # If we were unable to assign a type to this transform in the factory, infer the type here.
        if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "
+                expected_dtype = "array"
            else:
                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                # We can only infer the output types from the input types if the following two statemetns are true:
@@ -512,7 +518,7 @@ class KMeans(BaseTransformer):
            transform_kwargs = dict(
                session = dataset._session,
                dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type = expected_dtype,
            )

@@ -565,7 +571,7 @@ class KMeans(BaseTransformer):
                subproject=_SUBPROJECT,
            )
            output_result, fitted_estimator = model_trainer.train_fit_predict(
-
+                drop_input_cols=self._drop_input_cols,
                expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
            )
            self._sklearn_object = fitted_estimator
@@ -583,44 +589,6 @@ class KMeans(BaseTransformer):
        assert self._sklearn_object is not None
        return self._sklearn_object.embedding_

-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
    @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
    @telemetry.send_api_usage_telemetry(
        project=_PROJECT,
@@ -660,7 +628,7 @@ class KMeans(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )

@@ -725,7 +693,7 @@ class KMeans(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )
        elif isinstance(dataset, pd.DataFrame):
@@ -786,7 +754,7 @@ class KMeans(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )

@@ -851,7 +819,7 @@ class KMeans(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )

@@ -907,13 +875,17 @@ class KMeans(BaseTransformer):
        transform_kwargs: ScoreKwargsTypedDict = dict()

        if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
            selected_cols = self._get_active_columns()
            if len(selected_cols) > 0:
                dataset = dataset.select(selected_cols)
            assert isinstance(dataset._session, Session)  # keep mypy happy
            transform_kwargs = dict(
                session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self.
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                score_sproc_imports=['sklearn'],
            )
        elif isinstance(dataset, pd.DataFrame):
@@ -987,9 +959,9 @@ class KMeans(BaseTransformer):
            transform_kwargs = dict(
                session = dataset._session,
                dependencies = self._deps,
-
-                expected_output_cols_type
-                n_neighbors =
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                return_distance = return_distance
            )
        elif isinstance(dataset, pd.DataFrame):
```
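
The docstring added in the first hunk spells out the new contract for `_batch_inference_validate_snowpark`: confirm the estimator is fitted and the Snowpark DataFrame carries a session, then return the estimator's dependencies that are available in the Snowflake Anaconda channel. A minimal, self-contained sketch of that contract (not the library's implementation; the class name and the availability check below are illustrative stand-ins):

```python
# A minimal sketch of the contract described by the new docstring above;
# this is not snowflake-ml-python's code, and the availability check is a
# stand-in (it assumes every dependency exists in the Anaconda channel).
from typing import List, Optional


class SketchEstimator:
    def __init__(self, deps: List[str]) -> None:
        self._is_fitted = False
        self._deps = deps

    def _batch_inference_validate_snowpark(
        self, session: Optional[object], inference_method: str
    ) -> List[str]:
        if not self._is_fitted:
            # The library raises SnowflakeMLException here.
            raise RuntimeError(f"Estimator must be fitted before {inference_method!r}.")
        if session is None:
            # The library raises SnowflakeMLException here as well.
            raise RuntimeError("The Snowpark DataFrame has no active session.")
        # Stand-in for filtering self._deps against the Snowflake Anaconda channel.
        return list(self._deps)
```

The score hunk then appends the resolved list to the stored procedure dependencies via `["snowflake-snowpark-python"] + self._deps`.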

snowflake/ml/modeling/cluster/mean_shift.py

```diff
@@ -314,18 +314,24 @@ class MeanShift(BaseTransformer):
         self._get_model_signatures(dataset)
         return self

-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -399,7 +405,7 @@ class MeanShift(BaseTransformer):
            transform_kwargs = dict(
                session = dataset._session,
                dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type = expected_type_inferred,
            )

@@ -459,16 +465,16 @@ class MeanShift(BaseTransformer):
            # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
            # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
            # each row containing a list of values.
-            expected_dtype = "
+            expected_dtype = "array"

        # If we were unable to assign a type to this transform in the factory, infer the type here.
        if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "
+                expected_dtype = "array"
            else:
                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                # We can only infer the output types from the input types if the following two statemetns are true:
@@ -486,7 +492,7 @@ class MeanShift(BaseTransformer):
            transform_kwargs = dict(
                session = dataset._session,
                dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type = expected_dtype,
            )

@@ -539,7 +545,7 @@ class MeanShift(BaseTransformer):
                subproject=_SUBPROJECT,
            )
            output_result, fitted_estimator = model_trainer.train_fit_predict(
-
+                drop_input_cols=self._drop_input_cols,
                expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
            )
            self._sklearn_object = fitted_estimator
@@ -557,44 +563,6 @@ class MeanShift(BaseTransformer):
        assert self._sklearn_object is not None
        return self._sklearn_object.embedding_

-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
    @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
    @telemetry.send_api_usage_telemetry(
        project=_PROJECT,
@@ -634,7 +602,7 @@ class MeanShift(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )

@@ -699,7 +667,7 @@ class MeanShift(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )
        elif isinstance(dataset, pd.DataFrame):
@@ -760,7 +728,7 @@ class MeanShift(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )

@@ -825,7 +793,7 @@ class MeanShift(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )

@@ -879,13 +847,17 @@ class MeanShift(BaseTransformer):
        transform_kwargs: ScoreKwargsTypedDict = dict()

        if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
            selected_cols = self._get_active_columns()
            if len(selected_cols) > 0:
                dataset = dataset.select(selected_cols)
            assert isinstance(dataset._session, Session)  # keep mypy happy
            transform_kwargs = dict(
                session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self.
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                score_sproc_imports=['sklearn'],
            )
        elif isinstance(dataset, pd.DataFrame):
@@ -959,9 +931,9 @@ class MeanShift(BaseTransformer):
            transform_kwargs = dict(
                session = dataset._session,
                dependencies = self._deps,
-
-                expected_output_cols_type
-                n_neighbors =
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                return_distance = return_distance
            )
        elif isinstance(dataset, pd.DataFrame):
```

snowflake/ml/modeling/cluster/mini_batch_k_means.py

```diff
@@ -364,18 +364,24 @@ class MiniBatchKMeans(BaseTransformer):
         self._get_model_signatures(dataset)
         return self

-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -449,7 +455,7 @@ class MiniBatchKMeans(BaseTransformer):
            transform_kwargs = dict(
                session = dataset._session,
                dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type = expected_type_inferred,
            )

@@ -511,16 +517,16 @@ class MiniBatchKMeans(BaseTransformer):
            # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
            # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
            # each row containing a list of values.
-            expected_dtype = "
+            expected_dtype = "array"

        # If we were unable to assign a type to this transform in the factory, infer the type here.
        if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "
+                expected_dtype = "array"
            else:
                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                # We can only infer the output types from the input types if the following two statemetns are true:
@@ -538,7 +544,7 @@ class MiniBatchKMeans(BaseTransformer):
            transform_kwargs = dict(
                session = dataset._session,
                dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type = expected_dtype,
            )

@@ -591,7 +597,7 @@ class MiniBatchKMeans(BaseTransformer):
                subproject=_SUBPROJECT,
            )
            output_result, fitted_estimator = model_trainer.train_fit_predict(
-
+                drop_input_cols=self._drop_input_cols,
                expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
            )
            self._sklearn_object = fitted_estimator
@@ -609,44 +615,6 @@ class MiniBatchKMeans(BaseTransformer):
        assert self._sklearn_object is not None
        return self._sklearn_object.embedding_

-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
    @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
    @telemetry.send_api_usage_telemetry(
        project=_PROJECT,
@@ -686,7 +654,7 @@ class MiniBatchKMeans(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )

@@ -751,7 +719,7 @@ class MiniBatchKMeans(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )
        elif isinstance(dataset, pd.DataFrame):
@@ -812,7 +780,7 @@ class MiniBatchKMeans(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )

@@ -877,7 +845,7 @@ class MiniBatchKMeans(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )

@@ -933,13 +901,17 @@ class MiniBatchKMeans(BaseTransformer):
        transform_kwargs: ScoreKwargsTypedDict = dict()

        if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
            selected_cols = self._get_active_columns()
            if len(selected_cols) > 0:
                dataset = dataset.select(selected_cols)
            assert isinstance(dataset._session, Session)  # keep mypy happy
            transform_kwargs = dict(
                session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self.
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                score_sproc_imports=['sklearn'],
            )
        elif isinstance(dataset, pd.DataFrame):
@@ -1013,9 +985,9 @@ class MiniBatchKMeans(BaseTransformer):
            transform_kwargs = dict(
                session = dataset._session,
                dependencies = self._deps,
-
-                expected_output_cols_type
-                n_neighbors =
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                return_distance = return_distance
            )
        elif isinstance(dataset, pd.DataFrame):
```
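
The change repeated across all three files above (and, per the file list, across the other `+33 -61` estimator wrappers) replaces client-side pass-through column computation with a `drop_input_cols` flag forwarded to the inference handlers. A small sketch of the 1.3.0 behavior being replaced; the helper body is taken from the hunks above, while the standalone function name is illustrative:

```python
# Sketch of the column-handling switch; only the body of the removed
# _get_pass_through_columns helper is taken verbatim from the hunks above,
# and the standalone function name is illustrative.
from typing import List


def pass_through_columns_1_3_0(
    dataset_columns: List[str], output_cols: List[str], drop_input_cols: bool
) -> List[str]:
    if drop_input_cols:
        return []
    else:
        return list(set(dataset_columns) - set(output_cols))


# In 1.4.0 the explicit list disappears from the call sites; the flag is
# forwarded instead, e.g.
#   transform_kwargs = dict(..., drop_input_cols=self._drop_input_cols, ...)
# and the snowpark/pandas handlers decide which columns to keep.
```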