snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- snowflake/ml/_internal/file_utils.py +3 -3
- snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
- snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
- snowflake/ml/_internal/telemetry.py +11 -2
- snowflake/ml/_internal/utils/formatting.py +1 -1
- snowflake/ml/feature_store/feature_store.py +15 -106
- snowflake/ml/fileset/sfcfs.py +4 -3
- snowflake/ml/fileset/stage_fs.py +18 -0
- snowflake/ml/model/_api.py +9 -9
- snowflake/ml/model/_client/model/model_version_impl.py +20 -15
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
- snowflake/ml/model/_model_composer/model_composer.py +10 -8
- snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
- snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
- snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
- snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_packager.py +8 -6
- snowflake/ml/model/custom_model.py +3 -1
- snowflake/ml/model/type_hints.py +13 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
- snowflake/ml/modeling/_internal/model_specifications.py +3 -1
- snowflake/ml/modeling/_internal/model_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
- snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
- snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
- snowflake/ml/modeling/cluster/birch.py +33 -61
- snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
- snowflake/ml/modeling/cluster/dbscan.py +33 -61
- snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
- snowflake/ml/modeling/cluster/k_means.py +33 -61
- snowflake/ml/modeling/cluster/mean_shift.py +33 -61
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
- snowflake/ml/modeling/cluster/optics.py +33 -61
- snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
- snowflake/ml/modeling/compose/column_transformer.py +33 -61
- snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
- snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
- snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
- snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
- snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
- snowflake/ml/modeling/covariance/oas.py +33 -61
- snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
- snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
- snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
- snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
- snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/pca.py +33 -61
- snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
- snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
- snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
- snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
- snowflake/ml/modeling/framework/base.py +55 -5
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
- snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
- snowflake/ml/modeling/impute/knn_imputer.py +33 -61
- snowflake/ml/modeling/impute/missing_indicator.py +33 -61
- snowflake/ml/modeling/impute/simple_imputer.py +4 -15
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
- snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/lars.py +33 -61
- snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
- snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/perceptron.py +33 -61
- snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ridge.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
- snowflake/ml/modeling/manifold/isomap.py +33 -61
- snowflake/ml/modeling/manifold/mds.py +33 -61
- snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
- snowflake/ml/modeling/manifold/tsne.py +33 -61
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
- snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
- snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
- snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
- snowflake/ml/modeling/svm/linear_svc.py +33 -61
- snowflake/ml/modeling/svm/linear_svr.py +33 -61
- snowflake/ml/modeling/svm/nu_svc.py +33 -61
- snowflake/ml/modeling/svm/nu_svr.py +33 -61
- snowflake/ml/modeling/svm/svc.py +33 -61
- snowflake/ml/modeling/svm/svr.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
- snowflake/ml/registry/_manager/model_manager.py +6 -2
- snowflake/ml/registry/model_registry.py +100 -27
- snowflake/ml/registry/registry.py +6 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
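Two additions stand out in the file list. The new `model_handlers/sentence_transformers.py` (+214 lines) indicates first-class support for logging sentence-transformers models, and the new `human_readable_id` module (adjective/animal word lists plus an HRID generator) suggests auto-generated, human-readable version names. A hedged sketch of how the new handler would likely be exercised — the model name, column name, and sample data are illustrative, and `session` is assumed to be an existing `snowflake.snowpark.Session`:

```python
import pandas as pd
from sentence_transformers import SentenceTransformer
from snowflake.ml.registry import Registry

model = SentenceTransformer("all-MiniLM-L6-v2")
registry = Registry(session=session)

# version_name omitted: the new human_readable_id module suggests 1.4.0 can
# generate an adjective-animal style version name when none is supplied.
model_version = registry.log_model(
    model,
    model_name="MINILM_EMBEDDER",
    sample_input_data=pd.DataFrame({"TEXT": ["What is Snowpark ML?"]}),
)
```

Likewise, `distributed_hpo_trainer.py` (+547 lines) points at distributed hyperparameter search. Assuming the modeling wrappers keep their sklearn-style surface with explicit column bindings, a hedged usage sketch (column names illustrative):

```python
from snowflake.ml.modeling.model_selection import GridSearchCV
from snowflake.ml.modeling.xgboost import XGBClassifier

grid = GridSearchCV(
    estimator=XGBClassifier(),
    param_grid={"n_estimators": [50, 100], "max_depth": [3, 5]},
    input_cols=["F1", "F2"],
    label_cols=["LABEL"],
    output_cols=["PREDICTION"],
)
grid.fit(train_df)  # train_df: a Snowpark DataFrame; fit runs in-warehouse
```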
snowflake/ml/modeling/covariance/oas.py
@@ -263,18 +263,24 @@ class OAS(BaseTransformer):
         self._get_model_signatures(dataset)
         return self

-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -346,7 +352,7 @@ class OAS(BaseTransformer):
            transform_kwargs = dict(
                session = dataset._session,
                dependencies = self._deps,
-
+               drop_input_cols = self._drop_input_cols,
                expected_output_cols_type = expected_type_inferred,
            )

@@ -406,16 +412,16 @@ class OAS(BaseTransformer):
         # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
         # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
         # each row containing a list of values.
-        expected_dtype = "
+        expected_dtype = "array"

         # If we were unable to assign a type to this transform in the factory, infer the type here.
         if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
             if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
             elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "
+                expected_dtype = "array"
             else:
                 output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                 # We can only infer the output types from the input types if the following two statemetns are true:
@@ -433,7 +439,7 @@ class OAS(BaseTransformer):
            transform_kwargs = dict(
                session = dataset._session,
                dependencies = self._deps,
-
+               drop_input_cols = self._drop_input_cols,
                expected_output_cols_type = expected_dtype,
            )

@@ -484,7 +490,7 @@ class OAS(BaseTransformer):
                subproject=_SUBPROJECT,
            )
            output_result, fitted_estimator = model_trainer.train_fit_predict(
-
+               drop_input_cols=self._drop_input_cols,
                expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
            )
            self._sklearn_object = fitted_estimator
@@ -502,44 +508,6 @@ class OAS(BaseTransformer):
        assert self._sklearn_object is not None
        return self._sklearn_object.embedding_

-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-        else:
-            output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
    @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
    @telemetry.send_api_usage_telemetry(
        project=_PROJECT,
@@ -579,7 +547,7 @@ class OAS(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-
+               drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )

@@ -644,7 +612,7 @@ class OAS(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-
+               drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )
        elif isinstance(dataset, pd.DataFrame):
@@ -705,7 +673,7 @@ class OAS(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-
+               drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )

@@ -770,7 +738,7 @@ class OAS(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-
+               drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )

@@ -826,13 +794,17 @@ class OAS(BaseTransformer):
        transform_kwargs: ScoreKwargsTypedDict = dict()

        if isinstance(dataset, DataFrame):
+           self._deps = self._batch_inference_validate_snowpark(
+               dataset=dataset,
+               inference_method="score",
+           )
            selected_cols = self._get_active_columns()
            if len(selected_cols) > 0:
                dataset = dataset.select(selected_cols)
            assert isinstance(dataset._session, Session)  # keep mypy happy
            transform_kwargs = dict(
                session=dataset._session,
-               dependencies=["snowflake-snowpark-python"] + self.
+               dependencies=["snowflake-snowpark-python"] + self._deps,
                score_sproc_imports=['sklearn'],
            )
        elif isinstance(dataset, pd.DataFrame):
@@ -906,9 +878,9 @@ class OAS(BaseTransformer):
            transform_kwargs = dict(
                session = dataset._session,
                dependencies = self._deps,
-
-               expected_output_cols_type
-               n_neighbors =
+               drop_input_cols = self._drop_input_cols,
+               expected_output_cols_type="array",
+               n_neighbors = n_neighbors,
                return_distance = return_distance
            )
        elif isinstance(dataset, pd.DataFrame):
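The OAS hunks above all serve one refactor: `_batch_inference_validate_snowpark` now returns the packages it resolved against the Snowflake Anaconda channel, callers cache that list on `self._deps`, the per-estimator `_get_pass_through_columns` and `_get_output_column_names` helpers are deleted, and every `transform_kwargs` gains `drop_input_cols`. A condensed sketch of the resulting `score` path on a Snowpark DataFrame, assembled from the hunks (telemetry and the pandas branch omitted; not the verbatim generated code):

```python
# Inside score(); net result after applying the hunks above.
transform_kwargs: ScoreKwargsTypedDict = dict()
if isinstance(dataset, DataFrame):
    # New in 1.4.0: validation also returns the packages resolved against
    # the Snowflake Anaconda channel; the result is cached on self._deps.
    self._deps = self._batch_inference_validate_snowpark(
        dataset=dataset,
        inference_method="score",
    )
    selected_cols = self._get_active_columns()
    if len(selected_cols) > 0:
        dataset = dataset.select(selected_cols)
    assert isinstance(dataset._session, Session)  # keep mypy happy
    transform_kwargs = dict(
        session=dataset._session,
        dependencies=["snowflake-snowpark-python"] + self._deps,
        score_sproc_imports=["sklearn"],
    )
```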
snowflake/ml/modeling/covariance/shrunk_covariance.py
@@ -269,18 +269,24 @@ class ShrunkCovariance(BaseTransformer):
         self._get_model_signatures(dataset)
         return self

-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -352,7 +358,7 @@ class ShrunkCovariance(BaseTransformer):
            transform_kwargs = dict(
                session = dataset._session,
                dependencies = self._deps,
-
+               drop_input_cols = self._drop_input_cols,
                expected_output_cols_type = expected_type_inferred,
            )

@@ -412,16 +418,16 @@ class ShrunkCovariance(BaseTransformer):
         # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
         # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
         # each row containing a list of values.
-        expected_dtype = "
+        expected_dtype = "array"

         # If we were unable to assign a type to this transform in the factory, infer the type here.
         if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
             if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
             elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "
+                expected_dtype = "array"
             else:
                 output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                 # We can only infer the output types from the input types if the following two statemetns are true:
@@ -439,7 +445,7 @@ class ShrunkCovariance(BaseTransformer):
            transform_kwargs = dict(
                session = dataset._session,
                dependencies = self._deps,
-
+               drop_input_cols = self._drop_input_cols,
                expected_output_cols_type = expected_dtype,
            )

@@ -490,7 +496,7 @@ class ShrunkCovariance(BaseTransformer):
                subproject=_SUBPROJECT,
            )
            output_result, fitted_estimator = model_trainer.train_fit_predict(
-
+               drop_input_cols=self._drop_input_cols,
                expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
            )
            self._sklearn_object = fitted_estimator
@@ -508,44 +514,6 @@ class ShrunkCovariance(BaseTransformer):
        assert self._sklearn_object is not None
        return self._sklearn_object.embedding_

-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-        else:
-            output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
    @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
    @telemetry.send_api_usage_telemetry(
        project=_PROJECT,
@@ -585,7 +553,7 @@ class ShrunkCovariance(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-
+               drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )

@@ -650,7 +618,7 @@ class ShrunkCovariance(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-
+               drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )
        elif isinstance(dataset, pd.DataFrame):
@@ -711,7 +679,7 @@ class ShrunkCovariance(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-
+               drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )

@@ -776,7 +744,7 @@ class ShrunkCovariance(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-
+               drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )

@@ -832,13 +800,17 @@ class ShrunkCovariance(BaseTransformer):
        transform_kwargs: ScoreKwargsTypedDict = dict()

        if isinstance(dataset, DataFrame):
+           self._deps = self._batch_inference_validate_snowpark(
+               dataset=dataset,
+               inference_method="score",
+           )
            selected_cols = self._get_active_columns()
            if len(selected_cols) > 0:
                dataset = dataset.select(selected_cols)
            assert isinstance(dataset._session, Session)  # keep mypy happy
            transform_kwargs = dict(
                session=dataset._session,
-               dependencies=["snowflake-snowpark-python"] + self.
+               dependencies=["snowflake-snowpark-python"] + self._deps,
                score_sproc_imports=['sklearn'],
            )
        elif isinstance(dataset, pd.DataFrame):
@@ -912,9 +884,9 @@ class ShrunkCovariance(BaseTransformer):
            transform_kwargs = dict(
                session = dataset._session,
                dependencies = self._deps,
-
-               expected_output_cols_type
-               n_neighbors =
+               drop_input_cols = self._drop_input_cols,
+               expected_output_cols_type="array",
+               n_neighbors = n_neighbors,
                return_distance = return_distance
            )
        elif isinstance(dataset, pd.DataFrame):
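For reference, the helper deleted at the top of each of these files computed the pass-through columns inline; 1.4.0 instead forwards the boolean `self._drop_input_cols` in `transform_kwargs` and leaves that computation to the shared inference layer. The removed implementation, reassembled from the hunks:

```python
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
    # Removed in 1.4.0: columns to carry through unchanged to the output.
    if self._drop_input_cols:
        return []
    else:
        return list(set(dataset.columns) - set(self.output_cols))
```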
snowflake/ml/modeling/decomposition/dictionary_learning.py
@@ -375,18 +375,24 @@ class DictionaryLearning(BaseTransformer):
         self._get_model_signatures(dataset)
         return self

-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -458,7 +464,7 @@ class DictionaryLearning(BaseTransformer):
            transform_kwargs = dict(
                session = dataset._session,
                dependencies = self._deps,
-
+               drop_input_cols = self._drop_input_cols,
                expected_output_cols_type = expected_type_inferred,
            )

@@ -520,16 +526,16 @@ class DictionaryLearning(BaseTransformer):
         # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
         # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
         # each row containing a list of values.
-        expected_dtype = "
+        expected_dtype = "array"

         # If we were unable to assign a type to this transform in the factory, infer the type here.
         if expected_dtype == "":
-            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
+            # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
             if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                expected_dtype = "
-            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
+                expected_dtype = "array"
+            # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
             elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                expected_dtype = "
+                expected_dtype = "array"
             else:
                 output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                 # We can only infer the output types from the input types if the following two statemetns are true:
@@ -547,7 +553,7 @@ class DictionaryLearning(BaseTransformer):
            transform_kwargs = dict(
                session = dataset._session,
                dependencies = self._deps,
-
+               drop_input_cols = self._drop_input_cols,
                expected_output_cols_type = expected_dtype,
            )

@@ -598,7 +604,7 @@ class DictionaryLearning(BaseTransformer):
                subproject=_SUBPROJECT,
            )
            output_result, fitted_estimator = model_trainer.train_fit_predict(
-
+               drop_input_cols=self._drop_input_cols,
                expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
            )
            self._sklearn_object = fitted_estimator
@@ -616,44 +622,6 @@ class DictionaryLearning(BaseTransformer):
        assert self._sklearn_object is not None
        return self._sklearn_object.embedding_

-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-        else:
-            output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
    @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
    @telemetry.send_api_usage_telemetry(
        project=_PROJECT,
@@ -693,7 +661,7 @@ class DictionaryLearning(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-
+               drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )

@@ -758,7 +726,7 @@ class DictionaryLearning(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-
+               drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )
        elif isinstance(dataset, pd.DataFrame):
@@ -819,7 +787,7 @@ class DictionaryLearning(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-
+               drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )

@@ -884,7 +852,7 @@ class DictionaryLearning(BaseTransformer):
            transform_kwargs = dict(
                session=dataset._session,
                dependencies=self._deps,
-
+               drop_input_cols = self._drop_input_cols,
                expected_output_cols_type="float",
            )

@@ -938,13 +906,17 @@ class DictionaryLearning(BaseTransformer):
        transform_kwargs: ScoreKwargsTypedDict = dict()

        if isinstance(dataset, DataFrame):
+           self._deps = self._batch_inference_validate_snowpark(
+               dataset=dataset,
+               inference_method="score",
+           )
            selected_cols = self._get_active_columns()
            if len(selected_cols) > 0:
                dataset = dataset.select(selected_cols)
            assert isinstance(dataset._session, Session)  # keep mypy happy
            transform_kwargs = dict(
                session=dataset._session,
-               dependencies=["snowflake-snowpark-python"] + self.
+               dependencies=["snowflake-snowpark-python"] + self._deps,
                score_sproc_imports=['sklearn'],
            )
        elif isinstance(dataset, pd.DataFrame):
@@ -1018,9 +990,9 @@ class DictionaryLearning(BaseTransformer):
            transform_kwargs = dict(
                session = dataset._session,
                dependencies = self._deps,
-
-               expected_output_cols_type
-               n_neighbors =
+               drop_input_cols = self._drop_input_cols,
+               expected_output_cols_type="array",
+               n_neighbors = n_neighbors,
                return_distance = return_distance
            )
        elif isinstance(dataset, pd.DataFrame):
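The same twelve hunks repeat across every generated estimator in the file list (each "+33 -61"), which is why the modeling files shrink almost uniformly. The deleted `_get_output_column_names` has presumably been centralized rather than dropped: `snowflake/ml/modeling/framework/base.py` growing by 55 lines in the same release is consistent with that, though the hunks shown here do not confirm the destination. Its observable behavior for the common single-output case, re-expressed as a standalone sketch (Snowflake identifier normalization omitted):

```python
from typing import List, Optional
import numpy

def output_column_names(prefix: str, classes: Optional[numpy.ndarray]) -> List[str]:
    # Mirrors the removed helper's core logic: one column per fitted class
    # for classifiers, just the prefix for everything else.
    if classes is None:
        return [prefix]
    return [f"{prefix}{c}" for c in classes.tolist()]

# e.g. output_column_names("PREDICT_PROBA_", numpy.array([0, 1]))
# -> ['PREDICT_PROBA_0', 'PREDICT_PROBA_1']
```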