snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/file_utils.py +3 -3
- snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
- snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
- snowflake/ml/_internal/telemetry.py +11 -2
- snowflake/ml/_internal/utils/formatting.py +1 -1
- snowflake/ml/feature_store/feature_store.py +15 -106
- snowflake/ml/fileset/sfcfs.py +4 -3
- snowflake/ml/fileset/stage_fs.py +18 -0
- snowflake/ml/model/_api.py +9 -9
- snowflake/ml/model/_client/model/model_version_impl.py +20 -15
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
- snowflake/ml/model/_model_composer/model_composer.py +10 -8
- snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
- snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
- snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
- snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_packager.py +8 -6
- snowflake/ml/model/custom_model.py +3 -1
- snowflake/ml/model/type_hints.py +13 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
- snowflake/ml/modeling/_internal/model_specifications.py +3 -1
- snowflake/ml/modeling/_internal/model_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
- snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
- snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
- snowflake/ml/modeling/cluster/birch.py +33 -61
- snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
- snowflake/ml/modeling/cluster/dbscan.py +33 -61
- snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
- snowflake/ml/modeling/cluster/k_means.py +33 -61
- snowflake/ml/modeling/cluster/mean_shift.py +33 -61
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
- snowflake/ml/modeling/cluster/optics.py +33 -61
- snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
- snowflake/ml/modeling/compose/column_transformer.py +33 -61
- snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
- snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
- snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
- snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
- snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
- snowflake/ml/modeling/covariance/oas.py +33 -61
- snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
- snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
- snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
- snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
- snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/pca.py +33 -61
- snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
- snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
- snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
- snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
- snowflake/ml/modeling/framework/base.py +55 -5
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
- snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
- snowflake/ml/modeling/impute/knn_imputer.py +33 -61
- snowflake/ml/modeling/impute/missing_indicator.py +33 -61
- snowflake/ml/modeling/impute/simple_imputer.py +4 -15
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
- snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/lars.py +33 -61
- snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
- snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/perceptron.py +33 -61
- snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ridge.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
- snowflake/ml/modeling/manifold/isomap.py +33 -61
- snowflake/ml/modeling/manifold/mds.py +33 -61
- snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
- snowflake/ml/modeling/manifold/tsne.py +33 -61
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
- snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
- snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
- snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
- snowflake/ml/modeling/svm/linear_svc.py +33 -61
- snowflake/ml/modeling/svm/linear_svr.py +33 -61
- snowflake/ml/modeling/svm/nu_svc.py +33 -61
- snowflake/ml/modeling/svm/nu_svr.py +33 -61
- snowflake/ml/modeling/svm/svc.py +33 -61
- snowflake/ml/modeling/svm/svr.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
- snowflake/ml/registry/_manager/model_manager.py +6 -2
- snowflake/ml/registry/model_registry.py +100 -27
- snowflake/ml/registry/registry.py +6 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
@@ -391,18 +391,24 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
391
391
|
self._get_model_signatures(dataset)
|
392
392
|
return self
|
393
393
|
|
394
|
-
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
395
|
-
if self._drop_input_cols:
|
396
|
-
return []
|
397
|
-
else:
|
398
|
-
return list(set(dataset.columns) - set(self.output_cols))
|
399
|
-
|
400
394
|
def _batch_inference_validate_snowpark(
|
401
395
|
self,
|
402
396
|
dataset: DataFrame,
|
403
397
|
inference_method: str,
|
404
398
|
) -> List[str]:
|
405
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
399
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe and
|
400
|
+
return the available package that exists in the snowflake anaconda channel
|
401
|
+
|
402
|
+
Args:
|
403
|
+
dataset: snowpark dataframe
|
404
|
+
inference_method: the inference method such as predict, score...
|
405
|
+
|
406
|
+
Raises:
|
407
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
408
|
+
SnowflakeMLException: If the session is None, raise error
|
409
|
+
|
410
|
+
Returns:
|
411
|
+
A list of available package that exists in the snowflake anaconda channel
|
406
412
|
"""
|
407
413
|
if not self._is_fitted:
|
408
414
|
raise exceptions.SnowflakeMLException(
|
@@ -476,7 +482,7 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
476
482
|
transform_kwargs = dict(
|
477
483
|
session = dataset._session,
|
478
484
|
dependencies = self._deps,
|
479
|
-
|
485
|
+
drop_input_cols = self._drop_input_cols,
|
480
486
|
expected_output_cols_type = expected_type_inferred,
|
481
487
|
)
|
482
488
|
|
@@ -536,16 +542,16 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
536
542
|
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
|
537
543
|
# based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
|
538
544
|
# each row containing a list of values.
|
539
|
-
expected_dtype = "
|
545
|
+
expected_dtype = "array"
|
540
546
|
|
541
547
|
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
542
548
|
if expected_dtype == "":
|
543
|
-
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
|
549
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
|
544
550
|
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
545
|
-
expected_dtype = "
|
546
|
-
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
|
551
|
+
expected_dtype = "array"
|
552
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
|
547
553
|
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
548
|
-
expected_dtype = "
|
554
|
+
expected_dtype = "array"
|
549
555
|
else:
|
550
556
|
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
551
557
|
# We can only infer the output types from the input types if the following two statemetns are true:
|
@@ -563,7 +569,7 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
563
569
|
transform_kwargs = dict(
|
564
570
|
session = dataset._session,
|
565
571
|
dependencies = self._deps,
|
566
|
-
|
572
|
+
drop_input_cols = self._drop_input_cols,
|
567
573
|
expected_output_cols_type = expected_dtype,
|
568
574
|
)
|
569
575
|
|
@@ -614,7 +620,7 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
614
620
|
subproject=_SUBPROJECT,
|
615
621
|
)
|
616
622
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
617
|
-
|
623
|
+
drop_input_cols=self._drop_input_cols,
|
618
624
|
expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
|
619
625
|
)
|
620
626
|
self._sklearn_object = fitted_estimator
|
@@ -632,44 +638,6 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
632
638
|
assert self._sklearn_object is not None
|
633
639
|
return self._sklearn_object.embedding_
|
634
640
|
|
635
|
-
|
636
|
-
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
637
|
-
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
638
|
-
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
639
|
-
"""
|
640
|
-
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
641
|
-
if output_cols:
|
642
|
-
output_cols = [
|
643
|
-
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
644
|
-
for c in output_cols
|
645
|
-
]
|
646
|
-
elif getattr(self._sklearn_object, "classes_", None) is None:
|
647
|
-
output_cols = [output_cols_prefix]
|
648
|
-
elif self._sklearn_object is not None:
|
649
|
-
classes = self._sklearn_object.classes_
|
650
|
-
if isinstance(classes, numpy.ndarray):
|
651
|
-
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
652
|
-
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
653
|
-
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
654
|
-
output_cols = []
|
655
|
-
for i, cl in enumerate(classes):
|
656
|
-
# For binary classification, there is only one output column for each class
|
657
|
-
# ndarray as the two classes are complementary.
|
658
|
-
if len(cl) == 2:
|
659
|
-
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
660
|
-
else:
|
661
|
-
output_cols.extend([
|
662
|
-
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
663
|
-
])
|
664
|
-
else:
|
665
|
-
output_cols = []
|
666
|
-
|
667
|
-
# Make sure column names are valid snowflake identifiers.
|
668
|
-
assert output_cols is not None # Make MyPy happy
|
669
|
-
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
670
|
-
|
671
|
-
return rv
|
672
|
-
|
673
641
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
674
642
|
@telemetry.send_api_usage_telemetry(
|
675
643
|
project=_PROJECT,
|
@@ -711,7 +679,7 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
711
679
|
transform_kwargs = dict(
|
712
680
|
session=dataset._session,
|
713
681
|
dependencies=self._deps,
|
714
|
-
|
682
|
+
drop_input_cols = self._drop_input_cols,
|
715
683
|
expected_output_cols_type="float",
|
716
684
|
)
|
717
685
|
|
@@ -778,7 +746,7 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
778
746
|
transform_kwargs = dict(
|
779
747
|
session=dataset._session,
|
780
748
|
dependencies=self._deps,
|
781
|
-
|
749
|
+
drop_input_cols = self._drop_input_cols,
|
782
750
|
expected_output_cols_type="float",
|
783
751
|
)
|
784
752
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -839,7 +807,7 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
839
807
|
transform_kwargs = dict(
|
840
808
|
session=dataset._session,
|
841
809
|
dependencies=self._deps,
|
842
|
-
|
810
|
+
drop_input_cols = self._drop_input_cols,
|
843
811
|
expected_output_cols_type="float",
|
844
812
|
)
|
845
813
|
|
@@ -904,7 +872,7 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
904
872
|
transform_kwargs = dict(
|
905
873
|
session=dataset._session,
|
906
874
|
dependencies=self._deps,
|
907
|
-
|
875
|
+
drop_input_cols = self._drop_input_cols,
|
908
876
|
expected_output_cols_type="float",
|
909
877
|
)
|
910
878
|
|
@@ -960,13 +928,17 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
960
928
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
961
929
|
|
962
930
|
if isinstance(dataset, DataFrame):
|
931
|
+
self._deps = self._batch_inference_validate_snowpark(
|
932
|
+
dataset=dataset,
|
933
|
+
inference_method="score",
|
934
|
+
)
|
963
935
|
selected_cols = self._get_active_columns()
|
964
936
|
if len(selected_cols) > 0:
|
965
937
|
dataset = dataset.select(selected_cols)
|
966
938
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
967
939
|
transform_kwargs = dict(
|
968
940
|
session=dataset._session,
|
969
|
-
dependencies=["snowflake-snowpark-python"] + self.
|
941
|
+
dependencies=["snowflake-snowpark-python"] + self._deps,
|
970
942
|
score_sproc_imports=['sklearn'],
|
971
943
|
)
|
972
944
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1040,9 +1012,9 @@ class DecisionTreeClassifier(BaseTransformer):
|
|
1040
1012
|
transform_kwargs = dict(
|
1041
1013
|
session = dataset._session,
|
1042
1014
|
dependencies = self._deps,
|
1043
|
-
|
1044
|
-
expected_output_cols_type
|
1045
|
-
n_neighbors =
|
1015
|
+
drop_input_cols = self._drop_input_cols,
|
1016
|
+
expected_output_cols_type="array",
|
1017
|
+
n_neighbors = n_neighbors,
|
1046
1018
|
return_distance = return_distance
|
1047
1019
|
)
|
1048
1020
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -373,18 +373,24 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
373
373
|
self._get_model_signatures(dataset)
|
374
374
|
return self
|
375
375
|
|
376
|
-
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
377
|
-
if self._drop_input_cols:
|
378
|
-
return []
|
379
|
-
else:
|
380
|
-
return list(set(dataset.columns) - set(self.output_cols))
|
381
|
-
|
382
376
|
def _batch_inference_validate_snowpark(
|
383
377
|
self,
|
384
378
|
dataset: DataFrame,
|
385
379
|
inference_method: str,
|
386
380
|
) -> List[str]:
|
387
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
381
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe and
|
382
|
+
return the available package that exists in the snowflake anaconda channel
|
383
|
+
|
384
|
+
Args:
|
385
|
+
dataset: snowpark dataframe
|
386
|
+
inference_method: the inference method such as predict, score...
|
387
|
+
|
388
|
+
Raises:
|
389
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
390
|
+
SnowflakeMLException: If the session is None, raise error
|
391
|
+
|
392
|
+
Returns:
|
393
|
+
A list of available package that exists in the snowflake anaconda channel
|
388
394
|
"""
|
389
395
|
if not self._is_fitted:
|
390
396
|
raise exceptions.SnowflakeMLException(
|
@@ -458,7 +464,7 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
458
464
|
transform_kwargs = dict(
|
459
465
|
session = dataset._session,
|
460
466
|
dependencies = self._deps,
|
461
|
-
|
467
|
+
drop_input_cols = self._drop_input_cols,
|
462
468
|
expected_output_cols_type = expected_type_inferred,
|
463
469
|
)
|
464
470
|
|
@@ -518,16 +524,16 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
518
524
|
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
|
519
525
|
# based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
|
520
526
|
# each row containing a list of values.
|
521
|
-
expected_dtype = "
|
527
|
+
expected_dtype = "array"
|
522
528
|
|
523
529
|
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
524
530
|
if expected_dtype == "":
|
525
|
-
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
|
531
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
|
526
532
|
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
527
|
-
expected_dtype = "
|
528
|
-
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
|
533
|
+
expected_dtype = "array"
|
534
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
|
529
535
|
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
530
|
-
expected_dtype = "
|
536
|
+
expected_dtype = "array"
|
531
537
|
else:
|
532
538
|
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
533
539
|
# We can only infer the output types from the input types if the following two statemetns are true:
|
@@ -545,7 +551,7 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
545
551
|
transform_kwargs = dict(
|
546
552
|
session = dataset._session,
|
547
553
|
dependencies = self._deps,
|
548
|
-
|
554
|
+
drop_input_cols = self._drop_input_cols,
|
549
555
|
expected_output_cols_type = expected_dtype,
|
550
556
|
)
|
551
557
|
|
@@ -596,7 +602,7 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
596
602
|
subproject=_SUBPROJECT,
|
597
603
|
)
|
598
604
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
599
|
-
|
605
|
+
drop_input_cols=self._drop_input_cols,
|
600
606
|
expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
|
601
607
|
)
|
602
608
|
self._sklearn_object = fitted_estimator
|
@@ -614,44 +620,6 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
614
620
|
assert self._sklearn_object is not None
|
615
621
|
return self._sklearn_object.embedding_
|
616
622
|
|
617
|
-
|
618
|
-
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
619
|
-
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
620
|
-
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
621
|
-
"""
|
622
|
-
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
623
|
-
if output_cols:
|
624
|
-
output_cols = [
|
625
|
-
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
626
|
-
for c in output_cols
|
627
|
-
]
|
628
|
-
elif getattr(self._sklearn_object, "classes_", None) is None:
|
629
|
-
output_cols = [output_cols_prefix]
|
630
|
-
elif self._sklearn_object is not None:
|
631
|
-
classes = self._sklearn_object.classes_
|
632
|
-
if isinstance(classes, numpy.ndarray):
|
633
|
-
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
634
|
-
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
635
|
-
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
636
|
-
output_cols = []
|
637
|
-
for i, cl in enumerate(classes):
|
638
|
-
# For binary classification, there is only one output column for each class
|
639
|
-
# ndarray as the two classes are complementary.
|
640
|
-
if len(cl) == 2:
|
641
|
-
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
642
|
-
else:
|
643
|
-
output_cols.extend([
|
644
|
-
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
645
|
-
])
|
646
|
-
else:
|
647
|
-
output_cols = []
|
648
|
-
|
649
|
-
# Make sure column names are valid snowflake identifiers.
|
650
|
-
assert output_cols is not None # Make MyPy happy
|
651
|
-
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
652
|
-
|
653
|
-
return rv
|
654
|
-
|
655
623
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
656
624
|
@telemetry.send_api_usage_telemetry(
|
657
625
|
project=_PROJECT,
|
@@ -691,7 +659,7 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
691
659
|
transform_kwargs = dict(
|
692
660
|
session=dataset._session,
|
693
661
|
dependencies=self._deps,
|
694
|
-
|
662
|
+
drop_input_cols = self._drop_input_cols,
|
695
663
|
expected_output_cols_type="float",
|
696
664
|
)
|
697
665
|
|
@@ -756,7 +724,7 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
756
724
|
transform_kwargs = dict(
|
757
725
|
session=dataset._session,
|
758
726
|
dependencies=self._deps,
|
759
|
-
|
727
|
+
drop_input_cols = self._drop_input_cols,
|
760
728
|
expected_output_cols_type="float",
|
761
729
|
)
|
762
730
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -817,7 +785,7 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
817
785
|
transform_kwargs = dict(
|
818
786
|
session=dataset._session,
|
819
787
|
dependencies=self._deps,
|
820
|
-
|
788
|
+
drop_input_cols = self._drop_input_cols,
|
821
789
|
expected_output_cols_type="float",
|
822
790
|
)
|
823
791
|
|
@@ -882,7 +850,7 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
882
850
|
transform_kwargs = dict(
|
883
851
|
session=dataset._session,
|
884
852
|
dependencies=self._deps,
|
885
|
-
|
853
|
+
drop_input_cols = self._drop_input_cols,
|
886
854
|
expected_output_cols_type="float",
|
887
855
|
)
|
888
856
|
|
@@ -938,13 +906,17 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
938
906
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
939
907
|
|
940
908
|
if isinstance(dataset, DataFrame):
|
909
|
+
self._deps = self._batch_inference_validate_snowpark(
|
910
|
+
dataset=dataset,
|
911
|
+
inference_method="score",
|
912
|
+
)
|
941
913
|
selected_cols = self._get_active_columns()
|
942
914
|
if len(selected_cols) > 0:
|
943
915
|
dataset = dataset.select(selected_cols)
|
944
916
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
945
917
|
transform_kwargs = dict(
|
946
918
|
session=dataset._session,
|
947
|
-
dependencies=["snowflake-snowpark-python"] + self.
|
919
|
+
dependencies=["snowflake-snowpark-python"] + self._deps,
|
948
920
|
score_sproc_imports=['sklearn'],
|
949
921
|
)
|
950
922
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1018,9 +990,9 @@ class DecisionTreeRegressor(BaseTransformer):
|
|
1018
990
|
transform_kwargs = dict(
|
1019
991
|
session = dataset._session,
|
1020
992
|
dependencies = self._deps,
|
1021
|
-
|
1022
|
-
expected_output_cols_type
|
1023
|
-
n_neighbors =
|
993
|
+
drop_input_cols = self._drop_input_cols,
|
994
|
+
expected_output_cols_type="array",
|
995
|
+
n_neighbors = n_neighbors,
|
1024
996
|
return_distance = return_distance
|
1025
997
|
)
|
1026
998
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -383,18 +383,24 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
383
383
|
self._get_model_signatures(dataset)
|
384
384
|
return self
|
385
385
|
|
386
|
-
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
387
|
-
if self._drop_input_cols:
|
388
|
-
return []
|
389
|
-
else:
|
390
|
-
return list(set(dataset.columns) - set(self.output_cols))
|
391
|
-
|
392
386
|
def _batch_inference_validate_snowpark(
|
393
387
|
self,
|
394
388
|
dataset: DataFrame,
|
395
389
|
inference_method: str,
|
396
390
|
) -> List[str]:
|
397
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
391
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe and
|
392
|
+
return the available package that exists in the snowflake anaconda channel
|
393
|
+
|
394
|
+
Args:
|
395
|
+
dataset: snowpark dataframe
|
396
|
+
inference_method: the inference method such as predict, score...
|
397
|
+
|
398
|
+
Raises:
|
399
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
400
|
+
SnowflakeMLException: If the session is None, raise error
|
401
|
+
|
402
|
+
Returns:
|
403
|
+
A list of available package that exists in the snowflake anaconda channel
|
398
404
|
"""
|
399
405
|
if not self._is_fitted:
|
400
406
|
raise exceptions.SnowflakeMLException(
|
@@ -468,7 +474,7 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
468
474
|
transform_kwargs = dict(
|
469
475
|
session = dataset._session,
|
470
476
|
dependencies = self._deps,
|
471
|
-
|
477
|
+
drop_input_cols = self._drop_input_cols,
|
472
478
|
expected_output_cols_type = expected_type_inferred,
|
473
479
|
)
|
474
480
|
|
@@ -528,16 +534,16 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
528
534
|
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
|
529
535
|
# based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
|
530
536
|
# each row containing a list of values.
|
531
|
-
expected_dtype = "
|
537
|
+
expected_dtype = "array"
|
532
538
|
|
533
539
|
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
534
540
|
if expected_dtype == "":
|
535
|
-
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
|
541
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
|
536
542
|
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
537
|
-
expected_dtype = "
|
538
|
-
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
|
543
|
+
expected_dtype = "array"
|
544
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
|
539
545
|
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
540
|
-
expected_dtype = "
|
546
|
+
expected_dtype = "array"
|
541
547
|
else:
|
542
548
|
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
543
549
|
# We can only infer the output types from the input types if the following two statemetns are true:
|
@@ -555,7 +561,7 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
555
561
|
transform_kwargs = dict(
|
556
562
|
session = dataset._session,
|
557
563
|
dependencies = self._deps,
|
558
|
-
|
564
|
+
drop_input_cols = self._drop_input_cols,
|
559
565
|
expected_output_cols_type = expected_dtype,
|
560
566
|
)
|
561
567
|
|
@@ -606,7 +612,7 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
606
612
|
subproject=_SUBPROJECT,
|
607
613
|
)
|
608
614
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
609
|
-
|
615
|
+
drop_input_cols=self._drop_input_cols,
|
610
616
|
expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
|
611
617
|
)
|
612
618
|
self._sklearn_object = fitted_estimator
|
@@ -624,44 +630,6 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
624
630
|
assert self._sklearn_object is not None
|
625
631
|
return self._sklearn_object.embedding_
|
626
632
|
|
627
|
-
|
628
|
-
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
629
|
-
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
630
|
-
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
631
|
-
"""
|
632
|
-
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
633
|
-
if output_cols:
|
634
|
-
output_cols = [
|
635
|
-
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
636
|
-
for c in output_cols
|
637
|
-
]
|
638
|
-
elif getattr(self._sklearn_object, "classes_", None) is None:
|
639
|
-
output_cols = [output_cols_prefix]
|
640
|
-
elif self._sklearn_object is not None:
|
641
|
-
classes = self._sklearn_object.classes_
|
642
|
-
if isinstance(classes, numpy.ndarray):
|
643
|
-
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
644
|
-
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
645
|
-
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
646
|
-
output_cols = []
|
647
|
-
for i, cl in enumerate(classes):
|
648
|
-
# For binary classification, there is only one output column for each class
|
649
|
-
# ndarray as the two classes are complementary.
|
650
|
-
if len(cl) == 2:
|
651
|
-
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
652
|
-
else:
|
653
|
-
output_cols.extend([
|
654
|
-
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
655
|
-
])
|
656
|
-
else:
|
657
|
-
output_cols = []
|
658
|
-
|
659
|
-
# Make sure column names are valid snowflake identifiers.
|
660
|
-
assert output_cols is not None # Make MyPy happy
|
661
|
-
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
662
|
-
|
663
|
-
return rv
|
664
|
-
|
665
633
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
666
634
|
@telemetry.send_api_usage_telemetry(
|
667
635
|
project=_PROJECT,
|
@@ -703,7 +671,7 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
703
671
|
transform_kwargs = dict(
|
704
672
|
session=dataset._session,
|
705
673
|
dependencies=self._deps,
|
706
|
-
|
674
|
+
drop_input_cols = self._drop_input_cols,
|
707
675
|
expected_output_cols_type="float",
|
708
676
|
)
|
709
677
|
|
@@ -770,7 +738,7 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
770
738
|
transform_kwargs = dict(
|
771
739
|
session=dataset._session,
|
772
740
|
dependencies=self._deps,
|
773
|
-
|
741
|
+
drop_input_cols = self._drop_input_cols,
|
774
742
|
expected_output_cols_type="float",
|
775
743
|
)
|
776
744
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -831,7 +799,7 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
831
799
|
transform_kwargs = dict(
|
832
800
|
session=dataset._session,
|
833
801
|
dependencies=self._deps,
|
834
|
-
|
802
|
+
drop_input_cols = self._drop_input_cols,
|
835
803
|
expected_output_cols_type="float",
|
836
804
|
)
|
837
805
|
|
@@ -896,7 +864,7 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
896
864
|
transform_kwargs = dict(
|
897
865
|
session=dataset._session,
|
898
866
|
dependencies=self._deps,
|
899
|
-
|
867
|
+
drop_input_cols = self._drop_input_cols,
|
900
868
|
expected_output_cols_type="float",
|
901
869
|
)
|
902
870
|
|
@@ -952,13 +920,17 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
952
920
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
953
921
|
|
954
922
|
if isinstance(dataset, DataFrame):
|
923
|
+
self._deps = self._batch_inference_validate_snowpark(
|
924
|
+
dataset=dataset,
|
925
|
+
inference_method="score",
|
926
|
+
)
|
955
927
|
selected_cols = self._get_active_columns()
|
956
928
|
if len(selected_cols) > 0:
|
957
929
|
dataset = dataset.select(selected_cols)
|
958
930
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
959
931
|
transform_kwargs = dict(
|
960
932
|
session=dataset._session,
|
961
|
-
dependencies=["snowflake-snowpark-python"] + self.
|
933
|
+
dependencies=["snowflake-snowpark-python"] + self._deps,
|
962
934
|
score_sproc_imports=['sklearn'],
|
963
935
|
)
|
964
936
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1032,9 +1004,9 @@ class ExtraTreeClassifier(BaseTransformer):
|
|
1032
1004
|
transform_kwargs = dict(
|
1033
1005
|
session = dataset._session,
|
1034
1006
|
dependencies = self._deps,
|
1035
|
-
|
1036
|
-
expected_output_cols_type
|
1037
|
-
n_neighbors =
|
1007
|
+
drop_input_cols = self._drop_input_cols,
|
1008
|
+
expected_output_cols_type="array",
|
1009
|
+
n_neighbors = n_neighbors,
|
1038
1010
|
return_distance = return_distance
|
1039
1011
|
)
|
1040
1012
|
elif isinstance(dataset, pd.DataFrame):
|