snowflake-ml-python 1.7.5__py3-none-any.whl → 1.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. snowflake/cortex/_complete.py +58 -3
  2. snowflake/ml/_internal/file_utils.py +18 -4
  3. snowflake/ml/_internal/platform_capabilities.py +3 -0
  4. snowflake/ml/_internal/telemetry.py +4 -0
  5. snowflake/ml/fileset/fileset.py +0 -1
  6. snowflake/ml/jobs/_utils/constants.py +24 -0
  7. snowflake/ml/jobs/_utils/payload_utils.py +94 -20
  8. snowflake/ml/jobs/_utils/spec_utils.py +73 -31
  9. snowflake/ml/jobs/decorators.py +3 -0
  10. snowflake/ml/jobs/manager.py +5 -0
  11. snowflake/ml/model/_client/model/model_version_impl.py +1 -1
  12. snowflake/ml/model/_client/ops/model_ops.py +107 -14
  13. snowflake/ml/model/_client/ops/service_ops.py +1 -1
  14. snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
  15. snowflake/ml/model/_client/sql/model_version.py +58 -0
  16. snowflake/ml/model/_client/sql/service.py +8 -2
  17. snowflake/ml/model/_model_composer/model_composer.py +50 -3
  18. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
  19. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
  20. snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
  21. snowflake/ml/model/_packager/model_env/model_env.py +4 -1
  22. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +28 -24
  23. snowflake/ml/model/_packager/model_handlers/keras.py +1 -5
  24. snowflake/ml/model/_packager/model_handlers/pytorch.py +50 -20
  25. snowflake/ml/model/_packager/model_handlers/sklearn.py +1 -4
  26. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -2
  27. snowflake/ml/model/_packager/model_handlers/tensorflow.py +46 -26
  28. snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
  29. snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
  30. snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
  31. snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
  32. snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +5 -1
  35. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -0
  36. snowflake/ml/model/_packager/model_packager.py +3 -5
  37. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
  38. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -0
  39. snowflake/ml/model/_signatures/builtins_handler.py +20 -9
  40. snowflake/ml/model/_signatures/core.py +52 -31
  41. snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
  42. snowflake/ml/model/_signatures/numpy_handler.py +9 -17
  43. snowflake/ml/model/_signatures/pandas_handler.py +19 -30
  44. snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
  45. snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
  46. snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
  47. snowflake/ml/model/_signatures/utils.py +120 -8
  48. snowflake/ml/model/custom_model.py +13 -4
  49. snowflake/ml/model/model_signature.py +31 -13
  50. snowflake/ml/model/type_hints.py +13 -2
  51. snowflake/ml/modeling/metrics/ranking.py +3 -0
  52. snowflake/ml/modeling/metrics/regression.py +3 -0
  53. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
  54. snowflake/ml/registry/_manager/model_manager.py +55 -7
  55. snowflake/ml/registry/registry.py +18 -0
  56. snowflake/ml/version.py +1 -1
  57. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/METADATA +287 -11
  58. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/RECORD +61 -57
  59. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/WHEEL +1 -1
  60. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info/licenses}/LICENSE.txt +0 -0
  61. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/top_level.txt +0 -0
@@ -10,10 +10,14 @@ from snowflake.ml._internal import type_utils
10
10
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
11
11
  from snowflake.ml.model._packager.model_env import model_env
12
12
  from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
13
- from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
13
+ from snowflake.ml.model._packager.model_handlers_migrator import (
14
+ base_migrator,
15
+ pytorch_migrator_2023_12_01,
16
+ )
14
17
  from snowflake.ml.model._packager.model_meta import (
15
18
  model_blob_meta,
16
19
  model_meta as model_meta_api,
20
+ model_meta_schema,
17
21
  )
