snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/telemetry.py +142 -20
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +1 -1
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/_internal/utils/uri.py +2 -2
- snowflake/ml/data/data_connector.py +33 -7
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/feature_store/feature_store.py +41 -17
- snowflake/ml/feature_store/feature_view.py +2 -2
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/model/_client/model/model_version_impl.py +22 -7
- snowflake/ml/model/_client/ops/model_ops.py +39 -3
- snowflake/ml/model/_client/ops/service_ops.py +198 -7
- snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
- snowflake/ml/model/_client/sql/service.py +85 -18
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
- snowflake/ml/model/_model_composer/model_composer.py +2 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
- snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
- snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
- snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
- snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
- snowflake/ml/model/_packager/model_packager.py +2 -0
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/models/llm.py +3 -1
- snowflake/ml/model/type_hints.py +9 -1
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
- snowflake/ml/modeling/cluster/birch.py +60 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
- snowflake/ml/modeling/cluster/dbscan.py +60 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
- snowflake/ml/modeling/cluster/k_means.py +60 -21
- snowflake/ml/modeling/cluster/mean_shift.py +60 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
- snowflake/ml/modeling/cluster/optics.py +60 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
- snowflake/ml/modeling/compose/column_transformer.py +60 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
- snowflake/ml/modeling/covariance/oas.py +60 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/pca.py +60 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
- snowflake/ml/modeling/impute/knn_imputer.py +60 -21
- snowflake/ml/modeling/impute/missing_indicator.py +60 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/lars.py +60 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/perceptron.py +60 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ridge.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
- snowflake/ml/modeling/manifold/isomap.py +60 -21
- snowflake/ml/modeling/manifold/mds.py +60 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
- snowflake/ml/modeling/manifold/tsne.py +60 -21
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +1 -12
- snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
- snowflake/ml/modeling/svm/linear_svc.py +60 -21
- snowflake/ml/modeling/svm/linear_svr.py +60 -21
- snowflake/ml/modeling/svm/nu_svc.py +60 -21
- snowflake/ml/modeling/svm/nu_svr.py +60 -21
- snowflake/ml/modeling/svm/svc.py +60 -21
- snowflake/ml/modeling/svm/svr.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
- snowflake/ml/registry/_manager/model_manager.py +4 -0
- snowflake/ml/registry/model_registry.py +1 -1
- snowflake/ml/registry/registry.py +1 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
- snowflake/ml/data/torch_dataset.py +0 -33
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,116 @@
|
|
1
|
+
import json
|
2
|
+
from dataclasses import dataclass
|
3
|
+
from typing import TYPE_CHECKING, Any, Union
|
4
|
+
|
5
|
+
from snowflake.ml.model import model_signature, type_hints
|
6
|
+
from snowflake.ml.model._packager.model_handlers import _utils as handlers_utils
|
7
|
+
|
8
|
+
if TYPE_CHECKING:
|
9
|
+
import lightgbm
|
10
|
+
import xgboost
|
11
|
+
|
12
|
+
|
13
|
+
@dataclass
class ModelObjectiveAndOutputType:
    """Pair of an inferred model objective and the data type used for its explanation output."""

    # Learning task inferred for the model (binary/multi-class classification, ranking, regression, unknown).
    objective: type_hints.ModelObjective
    # Data type declared for the model's explain-method output.
    output_type: model_signature.DataType
|
17
|
+
|
18
|
+
|
19
|
+
def get_model_objective_lightgbm(model: Union["lightgbm.Booster", "lightgbm.LGBMModel"]) -> type_hints.ModelObjective:
    """Infer the learning task of a LightGBM model.

    Scikit-learn style estimators are classified by their class (classifier, ranker, regressor). Anything else
    (a raw ``lightgbm.Booster`` or a bare ``lightgbm.LGBMModel``) is classified by its ``objective`` parameter,
    accepting LightGBM's documented aliases for each built-in objective.

    Args:
        model: A LightGBM Booster or scikit-learn style LightGBM estimator.

    Returns:
        The inferred objective. ``ModelObjective.UNKNOWN`` is returned when the objective is missing, is a custom
        (non-string) objective, or is not one of the recognised built-in objectives (e.g. ``cross_entropy``).
    """
    import lightgbm

    # Canonical LightGBM objective names together with their documented aliases.
    _BINARY_CLASSIFICATION_OBJECTIVES = frozenset({"binary"})
    _MULTI_CLASSIFICATION_OBJECTIVES = frozenset(
        {"multiclass", "softmax", "multiclassova", "multiclass_ova", "ova", "ovr"}
    )
    _RANKING_OBJECTIVES = frozenset(
        {"lambdarank", "rank_xendcg", "xendcg", "xe_ndcg", "xe_ndcg_mart", "xendcg_mart"}
    )
    _REGRESSION_OBJECTIVES = frozenset(
        {
            # L2 loss
            "regression",
            "regression_l2",
            "l2",
            "mean_squared_error",
            "mse",
            "l2_root",
            "root_mean_squared_error",
            "rmse",
            # L1 loss
            "regression_l1",
            "l1",
            "mean_absolute_error",
            "mae",
            # other regression losses
            "huber",
            "fair",
            "poisson",
            "quantile",
            "tweedie",
            "mape",
            "mean_absolute_percentage_error",
            "gamma",
        }
    )

    # does not account for cross-entropy and custom objectives
    if isinstance(model, lightgbm.LGBMClassifier):
        num_classes = handlers_utils.get_num_classes_if_exists(model)
        if num_classes == 2:
            return type_hints.ModelObjective.BINARY_CLASSIFICATION
        return type_hints.ModelObjective.MULTI_CLASSIFICATION
    if isinstance(model, lightgbm.LGBMRanker):
        return type_hints.ModelObjective.RANKING
    if isinstance(model, lightgbm.LGBMRegressor):
        return type_hints.ModelObjective.REGRESSION

    if isinstance(model, lightgbm.LGBMModel):
        # The scikit-learn wrapper exposes its configuration through get_params(), not a ``params`` dict.
        model_objective = model.get_params().get("objective")
    else:
        # Booster.params may lack "objective"; treat that as unknown instead of raising KeyError.
        model_objective = (getattr(model, "params", None) or {}).get("objective")

    if not isinstance(model_objective, str):
        # Missing (None) or custom (callable) objectives cannot be classified.
        return type_hints.ModelObjective.UNKNOWN
    # Normalise case/whitespace defensively before the alias lookup.
    model_objective = model_objective.strip().lower()

    if model_objective in _BINARY_CLASSIFICATION_OBJECTIVES:
        return type_hints.ModelObjective.BINARY_CLASSIFICATION
    if model_objective in _MULTI_CLASSIFICATION_OBJECTIVES:
        return type_hints.ModelObjective.MULTI_CLASSIFICATION
    if model_objective in _RANKING_OBJECTIVES:
        return type_hints.ModelObjective.RANKING
    if model_objective in _REGRESSION_OBJECTIVES:
        return type_hints.ModelObjective.REGRESSION
    return type_hints.ModelObjective.UNKNOWN
|
58
|
+
|
59
|
+
|
60
|
+
def get_model_objective_xgb(model: Union["xgboost.Booster", "xgboost.XGBModel"]) -> type_hints.ModelObjective:
    """Infer the learning task of an XGBoost model from its objective name.

    For a raw ``xgboost.Booster`` and for a fitted scikit-learn style estimator, the objective is read from the
    trained booster's JSON configuration. ``XGBClassifier`` keeps ``objective="binary:logistic"`` in its constructor
    parameters even when it trains a ``multi:softprob`` booster on more than two classes, so ``get_params()`` alone
    would misreport multi-class models. Unfitted estimators fall back to ``get_params()``.

    Args:
        model: An XGBoost Booster or scikit-learn style XGBoost estimator.

    Returns:
        The inferred objective; ``ModelObjective.UNKNOWN`` when the objective is missing, custom (callable), or
        not in a recognised family (e.g. ``survival:*``).
    """
    import xgboost

    _BINARY_CLASSIFICATION_OBJECTIVE_PREFIX = ("binary:",)
    _MULTI_CLASSIFICATION_OBJECTIVE_PREFIX = ("multi:",)
    _RANKING_OBJECTIVE_PREFIX = ("rank:",)
    # count:poisson is Poisson regression for count data.
    _REGRESSION_OBJECTIVE_PREFIX = ("reg:", "count:")

    booster = None
    if isinstance(model, xgboost.Booster):
        booster = model
    elif isinstance(model, xgboost.XGBModel):
        try:
            booster = model.get_booster()
        except (ValueError, AttributeError):
            # Unfitted estimator: xgboost raises NotFittedError / XGBoostError, both ValueError subclasses.
            booster = None

    model_objective: Any = ""
    if booster is not None:
        model_params = json.loads(booster.save_config())
        model_objective = model_params.get("learner", {}).get("objective", "")
    elif hasattr(model, "get_params"):
        model_objective = model.get_params().get("objective", "")

    # The booster config stores the objective as a dict with a "name" entry.
    if isinstance(model_objective, dict):
        model_objective = model_objective.get("name", "")
    # None (the XGBModel default) or a callable custom objective cannot be classified; previously these
    # crashed the substring checks below with a TypeError.
    if not isinstance(model_objective, str):
        return type_hints.ModelObjective.UNKNOWN

    if model_objective.startswith(_BINARY_CLASSIFICATION_OBJECTIVE_PREFIX):
        return type_hints.ModelObjective.BINARY_CLASSIFICATION
    if model_objective.startswith(_MULTI_CLASSIFICATION_OBJECTIVE_PREFIX):
        return type_hints.ModelObjective.MULTI_CLASSIFICATION
    if model_objective.startswith(_RANKING_OBJECTIVE_PREFIX):
        return type_hints.ModelObjective.RANKING
    if model_objective.startswith(_REGRESSION_OBJECTIVE_PREFIX):
        return type_hints.ModelObjective.REGRESSION
    return type_hints.ModelObjective.UNKNOWN
|
92
|
+
|
93
|
+
|
94
|
+
def get_model_objective_and_output_type(model: Any) -> "ModelObjectiveAndOutputType":
    """Infer a tree model's objective and the data type declared for its explanation output.

    xgboost and lightgbm are imported lazily and independently. A missing optional library therefore only rules
    out that library's models, instead of replacing the "not supported" error with an ImportError.

    Args:
        model: A native xgboost or lightgbm model (Booster or scikit-learn style estimator).

    Returns:
        The inferred objective and output type: ``STRING`` for xgboost multi-class and for lightgbm binary or
        multi-class objectives, ``DOUBLE`` otherwise.

    Raises:
        ValueError: If the model is neither an xgboost nor a lightgbm model.
    """
    try:
        import xgboost
    except ImportError:
        xgboost = None  # type: ignore[assignment]

    if xgboost is not None and isinstance(model, (xgboost.Booster, xgboost.XGBModel)):
        model_objective = get_model_objective_xgb(model)
        output_type = model_signature.DataType.DOUBLE
        if model_objective == type_hints.ModelObjective.MULTI_CLASSIFICATION:
            output_type = model_signature.DataType.STRING
        return ModelObjectiveAndOutputType(objective=model_objective, output_type=output_type)

    try:
        import lightgbm
    except ImportError:
        lightgbm = None  # type: ignore[assignment]

    if lightgbm is not None and isinstance(model, (lightgbm.Booster, lightgbm.LGBMModel)):
        model_objective = get_model_objective_lightgbm(model)
        output_type = model_signature.DataType.DOUBLE
        if model_objective in [
            type_hints.ModelObjective.BINARY_CLASSIFICATION,
            type_hints.ModelObjective.MULTI_CLASSIFICATION,
        ]:
            output_type = model_signature.DataType.STRING
        return ModelObjectiveAndOutputType(objective=model_objective, output_type=output_type)

    raise ValueError(f"Model type {type(model)} is not supported")
|
@@ -45,23 +45,23 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
45
45
|
@classmethod
def get_model_objective(
    cls, model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]
) -> model_types.ModelObjective:
    """Infer the learning task of a scikit-learn estimator.

    Pipelines are reported as UNKNOWN, regressors as REGRESSION. Classifiers are split into binary or multi-class
    using ``n_classes_``, or the length of ``classes_`` when ``n_classes_`` is missing or zero. A non-integer
    class count (a per-output list or array) yields UNKNOWN.

    Args:
        model: A scikit-learn estimator or pipeline.

    Returns:
        The inferred model objective.
    """
    import numbers

    import sklearn.pipeline
    from sklearn.base import is_classifier, is_regressor

    if isinstance(model, sklearn.pipeline.Pipeline):
        return model_types.ModelObjective.UNKNOWN
    if is_regressor(model):
        return model_types.ModelObjective.REGRESSION
    if is_classifier(model):
        num_classes = getattr(model, "n_classes_", None)
        # Avoid `n_classes_ or ...`: truth-testing a multi-output NumPy array raises ValueError.
        if num_classes is None or (isinstance(num_classes, numbers.Integral) and num_classes == 0):
            num_classes = len(getattr(model, "classes_", []))
        # numbers.Integral also accepts NumPy integers (e.g. DecisionTreeClassifier.n_classes_), which are not int.
        if isinstance(num_classes, numbers.Integral):
            if num_classes > 2:
                return model_types.ModelObjective.MULTI_CLASSIFICATION
            return model_types.ModelObjective.BINARY_CLASSIFICATION
        return model_types.ModelObjective.UNKNOWN
    return model_types.ModelObjective.UNKNOWN
|
65
65
|
|
66
66
|
@classmethod
|
67
67
|
def can_handle(
|
@@ -95,6 +95,18 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
95
95
|
|
96
96
|
return cast(Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"], model)
|
97
97
|
|
98
|
+
@staticmethod
def get_explainability_supported_background(
    sample_input_data: Optional[model_types.SupportedDataType] = None,
) -> Optional[pd.DataFrame]:
    """Turn sample input into a pandas DataFrame usable as explanation background data.

    A pandas DataFrame is returned as-is; a Snowpark DataFrame is converted with the Snowpark handler.
    Any other input (including None) yields None.
    """
    if isinstance(sample_input_data, pd.DataFrame):
        return sample_input_data
    if isinstance(sample_input_data, sp_df.DataFrame):
        return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(sample_input_data)
    return None
|
109
|
+
|
98
110
|
@classmethod
|
99
111
|
def save_model(
|
100
112
|
cls,
|
@@ -106,32 +118,30 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
106
118
|
is_sub_model: Optional[bool] = False,
|
107
119
|
**kwargs: Unpack[model_types.SKLModelSaveOptions],
|
108
120
|
) -> None:
|
109
|
-
|
121
|
+
# setting None by default to distinguish if users did not set it
|
122
|
+
enable_explainability = kwargs.get("enable_explainability", None)
|
110
123
|
|
111
124
|
import sklearn.base
|
112
125
|
import sklearn.pipeline
|
113
126
|
|
114
127
|
assert isinstance(model, sklearn.base.BaseEstimator) or isinstance(model, sklearn.pipeline.Pipeline)
|
115
128
|
|
116
|
-
|
129
|
+
background_data = cls.get_explainability_supported_background(sample_input_data)
|
130
|
+
|
131
|
+
# if users did not ask then we enable if we have background data
|
132
|
+
if enable_explainability is None and background_data is not None:
|
133
|
+
enable_explainability = True
|
117
134
|
if enable_explainability:
|
118
|
-
#
|
119
|
-
if
|
120
|
-
isinstance(sample_input_data, pd.DataFrame) or isinstance(sample_input_data, sp_df.DataFrame)
|
121
|
-
):
|
135
|
+
# if users set it explicitly but no background data then error out
|
136
|
+
if background_data is None:
|
122
137
|
raise ValueError(
|
123
138
|
"Sample input data is required to enable explainability. Currently we only support this for "
|
124
139
|
+ "`pandas.DataFrame` and `snowflake.snowpark.dataframe.DataFrame`."
|
125
140
|
)
|
126
|
-
sample_input_data_pandas = (
|
127
|
-
sample_input_data
|
128
|
-
if isinstance(sample_input_data, pd.DataFrame)
|
129
|
-
else snowpark_handler.SnowparkDataFrameHandler.convert_to_df(sample_input_data)
|
130
|
-
)
|
131
141
|
data_blob_path = os.path.join(model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR)
|
132
142
|
os.makedirs(data_blob_path, exist_ok=True)
|
133
143
|
with open(os.path.join(data_blob_path, name + cls.BG_DATA_FILE_SUFFIX), "wb") as f:
|
134
|
-
|
144
|
+
background_data.to_parquet(f)
|
135
145
|
|
136
146
|
if not is_sub_model:
|
137
147
|
target_methods = handlers_utils.get_target_methods(
|
@@ -159,9 +169,13 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
159
169
|
get_prediction_fn=get_prediction,
|
160
170
|
)
|
161
171
|
|
172
|
+
model_objective = cls.get_model_objective(model)
|
173
|
+
model_meta.model_objective = model_objective
|
174
|
+
|
162
175
|
if enable_explainability:
|
163
176
|
output_type = model_signature.DataType.DOUBLE
|
164
|
-
|
177
|
+
|
178
|
+
if model_objective == model_types.ModelObjective.MULTI_CLASSIFICATION:
|
165
179
|
output_type = model_signature.DataType.STRING
|
166
180
|
model_meta = handlers_utils.add_explain_method_signature(
|
167
181
|
model_meta=model_meta,
|
@@ -184,10 +198,8 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
184
198
|
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
185
199
|
|
186
200
|
if enable_explainability:
|
187
|
-
model_meta.env.include_if_absent(
|
188
|
-
|
189
|
-
check_local_version=True,
|
190
|
-
)
|
201
|
+
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
|
202
|
+
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
191
203
|
|
192
204
|
model_meta.env.include_if_absent(
|
193
205
|
[model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn")], check_local_version=True
|
@@ -1,20 +1,27 @@
|
|
1
1
|
import os
|
2
2
|
import warnings
|
3
|
-
from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
|
3
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast, final
|
4
4
|
|
5
5
|
import cloudpickle
|
6
6
|
import numpy as np
|
7
7
|
import pandas as pd
|
8
|
+
from packaging import version
|
8
9
|
from typing_extensions import TypeGuard, Unpack
|
9
10
|
|
10
11
|
from snowflake.ml._internal import type_utils
|
12
|
+
from snowflake.ml._internal.exceptions import exceptions
|
11
13
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
12
14
|
from snowflake.ml.model._packager.model_env import model_env
|
13
|
-
from snowflake.ml.model._packager.model_handlers import
|
15
|
+
from snowflake.ml.model._packager.model_handlers import (
|
16
|
+
_base,
|
17
|
+
_utils as handlers_utils,
|
18
|
+
model_objective_utils,
|
19
|
+
)
|
14
20
|
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
15
21
|
from snowflake.ml.model._packager.model_meta import (
|
16
22
|
model_blob_meta,
|
17
23
|
model_meta as model_meta_api,
|
24
|
+
model_meta_schema,
|
18
25
|
)
|
19
26
|
from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
|
20
27
|
|
@@ -62,6 +69,53 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
62
69
|
|
63
70
|
return cast("BaseEstimator", model)
|
64
71
|
|
72
|
+
@classmethod
def _get_local_version_package(cls, pkg_name: str) -> Optional[version.Version]:
    """Return the parsed version of locally installed distribution ``pkg_name``, or None if it is not installed."""
    import importlib_metadata
    from packaging import version

    try:
        installed = importlib_metadata.distribution(pkg_name)  # type: ignore[no-untyped-call]
    except importlib_metadata.PackageNotFoundError:
        return None
    return version.parse(installed.version)
|
86
|
+
|
87
|
+
@classmethod
def _can_support_xgb(cls, enable_explainability: Optional[bool]) -> bool:
    """Report whether the locally installed xgboost can be used for explanations.

    Returns False when xgboost >= 2.1.0 is installed, and emits a UserWarning in that case if the caller
    explicitly enabled explainability. Returns True when xgboost is not installed or is older than 2.1.0.
    """
    local_xgb_version = cls._get_local_version_package("xgboost")

    if local_xgb_version and local_xgb_version >= version.parse("2.1.0"):
        if enable_explainability:
            warnings.warn(
                # Trailing space keeps the two sentences separated after concatenation.
                f"This version of xgboost {local_xgb_version} does not work with shap 0.42.1. "
                + "If you want model explanations, lower the xgboost version to <2.1.0.",
                category=UserWarning,
                stacklevel=1,
            )
        return False
    return True
|
102
|
+
|
103
|
+
@classmethod
def _get_supported_object_for_explainability(
    cls, estimator: "BaseEstimator", enable_explainability: Optional[bool]
) -> Any:
    """Convert a Snowpark ML estimator into a native xgboost/lightgbm object for explanations.

    Tries ``to_xgboost`` and then ``to_lightgbm``. A converter raising SnowflakeMLException is skipped. An xgboost
    result is discarded (None returned) when ``_can_support_xgb`` rejects the local xgboost version. Returns None
    if no converter succeeds.
    """
    for converter_name in ("to_xgboost", "to_lightgbm"):
        if not hasattr(estimator, converter_name):
            continue
        try:
            native_model = getattr(estimator, converter_name)()
            xgb_unsupported = converter_name == "to_xgboost" and not cls._can_support_xgb(enable_explainability)
            return None if xgb_unsupported else native_model
        except exceptions.SnowflakeMLException:
            continue  # this converter does not apply; try the next one
    return None
|
118
|
+
|
65
119
|
@classmethod
|
66
120
|
def save_model(
|
67
121
|
cls,
|
@@ -73,9 +127,8 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
73
127
|
is_sub_model: Optional[bool] = False,
|
74
128
|
**kwargs: Unpack[model_types.SNOWModelSaveOptions],
|
75
129
|
) -> None:
|
76
|
-
|
77
|
-
|
78
|
-
raise NotImplementedError("Explainability is not supported for Snowpark ML model.")
|
130
|
+
|
131
|
+
enable_explainability = kwargs.get("enable_explainability", None)
|
79
132
|
|
80
133
|
from snowflake.ml.modeling.framework.base import BaseEstimator
|
81
134
|
|
@@ -105,6 +158,26 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
105
158
|
raise ValueError(f"Target method {method_name} does not exist in the model.")
|
106
159
|
model_meta.signatures = temp_model_signature_dict
|
107
160
|
|
161
|
+
if enable_explainability or enable_explainability is None:
|
162
|
+
python_base_obj = cls._get_supported_object_for_explainability(model, enable_explainability)
|
163
|
+
if python_base_obj is None:
|
164
|
+
if enable_explainability: # if user set enable_explainability to True, throw error else silently skip
|
165
|
+
raise ValueError("Explain only support for xgboost or lightgbm Snowpark ML models.")
|
166
|
+
# set None to False so we don't include shap in the environment
|
167
|
+
enable_explainability = False
|
168
|
+
else:
|
169
|
+
model_objective_and_output_type = model_objective_utils.get_model_objective_and_output_type(
|
170
|
+
python_base_obj
|
171
|
+
)
|
172
|
+
model_meta.model_objective = model_objective_and_output_type.objective
|
173
|
+
model_meta = handlers_utils.add_explain_method_signature(
|
174
|
+
model_meta=model_meta,
|
175
|
+
explain_method="explain",
|
176
|
+
target_method="predict",
|
177
|
+
output_return_type=model_objective_and_output_type.output_type,
|
178
|
+
)
|
179
|
+
enable_explainability = True
|
180
|
+
|
108
181
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
109
182
|
os.makedirs(model_blob_path, exist_ok=True)
|
110
183
|
with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
|
@@ -122,7 +195,29 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
122
195
|
model_dependencies = model._get_dependencies()
|
123
196
|
for dep in model_dependencies:
|
124
197
|
pkg_name = dep.split("==")[0]
|
125
|
-
|
198
|
+
if pkg_name != "xgboost":
|
199
|
+
_include_if_absent_pkgs.append(model_env.ModelDependency(requirement=pkg_name, pip_name=pkg_name))
|
200
|
+
continue
|
201
|
+
|
202
|
+
local_xgb_version = cls._get_local_version_package("xgboost")
|
203
|
+
if local_xgb_version and local_xgb_version >= version.parse("2.0.0") and enable_explainability:
|
204
|
+
model_meta.env.include_if_absent(
|
205
|
+
[
|
206
|
+
model_env.ModelDependency(requirement="xgboost==2.0.*", pip_name="xgboost"),
|
207
|
+
],
|
208
|
+
check_local_version=False,
|
209
|
+
)
|
210
|
+
else:
|
211
|
+
model_meta.env.include_if_absent(
|
212
|
+
[
|
213
|
+
model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
|
214
|
+
],
|
215
|
+
check_local_version=True,
|
216
|
+
)
|
217
|
+
|
218
|
+
if enable_explainability:
|
219
|
+
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
|
220
|
+
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
126
221
|
model_meta.env.include_if_absent(_include_if_absent_pkgs, check_local_version=True)
|
127
222
|
|
128
223
|
@classmethod
|
@@ -177,6 +272,24 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
177
272
|
|
178
273
|
return model_signature_utils.rename_pandas_df(df, signature.outputs)
|
179
274
|
|
275
|
+
@custom_model.inference_api
def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
    """Compute SHAP values for ``X`` with a TreeExplainer over the native xgboost/lightgbm model.

    Raises:
        ValueError: If neither ``to_xgboost`` nor ``to_lightgbm`` yields a native model.
    """
    import shap

    for converter_name in ("to_xgboost", "to_lightgbm"):
        try:
            native_model = getattr(raw_model, converter_name)()
            shap_values = shap.TreeExplainer(native_model)(X).values
            return model_signature_utils.rename_pandas_df(pd.DataFrame(shap_values), signature.outputs)
        except exceptions.SnowflakeMLException:
            continue  # converter not applicable to this estimator
    raise ValueError("The model must be an xgboost or lightgbm estimator.")
|
289
|
+
|
290
|
+
if target_method == "explain":
|
291
|
+
return explain_fn
|
292
|
+
|
180
293
|
return fn
|
181
294
|
|
182
295
|
type_method_dict = {}
|
@@ -111,7 +111,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
|
|
111
111
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
112
112
|
os.makedirs(model_blob_path, exist_ok=True)
|
113
113
|
with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
|
114
|
-
torch.jit.save(model, f) # type:ignore[attr-defined]
|
114
|
+
torch.jit.save(model, f) # type:ignore[no-untyped-call, attr-defined]
|
115
115
|
base_meta = model_blob_meta.ModelBlobMeta(
|
116
116
|
name=name,
|
117
117
|
model_type=cls.HANDLER_TYPE,
|
@@ -141,7 +141,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
|
|
141
141
|
model_blob_metadata = model_blobs_metadata[name]
|
142
142
|
model_blob_filename = model_blob_metadata.path
|
143
143
|
with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f:
|
144
|
-
m = torch.jit.load( # type:ignore[attr-defined]
|
144
|
+
m = torch.jit.load( # type:ignore[no-untyped-call, attr-defined]
|
145
145
|
f, map_location="cuda" if kwargs.get("use_gpu", False) else "cpu"
|
146
146
|
)
|
147
147
|
assert isinstance(m, torch.jit.ScriptModule) # type:ignore[attr-defined]
|
@@ -1,6 +1,6 @@
|
|
1
1
|
# mypy: disable-error-code="import"
|
2
|
-
import json
|
3
2
|
import os
|
3
|
+
import warnings
|
4
4
|
from typing import (
|
5
5
|
TYPE_CHECKING,
|
6
6
|
Any,
|
@@ -13,14 +13,20 @@ from typing import (
|
|
13
13
|
final,
|
14
14
|
)
|
15
15
|
|
16
|
+
import importlib_metadata
|
16
17
|
import numpy as np
|
17
18
|
import pandas as pd
|
19
|
+
from packaging import version
|
18
20
|
from typing_extensions import TypeGuard, Unpack
|
19
21
|
|
20
22
|
from snowflake.ml._internal import type_utils
|
21
23
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
22
24
|
from snowflake.ml.model._packager.model_env import model_env
|
23
|
-
from snowflake.ml.model._packager.model_handlers import
|
25
|
+
from snowflake.ml.model._packager.model_handlers import (
|
26
|
+
_base,
|
27
|
+
_utils as handlers_utils,
|
28
|
+
model_objective_utils,
|
29
|
+
)
|
24
30
|
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
25
31
|
from snowflake.ml.model._packager.model_meta import (
|
26
32
|
model_blob_meta,
|
@@ -47,41 +53,6 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
47
53
|
|
48
54
|
MODEL_BLOB_FILE_OR_DIR = "model.ubj"
|
49
55
|
DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
|
50
|
-
_BINARY_CLASSIFICATION_OBJECTIVE_PREFIX = ["binary:"]
|
51
|
-
_MULTI_CLASSIFICATION_OBJECTIVE_PREFIX = ["multi:"]
|
52
|
-
_RANKING_OBJECTIVE_PREFIX = ["rank:"]
|
53
|
-
_REGRESSION_OBJECTIVE_PREFIX = ["reg:"]
|
54
|
-
|
55
|
-
@classmethod
|
56
|
-
def get_model_objective(
|
57
|
-
cls, model: Union["xgboost.Booster", "xgboost.XGBModel"]
|
58
|
-
) -> model_meta_schema.ModelObjective:
|
59
|
-
import xgboost
|
60
|
-
|
61
|
-
if isinstance(model, xgboost.XGBClassifier) or isinstance(model, xgboost.XGBRFClassifier):
|
62
|
-
num_classes = handlers_utils.get_num_classes_if_exists(model)
|
63
|
-
if num_classes == 2:
|
64
|
-
return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
|
65
|
-
return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
|
66
|
-
if isinstance(model, xgboost.XGBRegressor) or isinstance(model, xgboost.XGBRFRegressor):
|
67
|
-
return model_meta_schema.ModelObjective.REGRESSION
|
68
|
-
if isinstance(model, xgboost.XGBRanker):
|
69
|
-
return model_meta_schema.ModelObjective.RANKING
|
70
|
-
model_params = json.loads(model.save_config())
|
71
|
-
model_objective = model_params["learner"]["objective"]
|
72
|
-
for classification_objective in cls._BINARY_CLASSIFICATION_OBJECTIVE_PREFIX:
|
73
|
-
if classification_objective in model_objective:
|
74
|
-
return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
|
75
|
-
for classification_objective in cls._MULTI_CLASSIFICATION_OBJECTIVE_PREFIX:
|
76
|
-
if classification_objective in model_objective:
|
77
|
-
return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
|
78
|
-
for ranking_objective in cls._RANKING_OBJECTIVE_PREFIX:
|
79
|
-
if ranking_objective in model_objective:
|
80
|
-
return model_meta_schema.ModelObjective.RANKING
|
81
|
-
for regression_objective in cls._REGRESSION_OBJECTIVE_PREFIX:
|
82
|
-
if regression_objective in model_objective:
|
83
|
-
return model_meta_schema.ModelObjective.REGRESSION
|
84
|
-
return model_meta_schema.ModelObjective.UNKNOWN
|
85
56
|
|
86
57
|
@classmethod
|
87
58
|
def can_handle(
|
@@ -116,10 +87,29 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
116
87
|
is_sub_model: Optional[bool] = False,
|
117
88
|
**kwargs: Unpack[model_types.XGBModelSaveOptions],
|
118
89
|
) -> None:
|
90
|
+
enable_explainability = kwargs.get("enable_explainability", True)
|
91
|
+
|
119
92
|
import xgboost
|
120
93
|
|
121
94
|
assert isinstance(model, xgboost.Booster) or isinstance(model, xgboost.XGBModel)
|
122
95
|
|
96
|
+
local_xgb_version = None
|
97
|
+
|
98
|
+
try:
|
99
|
+
local_dist = importlib_metadata.distribution("xgboost") # type: ignore[no-untyped-call]
|
100
|
+
local_xgb_version = version.parse(local_dist.version)
|
101
|
+
except importlib_metadata.PackageNotFoundError:
|
102
|
+
pass
|
103
|
+
|
104
|
+
if local_xgb_version and local_xgb_version >= version.parse("2.1.0") and enable_explainability:
|
105
|
+
warnings.warn(
|
106
|
+
f"This version of xgboost {local_xgb_version} does not work with shap 0.42.1."
|
107
|
+
+ "If you want model explanations, lower the xgboost version to <2.1.0.",
|
108
|
+
category=UserWarning,
|
109
|
+
stacklevel=1,
|
110
|
+
)
|
111
|
+
enable_explainability = False
|
112
|
+
|
123
113
|
if not is_sub_model:
|
124
114
|
target_methods = handlers_utils.get_target_methods(
|
125
115
|
model=model,
|
@@ -148,17 +138,16 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
148
138
|
sample_input_data=sample_input_data,
|
149
139
|
get_prediction_fn=get_prediction,
|
150
140
|
)
|
151
|
-
|
152
|
-
model_meta.model_objective =
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
output_type = model_signature.DataType.STRING
|
141
|
+
model_objective_and_output = model_objective_utils.get_model_objective_and_output_type(model)
|
142
|
+
model_meta.model_objective = handlers_utils.validate_model_objective(
|
143
|
+
model_meta.model_objective, model_objective_and_output.objective
|
144
|
+
)
|
145
|
+
if enable_explainability:
|
157
146
|
model_meta = handlers_utils.add_explain_method_signature(
|
158
147
|
model_meta=model_meta,
|
159
148
|
explain_method="explain",
|
160
149
|
target_method="predict",
|
161
|
-
output_return_type=output_type,
|
150
|
+
output_return_type=model_objective_and_output.output_type,
|
162
151
|
)
|
163
152
|
model_meta.function_properties = {
|
164
153
|
"explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
|
@@ -180,15 +169,26 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
180
169
|
model_meta.env.include_if_absent(
|
181
170
|
[
|
182
171
|
model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn"),
|
183
|
-
model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
|
184
172
|
],
|
185
173
|
check_local_version=True,
|
186
174
|
)
|
187
|
-
if
|
175
|
+
if local_xgb_version and local_xgb_version >= version.parse("2.0.0") and enable_explainability:
|
188
176
|
model_meta.env.include_if_absent(
|
189
|
-
[
|
177
|
+
[
|
178
|
+
model_env.ModelDependency(requirement="xgboost==2.0.*", pip_name="xgboost"),
|
179
|
+
],
|
180
|
+
check_local_version=False,
|
181
|
+
)
|
182
|
+
else:
|
183
|
+
model_meta.env.include_if_absent(
|
184
|
+
[
|
185
|
+
model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
|
186
|
+
],
|
190
187
|
check_local_version=True,
|
191
188
|
)
|
189
|
+
|
190
|
+
if enable_explainability:
|
191
|
+
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
|
192
192
|
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
193
193
|
model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
|
194
194
|
|
@@ -269,7 +269,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
269
269
|
import shap
|
270
270
|
|
271
271
|
explainer = shap.TreeExplainer(raw_model)
|
272
|
-
df =
|
272
|
+
df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
|
273
273
|
return model_signature_utils.rename_pandas_df(df, signature.outputs)
|
274
274
|
|
275
275
|
if target_method == "explain":
|
@@ -55,6 +55,7 @@ def create_model_metadata(
|
|
55
55
|
conda_dependencies: Optional[List[str]] = None,
|
56
56
|
pip_requirements: Optional[List[str]] = None,
|
57
57
|
python_version: Optional[str] = None,
|
58
|
+
model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN,
|
58
59
|
**kwargs: Any,
|
59
60
|
) -> Generator["ModelMetadata", None, None]:
|
60
61
|
"""Create a generator for model metadata object. Use generator to ensure correct register and unregister for
|
@@ -74,6 +75,9 @@ def create_model_metadata(
|
|
74
75
|
pip_requirements: List of pip Python packages requirements for running the model. Defaults to None.
|
75
76
|
python_version: A string of python version where model is run. Used for user override. If specified as None,
|
76
77
|
current version would be captured. Defaults to None.
|
78
|
+
model_objective: The objective of the Model Version. It is an enum class ModelObjective with values REGRESSION,
|
79
|
+
BINARY_CLASSIFICATION, MULTI_CLASSIFICATION, RANKING, or UNKNOWN. By default it is set to
|
80
|
+
ModelObjective.UNKNOWN and may be overridden by inferring from the Model Object.
|
77
81
|
**kwargs: Dict of attributes and values of the metadata. Used when loading from file.
|
78
82
|
|
79
83
|
Raises:
|
@@ -131,6 +135,7 @@ def create_model_metadata(
|
|
131
135
|
model_type=model_type,
|
132
136
|
signatures=signatures,
|
133
137
|
function_properties=function_properties,
|
138
|
+
model_objective=model_objective,
|
134
139
|
)
|
135
140
|
|
136
141
|
code_dir_path = os.path.join(model_dir_path, MODEL_CODE_DIR)
|
@@ -261,7 +266,7 @@ class ModelMetadata:
|
|
261
266
|
min_snowpark_ml_version: Optional[str] = None,
|
262
267
|
models: Optional[Dict[str, model_blob_meta.ModelBlobMeta]] = None,
|
263
268
|
original_metadata_version: Optional[str] = model_meta_schema.MODEL_METADATA_VERSION,
|
264
|
-
model_objective:
|
269
|
+
model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN,
|
265
270
|
explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = None,
|
266
271
|
) -> None:
|
267
272
|
self.name = name
|
@@ -287,9 +292,7 @@ class ModelMetadata:
|
|
287
292
|
|
288
293
|
self.original_metadata_version = original_metadata_version
|
289
294
|
|
290
|
-
self.model_objective:
|
291
|
-
model_objective or model_meta_schema.ModelObjective.UNKNOWN
|
292
|
-
)
|
295
|
+
self.model_objective: model_types.ModelObjective = model_objective
|
293
296
|
self.explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = explain_algorithm
|
294
297
|
|
295
298
|
@property
|
@@ -387,7 +390,7 @@ class ModelMetadata:
|
|
387
390
|
signatures=loaded_meta["signatures"],
|
388
391
|
version=original_loaded_meta_version,
|
389
392
|
min_snowpark_ml_version=loaded_meta_min_snowpark_ml_version,
|
390
|
-
model_objective=loaded_meta.get("model_objective",
|
393
|
+
model_objective=loaded_meta.get("model_objective", model_types.ModelObjective.UNKNOWN.value),
|
391
394
|
explainability=loaded_meta.get("explainability", None),
|
392
395
|
function_properties=loaded_meta.get("function_properties", {}),
|
393
396
|
)
|
@@ -442,8 +445,8 @@ class ModelMetadata:
|
|
442
445
|
min_snowpark_ml_version=model_dict["min_snowpark_ml_version"],
|
443
446
|
models=models,
|
444
447
|
original_metadata_version=model_dict["version"],
|
445
|
-
model_objective=
|
446
|
-
model_dict.get("model_objective",
|
448
|
+
model_objective=model_types.ModelObjective(
|
449
|
+
model_dict.get("model_objective", model_types.ModelObjective.UNKNOWN.value)
|
447
450
|
),
|
448
451
|
explain_algorithm=explanation_algorithm,
|
449
452
|
function_properties=model_dict.get("function_properties", {}),
|