snowflake-ml-python 1.7.5__py3-none-any.whl → 1.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +58 -3
- snowflake/ml/_internal/file_utils.py +18 -4
- snowflake/ml/_internal/platform_capabilities.py +3 -0
- snowflake/ml/_internal/telemetry.py +4 -0
- snowflake/ml/fileset/fileset.py +0 -1
- snowflake/ml/jobs/_utils/constants.py +25 -1
- snowflake/ml/jobs/_utils/payload_utils.py +94 -20
- snowflake/ml/jobs/_utils/spec_utils.py +95 -31
- snowflake/ml/jobs/decorators.py +7 -0
- snowflake/ml/jobs/manager.py +20 -0
- snowflake/ml/model/_client/model/model_version_impl.py +1 -1
- snowflake/ml/model/_client/ops/model_ops.py +113 -17
- snowflake/ml/model/_client/ops/service_ops.py +16 -5
- snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
- snowflake/ml/model/_client/sql/model_version.py +58 -0
- snowflake/ml/model/_client/sql/service.py +10 -2
- snowflake/ml/model/_model_composer/model_composer.py +50 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +5 -2
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
- snowflake/ml/model/_packager/model_env/model_env.py +4 -1
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +28 -24
- snowflake/ml/model/_packager/model_handlers/keras.py +1 -5
- snowflake/ml/model/_packager/model_handlers/pytorch.py +50 -20
- snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -8
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -2
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +46 -26
- snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
- snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
- snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
- snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +1 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +5 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -0
- snowflake/ml/model/_packager/model_packager.py +3 -5
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -0
- snowflake/ml/model/_signatures/builtins_handler.py +20 -9
- snowflake/ml/model/_signatures/core.py +52 -31
- snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
- snowflake/ml/model/_signatures/numpy_handler.py +9 -17
- snowflake/ml/model/_signatures/pandas_handler.py +19 -30
- snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
- snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
- snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
- snowflake/ml/model/_signatures/utils.py +120 -8
- snowflake/ml/model/custom_model.py +13 -4
- snowflake/ml/model/model_signature.py +31 -13
- snowflake/ml/model/type_hints.py +13 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +5 -1
- snowflake/ml/modeling/metrics/ranking.py +3 -0
- snowflake/ml/modeling/metrics/regression.py +3 -0
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
- snowflake/ml/registry/_manager/model_manager.py +55 -7
- snowflake/ml/registry/registry.py +59 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/METADATA +308 -12
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/RECORD +62 -58
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info/licenses}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/top_level.txt +0 -0
@@ -30,10 +30,7 @@ from snowflake.ml.model._packager.model_meta import (
|
|
30
30
|
model_meta as model_meta_api,
|
31
31
|
model_meta_schema,
|
32
32
|
)
|
33
|
-
from snowflake.ml.model._signatures import
|
34
|
-
builtins_handler,
|
35
|
-
utils as model_signature_utils,
|
36
|
-
)
|
33
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
34
|
from snowflake.ml.model.models import huggingface_pipeline
|
38
35
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
39
36
|
|
@@ -66,16 +63,16 @@ def get_requirements_from_task(task: str, spcs_only: bool = False) -> List[model
|
|
66
63
|
return []
|
67
64
|
|
68
65
|
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
66
|
+
def sanitize_output(data: Any) -> Any:
|
67
|
+
if isinstance(data, np.number):
|
68
|
+
return data.item()
|
69
|
+
if isinstance(data, np.ndarray):
|
70
|
+
return sanitize_output(data.tolist())
|
71
|
+
if isinstance(data, list):
|
72
|
+
return [sanitize_output(x) for x in data]
|
73
|
+
if isinstance(data, dict):
|
74
|
+
return {k: sanitize_output(v) for k, v in data.items()}
|
75
|
+
return data
|
79
76
|
|
80
77
|
|
81
78
|
@final
|
@@ -410,13 +407,17 @@ class HuggingFacePipelineHandler(
|
|
410
407
|
)
|
411
408
|
for conv_data in X.to_dict("records")
|
412
409
|
]
|
413
|
-
elif len(signature.inputs) == 1:
|
414
|
-
input_data = X.to_dict("list")[signature.inputs[0].name]
|
415
410
|
else:
|
416
411
|
if isinstance(raw_model, transformers.TableQuestionAnsweringPipeline):
|
417
412
|
X["table"] = X["table"].apply(json.loads)
|
418
413
|
|
419
|
-
|
414
|
+
# Most pipelines if it is expecting more than one arguments,
|
415
|
+
# it is expecting a list of dict, where each dict has keys corresponding to the argument.
|
416
|
+
if len(signature.inputs) > 1:
|
417
|
+
input_data = X.to_dict("records")
|
418
|
+
# If it is only expecting one argument, Then it is expecting a list of something.
|
419
|
+
else:
|
420
|
+
input_data = X[signature.inputs[0].name].to_list()
|
420
421
|
temp_res = getattr(raw_model, target_method)(input_data)
|
421
422
|
|
422
423
|
# Some huggingface pipeline will omit the outer list when there is only 1 input.
|
@@ -439,7 +440,6 @@ class HuggingFacePipelineHandler(
|
|
439
440
|
),
|
440
441
|
)
|
441
442
|
and X.shape[0] == 1
|
442
|
-
and isinstance(temp_res[0], dict)
|
443
443
|
)
|
444
444
|
):
|
445
445
|
temp_res = [temp_res]
|
@@ -453,14 +453,18 @@ class HuggingFacePipelineHandler(
|
|
453
453
|
temp_res = [[conv.generated_responses] for conv in temp_res]
|
454
454
|
|
455
455
|
# To concat those who outputs a list with one input.
|
456
|
-
if
|
457
|
-
|
458
|
-
|
456
|
+
if isinstance(temp_res[0], list):
|
457
|
+
if isinstance(temp_res[0][0], dict):
|
458
|
+
res = pd.DataFrame({0: temp_res})
|
459
|
+
else:
|
460
|
+
res = pd.DataFrame(temp_res)
|
461
|
+
else:
|
459
462
|
res = pd.DataFrame(temp_res)
|
460
|
-
|
461
|
-
|
463
|
+
|
464
|
+
if hasattr(res, "map"):
|
465
|
+
res = res.map(sanitize_output)
|
462
466
|
else:
|
463
|
-
|
467
|
+
res = res.applymap(sanitize_output)
|
464
468
|
|
465
469
|
return model_signature_utils.rename_pandas_df(data=res, features=signature.outputs)
|
466
470
|
|
@@ -191,11 +191,7 @@ class KerasHandler(_base.BaseModelHandler["keras.Model"]):
|
|
191
191
|
signature: model_signature.ModelSignature,
|
192
192
|
target_method: str,
|
193
193
|
) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
|
194
|
-
dtype_map = {
|
195
|
-
spec.name: spec.as_dtype(force_numpy_dtype=True)
|
196
|
-
for spec in signature.inputs
|
197
|
-
if isinstance(spec, model_signature.FeatureSpec)
|
198
|
-
}
|
194
|
+
dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs}
|
199
195
|
|
200
196
|
@custom_model.inference_api
|
201
197
|
def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
|
@@ -10,10 +10,14 @@ from snowflake.ml._internal import type_utils
|
|
10
10
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
11
11
|
from snowflake.ml.model._packager.model_env import model_env
|
12
12
|
from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
|
13
|
-
from snowflake.ml.model._packager.model_handlers_migrator import
|
13
|
+
from snowflake.ml.model._packager.model_handlers_migrator import (
|
14
|
+
base_migrator,
|
15
|
+
pytorch_migrator_2023_12_01,
|
16
|
+
)
|
14
17
|
from snowflake.ml.model._packager.model_meta import (
|
15
18
|
model_blob_meta,
|
16
19
|
model_meta as model_meta_api,
|
20
|
+
model_meta_schema,
|
17
21
|
)
|
18
22
|
from snowflake.ml.model._signatures import (
|
19
23
|
pytorch_handler,
|
@@ -21,7 +25,6 @@ from snowflake.ml.model._signatures import (
|
|
21
25
|
)
|
22
26
|
|
23
27
|
if TYPE_CHECKING:
|
24
|
-
import sentence_transformers # noqa: F401
|
25
28
|
import torch
|
26
29
|
|
27
30
|
|
@@ -33,9 +36,11 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
33
36
|
"""
|
34
37
|
|
35
38
|
HANDLER_TYPE = "pytorch"
|
36
|
-
HANDLER_VERSION = "
|
37
|
-
_MIN_SNOWPARK_ML_VERSION = "1.0
|
38
|
-
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
|
39
|
+
HANDLER_VERSION = "2025-03-01"
|
40
|
+
_MIN_SNOWPARK_ML_VERSION = "1.8.0"
|
41
|
+
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
|
42
|
+
"2023-12-01": pytorch_migrator_2023_12_01.PyTorchHandlerMigrator20231201
|
43
|
+
}
|
39
44
|
|
40
45
|
MODEL_BLOB_FILE_OR_DIR = "model.pt"
|
41
46
|
DEFAULT_TARGET_METHODS = ["forward"]
|
@@ -89,22 +94,33 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
89
94
|
default_target_methods=cls.DEFAULT_TARGET_METHODS,
|
90
95
|
)
|
91
96
|
|
97
|
+
multiple_inputs = kwargs.get("multiple_inputs", False)
|
98
|
+
|
92
99
|
def get_prediction(
|
93
100
|
target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
|
94
101
|
) -> model_types.SupportedLocalDataType:
|
95
|
-
if
|
96
|
-
|
97
|
-
|
98
|
-
|
102
|
+
if multiple_inputs:
|
103
|
+
if not pytorch_handler.SeqOfPyTorchTensorHandler.can_handle(sample_input_data):
|
104
|
+
sample_input_data = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(
|
105
|
+
model_signature._convert_local_data_to_df(sample_input_data)
|
106
|
+
)
|
107
|
+
else:
|
108
|
+
if not pytorch_handler.PyTorchTensorHandler.can_handle(sample_input_data):
|
109
|
+
sample_input_data = pytorch_handler.PyTorchTensorHandler.convert_from_df(
|
110
|
+
model_signature._convert_local_data_to_df(sample_input_data)
|
111
|
+
)
|
99
112
|
|
100
113
|
model.eval()
|
101
114
|
target_method = getattr(model, target_method_name, None)
|
102
115
|
assert callable(target_method)
|
103
116
|
with torch.no_grad():
|
104
|
-
|
117
|
+
if multiple_inputs:
|
118
|
+
predictions_df = target_method(*sample_input_data)
|
119
|
+
if not isinstance(predictions_df, tuple):
|
120
|
+
predictions_df = [predictions_df]
|
121
|
+
else:
|
122
|
+
predictions_df = target_method(sample_input_data)
|
105
123
|
|
106
|
-
if isinstance(predictions_df, torch.Tensor):
|
107
|
-
predictions_df = [predictions_df]
|
108
124
|
return predictions_df
|
109
125
|
|
110
126
|
model_meta = handlers_utils.validate_signature(
|
@@ -127,6 +143,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
127
143
|
model_type=cls.HANDLER_TYPE,
|
128
144
|
handler_version=cls.HANDLER_VERSION,
|
129
145
|
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
146
|
+
options=model_meta_schema.PyTorchModelBlobOptions(multiple_inputs=multiple_inputs),
|
130
147
|
)
|
131
148
|
model_meta.models[name] = base_meta
|
132
149
|
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
@@ -172,6 +189,10 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
172
189
|
raw_model: "torch.nn.Module",
|
173
190
|
model_meta: model_meta_api.ModelMetadata,
|
174
191
|
) -> Type[custom_model.CustomModel]:
|
192
|
+
multiple_inputs = cast(
|
193
|
+
model_meta_schema.PyTorchModelBlobOptions, model_meta.models[model_meta.name].options
|
194
|
+
)["multiple_inputs"]
|
195
|
+
|
175
196
|
def fn_factory(
|
176
197
|
raw_model: "torch.nn.Module",
|
177
198
|
signature: model_signature.ModelSignature,
|
@@ -183,19 +204,28 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
183
204
|
raise ValueError("Tensor cannot handle null values.")
|
184
205
|
|
185
206
|
raw_model.eval()
|
186
|
-
|
207
|
+
if multiple_inputs:
|
208
|
+
st = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(X, signature.inputs)
|
209
|
+
|
210
|
+
if kwargs.get("use_gpu", False):
|
211
|
+
st = [element.cuda() for element in st]
|
187
212
|
|
188
|
-
|
189
|
-
|
213
|
+
with torch.no_grad():
|
214
|
+
res = getattr(raw_model, target_method)(*st)
|
190
215
|
|
191
|
-
|
192
|
-
|
216
|
+
if not isinstance(res, tuple):
|
217
|
+
res = [res]
|
218
|
+
else:
|
219
|
+
t = pytorch_handler.PyTorchTensorHandler.convert_from_df(X, signature.inputs)
|
220
|
+
if kwargs.get("use_gpu", False):
|
221
|
+
t = t.cuda()
|
193
222
|
|
194
|
-
|
195
|
-
|
223
|
+
with torch.no_grad():
|
224
|
+
res = getattr(raw_model, target_method)(t)
|
196
225
|
|
197
226
|
return model_signature_utils.rename_pandas_df(
|
198
|
-
|
227
|
+
model_signature._convert_local_data_to_df(res, ensure_serializable=True),
|
228
|
+
features=signature.outputs,
|
199
229
|
)
|
200
230
|
|
201
231
|
return fn
|
@@ -57,6 +57,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
57
57
|
"predict_proba",
|
58
58
|
"predict_log_proba",
|
59
59
|
"decision_function",
|
60
|
+
"score_samples",
|
60
61
|
]
|
61
62
|
EXPLAIN_TARGET_METHODS = ["predict", "predict_proba", "predict_log_proba"]
|
62
63
|
|
@@ -74,10 +75,6 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
74
75
|
and (
|
75
76
|
not type_utils.LazyType("lightgbm.LGBMModel").isinstance(model)
|
76
77
|
) # LGBMModel is actually a BaseEstimator
|
77
|
-
and any(
|
78
|
-
(hasattr(model, method) and callable(getattr(model, method, None)))
|
79
|
-
for method in cls.DEFAULT_TARGET_METHODS
|
80
|
-
)
|
81
78
|
)
|
82
79
|
|
83
80
|
@classmethod
|
@@ -297,10 +294,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
297
294
|
df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
|
298
295
|
except TypeError:
|
299
296
|
try:
|
300
|
-
dtype_map = {
|
301
|
-
spec.name: spec.as_dtype(force_numpy_dtype=True) # type: ignore[attr-defined]
|
302
|
-
for spec in signature.inputs
|
303
|
-
}
|
297
|
+
dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs}
|
304
298
|
|
305
299
|
if isinstance(X, pd.DataFrame):
|
306
300
|
X = X.astype(dtype_map, copy=False)
|
@@ -307,8 +307,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
307
307
|
except TypeError:
|
308
308
|
try:
|
309
309
|
dtype_map = {
|
310
|
-
spec.name: spec.as_dtype(force_numpy_dtype=True)
|
311
|
-
for spec in signature.inputs
|
310
|
+
spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs
|
312
311
|
}
|
313
312
|
|
314
313
|
if isinstance(X, pd.DataFrame):
|
@@ -1,7 +1,6 @@
|
|
1
1
|
import os
|
2
2
|
from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
|
3
3
|
|
4
|
-
import numpy as np
|
5
4
|
import pandas as pd
|
6
5
|
from packaging import version
|
7
6
|
from typing_extensions import TypeGuard, Unpack
|
@@ -13,6 +12,7 @@ from snowflake.ml.model._packager.model_handlers import _base, _utils as handler
|
|
13
12
|
from snowflake.ml.model._packager.model_handlers_migrator import (
|
14
13
|
base_migrator,
|
15
14
|
tensorflow_migrator_2023_12_01,
|
15
|
+
tensorflow_migrator_2025_01_01,
|
16
16
|
)
|
17
17
|
from snowflake.ml.model._packager.model_meta import (
|
18
18
|
model_blob_meta,
|
@@ -20,7 +20,6 @@ from snowflake.ml.model._packager.model_meta import (
|
|
20
20
|
model_meta_schema,
|
21
21
|
)
|
22
22
|
from snowflake.ml.model._signatures import (
|
23
|
-
numpy_handler,
|
24
23
|
tensorflow_handler,
|
25
24
|
utils as model_signature_utils,
|
26
25
|
)
|
@@ -37,10 +36,11 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
37
36
|
"""
|
38
37
|
|
39
38
|
HANDLER_TYPE = "tensorflow"
|
40
|
-
HANDLER_VERSION = "2025-
|
41
|
-
_MIN_SNOWPARK_ML_VERSION = "1.
|
39
|
+
HANDLER_VERSION = "2025-03-01"
|
40
|
+
_MIN_SNOWPARK_ML_VERSION = "1.8.0"
|
42
41
|
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
|
43
|
-
"2023-12-01": tensorflow_migrator_2023_12_01.TensorflowHandlerMigrator20231201
|
42
|
+
"2023-12-01": tensorflow_migrator_2023_12_01.TensorflowHandlerMigrator20231201,
|
43
|
+
"2025-01-01": tensorflow_migrator_2025_01_01.TensorflowHandlerMigrator20250101,
|
44
44
|
}
|
45
45
|
|
46
46
|
MODEL_BLOB_FILE_OR_DIR = "model"
|
@@ -112,25 +112,35 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
112
112
|
default_target_methods=default_target_methods,
|
113
113
|
)
|
114
114
|
|
115
|
+
multiple_inputs = kwargs.get("multiple_inputs", False)
|
116
|
+
|
115
117
|
if is_keras_model and len(target_methods) > 1:
|
116
118
|
raise ValueError("Keras model can only have one target method.")
|
117
119
|
|
118
120
|
def get_prediction(
|
119
121
|
target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
|
120
122
|
) -> model_types.SupportedLocalDataType:
|
121
|
-
if
|
122
|
-
|
123
|
-
|
124
|
-
|
123
|
+
if multiple_inputs:
|
124
|
+
if not tensorflow_handler.SeqOfTensorflowTensorHandler.can_handle(sample_input_data):
|
125
|
+
sample_input_data = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(
|
126
|
+
model_signature._convert_local_data_to_df(sample_input_data)
|
127
|
+
)
|
128
|
+
else:
|
129
|
+
if not tensorflow_handler.TensorflowTensorHandler.can_handle(sample_input_data):
|
130
|
+
sample_input_data = tensorflow_handler.TensorflowTensorHandler.convert_from_df(
|
131
|
+
model_signature._convert_local_data_to_df(sample_input_data)
|
132
|
+
)
|
125
133
|
|
126
134
|
target_method = getattr(model, target_method_name, None)
|
127
135
|
assert callable(target_method)
|
128
136
|
for tensor in sample_input_data:
|
129
137
|
tensorflow.stop_gradient(tensor)
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
138
|
+
if multiple_inputs:
|
139
|
+
predictions_df = target_method(*sample_input_data)
|
140
|
+
if not isinstance(predictions_df, tuple):
|
141
|
+
predictions_df = [predictions_df]
|
142
|
+
else:
|
143
|
+
predictions_df = target_method(sample_input_data)
|
134
144
|
|
135
145
|
return predictions_df
|
136
146
|
|
@@ -159,7 +169,9 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
159
169
|
model_type=cls.HANDLER_TYPE,
|
160
170
|
handler_version=cls.HANDLER_VERSION,
|
161
171
|
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
162
|
-
options=model_meta_schema.TensorflowModelBlobOptions(
|
172
|
+
options=model_meta_schema.TensorflowModelBlobOptions(
|
173
|
+
save_format=save_format, multiple_inputs=multiple_inputs
|
174
|
+
),
|
163
175
|
)
|
164
176
|
model_meta.models[name] = base_meta
|
165
177
|
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
@@ -219,6 +231,10 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
219
231
|
raw_model: "tensorflow.Module",
|
220
232
|
model_meta: model_meta_api.ModelMetadata,
|
221
233
|
) -> Type[custom_model.CustomModel]:
|
234
|
+
multiple_inputs = cast(
|
235
|
+
model_meta_schema.TensorflowModelBlobOptions, model_meta.models[model_meta.name].options
|
236
|
+
)["multiple_inputs"]
|
237
|
+
|
222
238
|
def fn_factory(
|
223
239
|
raw_model: "tensorflow.Module",
|
224
240
|
signature: model_signature.ModelSignature,
|
@@ -229,21 +245,25 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
229
245
|
if X.isnull().any(axis=None):
|
230
246
|
raise ValueError("Tensor cannot handle null values.")
|
231
247
|
|
232
|
-
|
248
|
+
if multiple_inputs:
|
249
|
+
t = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(X, signature.inputs)
|
233
250
|
|
234
|
-
|
235
|
-
|
236
|
-
|
251
|
+
for tensor in t:
|
252
|
+
tensorflow.stop_gradient(tensor)
|
253
|
+
res = getattr(raw_model, target_method)(*t)
|
237
254
|
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
if isinstance(res, list) and len(res) > 0 and isinstance(res[0], np.ndarray):
|
242
|
-
# In case of running on CPU, it will return numpy array
|
243
|
-
df = numpy_handler.SeqOfNumpyArrayHandler.convert_to_df(res)
|
255
|
+
if not isinstance(res, tuple):
|
256
|
+
res = [res]
|
244
257
|
else:
|
245
|
-
|
246
|
-
|
258
|
+
t = tensorflow_handler.TensorflowTensorHandler.convert_from_df(X, signature.inputs)
|
259
|
+
|
260
|
+
tensorflow.stop_gradient(t)
|
261
|
+
res = getattr(raw_model, target_method)(t)
|
262
|
+
|
263
|
+
return model_signature_utils.rename_pandas_df(
|
264
|
+
model_signature._convert_local_data_to_df(res, ensure_serializable=True),
|
265
|
+
features=signature.outputs,
|
266
|
+
)
|
247
267
|
|
248
268
|
return fn
|
249
269
|
|
@@ -8,10 +8,14 @@ from snowflake.ml._internal import type_utils
|
|
8
8
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
9
9
|
from snowflake.ml.model._packager.model_env import model_env
|
10
10
|
from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
|
11
|
-
from snowflake.ml.model._packager.model_handlers_migrator import
|
11
|
+
from snowflake.ml.model._packager.model_handlers_migrator import (
|
12
|
+
base_migrator,
|
13
|
+
torchscript_migrator_2023_12_01,
|
14
|
+
)
|
12
15
|
from snowflake.ml.model._packager.model_meta import (
|
13
16
|
model_blob_meta,
|
14
17
|
model_meta as model_meta_api,
|
18
|
+
model_meta_schema,
|
15
19
|
)
|
16
20
|
from snowflake.ml.model._signatures import (
|
17
21
|
pytorch_handler,
|
@@ -30,9 +34,11 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
|
30
34
|
"""
|
31
35
|
|
32
36
|
HANDLER_TYPE = "torchscript"
|
33
|
-
HANDLER_VERSION = "
|
34
|
-
_MIN_SNOWPARK_ML_VERSION = "1.0
|
35
|
-
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
|
37
|
+
HANDLER_VERSION = "2025-03-01"
|
38
|
+
_MIN_SNOWPARK_ML_VERSION = "1.8.0"
|
39
|
+
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
|
40
|
+
"2023-12-01": torchscript_migrator_2023_12_01.TorchScriptHandlerMigrator20231201
|
41
|
+
}
|
36
42
|
|
37
43
|
MODEL_BLOB_FILE_OR_DIR = "model.pt"
|
38
44
|
DEFAULT_TARGET_METHODS = ["forward"]
|
@@ -81,22 +87,32 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
|
81
87
|
default_target_methods=cls.DEFAULT_TARGET_METHODS,
|
82
88
|
)
|
83
89
|
|
90
|
+
multiple_inputs = kwargs.get("multiple_inputs", False)
|
91
|
+
|
84
92
|
def get_prediction(
|
85
93
|
target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
|
86
94
|
) -> model_types.SupportedLocalDataType:
|
87
|
-
if
|
88
|
-
|
89
|
-
|
90
|
-
|
95
|
+
if multiple_inputs:
|
96
|
+
if not pytorch_handler.SeqOfPyTorchTensorHandler.can_handle(sample_input_data):
|
97
|
+
sample_input_data = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(
|
98
|
+
model_signature._convert_local_data_to_df(sample_input_data)
|
99
|
+
)
|
100
|
+
else:
|
101
|
+
if not pytorch_handler.PyTorchTensorHandler.can_handle(sample_input_data):
|
102
|
+
sample_input_data = pytorch_handler.PyTorchTensorHandler.convert_from_df(
|
103
|
+
model_signature._convert_local_data_to_df(sample_input_data)
|
104
|
+
)
|
91
105
|
|
92
106
|
model.eval()
|
93
107
|
target_method = getattr(model, target_method_name, None)
|
94
108
|
assert callable(target_method)
|
95
109
|
with torch.no_grad():
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
110
|
+
if multiple_inputs:
|
111
|
+
predictions_df = target_method(*sample_input_data)
|
112
|
+
if not isinstance(predictions_df, tuple):
|
113
|
+
predictions_df = [predictions_df]
|
114
|
+
else:
|
115
|
+
predictions_df = target_method(sample_input_data)
|
100
116
|
|
101
117
|
return predictions_df
|
102
118
|
|
@@ -117,6 +133,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
|
117
133
|
model_type=cls.HANDLER_TYPE,
|
118
134
|
handler_version=cls.HANDLER_VERSION,
|
119
135
|
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
136
|
+
options=model_meta_schema.TorchScriptModelBlobOptions(multiple_inputs=multiple_inputs),
|
120
137
|
)
|
121
138
|
model_meta.models[name] = base_meta
|
122
139
|
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
@@ -170,6 +187,10 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
|
170
187
|
signature: model_signature.ModelSignature,
|
171
188
|
target_method: str,
|
172
189
|
) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
|
190
|
+
multiple_inputs = cast(
|
191
|
+
model_meta_schema.TorchScriptModelBlobOptions, model_meta.models[model_meta.name].options
|
192
|
+
)["multiple_inputs"]
|
193
|
+
|
173
194
|
@custom_model.inference_api
|
174
195
|
def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
|
175
196
|
if X.isnull().any(axis=None):
|
@@ -179,19 +200,27 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
|
179
200
|
|
180
201
|
raw_model.eval()
|
181
202
|
|
182
|
-
|
203
|
+
if multiple_inputs:
|
204
|
+
st = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(X, signature.inputs)
|
183
205
|
|
184
|
-
|
185
|
-
|
206
|
+
if kwargs.get("use_gpu", False):
|
207
|
+
st = [element.cuda() for element in st]
|
186
208
|
|
187
|
-
|
188
|
-
|
209
|
+
with torch.no_grad():
|
210
|
+
res = getattr(raw_model, target_method)(*st)
|
189
211
|
|
190
|
-
|
191
|
-
|
212
|
+
if not isinstance(res, tuple):
|
213
|
+
res = [res]
|
214
|
+
else:
|
215
|
+
t = pytorch_handler.PyTorchTensorHandler.convert_from_df(X, signature.inputs)
|
216
|
+
if kwargs.get("use_gpu", False):
|
217
|
+
t = t.cuda()
|
192
218
|
|
219
|
+
with torch.no_grad():
|
220
|
+
res = getattr(raw_model, target_method)(t)
|
193
221
|
return model_signature_utils.rename_pandas_df(
|
194
|
-
|
222
|
+
model_signature._convert_local_data_to_df(res, ensure_serializable=True),
|
223
|
+
features=signature.outputs,
|
195
224
|
)
|
196
225
|
|
197
226
|
return fn
|
@@ -99,10 +99,10 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
99
99
|
def get_prediction(
|
100
100
|
target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
|
101
101
|
) -> model_types.SupportedLocalDataType:
|
102
|
-
if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray)):
|
102
|
+
if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray, xgboost.DMatrix)):
|
103
103
|
sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
|
104
104
|
|
105
|
-
if isinstance(model, xgboost.Booster):
|
105
|
+
if isinstance(model, xgboost.Booster) and not isinstance(sample_input_data, xgboost.DMatrix):
|
106
106
|
sample_input_data = xgboost.DMatrix(sample_input_data)
|
107
107
|
|
108
108
|
target_method = getattr(model, target_method_name, None)
|
@@ -0,0 +1,20 @@
|
|
1
|
+
from typing import cast
|
2
|
+
|
3
|
+
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
4
|
+
from snowflake.ml.model._packager.model_meta import (
|
5
|
+
model_meta as model_meta_api,
|
6
|
+
model_meta_schema,
|
7
|
+
)
|
8
|
+
|
9
|
+
|
10
|
+
class PyTorchHandlerMigrator20231201(base_migrator.BaseModelHandlerMigrator):
|
11
|
+
source_version = "2023-12-01"
|
12
|
+
target_version = "2025-03-01"
|
13
|
+
|
14
|
+
@staticmethod
|
15
|
+
def upgrade(name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str) -> None:
|
16
|
+
|
17
|
+
model_blob_metadata = model_meta.models[name]
|
18
|
+
model_blob_options = cast(model_meta_schema.PyTorchModelBlobOptions, model_blob_metadata.options)
|
19
|
+
model_blob_options["multiple_inputs"] = True
|
20
|
+
model_meta.models[name].options = model_blob_options
|
@@ -0,0 +1,19 @@
|
|
1
|
+
from typing import cast
|
2
|
+
|
3
|
+
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
4
|
+
from snowflake.ml.model._packager.model_meta import (
|
5
|
+
model_meta as model_meta_api,
|
6
|
+
model_meta_schema,
|
7
|
+
)
|
8
|
+
|
9
|
+
|
10
|
+
class TensorflowHandlerMigrator20250101(base_migrator.BaseModelHandlerMigrator):
|
11
|
+
source_version = "2025-01-01"
|
12
|
+
target_version = "2025-03-01"
|
13
|
+
|
14
|
+
@staticmethod
|
15
|
+
def upgrade(name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str) -> None:
|
16
|
+
model_blob_metadata = model_meta.models[name]
|
17
|
+
model_blob_options = cast(model_meta_schema.TensorflowModelBlobOptions, model_blob_metadata.options)
|
18
|
+
model_blob_options["multiple_inputs"] = True
|
19
|
+
model_meta.models[name].options = model_blob_options
|
@@ -0,0 +1,20 @@
|
|
1
|
+
from typing import cast
|
2
|
+
|
3
|
+
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
4
|
+
from snowflake.ml.model._packager.model_meta import (
|
5
|
+
model_meta as model_meta_api,
|
6
|
+
model_meta_schema,
|
7
|
+
)
|
8
|
+
|
9
|
+
|
10
|
+
class TorchScriptHandlerMigrator20231201(base_migrator.BaseModelHandlerMigrator):
|
11
|
+
source_version = "2023-12-01"
|
12
|
+
target_version = "2025-03-01"
|
13
|
+
|
14
|
+
@staticmethod
|
15
|
+
def upgrade(name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str) -> None:
|
16
|
+
|
17
|
+
model_blob_metadata = model_meta.models[name]
|
18
|
+
model_blob_options = cast(model_meta_schema.PyTorchModelBlobOptions, model_blob_metadata.options)
|
19
|
+
model_blob_options["multiple_inputs"] = True
|
20
|
+
model_meta.models[name].options = model_blob_options
|
@@ -1,2 +1 @@
|
|
1
|
-
REQUIREMENTS = ['cloudpickle>=2.0.0']
|
2
|
-
ALL_REQUIREMENTS=['cloudpickle>=2.0.0']
|
1
|
+
REQUIREMENTS = ['cloudpickle>=2.0.0,<3']
|