snowflake-ml-python: snowflake_ml_python-1.7.4-py3-none-any.whl → snowflake_ml_python-1.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. snowflake/cortex/_complete.py +58 -3
  2. snowflake/ml/_internal/env_utils.py +64 -21
  3. snowflake/ml/_internal/file_utils.py +18 -4
  4. snowflake/ml/_internal/platform_capabilities.py +3 -0
  5. snowflake/ml/_internal/relax_version_strategy.py +16 -0
  6. snowflake/ml/_internal/telemetry.py +25 -0
  7. snowflake/ml/data/_internal/arrow_ingestor.py +1 -1
  8. snowflake/ml/feature_store/feature_store.py +18 -0
  9. snowflake/ml/feature_store/feature_view.py +46 -1
  10. snowflake/ml/fileset/fileset.py +0 -1
  11. snowflake/ml/jobs/_utils/constants.py +31 -1
  12. snowflake/ml/jobs/_utils/payload_utils.py +232 -72
  13. snowflake/ml/jobs/_utils/spec_utils.py +78 -38
  14. snowflake/ml/jobs/decorators.py +8 -25
  15. snowflake/ml/jobs/job.py +4 -4
  16. snowflake/ml/jobs/manager.py +5 -0
  17. snowflake/ml/model/_client/model/model_version_impl.py +1 -1
  18. snowflake/ml/model/_client/ops/model_ops.py +107 -14
  19. snowflake/ml/model/_client/ops/service_ops.py +1 -1
  20. snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
  21. snowflake/ml/model/_client/sql/model_version.py +58 -0
  22. snowflake/ml/model/_client/sql/service.py +8 -2
  23. snowflake/ml/model/_model_composer/model_composer.py +50 -3
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
  26. snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
  27. snowflake/ml/model/_packager/model_env/model_env.py +49 -29
  28. snowflake/ml/model/_packager/model_handlers/_utils.py +8 -4
  29. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +44 -24
  30. snowflake/ml/model/_packager/model_handlers/keras.py +226 -0
  31. snowflake/ml/model/_packager/model_handlers/pytorch.py +51 -20
  32. snowflake/ml/model/_packager/model_handlers/sklearn.py +25 -3
  33. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +73 -21
  34. snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -72
  35. snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
  36. snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
  37. snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
  38. snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2023_12_01.py +48 -0
  39. snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
  40. snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
  41. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  42. snowflake/ml/model/_packager/model_meta/model_meta.py +6 -2
  43. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +16 -0
  44. snowflake/ml/model/_packager/model_packager.py +3 -5
  45. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
  46. snowflake/ml/model/_packager/model_runtime/model_runtime.py +8 -1
  47. snowflake/ml/model/_packager/model_task/model_task_utils.py +5 -1
  48. snowflake/ml/model/_signatures/builtins_handler.py +20 -9
  49. snowflake/ml/model/_signatures/core.py +54 -33
  50. snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
  51. snowflake/ml/model/_signatures/numpy_handler.py +12 -20
  52. snowflake/ml/model/_signatures/pandas_handler.py +28 -37
  53. snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
  54. snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
  55. snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
  56. snowflake/ml/model/_signatures/utils.py +120 -8
  57. snowflake/ml/model/custom_model.py +13 -4
  58. snowflake/ml/model/model_signature.py +39 -13
  59. snowflake/ml/model/type_hints.py +28 -2
  60. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +14 -1
  61. snowflake/ml/modeling/metrics/ranking.py +3 -0
  62. snowflake/ml/modeling/metrics/regression.py +3 -0
  63. snowflake/ml/modeling/pipeline/pipeline.py +18 -1
  64. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
  65. snowflake/ml/modeling/preprocessing/polynomial_features.py +2 -2
  66. snowflake/ml/registry/_manager/model_manager.py +55 -7
  67. snowflake/ml/registry/registry.py +52 -4
  68. snowflake/ml/version.py +1 -1
  69. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/METADATA +336 -27
  70. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/RECORD +73 -66
  71. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/WHEEL +1 -1
  72. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info/licenses}/LICENSE.txt +0 -0
  73. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/top_level.txt +0 -0
@@ -10,10 +10,14 @@ from snowflake.ml._internal import type_utils
10
10
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
11
11
  from snowflake.ml.model._packager.model_env import model_env
12
12
  from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
