snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_classify_text.py +2 -2
- snowflake/cortex/_embed_text_1024.py +37 -0
- snowflake/cortex/_embed_text_768.py +37 -0
- snowflake/cortex/_extract_answer.py +2 -2
- snowflake/cortex/_sentiment.py +2 -2
- snowflake/cortex/_summarize.py +2 -2
- snowflake/cortex/_translate.py +2 -2
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/env_utils.py +5 -5
- snowflake/ml/_internal/exceptions/error_codes.py +2 -0
- snowflake/ml/_internal/telemetry.py +142 -20
- snowflake/ml/_internal/utils/db_utils.py +50 -0
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/service_logger.py +63 -0
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +26 -2
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
- snowflake/ml/data/data_connector.py +33 -7
- snowflake/ml/data/ingestor_utils.py +20 -10
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/feature_store/access_manager.py +3 -3
- snowflake/ml/feature_store/feature_store.py +60 -19
- snowflake/ml/feature_store/feature_view.py +84 -30
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/fileset/stage_fs.py +2 -1
- snowflake/ml/lineage/lineage_node.py +7 -2
- snowflake/ml/model/__init__.py +1 -2
- snowflake/ml/model/_client/model/model_version_impl.py +96 -12
- snowflake/ml/model/_client/ops/model_ops.py +124 -6
- snowflake/ml/model/_client/ops/service_ops.py +309 -9
- snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
- snowflake/ml/model/_client/sql/_base.py +5 -0
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +9 -5
- snowflake/ml/model/_client/sql/service.py +121 -20
- snowflake/ml/model/_model_composer/model_composer.py +11 -39
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
- snowflake/ml/model/_packager/model_env/model_env.py +4 -38
- snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
- snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
- snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
- snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
- snowflake/ml/model/_packager/model_packager.py +14 -8
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/type_hints.py +12 -145
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
- snowflake/ml/modeling/cluster/birch.py +61 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
- snowflake/ml/modeling/cluster/dbscan.py +61 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
- snowflake/ml/modeling/cluster/k_means.py +61 -21
- snowflake/ml/modeling/cluster/mean_shift.py +61 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
- snowflake/ml/modeling/cluster/optics.py +61 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
- snowflake/ml/modeling/compose/column_transformer.py +61 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
- snowflake/ml/modeling/covariance/oas.py +61 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/pca.py +61 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
- snowflake/ml/modeling/impute/knn_imputer.py +61 -21
- snowflake/ml/modeling/impute/missing_indicator.py +61 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/lars.py +61 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/perceptron.py +61 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ridge.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
- snowflake/ml/modeling/manifold/isomap.py +61 -21
- snowflake/ml/modeling/manifold/mds.py +61 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
- snowflake/ml/modeling/manifold/tsne.py +61 -21
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
- snowflake/ml/modeling/metrics/ranking.py +0 -3
- snowflake/ml/modeling/metrics/regression.py +0 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +1 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
- snowflake/ml/modeling/svm/linear_svc.py +61 -21
- snowflake/ml/modeling/svm/linear_svr.py +61 -21
- snowflake/ml/modeling/svm/nu_svc.py +61 -21
- snowflake/ml/modeling/svm/nu_svr.py +61 -21
- snowflake/ml/modeling/svm/svc.py +61 -21
- snowflake/ml/modeling/svm/svr.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
- snowflake/ml/monitoring/_client/model_monitor.py +126 -0
- snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
- snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
- snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
- snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
- snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
- snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
- snowflake/ml/monitoring/entities/output_score_type.py +90 -0
- snowflake/ml/registry/_manager/model_manager.py +4 -0
- snowflake/ml/registry/registry.py +166 -8
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
- snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
- snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
- snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
- snowflake/ml/_internal/utils/session_token_manager.py +0 -46
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
- snowflake/ml/_internal/utils/uri.py +0 -77
- snowflake/ml/data/torch_dataset.py +0 -33
- snowflake/ml/model/_api.py +0 -568
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
- snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
- snowflake/ml/model/deploy_platforms.py +0 -6
- snowflake/ml/model/models/llm.py +0 -104
- snowflake/ml/monitoring/monitor.py +0 -203
- snowflake/ml/registry/_initial_schema.py +0 -142
- snowflake/ml/registry/_schema.py +0 -82
- snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
- snowflake/ml/registry/_schema_version_manager.py +0 -163
- snowflake/ml/registry/model_registry.py +0 -2048
- snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
24
22
|
from snowflake.ml._internal import telemetry
|
25
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
26
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
27
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
28
26
|
from snowflake.snowpark import DataFrame, Session
|
29
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
30
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
31
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
32
|
-
ModelTransformHandlers,
|
33
30
|
BatchInferenceKwargsTypedDict,
|
34
31
|
ScoreKwargsTypedDict
|
35
32
|
)
|
@@ -577,12 +574,23 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
577
574
|
autogenerated=self._autogenerated,
|
578
575
|
subproject=_SUBPROJECT,
|
579
576
|
)
|
580
|
-
|
581
|
-
|
582
|
-
expected_output_cols_list=(
|
583
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
584
|
-
),
|
577
|
+
expected_output_cols = (
|
578
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
585
579
|
)
|
580
|
+
if isinstance(dataset, DataFrame):
|
581
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
582
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
583
|
+
)
|
584
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
585
|
+
drop_input_cols=self._drop_input_cols,
|
586
|
+
expected_output_cols_list=expected_output_cols,
|
587
|
+
example_output_pd_df=example_output_pd_df,
|
588
|
+
)
|
589
|
+
else:
|
590
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
591
|
+
drop_input_cols=self._drop_input_cols,
|
592
|
+
expected_output_cols_list=expected_output_cols,
|
593
|
+
)
|
586
594
|
self._sklearn_object = fitted_estimator
|
587
595
|
self._is_fitted = True
|
588
596
|
return output_result
|
@@ -605,6 +613,7 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
605
613
|
"""
|
606
614
|
self._infer_input_output_cols(dataset)
|
607
615
|
super()._check_dataset_type(dataset)
|
616
|
+
|
608
617
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
609
618
|
estimator=self._sklearn_object,
|
610
619
|
dataset=dataset,
|
@@ -661,12 +670,41 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
661
670
|
|
662
671
|
return rv
|
663
672
|
|
664
|
-
def
|
665
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
666
|
-
) -> List[str]:
|
673
|
+
def _align_expected_output(
|
674
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
675
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
676
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
677
|
+
and output dataframe with 1 line.
|
678
|
+
If the method is fit_predict, run 2 lines of data.
|
679
|
+
"""
|
667
680
|
# in case the inferred output column names dimension is different
|
668
681
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
669
|
-
|
682
|
+
|
683
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
684
|
+
# so change the minimum of number of rows to 2
|
685
|
+
num_examples = 2
|
686
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
687
|
+
project=_PROJECT,
|
688
|
+
subproject=_SUBPROJECT,
|
689
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
690
|
+
inspect.currentframe(), ExtraTreeRegressor.__class__.__name__
|
691
|
+
),
|
692
|
+
api_calls=[Session.call],
|
693
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
694
|
+
)
|
695
|
+
if output_cols_prefix == "fit_predict_":
|
696
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
697
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
698
|
+
num_examples = self._sklearn_object.n_clusters
|
699
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
700
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
701
|
+
num_examples = self._sklearn_object.min_samples
|
702
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
703
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
704
|
+
num_examples = self._sklearn_object.n_neighbors
|
705
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
706
|
+
else:
|
707
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
670
708
|
|
671
709
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
672
710
|
# seen during the fit.
|
@@ -678,12 +716,14 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
678
716
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
679
717
|
if self.sample_weight_col:
|
680
718
|
output_df_columns_set -= set(self.sample_weight_col)
|
719
|
+
|
681
720
|
# if the dimension of inferred output column names is correct; use it
|
682
721
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
683
|
-
return expected_output_cols_list
|
722
|
+
return expected_output_cols_list, output_df_pd
|
684
723
|
# otherwise, use the sklearn estimator's output
|
685
724
|
else:
|
686
|
-
|
725
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
726
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
687
727
|
|
688
728
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
689
729
|
@telemetry.send_api_usage_telemetry(
|
@@ -729,7 +769,7 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
729
769
|
drop_input_cols=self._drop_input_cols,
|
730
770
|
expected_output_cols_type="float",
|
731
771
|
)
|
732
|
-
expected_output_cols = self.
|
772
|
+
expected_output_cols, _ = self._align_expected_output(
|
733
773
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
734
774
|
)
|
735
775
|
|
@@ -795,7 +835,7 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
795
835
|
drop_input_cols=self._drop_input_cols,
|
796
836
|
expected_output_cols_type="float",
|
797
837
|
)
|
798
|
-
expected_output_cols = self.
|
838
|
+
expected_output_cols, _ = self._align_expected_output(
|
799
839
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
800
840
|
)
|
801
841
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -858,7 +898,7 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
858
898
|
drop_input_cols=self._drop_input_cols,
|
859
899
|
expected_output_cols_type="float",
|
860
900
|
)
|
861
|
-
expected_output_cols = self.
|
901
|
+
expected_output_cols, _ = self._align_expected_output(
|
862
902
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
863
903
|
)
|
864
904
|
|
@@ -923,7 +963,7 @@ class ExtraTreeRegressor(BaseTransformer):
|
|
923
963
|
drop_input_cols = self._drop_input_cols,
|
924
964
|
expected_output_cols_type="float",
|
925
965
|
)
|
926
|
-
expected_output_cols = self.
|
966
|
+
expected_output_cols, _ = self._align_expected_output(
|
927
967
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
928
968
|
)
|
929
969
|
|
@@ -4,18 +4,17 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
18
16
|
import numpy
|
17
|
+
import sklearn
|
19
18
|
import xgboost
|
20
19
|
from sklearn.utils.metaestimators import available_if
|
21
20
|
|
@@ -23,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
23
22
|
from snowflake.ml._internal import telemetry
|
24
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
25
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
26
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
27
26
|
from snowflake.snowpark import DataFrame, Session
|
28
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
29
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
30
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
31
|
-
ModelTransformHandlers,
|
32
30
|
BatchInferenceKwargsTypedDict,
|
33
31
|
ScoreKwargsTypedDict
|
34
32
|
)
|
@@ -361,7 +359,7 @@ class XGBClassifier(BaseTransformer):
|
|
361
359
|
self.set_sample_weight_col(sample_weight_col)
|
362
360
|
self._use_external_memory_version = use_external_memory_version
|
363
361
|
self._batch_size = batch_size
|
364
|
-
deps: Set[str] = set([f'numpy=={np.__version__}', f'xgboost=={xgboost.__version__}', f'cloudpickle=={cp.__version__}'])
|
362
|
+
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'xgboost=={xgboost.__version__}', f'cloudpickle=={cp.__version__}'])
|
365
363
|
|
366
364
|
self._deps = list(deps)
|
367
365
|
|
@@ -695,12 +693,23 @@ class XGBClassifier(BaseTransformer):
|
|
695
693
|
autogenerated=self._autogenerated,
|
696
694
|
subproject=_SUBPROJECT,
|
697
695
|
)
|
698
|
-
|
699
|
-
|
700
|
-
expected_output_cols_list=(
|
701
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
702
|
-
),
|
696
|
+
expected_output_cols = (
|
697
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
703
698
|
)
|
699
|
+
if isinstance(dataset, DataFrame):
|
700
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
701
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
702
|
+
)
|
703
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
704
|
+
drop_input_cols=self._drop_input_cols,
|
705
|
+
expected_output_cols_list=expected_output_cols,
|
706
|
+
example_output_pd_df=example_output_pd_df,
|
707
|
+
)
|
708
|
+
else:
|
709
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
710
|
+
drop_input_cols=self._drop_input_cols,
|
711
|
+
expected_output_cols_list=expected_output_cols,
|
712
|
+
)
|
704
713
|
self._sklearn_object = fitted_estimator
|
705
714
|
self._is_fitted = True
|
706
715
|
return output_result
|
@@ -723,6 +732,7 @@ class XGBClassifier(BaseTransformer):
|
|
723
732
|
"""
|
724
733
|
self._infer_input_output_cols(dataset)
|
725
734
|
super()._check_dataset_type(dataset)
|
735
|
+
|
726
736
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
727
737
|
estimator=self._sklearn_object,
|
728
738
|
dataset=dataset,
|
@@ -779,12 +789,41 @@ class XGBClassifier(BaseTransformer):
|
|
779
789
|
|
780
790
|
return rv
|
781
791
|
|
782
|
-
def
|
783
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
784
|
-
) -> List[str]:
|
792
|
+
def _align_expected_output(
|
793
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
794
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
795
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
796
|
+
and output dataframe with 1 line.
|
797
|
+
If the method is fit_predict, run 2 lines of data.
|
798
|
+
"""
|
785
799
|
# in case the inferred output column names dimension is different
|
786
800
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
787
|
-
|
801
|
+
|
802
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
803
|
+
# so change the minimum of number of rows to 2
|
804
|
+
num_examples = 2
|
805
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
806
|
+
project=_PROJECT,
|
807
|
+
subproject=_SUBPROJECT,
|
808
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
809
|
+
inspect.currentframe(), XGBClassifier.__class__.__name__
|
810
|
+
),
|
811
|
+
api_calls=[Session.call],
|
812
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
813
|
+
)
|
814
|
+
if output_cols_prefix == "fit_predict_":
|
815
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
816
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
817
|
+
num_examples = self._sklearn_object.n_clusters
|
818
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
819
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
820
|
+
num_examples = self._sklearn_object.min_samples
|
821
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
822
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
823
|
+
num_examples = self._sklearn_object.n_neighbors
|
824
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
825
|
+
else:
|
826
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
788
827
|
|
789
828
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
790
829
|
# seen during the fit.
|
@@ -796,12 +835,14 @@ class XGBClassifier(BaseTransformer):
|
|
796
835
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
797
836
|
if self.sample_weight_col:
|
798
837
|
output_df_columns_set -= set(self.sample_weight_col)
|
838
|
+
|
799
839
|
# if the dimension of inferred output column names is correct; use it
|
800
840
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
801
|
-
return expected_output_cols_list
|
841
|
+
return expected_output_cols_list, output_df_pd
|
802
842
|
# otherwise, use the sklearn estimator's output
|
803
843
|
else:
|
804
|
-
|
844
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
845
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
805
846
|
|
806
847
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
807
848
|
@telemetry.send_api_usage_telemetry(
|
@@ -849,7 +890,7 @@ class XGBClassifier(BaseTransformer):
|
|
849
890
|
drop_input_cols=self._drop_input_cols,
|
850
891
|
expected_output_cols_type="float",
|
851
892
|
)
|
852
|
-
expected_output_cols = self.
|
893
|
+
expected_output_cols, _ = self._align_expected_output(
|
853
894
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
854
895
|
)
|
855
896
|
|
@@ -917,7 +958,7 @@ class XGBClassifier(BaseTransformer):
|
|
917
958
|
drop_input_cols=self._drop_input_cols,
|
918
959
|
expected_output_cols_type="float",
|
919
960
|
)
|
920
|
-
expected_output_cols = self.
|
961
|
+
expected_output_cols, _ = self._align_expected_output(
|
921
962
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
922
963
|
)
|
923
964
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -980,7 +1021,7 @@ class XGBClassifier(BaseTransformer):
|
|
980
1021
|
drop_input_cols=self._drop_input_cols,
|
981
1022
|
expected_output_cols_type="float",
|
982
1023
|
)
|
983
|
-
expected_output_cols = self.
|
1024
|
+
expected_output_cols, _ = self._align_expected_output(
|
984
1025
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
985
1026
|
)
|
986
1027
|
|
@@ -1045,7 +1086,7 @@ class XGBClassifier(BaseTransformer):
|
|
1045
1086
|
drop_input_cols = self._drop_input_cols,
|
1046
1087
|
expected_output_cols_type="float",
|
1047
1088
|
)
|
1048
|
-
expected_output_cols = self.
|
1089
|
+
expected_output_cols, _ = self._align_expected_output(
|
1049
1090
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
1050
1091
|
)
|
1051
1092
|
|
@@ -1110,7 +1151,7 @@ class XGBClassifier(BaseTransformer):
|
|
1110
1151
|
transform_kwargs = dict(
|
1111
1152
|
session=dataset._session,
|
1112
1153
|
dependencies=self._deps,
|
1113
|
-
score_sproc_imports=['xgboost'],
|
1154
|
+
score_sproc_imports=['xgboost', 'sklearn'],
|
1114
1155
|
)
|
1115
1156
|
elif isinstance(dataset, pd.DataFrame):
|
1116
1157
|
# pandas_handler.score() does not require any extra kwargs.
|
@@ -4,18 +4,17 @@
|
|
4
4
|
#
|
5
5
|
import inspect
|
6
6
|
import os
|
7
|
-
import
|
8
|
-
from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
|
9
|
-
from typing_extensions import TypeGuard
|
7
|
+
from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
|
10
8
|
from uuid import uuid4
|
11
9
|
|
12
10
|
import cloudpickle as cp
|
13
|
-
import pandas as pd
|
14
11
|
import numpy as np
|
12
|
+
import pandas as pd
|
15
13
|
from numpy import typing as npt
|
16
14
|
|
17
15
|
|
18
16
|
import numpy
|
17
|
+
import sklearn
|
19
18
|
import xgboost
|
20
19
|
from sklearn.utils.metaestimators import available_if
|
21
20
|
|
@@ -23,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
|
|
23
22
|
from snowflake.ml._internal import telemetry
|
24
23
|
from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
|
25
24
|
from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
|
26
|
-
from snowflake.ml._internal.utils import
|
25
|
+
from snowflake.ml._internal.utils import identifier
|
27
26
|
from snowflake.snowpark import DataFrame, Session
|
28
27
|
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
|
29
28
|
from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
|
30
29
|
from snowflake.ml.modeling._internal.transformer_protocols import (
|
31
|
-
ModelTransformHandlers,
|
32
30
|
BatchInferenceKwargsTypedDict,
|
33
31
|
ScoreKwargsTypedDict
|
34
32
|
)
|
@@ -361,7 +359,7 @@ class XGBRegressor(BaseTransformer):
|
|
361
359
|
self.set_sample_weight_col(sample_weight_col)
|
362
360
|
self._use_external_memory_version = use_external_memory_version
|
363
361
|
self._batch_size = batch_size
|
364
|
-
deps: Set[str] = set([f'numpy=={np.__version__}', f'xgboost=={xgboost.__version__}', f'cloudpickle=={cp.__version__}'])
|
362
|
+
deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'xgboost=={xgboost.__version__}', f'cloudpickle=={cp.__version__}'])
|
365
363
|
|
366
364
|
self._deps = list(deps)
|
367
365
|
|
@@ -694,12 +692,23 @@ class XGBRegressor(BaseTransformer):
|
|
694
692
|
autogenerated=self._autogenerated,
|
695
693
|
subproject=_SUBPROJECT,
|
696
694
|
)
|
697
|
-
|
698
|
-
|
699
|
-
expected_output_cols_list=(
|
700
|
-
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
701
|
-
),
|
695
|
+
expected_output_cols = (
|
696
|
+
self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
|
702
697
|
)
|
698
|
+
if isinstance(dataset, DataFrame):
|
699
|
+
expected_output_cols, example_output_pd_df = self._align_expected_output(
|
700
|
+
"fit_predict", dataset, expected_output_cols, output_cols_prefix
|
701
|
+
)
|
702
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
703
|
+
drop_input_cols=self._drop_input_cols,
|
704
|
+
expected_output_cols_list=expected_output_cols,
|
705
|
+
example_output_pd_df=example_output_pd_df,
|
706
|
+
)
|
707
|
+
else:
|
708
|
+
output_result, fitted_estimator = model_trainer.train_fit_predict(
|
709
|
+
drop_input_cols=self._drop_input_cols,
|
710
|
+
expected_output_cols_list=expected_output_cols,
|
711
|
+
)
|
703
712
|
self._sklearn_object = fitted_estimator
|
704
713
|
self._is_fitted = True
|
705
714
|
return output_result
|
@@ -722,6 +731,7 @@ class XGBRegressor(BaseTransformer):
|
|
722
731
|
"""
|
723
732
|
self._infer_input_output_cols(dataset)
|
724
733
|
super()._check_dataset_type(dataset)
|
734
|
+
|
725
735
|
model_trainer = ModelTrainerBuilder.build_fit_transform(
|
726
736
|
estimator=self._sklearn_object,
|
727
737
|
dataset=dataset,
|
@@ -778,12 +788,41 @@ class XGBRegressor(BaseTransformer):
|
|
778
788
|
|
779
789
|
return rv
|
780
790
|
|
781
|
-
def
|
782
|
-
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
|
783
|
-
) -> List[str]:
|
791
|
+
def _align_expected_output(
|
792
|
+
self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
|
793
|
+
) -> Tuple[List[str], pd.DataFrame]:
|
794
|
+
""" Run 1 line of data with the desired method, and return one tuple that consists of the output column names
|
795
|
+
and output dataframe with 1 line.
|
796
|
+
If the method is fit_predict, run 2 lines of data.
|
797
|
+
"""
|
784
798
|
# in case the inferred output column names dimension is different
|
785
799
|
# we use one line of snowpark dataframe and put it into sklearn estimator using pandas
|
786
|
-
|
800
|
+
|
801
|
+
# For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
|
802
|
+
# so change the minimum of number of rows to 2
|
803
|
+
num_examples = 2
|
804
|
+
statement_params = telemetry.get_function_usage_statement_params(
|
805
|
+
project=_PROJECT,
|
806
|
+
subproject=_SUBPROJECT,
|
807
|
+
function_name=telemetry.get_statement_params_full_func_name(
|
808
|
+
inspect.currentframe(), XGBRegressor.__class__.__name__
|
809
|
+
),
|
810
|
+
api_calls=[Session.call],
|
811
|
+
custom_tags={"autogen": True} if self._autogenerated else None,
|
812
|
+
)
|
813
|
+
if output_cols_prefix == "fit_predict_":
|
814
|
+
if hasattr(self._sklearn_object, "n_clusters"):
|
815
|
+
# cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
|
816
|
+
num_examples = self._sklearn_object.n_clusters
|
817
|
+
elif hasattr(self._sklearn_object, "min_samples"):
|
818
|
+
# OPTICS default min_samples 5, which requires at least 5 lines of data
|
819
|
+
num_examples = self._sklearn_object.min_samples
|
820
|
+
elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
|
821
|
+
# LocalOutlierFactor expects n_neighbors <= n_samples
|
822
|
+
num_examples = self._sklearn_object.n_neighbors
|
823
|
+
sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
|
824
|
+
else:
|
825
|
+
sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
|
787
826
|
|
788
827
|
# Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
|
789
828
|
# seen during the fit.
|
@@ -795,12 +834,14 @@ class XGBRegressor(BaseTransformer):
|
|
795
834
|
output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
|
796
835
|
if self.sample_weight_col:
|
797
836
|
output_df_columns_set -= set(self.sample_weight_col)
|
837
|
+
|
798
838
|
# if the dimension of inferred output column names is correct; use it
|
799
839
|
if len(expected_output_cols_list) == len(output_df_columns_set):
|
800
|
-
return expected_output_cols_list
|
840
|
+
return expected_output_cols_list, output_df_pd
|
801
841
|
# otherwise, use the sklearn estimator's output
|
802
842
|
else:
|
803
|
-
|
843
|
+
expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
|
844
|
+
return expected_output_cols_list, output_df_pd[expected_output_cols_list]
|
804
845
|
|
805
846
|
@available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
|
806
847
|
@telemetry.send_api_usage_telemetry(
|
@@ -846,7 +887,7 @@ class XGBRegressor(BaseTransformer):
|
|
846
887
|
drop_input_cols=self._drop_input_cols,
|
847
888
|
expected_output_cols_type="float",
|
848
889
|
)
|
849
|
-
expected_output_cols = self.
|
890
|
+
expected_output_cols, _ = self._align_expected_output(
|
850
891
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
851
892
|
)
|
852
893
|
|
@@ -912,7 +953,7 @@ class XGBRegressor(BaseTransformer):
|
|
912
953
|
drop_input_cols=self._drop_input_cols,
|
913
954
|
expected_output_cols_type="float",
|
914
955
|
)
|
915
|
-
expected_output_cols = self.
|
956
|
+
expected_output_cols, _ = self._align_expected_output(
|
916
957
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
917
958
|
)
|
918
959
|
elif isinstance(dataset, pd.DataFrame):
|
@@ -975,7 +1016,7 @@ class XGBRegressor(BaseTransformer):
|
|
975
1016
|
drop_input_cols=self._drop_input_cols,
|
976
1017
|
expected_output_cols_type="float",
|
977
1018
|
)
|
978
|
-
expected_output_cols = self.
|
1019
|
+
expected_output_cols, _ = self._align_expected_output(
|
979
1020
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
980
1021
|
)
|
981
1022
|
|
@@ -1040,7 +1081,7 @@ class XGBRegressor(BaseTransformer):
|
|
1040
1081
|
drop_input_cols = self._drop_input_cols,
|
1041
1082
|
expected_output_cols_type="float",
|
1042
1083
|
)
|
1043
|
-
expected_output_cols = self.
|
1084
|
+
expected_output_cols, _ = self._align_expected_output(
|
1044
1085
|
inference_method, dataset, expected_output_cols, output_cols_prefix
|
1045
1086
|
)
|
1046
1087
|
|
@@ -1105,7 +1146,7 @@ class XGBRegressor(BaseTransformer):
|
|
1105
1146
|
transform_kwargs = dict(
|
1106
1147
|
session=dataset._session,
|
1107
1148
|
dependencies=self._deps,
|
1108
|
-
score_sproc_imports=['xgboost'],
|
1149
|
+
score_sproc_imports=['xgboost', 'sklearn'],
|
1109
1150
|
)
|
1110
1151
|
elif isinstance(dataset, pd.DataFrame):
|
1111
1152
|
# pandas_handler.score() does not require any extra kwargs.
|