snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/file_utils.py +3 -3
- snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
- snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
- snowflake/ml/_internal/telemetry.py +11 -2
- snowflake/ml/_internal/utils/formatting.py +1 -1
- snowflake/ml/feature_store/feature_store.py +15 -106
- snowflake/ml/fileset/sfcfs.py +4 -3
- snowflake/ml/fileset/stage_fs.py +18 -0
- snowflake/ml/model/_api.py +9 -9
- snowflake/ml/model/_client/model/model_version_impl.py +20 -15
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
- snowflake/ml/model/_model_composer/model_composer.py +10 -8
- snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
- snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
- snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
- snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_packager.py +8 -6
- snowflake/ml/model/custom_model.py +3 -1
- snowflake/ml/model/type_hints.py +13 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
- snowflake/ml/modeling/_internal/model_specifications.py +3 -1
- snowflake/ml/modeling/_internal/model_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
- snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
- snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
- snowflake/ml/modeling/cluster/birch.py +33 -61
- snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
- snowflake/ml/modeling/cluster/dbscan.py +33 -61
- snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
- snowflake/ml/modeling/cluster/k_means.py +33 -61
- snowflake/ml/modeling/cluster/mean_shift.py +33 -61
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
- snowflake/ml/modeling/cluster/optics.py +33 -61
- snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
- snowflake/ml/modeling/compose/column_transformer.py +33 -61
- snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
- snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
- snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
- snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
- snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
- snowflake/ml/modeling/covariance/oas.py +33 -61
- snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
- snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
- snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
- snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
- snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/pca.py +33 -61
- snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
- snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
- snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
- snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
- snowflake/ml/modeling/framework/base.py +55 -5
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
- snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
- snowflake/ml/modeling/impute/knn_imputer.py +33 -61
- snowflake/ml/modeling/impute/missing_indicator.py +33 -61
- snowflake/ml/modeling/impute/simple_imputer.py +4 -15
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
- snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/lars.py +33 -61
- snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
- snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/perceptron.py +33 -61
- snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ridge.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
- snowflake/ml/modeling/manifold/isomap.py +33 -61
- snowflake/ml/modeling/manifold/mds.py +33 -61
- snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
- snowflake/ml/modeling/manifold/tsne.py +33 -61
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
- snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
- snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
- snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
- snowflake/ml/modeling/svm/linear_svc.py +33 -61
- snowflake/ml/modeling/svm/linear_svr.py +33 -61
- snowflake/ml/modeling/svm/nu_svc.py +33 -61
- snowflake/ml/modeling/svm/nu_svr.py +33 -61
- snowflake/ml/modeling/svm/svc.py +33 -61
- snowflake/ml/modeling/svm/svr.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
- snowflake/ml/registry/_manager/model_manager.py +6 -2
- snowflake/ml/registry/model_registry.py +100 -27
- snowflake/ml/registry/registry.py +6 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
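Among the additions, snowflake/ml/_internal/human_readable_id/ ships two 128-entry word lists (adjectives.txt, animals.txt) plus a generator. The generator's implementation is not shown in this diff; the sketch below only illustrates how an adjective-animal ID scheme along these lines typically works, and every name in it is hypothetical.

    # Minimal sketch only; the real hrid_generator.py is not shown in this
    # diff, and all names here are hypothetical stand-ins.
    import random

    # Stand-ins for adjectives.txt / animals.txt (each 128 entries in the
    # package, i.e. 7 bits of entropy per word).
    ADJECTIVES = ["amber", "brisk", "calm", "deft"]
    ANIMALS = ["lynx", "otter", "heron", "vole"]

    def generate_hrid(rng: random.Random) -> str:
        # Produces ids of the form "brisk_otter_42".
        return f"{rng.choice(ADJECTIVES)}_{rng.choice(ANIMALS)}_{rng.randrange(100)}"

    print(generate_hrid(random.Random()))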
snowflake/ml/modeling/linear_model/ard_regression.py

@@ -319,18 +319,24 @@ class ARDRegression(BaseTransformer):
         self._get_model_signatures(dataset)
         return self
 
-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -404,7 +410,7 @@ class ARDRegression(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )
 
@@ -464,16 +470,16 @@ class ARDRegression(BaseTransformer):
             # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
             # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
             # each row containing a list of values.
-            expected_dtype = "
+            expected_dtype = "array"
 
         # If we were unable to assign a type to this transform in the factory, infer the type here.
         if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
             if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
             elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "
+                expected_dtype = "array"
             else:
                 output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                 # We can only infer the output types from the input types if the following two statemetns are true:
@@ -491,7 +497,7 @@ class ARDRegression(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_dtype,
             )
 
@@ -542,7 +548,7 @@ class ARDRegression(BaseTransformer):
                 subproject=_SUBPROJECT,
             )
             output_result, fitted_estimator = model_trainer.train_fit_predict(
-
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
             )
             self._sklearn_object = fitted_estimator
@@ -560,44 +566,6 @@ class ARDRegression(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_
 
-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -637,7 +605,7 @@ class ARDRegression(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -702,7 +670,7 @@ class ARDRegression(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -763,7 +731,7 @@ class ARDRegression(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -828,7 +796,7 @@ class ARDRegression(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -884,13 +852,17 @@ class ARDRegression(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self.
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -964,9 +936,9 @@ class ARDRegression(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
-                expected_output_cols_type
-                n_neighbors =
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
             )
         elif isinstance(dataset, pd.DataFrame):
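The hunks above repeat, nearly verbatim, in every estimator file with a +33 -61 count in the list at the top. The net effect, condensed into a self-contained sketch (the helper and attribute names come straight from the diff; the wrapper function and its signature are illustrative):

    # Condensed sketch of the 1.4.0 refactor shown above; not a complete class.
    from typing import Any, Dict

    def build_score_kwargs(estimator: Any, dataset: Any) -> Dict[str, Any]:
        # 1.4.0: _batch_inference_validate_snowpark() now also returns the
        # packages available in the Snowflake Anaconda channel, and callers
        # cache the result on the estimator instead of rebuilding dependencies.
        estimator._deps = estimator._batch_inference_validate_snowpark(
            dataset=dataset,
            inference_method="score",
        )
        # drop_input_cols is forwarded directly; the removed
        # _get_pass_through_columns() helper no longer computes the
        # pass-through column set on the client side.
        return dict(
            session=dataset._session,
            dependencies=["snowflake-snowpark-python"] + estimator._deps,
            score_sproc_imports=["sklearn"],
        )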
snowflake/ml/modeling/linear_model/bayesian_ridge.py

@@ -330,18 +330,24 @@ class BayesianRidge(BaseTransformer):
         self._get_model_signatures(dataset)
         return self
 
-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -415,7 +421,7 @@ class BayesianRidge(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )
 
@@ -475,16 +481,16 @@ class BayesianRidge(BaseTransformer):
             # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
             # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
             # each row containing a list of values.
-            expected_dtype = "
+            expected_dtype = "array"
 
         # If we were unable to assign a type to this transform in the factory, infer the type here.
         if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
             elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "
+                expected_dtype = "array"
             else:
                 output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                 # We can only infer the output types from the input types if the following two statemetns are true:
@@ -502,7 +508,7 @@ class BayesianRidge(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_dtype,
             )
 
@@ -553,7 +559,7 @@ class BayesianRidge(BaseTransformer):
                 subproject=_SUBPROJECT,
             )
             output_result, fitted_estimator = model_trainer.train_fit_predict(
-
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
             )
             self._sklearn_object = fitted_estimator
@@ -571,44 +577,6 @@ class BayesianRidge(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_
 
-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -648,7 +616,7 @@ class BayesianRidge(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -713,7 +681,7 @@ class BayesianRidge(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -774,7 +742,7 @@ class BayesianRidge(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -839,7 +807,7 @@ class BayesianRidge(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -895,13 +863,17 @@ class BayesianRidge(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
         if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self.
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -975,9 +947,9 @@ class BayesianRidge(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
-                expected_output_cols_type
-                n_neighbors =
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
             )
         elif isinstance(dataset, pd.DataFrame):
snowflake/ml/modeling/linear_model/elastic_net.py

@@ -329,18 +329,24 @@ class ElasticNet(BaseTransformer):
         self._get_model_signatures(dataset)
         return self
 
-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -414,7 +420,7 @@ class ElasticNet(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )
 
@@ -474,16 +480,16 @@ class ElasticNet(BaseTransformer):
             # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
             # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
             # each row containing a list of values.
-            expected_dtype = "
+            expected_dtype = "array"
 
         # If we were unable to assign a type to this transform in the factory, infer the type here.
         if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
             if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
             elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "
+                expected_dtype = "array"
             else:
                 output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                 # We can only infer the output types from the input types if the following two statemetns are true:
@@ -501,7 +507,7 @@ class ElasticNet(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_dtype,
             )
 
@@ -552,7 +558,7 @@ class ElasticNet(BaseTransformer):
                 subproject=_SUBPROJECT,
             )
             output_result, fitted_estimator = model_trainer.train_fit_predict(
-
+                drop_input_cols=self._drop_input_cols,
                 expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
             )
             self._sklearn_object = fitted_estimator
@@ -570,44 +576,6 @@ class ElasticNet(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_
 
-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -647,7 +615,7 @@ class ElasticNet(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -712,7 +680,7 @@ class ElasticNet(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -773,7 +741,7 @@ class ElasticNet(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -838,7 +806,7 @@ class ElasticNet(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
 
@@ -894,13 +862,17 @@ class ElasticNet(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()
 
        if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self.
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -974,9 +946,9 @@ class ElasticNet(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
-                expected_output_cols_type
-                n_neighbors =
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
             )
         elif isinstance(dataset, pd.DataFrame):
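All of the removed members (_get_pass_through_columns, _get_output_column_names) are private, so the public estimator surface should be unchanged by this refactor. A usage sketch against the documented snowflake.ml.modeling API, with illustrative column names and a toy pandas frame (the same estimators also accept Snowpark DataFrames inside a session):

    import pandas as pd
    from snowflake.ml.modeling.linear_model import ElasticNet

    # Toy data; column names are placeholders.
    train_df = pd.DataFrame({
        "FEATURE_1": [1.0, 2.0, 3.0, 4.0],
        "FEATURE_2": [0.5, 1.5, 2.5, 3.5],
        "TARGET":    [1.0, 2.0, 3.0, 4.0],
    })

    regressor = ElasticNet(
        input_cols=["FEATURE_1", "FEATURE_2"],
        label_cols=["TARGET"],
        output_cols=["PREDICTION"],
        drop_input_cols=False,  # the flag now threaded through transform_kwargs internally
    )
    regressor.fit(train_df)
    print(regressor.predict(train_df))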