13
- from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
13
+ from snowflake.ml.model._packager.model_handlers_migrator import (
14
+ base_migrator,
15
+ pytorch_migrator_2023_12_01,
16
+ )
14
17
  from snowflake.ml.model._packager.model_meta import (
15
18
  model_blob_meta,
16
19
  model_meta as model_meta_api,
20
+ model_meta_schema,
17
21
  )
18
22
  from snowflake.ml.model._signatures import (
19
23
  pytorch_handler,
@@ -21,7 +25,6 @@ from snowflake.ml.model._signatures import (
21
25
  )
22
26
 
23
27
  if TYPE_CHECKING:
24
- import sentence_transformers # noqa: F401
25
28
  import torch
26
29
 
27
30
 
@@ -33,9 +36,11 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
33
36
  """
34
37
 
35
38
  HANDLER_TYPE = "pytorch"
36
- HANDLER_VERSION = "2023-12-01"
37
- _MIN_SNOWPARK_ML_VERSION = "1.0.12"
38
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
39
+ HANDLER_VERSION = "2025-03-01"
40
+ _MIN_SNOWPARK_ML_VERSION = "1.8.0"
41
+ _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
42
+ "2023-12-01": pytorch_migrator_2023_12_01.PyTorchHandlerMigrator20231201
43
+ }
39
44
 
40
45
  MODEL_BLOB_FILE_OR_DIR = "model.pt"
41
46
  DEFAULT_TARGET_METHODS = ["forward"]
@@ -49,6 +54,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
49
54
  type_utils.LazyType("torch.nn.Module").isinstance(model)
50
55
  and not type_utils.LazyType("torch.jit.ScriptModule").isinstance(model)
51
56
  and not type_utils.LazyType("sentence_transformers.SentenceTransformer").isinstance(model)
57
+ and not type_utils.LazyType("keras.Model").isinstance(model)
52
58
  )
53
59
 
54
60
  @classmethod
@@ -88,22 +94,33 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
88
94
  default_target_methods=cls.DEFAULT_TARGET_METHODS,
89
95
  )
90
96
 
97
+ multiple_inputs = kwargs.get("multiple_inputs", False)
98
+
91
99
  def get_prediction(
92
100
  target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
93
101
  ) -> model_types.SupportedLocalDataType:
94
- if not pytorch_handler.SeqOfPyTorchTensorHandler.can_handle(sample_input_data):
95
- sample_input_data = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(
96
- model_signature._convert_local_data_to_df(sample_input_data)
97
- )
102
+ if multiple_inputs:
103
+ if not pytorch_handler.SeqOfPyTorchTensorHandler.can_handle(sample_input_data):
104
+ sample_input_data = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(
105
+ model_signature._convert_local_data_to_df(sample_input_data)
106
+ )
107
+ else:
108
+ if not pytorch_handler.PyTorchTensorHandler.can_handle(sample_input_data):
109
+ sample_input_data = pytorch_handler.PyTorchTensorHandler.convert_from_df(
110
+ model_signature._convert_local_data_to_df(sample_input_data)
111
+ )
98
112
 
99
113
  model.eval()
100
114
  target_method = getattr(model, target_method_name, None)
101
115
  assert callable(target_method)
102
116
  with torch.no_grad():
103
- predictions_df = target_method(*sample_input_data)
117
+ if multiple_inputs:
118
+ predictions_df = target_method(*sample_input_data)
119
+ if not isinstance(predictions_df, tuple):
120
+ predictions_df = [predictions_df]
121
+ else:
122
+ predictions_df = target_method(sample_input_data)
104
123
 
105
- if isinstance(predictions_df, torch.Tensor):
106
- predictions_df = [predictions_df]
107
124
  return predictions_df
108
125
 
109
126
  model_meta = handlers_utils.validate_signature(
@@ -126,6 +143,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
126
143
  model_type=cls.HANDLER_TYPE,
127
144
  handler_version=cls.HANDLER_VERSION,
128
145
  path=cls.MODEL_BLOB_FILE_OR_DIR,
146
+ options=model_meta_schema.PyTorchModelBlobOptions(multiple_inputs=multiple_inputs),
129
147
  )
130
148
  model_meta.models[name] = base_meta
131
149
  model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -171,6 +189,10 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
171
189
  raw_model: "torch.nn.Module",
172
190
  model_meta: model_meta_api.ModelMetadata,
173
191
  ) -> Type[custom_model.CustomModel]:
192
+ multiple_inputs = cast(
193
+ model_meta_schema.PyTorchModelBlobOptions, model_meta.models[model_meta.name].options
194
+ )["multiple_inputs"]
195
+
174
196
  def fn_factory(
175
197
  raw_model: "torch.nn.Module",
176
198
  signature: model_signature.ModelSignature,
@@ -182,19 +204,28 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
182
204
  raise ValueError("Tensor cannot handle null values.")
183
205
 
184
206
  raw_model.eval()
185
- t = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(X, signature.inputs)
207
+ if multiple_inputs:
208
+ st = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(X, signature.inputs)
209
+
210
+ if kwargs.get("use_gpu", False):
211
+ st = [element.cuda() for element in st]
186
212
 
187
- if kwargs.get("use_gpu", False):
188
- t = [element.cuda() for element in t]
213
+ with torch.no_grad():
214
+ res = getattr(raw_model, target_method)(*st)
189
215
 
190
- with torch.no_grad():
191
- res = getattr(raw_model, target_method)(*t)
216
+ if not isinstance(res, tuple):
217
+ res = [res]
218
+ else:
219
+ t = pytorch_handler.PyTorchTensorHandler.convert_from_df(X, signature.inputs)
220
+ if kwargs.get("use_gpu", False):
221
+ t = t.cuda()
192
222
 
193
- if isinstance(res, torch.Tensor):
194
- res = [res]
223
+ with torch.no_grad():
224
+ res = getattr(raw_model, target_method)(t)
195
225
 
196
226
  return model_signature_utils.rename_pandas_df(
197
- data=pytorch_handler.SeqOfPyTorchTensorHandler.convert_to_df(res), features=signature.outputs
227
+ model_signature._convert_local_data_to_df(res, ensure_serializable=True),
228
+ features=signature.outputs,
198
229
  )
199
230
 
200
231
  return fn
@@ -292,12 +292,34 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
292
292
  def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
293
293
  import shap
294
294
 
295
- # TODO: if not resolved by explainer, we need to pass the callable function
296
295
  try:
297
296
  explainer = shap.Explainer(raw_model, background_data)
298
297
  df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
299
- except TypeError as e:
300
- raise ValueError(f"Explanation for this model type not supported yet: {str(e)}")
298
+ except TypeError:
299
+ try:
300
+ dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs}
301
+
302
+ if isinstance(X, pd.DataFrame):
303
+ X = X.astype(dtype_map, copy=False)
304
+ if hasattr(raw_model, "predict_proba"):
305
+ if isinstance(X, np.ndarray):
306
+ explanations = shap.Explainer(
307
+ raw_model.predict_proba, background_data.values # type: ignore[union-attr]
308
+ )(X).values
309
+ else:
310
+ explanations = shap.Explainer(raw_model.predict_proba, background_data)(X).values
311
+ elif hasattr(raw_model, "predict"):
312
+ if isinstance(X, np.ndarray):
313
+ explanations = shap.Explainer(
314
+ raw_model.predict, background_data.values # type: ignore[union-attr]
315
+ )(X).values
316
+ else:
317
+ explanations = shap.Explainer(raw_model.predict, background_data)(X).values
318
+ else:
319
+ raise ValueError("Missing any supported target method to explain.")
320
+ df = handlers_utils.convert_explanations_to_2D_df(raw_model, explanations)
321
+ except TypeError as e:
322
+ raise ValueError(f"Explanation for this model type not supported yet: {str(e)}")
301
323
  return model_signature_utils.rename_pandas_df(df, signature.outputs)
302
324
 
303
325
  if target_method == "explain":
@@ -74,11 +74,6 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
74
74
  background_data: Optional[model_types.SupportedDataType],
75
75
  enable_explainability: Optional[bool],
76
76
  ) -> Any:
77
- from snowflake.ml.modeling import pipeline as snowml_pipeline
78
-
79
- # handle pipeline objects separately
80
- if isinstance(estimator, snowml_pipeline.Pipeline): # type: ignore[attr-defined]
81
- return None
82
77
 
83
78
  tree_methods = ["to_xgboost", "to_lightgbm", "to_sklearn"]
84
79
  non_tree_methods = ["to_sklearn"]
@@ -129,27 +124,54 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
129
124
  # Pipeline is inherited from BaseEstimator, so no need to add one more check
130
125
 
131
126
  if not is_sub_model:
132
- if model_meta.signatures:
127
+ if model_meta.signatures or sample_input_data is not None:
133
128
  warnings.warn(
134
129
  "Providing model signature for Snowpark ML "
135
130
  + "Modeling model is not required. Model signature will automatically be inferred during fitting. ",
136
131
  UserWarning,
137
132
  stacklevel=2,
138
133
  )
139
- assert hasattr(model, "model_signatures"), "Model does not have model signatures as expected."
140
- model_signature_dict = getattr(model, "model_signatures", {})
141
- target_methods = kwargs.pop("target_methods", None)
142
- if not target_methods:
143
- model_meta.signatures = model_signature_dict
134
+ target_methods = handlers_utils.get_target_methods(
135
+ model=model,
136
+ target_methods=kwargs.pop("target_methods", None),
137
+ default_target_methods=cls.DEFAULT_TARGET_METHODS,
138
+ )
139
+
140
+ def get_prediction(
141
+ target_method_name: str,
142
+ sample_input_data: model_types.SupportedLocalDataType,
143
+ ) -> model_types.SupportedLocalDataType:
144
+ if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray)):
145
+ sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
146
+
147
+ target_method = getattr(model, target_method_name, None)
148
+ assert callable(target_method)
149
+ predictions_df = target_method(sample_input_data)
150
+ return predictions_df
151
+
152
+ model_meta = handlers_utils.validate_signature(
153
+ model=model,
154
+ model_meta=model_meta,
155
+ target_methods=target_methods,
156
+ sample_input_data=sample_input_data,
157
+ get_prediction_fn=get_prediction,
158
+ is_for_modeling_model=True,
159
+ )
144
160
  else:
145
- temp_model_signature_dict = {}
146
- for method_name in target_methods:
147
- method_model_signature = model_signature_dict.get(method_name, None)
148
- if method_model_signature is not None:
149
- temp_model_signature_dict[method_name] = method_model_signature
150
- else:
151
- raise ValueError(f"Target method {method_name} does not exist in the model.")
152
- model_meta.signatures = temp_model_signature_dict
161
+ assert hasattr(model, "model_signatures"), "Model does not have model signatures as expected."
162
+ model_signature_dict = getattr(model, "model_signatures", {})
163
+ optional_target_methods = kwargs.pop("target_methods", None)
164
+ if not optional_target_methods:
165
+ model_meta.signatures = model_signature_dict
166
+ else:
167
+ temp_model_signature_dict = {}
168
+ for method_name in optional_target_methods:
169
+ method_model_signature = model_signature_dict.get(method_name, None)
170
+ if method_model_signature is not None:
171
+ temp_model_signature_dict[method_name] = method_model_signature
172
+ else:
173
+ raise ValueError(f"Target method {method_name} does not exist in the model.")
174
+ model_meta.signatures = temp_model_signature_dict
153
175
 
154
176
  python_base_obj = cls._get_supported_object_for_explainability(model, sample_input_data, enable_explainability)
155
177
  explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
@@ -279,9 +301,39 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
279
301
  for method_name in non_tree_methods:
280
302
  try:
281
303
  base_model = getattr(raw_model, method_name)()
282
- explainer = shap.Explainer(base_model, masker=background_data)
283
- df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
304
+ try:
305
+ explainer = shap.Explainer(base_model, masker=background_data)
306
+ df = handlers_utils.convert_explanations_to_2D_df(base_model, explainer(X).values)
307
+ except TypeError:
308
+ try:
309
+ dtype_map = {
310
+ spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs
311
+ }
312
+
313
+ if isinstance(X, pd.DataFrame):
314
+ X = X.astype(dtype_map, copy=False)
315
+ if hasattr(base_model, "predict_proba"):
316
+ if isinstance(X, np.ndarray):
317
+ explainer = shap.Explainer(
318
+ base_model.predict_proba,
319
+ background_data.values, # type: ignore[union-attr]
320
+ )
321
+ else:
322
+ explainer = shap.Explainer(base_model.predict_proba, background_data)
323
+ elif hasattr(base_model, "predict"):
324
+ if isinstance(X, np.ndarray):
325
+ explainer = shap.Explainer(
326
+ base_model.predict, background_data.values # type: ignore[union-attr]
327
+ )
328
+ else:
329
+ explainer = shap.Explainer(base_model.predict, background_data)
330
+ else:
331
+ raise ValueError("Missing any supported target method to explain.")
332
+ df = handlers_utils.convert_explanations_to_2D_df(base_model, explainer(X).values)
333
+ except TypeError as e:
334
+ raise ValueError(f"Explanation for this model type not supported yet: {str(e)}")
284
335
  return model_signature_utils.rename_pandas_df(df, signature.outputs)
336
+
285
337
  except exceptions.SnowflakeMLException:
286
338
  pass # Do nothing and continue to the next method
287
339
  raise ValueError("The model must be an xgboost, lightgbm or sklearn (not pipeline) estimator.")
@@ -1,7 +1,6 @@
1
1
  import os
2
2
  from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
3
3
 
4
- import numpy as np
5
4
  import pandas as pd
6
5
  from packaging import version
7
6
  from typing_extensions import TypeGuard, Unpack
@@ -10,14 +9,17 @@ from snowflake.ml._internal import type_utils
10
9
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
11
10
  from snowflake.ml.model._packager.model_env import model_env
12
11
  from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
13
- from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
12
+ from snowflake.ml.model._packager.model_handlers_migrator import (
13
+ base_migrator,
14
+ tensorflow_migrator_2023_12_01,
15
+ tensorflow_migrator_2025_01_01,
16
+ )
14
17
  from snowflake.ml.model._packager.model_meta import (
15
18
  model_blob_meta,
16
19
  model_meta as model_meta_api,
17
20
  model_meta_schema,
18
21
  )
19
22
  from snowflake.ml.model._signatures import (
20
- numpy_handler,
21
23
  tensorflow_handler,
22
24
  utils as model_signature_utils,
23
25
  )
@@ -28,15 +30,18 @@ if TYPE_CHECKING:
28
30
 
29
31
  @final
30
32
  class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
31
- """Handler for TensorFlow based model.
33
+ """Handler for TensorFlow based model or keras v2 model.
32
34
 
33
35
  Currently tensorflow.Module based classes are supported.
34
36
  """
35
37
 
36
38
  HANDLER_TYPE = "tensorflow"
37
- HANDLER_VERSION = "2023-12-01"
38
- _MIN_SNOWPARK_ML_VERSION = "1.0.12"
39
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
39
+ HANDLER_VERSION = "2025-03-01"
40
+ _MIN_SNOWPARK_ML_VERSION = "1.8.0"
41
+ _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
42
+ "2023-12-01": tensorflow_migrator_2023_12_01.TensorflowHandlerMigrator20231201,
43
+ "2025-01-01": tensorflow_migrator_2025_01_01.TensorflowHandlerMigrator20250101,
44
+ }
40
45
 
