snowflake-ml-python 1.7.5__py3-none-any.whl → 1.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. snowflake/cortex/_complete.py +58 -3
  2. snowflake/ml/_internal/file_utils.py +18 -4
  3. snowflake/ml/_internal/platform_capabilities.py +3 -0
  4. snowflake/ml/_internal/telemetry.py +4 -0
  5. snowflake/ml/fileset/fileset.py +0 -1
  6. snowflake/ml/jobs/_utils/constants.py +25 -1
  7. snowflake/ml/jobs/_utils/payload_utils.py +94 -20
  8. snowflake/ml/jobs/_utils/spec_utils.py +95 -31
  9. snowflake/ml/jobs/decorators.py +7 -0
  10. snowflake/ml/jobs/manager.py +20 -0
  11. snowflake/ml/model/_client/model/model_version_impl.py +1 -1
  12. snowflake/ml/model/_client/ops/model_ops.py +113 -17
  13. snowflake/ml/model/_client/ops/service_ops.py +16 -5
  14. snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
  15. snowflake/ml/model/_client/sql/model_version.py +58 -0
  16. snowflake/ml/model/_client/sql/service.py +10 -2
  17. snowflake/ml/model/_model_composer/model_composer.py +50 -3
  18. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +5 -2
  19. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
  20. snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
  21. snowflake/ml/model/_packager/model_env/model_env.py +4 -1
  22. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +28 -24
  23. snowflake/ml/model/_packager/model_handlers/keras.py +1 -5
  24. snowflake/ml/model/_packager/model_handlers/pytorch.py +50 -20
  25. snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -8
  26. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -2
  27. snowflake/ml/model/_packager/model_handlers/tensorflow.py +46 -26
  28. snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
  29. snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
  30. snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
  31. snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
  32. snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +1 -2
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +5 -1
  35. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -0
  36. snowflake/ml/model/_packager/model_packager.py +3 -5
  37. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
  38. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -0
  39. snowflake/ml/model/_signatures/builtins_handler.py +20 -9
  40. snowflake/ml/model/_signatures/core.py +52 -31
  41. snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
  42. snowflake/ml/model/_signatures/numpy_handler.py +9 -17
  43. snowflake/ml/model/_signatures/pandas_handler.py +19 -30
  44. snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
  45. snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
  46. snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
  47. snowflake/ml/model/_signatures/utils.py +120 -8
  48. snowflake/ml/model/custom_model.py +13 -4
  49. snowflake/ml/model/model_signature.py +31 -13
  50. snowflake/ml/model/type_hints.py +13 -2
  51. snowflake/ml/modeling/_internal/estimator_utils.py +5 -1
  52. snowflake/ml/modeling/metrics/ranking.py +3 -0
  53. snowflake/ml/modeling/metrics/regression.py +3 -0
  54. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
  55. snowflake/ml/registry/_manager/model_manager.py +55 -7
  56. snowflake/ml/registry/registry.py +59 -1
  57. snowflake/ml/version.py +1 -1
  58. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/METADATA +308 -12
  59. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/RECORD +62 -58
  60. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/WHEEL +1 -1
  61. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info/licenses}/LICENSE.txt +0 -0
  62. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/top_level.txt +0 -0
@@ -30,10 +30,7 @@ from snowflake.ml.model._packager.model_meta import (
30
30
  model_meta as model_meta_api,
31
31
  model_meta_schema,
32
32
  )
33
- from snowflake.ml.model._signatures import (
34
- builtins_handler,
35
- utils as model_signature_utils,
36
- )
33
+ from snowflake.ml.model._signatures import utils as model_signature_utils
37
34
  from snowflake.ml.model.models import huggingface_pipeline
38
35
  from snowflake.snowpark._internal import utils as snowpark_utils
39
36
 
