snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/file_utils.py +3 -3
- snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
- snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
- snowflake/ml/_internal/telemetry.py +11 -2
- snowflake/ml/_internal/utils/formatting.py +1 -1
- snowflake/ml/feature_store/feature_store.py +15 -106
- snowflake/ml/fileset/sfcfs.py +4 -3
- snowflake/ml/fileset/stage_fs.py +18 -0
- snowflake/ml/model/_api.py +9 -9
- snowflake/ml/model/_client/model/model_version_impl.py +20 -15
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
- snowflake/ml/model/_model_composer/model_composer.py +10 -8
- snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
- snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
- snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
- snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_packager.py +8 -6
- snowflake/ml/model/custom_model.py +3 -1
- snowflake/ml/model/type_hints.py +13 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
- snowflake/ml/modeling/_internal/model_specifications.py +3 -1
- snowflake/ml/modeling/_internal/model_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
- snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
- snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
- snowflake/ml/modeling/cluster/birch.py +33 -61
- snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
- snowflake/ml/modeling/cluster/dbscan.py +33 -61
- snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
- snowflake/ml/modeling/cluster/k_means.py +33 -61
- snowflake/ml/modeling/cluster/mean_shift.py +33 -61
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
- snowflake/ml/modeling/cluster/optics.py +33 -61
- snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
- snowflake/ml/modeling/compose/column_transformer.py +33 -61
- snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
- snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
- snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
- snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
- snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
- snowflake/ml/modeling/covariance/oas.py +33 -61
- snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
- snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
- snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
- snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
- snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/pca.py +33 -61
- snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
- snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
- snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
- snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
- snowflake/ml/modeling/framework/base.py +55 -5
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
- snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
- snowflake/ml/modeling/impute/knn_imputer.py +33 -61
- snowflake/ml/modeling/impute/missing_indicator.py +33 -61
- snowflake/ml/modeling/impute/simple_imputer.py +4 -15
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
- snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/lars.py +33 -61
- snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
- snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/perceptron.py +33 -61
- snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ridge.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
- snowflake/ml/modeling/manifold/isomap.py +33 -61
- snowflake/ml/modeling/manifold/mds.py +33 -61
- snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
- snowflake/ml/modeling/manifold/tsne.py +33 -61
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
- snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
- snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
- snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
- snowflake/ml/modeling/svm/linear_svc.py +33 -61
- snowflake/ml/modeling/svm/linear_svr.py +33 -61
- snowflake/ml/modeling/svm/nu_svc.py +33 -61
- snowflake/ml/modeling/svm/nu_svr.py +33 -61
- snowflake/ml/modeling/svm/svc.py +33 -61
- snowflake/ml/modeling/svm/svr.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
- snowflake/ml/registry/_manager/model_manager.py +6 -2
- snowflake/ml/registry/model_registry.py +100 -27
- snowflake/ml/registry/registry.py +6 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/decomposition/pca.py

@@ -347,18 +347,24 @@ class PCA(BaseTransformer):
         self._get_model_signatures(dataset)
         return self
 
-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -430,7 +436,7 @@ class PCA(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )
 
@@ -492,16 +498,16 @@ class PCA(BaseTransformer):
             # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
             # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
             # each row containing a list of values.
-            expected_dtype = "ARRAY"
+            expected_dtype = "array"
 
             # If we were unable to assign a type to this transform in the factory, infer the type here.
             if expected_dtype == "":
-                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
                 if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                    expected_dtype = "ARRAY"
-                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                    expected_dtype = "array"
+                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
                 elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                    expected_dtype = "ARRAY"
+                    expected_dtype = "array"
             else:
                 output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                 # We can only infer the output types from the input types if the following two statemetns are true:
@@ -519,7 +525,7 @@ class PCA(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_dtype,
             )
 
@@ -570,7 +576,7 @@ class PCA(BaseTransformer):
                 subproject=_SUBPROJECT,
             )
             output_result, fitted_estimator = model_trainer.train_fit_predict(
-                pass_through_columns=self._get_pass_through_columns(dataset),
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
             )
             self._sklearn_object = fitted_estimator
@@ -588,44 +594,6 @@ class PCA(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_
 
-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -665,7 +633,7 @@ class PCA(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -730,7 +698,7 @@ class PCA(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -791,7 +759,7 @@ class PCA(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -858,7 +826,7 @@ class PCA(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -914,13 +882,17 @@ class PCA(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -994,9 +966,9 @@ class PCA(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
-                expected_output_cols_type = "ARRAY",
-                n_neighbors = n_neighbors,
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
             )
         elif isinstance(dataset, pd.DataFrame):
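The recurring change across these hunks is that _batch_inference_validate_snowpark now doubles as dependency resolution: instead of a separate dependency lookup, score() assigns the method's return value to self._deps and splices it into the stored-procedure dependency list. Below is a minimal, runnable sketch of that control flow; the names MyEstimator and declared_deps are illustrative stand-ins, not the library's API, and the channel lookup is stubbed out.

from typing import List


class MyEstimator:
    """Toy stand-in for a generated snowflake-ml estimator (illustrative only)."""

    def __init__(self, declared_deps: List[str]) -> None:
        self._is_fitted = False
        self._declared_deps = declared_deps
        self._deps: List[str] = []

    def _batch_inference_validate_snowpark(self, dataset: object, inference_method: str) -> List[str]:
        # Per the new docstring: validate that batch inference can run, then
        # return the available packages. The real method checks the Snowflake
        # Anaconda channel; this stub returns the declared deps unchanged.
        if not self._is_fitted:
            raise RuntimeError(f"Estimator must be fitted before {inference_method}().")
        return list(self._declared_deps)

    def score(self, dataset: object) -> List[str]:
        # 1.4.0 pattern: the validation result feeds the sproc dependency list.
        self._deps = self._batch_inference_validate_snowpark(dataset, "score")
        return ["snowflake-snowpark-python"] + self._deps


est = MyEstimator(["scikit-learn", "numpy"])
est._is_fitted = True
print(est.score(dataset=None))
# ['snowflake-snowpark-python', 'scikit-learn', 'numpy']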
snowflake/ml/modeling/decomposition/sparse_pca.py

@@ -320,18 +320,24 @@ class SparsePCA(BaseTransformer):
         self._get_model_signatures(dataset)
         return self
 
-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -403,7 +409,7 @@ class SparsePCA(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )
 
@@ -465,16 +471,16 @@ class SparsePCA(BaseTransformer):
             # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
             # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
             # each row containing a list of values.
-            expected_dtype = "ARRAY"
+            expected_dtype = "array"
 
             # If we were unable to assign a type to this transform in the factory, infer the type here.
             if expected_dtype == "":
-                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
                 if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                    expected_dtype = "ARRAY"
-                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                    expected_dtype = "array"
+                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
                 elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                    expected_dtype = "ARRAY"
+                    expected_dtype = "array"
             else:
                 output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                 # We can only infer the output types from the input types if the following two statemetns are true:
@@ -492,7 +498,7 @@ class SparsePCA(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_dtype,
             )
 
@@ -543,7 +549,7 @@ class SparsePCA(BaseTransformer):
                 subproject=_SUBPROJECT,
             )
             output_result, fitted_estimator = model_trainer.train_fit_predict(
-                pass_through_columns=self._get_pass_through_columns(dataset),
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
             )
             self._sklearn_object = fitted_estimator
@@ -561,44 +567,6 @@ class SparsePCA(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_
 
-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -638,7 +606,7 @@ class SparsePCA(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -703,7 +671,7 @@ class SparsePCA(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -764,7 +732,7 @@ class SparsePCA(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -829,7 +797,7 @@ class SparsePCA(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -883,13 +851,17 @@ class SparsePCA(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -963,9 +935,9 @@ class SparsePCA(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
-                expected_output_cols_type = "ARRAY",
-                n_neighbors = n_neighbors,
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
            )
         elif isinstance(dataset, pd.DataFrame):
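SparsePCA repeats the PCA changes verbatim. One detail worth making explicit: the deleted _get_pass_through_columns helper computed the pass-through column set on the client, while 1.4.0 forwards the drop_input_cols flag and leaves column handling to the inference handlers. The removed helper's behavior, restated as a standalone function (hypothetical name; logic copied from the deleted lines):

from typing import List


def pass_through_columns(dataset_columns: List[str], output_cols: List[str],
                         drop_input_cols: bool) -> List[str]:
    # Same logic as the deleted helper: when inputs are kept, every column
    # that is not an output column passes through. Note the original used a
    # set difference, so the result order was unspecified.
    if drop_input_cols:
        return []
    return list(set(dataset_columns) - set(output_cols))


print(sorted(pass_through_columns(["ID", "X1", "OUTPUT"], ["OUTPUT"], drop_input_cols=False)))
# ['ID', 'X1']
print(pass_through_columns(["ID", "X1", "OUTPUT"], ["OUTPUT"], drop_input_cols=True))
# []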
snowflake/ml/modeling/decomposition/truncated_svd.py

@@ -301,18 +301,24 @@ class TruncatedSVD(BaseTransformer):
         self._get_model_signatures(dataset)
         return self
 
-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -384,7 +390,7 @@ class TruncatedSVD(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )
 
@@ -446,16 +452,16 @@ class TruncatedSVD(BaseTransformer):
             # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
             # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
             # each row containing a list of values.
-            expected_dtype = "ARRAY"
+            expected_dtype = "array"
 
             # If we were unable to assign a type to this transform in the factory, infer the type here.
             if expected_dtype == "":
-                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "ARRAY"
+                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
                 if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                    expected_dtype = "ARRAY"
-                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "ARRAY"
+                    expected_dtype = "array"
+                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
                 elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                    expected_dtype = "ARRAY"
+                    expected_dtype = "array"
             else:
                 output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                 # We can only infer the output types from the input types if the following two statemetns are true:
@@ -473,7 +479,7 @@ class TruncatedSVD(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_dtype,
             )
 
@@ -524,7 +530,7 @@ class TruncatedSVD(BaseTransformer):
                 subproject=_SUBPROJECT,
             )
             output_result, fitted_estimator = model_trainer.train_fit_predict(
-                pass_through_columns=self._get_pass_through_columns(dataset),
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
             )
             self._sklearn_object = fitted_estimator
@@ -542,44 +548,6 @@ class TruncatedSVD(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_
 
-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -619,7 +587,7 @@ class TruncatedSVD(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -684,7 +652,7 @@ class TruncatedSVD(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -745,7 +713,7 @@ class TruncatedSVD(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -810,7 +778,7 @@ class TruncatedSVD(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-                pass_through_cols=self._get_pass_through_columns(dataset),
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -864,13 +832,17 @@ class TruncatedSVD(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
        if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self._get_dependencies(),
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -944,9 +916,9 @@ class TruncatedSVD(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-                pass_through_cols = self._get_pass_through_columns(dataset),
-                expected_output_cols_type = "ARRAY",
-                n_neighbors = n_neighbors,
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
             )
         elif isinstance(dataset, pd.DataFrame):
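TruncatedSVD closes out the same template change. All three classes also drop identical private copies of _get_output_column_names; the file list shows snowflake/ml/modeling/_internal/estimator_utils.py growing by 61 lines in the same release, which is consistent with the helper being centralized there, though this excerpt does not show that file's hunks. Its column-naming scheme, restated as a simplified, runnable sketch (hypothetical function name; identifier resolution and renaming omitted):

from typing import List, Optional

import numpy


def output_column_names(prefix: str, classes: Optional[object]) -> List[str]:
    # Mirrors the deleted method: one column per class for classifiers,
    # a single prefix column otherwise.
    if classes is None:
        return [prefix]
    if isinstance(classes, numpy.ndarray):
        return [f"{prefix}{c}" for c in classes.tolist()]
    if isinstance(classes, list) and classes and isinstance(classes[0], numpy.ndarray):
        # Multioutput estimator: classes_ is a list of ndarrays.
        cols: List[str] = []
        for i, cl in enumerate(classes):
            if len(cl) == 2:
                # Binary output i: a single column, since the two classes
                # are complementary.
                cols.append(f"{prefix}{i}_{cl[0]}")
            else:
                cols.extend(f"{prefix}{i}_{c}" for c in cl.tolist())
        return cols
    return []


print(output_column_names("PREDICT_PROBA_", numpy.array([0, 1, 2])))
# ['PREDICT_PROBA_0', 'PREDICT_PROBA_1', 'PREDICT_PROBA_2']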