snowflake-ml-python 1.6.2__py3-none-any.whl → 1.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_classify_text.py +2 -2
- snowflake/cortex/_embed_text_1024.py +37 -0
- snowflake/cortex/_embed_text_768.py +37 -0
- snowflake/cortex/_extract_answer.py +2 -2
- snowflake/cortex/_sentiment.py +2 -2
- snowflake/cortex/_summarize.py +2 -2
- snowflake/cortex/_translate.py +2 -2
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/env_utils.py +5 -5
- snowflake/ml/_internal/exceptions/error_codes.py +2 -0
- snowflake/ml/_internal/utils/db_utils.py +50 -0
- snowflake/ml/_internal/utils/service_logger.py +63 -0
- snowflake/ml/_internal/utils/sql_identifier.py +25 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
- snowflake/ml/data/ingestor_utils.py +20 -10
- snowflake/ml/feature_store/access_manager.py +3 -3
- snowflake/ml/feature_store/feature_store.py +19 -2
- snowflake/ml/feature_store/feature_view.py +82 -28
- snowflake/ml/fileset/stage_fs.py +2 -1
- snowflake/ml/lineage/lineage_node.py +7 -2
- snowflake/ml/model/__init__.py +1 -2
- snowflake/ml/model/_client/model/model_version_impl.py +78 -9
- snowflake/ml/model/_client/ops/model_ops.py +89 -7
- snowflake/ml/model/_client/ops/service_ops.py +200 -91
- snowflake/ml/model/_client/service/model_deployment_spec.py +4 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
- snowflake/ml/model/_client/sql/_base.py +5 -0
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +9 -5
- snowflake/ml/model/_client/sql/service.py +47 -13
- snowflake/ml/model/_model_composer/model_composer.py +11 -41
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +29 -4
- snowflake/ml/model/_packager/model_env/model_env.py +4 -38
- snowflake/ml/model/_packager/model_handlers/_utils.py +106 -32
- snowflake/ml/model/_packager/model_handlers/catboost.py +26 -27
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -3
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +21 -6
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +111 -58
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
- snowflake/ml/model/_packager/model_handlers/sklearn.py +50 -66
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +36 -17
- snowflake/ml/model/_packager/model_handlers/xgboost.py +22 -7
- snowflake/ml/model/_packager/model_meta/model_meta.py +16 -45
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -6
- snowflake/ml/model/_packager/model_packager.py +14 -10
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
- snowflake/ml/model/type_hints.py +11 -152
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +0 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -0
- snowflake/ml/modeling/cluster/affinity_propagation.py +1 -0
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -0
- snowflake/ml/modeling/cluster/birch.py +1 -0
- snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -0
- snowflake/ml/modeling/cluster/dbscan.py +1 -0
- snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -0
- snowflake/ml/modeling/cluster/k_means.py +1 -0
- snowflake/ml/modeling/cluster/mean_shift.py +1 -0
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -0
- snowflake/ml/modeling/cluster/optics.py +1 -0
- snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -0
- snowflake/ml/modeling/cluster/spectral_clustering.py +1 -0
- snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -0
- snowflake/ml/modeling/compose/column_transformer.py +1 -0
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -0
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -0
- snowflake/ml/modeling/covariance/empirical_covariance.py +1 -0
- snowflake/ml/modeling/covariance/graphical_lasso.py +1 -0
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -0
- snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -0
- snowflake/ml/modeling/covariance/min_cov_det.py +1 -0
- snowflake/ml/modeling/covariance/oas.py +1 -0
- snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -0
- snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -0
- snowflake/ml/modeling/decomposition/factor_analysis.py +1 -0
- snowflake/ml/modeling/decomposition/fast_ica.py +1 -0
- snowflake/ml/modeling/decomposition/incremental_pca.py +1 -0
- snowflake/ml/modeling/decomposition/kernel_pca.py +1 -0
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -0
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -0
- snowflake/ml/modeling/decomposition/pca.py +1 -0
- snowflake/ml/modeling/decomposition/sparse_pca.py +1 -0
- snowflake/ml/modeling/decomposition/truncated_svd.py +1 -0
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -0
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -0
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/isolation_forest.py +1 -0
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -0
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -0
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -0
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -0
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -0
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -0
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -0
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -0
- snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -0
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -0
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -0
- snowflake/ml/modeling/impute/iterative_imputer.py +1 -0
- snowflake/ml/modeling/impute/knn_imputer.py +1 -0
- snowflake/ml/modeling/impute/missing_indicator.py +1 -0
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -0
- snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -0
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -0
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -0
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -0
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -0
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -0
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -0
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -0
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -0
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -0
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/lars.py +1 -0
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -0
- snowflake/ml/modeling/linear_model/lasso.py +1 -0
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -0
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -0
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -0
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -0
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -0
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -0
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -0
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -0
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -0
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -0
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -0
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -0
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -0
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/perceptron.py +1 -0
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/ridge.py +1 -0
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -0
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -0
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -0
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -0
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -0
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -0
- snowflake/ml/modeling/manifold/isomap.py +1 -0
- snowflake/ml/modeling/manifold/mds.py +1 -0
- snowflake/ml/modeling/manifold/spectral_embedding.py +1 -0
- snowflake/ml/modeling/manifold/tsne.py +1 -0
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
- snowflake/ml/modeling/metrics/ranking.py +0 -3
- snowflake/ml/modeling/metrics/regression.py +0 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -0
- snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -0
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -0
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -0
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -0
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -0
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -0
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -0
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -0
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -0
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -0
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -0
- snowflake/ml/modeling/neighbors/kernel_density.py +1 -0
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -0
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -0
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -0
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -0
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -0
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -0
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -0
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -0
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -0
- snowflake/ml/modeling/pipeline/pipeline.py +0 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -0
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -0
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -0
- snowflake/ml/modeling/svm/linear_svc.py +1 -0
- snowflake/ml/modeling/svm/linear_svr.py +1 -0
- snowflake/ml/modeling/svm/nu_svc.py +1 -0
- snowflake/ml/modeling/svm/nu_svr.py +1 -0
- snowflake/ml/modeling/svm/svc.py +1 -0
- snowflake/ml/modeling/svm/svr.py +1 -0
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -0
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -0
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -0
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -0
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -0
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -0
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -0
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -0
- snowflake/ml/monitoring/_client/model_monitor.py +126 -0
- snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
- snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
- snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
- snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
- snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
- snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
- snowflake/ml/monitoring/entities/output_score_type.py +90 -0
- snowflake/ml/registry/_manager/model_manager.py +4 -4
- snowflake/ml/registry/registry.py +165 -6
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +24 -9
- {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/RECORD +225 -249
- {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
- snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
- snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
- snowflake/ml/_internal/utils/session_token_manager.py +0 -46
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
- snowflake/ml/_internal/utils/uri.py +0 -77
- snowflake/ml/model/_api.py +0 -568
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
- snowflake/ml/model/_packager/model_handlers/llm.py +0 -269
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
- snowflake/ml/model/deploy_platforms.py +0 -6
- snowflake/ml/model/models/llm.py +0 -106
- snowflake/ml/monitoring/monitor.py +0 -203
- snowflake/ml/registry/_initial_schema.py +0 -142
- snowflake/ml/registry/_schema.py +0 -82
- snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
- snowflake/ml/registry/_schema_version_manager.py +0 -163
- snowflake/ml/registry/model_registry.py +0 -2048
- {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -2,23 +2,67 @@ import json
|
|
2
2
|
from dataclasses import dataclass
|
3
3
|
from typing import TYPE_CHECKING, Any, Union
|
4
4
|
|
5
|
+
from snowflake.ml._internal import type_utils
|
5
6
|
from snowflake.ml.model import model_signature, type_hints
|
6
7
|
from snowflake.ml.model._packager.model_handlers import _utils as handlers_utils
|
7
8
|
|
8
9
|
if TYPE_CHECKING:
|
10
|
+
import catboost
|
9
11
|
import lightgbm
|
12
|
+
import sklearn
|
13
|
+
import sklearn.pipeline
|
10
14
|
import xgboost
|
11
15
|
|
12
16
|
|
13
17
|
@dataclass
|
14
|
-
class
|
15
|
-
|
18
|
+
class ModelTaskAndOutputType:
|
19
|
+
task: type_hints.Task
|
16
20
|
output_type: model_signature.DataType
|
17
21
|
|
18
22
|
|
19
|
-
def
|
23
|
+
def get_task_skl(model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]) -> type_hints.Task:
|
24
|
+
from sklearn.base import is_classifier, is_regressor
|
25
|
+
|
26
|
+
if type_utils.LazyType("sklearn.pipeline.Pipeline").isinstance(model):
|
27
|
+
return type_hints.Task.UNKNOWN
|
28
|
+
if is_regressor(model):
|
29
|
+
return type_hints.Task.TABULAR_REGRESSION
|
30
|
+
if is_classifier(model):
|
31
|
+
classes_list = getattr(model, "classes_", [])
|
32
|
+
num_classes = getattr(model, "n_classes_", None) or len(classes_list)
|
33
|
+
if isinstance(num_classes, int):
|
34
|
+
if num_classes > 2:
|
35
|
+
return type_hints.Task.TABULAR_MULTI_CLASSIFICATION
|
36
|
+
return type_hints.Task.TABULAR_BINARY_CLASSIFICATION
|
37
|
+
return type_hints.Task.UNKNOWN
|
38
|
+
return type_hints.Task.UNKNOWN
|
39
|
+
|
40
|
+
|
41
|
+
def get_model_task_catboost(model: "catboost.CatBoost") -> type_hints.Task:
|
42
|
+
loss_function = None
|
43
|
+
if type_utils.LazyType("catboost.CatBoost").isinstance(model):
|
44
|
+
loss_function = model.get_all_params()["loss_function"] # type: ignore[attr-defined]
|
45
|
+
|
46
|
+
if (type_utils.LazyType("catboost.CatBoostClassifier").isinstance(model)) or model._is_classification_objective(
|
47
|
+
loss_function
|
48
|
+
):
|
49
|
+
num_classes = handlers_utils.get_num_classes_if_exists(model)
|
50
|
+
if num_classes == 0:
|
51
|
+
return type_hints.Task.UNKNOWN
|
52
|
+
if num_classes <= 2:
|
53
|
+
return type_hints.Task.TABULAR_BINARY_CLASSIFICATION
|
54
|
+
return type_hints.Task.TABULAR_MULTI_CLASSIFICATION
|
55
|
+
if (type_utils.LazyType("catboost.CatBoostRanker").isinstance(model)) or model._is_ranking_objective(loss_function):
|
56
|
+
return type_hints.Task.TABULAR_RANKING
|
57
|
+
if (type_utils.LazyType("catboost.CatBoostRegressor").isinstance(model)) or model._is_regression_objective(
|
58
|
+
loss_function
|
59
|
+
):
|
60
|
+
return type_hints.Task.TABULAR_REGRESSION
|
20
61
|
|
21
|
-
|
62
|
+
return type_hints.Task.UNKNOWN
|
63
|
+
|
64
|
+
|
65
|
+
def get_model_task_lightgbm(model: Union["lightgbm.Booster", "lightgbm.LGBMModel"]) -> type_hints.Task:
|
22
66
|
|
23
67
|
_BINARY_CLASSIFICATION_OBJECTIVES = ["binary"]
|
24
68
|
_MULTI_CLASSIFICATION_OBJECTIVES = ["multiclass", "multiclassova"]
|
@@ -36,81 +80,90 @@ def get_model_objective_lightgbm(model: Union["lightgbm.Booster", "lightgbm.LGBM
|
|
36
80
|
]
|
37
81
|
|
38
82
|
# does not account for cross-entropy and custom
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
if
|
45
|
-
return type_hints.
|
46
|
-
if
|
47
|
-
return type_hints.
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
return type_hints.ModelObjective.REGRESSION
|
57
|
-
return type_hints.ModelObjective.UNKNOWN
|
58
|
-
|
59
|
-
|
60
|
-
def get_model_objective_xgb(model: Union["xgboost.Booster", "xgboost.XGBModel"]) -> type_hints.ModelObjective:
|
61
|
-
|
62
|
-
import xgboost
|
83
|
+
model_task = ""
|
84
|
+
if type_utils.LazyType("lightgbm.Booster").isinstance(model):
|
85
|
+
model_task = model.params["objective"] # type: ignore[attr-defined]
|
86
|
+
elif hasattr(model, "objective_"):
|
87
|
+
model_task = model.objective_
|
88
|
+
if model_task in _BINARY_CLASSIFICATION_OBJECTIVES:
|
89
|
+
return type_hints.Task.TABULAR_BINARY_CLASSIFICATION
|
90
|
+
if model_task in _MULTI_CLASSIFICATION_OBJECTIVES:
|
91
|
+
return type_hints.Task.TABULAR_MULTI_CLASSIFICATION
|
92
|
+
if model_task in _RANKING_OBJECTIVES:
|
93
|
+
return type_hints.Task.TABULAR_RANKING
|
94
|
+
if model_task in _REGRESSION_OBJECTIVES:
|
95
|
+
return type_hints.Task.TABULAR_REGRESSION
|
96
|
+
return type_hints.Task.UNKNOWN
|
97
|
+
|
98
|
+
|
99
|
+
def get_model_task_xgb(model: Union["xgboost.Booster", "xgboost.XGBModel"]) -> type_hints.Task:
|
63
100
|
|
64
101
|
_BINARY_CLASSIFICATION_OBJECTIVE_PREFIX = ["binary:"]
|
65
102
|
_MULTI_CLASSIFICATION_OBJECTIVE_PREFIX = ["multi:"]
|
66
103
|
_RANKING_OBJECTIVE_PREFIX = ["rank:"]
|
67
104
|
_REGRESSION_OBJECTIVE_PREFIX = ["reg:"]
|
68
105
|
|
69
|
-
|
70
|
-
if
|
71
|
-
model_params = json.loads(model.save_config())
|
72
|
-
|
106
|
+
model_task = ""
|
107
|
+
if type_utils.LazyType("xgboost.Booster").isinstance(model):
|
108
|
+
model_params = json.loads(model.save_config()) # type: ignore[attr-defined]
|
109
|
+
model_task = model_params.get("learner", {}).get("objective", "")
|
73
110
|
else:
|
74
111
|
if hasattr(model, "get_params"):
|
75
|
-
|
112
|
+
model_task = model.get_params().get("objective", "")
|
76
113
|
|
77
|
-
if isinstance(
|
78
|
-
|
114
|
+
if isinstance(model_task, dict):
|
115
|
+
model_task = model_task.get("name", "")
|
79
116
|
for classification_objective in _BINARY_CLASSIFICATION_OBJECTIVE_PREFIX:
|
80
|
-
if classification_objective in
|
81
|
-
return type_hints.
|
117
|
+
if classification_objective in model_task:
|
118
|
+
return type_hints.Task.TABULAR_BINARY_CLASSIFICATION
|
82
119
|
for classification_objective in _MULTI_CLASSIFICATION_OBJECTIVE_PREFIX:
|
83
|
-
if classification_objective in
|
84
|
-
return type_hints.
|
120
|
+
if classification_objective in model_task:
|
121
|
+
return type_hints.Task.TABULAR_MULTI_CLASSIFICATION
|
85
122
|
for ranking_objective in _RANKING_OBJECTIVE_PREFIX:
|
86
|
-
if ranking_objective in
|
87
|
-
return type_hints.
|
123
|
+
if ranking_objective in model_task:
|
124
|
+
return type_hints.Task.TABULAR_RANKING
|
88
125
|
for regression_objective in _REGRESSION_OBJECTIVE_PREFIX:
|
89
|
-
if regression_objective in
|
90
|
-
return type_hints.
|
91
|
-
return type_hints.
|
126
|
+
if regression_objective in model_task:
|
127
|
+
return type_hints.Task.TABULAR_REGRESSION
|
128
|
+
return type_hints.Task.UNKNOWN
|
92
129
|
|
93
130
|
|
94
|
-
def
|
95
|
-
|
131
|
+
def get_model_task_and_output_type(model: Any) -> ModelTaskAndOutputType:
|
132
|
+
if type_utils.LazyType("xgboost.Booster").isinstance(model) or type_utils.LazyType("xgboost.XGBModel").isinstance(
|
133
|
+
model
|
134
|
+
):
|
135
|
+
task = get_model_task_xgb(model)
|
136
|
+
output_type = model_signature.DataType.DOUBLE
|
137
|
+
if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
|
138
|
+
output_type = model_signature.DataType.STRING
|
139
|
+
return ModelTaskAndOutputType(task=task, output_type=output_type)
|
96
140
|
|
97
|
-
if
|
98
|
-
|
141
|
+
if type_utils.LazyType("lightgbm.Booster").isinstance(model) or type_utils.LazyType(
|
142
|
+
"lightgbm.LGBMModel"
|
143
|
+
).isinstance(model):
|
144
|
+
task = get_model_task_lightgbm(model)
|
99
145
|
output_type = model_signature.DataType.DOUBLE
|
100
|
-
if
|
146
|
+
if task in [
|
147
|
+
type_hints.Task.TABULAR_BINARY_CLASSIFICATION,
|
148
|
+
type_hints.Task.TABULAR_MULTI_CLASSIFICATION,
|
149
|
+
]:
|
101
150
|
output_type = model_signature.DataType.STRING
|
102
|
-
return
|
151
|
+
return ModelTaskAndOutputType(task=task, output_type=output_type)
|
103
152
|
|
104
|
-
|
153
|
+
if type_utils.LazyType("catboost.CatBoost").isinstance(model):
|
154
|
+
task = get_model_task_catboost(model)
|
155
|
+
output_type = model_signature.DataType.DOUBLE
|
156
|
+
if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
|
157
|
+
output_type = model_signature.DataType.STRING
|
158
|
+
return ModelTaskAndOutputType(task=task, output_type=output_type)
|
105
159
|
|
106
|
-
if isinstance(model
|
107
|
-
|
160
|
+
if type_utils.LazyType("sklearn.base.BaseEstimator").isinstance(model) or type_utils.LazyType(
|
161
|
+
"sklearn.pipeline.Pipeline"
|
162
|
+
).isinstance(model):
|
163
|
+
task = get_task_skl(model)
|
108
164
|
output_type = model_signature.DataType.DOUBLE
|
109
|
-
if
|
110
|
-
type_hints.ModelObjective.BINARY_CLASSIFICATION,
|
111
|
-
type_hints.ModelObjective.MULTI_CLASSIFICATION,
|
112
|
-
]:
|
165
|
+
if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
|
113
166
|
output_type = model_signature.DataType.STRING
|
114
|
-
return
|
167
|
+
return ModelTaskAndOutputType(task=task, output_type=output_type)
|
115
168
|
|
116
169
|
raise ValueError(f"Model type {type(model)} is not supported")
|
@@ -2,7 +2,6 @@ import logging
|
|
2
2
|
import os
|
3
3
|
from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
|
4
4
|
|
5
|
-
import cloudpickle
|
6
5
|
import pandas as pd
|
7
6
|
from typing_extensions import TypeGuard, Unpack
|
8
7
|
|
@@ -120,9 +119,21 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
120
119
|
model_meta.env.include_if_absent(
|
121
120
|
[
|
122
121
|
model_env.ModelDependency(requirement="sentence-transformers", pip_name="sentence-transformers"),
|
122
|
+
model_env.ModelDependency(requirement="transformers", pip_name="transformers"),
|
123
|
+
model_env.ModelDependency(requirement="pytorch", pip_name="torch"),
|
123
124
|
],
|
124
125
|
check_local_version=True,
|
125
126
|
)
|
127
|
+
model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
|
128
|
+
|
129
|
+
@staticmethod
|
130
|
+
def _get_device_config(**kwargs: Unpack[model_types.SentenceTransformersLoadOptions]) -> Optional[str]:
|
131
|
+
if kwargs.get("device", None) is not None:
|
132
|
+
return kwargs["device"]
|
133
|
+
elif kwargs.get("use_gpu", False):
|
134
|
+
return "cuda"
|
135
|
+
|
136
|
+
return None
|
126
137
|
|
127
138
|
@classmethod
|
128
139
|
def load_model(
|
@@ -144,13 +155,9 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
144
155
|
model_blob_filename = model_blob_metadata.path
|
145
156
|
model_blob_file_or_dir_path = os.path.join(model_blob_path, model_blob_filename)
|
146
157
|
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
assert os.path.isfile(model_blob_file_or_dir_path) # if the saved model is a file
|
151
|
-
with open(model_blob_file_or_dir_path, "rb") as f:
|
152
|
-
model = cloudpickle.load(f)
|
153
|
-
assert isinstance(model, sentence_transformers.SentenceTransformer)
|
158
|
+
model = sentence_transformers.SentenceTransformer(
|
159
|
+
model_blob_file_or_dir_path, device=cls._get_device_config(**kwargs)
|
160
|
+
)
|
154
161
|
return model
|
155
162
|
|
156
163
|
@classmethod
|
@@ -1,4 +1,5 @@
|
|
1
1
|
import os
|
2
|
+
import warnings
|
2
3
|
from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union, cast, final
|
3
4
|
|
4
5
|
import cloudpickle
|
@@ -6,22 +7,21 @@ import numpy as np
|
|
6
7
|
import pandas as pd
|
7
8
|
from typing_extensions import TypeGuard, Unpack
|
8
9
|
|
9
|
-
import snowflake.snowpark.dataframe as sp_df
|
10
10
|
from snowflake.ml._internal import type_utils
|
11
11
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
12
12
|
from snowflake.ml.model._packager.model_env import model_env
|
13
|
-
from snowflake.ml.model._packager.model_handlers import
|
13
|
+
from snowflake.ml.model._packager.model_handlers import (
|
14
|
+
_base,
|
15
|
+
_utils as handlers_utils,
|
16
|
+
model_objective_utils,
|
17
|
+
)
|
14
18
|
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
15
19
|
from snowflake.ml.model._packager.model_meta import (
|
16
20
|
model_blob_meta,
|
17
21
|
model_meta as model_meta_api,
|
18
22
|
model_meta_schema,
|
19
23
|
)
|
20
|
-
from snowflake.ml.model._signatures import
|
21
|
-
numpy_handler,
|
22
|
-
snowpark_handler,
|
23
|
-
utils as model_signature_utils,
|
24
|
-
)
|
24
|
+
from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
|
25
25
|
|
26
26
|
if TYPE_CHECKING:
|
27
27
|
import sklearn.base
|
@@ -40,28 +40,14 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
40
40
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
41
41
|
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
|
42
42
|
|
43
|
-
DEFAULT_TARGET_METHODS = [
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
if isinstance(model, sklearn.pipeline.Pipeline):
|
53
|
-
return model_types.ModelObjective.UNKNOWN
|
54
|
-
if is_regressor(model):
|
55
|
-
return model_types.ModelObjective.REGRESSION
|
56
|
-
if is_classifier(model):
|
57
|
-
classes_list = getattr(model, "classes_", [])
|
58
|
-
num_classes = getattr(model, "n_classes_", None) or len(classes_list)
|
59
|
-
if isinstance(num_classes, int):
|
60
|
-
if num_classes > 2:
|
61
|
-
return model_types.ModelObjective.MULTI_CLASSIFICATION
|
62
|
-
return model_types.ModelObjective.BINARY_CLASSIFICATION
|
63
|
-
return model_types.ModelObjective.UNKNOWN
|
64
|
-
return model_types.ModelObjective.UNKNOWN
|
43
|
+
DEFAULT_TARGET_METHODS = [
|
44
|
+
"predict",
|
45
|
+
"transform",
|
46
|
+
"predict_proba",
|
47
|
+
"predict_log_proba",
|
48
|
+
"decision_function",
|
49
|
+
]
|
50
|
+
EXPLAIN_TARGET_METHODS = ["predict", "predict_proba", "predict_log_proba"]
|
65
51
|
|
66
52
|
@classmethod
|
67
53
|
def can_handle(
|
@@ -95,18 +81,6 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
95
81
|
|
96
82
|
return cast(Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"], model)
|
97
83
|
|
98
|
-
@staticmethod
|
99
|
-
def get_explainability_supported_background(
|
100
|
-
sample_input_data: Optional[model_types.SupportedDataType] = None,
|
101
|
-
) -> Optional[pd.DataFrame]:
|
102
|
-
if isinstance(sample_input_data, pd.DataFrame) or isinstance(sample_input_data, sp_df.DataFrame):
|
103
|
-
return (
|
104
|
-
sample_input_data
|
105
|
-
if isinstance(sample_input_data, pd.DataFrame)
|
106
|
-
else snowpark_handler.SnowparkDataFrameHandler.convert_to_df(sample_input_data)
|
107
|
-
)
|
108
|
-
return None
|
109
|
-
|
110
84
|
@classmethod
|
111
85
|
def save_model(
|
112
86
|
cls,
|
@@ -125,23 +99,10 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
125
99
|
import sklearn.pipeline
|
126
100
|
|
127
101
|
assert isinstance(model, sklearn.base.BaseEstimator) or isinstance(model, sklearn.pipeline.Pipeline)
|
128
|
-
|
129
|
-
background_data = cls.get_explainability_supported_background(sample_input_data)
|
130
|
-
|
131
|
-
# if users did not ask then we enable if we have background data
|
132
|
-
if enable_explainability is None and background_data is not None:
|
133
|
-
enable_explainability = True
|
134
102
|
if enable_explainability:
|
135
|
-
# if users set it explicitly but no
|
136
|
-
if
|
137
|
-
raise ValueError(
|
138
|
-
"Sample input data is required to enable explainability. Currently we only support this for "
|
139
|
-
+ "`pandas.DataFrame` and `snowflake.snowpark.dataframe.DataFrame`."
|
140
|
-
)
|
141
|
-
data_blob_path = os.path.join(model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR)
|
142
|
-
os.makedirs(data_blob_path, exist_ok=True)
|
143
|
-
with open(os.path.join(data_blob_path, name + cls.BG_DATA_FILE_SUFFIX), "wb") as f:
|
144
|
-
background_data.to_parquet(f)
|
103
|
+
# if users set it explicitly but no sample_input_data then error out
|
104
|
+
if sample_input_data is None:
|
105
|
+
raise ValueError("Sample input data is required to enable explainability.")
|
145
106
|
|
146
107
|
if not is_sub_model:
|
147
108
|
target_methods = handlers_utils.get_target_methods(
|
@@ -151,7 +112,8 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
151
112
|
)
|
152
113
|
|
153
114
|
def get_prediction(
|
154
|
-
target_method_name: str,
|
115
|
+
target_method_name: str,
|
116
|
+
sample_input_data: model_types.SupportedLocalDataType,
|
155
117
|
) -> model_types.SupportedLocalDataType:
|
156
118
|
if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray)):
|
157
119
|
sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
|
@@ -169,19 +131,40 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
169
131
|
get_prediction_fn=get_prediction,
|
170
132
|
)
|
171
133
|
|
172
|
-
|
173
|
-
model_meta.model_objective = model_objective
|
134
|
+
explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
|
174
135
|
|
136
|
+
background_data = handlers_utils.get_explainability_supported_background(
|
137
|
+
sample_input_data, model_meta, explain_target_method
|
138
|
+
)
|
139
|
+
|
140
|
+
model_task_and_output_type = model_objective_utils.get_model_task_and_output_type(model)
|
141
|
+
model_meta.task = model_task_and_output_type.task
|
142
|
+
|
143
|
+
# if users did not ask then we enable if we have background data
|
144
|
+
if enable_explainability is None:
|
145
|
+
if background_data is None:
|
146
|
+
warnings.warn(
|
147
|
+
"sample_input_data should be provided to enable explainability by default",
|
148
|
+
category=UserWarning,
|
149
|
+
stacklevel=1,
|
150
|
+
)
|
151
|
+
enable_explainability = False
|
152
|
+
else:
|
153
|
+
enable_explainability = True
|
175
154
|
if enable_explainability:
|
176
|
-
|
155
|
+
handlers_utils.save_background_data(
|
156
|
+
model_blobs_dir_path,
|
157
|
+
cls.EXPLAIN_ARTIFACTS_DIR,
|
158
|
+
cls.BG_DATA_FILE_SUFFIX,
|
159
|
+
name,
|
160
|
+
background_data,
|
161
|
+
)
|
177
162
|
|
178
|
-
if model_objective == model_types.ModelObjective.MULTI_CLASSIFICATION:
|
179
|
-
output_type = model_signature.DataType.STRING
|
180
163
|
model_meta = handlers_utils.add_explain_method_signature(
|
181
164
|
model_meta=model_meta,
|
182
165
|
explain_method="explain",
|
183
|
-
target_method=
|
184
|
-
output_return_type=output_type,
|
166
|
+
target_method=explain_target_method,
|
167
|
+
output_return_type=model_task_and_output_type.output_type,
|
185
168
|
)
|
186
169
|
|
187
170
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
@@ -202,7 +185,8 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
202
185
|
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
203
186
|
|
204
187
|
model_meta.env.include_if_absent(
|
205
|
-
[model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn")],
|
188
|
+
[model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn")],
|
189
|
+
check_local_version=True,
|
206
190
|
)
|
207
191
|
|
208
192
|
@classmethod
|
@@ -43,6 +43,8 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
43
43
|
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
|
44
44
|
|
45
45
|
DEFAULT_TARGET_METHODS = ["predict", "transform", "predict_proba", "predict_log_proba", "decision_function"]
|
46
|
+
EXPLAIN_TARGET_METHODS = ["predict", "predict_proba", "predict_log_proba"]
|
47
|
+
|
46
48
|
IS_AUTO_SIGNATURE = True
|
47
49
|
|
48
50
|
@classmethod
|
@@ -71,13 +73,14 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
71
73
|
|
72
74
|
@classmethod
|
73
75
|
def _get_local_version_package(cls, pkg_name: str) -> Optional[version.Version]:
|
74
|
-
import importlib_metadata
|
76
|
+
from importlib import metadata as importlib_metadata
|
77
|
+
|
75
78
|
from packaging import version
|
76
79
|
|
77
80
|
local_version = None
|
78
81
|
|
79
82
|
try:
|
80
|
-
local_dist = importlib_metadata.distribution(pkg_name)
|
83
|
+
local_dist = importlib_metadata.distribution(pkg_name)
|
81
84
|
local_version = version.parse(local_dist.version)
|
82
85
|
except importlib_metadata.PackageNotFoundError:
|
83
86
|
pass
|
@@ -104,7 +107,13 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
104
107
|
def _get_supported_object_for_explainability(
|
105
108
|
cls, estimator: "BaseEstimator", enable_explainability: Optional[bool]
|
106
109
|
) -> Any:
|
107
|
-
|
110
|
+
from snowflake.ml.modeling import pipeline as snowml_pipeline
|
111
|
+
|
112
|
+
# handle pipeline objects separately
|
113
|
+
if isinstance(estimator, snowml_pipeline.Pipeline): # type: ignore[attr-defined]
|
114
|
+
return None
|
115
|
+
|
116
|
+
methods = ["to_xgboost", "to_lightgbm", "to_sklearn"]
|
108
117
|
for method_name in methods:
|
109
118
|
if hasattr(estimator, method_name):
|
110
119
|
try:
|
@@ -136,9 +145,9 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
136
145
|
# Pipeline is inherited from BaseEstimator, so no need to add one more check
|
137
146
|
|
138
147
|
if not is_sub_model:
|
139
|
-
if
|
148
|
+
if model_meta.signatures:
|
140
149
|
warnings.warn(
|
141
|
-
"
|
150
|
+
"Providing model signature for Snowpark ML "
|
142
151
|
+ "Modeling model is not required. Model signature will automatically be inferred during fitting. ",
|
143
152
|
UserWarning,
|
144
153
|
stacklevel=2,
|
@@ -162,22 +171,31 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
162
171
|
python_base_obj = cls._get_supported_object_for_explainability(model, enable_explainability)
|
163
172
|
if python_base_obj is None:
|
164
173
|
if enable_explainability: # if user set enable_explainability to True, throw error else silently skip
|
165
|
-
raise ValueError(
|
174
|
+
raise ValueError(
|
175
|
+
"Explain only supported for xgboost, lightgbm and sklearn (not pipeline) Snowpark ML models."
|
176
|
+
)
|
166
177
|
# set None to False so we don't include shap in the environment
|
167
178
|
enable_explainability = False
|
168
179
|
else:
|
169
|
-
|
170
|
-
|
171
|
-
)
|
172
|
-
model_meta.model_objective = model_objective_and_output_type.objective
|
180
|
+
model_task_and_output_type = model_objective_utils.get_model_task_and_output_type(python_base_obj)
|
181
|
+
model_meta.task = model_task_and_output_type.task
|
182
|
+
explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
|
173
183
|
model_meta = handlers_utils.add_explain_method_signature(
|
174
184
|
model_meta=model_meta,
|
175
185
|
explain_method="explain",
|
176
|
-
target_method=
|
177
|
-
output_return_type=
|
186
|
+
target_method=explain_target_method,
|
187
|
+
output_return_type=model_task_and_output_type.output_type,
|
178
188
|
)
|
179
189
|
enable_explainability = True
|
180
190
|
|
191
|
+
background_data = handlers_utils.get_explainability_supported_background(
|
192
|
+
sample_input_data, model_meta, explain_target_method
|
193
|
+
)
|
194
|
+
if background_data is not None:
|
195
|
+
handlers_utils.save_background_data(
|
196
|
+
model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data
|
197
|
+
)
|
198
|
+
|
181
199
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
182
200
|
os.makedirs(model_blob_path, exist_ok=True)
|
183
201
|
with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
|
@@ -258,6 +276,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
258
276
|
raw_model: "BaseEstimator",
|
259
277
|
signature: model_signature.ModelSignature,
|
260
278
|
target_method: str,
|
279
|
+
background_data: Optional[pd.DataFrame] = None,
|
261
280
|
) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
|
262
281
|
@custom_model.inference_api
|
263
282
|
def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
|
@@ -276,16 +295,16 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
276
295
|
def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
|
277
296
|
import shap
|
278
297
|
|
279
|
-
methods = ["to_xgboost", "to_lightgbm"]
|
298
|
+
methods = ["to_xgboost", "to_lightgbm", "to_sklearn"]
|
280
299
|
for method_name in methods:
|
281
300
|
try:
|
282
301
|
base_model = getattr(raw_model, method_name)()
|
283
|
-
explainer = shap.
|
284
|
-
df =
|
302
|
+
explainer = shap.Explainer(base_model, masker=background_data)
|
303
|
+
df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
|
285
304
|
return model_signature_utils.rename_pandas_df(df, signature.outputs)
|
286
305
|
except exceptions.SnowflakeMLException:
|
287
306
|
pass # Do nothing and continue to the next method
|
288
|
-
raise ValueError("The model must be an xgboost or
|
307
|
+
raise ValueError("The model must be an xgboost, lightgbm or sklearn (not pipeline) estimator.")
|
289
308
|
|
290
309
|
if target_method == "explain":
|
291
310
|
return explain_fn
|
@@ -294,7 +313,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
294
313
|
|
295
314
|
type_method_dict = {}
|
296
315
|
for target_method_name, sig in model_meta.signatures.items():
|
297
|
-
type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
|
316
|
+
type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name, background_data)
|
298
317
|
|
299
318
|
_SnowMLModel = type(
|
300
319
|
"_SnowMLModel",
|
@@ -1,6 +1,7 @@
|
|
1
1
|
# mypy: disable-error-code="import"
|
2
2
|
import os
|
3
3
|
import warnings
|
4
|
+
from importlib import metadata as importlib_metadata
|
4
5
|
from typing import (
|
5
6
|
TYPE_CHECKING,
|
6
7
|
Any,
|
@@ -13,7 +14,6 @@ from typing import (
|
|
13
14
|
final,
|
14
15
|
)
|
15
16
|
|
16
|
-
import importlib_metadata
|
17
17
|
import numpy as np
|
18
18
|
import pandas as pd
|
19
19
|
from packaging import version
|
@@ -53,6 +53,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
53
53
|
|
54
54
|
MODEL_BLOB_FILE_OR_DIR = "model.ubj"
|
55
55
|
DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
|
56
|
+
EXPLAIN_TARGET_METHODS = ["predict", "predict_proba"]
|
56
57
|
|
57
58
|
@classmethod
|
58
59
|
def can_handle(
|
@@ -96,7 +97,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
96
97
|
local_xgb_version = None
|
97
98
|
|
98
99
|
try:
|
99
|
-
local_dist = importlib_metadata.distribution("xgboost")
|
100
|
+
local_dist = importlib_metadata.distribution("xgboost")
|
100
101
|
local_xgb_version = version.parse(local_dist.version)
|
101
102
|
except importlib_metadata.PackageNotFoundError:
|
102
103
|
pass
|
@@ -138,21 +139,35 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
138
139
|
sample_input_data=sample_input_data,
|
139
140
|
get_prediction_fn=get_prediction,
|
140
141
|
)
|
141
|
-
|
142
|
-
model_meta.
|
143
|
-
model_meta.model_objective, model_objective_and_output.objective
|
144
|
-
)
|
142
|
+
model_task_and_output = model_objective_utils.get_model_task_and_output_type(model)
|
143
|
+
model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
|
145
144
|
if enable_explainability:
|
146
145
|
model_meta = handlers_utils.add_explain_method_signature(
|
147
146
|
model_meta=model_meta,
|
148
147
|
explain_method="explain",
|
149
148
|
target_method="predict",
|
150
|
-
output_return_type=
|
149
|
+
output_return_type=model_task_and_output.output_type,
|
151
150
|
)
|
152
151
|
model_meta.function_properties = {
|
153
152
|
"explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
|
154
153
|
}
|
155
154
|
|
155
|
+
explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
|
156
|
+
|
157
|
+
background_data = handlers_utils.get_explainability_supported_background(
|
158
|
+
sample_input_data, model_meta, explain_target_method
|
159
|
+
)
|
160
|
+
if background_data is not None:
|
161
|
+
handlers_utils.save_background_data(
|
162
|
+
model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data
|
163
|
+
)
|
164
|
+
else:
|
165
|
+
warnings.warn(
|
166
|
+
"sample_input_data should be provided for better explainability results",
|
167
|
+
category=UserWarning,
|
168
|
+
stacklevel=1,
|
169
|
+
)
|
170
|
+
|
156
171
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
157
172
|
os.makedirs(model_blob_path, exist_ok=True)
|
158
173
|
model.save_model(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
|