snowflake-ml-python 1.7.4__py3-none-any.whl → 1.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. snowflake/cortex/_complete.py +58 -3
  2. snowflake/ml/_internal/env_utils.py +64 -21
  3. snowflake/ml/_internal/file_utils.py +18 -4
  4. snowflake/ml/_internal/platform_capabilities.py +3 -0
  5. snowflake/ml/_internal/relax_version_strategy.py +16 -0
  6. snowflake/ml/_internal/telemetry.py +25 -0
  7. snowflake/ml/data/_internal/arrow_ingestor.py +1 -1
  8. snowflake/ml/feature_store/feature_store.py +18 -0
  9. snowflake/ml/feature_store/feature_view.py +46 -1
  10. snowflake/ml/fileset/fileset.py +0 -1
  11. snowflake/ml/jobs/_utils/constants.py +31 -1
  12. snowflake/ml/jobs/_utils/payload_utils.py +232 -72
  13. snowflake/ml/jobs/_utils/spec_utils.py +78 -38
  14. snowflake/ml/jobs/decorators.py +8 -25
  15. snowflake/ml/jobs/job.py +4 -4
  16. snowflake/ml/jobs/manager.py +5 -0
  17. snowflake/ml/model/_client/model/model_version_impl.py +1 -1
  18. snowflake/ml/model/_client/ops/model_ops.py +107 -14
  19. snowflake/ml/model/_client/ops/service_ops.py +1 -1
  20. snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
  21. snowflake/ml/model/_client/sql/model_version.py +58 -0
  22. snowflake/ml/model/_client/sql/service.py +8 -2
  23. snowflake/ml/model/_model_composer/model_composer.py +50 -3
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
  26. snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
  27. snowflake/ml/model/_packager/model_env/model_env.py +49 -29
  28. snowflake/ml/model/_packager/model_handlers/_utils.py +8 -4
  29. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +44 -24
  30. snowflake/ml/model/_packager/model_handlers/keras.py +226 -0
  31. snowflake/ml/model/_packager/model_handlers/pytorch.py +51 -20
  32. snowflake/ml/model/_packager/model_handlers/sklearn.py +25 -3
  33. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +73 -21
  34. snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -72
  35. snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
  36. snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
  37. snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
  38. snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2023_12_01.py +48 -0
  39. snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
  40. snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
  41. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  42. snowflake/ml/model/_packager/model_meta/model_meta.py +6 -2
  43. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +16 -0
  44. snowflake/ml/model/_packager/model_packager.py +3 -5
  45. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
  46. snowflake/ml/model/_packager/model_runtime/model_runtime.py +8 -1
  47. snowflake/ml/model/_packager/model_task/model_task_utils.py +5 -1
  48. snowflake/ml/model/_signatures/builtins_handler.py +20 -9
  49. snowflake/ml/model/_signatures/core.py +54 -33
  50. snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
  51. snowflake/ml/model/_signatures/numpy_handler.py +12 -20
  52. snowflake/ml/model/_signatures/pandas_handler.py +28 -37
  53. snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
  54. snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
  55. snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
  56. snowflake/ml/model/_signatures/utils.py +120 -8
  57. snowflake/ml/model/custom_model.py +13 -4
  58. snowflake/ml/model/model_signature.py +39 -13
  59. snowflake/ml/model/type_hints.py +28 -2
  60. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +14 -1
  61. snowflake/ml/modeling/metrics/ranking.py +3 -0
  62. snowflake/ml/modeling/metrics/regression.py +3 -0
  63. snowflake/ml/modeling/pipeline/pipeline.py +18 -1
  64. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
  65. snowflake/ml/modeling/preprocessing/polynomial_features.py +2 -2
  66. snowflake/ml/registry/_manager/model_manager.py +55 -7
  67. snowflake/ml/registry/registry.py +52 -4
  68. snowflake/ml/version.py +1 -1
  69. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/METADATA +336 -27
  70. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/RECORD +73 -66
  71. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/WHEEL +1 -1
  72. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info/licenses}/LICENSE.txt +0 -0
  73. {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/top_level.txt +0 -0
@@ -8,10 +8,14 @@ from snowflake.ml._internal import type_utils
8
8
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
9
9
  from snowflake.ml.model._packager.model_env import model_env
10
10
  from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
11
- from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
11
+ from snowflake.ml.model._packager.model_handlers_migrator import (
12
+ base_migrator,
13
+ torchscript_migrator_2023_12_01,
14
+ )
12
15
  from snowflake.ml.model._packager.model_meta import (
13
16
  model_blob_meta,
14
17
  model_meta as model_meta_api,
18
+ model_meta_schema,
15
19
  )
16
20
  from snowflake.ml.model._signatures import (
17
21
  pytorch_handler,
@@ -30,9 +34,11 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
30
34
  """
31
35
 
32
36
  HANDLER_TYPE = "torchscript"
33
- HANDLER_VERSION = "2023-12-01"
34
- _MIN_SNOWPARK_ML_VERSION = "1.0.12"
35
- _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
37
+ HANDLER_VERSION = "2025-03-01"
38
+ _MIN_SNOWPARK_ML_VERSION = "1.8.0"
39
+ _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
40
+ "2023-12-01": torchscript_migrator_2023_12_01.TorchScriptHandlerMigrator20231201
41
+ }
36
42
 
37
43
  MODEL_BLOB_FILE_OR_DIR = "model.pt"
38
44
  DEFAULT_TARGET_METHODS = ["forward"]
@@ -81,22 +87,32 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
81
87
  default_target_methods=cls.DEFAULT_TARGET_METHODS,
82
88
  )
83
89
 
90
+ multiple_inputs = kwargs.get("multiple_inputs", False)
91
+
84
92
  def get_prediction(
85
93
  target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
86
94
  ) -> model_types.SupportedLocalDataType:
87
- if not pytorch_handler.SeqOfPyTorchTensorHandler.can_handle(sample_input_data):
88
- sample_input_data = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(
89
- model_signature._convert_local_data_to_df(sample_input_data)
90
- )
95
+ if multiple_inputs:
96
+ if not pytorch_handler.SeqOfPyTorchTensorHandler.can_handle(sample_input_data):
97
+ sample_input_data = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(
98
+ model_signature._convert_local_data_to_df(sample_input_data)
99
+ )
100
+ else:
101
+ if not pytorch_handler.PyTorchTensorHandler.can_handle(sample_input_data):
102
+ sample_input_data = pytorch_handler.PyTorchTensorHandler.convert_from_df(
103
+ model_signature._convert_local_data_to_df(sample_input_data)
104
+ )
91
105
 
92
106
  model.eval()
93
107
  target_method = getattr(model, target_method_name, None)
94
108
  assert callable(target_method)
95
109
  with torch.no_grad():
96
- predictions_df = target_method(*sample_input_data)
97
-
98
- if isinstance(predictions_df, torch.Tensor):
99
- predictions_df = [predictions_df]
110
+ if multiple_inputs:
111
+ predictions_df = target_method(*sample_input_data)
112
+ if not isinstance(predictions_df, tuple):
113
+ predictions_df = [predictions_df]
114
+ else:
115
+ predictions_df = target_method(sample_input_data)
100
116
 
101
117
  return predictions_df
102
118
 
@@ -117,6 +133,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
117
133
  model_type=cls.HANDLER_TYPE,
118
134
  handler_version=cls.HANDLER_VERSION,
119
135
  path=cls.MODEL_BLOB_FILE_OR_DIR,
136
+ options=model_meta_schema.TorchScriptModelBlobOptions(multiple_inputs=multiple_inputs),
120
137
  )
121
138
  model_meta.models[name] = base_meta
122
139
  model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -170,6 +187,10 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
170
187
  signature: model_signature.ModelSignature,
171
188
  target_method: str,
172
189
  ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
190
+ multiple_inputs = cast(
191
+ model_meta_schema.TorchScriptModelBlobOptions, model_meta.models[model_meta.name].options
192
+ )["multiple_inputs"]
193
+
173
194
  @custom_model.inference_api
174
195
  def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
175
196
  if X.isnull().any(axis=None):
@@ -179,19 +200,27 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
179
200
 
180
201
  raw_model.eval()
181
202
 
182
- t = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(X, signature.inputs)
203
+ if multiple_inputs:
204
+ st = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(X, signature.inputs)
183
205
 
184
- if kwargs.get("use_gpu", False):
185
- t = [element.cuda() for element in t]
206
+ if kwargs.get("use_gpu", False):
207
+ st = [element.cuda() for element in st]
186
208
 
187
- with torch.no_grad():
188
- res = getattr(raw_model, target_method)(*t)
209
+ with torch.no_grad():
210
+ res = getattr(raw_model, target_method)(*st)
189
211
 
190
- if isinstance(res, torch.Tensor):
191
- res = [res]
212
+ if not isinstance(res, tuple):
213
+ res = [res]
214
+ else:
215
+ t = pytorch_handler.PyTorchTensorHandler.convert_from_df(X, signature.inputs)
216
+ if kwargs.get("use_gpu", False):
217
+ t = t.cuda()
192
218
 
219
+ with torch.no_grad():
220
+ res = getattr(raw_model, target_method)(t)
193
221
  return model_signature_utils.rename_pandas_df(
194
- data=pytorch_handler.SeqOfPyTorchTensorHandler.convert_to_df(res), features=signature.outputs
222
+ model_signature._convert_local_data_to_df(res, ensure_serializable=True),
223
+ features=signature.outputs,
195
224
  )
196
225
 
197
226
  return fn
@@ -99,10 +99,10 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
99
99
  def get_prediction(
100
100
  target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
101
101
  ) -> model_types.SupportedLocalDataType:
102
- if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray)):
102
+ if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray, xgboost.DMatrix)):
103
103
  sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
104
104
 
105
- if isinstance(model, xgboost.Booster):
105
+ if isinstance(model, xgboost.Booster) and not isinstance(sample_input_data, xgboost.DMatrix):
106
106
  sample_input_data = xgboost.DMatrix(sample_input_data)
107
107
 
108
108
  target_method = getattr(model, target_method_name, None)
@@ -0,0 +1,20 @@
1
+ from typing import cast
2
+
3
+ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
4
+ from snowflake.ml.model._packager.model_meta import (
5
+ model_meta as model_meta_api,
6
+ model_meta_schema,
7
+ )
8
+
9
+
10
+ class PyTorchHandlerMigrator20231201(base_migrator.BaseModelHandlerMigrator):
11
+ source_version = "2023-12-01"
12
+ target_version = "2025-03-01"
13
+
14
+ @staticmethod
15
+ def upgrade(name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str) -> None:
16
+
17
+ model_blob_metadata = model_meta.models[name]
18
+ model_blob_options = cast(model_meta_schema.PyTorchModelBlobOptions, model_blob_metadata.options)
19
+ model_blob_options["multiple_inputs"] = True
20
+ model_meta.models[name].options = model_blob_options
@@ -0,0 +1,48 @@
1
+ from typing import cast
2
+
3
+ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
4
+ from snowflake.ml.model._packager.model_meta import (
5
+ model_meta as model_meta_api,
6
+ model_meta_schema,
7
+ )
8
+
9
+
10
+ class TensorflowHandlerMigrator20231201(base_migrator.BaseModelHandlerMigrator):
11
+ source_version = "2023-12-01"
12
+ target_version = "2025-01-01"
13
+
14
+ @staticmethod
15
+ def upgrade(name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str) -> None:
16
+
17
+ model_blob_metadata = model_meta.models[name]
18
+ model_blob_options = cast(model_meta_schema.TensorflowModelBlobOptions, model_blob_metadata.options)
19
+ # To migrate code <= 1.7.0, default to keras model
20
+ is_old_model = "save_format" not in model_blob_options and "is_keras_model" not in model_blob_options
21
+ # To migrate code form 1.7.1, default to False.
22
+ is_keras_model = model_blob_options.get("is_keras_model", False)
23
+ # To migrate code from 1.7.2, default to tf, has options keras, keras_tf, cloudpickle, tf
24
+ #
25
+ # if is_keras_model or is_tf_keras_model:
26
+ # if is_keras_functional_or_sequential_model:
27
+ # save_format = "keras"
28
+ # elif keras_version.major == 2 or is_tf_keras_model:
29
+ # save_format = "keras_tf"
30
+ # else:
31
+ # save_format = "cloudpickle"
32
+ # else:
33
+ # save_format = "tf"
34
+ #
35
+ save_format = model_blob_options.get("save_format", "tf")
36
+
37
+ if save_format == "keras" or is_keras_model or is_old_model:
38
+ save_format = "keras_tf"
39
+ elif save_format == "cloudpickle":
40
+ # Given the old logic, this could only happen if the original model is a keras model, and keras is 3.x
41
+ # However, in this case, keras.Model does not extends from tensorflow.Module
42
+ # So actually TensorflowHandler will not be triggered, we could safely error this out.
43
+ raise NotImplementedError(
44
+ "Unable to upgrade keras 3.x model saved by old handler. This is not supposed to happen"
45
+ )
46
+
47
+ model_blob_options["save_format"] = save_format
48
+ model_meta.models[name].options = model_blob_options
@@ -0,0 +1,19 @@
1
+ from typing import cast
2
+
3
+ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
4
+ from snowflake.ml.model._packager.model_meta import (
5
+ model_meta as model_meta_api,
6
+ model_meta_schema,
7
+ )
8
+
9
+
10
+ class TensorflowHandlerMigrator20250101(base_migrator.BaseModelHandlerMigrator):
11
+ source_version = "2025-01-01"
12
+ target_version = "2025-03-01"
13
+
14
+ @staticmethod
15
+ def upgrade(name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str) -> None:
16
+ model_blob_metadata = model_meta.models[name]
17
+ model_blob_options = cast(model_meta_schema.TensorflowModelBlobOptions, model_blob_metadata.options)
18
+ model_blob_options["multiple_inputs"] = True
19
+ model_meta.models[name].options = model_blob_options
@@ -0,0 +1,20 @@
1
+ from typing import cast
2
+
3
+ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
4
+ from snowflake.ml.model._packager.model_meta import (
5
+ model_meta as model_meta_api,
6
+ model_meta_schema,
7
+ )
8
+
9
+
10
+ class TorchScriptHandlerMigrator20231201(base_migrator.BaseModelHandlerMigrator):
11
+ source_version = "2023-12-01"
12
+ target_version = "2025-03-01"
13
+
14
+ @staticmethod
15
+ def upgrade(name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str) -> None:
16
+
17
+ model_blob_metadata = model_meta.models[name]
18
+ model_blob_options = cast(model_meta_schema.PyTorchModelBlobOptions, model_blob_metadata.options)
19
+ model_blob_options["multiple_inputs"] = True
20
+ model_meta.models[name].options = model_blob_options
@@ -1,2 +1 @@
1
1
  REQUIREMENTS = ['cloudpickle>=2.0.0']
2
- ALL_REQUIREMENTS=['cloudpickle>=2.0.0']
@@ -48,6 +48,7 @@ def create_model_metadata(
48
48
  ext_modules: Optional[List[ModuleType]] = None,
49
49
  conda_dependencies: Optional[List[str]] = None,
50
50
  pip_requirements: Optional[List[str]] = None,
51
+ artifact_repository_map: Optional[Dict[str, str]] = None,
51
52
  python_version: Optional[str] = None,
52
53
  task: model_types.Task = model_types.Task.UNKNOWN,
53
54
  **kwargs: Any,
@@ -67,6 +68,7 @@ def create_model_metadata(
67
68
  ext_modules: List of names of modules that need to be pickled with the model. Defaults to None.
68
69
  conda_dependencies: List of conda requirements for running the model. Defaults to None.
69
70
  pip_requirements: List of pip Python packages requirements for running the model. Defaults to None.
71
+ artifact_repository_map: A dict mapping from package channel to artifact repository name.
70
72
  python_version: A string of python version where model is run. Used for user override. If specified as None,
71
73
  current version would be captured. Defaults to None.
72
74
  task: The task of the Model Version. It is an enum class Task with values TABULAR_REGRESSION,
@@ -102,6 +104,7 @@ def create_model_metadata(
102
104
  env = _create_env_for_model_metadata(
103
105
  conda_dependencies=conda_dependencies,
104
106
  pip_requirements=pip_requirements,
107
+ artifact_repository_map=artifact_repository_map,
105
108
  python_version=python_version,
106
109
  embed_local_ml_library=embed_local_ml_library,
107
110
  )
@@ -151,6 +154,7 @@ def _create_env_for_model_metadata(
151
154
  *,
152
155
  conda_dependencies: Optional[List[str]] = None,
153
156
  pip_requirements: Optional[List[str]] = None,
157
+ artifact_repository_map: Optional[Dict[str, str]] = None,
154
158
  python_version: Optional[str] = None,
155
159
  embed_local_ml_library: bool = False,
156
160
  ) -> model_env.ModelEnv:
@@ -159,6 +163,7 @@ def _create_env_for_model_metadata(
159
163
  # Mypy doesn't like getter and setter have different types. See python/mypy #3004
160
164
  env.conda_dependencies = conda_dependencies # type: ignore[assignment]
161
165
  env.pip_requirements = pip_requirements # type: ignore[assignment]
166
+ env.artifact_repository_map = artifact_repository_map
162
167
  env.python_version = python_version # type: ignore[assignment]
163
168
  env.snowpark_ml_version = snowml_env.VERSION
164
169
 
@@ -331,7 +336,6 @@ class ModelMetadata:
331
336
  "function_properties": self.function_properties,
332
337
  }
333
338
  )
334
-
335
339
  with open(model_yaml_path, "w", encoding="utf-8") as out:
336
340
  yaml.SafeDumper.ignore_aliases = lambda *args: True # type: ignore[method-assign]
337
341
  yaml.safe_dump(model_dict, stream=out, default_flow_style=False)
@@ -352,7 +356,7 @@ class ModelMetadata:
352
356
  version.parse(loaded_meta_min_snowpark_ml_version) > version.parse(snowml_env.VERSION)
353
357
  ):
354
358
  raise RuntimeError(
355
- f"The minimal version required to load the model is {loaded_meta_min_snowpark_ml_version},"
359
+ f"The minimal version required to load the model is {loaded_meta_min_snowpark_ml_version}, "
356
360
  f"while current version of Snowpark ML library is {snowml_env.VERSION}."
357
361
  )
358
362
  return model_meta_schema.ModelMetadataDict(
@@ -18,6 +18,7 @@ class FunctionProperties(Enum):
18
18
  class ModelRuntimeDependenciesDict(TypedDict):
19
19
  conda: Required[str]
20
20
  pip: Required[str]
21
+ artifact_repository_map: NotRequired[Optional[Dict[str, str]]]
21
22
 
22
23
 
23
24
  class ModelRuntimeDict(TypedDict):
@@ -28,6 +29,7 @@ class ModelRuntimeDict(TypedDict):
28
29
  class ModelEnvDict(TypedDict):
29
30
  conda: Required[str]
30
31
  pip: Required[str]
32
+ artifact_repository_map: NotRequired[Optional[Dict[str, str]]]
31
33
  python_version: Required[str]
32
34
  cuda_version: NotRequired[Optional[str]]
33
35
  snowpark_ml_version: Required[str]
@@ -44,6 +46,9 @@ class CatBoostModelBlobOptions(BaseModelBlobOptions):
44
46
  class HuggingFacePipelineModelBlobOptions(BaseModelBlobOptions):
45
47
  task: Required[str]
46
48
  batch_size: Required[int]
49
+ has_tokenizer: NotRequired[bool]
50
+ has_feature_extractor: NotRequired[bool]
51
+ has_image_preprocessor: NotRequired[bool]
47
52
 
48
53
 
49
54
  class LightGBMModelBlobOptions(BaseModelBlobOptions):
@@ -58,8 +63,17 @@ class XgboostModelBlobOptions(BaseModelBlobOptions):
58
63
  xgb_estimator_type: Required[str]
59
64
 
60
65
 
66
+ class PyTorchModelBlobOptions(BaseModelBlobOptions):
67
+ multiple_inputs: Required[bool]
68
+
69
+
70
+ class TorchScriptModelBlobOptions(BaseModelBlobOptions):
71
+ multiple_inputs: Required[bool]
72
+
73
+
61
74
  class TensorflowModelBlobOptions(BaseModelBlobOptions):
62
75
  save_format: Required[str]
76
+ multiple_inputs: Required[bool]
63
77
 
64
78
 
65
79
  class SentenceTransformersModelBlobOptions(BaseModelBlobOptions):
@@ -71,6 +85,8 @@ ModelBlobOptions = Union[
71
85
  HuggingFacePipelineModelBlobOptions,
72
86
  MLFlowModelBlobOptions,
73
87
  XgboostModelBlobOptions,
88
+ PyTorchModelBlobOptions,
89
+ TorchScriptModelBlobOptions,
74
90
  TensorflowModelBlobOptions,
75
91
  SentenceTransformersModelBlobOptions,
76
92
  ]
@@ -43,13 +43,13 @@ class ModelPackager:
43
43
  metadata: Optional[Dict[str, str]] = None,
44
44
  conda_dependencies: Optional[List[str]] = None,
45
45
  pip_requirements: Optional[List[str]] = None,
46
+ artifact_repository_map: Optional[Dict[str, str]] = None,
46
47
  python_version: Optional[str] = None,
47
48
  ext_modules: Optional[List[ModuleType]] = None,
48
49
  code_paths: Optional[List[str]] = None,
49
- options: Optional[model_types.ModelSaveOption] = None,
50
+ options: model_types.ModelSaveOption,
50
51
  task: model_types.Task = model_types.Task.UNKNOWN,
51
52
  ) -> model_meta.ModelMetadata:
52
-
53
53
  if (signatures is None) and (sample_input_data is None) and not model_handler.is_auto_signature_model(model):
54
54
  raise snowml_exceptions.SnowflakeMLException(
55
55
  error_code=error_codes.INVALID_ARGUMENT,
@@ -58,9 +58,6 @@ class ModelPackager:
58
58
  ),
59
59
  )
60
60
 
61
- if not options:
62
- options = model_types.BaseModelSaveOption()
63
-
64
61
  handler = model_handler.find_handler(model)
65
62
  if handler is None:
66
63
  raise snowml_exceptions.SnowflakeMLException(
@@ -77,6 +74,7 @@ class ModelPackager:
77
74
  ext_modules=ext_modules,
78
75
  conda_dependencies=conda_dependencies,
79
76
  pip_requirements=pip_requirements,
77
+ artifact_repository_map=artifact_repository_map,
80
78
  python_version=python_version,
81
79
  task=task,
82
80
  **options,
@@ -1,2 +1 @@
1
- REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<5', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2,!=1.26.0', 'sqlparse>=0.4,<1', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
2
- ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<5', 'cachetools>=3.1.1,<6', 'catboost>=1.2.0, <2', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'huggingface_hub<0.26', 'importlib_resources>=6.1.1, <7', 'lightgbm>=4.1.0, <5', 'mlflow>=2.16.0, <3', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pytorch>=2.0.1,<2.3.0', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'sentence-transformers>=2.2.2,<3', 'sentencepiece>=0.1.95,<1', 'shap>=0.46.0,<1', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2,!=1.26.0', 'sqlparse>=0.4,<1', 'tensorflow>=2.12.0,<3', 'tokenizers>=0.10,<1', 'torchdata>=0.4,<1', 'transformers>=4.32.1,<5', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
1
+ REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<5', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.12.0,<4', 'snowflake-snowpark-python>=1.17.0,<2,!=1.26.0', 'sqlparse>=0.4,<1', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
@@ -45,6 +45,7 @@ class ModelRuntime:
45
45
  self.name = name
46
46
  self.runtime_env = copy.deepcopy(env)
47
47
  self.imports = imports or []
48
+ self.is_gpu = is_gpu
48
49
 
49
50
  if loading_from_file:
50
51
  return
@@ -88,13 +89,18 @@ class ModelRuntime:
88
89
  self.runtime_env.conda_env_rel_path = self.runtime_rel_path / self.runtime_env.conda_env_rel_path
89
90
  self.runtime_env.pip_requirements_rel_path = self.runtime_rel_path / self.runtime_env.pip_requirements_rel_path
90
91
 
91
- env_dict = self.runtime_env.save_as_dict(packager_path, default_channel_override=default_channel_override)
92
+ env_dict = self.runtime_env.save_as_dict(
93
+ packager_path, default_channel_override=default_channel_override, is_gpu=self.is_gpu
94
+ )
92
95
 
93
96
  return model_meta_schema.ModelRuntimeDict(
94
97
  imports=list(map(str, self.imports)),
95
98
  dependencies=model_meta_schema.ModelRuntimeDependenciesDict(
96
99
  conda=env_dict["conda"],
97
100
  pip=env_dict["pip"],
101
+ artifact_repository_map=env_dict["artifact_repository_map"]
102
+ if env_dict.get("artifact_repository_map") is not None
103
+ else {},
98
104
  ),
99
105
  )
100
106
 
@@ -109,6 +115,7 @@ class ModelRuntime:
109
115
  env.python_version = meta_env.python_version
110
116
  env.cuda_version = meta_env.cuda_version
111
117
  env.snowpark_ml_version = meta_env.snowpark_ml_version
118
+ env.artifact_repository_map = meta_env.artifact_repository_map
112
119
 
113
120
  conda_env_rel_path = pathlib.PurePosixPath(loaded_dict["dependencies"]["conda"])
114
121
  pip_requirements_rel_path = pathlib.PurePosixPath(loaded_dict["dependencies"]["pip"])
@@ -24,7 +24,11 @@ def get_task_skl(model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pi
24
24
  from sklearn.base import is_classifier, is_regressor
25
25
 
26
26
  if type_utils.LazyType("sklearn.pipeline.Pipeline").isinstance(model):
27
- return type_hints.Task.UNKNOWN
27
+ if hasattr(model, "predict_proba") or hasattr(model, "predict"):
28
+ model = model.steps[-1][1] # type: ignore[attr-defined]
29
+ return _get_model_task(model)
30
+ else:
31
+ return type_hints.Task.UNKNOWN
28
32
  if is_regressor(model):
29
33
  return type_hints.Task.TABULAR_REGRESSION
30
34
  if is_classifier(model):
@@ -14,21 +14,32 @@ from snowflake.ml.model._signatures import base_handler, core, pandas_handler
14
14
 
15
15
 
16
16
  class ListOfBuiltinHandler(base_handler.BaseDataHandler[model_types._SupportedBuiltinsList]):
17
+ @staticmethod
18
+ def _can_handle_element(
19
+ element: model_types._SupportedBuiltins,
20
+ ) -> TypeGuard[model_types._SupportedBuiltins]:
21
+ if isinstance(element, abc.Sequence) and not isinstance(element, str):
22
+ for sub_element in element:
23
+ if not ListOfBuiltinHandler._can_handle_element(sub_element):
24
+ return False
25
+ return True
26
+ elif isinstance(element, abc.Mapping):
27
+ for key, value in element.items():
28
+ if not isinstance(key, str):
29
+ return False
30
+ if not ListOfBuiltinHandler._can_handle_element(value):
31
+ return False
32
+ return True
33
+ else:
34
+ return isinstance(element, (int, float, bool, str, datetime.datetime))
35
+
17
36
  @staticmethod
18
37
  def can_handle(data: model_types.SupportedDataType) -> TypeGuard[model_types._SupportedBuiltinsList]:
19
38
  if not isinstance(data, abc.Sequence) or isinstance(data, str):
20
39
  return False
21
40
  if len(data) == 0:
22
41
  return False
23
- can_handle = True
24
- for element in data:
25
- # String is a Sequence but we take them as an whole
26
- if isinstance(element, abc.Sequence) and not isinstance(element, str):
27
- can_handle = ListOfBuiltinHandler.can_handle(element)
28
- elif not isinstance(element, (int, float, bool, str, datetime.datetime)):
29
- can_handle = False
30
- break
31
- return can_handle
42
+ return ListOfBuiltinHandler._can_handle_element(data)
32
43
 
33
44
  @staticmethod
34
45
  def count(data: model_types._SupportedBuiltinsList) -> int: