snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +7 -33
- snowflake/ml/_internal/env_utils.py +11 -5
- snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
- snowflake/ml/_internal/telemetry.py +156 -20
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +1 -1
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/_internal/utils/uri.py +2 -2
- snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
- snowflake/ml/data/data_connector.py +88 -9
- snowflake/ml/data/data_ingestor.py +18 -1
- snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/dataset/dataset_metadata.py +3 -1
- snowflake/ml/dataset/dataset_reader.py +9 -3
- snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
- snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
- snowflake/ml/feature_store/examples/example_helper.py +69 -31
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
- snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
- snowflake/ml/feature_store/feature_store.py +100 -41
- snowflake/ml/feature_store/feature_view.py +149 -5
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/model/_client/model/model_impl.py +11 -2
- snowflake/ml/model/_client/model/model_version_impl.py +186 -20
- snowflake/ml/model/_client/ops/model_ops.py +144 -30
- snowflake/ml/model/_client/ops/service_ops.py +312 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
- snowflake/ml/model/_client/sql/model_version.py +13 -4
- snowflake/ml/model/_client/sql/service.py +196 -0
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
- snowflake/ml/model/_model_composer/model_composer.py +5 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
- snowflake/ml/model/_packager/model_env/model_env.py +7 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
- snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
- snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
- snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
- snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
- snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
- snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
- snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
- snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
- snowflake/ml/model/_packager/model_packager.py +4 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/models/llm.py +3 -1
- snowflake/ml/model/type_hints.py +10 -4
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
- snowflake/ml/modeling/cluster/birch.py +60 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
- snowflake/ml/modeling/cluster/dbscan.py +60 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
- snowflake/ml/modeling/cluster/k_means.py +60 -21
- snowflake/ml/modeling/cluster/mean_shift.py +60 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
- snowflake/ml/modeling/cluster/optics.py +60 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
- snowflake/ml/modeling/compose/column_transformer.py +60 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
- snowflake/ml/modeling/covariance/oas.py +60 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/pca.py +60 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
- snowflake/ml/modeling/framework/base.py +28 -19
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
- snowflake/ml/modeling/impute/knn_imputer.py +60 -21
- snowflake/ml/modeling/impute/missing_indicator.py +60 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/lars.py +60 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/perceptron.py +60 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ridge.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
- snowflake/ml/modeling/manifold/isomap.py +60 -21
- snowflake/ml/modeling/manifold/mds.py +60 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
- snowflake/ml/modeling/manifold/tsne.py +60 -21
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +4 -12
- snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
- snowflake/ml/modeling/svm/linear_svc.py +60 -21
- snowflake/ml/modeling/svm/linear_svr.py +60 -21
- snowflake/ml/modeling/svm/nu_svc.py +60 -21
- snowflake/ml/modeling/svm/nu_svr.py +60 -21
- snowflake/ml/modeling/svm/svc.py +60 -21
- snowflake/ml/modeling/svm/svr.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
- snowflake/ml/registry/_manager/model_manager.py +20 -2
- snowflake/ml/registry/model_registry.py +1 -1
- snowflake/ml/registry/registry.py +1 -2
- snowflake/ml/utils/sql_client.py +22 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -1,20 +1,27 @@
|
|
1
1
|
import os
|
2
2
|
import warnings
|
3
|
-
from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
|
3
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast, final
|
4
4
|
|
5
5
|
import cloudpickle
|
6
6
|
import numpy as np
|
7
7
|
import pandas as pd
|
8
|
+
from packaging import version
|
8
9
|
from typing_extensions import TypeGuard, Unpack
|
9
10
|
|
10
11
|
from snowflake.ml._internal import type_utils
|
12
|
+
from snowflake.ml._internal.exceptions import exceptions
|
11
13
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
12
14
|
from snowflake.ml.model._packager.model_env import model_env
|
13
|
-
from snowflake.ml.model._packager.model_handlers import
|
15
|
+
from snowflake.ml.model._packager.model_handlers import (
|
16
|
+
_base,
|
17
|
+
_utils as handlers_utils,
|
18
|
+
model_objective_utils,
|
19
|
+
)
|
14
20
|
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
15
21
|
from snowflake.ml.model._packager.model_meta import (
|
16
22
|
model_blob_meta,
|
17
23
|
model_meta as model_meta_api,
|
24
|
+
model_meta_schema,
|
18
25
|
)
|
19
26
|
from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
|
20
27
|
|
@@ -62,6 +69,53 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
62
69
|
|
63
70
|
return cast("BaseEstimator", model)
|
64
71
|
|
72
|
+
@classmethod
|
73
|
+
def _get_local_version_package(cls, pkg_name: str) -> Optional[version.Version]:
|
74
|
+
import importlib_metadata
|
75
|
+
from packaging import version
|
76
|
+
|
77
|
+
local_version = None
|
78
|
+
|
79
|
+
try:
|
80
|
+
local_dist = importlib_metadata.distribution(pkg_name) # type: ignore[no-untyped-call]
|
81
|
+
local_version = version.parse(local_dist.version)
|
82
|
+
except importlib_metadata.PackageNotFoundError:
|
83
|
+
pass
|
84
|
+
|
85
|
+
return local_version
|
86
|
+
|
87
|
+
@classmethod
|
88
|
+
def _can_support_xgb(cls, enable_explainability: Optional[bool]) -> bool:
|
89
|
+
|
90
|
+
local_xgb_version = cls._get_local_version_package("xgboost")
|
91
|
+
|
92
|
+
if local_xgb_version and local_xgb_version >= version.parse("2.1.0"):
|
93
|
+
if enable_explainability:
|
94
|
+
warnings.warn(
|
95
|
+
f"This version of xgboost {local_xgb_version} does not work with shap 0.42.1."
|
96
|
+
+ "If you want model explanations, lower the xgboost version to <2.1.0.",
|
97
|
+
category=UserWarning,
|
98
|
+
stacklevel=1,
|
99
|
+
)
|
100
|
+
return False
|
101
|
+
return True
|
102
|
+
|
103
|
+
@classmethod
|
104
|
+
def _get_supported_object_for_explainability(
|
105
|
+
cls, estimator: "BaseEstimator", enable_explainability: Optional[bool]
|
106
|
+
) -> Any:
|
107
|
+
methods = ["to_xgboost", "to_lightgbm"]
|
108
|
+
for method_name in methods:
|
109
|
+
if hasattr(estimator, method_name):
|
110
|
+
try:
|
111
|
+
result = getattr(estimator, method_name)()
|
112
|
+
if method_name == "to_xgboost" and not cls._can_support_xgb(enable_explainability):
|
113
|
+
return None
|
114
|
+
return result
|
115
|
+
except exceptions.SnowflakeMLException:
|
116
|
+
pass # Do nothing and continue to the next method
|
117
|
+
return None
|
118
|
+
|
65
119
|
@classmethod
|
66
120
|
def save_model(
|
67
121
|
cls,
|
@@ -73,6 +127,9 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
73
127
|
is_sub_model: Optional[bool] = False,
|
74
128
|
**kwargs: Unpack[model_types.SNOWModelSaveOptions],
|
75
129
|
) -> None:
|
130
|
+
|
131
|
+
enable_explainability = kwargs.get("enable_explainability", None)
|
132
|
+
|
76
133
|
from snowflake.ml.modeling.framework.base import BaseEstimator
|
77
134
|
|
78
135
|
assert isinstance(model, BaseEstimator)
|
@@ -101,15 +158,35 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
101
158
|
raise ValueError(f"Target method {method_name} does not exist in the model.")
|
102
159
|
model_meta.signatures = temp_model_signature_dict
|
103
160
|
|
161
|
+
if enable_explainability or enable_explainability is None:
|
162
|
+
python_base_obj = cls._get_supported_object_for_explainability(model, enable_explainability)
|
163
|
+
if python_base_obj is None:
|
164
|
+
if enable_explainability: # if user set enable_explainability to True, throw error else silently skip
|
165
|
+
raise ValueError("Explain only support for xgboost or lightgbm Snowpark ML models.")
|
166
|
+
# set None to False so we don't include shap in the environment
|
167
|
+
enable_explainability = False
|
168
|
+
else:
|
169
|
+
model_objective_and_output_type = model_objective_utils.get_model_objective_and_output_type(
|
170
|
+
python_base_obj
|
171
|
+
)
|
172
|
+
model_meta.model_objective = model_objective_and_output_type.objective
|
173
|
+
model_meta = handlers_utils.add_explain_method_signature(
|
174
|
+
model_meta=model_meta,
|
175
|
+
explain_method="explain",
|
176
|
+
target_method="predict",
|
177
|
+
output_return_type=model_objective_and_output_type.output_type,
|
178
|
+
)
|
179
|
+
enable_explainability = True
|
180
|
+
|
104
181
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
105
182
|
os.makedirs(model_blob_path, exist_ok=True)
|
106
|
-
with open(os.path.join(model_blob_path, cls.
|
183
|
+
with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
|
107
184
|
cloudpickle.dump(model, f)
|
108
185
|
base_meta = model_blob_meta.ModelBlobMeta(
|
109
186
|
name=name,
|
110
187
|
model_type=cls.HANDLER_TYPE,
|
111
188
|
handler_version=cls.HANDLER_VERSION,
|
112
|
-
path=cls.
|
189
|
+
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
113
190
|
)
|
114
191
|
model_meta.models[name] = base_meta
|
115
192
|
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
@@ -118,7 +195,29 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
118
195
|
model_dependencies = model._get_dependencies()
|
119
196
|
for dep in model_dependencies:
|
120
197
|
pkg_name = dep.split("==")[0]
|
121
|
-
|
198
|
+
if pkg_name != "xgboost":
|
199
|
+
_include_if_absent_pkgs.append(model_env.ModelDependency(requirement=pkg_name, pip_name=pkg_name))
|
200
|
+
continue
|
201
|
+
|
202
|
+
local_xgb_version = cls._get_local_version_package("xgboost")
|
203
|
+
if local_xgb_version and local_xgb_version >= version.parse("2.0.0") and enable_explainability:
|
204
|
+
model_meta.env.include_if_absent(
|
205
|
+
[
|
206
|
+
model_env.ModelDependency(requirement="xgboost==2.0.*", pip_name="xgboost"),
|
207
|
+
],
|
208
|
+
check_local_version=False,
|
209
|
+
)
|
210
|
+
else:
|
211
|
+
model_meta.env.include_if_absent(
|
212
|
+
[
|
213
|
+
model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
|
214
|
+
],
|
215
|
+
check_local_version=True,
|
216
|
+
)
|
217
|
+
|
218
|
+
if enable_explainability:
|
219
|
+
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
|
220
|
+
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
122
221
|
model_meta.env.include_if_absent(_include_if_absent_pkgs, check_local_version=True)
|
123
222
|
|
124
223
|
@classmethod
|
@@ -146,6 +245,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
146
245
|
cls,
|
147
246
|
raw_model: "BaseEstimator",
|
148
247
|
model_meta: model_meta_api.ModelMetadata,
|
248
|
+
background_data: Optional[pd.DataFrame] = None,
|
149
249
|
**kwargs: Unpack[model_types.SNOWModelLoadOptions],
|
150
250
|
) -> custom_model.CustomModel:
|
151
251
|
from snowflake.ml.model import custom_model
|
@@ -172,6 +272,24 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
172
272
|
|
173
273
|
return model_signature_utils.rename_pandas_df(df, signature.outputs)
|
174
274
|
|
275
|
+
@custom_model.inference_api
|
276
|
+
def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
|
277
|
+
import shap
|
278
|
+
|
279
|
+
methods = ["to_xgboost", "to_lightgbm"]
|
280
|
+
for method_name in methods:
|
281
|
+
try:
|
282
|
+
base_model = getattr(raw_model, method_name)()
|
283
|
+
explainer = shap.TreeExplainer(base_model)
|
284
|
+
df = pd.DataFrame(explainer(X).values)
|
285
|
+
return model_signature_utils.rename_pandas_df(df, signature.outputs)
|
286
|
+
except exceptions.SnowflakeMLException:
|
287
|
+
pass # Do nothing and continue to the next method
|
288
|
+
raise ValueError("The model must be an xgboost or lightgbm estimator.")
|
289
|
+
|
290
|
+
if target_method == "explain":
|
291
|
+
return explain_fn
|
292
|
+
|
175
293
|
return fn
|
176
294
|
|
177
295
|
type_method_dict = {}
|
@@ -36,7 +36,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
36
36
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
37
37
|
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
|
38
38
|
|
39
|
-
|
39
|
+
MODEL_BLOB_FILE_OR_DIR = "model"
|
40
40
|
DEFAULT_TARGET_METHODS = ["__call__"]
|
41
41
|
|
42
42
|
@classmethod
|
@@ -68,6 +68,10 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
68
68
|
is_sub_model: Optional[bool] = False,
|
69
69
|
**kwargs: Unpack[model_types.TensorflowSaveOptions],
|
70
70
|
) -> None:
|
71
|
+
enable_explainability = kwargs.get("enable_explainability", False)
|
72
|
+
if enable_explainability:
|
73
|
+
raise NotImplementedError("Explainability is not supported for Tensorflow model.")
|
74
|
+
|
71
75
|
import tensorflow
|
72
76
|
|
73
77
|
assert isinstance(model, tensorflow.Module)
|
@@ -114,15 +118,15 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
114
118
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
115
119
|
os.makedirs(model_blob_path, exist_ok=True)
|
116
120
|
if isinstance(model, tensorflow.keras.Model):
|
117
|
-
tensorflow.keras.models.save_model(model, os.path.join(model_blob_path, cls.
|
121
|
+
tensorflow.keras.models.save_model(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
|
118
122
|
else:
|
119
|
-
tensorflow.saved_model.save(model, os.path.join(model_blob_path, cls.
|
123
|
+
tensorflow.saved_model.save(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
|
120
124
|
|
121
125
|
base_meta = model_blob_meta.ModelBlobMeta(
|
122
126
|
name=name,
|
123
127
|
model_type=cls.HANDLER_TYPE,
|
124
128
|
handler_version=cls.HANDLER_VERSION,
|
125
|
-
path=cls.
|
129
|
+
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
126
130
|
)
|
127
131
|
model_meta.models[name] = base_meta
|
128
132
|
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
@@ -156,6 +160,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
156
160
|
cls,
|
157
161
|
raw_model: "tensorflow.Module",
|
158
162
|
model_meta: model_meta_api.ModelMetadata,
|
163
|
+
background_data: Optional[pd.DataFrame] = None,
|
159
164
|
**kwargs: Unpack[model_types.TensorflowLoadOptions],
|
160
165
|
) -> custom_model.CustomModel:
|
161
166
|
import tensorflow
|
@@ -34,7 +34,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
|
|
34
34
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
35
35
|
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
|
36
36
|
|
37
|
-
|
37
|
+
MODEL_BLOB_FILE_OR_DIR = "model.pt"
|
38
38
|
DEFAULT_TARGET_METHODS = ["forward"]
|
39
39
|
|
40
40
|
@classmethod
|
@@ -66,6 +66,10 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
|
|
66
66
|
is_sub_model: Optional[bool] = False,
|
67
67
|
**kwargs: Unpack[model_types.TorchScriptSaveOptions],
|
68
68
|
) -> None:
|
69
|
+
enable_explainability = kwargs.get("enable_explainability", False)
|
70
|
+
if enable_explainability:
|
71
|
+
raise NotImplementedError("Explainability is not supported for Torch Script model.")
|
72
|
+
|
69
73
|
import torch
|
70
74
|
|
71
75
|
assert isinstance(model, torch.jit.ScriptModule) # type:ignore[attr-defined]
|
@@ -106,13 +110,13 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
|
|
106
110
|
|
107
111
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
108
112
|
os.makedirs(model_blob_path, exist_ok=True)
|
109
|
-
with open(os.path.join(model_blob_path, cls.
|
110
|
-
torch.jit.save(model, f) # type:ignore[attr-defined]
|
113
|
+
with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
|
114
|
+
torch.jit.save(model, f) # type:ignore[no-untyped-call, attr-defined]
|
111
115
|
base_meta = model_blob_meta.ModelBlobMeta(
|
112
116
|
name=name,
|
113
117
|
model_type=cls.HANDLER_TYPE,
|
114
118
|
handler_version=cls.HANDLER_VERSION,
|
115
|
-
path=cls.
|
119
|
+
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
116
120
|
)
|
117
121
|
model_meta.models[name] = base_meta
|
118
122
|
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
@@ -137,7 +141,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
|
|
137
141
|
model_blob_metadata = model_blobs_metadata[name]
|
138
142
|
model_blob_filename = model_blob_metadata.path
|
139
143
|
with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f:
|
140
|
-
m = torch.jit.load( # type:ignore[attr-defined]
|
144
|
+
m = torch.jit.load( # type:ignore[no-untyped-call, attr-defined]
|
141
145
|
f, map_location="cuda" if kwargs.get("use_gpu", False) else "cpu"
|
142
146
|
)
|
143
147
|
assert isinstance(m, torch.jit.ScriptModule) # type:ignore[attr-defined]
|
@@ -152,6 +156,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
|
|
152
156
|
cls,
|
153
157
|
raw_model: "torch.jit.ScriptModule", # type:ignore[name-defined]
|
154
158
|
model_meta: model_meta_api.ModelMetadata,
|
159
|
+
background_data: Optional[pd.DataFrame] = None,
|
155
160
|
**kwargs: Unpack[model_types.TorchScriptLoadOptions],
|
156
161
|
) -> custom_model.CustomModel:
|
157
162
|
from snowflake.ml.model import custom_model
|
@@ -1,6 +1,6 @@
|
|
1
1
|
# mypy: disable-error-code="import"
|
2
|
-
import json
|
3
2
|
import os
|
3
|
+
import warnings
|
4
4
|
from typing import (
|
5
5
|
TYPE_CHECKING,
|
6
6
|
Any,
|
@@ -13,14 +13,20 @@ from typing import (
|
|
13
13
|
final,
|
14
14
|
)
|
15
15
|
|
16
|
+
import importlib_metadata
|
16
17
|
import numpy as np
|
17
18
|
import pandas as pd
|
19
|
+
from packaging import version
|
18
20
|
from typing_extensions import TypeGuard, Unpack
|
19
21
|
|
20
22
|
from snowflake.ml._internal import type_utils
|
21
23
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
22
24
|
from snowflake.ml.model._packager.model_env import model_env
|
23
|
-
from snowflake.ml.model._packager.model_handlers import
|
25
|
+
from snowflake.ml.model._packager.model_handlers import (
|
26
|
+
_base,
|
27
|
+
_utils as handlers_utils,
|
28
|
+
model_objective_utils,
|
29
|
+
)
|
24
30
|
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
25
31
|
from snowflake.ml.model._packager.model_meta import (
|
26
32
|
model_blob_meta,
|
@@ -45,41 +51,8 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
45
51
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
46
52
|
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
|
47
53
|
|
48
|
-
|
54
|
+
MODEL_BLOB_FILE_OR_DIR = "model.ubj"
|
49
55
|
DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
|
50
|
-
_BINARY_CLASSIFICATION_OBJECTIVE_PREFIX = ["binary:"]
|
51
|
-
_MULTI_CLASSIFICATION_OBJECTIVE_PREFIX = ["multi:"]
|
52
|
-
_RANKING_OBJECTIVE_PREFIX = ["rank:"]
|
53
|
-
_REGRESSION_OBJECTIVE_PREFIX = ["reg:"]
|
54
|
-
|
55
|
-
@classmethod
|
56
|
-
def get_model_objective(cls, model: Union["xgboost.Booster", "xgboost.XGBModel"]) -> _base.ModelObjective:
|
57
|
-
import xgboost
|
58
|
-
|
59
|
-
if isinstance(model, xgboost.XGBClassifier) or isinstance(model, xgboost.XGBRFClassifier):
|
60
|
-
num_classes = handlers_utils.get_num_classes_if_exists(model)
|
61
|
-
if num_classes == 2:
|
62
|
-
return _base.ModelObjective.BINARY_CLASSIFICATION
|
63
|
-
return _base.ModelObjective.MULTI_CLASSIFICATION
|
64
|
-
if isinstance(model, xgboost.XGBRegressor) or isinstance(model, xgboost.XGBRFRegressor):
|
65
|
-
return _base.ModelObjective.REGRESSION
|
66
|
-
if isinstance(model, xgboost.XGBRanker):
|
67
|
-
return _base.ModelObjective.RANKING
|
68
|
-
model_params = json.loads(model.save_config())
|
69
|
-
model_objective = model_params["learner"]["objective"]
|
70
|
-
for classification_objective in cls._BINARY_CLASSIFICATION_OBJECTIVE_PREFIX:
|
71
|
-
if classification_objective in model_objective:
|
72
|
-
return _base.ModelObjective.BINARY_CLASSIFICATION
|
73
|
-
for classification_objective in cls._MULTI_CLASSIFICATION_OBJECTIVE_PREFIX:
|
74
|
-
if classification_objective in model_objective:
|
75
|
-
return _base.ModelObjective.MULTI_CLASSIFICATION
|
76
|
-
for ranking_objective in cls._RANKING_OBJECTIVE_PREFIX:
|
77
|
-
if ranking_objective in model_objective:
|
78
|
-
return _base.ModelObjective.RANKING
|
79
|
-
for regression_objective in cls._REGRESSION_OBJECTIVE_PREFIX:
|
80
|
-
if regression_objective in model_objective:
|
81
|
-
return _base.ModelObjective.REGRESSION
|
82
|
-
return _base.ModelObjective.UNKNOWN
|
83
56
|
|
84
57
|
@classmethod
|
85
58
|
def can_handle(
|
@@ -114,10 +87,29 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
114
87
|
is_sub_model: Optional[bool] = False,
|
115
88
|
**kwargs: Unpack[model_types.XGBModelSaveOptions],
|
116
89
|
) -> None:
|
90
|
+
enable_explainability = kwargs.get("enable_explainability", True)
|
91
|
+
|
117
92
|
import xgboost
|
118
93
|
|
119
94
|
assert isinstance(model, xgboost.Booster) or isinstance(model, xgboost.XGBModel)
|
120
95
|
|
96
|
+
local_xgb_version = None
|
97
|
+
|
98
|
+
try:
|
99
|
+
local_dist = importlib_metadata.distribution("xgboost") # type: ignore[no-untyped-call]
|
100
|
+
local_xgb_version = version.parse(local_dist.version)
|
101
|
+
except importlib_metadata.PackageNotFoundError:
|
102
|
+
pass
|
103
|
+
|
104
|
+
if local_xgb_version and local_xgb_version >= version.parse("2.1.0") and enable_explainability:
|
105
|
+
warnings.warn(
|
106
|
+
f"This version of xgboost {local_xgb_version} does not work with shap 0.42.1."
|
107
|
+
+ "If you want model explanations, lower the xgboost version to <2.1.0.",
|
108
|
+
category=UserWarning,
|
109
|
+
stacklevel=1,
|
110
|
+
)
|
111
|
+
enable_explainability = False
|
112
|
+
|
121
113
|
if not is_sub_model:
|
122
114
|
target_methods = handlers_utils.get_target_methods(
|
123
115
|
model=model,
|
@@ -146,25 +138,29 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
146
138
|
sample_input_data=sample_input_data,
|
147
139
|
get_prediction_fn=get_prediction,
|
148
140
|
)
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
141
|
+
model_objective_and_output = model_objective_utils.get_model_objective_and_output_type(model)
|
142
|
+
model_meta.model_objective = handlers_utils.validate_model_objective(
|
143
|
+
model_meta.model_objective, model_objective_and_output.objective
|
144
|
+
)
|
145
|
+
if enable_explainability:
|
153
146
|
model_meta = handlers_utils.add_explain_method_signature(
|
154
147
|
model_meta=model_meta,
|
155
148
|
explain_method="explain",
|
156
149
|
target_method="predict",
|
157
|
-
output_return_type=output_type,
|
150
|
+
output_return_type=model_objective_and_output.output_type,
|
158
151
|
)
|
152
|
+
model_meta.function_properties = {
|
153
|
+
"explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
|
154
|
+
}
|
159
155
|
|
160
156
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
161
157
|
os.makedirs(model_blob_path, exist_ok=True)
|
162
|
-
model.save_model(os.path.join(model_blob_path, cls.
|
158
|
+
model.save_model(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
|
163
159
|
base_meta = model_blob_meta.ModelBlobMeta(
|
164
160
|
name=name,
|
165
161
|
model_type=cls.HANDLER_TYPE,
|
166
162
|
handler_version=cls.HANDLER_VERSION,
|
167
|
-
path=cls.
|
163
|
+
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
168
164
|
options=model_meta_schema.XgboostModelBlobOptions({"xgb_estimator_type": model.__class__.__name__}),
|
169
165
|
)
|
170
166
|
model_meta.models[name] = base_meta
|
@@ -173,15 +169,27 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
173
169
|
model_meta.env.include_if_absent(
|
174
170
|
[
|
175
171
|
model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn"),
|
176
|
-
model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
|
177
172
|
],
|
178
173
|
check_local_version=True,
|
179
174
|
)
|
180
|
-
if
|
175
|
+
if local_xgb_version and local_xgb_version >= version.parse("2.0.0") and enable_explainability:
|
181
176
|
model_meta.env.include_if_absent(
|
182
|
-
[
|
177
|
+
[
|
178
|
+
model_env.ModelDependency(requirement="xgboost==2.0.*", pip_name="xgboost"),
|
179
|
+
],
|
180
|
+
check_local_version=False,
|
181
|
+
)
|
182
|
+
else:
|
183
|
+
model_meta.env.include_if_absent(
|
184
|
+
[
|
185
|
+
model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
|
186
|
+
],
|
183
187
|
check_local_version=True,
|
184
188
|
)
|
189
|
+
|
190
|
+
if enable_explainability:
|
191
|
+
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
|
192
|
+
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
185
193
|
model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
|
186
194
|
|
187
195
|
@classmethod
|
@@ -224,6 +232,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
224
232
|
cls,
|
225
233
|
raw_model: Union["xgboost.Booster", "xgboost.XGBModel"],
|
226
234
|
model_meta: model_meta_api.ModelMetadata,
|
235
|
+
background_data: Optional[pd.DataFrame] = None,
|
227
236
|
**kwargs: Unpack[model_types.XGBModelLoadOptions],
|
228
237
|
) -> custom_model.CustomModel:
|
229
238
|
import xgboost
|
@@ -260,7 +269,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
260
269
|
import shap
|
261
270
|
|
262
271
|
explainer = shap.TreeExplainer(raw_model)
|
263
|
-
df =
|
272
|
+
df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
|
264
273
|
return model_signature_utils.rename_pandas_df(df, signature.outputs)
|
265
274
|
|
266
275
|
if target_method == "explain":
|
@@ -55,6 +55,7 @@ def create_model_metadata(
|
|
55
55
|
conda_dependencies: Optional[List[str]] = None,
|
56
56
|
pip_requirements: Optional[List[str]] = None,
|
57
57
|
python_version: Optional[str] = None,
|
58
|
+
model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN,
|
58
59
|
**kwargs: Any,
|
59
60
|
) -> Generator["ModelMetadata", None, None]:
|
60
61
|
"""Create a generator for model metadata object. Use generator to ensure correct register and unregister for
|
@@ -74,6 +75,9 @@ def create_model_metadata(
|
|
74
75
|
pip_requirements: List of pip Python packages requirements for running the model. Defaults to None.
|
75
76
|
python_version: A string of python version where model is run. Used for user override. If specified as None,
|
76
77
|
current version would be captured. Defaults to None.
|
78
|
+
model_objective: The objective of the Model Version. It is an enum class ModelObjective with values REGRESSION,
|
79
|
+
BINARY_CLASSIFICATION, MULTI_CLASSIFICATION, RANKING, or UNKNOWN. By default it is set to
|
80
|
+
ModelObjective.UNKNOWN and may be overridden by inferring from the Model Object.
|
77
81
|
**kwargs: Dict of attributes and values of the metadata. Used when loading from file.
|
78
82
|
|
79
83
|
Raises:
|
@@ -131,6 +135,7 @@ def create_model_metadata(
|
|
131
135
|
model_type=model_type,
|
132
136
|
signatures=signatures,
|
133
137
|
function_properties=function_properties,
|
138
|
+
model_objective=model_objective,
|
134
139
|
)
|
135
140
|
|
136
141
|
code_dir_path = os.path.join(model_dir_path, MODEL_CODE_DIR)
|
@@ -237,6 +242,7 @@ class ModelMetadata:
|
|
237
242
|
function_properties: A dict mapping function names to dict mapping function property key to value.
|
238
243
|
metadata: User provided key-value metadata of the model. Defaults to None.
|
239
244
|
creation_timestamp: Unix timestamp when the model metadata is created.
|
245
|
+
model_objective: Model objective like regression, classification etc.
|
240
246
|
"""
|
241
247
|
|
242
248
|
def telemetry_metadata(self) -> ModelMetadataTelemetryDict:
|
@@ -260,6 +266,8 @@ class ModelMetadata:
|
|
260
266
|
min_snowpark_ml_version: Optional[str] = None,
|
261
267
|
models: Optional[Dict[str, model_blob_meta.ModelBlobMeta]] = None,
|
262
268
|
original_metadata_version: Optional[str] = model_meta_schema.MODEL_METADATA_VERSION,
|
269
|
+
model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN,
|
270
|
+
explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = None,
|
263
271
|
) -> None:
|
264
272
|
self.name = name
|
265
273
|
self.signatures: Dict[str, model_signature.ModelSignature] = dict()
|
@@ -284,6 +292,9 @@ class ModelMetadata:
|
|
284
292
|
|
285
293
|
self.original_metadata_version = original_metadata_version
|
286
294
|
|
295
|
+
self.model_objective: model_types.ModelObjective = model_objective
|
296
|
+
self.explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = explain_algorithm
|
297
|
+
|
287
298
|
@property
|
288
299
|
def min_snowpark_ml_version(self) -> str:
|
289
300
|
return self._min_snowpark_ml_version.base_version
|
@@ -321,9 +332,11 @@ class ModelMetadata:
|
|
321
332
|
model_dict = model_meta_schema.ModelMetadataDict(
|
322
333
|
{
|
323
334
|
"creation_timestamp": self.creation_timestamp,
|
324
|
-
"env": self.env.save_as_dict(
|
335
|
+
"env": self.env.save_as_dict(
|
336
|
+
pathlib.Path(model_dir_path), default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
|
337
|
+
),
|
325
338
|
"runtimes": {
|
326
|
-
runtime_name: runtime.save(pathlib.Path(model_dir_path))
|
339
|
+
runtime_name: runtime.save(pathlib.Path(model_dir_path), default_channel_override="conda-forge")
|
327
340
|
for runtime_name, runtime in self.runtimes.items()
|
328
341
|
},
|
329
342
|
"metadata": self.metadata,
|
@@ -333,6 +346,13 @@ class ModelMetadata:
|
|
333
346
|
"signatures": {func_name: sig.to_dict() for func_name, sig in self.signatures.items()},
|
334
347
|
"version": model_meta_schema.MODEL_METADATA_VERSION,
|
335
348
|
"min_snowpark_ml_version": self.min_snowpark_ml_version,
|
349
|
+
"model_objective": self.model_objective.value,
|
350
|
+
"explainability": (
|
351
|
+
model_meta_schema.ExplainabilityMetadataDict(algorithm=self.explain_algorithm.value)
|
352
|
+
if self.explain_algorithm
|
353
|
+
else None
|
354
|
+
),
|
355
|
+
"function_properties": self.function_properties,
|
336
356
|
}
|
337
357
|
)
|
338
358
|
|
@@ -370,6 +390,9 @@ class ModelMetadata:
|
|
370
390
|
signatures=loaded_meta["signatures"],
|
371
391
|
version=original_loaded_meta_version,
|
372
392
|
min_snowpark_ml_version=loaded_meta_min_snowpark_ml_version,
|
393
|
+
model_objective=loaded_meta.get("model_objective", model_types.ModelObjective.UNKNOWN.value),
|
394
|
+
explainability=loaded_meta.get("explainability", None),
|
395
|
+
function_properties=loaded_meta.get("function_properties", {}),
|
373
396
|
)
|
374
397
|
|
375
398
|
@classmethod
|
@@ -406,6 +429,11 @@ class ModelMetadata:
|
|
406
429
|
else:
|
407
430
|
runtimes = None
|
408
431
|
|
432
|
+
explanation_algorithm_dict = model_dict.get("explainability", None)
|
433
|
+
explanation_algorithm = None
|
434
|
+
if explanation_algorithm_dict:
|
435
|
+
explanation_algorithm = model_meta_schema.ModelExplainAlgorithm(explanation_algorithm_dict["algorithm"])
|
436
|
+
|
409
437
|
return cls(
|
410
438
|
name=model_dict["name"],
|
411
439
|
model_type=model_dict["model_type"],
|
@@ -417,4 +445,9 @@ class ModelMetadata:
|
|
417
445
|
min_snowpark_ml_version=model_dict["min_snowpark_ml_version"],
|
418
446
|
models=models,
|
419
447
|
original_metadata_version=model_dict["version"],
|
448
|
+
model_objective=model_types.ModelObjective(
|
449
|
+
model_dict.get("model_objective", model_types.ModelObjective.UNKNOWN.value)
|
450
|
+
),
|
451
|
+
explain_algorithm=explanation_algorithm,
|
452
|
+
function_properties=model_dict.get("function_properties", {}),
|
420
453
|
)
|
@@ -71,6 +71,10 @@ ModelBlobOptions = Union[
|
|
71
71
|
]
|
72
72
|
|
73
73
|
|
74
|
+
class ExplainabilityMetadataDict(TypedDict):
|
75
|
+
algorithm: Required[str]
|
76
|
+
|
77
|
+
|
74
78
|
class ModelBlobMetadataDict(TypedDict):
|
75
79
|
name: Required[str]
|
76
80
|
model_type: Required[type_hints.SupportedModelHandlerType]
|
@@ -92,3 +96,10 @@ class ModelMetadataDict(TypedDict):
|
|
92
96
|
signatures: Required[Dict[str, Dict[str, Any]]]
|
93
97
|
version: Required[str]
|
94
98
|
min_snowpark_ml_version: Required[str]
|
99
|
+
model_objective: Required[str]
|
100
|
+
explainability: NotRequired[Optional[ExplainabilityMetadataDict]]
|
101
|
+
function_properties: NotRequired[Dict[str, Dict[str, Any]]]
|
102
|
+
|
103
|
+
|
104
|
+
class ModelExplainAlgorithm(Enum):
|
105
|
+
SHAP = "shap"
|
@@ -47,6 +47,7 @@ class ModelPackager:
|
|
47
47
|
ext_modules: Optional[List[ModuleType]] = None,
|
48
48
|
code_paths: Optional[List[str]] = None,
|
49
49
|
options: Optional[model_types.ModelSaveOption] = None,
|
50
|
+
model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN,
|
50
51
|
) -> model_meta.ModelMetadata:
|
51
52
|
if (signatures is None) and (sample_input_data is None) and not model_handler.is_auto_signature_model(model):
|
52
53
|
raise snowml_exceptions.SnowflakeMLException(
|
@@ -84,6 +85,7 @@ class ModelPackager:
|
|
84
85
|
conda_dependencies=conda_dependencies,
|
85
86
|
pip_requirements=pip_requirements,
|
86
87
|
python_version=python_version,
|
88
|
+
model_objective=model_objective,
|
87
89
|
**options,
|
88
90
|
) as meta:
|
89
91
|
model_blobs_path = os.path.join(self.local_dir_path, ModelPackager.MODEL_BLOBS_DIR)
|
@@ -146,7 +148,8 @@ class ModelPackager:
|
|
146
148
|
m = handler.load_model(self.meta.name, self.meta, model_blobs_path, **options)
|
147
149
|
|
148
150
|
if as_custom_model:
|
149
|
-
|
151
|
+
background_data = handler.load_background_data(self.meta.name, model_blobs_path)
|
152
|
+
m = handler.convert_as_custom_model(m, self.meta, background_data, **options)
|
150
153
|
assert isinstance(m, custom_model.CustomModel)
|
151
154
|
|
152
155
|
self.model = m
|
@@ -67,7 +67,9 @@ class ModelRuntime:
|
|
67
67
|
def runtime_rel_path(self) -> pathlib.PurePosixPath:
|
68
68
|
return pathlib.PurePosixPath(ModelRuntime.RUNTIME_DIR_REL_PATH) / self.name
|
69
69
|
|
70
|
-
def save(
|
70
|
+
def save(
|
71
|
+
self, packager_path: pathlib.Path, default_channel_override: str = env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
|
72
|
+
) -> model_meta_schema.ModelRuntimeDict:
|
71
73
|
runtime_base_path = packager_path / self.runtime_rel_path
|
72
74
|
runtime_base_path.mkdir(parents=True, exist_ok=True)
|
73
75
|
|
@@ -80,7 +82,7 @@ class ModelRuntime:
|
|
80
82
|
self.runtime_env.conda_env_rel_path = self.runtime_rel_path / self.runtime_env.conda_env_rel_path
|
81
83
|
self.runtime_env.pip_requirements_rel_path = self.runtime_rel_path / self.runtime_env.pip_requirements_rel_path
|
82
84
|
|
83
|
-
env_dict = self.runtime_env.save_as_dict(packager_path)
|
85
|
+
env_dict = self.runtime_env.save_as_dict(packager_path, default_channel_override=default_channel_override)
|
84
86
|
|
85
87
|
return model_meta_schema.ModelRuntimeDict(
|
86
88
|
imports=list(map(str, self.imports)),
|
@@ -30,7 +30,7 @@ class SeqOfPyTorchTensorHandler(base_handler.BaseDataHandler[Sequence["torch.Ten
|
|
30
30
|
|
31
31
|
@staticmethod
|
32
32
|
def count(data: Sequence["torch.Tensor"]) -> int:
|
33
|
-
return min(data_col.shape[0] for data_col in data)
|
33
|
+
return min(data_col.shape[0] for data_col in data) # type: ignore[no-any-return]
|
34
34
|
|
35
35
|
@staticmethod
|
36
36
|
def truncate(data: Sequence["torch.Tensor"]) -> Sequence["torch.Tensor"]:
|