@@ -66,16 +63,16 @@ def get_requirements_from_task(task: str, spcs_only: bool = False) -> List[model
66
63
  return []
67
64
 
68
65
 
69
- class NumpyEncoder(json.JSONEncoder):
70
- # This is a JSON encoder class to ensure the output from Huggingface pipeline is JSON serializable.
71
- # What it covers is numpy object.
72
- def default(self, z: object) -> object:
73
- if isinstance(z, np.number):
74
- if np.can_cast(z, np.int64, casting="safe"):
75
- return int(z)
76
- elif np.can_cast(z, np.float64, casting="safe"):
77
- return z.astype(np.float64)
78
- return super().default(z)
66
+ def sanitize_output(data: Any) -> Any:
67
+ if isinstance(data, np.number):
68
+ return data.item()
69
+ if isinstance(data, np.ndarray):
70
+ return sanitize_output(data.tolist())
71
+ if isinstance(data, list):
72
+ return [sanitize_output(x) for x in data]
73
+ if isinstance(data, dict):
74
+ return {k: sanitize_output(v) for k, v in data.items()}
75
+ return data
79
76
 
80
77
 
81
78
  @final
@@ -410,13 +407,17 @@ class HuggingFacePipelineHandler(
410
407
  )
411
408
  for conv_data in X.to_dict("records")
412
409
  ]
413
- elif len(signature.inputs) == 1:
414
- input_data = X.to_dict("list")[signature.inputs[0].name]
415
410
  else:
416
411
  if isinstance(raw_model, transformers.TableQuestionAnsweringPipeline):
417
412
  X["table"] = X["table"].apply(json.loads)
418
413
 
419
- input_data = X.to_dict("records")
414
+ # Most pipelines if it is expecting more than one arguments,
415
+ # it is expecting a list of dict, where each dict has keys corresponding to the argument.
416
+ if len(signature.inputs) > 1:
417
+ input_data = X.to_dict("records")
418
+ # If it is only expecting one argument, Then it is expecting a list of something.
419
+ else:
420
+ input_data = X[signature.inputs[0].name].to_list()
420
421
  temp_res = getattr(raw_model, target_method)(input_data)
421
422
 
422
423
  # Some huggingface pipeline will omit the outer list when there is only 1 input.
@@ -439,7 +440,6 @@ class HuggingFacePipelineHandler(
439
440
  ),
440
441
  )
441
442
  and X.shape[0] == 1
442
- and isinstance(temp_res[0], dict)
443
443
  )
444
444
  ):
445
445
  temp_res = [temp_res]
@@ -453,14 +453,18 @@ class HuggingFacePipelineHandler(
453
453
  temp_res = [[conv.generated_responses] for conv in temp_res]
454
454
 
455
455
  # To concat those who outputs a list with one input.
456
- if builtins_handler.ListOfBuiltinHandler.can_handle(temp_res):
457
- res = builtins_handler.ListOfBuiltinHandler.convert_to_df(temp_res)
458
- elif isinstance(temp_res[0], dict):
456
+ if isinstance(temp_res[0], list):
457
+ if isinstance(temp_res[0][0], dict):
458
+ res = pd.DataFrame({0: temp_res})
459
+ else:
460
+ res = pd.DataFrame(temp_res)
461
+ else:
459
462
  res = pd.DataFrame(temp_res)
460
- elif isinstance(temp_res[0], list):
461
- res = pd.DataFrame([json.dumps(output, cls=NumpyEncoder) for output in temp_res])
463
+
464
+ if hasattr(res, "map"):
465
+ res = res.map(sanitize_output)
462
466
  else:
463
- raise ValueError(f"Cannot parse output {temp_res} from pipeline object")
467
+ res = res.applymap(sanitize_output)
464
468
 
465
469
  return model_signature_utils.rename_pandas_df(data=res, features=signature.outputs)
466
470
 
@@ -191,11 +191,7 @@ class KerasHandler(_base.BaseModelHandler["keras.Model"]):
191
191
  signature: model_signature.ModelSignature,
192
192
  target_method: str,
193
193
  ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
194
- dtype_map = {
195
- spec.name: spec.as_dtype(force_numpy_dtype=True)
196
- for spec in signature.inputs
197
- if isinstance(spec, model_signature.FeatureSpec)
198
- }
194
+ dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs}
199
195
 
