snowflake-ml-python 1.7.4-py3-none-any.whl → 1.7.5-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (41)
  1. snowflake/ml/_internal/env_utils.py +64 -21
  2. snowflake/ml/_internal/relax_version_strategy.py +16 -0
  3. snowflake/ml/_internal/telemetry.py +21 -0
  4. snowflake/ml/data/_internal/arrow_ingestor.py +1 -1
  5. snowflake/ml/feature_store/feature_store.py +18 -0
  6. snowflake/ml/feature_store/feature_view.py +46 -1
  7. snowflake/ml/jobs/_utils/constants.py +7 -1
  8. snowflake/ml/jobs/_utils/payload_utils.py +139 -53
  9. snowflake/ml/jobs/_utils/spec_utils.py +5 -7
  10. snowflake/ml/jobs/decorators.py +5 -25
  11. snowflake/ml/jobs/job.py +4 -4
  12. snowflake/ml/model/_packager/model_env/model_env.py +45 -28
  13. snowflake/ml/model/_packager/model_handlers/_utils.py +8 -4
  14. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +16 -0
  15. snowflake/ml/model/_packager/model_handlers/keras.py +230 -0
  16. snowflake/ml/model/_packager/model_handlers/pytorch.py +1 -0
  17. snowflake/ml/model/_packager/model_handlers/sklearn.py +28 -3
  18. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +74 -21
  19. snowflake/ml/model/_packager/model_handlers/tensorflow.py +27 -49
  20. snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2023_12_01.py +48 -0
  21. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -1
  22. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +3 -0
  23. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  24. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -1
  25. snowflake/ml/model/_packager/model_task/model_task_utils.py +5 -1
  26. snowflake/ml/model/_signatures/core.py +2 -2
  27. snowflake/ml/model/_signatures/numpy_handler.py +5 -5
  28. snowflake/ml/model/_signatures/pandas_handler.py +9 -7
  29. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  30. snowflake/ml/model/model_signature.py +8 -0
  31. snowflake/ml/model/type_hints.py +15 -0
  32. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +14 -1
  33. snowflake/ml/modeling/pipeline/pipeline.py +18 -1
  34. snowflake/ml/modeling/preprocessing/polynomial_features.py +2 -2
  35. snowflake/ml/registry/registry.py +34 -4
  36. snowflake/ml/version.py +1 -1
  37. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.7.5.dist-info}/METADATA +58 -25
  38. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.7.5.dist-info}/RECORD +41 -38
  39. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.7.5.dist-info}/WHEEL +1 -1
  40. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.7.5.dist-info}/LICENSE.txt +0 -0
  41. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.7.5.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_packager/model_handlers/keras.py (new file)
@@ -0,0 +1,230 @@
+import os
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
+
+import cloudpickle
+import numpy as np
+import pandas as pd
+from packaging import version
+from typing_extensions import TypeGuard, Unpack
+
+from snowflake.ml._internal import type_utils
+from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
+from snowflake.ml.model._packager.model_env import model_env
+from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
+from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
+from snowflake.ml.model._packager.model_meta import (
+    model_blob_meta,
+    model_meta as model_meta_api,
+)
+from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
+
+if TYPE_CHECKING:
+    import keras
+
+
+@final
+class KerasHandler(_base.BaseModelHandler["keras.Model"]):
+    """Handler for Keras v3 model.
+
+    Currently keras.Model based classes are supported.
+    """
+
+    HANDLER_TYPE = "keras"
+    HANDLER_VERSION = "2025-01-01"
+    _MIN_SNOWPARK_ML_VERSION = "1.7.5"
+    _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
+
+    MODEL_BLOB_FILE_OR_DIR = "model.keras"
+    CUSTOM_OBJECT_SAVE_PATH = "custom_objects.pkl"
+    DEFAULT_TARGET_METHODS = ["predict"]
+
+    @classmethod
+    def can_handle(
+        cls,
+        model: model_types.SupportedModelType,
+    ) -> TypeGuard["keras.Model"]:
+        if not type_utils.LazyType("keras.Model").isinstance(model):
+            return False
+        import keras
+
+        return version.parse(keras.__version__) >= version.parse("3.0.0")
+
+    @classmethod
+    def cast_model(
+        cls,
+        model: model_types.SupportedModelType,
+    ) -> "keras.Model":
+        import keras
+
+        assert isinstance(model, keras.Model)
+
+        return cast(keras.Model, model)
+
+    @classmethod
+    def save_model(
+        cls,
+        name: str,
+        model: "keras.Model",
+        model_meta: model_meta_api.ModelMetadata,
+        model_blobs_dir_path: str,
+        sample_input_data: Optional[model_types.SupportedDataType] = None,
+        is_sub_model: Optional[bool] = False,
+        **kwargs: Unpack[model_types.TensorflowSaveOptions],
+    ) -> None:
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for Keras model.")
+
+        import keras
+
+        assert isinstance(model, keras.Model)
+
+        if not is_sub_model:
+            target_methods = handlers_utils.get_target_methods(
+                model=model,
+                target_methods=kwargs.pop("target_methods", None),
+                default_target_methods=cls.DEFAULT_TARGET_METHODS,
+            )
+
+            def get_prediction(
+                target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
+            ) -> model_types.SupportedLocalDataType:
+                target_method = getattr(model, target_method_name, None)
+                assert callable(target_method)
+                predictions_df = target_method(sample_input_data)
+
+                if (
+                    type_utils.LazyType("tensorflow.Tensor").isinstance(predictions_df)
+                    or type_utils.LazyType("tensorflow.Variable").isinstance(predictions_df)
+                    or type_utils.LazyType("torch.Tensor").isinstance(predictions_df)
+                ):
+                    predictions_df = [predictions_df]
+
+                return predictions_df
+
+            model_meta = handlers_utils.validate_signature(
+                model=model,
+                model_meta=model_meta,
+                target_methods=target_methods,
+                sample_input_data=sample_input_data,
+                get_prediction_fn=get_prediction,
+            )
+
+        model_blob_path = os.path.join(model_blobs_dir_path, name)
+        os.makedirs(model_blob_path, exist_ok=True)
+        save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
+        model.save(save_path)
+
+        custom_object_save_path = os.path.join(model_blob_path, cls.CUSTOM_OBJECT_SAVE_PATH)
+        custom_objects = keras.saving.get_custom_objects()
+        with open(custom_object_save_path, "wb") as f:
+            cloudpickle.dump(custom_objects, f)
+
+        base_meta = model_blob_meta.ModelBlobMeta(
+            name=name,
+            model_type=cls.HANDLER_TYPE,
+            handler_version=cls.HANDLER_VERSION,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
+        )
+        model_meta.models[name] = base_meta
+        model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
+
+        dependencies = [
+            model_env.ModelDependency(requirement="keras>=3", pip_name="keras"),
+        ]
+        keras_backend = keras.backend.backend()
+        if keras_backend == "tensorflow":
+            dependencies.append(model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"))
+        elif keras_backend == "torch":
+            dependencies.append(model_env.ModelDependency(requirement="pytorch", pip_name="torch"))
+        elif keras_backend == "jax":
+            dependencies.append(model_env.ModelDependency(requirement="jax", pip_name="jax"))
+        else:
+            raise ValueError(f"Unsupported backend {keras_backend}")
+
+        model_meta.env.include_if_absent(
+            dependencies,
+            check_local_version=True,
+        )
+        model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
+
+    @classmethod
+    def load_model(
+        cls,
+        name: str,
+        model_meta: model_meta_api.ModelMetadata,
+        model_blobs_dir_path: str,
+        **kwargs: Unpack[model_types.TensorflowLoadOptions],
+    ) -> "keras.Model":
+        import keras
+
+        model_blob_path = os.path.join(model_blobs_dir_path, name)
+        model_blobs_metadata = model_meta.models
+        model_blob_metadata = model_blobs_metadata[name]
+        model_blob_filename = model_blob_metadata.path
+
+        custom_object_save_path = os.path.join(model_blob_path, cls.CUSTOM_OBJECT_SAVE_PATH)
+        with open(custom_object_save_path, "rb") as f:
+            custom_objects = cloudpickle.load(f)
+        load_path = os.path.join(model_blob_path, model_blob_filename)
+        m = keras.models.load_model(load_path, custom_objects=custom_objects, safe_mode=False)
+
+        return cast(keras.Model, m)
+
+    @classmethod
+    def convert_as_custom_model(
+        cls,
+        raw_model: "keras.Model",
+        model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
+        **kwargs: Unpack[model_types.TensorflowLoadOptions],
+    ) -> custom_model.CustomModel:
+
+        from snowflake.ml.model import custom_model
+
+        def _create_custom_model(
+            raw_model: "keras.Model",
+            model_meta: model_meta_api.ModelMetadata,
+        ) -> Type[custom_model.CustomModel]:
+            def fn_factory(
+                raw_model: "keras.Model",
+                signature: model_signature.ModelSignature,
+                target_method: str,
+            ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
+                dtype_map = {
+                    spec.name: spec.as_dtype(force_numpy_dtype=True)
+                    for spec in signature.inputs
+                    if isinstance(spec, model_signature.FeatureSpec)
+                }
+
+                @custom_model.inference_api
+                def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
+                    res = getattr(raw_model, target_method)(X.astype(dtype_map), verbose=0)
+
+                    if isinstance(res, list) and len(res) > 0 and isinstance(res[0], np.ndarray):
+                        # In case of multi-output estimators, predict_proba(), decision_function(), etc.
+                        # return a list of ndarrays. We need to deal with them separately.
+                        df = numpy_handler.SeqOfNumpyArrayHandler.convert_to_df(res)
+                    else:
+                        df = pd.DataFrame(res)
+
+                    return model_signature_utils.rename_pandas_df(df, signature.outputs)
+
+                return fn
+
+            type_method_dict = {}
+            for target_method_name, sig in model_meta.signatures.items():
+                type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
+
+            _KerasModel = type(
+                "_KerasModel",
+                (custom_model.CustomModel,),
+                type_method_dict,
+            )
+
+            return _KerasModel
+
+        _KerasModel = _create_custom_model(raw_model, model_meta)
+        keras_model = _KerasModel(custom_model.ModelContext())
+
+        return keras_model
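For reference, this new handler is picked up by the registry's normal handler dispatch: can_handle only claims keras.Model instances at keras>=3. A minimal sketch of logging such a model through the public registry API, assuming a live Snowpark `session` and an illustrative model name:

    import keras
    import numpy as np
    from snowflake.ml.registry import Registry

    # Tiny Keras 3 model; the sample input drives signature inference.
    model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])

    reg = Registry(session=session)  # `session` is an existing Snowpark session
    mv = reg.log_model(
        model,
        model_name="my_keras_model",  # illustrative name
        version_name="v1",
        sample_input_data=np.random.rand(10, 4).astype(np.float32),
    )

The returned model version can then be invoked with mv.run(...), which routes through the _KerasModel wrapper built in convert_as_custom_model above.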
snowflake/ml/model/_packager/model_handlers/pytorch.py
@@ -49,6 +49,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
             type_utils.LazyType("torch.nn.Module").isinstance(model)
             and not type_utils.LazyType("torch.jit.ScriptModule").isinstance(model)
             and not type_utils.LazyType("sentence_transformers.SentenceTransformer").isinstance(model)
+            and not type_utils.LazyType("keras.Model").isinstance(model)
         )
 
     @classmethod
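The new exclusion matters because, under the torch backend, a Keras 3 model is itself a torch.nn.Module subclass; without it PyTorchHandler.can_handle would claim the model before KerasHandler gets a chance. A sketch illustrating the overlap, assuming keras>=3 installed with KERAS_BACKEND=torch:

    import keras
    import torch

    model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
    # True under the torch backend, which is exactly the dispatch ambiguity
    # the added can_handle check resolves.
    print(isinstance(model, torch.nn.Module))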
snowflake/ml/model/_packager/model_handlers/sklearn.py
@@ -292,12 +292,37 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
             def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
                 import shap
 
-                # TODO: if not resolved by explainer, we need to pass the callable function
                 try:
                     explainer = shap.Explainer(raw_model, background_data)
                     df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
-                except TypeError as e:
-                    raise ValueError(f"Explanation for this model type not supported yet: {str(e)}")
+                except TypeError:
+                    try:
+                        dtype_map = {
+                            spec.name: spec.as_dtype(force_numpy_dtype=True)  # type: ignore[attr-defined]
+                            for spec in signature.inputs
+                        }
+
+                        if isinstance(X, pd.DataFrame):
+                            X = X.astype(dtype_map, copy=False)
+                        if hasattr(raw_model, "predict_proba"):
+                            if isinstance(X, np.ndarray):
+                                explanations = shap.Explainer(
+                                    raw_model.predict_proba, background_data.values  # type: ignore[union-attr]
+                                )(X).values
+                            else:
+                                explanations = shap.Explainer(raw_model.predict_proba, background_data)(X).values
+                        elif hasattr(raw_model, "predict"):
+                            if isinstance(X, np.ndarray):
+                                explanations = shap.Explainer(
+                                    raw_model.predict, background_data.values  # type: ignore[union-attr]
+                                )(X).values
+                            else:
+                                explanations = shap.Explainer(raw_model.predict, background_data)(X).values
+                        else:
+                            raise ValueError("Missing any supported target method to explain.")
+                        df = handlers_utils.convert_explanations_to_2D_df(raw_model, explanations)
+                    except TypeError as e:
+                        raise ValueError(f"Explanation for this model type not supported yet: {str(e)}")
                 return model_signature_utils.rename_pandas_df(df, signature.outputs)
 
             if target_method == "explain":
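The pattern added here: shap.Explainer raises TypeError for estimators it cannot wrap directly, and passing the estimator's prediction callable instead makes shap fall back to a model-agnostic explainer built from the background data. A standalone sketch of the same idea (not the handler code itself):

    import shap

    def explain_with_fallback(estimator, X, background_data):
        try:
            # Fast path: shap picks a native explainer for the estimator type.
            return shap.Explainer(estimator, background_data)(X).values
        except TypeError:
            # Fallback: hand shap a callable so it builds a model-agnostic explainer.
            fn = getattr(estimator, "predict_proba", None) or getattr(estimator, "predict", None)
            if fn is None:
                raise ValueError("Missing any supported target method to explain.")
            return shap.Explainer(fn, background_data)(X).values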
snowflake/ml/model/_packager/model_handlers/snowmlmodel.py
@@ -74,11 +74,6 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         background_data: Optional[model_types.SupportedDataType],
         enable_explainability: Optional[bool],
     ) -> Any:
-        from snowflake.ml.modeling import pipeline as snowml_pipeline
-
-        # handle pipeline objects separately
-        if isinstance(estimator, snowml_pipeline.Pipeline):  # type: ignore[attr-defined]
-            return None
 
         tree_methods = ["to_xgboost", "to_lightgbm", "to_sklearn"]
         non_tree_methods = ["to_sklearn"]
@@ -129,27 +124,54 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         # Pipeline is inherited from BaseEstimator, so no need to add one more check
 
         if not is_sub_model:
-            if model_meta.signatures:
+            if model_meta.signatures or sample_input_data is not None:
                 warnings.warn(
                     "Providing model signature for Snowpark ML "
                     + "Modeling model is not required. Model signature will automatically be inferred during fitting. ",
                     UserWarning,
                     stacklevel=2,
                 )
-            assert hasattr(model, "model_signatures"), "Model does not have model signatures as expected."
-            model_signature_dict = getattr(model, "model_signatures", {})
-            target_methods = kwargs.pop("target_methods", None)
-            if not target_methods:
-                model_meta.signatures = model_signature_dict
+                target_methods = handlers_utils.get_target_methods(
+                    model=model,
+                    target_methods=kwargs.pop("target_methods", None),
+                    default_target_methods=cls.DEFAULT_TARGET_METHODS,
+                )
+
+                def get_prediction(
+                    target_method_name: str,
+                    sample_input_data: model_types.SupportedLocalDataType,
+                ) -> model_types.SupportedLocalDataType:
+                    if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray)):
+                        sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
+
+                    target_method = getattr(model, target_method_name, None)
+                    assert callable(target_method)
+                    predictions_df = target_method(sample_input_data)
+                    return predictions_df
+
+                model_meta = handlers_utils.validate_signature(
+                    model=model,
+                    model_meta=model_meta,
+                    target_methods=target_methods,
+                    sample_input_data=sample_input_data,
+                    get_prediction_fn=get_prediction,
+                    is_for_modeling_model=True,
+                )
             else:
-                temp_model_signature_dict = {}
-                for method_name in target_methods:
-                    method_model_signature = model_signature_dict.get(method_name, None)
-                    if method_model_signature is not None:
-                        temp_model_signature_dict[method_name] = method_model_signature
-                    else:
-                        raise ValueError(f"Target method {method_name} does not exist in the model.")
-                model_meta.signatures = temp_model_signature_dict
+                assert hasattr(model, "model_signatures"), "Model does not have model signatures as expected."
+                model_signature_dict = getattr(model, "model_signatures", {})
+                optional_target_methods = kwargs.pop("target_methods", None)
+                if not optional_target_methods:
+                    model_meta.signatures = model_signature_dict
+                else:
+                    temp_model_signature_dict = {}
+                    for method_name in optional_target_methods:
+                        method_model_signature = model_signature_dict.get(method_name, None)
+                        if method_model_signature is not None:
+                            temp_model_signature_dict[method_name] = method_model_signature
+                        else:
+                            raise ValueError(f"Target method {method_name} does not exist in the model.")
                    model_meta.signatures = temp_model_signature_dict
 
         python_base_obj = cls._get_supported_object_for_explainability(model, sample_input_data, enable_explainability)
         explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
@@ -279,9 +301,40 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
             for method_name in non_tree_methods:
                 try:
                     base_model = getattr(raw_model, method_name)()
-                    explainer = shap.Explainer(base_model, masker=background_data)
-                    df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
+                    try:
+                        explainer = shap.Explainer(base_model, masker=background_data)
+                        df = handlers_utils.convert_explanations_to_2D_df(base_model, explainer(X).values)
+                    except TypeError:
+                        try:
+                            dtype_map = {
+                                spec.name: spec.as_dtype(force_numpy_dtype=True)  # type: ignore[attr-defined]
+                                for spec in signature.inputs
+                            }
+
+                            if isinstance(X, pd.DataFrame):
+                                X = X.astype(dtype_map, copy=False)
+                            if hasattr(base_model, "predict_proba"):
+                                if isinstance(X, np.ndarray):
+                                    explainer = shap.Explainer(
+                                        base_model.predict_proba,
+                                        background_data.values,  # type: ignore[union-attr]
+                                    )
+                                else:
+                                    explainer = shap.Explainer(base_model.predict_proba, background_data)
+                            elif hasattr(base_model, "predict"):
+                                if isinstance(X, np.ndarray):
+                                    explainer = shap.Explainer(
+                                        base_model.predict, background_data.values  # type: ignore[union-attr]
+                                    )
+                                else:
+                                    explainer = shap.Explainer(base_model.predict, background_data)
+                            else:
+                                raise ValueError("Missing any supported target method to explain.")
+                            df = handlers_utils.convert_explanations_to_2D_df(base_model, explainer(X).values)
+                        except TypeError as e:
+                            raise ValueError(f"Explanation for this model type not supported yet: {str(e)}")
                     return model_signature_utils.rename_pandas_df(df, signature.outputs)
+
                 except exceptions.SnowflakeMLException:
                     pass  # Do nothing and continue to the next method
             raise ValueError("The model must be an xgboost, lightgbm or sklearn (not pipeline) estimator.")
snowflake/ml/model/_packager/model_handlers/tensorflow.py
@@ -10,7 +10,10 @@ from snowflake.ml._internal import type_utils
 from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
 from snowflake.ml.model._packager.model_env import model_env
 from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
-from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
+from snowflake.ml.model._packager.model_handlers_migrator import (
+    base_migrator,
+    tensorflow_migrator_2023_12_01,
+)
 from snowflake.ml.model._packager.model_meta import (
     model_blob_meta,
     model_meta as model_meta_api,
@@ -28,15 +31,17 @@ if TYPE_CHECKING:
 
 @final
 class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
-    """Handler for TensorFlow based model.
+    """Handler for TensorFlow based model or keras v2 model.
 
     Currently tensorflow.Module based classes are supported.
     """
 
     HANDLER_TYPE = "tensorflow"
-    HANDLER_VERSION = "2023-12-01"
-    _MIN_SNOWPARK_ML_VERSION = "1.0.12"
-    _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
+    HANDLER_VERSION = "2025-01-01"
+    _MIN_SNOWPARK_ML_VERSION = "1.7.5"
+    _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
+        "2023-12-01": tensorflow_migrator_2023_12_01.TensorflowHandlerMigrator20231201
+    }
 
     MODEL_BLOB_FILE_OR_DIR = "model"
     DEFAULT_TARGET_METHODS = ["__call__"]
@@ -46,7 +51,13 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         cls,
         model: model_types.SupportedModelType,
     ) -> TypeGuard["tensorflow.nn.Module"]:
-        return type_utils.LazyType("tensorflow.Module").isinstance(model)
+        if not type_utils.LazyType("tensorflow.Module").isinstance(model):
+            return False
+        if type_utils.LazyType("keras.Model").isinstance(model):
+            import keras
+
+            return version.parse(keras.__version__) < version.parse("3.0.0")
+        return True
 
     @classmethod
     def cast_model(
@@ -74,44 +85,22 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         if enable_explainability:
             raise NotImplementedError("Explainability is not supported for Tensorflow model.")
 
-        # When tensorflow is installed, keras is also installed.
-        import keras
         import tensorflow
 
         assert isinstance(model, tensorflow.Module)
 
-        is_keras_model = type_utils.LazyType("tensorflow.keras.Model").isinstance(model) or type_utils.LazyType(
-            "keras.Model"
-        ).isinstance(model)
+        is_keras_model = type_utils.LazyType("keras.Model").isinstance(model)
         is_tf_keras_model = type_utils.LazyType("tf_keras.Model").isinstance(model)
-        is_keras_functional_or_sequential_model = (
-            getattr(model, "_is_graph_network", False)
-            or type_utils.LazyType("tensorflow.keras.engine.sequential.Sequential").isinstance(model)
-            or type_utils.LazyType("keras.engine.sequential.Sequential").isinstance(model)
-            or type_utils.LazyType("tf_keras.engine.sequential.Sequential").isinstance(model)
-        )
-
-        assert isinstance(model, tensorflow.Module)
-
-        keras_version = version.parse(keras.__version__)
-
         # Tensorflow and keras model save format is different.
-        # Keras functional or sequential models are saved as keras format
-        # Keras v3 other models are saved using cloudpickle
-        # Keras v2 other models are saved using tensorflow saved model format
-        # Tensorflow models are saved using tensorflow saved model format
+        # Keras v2 models are saved using keras api
+        # Tensorflow models are saved using tensorflow api
 
         if is_keras_model or is_tf_keras_model:
-            if is_keras_functional_or_sequential_model:
-                save_format = "keras"
-            elif keras_version.major == 2 or is_tf_keras_model:
-                save_format = "keras_tf"
-            else:
-                save_format = "cloudpickle"
+            save_format = "keras_tf"
         else:
            save_format = "tf"
 
-        if is_keras_model:
+        if is_keras_model or is_tf_keras_model:
            default_target_methods = ["predict"]
         else:
            default_target_methods = cls.DEFAULT_TARGET_METHODS
@@ -156,15 +145,8 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
         save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
-        if save_format == "keras":
-            model.save(save_path, save_format="keras")
-        elif save_format == "keras_tf":
+        if save_format == "keras_tf":
             model.save(save_path, save_format="tf")
-        elif save_format == "cloudpickle":
-            import cloudpickle
-
-            with open(save_path, "wb") as f:
-                cloudpickle.dump(model, f)
         else:
             tensorflow.saved_model.save(
                 model,
@@ -186,7 +168,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
             model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"),
         ]
         if is_keras_model:
-            dependencies.append(model_env.ModelDependency(requirement="keras", pip_name="keras"))
+            dependencies.append(model_env.ModelDependency(requirement="keras<=3", pip_name="keras"))
         elif is_tf_keras_model:
             dependencies.append(model_env.ModelDependency(requirement="tf-keras", pip_name="tf-keras"))
 
@@ -204,6 +186,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         model_blobs_dir_path: str,
         **kwargs: Unpack[model_types.TensorflowLoadOptions],
     ) -> "tensorflow.Module":
+        os.environ["TF_USE_LEGACY_KERAS"] = "1"
         import tensorflow
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
@@ -212,14 +195,9 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         model_blob_filename = model_blob_metadata.path
         model_blob_options = cast(model_meta_schema.TensorflowModelBlobOptions, model_blob_metadata.options)
         load_path = os.path.join(model_blob_path, model_blob_filename)
-        save_format = model_blob_options.get("save_format", "tf")
-        if save_format == "keras" or save_format == "keras_tf":
+        save_format = model_blob_options.get("save_format", "keras_tf")
+        if save_format == "keras_tf":
             m = tensorflow.keras.models.load_model(load_path)
-        elif save_format == "cloudpickle":
-            import cloudpickle
-
-            with open(load_path, "rb") as f:
-                m = cloudpickle.load(f)
         else:
             m = tensorflow.saved_model.load(load_path)
 
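Note on the TF_USE_LEGACY_KERAS line added in load_model: the toggle only takes effect if it is set before TensorFlow is first imported, which is why it precedes the import statement. A minimal sketch of the same trick, assuming TensorFlow 2.16+ with the tf_keras package installed and an illustrative model path:

    import os

    # Must happen before the first `import tensorflow` anywhere in the process;
    # on TF 2.16+ this points tf.keras at the Keras 2 implementation (tf_keras).
    os.environ["TF_USE_LEGACY_KERAS"] = "1"

    import tensorflow as tf

    m = tf.keras.models.load_model("/tmp/model")  # loads with Keras 2 semantics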
snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2023_12_01.py (new file)
@@ -0,0 +1,48 @@
+from typing import cast
+
+from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
+from snowflake.ml.model._packager.model_meta import (
+    model_meta as model_meta_api,
+    model_meta_schema,
+)
+
+
+class TensorflowHandlerMigrator20231201(base_migrator.BaseModelHandlerMigrator):
+    source_version = "2023-12-01"
+    target_version = "2025-01-01"
+
+    @staticmethod
+    def upgrade(name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str) -> None:
+
+        model_blob_metadata = model_meta.models[name]
+        model_blob_options = cast(model_meta_schema.TensorflowModelBlobOptions, model_blob_metadata.options)
+        # To migrate code <= 1.7.0, default to keras model
+        is_old_model = "save_format" not in model_blob_options and "is_keras_model" not in model_blob_options
+        # To migrate code from 1.7.1, default to False.
+        is_keras_model = model_blob_options.get("is_keras_model", False)
+        # To migrate code from 1.7.2, default to "tf"; the options were keras, keras_tf, cloudpickle and tf:
+        #
+        # if is_keras_model or is_tf_keras_model:
+        #     if is_keras_functional_or_sequential_model:
+        #         save_format = "keras"
+        #     elif keras_version.major == 2 or is_tf_keras_model:
+        #         save_format = "keras_tf"
+        #     else:
+        #         save_format = "cloudpickle"
+        # else:
+        #     save_format = "tf"
+        #
+        save_format = model_blob_options.get("save_format", "tf")
+
+        if save_format == "keras" or is_keras_model or is_old_model:
+            save_format = "keras_tf"
+        elif save_format == "cloudpickle":
+            # Given the old logic, this could only happen if the original model was a keras model and keras is 3.x.
+            # In that case, however, keras.Model does not extend tensorflow.Module, so TensorflowHandler would never
+            # have been triggered; we can safely raise here.
+            raise NotImplementedError(
+                "Unable to upgrade keras 3.x model saved by old handler. This is not supposed to happen"
+            )
+
+        model_blob_options["save_format"] = save_format
+        model_meta.models[name].options = model_blob_options
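In effect the migrator collapses the four legacy save formats into two ("keras_tf" and "tf"). A toy restatement of the decision table, using a hypothetical helper that is not part of the package:

    def migrated_save_format(options: dict) -> str:
        # Pre-1.7.1 metadata carried neither key and always described a keras model.
        is_old_model = "save_format" not in options and "is_keras_model" not in options
        is_keras_model = options.get("is_keras_model", False)  # 1.7.1 metadata
        save_format = options.get("save_format", "tf")         # 1.7.2+ metadata

        if save_format == "keras" or is_keras_model or is_old_model:
            return "keras_tf"
        if save_format == "cloudpickle":
            raise NotImplementedError("keras 3.x blobs from the old handler cannot be upgraded")
        return save_format  # "keras_tf" and "tf" pass through unchanged

    assert migrated_save_format({}) == "keras_tf"
    assert migrated_save_format({"save_format": "keras"}) == "keras_tf"
    assert migrated_save_format({"save_format": "tf"}) == "tf"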
snowflake/ml/model/_packager/model_meta/model_meta.py
@@ -352,7 +352,7 @@ class ModelMetadata:
             version.parse(loaded_meta_min_snowpark_ml_version) > version.parse(snowml_env.VERSION)
         ):
             raise RuntimeError(
-                f"The minimal version required to load the model is {loaded_meta_min_snowpark_ml_version},"
+                f"The minimal version required to load the model is {loaded_meta_min_snowpark_ml_version}, "
                 f"while current version of Snowpark ML library is {snowml_env.VERSION}."
             )
         return model_meta_schema.ModelMetadataDict(
snowflake/ml/model/_packager/model_meta/model_meta_schema.py
@@ -44,6 +44,9 @@ class CatBoostModelBlobOptions(BaseModelBlobOptions):
 class HuggingFacePipelineModelBlobOptions(BaseModelBlobOptions):
     task: Required[str]
     batch_size: Required[int]
+    has_tokenizer: NotRequired[bool]
+    has_feature_extractor: NotRequired[bool]
+    has_image_preprocessor: NotRequired[bool]
 
 
 class LightGBMModelBlobOptions(BaseModelBlobOptions):
snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py
@@ -1,2 +1,2 @@
 REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<5', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2,!=1.26.0', 'sqlparse>=0.4,<1', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
-ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<5', 'cachetools>=3.1.1,<6', 'catboost>=1.2.0, <2', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'huggingface_hub<0.26', 'importlib_resources>=6.1.1, <7', 'lightgbm>=4.1.0, <5', 'mlflow>=2.16.0, <3', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pytorch>=2.0.1,<2.3.0', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'sentence-transformers>=2.2.2,<3', 'sentencepiece>=0.1.95,<1', 'shap>=0.46.0,<1', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2,!=1.26.0', 'sqlparse>=0.4,<1', 'tensorflow>=2.12.0,<3', 'tokenizers>=0.10,<1', 'torchdata>=0.4,<1', 'transformers>=4.32.1,<5', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
+ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<5', 'cachetools>=3.1.1,<6', 'catboost>=1.2.0, <2', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'importlib_resources>=6.1.1, <7', 'keras>=2.0.0,<4', 'lightgbm>=4.1.0, <5', 'mlflow>=2.16.0, <3', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pytorch>=2.0.1,<3', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'sentence-transformers>=2.7.0,<3', 'sentencepiece>=0.1.95,<0.2.0', 'shap>=0.46.0,<1', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2,!=1.26.0', 'sqlparse>=0.4,<1', 'tensorflow>=2.17.0,<3', 'tokenizers>=0.15.1,<1', 'torchdata>=0.4,<1', 'transformers>=4.37.2,<5', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
snowflake/ml/model/_packager/model_runtime/model_runtime.py
@@ -45,6 +45,7 @@ class ModelRuntime:
         self.name = name
         self.runtime_env = copy.deepcopy(env)
         self.imports = imports or []
+        self.is_gpu = is_gpu
 
         if loading_from_file:
             return
@@ -88,7 +89,9 @@ class ModelRuntime:
         self.runtime_env.conda_env_rel_path = self.runtime_rel_path / self.runtime_env.conda_env_rel_path
         self.runtime_env.pip_requirements_rel_path = self.runtime_rel_path / self.runtime_env.pip_requirements_rel_path
 
-        env_dict = self.runtime_env.save_as_dict(packager_path, default_channel_override=default_channel_override)
+        env_dict = self.runtime_env.save_as_dict(
+            packager_path, default_channel_override=default_channel_override, is_gpu=self.is_gpu
+        )
 
         return model_meta_schema.ModelRuntimeDict(
             imports=list(map(str, self.imports)),