snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/file_utils.py +3 -3
- snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
- snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
- snowflake/ml/_internal/telemetry.py +11 -2
- snowflake/ml/_internal/utils/formatting.py +1 -1
- snowflake/ml/feature_store/feature_store.py +15 -106
- snowflake/ml/fileset/sfcfs.py +4 -3
- snowflake/ml/fileset/stage_fs.py +18 -0
- snowflake/ml/model/_api.py +9 -9
- snowflake/ml/model/_client/model/model_version_impl.py +20 -15
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
- snowflake/ml/model/_model_composer/model_composer.py +10 -8
- snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
- snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
- snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
- snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_packager.py +8 -6
- snowflake/ml/model/custom_model.py +3 -1
- snowflake/ml/model/type_hints.py +13 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
- snowflake/ml/modeling/_internal/model_specifications.py +3 -1
- snowflake/ml/modeling/_internal/model_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
- snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
- snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
- snowflake/ml/modeling/cluster/birch.py +33 -61
- snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
- snowflake/ml/modeling/cluster/dbscan.py +33 -61
- snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
- snowflake/ml/modeling/cluster/k_means.py +33 -61
- snowflake/ml/modeling/cluster/mean_shift.py +33 -61
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
- snowflake/ml/modeling/cluster/optics.py +33 -61
- snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
- snowflake/ml/modeling/compose/column_transformer.py +33 -61
- snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
- snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
- snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
- snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
- snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
- snowflake/ml/modeling/covariance/oas.py +33 -61
- snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
- snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
- snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
- snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
- snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/pca.py +33 -61
- snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
- snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
- snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
- snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
- snowflake/ml/modeling/framework/base.py +55 -5
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
- snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
- snowflake/ml/modeling/impute/knn_imputer.py +33 -61
- snowflake/ml/modeling/impute/missing_indicator.py +33 -61
- snowflake/ml/modeling/impute/simple_imputer.py +4 -15
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
- snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/lars.py +33 -61
- snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
- snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/perceptron.py +33 -61
- snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ridge.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
- snowflake/ml/modeling/manifold/isomap.py +33 -61
- snowflake/ml/modeling/manifold/mds.py +33 -61
- snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
- snowflake/ml/modeling/manifold/tsne.py +33 -61
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
- snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
- snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
- snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
- snowflake/ml/modeling/svm/linear_svc.py +33 -61
- snowflake/ml/modeling/svm/linear_svr.py +33 -61
- snowflake/ml/modeling/svm/nu_svc.py +33 -61
- snowflake/ml/modeling/svm/nu_svr.py +33 -61
- snowflake/ml/modeling/svm/svc.py +33 -61
- snowflake/ml/modeling/svm/svr.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
- snowflake/ml/registry/_manager/model_manager.py +6 -2
- snowflake/ml/registry/model_registry.py +100 -27
- snowflake/ml/registry/registry.py +6 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
@@ -283,18 +283,24 @@ class PolynomialFeatures(BaseTransformer):
|
|
283
283
|
self._get_model_signatures(dataset)
|
284
284
|
return self
|
285
285
|
|
286
|
-
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
287
|
-
if self._drop_input_cols:
|
288
|
-
return []
|
289
|
-
else:
|
290
|
-
return list(set(dataset.columns) - set(self.output_cols))
|
291
|
-
|
292
286
|
def _batch_inference_validate_snowpark(
|
293
287
|
self,
|
294
288
|
dataset: DataFrame,
|
295
289
|
inference_method: str,
|
296
290
|
) -> List[str]:
|
297
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
291
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe and
|
292
|
+
return the available package that exists in the snowflake anaconda channel
|
293
|
+
|
294
|
+
Args:
|
295
|
+
dataset: snowpark dataframe
|
296
|
+
inference_method: the inference method such as predict, score...
|
297
|
+
|
298
|
+
Raises:
|
299
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
300
|
+
SnowflakeMLException: If the session is None, raise error
|
301
|
+
|
302
|
+
Returns:
|
303
|
+
A list of available package that exists in the snowflake anaconda channel
|
298
304
|
"""
|
299
305
|
if not self._is_fitted:
|
300
306
|
raise exceptions.SnowflakeMLException(
|
@@ -366,7 +372,7 @@ class PolynomialFeatures(BaseTransformer):
|
|
366
372
|
transform_kwargs = dict(
|
367
373
|
session = dataset._session,
|
368
374
|
dependencies = self._deps,
|
369
|
-
|
375
|
+
drop_input_cols = self._drop_input_cols,
|
370
376
|
expected_output_cols_type = expected_type_inferred,
|
371
377
|
)
|
372
378
|
|
@@ -428,16 +434,16 @@ class PolynomialFeatures(BaseTransformer):
|
|
428
434
|
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
|
429
435
|
# based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
|
430
436
|
# each row containing a list of values.
|
431
|
-
expected_dtype = "
|
437
|
+
expected_dtype = "array"
|
432
438
|
|
433
439
|
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
434
440
|
if expected_dtype == "":
|
435
|
-
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
|
441
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
|
436
442
|
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
437
|
-
expected_dtype = "
|
438
|
-
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
|
443
|
+
expected_dtype = "array"
|
444
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
|
439
445
|
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
440
|
-
expected_dtype = "
|
446
|
+
expected_dtype = "array"
|
441
447
|
else:
|
442
448
|
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
443
449
|
# We can only infer the output types from the input types if the following two statemetns are true:
|
@@ -455,7 +461,7 @@ class PolynomialFeatures(BaseTransformer):
|
|
455
461
|
transform_kwargs = dict(
|
456
462
|
session = dataset._session,
|
457
463
|
dependencies = self._deps,
|
458
|
-
|
464
|
+
drop_input_cols = self._drop_input_cols,
|
459
465
|
expected_output_cols_type = expected_dtype,
|
460
466
|
)
|
461
467
|
|
@@ -506,7 +512,7 @@ class PolynomialFeatures(BaseTransformer):
|
|
506
512
|
subproject=_SUBPROJECT,
|
507
513
|
)
|
508
514
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
509
|
-
|
515
|
+
drop_input_cols=self._drop_input_cols,
|
510
516
|
expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
|
511
517
|
)
|
512
518
|
self._sklearn_object = fitted_estimator
|
@@ -524,44 +530,6 @@ class PolynomialFeatures(BaseTransformer):
|
|
524
530
|
assert self._sklearn_object is not None
|
525
531
|
return self._sklearn_object.embedding_
|
526
532
|
|
527
|
-
|
528
|
-
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
529
|
-
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
530
|
-
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
531
|
-
"""
|
532
|
-
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
533
|
-
if output_cols:
|
534
|
-
output_cols = [
|
535
|
-
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
536
|
-
for c in output_cols
|
537
|
-
]
|
538
|
-
elif getattr(self._sklearn_object, "classes_", None) is None:
|
539
|
-
output_cols = [output_cols_prefix]
|
540
|
-
elif self._sklearn_object is not None:
|
541
|
-
classes = self._sklearn_object.classes_
|
542
|
-
if isinstance(classes, numpy.ndarray):
|
543
|
-
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
544
|
-
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
545
|
-
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
546
|
-
output_cols = []
|
547
|
-
for i, cl in enumerate(classes):
|
548
|
-
# For binary classification, there is only one output column for each class
|
549
|
-
# ndarray as the two classes are complementary.
|
550
|
-
if len(cl) == 2:
|
551
|
-
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
552
|
-
else:
|
553
|
-
output_cols.extend([
|
554
|
-
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
555
|
-
])
|
556
|
-
else:
|
557
|
-
output_cols = []
|
558
|
-
|
559
|
-
# Make sure column names are valid snowflake identifiers.
|
560
|
-
assert output_cols is not None # Make MyPy happy
|
561
|
-
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
562
|
-
|
563
|
-
return rv
|
564
|
-
|
565
533
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
566
534
|
@telemetry.send_api_usage_telemetry(
|
567
535
|
project=_PROJECT,
|
@@ -601,7 +569,7 @@ class PolynomialFeatures(BaseTransformer):
|
|
601
569
|
transform_kwargs = dict(
|
602
570
|
session=dataset._session,
|
603
571
|
dependencies=self._deps,
|
604
|
-
|
572
|
+
drop_input_cols = self._drop_input_cols,
|
605
573
|
expected_output_cols_type="float",
|
606
574
|
)
|
607
575
|
|
@@ -666,7 +634,7 @@ class PolynomialFeatures(BaseTransformer):
|
|
666
634
|
transform_kwargs = dict(
|
667
635
|
session=dataset._session,
|
668
636
|
dependencies=self._deps,
|
669
|
-
|
637
|
+
drop_input_cols = self._drop_input_cols,
|
670
638
|
expected_output_cols_type="float",
|
671
639
|
)
|
672
640
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -727,7 +695,7 @@ class PolynomialFeatures(BaseTransformer):
|
|
727
695
|
transform_kwargs = dict(
|
728
696
|
session=dataset._session,
|
729
697
|
dependencies=self._deps,
|
730
|
-
|
698
|
+
drop_input_cols = self._drop_input_cols,
|
731
699
|
expected_output_cols_type="float",
|
732
700
|
)
|
733
701
|
|
@@ -792,7 +760,7 @@ class PolynomialFeatures(BaseTransformer):
|
|
792
760
|
transform_kwargs = dict(
|
793
761
|
session=dataset._session,
|
794
762
|
dependencies=self._deps,
|
795
|
-
|
763
|
+
drop_input_cols = self._drop_input_cols,
|
796
764
|
expected_output_cols_type="float",
|
797
765
|
)
|
798
766
|
|
@@ -846,13 +814,17 @@ class PolynomialFeatures(BaseTransformer):
|
|
846
814
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
847
815
|
|
848
816
|
if isinstance(dataset, DataFrame):
|
817
|
+
self._deps = self._batch_inference_validate_snowpark(
|
818
|
+
dataset=dataset,
|
819
|
+
inference_method="score",
|
820
|
+
)
|
849
821
|
selected_cols = self._get_active_columns()
|
850
822
|
if len(selected_cols) > 0:
|
851
823
|
dataset = dataset.select(selected_cols)
|
852
824
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
853
825
|
transform_kwargs = dict(
|
854
826
|
session=dataset._session,
|
855
|
-
dependencies=["snowflake-snowpark-python"] + self.
|
827
|
+
dependencies=["snowflake-snowpark-python"] + self._deps,
|
856
828
|
score_sproc_imports=['sklearn'],
|
857
829
|
)
|
858
830
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -926,9 +898,9 @@ class PolynomialFeatures(BaseTransformer):
|
|
926
898
|
transform_kwargs = dict(
|
927
899
|
session = dataset._session,
|
928
900
|
dependencies = self._deps,
|
929
|
-
|
930
|
-
expected_output_cols_type
|
931
|
-
n_neighbors =
|
901
|
+
drop_input_cols = self._drop_input_cols,
|
902
|
+
expected_output_cols_type="array",
|
903
|
+
n_neighbors = n_neighbors,
|
932
904
|
return_distance = return_distance
|
933
905
|
)
|
934
906
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -289,18 +289,24 @@ class LabelPropagation(BaseTransformer):
|
|
289
289
|
self._get_model_signatures(dataset)
|
290
290
|
return self
|
291
291
|
|
292
|
-
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
293
|
-
if self._drop_input_cols:
|
294
|
-
return []
|
295
|
-
else:
|
296
|
-
return list(set(dataset.columns) - set(self.output_cols))
|
297
|
-
|
298
292
|
def _batch_inference_validate_snowpark(
|
299
293
|
self,
|
300
294
|
dataset: DataFrame,
|
301
295
|
inference_method: str,
|
302
296
|
) -> List[str]:
|
303
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
297
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe and
|
298
|
+
return the available package that exists in the snowflake anaconda channel
|
299
|
+
|
300
|
+
Args:
|
301
|
+
dataset: snowpark dataframe
|
302
|
+
inference_method: the inference method such as predict, score...
|
303
|
+
|
304
|
+
Raises:
|
305
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
306
|
+
SnowflakeMLException: If the session is None, raise error
|
307
|
+
|
308
|
+
Returns:
|
309
|
+
A list of available package that exists in the snowflake anaconda channel
|
304
310
|
"""
|
305
311
|
if not self._is_fitted:
|
306
312
|
raise exceptions.SnowflakeMLException(
|
@@ -374,7 +380,7 @@ class LabelPropagation(BaseTransformer):
|
|
374
380
|
transform_kwargs = dict(
|
375
381
|
session = dataset._session,
|
376
382
|
dependencies = self._deps,
|
377
|
-
|
383
|
+
drop_input_cols = self._drop_input_cols,
|
378
384
|
expected_output_cols_type = expected_type_inferred,
|
379
385
|
)
|
380
386
|
|
@@ -434,16 +440,16 @@ class LabelPropagation(BaseTransformer):
|
|
434
440
|
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
|
435
441
|
# based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
|
436
442
|
# each row containing a list of values.
|
437
|
-
expected_dtype = "
|
443
|
+
expected_dtype = "array"
|
438
444
|
|
439
445
|
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
440
446
|
if expected_dtype == "":
|
441
|
-
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
|
447
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
|
442
448
|
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
443
|
-
expected_dtype = "
|
444
|
-
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
|
449
|
+
expected_dtype = "array"
|
450
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
|
445
451
|
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
446
|
-
expected_dtype = "
|
452
|
+
expected_dtype = "array"
|
447
453
|
else:
|
448
454
|
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
449
455
|
# We can only infer the output types from the input types if the following two statemetns are true:
|
@@ -461,7 +467,7 @@ class LabelPropagation(BaseTransformer):
|
|
461
467
|
transform_kwargs = dict(
|
462
468
|
session = dataset._session,
|
463
469
|
dependencies = self._deps,
|
464
|
-
|
470
|
+
drop_input_cols = self._drop_input_cols,
|
465
471
|
expected_output_cols_type = expected_dtype,
|
466
472
|
)
|
467
473
|
|
@@ -512,7 +518,7 @@ class LabelPropagation(BaseTransformer):
|
|
512
518
|
subproject=_SUBPROJECT,
|
513
519
|
)
|
514
520
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
515
|
-
|
521
|
+
drop_input_cols=self._drop_input_cols,
|
516
522
|
expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
|
517
523
|
)
|
518
524
|
self._sklearn_object = fitted_estimator
|
@@ -530,44 +536,6 @@ class LabelPropagation(BaseTransformer):
|
|
530
536
|
assert self._sklearn_object is not None
|
531
537
|
return self._sklearn_object.embedding_
|
532
538
|
|
533
|
-
|
534
|
-
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
535
|
-
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
536
|
-
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
537
|
-
"""
|
538
|
-
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
539
|
-
if output_cols:
|
540
|
-
output_cols = [
|
541
|
-
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
542
|
-
for c in output_cols
|
543
|
-
]
|
544
|
-
elif getattr(self._sklearn_object, "classes_", None) is None:
|
545
|
-
output_cols = [output_cols_prefix]
|
546
|
-
elif self._sklearn_object is not None:
|
547
|
-
classes = self._sklearn_object.classes_
|
548
|
-
if isinstance(classes, numpy.ndarray):
|
549
|
-
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
550
|
-
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
551
|
-
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
552
|
-
output_cols = []
|
553
|
-
for i, cl in enumerate(classes):
|
554
|
-
# For binary classification, there is only one output column for each class
|
555
|
-
# ndarray as the two classes are complementary.
|
556
|
-
if len(cl) == 2:
|
557
|
-
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
558
|
-
else:
|
559
|
-
output_cols.extend([
|
560
|
-
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
561
|
-
])
|
562
|
-
else:
|
563
|
-
output_cols = []
|
564
|
-
|
565
|
-
# Make sure column names are valid snowflake identifiers.
|
566
|
-
assert output_cols is not None # Make MyPy happy
|
567
|
-
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
568
|
-
|
569
|
-
return rv
|
570
|
-
|
571
539
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
572
540
|
@telemetry.send_api_usage_telemetry(
|
573
541
|
project=_PROJECT,
|
@@ -609,7 +577,7 @@ class LabelPropagation(BaseTransformer):
|
|
609
577
|
transform_kwargs = dict(
|
610
578
|
session=dataset._session,
|
611
579
|
dependencies=self._deps,
|
612
|
-
|
580
|
+
drop_input_cols = self._drop_input_cols,
|
613
581
|
expected_output_cols_type="float",
|
614
582
|
)
|
615
583
|
|
@@ -676,7 +644,7 @@ class LabelPropagation(BaseTransformer):
|
|
676
644
|
transform_kwargs = dict(
|
677
645
|
session=dataset._session,
|
678
646
|
dependencies=self._deps,
|
679
|
-
|
647
|
+
drop_input_cols = self._drop_input_cols,
|
680
648
|
expected_output_cols_type="float",
|
681
649
|
)
|
682
650
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -737,7 +705,7 @@ class LabelPropagation(BaseTransformer):
|
|
737
705
|
transform_kwargs = dict(
|
738
706
|
session=dataset._session,
|
739
707
|
dependencies=self._deps,
|
740
|
-
|
708
|
+
drop_input_cols = self._drop_input_cols,
|
741
709
|
expected_output_cols_type="float",
|
742
710
|
)
|
743
711
|
|
@@ -802,7 +770,7 @@ class LabelPropagation(BaseTransformer):
|
|
802
770
|
transform_kwargs = dict(
|
803
771
|
session=dataset._session,
|
804
772
|
dependencies=self._deps,
|
805
|
-
|
773
|
+
drop_input_cols = self._drop_input_cols,
|
806
774
|
expected_output_cols_type="float",
|
807
775
|
)
|
808
776
|
|
@@ -858,13 +826,17 @@ class LabelPropagation(BaseTransformer):
|
|
858
826
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
859
827
|
|
860
828
|
if isinstance(dataset, DataFrame):
|
829
|
+
self._deps = self._batch_inference_validate_snowpark(
|
830
|
+
dataset=dataset,
|
831
|
+
inference_method="score",
|
832
|
+
)
|
861
833
|
selected_cols = self._get_active_columns()
|
862
834
|
if len(selected_cols) > 0:
|
863
835
|
dataset = dataset.select(selected_cols)
|
864
836
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
865
837
|
transform_kwargs = dict(
|
866
838
|
session=dataset._session,
|
867
|
-
dependencies=["snowflake-snowpark-python"] + self.
|
839
|
+
dependencies=["snowflake-snowpark-python"] + self._deps,
|
868
840
|
score_sproc_imports=['sklearn'],
|
869
841
|
)
|
870
842
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -938,9 +910,9 @@ class LabelPropagation(BaseTransformer):
|
|
938
910
|
transform_kwargs = dict(
|
939
911
|
session = dataset._session,
|
940
912
|
dependencies = self._deps,
|
941
|
-
|
942
|
-
expected_output_cols_type
|
943
|
-
n_neighbors =
|
913
|
+
drop_input_cols = self._drop_input_cols,
|
914
|
+
expected_output_cols_type="array",
|
915
|
+
n_neighbors = n_neighbors,
|
944
916
|
return_distance = return_distance
|
945
917
|
)
|
946
918
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -298,18 +298,24 @@ class LabelSpreading(BaseTransformer):
|
|
298
298
|
self._get_model_signatures(dataset)
|
299
299
|
return self
|
300
300
|
|
301
|
-
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
302
|
-
if self._drop_input_cols:
|
303
|
-
return []
|
304
|
-
else:
|
305
|
-
return list(set(dataset.columns) - set(self.output_cols))
|
306
|
-
|
307
301
|
def _batch_inference_validate_snowpark(
|
308
302
|
self,
|
309
303
|
dataset: DataFrame,
|
310
304
|
inference_method: str,
|
311
305
|
) -> List[str]:
|
312
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
306
|
+
"""Util method to validate that batch inference can be run on a snowpark dataframe and
|
307
|
+
return the available packages that exist in the Snowflake Anaconda channel
|
308
|
+
|
309
|
+
Args:
|
310
|
+
dataset: snowpark dataframe
|
311
|
+
inference_method: the inference method such as predict, score...
|
312
|
+
|
313
|
+
Raises:
|
314
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
315
|
+
SnowflakeMLException: If the session is None, raise error
|
316
|
+
|
317
|
+
Returns:
|
318
|
+
A list of available packages that exist in the Snowflake Anaconda channel
|
313
319
|
"""
|
314
320
|
if not self._is_fitted:
|
315
321
|
raise exceptions.SnowflakeMLException(
|
@@ -383,7 +389,7 @@ class LabelSpreading(BaseTransformer):
|
|
383
389
|
transform_kwargs = dict(
|
384
390
|
session = dataset._session,
|
385
391
|
dependencies = self._deps,
|
386
|
-
|
392
|
+
drop_input_cols = self._drop_input_cols,
|
387
393
|
expected_output_cols_type = expected_type_inferred,
|
388
394
|
)
|
389
395
|
|
@@ -443,16 +449,16 @@ class LabelSpreading(BaseTransformer):
|
|
443
449
|
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
|
444
450
|
# based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
|
445
451
|
# each row containing a list of values.
|
446
|
-
expected_dtype = "
|
452
|
+
expected_dtype = "array"
|
447
453
|
|
448
454
|
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
449
455
|
if expected_dtype == "":
|
450
|
-
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
|
456
|
+
# For a clustering transformer, if the number of output columns does not equal the number of clusters, the response will be an "array"
|
451
457
|
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
452
|
-
expected_dtype = "
|
453
|
-
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
|
458
|
+
expected_dtype = "array"
|
459
|
+
# For a decomposition transformer, if the number of output columns does not equal the number of components, the response will be an "array"
|
454
460
|
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
455
|
-
expected_dtype = "
|
461
|
+
expected_dtype = "array"
|
456
462
|
else:
|
457
463
|
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
458
464
|
# We can only infer the output types from the input types if the following two statements are true:
|
@@ -470,7 +476,7 @@ class LabelSpreading(BaseTransformer):
|
|
470
476
|
transform_kwargs = dict(
|
471
477
|
session = dataset._session,
|
472
478
|
dependencies = self._deps,
|
473
|
-
|
479
|
+
drop_input_cols = self._drop_input_cols,
|
474
480
|
expected_output_cols_type = expected_dtype,
|
475
481
|
)
|
476
482
|
|
@@ -521,7 +527,7 @@ class LabelSpreading(BaseTransformer):
|
|
521
527
|
subproject=_SUBPROJECT,
|
522
528
|
)
|
523
529
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
524
|
-
|
530
|
+
drop_input_cols=self._drop_input_cols,
|
525
531
|
expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
|
526
532
|
)
|
527
533
|
self._sklearn_object = fitted_estimator
|
@@ -539,44 +545,6 @@ class LabelSpreading(BaseTransformer):
|
|
539
545
|
assert self._sklearn_object is not None
|
540
546
|
return self._sklearn_object.embedding_
|
541
547
|
|
542
|
-
|
543
|
-
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
544
|
-
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
545
|
-
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
546
|
-
"""
|
547
|
-
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
548
|
-
if output_cols:
|
549
|
-
output_cols = [
|
550
|
-
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
551
|
-
for c in output_cols
|
552
|
-
]
|
553
|
-
elif getattr(self._sklearn_object, "classes_", None) is None:
|
554
|
-
output_cols = [output_cols_prefix]
|
555
|
-
elif self._sklearn_object is not None:
|
556
|
-
classes = self._sklearn_object.classes_
|
557
|
-
if isinstance(classes, numpy.ndarray):
|
558
|
-
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
559
|
-
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
560
|
-
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
561
|
-
output_cols = []
|
562
|
-
for i, cl in enumerate(classes):
|
563
|
-
# For binary classification, there is only one output column for each class
|
564
|
-
# ndarray as the two classes are complementary.
|
565
|
-
if len(cl) == 2:
|
566
|
-
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
567
|
-
else:
|
568
|
-
output_cols.extend([
|
569
|
-
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
570
|
-
])
|
571
|
-
else:
|
572
|
-
output_cols = []
|
573
|
-
|
574
|
-
# Make sure column names are valid snowflake identifiers.
|
575
|
-
assert output_cols is not None # Make MyPy happy
|
576
|
-
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
577
|
-
|
578
|
-
return rv
|
579
|
-
|
580
548
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
581
549
|
@telemetry.send_api_usage_telemetry(
|
582
550
|
project=_PROJECT,
|
@@ -618,7 +586,7 @@ class LabelSpreading(BaseTransformer):
|
|
618
586
|
transform_kwargs = dict(
|
619
587
|
session=dataset._session,
|
620
588
|
dependencies=self._deps,
|
621
|
-
|
589
|
+
drop_input_cols = self._drop_input_cols,
|
622
590
|
expected_output_cols_type="float",
|
623
591
|
)
|
624
592
|
|
@@ -685,7 +653,7 @@ class LabelSpreading(BaseTransformer):
|
|
685
653
|
transform_kwargs = dict(
|
686
654
|
session=dataset._session,
|
687
655
|
dependencies=self._deps,
|
688
|
-
|
656
|
+
drop_input_cols = self._drop_input_cols,
|
689
657
|
expected_output_cols_type="float",
|
690
658
|
)
|
691
659
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -746,7 +714,7 @@ class LabelSpreading(BaseTransformer):
|
|
746
714
|
transform_kwargs = dict(
|
747
715
|
session=dataset._session,
|
748
716
|
dependencies=self._deps,
|
749
|
-
|
717
|
+
drop_input_cols = self._drop_input_cols,
|
750
718
|
expected_output_cols_type="float",
|
751
719
|
)
|
752
720
|
|
@@ -811,7 +779,7 @@ class LabelSpreading(BaseTransformer):
|
|
811
779
|
transform_kwargs = dict(
|
812
780
|
session=dataset._session,
|
813
781
|
dependencies=self._deps,
|
814
|
-
|
782
|
+
drop_input_cols = self._drop_input_cols,
|
815
783
|
expected_output_cols_type="float",
|
816
784
|
)
|
817
785
|
|
@@ -867,13 +835,17 @@ class LabelSpreading(BaseTransformer):
|
|
867
835
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
868
836
|
|
869
837
|
if isinstance(dataset, DataFrame):
|
838
|
+
self._deps = self._batch_inference_validate_snowpark(
|
839
|
+
dataset=dataset,
|
840
|
+
inference_method="score",
|
841
|
+
)
|
870
842
|
selected_cols = self._get_active_columns()
|
871
843
|
if len(selected_cols) > 0:
|
872
844
|
dataset = dataset.select(selected_cols)
|
873
845
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
874
846
|
transform_kwargs = dict(
|
875
847
|
session=dataset._session,
|
876
|
-
dependencies=["snowflake-snowpark-python"] + self.
|
848
|
+
dependencies=["snowflake-snowpark-python"] + self._deps,
|
877
849
|
score_sproc_imports=['sklearn'],
|
878
850
|
)
|
879
851
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -947,9 +919,9 @@ class LabelSpreading(BaseTransformer):
|
|
947
919
|
transform_kwargs = dict(
|
948
920
|
session = dataset._session,
|
949
921
|
dependencies = self._deps,
|
950
|
-
|
951
|
-
expected_output_cols_type
|
952
|
-
n_neighbors =
|
922
|
+
drop_input_cols = self._drop_input_cols,
|
923
|
+
expected_output_cols_type="array",
|
924
|
+
n_neighbors = n_neighbors,
|
953
925
|
return_distance = return_distance
|
954
926
|
)
|
955
927
|
elif isinstance(dataset, pd.DataFrame):
|