200
196
  @custom_model.inference_api
201
197
  def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
@@ -10,10 +10,14 @@ from snowflake.ml._internal import type_utils
10
10
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
11
11
  from snowflake.ml.model._packager.model_env import model_env
12
12
  from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
13
- from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
13
+ from snowflake.ml.model._packager.model_handlers_migrator import (
14
+ base_migrator,
15
+ pytorch_migrator_2023_12_01,
16
+ )
14
17
  from snowflake.ml.model._packager.model_meta import (
15
18
  model_blob_meta,
16
19
  model_meta as model_meta_api,
20
+ model_meta_schema,
17
21
  )
18
22
  from snowflake.ml.model._signatures import (
19
23
  pytorch_handler,
@@ -21,7 +25,6 @@ from snowflake.ml.model._signatures import (
21
25
  )
22
26
 
23
27
  if TYPE_CHECKING:
24
- import sentence_transformers # noqa: F401
25
28
  import torch
26
29
 
27
30
 
@@ -33,9 +36,11 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
33
36
  """
34
37
 
35
38
  HANDLER_TYPE = "pytorch"
36
- HANDLER_VERSION = "2023-12-01"
37
- _MIN_SNOWPARK_ML_VERSION = "1.0.12"
38
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
39
+ HANDLER_VERSION = "2025-03-01"
40
+ _MIN_SNOWPARK_ML_VERSION = "1.8.0"
41
+ _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
42
+ "2023-12-01": pytorch_migrator_2023_12_01.PyTorchHandlerMigrator20231201
43
+ }
39
44
 
40
45
  MODEL_BLOB_FILE_OR_DIR = "model.pt"
41
46
  DEFAULT_TARGET_METHODS = ["forward"]
@@ -89,22 +94,33 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
89
94
  default_target_methods=cls.DEFAULT_TARGET_METHODS,
90
95
  )
91
96
 
97
+ multiple_inputs = kwargs.get("multiple_inputs", False)
98
+
92
99
  def get_prediction(
93
100
  target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
94
101
  ) -> model_types.SupportedLocalDataType:
95
- if not pytorch_handler.SeqOfPyTorchTensorHandler.can_handle(sample_input_data):
96
- sample_input_data = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(
97
- model_signature._convert_local_data_to_df(sample_input_data)
98
- )
102
+ if multiple_inputs:
103
+ if not pytorch_handler.SeqOfPyTorchTensorHandler.can_handle(sample_input_data):
104
+ sample_input_data = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(
105
+ model_signature._convert_local_data_to_df(sample_input_data)
106
+ )
107
+ else:
108
+ if not pytorch_handler.PyTorchTensorHandler.can_handle(sample_input_data):
109
+ sample_input_data = pytorch_handler.PyTorchTensorHandler.convert_from_df(
110
+ model_signature._convert_local_data_to_df(sample_input_data)
111
+ )
99
112
 
100
113
  model.eval()
101
114
  target_method = getattr(model, target_method_name, None)
102
115
  assert callable(target_method)
103
116
  with torch.no_grad():
104
- predictions_df = target_method(*sample_input_data)
117
+ if multiple_inputs:
118
+ predictions_df = target_method(*sample_input_data)
119
+ if not isinstance(predictions_df, tuple):
120
+ predictions_df = [predictions_df]
121
+ else:
122
+ predictions_df = target_method(sample_input_data)
105
123
 
106
- if isinstance(predictions_df, torch.Tensor):
107
- predictions_df = [predictions_df]
108
124
  return predictions_df
109
125
 
110
126
  model_meta = handlers_utils.validate_signature(
@@ -127,6 +143,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
127
143
  model_type=cls.HANDLER_TYPE,
128
144
  handler_version=cls.HANDLER_VERSION,
129
145
  path=cls.MODEL_BLOB_FILE_OR_DIR,
146
+ options=model_meta_schema.PyTorchModelBlobOptions(multiple_inputs=multiple_inputs),
130
147
  )
131
148
  model_meta.models[name] = base_meta
132
149
  model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -172,6 +189,10 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
172
189
  raw_model: "torch.nn.Module",
173
190
  model_meta: model_meta_api.ModelMetadata,
174
191
  ) -> Type[custom_model.CustomModel]:
192
+ multiple_inputs = cast(
193
+ model_meta_schema.PyTorchModelBlobOptions, model_meta.models[model_meta.name].options
194
+ )["multiple_inputs"]
195
+
175
196
  def fn_factory(
176
197
  raw_model: "torch.nn.Module",
177
198
  signature: model_signature.ModelSignature,
@@ -183,19 +204,28 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
183
204
  raise ValueError("Tensor cannot handle null values.")
184
205
 
185
206
  raw_model.eval()
186
- t = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(X, signature.inputs)
207
+ if multiple_inputs:
208
+ st = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(X, signature.inputs)
209
+
210
+ if kwargs.get("use_gpu", False):
211
+ st = [element.cuda() for element in st]
187
212
 
188
- if kwargs.get("use_gpu", False):
189
- t = [element.cuda() for element in t]
213
+ with torch.no_grad():
214
+ res = getattr(raw_model, target_method)(*st)
190
215
 
191
- with torch.no_grad():
192
- res = getattr(raw_model, target_method)(*t)
216
+ if not isinstance(res, tuple):
217
+ res = [res]
218
+ else:
219
+ t = pytorch_handler.PyTorchTensorHandler.convert_from_df(X, signature.inputs)
220
+ if kwargs.get("use_gpu", False):
221
+ t = t.cuda()
193
222
 
194
- if isinstance(res, torch.Tensor):
195
- res = [res]
223
+ with torch.no_grad():
224
+ res = getattr(raw_model, target_method)(t)
196
225
 
197
226
  return model_signature_utils.rename_pandas_df(
198
- data=pytorch_handler.SeqOfPyTorchTensorHandler.convert_to_df(res), features=signature.outputs
227
+ model_signature._convert_local_data_to_df(res, ensure_serializable=True),
228
+ features=signature.outputs,
199
229
  )
200
230
 
201
231
  return fn
@@ -57,6 +57,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
57
57
  "predict_proba",
58
58
  "predict_log_proba",
59
59
  "decision_function",
60
+ "score_samples",
60
61
  ]
61
62
  EXPLAIN_TARGET_METHODS = ["predict", "predict_proba", "predict_log_proba"]
62
63
 
@@ -74,10 +75,6 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
74
75
  and (
75
76
  not type_utils.LazyType("lightgbm.LGBMModel").isinstance(model)
76
77
  ) # LGBMModel is actually a BaseEstimator
77
- and any(
78
- (hasattr(model, method) and callable(getattr(model, method, None)))
79
- for method in cls.DEFAULT_TARGET_METHODS
80
- )
81
78
  )
82
79
 
83
80
  @classmethod
@@ -297,10 +294,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
297
294
  df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
298
295
  except TypeError:
299
296
  try:
300
- dtype_map = {
301
- spec.name: spec.as_dtype(force_numpy_dtype=True) # type: ignore[attr-defined]
302
- for spec in signature.inputs
303
- }
297
+ dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs}
304
298
 
305
299
  if isinstance(X, pd.DataFrame):
306
300
  X = X.astype(dtype_map, copy=False)
@@ -307,8 +307,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
307
307
  except TypeError:
308
308
  try:
309
309
  dtype_map = {
310
- spec.name: spec.as_dtype(force_numpy_dtype=True) # type: ignore[attr-defined]
311
- for spec in signature.inputs
310
+ spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs
312
311
  }
313
312
 
314
313
  if isinstance(X, pd.DataFrame):
@@ -1,7 +1,6 @@
1
1
  import os
2
2
  from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
3
3
 
4
- import numpy as np
5
4
  import pandas as pd
6
5
  from packaging import version
7
6
  from typing_extensions import TypeGuard, Unpack
@@ -13,6 +12,7 @@ from snowflake.ml.model._packager.model_handlers import _base, _utils as handler
13
12
  from snowflake.ml.model._packager.model_handlers_migrator import (
14
13
  base_migrator,
15
14
  tensorflow_migrator_2023_12_01,
15
+ tensorflow_migrator_2025_01_01,
16
16
  )
17
17
  from snowflake.ml.model._packager.model_meta import (
18
18
  model_blob_meta,
@@ -20,7 +20,6 @@ from snowflake.ml.model._packager.model_meta import (
20
20
  model_meta_schema,
21
21
  )
22
22
  from snowflake.ml.model._signatures import (
23
- numpy_handler,
24
23
  tensorflow_handler,
25
24
  utils as model_signature_utils,
26
25
  )
@@ -37,10 +36,11 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
37
36
  """
38
37
 
39
38
  HANDLER_TYPE = "tensorflow"
40
- HANDLER_VERSION = "2025-01-01"
41
- _MIN_SNOWPARK_ML_VERSION = "1.7.5"
39
+ HANDLER_VERSION = "2025-03-01"
40
+ _MIN_SNOWPARK_ML_VERSION = "1.8.0"
42
41
  _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
43
- "2023-12-01": tensorflow_migrator_2023_12_01.TensorflowHandlerMigrator20231201
42
+ "2023-12-01": tensorflow_migrator_2023_12_01.TensorflowHandlerMigrator20231201,
43
+ "2025-01-01": tensorflow_migrator_2025_01_01.TensorflowHandlerMigrator20250101,
44
44
  }
