snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registries.
- snowflake/ml/_internal/file_utils.py +3 -3
- snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
- snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
- snowflake/ml/_internal/telemetry.py +11 -2
- snowflake/ml/_internal/utils/formatting.py +1 -1
- snowflake/ml/feature_store/feature_store.py +15 -106
- snowflake/ml/fileset/sfcfs.py +4 -3
- snowflake/ml/fileset/stage_fs.py +18 -0
- snowflake/ml/model/_api.py +9 -9
- snowflake/ml/model/_client/model/model_version_impl.py +20 -15
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
- snowflake/ml/model/_model_composer/model_composer.py +10 -8
- snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
- snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
- snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
- snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_packager.py +8 -6
- snowflake/ml/model/custom_model.py +3 -1
- snowflake/ml/model/type_hints.py +13 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
- snowflake/ml/modeling/_internal/model_specifications.py +3 -1
- snowflake/ml/modeling/_internal/model_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
- snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
- snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
- snowflake/ml/modeling/cluster/birch.py +33 -61
- snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
- snowflake/ml/modeling/cluster/dbscan.py +33 -61
- snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
- snowflake/ml/modeling/cluster/k_means.py +33 -61
- snowflake/ml/modeling/cluster/mean_shift.py +33 -61
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
- snowflake/ml/modeling/cluster/optics.py +33 -61
- snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
- snowflake/ml/modeling/compose/column_transformer.py +33 -61
- snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
- snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
- snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
- snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
- snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
- snowflake/ml/modeling/covariance/oas.py +33 -61
- snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
- snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
- snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
- snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
- snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/pca.py +33 -61
- snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
- snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
- snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
- snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
- snowflake/ml/modeling/framework/base.py +55 -5
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
- snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
- snowflake/ml/modeling/impute/knn_imputer.py +33 -61
- snowflake/ml/modeling/impute/missing_indicator.py +33 -61
- snowflake/ml/modeling/impute/simple_imputer.py +4 -15
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
- snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/lars.py +33 -61
- snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
- snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/perceptron.py +33 -61
- snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ridge.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
- snowflake/ml/modeling/manifold/isomap.py +33 -61
- snowflake/ml/modeling/manifold/mds.py +33 -61
- snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
- snowflake/ml/modeling/manifold/tsne.py +33 -61
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
- snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
- snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
- snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
- snowflake/ml/modeling/svm/linear_svc.py +33 -61
- snowflake/ml/modeling/svm/linear_svr.py +33 -61
- snowflake/ml/modeling/svm/nu_svc.py +33 -61
- snowflake/ml/modeling/svm/nu_svr.py +33 -61
- snowflake/ml/modeling/svm/svc.py +33 -61
- snowflake/ml/modeling/svm/svr.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
- snowflake/ml/registry/_manager/model_manager.py +6 -2
- snowflake/ml/registry/model_registry.py +100 -27
- snowflake/ml/registry/registry.py +6 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py

@@ -284,18 +284,24 @@ class PolynomialCountSketch(BaseTransformer):
         self._get_model_signatures(dataset)
         return self

-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
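This hunk drops the per-estimator `_get_pass_through_columns` helper; the hunks that follow show its call sites switching to a `drop_input_cols` flag passed through `transform_kwargs`. A minimal sketch (not the library's code) of the pass-through rule the removed helper encoded:

```python
from typing import List

def pass_through_columns(dataset_columns: List[str], output_cols: List[str],
                         drop_input_cols: bool) -> List[str]:
    """Mirror of the removed helper's logic: carry nothing through when
    drop_input_cols is set, otherwise pass through every non-output column."""
    if drop_input_cols:
        return []
    return list(set(dataset_columns) - set(output_cols))

# pass_through_columns(["A", "B", "OUTPUT"], ["OUTPUT"], False)
# -> ["A", "B"] (order not guaranteed: the original uses a set difference)
```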
@@ -367,7 +373,7 @@ class PolynomialCountSketch(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )

@@ -429,16 +435,16 @@ class PolynomialCountSketch(BaseTransformer):
             # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
             # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
             # each row containing a list of values.
-            expected_dtype = "
+            expected_dtype = "array"

             # If we were unable to assign a type to this transform in the factory, infer the type here.
             if expected_dtype == "":
-                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
+                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
                 if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                    expected_dtype = "
-                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
+                    expected_dtype = "array"
+                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
                 elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                    expected_dtype = "
+                    expected_dtype = "array"
                 else:
                     output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                     # We can only infer the output types from the input types if the following two statemetns are true:
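The comment in this hunk describes the convention behind the `"array"` expected dtype: a 2-D transform result is folded into a single column whose rows each hold the full list of values. A minimal illustration (not library code):

```python
import numpy as np
import pandas as pd

# A (n_samples, n_features) transform result...
result = np.arange(6).reshape(3, 2)
# ...becomes a (n_samples, 1) frame, each row holding a list of values,
# which is what the "array" expected_dtype encodes.
folded = pd.DataFrame({"OUTPUT": result.tolist()})
assert folded.shape == (3, 1) and folded["OUTPUT"][0] == [0, 1]
```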
@@ -456,7 +462,7 @@ class PolynomialCountSketch(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_dtype,
             )

@@ -507,7 +513,7 @@ class PolynomialCountSketch(BaseTransformer):
             subproject=_SUBPROJECT,
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
-
+            drop_input_cols=self._drop_input_cols,
             expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
         )
         self._sklearn_object = fitted_estimator
@@ -525,44 +531,6 @@ class PolynomialCountSketch(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_

-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
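`_get_output_column_names` disappears from every generated estimator in this release, presumably centralized in shared helper code (the file list above shows large additions to `snowflake/ml/modeling/_internal/estimator_utils.py`). A simplified sketch (not the library's code, and omitting the multioutput case) of the naming rule the removed method implemented:

```python
from typing import List, Optional
import numpy

def output_column_names(prefix: str, classes: Optional[numpy.ndarray]) -> List[str]:
    """One column per class label for a fitted classifier (prefix + label);
    a single prefix column for non-classifiers."""
    if classes is None:
        return [prefix]
    return [f"{prefix}{c}" for c in classes.tolist()]

# output_column_names("PREDICT_PROBA_", numpy.array([0, 1]))
# -> ['PREDICT_PROBA_0', 'PREDICT_PROBA_1']
```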
@@ -602,7 +570,7 @@ class PolynomialCountSketch(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -667,7 +635,7 @@ class PolynomialCountSketch(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -728,7 +696,7 @@ class PolynomialCountSketch(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -793,7 +761,7 @@ class PolynomialCountSketch(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -847,13 +815,17 @@ class PolynomialCountSketch(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self.
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
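Per the new docstring earlier in this file, `_batch_inference_validate_snowpark` both validates preconditions (fitted estimator, live session) and returns the packages resolvable from the Snowflake Anaconda channel; `score()` now stores that result on `self._deps` before building the stored-procedure dependency list. A rough, self-contained sketch of that contract, with hypothetical names standing in for the library's internals:

```python
from typing import Dict, List

def validate_and_resolve_deps(is_fitted: bool, session_alive: bool,
                              wanted: List[str], channel: Dict[str, str]) -> List[str]:
    """Sketch of the documented contract: raise if the estimator is unfitted
    or the session is gone; otherwise return the subset of wanted packages
    that the (here: mocked) anaconda channel can serve."""
    if not is_fitted:
        raise RuntimeError("estimator is not fitted")
    if not session_alive:
        raise RuntimeError("no active Snowpark session")
    return [f"{p}=={channel[p]}" for p in wanted if p in channel]

# validate_and_resolve_deps(True, True, ["scikit-learn", "xgboost"],
#                           {"scikit-learn": "1.3.0"}) -> ['scikit-learn==1.3.0']
```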
@@ -927,9 +899,9 @@ class PolynomialCountSketch(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
-                expected_output_cols_type
-                n_neighbors =
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
             )
         elif isinstance(dataset, pd.DataFrame):
snowflake/ml/modeling/kernel_approximation/rbf_sampler.py

@@ -271,18 +271,24 @@ class RBFSampler(BaseTransformer):
         self._get_model_signatures(dataset)
         return self

-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -354,7 +360,7 @@ class RBFSampler(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )

@@ -416,16 +422,16 @@ class RBFSampler(BaseTransformer):
             # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
             # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
             # each row containing a list of values.
-            expected_dtype = "
+            expected_dtype = "array"

             # If we were unable to assign a type to this transform in the factory, infer the type here.
             if expected_dtype == "":
-                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
+                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
                 if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                    expected_dtype = "
-                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
+                    expected_dtype = "array"
+                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
                 elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                    expected_dtype = "
+                    expected_dtype = "array"
                 else:
                     output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                     # We can only infer the output types from the input types if the following two statemetns are true:
@@ -443,7 +449,7 @@ class RBFSampler(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_dtype,
             )

@@ -494,7 +500,7 @@ class RBFSampler(BaseTransformer):
             subproject=_SUBPROJECT,
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
-
+            drop_input_cols=self._drop_input_cols,
             expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
         )
         self._sklearn_object = fitted_estimator
@@ -512,44 +518,6 @@ class RBFSampler(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_

-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -589,7 +557,7 @@ class RBFSampler(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -654,7 +622,7 @@ class RBFSampler(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -715,7 +683,7 @@ class RBFSampler(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -780,7 +748,7 @@ class RBFSampler(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -834,13 +802,17 @@ class RBFSampler(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self.
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -914,9 +886,9 @@ class RBFSampler(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
-                expected_output_cols_type
-                n_neighbors =
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
             )
         elif isinstance(dataset, pd.DataFrame):
snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py

@@ -269,18 +269,24 @@ class SkewedChi2Sampler(BaseTransformer):
         self._get_model_signatures(dataset)
         return self

-    def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
-        if self._drop_input_cols:
-            return []
-        else:
-            return list(set(dataset.columns) - set(self.output_cols))
-
     def _batch_inference_validate_snowpark(
         self,
         dataset: DataFrame,
         inference_method: str,
     ) -> List[str]:
-        """Util method to run validate that batch inference can be run on a snowpark dataframe
+        """Util method to run validate that batch inference can be run on a snowpark dataframe and
+        return the available package that exists in the snowflake anaconda channel
+
+        Args:
+            dataset: snowpark dataframe
+            inference_method: the inference method such as predict, score...
+
+        Raises:
+            SnowflakeMLException: If the estimator is not fitted, raise error
+            SnowflakeMLException: If the session is None, raise error
+
+        Returns:
+            A list of available package that exists in the snowflake anaconda channel
         """
         if not self._is_fitted:
             raise exceptions.SnowflakeMLException(
@@ -352,7 +358,7 @@ class SkewedChi2Sampler(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_type_inferred,
             )

@@ -414,16 +420,16 @@ class SkewedChi2Sampler(BaseTransformer):
             # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
             # based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
             # each row containing a list of values.
-            expected_dtype = "
+            expected_dtype = "array"

             # If we were unable to assign a type to this transform in the factory, infer the type here.
             if expected_dtype == "":
-                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
+                # If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
                 if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
-                    expected_dtype = "
-                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
+                    expected_dtype = "array"
+                # If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
                 elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
-                    expected_dtype = "
+                    expected_dtype = "array"
                 else:
                     output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
                     # We can only infer the output types from the input types if the following two statemetns are true:
@@ -441,7 +447,7 @@ class SkewedChi2Sampler(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type = expected_dtype,
             )

@@ -492,7 +498,7 @@ class SkewedChi2Sampler(BaseTransformer):
             subproject=_SUBPROJECT,
         )
         output_result, fitted_estimator = model_trainer.train_fit_predict(
-
+            drop_input_cols=self._drop_input_cols,
             expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
         )
         self._sklearn_object = fitted_estimator
@@ -510,44 +516,6 @@ class SkewedChi2Sampler(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object.embedding_

-
-    def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
-        """ Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
-        Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
-        """
-        output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
-        if output_cols:
-            output_cols = [
-                identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
-                for c in output_cols
-            ]
-        elif getattr(self._sklearn_object, "classes_", None) is None:
-            output_cols = [output_cols_prefix]
-        elif self._sklearn_object is not None:
-            classes = self._sklearn_object.classes_
-            if isinstance(classes, numpy.ndarray):
-                output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
-            elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
-                # If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
-                output_cols = []
-                for i, cl in enumerate(classes):
-                    # For binary classification, there is only one output column for each class
-                    # ndarray as the two classes are complementary.
-                    if len(cl) == 2:
-                        output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
-                    else:
-                        output_cols.extend([
-                            f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
-                        ])
-            else:
-                output_cols = []
-
-        # Make sure column names are valid snowflake identifiers.
-        assert output_cols is not None  # Make MyPy happy
-        rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
-
-        return rv
-
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
@@ -587,7 +555,7 @@ class SkewedChi2Sampler(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -652,7 +620,7 @@ class SkewedChi2Sampler(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -713,7 +681,7 @@ class SkewedChi2Sampler(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -778,7 +746,7 @@ class SkewedChi2Sampler(BaseTransformer):
             transform_kwargs = dict(
                 session=dataset._session,
                 dependencies=self._deps,
-
+                drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )

@@ -832,13 +800,17 @@ class SkewedChi2Sampler(BaseTransformer):
         transform_kwargs: ScoreKwargsTypedDict = dict()

         if isinstance(dataset, DataFrame):
+            self._deps = self._batch_inference_validate_snowpark(
+                dataset=dataset,
+                inference_method="score",
+            )
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
                 dataset = dataset.select(selected_cols)
             assert isinstance(dataset._session, Session)  # keep mypy happy
             transform_kwargs = dict(
                 session=dataset._session,
-                dependencies=["snowflake-snowpark-python"] + self.
+                dependencies=["snowflake-snowpark-python"] + self._deps,
                 score_sproc_imports=['sklearn'],
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -912,9 +884,9 @@ class SkewedChi2Sampler(BaseTransformer):
             transform_kwargs = dict(
                 session = dataset._session,
                 dependencies = self._deps,
-
-                expected_output_cols_type
-                n_neighbors =
+                drop_input_cols = self._drop_input_cols,
+                expected_output_cols_type="array",
+                n_neighbors = n_neighbors,
                 return_distance = return_distance
             )
         elif isinstance(dataset, pd.DataFrame):