snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_classify_text.py +2 -2
- snowflake/cortex/_embed_text_1024.py +37 -0
- snowflake/cortex/_embed_text_768.py +37 -0
- snowflake/cortex/_extract_answer.py +2 -2
- snowflake/cortex/_sentiment.py +2 -2
- snowflake/cortex/_summarize.py +2 -2
- snowflake/cortex/_translate.py +2 -2
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/env_utils.py +5 -5
- snowflake/ml/_internal/exceptions/error_codes.py +2 -0
- snowflake/ml/_internal/telemetry.py +142 -20
- snowflake/ml/_internal/utils/db_utils.py +50 -0
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/service_logger.py +63 -0
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +26 -2
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
- snowflake/ml/data/data_connector.py +33 -7
- snowflake/ml/data/ingestor_utils.py +20 -10
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/feature_store/access_manager.py +3 -3
- snowflake/ml/feature_store/feature_store.py +60 -19
- snowflake/ml/feature_store/feature_view.py +84 -30
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/fileset/stage_fs.py +2 -1
- snowflake/ml/lineage/lineage_node.py +7 -2
- snowflake/ml/model/__init__.py +1 -2
- snowflake/ml/model/_client/model/model_version_impl.py +96 -12
- snowflake/ml/model/_client/ops/model_ops.py +124 -6
- snowflake/ml/model/_client/ops/service_ops.py +309 -9
- snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
- snowflake/ml/model/_client/sql/_base.py +5 -0
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +9 -5
- snowflake/ml/model/_client/sql/service.py +121 -20
- snowflake/ml/model/_model_composer/model_composer.py +11 -39
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
- snowflake/ml/model/_packager/model_env/model_env.py +4 -38
- snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
- snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
- snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
- snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
- snowflake/ml/model/_packager/model_packager.py +14 -8
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/type_hints.py +12 -145
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
- snowflake/ml/modeling/cluster/birch.py +61 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
- snowflake/ml/modeling/cluster/dbscan.py +61 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
- snowflake/ml/modeling/cluster/k_means.py +61 -21
- snowflake/ml/modeling/cluster/mean_shift.py +61 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
- snowflake/ml/modeling/cluster/optics.py +61 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
- snowflake/ml/modeling/compose/column_transformer.py +61 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
- snowflake/ml/modeling/covariance/oas.py +61 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/pca.py +61 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
- snowflake/ml/modeling/impute/knn_imputer.py +61 -21
- snowflake/ml/modeling/impute/missing_indicator.py +61 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/lars.py +61 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/perceptron.py +61 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ridge.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
- snowflake/ml/modeling/manifold/isomap.py +61 -21
- snowflake/ml/modeling/manifold/mds.py +61 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
- snowflake/ml/modeling/manifold/tsne.py +61 -21
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
- snowflake/ml/modeling/metrics/ranking.py +0 -3
- snowflake/ml/modeling/metrics/regression.py +0 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +1 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
- snowflake/ml/modeling/svm/linear_svc.py +61 -21
- snowflake/ml/modeling/svm/linear_svr.py +61 -21
- snowflake/ml/modeling/svm/nu_svc.py +61 -21
- snowflake/ml/modeling/svm/nu_svr.py +61 -21
- snowflake/ml/modeling/svm/svc.py +61 -21
- snowflake/ml/modeling/svm/svr.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
- snowflake/ml/monitoring/_client/model_monitor.py +126 -0
- snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
- snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
- snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
- snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
- snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
- snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
- snowflake/ml/monitoring/entities/output_score_type.py +90 -0
- snowflake/ml/registry/_manager/model_manager.py +4 -0
- snowflake/ml/registry/registry.py +166 -8
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
- snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
- snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
- snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
- snowflake/ml/_internal/utils/session_token_manager.py +0 -46
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
- snowflake/ml/_internal/utils/uri.py +0 -77
- snowflake/ml/data/torch_dataset.py +0 -33
- snowflake/ml/model/_api.py +0 -568
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
- snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
- snowflake/ml/model/deploy_platforms.py +0 -6
- snowflake/ml/model/models/llm.py +0 -104
- snowflake/ml/monitoring/monitor.py +0 -203
- snowflake/ml/registry/_initial_schema.py +0 -142
- snowflake/ml/registry/_schema.py +0 -82
- snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
- snowflake/ml/registry/_schema_version_manager.py +0 -163
- snowflake/ml/registry/model_registry.py +0 -2048
- snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,5 @@
|
|
1
1
|
import os
|
2
|
+
import warnings
|
2
3
|
from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union, cast, final
|
3
4
|
|
4
5
|
import cloudpickle
|
@@ -6,22 +7,21 @@ import numpy as np
|
|
6
7
|
import pandas as pd
|
7
8
|
from typing_extensions import TypeGuard, Unpack
|
8
9
|
|
9
|
-
import snowflake.snowpark.dataframe as sp_df
|
10
10
|
from snowflake.ml._internal import type_utils
|
11
11
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
12
12
|
from snowflake.ml.model._packager.model_env import model_env
|
13
|
-
from snowflake.ml.model._packager.model_handlers import
|
13
|
+
from snowflake.ml.model._packager.model_handlers import (
|
14
|
+
_base,
|
15
|
+
_utils as handlers_utils,
|
16
|
+
model_objective_utils,
|
17
|
+
)
|
14
18
|
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
15
19
|
from snowflake.ml.model._packager.model_meta import (
|
16
20
|
model_blob_meta,
|
17
21
|
model_meta as model_meta_api,
|
18
22
|
model_meta_schema,
|
19
23
|
)
|
20
|
-
from snowflake.ml.model._signatures import
|
21
|
-
numpy_handler,
|
22
|
-
snowpark_handler,
|
23
|
-
utils as model_signature_utils,
|
24
|
-
)
|
24
|
+
from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
|
25
25
|
|
26
26
|
if TYPE_CHECKING:
|
27
27
|
import sklearn.base
|
@@ -40,28 +40,14 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
40
40
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
41
41
|
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
|
42
42
|
|
43
|
-
DEFAULT_TARGET_METHODS = [
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
if isinstance(model, sklearn.pipeline.Pipeline):
|
53
|
-
return model_meta_schema.ModelObjective.UNKNOWN
|
54
|
-
if is_regressor(model):
|
55
|
-
return model_meta_schema.ModelObjective.REGRESSION
|
56
|
-
if is_classifier(model):
|
57
|
-
classes_list = getattr(model, "classes_", [])
|
58
|
-
num_classes = getattr(model, "n_classes_", None) or len(classes_list)
|
59
|
-
if isinstance(num_classes, int):
|
60
|
-
if num_classes > 2:
|
61
|
-
return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
|
62
|
-
return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
|
63
|
-
return model_meta_schema.ModelObjective.UNKNOWN
|
64
|
-
return model_meta_schema.ModelObjective.UNKNOWN
|
43
|
+
DEFAULT_TARGET_METHODS = [
|
44
|
+
"predict",
|
45
|
+
"transform",
|
46
|
+
"predict_proba",
|
47
|
+
"predict_log_proba",
|
48
|
+
"decision_function",
|
49
|
+
]
|
50
|
+
EXPLAIN_TARGET_METHODS = ["predict", "predict_proba", "predict_log_proba"]
|
65
51
|
|
66
52
|
@classmethod
|
67
53
|
def can_handle(
|
@@ -106,32 +92,17 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
106
92
|
is_sub_model: Optional[bool] = False,
|
107
93
|
**kwargs: Unpack[model_types.SKLModelSaveOptions],
|
108
94
|
) -> None:
|
109
|
-
|
95
|
+
# setting None by default to distinguish if users did not set it
|
96
|
+
enable_explainability = kwargs.get("enable_explainability", None)
|
110
97
|
|
111
98
|
import sklearn.base
|
112
99
|
import sklearn.pipeline
|
113
100
|
|
114
101
|
assert isinstance(model, sklearn.base.BaseEstimator) or isinstance(model, sklearn.pipeline.Pipeline)
|
115
|
-
|
116
|
-
enable_explainability = kwargs.get("enable_explainability", False)
|
117
102
|
if enable_explainability:
|
118
|
-
#
|
119
|
-
if sample_input_data is None
|
120
|
-
|
121
|
-
):
|
122
|
-
raise ValueError(
|
123
|
-
"Sample input data is required to enable explainability. Currently we only support this for "
|
124
|
-
+ "`pandas.DataFrame` and `snowflake.snowpark.dataframe.DataFrame`."
|
125
|
-
)
|
126
|
-
sample_input_data_pandas = (
|
127
|
-
sample_input_data
|
128
|
-
if isinstance(sample_input_data, pd.DataFrame)
|
129
|
-
else snowpark_handler.SnowparkDataFrameHandler.convert_to_df(sample_input_data)
|
130
|
-
)
|
131
|
-
data_blob_path = os.path.join(model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR)
|
132
|
-
os.makedirs(data_blob_path, exist_ok=True)
|
133
|
-
with open(os.path.join(data_blob_path, name + cls.BG_DATA_FILE_SUFFIX), "wb") as f:
|
134
|
-
sample_input_data_pandas.to_parquet(f)
|
103
|
+
# if users set it explicitly but no sample_input_data then error out
|
104
|
+
if sample_input_data is None:
|
105
|
+
raise ValueError("Sample input data is required to enable explainability.")
|
135
106
|
|
136
107
|
if not is_sub_model:
|
137
108
|
target_methods = handlers_utils.get_target_methods(
|
@@ -141,7 +112,8 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
141
112
|
)
|
142
113
|
|
143
114
|
def get_prediction(
|
144
|
-
target_method_name: str,
|
115
|
+
target_method_name: str,
|
116
|
+
sample_input_data: model_types.SupportedLocalDataType,
|
145
117
|
) -> model_types.SupportedLocalDataType:
|
146
118
|
if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray)):
|
147
119
|
sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
|
@@ -159,15 +131,40 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
159
131
|
get_prediction_fn=get_prediction,
|
160
132
|
)
|
161
133
|
|
134
|
+
explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
|
135
|
+
|
136
|
+
background_data = handlers_utils.get_explainability_supported_background(
|
137
|
+
sample_input_data, model_meta, explain_target_method
|
138
|
+
)
|
139
|
+
|
140
|
+
model_task_and_output_type = model_objective_utils.get_model_task_and_output_type(model)
|
141
|
+
model_meta.task = model_task_and_output_type.task
|
142
|
+
|
143
|
+
# if users did not ask then we enable if we have background data
|
144
|
+
if enable_explainability is None:
|
145
|
+
if background_data is None:
|
146
|
+
warnings.warn(
|
147
|
+
"sample_input_data should be provided to enable explainability by default",
|
148
|
+
category=UserWarning,
|
149
|
+
stacklevel=1,
|
150
|
+
)
|
151
|
+
enable_explainability = False
|
152
|
+
else:
|
153
|
+
enable_explainability = True
|
162
154
|
if enable_explainability:
|
163
|
-
|
164
|
-
|
165
|
-
|
155
|
+
handlers_utils.save_background_data(
|
156
|
+
model_blobs_dir_path,
|
157
|
+
cls.EXPLAIN_ARTIFACTS_DIR,
|
158
|
+
cls.BG_DATA_FILE_SUFFIX,
|
159
|
+
name,
|
160
|
+
background_data,
|
161
|
+
)
|
162
|
+
|
166
163
|
model_meta = handlers_utils.add_explain_method_signature(
|
167
164
|
model_meta=model_meta,
|
168
165
|
explain_method="explain",
|
169
|
-
target_method=
|
170
|
-
output_return_type=output_type,
|
166
|
+
target_method=explain_target_method,
|
167
|
+
output_return_type=model_task_and_output_type.output_type,
|
171
168
|
)
|
172
169
|
|
173
170
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
@@ -184,13 +181,12 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
184
181
|
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
185
182
|
|
186
183
|
if enable_explainability:
|
187
|
-
model_meta.env.include_if_absent(
|
188
|
-
|
189
|
-
check_local_version=True,
|
190
|
-
)
|
184
|
+
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
|
185
|
+
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
191
186
|
|
192
187
|
model_meta.env.include_if_absent(
|
193
|
-
[model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn")],
|
188
|
+
[model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn")],
|
189
|
+
check_local_version=True,
|
194
190
|
)
|
195
191
|
|
196
192
|
@classmethod
|
@@ -1,20 +1,27 @@
|
|
1
1
|
import os
|
2
2
|
import warnings
|
3
|
-
from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
|
3
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast, final
|
4
4
|
|
5
5
|
import cloudpickle
|
6
6
|
import numpy as np
|
7
7
|
import pandas as pd
|
8
|
+
from packaging import version
|
8
9
|
from typing_extensions import TypeGuard, Unpack
|
9
10
|
|
10
11
|
from snowflake.ml._internal import type_utils
|
12
|
+
from snowflake.ml._internal.exceptions import exceptions
|
11
13
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
12
14
|
from snowflake.ml.model._packager.model_env import model_env
|
13
|
-
from snowflake.ml.model._packager.model_handlers import
|
15
|
+
from snowflake.ml.model._packager.model_handlers import (
|
16
|
+
_base,
|
17
|
+
_utils as handlers_utils,
|
18
|
+
model_objective_utils,
|
19
|
+
)
|
14
20
|
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
15
21
|
from snowflake.ml.model._packager.model_meta import (
|
16
22
|
model_blob_meta,
|
17
23
|
model_meta as model_meta_api,
|
24
|
+
model_meta_schema,
|
18
25
|
)
|
19
26
|
from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
|
20
27
|
|
@@ -36,6 +43,8 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
36
43
|
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
|
37
44
|
|
38
45
|
DEFAULT_TARGET_METHODS = ["predict", "transform", "predict_proba", "predict_log_proba", "decision_function"]
|
46
|
+
EXPLAIN_TARGET_METHODS = ["predict", "predict_proba", "predict_log_proba"]
|
47
|
+
|
39
48
|
IS_AUTO_SIGNATURE = True
|
40
49
|
|
41
50
|
@classmethod
|
@@ -62,6 +71,60 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
62
71
|
|
63
72
|
return cast("BaseEstimator", model)
|
64
73
|
|
74
|
+
@classmethod
|
75
|
+
def _get_local_version_package(cls, pkg_name: str) -> Optional[version.Version]:
|
76
|
+
from importlib import metadata as importlib_metadata
|
77
|
+
|
78
|
+
from packaging import version
|
79
|
+
|
80
|
+
local_version = None
|
81
|
+
|
82
|
+
try:
|
83
|
+
local_dist = importlib_metadata.distribution(pkg_name)
|
84
|
+
local_version = version.parse(local_dist.version)
|
85
|
+
except importlib_metadata.PackageNotFoundError:
|
86
|
+
pass
|
87
|
+
|
88
|
+
return local_version
|
89
|
+
|
90
|
+
@classmethod
|
91
|
+
def _can_support_xgb(cls, enable_explainability: Optional[bool]) -> bool:
|
92
|
+
|
93
|
+
local_xgb_version = cls._get_local_version_package("xgboost")
|
94
|
+
|
95
|
+
if local_xgb_version and local_xgb_version >= version.parse("2.1.0"):
|
96
|
+
if enable_explainability:
|
97
|
+
warnings.warn(
|
98
|
+
f"This version of xgboost {local_xgb_version} does not work with shap 0.42.1."
|
99
|
+
+ "If you want model explanations, lower the xgboost version to <2.1.0.",
|
100
|
+
category=UserWarning,
|
101
|
+
stacklevel=1,
|
102
|
+
)
|
103
|
+
return False
|
104
|
+
return True
|
105
|
+
|
106
|
+
@classmethod
|
107
|
+
def _get_supported_object_for_explainability(
|
108
|
+
cls, estimator: "BaseEstimator", enable_explainability: Optional[bool]
|
109
|
+
) -> Any:
|
110
|
+
from snowflake.ml.modeling import pipeline as snowml_pipeline
|
111
|
+
|
112
|
+
# handle pipeline objects separately
|
113
|
+
if isinstance(estimator, snowml_pipeline.Pipeline): # type: ignore[attr-defined]
|
114
|
+
return None
|
115
|
+
|
116
|
+
methods = ["to_xgboost", "to_lightgbm", "to_sklearn"]
|
117
|
+
for method_name in methods:
|
118
|
+
if hasattr(estimator, method_name):
|
119
|
+
try:
|
120
|
+
result = getattr(estimator, method_name)()
|
121
|
+
if method_name == "to_xgboost" and not cls._can_support_xgb(enable_explainability):
|
122
|
+
return None
|
123
|
+
return result
|
124
|
+
except exceptions.SnowflakeMLException:
|
125
|
+
pass # Do nothing and continue to the next method
|
126
|
+
return None
|
127
|
+
|
65
128
|
@classmethod
|
66
129
|
def save_model(
|
67
130
|
cls,
|
@@ -73,9 +136,8 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
73
136
|
is_sub_model: Optional[bool] = False,
|
74
137
|
**kwargs: Unpack[model_types.SNOWModelSaveOptions],
|
75
138
|
) -> None:
|
76
|
-
|
77
|
-
|
78
|
-
raise NotImplementedError("Explainability is not supported for Snowpark ML model.")
|
139
|
+
|
140
|
+
enable_explainability = kwargs.get("enable_explainability", None)
|
79
141
|
|
80
142
|
from snowflake.ml.modeling.framework.base import BaseEstimator
|
81
143
|
|
@@ -83,9 +145,9 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
83
145
|
# Pipeline is inherited from BaseEstimator, so no need to add one more check
|
84
146
|
|
85
147
|
if not is_sub_model:
|
86
|
-
if
|
148
|
+
if model_meta.signatures:
|
87
149
|
warnings.warn(
|
88
|
-
"
|
150
|
+
"Providing model signature for Snowpark ML "
|
89
151
|
+ "Modeling model is not required. Model signature will automatically be inferred during fitting. ",
|
90
152
|
UserWarning,
|
91
153
|
stacklevel=2,
|
@@ -105,6 +167,35 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
105
167
|
raise ValueError(f"Target method {method_name} does not exist in the model.")
|
106
168
|
model_meta.signatures = temp_model_signature_dict
|
107
169
|
|
170
|
+
if enable_explainability or enable_explainability is None:
|
171
|
+
python_base_obj = cls._get_supported_object_for_explainability(model, enable_explainability)
|
172
|
+
if python_base_obj is None:
|
173
|
+
if enable_explainability: # if user set enable_explainability to True, throw error else silently skip
|
174
|
+
raise ValueError(
|
175
|
+
"Explain only supported for xgboost, lightgbm and sklearn (not pipeline) Snowpark ML models."
|
176
|
+
)
|
177
|
+
# set None to False so we don't include shap in the environment
|
178
|
+
enable_explainability = False
|
179
|
+
else:
|
180
|
+
model_task_and_output_type = model_objective_utils.get_model_task_and_output_type(python_base_obj)
|
181
|
+
model_meta.task = model_task_and_output_type.task
|
182
|
+
explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
|
183
|
+
model_meta = handlers_utils.add_explain_method_signature(
|
184
|
+
model_meta=model_meta,
|
185
|
+
explain_method="explain",
|
186
|
+
target_method=explain_target_method,
|
187
|
+
output_return_type=model_task_and_output_type.output_type,
|
188
|
+
)
|
189
|
+
enable_explainability = True
|
190
|
+
|
191
|
+
background_data = handlers_utils.get_explainability_supported_background(
|
192
|
+
sample_input_data, model_meta, explain_target_method
|
193
|
+
)
|
194
|
+
if background_data is not None:
|
195
|
+
handlers_utils.save_background_data(
|
196
|
+
model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data
|
197
|
+
)
|
198
|
+
|
108
199
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
109
200
|
os.makedirs(model_blob_path, exist_ok=True)
|
110
201
|
with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
|
@@ -122,7 +213,29 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
122
213
|
model_dependencies = model._get_dependencies()
|
123
214
|
for dep in model_dependencies:
|
124
215
|
pkg_name = dep.split("==")[0]
|
125
|
-
|
216
|
+
if pkg_name != "xgboost":
|
217
|
+
_include_if_absent_pkgs.append(model_env.ModelDependency(requirement=pkg_name, pip_name=pkg_name))
|
218
|
+
continue
|
219
|
+
|
220
|
+
local_xgb_version = cls._get_local_version_package("xgboost")
|
221
|
+
if local_xgb_version and local_xgb_version >= version.parse("2.0.0") and enable_explainability:
|
222
|
+
model_meta.env.include_if_absent(
|
223
|
+
[
|
224
|
+
model_env.ModelDependency(requirement="xgboost==2.0.*", pip_name="xgboost"),
|
225
|
+
],
|
226
|
+
check_local_version=False,
|
227
|
+
)
|
228
|
+
else:
|
229
|
+
model_meta.env.include_if_absent(
|
230
|
+
[
|
231
|
+
model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
|
232
|
+
],
|
233
|
+
check_local_version=True,
|
234
|
+
)
|
235
|
+
|
236
|
+
if enable_explainability:
|
237
|
+
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
|
238
|
+
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
126
239
|
model_meta.env.include_if_absent(_include_if_absent_pkgs, check_local_version=True)
|
127
240
|
|
128
241
|
@classmethod
|
@@ -163,6 +276,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
163
276
|
raw_model: "BaseEstimator",
|
164
277
|
signature: model_signature.ModelSignature,
|
165
278
|
target_method: str,
|
279
|
+
background_data: Optional[pd.DataFrame] = None,
|
166
280
|
) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
|
167
281
|
@custom_model.inference_api
|
168
282
|
def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
|
@@ -177,11 +291,29 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
177
291
|
|
178
292
|
return model_signature_utils.rename_pandas_df(df, signature.outputs)
|
179
293
|
|
294
|
+
@custom_model.inference_api
|
295
|
+
def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
|
296
|
+
import shap
|
297
|
+
|
298
|
+
methods = ["to_xgboost", "to_lightgbm", "to_sklearn"]
|
299
|
+
for method_name in methods:
|
300
|
+
try:
|
301
|
+
base_model = getattr(raw_model, method_name)()
|
302
|
+
explainer = shap.Explainer(base_model, masker=background_data)
|
303
|
+
df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
|
304
|
+
return model_signature_utils.rename_pandas_df(df, signature.outputs)
|
305
|
+
except exceptions.SnowflakeMLException:
|
306
|
+
pass # Do nothing and continue to the next method
|
307
|
+
raise ValueError("The model must be an xgboost, lightgbm or sklearn (not pipeline) estimator.")
|
308
|
+
|
309
|
+
if target_method == "explain":
|
310
|
+
return explain_fn
|
311
|
+
|
180
312
|
return fn
|
181
313
|
|
182
314
|
type_method_dict = {}
|
183
315
|
for target_method_name, sig in model_meta.signatures.items():
|
184
|
-
type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
|
316
|
+
type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name, background_data)
|
185
317
|
|
186
318
|
_SnowMLModel = type(
|
187
319
|
"_SnowMLModel",
|
@@ -111,7 +111,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
|
|
111
111
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
112
112
|
os.makedirs(model_blob_path, exist_ok=True)
|
113
113
|
with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
|
114
|
-
torch.jit.save(model, f) # type:ignore[attr-defined]
|
114
|
+
torch.jit.save(model, f) # type:ignore[no-untyped-call, attr-defined]
|
115
115
|
base_meta = model_blob_meta.ModelBlobMeta(
|
116
116
|
name=name,
|
117
117
|
model_type=cls.HANDLER_TYPE,
|
@@ -141,7 +141,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
|
|
141
141
|
model_blob_metadata = model_blobs_metadata[name]
|
142
142
|
model_blob_filename = model_blob_metadata.path
|
143
143
|
with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f:
|
144
|
-
m = torch.jit.load( # type:ignore[attr-defined]
|
144
|
+
m = torch.jit.load( # type:ignore[no-untyped-call, attr-defined]
|
145
145
|
f, map_location="cuda" if kwargs.get("use_gpu", False) else "cpu"
|
146
146
|
)
|
147
147
|
assert isinstance(m, torch.jit.ScriptModule) # type:ignore[attr-defined]
|
@@ -1,6 +1,7 @@
|
|
1
1
|
# mypy: disable-error-code="import"
|
2
|
-
import json
|
3
2
|
import os
|
3
|
+
import warnings
|
4
|
+
from importlib import metadata as importlib_metadata
|
4
5
|
from typing import (
|
5
6
|
TYPE_CHECKING,
|
6
7
|
Any,
|
@@ -15,12 +16,17 @@ from typing import (
|
|
15
16
|
|
16
17
|
import numpy as np
|
17
18
|
import pandas as pd
|
19
|
+
from packaging import version
|
18
20
|
from typing_extensions import TypeGuard, Unpack
|
19
21
|
|
20
22
|
from snowflake.ml._internal import type_utils
|
21
23
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
22
24
|
from snowflake.ml.model._packager.model_env import model_env
|
23
|
-
from snowflake.ml.model._packager.model_handlers import
|
25
|
+
from snowflake.ml.model._packager.model_handlers import (
|
26
|
+
_base,
|
27
|
+
_utils as handlers_utils,
|
28
|
+
model_objective_utils,
|
29
|
+
)
|
24
30
|
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
25
31
|
from snowflake.ml.model._packager.model_meta import (
|
26
32
|
model_blob_meta,
|
@@ -47,41 +53,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
47
53
|
|
48
54
|
MODEL_BLOB_FILE_OR_DIR = "model.ubj"
|
49
55
|
DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
|
50
|
-
|
51
|
-
_MULTI_CLASSIFICATION_OBJECTIVE_PREFIX = ["multi:"]
|
52
|
-
_RANKING_OBJECTIVE_PREFIX = ["rank:"]
|
53
|
-
_REGRESSION_OBJECTIVE_PREFIX = ["reg:"]
|
54
|
-
|
55
|
-
@classmethod
|
56
|
-
def get_model_objective(
|
57
|
-
cls, model: Union["xgboost.Booster", "xgboost.XGBModel"]
|
58
|
-
) -> model_meta_schema.ModelObjective:
|
59
|
-
import xgboost
|
60
|
-
|
61
|
-
if isinstance(model, xgboost.XGBClassifier) or isinstance(model, xgboost.XGBRFClassifier):
|
62
|
-
num_classes = handlers_utils.get_num_classes_if_exists(model)
|
63
|
-
if num_classes == 2:
|
64
|
-
return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
|
65
|
-
return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
|
66
|
-
if isinstance(model, xgboost.XGBRegressor) or isinstance(model, xgboost.XGBRFRegressor):
|
67
|
-
return model_meta_schema.ModelObjective.REGRESSION
|
68
|
-
if isinstance(model, xgboost.XGBRanker):
|
69
|
-
return model_meta_schema.ModelObjective.RANKING
|
70
|
-
model_params = json.loads(model.save_config())
|
71
|
-
model_objective = model_params["learner"]["objective"]
|
72
|
-
for classification_objective in cls._BINARY_CLASSIFICATION_OBJECTIVE_PREFIX:
|
73
|
-
if classification_objective in model_objective:
|
74
|
-
return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
|
75
|
-
for classification_objective in cls._MULTI_CLASSIFICATION_OBJECTIVE_PREFIX:
|
76
|
-
if classification_objective in model_objective:
|
77
|
-
return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
|
78
|
-
for ranking_objective in cls._RANKING_OBJECTIVE_PREFIX:
|
79
|
-
if ranking_objective in model_objective:
|
80
|
-
return model_meta_schema.ModelObjective.RANKING
|
81
|
-
for regression_objective in cls._REGRESSION_OBJECTIVE_PREFIX:
|
82
|
-
if regression_objective in model_objective:
|
83
|
-
return model_meta_schema.ModelObjective.REGRESSION
|
84
|
-
return model_meta_schema.ModelObjective.UNKNOWN
|
56
|
+
EXPLAIN_TARGET_METHODS = ["predict", "predict_proba"]
|
85
57
|
|
86
58
|
@classmethod
|
87
59
|
def can_handle(
|
@@ -116,10 +88,29 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
116
88
|
is_sub_model: Optional[bool] = False,
|
117
89
|
**kwargs: Unpack[model_types.XGBModelSaveOptions],
|
118
90
|
) -> None:
|
91
|
+
enable_explainability = kwargs.get("enable_explainability", True)
|
92
|
+
|
119
93
|
import xgboost
|
120
94
|
|
121
95
|
assert isinstance(model, xgboost.Booster) or isinstance(model, xgboost.XGBModel)
|
122
96
|
|
97
|
+
local_xgb_version = None
|
98
|
+
|
99
|
+
try:
|
100
|
+
local_dist = importlib_metadata.distribution("xgboost")
|
101
|
+
local_xgb_version = version.parse(local_dist.version)
|
102
|
+
except importlib_metadata.PackageNotFoundError:
|
103
|
+
pass
|
104
|
+
|
105
|
+
if local_xgb_version and local_xgb_version >= version.parse("2.1.0") and enable_explainability:
|
106
|
+
warnings.warn(
|
107
|
+
f"This version of xgboost {local_xgb_version} does not work with shap 0.42.1."
|
108
|
+
+ "If you want model explanations, lower the xgboost version to <2.1.0.",
|
109
|
+
category=UserWarning,
|
110
|
+
stacklevel=1,
|
111
|
+
)
|
112
|
+
enable_explainability = False
|
113
|
+
|
123
114
|
if not is_sub_model:
|
124
115
|
target_methods = handlers_utils.get_target_methods(
|
125
116
|
model=model,
|
@@ -148,22 +139,35 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
148
139
|
sample_input_data=sample_input_data,
|
149
140
|
get_prediction_fn=get_prediction,
|
150
141
|
)
|
151
|
-
|
152
|
-
model_meta.
|
153
|
-
if
|
154
|
-
output_type = model_signature.DataType.DOUBLE
|
155
|
-
if model_objective == model_meta_schema.ModelObjective.MULTI_CLASSIFICATION:
|
156
|
-
output_type = model_signature.DataType.STRING
|
142
|
+
model_task_and_output = model_objective_utils.get_model_task_and_output_type(model)
|
143
|
+
model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
|
144
|
+
if enable_explainability:
|
157
145
|
model_meta = handlers_utils.add_explain_method_signature(
|
158
146
|
model_meta=model_meta,
|
159
147
|
explain_method="explain",
|
160
148
|
target_method="predict",
|
161
|
-
output_return_type=output_type,
|
149
|
+
output_return_type=model_task_and_output.output_type,
|
162
150
|
)
|
163
151
|
model_meta.function_properties = {
|
164
152
|
"explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
|
165
153
|
}
|
166
154
|
|
155
|
+
explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
|
156
|
+
|
157
|
+
background_data = handlers_utils.get_explainability_supported_background(
|
158
|
+
sample_input_data, model_meta, explain_target_method
|
159
|
+
)
|
160
|
+
if background_data is not None:
|
161
|
+
handlers_utils.save_background_data(
|
162
|
+
model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data
|
163
|
+
)
|
164
|
+
else:
|
165
|
+
warnings.warn(
|
166
|
+
"sample_input_data should be provided for better explainability results",
|
167
|
+
category=UserWarning,
|
168
|
+
stacklevel=1,
|
169
|
+
)
|
170
|
+
|
167
171
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
168
172
|
os.makedirs(model_blob_path, exist_ok=True)
|
169
173
|
model.save_model(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
|
@@ -180,15 +184,26 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
180
184
|
model_meta.env.include_if_absent(
|
181
185
|
[
|
182
186
|
model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn"),
|
183
|
-
model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
|
184
187
|
],
|
185
188
|
check_local_version=True,
|
186
189
|
)
|
187
|
-
if
|
190
|
+
if local_xgb_version and local_xgb_version >= version.parse("2.0.0") and enable_explainability:
|
191
|
+
model_meta.env.include_if_absent(
|
192
|
+
[
|
193
|
+
model_env.ModelDependency(requirement="xgboost==2.0.*", pip_name="xgboost"),
|
194
|
+
],
|
195
|
+
check_local_version=False,
|
196
|
+
)
|
197
|
+
else:
|
188
198
|
model_meta.env.include_if_absent(
|
189
|
-
[
|
199
|
+
[
|
200
|
+
model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
|
201
|
+
],
|
190
202
|
check_local_version=True,
|
191
203
|
)
|
204
|
+
|
205
|
+
if enable_explainability:
|
206
|
+
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
|
192
207
|
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
193
208
|
model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
|
194
209
|
|
@@ -269,7 +284,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
269
284
|
import shap
|
270
285
|
|
271
286
|
explainer = shap.TreeExplainer(raw_model)
|
272
|
-
df =
|
287
|
+
df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
|
273
288
|
return model_signature_utils.rename_pandas_df(df, signature.outputs)
|
274
289
|
|
275
290
|
if target_method == "explain":
|