45
45
 
46
46
  MODEL_BLOB_FILE_OR_DIR = "model"
@@ -112,25 +112,35 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
112
112
  default_target_methods=default_target_methods,
113
113
  )
114
114
 
115
+ multiple_inputs = kwargs.get("multiple_inputs", False)
116
+
115
117
  if is_keras_model and len(target_methods) > 1:
116
118
  raise ValueError("Keras model can only have one target method.")
117
119
 
118
120
  def get_prediction(
119
121
  target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
120
122
  ) -> model_types.SupportedLocalDataType:
121
- if not tensorflow_handler.SeqOfTensorflowTensorHandler.can_handle(sample_input_data):
122
- sample_input_data = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(
123
- model_signature._convert_local_data_to_df(sample_input_data)
124
- )
123
+ if multiple_inputs:
124
+ if not tensorflow_handler.SeqOfTensorflowTensorHandler.can_handle(sample_input_data):
125
+ sample_input_data = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(
126
+ model_signature._convert_local_data_to_df(sample_input_data)
127
+ )
128
+ else:
129
+ if not tensorflow_handler.TensorflowTensorHandler.can_handle(sample_input_data):
130
+ sample_input_data = tensorflow_handler.TensorflowTensorHandler.convert_from_df(
131
+ model_signature._convert_local_data_to_df(sample_input_data)
132
+ )
125
133
 