18
22
  from snowflake.ml.model._signatures import (
19
23
  pytorch_handler,
@@ -21,7 +25,6 @@ from snowflake.ml.model._signatures import (
21
25
  )
22
26
 
23
27
  if TYPE_CHECKING:
24
- import sentence_transformers # noqa: F401
25
28
  import torch
26
29
 
27
30
 
@@ -33,9 +36,11 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
33
36
  """
34
37
 
35
38
  HANDLER_TYPE = "pytorch"
36
- HANDLER_VERSION = "2023-12-01"
37
- _MIN_SNOWPARK_ML_VERSION = "1.0.12"
38
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
39
+ HANDLER_VERSION = "2025-03-01"
40
+ _MIN_SNOWPARK_ML_VERSION = "1.8.0"
41
+ _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
42
+ "2023-12-01": pytorch_migrator_2023_12_01.PyTorchHandlerMigrator20231201
43
+ }
39
44
 
40
45
  MODEL_BLOB_FILE_OR_DIR = "model.pt"
41
46
  DEFAULT_TARGET_METHODS = ["forward"]
@@ -89,22 +94,33 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
89
94
  default_target_methods=cls.DEFAULT_TARGET_METHODS,
90
95
  )
91
96
 
97
+ multiple_inputs = kwargs.get("multiple_inputs", False)
98
+
92
99
  def get_prediction(
93
100
  target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
94
101
  ) -> model_types.SupportedLocalDataType:
95
- if not pytorch_handler.SeqOfPyTorchTensorHandler.can_handle(sample_input_data):
96
- sample_input_data = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(
97
- model_signature._convert_local_data_to_df(sample_input_data)
98
- )
102
+ if multiple_inputs:
103
+ if not pytorch_handler.SeqOfPyTorchTensorHandler.can_handle(sample_input_data):
104
+ sample_input_data = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(
105
+ model_signature._convert_local_data_to_df(sample_input_data)
106
+ )
107
+ else:
108
+ if not pytorch_handler.PyTorchTensorHandler.can_handle(sample_input_data):
109
+ sample_input_data = pytorch_handler.PyTorchTensorHandler.convert_from_df(
110
+ model_signature._convert_local_data_to_df(sample_input_data)
111
+ )
99
112
 
100
113
  model.eval()
101
114
  target_method = getattr(model, target_method_name, None)
102
115
  assert callable(target_method)
103
116
  with torch.no_grad():
104
- predictions_df = target_method(*sample_input_data)
117
+ if multiple_inputs:
118
+ predictions_df = target_method(*sample_input_data)
119
+ if not isinstance(predictions_df, tuple):
120
+ predictions_df = [predictions_df]
121
+ else:
122
+ predictions_df = target_method(sample_input_data)
105
123
 
106
- if isinstance(predictions_df, torch.Tensor):
107
- predictions_df = [predictions_df]
108
124
  return predictions_df
109
125
 
110
126
  model_meta = handlers_utils.validate_signature(
@@ -127,6 +143,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
127
143
  model_type=cls.HANDLER_TYPE,
128
144
  handler_version=cls.HANDLER_VERSION,
129
145
  path=cls.MODEL_BLOB_FILE_OR_DIR,
146
+ options=model_meta_schema.PyTorchModelBlobOptions(multiple_inputs=multiple_inputs),
130
147
  )
131
148
  model_meta.models[name] = base_meta
132
149
  model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -172,6 +189,10 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
172
189
  raw_model: "torch.nn.Module",
173
190
  model_meta: model_meta_api.ModelMetadata,
174
191
  ) -> Type[custom_model.CustomModel]:
192
+ multiple_inputs = cast(
193
+ model_meta_schema.PyTorchModelBlobOptions, model_meta.models[model_meta.name].options
194
+ )["multiple_inputs"]
195
+
175
196
  def fn_factory(
176
197
  raw_model: "torch.nn.Module",
177
198
  signature: model_signature.ModelSignature,
@@ -183,19 +204,28 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
183
204
  raise ValueError("Tensor cannot handle null values.")
184
205
 
185
206
  raw_model.eval()
186
- t = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(X, signature.inputs)
207
+ if multiple_inputs:
208
+ st = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(X, signature.inputs)
209
+
210
+ if kwargs.get("use_gpu", False):
211
+ st = [element.cuda() for element in st]
187
212
 
188
- if kwargs.get("use_gpu", False):
189
- t = [element.cuda() for element in t]
213
+ with torch.no_grad():
214
+ res = getattr(raw_model, target_method)(*st)
190
215
 
191
- with torch.no_grad():
192
- res = getattr(raw_model, target_method)(*t)
216
+ if not isinstance(res, tuple):
217
+ res = [res]
218
+ else:
219
+ t = pytorch_handler.PyTorchTensorHandler.convert_from_df(X, signature.inputs)
220
+ if kwargs.get("use_gpu", False):
221
+ t = t.cuda()
193
222
 
194
- if isinstance(res, torch.Tensor):
195
- res = [res]
223
+ with torch.no_grad():
224
+ res = getattr(raw_model, target_method)(t)
196
225
 
197
226
  return model_signature_utils.rename_pandas_df(
198
- data=pytorch_handler.SeqOfPyTorchTensorHandler.convert_to_df(res), features=signature.outputs
227
+ model_signature._convert_local_data_to_df(res, ensure_serializable=True),
228
+ features=signature.outputs,
199
229
  )
200
230
 
201
231
  return fn
@@ -297,10 +297,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
297
297
  df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
298
298
  except TypeError:
299
299
  try:
300
- dtype_map = {
301
- spec.name: spec.as_dtype(force_numpy_dtype=True) # type: ignore[attr-defined]
302
- for spec in signature.inputs
303
- }
300
+ dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs}
304
301
 
305
302
  if isinstance(X, pd.DataFrame):
306
303
  X = X.astype(dtype_map, copy=False)
@@ -307,8 +307,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
307
307
  except TypeError:
308
308
  try:
309
309
  dtype_map = {
310
- spec.name: spec.as_dtype(force_numpy_dtype=True) # type: ignore[attr-defined]
311
- for spec in signature.inputs
310
+ spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs
312
311
  }
313
312
 
314
313
  if isinstance(X, pd.DataFrame):
@@ -1,7 +1,6 @@
1
1
  import os
2
2
  from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
3
3
 
4
- import numpy as np
5
4
  import pandas as pd
6
5
  from packaging import version
7
6
  from typing_extensions import TypeGuard, Unpack
@@ -13,6 +12,7 @@ from snowflake.ml.model._packager.model_handlers import _base, _utils as handler
13
12
  from snowflake.ml.model._packager.model_handlers_migrator import (
14
13
  base_migrator,
15
14
  tensorflow_migrator_2023_12_01,
15
+ tensorflow_migrator_2025_01_01,
16
16
  )
17
17
  from snowflake.ml.model._packager.model_meta import (
18
18
  model_blob_meta,
@@ -20,7 +20,6 @@ from snowflake.ml.model._packager.model_meta import (
20
20
  model_meta_schema,
21
21
  )
22
22
  from snowflake.ml.model._signatures import (
23
- numpy_handler,
24
23
  tensorflow_handler,
25
24
  utils as model_signature_utils,
26
25
  )
@@ -37,10 +36,11 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
37
36
  """
38
37
 
39
38
  HANDLER_TYPE = "tensorflow"
40
- HANDLER_VERSION = "2025-01-01"
41
- _MIN_SNOWPARK_ML_VERSION = "1.7.5"
39
+ HANDLER_VERSION = "2025-03-01"
40
+ _MIN_SNOWPARK_ML_VERSION = "1.8.0"
42
41
  _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
43
- "2023-12-01": tensorflow_migrator_2023_12_01.TensorflowHandlerMigrator20231201
42
+ "2023-12-01": tensorflow_migrator_2023_12_01.TensorflowHandlerMigrator20231201,
43
+ "2025-01-01": tensorflow_migrator_2025_01_01.TensorflowHandlerMigrator20250101,
44
44
  }
