snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/file_utils.py +3 -3
- snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
- snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
- snowflake/ml/_internal/telemetry.py +11 -2
- snowflake/ml/_internal/utils/formatting.py +1 -1
- snowflake/ml/feature_store/feature_store.py +15 -106
- snowflake/ml/fileset/sfcfs.py +4 -3
- snowflake/ml/fileset/stage_fs.py +18 -0
- snowflake/ml/model/_api.py +9 -9
- snowflake/ml/model/_client/model/model_version_impl.py +20 -15
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
- snowflake/ml/model/_model_composer/model_composer.py +10 -8
- snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
- snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
- snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
- snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_packager.py +8 -6
- snowflake/ml/model/custom_model.py +3 -1
- snowflake/ml/model/type_hints.py +13 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
- snowflake/ml/modeling/_internal/model_specifications.py +3 -1
- snowflake/ml/modeling/_internal/model_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
- snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
- snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
- snowflake/ml/modeling/cluster/birch.py +33 -61
- snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
- snowflake/ml/modeling/cluster/dbscan.py +33 -61
- snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
- snowflake/ml/modeling/cluster/k_means.py +33 -61
- snowflake/ml/modeling/cluster/mean_shift.py +33 -61
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
- snowflake/ml/modeling/cluster/optics.py +33 -61
- snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
- snowflake/ml/modeling/compose/column_transformer.py +33 -61
- snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
- snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
- snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
- snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
- snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
- snowflake/ml/modeling/covariance/oas.py +33 -61
- snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
- snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
- snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
- snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
- snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/pca.py +33 -61
- snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
- snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
- snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
- snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
- snowflake/ml/modeling/framework/base.py +55 -5
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
- snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
- snowflake/ml/modeling/impute/knn_imputer.py +33 -61
- snowflake/ml/modeling/impute/missing_indicator.py +33 -61
- snowflake/ml/modeling/impute/simple_imputer.py +4 -15
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
- snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/lars.py +33 -61
- snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
- snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/perceptron.py +33 -61
- snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ridge.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
- snowflake/ml/modeling/manifold/isomap.py +33 -61
- snowflake/ml/modeling/manifold/mds.py +33 -61
- snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
- snowflake/ml/modeling/manifold/tsne.py +33 -61
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
- snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
- snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
- snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
- snowflake/ml/modeling/svm/linear_svc.py +33 -61
- snowflake/ml/modeling/svm/linear_svr.py +33 -61
- snowflake/ml/modeling/svm/nu_svc.py +33 -61
- snowflake/ml/modeling/svm/nu_svr.py +33 -61
- snowflake/ml/modeling/svm/svc.py +33 -61
- snowflake/ml/modeling/svm/svr.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
- snowflake/ml/registry/_manager/model_manager.py +6 -2
- snowflake/ml/registry/model_registry.py +100 -27
- snowflake/ml/registry/registry.py +6 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
@@ -351,18 +351,24 @@ class LassoCV(BaseTransformer):
|
|
351
351
|
self._get_model_signatures(dataset)
|
352
352
|
return self
|
353
353
|
|
354
|
-
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
355
|
-
if self._drop_input_cols:
|
356
|
-
return []
|
357
|
-
else:
|
358
|
-
return list(set(dataset.columns) - set(self.output_cols))
|
359
|
-
|
360
354
|
def _batch_inference_validate_snowpark(
|
361
355
|
self,
|
362
356
|
dataset: DataFrame,
|
363
357
|
inference_method: str,
|
364
358
|
) -> List[str]:
|
365
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
359
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe and
|
360
|
+
return the available package that exists in the snowflake anaconda channel
|
361
|
+
|
362
|
+
Args:
|
363
|
+
dataset: snowpark dataframe
|
364
|
+
inference_method: the inference method such as predict, score...
|
365
|
+
|
366
|
+
Raises:
|
367
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
368
|
+
SnowflakeMLException: If the session is None, raise error
|
369
|
+
|
370
|
+
Returns:
|
371
|
+
A list of available package that exists in the snowflake anaconda channel
|
366
372
|
"""
|
367
373
|
if not self._is_fitted:
|
368
374
|
raise exceptions.SnowflakeMLException(
|
@@ -436,7 +442,7 @@ class LassoCV(BaseTransformer):
|
|
436
442
|
transform_kwargs = dict(
|
437
443
|
session = dataset._session,
|
438
444
|
dependencies = self._deps,
|
439
|
-
|
445
|
+
drop_input_cols = self._drop_input_cols,
|
440
446
|
expected_output_cols_type = expected_type_inferred,
|
441
447
|
)
|
442
448
|
|
@@ -496,16 +502,16 @@ class LassoCV(BaseTransformer):
|
|
496
502
|
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
|
497
503
|
# based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
|
498
504
|
# each row containing a list of values.
|
499
|
-
expected_dtype = "
|
505
|
+
expected_dtype = "array"
|
500
506
|
|
501
507
|
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
502
508
|
if expected_dtype == "":
|
503
|
-
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
|
509
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
|
504
510
|
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
505
|
-
expected_dtype = "
|
506
|
-
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
|
511
|
+
expected_dtype = "array"
|
512
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
|
507
513
|
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
508
|
-
expected_dtype = "
|
514
|
+
expected_dtype = "array"
|
509
515
|
else:
|
510
516
|
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
511
517
|
# We can only infer the output types from the input types if the following two statemetns are true:
|
@@ -523,7 +529,7 @@ class LassoCV(BaseTransformer):
|
|
523
529
|
transform_kwargs = dict(
|
524
530
|
session = dataset._session,
|
525
531
|
dependencies = self._deps,
|
526
|
-
|
532
|
+
drop_input_cols = self._drop_input_cols,
|
527
533
|
expected_output_cols_type = expected_dtype,
|
528
534
|
)
|
529
535
|
|
@@ -574,7 +580,7 @@ class LassoCV(BaseTransformer):
|
|
574
580
|
subproject=_SUBPROJECT,
|
575
581
|
)
|
576
582
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
577
|
-
|
583
|
+
drop_input_cols=self._drop_input_cols,
|
578
584
|
expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
|
579
585
|
)
|
580
586
|
self._sklearn_object = fitted_estimator
|
@@ -592,44 +598,6 @@ class LassoCV(BaseTransformer):
|
|
592
598
|
assert self._sklearn_object is not None
|
593
599
|
return self._sklearn_object.embedding_
|
594
600
|
|
595
|
-
|
596
|
-
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
597
|
-
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
598
|
-
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
599
|
-
"""
|
600
|
-
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
601
|
-
if output_cols:
|
602
|
-
output_cols = [
|
603
|
-
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
604
|
-
for c in output_cols
|
605
|
-
]
|
606
|
-
elif getattr(self._sklearn_object, "classes_", None) is None:
|
607
|
-
output_cols = [output_cols_prefix]
|
608
|
-
elif self._sklearn_object is not None:
|
609
|
-
classes = self._sklearn_object.classes_
|
610
|
-
if isinstance(classes, numpy.ndarray):
|
611
|
-
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
612
|
-
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
613
|
-
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
614
|
-
output_cols = []
|
615
|
-
for i, cl in enumerate(classes):
|
616
|
-
# For binary classification, there is only one output column for each class
|
617
|
-
# ndarray as the two classes are complementary.
|
618
|
-
if len(cl) == 2:
|
619
|
-
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
620
|
-
else:
|
621
|
-
output_cols.extend([
|
622
|
-
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
623
|
-
])
|
624
|
-
else:
|
625
|
-
output_cols = []
|
626
|
-
|
627
|
-
# Make sure column names are valid snowflake identifiers.
|
628
|
-
assert output_cols is not None # Make MyPy happy
|
629
|
-
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
630
|
-
|
631
|
-
return rv
|
632
|
-
|
633
601
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
634
602
|
@telemetry.send_api_usage_telemetry(
|
635
603
|
project=_PROJECT,
|
@@ -669,7 +637,7 @@ class LassoCV(BaseTransformer):
|
|
669
637
|
transform_kwargs = dict(
|
670
638
|
session=dataset._session,
|
671
639
|
dependencies=self._deps,
|
672
|
-
|
640
|
+
drop_input_cols = self._drop_input_cols,
|
673
641
|
expected_output_cols_type="float",
|
674
642
|
)
|
675
643
|
|
@@ -734,7 +702,7 @@ class LassoCV(BaseTransformer):
|
|
734
702
|
transform_kwargs = dict(
|
735
703
|
session=dataset._session,
|
736
704
|
dependencies=self._deps,
|
737
|
-
|
705
|
+
drop_input_cols = self._drop_input_cols,
|
738
706
|
expected_output_cols_type="float",
|
739
707
|
)
|
740
708
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -795,7 +763,7 @@ class LassoCV(BaseTransformer):
|
|
795
763
|
transform_kwargs = dict(
|
796
764
|
session=dataset._session,
|
797
765
|
dependencies=self._deps,
|
798
|
-
|
766
|
+
drop_input_cols = self._drop_input_cols,
|
799
767
|
expected_output_cols_type="float",
|
800
768
|
)
|
801
769
|
|
@@ -860,7 +828,7 @@ class LassoCV(BaseTransformer):
|
|
860
828
|
transform_kwargs = dict(
|
861
829
|
session=dataset._session,
|
862
830
|
dependencies=self._deps,
|
863
|
-
|
831
|
+
drop_input_cols = self._drop_input_cols,
|
864
832
|
expected_output_cols_type="float",
|
865
833
|
)
|
866
834
|
|
@@ -916,13 +884,17 @@ class LassoCV(BaseTransformer):
|
|
916
884
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
917
885
|
|
918
886
|
if isinstance(dataset, DataFrame):
|
887
|
+
self._deps = self._batch_inference_validate_snowpark(
|
888
|
+
dataset=dataset,
|
889
|
+
inference_method="score",
|
890
|
+
)
|
919
891
|
selected_cols = self._get_active_columns()
|
920
892
|
if len(selected_cols) > 0:
|
921
893
|
dataset = dataset.select(selected_cols)
|
922
894
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
923
895
|
transform_kwargs = dict(
|
924
896
|
session=dataset._session,
|
925
|
-
dependencies=["snowflake-snowpark-python"] + self.
|
897
|
+
dependencies=["snowflake-snowpark-python"] + self._deps,
|
926
898
|
score_sproc_imports=['sklearn'],
|
927
899
|
)
|
928
900
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -996,9 +968,9 @@ class LassoCV(BaseTransformer):
|
|
996
968
|
transform_kwargs = dict(
|
997
969
|
session = dataset._session,
|
998
970
|
dependencies = self._deps,
|
999
|
-
|
1000
|
-
expected_output_cols_type
|
1001
|
-
n_neighbors =
|
971
|
+
drop_input_cols = self._drop_input_cols,
|
972
|
+
expected_output_cols_type="array",
|
973
|
+
n_neighbors = n_neighbors,
|
1002
974
|
return_distance = return_distance
|
1003
975
|
)
|
1004
976
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -343,18 +343,24 @@ class LassoLars(BaseTransformer):
|
|
343
343
|
self._get_model_signatures(dataset)
|
344
344
|
return self
|
345
345
|
|
346
|
-
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
347
|
-
if self._drop_input_cols:
|
348
|
-
return []
|
349
|
-
else:
|
350
|
-
return list(set(dataset.columns) - set(self.output_cols))
|
351
|
-
|
352
346
|
def _batch_inference_validate_snowpark(
|
353
347
|
self,
|
354
348
|
dataset: DataFrame,
|
355
349
|
inference_method: str,
|
356
350
|
) -> List[str]:
|
357
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
351
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe and
|
352
|
+
return the available package that exists in the snowflake anaconda channel
|
353
|
+
|
354
|
+
Args:
|
355
|
+
dataset: snowpark dataframe
|
356
|
+
inference_method: the inference method such as predict, score...
|
357
|
+
|
358
|
+
Raises:
|
359
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
360
|
+
SnowflakeMLException: If the session is None, raise error
|
361
|
+
|
362
|
+
Returns:
|
363
|
+
A list of available package that exists in the snowflake anaconda channel
|
358
364
|
"""
|
359
365
|
if not self._is_fitted:
|
360
366
|
raise exceptions.SnowflakeMLException(
|
@@ -428,7 +434,7 @@ class LassoLars(BaseTransformer):
|
|
428
434
|
transform_kwargs = dict(
|
429
435
|
session = dataset._session,
|
430
436
|
dependencies = self._deps,
|
431
|
-
|
437
|
+
drop_input_cols = self._drop_input_cols,
|
432
438
|
expected_output_cols_type = expected_type_inferred,
|
433
439
|
)
|
434
440
|
|
@@ -488,16 +494,16 @@ class LassoLars(BaseTransformer):
|
|
488
494
|
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
|
489
495
|
# based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
|
490
496
|
# each row containing a list of values.
|
491
|
-
expected_dtype = "
|
497
|
+
expected_dtype = "array"
|
492
498
|
|
493
499
|
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
494
500
|
if expected_dtype == "":
|
495
|
-
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
|
501
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
|
496
502
|
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
497
|
-
expected_dtype = "
|
498
|
-
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
|
503
|
+
expected_dtype = "array"
|
504
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
|
499
505
|
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
500
|
-
expected_dtype = "
|
506
|
+
expected_dtype = "array"
|
501
507
|
else:
|
502
508
|
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
503
509
|
# We can only infer the output types from the input types if the following two statemetns are true:
|
@@ -515,7 +521,7 @@ class LassoLars(BaseTransformer):
|
|
515
521
|
transform_kwargs = dict(
|
516
522
|
session = dataset._session,
|
517
523
|
dependencies = self._deps,
|
518
|
-
|
524
|
+
drop_input_cols = self._drop_input_cols,
|
519
525
|
expected_output_cols_type = expected_dtype,
|
520
526
|
)
|
521
527
|
|
@@ -566,7 +572,7 @@ class LassoLars(BaseTransformer):
|
|
566
572
|
subproject=_SUBPROJECT,
|
567
573
|
)
|
568
574
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
569
|
-
|
575
|
+
drop_input_cols=self._drop_input_cols,
|
570
576
|
expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
|
571
577
|
)
|
572
578
|
self._sklearn_object = fitted_estimator
|
@@ -584,44 +590,6 @@ class LassoLars(BaseTransformer):
|
|
584
590
|
assert self._sklearn_object is not None
|
585
591
|
return self._sklearn_object.embedding_
|
586
592
|
|
587
|
-
|
588
|
-
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
589
|
-
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
590
|
-
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
591
|
-
"""
|
592
|
-
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
593
|
-
if output_cols:
|
594
|
-
output_cols = [
|
595
|
-
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
596
|
-
for c in output_cols
|
597
|
-
]
|
598
|
-
elif getattr(self._sklearn_object, "classes_", None) is None:
|
599
|
-
output_cols = [output_cols_prefix]
|
600
|
-
elif self._sklearn_object is not None:
|
601
|
-
classes = self._sklearn_object.classes_
|
602
|
-
if isinstance(classes, numpy.ndarray):
|
603
|
-
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
604
|
-
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
605
|
-
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
606
|
-
output_cols = []
|
607
|
-
for i, cl in enumerate(classes):
|
608
|
-
# For binary classification, there is only one output column for each class
|
609
|
-
# ndarray as the two classes are complementary.
|
610
|
-
if len(cl) == 2:
|
611
|
-
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
612
|
-
else:
|
613
|
-
output_cols.extend([
|
614
|
-
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
615
|
-
])
|
616
|
-
else:
|
617
|
-
output_cols = []
|
618
|
-
|
619
|
-
# Make sure column names are valid snowflake identifiers.
|
620
|
-
assert output_cols is not None # Make MyPy happy
|
621
|
-
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
622
|
-
|
623
|
-
return rv
|
624
|
-
|
625
593
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
626
594
|
@telemetry.send_api_usage_telemetry(
|
627
595
|
project=_PROJECT,
|
@@ -661,7 +629,7 @@ class LassoLars(BaseTransformer):
|
|
661
629
|
transform_kwargs = dict(
|
662
630
|
session=dataset._session,
|
663
631
|
dependencies=self._deps,
|
664
|
-
|
632
|
+
drop_input_cols = self._drop_input_cols,
|
665
633
|
expected_output_cols_type="float",
|
666
634
|
)
|
667
635
|
|
@@ -726,7 +694,7 @@ class LassoLars(BaseTransformer):
|
|
726
694
|
transform_kwargs = dict(
|
727
695
|
session=dataset._session,
|
728
696
|
dependencies=self._deps,
|
729
|
-
|
697
|
+
drop_input_cols = self._drop_input_cols,
|
730
698
|
expected_output_cols_type="float",
|
731
699
|
)
|
732
700
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -787,7 +755,7 @@ class LassoLars(BaseTransformer):
|
|
787
755
|
transform_kwargs = dict(
|
788
756
|
session=dataset._session,
|
789
757
|
dependencies=self._deps,
|
790
|
-
|
758
|
+
drop_input_cols = self._drop_input_cols,
|
791
759
|
expected_output_cols_type="float",
|
792
760
|
)
|
793
761
|
|
@@ -852,7 +820,7 @@ class LassoLars(BaseTransformer):
|
|
852
820
|
transform_kwargs = dict(
|
853
821
|
session=dataset._session,
|
854
822
|
dependencies=self._deps,
|
855
|
-
|
823
|
+
drop_input_cols = self._drop_input_cols,
|
856
824
|
expected_output_cols_type="float",
|
857
825
|
)
|
858
826
|
|
@@ -908,13 +876,17 @@ class LassoLars(BaseTransformer):
|
|
908
876
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
909
877
|
|
910
878
|
if isinstance(dataset, DataFrame):
|
879
|
+
self._deps = self._batch_inference_validate_snowpark(
|
880
|
+
dataset=dataset,
|
881
|
+
inference_method="score",
|
882
|
+
)
|
911
883
|
selected_cols = self._get_active_columns()
|
912
884
|
if len(selected_cols) > 0:
|
913
885
|
dataset = dataset.select(selected_cols)
|
914
886
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
915
887
|
transform_kwargs = dict(
|
916
888
|
session=dataset._session,
|
917
|
-
dependencies=["snowflake-snowpark-python"] + self.
|
889
|
+
dependencies=["snowflake-snowpark-python"] + self._deps,
|
918
890
|
score_sproc_imports=['sklearn'],
|
919
891
|
)
|
920
892
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -988,9 +960,9 @@ class LassoLars(BaseTransformer):
|
|
988
960
|
transform_kwargs = dict(
|
989
961
|
session = dataset._session,
|
990
962
|
dependencies = self._deps,
|
991
|
-
|
992
|
-
expected_output_cols_type
|
993
|
-
n_neighbors =
|
963
|
+
drop_input_cols = self._drop_input_cols,
|
964
|
+
expected_output_cols_type="array",
|
965
|
+
n_neighbors = n_neighbors,
|
994
966
|
return_distance = return_distance
|
995
967
|
)
|
996
968
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -344,18 +344,24 @@ class LassoLarsCV(BaseTransformer):
|
|
344
344
|
self._get_model_signatures(dataset)
|
345
345
|
return self
|
346
346
|
|
347
|
-
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
348
|
-
if self._drop_input_cols:
|
349
|
-
return []
|
350
|
-
else:
|
351
|
-
return list(set(dataset.columns) - set(self.output_cols))
|
352
|
-
|
353
347
|
def _batch_inference_validate_snowpark(
|
354
348
|
self,
|
355
349
|
dataset: DataFrame,
|
356
350
|
inference_method: str,
|
357
351
|
) -> List[str]:
|
358
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
352
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe and
|
353
|
+
return the available package that exists in the snowflake anaconda channel
|
354
|
+
|
355
|
+
Args:
|
356
|
+
dataset: snowpark dataframe
|
357
|
+
inference_method: the inference method such as predict, score...
|
358
|
+
|
359
|
+
Raises:
|
360
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
361
|
+
SnowflakeMLException: If the session is None, raise error
|
362
|
+
|
363
|
+
Returns:
|
364
|
+
A list of available package that exists in the snowflake anaconda channel
|
359
365
|
"""
|
360
366
|
if not self._is_fitted:
|
361
367
|
raise exceptions.SnowflakeMLException(
|
@@ -429,7 +435,7 @@ class LassoLarsCV(BaseTransformer):
|
|
429
435
|
transform_kwargs = dict(
|
430
436
|
session = dataset._session,
|
431
437
|
dependencies = self._deps,
|
432
|
-
|
438
|
+
drop_input_cols = self._drop_input_cols,
|
433
439
|
expected_output_cols_type = expected_type_inferred,
|
434
440
|
)
|
435
441
|
|
@@ -489,16 +495,16 @@ class LassoLarsCV(BaseTransformer):
|
|
489
495
|
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
|
490
496
|
# based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
|
491
497
|
# each row containing a list of values.
|
492
|
-
expected_dtype = "
|
498
|
+
expected_dtype = "array"
|
493
499
|
|
494
500
|
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
495
501
|
if expected_dtype == "":
|
496
|
-
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
|
502
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
|
497
503
|
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
498
|
-
expected_dtype = "
|
499
|
-
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
|
504
|
+
expected_dtype = "array"
|
505
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
|
500
506
|
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
501
|
-
expected_dtype = "
|
507
|
+
expected_dtype = "array"
|
502
508
|
else:
|
503
509
|
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
504
510
|
# We can only infer the output types from the input types if the following two statemetns are true:
|
@@ -516,7 +522,7 @@ class LassoLarsCV(BaseTransformer):
|
|
516
522
|
transform_kwargs = dict(
|
517
523
|
session = dataset._session,
|
518
524
|
dependencies = self._deps,
|
519
|
-
|
525
|
+
drop_input_cols = self._drop_input_cols,
|
520
526
|
expected_output_cols_type = expected_dtype,
|
521
527
|
)
|
522
528
|
|
@@ -567,7 +573,7 @@ class LassoLarsCV(BaseTransformer):
|
|
567
573
|
subproject=_SUBPROJECT,
|
568
574
|
)
|
569
575
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
570
|
-
|
576
|
+
drop_input_cols=self._drop_input_cols,
|
571
577
|
expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
|
572
578
|
)
|
573
579
|
self._sklearn_object = fitted_estimator
|
@@ -585,44 +591,6 @@ class LassoLarsCV(BaseTransformer):
|
|
585
591
|
assert self._sklearn_object is not None
|
586
592
|
return self._sklearn_object.embedding_
|
587
593
|
|
588
|
-
|
589
|
-
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
590
|
-
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
591
|
-
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
592
|
-
"""
|
593
|
-
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
594
|
-
if output_cols:
|
595
|
-
output_cols = [
|
596
|
-
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
597
|
-
for c in output_cols
|
598
|
-
]
|
599
|
-
elif getattr(self._sklearn_object, "classes_", None) is None:
|
600
|
-
output_cols = [output_cols_prefix]
|
601
|
-
elif self._sklearn_object is not None:
|
602
|
-
classes = self._sklearn_object.classes_
|
603
|
-
if isinstance(classes, numpy.ndarray):
|
604
|
-
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
605
|
-
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
606
|
-
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
607
|
-
output_cols = []
|
608
|
-
for i, cl in enumerate(classes):
|
609
|
-
# For binary classification, there is only one output column for each class
|
610
|
-
# ndarray as the two classes are complementary.
|
611
|
-
if len(cl) == 2:
|
612
|
-
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
613
|
-
else:
|
614
|
-
output_cols.extend([
|
615
|
-
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
616
|
-
])
|
617
|
-
else:
|
618
|
-
output_cols = []
|
619
|
-
|
620
|
-
# Make sure column names are valid snowflake identifiers.
|
621
|
-
assert output_cols is not None # Make MyPy happy
|
622
|
-
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
623
|
-
|
624
|
-
return rv
|
625
|
-
|
626
594
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
627
595
|
@telemetry.send_api_usage_telemetry(
|
628
596
|
project=_PROJECT,
|
@@ -662,7 +630,7 @@ class LassoLarsCV(BaseTransformer):
|
|
662
630
|
transform_kwargs = dict(
|
663
631
|
session=dataset._session,
|
664
632
|
dependencies=self._deps,
|
665
|
-
|
633
|
+
drop_input_cols = self._drop_input_cols,
|
666
634
|
expected_output_cols_type="float",
|
667
635
|
)
|
668
636
|
|
@@ -727,7 +695,7 @@ class LassoLarsCV(BaseTransformer):
|
|
727
695
|
transform_kwargs = dict(
|
728
696
|
session=dataset._session,
|
729
697
|
dependencies=self._deps,
|
730
|
-
|
698
|
+
drop_input_cols = self._drop_input_cols,
|
731
699
|
expected_output_cols_type="float",
|
732
700
|
)
|
733
701
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -788,7 +756,7 @@ class LassoLarsCV(BaseTransformer):
|
|
788
756
|
transform_kwargs = dict(
|
789
757
|
session=dataset._session,
|
790
758
|
dependencies=self._deps,
|
791
|
-
|
759
|
+
drop_input_cols = self._drop_input_cols,
|
792
760
|
expected_output_cols_type="float",
|
793
761
|
)
|
794
762
|
|
@@ -853,7 +821,7 @@ class LassoLarsCV(BaseTransformer):
|
|
853
821
|
transform_kwargs = dict(
|
854
822
|
session=dataset._session,
|
855
823
|
dependencies=self._deps,
|
856
|
-
|
824
|
+
drop_input_cols = self._drop_input_cols,
|
857
825
|
expected_output_cols_type="float",
|
858
826
|
)
|
859
827
|
|
@@ -909,13 +877,17 @@ class LassoLarsCV(BaseTransformer):
|
|
909
877
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
910
878
|
|
911
879
|
if isinstance(dataset, DataFrame):
|
880
|
+
self._deps = self._batch_inference_validate_snowpark(
|
881
|
+
dataset=dataset,
|
882
|
+
inference_method="score",
|
883
|
+
)
|
912
884
|
selected_cols = self._get_active_columns()
|
913
885
|
if len(selected_cols) > 0:
|
914
886
|
dataset = dataset.select(selected_cols)
|
915
887
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
916
888
|
transform_kwargs = dict(
|
917
889
|
session=dataset._session,
|
918
|
-
dependencies=["snowflake-snowpark-python"] + self.
|
890
|
+
dependencies=["snowflake-snowpark-python"] + self._deps,
|
919
891
|
score_sproc_imports=['sklearn'],
|
920
892
|
)
|
921
893
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -989,9 +961,9 @@ class LassoLarsCV(BaseTransformer):
|
|
989
961
|
transform_kwargs = dict(
|
990
962
|
session = dataset._session,
|
991
963
|
dependencies = self._deps,
|
992
|
-
|
993
|
-
expected_output_cols_type
|
994
|
-
n_neighbors =
|
964
|
+
drop_input_cols = self._drop_input_cols,
|
965
|
+
expected_output_cols_type="array",
|
966
|
+
n_neighbors = n_neighbors,
|
995
967
|
return_distance = return_distance
|
996
968
|
)
|
997
969
|
elif isinstance(dataset, pd.DataFrame):
|