41
46
  MODEL_BLOB_FILE_OR_DIR = "model"
42
47
  DEFAULT_TARGET_METHODS = ["__call__"]
@@ -46,7 +51,13 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
46
51
  cls,
47
52
  model: model_types.SupportedModelType,
48
53
  ) -> TypeGuard["tensorflow.nn.Module"]:
49
- return type_utils.LazyType("tensorflow.Module").isinstance(model)
54
+ if not type_utils.LazyType("tensorflow.Module").isinstance(model):
55
+ return False
56
+ if type_utils.LazyType("keras.Model").isinstance(model):
57
+ import keras
58
+
59
+ return version.parse(keras.__version__) < version.parse("3.0.0")
60
+ return True
50
61
 
51
62
  @classmethod
52
63
  def cast_model(
@@ -74,44 +85,22 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
74
85
  if enable_explainability:
75
86
  raise NotImplementedError("Explainability is not supported for Tensorflow model.")
76
87
 
77
- # When tensorflow is installed, keras is also installed.
78
- import keras
79
88
  import tensorflow
80
89
 
81
90
  assert isinstance(model, tensorflow.Module)
82
91
 
83
- is_keras_model = type_utils.LazyType("tensorflow.keras.Model").isinstance(model) or type_utils.LazyType(
84
- "keras.Model"
85
- ).isinstance(model)
92
+ is_keras_model = type_utils.LazyType("keras.Model").isinstance(model)
86
93
  is_tf_keras_model = type_utils.LazyType("tf_keras.Model").isinstance(model)
87
- is_keras_functional_or_sequential_model = (
88
- getattr(model, "_is_graph_network", False)
89
- or type_utils.LazyType("tensorflow.keras.engine.sequential.Sequential").isinstance(model)
90
- or type_utils.LazyType("keras.engine.sequential.Sequential").isinstance(model)
91
- or type_utils.LazyType("tf_keras.engine.sequential.Sequential").isinstance(model)
92
- )
93
-
94
- assert isinstance(model, tensorflow.Module)
95
-
96
- keras_version = version.parse(keras.__version__)
97
-
98
94
  # Tensorflow and keras model save format is different.
99
- # Keras functional or sequential models are saved as keras format
100
- # Keras v3 other models are saved using cloudpickle
101
- # Keras v2 other models are saved using tensorflow saved model format
102
- # Tensorflow models are saved using tensorflow saved model format
95
+ # Keras v2 models are saved using keras api
96
+ # Tensorflow models are saved using tensorflow api
103
97
 
104
98
  if is_keras_model or is_tf_keras_model:
105
- if is_keras_functional_or_sequential_model:
106
- save_format = "keras"
107
- elif keras_version.major == 2 or is_tf_keras_model:
108
- save_format = "keras_tf"
109
- else:
110
- save_format = "cloudpickle"
99
+ save_format = "keras_tf"
111
100
  else:
112
101
  save_format = "tf"
113
102
 
114
- if is_keras_model:
103
+ if is_keras_model or is_tf_keras_model:
115
104
  default_target_methods = ["predict"]
116
105
  else:
117
106
  default_target_methods = cls.DEFAULT_TARGET_METHODS
@@ -123,25 +112,35 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
123
112
  default_target_methods=default_target_methods,
124
113
  )