45
45
 
46
46
  MODEL_BLOB_FILE_OR_DIR = "model"
@@ -112,25 +112,35 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
112
112
  default_target_methods=default_target_methods,
113
113
  )
114
114
 
115
+ multiple_inputs = kwargs.get("multiple_inputs", False)
116
+
115
117
  if is_keras_model and len(target_methods) > 1:
116
118
  raise ValueError("Keras model can only have one target method.")
117
119
 
118
120
  def get_prediction(
119
121
  target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
120
122
  ) -> model_types.SupportedLocalDataType:
121
- if not tensorflow_handler.SeqOfTensorflowTensorHandler.can_handle(sample_input_data):
122
- sample_input_data = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(
123
- model_signature._convert_local_data_to_df(sample_input_data)
124
- )
123
+ if multiple_inputs:
124
+ if not tensorflow_handler.SeqOfTensorflowTensorHandler.can_handle(sample_input_data):
125
+ sample_input_data = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(
126
+ model_signature._convert_local_data_to_df(sample_input_data)
127
+ )
128
+ else:
129
+ if not tensorflow_handler.TensorflowTensorHandler.can_handle(sample_input_data):
130
+ sample_input_data = tensorflow_handler.TensorflowTensorHandler.convert_from_df(
131
+ model_signature._convert_local_data_to_df(sample_input_data)
132
+ )
125
133
 
