snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +7 -33
- snowflake/ml/_internal/env_utils.py +11 -5
- snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
- snowflake/ml/_internal/telemetry.py +14 -0
- snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
- snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
- snowflake/ml/data/data_connector.py +59 -6
- snowflake/ml/data/data_ingestor.py +18 -1
- snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
- snowflake/ml/data/torch_dataset.py +33 -0
- snowflake/ml/dataset/dataset_metadata.py +3 -1
- snowflake/ml/dataset/dataset_reader.py +9 -3
- snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
- snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
- snowflake/ml/feature_store/examples/example_helper.py +69 -31
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
- snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
- snowflake/ml/feature_store/feature_store.py +59 -24
- snowflake/ml/feature_store/feature_view.py +148 -4
- snowflake/ml/model/_client/model/model_impl.py +11 -2
- snowflake/ml/model/_client/model/model_version_impl.py +171 -20
- snowflake/ml/model/_client/ops/model_ops.py +105 -27
- snowflake/ml/model/_client/ops/service_ops.py +121 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +95 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +31 -0
- snowflake/ml/model/_client/sql/model_version.py +13 -4
- snowflake/ml/model/_client/sql/service.py +129 -0
- snowflake/ml/model/_model_composer/model_composer.py +3 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +10 -2
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
- snowflake/ml/model/_packager/model_env/model_env.py +7 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
- snowflake/ml/model/_packager/model_handlers/catboost.py +19 -12
- snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +9 -5
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +27 -18
- snowflake/ml/model/_packager/model_handlers/llm.py +7 -3
- snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sklearn.py +87 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +7 -2
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +8 -3
- snowflake/ml/model/_packager/model_handlers/xgboost.py +25 -16
- snowflake/ml/model/_packager/model_meta/model_meta.py +32 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +19 -0
- snowflake/ml/model/_packager/model_packager.py +2 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/type_hints.py +1 -3
- snowflake/ml/modeling/framework/base.py +28 -19
- snowflake/ml/modeling/pipeline/pipeline.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +16 -2
- snowflake/ml/utils/sql_client.py +22 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/METADATA +35 -2
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/RECORD +73 -62
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/top_level.txt +0 -0
@@ -41,7 +41,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
41
41
|
_MIN_SNOWPARK_ML_VERSION = "1.3.1"
|
42
42
|
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
|
43
43
|
|
44
|
-
|
44
|
+
MODEL_BLOB_FILE_OR_DIR = "model.pkl"
|
45
45
|
DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
|
46
46
|
_BINARY_CLASSIFICATION_OBJECTIVES = ["binary"]
|
47
47
|
_MULTI_CLASSIFICATION_OBJECTIVES = ["multiclass", "multiclassova"]
|
@@ -59,29 +59,31 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
59
59
|
]
|
60
60
|
|
61
61
|
@classmethod
|
62
|
-
def get_model_objective(
|
62
|
+
def get_model_objective(
|
63
|
+
cls, model: Union["lightgbm.Booster", "lightgbm.LGBMModel"]
|
64
|
+
) -> model_meta_schema.ModelObjective:
|
63
65
|
import lightgbm
|
64
66
|
|
65
67
|
# does not account for cross-entropy and custom
|
66
68
|
if isinstance(model, lightgbm.LGBMClassifier):
|
67
69
|
num_classes = handlers_utils.get_num_classes_if_exists(model)
|
68
70
|
if num_classes == 2:
|
69
|
-
return
|
70
|
-
return
|
71
|
+
return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
|
72
|
+
return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
|
71
73
|
if isinstance(model, lightgbm.LGBMRanker):
|
72
|
-
return
|
74
|
+
return model_meta_schema.ModelObjective.RANKING
|
73
75
|
if isinstance(model, lightgbm.LGBMRegressor):
|
74
|
-
return
|
76
|
+
return model_meta_schema.ModelObjective.REGRESSION
|
75
77
|
model_objective = model.params["objective"]
|
76
78
|
if model_objective in cls._BINARY_CLASSIFICATION_OBJECTIVES:
|
77
|
-
return
|
79
|
+
return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
|
78
80
|
if model_objective in cls._MULTI_CLASSIFICATION_OBJECTIVES:
|
79
|
-
return
|
81
|
+
return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
|
80
82
|
if model_objective in cls._RANKING_OBJECTIVES:
|
81
|
-
return
|
83
|
+
return model_meta_schema.ModelObjective.RANKING
|
82
84
|
if model_objective in cls._REGRESSION_OBJECTIVES:
|
83
|
-
return
|
84
|
-
return
|
85
|
+
return model_meta_schema.ModelObjective.REGRESSION
|
86
|
+
return model_meta_schema.ModelObjective.UNKNOWN
|
85
87
|
|
86
88
|
@classmethod
|
87
89
|
def can_handle(
|
@@ -144,11 +146,13 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
144
146
|
sample_input_data=sample_input_data,
|
145
147
|
get_prediction_fn=get_prediction,
|
146
148
|
)
|
147
|
-
|
149
|
+
model_objective = cls.get_model_objective(model)
|
150
|
+
model_meta.model_objective = model_objective
|
151
|
+
if kwargs.get("enable_explainability", True):
|
148
152
|
output_type = model_signature.DataType.DOUBLE
|
149
|
-
if
|
150
|
-
|
151
|
-
|
153
|
+
if model_objective in [
|
154
|
+
model_meta_schema.ModelObjective.BINARY_CLASSIFICATION,
|
155
|
+
model_meta_schema.ModelObjective.MULTI_CLASSIFICATION,
|
152
156
|
]:
|
153
157
|
output_type = model_signature.DataType.STRING
|
154
158
|
model_meta = handlers_utils.add_explain_method_signature(
|
@@ -157,11 +161,14 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
157
161
|
target_method="predict",
|
158
162
|
output_return_type=output_type,
|
159
163
|
)
|
164
|
+
model_meta.function_properties = {
|
165
|
+
"explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
|
166
|
+
}
|
160
167
|
|
161
168
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
162
169
|
os.makedirs(model_blob_path, exist_ok=True)
|
163
170
|
|
164
|
-
model_save_path = os.path.join(model_blob_path, cls.
|
171
|
+
model_save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
|
165
172
|
with open(model_save_path, "wb") as f:
|
166
173
|
cloudpickle.dump(model, f)
|
167
174
|
|
@@ -169,7 +176,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
169
176
|
name=name,
|
170
177
|
model_type=cls.HANDLER_TYPE,
|
171
178
|
handler_version=cls.HANDLER_VERSION,
|
172
|
-
path=cls.
|
179
|
+
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
173
180
|
options=model_meta_schema.LightGBMModelBlobOptions({"lightgbm_estimator_type": model.__class__.__name__}),
|
174
181
|
)
|
175
182
|
model_meta.models[name] = base_meta
|
@@ -182,11 +189,12 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
182
189
|
],
|
183
190
|
check_local_version=True,
|
184
191
|
)
|
185
|
-
if kwargs.get("enable_explainability",
|
192
|
+
if kwargs.get("enable_explainability", True):
|
186
193
|
model_meta.env.include_if_absent(
|
187
194
|
[model_env.ModelDependency(requirement="shap", pip_name="shap")],
|
188
195
|
check_local_version=True,
|
189
196
|
)
|
197
|
+
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
190
198
|
|
191
199
|
return None
|
192
200
|
|
@@ -226,6 +234,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
226
234
|
cls,
|
227
235
|
raw_model: Union["lightgbm.Booster", "lightgbm.XGBModel"],
|
228
236
|
model_meta: model_meta_api.ModelMetadata,
|
237
|
+
background_data: Optional[pd.DataFrame] = None,
|
229
238
|
**kwargs: Unpack[model_types.LGBMModelLoadOptions],
|
230
239
|
) -> custom_model.CustomModel:
|
231
240
|
import lightgbm
|
@@ -28,7 +28,7 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
|
|
28
28
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
29
29
|
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
|
30
30
|
|
31
|
-
|
31
|
+
MODEL_BLOB_FILE_OR_DIR = "model"
|
32
32
|
LLM_META = "llm_meta"
|
33
33
|
IS_AUTO_SIGNATURE = True
|
34
34
|
|
@@ -59,9 +59,12 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
|
|
59
59
|
**kwargs: Unpack[model_types.LLMSaveOptions],
|
60
60
|
) -> None:
|
61
61
|
assert not is_sub_model, "LLM can not be sub-model."
|
62
|
+
enable_explainability = kwargs.get("enable_explainability", False)
|
63
|
+
if enable_explainability:
|
64
|
+
raise NotImplementedError("Explainability is not supported for llm model.")
|
62
65
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
63
66
|
os.makedirs(model_blob_path, exist_ok=True)
|
64
|
-
model_blob_dir_path = os.path.join(model_blob_path, cls.
|
67
|
+
model_blob_dir_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
|
65
68
|
|
66
69
|
sig = model_signature.ModelSignature(
|
67
70
|
inputs=[
|
@@ -86,7 +89,7 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
|
|
86
89
|
name=name,
|
87
90
|
model_type=cls.HANDLER_TYPE,
|
88
91
|
handler_version=cls.HANDLER_VERSION,
|
89
|
-
path=cls.
|
92
|
+
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
90
93
|
options=model_meta_schema.LLMModelBlobOptions(
|
91
94
|
{
|
92
95
|
"batch_size": model.max_batch_size,
|
@@ -143,6 +146,7 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
|
|
143
146
|
cls,
|
144
147
|
raw_model: llm.LLM,
|
145
148
|
model_meta: model_meta_api.ModelMetadata,
|
149
|
+
background_data: Optional[pd.DataFrame] = None,
|
146
150
|
**kwargs: Unpack[model_types.LLMLoadOptions],
|
147
151
|
) -> custom_model.CustomModel:
|
148
152
|
import gc
|
@@ -63,7 +63,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
|
|
63
63
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
64
64
|
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
|
65
65
|
|
66
|
-
|
66
|
+
MODEL_BLOB_FILE_OR_DIR = "model"
|
67
67
|
_DEFAULT_TARGET_METHOD = "predict"
|
68
68
|
DEFAULT_TARGET_METHODS = [_DEFAULT_TARGET_METHOD]
|
69
69
|
IS_AUTO_SIGNATURE = True
|
@@ -97,6 +97,10 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
|
|
97
97
|
is_sub_model: Optional[bool] = False,
|
98
98
|
**kwargs: Unpack[model_types.MLFlowSaveOptions],
|
99
99
|
) -> None:
|
100
|
+
enable_explainability = kwargs.get("enable_explainability", False)
|
101
|
+
if enable_explainability:
|
102
|
+
raise NotImplementedError("Explainability is not supported for MLFlow model.")
|
103
|
+
|
100
104
|
import mlflow
|
101
105
|
|
102
106
|
assert isinstance(model, mlflow.pyfunc.PyFuncModel)
|
@@ -142,13 +146,13 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
|
|
142
146
|
except (mlflow.MlflowException, OSError):
|
143
147
|
raise ValueError("Cannot load MLFlow model artifacts.")
|
144
148
|
|
145
|
-
file_utils.copy_file_or_tree(local_path, os.path.join(model_blob_path, cls.
|
149
|
+
file_utils.copy_file_or_tree(local_path, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
|
146
150
|
|
147
151
|
base_meta = model_blob_meta.ModelBlobMeta(
|
148
152
|
name=name,
|
149
153
|
model_type=cls.HANDLER_TYPE,
|
150
154
|
handler_version=cls.HANDLER_VERSION,
|
151
|
-
path=cls.
|
155
|
+
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
152
156
|
options=model_meta_schema.MLFlowModelBlobOptions({"artifact_path": model_info.artifact_path}),
|
153
157
|
)
|
154
158
|
model_meta.models[name] = base_meta
|
@@ -194,6 +198,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
|
|
194
198
|
cls,
|
195
199
|
raw_model: "mlflow.pyfunc.PyFuncModel",
|
196
200
|
model_meta: model_meta_api.ModelMetadata,
|
201
|
+
background_data: Optional[pd.DataFrame] = None,
|
197
202
|
**kwargs: Unpack[model_types.MLFlowLoadOptions],
|
198
203
|
) -> custom_model.CustomModel:
|
199
204
|
from snowflake.ml.model import custom_model
|
@@ -37,7 +37,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
37
37
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
38
38
|
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
|
39
39
|
|
40
|
-
|
40
|
+
MODEL_BLOB_FILE_OR_DIR = "model.pt"
|
41
41
|
DEFAULT_TARGET_METHODS = ["forward"]
|
42
42
|
|
43
43
|
@classmethod
|
@@ -73,6 +73,10 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
73
73
|
is_sub_model: Optional[bool] = False,
|
74
74
|
**kwargs: Unpack[model_types.PyTorchSaveOptions],
|
75
75
|
) -> None:
|
76
|
+
enable_explainability = kwargs.get("enable_explainability", False)
|
77
|
+
if enable_explainability:
|
78
|
+
raise NotImplementedError("Explainability is not supported for PyTorch model.")
|
79
|
+
|
76
80
|
import torch
|
77
81
|
|
78
82
|
assert isinstance(model, torch.nn.Module)
|
@@ -115,13 +119,13 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
115
119
|
cloudpickle.register_pickle_by_value(sys.modules[model.__module__])
|
116
120
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
117
121
|
os.makedirs(model_blob_path, exist_ok=True)
|
118
|
-
with open(os.path.join(model_blob_path, cls.
|
122
|
+
with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
|
119
123
|
torch.save(model, f, pickle_module=cloudpickle)
|
120
124
|
base_meta = model_blob_meta.ModelBlobMeta(
|
121
125
|
name=name,
|
122
126
|
model_type=cls.HANDLER_TYPE,
|
123
127
|
handler_version=cls.HANDLER_VERSION,
|
124
|
-
path=cls.
|
128
|
+
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
125
129
|
)
|
126
130
|
model_meta.models[name] = base_meta
|
127
131
|
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
@@ -156,6 +160,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
156
160
|
cls,
|
157
161
|
raw_model: "torch.nn.Module",
|
158
162
|
model_meta: model_meta_api.ModelMetadata,
|
163
|
+
background_data: Optional[pd.DataFrame] = None,
|
159
164
|
**kwargs: Unpack[model_types.PyTorchLoadOptions],
|
160
165
|
) -> custom_model.CustomModel:
|
161
166
|
import torch
|
@@ -31,7 +31,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
31
31
|
_MIN_SNOWPARK_ML_VERSION = "1.3.1"
|
32
32
|
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
|
33
33
|
|
34
|
-
|
34
|
+
MODEL_BLOB_FILE_OR_DIR = "model"
|
35
35
|
DEFAULT_TARGET_METHODS = ["encode"]
|
36
36
|
|
37
37
|
@classmethod
|
@@ -64,6 +64,10 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
64
64
|
is_sub_model: Optional[bool] = False,
|
65
65
|
**kwargs: Unpack[model_types.SentenceTransformersSaveOptions], # registry.log_model(options={...})
|
66
66
|
) -> None:
|
67
|
+
enable_explainability = kwargs.get("enable_explainability", False)
|
68
|
+
if enable_explainability:
|
69
|
+
raise NotImplementedError("Explainability is not supported for Sentence Transformer model.")
|
70
|
+
|
67
71
|
# Validate target methods and signature (if possible)
|
68
72
|
if not is_sub_model:
|
69
73
|
target_methods = handlers_utils.get_target_methods(
|
@@ -101,14 +105,14 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
101
105
|
# save model
|
102
106
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
103
107
|
os.makedirs(model_blob_path, exist_ok=True)
|
104
|
-
model.save(os.path.join(model_blob_path, cls.
|
108
|
+
model.save(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
|
105
109
|
|
106
110
|
# save model metadata
|
107
111
|
base_meta = model_blob_meta.ModelBlobMeta(
|
108
112
|
name=name,
|
109
113
|
model_type=cls.HANDLER_TYPE,
|
110
114
|
handler_version=cls.HANDLER_VERSION,
|
111
|
-
path=cls.
|
115
|
+
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
112
116
|
)
|
113
117
|
model_meta.models[name] = base_meta
|
114
118
|
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
@@ -154,6 +158,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
154
158
|
cls,
|
155
159
|
raw_model: "sentence_transformers.SentenceTransformer",
|
156
160
|
model_meta: model_meta_api.ModelMetadata,
|
161
|
+
background_data: Optional[pd.DataFrame] = None,
|
157
162
|
**kwargs: Unpack[model_types.SentenceTransformersLoadOptions],
|
158
163
|
) -> custom_model.CustomModel:
|
159
164
|
import sentence_transformers
|
@@ -6,6 +6,7 @@ import numpy as np
|
|
6
6
|
import pandas as pd
|
7
7
|
from typing_extensions import TypeGuard, Unpack
|
8
8
|
|
9
|
+
import snowflake.snowpark.dataframe as sp_df
|
9
10
|
from snowflake.ml._internal import type_utils
|
10
11
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
11
12
|
from snowflake.ml.model._packager.model_env import model_env
|
@@ -14,8 +15,13 @@ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
|
14
15
|
from snowflake.ml.model._packager.model_meta import (
|
15
16
|
model_blob_meta,
|
16
17
|
model_meta as model_meta_api,
|
18
|
+
model_meta_schema,
|
19
|
+
)
|
20
|
+
from snowflake.ml.model._signatures import (
|
21
|
+
numpy_handler,
|
22
|
+
snowpark_handler,
|
23
|
+
utils as model_signature_utils,
|
17
24
|
)
|
18
|
-
from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
|
19
25
|
|
20
26
|
if TYPE_CHECKING:
|
21
27
|
import sklearn.base
|
@@ -36,6 +42,27 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
36
42
|
|
37
43
|
DEFAULT_TARGET_METHODS = ["predict", "transform", "predict_proba", "predict_log_proba", "decision_function"]
|
38
44
|
|
45
|
+
@classmethod
|
46
|
+
def get_model_objective(
|
47
|
+
cls, model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]
|
48
|
+
) -> model_meta_schema.ModelObjective:
|
49
|
+
import sklearn.pipeline
|
50
|
+
from sklearn.base import is_classifier, is_regressor
|
51
|
+
|
52
|
+
if isinstance(model, sklearn.pipeline.Pipeline):
|
53
|
+
return model_meta_schema.ModelObjective.UNKNOWN
|
54
|
+
if is_regressor(model):
|
55
|
+
return model_meta_schema.ModelObjective.REGRESSION
|
56
|
+
if is_classifier(model):
|
57
|
+
classes_list = getattr(model, "classes_", [])
|
58
|
+
num_classes = getattr(model, "n_classes_", None) or len(classes_list)
|
59
|
+
if isinstance(num_classes, int):
|
60
|
+
if num_classes > 2:
|
61
|
+
return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
|
62
|
+
return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
|
63
|
+
return model_meta_schema.ModelObjective.UNKNOWN
|
64
|
+
return model_meta_schema.ModelObjective.UNKNOWN
|
65
|
+
|
39
66
|
@classmethod
|
40
67
|
def can_handle(
|
41
68
|
cls,
|
@@ -79,11 +106,33 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
79
106
|
is_sub_model: Optional[bool] = False,
|
80
107
|
**kwargs: Unpack[model_types.SKLModelSaveOptions],
|
81
108
|
) -> None:
|
109
|
+
enable_explainability = kwargs.get("enable_explainability", False)
|
110
|
+
|
82
111
|
import sklearn.base
|
83
112
|
import sklearn.pipeline
|
84
113
|
|
85
114
|
assert isinstance(model, sklearn.base.BaseEstimator) or isinstance(model, sklearn.pipeline.Pipeline)
|
86
115
|
|
116
|
+
enable_explainability = kwargs.get("enable_explainability", False)
|
117
|
+
if enable_explainability:
|
118
|
+
# TODO: Currently limited to pandas df, need to extend to other types.
|
119
|
+
if sample_input_data is None or not (
|
120
|
+
isinstance(sample_input_data, pd.DataFrame) or isinstance(sample_input_data, sp_df.DataFrame)
|
121
|
+
):
|
122
|
+
raise ValueError(
|
123
|
+
"Sample input data is required to enable explainability. Currently we only support this for "
|
124
|
+
+ "`pandas.DataFrame` and `snowflake.snowpark.dataframe.DataFrame`."
|
125
|
+
)
|
126
|
+
sample_input_data_pandas = (
|
127
|
+
sample_input_data
|
128
|
+
if isinstance(sample_input_data, pd.DataFrame)
|
129
|
+
else snowpark_handler.SnowparkDataFrameHandler.convert_to_df(sample_input_data)
|
130
|
+
)
|
131
|
+
data_blob_path = os.path.join(model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR)
|
132
|
+
os.makedirs(data_blob_path, exist_ok=True)
|
133
|
+
with open(os.path.join(data_blob_path, name + cls.BG_DATA_FILE_SUFFIX), "wb") as f:
|
134
|
+
sample_input_data_pandas.to_parquet(f)
|
135
|
+
|
87
136
|
if not is_sub_model:
|
88
137
|
target_methods = handlers_utils.get_target_methods(
|
89
138
|
model=model,
|
@@ -110,19 +159,36 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
110
159
|
get_prediction_fn=get_prediction,
|
111
160
|
)
|
112
161
|
|
162
|
+
if enable_explainability:
|
163
|
+
output_type = model_signature.DataType.DOUBLE
|
164
|
+
if cls.get_model_objective(model) == model_meta_schema.ModelObjective.MULTI_CLASSIFICATION:
|
165
|
+
output_type = model_signature.DataType.STRING
|
166
|
+
model_meta = handlers_utils.add_explain_method_signature(
|
167
|
+
model_meta=model_meta,
|
168
|
+
explain_method="explain",
|
169
|
+
target_method="predict",
|
170
|
+
output_return_type=output_type,
|
171
|
+
)
|
172
|
+
|
113
173
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
114
174
|
os.makedirs(model_blob_path, exist_ok=True)
|
115
|
-
with open(os.path.join(model_blob_path, cls.
|
175
|
+
with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
|
116
176
|
cloudpickle.dump(model, f)
|
117
177
|
base_meta = model_blob_meta.ModelBlobMeta(
|
118
178
|
name=name,
|
119
179
|
model_type=cls.HANDLER_TYPE,
|
120
180
|
handler_version=cls.HANDLER_VERSION,
|
121
|
-
path=cls.
|
181
|
+
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
122
182
|
)
|
123
183
|
model_meta.models[name] = base_meta
|
124
184
|
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
125
185
|
|
186
|
+
if enable_explainability:
|
187
|
+
model_meta.env.include_if_absent(
|
188
|
+
[model_env.ModelDependency(requirement="shap", pip_name="shap")],
|
189
|
+
check_local_version=True,
|
190
|
+
)
|
191
|
+
|
126
192
|
model_meta.env.include_if_absent(
|
127
193
|
[model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn")], check_local_version=True
|
128
194
|
)
|
@@ -153,6 +219,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
153
219
|
cls,
|
154
220
|
raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
|
155
221
|
model_meta: model_meta_api.ModelMetadata,
|
222
|
+
background_data: Optional[pd.DataFrame] = None,
|
156
223
|
**kwargs: Unpack[model_types.SKLModelLoadOptions],
|
157
224
|
) -> custom_model.CustomModel:
|
158
225
|
from snowflake.ml.model import custom_model
|
@@ -165,6 +232,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
165
232
|
raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
|
166
233
|
signature: model_signature.ModelSignature,
|
167
234
|
target_method: str,
|
235
|
+
background_data: Optional[pd.DataFrame],
|
168
236
|
) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
|
169
237
|
@custom_model.inference_api
|
170
238
|
def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
|
@@ -179,11 +247,26 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
179
247
|
|
180
248
|
return model_signature_utils.rename_pandas_df(df, signature.outputs)
|
181
249
|
|
250
|
+
@custom_model.inference_api
|
251
|
+
def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
|
252
|
+
import shap
|
253
|
+
|
254
|
+
# TODO: if not resolved by explainer, we need to pass the callable function
|
255
|
+
try:
|
256
|
+
explainer = shap.Explainer(raw_model, background_data)
|
257
|
+
df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
|
258
|
+
except TypeError as e:
|
259
|
+
raise ValueError(f"Explanation for this model type not supported yet: {str(e)}")
|
260
|
+
return model_signature_utils.rename_pandas_df(df, signature.outputs)
|
261
|
+
|
262
|
+
if target_method == "explain":
|
263
|
+
return explain_fn
|
264
|
+
|
182
265
|
return fn
|
183
266
|
|
184
267
|
type_method_dict = {}
|
185
268
|
for target_method_name, sig in model_meta.signatures.items():
|
186
|
-
type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
|
269
|
+
type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name, background_data)
|
187
270
|
|
188
271
|
_SKLModel = type(
|
189
272
|
"_SKLModel",
|
@@ -73,6 +73,10 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
73
73
|
is_sub_model: Optional[bool] = False,
|
74
74
|
**kwargs: Unpack[model_types.SNOWModelSaveOptions],
|
75
75
|
) -> None:
|
76
|
+
enable_explainability = kwargs.get("enable_explainability", False)
|
77
|
+
if enable_explainability:
|
78
|
+
raise NotImplementedError("Explainability is not supported for Snowpark ML model.")
|
79
|
+
|
76
80
|
from snowflake.ml.modeling.framework.base import BaseEstimator
|
77
81
|
|
78
82
|
assert isinstance(model, BaseEstimator)
|
@@ -103,13 +107,13 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
103
107
|
|
104
108
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
105
109
|
os.makedirs(model_blob_path, exist_ok=True)
|
106
|
-
with open(os.path.join(model_blob_path, cls.
|
110
|
+
with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
|
107
111
|
cloudpickle.dump(model, f)
|
108
112
|
base_meta = model_blob_meta.ModelBlobMeta(
|
109
113
|
name=name,
|
110
114
|
model_type=cls.HANDLER_TYPE,
|
111
115
|
handler_version=cls.HANDLER_VERSION,
|
112
|
-
path=cls.
|
116
|
+
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
113
117
|
)
|
114
118
|
model_meta.models[name] = base_meta
|
115
119
|
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
@@ -146,6 +150,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
146
150
|
cls,
|
147
151
|
raw_model: "BaseEstimator",
|
148
152
|
model_meta: model_meta_api.ModelMetadata,
|
153
|
+
background_data: Optional[pd.DataFrame] = None,
|
149
154
|
**kwargs: Unpack[model_types.SNOWModelLoadOptions],
|
150
155
|
) -> custom_model.CustomModel:
|
151
156
|
from snowflake.ml.model import custom_model
|
@@ -36,7 +36,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
36
36
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
37
37
|
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
|
38
38
|
|
39
|
-
|
39
|
+
MODEL_BLOB_FILE_OR_DIR = "model"
|
40
40
|
DEFAULT_TARGET_METHODS = ["__call__"]
|
41
41
|
|
42
42
|
@classmethod
|
@@ -68,6 +68,10 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
68
68
|
is_sub_model: Optional[bool] = False,
|
69
69
|
**kwargs: Unpack[model_types.TensorflowSaveOptions],
|
70
70
|
) -> None:
|
71
|
+
enable_explainability = kwargs.get("enable_explainability", False)
|
72
|
+
if enable_explainability:
|
73
|
+
raise NotImplementedError("Explainability is not supported for Tensorflow model.")
|
74
|
+
|
71
75
|
import tensorflow
|
72
76
|
|
73
77
|
assert isinstance(model, tensorflow.Module)
|
@@ -114,15 +118,15 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
114
118
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
115
119
|
os.makedirs(model_blob_path, exist_ok=True)
|
116
120
|
if isinstance(model, tensorflow.keras.Model):
|
117
|
-
tensorflow.keras.models.save_model(model, os.path.join(model_blob_path, cls.
|
121
|
+
tensorflow.keras.models.save_model(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
|
118
122
|
else:
|
119
|
-
tensorflow.saved_model.save(model, os.path.join(model_blob_path, cls.
|
123
|
+
tensorflow.saved_model.save(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
|
120
124
|
|
121
125
|
base_meta = model_blob_meta.ModelBlobMeta(
|
122
126
|
name=name,
|
123
127
|
model_type=cls.HANDLER_TYPE,
|
124
128
|
handler_version=cls.HANDLER_VERSION,
|
125
|
-
path=cls.
|
129
|
+
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
126
130
|
)
|
127
131
|
model_meta.models[name] = base_meta
|
128
132
|
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
@@ -156,6 +160,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
156
160
|
cls,
|
157
161
|
raw_model: "tensorflow.Module",
|
158
162
|
model_meta: model_meta_api.ModelMetadata,
|
163
|
+
background_data: Optional[pd.DataFrame] = None,
|
159
164
|
**kwargs: Unpack[model_types.TensorflowLoadOptions],
|
160
165
|
) -> custom_model.CustomModel:
|
161
166
|
import tensorflow
|
@@ -34,7 +34,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
|
|
34
34
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
35
35
|
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
|
36
36
|
|
37
|
-
|
37
|
+
MODEL_BLOB_FILE_OR_DIR = "model.pt"
|
38
38
|
DEFAULT_TARGET_METHODS = ["forward"]
|
39
39
|
|
40
40
|
@classmethod
|
@@ -66,6 +66,10 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
|
|
66
66
|
is_sub_model: Optional[bool] = False,
|
67
67
|
**kwargs: Unpack[model_types.TorchScriptSaveOptions],
|
68
68
|
) -> None:
|
69
|
+
enable_explainability = kwargs.get("enable_explainability", False)
|
70
|
+
if enable_explainability:
|
71
|
+
raise NotImplementedError("Explainability is not supported for Torch Script model.")
|
72
|
+
|
69
73
|
import torch
|
70
74
|
|
71
75
|
assert isinstance(model, torch.jit.ScriptModule) # type:ignore[attr-defined]
|
@@ -106,13 +110,13 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
|
|
106
110
|
|
107
111
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
108
112
|
os.makedirs(model_blob_path, exist_ok=True)
|
109
|
-
with open(os.path.join(model_blob_path, cls.
|
113
|
+
with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
|
110
114
|
torch.jit.save(model, f) # type:ignore[attr-defined]
|
111
115
|
base_meta = model_blob_meta.ModelBlobMeta(
|
112
116
|
name=name,
|
113
117
|
model_type=cls.HANDLER_TYPE,
|
114
118
|
handler_version=cls.HANDLER_VERSION,
|
115
|
-
path=cls.
|
119
|
+
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
116
120
|
)
|
117
121
|
model_meta.models[name] = base_meta
|
118
122
|
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
@@ -152,6 +156,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
|
|
152
156
|
cls,
|
153
157
|
raw_model: "torch.jit.ScriptModule", # type:ignore[name-defined]
|
154
158
|
model_meta: model_meta_api.ModelMetadata,
|
159
|
+
background_data: Optional[pd.DataFrame] = None,
|
155
160
|
**kwargs: Unpack[model_types.TorchScriptLoadOptions],
|
156
161
|
) -> custom_model.CustomModel:
|
157
162
|
from snowflake.ml.model import custom_model
|