125
114
 
115
+ multiple_inputs = kwargs.get("multiple_inputs", False)
116
+
126
117
  if is_keras_model and len(target_methods) > 1:
127
118
  raise ValueError("Keras model can only have one target method.")
128
119
 
129
120
  def get_prediction(
130
121
  target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
131
122
  ) -> model_types.SupportedLocalDataType:
132
- if not tensorflow_handler.SeqOfTensorflowTensorHandler.can_handle(sample_input_data):
133
- sample_input_data = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(
134
- model_signature._convert_local_data_to_df(sample_input_data)
135
- )
123
+ if multiple_inputs:
124
+ if not tensorflow_handler.SeqOfTensorflowTensorHandler.can_handle(sample_input_data):
125
+ sample_input_data = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(
126
+ model_signature._convert_local_data_to_df(sample_input_data)
127
+ )
128
+ else:
129
+ if not tensorflow_handler.TensorflowTensorHandler.can_handle(sample_input_data):
130
+ sample_input_data = tensorflow_handler.TensorflowTensorHandler.convert_from_df(
131
+ model_signature._convert_local_data_to_df(sample_input_data)
132
+ )
136
133
 
137
134
  target_method = getattr(model, target_method_name, None)
138
135
  assert callable(target_method)
139
136
  for tensor in sample_input_data:
140
137
  tensorflow.stop_gradient(tensor)
141
- predictions_df = target_method(*sample_input_data)
142
-
143
- if isinstance(predictions_df, (tensorflow.Tensor, tensorflow.Variable, np.ndarray)):
144
- predictions_df = [predictions_df]
138
+ if multiple_inputs:
139
+ predictions_df = target_method(*sample_input_data)
140
+ if not isinstance(predictions_df, tuple):
141
+ predictions_df = [predictions_df]
142
+ else:
143
+ predictions_df = target_method(sample_input_data)
145
144
 
146
145
  return predictions_df
147
146
 
@@ -156,15 +155,8 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
156
155
  model_blob_path = os.path.join(model_blobs_dir_path, name)
157
156
  os.makedirs(model_blob_path, exist_ok=True)
158
157
  save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
159
- if save_format == "keras":
160
- model.save(save_path, save_format="keras")
161
- elif save_format == "keras_tf":
158
+ if save_format == "keras_tf":
162
159
  model.save(save_path, save_format="tf")
163
- elif save_format == "cloudpickle":
164
- import cloudpickle
165
-
166
- with open(save_path, "wb") as f:
167
- cloudpickle.dump(model, f)
168
160
  else:
169
161
  tensorflow.saved_model.save(
170
162
  model,
@@ -177,7 +169,9 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
177
169
  model_type=cls.HANDLER_TYPE,
178
170
  handler_version=cls.HANDLER_VERSION,
179
171
  path=cls.MODEL_BLOB_FILE_OR_DIR,
180
- options=model_meta_schema.TensorflowModelBlobOptions(save_format=save_format),
172
+ options=model_meta_schema.TensorflowModelBlobOptions(
173
+ save_format=save_format, multiple_inputs=multiple_inputs
174
+ ),
181
175
  )