126
134
  target_method = getattr(model, target_method_name, None)
127
135
  assert callable(target_method)
128
136
  for tensor in sample_input_data:
129
137
  tensorflow.stop_gradient(tensor)
130
- predictions_df = target_method(*sample_input_data)
131
-
132
- if isinstance(predictions_df, (tensorflow.Tensor, tensorflow.Variable, np.ndarray)):
133
- predictions_df = [predictions_df]
138
+ if multiple_inputs:
139
+ predictions_df = target_method(*sample_input_data)
140
+ if not isinstance(predictions_df, tuple):
141
+ predictions_df = [predictions_df]
142
+ else:
143
+ predictions_df = target_method(sample_input_data)
134
144
 
135
145
  return predictions_df
136
146
 
@@ -159,7 +169,9 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
159
169
  model_type=cls.HANDLER_TYPE,
160
170
  handler_version=cls.HANDLER_VERSION,
161
171
  path=cls.MODEL_BLOB_FILE_OR_DIR,
162
- options=model_meta_schema.TensorflowModelBlobOptions(save_format=save_format),
172
+ options=model_meta_schema.TensorflowModelBlobOptions(
173
+ save_format=save_format, multiple_inputs=multiple_inputs
174
+ ),
163
175
  )
164
176
  model_meta.models[name] = base_meta
165
177
  model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -219,6 +231,10 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
219
231
  raw_model: "tensorflow.Module",
220
232
  model_meta: model_meta_api.ModelMetadata,
221
233
  ) -> Type[custom_model.CustomModel]:
