snowflake-ml-python 1.7.4__py3-none-any.whl → 1.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +58 -3
- snowflake/ml/_internal/env_utils.py +64 -21
- snowflake/ml/_internal/file_utils.py +18 -4
- snowflake/ml/_internal/platform_capabilities.py +3 -0
- snowflake/ml/_internal/relax_version_strategy.py +16 -0
- snowflake/ml/_internal/telemetry.py +25 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -1
- snowflake/ml/feature_store/feature_store.py +18 -0
- snowflake/ml/feature_store/feature_view.py +46 -1
- snowflake/ml/fileset/fileset.py +0 -1
- snowflake/ml/jobs/_utils/constants.py +31 -1
- snowflake/ml/jobs/_utils/payload_utils.py +232 -72
- snowflake/ml/jobs/_utils/spec_utils.py +78 -38
- snowflake/ml/jobs/decorators.py +8 -25
- snowflake/ml/jobs/job.py +4 -4
- snowflake/ml/jobs/manager.py +5 -0
- snowflake/ml/model/_client/model/model_version_impl.py +1 -1
- snowflake/ml/model/_client/ops/model_ops.py +107 -14
- snowflake/ml/model/_client/ops/service_ops.py +1 -1
- snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
- snowflake/ml/model/_client/sql/model_version.py +58 -0
- snowflake/ml/model/_client/sql/service.py +8 -2
- snowflake/ml/model/_model_composer/model_composer.py +50 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
- snowflake/ml/model/_packager/model_env/model_env.py +49 -29
- snowflake/ml/model/_packager/model_handlers/_utils.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +44 -24
- snowflake/ml/model/_packager/model_handlers/keras.py +226 -0
- snowflake/ml/model/_packager/model_handlers/pytorch.py +51 -20
- snowflake/ml/model/_packager/model_handlers/sklearn.py +25 -3
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +73 -21
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -72
- snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
- snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
- snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2023_12_01.py +48 -0
- snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
- snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +6 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +16 -0
- snowflake/ml/model/_packager/model_packager.py +3 -5
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +8 -1
- snowflake/ml/model/_packager/model_task/model_task_utils.py +5 -1
- snowflake/ml/model/_signatures/builtins_handler.py +20 -9
- snowflake/ml/model/_signatures/core.py +54 -33
- snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
- snowflake/ml/model/_signatures/numpy_handler.py +12 -20
- snowflake/ml/model/_signatures/pandas_handler.py +28 -37
- snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
- snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
- snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
- snowflake/ml/model/_signatures/utils.py +120 -8
- snowflake/ml/model/custom_model.py +13 -4
- snowflake/ml/model/model_signature.py +39 -13
- snowflake/ml/model/type_hints.py +28 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +14 -1
- snowflake/ml/modeling/metrics/ranking.py +3 -0
- snowflake/ml/modeling/metrics/regression.py +3 -0
- snowflake/ml/modeling/pipeline/pipeline.py +18 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +2 -2
- snowflake/ml/registry/_manager/model_manager.py +55 -7
- snowflake/ml/registry/registry.py +52 -4
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/METADATA +336 -27
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/RECORD +73 -66
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info/licenses}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_packager/model_handlers/pytorch.py

@@ -10,10 +10,14 @@ from snowflake.ml._internal import type_utils
 from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
 from snowflake.ml.model._packager.model_env import model_env
 from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
-from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
+from snowflake.ml.model._packager.model_handlers_migrator import (
+    base_migrator,
+    pytorch_migrator_2023_12_01,
+)
 from snowflake.ml.model._packager.model_meta import (
     model_blob_meta,
     model_meta as model_meta_api,
+    model_meta_schema,
 )
 from snowflake.ml.model._signatures import (
     pytorch_handler,
@@ -21,7 +25,6 @@ from snowflake.ml.model._signatures import (
 )
 
 if TYPE_CHECKING:
-    import sentence_transformers  # noqa: F401
     import torch
 
 
@@ -33,9 +36,11 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
     """
 
     HANDLER_TYPE = "pytorch"
-    HANDLER_VERSION = "2023-12-01"
-    _MIN_SNOWPARK_ML_VERSION = "1.0.12"
-    _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
+    HANDLER_VERSION = "2025-03-01"
+    _MIN_SNOWPARK_ML_VERSION = "1.8.0"
+    _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
+        "2023-12-01": pytorch_migrator_2023_12_01.PyTorchHandlerMigrator20231201
+    }
 
     MODEL_BLOB_FILE_OR_DIR = "model.pt"
     DEFAULT_TARGET_METHODS = ["forward"]
@@ -49,6 +54,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
             type_utils.LazyType("torch.nn.Module").isinstance(model)
             and not type_utils.LazyType("torch.jit.ScriptModule").isinstance(model)
             and not type_utils.LazyType("sentence_transformers.SentenceTransformer").isinstance(model)
+            and not type_utils.LazyType("keras.Model").isinstance(model)
         )
 
     @classmethod
@@ -88,22 +94,33 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
             default_target_methods=cls.DEFAULT_TARGET_METHODS,
         )
 
+        multiple_inputs = kwargs.get("multiple_inputs", False)
+
         def get_prediction(
             target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
         ) -> model_types.SupportedLocalDataType:
-            if not pytorch_handler.SeqOfPyTorchTensorHandler.can_handle(sample_input_data):
-                sample_input_data = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(
-                    model_signature._convert_local_data_to_df(sample_input_data)
-                )
+            if multiple_inputs:
+                if not pytorch_handler.SeqOfPyTorchTensorHandler.can_handle(sample_input_data):
+                    sample_input_data = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(
+                        model_signature._convert_local_data_to_df(sample_input_data)
+                    )
+            else:
+                if not pytorch_handler.PyTorchTensorHandler.can_handle(sample_input_data):
+                    sample_input_data = pytorch_handler.PyTorchTensorHandler.convert_from_df(
+                        model_signature._convert_local_data_to_df(sample_input_data)
+                    )
 
             model.eval()
             target_method = getattr(model, target_method_name, None)
             assert callable(target_method)
             with torch.no_grad():
-                predictions_df = target_method(*sample_input_data)
+                if multiple_inputs:
+                    predictions_df = target_method(*sample_input_data)
+                    if not isinstance(predictions_df, tuple):
+                        predictions_df = [predictions_df]
+                else:
+                    predictions_df = target_method(sample_input_data)
 
-            if isinstance(predictions_df, torch.Tensor):
-                predictions_df = [predictions_df]
             return predictions_df
 
         model_meta = handlers_utils.validate_signature(
@@ -126,6 +143,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
             path=cls.MODEL_BLOB_FILE_OR_DIR,
+            options=model_meta_schema.PyTorchModelBlobOptions(multiple_inputs=multiple_inputs),
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
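The new `multiple_inputs` save option (default `False`) switches the handler between unpacking a sequence of tensors into the target method and passing a single tensor straight through. A minimal sketch of the two calling conventions the diff above distinguishes; `TwoInputModel` is a hypothetical module used only for illustration:

```python
import torch

# Hypothetical two-input module, standing in for a user model.
class TwoInputModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 1)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        return self.linear(x1 + x2)

model = TwoInputModel()
model.eval()

with torch.no_grad():
    # multiple_inputs=True: sample data is a sequence of tensors, unpacked
    # like target_method(*sample_input_data) in the handler above.
    out_multi = model(torch.rand(8, 4), torch.rand(8, 4))

single_input_model = torch.nn.Linear(4, 1).eval()
with torch.no_grad():
    # multiple_inputs=False (the new default): one tensor is passed as-is,
    # like target_method(sample_input_data) in the handler above.
    out_single = single_input_model(torch.rand(8, 4))
```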
@@ -171,6 +189,10 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
         raw_model: "torch.nn.Module",
         model_meta: model_meta_api.ModelMetadata,
     ) -> Type[custom_model.CustomModel]:
+        multiple_inputs = cast(
+            model_meta_schema.PyTorchModelBlobOptions, model_meta.models[model_meta.name].options
+        )["multiple_inputs"]
+
         def fn_factory(
             raw_model: "torch.nn.Module",
             signature: model_signature.ModelSignature,
@@ -182,19 +204,28 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
                     raise ValueError("Tensor cannot handle null values.")
 
                 raw_model.eval()
-                t = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(X, signature.inputs)
+                if multiple_inputs:
+                    st = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(X, signature.inputs)
+
+                    if kwargs.get("use_gpu", False):
+                        st = [element.cuda() for element in st]
 
-                if kwargs.get("use_gpu", False):
-                    t = [element.cuda() for element in t]
+                    with torch.no_grad():
+                        res = getattr(raw_model, target_method)(*st)
 
-                with torch.no_grad():
-                    res = getattr(raw_model, target_method)(*t)
+                    if not isinstance(res, tuple):
+                        res = [res]
+                else:
+                    t = pytorch_handler.PyTorchTensorHandler.convert_from_df(X, signature.inputs)
+                    if kwargs.get("use_gpu", False):
+                        t = t.cuda()
 
-                if isinstance(res, torch.Tensor):
-                    res = [res]
+                    with torch.no_grad():
+                        res = getattr(raw_model, target_method)(t)
 
                 return model_signature_utils.rename_pandas_df(
-                    data=pytorch_handler.SeqOfPyTorchTensorHandler.convert_to_df(res), features=signature.outputs
+                    model_signature._convert_local_data_to_df(res, ensure_serializable=True),
+                    features=signature.outputs,
                 )
 
             return fn
snowflake/ml/model/_packager/model_handlers/sklearn.py

@@ -292,12 +292,34 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]]):
             def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
                 import shap
 
-                # TODO: if not resolved by explainer, we need to pass the callable function
                 try:
                     explainer = shap.Explainer(raw_model, background_data)
                     df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
-                except TypeError as e:
-                    raise ValueError(f"Explanation for this model type not supported yet: {str(e)}")
+                except TypeError:
+                    try:
+                        dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs}
+
+                        if isinstance(X, pd.DataFrame):
+                            X = X.astype(dtype_map, copy=False)
+                        if hasattr(raw_model, "predict_proba"):
+                            if isinstance(X, np.ndarray):
+                                explanations = shap.Explainer(
+                                    raw_model.predict_proba, background_data.values  # type: ignore[union-attr]
+                                )(X).values
+                            else:
+                                explanations = shap.Explainer(raw_model.predict_proba, background_data)(X).values
+                        elif hasattr(raw_model, "predict"):
+                            if isinstance(X, np.ndarray):
+                                explanations = shap.Explainer(
+                                    raw_model.predict, background_data.values  # type: ignore[union-attr]
+                                )(X).values
+                            else:
+                                explanations = shap.Explainer(raw_model.predict, background_data)(X).values
+                        else:
+                            raise ValueError("Missing any supported target method to explain.")
+                        df = handlers_utils.convert_explanations_to_2D_df(raw_model, explanations)
+                    except TypeError as e:
+                        raise ValueError(f"Explanation for this model type not supported yet: {str(e)}")
                 return model_signature_utils.rename_pandas_df(df, signature.outputs)
 
             if target_method == "explain":
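The fallback above follows a common SHAP pattern: when `shap.Explainer` cannot dispatch on the estimator object itself and raises `TypeError`, wrap the prediction callable instead, preferring `predict_proba` over `predict`. A standalone sketch of the same pattern with an illustrative sklearn classifier (the data and model here are placeholders, not taken from the package):

```python
import numpy as np
import pandas as pd
import shap
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(0)
background = pd.DataFrame(rng.random((100, 4)), columns=list("abcd"))
labels = rng.integers(0, 2, 100)
clf = RandomForestClassifier(n_estimators=10).fit(background, labels)

try:
    # Preferred path: let shap dispatch on the model object itself.
    values = shap.Explainer(clf, background)(background.head(5)).values
except TypeError:
    # Fallback path: explain the prediction callable directly,
    # preferring predict_proba when the estimator exposes it.
    fn = clf.predict_proba if hasattr(clf, "predict_proba") else clf.predict
    values = shap.Explainer(fn, background)(background.head(5)).values
```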
snowflake/ml/model/_packager/model_handlers/snowmlmodel.py

@@ -74,11 +74,6 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         background_data: Optional[model_types.SupportedDataType],
         enable_explainability: Optional[bool],
     ) -> Any:
-        from snowflake.ml.modeling import pipeline as snowml_pipeline
-
-        # handle pipeline objects separately
-        if isinstance(estimator, snowml_pipeline.Pipeline):  # type: ignore[attr-defined]
-            return None
 
         tree_methods = ["to_xgboost", "to_lightgbm", "to_sklearn"]
         non_tree_methods = ["to_sklearn"]
@@ -129,27 +124,54 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         # Pipeline is inherited from BaseEstimator, so no need to add one more check
 
         if not is_sub_model:
-            if model_meta.signatures:
+            if model_meta.signatures or sample_input_data is not None:
                 warnings.warn(
                     "Providing model signature for Snowpark ML "
                     + "Modeling model is not required. Model signature will automatically be inferred during fitting. ",
                     UserWarning,
                     stacklevel=2,
                 )
-            assert hasattr(model, "model_signatures"), "Model does not have model signatures as expected."
-            model_signature_dict = getattr(model, "model_signatures", {})
-            optional_target_methods = kwargs.pop("target_methods", None)
-            if not optional_target_methods:
-                model_meta.signatures = model_signature_dict
+                target_methods = handlers_utils.get_target_methods(
+                    model=model,
+                    target_methods=kwargs.pop("target_methods", None),
+                    default_target_methods=cls.DEFAULT_TARGET_METHODS,
+                )
+
+                def get_prediction(
+                    target_method_name: str,
+                    sample_input_data: model_types.SupportedLocalDataType,
+                ) -> model_types.SupportedLocalDataType:
+                    if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray)):
+                        sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
+
+                    target_method = getattr(model, target_method_name, None)
+                    assert callable(target_method)
+                    predictions_df = target_method(sample_input_data)
+                    return predictions_df
+
+                model_meta = handlers_utils.validate_signature(
+                    model=model,
+                    model_meta=model_meta,
+                    target_methods=target_methods,
+                    sample_input_data=sample_input_data,
+                    get_prediction_fn=get_prediction,
+                    is_for_modeling_model=True,
+                )
             else:
-                temp_model_signature_dict = {}
-                for method_name in optional_target_methods:
-                    method_model_signature = model_signature_dict.get(method_name, None)
-                    if method_model_signature is not None:
-                        temp_model_signature_dict[method_name] = method_model_signature
-                    else:
-                        raise ValueError(f"Target method {method_name} does not exist in the model.")
-                model_meta.signatures = temp_model_signature_dict
+                assert hasattr(model, "model_signatures"), "Model does not have model signatures as expected."
+                model_signature_dict = getattr(model, "model_signatures", {})
+                optional_target_methods = kwargs.pop("target_methods", None)
+                if not optional_target_methods:
+                    model_meta.signatures = model_signature_dict
+                else:
+                    temp_model_signature_dict = {}
+                    for method_name in optional_target_methods:
+                        method_model_signature = model_signature_dict.get(method_name, None)
+                        if method_model_signature is not None:
+                            temp_model_signature_dict[method_name] = method_model_signature
+                        else:
+                            raise ValueError(f"Target method {method_name} does not exist in the model.")
+                    model_meta.signatures = temp_model_signature_dict
 
         python_base_obj = cls._get_supported_object_for_explainability(model, sample_input_data, enable_explainability)
         explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
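With this change, supplying `sample_input_data` for a Snowpark ML modeling model routes through `validate_signature` (and emits the warning above) instead of relying solely on the fitted `model_signatures`. A hedged sketch of what that looks like at the registry level; `session`, `fitted_estimator`, and `train_df` are assumed to already exist and are not defined in this diff:

```python
from snowflake.ml.registry import Registry

reg = Registry(session=session)  # assumes an existing Snowpark session

# Passing sample_input_data now triggers explicit signature inference for
# the modeling model (with the UserWarning shown in the hunk above).
mv = reg.log_model(
    fitted_estimator,              # a fitted snowflake.ml.modeling estimator (assumed)
    model_name="my_model",
    version_name="v1",
    sample_input_data=train_df.limit(100),
)
```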
@@ -279,9 +301,39 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         for method_name in non_tree_methods:
             try:
                 base_model = getattr(raw_model, method_name)()
-                explainer = shap.Explainer(base_model, masker=background_data)
-                df = handlers_utils.convert_explanations_to_2D_df(base_model, explainer(X).values)
+                try:
+                    explainer = shap.Explainer(base_model, masker=background_data)
+                    df = handlers_utils.convert_explanations_to_2D_df(base_model, explainer(X).values)
+                except TypeError:
+                    try:
+                        dtype_map = {
+                            spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs
+                        }
+
+                        if isinstance(X, pd.DataFrame):
+                            X = X.astype(dtype_map, copy=False)
+                        if hasattr(base_model, "predict_proba"):
+                            if isinstance(X, np.ndarray):
+                                explainer = shap.Explainer(
+                                    base_model.predict_proba,
+                                    background_data.values,  # type: ignore[union-attr]
+                                )
+                            else:
+                                explainer = shap.Explainer(base_model.predict_proba, background_data)
+                        elif hasattr(base_model, "predict"):
+                            if isinstance(X, np.ndarray):
+                                explainer = shap.Explainer(
+                                    base_model.predict, background_data.values  # type: ignore[union-attr]
+                                )
+                            else:
+                                explainer = shap.Explainer(base_model.predict, background_data)
+                        else:
+                            raise ValueError("Missing any supported target method to explain.")
+                        df = handlers_utils.convert_explanations_to_2D_df(base_model, explainer(X).values)
+                    except TypeError as e:
+                        raise ValueError(f"Explanation for this model type not supported yet: {str(e)}")
                 return model_signature_utils.rename_pandas_df(df, signature.outputs)
+
             except exceptions.SnowflakeMLException:
                 pass  # Do nothing and continue to the next method
         raise ValueError("The model must be an xgboost, lightgbm or sklearn (not pipeline) estimator.")
snowflake/ml/model/_packager/model_handlers/tensorflow.py

@@ -1,7 +1,6 @@
 import os
 from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
 
-import numpy as np
 import pandas as pd
 from packaging import version
 from typing_extensions import TypeGuard, Unpack
@@ -10,14 +9,17 @@ from snowflake.ml._internal import type_utils
 from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
 from snowflake.ml.model._packager.model_env import model_env
 from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
-from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
+from snowflake.ml.model._packager.model_handlers_migrator import (
+    base_migrator,
+    tensorflow_migrator_2023_12_01,
+    tensorflow_migrator_2025_01_01,
+)
 from snowflake.ml.model._packager.model_meta import (
     model_blob_meta,
     model_meta as model_meta_api,
     model_meta_schema,
 )
 from snowflake.ml.model._signatures import (
-    numpy_handler,
     tensorflow_handler,
     utils as model_signature_utils,
 )
@@ -28,15 +30,18 @@ if TYPE_CHECKING:
 
 @final
 class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
-    """Handler for TensorFlow based model.
+    """Handler for TensorFlow based model or keras v2 model.
 
     Currently tensorflow.Module based classes are supported.
     """
 
     HANDLER_TYPE = "tensorflow"
-    HANDLER_VERSION = "2023-12-01"
-    _MIN_SNOWPARK_ML_VERSION = "1.0.12"
-    _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
+    HANDLER_VERSION = "2025-03-01"
+    _MIN_SNOWPARK_ML_VERSION = "1.8.0"
+    _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
+        "2023-12-01": tensorflow_migrator_2023_12_01.TensorflowHandlerMigrator20231201,
+        "2025-01-01": tensorflow_migrator_2025_01_01.TensorflowHandlerMigrator20250101,
+    }
 
     MODEL_BLOB_FILE_OR_DIR = "model"
     DEFAULT_TARGET_METHODS = ["__call__"]
@@ -46,7 +51,13 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         cls,
         model: model_types.SupportedModelType,
     ) -> TypeGuard["tensorflow.nn.Module"]:
-        return type_utils.LazyType("tensorflow.Module").isinstance(model)
+        if not type_utils.LazyType("tensorflow.Module").isinstance(model):
+            return False
+        if type_utils.LazyType("keras.Model").isinstance(model):
+            import keras
+
+            return version.parse(keras.__version__) < version.parse("3.0.0")
+        return True
 
     @classmethod
     def cast_model(
@@ -74,44 +85,22 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         if enable_explainability:
             raise NotImplementedError("Explainability is not supported for Tensorflow model.")
 
-        # When tensorflow is installed, keras is also installed.
-        import keras
         import tensorflow
 
         assert isinstance(model, tensorflow.Module)
 
-        is_keras_model = type_utils.LazyType("tensorflow.keras.Model").isinstance(model) or type_utils.LazyType(
-            "keras.Model"
-        ).isinstance(model)
+        is_keras_model = type_utils.LazyType("keras.Model").isinstance(model)
         is_tf_keras_model = type_utils.LazyType("tf_keras.Model").isinstance(model)
-        is_keras_functional_or_sequential_model = (
-            getattr(model, "_is_graph_network", False)
-            or type_utils.LazyType("tensorflow.keras.engine.sequential.Sequential").isinstance(model)
-            or type_utils.LazyType("keras.engine.sequential.Sequential").isinstance(model)
-            or type_utils.LazyType("tf_keras.engine.sequential.Sequential").isinstance(model)
-        )
-
-        assert isinstance(model, tensorflow.Module)
-
-        keras_version = version.parse(keras.__version__)
-
         # Tensorflow and keras model save format is different.
-        # Keras
-        #
-        # Keras v2 other models are saved using tensorflow saved model format
-        # Tensorflow models are saved using tensorflow saved model format
+        # Keras v2 models are saved using keras api
+        # Tensorflow models are saved using tensorflow api
 
         if is_keras_model or is_tf_keras_model:
-            if is_keras_functional_or_sequential_model:
-                save_format = "keras"
-            elif keras_version.major == 2 or is_tf_keras_model:
-                save_format = "keras_tf"
-            else:
-                save_format = "cloudpickle"
+            save_format = "keras_tf"
         else:
             save_format = "tf"
 
-        if is_keras_model:
+        if is_keras_model or is_tf_keras_model:
            default_target_methods = ["predict"]
         else:
             default_target_methods = cls.DEFAULT_TARGET_METHODS
@@ -123,25 +112,35 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
             default_target_methods=default_target_methods,
         )
 
+        multiple_inputs = kwargs.get("multiple_inputs", False)
+
         if is_keras_model and len(target_methods) > 1:
             raise ValueError("Keras model can only have one target method.")
 
         def get_prediction(
             target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
         ) -> model_types.SupportedLocalDataType:
-            if not tensorflow_handler.SeqOfTensorflowTensorHandler.can_handle(sample_input_data):
-                sample_input_data = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(
-                    model_signature._convert_local_data_to_df(sample_input_data)
-                )
+            if multiple_inputs:
+                if not tensorflow_handler.SeqOfTensorflowTensorHandler.can_handle(sample_input_data):
+                    sample_input_data = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(
+                        model_signature._convert_local_data_to_df(sample_input_data)
+                    )
+            else:
+                if not tensorflow_handler.TensorflowTensorHandler.can_handle(sample_input_data):
+                    sample_input_data = tensorflow_handler.TensorflowTensorHandler.convert_from_df(
+                        model_signature._convert_local_data_to_df(sample_input_data)
+                    )
 
             target_method = getattr(model, target_method_name, None)
             assert callable(target_method)
             for tensor in sample_input_data:
                 tensorflow.stop_gradient(tensor)
-            predictions_df = target_method(*sample_input_data)
-
-            if isinstance(predictions_df, (tensorflow.Tensor, tensorflow.Variable)):
-                predictions_df = [predictions_df]
+            if multiple_inputs:
+                predictions_df = target_method(*sample_input_data)
+                if not isinstance(predictions_df, tuple):
+                    predictions_df = [predictions_df]
+            else:
+                predictions_df = target_method(sample_input_data)
 
             return predictions_df
 
@@ -156,15 +155,8 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
         save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
-        if save_format == "keras":
-            model.save(save_path, save_format="keras")
-        elif save_format == "keras_tf":
+        if save_format == "keras_tf":
             model.save(save_path, save_format="tf")
-        elif save_format == "cloudpickle":
-            import cloudpickle
-
-            with open(save_path, "wb") as f:
-                cloudpickle.dump(model, f)
         else:
             tensorflow.saved_model.save(
                 model,
|
|
177
169
|
model_type=cls.HANDLER_TYPE,
|
178
170
|
handler_version=cls.HANDLER_VERSION,
|
179
171
|
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
180
|
-
options=model_meta_schema.TensorflowModelBlobOptions(
|
172
|
+
options=model_meta_schema.TensorflowModelBlobOptions(
|
173
|
+
save_format=save_format, multiple_inputs=multiple_inputs
|
174
|
+
),
|
181
175
|
)
|
182
176
|
model_meta.models[name] = base_meta
|
183
177
|
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
@@ -186,7 +180,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
             model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"),
         ]
         if is_keras_model:
-            dependencies.append(model_env.ModelDependency(requirement="keras", pip_name="keras"))
+            dependencies.append(model_env.ModelDependency(requirement="keras<=3", pip_name="keras"))
        elif is_tf_keras_model:
             dependencies.append(model_env.ModelDependency(requirement="tf-keras", pip_name="tf-keras"))
 
@@ -204,6 +198,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         model_blobs_dir_path: str,
         **kwargs: Unpack[model_types.TensorflowLoadOptions],
     ) -> "tensorflow.Module":
+        os.environ["TF_USE_LEGACY_KERAS"] = "1"
         import tensorflow
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
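Setting `TF_USE_LEGACY_KERAS=1` only takes effect if it happens before TensorFlow (and therefore Keras) is first imported in the process, which is why the handler writes the environment variable ahead of `import tensorflow`. A minimal sketch of the same ordering:

```python
import os

# Must be set before the first TensorFlow import in the process; with the
# tf-keras package installed this keeps tf.keras pointing at the Keras 2 API.
os.environ["TF_USE_LEGACY_KERAS"] = "1"

import tensorflow as tf  # noqa: E402  (deliberately imported after the flag)

print(tf.keras.__name__)
```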
@@ -212,14 +207,9 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         model_blob_filename = model_blob_metadata.path
         model_blob_options = cast(model_meta_schema.TensorflowModelBlobOptions, model_blob_metadata.options)
         load_path = os.path.join(model_blob_path, model_blob_filename)
-        save_format = model_blob_options.get("save_format", "keras")
-        if save_format == "keras" or save_format == "keras_tf":
+        save_format = model_blob_options.get("save_format", "keras_tf")
+        if save_format == "keras_tf":
             m = tensorflow.keras.models.load_model(load_path)
-        elif save_format == "cloudpickle":
-            import cloudpickle
-
-            with open(load_path, "rb") as f:
-                m = cloudpickle.load(f)
         else:
             m = tensorflow.saved_model.load(load_path)
 
@@ -241,6 +231,10 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         raw_model: "tensorflow.Module",
         model_meta: model_meta_api.ModelMetadata,
     ) -> Type[custom_model.CustomModel]:
+        multiple_inputs = cast(
+            model_meta_schema.TensorflowModelBlobOptions, model_meta.models[model_meta.name].options
+        )["multiple_inputs"]
+
         def fn_factory(
             raw_model: "tensorflow.Module",
             signature: model_signature.ModelSignature,
@@ -251,21 +245,25 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
                 if X.isnull().any(axis=None):
                     raise ValueError("Tensor cannot handle null values.")
 
-                t = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(X, signature.inputs)
+                if multiple_inputs:
+                    t = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(X, signature.inputs)
 
-                for tensor in t:
-                    tensorflow.stop_gradient(tensor)
-                res = getattr(raw_model, target_method)(*t)
+                    for tensor in t:
+                        tensorflow.stop_gradient(tensor)
+                    res = getattr(raw_model, target_method)(*t)
 
-                if isinstance(res, (tensorflow.Tensor, tensorflow.Variable)):
-                    res = [res]
-
-                if isinstance(res, list) and len(res) > 0 and isinstance(res[0], np.ndarray):
-                    # In case of running on CPU, it will return numpy array
-                    df = numpy_handler.SeqOfNumpyArrayHandler.convert_to_df(res)
+                    if not isinstance(res, tuple):
+                        res = [res]
                 else:
-                    df = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_to_df(res)
-                return model_signature_utils.rename_pandas_df(df, features=signature.outputs)
+                    t = tensorflow_handler.TensorflowTensorHandler.convert_from_df(X, signature.inputs)
+
+                    tensorflow.stop_gradient(t)
+                    res = getattr(raw_model, target_method)(t)
+
+                return model_signature_utils.rename_pandas_df(
+                    model_signature._convert_local_data_to_df(res, ensure_serializable=True),
+                    features=signature.outputs,
+                )
 
             return fn
 
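The new single-input TensorFlow path above amounts to the following standalone snippet (the toy model stands in for `raw_model` and is not from the package): one tensor in, gradient tracking stopped, one tensor out, with the caller converting the result back to a DataFrame.

```python
import tensorflow as tf

# Toy single-input model; illustrative stand-in for raw_model above.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])

t = tf.random.uniform((8, 4))
t = tf.stop_gradient(t)  # mirrors tensorflow.stop_gradient(t) in the handler
res = model(t)           # mirrors getattr(raw_model, target_method)(t)
print(res.shape)         # (8, 1)
```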