182
176
  model_meta.models[name] = base_meta
183
177
  model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -186,7 +180,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
186
180
  model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"),
187
181
  ]
188
182
  if is_keras_model:
189
- dependencies.append(model_env.ModelDependency(requirement="keras", pip_name="keras"))
183
+ dependencies.append(model_env.ModelDependency(requirement="keras<=3", pip_name="keras"))
190
184
  elif is_tf_keras_model:
191
185
  dependencies.append(model_env.ModelDependency(requirement="tf-keras", pip_name="tf-keras"))
192
186
 
@@ -204,6 +198,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
204
198
  model_blobs_dir_path: str,
205
199
  **kwargs: Unpack[model_types.TensorflowLoadOptions],
206
200
  ) -> "tensorflow.Module":
201
+ os.environ["TF_USE_LEGACY_KERAS"] = "1"
207
202
  import tensorflow
208
203
 
209
204
  model_blob_path = os.path.join(model_blobs_dir_path, name)
@@ -212,14 +207,9 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
212
207
  model_blob_filename = model_blob_metadata.path
213
208
  model_blob_options = cast(model_meta_schema.TensorflowModelBlobOptions, model_blob_metadata.options)
214
209
  load_path = os.path.join(model_blob_path, model_blob_filename)
215
- save_format = model_blob_options.get("save_format", "tf")
216
- if save_format == "keras" or save_format == "keras_tf":
210
+ save_format = model_blob_options.get("save_format", "keras_tf")
211
+ if save_format == "keras_tf":
217
212
  m = tensorflow.keras.models.load_model(load_path)
