snowflake-ml-python 1.5.4__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
- snowflake/cortex/__init__.py +2 -0
- snowflake/cortex/_classify_text.py +36 -0
- snowflake/cortex/_complete.py +67 -10
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
- snowflake/ml/_internal/telemetry.py +12 -2
- snowflake/ml/data/_internal/arrow_ingestor.py +228 -0
- snowflake/ml/data/_internal/ingestor_utils.py +58 -0
- snowflake/ml/data/data_connector.py +133 -0
- snowflake/ml/data/data_ingestor.py +28 -0
- snowflake/ml/data/data_source.py +23 -0
- snowflake/ml/dataset/dataset.py +1 -13
- snowflake/ml/dataset/dataset_reader.py +18 -118
- snowflake/ml/feature_store/access_manager.py +7 -1
- snowflake/ml/feature_store/entity.py +19 -2
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +31 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +24 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +4 -0
- snowflake/ml/feature_store/examples/example_helper.py +240 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/dropoff_features.py +39 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +58 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
- snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
- snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +29 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +21 -0
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +5 -0
- snowflake/ml/feature_store/feature_store.py +579 -53
- snowflake/ml/feature_store/feature_view.py +168 -5
- snowflake/ml/fileset/stage_fs.py +18 -10
- snowflake/ml/lineage/lineage_node.py +1 -1
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
- snowflake/ml/model/_model_composer/model_composer.py +11 -14
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +24 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
- snowflake/ml/model/_model_composer/model_method/function_generator.py +3 -3
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +3 -27
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -32
- snowflake/ml/model/_model_composer/model_method/model_method.py +5 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +11 -1
- snowflake/ml/model/_packager/model_handlers/_utils.py +58 -1
- snowflake/ml/model/_packager/model_handlers/catboost.py +42 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +68 -0
- snowflake/ml/model/_packager/model_handlers/xgboost.py +59 -0
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -5
- snowflake/ml/model/model_signature.py +4 -4
- snowflake/ml/model/type_hints.py +4 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +1 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
- snowflake/ml/modeling/impute/simple_imputer.py +26 -0
- snowflake/ml/modeling/pipeline/pipeline.py +4 -4
- snowflake/ml/registry/registry.py +100 -13
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/METADATA +48 -2
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/RECORD +64 -42
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/lineage/data_source.py +0 -10
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_packager/model_handlers/_base.py
CHANGED
@@ -1,4 +1,5 @@
 from abc import abstractmethod
+from enum import Enum
 from typing import Dict, Generic, Optional, Protocol, Type, final
 
 from typing_extensions import TypeGuard, Unpack
@@ -8,6 +9,15 @@ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
 from snowflake.ml.model._packager.model_meta import model_meta
 
 
+class ModelObjective(Enum):
+    # This is not getting stored anywhere as metadata yet so it should be fine to slowly extend it for better coverage
+    UNKNOWN = "unknown"
+    BINARY_CLASSIFICATION = "binary_classification"
+    MULTI_CLASSIFICATION = "multi_classification"
+    REGRESSION = "regression"
+    RANKING = "ranking"
+
+
 class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
     HANDLER_TYPE: model_types.SupportedModelHandlerType
     HANDLER_VERSION: str
@@ -16,7 +26,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
 
     @classmethod
     @abstractmethod
-    def can_handle(cls, model: model_types.
+    def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard[model_types._ModelType]:
         """Whether this handler could support the type of the `model`.
 
         Args:
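Read alongside the handler hunks below, the enum's practical effect is choosing the column type of the new `explain` output. A minimal standalone sketch of that decision (illustrative only, not the packaged code; note the LightGBM hunk also uses STRING for binary classification):

```python
from enum import Enum


class ModelObjective(Enum):
    UNKNOWN = "unknown"
    BINARY_CLASSIFICATION = "binary_classification"
    MULTI_CLASSIFICATION = "multi_classification"
    REGRESSION = "regression"
    RANKING = "ranking"


def explain_output_dtype(objective: ModelObjective) -> str:
    # CatBoost and XGBoost emit per-class JSON strings only for multi-class
    # models; everything else gets one DOUBLE per input feature.
    return "STRING" if objective is ModelObjective.MULTI_CLASSIFICATION else "DOUBLE"


assert explain_output_dtype(ModelObjective.REGRESSION) == "DOUBLE"
assert explain_output_dtype(ModelObjective.MULTI_CLASSIFICATION) == "STRING"
```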
snowflake/ml/model/_packager/model_handlers/_utils.py
CHANGED
@@ -1,4 +1,9 @@
-
+import json
+from typing import Any, Callable, Iterable, Optional, Sequence, cast
+
+import numpy as np
+import numpy.typing as npt
+import pandas as pd
 
 from snowflake.ml.model import model_signature, type_hints as model_types
 from snowflake.ml.model._packager.model_meta import model_meta
@@ -40,6 +45,24 @@ def validate_signature(
     return model_meta
 
 
+def add_explain_method_signature(
+    model_meta: model_meta.ModelMetadata,
+    explain_method: str,
+    target_method: str,
+    output_return_type: model_signature.DataType = model_signature.DataType.DOUBLE,
+) -> model_meta.ModelMetadata:
+    if target_method not in model_meta.signatures:
+        raise ValueError(f"Signature for target method {target_method} is missing")
+    inputs = model_meta.signatures[target_method].inputs
+    model_meta.signatures[explain_method] = model_signature.ModelSignature(
+        inputs=inputs,
+        outputs=[
+            model_signature.FeatureSpec(dtype=output_return_type, name=f"{spec.name}_explanation") for spec in inputs
+        ],
+    )
+    return model_meta
+
+
 def get_target_methods(
     model: model_types.SupportedModelType,
     target_methods: Optional[Sequence[str]],
@@ -56,3 +79,37 @@ def validate_target_methods(model: model_types.SupportedModelType, target_method
     for method_name in target_methods:
         if not _is_callable(model, method_name):
             raise ValueError(f"Target method {method_name} is not callable or does not exist in the model.")
+
+
+def get_num_classes_if_exists(model: model_types.SupportedModelType) -> int:
+    num_classes = getattr(model, "classes_", [])
+    return len(num_classes)
+
+
+def convert_explanations_to_2D_df(
+    model: model_types.SupportedModelType, explanations: npt.NDArray[Any]
+) -> pd.DataFrame:
+    if explanations.ndim != 3:
+        return pd.DataFrame(explanations)
+
+    if hasattr(model, "classes_"):
+        classes_list = [cl for cl in model.classes_]  # type:ignore[union-attr]
+        len_classes = len(classes_list)
+        if explanations.shape[2] != len_classes:
+            raise ValueError(f"Model has {len_classes} classes but explanations have {explanations.shape[2]}")
+    else:
+        classes_list = [i for i in range(explanations.shape[2])]
+    exp_2d = []
+    # TODO (SNOW-1549044): Optimize this
+    for row in explanations:
+        col_list = []
+        for column in row:
+            class_explanations = {}
+            for cl, cl_exp in zip(classes_list, column):
+                if isinstance(cl, (int, np.integer)):
+                    cl = int(cl)
+                class_explanations[cl] = cl_exp
+            col_list.append(json.dumps(class_explanations))
+        exp_2d.append(col_list)
+
+    return pd.DataFrame(exp_2d)
snowflake/ml/model/_packager/model_handlers/catboost.py
CHANGED
@@ -33,6 +33,22 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
     MODELE_BLOB_FILE_OR_DIR = "model.bin"
     DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
 
+    @classmethod
+    def get_model_objective(cls, model: "catboost.CatBoost") -> _base.ModelObjective:
+        import catboost
+
+        if isinstance(model, catboost.CatBoostClassifier):
+            num_classes = handlers_utils.get_num_classes_if_exists(model)
+            if num_classes == 2:
+                return _base.ModelObjective.BINARY_CLASSIFICATION
+            return _base.ModelObjective.MULTI_CLASSIFICATION
+        if isinstance(model, catboost.CatBoostRanker):
+            return _base.ModelObjective.RANKING
+        if isinstance(model, catboost.CatBoostRegressor):
+            return _base.ModelObjective.REGRESSION
+        # TODO: Find out model type from the generic Catboost Model
+        return _base.ModelObjective.UNKNOWN
+
     @classmethod
     def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard["catboost.CatBoost"]:
         return (type_utils.LazyType("catboost.CatBoost").isinstance(model)) and any(
@@ -89,6 +105,16 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
             sample_input_data=sample_input_data,
             get_prediction_fn=get_prediction,
         )
+        if kwargs.get("enable_explainability", False):
+            output_type = model_signature.DataType.DOUBLE
+            if cls.get_model_objective(model) == _base.ModelObjective.MULTI_CLASSIFICATION:
+                output_type = model_signature.DataType.STRING
+            model_meta = handlers_utils.add_explain_method_signature(
+                model_meta=model_meta,
+                explain_method="explain",
+                target_method="predict",
+                output_return_type=output_type,
+            )
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
@@ -112,6 +138,11 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
             ],
             check_local_version=True,
         )
+        if kwargs.get("enable_explainability", False):
+            model_meta.env.include_if_absent(
+                [model_env.ModelDependency(requirement="shap", pip_name="shap")],
+                check_local_version=True,
+            )
         model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
 
         return None
@@ -186,6 +217,17 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
 
             return model_signature_utils.rename_pandas_df(df, signature.outputs)
 
+        @custom_model.inference_api
+        def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
+            import shap
+
+            explainer = shap.TreeExplainer(raw_model)
+            df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
+            return model_signature_utils.rename_pandas_df(df, signature.outputs)
+
+        if target_method == "explain":
+            return explain_fn
+
         return fn
 
     type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
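How the new objective detection behaves in practice for CatBoost, as a sketch; the binary/multi split hinges on the fitted `classes_` attribute read by `get_num_classes_if_exists`:

```python
import catboost
import numpy as np

X = np.random.rand(40, 3)

clf = catboost.CatBoostClassifier(iterations=5, verbose=False)
clf.fit(X, np.random.randint(0, 2, size=40))
print(len(clf.classes_))  # 2 -> BINARY_CLASSIFICATION; 3+ -> MULTI_CLASSIFICATION

reg = catboost.CatBoostRegressor(iterations=5, verbose=False)
reg.fit(X, np.random.rand(40))
# isinstance checks map CatBoostRegressor -> REGRESSION and CatBoostRanker -> RANKING;
# a generic catboost.CatBoost falls through to UNKNOWN (see the TODO above).
```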
snowflake/ml/model/_packager/model_handlers/lightgbm.py
CHANGED
@@ -43,6 +43,45 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
 
     MODELE_BLOB_FILE_OR_DIR = "model.pkl"
     DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
+    _BINARY_CLASSIFICATION_OBJECTIVES = ["binary"]
+    _MULTI_CLASSIFICATION_OBJECTIVES = ["multiclass", "multiclassova"]
+    _RANKING_OBJECTIVES = ["lambdarank", "rank_xendcg"]
+    _REGRESSION_OBJECTIVES = [
+        "regression",
+        "regression_l1",
+        "huber",
+        "fair",
+        "poisson",
+        "quantile",
+        "tweedie",
+        "mape",
+        "gamma",
+    ]
+
+    @classmethod
+    def get_model_objective(cls, model: Union["lightgbm.Booster", "lightgbm.LGBMModel"]) -> _base.ModelObjective:
+        import lightgbm
+
+        # does not account for cross-entropy and custom
+        if isinstance(model, lightgbm.LGBMClassifier):
+            num_classes = handlers_utils.get_num_classes_if_exists(model)
+            if num_classes == 2:
+                return _base.ModelObjective.BINARY_CLASSIFICATION
+            return _base.ModelObjective.MULTI_CLASSIFICATION
+        if isinstance(model, lightgbm.LGBMRanker):
+            return _base.ModelObjective.RANKING
+        if isinstance(model, lightgbm.LGBMRegressor):
+            return _base.ModelObjective.REGRESSION
+        model_objective = model.params["objective"]
+        if model_objective in cls._BINARY_CLASSIFICATION_OBJECTIVES:
+            return _base.ModelObjective.BINARY_CLASSIFICATION
+        if model_objective in cls._MULTI_CLASSIFICATION_OBJECTIVES:
+            return _base.ModelObjective.MULTI_CLASSIFICATION
+        if model_objective in cls._RANKING_OBJECTIVES:
+            return _base.ModelObjective.RANKING
+        if model_objective in cls._REGRESSION_OBJECTIVES:
+            return _base.ModelObjective.REGRESSION
+        return _base.ModelObjective.UNKNOWN
 
     @classmethod
     def can_handle(
@@ -105,6 +144,19 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
             sample_input_data=sample_input_data,
             get_prediction_fn=get_prediction,
         )
+        if kwargs.get("enable_explainability", False):
+            output_type = model_signature.DataType.DOUBLE
+            if cls.get_model_objective(model) in [
+                _base.ModelObjective.BINARY_CLASSIFICATION,
+                _base.ModelObjective.MULTI_CLASSIFICATION,
+            ]:
+                output_type = model_signature.DataType.STRING
+            model_meta = handlers_utils.add_explain_method_signature(
+                model_meta=model_meta,
+                explain_method="explain",
+                target_method="predict",
+                output_return_type=output_type,
+            )
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
@@ -130,6 +182,11 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
             ],
             check_local_version=True,
        )
+        if kwargs.get("enable_explainability", False):
+            model_meta.env.include_if_absent(
+                [model_env.ModelDependency(requirement="shap", pip_name="shap")],
+                check_local_version=True,
+            )
 
         return None
 
@@ -198,6 +255,17 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
 
             return model_signature_utils.rename_pandas_df(df, signature.outputs)
 
+        @custom_model.inference_api
+        def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
+            import shap
+
+            explainer = shap.TreeExplainer(raw_model)
+            df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
+            return model_signature_utils.rename_pandas_df(df, signature.outputs)
+
+        if target_method == "explain":
+            return explain_fn
+
         return fn
 
     type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
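For a raw `lightgbm.Booster` there is no sklearn wrapper class to isinstance-check, so the handler falls back to the `objective` entry in the training params. A sketch of that fallback path:

```python
import lightgbm as lgb
import numpy as np

X, y = np.random.rand(100, 4), np.random.rand(100)
booster = lgb.train({"objective": "huber", "verbosity": -1}, lgb.Dataset(X, label=y))

print(booster.params["objective"])  # "huber"
# "huber" is listed in _REGRESSION_OBJECTIVES, so get_model_objective returns
# REGRESSION; "binary" -> BINARY_CLASSIFICATION, "lambdarank" -> RANKING, and
# unlisted objectives (e.g. cross-entropy or custom) fall through to UNKNOWN.
```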
snowflake/ml/model/_packager/model_handlers/xgboost.py
CHANGED
@@ -1,4 +1,5 @@
 # mypy: disable-error-code="import"
+import json
 import os
 from typing import (
     TYPE_CHECKING,
@@ -46,6 +47,39 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
 
     MODELE_BLOB_FILE_OR_DIR = "model.ubj"
     DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
+    _BINARY_CLASSIFICATION_OBJECTIVE_PREFIX = ["binary:"]
+    _MULTI_CLASSIFICATION_OBJECTIVE_PREFIX = ["multi:"]
+    _RANKING_OBJECTIVE_PREFIX = ["rank:"]
+    _REGRESSION_OBJECTIVE_PREFIX = ["reg:"]
+
+    @classmethod
+    def get_model_objective(cls, model: Union["xgboost.Booster", "xgboost.XGBModel"]) -> _base.ModelObjective:
+        import xgboost
+
+        if isinstance(model, xgboost.XGBClassifier) or isinstance(model, xgboost.XGBRFClassifier):
+            num_classes = handlers_utils.get_num_classes_if_exists(model)
+            if num_classes == 2:
+                return _base.ModelObjective.BINARY_CLASSIFICATION
+            return _base.ModelObjective.MULTI_CLASSIFICATION
+        if isinstance(model, xgboost.XGBRegressor) or isinstance(model, xgboost.XGBRFRegressor):
+            return _base.ModelObjective.REGRESSION
+        if isinstance(model, xgboost.XGBRanker):
+            return _base.ModelObjective.RANKING
+        model_params = json.loads(model.save_config())
+        model_objective = model_params["learner"]["objective"]
+        for classification_objective in cls._BINARY_CLASSIFICATION_OBJECTIVE_PREFIX:
+            if classification_objective in model_objective:
+                return _base.ModelObjective.BINARY_CLASSIFICATION
+        for classification_objective in cls._MULTI_CLASSIFICATION_OBJECTIVE_PREFIX:
+            if classification_objective in model_objective:
+                return _base.ModelObjective.MULTI_CLASSIFICATION
+        for ranking_objective in cls._RANKING_OBJECTIVE_PREFIX:
+            if ranking_objective in model_objective:
+                return _base.ModelObjective.RANKING
+        for regression_objective in cls._REGRESSION_OBJECTIVE_PREFIX:
+            if regression_objective in model_objective:
+                return _base.ModelObjective.REGRESSION
+        return _base.ModelObjective.UNKNOWN
 
     @classmethod
     def can_handle(
@@ -112,6 +146,16 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
             sample_input_data=sample_input_data,
             get_prediction_fn=get_prediction,
         )
+        if kwargs.get("enable_explainability", False):
+            output_type = model_signature.DataType.DOUBLE
+            if cls.get_model_objective(model) == _base.ModelObjective.MULTI_CLASSIFICATION:
+                output_type = model_signature.DataType.STRING
+            model_meta = handlers_utils.add_explain_method_signature(
+                model_meta=model_meta,
+                explain_method="explain",
+                target_method="predict",
+                output_return_type=output_type,
+            )
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
@@ -133,6 +177,11 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
             ],
             check_local_version=True,
         )
+        if kwargs.get("enable_explainability", False):
+            model_meta.env.include_if_absent(
+                [model_env.ModelDependency(requirement="shap", pip_name="shap")],
+                check_local_version=True,
+            )
         model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
 
     @classmethod
@@ -206,6 +255,16 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
 
             return model_signature_utils.rename_pandas_df(df, signature.outputs)
 
+        @custom_model.inference_api
+        def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
+            import shap
+
+            explainer = shap.TreeExplainer(raw_model)
+            df = pd.DataFrame(explainer(X).values)
+            return model_signature_utils.rename_pandas_df(df, signature.outputs)
+
+        if target_method == "explain":
+            return explain_fn
         return fn
 
     type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
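For a raw `xgboost.Booster` the handler parses `save_config()`, which serializes the full training configuration to JSON, and prefix-matches the objective. A sketch; `str()` is used here so the substring check works whether the `["learner"]["objective"]` entry is a plain string or a nested config object:

```python
import json

import numpy as np
import xgboost as xgb

X, y = np.random.rand(100, 4), np.random.rand(100)
booster = xgb.train({"objective": "reg:squarederror"}, xgb.DMatrix(X, label=y), num_boost_round=5)

config = json.loads(booster.save_config())
objective = str(config["learner"]["objective"])
print("reg:" in objective)  # True -> REGRESSION
# Other prefixes: "binary:" -> BINARY_CLASSIFICATION, "multi:" -> MULTI_CLASSIFICATION,
# "rank:" -> RANKING; anything else -> UNKNOWN.
```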
snowflake/ml/model/_packager/model_runtime/model_runtime.py
CHANGED
@@ -35,7 +35,7 @@ class ModelRuntime:
         self,
         name: str,
         env: model_env.ModelEnv,
-        imports: Optional[List[
+        imports: Optional[List[str]] = None,
         is_gpu: bool = False,
         loading_from_file: bool = False,
     ) -> None:
@@ -75,7 +75,7 @@ class ModelRuntime:
             snowpark_ml_lib_path = runtime_base_path / "snowflake-ml-python.zip"
             file_utils.zip_python_package(str(snowpark_ml_lib_path), "snowflake.ml")
             snowpark_ml_lib_rel_path = pathlib.PurePosixPath(snowpark_ml_lib_path.relative_to(packager_path).as_posix())
-            self.imports.append(snowpark_ml_lib_rel_path)
+            self.imports.append(str(snowpark_ml_lib_rel_path))
 
         self.runtime_env.conda_env_rel_path = self.runtime_rel_path / self.runtime_env.conda_env_rel_path
         self.runtime_env.pip_requirements_rel_path = self.runtime_rel_path / self.runtime_env.pip_requirements_rel_path
@@ -108,6 +108,4 @@ class ModelRuntime:
             warnings.simplefilter("ignore")
             env.load_from_conda_file(packager_path / conda_env_rel_path)
             env.load_from_pip_file(packager_path / pip_requirements_rel_path)
-        return ModelRuntime(
-            name=name, env=env, imports=list(map(pathlib.PurePosixPath, loaded_dict["imports"])), loading_from_file=True
-        )
+        return ModelRuntime(name=name, env=env, imports=loaded_dict["imports"], loading_from_file=True)
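The `imports` change normalizes paths to plain strings. A plausible motivation (my inference, not stated in the diff) is that `pathlib.PurePosixPath` values do not survive serialization of the runtime manifest, while the strings already in `loaded_dict["imports"]` round-trip as-is; the zip path below is hypothetical:

```python
import json
import pathlib

rel_path = pathlib.PurePosixPath("runtimes/cpu/snowflake-ml-python.zip")  # hypothetical path

try:
    json.dumps({"imports": [rel_path]})
except TypeError as e:
    print(e)  # PurePosixPath is not JSON serializable

print(json.dumps({"imports": [str(rel_path)]}))  # strings round-trip cleanly
```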
snowflake/ml/model/model_signature.py
CHANGED
@@ -232,7 +232,7 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
                 ),
             )
         else:
-            if isinstance(data_col[0], list):
+            if isinstance(data_col.iloc[0], list):
                 if not ft_shape:
                     raise snowml_exceptions.SnowflakeMLException(
                         error_code=error_codes.INVALID_DATA,
@@ -266,7 +266,7 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
                 ),
             )
 
-            elif isinstance(data_col[0], np.ndarray):
+            elif isinstance(data_col.iloc[0], np.ndarray):
                 if not ft_shape:
                     raise snowml_exceptions.SnowflakeMLException(
                         error_code=error_codes.INVALID_DATA,
@@ -297,7 +297,7 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
                 ),
             )
 
-            elif isinstance(data_col[0], str):
+            elif isinstance(data_col.iloc[0], str):
                 if ft_shape is not None:
                     raise snowml_exceptions.SnowflakeMLException(
                         error_code=error_codes.INVALID_DATA,
@@ -316,7 +316,7 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
                 ),
             )
 
-            elif isinstance(data_col[0], bytes):
+            elif isinstance(data_col.iloc[0], bytes):
                 if ft_shape is not None:
                     raise snowml_exceptions.SnowflakeMLException(
                         error_code=error_codes.INVALID_DATA,
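The `data_col[0]` → `data_col.iloc[0]` fixes matter because bracket indexing on a Series is label-based: if the validated DataFrame is a slice whose index does not start at 0, `data_col[0]` raises `KeyError` even though the column is non-empty. A minimal reproduction:

```python
import pandas as pd

df = pd.DataFrame({"a": [[1, 2], [3, 4], [5, 6]]})
data_col = df["a"].iloc[1:]  # index is now [1, 2]

try:
    data_col[0]  # label lookup: there is no label 0
except KeyError:
    print("KeyError with [0]")

print(data_col.iloc[0])  # positional lookup: [3, 4]
```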
snowflake/ml/model/type_hints.py
CHANGED
@@ -232,11 +232,13 @@ class BaseModelSaveOption(TypedDict):
     _legacy_save: NotRequired[bool]
     function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]]
     method_options: NotRequired[Dict[str, ModelMethodSaveOptions]]
+    include_pip_dependencies: NotRequired[bool]
 
 
 class CatBoostModelSaveOptions(BaseModelSaveOption):
     target_methods: NotRequired[Sequence[str]]
     cuda_version: NotRequired[str]
+    enable_explainability: NotRequired[bool]
 
 
 class CustomModelSaveOption(BaseModelSaveOption):
@@ -250,10 +252,12 @@ class SKLModelSaveOptions(BaseModelSaveOption):
 class XGBModelSaveOptions(BaseModelSaveOption):
     target_methods: NotRequired[Sequence[str]]
     cuda_version: NotRequired[str]
+    enable_explainability: NotRequired[bool]
 
 
 class LGBMModelSaveOptions(BaseModelSaveOption):
     target_methods: NotRequired[Sequence[str]]
+    enable_explainability: NotRequired[bool]
 
 
 class SNOWModelSaveOptions(BaseModelSaveOption):
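With these save options, explainability is opt-in at logging time. A hedged sketch of how a caller might request it through the registry (`reg`, `model`, and `X_sample` are assumed to exist; `options` is forwarded to the handler, where `enable_explainability` adds the `explain` signature and the `shap` dependency, as in the handler hunks above):

```python
# Assumes: reg is a snowflake.ml.registry.Registry instance, model is a fitted
# xgboost.XGBClassifier, X_sample is a small pandas DataFrame of sample inputs.
mv = reg.log_model(
    model,
    model_name="my_model",
    version_name="v1",
    sample_input_data=X_sample,
    options={"enable_explainability": True},  # XGBModelSaveOptions
)
# The logged version should then expose an `explain` method alongside `predict`,
# returning one `<feature>_explanation` column per input feature.
```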
snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py
CHANGED
@@ -41,7 +41,7 @@ cp.register_pickle_by_value(inspect.getmodule(snowpark_dataframe_utils.cast_snow
 
 _PROJECT = "ModelDevelopment"
 DEFAULT_UDTF_NJOBS = 3
-ENABLE_EFFICIENT_MEMORY_USAGE =
+ENABLE_EFFICIENT_MEMORY_USAGE = True
 _UDTF_STAGE_NAME = f"MEMORY_EFFICIENT_UDTF_{str(uuid.uuid4()).replace('-', '_')}"
 
 
snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py
CHANGED
@@ -83,7 +83,19 @@ def _load_data_into_udf() -> Tuple[
     with open(local_fit_and_score_kwargs_file_path, mode="rb") as local_fit_and_score_kwargs_file_obj:
         fit_and_score_kwargs = cp.load(local_fit_and_score_kwargs_file_obj)
 
-    #
+    # Convert dataframe to numpy would save memory consumption
+    # Except for Pipeline, we need to keep the dataframe for the column names
+    from sklearn.pipeline import Pipeline
+    if isinstance(base_estimator, Pipeline):
+        return (
+            df[CONSTANTS['input_cols']],
+            df[CONSTANTS['label_cols']].squeeze(),
+            indices,
+            params_to_evaluate,
+            base_estimator,
+            fit_and_score_kwargs,
+            CONSTANTS
+        )
     return (
         df[CONSTANTS['input_cols']].to_numpy(),
         df[CONSTANTS['label_cols']].squeeze().to_numpy(),
snowflake/ml/modeling/impute/simple_imputer.py
CHANGED
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 import copy
+import warnings
 from typing import Any, Dict, Iterable, Optional, Type, Union
 
 import numpy as np
@@ -10,6 +11,7 @@ from sklearn import impute
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions
+from snowflake.ml._internal.utils import formatting
 from snowflake.ml.modeling.framework import _utils, base
 from snowflake.snowpark import functions as F, types as T
 from snowflake.snowpark._internal import utils as snowpark_utils
@@ -171,6 +173,14 @@ class SimpleImputer(base.BaseTransformer):
         self.set_output_cols(output_cols)
         self.set_passthrough_cols(passthrough_cols)
 
+    def _is_integer_type(self, column_type: T.DataType) -> bool:
+        return (
+            isinstance(column_type, T.ByteType)
+            or isinstance(column_type, T.ShortType)
+            or isinstance(column_type, T.IntegerType)
+            or isinstance(column_type, T.LongType)
+        )
+
     def _reset(self) -> None:
         """
         Reset internal data-dependent state of the imputer, if necessary.
@@ -389,6 +399,22 @@ class SimpleImputer(base.BaseTransformer):
                 # Use `fillna` for replacing nans. Check if the column has a string data type, or coerce a float.
                 if not isinstance(input_col_datatypes[input_col], T.StringType):
                     statistic = float(statistic)
+
+                    if self._is_integer_type(input_col_datatypes[input_col]):
+                        if statistic.is_integer():
+                            statistic = int(statistic)
+                        else:
+                            warnings.warn(
+                                formatting.unwrap(
+                                    f"""
+                                    Integer column may not be imputed with a non-integer value {statistic}.
+                                    In order to impute a non-integer value, convert the column to FloatType before imputing.
+                                    """
+                                ),
+                                category=UserWarning,
+                                stacklevel=1,
+                            )
+
                 transformed_dataset = transformed_dataset.na.fill({output_col: statistic})
             else:
                 transformed_dataset = transformed_dataset.na.replace(
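The net effect: an integer column whose imputation statistic is whole (say a median of 3.0) is filled with the int `3`, while a fractional statistic (say a mean of 2.5) is left as-is and warns, since Snowpark will not fill an integer column with a float. The core decision, restated in isolation:

```python
import warnings


def coerce_statistic(statistic: float, is_integer_col: bool):
    # Standalone rendition of the diff's logic, for illustration only.
    if is_integer_col:
        if statistic.is_integer():
            return int(statistic)
        warnings.warn(
            f"Integer column may not be imputed with a non-integer value {statistic}. "
            "In order to impute a non-integer value, convert the column to FloatType before imputing.",
            category=UserWarning,
        )
    return statistic


print(coerce_statistic(3.0, True))  # 3
print(coerce_statistic(2.5, True))  # 2.5, after emitting a UserWarning
```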
snowflake/ml/modeling/pipeline/pipeline.py
CHANGED
@@ -99,10 +99,6 @@ class Pipeline(base.BaseTransformer):
             must implement `fit` and `transform` methods.
             The final step can be a transform or estimator, that is, it must implement
             `fit` and `transform`/`predict` methods.
-            TODO: SKLearn pipeline expects last step(and only the last step) to be an estimator obj or a dummy
-            estimator(like None or passthrough). Currently this Pipeline class works with a list of all
-            transforms or a list of transforms ending with an estimator. Should we change this implementation
-            to only work with list of steps ending with an estimator or a dummy estimator like SKLearn?
 
         Args:
             steps: List of (name, transform) tuples (implementing `fit`/`transform`) that
@@ -111,6 +107,10 @@ class Pipeline(base.BaseTransformer):
         """
         super().__init__()
         self.steps = steps
+        # TODO(snandamuri): SKLearn pipeline expects last step(and only the last step) to be an estimator obj or a dummy
+        # estimator(like None or passthrough). Currently this Pipeline class works with a list of all
+        # transforms or a list of transforms ending with an estimator. Should we change this implementation
+        # to only work with list of steps ending with an estimator or a dummy estimator like SKLearn?
         self._is_final_step_estimator = Pipeline._is_estimator(steps[-1][1])
         self._is_fitted = False
         self._feature_names_in: List[np.ndarray[Any, np.dtype[Any]]] = []