126
134
  target_method = getattr(model, target_method_name, None)
127
135
  assert callable(target_method)
128
136
  for tensor in sample_input_data:
129
137
  tensorflow.stop_gradient(tensor)
130
- predictions_df = target_method(*sample_input_data)
131
-
132
- if isinstance(predictions_df, (tensorflow.Tensor, tensorflow.Variable, np.ndarray)):
133
- predictions_df = [predictions_df]
138
+ if multiple_inputs:
139
+ predictions_df = target_method(*sample_input_data)
140
+ if not isinstance(predictions_df, tuple):
141
+ predictions_df = [predictions_df]
142
+ else:
143
+ predictions_df = target_method(sample_input_data)
134
144
 
135
145
  return predictions_df
136
146
 
@@ -159,7 +169,9 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
159
169
  model_type=cls.HANDLER_TYPE,
160
170
  handler_version=cls.HANDLER_VERSION,
161
171
  path=cls.MODEL_BLOB_FILE_OR_DIR,
162
- options=model_meta_schema.TensorflowModelBlobOptions(save_format=save_format),
172
+ options=model_meta_schema.TensorflowModelBlobOptions(
173
+ save_format=save_format, multiple_inputs=multiple_inputs
174
+ ),
163
175
  )
164
176
  model_meta.models[name] = base_meta
