snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +7 -33
- snowflake/ml/_internal/env_utils.py +11 -5
- snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
- snowflake/ml/_internal/telemetry.py +156 -20
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +1 -1
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/_internal/utils/uri.py +2 -2
- snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
- snowflake/ml/data/data_connector.py +88 -9
- snowflake/ml/data/data_ingestor.py +18 -1
- snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/dataset/dataset_metadata.py +3 -1
- snowflake/ml/dataset/dataset_reader.py +9 -3
- snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
- snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
- snowflake/ml/feature_store/examples/example_helper.py +69 -31
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
- snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
- snowflake/ml/feature_store/feature_store.py +100 -41
- snowflake/ml/feature_store/feature_view.py +149 -5
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/model/_client/model/model_impl.py +11 -2
- snowflake/ml/model/_client/model/model_version_impl.py +186 -20
- snowflake/ml/model/_client/ops/model_ops.py +144 -30
- snowflake/ml/model/_client/ops/service_ops.py +312 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
- snowflake/ml/model/_client/sql/model_version.py +13 -4
- snowflake/ml/model/_client/sql/service.py +196 -0
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
- snowflake/ml/model/_model_composer/model_composer.py +5 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
- snowflake/ml/model/_packager/model_env/model_env.py +7 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
- snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
- snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
- snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
- snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
- snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
- snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
- snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
- snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
- snowflake/ml/model/_packager/model_packager.py +4 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/models/llm.py +3 -1
- snowflake/ml/model/type_hints.py +10 -4
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
- snowflake/ml/modeling/cluster/birch.py +60 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
- snowflake/ml/modeling/cluster/dbscan.py +60 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
- snowflake/ml/modeling/cluster/k_means.py +60 -21
- snowflake/ml/modeling/cluster/mean_shift.py +60 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
- snowflake/ml/modeling/cluster/optics.py +60 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
- snowflake/ml/modeling/compose/column_transformer.py +60 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
- snowflake/ml/modeling/covariance/oas.py +60 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/pca.py +60 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
- snowflake/ml/modeling/framework/base.py +28 -19
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
- snowflake/ml/modeling/impute/knn_imputer.py +60 -21
- snowflake/ml/modeling/impute/missing_indicator.py +60 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/lars.py +60 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/perceptron.py +60 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ridge.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
- snowflake/ml/modeling/manifold/isomap.py +60 -21
- snowflake/ml/modeling/manifold/mds.py +60 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
- snowflake/ml/modeling/manifold/tsne.py +60 -21
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +4 -12
- snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
- snowflake/ml/modeling/svm/linear_svc.py +60 -21
- snowflake/ml/modeling/svm/linear_svr.py +60 -21
- snowflake/ml/modeling/svm/nu_svc.py +60 -21
- snowflake/ml/modeling/svm/nu_svr.py +60 -21
- snowflake/ml/modeling/svm/svc.py +60 -21
- snowflake/ml/modeling/svm/svr.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
- snowflake/ml/registry/_manager/model_manager.py +20 -2
- snowflake/ml/registry/model_registry.py +1 -1
- snowflake/ml/registry/registry.py +1 -2
- snowflake/ml/utils/sql_client.py +22 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -19,7 +19,11 @@ from typing_extensions import TypeGuard, Unpack
|
|
19
19
|
from snowflake.ml._internal import type_utils
|
20
20
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
21
21
|
from snowflake.ml.model._packager.model_env import model_env
|
22
|
-
from snowflake.ml.model._packager.model_handlers import
|
22
|
+
from snowflake.ml.model._packager.model_handlers import (
|
23
|
+
_base,
|
24
|
+
_utils as handlers_utils,
|
25
|
+
model_objective_utils,
|
26
|
+
)
|
23
27
|
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
24
28
|
from snowflake.ml.model._packager.model_meta import (
|
25
29
|
model_blob_meta,
|
@@ -41,47 +45,8 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
41
45
|
_MIN_SNOWPARK_ML_VERSION = "1.3.1"
|
42
46
|
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
|
43
47
|
|
44
|
-
|
48
|
+
MODEL_BLOB_FILE_OR_DIR = "model.pkl"
|
45
49
|
DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
|
46
|
-
_BINARY_CLASSIFICATION_OBJECTIVES = ["binary"]
|
47
|
-
_MULTI_CLASSIFICATION_OBJECTIVES = ["multiclass", "multiclassova"]
|
48
|
-
_RANKING_OBJECTIVES = ["lambdarank", "rank_xendcg"]
|
49
|
-
_REGRESSION_OBJECTIVES = [
|
50
|
-
"regression",
|
51
|
-
"regression_l1",
|
52
|
-
"huber",
|
53
|
-
"fair",
|
54
|
-
"poisson",
|
55
|
-
"quantile",
|
56
|
-
"tweedie",
|
57
|
-
"mape",
|
58
|
-
"gamma",
|
59
|
-
]
|
60
|
-
|
61
|
-
@classmethod
|
62
|
-
def get_model_objective(cls, model: Union["lightgbm.Booster", "lightgbm.LGBMModel"]) -> _base.ModelObjective:
|
63
|
-
import lightgbm
|
64
|
-
|
65
|
-
# does not account for cross-entropy and custom
|
66
|
-
if isinstance(model, lightgbm.LGBMClassifier):
|
67
|
-
num_classes = handlers_utils.get_num_classes_if_exists(model)
|
68
|
-
if num_classes == 2:
|
69
|
-
return _base.ModelObjective.BINARY_CLASSIFICATION
|
70
|
-
return _base.ModelObjective.MULTI_CLASSIFICATION
|
71
|
-
if isinstance(model, lightgbm.LGBMRanker):
|
72
|
-
return _base.ModelObjective.RANKING
|
73
|
-
if isinstance(model, lightgbm.LGBMRegressor):
|
74
|
-
return _base.ModelObjective.REGRESSION
|
75
|
-
model_objective = model.params["objective"]
|
76
|
-
if model_objective in cls._BINARY_CLASSIFICATION_OBJECTIVES:
|
77
|
-
return _base.ModelObjective.BINARY_CLASSIFICATION
|
78
|
-
if model_objective in cls._MULTI_CLASSIFICATION_OBJECTIVES:
|
79
|
-
return _base.ModelObjective.MULTI_CLASSIFICATION
|
80
|
-
if model_objective in cls._RANKING_OBJECTIVES:
|
81
|
-
return _base.ModelObjective.RANKING
|
82
|
-
if model_objective in cls._REGRESSION_OBJECTIVES:
|
83
|
-
return _base.ModelObjective.REGRESSION
|
84
|
-
return _base.ModelObjective.UNKNOWN
|
85
50
|
|
86
51
|
@classmethod
|
87
52
|
def can_handle(
|
@@ -116,6 +81,8 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
116
81
|
is_sub_model: Optional[bool] = False,
|
117
82
|
**kwargs: Unpack[model_types.LGBMModelSaveOptions],
|
118
83
|
) -> None:
|
84
|
+
enable_explainability = kwargs.get("enable_explainability", True)
|
85
|
+
|
119
86
|
import lightgbm
|
120
87
|
|
121
88
|
assert isinstance(model, lightgbm.Booster) or isinstance(model, lightgbm.LGBMModel)
|
@@ -144,24 +111,25 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
144
111
|
sample_input_data=sample_input_data,
|
145
112
|
get_prediction_fn=get_prediction,
|
146
113
|
)
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
]:
|
153
|
-
output_type = model_signature.DataType.STRING
|
114
|
+
model_objective_and_output = model_objective_utils.get_model_objective_and_output_type(model)
|
115
|
+
model_meta.model_objective = handlers_utils.validate_model_objective(
|
116
|
+
model_meta.model_objective, model_objective_and_output.objective
|
117
|
+
)
|
118
|
+
if enable_explainability:
|
154
119
|
model_meta = handlers_utils.add_explain_method_signature(
|
155
120
|
model_meta=model_meta,
|
156
121
|
explain_method="explain",
|
157
122
|
target_method="predict",
|
158
|
-
output_return_type=output_type,
|
123
|
+
output_return_type=model_objective_and_output.output_type,
|
159
124
|
)
|
125
|
+
model_meta.function_properties = {
|
126
|
+
"explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
|
127
|
+
}
|
160
128
|
|
161
129
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
162
130
|
os.makedirs(model_blob_path, exist_ok=True)
|
163
131
|
|
164
|
-
model_save_path = os.path.join(model_blob_path, cls.
|
132
|
+
model_save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
|
165
133
|
with open(model_save_path, "wb") as f:
|
166
134
|
cloudpickle.dump(model, f)
|
167
135
|
|
@@ -169,7 +137,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
169
137
|
name=name,
|
170
138
|
model_type=cls.HANDLER_TYPE,
|
171
139
|
handler_version=cls.HANDLER_VERSION,
|
172
|
-
path=cls.
|
140
|
+
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
173
141
|
options=model_meta_schema.LightGBMModelBlobOptions({"lightgbm_estimator_type": model.__class__.__name__}),
|
174
142
|
)
|
175
143
|
model_meta.models[name] = base_meta
|
@@ -182,11 +150,9 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
182
150
|
],
|
183
151
|
check_local_version=True,
|
184
152
|
)
|
185
|
-
if
|
186
|
-
model_meta.env.include_if_absent(
|
187
|
-
|
188
|
-
check_local_version=True,
|
189
|
-
)
|
153
|
+
if enable_explainability:
|
154
|
+
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
|
155
|
+
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
190
156
|
|
191
157
|
return None
|
192
158
|
|
@@ -226,6 +192,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
226
192
|
cls,
|
227
193
|
raw_model: Union["lightgbm.Booster", "lightgbm.XGBModel"],
|
228
194
|
model_meta: model_meta_api.ModelMetadata,
|
195
|
+
background_data: Optional[pd.DataFrame] = None,
|
229
196
|
**kwargs: Unpack[model_types.LGBMModelLoadOptions],
|
230
197
|
) -> custom_model.CustomModel:
|
231
198
|
import lightgbm
|
@@ -28,7 +28,7 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
|
|
28
28
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
29
29
|
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
|
30
30
|
|
31
|
-
|
31
|
+
MODEL_BLOB_FILE_OR_DIR = "model"
|
32
32
|
LLM_META = "llm_meta"
|
33
33
|
IS_AUTO_SIGNATURE = True
|
34
34
|
|
@@ -59,9 +59,12 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
|
|
59
59
|
**kwargs: Unpack[model_types.LLMSaveOptions],
|
60
60
|
) -> None:
|
61
61
|
assert not is_sub_model, "LLM can not be sub-model."
|
62
|
+
enable_explainability = kwargs.get("enable_explainability", False)
|
63
|
+
if enable_explainability:
|
64
|
+
raise NotImplementedError("Explainability is not supported for llm model.")
|
62
65
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
63
66
|
os.makedirs(model_blob_path, exist_ok=True)
|
64
|
-
model_blob_dir_path = os.path.join(model_blob_path, cls.
|
67
|
+
model_blob_dir_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
|
65
68
|
|
66
69
|
sig = model_signature.ModelSignature(
|
67
70
|
inputs=[
|
@@ -86,7 +89,7 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
|
|
86
89
|
name=name,
|
87
90
|
model_type=cls.HANDLER_TYPE,
|
88
91
|
handler_version=cls.HANDLER_VERSION,
|
89
|
-
path=cls.
|
92
|
+
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
90
93
|
options=model_meta_schema.LLMModelBlobOptions(
|
91
94
|
{
|
92
95
|
"batch_size": model.max_batch_size,
|
@@ -143,6 +146,7 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
|
|
143
146
|
cls,
|
144
147
|
raw_model: llm.LLM,
|
145
148
|
model_meta: model_meta_api.ModelMetadata,
|
149
|
+
background_data: Optional[pd.DataFrame] = None,
|
146
150
|
**kwargs: Unpack[model_types.LLMLoadOptions],
|
147
151
|
) -> custom_model.CustomModel:
|
148
152
|
import gc
|
@@ -201,7 +205,9 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
|
|
201
205
|
"token": raw_model.token,
|
202
206
|
}
|
203
207
|
model_dir_path = raw_model.model_id_or_path
|
204
|
-
peft_config = peft.PeftConfig.from_pretrained(
|
208
|
+
peft_config = peft.PeftConfig.from_pretrained( # type: ignore[no-untyped-call, attr-defined]
|
209
|
+
model_dir_path
|
210
|
+
)
|
205
211
|
base_model_path = peft_config.base_model_name_or_path
|
206
212
|
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
207
213
|
base_model_path,
|
@@ -217,7 +223,7 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
|
|
217
223
|
model_dir_path,
|
218
224
|
device_map="auto",
|
219
225
|
torch_dtype="auto",
|
220
|
-
**hub_kwargs,
|
226
|
+
**hub_kwargs, # type: ignore[arg-type]
|
221
227
|
)
|
222
228
|
hf_model.eval()
|
223
229
|
hf_model = hf_model.merge_and_unload()
|
@@ -63,7 +63,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
|
|
63
63
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
64
64
|
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
|
65
65
|
|
66
|
-
|
66
|
+
MODEL_BLOB_FILE_OR_DIR = "model"
|
67
67
|
_DEFAULT_TARGET_METHOD = "predict"
|
68
68
|
DEFAULT_TARGET_METHODS = [_DEFAULT_TARGET_METHOD]
|
69
69
|
IS_AUTO_SIGNATURE = True
|
@@ -97,6 +97,10 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
|
|
97
97
|
is_sub_model: Optional[bool] = False,
|
98
98
|
**kwargs: Unpack[model_types.MLFlowSaveOptions],
|
99
99
|
) -> None:
|
100
|
+
enable_explainability = kwargs.get("enable_explainability", False)
|
101
|
+
if enable_explainability:
|
102
|
+
raise NotImplementedError("Explainability is not supported for MLFlow model.")
|
103
|
+
|
100
104
|
import mlflow
|
101
105
|
|
102
106
|
assert isinstance(model, mlflow.pyfunc.PyFuncModel)
|
@@ -142,13 +146,13 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
|
|
142
146
|
except (mlflow.MlflowException, OSError):
|
143
147
|
raise ValueError("Cannot load MLFlow model artifacts.")
|
144
148
|
|
145
|
-
file_utils.copy_file_or_tree(local_path, os.path.join(model_blob_path, cls.
|
149
|
+
file_utils.copy_file_or_tree(local_path, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
|
146
150
|
|
147
151
|
base_meta = model_blob_meta.ModelBlobMeta(
|
148
152
|
name=name,
|
149
153
|
model_type=cls.HANDLER_TYPE,
|
150
154
|
handler_version=cls.HANDLER_VERSION,
|
151
|
-
path=cls.
|
155
|
+
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
152
156
|
options=model_meta_schema.MLFlowModelBlobOptions({"artifact_path": model_info.artifact_path}),
|
153
157
|
)
|
154
158
|
model_meta.models[name] = base_meta
|
@@ -194,6 +198,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
|
|
194
198
|
cls,
|
195
199
|
raw_model: "mlflow.pyfunc.PyFuncModel",
|
196
200
|
model_meta: model_meta_api.ModelMetadata,
|
201
|
+
background_data: Optional[pd.DataFrame] = None,
|
197
202
|
**kwargs: Unpack[model_types.MLFlowLoadOptions],
|
198
203
|
) -> custom_model.CustomModel:
|
199
204
|
from snowflake.ml.model import custom_model
|
@@ -0,0 +1,116 @@
|
|
1
|
+
import json
|
2
|
+
from dataclasses import dataclass
|
3
|
+
from typing import TYPE_CHECKING, Any, Union
|
4
|
+
|
5
|
+
from snowflake.ml.model import model_signature, type_hints
|
6
|
+
from snowflake.ml.model._packager.model_handlers import _utils as handlers_utils
|
7
|
+
|
8
|
+
if TYPE_CHECKING:
|
9
|
+
import lightgbm
|
10
|
+
import xgboost
|
11
|
+
|
12
|
+
|
13
|
+
@dataclass
|
14
|
+
class ModelObjectiveAndOutputType:
|
15
|
+
objective: type_hints.ModelObjective
|
16
|
+
output_type: model_signature.DataType
|
17
|
+
|
18
|
+
|
19
|
+
def get_model_objective_lightgbm(model: Union["lightgbm.Booster", "lightgbm.LGBMModel"]) -> type_hints.ModelObjective:
|
20
|
+
|
21
|
+
import lightgbm
|
22
|
+
|
23
|
+
_BINARY_CLASSIFICATION_OBJECTIVES = ["binary"]
|
24
|
+
_MULTI_CLASSIFICATION_OBJECTIVES = ["multiclass", "multiclassova"]
|
25
|
+
_RANKING_OBJECTIVES = ["lambdarank", "rank_xendcg"]
|
26
|
+
_REGRESSION_OBJECTIVES = [
|
27
|
+
"regression",
|
28
|
+
"regression_l1",
|
29
|
+
"huber",
|
30
|
+
"fair",
|
31
|
+
"poisson",
|
32
|
+
"quantile",
|
33
|
+
"tweedie",
|
34
|
+
"mape",
|
35
|
+
"gamma",
|
36
|
+
]
|
37
|
+
|
38
|
+
# does not account for cross-entropy and custom
|
39
|
+
if isinstance(model, lightgbm.LGBMClassifier):
|
40
|
+
num_classes = handlers_utils.get_num_classes_if_exists(model)
|
41
|
+
if num_classes == 2:
|
42
|
+
return type_hints.ModelObjective.BINARY_CLASSIFICATION
|
43
|
+
return type_hints.ModelObjective.MULTI_CLASSIFICATION
|
44
|
+
if isinstance(model, lightgbm.LGBMRanker):
|
45
|
+
return type_hints.ModelObjective.RANKING
|
46
|
+
if isinstance(model, lightgbm.LGBMRegressor):
|
47
|
+
return type_hints.ModelObjective.REGRESSION
|
48
|
+
model_objective = model.params["objective"]
|
49
|
+
if model_objective in _BINARY_CLASSIFICATION_OBJECTIVES:
|
50
|
+
return type_hints.ModelObjective.BINARY_CLASSIFICATION
|
51
|
+
if model_objective in _MULTI_CLASSIFICATION_OBJECTIVES:
|
52
|
+
return type_hints.ModelObjective.MULTI_CLASSIFICATION
|
53
|
+
if model_objective in _RANKING_OBJECTIVES:
|
54
|
+
return type_hints.ModelObjective.RANKING
|
55
|
+
if model_objective in _REGRESSION_OBJECTIVES:
|
56
|
+
return type_hints.ModelObjective.REGRESSION
|
57
|
+
return type_hints.ModelObjective.UNKNOWN
|
58
|
+
|
59
|
+
|
60
|
+
def get_model_objective_xgb(model: Union["xgboost.Booster", "xgboost.XGBModel"]) -> type_hints.ModelObjective:
|
61
|
+
|
62
|
+
import xgboost
|
63
|
+
|
64
|
+
_BINARY_CLASSIFICATION_OBJECTIVE_PREFIX = ["binary:"]
|
65
|
+
_MULTI_CLASSIFICATION_OBJECTIVE_PREFIX = ["multi:"]
|
66
|
+
_RANKING_OBJECTIVE_PREFIX = ["rank:"]
|
67
|
+
_REGRESSION_OBJECTIVE_PREFIX = ["reg:"]
|
68
|
+
|
69
|
+
model_objective = ""
|
70
|
+
if isinstance(model, xgboost.Booster):
|
71
|
+
model_params = json.loads(model.save_config())
|
72
|
+
model_objective = model_params.get("learner", {}).get("objective", "")
|
73
|
+
else:
|
74
|
+
if hasattr(model, "get_params"):
|
75
|
+
model_objective = model.get_params().get("objective", "")
|
76
|
+
|
77
|
+
if isinstance(model_objective, dict):
|
78
|
+
model_objective = model_objective.get("name", "")
|
79
|
+
for classification_objective in _BINARY_CLASSIFICATION_OBJECTIVE_PREFIX:
|
80
|
+
if classification_objective in model_objective:
|
81
|
+
return type_hints.ModelObjective.BINARY_CLASSIFICATION
|
82
|
+
for classification_objective in _MULTI_CLASSIFICATION_OBJECTIVE_PREFIX:
|
83
|
+
if classification_objective in model_objective:
|
84
|
+
return type_hints.ModelObjective.MULTI_CLASSIFICATION
|
85
|
+
for ranking_objective in _RANKING_OBJECTIVE_PREFIX:
|
86
|
+
if ranking_objective in model_objective:
|
87
|
+
return type_hints.ModelObjective.RANKING
|
88
|
+
for regression_objective in _REGRESSION_OBJECTIVE_PREFIX:
|
89
|
+
if regression_objective in model_objective:
|
90
|
+
return type_hints.ModelObjective.REGRESSION
|
91
|
+
return type_hints.ModelObjective.UNKNOWN
|
92
|
+
|
93
|
+
|
94
|
+
def get_model_objective_and_output_type(model: Any) -> ModelObjectiveAndOutputType:
    """Return the model objective and prediction output type for a supported model.

    XGBoost models are checked first, then LightGBM models. Each library is
    imported lazily and treated as optional, so a model from one library can be
    handled when the other library is not installed.

    The output type defaults to ``DOUBLE``. It is ``STRING`` for XGBoost
    multi-class classification, and for LightGBM binary or multi-class
    classification.

    Args:
        model: The model object to inspect.

    Returns:
        A ``ModelObjectiveAndOutputType`` holding the detected objective and the
        output data type.

    Raises:
        ValueError: If the model is neither an XGBoost nor a LightGBM model (or the
            matching library is not importable).
    """
    try:
        import xgboost

        is_xgboost_model = isinstance(model, (xgboost.Booster, xgboost.XGBModel))
    except ImportError:
        # Without xgboost installed the model cannot be an XGBoost model.
        is_xgboost_model = False

    if is_xgboost_model:
        model_objective = get_model_objective_xgb(model)
        output_type = model_signature.DataType.DOUBLE
        if model_objective == type_hints.ModelObjective.MULTI_CLASSIFICATION:
            output_type = model_signature.DataType.STRING
        return ModelObjectiveAndOutputType(objective=model_objective, output_type=output_type)

    try:
        import lightgbm

        is_lightgbm_model = isinstance(model, (lightgbm.Booster, lightgbm.LGBMModel))
    except ImportError:
        # Without lightgbm installed the model cannot be a LightGBM model.
        is_lightgbm_model = False

    if is_lightgbm_model:
        model_objective = get_model_objective_lightgbm(model)
        output_type = model_signature.DataType.DOUBLE
        if model_objective in (
            type_hints.ModelObjective.BINARY_CLASSIFICATION,
            type_hints.ModelObjective.MULTI_CLASSIFICATION,
        ):
            output_type = model_signature.DataType.STRING
        return ModelObjectiveAndOutputType(objective=model_objective, output_type=output_type)

    raise ValueError(f"Model type {type(model)} is not supported")
|
@@ -37,7 +37,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
37
37
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
38
38
|
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
|
39
39
|
|
40
|
-
|
40
|
+
MODEL_BLOB_FILE_OR_DIR = "model.pt"
|
41
41
|
DEFAULT_TARGET_METHODS = ["forward"]
|
42
42
|
|
43
43
|
@classmethod
|
@@ -73,6 +73,10 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
73
73
|
is_sub_model: Optional[bool] = False,
|
74
74
|
**kwargs: Unpack[model_types.PyTorchSaveOptions],
|
75
75
|
) -> None:
|
76
|
+
enable_explainability = kwargs.get("enable_explainability", False)
|
77
|
+
if enable_explainability:
|
78
|
+
raise NotImplementedError("Explainability is not supported for PyTorch model.")
|
79
|
+
|
76
80
|
import torch
|
77
81
|
|
78
82
|
assert isinstance(model, torch.nn.Module)
|
@@ -115,13 +119,13 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
115
119
|
cloudpickle.register_pickle_by_value(sys.modules[model.__module__])
|
116
120
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
117
121
|
os.makedirs(model_blob_path, exist_ok=True)
|
118
|
-
with open(os.path.join(model_blob_path, cls.
|
122
|
+
with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
|
119
123
|
torch.save(model, f, pickle_module=cloudpickle)
|
120
124
|
base_meta = model_blob_meta.ModelBlobMeta(
|
121
125
|
name=name,
|
122
126
|
model_type=cls.HANDLER_TYPE,
|
123
127
|
handler_version=cls.HANDLER_VERSION,
|
124
|
-
path=cls.
|
128
|
+
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
125
129
|
)
|
126
130
|
model_meta.models[name] = base_meta
|
127
131
|
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
@@ -156,6 +160,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
156
160
|
cls,
|
157
161
|
raw_model: "torch.nn.Module",
|
158
162
|
model_meta: model_meta_api.ModelMetadata,
|
163
|
+
background_data: Optional[pd.DataFrame] = None,
|
159
164
|
**kwargs: Unpack[model_types.PyTorchLoadOptions],
|
160
165
|
) -> custom_model.CustomModel:
|
161
166
|
import torch
|
@@ -31,7 +31,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
31
31
|
_MIN_SNOWPARK_ML_VERSION = "1.3.1"
|
32
32
|
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
|
33
33
|
|
34
|
-
|
34
|
+
MODEL_BLOB_FILE_OR_DIR = "model"
|
35
35
|
DEFAULT_TARGET_METHODS = ["encode"]
|
36
36
|
|
37
37
|
@classmethod
|
@@ -64,6 +64,10 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
64
64
|
is_sub_model: Optional[bool] = False,
|
65
65
|
**kwargs: Unpack[model_types.SentenceTransformersSaveOptions], # registry.log_model(options={...})
|
66
66
|
) -> None:
|
67
|
+
enable_explainability = kwargs.get("enable_explainability", False)
|
68
|
+
if enable_explainability:
|
69
|
+
raise NotImplementedError("Explainability is not supported for Sentence Transformer model.")
|
70
|
+
|
67
71
|
# Validate target methods and signature (if possible)
|
68
72
|
if not is_sub_model:
|
69
73
|
target_methods = handlers_utils.get_target_methods(
|
@@ -101,14 +105,14 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
101
105
|
# save model
|
102
106
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
103
107
|
os.makedirs(model_blob_path, exist_ok=True)
|
104
|
-
model.save(os.path.join(model_blob_path, cls.
|
108
|
+
model.save(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
|
105
109
|
|
106
110
|
# save model metadata
|
107
111
|
base_meta = model_blob_meta.ModelBlobMeta(
|
108
112
|
name=name,
|
109
113
|
model_type=cls.HANDLER_TYPE,
|
110
114
|
handler_version=cls.HANDLER_VERSION,
|
111
|
-
path=cls.
|
115
|
+
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
112
116
|
)
|
113
117
|
model_meta.models[name] = base_meta
|
114
118
|
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
@@ -154,6 +158,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
154
158
|
cls,
|
155
159
|
raw_model: "sentence_transformers.SentenceTransformer",
|
156
160
|
model_meta: model_meta_api.ModelMetadata,
|
161
|
+
background_data: Optional[pd.DataFrame] = None,
|
157
162
|
**kwargs: Unpack[model_types.SentenceTransformersLoadOptions],
|
158
163
|
) -> custom_model.CustomModel:
|
159
164
|
import sentence_transformers
|
@@ -6,6 +6,7 @@ import numpy as np
|
|
6
6
|
import pandas as pd
|
7
7
|
from typing_extensions import TypeGuard, Unpack
|
8
8
|
|
9
|
+
import snowflake.snowpark.dataframe as sp_df
|
9
10
|
from snowflake.ml._internal import type_utils
|
10
11
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
11
12
|
from snowflake.ml.model._packager.model_env import model_env
|
@@ -14,8 +15,13 @@ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
|
14
15
|
from snowflake.ml.model._packager.model_meta import (
|
15
16
|
model_blob_meta,
|
16
17
|
model_meta as model_meta_api,
|
18
|
+
model_meta_schema,
|
19
|
+
)
|
20
|
+
from snowflake.ml.model._signatures import (
|
21
|
+
numpy_handler,
|
22
|
+
snowpark_handler,
|
23
|
+
utils as model_signature_utils,
|
17
24
|
)
|
18
|
-
from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
|
19
25
|
|
20
26
|
if TYPE_CHECKING:
|
21
27
|
import sklearn.base
|
@@ -36,6 +42,27 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
36
42
|
|
37
43
|
DEFAULT_TARGET_METHODS = ["predict", "transform", "predict_proba", "predict_log_proba", "decision_function"]
|
38
44
|
|
45
|
+
@classmethod
def get_model_objective(
    cls, model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]
) -> model_types.ModelObjective:
    """Infer the model objective of a scikit-learn estimator.

    Pipelines are always reported as ``UNKNOWN``. Regressors map to
    ``REGRESSION``. Classifiers map to ``MULTI_CLASSIFICATION`` when they have
    more than two classes and to ``BINARY_CLASSIFICATION`` otherwise; the class
    count comes from ``n_classes_`` and falls back to ``len(classes_)`` when
    that attribute is missing or zero.

    Args:
        model: The estimator or pipeline to inspect.

    Returns:
        The detected objective, or ``ModelObjective.UNKNOWN`` when the model is
        neither a regressor nor a classifier, or the class count is not a single
        integer.
    """
    import numbers

    import sklearn.pipeline
    from sklearn.base import is_classifier, is_regressor

    if isinstance(model, sklearn.pipeline.Pipeline):
        return model_types.ModelObjective.UNKNOWN
    if is_regressor(model):
        return model_types.ModelObjective.REGRESSION
    if not is_classifier(model):
        return model_types.ModelObjective.UNKNOWN

    num_classes = getattr(model, "n_classes_", None)
    # Compare explicitly instead of relying on truthiness: a multi-element
    # array-like ``n_classes_`` cannot be evaluated as a boolean.
    if num_classes is None or (isinstance(num_classes, numbers.Integral) and num_classes == 0):
        num_classes = len(getattr(model, "classes_", []))

    # ``numbers.Integral`` also covers NumPy integer scalars, which a plain
    # ``isinstance(..., int)`` check rejects.
    if not isinstance(num_classes, numbers.Integral):
        return model_types.ModelObjective.UNKNOWN
    if num_classes > 2:
        return model_types.ModelObjective.MULTI_CLASSIFICATION
    return model_types.ModelObjective.BINARY_CLASSIFICATION
|
65
|
+
|
39
66
|
@classmethod
|
40
67
|
def can_handle(
|
41
68
|
cls,
|
@@ -68,6 +95,18 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
68
95
|
|
69
96
|
return cast(Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"], model)
|
70
97
|
|
98
|
+
@staticmethod
def get_explainability_supported_background(
    sample_input_data: Optional[model_types.SupportedDataType] = None,
) -> Optional[pd.DataFrame]:
    """Return the sample input as a pandas DataFrame usable as background data.

    Args:
        sample_input_data: Sample input supplied when saving the model.

    Returns:
        The input itself when it is a pandas DataFrame, the result of
        ``SnowparkDataFrameHandler.convert_to_df`` when it is a Snowpark
        DataFrame, and ``None`` for any other input (including ``None``).
    """
    # pandas input is already in the required form.
    if isinstance(sample_input_data, pd.DataFrame):
        return sample_input_data
    # Snowpark input goes through the Snowpark handler's conversion.
    if isinstance(sample_input_data, sp_df.DataFrame):
        return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(sample_input_data)
    return None
|
109
|
+
|
71
110
|
@classmethod
|
72
111
|
def save_model(
|
73
112
|
cls,
|
@@ -79,11 +118,31 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
79
118
|
is_sub_model: Optional[bool] = False,
|
80
119
|
**kwargs: Unpack[model_types.SKLModelSaveOptions],
|
81
120
|
) -> None:
|
121
|
+
# setting None by default to distinguish if users did not set it
|
122
|
+
enable_explainability = kwargs.get("enable_explainability", None)
|
123
|
+
|
82
124
|
import sklearn.base
|
83
125
|
import sklearn.pipeline
|
84
126
|
|
85
127
|
assert isinstance(model, sklearn.base.BaseEstimator) or isinstance(model, sklearn.pipeline.Pipeline)
|
86
128
|
|
129
|
+
background_data = cls.get_explainability_supported_background(sample_input_data)
|
130
|
+
|
131
|
+
# if users did not ask then we enable if we have background data
|
132
|
+
if enable_explainability is None and background_data is not None:
|
133
|
+
enable_explainability = True
|
134
|
+
if enable_explainability:
|
135
|
+
# if users set it explicitly but no background data then error out
|
136
|
+
if background_data is None:
|
137
|
+
raise ValueError(
|
138
|
+
"Sample input data is required to enable explainability. Currently we only support this for "
|
139
|
+
+ "`pandas.DataFrame` and `snowflake.snowpark.dataframe.DataFrame`."
|
140
|
+
)
|
141
|
+
data_blob_path = os.path.join(model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR)
|
142
|
+
os.makedirs(data_blob_path, exist_ok=True)
|
143
|
+
with open(os.path.join(data_blob_path, name + cls.BG_DATA_FILE_SUFFIX), "wb") as f:
|
144
|
+
background_data.to_parquet(f)
|
145
|
+
|
87
146
|
if not is_sub_model:
|
88
147
|
target_methods = handlers_utils.get_target_methods(
|
89
148
|
model=model,
|
@@ -110,19 +169,38 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
110
169
|
get_prediction_fn=get_prediction,
|
111
170
|
)
|
112
171
|
|
172
|
+
model_objective = cls.get_model_objective(model)
|
173
|
+
model_meta.model_objective = model_objective
|
174
|
+
|
175
|
+
if enable_explainability:
|
176
|
+
output_type = model_signature.DataType.DOUBLE
|
177
|
+
|
178
|
+
if model_objective == model_types.ModelObjective.MULTI_CLASSIFICATION:
|
179
|
+
output_type = model_signature.DataType.STRING
|
180
|
+
model_meta = handlers_utils.add_explain_method_signature(
|
181
|
+
model_meta=model_meta,
|
182
|
+
explain_method="explain",
|
183
|
+
target_method="predict",
|
184
|
+
output_return_type=output_type,
|
185
|
+
)
|
186
|
+
|
113
187
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
114
188
|
os.makedirs(model_blob_path, exist_ok=True)
|
115
|
-
with open(os.path.join(model_blob_path, cls.
|
189
|
+
with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
|
116
190
|
cloudpickle.dump(model, f)
|
117
191
|
base_meta = model_blob_meta.ModelBlobMeta(
|
118
192
|
name=name,
|
119
193
|
model_type=cls.HANDLER_TYPE,
|
120
194
|
handler_version=cls.HANDLER_VERSION,
|
121
|
-
path=cls.
|
195
|
+
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
122
196
|
)
|
123
197
|
model_meta.models[name] = base_meta
|
124
198
|
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
125
199
|
|
200
|
+
if enable_explainability:
|
201
|
+
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
|
202
|
+
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
203
|
+
|
126
204
|
model_meta.env.include_if_absent(
|
127
205
|
[model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn")], check_local_version=True
|
128
206
|
)
|
@@ -153,6 +231,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
153
231
|
cls,
|
154
232
|
raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
|
155
233
|
model_meta: model_meta_api.ModelMetadata,
|
234
|
+
background_data: Optional[pd.DataFrame] = None,
|
156
235
|
**kwargs: Unpack[model_types.SKLModelLoadOptions],
|
157
236
|
) -> custom_model.CustomModel:
|
158
237
|
from snowflake.ml.model import custom_model
|
@@ -165,6 +244,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
165
244
|
raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
|
166
245
|
signature: model_signature.ModelSignature,
|
167
246
|
target_method: str,
|
247
|
+
background_data: Optional[pd.DataFrame],
|
168
248
|
) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
|
169
249
|
@custom_model.inference_api
|
170
250
|
def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
|
@@ -179,11 +259,26 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
179
259
|
|
180
260
|
return model_signature_utils.rename_pandas_df(df, signature.outputs)
|
181
261
|
|
262
|
+
@custom_model.inference_api
def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
    """Compute explanations for ``X`` with a SHAP explainer built on the raw model.

    The explainer is created from the enclosing scope's ``raw_model`` and
    ``background_data``; its values are converted with
    ``handlers_utils.convert_explanations_to_2D_df`` and the columns renamed to
    match ``signature.outputs``.

    Raises:
        ValueError: If building or running the explainer (or converting its
            output) raises ``TypeError``; the original error is kept as the cause.
    """
    import shap

    # TODO: if not resolved by explainer, we need to pass the callable function
    try:
        explainer = shap.Explainer(raw_model, background_data)
        df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
    except TypeError as e:
        # Chain explicitly so the traceback shows the TypeError as the direct cause.
        raise ValueError(f"Explanation for this model type not supported yet: {str(e)}") from e
    return model_signature_utils.rename_pandas_df(df, signature.outputs)
|
273
|
+
|
274
|
+
if target_method == "explain":
|
275
|
+
return explain_fn
|
276
|
+
|
182
277
|
return fn
|
183
278
|
|
184
279
|
type_method_dict = {}
|
185
280
|
for target_method_name, sig in model_meta.signatures.items():
|
186
|
-
type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
|
281
|
+
type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name, background_data)
|
187
282
|
|
188
283
|
_SKLModel = type(
|
189
284
|
"_SKLModel",
|