snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/file_utils.py +3 -3
- snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
- snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
- snowflake/ml/_internal/telemetry.py +11 -2
- snowflake/ml/_internal/utils/formatting.py +1 -1
- snowflake/ml/feature_store/feature_store.py +15 -106
- snowflake/ml/fileset/sfcfs.py +4 -3
- snowflake/ml/fileset/stage_fs.py +18 -0
- snowflake/ml/model/_api.py +9 -9
- snowflake/ml/model/_client/model/model_version_impl.py +20 -15
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
- snowflake/ml/model/_model_composer/model_composer.py +10 -8
- snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
- snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
- snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
- snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_packager.py +8 -6
- snowflake/ml/model/custom_model.py +3 -1
- snowflake/ml/model/type_hints.py +13 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
- snowflake/ml/modeling/_internal/model_specifications.py +3 -1
- snowflake/ml/modeling/_internal/model_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
- snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
- snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
- snowflake/ml/modeling/cluster/birch.py +33 -61
- snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
- snowflake/ml/modeling/cluster/dbscan.py +33 -61
- snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
- snowflake/ml/modeling/cluster/k_means.py +33 -61
- snowflake/ml/modeling/cluster/mean_shift.py +33 -61
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
- snowflake/ml/modeling/cluster/optics.py +33 -61
- snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
- snowflake/ml/modeling/compose/column_transformer.py +33 -61
- snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
- snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
- snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
- snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
- snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
- snowflake/ml/modeling/covariance/oas.py +33 -61
- snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
- snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
- snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
- snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
- snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/pca.py +33 -61
- snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
- snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
- snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
- snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
- snowflake/ml/modeling/framework/base.py +55 -5
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
- snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
- snowflake/ml/modeling/impute/knn_imputer.py +33 -61
- snowflake/ml/modeling/impute/missing_indicator.py +33 -61
- snowflake/ml/modeling/impute/simple_imputer.py +4 -15
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
- snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/lars.py +33 -61
- snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
- snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/perceptron.py +33 -61
- snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ridge.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
- snowflake/ml/modeling/manifold/isomap.py +33 -61
- snowflake/ml/modeling/manifold/mds.py +33 -61
- snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
- snowflake/ml/modeling/manifold/tsne.py +33 -61
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
- snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
- snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
- snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
- snowflake/ml/modeling/svm/linear_svc.py +33 -61
- snowflake/ml/modeling/svm/linear_svr.py +33 -61
- snowflake/ml/modeling/svm/nu_svc.py +33 -61
- snowflake/ml/modeling/svm/nu_svr.py +33 -61
- snowflake/ml/modeling/svm/svc.py +33 -61
- snowflake/ml/modeling/svm/svr.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
- snowflake/ml/registry/_manager/model_manager.py +6 -2
- snowflake/ml/registry/model_registry.py +100 -27
- snowflake/ml/registry/registry.py +6 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
@@ -358,18 +358,24 @@ class RidgeClassifier(BaseTransformer):
|
|
358
358
|
self._get_model_signatures(dataset)
|
359
359
|
return self
|
360
360
|
|
361
|
-
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
362
|
-
if self._drop_input_cols:
|
363
|
-
return []
|
364
|
-
else:
|
365
|
-
return list(set(dataset.columns) - set(self.output_cols))
|
366
|
-
|
367
361
|
def _batch_inference_validate_snowpark(
|
368
362
|
self,
|
369
363
|
dataset: DataFrame,
|
370
364
|
inference_method: str,
|
371
365
|
) -> List[str]:
|
372
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
366
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe and
|
367
|
+
return the available package that exists in the snowflake anaconda channel
|
368
|
+
|
369
|
+
Args:
|
370
|
+
dataset: snowpark dataframe
|
371
|
+
inference_method: the inference method such as predict, score...
|
372
|
+
|
373
|
+
Raises:
|
374
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
375
|
+
SnowflakeMLException: If the session is None, raise error
|
376
|
+
|
377
|
+
Returns:
|
378
|
+
A list of available package that exists in the snowflake anaconda channel
|
373
379
|
"""
|
374
380
|
if not self._is_fitted:
|
375
381
|
raise exceptions.SnowflakeMLException(
|
@@ -443,7 +449,7 @@ class RidgeClassifier(BaseTransformer):
|
|
443
449
|
transform_kwargs = dict(
|
444
450
|
session = dataset._session,
|
445
451
|
dependencies = self._deps,
|
446
|
-
|
452
|
+
drop_input_cols = self._drop_input_cols,
|
447
453
|
expected_output_cols_type = expected_type_inferred,
|
448
454
|
)
|
449
455
|
|
@@ -503,16 +509,16 @@ class RidgeClassifier(BaseTransformer):
|
|
503
509
|
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
|
504
510
|
# based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
|
505
511
|
# each row containing a list of values.
|
506
|
-
expected_dtype = "
|
512
|
+
expected_dtype = "array"
|
507
513
|
|
508
514
|
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
509
515
|
if expected_dtype == "":
|
510
|
-
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
|
516
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
|
511
517
|
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
512
|
-
expected_dtype = "
|
513
|
-
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
|
518
|
+
expected_dtype = "array"
|
519
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
|
514
520
|
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
515
|
-
expected_dtype = "
|
521
|
+
expected_dtype = "array"
|
516
522
|
else:
|
517
523
|
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
518
524
|
# We can only infer the output types from the input types if the following two statemetns are true:
|
@@ -530,7 +536,7 @@ class RidgeClassifier(BaseTransformer):
|
|
530
536
|
transform_kwargs = dict(
|
531
537
|
session = dataset._session,
|
532
538
|
dependencies = self._deps,
|
533
|
-
|
539
|
+
drop_input_cols = self._drop_input_cols,
|
534
540
|
expected_output_cols_type = expected_dtype,
|
535
541
|
)
|
536
542
|
|
@@ -581,7 +587,7 @@ class RidgeClassifier(BaseTransformer):
|
|
581
587
|
subproject=_SUBPROJECT,
|
582
588
|
)
|
583
589
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
584
|
-
|
590
|
+
drop_input_cols=self._drop_input_cols,
|
585
591
|
expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
|
586
592
|
)
|
587
593
|
self._sklearn_object = fitted_estimator
|
@@ -599,44 +605,6 @@ class RidgeClassifier(BaseTransformer):
|
|
599
605
|
assert self._sklearn_object is not None
|
600
606
|
return self._sklearn_object.embedding_
|
601
607
|
|
602
|
-
|
603
|
-
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
604
|
-
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
605
|
-
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
606
|
-
"""
|
607
|
-
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
608
|
-
if output_cols:
|
609
|
-
output_cols = [
|
610
|
-
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
611
|
-
for c in output_cols
|
612
|
-
]
|
613
|
-
elif getattr(self._sklearn_object, "classes_", None) is None:
|
614
|
-
output_cols = [output_cols_prefix]
|
615
|
-
elif self._sklearn_object is not None:
|
616
|
-
classes = self._sklearn_object.classes_
|
617
|
-
if isinstance(classes, numpy.ndarray):
|
618
|
-
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
619
|
-
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
620
|
-
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
621
|
-
output_cols = []
|
622
|
-
for i, cl in enumerate(classes):
|
623
|
-
# For binary classification, there is only one output column for each class
|
624
|
-
# ndarray as the two classes are complementary.
|
625
|
-
if len(cl) == 2:
|
626
|
-
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
627
|
-
else:
|
628
|
-
output_cols.extend([
|
629
|
-
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
630
|
-
])
|
631
|
-
else:
|
632
|
-
output_cols = []
|
633
|
-
|
634
|
-
# Make sure column names are valid snowflake identifiers.
|
635
|
-
assert output_cols is not None # Make MyPy happy
|
636
|
-
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
637
|
-
|
638
|
-
return rv
|
639
|
-
|
640
608
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
641
609
|
@telemetry.send_api_usage_telemetry(
|
642
610
|
project=_PROJECT,
|
@@ -676,7 +644,7 @@ class RidgeClassifier(BaseTransformer):
|
|
676
644
|
transform_kwargs = dict(
|
677
645
|
session=dataset._session,
|
678
646
|
dependencies=self._deps,
|
679
|
-
|
647
|
+
drop_input_cols = self._drop_input_cols,
|
680
648
|
expected_output_cols_type="float",
|
681
649
|
)
|
682
650
|
|
@@ -741,7 +709,7 @@ class RidgeClassifier(BaseTransformer):
|
|
741
709
|
transform_kwargs = dict(
|
742
710
|
session=dataset._session,
|
743
711
|
dependencies=self._deps,
|
744
|
-
|
712
|
+
drop_input_cols = self._drop_input_cols,
|
745
713
|
expected_output_cols_type="float",
|
746
714
|
)
|
747
715
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -804,7 +772,7 @@ class RidgeClassifier(BaseTransformer):
|
|
804
772
|
transform_kwargs = dict(
|
805
773
|
session=dataset._session,
|
806
774
|
dependencies=self._deps,
|
807
|
-
|
775
|
+
drop_input_cols = self._drop_input_cols,
|
808
776
|
expected_output_cols_type="float",
|
809
777
|
)
|
810
778
|
|
@@ -869,7 +837,7 @@ class RidgeClassifier(BaseTransformer):
|
|
869
837
|
transform_kwargs = dict(
|
870
838
|
session=dataset._session,
|
871
839
|
dependencies=self._deps,
|
872
|
-
|
840
|
+
drop_input_cols = self._drop_input_cols,
|
873
841
|
expected_output_cols_type="float",
|
874
842
|
)
|
875
843
|
|
@@ -925,13 +893,17 @@ class RidgeClassifier(BaseTransformer):
|
|
925
893
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
926
894
|
|
927
895
|
if isinstance(dataset, DataFrame):
|
896
|
+
self._deps = self._batch_inference_validate_snowpark(
|
897
|
+
dataset=dataset,
|
898
|
+
inference_method="score",
|
899
|
+
)
|
928
900
|
selected_cols = self._get_active_columns()
|
929
901
|
if len(selected_cols) > 0:
|
930
902
|
dataset = dataset.select(selected_cols)
|
931
903
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
932
904
|
transform_kwargs = dict(
|
933
905
|
session=dataset._session,
|
934
|
-
dependencies=["snowflake-snowpark-python"] + self.
|
906
|
+
dependencies=["snowflake-snowpark-python"] + self._deps,
|
935
907
|
score_sproc_imports=['sklearn'],
|
936
908
|
)
|
937
909
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1005,9 +977,9 @@ class RidgeClassifier(BaseTransformer):
|
|
1005
977
|
transform_kwargs = dict(
|
1006
978
|
session = dataset._session,
|
1007
979
|
dependencies = self._deps,
|
1008
|
-
|
1009
|
-
expected_output_cols_type
|
1010
|
-
n_neighbors =
|
980
|
+
drop_input_cols = self._drop_input_cols,
|
981
|
+
expected_output_cols_type="array",
|
982
|
+
n_neighbors = n_neighbors,
|
1011
983
|
return_distance = return_distance
|
1012
984
|
)
|
1013
985
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -309,18 +309,24 @@ class RidgeClassifierCV(BaseTransformer):
|
|
309
309
|
self._get_model_signatures(dataset)
|
310
310
|
return self
|
311
311
|
|
312
|
-
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
313
|
-
if self._drop_input_cols:
|
314
|
-
return []
|
315
|
-
else:
|
316
|
-
return list(set(dataset.columns) - set(self.output_cols))
|
317
|
-
|
318
312
|
def _batch_inference_validate_snowpark(
|
319
313
|
self,
|
320
314
|
dataset: DataFrame,
|
321
315
|
inference_method: str,
|
322
316
|
) -> List[str]:
|
323
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
317
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe and
|
318
|
+
return the available package that exists in the snowflake anaconda channel
|
319
|
+
|
320
|
+
Args:
|
321
|
+
dataset: snowpark dataframe
|
322
|
+
inference_method: the inference method such as predict, score...
|
323
|
+
|
324
|
+
Raises:
|
325
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
326
|
+
SnowflakeMLException: If the session is None, raise error
|
327
|
+
|
328
|
+
Returns:
|
329
|
+
A list of available package that exists in the snowflake anaconda channel
|
324
330
|
"""
|
325
331
|
if not self._is_fitted:
|
326
332
|
raise exceptions.SnowflakeMLException(
|
@@ -394,7 +400,7 @@ class RidgeClassifierCV(BaseTransformer):
|
|
394
400
|
transform_kwargs = dict(
|
395
401
|
session = dataset._session,
|
396
402
|
dependencies = self._deps,
|
397
|
-
|
403
|
+
drop_input_cols = self._drop_input_cols,
|
398
404
|
expected_output_cols_type = expected_type_inferred,
|
399
405
|
)
|
400
406
|
|
@@ -454,16 +460,16 @@ class RidgeClassifierCV(BaseTransformer):
|
|
454
460
|
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
|
455
461
|
# based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
|
456
462
|
# each row containing a list of values.
|
457
|
-
expected_dtype = "
|
463
|
+
expected_dtype = "array"
|
458
464
|
|
459
465
|
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
460
466
|
if expected_dtype == "":
|
461
|
-
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
|
467
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
|
462
468
|
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
463
|
-
expected_dtype = "
|
464
|
-
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
|
469
|
+
expected_dtype = "array"
|
470
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
|
465
471
|
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
466
|
-
expected_dtype = "
|
472
|
+
expected_dtype = "array"
|
467
473
|
else:
|
468
474
|
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
469
475
|
# We can only infer the output types from the input types if the following two statemetns are true:
|
@@ -481,7 +487,7 @@ class RidgeClassifierCV(BaseTransformer):
|
|
481
487
|
transform_kwargs = dict(
|
482
488
|
session = dataset._session,
|
483
489
|
dependencies = self._deps,
|
484
|
-
|
490
|
+
drop_input_cols = self._drop_input_cols,
|
485
491
|
expected_output_cols_type = expected_dtype,
|
486
492
|
)
|
487
493
|
|
@@ -532,7 +538,7 @@ class RidgeClassifierCV(BaseTransformer):
|
|
532
538
|
subproject=_SUBPROJECT,
|
533
539
|
)
|
534
540
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
535
|
-
|
541
|
+
drop_input_cols=self._drop_input_cols,
|
536
542
|
expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
|
537
543
|
)
|
538
544
|
self._sklearn_object = fitted_estimator
|
@@ -550,44 +556,6 @@ class RidgeClassifierCV(BaseTransformer):
|
|
550
556
|
assert self._sklearn_object is not None
|
551
557
|
return self._sklearn_object.embedding_
|
552
558
|
|
553
|
-
|
554
|
-
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
555
|
-
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
556
|
-
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
557
|
-
"""
|
558
|
-
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
559
|
-
if output_cols:
|
560
|
-
output_cols = [
|
561
|
-
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
562
|
-
for c in output_cols
|
563
|
-
]
|
564
|
-
elif getattr(self._sklearn_object, "classes_", None) is None:
|
565
|
-
output_cols = [output_cols_prefix]
|
566
|
-
elif self._sklearn_object is not None:
|
567
|
-
classes = self._sklearn_object.classes_
|
568
|
-
if isinstance(classes, numpy.ndarray):
|
569
|
-
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
570
|
-
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
571
|
-
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
572
|
-
output_cols = []
|
573
|
-
for i, cl in enumerate(classes):
|
574
|
-
# For binary classification, there is only one output column for each class
|
575
|
-
# ndarray as the two classes are complementary.
|
576
|
-
if len(cl) == 2:
|
577
|
-
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
578
|
-
else:
|
579
|
-
output_cols.extend([
|
580
|
-
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
581
|
-
])
|
582
|
-
else:
|
583
|
-
output_cols = []
|
584
|
-
|
585
|
-
# Make sure column names are valid snowflake identifiers.
|
586
|
-
assert output_cols is not None # Make MyPy happy
|
587
|
-
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
588
|
-
|
589
|
-
return rv
|
590
|
-
|
591
559
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
592
560
|
@telemetry.send_api_usage_telemetry(
|
593
561
|
project=_PROJECT,
|
@@ -627,7 +595,7 @@ class RidgeClassifierCV(BaseTransformer):
|
|
627
595
|
transform_kwargs = dict(
|
628
596
|
session=dataset._session,
|
629
597
|
dependencies=self._deps,
|
630
|
-
|
598
|
+
drop_input_cols = self._drop_input_cols,
|
631
599
|
expected_output_cols_type="float",
|
632
600
|
)
|
633
601
|
|
@@ -692,7 +660,7 @@ class RidgeClassifierCV(BaseTransformer):
|
|
692
660
|
transform_kwargs = dict(
|
693
661
|
session=dataset._session,
|
694
662
|
dependencies=self._deps,
|
695
|
-
|
663
|
+
drop_input_cols = self._drop_input_cols,
|
696
664
|
expected_output_cols_type="float",
|
697
665
|
)
|
698
666
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -755,7 +723,7 @@ class RidgeClassifierCV(BaseTransformer):
|
|
755
723
|
transform_kwargs = dict(
|
756
724
|
session=dataset._session,
|
757
725
|
dependencies=self._deps,
|
758
|
-
|
726
|
+
drop_input_cols = self._drop_input_cols,
|
759
727
|
expected_output_cols_type="float",
|
760
728
|
)
|
761
729
|
|
@@ -820,7 +788,7 @@ class RidgeClassifierCV(BaseTransformer):
|
|
820
788
|
transform_kwargs = dict(
|
821
789
|
session=dataset._session,
|
822
790
|
dependencies=self._deps,
|
823
|
-
|
791
|
+
drop_input_cols = self._drop_input_cols,
|
824
792
|
expected_output_cols_type="float",
|
825
793
|
)
|
826
794
|
|
@@ -876,13 +844,17 @@ class RidgeClassifierCV(BaseTransformer):
|
|
876
844
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
877
845
|
|
878
846
|
if isinstance(dataset, DataFrame):
|
847
|
+
self._deps = self._batch_inference_validate_snowpark(
|
848
|
+
dataset=dataset,
|
849
|
+
inference_method="score",
|
850
|
+
)
|
879
851
|
selected_cols = self._get_active_columns()
|
880
852
|
if len(selected_cols) > 0:
|
881
853
|
dataset = dataset.select(selected_cols)
|
882
854
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
883
855
|
transform_kwargs = dict(
|
884
856
|
session=dataset._session,
|
885
|
-
dependencies=["snowflake-snowpark-python"] + self.
|
857
|
+
dependencies=["snowflake-snowpark-python"] + self._deps,
|
886
858
|
score_sproc_imports=['sklearn'],
|
887
859
|
)
|
888
860
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -956,9 +928,9 @@ class RidgeClassifierCV(BaseTransformer):
|
|
956
928
|
transform_kwargs = dict(
|
957
929
|
session = dataset._session,
|
958
930
|
dependencies = self._deps,
|
959
|
-
|
960
|
-
expected_output_cols_type
|
961
|
-
n_neighbors =
|
931
|
+
drop_input_cols = self._drop_input_cols,
|
932
|
+
expected_output_cols_type="array",
|
933
|
+
n_neighbors = n_neighbors,
|
962
934
|
return_distance = return_distance
|
963
935
|
)
|
964
936
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -330,18 +330,24 @@ class RidgeCV(BaseTransformer):
|
|
330
330
|
self._get_model_signatures(dataset)
|
331
331
|
return self
|
332
332
|
|
333
|
-
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
334
|
-
if self._drop_input_cols:
|
335
|
-
return []
|
336
|
-
else:
|
337
|
-
return list(set(dataset.columns) - set(self.output_cols))
|
338
|
-
|
339
333
|
def _batch_inference_validate_snowpark(
|
340
334
|
self,
|
341
335
|
dataset: DataFrame,
|
342
336
|
inference_method: str,
|
343
337
|
) -> List[str]:
|
344
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
338
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe and
|
339
|
+
return the available package that exists in the snowflake anaconda channel
|
340
|
+
|
341
|
+
Args:
|
342
|
+
dataset: snowpark dataframe
|
343
|
+
inference_method: the inference method such as predict, score...
|
344
|
+
|
345
|
+
Raises:
|
346
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
347
|
+
SnowflakeMLException: If the session is None, raise error
|
348
|
+
|
349
|
+
Returns:
|
350
|
+
A list of available package that exists in the snowflake anaconda channel
|
345
351
|
"""
|
346
352
|
if not self._is_fitted:
|
347
353
|
raise exceptions.SnowflakeMLException(
|
@@ -415,7 +421,7 @@ class RidgeCV(BaseTransformer):
|
|
415
421
|
transform_kwargs = dict(
|
416
422
|
session = dataset._session,
|
417
423
|
dependencies = self._deps,
|
418
|
-
|
424
|
+
drop_input_cols = self._drop_input_cols,
|
419
425
|
expected_output_cols_type = expected_type_inferred,
|
420
426
|
)
|
421
427
|
|
@@ -475,16 +481,16 @@ class RidgeCV(BaseTransformer):
|
|
475
481
|
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
|
476
482
|
# based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
|
477
483
|
# each row containing a list of values.
|
478
|
-
expected_dtype = "
|
484
|
+
expected_dtype = "array"
|
479
485
|
|
480
486
|
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
481
487
|
if expected_dtype == "":
|
482
|
-
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
|
488
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
|
483
489
|
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
484
|
-
expected_dtype = "
|
485
|
-
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
|
490
|
+
expected_dtype = "array"
|
491
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
|
486
492
|
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
487
|
-
expected_dtype = "
|
493
|
+
expected_dtype = "array"
|
488
494
|
else:
|
489
495
|
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
490
496
|
# We can only infer the output types from the input types if the following two statemetns are true:
|
@@ -502,7 +508,7 @@ class RidgeCV(BaseTransformer):
|
|
502
508
|
transform_kwargs = dict(
|
503
509
|
session = dataset._session,
|
504
510
|
dependencies = self._deps,
|
505
|
-
|
511
|
+
drop_input_cols = self._drop_input_cols,
|
506
512
|
expected_output_cols_type = expected_dtype,
|
507
513
|
)
|
508
514
|
|
@@ -553,7 +559,7 @@ class RidgeCV(BaseTransformer):
|
|
553
559
|
subproject=_SUBPROJECT,
|
554
560
|
)
|
555
561
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
556
|
-
|
562
|
+
drop_input_cols=self._drop_input_cols,
|
557
563
|
expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
|
558
564
|
)
|
559
565
|
self._sklearn_object = fitted_estimator
|
@@ -571,44 +577,6 @@ class RidgeCV(BaseTransformer):
|
|
571
577
|
assert self._sklearn_object is not None
|
572
578
|
return self._sklearn_object.embedding_
|
573
579
|
|
574
|
-
|
575
|
-
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
576
|
-
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
577
|
-
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
578
|
-
"""
|
579
|
-
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
580
|
-
if output_cols:
|
581
|
-
output_cols = [
|
582
|
-
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
583
|
-
for c in output_cols
|
584
|
-
]
|
585
|
-
elif getattr(self._sklearn_object, "classes_", None) is None:
|
586
|
-
output_cols = [output_cols_prefix]
|
587
|
-
elif self._sklearn_object is not None:
|
588
|
-
classes = self._sklearn_object.classes_
|
589
|
-
if isinstance(classes, numpy.ndarray):
|
590
|
-
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
591
|
-
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
592
|
-
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
593
|
-
output_cols = []
|
594
|
-
for i, cl in enumerate(classes):
|
595
|
-
# For binary classification, there is only one output column for each class
|
596
|
-
# ndarray as the two classes are complementary.
|
597
|
-
if len(cl) == 2:
|
598
|
-
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
599
|
-
else:
|
600
|
-
output_cols.extend([
|
601
|
-
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
602
|
-
])
|
603
|
-
else:
|
604
|
-
output_cols = []
|
605
|
-
|
606
|
-
# Make sure column names are valid snowflake identifiers.
|
607
|
-
assert output_cols is not None # Make MyPy happy
|
608
|
-
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
609
|
-
|
610
|
-
return rv
|
611
|
-
|
612
580
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
613
581
|
@telemetry.send_api_usage_telemetry(
|
614
582
|
project=_PROJECT,
|
@@ -648,7 +616,7 @@ class RidgeCV(BaseTransformer):
|
|
648
616
|
transform_kwargs = dict(
|
649
617
|
session=dataset._session,
|
650
618
|
dependencies=self._deps,
|
651
|
-
|
619
|
+
drop_input_cols = self._drop_input_cols,
|
652
620
|
expected_output_cols_type="float",
|
653
621
|
)
|
654
622
|
|
@@ -713,7 +681,7 @@ class RidgeCV(BaseTransformer):
|
|
713
681
|
transform_kwargs = dict(
|
714
682
|
session=dataset._session,
|
715
683
|
dependencies=self._deps,
|
716
|
-
|
684
|
+
drop_input_cols = self._drop_input_cols,
|
717
685
|
expected_output_cols_type="float",
|
718
686
|
)
|
719
687
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -774,7 +742,7 @@ class RidgeCV(BaseTransformer):
|
|
774
742
|
transform_kwargs = dict(
|
775
743
|
session=dataset._session,
|
776
744
|
dependencies=self._deps,
|
777
|
-
|
745
|
+
drop_input_cols = self._drop_input_cols,
|
778
746
|
expected_output_cols_type="float",
|
779
747
|
)
|
780
748
|
|
@@ -839,7 +807,7 @@ class RidgeCV(BaseTransformer):
|
|
839
807
|
transform_kwargs = dict(
|
840
808
|
session=dataset._session,
|
841
809
|
dependencies=self._deps,
|
842
|
-
|
810
|
+
drop_input_cols = self._drop_input_cols,
|
843
811
|
expected_output_cols_type="float",
|
844
812
|
)
|
845
813
|
|
@@ -895,13 +863,17 @@ class RidgeCV(BaseTransformer):
|
|
895
863
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
896
864
|
|
897
865
|
if isinstance(dataset, DataFrame):
|
866
|
+
self._deps = self._batch_inference_validate_snowpark(
|
867
|
+
dataset=dataset,
|
868
|
+
inference_method="score",
|
869
|
+
)
|
898
870
|
selected_cols = self._get_active_columns()
|
899
871
|
if len(selected_cols) > 0:
|
900
872
|
dataset = dataset.select(selected_cols)
|
901
873
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
902
874
|
transform_kwargs = dict(
|
903
875
|
session=dataset._session,
|
904
|
-
dependencies=["snowflake-snowpark-python"] + self.
|
876
|
+
dependencies=["snowflake-snowpark-python"] + self._deps,
|
905
877
|
score_sproc_imports=['sklearn'],
|
906
878
|
)
|
907
879
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -975,9 +947,9 @@ class RidgeCV(BaseTransformer):
|
|
975
947
|
transform_kwargs = dict(
|
976
948
|
session = dataset._session,
|
977
949
|
dependencies = self._deps,
|
978
|
-
|
979
|
-
expected_output_cols_type
|
980
|
-
n_neighbors =
|
950
|
+
drop_input_cols = self._drop_input_cols,
|
951
|
+
expected_output_cols_type="array",
|
952
|
+
n_neighbors = n_neighbors,
|
981
953
|
return_distance = return_distance
|
982
954
|
)
|
983
955
|
elif isinstance(dataset, pd.DataFrame):
|