snowflake-ml-python 1.7.5__py3-none-any.whl → 1.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +58 -3
- snowflake/ml/_internal/file_utils.py +18 -4
- snowflake/ml/_internal/platform_capabilities.py +3 -0
- snowflake/ml/_internal/telemetry.py +4 -0
- snowflake/ml/fileset/fileset.py +0 -1
- snowflake/ml/jobs/_utils/constants.py +24 -0
- snowflake/ml/jobs/_utils/payload_utils.py +94 -20
- snowflake/ml/jobs/_utils/spec_utils.py +73 -31
- snowflake/ml/jobs/decorators.py +3 -0
- snowflake/ml/jobs/manager.py +5 -0
- snowflake/ml/model/_client/model/model_version_impl.py +1 -1
- snowflake/ml/model/_client/ops/model_ops.py +107 -14
- snowflake/ml/model/_client/ops/service_ops.py +1 -1
- snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
- snowflake/ml/model/_client/sql/model_version.py +58 -0
- snowflake/ml/model/_client/sql/service.py +8 -2
- snowflake/ml/model/_model_composer/model_composer.py +50 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
- snowflake/ml/model/_packager/model_env/model_env.py +4 -1
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +28 -24
- snowflake/ml/model/_packager/model_handlers/keras.py +1 -5
- snowflake/ml/model/_packager/model_handlers/pytorch.py +50 -20
- snowflake/ml/model/_packager/model_handlers/sklearn.py +1 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -2
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +46 -26
- snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
- snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
- snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
- snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +5 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -0
- snowflake/ml/model/_packager/model_packager.py +3 -5
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -0
- snowflake/ml/model/_signatures/builtins_handler.py +20 -9
- snowflake/ml/model/_signatures/core.py +52 -31
- snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
- snowflake/ml/model/_signatures/numpy_handler.py +9 -17
- snowflake/ml/model/_signatures/pandas_handler.py +19 -30
- snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
- snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
- snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
- snowflake/ml/model/_signatures/utils.py +120 -8
- snowflake/ml/model/custom_model.py +13 -4
- snowflake/ml/model/model_signature.py +31 -13
- snowflake/ml/model/type_hints.py +13 -2
- snowflake/ml/modeling/metrics/ranking.py +3 -0
- snowflake/ml/modeling/metrics/regression.py +3 -0
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
- snowflake/ml/registry/_manager/model_manager.py +55 -7
- snowflake/ml/registry/registry.py +18 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/METADATA +287 -11
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/RECORD +61 -57
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info/licenses}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/top_level.txt +0 -0
@@ -10,10 +10,14 @@ from snowflake.ml._internal import type_utils
|
|
10
10
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
11
11
|
from snowflake.ml.model._packager.model_env import model_env
|
12
12
|
from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
|
13
|
-
from snowflake.ml.model._packager.model_handlers_migrator import
|
13
|
+
from snowflake.ml.model._packager.model_handlers_migrator import (
|
14
|
+
base_migrator,
|
15
|
+
pytorch_migrator_2023_12_01,
|
16
|
+
)
|
14
17
|
from snowflake.ml.model._packager.model_meta import (
|
15
18
|
model_blob_meta,
|
16
19
|
model_meta as model_meta_api,
|
20
|
+
model_meta_schema,
|
17
21
|
)
|
18
22
|
from snowflake.ml.model._signatures import (
|
19
23
|
pytorch_handler,
|
@@ -21,7 +25,6 @@ from snowflake.ml.model._signatures import (
|
|
21
25
|
)
|
22
26
|
|
23
27
|
if TYPE_CHECKING:
|
24
|
-
import sentence_transformers # noqa: F401
|
25
28
|
import torch
|
26
29
|
|
27
30
|
|
@@ -33,9 +36,11 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
33
36
|
"""
|
34
37
|
|
35
38
|
HANDLER_TYPE = "pytorch"
|
36
|
-
HANDLER_VERSION = "
|
37
|
-
_MIN_SNOWPARK_ML_VERSION = "1.0
|
38
|
-
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
|
39
|
+
HANDLER_VERSION = "2025-03-01"
|
40
|
+
_MIN_SNOWPARK_ML_VERSION = "1.8.0"
|
41
|
+
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
|
42
|
+
"2023-12-01": pytorch_migrator_2023_12_01.PyTorchHandlerMigrator20231201
|
43
|
+
}
|
39
44
|
|
40
45
|
MODEL_BLOB_FILE_OR_DIR = "model.pt"
|
41
46
|
DEFAULT_TARGET_METHODS = ["forward"]
|
@@ -89,22 +94,33 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
89
94
|
default_target_methods=cls.DEFAULT_TARGET_METHODS,
|
90
95
|
)
|
91
96
|
|
97
|
+
multiple_inputs = kwargs.get("multiple_inputs", False)
|
98
|
+
|
92
99
|
def get_prediction(
|
93
100
|
target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
|
94
101
|
) -> model_types.SupportedLocalDataType:
|
95
|
-
if
|
96
|
-
|
97
|
-
|
98
|
-
|
102
|
+
if multiple_inputs:
|
103
|
+
if not pytorch_handler.SeqOfPyTorchTensorHandler.can_handle(sample_input_data):
|
104
|
+
sample_input_data = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(
|
105
|
+
model_signature._convert_local_data_to_df(sample_input_data)
|
106
|
+
)
|
107
|
+
else:
|
108
|
+
if not pytorch_handler.PyTorchTensorHandler.can_handle(sample_input_data):
|
109
|
+
sample_input_data = pytorch_handler.PyTorchTensorHandler.convert_from_df(
|
110
|
+
model_signature._convert_local_data_to_df(sample_input_data)
|
111
|
+
)
|
99
112
|
|
100
113
|
model.eval()
|
101
114
|
target_method = getattr(model, target_method_name, None)
|
102
115
|
assert callable(target_method)
|
103
116
|
with torch.no_grad():
|
104
|
-
|
117
|
+
if multiple_inputs:
|
118
|
+
predictions_df = target_method(*sample_input_data)
|
119
|
+
if not isinstance(predictions_df, tuple):
|
120
|
+
predictions_df = [predictions_df]
|
121
|
+
else:
|
122
|
+
predictions_df = target_method(sample_input_data)
|
105
123
|
|
106
|
-
if isinstance(predictions_df, torch.Tensor):
|
107
|
-
predictions_df = [predictions_df]
|
108
124
|
return predictions_df
|
109
125
|
|
110
126
|
model_meta = handlers_utils.validate_signature(
|
@@ -127,6 +143,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
127
143
|
model_type=cls.HANDLER_TYPE,
|
128
144
|
handler_version=cls.HANDLER_VERSION,
|
129
145
|
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
146
|
+
options=model_meta_schema.PyTorchModelBlobOptions(multiple_inputs=multiple_inputs),
|
130
147
|
)
|
131
148
|
model_meta.models[name] = base_meta
|
132
149
|
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
@@ -172,6 +189,10 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
172
189
|
raw_model: "torch.nn.Module",
|
173
190
|
model_meta: model_meta_api.ModelMetadata,
|
174
191
|
) -> Type[custom_model.CustomModel]:
|
192
|
+
multiple_inputs = cast(
|
193
|
+
model_meta_schema.PyTorchModelBlobOptions, model_meta.models[model_meta.name].options
|
194
|
+
)["multiple_inputs"]
|
195
|
+
|
175
196
|
def fn_factory(
|
176
197
|
raw_model: "torch.nn.Module",
|
177
198
|
signature: model_signature.ModelSignature,
|
@@ -183,19 +204,28 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
183
204
|
raise ValueError("Tensor cannot handle null values.")
|
184
205
|
|
185
206
|
raw_model.eval()
|
186
|
-
|
207
|
+
if multiple_inputs:
|
208
|
+
st = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(X, signature.inputs)
|
209
|
+
|
210
|
+
if kwargs.get("use_gpu", False):
|
211
|
+
st = [element.cuda() for element in st]
|
187
212
|
|
188
|
-
|
189
|
-
|
213
|
+
with torch.no_grad():
|
214
|
+
res = getattr(raw_model, target_method)(*st)
|
190
215
|
|
191
|
-
|
192
|
-
|
216
|
+
if not isinstance(res, tuple):
|
217
|
+
res = [res]
|
218
|
+
else:
|
219
|
+
t = pytorch_handler.PyTorchTensorHandler.convert_from_df(X, signature.inputs)
|
220
|
+
if kwargs.get("use_gpu", False):
|
221
|
+
t = t.cuda()
|
193
222
|
|
194
|
-
|
195
|
-
|
223
|
+
with torch.no_grad():
|
224
|
+
res = getattr(raw_model, target_method)(t)
|
196
225
|
|
197
226
|
return model_signature_utils.rename_pandas_df(
|
198
|
-
|
227
|
+
model_signature._convert_local_data_to_df(res, ensure_serializable=True),
|
228
|
+
features=signature.outputs,
|
199
229
|
)
|
200
230
|
|
201
231
|
return fn
|
@@ -297,10 +297,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
297
297
|
df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
|
298
298
|
except TypeError:
|
299
299
|
try:
|
300
|
-
dtype_map = {
|
301
|
-
spec.name: spec.as_dtype(force_numpy_dtype=True) # type: ignore[attr-defined]
|
302
|
-
for spec in signature.inputs
|
303
|
-
}
|
300
|
+
dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs}
|
304
301
|
|
305
302
|
if isinstance(X, pd.DataFrame):
|
306
303
|
X = X.astype(dtype_map, copy=False)
|
@@ -307,8 +307,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
307
307
|
except TypeError:
|
308
308
|
try:
|
309
309
|
dtype_map = {
|
310
|
-
spec.name: spec.as_dtype(force_numpy_dtype=True)
|
311
|
-
for spec in signature.inputs
|
310
|
+
spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs
|
312
311
|
}
|
313
312
|
|
314
313
|
if isinstance(X, pd.DataFrame):
|
@@ -1,7 +1,6 @@
|
|
1
1
|
import os
|
2
2
|
from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
|
3
3
|
|
4
|
-
import numpy as np
|
5
4
|
import pandas as pd
|
6
5
|
from packaging import version
|
7
6
|
from typing_extensions import TypeGuard, Unpack
|
@@ -13,6 +12,7 @@ from snowflake.ml.model._packager.model_handlers import _base, _utils as handler
|
|
13
12
|
from snowflake.ml.model._packager.model_handlers_migrator import (
|
14
13
|
base_migrator,
|
15
14
|
tensorflow_migrator_2023_12_01,
|
15
|
+
tensorflow_migrator_2025_01_01,
|
16
16
|
)
|
17
17
|
from snowflake.ml.model._packager.model_meta import (
|
18
18
|
model_blob_meta,
|
@@ -20,7 +20,6 @@ from snowflake.ml.model._packager.model_meta import (
|
|
20
20
|
model_meta_schema,
|
21
21
|
)
|
22
22
|
from snowflake.ml.model._signatures import (
|
23
|
-
numpy_handler,
|
24
23
|
tensorflow_handler,
|
25
24
|
utils as model_signature_utils,
|
26
25
|
)
|
@@ -37,10 +36,11 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
37
36
|
"""
|
38
37
|
|
39
38
|
HANDLER_TYPE = "tensorflow"
|
40
|
-
HANDLER_VERSION = "2025-
|
41
|
-
_MIN_SNOWPARK_ML_VERSION = "1.
|
39
|
+
HANDLER_VERSION = "2025-03-01"
|
40
|
+
_MIN_SNOWPARK_ML_VERSION = "1.8.0"
|
42
41
|
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
|
43
|
-
"2023-12-01": tensorflow_migrator_2023_12_01.TensorflowHandlerMigrator20231201
|
42
|
+
"2023-12-01": tensorflow_migrator_2023_12_01.TensorflowHandlerMigrator20231201,
|
43
|
+
"2025-01-01": tensorflow_migrator_2025_01_01.TensorflowHandlerMigrator20250101,
|
44
44
|
}
|
45
45
|
|
46
46
|
MODEL_BLOB_FILE_OR_DIR = "model"
|
@@ -112,25 +112,35 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
112
112
|
default_target_methods=default_target_methods,
|
113
113
|
)
|
114
114
|
|
115
|
+
multiple_inputs = kwargs.get("multiple_inputs", False)
|
116
|
+
|
115
117
|
if is_keras_model and len(target_methods) > 1:
|
116
118
|
raise ValueError("Keras model can only have one target method.")
|
117
119
|
|
118
120
|
def get_prediction(
|
119
121
|
target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
|
120
122
|
) -> model_types.SupportedLocalDataType:
|
121
|
-
if
|
122
|
-
|
123
|
-
|
124
|
-
|
123
|
+
if multiple_inputs:
|
124
|
+
if not tensorflow_handler.SeqOfTensorflowTensorHandler.can_handle(sample_input_data):
|
125
|
+
sample_input_data = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(
|
126
|
+
model_signature._convert_local_data_to_df(sample_input_data)
|
127
|
+
)
|
128
|
+
else:
|
129
|
+
if not tensorflow_handler.TensorflowTensorHandler.can_handle(sample_input_data):
|
130
|
+
sample_input_data = tensorflow_handler.TensorflowTensorHandler.convert_from_df(
|
131
|
+
model_signature._convert_local_data_to_df(sample_input_data)
|
132
|
+
)
|
125
133
|
|
126
134
|
target_method = getattr(model, target_method_name, None)
|
127
135
|
assert callable(target_method)
|
128
136
|
for tensor in sample_input_data:
|
129
137
|
tensorflow.stop_gradient(tensor)
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
138
|
+
if multiple_inputs:
|
139
|
+
predictions_df = target_method(*sample_input_data)
|
140
|
+
if not isinstance(predictions_df, tuple):
|
141
|
+
predictions_df = [predictions_df]
|
142
|
+
else:
|
143
|
+
predictions_df = target_method(sample_input_data)
|
134
144
|
|
135
145
|
return predictions_df
|
136
146
|
|
@@ -159,7 +169,9 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
159
169
|
model_type=cls.HANDLER_TYPE,
|
160
170
|
handler_version=cls.HANDLER_VERSION,
|
161
171
|
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
162
|
-
options=model_meta_schema.TensorflowModelBlobOptions(
|
172
|
+
options=model_meta_schema.TensorflowModelBlobOptions(
|
173
|
+
save_format=save_format, multiple_inputs=multiple_inputs
|
174
|
+
),
|
163
175
|
)
|
164
176
|
model_meta.models[name] = base_meta
|
165
177
|
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
@@ -219,6 +231,10 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
219
231
|
raw_model: "tensorflow.Module",
|
220
232
|
model_meta: model_meta_api.ModelMetadata,
|
221
233
|
) -> Type[custom_model.CustomModel]:
|
234
|
+
multiple_inputs = cast(
|
235
|
+
model_meta_schema.TensorflowModelBlobOptions, model_meta.models[model_meta.name].options
|
236
|
+
)["multiple_inputs"]
|
237
|
+
|
222
238
|
def fn_factory(
|
223
239
|
raw_model: "tensorflow.Module",
|
224
240
|
signature: model_signature.ModelSignature,
|
@@ -229,21 +245,25 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
229
245
|
if X.isnull().any(axis=None):
|
230
246
|
raise ValueError("Tensor cannot handle null values.")
|
231
247
|
|
232
|
-
|
248
|
+
if multiple_inputs:
|
249
|
+
t = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(X, signature.inputs)
|
233
250
|
|
234
|
-
|
235
|
-
|
236
|
-
|
251
|
+
for tensor in t:
|
252
|
+
tensorflow.stop_gradient(tensor)
|
253
|
+
res = getattr(raw_model, target_method)(*t)
|
237
254
|
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
if isinstance(res, list) and len(res) > 0 and isinstance(res[0], np.ndarray):
|
242
|
-
# In case of running on CPU, it will return numpy array
|
243
|
-
df = numpy_handler.SeqOfNumpyArrayHandler.convert_to_df(res)
|
255
|
+
if not isinstance(res, tuple):
|
256
|
+
res = [res]
|
244
257
|
else:
|
245
|
-
|
246
|
-
|
258
|
+
t = tensorflow_handler.TensorflowTensorHandler.convert_from_df(X, signature.inputs)
|
259
|
+
|
260
|
+
tensorflow.stop_gradient(t)
|
261
|
+
res = getattr(raw_model, target_method)(t)
|
262
|
+
|
263
|
+
return model_signature_utils.rename_pandas_df(
|
264
|
+
model_signature._convert_local_data_to_df(res, ensure_serializable=True),
|
265
|
+
features=signature.outputs,
|
266
|
+
)
|
247
267
|
|
248
268
|
return fn
|
249
269
|
|
@@ -8,10 +8,14 @@ from snowflake.ml._internal import type_utils
|
|
8
8
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
9
9
|
from snowflake.ml.model._packager.model_env import model_env
|
10
10
|
from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
|
11
|
-
from snowflake.ml.model._packager.model_handlers_migrator import
|
11
|
+
from snowflake.ml.model._packager.model_handlers_migrator import (
|
12
|
+
base_migrator,
|
13
|
+
torchscript_migrator_2023_12_01,
|
14
|
+
)
|
12
15
|
from snowflake.ml.model._packager.model_meta import (
|
13
16
|
model_blob_meta,
|
14
17
|
model_meta as model_meta_api,
|
18
|
+
model_meta_schema,
|
15
19
|
)
|
16
20
|
from snowflake.ml.model._signatures import (
|
17
21
|
pytorch_handler,
|
@@ -30,9 +34,11 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
|
30
34
|
"""
|
31
35
|
|
32
36
|
HANDLER_TYPE = "torchscript"
|
33
|
-
HANDLER_VERSION = "
|
34
|
-
_MIN_SNOWPARK_ML_VERSION = "1.0
|
35
|
-
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
|
37
|
+
HANDLER_VERSION = "2025-03-01"
|
38
|
+
_MIN_SNOWPARK_ML_VERSION = "1.8.0"
|
39
|
+
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
|
40
|
+
"2023-12-01": torchscript_migrator_2023_12_01.TorchScriptHandlerMigrator20231201
|
41
|
+
}
|
36
42
|
|
37
43
|
MODEL_BLOB_FILE_OR_DIR = "model.pt"
|
38
44
|
DEFAULT_TARGET_METHODS = ["forward"]
|
@@ -81,22 +87,32 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
|
81
87
|
default_target_methods=cls.DEFAULT_TARGET_METHODS,
|
82
88
|
)
|
83
89
|
|
90
|
+
multiple_inputs = kwargs.get("multiple_inputs", False)
|
91
|
+
|
84
92
|
def get_prediction(
|
85
93
|
target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
|
86
94
|
) -> model_types.SupportedLocalDataType:
|
87
|
-
if
|
88
|
-
|
89
|
-
|
90
|
-
|
95
|
+
if multiple_inputs:
|
96
|
+
if not pytorch_handler.SeqOfPyTorchTensorHandler.can_handle(sample_input_data):
|
97
|
+
sample_input_data = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(
|
98
|
+
model_signature._convert_local_data_to_df(sample_input_data)
|
99
|
+
)
|
100
|
+
else:
|
101
|
+
if not pytorch_handler.PyTorchTensorHandler.can_handle(sample_input_data):
|
102
|
+
sample_input_data = pytorch_handler.PyTorchTensorHandler.convert_from_df(
|
103
|
+
model_signature._convert_local_data_to_df(sample_input_data)
|
104
|
+
)
|
91
105
|
|
92
106
|
model.eval()
|
93
107
|
target_method = getattr(model, target_method_name, None)
|
94
108
|
assert callable(target_method)
|
95
109
|
with torch.no_grad():
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
110
|
+
if multiple_inputs:
|
111
|
+
predictions_df = target_method(*sample_input_data)
|
112
|
+
if not isinstance(predictions_df, tuple):
|
113
|
+
predictions_df = [predictions_df]
|
114
|
+
else:
|
115
|
+
predictions_df = target_method(sample_input_data)
|
100
116
|
|
101
117
|
return predictions_df
|
102
118
|
|
@@ -117,6 +133,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
|
117
133
|
model_type=cls.HANDLER_TYPE,
|
118
134
|
handler_version=cls.HANDLER_VERSION,
|
119
135
|
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
136
|
+
options=model_meta_schema.TorchScriptModelBlobOptions(multiple_inputs=multiple_inputs),
|
120
137
|
)
|
121
138
|
model_meta.models[name] = base_meta
|
122
139
|
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
@@ -170,6 +187,10 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
|
170
187
|
signature: model_signature.ModelSignature,
|
171
188
|
target_method: str,
|
172
189
|
) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
|
190
|
+
multiple_inputs = cast(
|
191
|
+
model_meta_schema.TorchScriptModelBlobOptions, model_meta.models[model_meta.name].options
|
192
|
+
)["multiple_inputs"]
|
193
|
+
|
173
194
|
@custom_model.inference_api
|
174
195
|
def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
|
175
196
|
if X.isnull().any(axis=None):
|
@@ -179,19 +200,27 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
|
179
200
|
|
180
201
|
raw_model.eval()
|
181
202
|
|
182
|
-
|
203
|
+
if multiple_inputs:
|
204
|
+
st = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(X, signature.inputs)
|
183
205
|
|
184
|
-
|
185
|
-
|
206
|
+
if kwargs.get("use_gpu", False):
|
207
|
+
st = [element.cuda() for element in st]
|
186
208
|
|
187
|
-
|
188
|
-
|
209
|
+
with torch.no_grad():
|
210
|
+
res = getattr(raw_model, target_method)(*st)
|
189
211
|
|
190
|
-
|
191
|
-
|
212
|
+
if not isinstance(res, tuple):
|
213
|
+
res = [res]
|
214
|
+
else:
|
215
|
+
t = pytorch_handler.PyTorchTensorHandler.convert_from_df(X, signature.inputs)
|
216
|
+
if kwargs.get("use_gpu", False):
|
217
|
+
t = t.cuda()
|
192
218
|
|
219
|
+
with torch.no_grad():
|
220
|
+
res = getattr(raw_model, target_method)(t)
|
193
221
|
return model_signature_utils.rename_pandas_df(
|
194
|
-
|
222
|
+
model_signature._convert_local_data_to_df(res, ensure_serializable=True),
|
223
|
+
features=signature.outputs,
|
195
224
|
)
|
196
225
|
|
197
226
|
return fn
|
@@ -99,10 +99,10 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
99
99
|
def get_prediction(
|
100
100
|
target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
|
101
101
|
) -> model_types.SupportedLocalDataType:
|
102
|
-
if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray)):
|
102
|
+
if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray, xgboost.DMatrix)):
|
103
103
|
sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
|
104
104
|
|
105
|
-
if isinstance(model, xgboost.Booster):
|
105
|
+
if isinstance(model, xgboost.Booster) and not isinstance(sample_input_data, xgboost.DMatrix):
|
106
106
|
sample_input_data = xgboost.DMatrix(sample_input_data)
|
107
107
|
|
108
108
|
target_method = getattr(model, target_method_name, None)
|
@@ -0,0 +1,20 @@
|
|
1
|
+
from typing import cast
|
2
|
+
|
3
|
+
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
4
|
+
from snowflake.ml.model._packager.model_meta import (
|
5
|
+
model_meta as model_meta_api,
|
6
|
+
model_meta_schema,
|
7
|
+
)
|
8
|
+
|
9
|
+
|
10
|
+
class PyTorchHandlerMigrator20231201(base_migrator.BaseModelHandlerMigrator):
|
11
|
+
source_version = "2023-12-01"
|
12
|
+
target_version = "2025-03-01"
|
13
|
+
|
14
|
+
@staticmethod
|
15
|
+
def upgrade(name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str) -> None:
|
16
|
+
|
17
|
+
model_blob_metadata = model_meta.models[name]
|
18
|
+
model_blob_options = cast(model_meta_schema.PyTorchModelBlobOptions, model_blob_metadata.options)
|
19
|
+
model_blob_options["multiple_inputs"] = True
|
20
|
+
model_meta.models[name].options = model_blob_options
|
@@ -0,0 +1,19 @@
|
|
1
|
+
from typing import cast
|
2
|
+
|
3
|
+
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
4
|
+
from snowflake.ml.model._packager.model_meta import (
|
5
|
+
model_meta as model_meta_api,
|
6
|
+
model_meta_schema,
|
7
|
+
)
|
8
|
+
|
9
|
+
|
10
|
+
class TensorflowHandlerMigrator20250101(base_migrator.BaseModelHandlerMigrator):
|
11
|
+
source_version = "2025-01-01"
|
12
|
+
target_version = "2025-03-01"
|
13
|
+
|
14
|
+
@staticmethod
|
15
|
+
def upgrade(name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str) -> None:
|
16
|
+
model_blob_metadata = model_meta.models[name]
|
17
|
+
model_blob_options = cast(model_meta_schema.TensorflowModelBlobOptions, model_blob_metadata.options)
|
18
|
+
model_blob_options["multiple_inputs"] = True
|
19
|
+
model_meta.models[name].options = model_blob_options
|
@@ -0,0 +1,20 @@
|
|
1
|
+
from typing import cast
|
2
|
+
|
3
|
+
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
4
|
+
from snowflake.ml.model._packager.model_meta import (
|
5
|
+
model_meta as model_meta_api,
|
6
|
+
model_meta_schema,
|
7
|
+
)
|
8
|
+
|
9
|
+
|
10
|
+
class TorchScriptHandlerMigrator20231201(base_migrator.BaseModelHandlerMigrator):
|
11
|
+
source_version = "2023-12-01"
|
12
|
+
target_version = "2025-03-01"
|
13
|
+
|
14
|
+
@staticmethod
|
15
|
+
def upgrade(name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str) -> None:
|
16
|
+
|
17
|
+
model_blob_metadata = model_meta.models[name]
|
18
|
+
model_blob_options = cast(model_meta_schema.PyTorchModelBlobOptions, model_blob_metadata.options)
|
19
|
+
model_blob_options["multiple_inputs"] = True
|
20
|
+
model_meta.models[name].options = model_blob_options
|
@@ -48,6 +48,7 @@ def create_model_metadata(
|
|
48
48
|
ext_modules: Optional[List[ModuleType]] = None,
|
49
49
|
conda_dependencies: Optional[List[str]] = None,
|
50
50
|
pip_requirements: Optional[List[str]] = None,
|
51
|
+
artifact_repository_map: Optional[Dict[str, str]] = None,
|
51
52
|
python_version: Optional[str] = None,
|
52
53
|
task: model_types.Task = model_types.Task.UNKNOWN,
|
53
54
|
**kwargs: Any,
|
@@ -67,6 +68,7 @@ def create_model_metadata(
|
|
67
68
|
ext_modules: List of names of modules that need to be pickled with the model. Defaults to None.
|
68
69
|
conda_dependencies: List of conda requirements for running the model. Defaults to None.
|
69
70
|
pip_requirements: List of pip Python packages requirements for running the model. Defaults to None.
|
71
|
+
artifact_repository_map: A dict mapping from package channel to artifact repository name.
|
70
72
|
python_version: A string of python version where model is run. Used for user override. If specified as None,
|
71
73
|
current version would be captured. Defaults to None.
|
72
74
|
task: The task of the Model Version. It is an enum class Task with values TABULAR_REGRESSION,
|
@@ -102,6 +104,7 @@ def create_model_metadata(
|
|
102
104
|
env = _create_env_for_model_metadata(
|
103
105
|
conda_dependencies=conda_dependencies,
|
104
106
|
pip_requirements=pip_requirements,
|
107
|
+
artifact_repository_map=artifact_repository_map,
|
105
108
|
python_version=python_version,
|
106
109
|
embed_local_ml_library=embed_local_ml_library,
|
107
110
|
)
|
@@ -151,6 +154,7 @@ def _create_env_for_model_metadata(
|
|
151
154
|
*,
|
152
155
|
conda_dependencies: Optional[List[str]] = None,
|
153
156
|
pip_requirements: Optional[List[str]] = None,
|
157
|
+
artifact_repository_map: Optional[Dict[str, str]] = None,
|
154
158
|
python_version: Optional[str] = None,
|
155
159
|
embed_local_ml_library: bool = False,
|
156
160
|
) -> model_env.ModelEnv:
|
@@ -159,6 +163,7 @@ def _create_env_for_model_metadata(
|
|
159
163
|
# Mypy doesn't like getter and setter have different types. See python/mypy #3004
|
160
164
|
env.conda_dependencies = conda_dependencies # type: ignore[assignment]
|
161
165
|
env.pip_requirements = pip_requirements # type: ignore[assignment]
|
166
|
+
env.artifact_repository_map = artifact_repository_map
|
162
167
|
env.python_version = python_version # type: ignore[assignment]
|
163
168
|
env.snowpark_ml_version = snowml_env.VERSION
|
164
169
|
|
@@ -331,7 +336,6 @@ class ModelMetadata:
|
|
331
336
|
"function_properties": self.function_properties,
|
332
337
|
}
|
333
338
|
)
|
334
|
-
|
335
339
|
with open(model_yaml_path, "w", encoding="utf-8") as out:
|
336
340
|
yaml.SafeDumper.ignore_aliases = lambda *args: True # type: ignore[method-assign]
|
337
341
|
yaml.safe_dump(model_dict, stream=out, default_flow_style=False)
|
@@ -18,6 +18,7 @@ class FunctionProperties(Enum):
|
|
18
18
|
class ModelRuntimeDependenciesDict(TypedDict):
|
19
19
|
conda: Required[str]
|
20
20
|
pip: Required[str]
|
21
|
+
artifact_repository_map: NotRequired[Optional[Dict[str, str]]]
|
21
22
|
|
22
23
|
|
23
24
|
class ModelRuntimeDict(TypedDict):
|
@@ -28,6 +29,7 @@ class ModelRuntimeDict(TypedDict):
|
|
28
29
|
class ModelEnvDict(TypedDict):
|
29
30
|
conda: Required[str]
|
30
31
|
pip: Required[str]
|
32
|
+
artifact_repository_map: NotRequired[Optional[Dict[str, str]]]
|
31
33
|
python_version: Required[str]
|
32
34
|
cuda_version: NotRequired[Optional[str]]
|
33
35
|
snowpark_ml_version: Required[str]
|
@@ -61,8 +63,17 @@ class XgboostModelBlobOptions(BaseModelBlobOptions):
|
|
61
63
|
xgb_estimator_type: Required[str]
|
62
64
|
|
63
65
|
|
66
|
+
class PyTorchModelBlobOptions(BaseModelBlobOptions):
|
67
|
+
multiple_inputs: Required[bool]
|
68
|
+
|
69
|
+
|
70
|
+
class TorchScriptModelBlobOptions(BaseModelBlobOptions):
|
71
|
+
multiple_inputs: Required[bool]
|
72
|
+
|
73
|
+
|
64
74
|
class TensorflowModelBlobOptions(BaseModelBlobOptions):
|
65
75
|
save_format: Required[str]
|
76
|
+
multiple_inputs: Required[bool]
|
66
77
|
|
67
78
|
|
68
79
|
class SentenceTransformersModelBlobOptions(BaseModelBlobOptions):
|
@@ -74,6 +85,8 @@ ModelBlobOptions = Union[
|
|
74
85
|
HuggingFacePipelineModelBlobOptions,
|
75
86
|
MLFlowModelBlobOptions,
|
76
87
|
XgboostModelBlobOptions,
|
88
|
+
PyTorchModelBlobOptions,
|
89
|
+
TorchScriptModelBlobOptions,
|
77
90
|
TensorflowModelBlobOptions,
|
78
91
|
SentenceTransformersModelBlobOptions,
|
79
92
|
]
|
@@ -43,13 +43,13 @@ class ModelPackager:
|
|
43
43
|
metadata: Optional[Dict[str, str]] = None,
|
44
44
|
conda_dependencies: Optional[List[str]] = None,
|
45
45
|
pip_requirements: Optional[List[str]] = None,
|
46
|
+
artifact_repository_map: Optional[Dict[str, str]] = None,
|
46
47
|
python_version: Optional[str] = None,
|
47
48
|
ext_modules: Optional[List[ModuleType]] = None,
|
48
49
|
code_paths: Optional[List[str]] = None,
|
49
|
-
options:
|
50
|
+
options: model_types.ModelSaveOption,
|
50
51
|
task: model_types.Task = model_types.Task.UNKNOWN,
|
51
52
|
) -> model_meta.ModelMetadata:
|
52
|
-
|
53
53
|
if (signatures is None) and (sample_input_data is None) and not model_handler.is_auto_signature_model(model):
|
54
54
|
raise snowml_exceptions.SnowflakeMLException(
|
55
55
|
error_code=error_codes.INVALID_ARGUMENT,
|
@@ -58,9 +58,6 @@ class ModelPackager:
|
|
58
58
|
),
|
59
59
|
)
|
60
60
|
|
61
|
-
if not options:
|
62
|
-
options = model_types.BaseModelSaveOption()
|
63
|
-
|
64
61
|
handler = model_handler.find_handler(model)
|
65
62
|
if handler is None:
|
66
63
|
raise snowml_exceptions.SnowflakeMLException(
|
@@ -77,6 +74,7 @@ class ModelPackager:
|
|
77
74
|
ext_modules=ext_modules,
|
78
75
|
conda_dependencies=conda_dependencies,
|
79
76
|
pip_requirements=pip_requirements,
|
77
|
+
artifact_repository_map=artifact_repository_map,
|
80
78
|
python_version=python_version,
|
81
79
|
task=task,
|
82
80
|
**options,
|
@@ -1,2 +1 @@
|
|
1
|
-
REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<5', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.
|
2
|
-
ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<5', 'cachetools>=3.1.1,<6', 'catboost>=1.2.0, <2', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'importlib_resources>=6.1.1, <7', 'keras>=2.0.0,<4', 'lightgbm>=4.1.0, <5', 'mlflow>=2.16.0, <3', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pytorch>=2.0.1,<3', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'sentence-transformers>=2.7.0,<3', 'sentencepiece>=0.1.95,<0.2.0', 'shap>=0.46.0,<1', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2,!=1.26.0', 'sqlparse>=0.4,<1', 'tensorflow>=2.17.0,<3', 'tokenizers>=0.15.1,<1', 'torchdata>=0.4,<1', 'transformers>=4.37.2,<5', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
|
1
|
+
REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<5', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.12.0,<4', 'snowflake-snowpark-python>=1.17.0,<2,!=1.26.0', 'sqlparse>=0.4,<1', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
|