234
+ multiple_inputs = cast(
235
+ model_meta_schema.TensorflowModelBlobOptions, model_meta.models[model_meta.name].options
236
+ )["multiple_inputs"]
237
+
222
238
  def fn_factory(
223
239
  raw_model: "tensorflow.Module",
224
240
  signature: model_signature.ModelSignature,
@@ -229,21 +245,25 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
229
245
  if X.isnull().any(axis=None):
230
246
  raise ValueError("Tensor cannot handle null values.")
231
247
 
232
- t = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(X, signature.inputs)
248
+ if multiple_inputs:
249
+ t = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(X, signature.inputs)
233
250
 
234
- for tensor in t:
235
- tensorflow.stop_gradient(tensor)
236
- res = getattr(raw_model, target_method)(*t)
251
+ for tensor in t:
252
+ tensorflow.stop_gradient(tensor)
253
+ res = getattr(raw_model, target_method)(*t)
237
254
 
238
- if isinstance(res, (tensorflow.Tensor, tensorflow.Variable, np.ndarray)):
239
- res = [res]
240
-
241
- if isinstance(res, list) and len(res) > 0 and isinstance(res[0], np.ndarray):
242
- # In case of running on CPU, it will return numpy array
243
- df = numpy_handler.SeqOfNumpyArrayHandler.convert_to_df(res)
255
+ if not isinstance(res, tuple):
256
+ res = [res]
244
257
  else:
245
- df = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_to_df(res)
246
- return model_signature_utils.rename_pandas_df(df, signature.outputs)
258
+ t = tensorflow_handler.TensorflowTensorHandler.convert_from_df(X, signature.inputs)
259
+
260
+ tensorflow.stop_gradient(t)
261
+ res = getattr(raw_model, target_method)(t)
262
+
263
+ return model_signature_utils.rename_pandas_df(
264
+ model_signature._convert_local_data_to_df(res, ensure_serializable=True),
265
+ features=signature.outputs,
266
+ )
247
267
 
248
268
  return fn
249
269
 
@@ -8,10 +8,14 @@ from snowflake.ml._internal import type_utils
8
8
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
9
9
  from snowflake.ml.model._packager.model_env import model_env
10
10
  from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
11
- from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
11
+ from snowflake.ml.model._packager.model_handlers_migrator import (
12
+ base_migrator,
13
+ torchscript_migrator_2023_12_01,
14
+ )
12
15
  from snowflake.ml.model._packager.model_meta import (
13
16
  model_blob_meta,
14
17
  model_meta as model_meta_api,
18
+ model_meta_schema,
15
19
  )
16
20
  from snowflake.ml.model._signatures import (
17
21
  pytorch_handler,
@@ -30,9 +34,11 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
30
34
  """
31
35
 
32
36
  HANDLER_TYPE = "torchscript"
33
- HANDLER_VERSION = "2023-12-01"
34
- _MIN_SNOWPARK_ML_VERSION = "1.0.12"
35
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
37
+ HANDLER_VERSION = "2025-03-01"
38
+ _MIN_SNOWPARK_ML_VERSION = "1.8.0"
39
+ _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
40
+ "2023-12-01": torchscript_migrator_2023_12_01.TorchScriptHandlerMigrator20231201
41
+ }
36
42
 
37
43
  MODEL_BLOB_FILE_OR_DIR = "model.pt"
38
44
  DEFAULT_TARGET_METHODS = ["forward"]
@@ -81,22 +87,32 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
81
87
  default_target_methods=cls.DEFAULT_TARGET_METHODS,
82
88
  )
83
89
 
90
+ multiple_inputs = kwargs.get("multiple_inputs", False)
91
+
84
92
  def get_prediction(
85
93
  target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
86
94
  ) -> model_types.SupportedLocalDataType:
87
- if not pytorch_handler.SeqOfPyTorchTensorHandler.can_handle(sample_input_data):
88
- sample_input_data = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(
89
- model_signature._convert_local_data_to_df(sample_input_data)
90
- )
95
+ if multiple_inputs:
96
+ if not pytorch_handler.SeqOfPyTorchTensorHandler.can_handle(sample_input_data):
97
+ sample_input_data = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(
98
+ model_signature._convert_local_data_to_df(sample_input_data)
99
+ )
100
+ else:
101
+ if not pytorch_handler.PyTorchTensorHandler.can_handle(sample_input_data):
102
+ sample_input_data = pytorch_handler.PyTorchTensorHandler.convert_from_df(
103
+ model_signature._convert_local_data_to_df(sample_input_data)
104
+ )
91
105
 
92
106
  model.eval()
93
107
  target_method = getattr(model, target_method_name, None)
94
108
  assert callable(target_method)
95
109
  with torch.no_grad():
96
- predictions_df = target_method(*sample_input_data)
97
-
98
- if isinstance(predictions_df, torch.Tensor):
99
- predictions_df = [predictions_df]
110
+ if multiple_inputs:
111
+ predictions_df = target_method(*sample_input_data)
112
+ if not isinstance(predictions_df, tuple):
113
+ predictions_df = [predictions_df]
114
+ else:
115
+ predictions_df = target_method(sample_input_data)
100
116
 
101
117
  return predictions_df
102
118
 
@@ -117,6 +133,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
117
133
  model_type=cls.HANDLER_TYPE,
118
134
  handler_version=cls.HANDLER_VERSION,
119
135
  path=cls.MODEL_BLOB_FILE_OR_DIR,
136
+ options=model_meta_schema.TorchScriptModelBlobOptions(multiple_inputs=multiple_inputs),
120
137
  )
121
138
  model_meta.models[name] = base_meta
122
139
  model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -170,6 +187,10 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
170
187
  signature: model_signature.ModelSignature,
171
188
  target_method: str,
172
189
  ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
190
+ multiple_inputs = cast(
191
+ model_meta_schema.TorchScriptModelBlobOptions, model_meta.models[model_meta.name].options
192
+ )["multiple_inputs"]
193
+
173
194
  @custom_model.inference_api
174
195
  def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
175
196
  if X.isnull().any(axis=None):
@@ -179,19 +200,27 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
179
200
 
180
201
  raw_model.eval()
181
202
 
182
- t = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(X, signature.inputs)
203
+ if multiple_inputs:
204
+ st = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(X, signature.inputs)
183
205
 
184
- if kwargs.get("use_gpu", False):
185
- t = [element.cuda() for element in t]
206
+ if kwargs.get("use_gpu", False):
207
+ st = [element.cuda() for element in st]
186
208
 
187
- with torch.no_grad():
188
- res = getattr(raw_model, target_method)(*t)
209
+ with torch.no_grad():
210
+ res = getattr(raw_model, target_method)(*st)
189
211
 
190
- if isinstance(res, torch.Tensor):
191
- res = [res]
212
+ if not isinstance(res, tuple):
213
+ res = [res]
214
+ else:
215
+ t = pytorch_handler.PyTorchTensorHandler.convert_from_df(X, signature.inputs)
216
+ if kwargs.get("use_gpu", False):
217
+ t = t.cuda()
192
218
 
219
+ with torch.no_grad():
220
+ res = getattr(raw_model, target_method)(t)
193
221
  return model_signature_utils.rename_pandas_df(
194
- data=pytorch_handler.SeqOfPyTorchTensorHandler.convert_to_df(res), features=signature.outputs
222
+ model_signature._convert_local_data_to_df(res, ensure_serializable=True),
223
+ features=signature.outputs,
195
224
  )
196
225
 
197
226
  return fn
@@ -99,10 +99,10 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
99
99
  def get_prediction(
100
100
  target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
101
101
  ) -> model_types.SupportedLocalDataType:
102
- if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray)):
102
+ if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray, xgboost.DMatrix)):
103
103
  sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
104
104
 
105
- if isinstance(model, xgboost.Booster):
105
+ if isinstance(model, xgboost.Booster) and not isinstance(sample_input_data, xgboost.DMatrix):
106
106
  sample_input_data = xgboost.DMatrix(sample_input_data)
107
107
 
108
108
  target_method = getattr(model, target_method_name, None)
@@ -0,0 +1,20 @@
1
+ from typing import cast
2
+
3
+ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
4
+ from snowflake.ml.model._packager.model_meta import (
5
+ model_meta as model_meta_api,
6
+ model_meta_schema,
7
+ )
8
+
9
+
10
+ class PyTorchHandlerMigrator20231201(base_migrator.BaseModelHandlerMigrator):
11
+ source_version = "2023-12-01"
12
+ target_version = "2025-03-01"
13
+
14
+ @staticmethod
15
+ def upgrade(name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str) -> None:
16
+
17
+ model_blob_metadata = model_meta.models[name]
18
+ model_blob_options = cast(model_meta_schema.PyTorchModelBlobOptions, model_blob_metadata.options)
19
+ model_blob_options["multiple_inputs"] = True
20
+ model_meta.models[name].options = model_blob_options
@@ -0,0 +1,19 @@
1
+ from typing import cast
2
+
3
+ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
4
+ from snowflake.ml.model._packager.model_meta import (
5
+ model_meta as model_meta_api,
6
+ model_meta_schema,
7
+ )
8
+
9
+
10
+ class TensorflowHandlerMigrator20250101(base_migrator.BaseModelHandlerMigrator):
11
+ source_version = "2025-01-01"
12
+ target_version = "2025-03-01"
13
+
14
+ @staticmethod
15
+ def upgrade(name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str) -> None:
16
+ model_blob_metadata = model_meta.models[name]
17
+ model_blob_options = cast(model_meta_schema.TensorflowModelBlobOptions, model_blob_metadata.options)
18
+ model_blob_options["multiple_inputs"] = True
19
+ model_meta.models[name].options = model_blob_options
@@ -0,0 +1,20 @@
1
+ from typing import cast
2
+
3
+ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
4
+ from snowflake.ml.model._packager.model_meta import (
5
+ model_meta as model_meta_api,
6
+ model_meta_schema,
7
+ )
8
+
9
+
10
+ class TorchScriptHandlerMigrator20231201(base_migrator.BaseModelHandlerMigrator):
11
+ source_version = "2023-12-01"
12
+ target_version = "2025-03-01"
13
+
14
+ @staticmethod
15
+ def upgrade(name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str) -> None:
16
+
17
+ model_blob_metadata = model_meta.models[name]
18
+ model_blob_options = cast(model_meta_schema.PyTorchModelBlobOptions, model_blob_metadata.options)
19
+ model_blob_options["multiple_inputs"] = True
20
+ model_meta.models[name].options = model_blob_options
@@ -1,2 +1 @@
1
1
  REQUIREMENTS = ['cloudpickle>=2.0.0']
2
- ALL_REQUIREMENTS=['cloudpickle>=2.0.0']
@@ -48,6 +48,7 @@ def create_model_metadata(
48
48
  ext_modules: Optional[List[ModuleType]] = None,
49
49
  conda_dependencies: Optional[List[str]] = None,
50
50
  pip_requirements: Optional[List[str]] = None,
51
+ artifact_repository_map: Optional[Dict[str, str]] = None,
51
52
  python_version: Optional[str] = None,
52
53
  task: model_types.Task = model_types.Task.UNKNOWN,
53
54
  **kwargs: Any,
@@ -67,6 +68,7 @@ def create_model_metadata(
67
68
  ext_modules: List of names of modules that need to be pickled with the model. Defaults to None.
68
69
  conda_dependencies: List of conda requirements for running the model. Defaults to None.
69
70
  pip_requirements: List of pip Python packages requirements for running the model. Defaults to None.
71
+ artifact_repository_map: A dict mapping from package channel to artifact repository name.
70
72
  python_version: A string of python version where model is run. Used for user override. If specified as None,
71
73
  current version would be captured. Defaults to None.
72
74
  task: The task of the Model Version. It is an enum class Task with values TABULAR_REGRESSION,
@@ -102,6 +104,7 @@ def create_model_metadata(
102
104
  env = _create_env_for_model_metadata(
103
105
  conda_dependencies=conda_dependencies,
104
106
  pip_requirements=pip_requirements,
107
+ artifact_repository_map=artifact_repository_map,
105
108
  python_version=python_version,
106
109
  embed_local_ml_library=embed_local_ml_library,
107
110
  )
@@ -151,6 +154,7 @@ def _create_env_for_model_metadata(
151
154
  *,
152
155
  conda_dependencies: Optional[List[str]] = None,
153
156
  pip_requirements: Optional[List[str]] = None,
157
+ artifact_repository_map: Optional[Dict[str, str]] = None,
154
158
  python_version: Optional[str] = None,
155
159
  embed_local_ml_library: bool = False,
156
160
  ) -> model_env.ModelEnv:
@@ -159,6 +163,7 @@ def _create_env_for_model_metadata(
159
163
  # Mypy doesn't like getter and setter have different types. See python/mypy #3004
160
164
  env.conda_dependencies = conda_dependencies # type: ignore[assignment]
161
165
  env.pip_requirements = pip_requirements # type: ignore[assignment]
166
+ env.artifact_repository_map = artifact_repository_map
162
167
  env.python_version = python_version # type: ignore[assignment]
163
168
  env.snowpark_ml_version = snowml_env.VERSION
164
169
 
@@ -331,7 +336,6 @@ class ModelMetadata:
331
336
  "function_properties": self.function_properties,
332
337
  }
333
338
  )
334
-
335
339
  with open(model_yaml_path, "w", encoding="utf-8") as out:
336
340
  yaml.SafeDumper.ignore_aliases = lambda *args: True # type: ignore[method-assign]
337
341
  yaml.safe_dump(model_dict, stream=out, default_flow_style=False)
@@ -18,6 +18,7 @@ class FunctionProperties(Enum):
18
18
  class ModelRuntimeDependenciesDict(TypedDict):
19
19
  conda: Required[str]
20
20
  pip: Required[str]
21
+ artifact_repository_map: NotRequired[Optional[Dict[str, str]]]
21
22
 
22
23
 
23
24
  class ModelRuntimeDict(TypedDict):
@@ -28,6 +29,7 @@ class ModelRuntimeDict(TypedDict):
28
29
  class ModelEnvDict(TypedDict):
29
30
  conda: Required[str]
30
31
  pip: Required[str]
32
+ artifact_repository_map: NotRequired[Optional[Dict[str, str]]]
31
33
  python_version: Required[str]
32
34
  cuda_version: NotRequired[Optional[str]]
33
35
  snowpark_ml_version: Required[str]
@@ -61,8 +63,17 @@ class XgboostModelBlobOptions(BaseModelBlobOptions):
61
63
  xgb_estimator_type: Required[str]
62
64
 
63
65
 
66
+ class PyTorchModelBlobOptions(BaseModelBlobOptions):
67
+ multiple_inputs: Required[bool]
68
+
69
+
70
+ class TorchScriptModelBlobOptions(BaseModelBlobOptions):
71
+ multiple_inputs: Required[bool]
72
+
73
+
64
74
  class TensorflowModelBlobOptions(BaseModelBlobOptions):
65
75
  save_format: Required[str]
76
+ multiple_inputs: Required[bool]
66
77
 
67
78
 
68
79
  class SentenceTransformersModelBlobOptions(BaseModelBlobOptions):
@@ -74,6 +85,8 @@ ModelBlobOptions = Union[
74
85
  HuggingFacePipelineModelBlobOptions,
75
86
  MLFlowModelBlobOptions,
76
87
  XgboostModelBlobOptions,
88
+ PyTorchModelBlobOptions,
89
+ TorchScriptModelBlobOptions,
77
90
  TensorflowModelBlobOptions,
78
91
  SentenceTransformersModelBlobOptions,
79
92
  ]
@@ -43,13 +43,13 @@ class ModelPackager:
43
43
  metadata: Optional[Dict[str, str]] = None,
44
44
  conda_dependencies: Optional[List[str]] = None,
45
45
  pip_requirements: Optional[List[str]] = None,
46
+ artifact_repository_map: Optional[Dict[str, str]] = None,
46
47
  python_version: Optional[str] = None,
47
48
  ext_modules: Optional[List[ModuleType]] = None,
48
49
  code_paths: Optional[List[str]] = None,
49
- options: Optional[model_types.ModelSaveOption] = None,
50
+ options: model_types.ModelSaveOption,
50
51
  task: model_types.Task = model_types.Task.UNKNOWN,
51
52
  ) -> model_meta.ModelMetadata:
52
-
53
53
  if (signatures is None) and (sample_input_data is None) and not model_handler.is_auto_signature_model(model):
54
54
  raise snowml_exceptions.SnowflakeMLException(
55
55
  error_code=error_codes.INVALID_ARGUMENT,
@@ -58,9 +58,6 @@ class ModelPackager:
58
58
  ),
59
59
  )
60
60
 
61
- if not options:
62
- options = model_types.BaseModelSaveOption()
63
-
64
61
  handler = model_handler.find_handler(model)
65
62
  if handler is None:
66
63
  raise snowml_exceptions.SnowflakeMLException(
@@ -77,6 +74,7 @@ class ModelPackager:
77
74
  ext_modules=ext_modules,
78
75
  conda_dependencies=conda_dependencies,
79
76
  pip_requirements=pip_requirements,
77
+ artifact_repository_map=artifact_repository_map,
80
78
  python_version=python_version,
81
79
  task=task,
82
80
  **options,
@@ -1,2 +1 @@
1
- REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<5', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2,!=1.26.0', 'sqlparse>=0.4,<1', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
2
- ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<5', 'cachetools>=3.1.1,<6', 'catboost>=1.2.0, <2', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'importlib_resources>=6.1.1, <7', 'keras>=2.0.0,<4', 'lightgbm>=4.1.0, <5', 'mlflow>=2.16.0, <3', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pytorch>=2.0.1,<3', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'sentence-transformers>=2.7.0,<3', 'sentencepiece>=0.1.95,<0.2.0', 'shap>=0.46.0,<1', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2,!=1.26.0', 'sqlparse>=0.4,<1', 'tensorflow>=2.17.0,<3', 'tokenizers>=0.15.1,<1', 'torchdata>=0.4,<1', 'transformers>=4.37.2,<5', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
1
+ REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<5', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.12.0,<4', 'snowflake-snowpark-python>=1.17.0,<2,!=1.26.0', 'sqlparse>=0.4,<1', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']