snowflake-ml-python 1.7.4__py3-none-any.whl → 1.8.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in their respective public registries.
Files changed (73)
  1. snowflake/cortex/_complete.py +58 -3
  2. snowflake/ml/_internal/env_utils.py +64 -21
  3. snowflake/ml/_internal/file_utils.py +18 -4
  4. snowflake/ml/_internal/platform_capabilities.py +3 -0
  5. snowflake/ml/_internal/relax_version_strategy.py +16 -0
  6. snowflake/ml/_internal/telemetry.py +25 -0
  7. snowflake/ml/data/_internal/arrow_ingestor.py +1 -1
  8. snowflake/ml/feature_store/feature_store.py +18 -0
  9. snowflake/ml/feature_store/feature_view.py +46 -1
  10. snowflake/ml/fileset/fileset.py +0 -1
  11. snowflake/ml/jobs/_utils/constants.py +31 -1
  12. snowflake/ml/jobs/_utils/payload_utils.py +232 -72
  13. snowflake/ml/jobs/_utils/spec_utils.py +78 -38
  14. snowflake/ml/jobs/decorators.py +8 -25
  15. snowflake/ml/jobs/job.py +4 -4
  16. snowflake/ml/jobs/manager.py +5 -0
  17. snowflake/ml/model/_client/model/model_version_impl.py +1 -1
  18. snowflake/ml/model/_client/ops/model_ops.py +107 -14
  19. snowflake/ml/model/_client/ops/service_ops.py +1 -1
  20. snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
  21. snowflake/ml/model/_client/sql/model_version.py +58 -0
  22. snowflake/ml/model/_client/sql/service.py +8 -2
  23. snowflake/ml/model/_model_composer/model_composer.py +50 -3
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
  26. snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
  27. snowflake/ml/model/_packager/model_env/model_env.py +49 -29
  28. snowflake/ml/model/_packager/model_handlers/_utils.py +8 -4
  29. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +44 -24
  30. snowflake/ml/model/_packager/model_handlers/keras.py +226 -0
  31. snowflake/ml/model/_packager/model_handlers/pytorch.py +51 -20
  32. snowflake/ml/model/_packager/model_handlers/sklearn.py +25 -3
  33. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +73 -21
  34. snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -72
  35. snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
  36. snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
  37. snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
  38. snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2023_12_01.py +48 -0
  39. snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
  40. snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
  41. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  42. snowflake/ml/model/_packager/model_meta/model_meta.py +6 -2
  43. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +16 -0
  44. snowflake/ml/model/_packager/model_packager.py +3 -5
  45. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
  46. snowflake/ml/model/_packager/model_runtime/model_runtime.py +8 -1
  47. snowflake/ml/model/_packager/model_task/model_task_utils.py +5 -1
  48. snowflake/ml/model/_signatures/builtins_handler.py +20 -9
  49. snowflake/ml/model/_signatures/core.py +54 -33
  50. snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
  51. snowflake/ml/model/_signatures/numpy_handler.py +12 -20
  52. snowflake/ml/model/_signatures/pandas_handler.py +28 -37
  53. snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
  54. snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
  55. snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
  56. snowflake/ml/model/_signatures/utils.py +120 -8
  57. snowflake/ml/model/custom_model.py +13 -4
  58. snowflake/ml/model/model_signature.py +39 -13
  59. snowflake/ml/model/type_hints.py +28 -2
  60. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +14 -1
  61. snowflake/ml/modeling/metrics/ranking.py +3 -0
  62. snowflake/ml/modeling/metrics/regression.py +3 -0
  63. snowflake/ml/modeling/pipeline/pipeline.py +18 -1
  64. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
  65. snowflake/ml/modeling/preprocessing/polynomial_features.py +2 -2
  66. snowflake/ml/registry/_manager/model_manager.py +55 -7
  67. snowflake/ml/registry/registry.py +52 -4
  68. snowflake/ml/version.py +1 -1
  69. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/METADATA +336 -27
  70. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/RECORD +73 -66
  71. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/WHEEL +1 -1
  72. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info/licenses}/LICENSE.txt +0 -0
  73. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/top_level.txt +0 -0
--- a/snowflake/ml/model/_packager/model_env/model_env.py
+++ b/snowflake/ml/model/_packager/model_env/model_env.py
@@ -3,7 +3,7 @@ import itertools
 import os
 import pathlib
 import warnings
-from typing import DefaultDict, List, Optional
+from typing import DefaultDict, Dict, List, Optional
 
 from packaging import requirements, version
 
@@ -36,6 +36,7 @@ class ModelEnv:
         pip_requirements_rel_path = os.path.join(_DEFAULT_ENV_DIR, _DEFAULT_PIP_REQUIREMENTS_FILENAME)
         self.conda_env_rel_path = pathlib.PurePosixPath(pathlib.Path(conda_env_rel_path).as_posix())
         self.pip_requirements_rel_path = pathlib.PurePosixPath(pathlib.Path(pip_requirements_rel_path).as_posix())
+        self.artifact_repository_map: Optional[Dict[str, str]] = None
         self._conda_dependencies: DefaultDict[str, List[requirements.Requirement]] = collections.defaultdict(list)
         self._pip_requirements: List[requirements.Requirement] = []
         self._python_version: version.Version = version.parse(snowml_env.PYTHON_VERSION)
@@ -113,7 +114,33 @@ class ModelEnv:
         self._snowpark_ml_version = version.parse(snowpark_ml_version)
 
     def include_if_absent(self, pkgs: List[ModelDependency], check_local_version: bool = False) -> None:
-        """Append requirements into model env if absent.
+        """Append requirements into model env if absent. Depending on the environment, requirements may be added
+        to either the pip requirements or conda dependencies.
+
+        Args:
+            pkgs: A list of ModelDependency namedtuple to be appended.
+            check_local_version: Flag to indicate if it is required to pin to local version. Defaults to False.
+        """
+        if self.pip_requirements and not self.conda_dependencies and pkgs:
+            pip_pkg_reqs: List[str] = []
+            warnings.warn(
+                (
+                    "Dependencies specified from pip requirements."
+                    " This may prevent model deploying to Snowflake Warehouse."
+                ),
+                category=UserWarning,
+                stacklevel=2,
+            )
+            for conda_req_str, pip_name in pkgs:
+                _, conda_req = env_utils._validate_conda_dependency_string(conda_req_str)
+                pip_req = requirements.Requirement(f"{pip_name}{conda_req.specifier}")
+                pip_pkg_reqs.append(str(pip_req))
+            self._include_if_absent_pip(pip_pkg_reqs, check_local_version)
+        else:
+            self._include_if_absent_conda(pkgs, check_local_version)
+
+    def _include_if_absent_conda(self, pkgs: List[ModelDependency], check_local_version: bool = False) -> None:
+        """Append requirements into model env conda dependencies if absent.
 
         Args:
             pkgs: A list of ModelDependency namedtuple to be appended.
@@ -134,8 +161,8 @@ class ModelEnv:
             if show_warning_message:
                 warnings.warn(
                     (
-                        f"Basic dependency {req_to_add.name} specified from PIP requirements."
-                        + " This may prevent model deploying to Snowflake Warehouse."
+                        f"Basic dependency {req_to_add.name} specified from pip requirements."
+                        " This may prevent model deploying to Snowflake Warehouse."
                     ),
                     category=UserWarning,
                     stacklevel=2,
@@ -157,11 +184,11 @@ class ModelEnv:
                     stacklevel=2,
                 )
 
-    def include_if_absent_pip(self, pkgs: List[str], check_local_version: bool = False) -> None:
-        """Append pip requirements into model env if absent.
+    def _include_if_absent_pip(self, pkgs: List[str], check_local_version: bool = False) -> None:
+        """Append pip requirements into model env pip requirements if absent.
 
         Args:
-            pkgs: A list of string to be appended in pip requirement.
+            pkgs: A list of strings to be appended to pip environment.
             check_local_version: Flag to indicate if it is required to pin to local version. Defaults to False.
         """
 
@@ -187,25 +214,6 @@ class ModelEnv:
             self._conda_dependencies[channel].remove(spec)
 
     def generate_env_for_cuda(self) -> None:
-        if self.cuda_version is None:
-            return
-
-        cuda_spec = env_utils.find_dep_spec(
-            self._conda_dependencies, self._pip_requirements, conda_pkg_name="cuda", remove_spec=False
-        )
-        if cuda_spec and not cuda_spec.specifier.contains(self.cuda_version):
-            raise ValueError(
-                "The CUDA requirement you specified in your conda dependencies or pip requirements is"
-                " conflicting with CUDA version required. Please do not specify CUDA dependency using conda"
-                " dependencies or pip requirements."
-            )
-
-        if not cuda_spec:
-            self.include_if_absent(
-                [ModelDependency(requirement=f"nvidia::cuda=={self.cuda_version}.*", pip_name="cuda")],
-                check_local_version=False,
-            )
-
         xgboost_spec = env_utils.find_dep_spec(
             self._conda_dependencies, self._pip_requirements, conda_pkg_name="xgboost", remove_spec=True
         )
@@ -236,7 +244,7 @@ class ModelEnv:
                 check_local_version=False,
             )
 
-        self.include_if_absent_pip(["bitsandbytes>=0.41.0"], check_local_version=False)
+        self._include_if_absent_pip(["bitsandbytes>=0.41.0"], check_local_version=False)
 
     def relax_version(self) -> None:
         """Relax the version requirements for both conda dependencies and pip requirements.
@@ -252,7 +260,9 @@ class ModelEnv:
         self._pip_requirements = list(map(env_utils.relax_requirement_version, self._pip_requirements))
 
     def load_from_conda_file(self, conda_env_path: pathlib.Path) -> None:
-        conda_dependencies_dict, pip_requirements_list, python_version = env_utils.load_conda_env_file(conda_env_path)
+        conda_dependencies_dict, pip_requirements_list, python_version, cuda_version = env_utils.load_conda_env_file(
+            conda_env_path
+        )
 
         for channel, channel_dependencies in conda_dependencies_dict.items():
             if channel != env_utils.DEFAULT_CHANNEL_NAME:
@@ -310,6 +320,9 @@ class ModelEnv:
         if python_version:
             self.python_version = python_version
 
+        if cuda_version:
+            self.cuda_version = cuda_version
+
     def load_from_pip_file(self, pip_requirements_path: pathlib.Path) -> None:
         pip_requirements_list = env_utils.load_requirements_file(pip_requirements_path)
 
@@ -333,6 +346,7 @@ class ModelEnv:
     def load_from_dict(self, base_dir: pathlib.Path, env_dict: model_meta_schema.ModelEnvDict) -> None:
         self.conda_env_rel_path = pathlib.PurePosixPath(env_dict["conda"])
         self.pip_requirements_rel_path = pathlib.PurePosixPath(env_dict["pip"])
+        self.artifact_repository_map = env_dict.get("artifact_repository_map", None)
 
         self.load_from_conda_file(base_dir / self.conda_env_rel_path)
         self.load_from_pip_file(base_dir / self.pip_requirements_rel_path)
@@ -342,12 +356,17 @@ class ModelEnv:
         self.snowpark_ml_version = env_dict["snowpark_ml_version"]
 
     def save_as_dict(
-        self, base_dir: pathlib.Path, default_channel_override: str = env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
+        self,
+        base_dir: pathlib.Path,
+        default_channel_override: str = env_utils.SNOWFLAKE_CONDA_CHANNEL_URL,
+        is_gpu: Optional[bool] = False,
     ) -> model_meta_schema.ModelEnvDict:
+        cuda_version = self.cuda_version if is_gpu else None
        env_utils.save_conda_env_file(
            pathlib.Path(base_dir / self.conda_env_rel_path),
            self._conda_dependencies,
            self.python_version,
+            cuda_version,
            default_channel_override=default_channel_override,
        )
        env_utils.save_requirements_file(
@@ -356,6 +375,7 @@ class ModelEnv:
         return {
             "conda": self.conda_env_rel_path.as_posix(),
             "pip": self.pip_requirements_rel_path.as_posix(),
+            "artifact_repository_map": self.artifact_repository_map if self.artifact_repository_map is not None else {},
             "python_version": self.python_version,
             "cuda_version": self.cuda_version,
             "snowpark_ml_version": self.snowpark_ml_version,
--- a/snowflake/ml/model/_packager/model_handlers/_utils.py
+++ b/snowflake/ml/model/_packager/model_handlers/_utils.py
@@ -39,7 +39,7 @@ def _is_callable(model: model_types.SupportedModelType, method_name: str) -> boo
 
 
 def get_truncated_sample_data(
-    sample_input_data: model_types.SupportedDataType, length: int = 100
+    sample_input_data: model_types.SupportedDataType, length: int = 100, is_for_modeling_model: bool = False
 ) -> model_types.SupportedLocalDataType:
     trunc_sample_input = model_signature._truncate_data(sample_input_data, length=length)
     local_sample_input: model_types.SupportedLocalDataType = None
@@ -47,6 +47,8 @@ def get_truncated_sample_data(
         # Added because of Any from missing stubs.
         trunc_sample_input = cast(SnowparkDataFrame, trunc_sample_input)
         local_sample_input = snowpark_handler.SnowparkDataFrameHandler.convert_to_df(trunc_sample_input)
+        if is_for_modeling_model:
+            local_sample_input.columns = trunc_sample_input.columns
     else:
         local_sample_input = trunc_sample_input
     return local_sample_input
@@ -58,13 +60,15 @@ def validate_signature(
     target_methods: Iterable[str],
     sample_input_data: Optional[model_types.SupportedDataType],
     get_prediction_fn: Callable[[str, model_types.SupportedLocalDataType], model_types.SupportedLocalDataType],
+    is_for_modeling_model: bool = False,
 ) -> model_meta.ModelMetadata:
     if model_meta.signatures:
         validate_target_methods(model, list(model_meta.signatures.keys()))
         if sample_input_data is not None:
-            local_sample_input = get_truncated_sample_data(sample_input_data)
+            local_sample_input = get_truncated_sample_data(
+                sample_input_data, is_for_modeling_model=is_for_modeling_model
+            )
         for target_method in model_meta.signatures.keys():
-
             model_signature_inst = model_meta.signatures.get(target_method)
             if model_signature_inst is not None:
                 # strict validation the input signature
@@ -77,7 +81,7 @@ def validate_signature(
         assert (
             sample_input_data is not None
         ), "Model signature and sample input are None at the same time. This should not happen with local model."
-        local_sample_input = get_truncated_sample_data(sample_input_data)
+        local_sample_input = get_truncated_sample_data(sample_input_data, is_for_modeling_model=is_for_modeling_model)
        for target_method in target_methods:
            predictions_df = get_prediction_fn(target_method, local_sample_input)
            sig = model_signature.infer_signature(
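
The is_for_modeling_model flag threaded through validate_signature() and get_truncated_sample_data() above copies the Snowpark DataFrame's column names back onto the converted pandas sample data, presumably because SnowparkDataFrameHandler.convert_to_df can normalize identifiers that snowflake.ml modeling models look up by their exact names. A rough illustration with hypothetical column names:

    import pandas as pd

    # Hypothetical values: the converted frame carries normalized names,
    # while the Snowpark DataFrame reports the original identifiers.
    local_sample_input = pd.DataFrame([[1, 2]], columns=["FEATURE_A", "FEATURE_B"])
    snowpark_columns = ['"feature_a"', '"feature_b"']

    # Effect of passing is_for_modeling_model=True:
    local_sample_input.columns = snowpark_columns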
--- a/snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py
+++ b/snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py
@@ -30,10 +30,7 @@ from snowflake.ml.model._packager.model_meta import (
     model_meta as model_meta_api,
     model_meta_schema,
 )
-from snowflake.ml.model._signatures import (
-    builtins_handler,
-    utils as model_signature_utils,
-)
+from snowflake.ml.model._signatures import utils as model_signature_utils
 from snowflake.ml.model.models import huggingface_pipeline
 from snowflake.snowpark._internal import utils as snowpark_utils
 
@@ -66,16 +63,16 @@ def get_requirements_from_task(task: str, spcs_only: bool = False) -> List[model
     return []
 
 
-class NumpyEncoder(json.JSONEncoder):
-    # This is a JSON encoder class to ensure the output from Huggingface pipeline is JSON serializable.
-    # What it covers is numpy object.
-    def default(self, z: object) -> object:
-        if isinstance(z, np.number):
-            if np.can_cast(z, np.int64, casting="safe"):
-                return int(z)
-            elif np.can_cast(z, np.float64, casting="safe"):
-                return z.astype(np.float64)
-        return super().default(z)
+def sanitize_output(data: Any) -> Any:
+    if isinstance(data, np.number):
+        return data.item()
+    if isinstance(data, np.ndarray):
+        return sanitize_output(data.tolist())
+    if isinstance(data, list):
+        return [sanitize_output(x) for x in data]
+    if isinstance(data, dict):
+        return {k: sanitize_output(v) for k, v in data.items()}
+    return data
 
 
 @final
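
sanitize_output() replaces the deleted NumpyEncoder: instead of forcing numpy values through a JSON round trip, pipeline outputs are now recursively converted to plain Python types. For example:

    import numpy as np

    nested = {"labels": np.array(["a", "b"]), "rank": np.int64(1)}
    sanitize_output(nested)
    # -> {"labels": ["a", "b"], "rank": 1}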
@@ -146,6 +143,10 @@ class HuggingFacePipelineHandler(
         framework = getattr(model, "framework", None)
         batch_size = getattr(model, "batch_size", None)
 
+        has_tokenizer = getattr(model, "tokenizer", None) is not None
+        has_feature_extractor = getattr(model, "feature_extractor", None) is not None
+        has_image_preprocessor = getattr(model, "image_preprocessor", None) is not None
+
         if type_utils.LazyType("transformers.Pipeline").isinstance(model):
             params = {
                 **model._preprocess_params,  # type:ignore[attr-defined]
@@ -234,6 +235,9 @@ class HuggingFacePipelineHandler(
                 {
                     "task": task,
                     "batch_size": batch_size if batch_size is not None else 1,
+                    "has_tokenizer": has_tokenizer,
+                    "has_feature_extractor": has_feature_extractor,
+                    "has_image_preprocessor": has_image_preprocessor,
                 }
             ),
         )
@@ -308,6 +312,14 @@ class HuggingFacePipelineHandler(
         if os.path.isdir(model_blob_file_or_dir_path):
             import transformers
 
+            additional_pipeline_params = {}
+            if model_blob_options.get("has_tokenizer", False):
+                additional_pipeline_params["tokenizer"] = model_blob_file_or_dir_path
+            if model_blob_options.get("has_feature_extractor", False):
+                additional_pipeline_params["feature_extractor"] = model_blob_file_or_dir_path
+            if model_blob_options.get("has_image_preprocessor", False):
+                additional_pipeline_params["image_preprocessor"] = model_blob_file_or_dir_path
+
             with open(
                 os.path.join(
                     model_blob_file_or_dir_path,
@@ -324,6 +336,7 @@ class HuggingFacePipelineHandler(
                 model=model_blob_file_or_dir_path,
                 trust_remote_code=True,
                 torch_dtype="auto",
+                **additional_pipeline_params,
                 **device_config,
             )
 
@@ -394,13 +407,17 @@ class HuggingFacePipelineHandler(
                     )
                     for conv_data in X.to_dict("records")
                 ]
-            elif len(signature.inputs) == 1:
-                input_data = X.to_dict("list")[signature.inputs[0].name]
             else:
                 if isinstance(raw_model, transformers.TableQuestionAnsweringPipeline):
                     X["table"] = X["table"].apply(json.loads)
 
-                input_data = X.to_dict("records")
+                # Most pipelines if it is expecting more than one arguments,
+                # it is expecting a list of dict, where each dict has keys corresponding to the argument.
+                if len(signature.inputs) > 1:
+                    input_data = X.to_dict("records")
+                # If it is only expecting one argument, Then it is expecting a list of something.
+                else:
+                    input_data = X[signature.inputs[0].name].to_list()
             temp_res = getattr(raw_model, target_method)(input_data)
 
             # Some huggingface pipeline will omit the outer list when there is only 1 input.
@@ -423,7 +440,6 @@ class HuggingFacePipelineHandler(
                         ),
                     )
                     and X.shape[0] == 1
-                    and isinstance(temp_res[0], dict)
                 )
             ):
                 temp_res = [temp_res]
@@ -437,14 +453,18 @@ class HuggingFacePipelineHandler(
                 temp_res = [[conv.generated_responses] for conv in temp_res]
 
             # To concat those who outputs a list with one input.
-            if builtins_handler.ListOfBuiltinHandler.can_handle(temp_res):
-                res = builtins_handler.ListOfBuiltinHandler.convert_to_df(temp_res)
-            elif isinstance(temp_res[0], dict):
+            if isinstance(temp_res[0], list):
+                if isinstance(temp_res[0][0], dict):
+                    res = pd.DataFrame({0: temp_res})
+                else:
+                    res = pd.DataFrame(temp_res)
+            else:
                 res = pd.DataFrame(temp_res)
-            elif isinstance(temp_res[0], list):
-                res = pd.DataFrame([json.dumps(output, cls=NumpyEncoder) for output in temp_res])
+
+            if hasattr(res, "map"):
+                res = res.map(sanitize_output)
             else:
-                raise ValueError(f"Cannot parse output {temp_res} from pipeline object")
+                res = res.applymap(sanitize_output)
 
             return model_signature_utils.rename_pandas_df(data=res, features=signature.outputs)
 
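Two behavioral notes on the inference changes above: a pipeline whose signature declares a single input now receives a plain list rather than a dict of lists, and the hasattr(res, "map") check is a pandas compatibility shim, since DataFrame.map superseded the deprecated DataFrame.applymap in pandas 2.1. A sketch of the single-input path, with an illustrative column name:

    import pandas as pd

    X = pd.DataFrame({"inputs": ["I love this!", "Meh."]})
    # With exactly one signature input named "inputs", the handler now builds:
    input_data = X["inputs"].to_list()  # ["I love this!", "Meh."]
    # and passes that list to the pipeline, yielding one result per element.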
--- /dev/null
+++ b/snowflake/ml/model/_packager/model_handlers/keras.py
@@ -0,0 +1,226 @@
+import os
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
+
+import cloudpickle
+import numpy as np
+import pandas as pd
+from packaging import version
+from typing_extensions import TypeGuard, Unpack
+
+from snowflake.ml._internal import type_utils
+from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
+from snowflake.ml.model._packager.model_env import model_env
+from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
+from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
+from snowflake.ml.model._packager.model_meta import (
+    model_blob_meta,
+    model_meta as model_meta_api,
+)
+from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
+
+if TYPE_CHECKING:
+    import keras
+
+
+@final
+class KerasHandler(_base.BaseModelHandler["keras.Model"]):
+    """Handler for Keras v3 model.
+
+    Currently keras.Model based classes are supported.
+    """
+
+    HANDLER_TYPE = "keras"
+    HANDLER_VERSION = "2025-01-01"
+    _MIN_SNOWPARK_ML_VERSION = "1.7.5"
+    _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
+
+    MODEL_BLOB_FILE_OR_DIR = "model.keras"
+    CUSTOM_OBJECT_SAVE_PATH = "custom_objects.pkl"
+    DEFAULT_TARGET_METHODS = ["predict"]
+
+    @classmethod
+    def can_handle(
+        cls,
+        model: model_types.SupportedModelType,
+    ) -> TypeGuard["keras.Model"]:
+        if not type_utils.LazyType("keras.Model").isinstance(model):
+            return False
+        import keras
+
+        return version.parse(keras.__version__) >= version.parse("3.0.0")
+
+    @classmethod
+    def cast_model(
+        cls,
+        model: model_types.SupportedModelType,
+    ) -> "keras.Model":
+        import keras
+
+        assert isinstance(model, keras.Model)
+
+        return cast(keras.Model, model)
+
+    @classmethod
+    def save_model(
+        cls,
+        name: str,
+        model: "keras.Model",
+        model_meta: model_meta_api.ModelMetadata,
+        model_blobs_dir_path: str,
+        sample_input_data: Optional[model_types.SupportedDataType] = None,
+        is_sub_model: Optional[bool] = False,
+        **kwargs: Unpack[model_types.TensorflowSaveOptions],
+    ) -> None:
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for Tensorflow model.")
+
+        import keras
+
+        assert isinstance(model, keras.Model)
+
+        if not is_sub_model:
+            target_methods = handlers_utils.get_target_methods(
+                model=model,
+                target_methods=kwargs.pop("target_methods", None),
+                default_target_methods=cls.DEFAULT_TARGET_METHODS,
+            )
+
+            def get_prediction(
+                target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
+            ) -> model_types.SupportedLocalDataType:
+                target_method = getattr(model, target_method_name, None)
+                assert callable(target_method)
+                predictions_df = target_method(sample_input_data)
+
+                if (
+                    type_utils.LazyType("tensorflow.Tensor").isinstance(predictions_df)
+                    or type_utils.LazyType("tensorflow.Variable").isinstance(predictions_df)
+                    or type_utils.LazyType("torch.Tensor").isinstance(predictions_df)
+                ):
+                    predictions_df = [predictions_df]
+
+                return predictions_df
+
+            model_meta = handlers_utils.validate_signature(
+                model=model,
+                model_meta=model_meta,
+                target_methods=target_methods,
+                sample_input_data=sample_input_data,
+                get_prediction_fn=get_prediction,
+            )
+
+        model_blob_path = os.path.join(model_blobs_dir_path, name)
+        os.makedirs(model_blob_path, exist_ok=True)
+        save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
+        model.save(save_path)
+
+        custom_object_save_path = os.path.join(model_blob_path, cls.CUSTOM_OBJECT_SAVE_PATH)
+        custom_objects = keras.saving.get_custom_objects()
+        with open(custom_object_save_path, "wb") as f:
+            cloudpickle.dump(custom_objects, f)
+
+        base_meta = model_blob_meta.ModelBlobMeta(
+            name=name,
+            model_type=cls.HANDLER_TYPE,
+            handler_version=cls.HANDLER_VERSION,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
+        )
+        model_meta.models[name] = base_meta
+        model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
+
+        dependencies = [
+            model_env.ModelDependency(requirement="keras>=3", pip_name="keras"),
+        ]
+        keras_backend = keras.backend.backend()
+        if keras_backend == "tensorflow":
+            dependencies.append(model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"))
+        elif keras_backend == "torch":
+            dependencies.append(model_env.ModelDependency(requirement="pytorch", pip_name="torch"))
+        elif keras_backend == "jax":
+            dependencies.append(model_env.ModelDependency(requirement="jax", pip_name="jax"))
+        else:
+            raise ValueError(f"Unsupported backend {keras_backend}")
+
+        model_meta.env.include_if_absent(
+            dependencies,
+            check_local_version=True,
+        )
+        model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
+
+    @classmethod
+    def load_model(
+        cls,
+        name: str,
+        model_meta: model_meta_api.ModelMetadata,
+        model_blobs_dir_path: str,
+        **kwargs: Unpack[model_types.TensorflowLoadOptions],
+    ) -> "keras.Model":
+        import keras
+
+        model_blob_path = os.path.join(model_blobs_dir_path, name)
+        model_blobs_metadata = model_meta.models
+        model_blob_metadata = model_blobs_metadata[name]
+        model_blob_filename = model_blob_metadata.path
+
+        custom_object_save_path = os.path.join(model_blob_path, cls.CUSTOM_OBJECT_SAVE_PATH)
+        with open(custom_object_save_path, "rb") as f:
+            custom_objects = cloudpickle.load(f)
+        load_path = os.path.join(model_blob_path, model_blob_filename)
+        m = keras.models.load_model(load_path, custom_objects=custom_objects, safe_mode=False)
+
+        return cast(keras.Model, m)
+
+    @classmethod
+    def convert_as_custom_model(
+        cls,
+        raw_model: "keras.Model",
+        model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
+        **kwargs: Unpack[model_types.TensorflowLoadOptions],
+    ) -> custom_model.CustomModel:
+
+        from snowflake.ml.model import custom_model
+
+        def _create_custom_model(
+            raw_model: "keras.Model",
+            model_meta: model_meta_api.ModelMetadata,
+        ) -> Type[custom_model.CustomModel]:
+            def fn_factory(
+                raw_model: "keras.Model",
+                signature: model_signature.ModelSignature,
+                target_method: str,
+            ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
+                dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs}
+
+                @custom_model.inference_api
+                def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
+                    res = getattr(raw_model, target_method)(X.astype(dtype_map), verbose=0)
+
+                    if isinstance(res, list) and len(res) > 0 and isinstance(res[0], np.ndarray):
+                        # In case of multi-output estimators, predict_proba(), decision_function(), etc., functions
+                        # return a list of ndarrays. We need to deal them separately
+                        df = numpy_handler.SeqOfNumpyArrayHandler.convert_to_df(res)
+                    else:
+                        df = pd.DataFrame(res)
+
+                    return model_signature_utils.rename_pandas_df(df, signature.outputs)
+
+                return fn
+
+            type_method_dict = {}
+            for target_method_name, sig in model_meta.signatures.items():
+                type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
+
+            _KerasModel = type(
+                "_KerasModel",
+                (custom_model.CustomModel,),
+                type_method_dict,
+            )
+
+            return _KerasModel
+
+        _KerasModel = _create_custom_model(raw_model, model_meta)
+        keras_model = _KerasModel(custom_model.ModelContext())
+
+        return keras_model
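
The new KerasHandler stores Keras 3 models in the native model.keras format alongside pickled custom objects, and pins the active backend (tensorflow, torch, or jax) as an additional dependency. A minimal end-to-end sketch, assuming an existing Snowpark session object (model and version names are illustrative):

    import keras
    import numpy as np
    from snowflake.ml.registry import Registry

    model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse")

    reg = Registry(session=session)
    mv = reg.log_model(
        model,
        model_name="keras_demo",
        version_name="v1",
        sample_input_data=np.random.rand(10, 4).astype(np.float32),
    )
    mv.run(np.random.rand(2, 4).astype(np.float32), function_name="predict")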