snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +7 -33
- snowflake/ml/_internal/env_utils.py +11 -5
- snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
- snowflake/ml/_internal/telemetry.py +156 -20
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +1 -1
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/_internal/utils/uri.py +2 -2
- snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
- snowflake/ml/data/data_connector.py +88 -9
- snowflake/ml/data/data_ingestor.py +18 -1
- snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/dataset/dataset_metadata.py +3 -1
- snowflake/ml/dataset/dataset_reader.py +9 -3
- snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
- snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
- snowflake/ml/feature_store/examples/example_helper.py +69 -31
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
- snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
- snowflake/ml/feature_store/feature_store.py +100 -41
- snowflake/ml/feature_store/feature_view.py +149 -5
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/model/_client/model/model_impl.py +11 -2
- snowflake/ml/model/_client/model/model_version_impl.py +186 -20
- snowflake/ml/model/_client/ops/model_ops.py +144 -30
- snowflake/ml/model/_client/ops/service_ops.py +312 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
- snowflake/ml/model/_client/sql/model_version.py +13 -4
- snowflake/ml/model/_client/sql/service.py +196 -0
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
- snowflake/ml/model/_model_composer/model_composer.py +5 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
- snowflake/ml/model/_packager/model_env/model_env.py +7 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
- snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
- snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
- snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
- snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
- snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
- snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
- snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
- snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
- snowflake/ml/model/_packager/model_packager.py +4 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/models/llm.py +3 -1
- snowflake/ml/model/type_hints.py +10 -4
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
- snowflake/ml/modeling/cluster/birch.py +60 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
- snowflake/ml/modeling/cluster/dbscan.py +60 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
- snowflake/ml/modeling/cluster/k_means.py +60 -21
- snowflake/ml/modeling/cluster/mean_shift.py +60 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
- snowflake/ml/modeling/cluster/optics.py +60 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
- snowflake/ml/modeling/compose/column_transformer.py +60 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
- snowflake/ml/modeling/covariance/oas.py +60 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/pca.py +60 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
- snowflake/ml/modeling/framework/base.py +28 -19
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
- snowflake/ml/modeling/impute/knn_imputer.py +60 -21
- snowflake/ml/modeling/impute/missing_indicator.py +60 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/lars.py +60 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/perceptron.py +60 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ridge.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
- snowflake/ml/modeling/manifold/isomap.py +60 -21
- snowflake/ml/modeling/manifold/mds.py +60 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
- snowflake/ml/modeling/manifold/tsne.py +60 -21
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +4 -12
- snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
- snowflake/ml/modeling/svm/linear_svc.py +60 -21
- snowflake/ml/modeling/svm/linear_svr.py +60 -21
- snowflake/ml/modeling/svm/nu_svc.py +60 -21
- snowflake/ml/modeling/svm/nu_svr.py +60 -21
- snowflake/ml/modeling/svm/svc.py +60 -21
- snowflake/ml/modeling/svm/svr.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
- snowflake/ml/registry/_manager/model_manager.py +20 -2
- snowflake/ml/registry/model_registry.py +1 -1
- snowflake/ml/registry/registry.py +1 -2
- snowflake/ml/utils/sql_client.py +22 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -577,12 +574,23 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
577
574
|
autogenerated=self._autogenerated,
|
578
575
|
subproject=_SUBPROJECT,
|
579
576
|
)
|
580
|
-
|
581
|
-
|
582
|
-
expected_output_cols_list=(
|
583
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
584
|
-
),
|
577
|
+
expected_output_cols = (
|
578
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
585
579
|
)
|
580
|
+
if isinstance(dataset, DataFrame):
|
581
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
582
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
583
|
+
)
|
584
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
585
|
+
drop_input_cols=self._drop_input_cols,
|
586
|
+
expected_output_cols_list=expected_output_cols,
|
587
|
+
example_output_pd_df=example_output_pd_df,
|
588
|
+
)
|
589
|
+
else:
|
590
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
591
|
+
drop_input_cols=self._drop_input_cols,
|
592
|
+
expected_output_cols_list=expected_output_cols,
|
593
|
+
)
|
586
594
|
self._sklearn_object = fitted_estimator
|
587
595
|
self._is_fitted = True
|
588
596
|
return output_result
|
@@ -661,12 +669,41 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
661
669
|
|
662
670
|
return rv
|
663
671
|
|
664
|
-
def
|
665
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
666
|
-
) -> List[str]:
|
672
|
+
def _align_expected_output(
|
673
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
674
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
675
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
676
|
+
and output dataframe with 1 line.
|
677
|
+
If the method is fit_predict, run 2 lines of data.
|
678
|
+
"""
|
667
679
|
# in case the inferred output column names dimension is different
|
668
680
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
669
|
-
|
681
|
+
|
682
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
683
|
+
# so change the minimum of number of rows to 2
|
684
|
+
num_examples = 2
|
685
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
686
|
+
project=_PROJECT,
|
687
|
+
subproject=_SUBPROJECT,
|
688
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
689
|
+
inspect.currentframe(), ExtraTreeRegressor.__class__.__name__
|
690
|
+
),
|
691
|
+
api_calls=[Session.call],
|
692
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
693
|
+
)
|
694
|
+
if output_cols_prefix == "fit_predict_":
|
695
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
696
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
697
|
+
num_examples = self._sklearn_object.n_clusters
|
698
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
699
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
700
|
+
num_examples = self._sklearn_object.min_samples
|
701
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
702
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
703
|
+
num_examples = self._sklearn_object.n_neighbors
|
704
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
705
|
+
else:
|
706
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
670
707
|
|
671
708
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
672
709
|
# seen during the fit.
|
@@ -678,12 +715,14 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
678
715
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
679
716
|
if self.sample_weight_col:
|
680
717
|
output_df_columns_set -= set(self.sample_weight_col)
|
718
|
+
|
681
719
|
# if the dimension of inferred output column names is correct; use it
|
682
720
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
683
|
-
return expected_output_cols_list
|
721
|
+
return expected_output_cols_list, output_df_pd
|
684
722
|
# otherwise, use the sklearn estimator's output
|
685
723
|
else:
|
686
|
-
|
724
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
725
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
687
726
|
|
688
727
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
689
728
|
@telemetry.send_api_usage_telemetry(
|
@@ -729,7 +768,7 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
729
768
|
drop_input_cols=self._drop_input_cols,
|
730
769
|
expected_output_cols_type="float",
|
731
770
|
)
|
732
|
-
expected_output_cols = self.
|
771
|
+
expected_output_cols, _ = self._align_expected_output(
|
733
772
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
734
773
|
)
|
735
774
|
|
@@ -795,7 +834,7 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
795
834
|
drop_input_cols=self._drop_input_cols,
|
796
835
|
expected_output_cols_type="float",
|
797
836
|
)
|
798
|
-
expected_output_cols = self.
|
837
|
+
expected_output_cols, _ = self._align_expected_output(
|
799
838
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
800
839
|
)
|
801
840
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -858,7 +897,7 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
858
897
|
drop_input_cols=self._drop_input_cols,
|
859
898
|
expected_output_cols_type="float",
|
860
899
|
)
|
861
|
-
expected_output_cols = self.
|
900
|
+
expected_output_cols, _ = self._align_expected_output(
|
862
901
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
863
902
|
)
|
864
903
|
|
@@ -923,7 +962,7 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
923
962
|
drop_input_cols = self._drop_input_cols,
|
924
963
|
expected_output_cols_type="float",
|
925
964
|
)
|
926
|
-
expected_output_cols = self.
|
965
|
+
expected_output_cols, _ = self._align_expected_output(
|
927
966
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
928
967
|
)
|
929
968
|
|
@@ -4,18 +4,17 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
18
16
|
import numpy
|
17
|
+
import sklearn
|
19
18
|
import xgboost
|
20
19
|
from sklearn.utils.metaestimators import available_if
|
21
20
|
|
@@ -23,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
23
22
|
from snowflake.ml._internal import telemetry
|
24
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
25
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
26
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
27
26
|
from snowflake.snowpark import DataFrame, Session
|
28
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
29
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
30
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
31
|
-
ModelTransformHandlers,
|
32
30
|
BatchInferenceKwargsTypedDict,
|
33
31
|
ScoreKwargsTypedDict
|
34
32
|
)
|
@@ -361,7 +359,7 @@ class XGBClassifier(BaseTransformer):
|
|
361
359
|
self.set_sample_weight_col(sample_weight_col)
|
362
360
|
self._use_external_memory_version = use_external_memory_version
|
363
361
|
self._batch_size = batch_size
|
364
|
-
deps: Set[str] = set([f'numpy=={np.__version__}', f'xgboost=={xgboost.__version__}', f'cloudpickle=={cp.__version__}'])
|
362
|
+
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'xgboost=={xgboost.__version__}', f'cloudpickle=={cp.__version__}'])
|
365
363
|
|
366
364
|
self._deps = list(deps)
|
367
365
|
|
@@ -695,12 +693,23 @@ class XGBClassifier(BaseTransformer):
|
|
695
693
|
autogenerated=self._autogenerated,
|
696
694
|
subproject=_SUBPROJECT,
|
697
695
|
)
|
698
|
-
|
699
|
-
|
700
|
-
expected_output_cols_list=(
|
701
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
702
|
-
),
|
696
|
+
expected_output_cols = (
|
697
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
703
698
|
)
|
699
|
+
if isinstance(dataset, DataFrame):
|
700
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
701
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
702
|
+
)
|
703
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
704
|
+
drop_input_cols=self._drop_input_cols,
|
705
|
+
expected_output_cols_list=expected_output_cols,
|
706
|
+
example_output_pd_df=example_output_pd_df,
|
707
|
+
)
|
708
|
+
else:
|
709
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
710
|
+
drop_input_cols=self._drop_input_cols,
|
711
|
+
expected_output_cols_list=expected_output_cols,
|
712
|
+
)
|
704
713
|
self._sklearn_object = fitted_estimator
|
705
714
|
self._is_fitted = True
|
706
715
|
return output_result
|
@@ -779,12 +788,41 @@ class XGBClassifier(BaseTransformer):
|
|
779
788
|
|
780
789
|
return rv
|
781
790
|
|
782
|
-
def
|
783
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
784
|
-
) -> List[str]:
|
791
|
+
def _align_expected_output(
|
792
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
793
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
794
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
795
|
+
and output dataframe with 1 line.
|
796
|
+
If the method is fit_predict, run 2 lines of data.
|
797
|
+
"""
|
785
798
|
# in case the inferred output column names dimension is different
|
786
799
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
787
|
-
|
800
|
+
|
801
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
802
|
+
# so change the minimum of number of rows to 2
|
803
|
+
num_examples = 2
|
804
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
805
|
+
project=_PROJECT,
|
806
|
+
subproject=_SUBPROJECT,
|
807
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
808
|
+
inspect.currentframe(), XGBClassifier.__class__.__name__
|
809
|
+
),
|
810
|
+
api_calls=[Session.call],
|
811
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
812
|
+
)
|
813
|
+
if output_cols_prefix == "fit_predict_":
|
814
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
815
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
816
|
+
num_examples = self._sklearn_object.n_clusters
|
817
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
818
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
819
|
+
num_examples = self._sklearn_object.min_samples
|
820
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
821
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
822
|
+
num_examples = self._sklearn_object.n_neighbors
|
823
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
824
|
+
else:
|
825
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
788
826
|
|
789
827
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
790
828
|
# seen during the fit.
|
@@ -796,12 +834,14 @@ class XGBClassifier(BaseTransformer):
|
|
796
834
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
797
835
|
if self.sample_weight_col:
|
798
836
|
output_df_columns_set -= set(self.sample_weight_col)
|
837
|
+
|
799
838
|
# if the dimension of inferred output column names is correct; use it
|
800
839
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
801
|
-
return expected_output_cols_list
|
840
|
+
return expected_output_cols_list, output_df_pd
|
802
841
|
# otherwise, use the sklearn estimator's output
|
803
842
|
else:
|
804
|
-
|
843
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
844
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
805
845
|
|
806
846
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
807
847
|
@telemetry.send_api_usage_telemetry(
|
@@ -849,7 +889,7 @@ class XGBClassifier(BaseTransformer):
|
|
849
889
|
drop_input_cols=self._drop_input_cols,
|
850
890
|
expected_output_cols_type="float",
|
851
891
|
)
|
852
|
-
expected_output_cols = self.
|
892
|
+
expected_output_cols, _ = self._align_expected_output(
|
853
893
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
854
894
|
)
|
855
895
|
|
@@ -917,7 +957,7 @@ class XGBClassifier(BaseTransformer):
|
|
917
957
|
drop_input_cols=self._drop_input_cols,
|
918
958
|
expected_output_cols_type="float",
|
919
959
|
)
|
920
|
-
expected_output_cols = self.
|
960
|
+
expected_output_cols, _ = self._align_expected_output(
|
921
961
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
922
962
|
)
|
923
963
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -980,7 +1020,7 @@ class XGBClassifier(BaseTransformer):
|
|
980
1020
|
drop_input_cols=self._drop_input_cols,
|
981
1021
|
expected_output_cols_type="float",
|
982
1022
|
)
|
983
|
-
expected_output_cols = self.
|
1023
|
+
expected_output_cols, _ = self._align_expected_output(
|
984
1024
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
985
1025
|
)
|
986
1026
|
|
@@ -1045,7 +1085,7 @@ class XGBClassifier(BaseTransformer):
|
|
1045
1085
|
drop_input_cols = self._drop_input_cols,
|
1046
1086
|
expected_output_cols_type="float",
|
1047
1087
|
)
|
1048
|
-
expected_output_cols = self.
|
1088
|
+
expected_output_cols, _ = self._align_expected_output(
|
1049
1089
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
1050
1090
|
)
|
1051
1091
|
|
@@ -1110,7 +1150,7 @@ class XGBClassifier(BaseTransformer):
|
|
1110
1150
|
transform_kwargs = dict(
|
1111
1151
|
session=dataset._session,
|
1112
1152
|
dependencies=self._deps,
|
1113
|
-
score_sproc_imports=['xgboost'],
|
1153
|
+
score_sproc_imports=['xgboost', 'sklearn'],
|
1114
1154
|
)
|
1115
1155
|
elif isinstance(dataset, pd.DataFrame):
|
1116
1156
|
# pandas_handler.score() does not require any extra kwargs.
|
@@ -4,18 +4,17 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
18
16
|
import numpy
|
17
|
+
import sklearn
|
19
18
|
import xgboost
|
20
19
|
from sklearn.utils.metaestimators import available_if
|
21
20
|
|
@@ -23,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
23
22
|
from snowflake.ml._internal import telemetry
|
24
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
25
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
26
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
27
26
|
from snowflake.snowpark import DataFrame, Session
|
28
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
29
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
30
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
31
|
-
ModelTransformHandlers,
|
32
30
|
BatchInferenceKwargsTypedDict,
|
33
31
|
ScoreKwargsTypedDict
|
34
32
|
)
|
@@ -361,7 +359,7 @@ class XGBRegressor(BaseTransformer):
|
|
361
359
|
self.set_sample_weight_col(sample_weight_col)
|
362
360
|
self._use_external_memory_version = use_external_memory_version
|
363
361
|
self._batch_size = batch_size
|
364
|
-
deps: Set[str] = set([f'numpy=={np.__version__}', f'xgboost=={xgboost.__version__}', f'cloudpickle=={cp.__version__}'])
|
362
|
+
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'xgboost=={xgboost.__version__}', f'cloudpickle=={cp.__version__}'])
|
365
363
|
|
366
364
|
self._deps = list(deps)
|
367
365
|
|
@@ -694,12 +692,23 @@ class XGBRegressor(BaseTransformer):
|
|
694
692
|
autogenerated=self._autogenerated,
|
695
693
|
subproject=_SUBPROJECT,
|
696
694
|
)
|
697
|
-
|
698
|
-
|
699
|
-
expected_output_cols_list=(
|
700
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
701
|
-
),
|
695
|
+
expected_output_cols = (
|
696
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
702
697
|
)
|
698
|
+
if isinstance(dataset, DataFrame):
|
699
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
700
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
701
|
+
)
|
702
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
703
|
+
drop_input_cols=self._drop_input_cols,
|
704
|
+
expected_output_cols_list=expected_output_cols,
|
705
|
+
example_output_pd_df=example_output_pd_df,
|
706
|
+
)
|
707
|
+
else:
|
708
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
709
|
+
drop_input_cols=self._drop_input_cols,
|
710
|
+
expected_output_cols_list=expected_output_cols,
|
711
|
+
)
|
703
712
|
self._sklearn_object = fitted_estimator
|
704
713
|
self._is_fitted = True
|
705
714
|
return output_result
|
@@ -778,12 +787,41 @@ class XGBRegressor(BaseTransformer):
|
|
778
787
|
|
779
788
|
return rv
|
780
789
|
|
781
|
-
def
|
782
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
783
|
-
) -> List[str]:
|
790
|
+
def _align_expected_output(
|
791
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
792
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
793
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
794
|
+
and output dataframe with 1 line.
|
795
|
+
If the method is fit_predict, run 2 lines of data.
|
796
|
+
"""
|
784
797
|
# in case the inferred output column names dimension is different
|
785
798
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
786
|
-
|
799
|
+
|
800
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
801
|
+
# so change the minimum of number of rows to 2
|
802
|
+
num_examples = 2
|
803
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
804
|
+
project=_PROJECT,
|
805
|
+
subproject=_SUBPROJECT,
|
806
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
807
|
+
inspect.currentframe(), XGBRegressor.__class__.__name__
|
808
|
+
),
|
809
|
+
api_calls=[Session.call],
|
810
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
811
|
+
)
|
812
|
+
if output_cols_prefix == "fit_predict_":
|
813
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
814
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
815
|
+
num_examples = self._sklearn_object.n_clusters
|
816
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
817
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
818
|
+
num_examples = self._sklearn_object.min_samples
|
819
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
820
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
821
|
+
num_examples = self._sklearn_object.n_neighbors
|
822
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
823
|
+
else:
|
824
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
787
825
|
|
788
826
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
789
827
|
# seen during the fit.
|
@@ -795,12 +833,14 @@ class XGBRegressor(BaseTransformer):
|
|
795
833
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
796
834
|
if self.sample_weight_col:
|
797
835
|
output_df_columns_set -= set(self.sample_weight_col)
|
836
|
+
|
798
837
|
# if the dimension of inferred output column names is correct; use it
|
799
838
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
800
|
-
return expected_output_cols_list
|
839
|
+
return expected_output_cols_list, output_df_pd
|
801
840
|
# otherwise, use the sklearn estimator's output
|
802
841
|
else:
|
803
|
-
|
842
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
843
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
804
844
|
|
805
845
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
806
846
|
@telemetry.send_api_usage_telemetry(
|
@@ -846,7 +886,7 @@ class XGBRegressor(BaseTransformer):
|
|
846
886
|
drop_input_cols=self._drop_input_cols,
|
847
887
|
expected_output_cols_type="float",
|
848
888
|
)
|
849
|
-
expected_output_cols = self.
|
889
|
+
expected_output_cols, _ = self._align_expected_output(
|
850
890
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
851
891
|
)
|
852
892
|
|
@@ -912,7 +952,7 @@ class XGBRegressor(BaseTransformer):
|
|
912
952
|
drop_input_cols=self._drop_input_cols,
|
913
953
|
expected_output_cols_type="float",
|
914
954
|
)
|
915
|
-
expected_output_cols = self.
|
955
|
+
expected_output_cols, _ = self._align_expected_output(
|
916
956
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
917
957
|
)
|
918
958
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -975,7 +1015,7 @@ class XGBRegressor(BaseTransformer):
|
|
975
1015
|
drop_input_cols=self._drop_input_cols,
|
976
1016
|
expected_output_cols_type="float",
|
977
1017
|
)
|
978
|
-
expected_output_cols = self.
|
1018
|
+
expected_output_cols, _ = self._align_expected_output(
|
979
1019
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
980
1020
|
)
|
981
1021
|
|
@@ -1040,7 +1080,7 @@ class XGBRegressor(BaseTransformer):
|
|
1040
1080
|
drop_input_cols = self._drop_input_cols,
|
1041
1081
|
expected_output_cols_type="float",
|
1042
1082
|
)
|
1043
|
-
expected_output_cols = self.
|
1083
|
+
expected_output_cols, _ = self._align_expected_output(
|
1044
1084
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
1045
1085
|
)
|
1046
1086
|
|
@@ -1105,7 +1145,7 @@ class XGBRegressor(BaseTransformer):
|
|
1105
1145
|
transform_kwargs = dict(
|
1106
1146
|
session=dataset._session,
|
1107
1147
|
dependencies=self._deps,
|
1108
|
-
score_sproc_imports=['xgboost'],
|
1148
|
+
score_sproc_imports=['xgboost', 'sklearn'],
|
1109
1149
|
)
|
1110
1150
|
elif isinstance(dataset, pd.DataFrame):
|
1111
1151
|
# pandas_handler.score() does not require any extra kwargs.
|