218
- elif save_format == "cloudpickle":
219
- import cloudpickle
220
-
221
- with open(load_path, "rb") as f:
222
- m = cloudpickle.load(f)
223
213
  else:
224
214
  m = tensorflow.saved_model.load(load_path)
225
215
 
@@ -241,6 +231,10 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
241
231
  raw_model: "tensorflow.Module",
242
232
  model_meta: model_meta_api.ModelMetadata,
243
233
  ) -> Type[custom_model.CustomModel]:
234
+ multiple_inputs = cast(
235
+ model_meta_schema.TensorflowModelBlobOptions, model_meta.models[model_meta.name].options
236
+ )["multiple_inputs"]
237
+
244
238
  def fn_factory(
245
239
  raw_model: "tensorflow.Module",
246
240
  signature: model_signature.ModelSignature,
@@ -251,21 +245,25 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
251
245
  if X.isnull().any(axis=None):
252
246
  raise ValueError("Tensor cannot handle null values.")
253
247
 
254
- t = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(X, signature.inputs)
248
+ if multiple_inputs:
249
+ t = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(X, signature.inputs)
255
250
 
256
- for tensor in t:
257
- tensorflow.stop_gradient(tensor)
258
- res = getattr(raw_model, target_method)(*t)
251
+ for tensor in t:
252
+ tensorflow.stop_gradient(tensor)
253
+ res = getattr(raw_model, target_method)(*t)
259
254
 
260
- if isinstance(res, (tensorflow.Tensor, tensorflow.Variable, np.ndarray)):
261
- res = [res]
262
-
263
- if isinstance(res, list) and len(res) > 0 and isinstance(res[0], np.ndarray):
264
- # In case of running on CPU, it will return numpy array
265
- df = numpy_handler.SeqOfNumpyArrayHandler.convert_to_df(res)
255
+ if not isinstance(res, tuple):
256
+ res = [res]
266
257
  else:
267
- df = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_to_df(res)
268
- return model_signature_utils.rename_pandas_df(df, signature.outputs)
258
+ t = tensorflow_handler.TensorflowTensorHandler.convert_from_df(X, signature.inputs)
259
+
260
+ tensorflow.stop_gradient(t)
261
+ res = getattr(raw_model, target_method)(t)
262
+
263
+ return model_signature_utils.rename_pandas_df(
264
+ model_signature._convert_local_data_to_df(res, ensure_serializable=True),
265
+ features=signature.outputs,
266
+ )
269
267
 
270
268
  return fn
271
269