snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/file_utils.py +3 -3
- snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
- snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
- snowflake/ml/_internal/telemetry.py +11 -2
- snowflake/ml/_internal/utils/formatting.py +1 -1
- snowflake/ml/feature_store/feature_store.py +15 -106
- snowflake/ml/fileset/sfcfs.py +4 -3
- snowflake/ml/fileset/stage_fs.py +18 -0
- snowflake/ml/model/_api.py +9 -9
- snowflake/ml/model/_client/model/model_version_impl.py +20 -15
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
- snowflake/ml/model/_model_composer/model_composer.py +10 -8
- snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
- snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
- snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
- snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_packager.py +8 -6
- snowflake/ml/model/custom_model.py +3 -1
- snowflake/ml/model/type_hints.py +13 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
- snowflake/ml/modeling/_internal/model_specifications.py +3 -1
- snowflake/ml/modeling/_internal/model_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
- snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
- snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
- snowflake/ml/modeling/cluster/birch.py +33 -61
- snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
- snowflake/ml/modeling/cluster/dbscan.py +33 -61
- snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
- snowflake/ml/modeling/cluster/k_means.py +33 -61
- snowflake/ml/modeling/cluster/mean_shift.py +33 -61
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
- snowflake/ml/modeling/cluster/optics.py +33 -61
- snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
- snowflake/ml/modeling/compose/column_transformer.py +33 -61
- snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
- snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
- snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
- snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
- snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
- snowflake/ml/modeling/covariance/oas.py +33 -61
- snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
- snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
- snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
- snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
- snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/pca.py +33 -61
- snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
- snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
- snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
- snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
- snowflake/ml/modeling/framework/base.py +55 -5
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
- snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
- snowflake/ml/modeling/impute/knn_imputer.py +33 -61
- snowflake/ml/modeling/impute/missing_indicator.py +33 -61
- snowflake/ml/modeling/impute/simple_imputer.py +4 -15
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
- snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/lars.py +33 -61
- snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
- snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/perceptron.py +33 -61
- snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ridge.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
- snowflake/ml/modeling/manifold/isomap.py +33 -61
- snowflake/ml/modeling/manifold/mds.py +33 -61
- snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
- snowflake/ml/modeling/manifold/tsne.py +33 -61
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
- snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
- snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
- snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
- snowflake/ml/modeling/svm/linear_svc.py +33 -61
- snowflake/ml/modeling/svm/linear_svr.py +33 -61
- snowflake/ml/modeling/svm/nu_svc.py +33 -61
- snowflake/ml/modeling/svm/nu_svr.py +33 -61
- snowflake/ml/modeling/svm/svc.py +33 -61
- snowflake/ml/modeling/svm/svr.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
- snowflake/ml/registry/_manager/model_manager.py +6 -2
- snowflake/ml/registry/model_registry.py +100 -27
- snowflake/ml/registry/registry.py +6 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
@@ -365,18 +365,24 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
365
365
|
self._get_model_signatures(dataset)
|
366
366
|
return self
|
367
367
|
|
368
|
-
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
369
|
-
if self._drop_input_cols:
|
370
|
-
return []
|
371
|
-
else:
|
372
|
-
return list(set(dataset.columns) - set(self.output_cols))
|
373
|
-
|
374
368
|
def _batch_inference_validate_snowpark(
|
375
369
|
self,
|
376
370
|
dataset: DataFrame,
|
377
371
|
inference_method: str,
|
378
372
|
) -> List[str]:
|
379
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
373
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe and
|
374
|
+
return the available package that exists in the snowflake anaconda channel
|
375
|
+
|
376
|
+
Args:
|
377
|
+
dataset: snowpark dataframe
|
378
|
+
inference_method: the inference method such as predict, score...
|
379
|
+
|
380
|
+
Raises:
|
381
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
382
|
+
SnowflakeMLException: If the session is None, raise error
|
383
|
+
|
384
|
+
Returns:
|
385
|
+
A list of available package that exists in the snowflake anaconda channel
|
380
386
|
"""
|
381
387
|
if not self._is_fitted:
|
382
388
|
raise exceptions.SnowflakeMLException(
|
@@ -450,7 +456,7 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
450
456
|
transform_kwargs = dict(
|
451
457
|
session = dataset._session,
|
452
458
|
dependencies = self._deps,
|
453
|
-
|
459
|
+
drop_input_cols = self._drop_input_cols,
|
454
460
|
expected_output_cols_type = expected_type_inferred,
|
455
461
|
)
|
456
462
|
|
@@ -510,16 +516,16 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
510
516
|
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
|
511
517
|
# based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
|
512
518
|
# each row containing a list of values.
|
513
|
-
expected_dtype = "
|
519
|
+
expected_dtype = "array"
|
514
520
|
|
515
521
|
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
516
522
|
if expected_dtype == "":
|
517
|
-
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
|
523
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
|
518
524
|
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
519
|
-
expected_dtype = "
|
520
|
-
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
|
525
|
+
expected_dtype = "array"
|
526
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
|
521
527
|
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
522
|
-
expected_dtype = "
|
528
|
+
expected_dtype = "array"
|
523
529
|
else:
|
524
530
|
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
525
531
|
# We can only infer the output types from the input types if the following two statemetns are true:
|
@@ -537,7 +543,7 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
537
543
|
transform_kwargs = dict(
|
538
544
|
session = dataset._session,
|
539
545
|
dependencies = self._deps,
|
540
|
-
|
546
|
+
drop_input_cols = self._drop_input_cols,
|
541
547
|
expected_output_cols_type = expected_dtype,
|
542
548
|
)
|
543
549
|
|
@@ -588,7 +594,7 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
588
594
|
subproject=_SUBPROJECT,
|
589
595
|
)
|
590
596
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
591
|
-
|
597
|
+
drop_input_cols=self._drop_input_cols,
|
592
598
|
expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
|
593
599
|
)
|
594
600
|
self._sklearn_object = fitted_estimator
|
@@ -606,44 +612,6 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
606
612
|
assert self._sklearn_object is not None
|
607
613
|
return self._sklearn_object.embedding_
|
608
614
|
|
609
|
-
|
610
|
-
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
611
|
-
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
612
|
-
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
613
|
-
"""
|
614
|
-
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
615
|
-
if output_cols:
|
616
|
-
output_cols = [
|
617
|
-
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
618
|
-
for c in output_cols
|
619
|
-
]
|
620
|
-
elif getattr(self._sklearn_object, "classes_", None) is None:
|
621
|
-
output_cols = [output_cols_prefix]
|
622
|
-
elif self._sklearn_object is not None:
|
623
|
-
classes = self._sklearn_object.classes_
|
624
|
-
if isinstance(classes, numpy.ndarray):
|
625
|
-
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
626
|
-
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
627
|
-
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
628
|
-
output_cols = []
|
629
|
-
for i, cl in enumerate(classes):
|
630
|
-
# For binary classification, there is only one output column for each class
|
631
|
-
# ndarray as the two classes are complementary.
|
632
|
-
if len(cl) == 2:
|
633
|
-
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
634
|
-
else:
|
635
|
-
output_cols.extend([
|
636
|
-
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
637
|
-
])
|
638
|
-
else:
|
639
|
-
output_cols = []
|
640
|
-
|
641
|
-
# Make sure column names are valid snowflake identifiers.
|
642
|
-
assert output_cols is not None # Make MyPy happy
|
643
|
-
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
644
|
-
|
645
|
-
return rv
|
646
|
-
|
647
615
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
648
616
|
@telemetry.send_api_usage_telemetry(
|
649
617
|
project=_PROJECT,
|
@@ -683,7 +651,7 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
683
651
|
transform_kwargs = dict(
|
684
652
|
session=dataset._session,
|
685
653
|
dependencies=self._deps,
|
686
|
-
|
654
|
+
drop_input_cols = self._drop_input_cols,
|
687
655
|
expected_output_cols_type="float",
|
688
656
|
)
|
689
657
|
|
@@ -748,7 +716,7 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
748
716
|
transform_kwargs = dict(
|
749
717
|
session=dataset._session,
|
750
718
|
dependencies=self._deps,
|
751
|
-
|
719
|
+
drop_input_cols = self._drop_input_cols,
|
752
720
|
expected_output_cols_type="float",
|
753
721
|
)
|
754
722
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -809,7 +777,7 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
809
777
|
transform_kwargs = dict(
|
810
778
|
session=dataset._session,
|
811
779
|
dependencies=self._deps,
|
812
|
-
|
780
|
+
drop_input_cols = self._drop_input_cols,
|
813
781
|
expected_output_cols_type="float",
|
814
782
|
)
|
815
783
|
|
@@ -874,7 +842,7 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
874
842
|
transform_kwargs = dict(
|
875
843
|
session=dataset._session,
|
876
844
|
dependencies=self._deps,
|
877
|
-
|
845
|
+
drop_input_cols = self._drop_input_cols,
|
878
846
|
expected_output_cols_type="float",
|
879
847
|
)
|
880
848
|
|
@@ -930,13 +898,17 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
930
898
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
931
899
|
|
932
900
|
if isinstance(dataset, DataFrame):
|
901
|
+
self._deps = self._batch_inference_validate_snowpark(
|
902
|
+
dataset=dataset,
|
903
|
+
inference_method="score",
|
904
|
+
)
|
933
905
|
selected_cols = self._get_active_columns()
|
934
906
|
if len(selected_cols) > 0:
|
935
907
|
dataset = dataset.select(selected_cols)
|
936
908
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
937
909
|
transform_kwargs = dict(
|
938
910
|
session=dataset._session,
|
939
|
-
dependencies=["snowflake-snowpark-python"] + self.
|
911
|
+
dependencies=["snowflake-snowpark-python"] + self._deps,
|
940
912
|
score_sproc_imports=['sklearn'],
|
941
913
|
)
|
942
914
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1010,9 +982,9 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
1010
982
|
transform_kwargs = dict(
|
1011
983
|
session = dataset._session,
|
1012
984
|
dependencies = self._deps,
|
1013
|
-
|
1014
|
-
expected_output_cols_type
|
1015
|
-
n_neighbors =
|
985
|
+
drop_input_cols = self._drop_input_cols,
|
986
|
+
expected_output_cols_type="array",
|
987
|
+
n_neighbors = n_neighbors,
|
1016
988
|
return_distance = return_distance
|
1017
989
|
)
|
1018
990
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -483,18 +483,24 @@ class XGBClassifier(BaseTransformer):
|
|
483
483
|
self._get_model_signatures(dataset)
|
484
484
|
return self
|
485
485
|
|
486
|
-
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
487
|
-
if self._drop_input_cols:
|
488
|
-
return []
|
489
|
-
else:
|
490
|
-
return list(set(dataset.columns) - set(self.output_cols))
|
491
|
-
|
492
486
|
def _batch_inference_validate_snowpark(
|
493
487
|
self,
|
494
488
|
dataset: DataFrame,
|
495
489
|
inference_method: str,
|
496
490
|
) -> List[str]:
|
497
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
491
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe and
|
492
|
+
return the available package that exists in the snowflake anaconda channel
|
493
|
+
|
494
|
+
Args:
|
495
|
+
dataset: snowpark dataframe
|
496
|
+
inference_method: the inference method such as predict, score...
|
497
|
+
|
498
|
+
Raises:
|
499
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
500
|
+
SnowflakeMLException: If the session is None, raise error
|
501
|
+
|
502
|
+
Returns:
|
503
|
+
A list of available package that exists in the snowflake anaconda channel
|
498
504
|
"""
|
499
505
|
if not self._is_fitted:
|
500
506
|
raise exceptions.SnowflakeMLException(
|
@@ -568,7 +574,7 @@ class XGBClassifier(BaseTransformer):
|
|
568
574
|
transform_kwargs = dict(
|
569
575
|
session = dataset._session,
|
570
576
|
dependencies = self._deps,
|
571
|
-
|
577
|
+
drop_input_cols = self._drop_input_cols,
|
572
578
|
expected_output_cols_type = expected_type_inferred,
|
573
579
|
)
|
574
580
|
|
@@ -628,16 +634,16 @@ class XGBClassifier(BaseTransformer):
|
|
628
634
|
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
|
629
635
|
# based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
|
630
636
|
# each row containing a list of values.
|
631
|
-
expected_dtype = "
|
637
|
+
expected_dtype = "array"
|
632
638
|
|
633
639
|
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
634
640
|
if expected_dtype == "":
|
635
|
-
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
|
641
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
|
636
642
|
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
637
|
-
expected_dtype = "
|
638
|
-
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
|
643
|
+
expected_dtype = "array"
|
644
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
|
639
645
|
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
640
|
-
expected_dtype = "
|
646
|
+
expected_dtype = "array"
|
641
647
|
else:
|
642
648
|
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
643
649
|
# We can only infer the output types from the input types if the following two statemetns are true:
|
@@ -655,7 +661,7 @@ class XGBClassifier(BaseTransformer):
|
|
655
661
|
transform_kwargs = dict(
|
656
662
|
session = dataset._session,
|
657
663
|
dependencies = self._deps,
|
658
|
-
|
664
|
+
drop_input_cols = self._drop_input_cols,
|
659
665
|
expected_output_cols_type = expected_dtype,
|
660
666
|
)
|
661
667
|
|
@@ -706,7 +712,7 @@ class XGBClassifier(BaseTransformer):
|
|
706
712
|
subproject=_SUBPROJECT,
|
707
713
|
)
|
708
714
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
709
|
-
|
715
|
+
drop_input_cols=self._drop_input_cols,
|
710
716
|
expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
|
711
717
|
)
|
712
718
|
self._sklearn_object = fitted_estimator
|
@@ -724,44 +730,6 @@ class XGBClassifier(BaseTransformer):
|
|
724
730
|
assert self._sklearn_object is not None
|
725
731
|
return self._sklearn_object.embedding_
|
726
732
|
|
727
|
-
|
728
|
-
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
729
|
-
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
730
|
-
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
731
|
-
"""
|
732
|
-
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
733
|
-
if output_cols:
|
734
|
-
output_cols = [
|
735
|
-
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
736
|
-
for c in output_cols
|
737
|
-
]
|
738
|
-
elif getattr(self._sklearn_object, "classes_", None) is None:
|
739
|
-
output_cols = [output_cols_prefix]
|
740
|
-
elif self._sklearn_object is not None:
|
741
|
-
classes = self._sklearn_object.classes_
|
742
|
-
if isinstance(classes, numpy.ndarray):
|
743
|
-
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
744
|
-
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
745
|
-
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
746
|
-
output_cols = []
|
747
|
-
for i, cl in enumerate(classes):
|
748
|
-
# For binary classification, there is only one output column for each class
|
749
|
-
# ndarray as the two classes are complementary.
|
750
|
-
if len(cl) == 2:
|
751
|
-
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
752
|
-
else:
|
753
|
-
output_cols.extend([
|
754
|
-
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
755
|
-
])
|
756
|
-
else:
|
757
|
-
output_cols = []
|
758
|
-
|
759
|
-
# Make sure column names are valid snowflake identifiers.
|
760
|
-
assert output_cols is not None # Make MyPy happy
|
761
|
-
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
762
|
-
|
763
|
-
return rv
|
764
|
-
|
765
733
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
766
734
|
@telemetry.send_api_usage_telemetry(
|
767
735
|
project=_PROJECT,
|
@@ -803,7 +771,7 @@ class XGBClassifier(BaseTransformer):
|
|
803
771
|
transform_kwargs = dict(
|
804
772
|
session=dataset._session,
|
805
773
|
dependencies=self._deps,
|
806
|
-
|
774
|
+
drop_input_cols = self._drop_input_cols,
|
807
775
|
expected_output_cols_type="float",
|
808
776
|
)
|
809
777
|
|
@@ -870,7 +838,7 @@ class XGBClassifier(BaseTransformer):
|
|
870
838
|
transform_kwargs = dict(
|
871
839
|
session=dataset._session,
|
872
840
|
dependencies=self._deps,
|
873
|
-
|
841
|
+
drop_input_cols = self._drop_input_cols,
|
874
842
|
expected_output_cols_type="float",
|
875
843
|
)
|
876
844
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -931,7 +899,7 @@ class XGBClassifier(BaseTransformer):
|
|
931
899
|
transform_kwargs = dict(
|
932
900
|
session=dataset._session,
|
933
901
|
dependencies=self._deps,
|
934
|
-
|
902
|
+
drop_input_cols = self._drop_input_cols,
|
935
903
|
expected_output_cols_type="float",
|
936
904
|
)
|
937
905
|
|
@@ -996,7 +964,7 @@ class XGBClassifier(BaseTransformer):
|
|
996
964
|
transform_kwargs = dict(
|
997
965
|
session=dataset._session,
|
998
966
|
dependencies=self._deps,
|
999
|
-
|
967
|
+
drop_input_cols = self._drop_input_cols,
|
1000
968
|
expected_output_cols_type="float",
|
1001
969
|
)
|
1002
970
|
|
@@ -1052,13 +1020,17 @@ class XGBClassifier(BaseTransformer):
|
|
1052
1020
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
1053
1021
|
|
1054
1022
|
if isinstance(dataset, DataFrame):
|
1023
|
+
self._deps = self._batch_inference_validate_snowpark(
|
1024
|
+
dataset=dataset,
|
1025
|
+
inference_method="score",
|
1026
|
+
)
|
1055
1027
|
selected_cols = self._get_active_columns()
|
1056
1028
|
if len(selected_cols) > 0:
|
1057
1029
|
dataset = dataset.select(selected_cols)
|
1058
1030
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
1059
1031
|
transform_kwargs = dict(
|
1060
1032
|
session=dataset._session,
|
1061
|
-
dependencies=["snowflake-snowpark-python"] + self.
|
1033
|
+
dependencies=["snowflake-snowpark-python"] + self._deps,
|
1062
1034
|
score_sproc_imports=['xgboost'],
|
1063
1035
|
)
|
1064
1036
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1132,9 +1104,9 @@ class XGBClassifier(BaseTransformer):
|
|
1132
1104
|
transform_kwargs = dict(
|
1133
1105
|
session = dataset._session,
|
1134
1106
|
dependencies = self._deps,
|
1135
|
-
|
1136
|
-
expected_output_cols_type
|
1137
|
-
n_neighbors =
|
1107
|
+
drop_input_cols = self._drop_input_cols,
|
1108
|
+
expected_output_cols_type="array",
|
1109
|
+
n_neighbors = n_neighbors,
|
1138
1110
|
return_distance = return_distance
|
1139
1111
|
)
|
1140
1112
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -482,18 +482,24 @@ class XGBRegressor(BaseTransformer):
|
|
482
482
|
self._get_model_signatures(dataset)
|
483
483
|
return self
|
484
484
|
|
485
|
-
def _get_pass_through_columns(self, dataset: DataFrame) -> List[str]:
|
486
|
-
if self._drop_input_cols:
|
487
|
-
return []
|
488
|
-
else:
|
489
|
-
return list(set(dataset.columns) - set(self.output_cols))
|
490
|
-
|
491
485
|
def _batch_inference_validate_snowpark(
|
492
486
|
self,
|
493
487
|
dataset: DataFrame,
|
494
488
|
inference_method: str,
|
495
489
|
) -> List[str]:
|
496
|
-
"""Util method to run validate that batch inference can be run on a snowpark dataframe
|
490
|
+
"""Util method to run validate that batch inference can be run on a snowpark dataframe and
|
491
|
+
return the available package that exists in the snowflake anaconda channel
|
492
|
+
|
493
|
+
Args:
|
494
|
+
dataset: snowpark dataframe
|
495
|
+
inference_method: the inference method such as predict, score...
|
496
|
+
|
497
|
+
Raises:
|
498
|
+
SnowflakeMLException: If the estimator is not fitted, raise error
|
499
|
+
SnowflakeMLException: If the session is None, raise error
|
500
|
+
|
501
|
+
Returns:
|
502
|
+
A list of available package that exists in the snowflake anaconda channel
|
497
503
|
"""
|
498
504
|
if not self._is_fitted:
|
499
505
|
raise exceptions.SnowflakeMLException(
|
@@ -567,7 +573,7 @@ class XGBRegressor(BaseTransformer):
|
|
567
573
|
transform_kwargs = dict(
|
568
574
|
session = dataset._session,
|
569
575
|
dependencies = self._deps,
|
570
|
-
|
576
|
+
drop_input_cols = self._drop_input_cols,
|
571
577
|
expected_output_cols_type = expected_type_inferred,
|
572
578
|
)
|
573
579
|
|
@@ -627,16 +633,16 @@ class XGBRegressor(BaseTransformer):
|
|
627
633
|
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
|
628
634
|
# based on init param values. We will convert that to pandas dataframe of shape (n_samples, 1) with
|
629
635
|
# each row containing a list of values.
|
630
|
-
expected_dtype = "
|
636
|
+
expected_dtype = "array"
|
631
637
|
|
632
638
|
# If we were unable to assign a type to this transform in the factory, infer the type here.
|
633
639
|
if expected_dtype == "":
|
634
|
-
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "
|
640
|
+
# If this is a clustering transformer, if the number of output columns does not equal the number of clusters the response will be an "array"
|
635
641
|
if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
|
636
|
-
expected_dtype = "
|
637
|
-
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "
|
642
|
+
expected_dtype = "array"
|
643
|
+
# If this is a decomposition transformer, if the number of output columns does not equal the number of components the response will be an "array"
|
638
644
|
elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
|
639
|
-
expected_dtype = "
|
645
|
+
expected_dtype = "array"
|
640
646
|
else:
|
641
647
|
output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
|
642
648
|
# We can only infer the output types from the input types if the following two statemetns are true:
|
@@ -654,7 +660,7 @@ class XGBRegressor(BaseTransformer):
|
|
654
660
|
transform_kwargs = dict(
|
655
661
|
session = dataset._session,
|
656
662
|
dependencies = self._deps,
|
657
|
-
|
663
|
+
drop_input_cols = self._drop_input_cols,
|
658
664
|
expected_output_cols_type = expected_dtype,
|
659
665
|
)
|
660
666
|
|
@@ -705,7 +711,7 @@ class XGBRegressor(BaseTransformer):
|
|
705
711
|
subproject=_SUBPROJECT,
|
706
712
|
)
|
707
713
|
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
708
|
-
|
714
|
+
drop_input_cols=self._drop_input_cols,
|
709
715
|
expected_output_cols_list=self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix),
|
710
716
|
)
|
711
717
|
self._sklearn_object = fitted_estimator
|
@@ -723,44 +729,6 @@ class XGBRegressor(BaseTransformer):
|
|
723
729
|
assert self._sklearn_object is not None
|
724
730
|
return self._sklearn_object.embedding_
|
725
731
|
|
726
|
-
|
727
|
-
def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
|
728
|
-
""" Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
|
729
|
-
Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
|
730
|
-
"""
|
731
|
-
output_cols_prefix = identifier.resolve_identifier(output_cols_prefix)
|
732
|
-
if output_cols:
|
733
|
-
output_cols = [
|
734
|
-
identifier.concat_names([output_cols_prefix, identifier.resolve_identifier(c)])
|
735
|
-
for c in output_cols
|
736
|
-
]
|
737
|
-
elif getattr(self._sklearn_object, "classes_", None) is None:
|
738
|
-
output_cols = [output_cols_prefix]
|
739
|
-
elif self._sklearn_object is not None:
|
740
|
-
classes = self._sklearn_object.classes_
|
741
|
-
if isinstance(classes, numpy.ndarray):
|
742
|
-
output_cols = [f'{output_cols_prefix}{str(c)}' for c in classes.tolist()]
|
743
|
-
elif isinstance(classes, list) and len(classes) > 0 and isinstance(classes[0], numpy.ndarray):
|
744
|
-
# If the estimator is a multioutput estimator, classes_ will be a list of ndarrays.
|
745
|
-
output_cols = []
|
746
|
-
for i, cl in enumerate(classes):
|
747
|
-
# For binary classification, there is only one output column for each class
|
748
|
-
# ndarray as the two classes are complementary.
|
749
|
-
if len(cl) == 2:
|
750
|
-
output_cols.append(f'{output_cols_prefix}{i}_{cl[0]}')
|
751
|
-
else:
|
752
|
-
output_cols.extend([
|
753
|
-
f'{output_cols_prefix}{i}_{c}' for c in cl.tolist()
|
754
|
-
])
|
755
|
-
else:
|
756
|
-
output_cols = []
|
757
|
-
|
758
|
-
# Make sure column names are valid snowflake identifiers.
|
759
|
-
assert output_cols is not None # Make MyPy happy
|
760
|
-
rv = [identifier.rename_to_valid_snowflake_identifier(c) for c in output_cols]
|
761
|
-
|
762
|
-
return rv
|
763
|
-
|
764
732
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
765
733
|
@telemetry.send_api_usage_telemetry(
|
766
734
|
project=_PROJECT,
|
@@ -800,7 +768,7 @@ class XGBRegressor(BaseTransformer):
|
|
800
768
|
transform_kwargs = dict(
|
801
769
|
session=dataset._session,
|
802
770
|
dependencies=self._deps,
|
803
|
-
|
771
|
+
drop_input_cols = self._drop_input_cols,
|
804
772
|
expected_output_cols_type="float",
|
805
773
|
)
|
806
774
|
|
@@ -865,7 +833,7 @@ class XGBRegressor(BaseTransformer):
|
|
865
833
|
transform_kwargs = dict(
|
866
834
|
session=dataset._session,
|
867
835
|
dependencies=self._deps,
|
868
|
-
|
836
|
+
drop_input_cols = self._drop_input_cols,
|
869
837
|
expected_output_cols_type="float",
|
870
838
|
)
|
871
839
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -926,7 +894,7 @@ class XGBRegressor(BaseTransformer):
|
|
926
894
|
transform_kwargs = dict(
|
927
895
|
session=dataset._session,
|
928
896
|
dependencies=self._deps,
|
929
|
-
|
897
|
+
drop_input_cols = self._drop_input_cols,
|
930
898
|
expected_output_cols_type="float",
|
931
899
|
)
|
932
900
|
|
@@ -991,7 +959,7 @@ class XGBRegressor(BaseTransformer):
|
|
991
959
|
transform_kwargs = dict(
|
992
960
|
session=dataset._session,
|
993
961
|
dependencies=self._deps,
|
994
|
-
|
962
|
+
drop_input_cols = self._drop_input_cols,
|
995
963
|
expected_output_cols_type="float",
|
996
964
|
)
|
997
965
|
|
@@ -1047,13 +1015,17 @@ class XGBRegressor(BaseTransformer):
|
|
1047
1015
|
transform_kwargs: ScoreKwargsTypedDict = dict()
|
1048
1016
|
|
1049
1017
|
if isinstance(dataset, DataFrame):
|
1018
|
+
self._deps = self._batch_inference_validate_snowpark(
|
1019
|
+
dataset=dataset,
|
1020
|
+
inference_method="score",
|
1021
|
+
)
|
1050
1022
|
selected_cols = self._get_active_columns()
|
1051
1023
|
if len(selected_cols) > 0:
|
1052
1024
|
dataset = dataset.select(selected_cols)
|
1053
1025
|
assert isinstance(dataset._session, Session) # keep mypy happy
|
1054
1026
|
transform_kwargs = dict(
|
1055
1027
|
session=dataset._session,
|
1056
|
-
dependencies=["snowflake-snowpark-python"] + self.
|
1028
|
+
dependencies=["snowflake-snowpark-python"] + self._deps,
|
1057
1029
|
score_sproc_imports=['xgboost'],
|
1058
1030
|
)
|
1059
1031
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -1127,9 +1099,9 @@ class XGBRegressor(BaseTransformer):
|
|
1127
1099
|
transform_kwargs = dict(
|
1128
1100
|
session = dataset._session,
|
1129
1101
|
dependencies = self._deps,
|
1130
|
-
|
1131
|
-
expected_output_cols_type
|
1132
|
-
n_neighbors =
|
1102
|
+
drop_input_cols = self._drop_input_cols,
|
1103
|
+
expected_output_cols_type="array",
|
1104
|
+
n_neighbors = n_neighbors,
|
1133
1105
|
return_distance = return_distance
|
1134
1106
|
)
|
1135
1107
|
elif isinstance(dataset, pd.DataFrame):
|