165
177
  model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -219,6 +231,10 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
219
231
  raw_model: "tensorflow.Module",
220
232
  model_meta: model_meta_api.ModelMetadata,
221
233
  ) -> Type[custom_model.CustomModel]:
234
+ multiple_inputs = cast(
235
+ model_meta_schema.TensorflowModelBlobOptions, model_meta.models[model_meta.name].options
236
+ )["multiple_inputs"]
237
+
222
238
  def fn_factory(
223
239
  raw_model: "tensorflow.Module",
224
240
  signature: model_signature.ModelSignature,
@@ -229,21 +245,25 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
229
245
  if X.isnull().any(axis=None):
230
246
  raise ValueError("Tensor cannot handle null values.")
231
247
 
232
- t = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(X, signature.inputs)
248
+ if multiple_inputs:
249
+ t = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(X, signature.inputs)
233
250
 
234
- for tensor in t:
235
- tensorflow.stop_gradient(tensor)
236
- res = getattr(raw_model, target_method)(*t)
251
+ for tensor in t:
252
+ tensorflow.stop_gradient(tensor)
253
+ res = getattr(raw_model, target_method)(*t)
237
254
 
238
- if isinstance(res, (tensorflow.Tensor, tensorflow.Variable, np.ndarray)):
239
- res = [res]
240
-
241
- if isinstance(res, list) and len(res) > 0 and isinstance(res[0], np.ndarray):
242
- # In case of running on CPU, it will return numpy array
243
- df = numpy_handler.SeqOfNumpyArrayHandler.convert_to_df(res)
255
+ if not isinstance(res, tuple):
256
+ res = [res]
244
257
  else:
245
- df = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_to_df(res)
246
- return model_signature_utils.rename_pandas_df(df, signature.outputs)
258
+ t = tensorflow_handler.TensorflowTensorHandler.convert_from_df(X, signature.inputs)
259
+
260
+ tensorflow.stop_gradient(t)
261
+ res = getattr(raw_model, target_method)(t)
262
+
263
+ return model_signature_utils.rename_pandas_df(
264
+ model_signature._convert_local_data_to_df(res, ensure_serializable=True),
265
+ features=signature.outputs,
266
+ )
247
267
 
248
268
  return fn
249
269
 
@@ -8,10 +8,14 @@ from snowflake.ml._internal import type_utils
8
8
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
9
9
  from snowflake.ml.model._packager.model_env import model_env
10
10
  from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
11
- from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
11
+ from snowflake.ml.model._packager.model_handlers_migrator import (
12
+ base_migrator,
13
+ torchscript_migrator_2023_12_01,
14
+ )
12
15
  from snowflake.ml.model._packager.model_meta import (
13
16
  model_blob_meta,
14
17
  model_meta as model_meta_api,
18
+ model_meta_schema,
15
19
  )
16
20
  from snowflake.ml.model._signatures import (
17
21
  pytorch_handler,
@@ -30,9 +34,11 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
30
34
  """
31
35
 
32
36
  HANDLER_TYPE = "torchscript"
33
- HANDLER_VERSION = "2023-12-01"
34
- _MIN_SNOWPARK_ML_VERSION = "1.0.12"
35
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
37
+ HANDLER_VERSION = "2025-03-01"
38
+ _MIN_SNOWPARK_ML_VERSION = "1.8.0"
39
+ _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
40
+ "2023-12-01": torchscript_migrator_2023_12_01.TorchScriptHandlerMigrator20231201
41
+ }
36
42
 
37
43
  MODEL_BLOB_FILE_OR_DIR = "model.pt"
38
44
  DEFAULT_TARGET_METHODS = ["forward"]
@@ -81,22 +87,32 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
81
87
  default_target_methods=cls.DEFAULT_TARGET_METHODS,
82
88
  )
83
89
 
90
+ multiple_inputs = kwargs.get("multiple_inputs", False)
91
+
84
92
  def get_prediction(
85
93
  target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
86
94
  ) -> model_types.SupportedLocalDataType:
87
- if not pytorch_handler.SeqOfPyTorchTensorHandler.can_handle(sample_input_data):
88
- sample_input_data = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(
89
- model_signature._convert_local_data_to_df(sample_input_data)
90
- )
95
+ if multiple_inputs:
96
+ if not pytorch_handler.SeqOfPyTorchTensorHandler.can_handle(sample_input_data):
97
+ sample_input_data = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(
98
+ model_signature._convert_local_data_to_df(sample_input_data)
99
+ )
100
+ else:
101
+ if not pytorch_handler.PyTorchTensorHandler.can_handle(sample_input_data):
102
+ sample_input_data = pytorch_handler.PyTorchTensorHandler.convert_from_df(
103
+ model_signature._convert_local_data_to_df(sample_input_data)
104
+ )
91
105
 
92
106
  model.eval()
93
107
  target_method = getattr(model, target_method_name, None)
94
108
  assert callable(target_method)
95
109
  with torch.no_grad():
96
- predictions_df = target_method(*sample_input_data)
97
-
98
- if isinstance(predictions_df, torch.Tensor):
99
- predictions_df = [predictions_df]
110
+ if multiple_inputs:
111
+ predictions_df = target_method(*sample_input_data)
112
+ if not isinstance(predictions_df, tuple):
113
+ predictions_df = [predictions_df]
114
+ else:
115
+ predictions_df = target_method(sample_input_data)
100
116
 
101
117
  return predictions_df
102
118
 
@@ -117,6 +133,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
117
133
  model_type=cls.HANDLER_TYPE,
118
134
  handler_version=cls.HANDLER_VERSION,
119
135
  path=cls.MODEL_BLOB_FILE_OR_DIR,
136
+ options=model_meta_schema.TorchScriptModelBlobOptions(multiple_inputs=multiple_inputs),
120
137
  )
121
138
  model_meta.models[name] = base_meta
122
139
  model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -170,6 +187,10 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
170
187
  signature: model_signature.ModelSignature,
171
188
  target_method: str,
172
189
  ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
190
+ multiple_inputs = cast(
191
+ model_meta_schema.TorchScriptModelBlobOptions, model_meta.models[model_meta.name].options
192
+ )["multiple_inputs"]
193
+
173
194
  @custom_model.inference_api
174
195
  def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
175
196
  if X.isnull().any(axis=None):
@@ -179,19 +200,27 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
179
200
 
180
201
  raw_model.eval()
181
202
 
182
- t = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(X, signature.inputs)
203
+ if multiple_inputs:
204
+ st = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(X, signature.inputs)
183
205
 
184
- if kwargs.get("use_gpu", False):
185
- t = [element.cuda() for element in t]
206
+ if kwargs.get("use_gpu", False):
207
+ st = [element.cuda() for element in st]
186
208
 
187
- with torch.no_grad():
188
- res = getattr(raw_model, target_method)(*t)
209
+ with torch.no_grad():
210
+ res = getattr(raw_model, target_method)(*st)
189
211
 
190
- if isinstance(res, torch.Tensor):
191
- res = [res]
212
+ if not isinstance(res, tuple):
213
+ res = [res]
214
+ else:
215
+ t = pytorch_handler.PyTorchTensorHandler.convert_from_df(X, signature.inputs)
216
+ if kwargs.get("use_gpu", False):
217
+ t = t.cuda()
192
218
 
219
+ with torch.no_grad():
220
+ res = getattr(raw_model, target_method)(t)
193
221
  return model_signature_utils.rename_pandas_df(
194
- data=pytorch_handler.SeqOfPyTorchTensorHandler.convert_to_df(res), features=signature.outputs
222
+ model_signature._convert_local_data_to_df(res, ensure_serializable=True),
223
+ features=signature.outputs,
195
224
  )
196
225
 
197
226
  return fn
@@ -99,10 +99,10 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
99
99
  def get_prediction(
100
100
  target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
101
101
  ) -> model_types.SupportedLocalDataType:
102
- if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray)):
102
+ if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray, xgboost.DMatrix)):
103
103
  sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
104
104
 
105
- if isinstance(model, xgboost.Booster):
105
+ if isinstance(model, xgboost.Booster) and not isinstance(sample_input_data, xgboost.DMatrix):
106
106
  sample_input_data = xgboost.DMatrix(sample_input_data)
107
107
 
108
108
  target_method = getattr(model, target_method_name, None)
@@ -0,0 +1,20 @@
1
+ from typing import cast
2
+
3
+ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
4
+ from snowflake.ml.model._packager.model_meta import (
5
+ model_meta as model_meta_api,
6
+ model_meta_schema,
7
+ )
8
+
9
+
10
+ class PyTorchHandlerMigrator20231201(base_migrator.BaseModelHandlerMigrator):
11
+ source_version = "2023-12-01"
12
+ target_version = "2025-03-01"
13
+
14
+ @staticmethod
15
+ def upgrade(name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str) -> None:
16
+
17
+ model_blob_metadata = model_meta.models[name]
18
+ model_blob_options = cast(model_meta_schema.PyTorchModelBlobOptions, model_blob_metadata.options)
19
+ model_blob_options["multiple_inputs"] = True
20
+ model_meta.models[name].options = model_blob_options
@@ -0,0 +1,19 @@
1
+ from typing import cast
2
+
3
+ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
4
+ from snowflake.ml.model._packager.model_meta import (
5
+ model_meta as model_meta_api,
6
+ model_meta_schema,
7
+ )
8
+
9
+
10
+ class TensorflowHandlerMigrator20250101(base_migrator.BaseModelHandlerMigrator):
11
+ source_version = "2025-01-01"
12
+ target_version = "2025-03-01"
13
+
14
+ @staticmethod
15
+ def upgrade(name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str) -> None:
16
+ model_blob_metadata = model_meta.models[name]
17
+ model_blob_options = cast(model_meta_schema.TensorflowModelBlobOptions, model_blob_metadata.options)
18
+ model_blob_options["multiple_inputs"] = True
19
+ model_meta.models[name].options = model_blob_options
@@ -0,0 +1,20 @@
1
+ from typing import cast
2
+
3
+ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
4
+ from snowflake.ml.model._packager.model_meta import (
5
+ model_meta as model_meta_api,
6
+ model_meta_schema,
7
+ )
8
+
9
+
10
+ class TorchScriptHandlerMigrator20231201(base_migrator.BaseModelHandlerMigrator):
11
+ source_version = "2023-12-01"
12
+ target_version = "2025-03-01"
13
+
14
+ @staticmethod
15
+ def upgrade(name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str) -> None:
16
+
17
+ model_blob_metadata = model_meta.models[name]
18
+ model_blob_options = cast(model_meta_schema.PyTorchModelBlobOptions, model_blob_metadata.options)
19
+ model_blob_options["multiple_inputs"] = True
20
+ model_meta.models[name].options = model_blob_options
@@ -1,2 +1 @@
1
- REQUIREMENTS = ['cloudpickle>=2.0.0']
2
- ALL_REQUIREMENTS=['cloudpickle>=2.0.0']
1
+ REQUIREMENTS = ['cloudpickle>=2.0.0,<3']