snowflake-ml-python 1.7.5__py3-none-any.whl → 1.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +58 -3
- snowflake/ml/_internal/file_utils.py +18 -4
- snowflake/ml/_internal/platform_capabilities.py +3 -0
- snowflake/ml/_internal/telemetry.py +4 -0
- snowflake/ml/fileset/fileset.py +0 -1
- snowflake/ml/jobs/_utils/constants.py +24 -0
- snowflake/ml/jobs/_utils/payload_utils.py +94 -20
- snowflake/ml/jobs/_utils/spec_utils.py +73 -31
- snowflake/ml/jobs/decorators.py +3 -0
- snowflake/ml/jobs/manager.py +5 -0
- snowflake/ml/model/_client/model/model_version_impl.py +1 -1
- snowflake/ml/model/_client/ops/model_ops.py +107 -14
- snowflake/ml/model/_client/ops/service_ops.py +1 -1
- snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
- snowflake/ml/model/_client/sql/model_version.py +58 -0
- snowflake/ml/model/_client/sql/service.py +8 -2
- snowflake/ml/model/_model_composer/model_composer.py +50 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
- snowflake/ml/model/_packager/model_env/model_env.py +4 -1
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +28 -24
- snowflake/ml/model/_packager/model_handlers/keras.py +1 -5
- snowflake/ml/model/_packager/model_handlers/pytorch.py +50 -20
- snowflake/ml/model/_packager/model_handlers/sklearn.py +1 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -2
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +46 -26
- snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
- snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
- snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
- snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +5 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -0
- snowflake/ml/model/_packager/model_packager.py +3 -5
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -0
- snowflake/ml/model/_signatures/builtins_handler.py +20 -9
- snowflake/ml/model/_signatures/core.py +52 -31
- snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
- snowflake/ml/model/_signatures/numpy_handler.py +9 -17
- snowflake/ml/model/_signatures/pandas_handler.py +19 -30
- snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
- snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
- snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
- snowflake/ml/model/_signatures/utils.py +120 -8
- snowflake/ml/model/custom_model.py +13 -4
- snowflake/ml/model/model_signature.py +31 -13
- snowflake/ml/model/type_hints.py +13 -2
- snowflake/ml/modeling/metrics/ranking.py +3 -0
- snowflake/ml/modeling/metrics/regression.py +3 -0
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
- snowflake/ml/registry/_manager/model_manager.py +55 -7
- snowflake/ml/registry/registry.py +18 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/METADATA +287 -11
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/RECORD +61 -57
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info/licenses}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_signatures/tensorflow_handler.py
CHANGED

```diff
@@ -1,5 +1,5 @@
 from collections import abc
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, Literal, Optional, Sequence, Union
 
 import numpy as np
 import pandas as pd
@@ -11,12 +11,62 @@ from snowflake.ml._internal.exceptions import (
     exceptions as snowml_exceptions,
 )
 from snowflake.ml.model import type_hints as model_types
-from snowflake.ml.model._signatures import base_handler, core
+from snowflake.ml.model._signatures import base_handler, core, numpy_handler
 
 if TYPE_CHECKING:
     import tensorflow
 
 
+class TensorflowTensorHandler(base_handler.BaseDataHandler[Union["tensorflow.Tensor", "tensorflow.Variable"]]):
+    @staticmethod
+    def can_handle(
+        data: model_types.SupportedDataType,
+    ) -> TypeGuard[Union["tensorflow.Tensor", "tensorflow.Variable"]]:
+        return type_utils.LazyType("tensorflow.Tensor").isinstance(data) or type_utils.LazyType(
+            "tensorflow.Variable"
+        ).isinstance(data)
+
+    @staticmethod
+    def count(data: Union["tensorflow.Tensor", "tensorflow.Variable"]) -> int:
+        return numpy_handler.NumpyArrayHandler.count(data.numpy())
+
+    @staticmethod
+    def truncate(
+        data: Union["tensorflow.Tensor", "tensorflow.Variable"], length: int
+    ) -> Union["tensorflow.Tensor", "tensorflow.Variable"]:
+        return data[: min(TensorflowTensorHandler.count(data), length)]
+
+    @staticmethod
+    def validate(data: Union["tensorflow.Tensor", "tensorflow.Variable"]) -> None:
+        numpy_handler.NumpyArrayHandler.validate(data.numpy())
+
+    @staticmethod
+    def infer_signature(
+        data: Union["tensorflow.Tensor", "tensorflow.Variable"], role: Literal["input", "output"]
+    ) -> Sequence[core.BaseFeatureSpec]:
+        return numpy_handler.NumpyArrayHandler.infer_signature(data.numpy(), role=role)
+
+    @staticmethod
+    def convert_to_df(
+        data: Union["tensorflow.Tensor", "tensorflow.Variable"], ensure_serializable: bool = True
+    ) -> pd.DataFrame:
+        return numpy_handler.NumpyArrayHandler.convert_to_df(data.numpy(), ensure_serializable=ensure_serializable)
+
+    @staticmethod
+    def convert_from_df(
+        df: pd.DataFrame, features: Optional[Sequence[core.BaseFeatureSpec]] = None
+    ) -> Union["tensorflow.Tensor", "tensorflow.Variable"]:
+        import tensorflow as tf
+
+        if features is None:
+            if any(dtype == np.dtype("O") for dtype in df.dtypes):
+                return tf.convert_to_tensor(np.array(df.to_numpy().tolist()))
+            return tf.convert_to_tensor(df.to_numpy())
+
+        assert isinstance(features[0], core.FeatureSpec)
+        return tf.convert_to_tensor(np.array(df.to_numpy().tolist()), dtype=features[0]._dtype._numpy_type)
+
+
 class SeqOfTensorflowTensorHandler(
     base_handler.BaseDataHandler[Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]]
 ):
@@ -28,35 +78,12 @@ class SeqOfTensorflowTensorHandler(
             return False
         if len(data) == 0:
             return False
-        if type_utils.LazyType("tensorflow.Tensor").isinstance(data[0]) or type_utils.LazyType(
-            "tensorflow.Variable"
-        ).isinstance(data[0]):
-            return all(
-                type_utils.LazyType("tensorflow.Tensor").isinstance(data_col)
-                or type_utils.LazyType("tensorflow.Variable").isinstance(data_col)
-                for data_col in data
-            )
-        return False
+
+        return all(TensorflowTensorHandler.can_handle(data_col) for data_col in data)
 
     @staticmethod
     def count(data: Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]) -> int:
-        import tensorflow as tf
-
-        rows = []
-        for data_col in data:
-            shapes = data_col.shape.as_list()
-            if data_col.shape == tf.TensorShape(None) or (not shapes) or (shapes[0] is None):
-                # Unknown shape array
-                raise snowml_exceptions.SnowflakeMLException(
-                    error_code=error_codes.INVALID_DATA,
-                    original_exception=ValueError("Data Validation Error: Unknown shape data is found."),
-                )
-            # Make mypy happy
-            assert isinstance(shapes[0], int)
-
-            rows.append(shapes[0])
-
-        return min(rows)
+        return min(TensorflowTensorHandler.count(data_col) for data_col in data)
 
     @staticmethod
     def truncate(
@@ -66,49 +93,14 @@ class SeqOfTensorflowTensorHandler(
 
     @staticmethod
     def validate(data: Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]]) -> None:
-        import tensorflow as tf
-
         for data_col in data:
-            if data_col.shape == tf.TensorShape(None):
-                # Unknown shape array
-                raise snowml_exceptions.SnowflakeMLException(
-                    error_code=error_codes.INVALID_DATA,
-                    original_exception=ValueError("Data Validation Error: Unknown shape data is found."),
-                )
-
-            if data_col.shape == tf.TensorShape([0]):
-                # Empty array
-                raise snowml_exceptions.SnowflakeMLException(
-                    error_code=error_codes.INVALID_DATA,
-                    original_exception=ValueError("Data Validation Error: Empty data is found."),
-                )
-
-            if data_col.shape == tf.TensorShape([1]) or data_col.shape == tf.TensorShape([]):
-                # scalar
-                raise snowml_exceptions.SnowflakeMLException(
-                    error_code=error_codes.INVALID_DATA,
-                    original_exception=ValueError("Data Validation Error: Scalar data is found."),
-                )
+            TensorflowTensorHandler.validate(data_col)
 
     @staticmethod
     def infer_signature(
         data: Sequence[Union["tensorflow.Tensor", "tensorflow.Variable"]], role: Literal["input", "output"]
     ) -> Sequence[core.BaseFeatureSpec]:
-        feature_prefix = f"{SeqOfTensorflowTensorHandler.FEATURE_PREFIX}_"
-        features: List[core.BaseFeatureSpec] = []
-        role_prefix = (
-            SeqOfTensorflowTensorHandler.INPUT_PREFIX if role == "input" else SeqOfTensorflowTensorHandler.OUTPUT_PREFIX
-        ) + "_"
-
-        for i, data_col in enumerate(data):
-            dtype = core.DataType.from_numpy_type(data_col.dtype.as_numpy_dtype)
-            ft_name = f"{role_prefix}{feature_prefix}{i}"
-            if len(data_col.shape) == 1:
-                features.append(core.FeatureSpec(dtype=dtype, name=ft_name, nullable=False))
-            else:
-                ft_shape = tuple(data_col.shape[1:])
-                features.append(core.FeatureSpec(dtype=dtype, name=ft_name, shape=ft_shape, nullable=False))
-        return features
+        return numpy_handler.SeqOfNumpyArrayHandler.infer_signature([data_col.numpy() for data_col in data], role=role)
 
     @staticmethod
     def convert_to_df(
@@ -129,8 +121,10 @@ class SeqOfTensorflowTensorHandler(
         for feature in features:
             if isinstance(feature, core.FeatureGroupSpec):
                 raise snowml_exceptions.SnowflakeMLException(
-                    error_code=error_codes.NOT_IMPLEMENTED,
-                    original_exception=NotImplementedError("FeatureGroupSpec is not supported when converting to Tensorflow tensor."),
+                    error_code=error_codes.INVALID_DATA_TYPE,
+                    original_exception=NotImplementedError(
+                        "FeatureGroupSpec is not supported when converting to Tensorflow tensor."
+                    ),
                 )
             assert isinstance(feature, core.FeatureSpec), "Invalid feature kind."
             res.append(
```
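The new `TensorflowTensorHandler` accepts a single tensor where the existing `SeqOfTensorflowTensorHandler` required a sequence, delegating counting, validation, and signature inference to `NumpyArrayHandler` through `.numpy()`. A minimal sketch of what that enables, using only names visible in the diff above; the call pattern is illustrative, not a documented public API:

```python
import tensorflow as tf

from snowflake.ml.model._signatures import tensorflow_handler

t = tf.constant([[1.0, 2.0], [3.0, 4.0]])

# A bare tensor is now recognized directly...
assert tensorflow_handler.TensorflowTensorHandler.can_handle(t)

# ...and its signature is inferred by round-tripping through numpy.
specs = tensorflow_handler.TensorflowTensorHandler.infer_signature(t, role="input")
print(specs)  # FeatureSpec(s) derived from NumpyArrayHandler on t.numpy()
```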
snowflake/ml/model/_signatures/utils.py
CHANGED

```diff
@@ -135,7 +135,16 @@ def huggingface_pipeline_signature_auto_infer(task: str, params: Dict[str, Any])
             core.FeatureSpec(name="inputs", dtype=core.DataType.STRING),
         ],
         outputs=[
-            core.
+            core.FeatureGroupSpec(
+                name="outputs",
+                specs=[
+                    core.FeatureSpec(name="sequence", dtype=core.DataType.STRING),
+                    core.FeatureSpec(name="score", dtype=core.DataType.DOUBLE),
+                    core.FeatureSpec(name="token", dtype=core.DataType.INT64),
+                    core.FeatureSpec(name="token_str", dtype=core.DataType.STRING),
+                ],
+                shape=(-1,),
+            ),
         ],
     )
 
@@ -144,7 +153,18 @@ def huggingface_pipeline_signature_auto_infer(task: str, params: Dict[str, Any])
     return core.ModelSignature(
         inputs=[core.FeatureSpec(name="inputs", dtype=core.DataType.STRING)],
         outputs=[
-            core.
+            core.FeatureGroupSpec(
+                name="outputs",
+                specs=[
+                    core.FeatureSpec(name="word", dtype=core.DataType.STRING),
+                    core.FeatureSpec(name="score", dtype=core.DataType.DOUBLE),
+                    core.FeatureSpec(name="entity", dtype=core.DataType.STRING),
+                    core.FeatureSpec(name="index", dtype=core.DataType.INT64),
+                    core.FeatureSpec(name="start", dtype=core.DataType.INT64),
+                    core.FeatureSpec(name="end", dtype=core.DataType.INT64),
+                ],
+                shape=(-1,),
+            ),
         ],
     )
 
@@ -171,7 +191,16 @@ def huggingface_pipeline_signature_auto_infer(task: str, params: Dict[str, Any])
             core.FeatureSpec(name="context", dtype=core.DataType.STRING),
         ],
         outputs=[
-            core.
+            core.FeatureGroupSpec(
+                name="answers",
+                specs=[
+                    core.FeatureSpec(name="score", dtype=core.DataType.DOUBLE),
+                    core.FeatureSpec(name="start", dtype=core.DataType.INT64),
+                    core.FeatureSpec(name="end", dtype=core.DataType.INT64),
+                    core.FeatureSpec(name="answer", dtype=core.DataType.STRING),
+                ],
+                shape=(-1,),
+            ),
         ],
     )
 
@@ -216,17 +245,22 @@ def huggingface_pipeline_signature_auto_infer(task: str, params: Dict[str, Any])
     return core.ModelSignature(
         inputs=[
             core.FeatureSpec(name="text", dtype=core.DataType.STRING),
-            core.FeatureSpec(name="text_pair", dtype=core.DataType.STRING),
         ],
         outputs=[
-            core.
+            core.FeatureGroupSpec(
+                name="labels",
+                specs=[
+                    core.FeatureSpec(name="label", dtype=core.DataType.STRING),
+                    core.FeatureSpec(name="score", dtype=core.DataType.DOUBLE),
+                ],
+                shape=(-1,),
+            ),
         ],
     )
     # Else, return a dict per input
     return core.ModelSignature(
         inputs=[
             core.FeatureSpec(name="text", dtype=core.DataType.STRING),
-            core.FeatureSpec(name="text_pair", dtype=core.DataType.STRING),
         ],
         outputs=[
             core.FeatureSpec(name="label", dtype=core.DataType.STRING),
@@ -243,9 +277,24 @@ def huggingface_pipeline_signature_auto_infer(task: str, params: Dict[str, Any])
     )
     # Always generate a list of dict per input
     return core.ModelSignature(
-        inputs=[
+        inputs=[
+            core.FeatureGroupSpec(
+                name="inputs",
+                specs=[
+                    core.FeatureSpec(name="role", dtype=core.DataType.STRING),
+                    core.FeatureSpec(name="content", dtype=core.DataType.STRING),
+                ],
+                shape=(-1,),
+            ),
+        ],
         outputs=[
-            core.
+            core.FeatureGroupSpec(
+                name="outputs",
+                specs=[
+                    core.FeatureSpec(name="generated_text", dtype=core.DataType.STRING),
+                ],
+                shape=(-1,),
+            )
         ],
     )
 
@@ -300,3 +349,66 @@ def huggingface_pipeline_signature_auto_infer(task: str, params: Dict[str, Any])
 
 def series_dropna(series: pd.Series) -> pd.Series:
     return series.dropna(inplace=False).reset_index(drop=True).convert_dtypes()
+
+
+def infer_list(name: str, data: List[Any]) -> core.BaseFeatureSpec:
+    """Infer the feature specification from a list.
+
+    Args:
+        name: Feature name.
+        data: A list.
+
+    Raises:
+        SnowflakeMLException: ValueError: Raised when empty list is provided.
+
+    Returns:
+        A feature specification.
+    """
+    if not data:
+        raise snowml_exceptions.SnowflakeMLException(
+            error_code=error_codes.INVALID_DATA,
+            original_exception=ValueError("Data Validation Error: Empty list is found."),
+        )
+
+    if all(isinstance(value, dict) for value in data):
+        ft = infer_dict(name, data[0])
+        ft._name = name
+        ft._shape = (-1,)
+        return ft
+
+    arr = convert_list_to_ndarray(data)
+    arr_dtype = core.DataType.from_numpy_type(arr.dtype)
+
+    return core.FeatureSpec(name=name, dtype=arr_dtype, shape=arr.shape)
+
+
+def infer_dict(name: str, data: Dict[str, Any]) -> core.FeatureGroupSpec:
+    """Infer the feature specification from a dictionary.
+
+    Args:
+        name: Feature name.
+        data: A dictionary.
+
+    Raises:
+        SnowflakeMLException: ValueError: Raised when empty dictionary is provided.
+        SnowflakeMLException: ValueError: Raised when empty list is found in the dictionary.
+
+    Returns:
+        A feature group specification.
+    """
+    if not data:
+        raise snowml_exceptions.SnowflakeMLException(
+            error_code=error_codes.INVALID_DATA,
+            original_exception=ValueError("Data Validation Error: Empty dictionary is found."),
+        )
+
+    specs = []
+    for key, value in data.items():
+        if isinstance(value, list):
+            specs.append(infer_list(key, value))
+        elif isinstance(value, dict):
+            specs.append(infer_dict(key, value))
+        else:
+            specs.append(core.FeatureSpec(name=key, dtype=core.DataType.from_numpy_type(np.array(value).dtype)))
+
+    return core.FeatureGroupSpec(name=name, specs=specs)
```
snowflake/ml/model/custom_model.py
CHANGED

```diff
@@ -76,7 +76,7 @@ class ModelRef:
     def __getattr__(self, method_name: str) -> Any:
         if hasattr(self._model, method_name):
             return MethodRef(self, method_name)
-        raise
+        raise AttributeError(f"Method {method_name} not found in model {self._name}.")
 
     def __getstate__(self) -> Dict[str, Any]:
         state = self.__dict__.copy()
@@ -94,7 +94,16 @@ class ModelRef:
 
 class ModelContext:
     """
-    Context for a custom model
+    Context for a custom model storing paths to file artifacts and model object references.
+
+    Keyword argument values can be string file paths or supported in-memory models. Paths and model references
+    can be accessed with dictionary access methods in the custom model.
+
+    For example, in a custom model with `context=ModelContext(my_file='my_file.pkl', my_model=my_model)`,
+    the filepath and model reference can be accessed with `self.context['my_file']` and `self.context['my_model']`
+    in the inference and init methods.
+
+    The use of `artifacts` and `model_refs` arguments is deprecated. Set keyword arguments directly instead.
 
     Attributes:
         artifacts: A dictionary mapping the name of the artifact to its path.
@@ -267,14 +276,14 @@ def _validate_predict_function(func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]) -> None:
 
 
 def inference_api(
-    func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]
+    func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame],
 ) -> Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]:
     func.__dict__["_is_inference_api"] = True
     return func
 
 
 def partitioned_inference_api(
-    func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]
+    func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame],
 ) -> Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]:
     func.__dict__["_is_inference_api"] = True
     func.__dict__["_is_partitioned_inference_api"] = True
```
snowflake/ml/model/model_signature.py
CHANGED

```diff
@@ -32,6 +32,7 @@ from snowflake.ml.model._signatures import (
     base_handler,
     builtins_handler as builtins_handler,
     core,
+    dmatrix_handler,
     numpy_handler,
     pandas_handler,
     pytorch_handler,
@@ -52,8 +53,11 @@ _LOCAL_DATA_HANDLERS: List[Type[base_handler.BaseDataHandler[Any]]] = [
     numpy_handler.NumpyArrayHandler,
     builtins_handler.ListOfBuiltinHandler,
     numpy_handler.SeqOfNumpyArrayHandler,
+    pytorch_handler.PyTorchTensorHandler,
     pytorch_handler.SeqOfPyTorchTensorHandler,
+    tensorflow_handler.TensorflowTensorHandler,
     tensorflow_handler.SeqOfTensorflowTensorHandler,
+    dmatrix_handler.XGBoostDMatrixHandler,
 ]
 _ALL_DATA_HANDLERS = _LOCAL_DATA_HANDLERS + [snowpark_handler.SnowparkDataFrameHandler]
 
@@ -218,7 +222,6 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureSpec],
         strict: Enable strict validation, this includes value range based validation
 
     Raises:
-        SnowflakeMLException: NotImplementedError: FeatureGroupSpec is not supported.
         SnowflakeMLException: ValueError: Raised when a feature cannot be found.
         SnowflakeMLException: ValueError: Raised when feature is scalar but confront list element.
         SnowflakeMLException: ValueError: Raised when feature type is not aligned in list element.
@@ -236,7 +239,10 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureSpec],
         except KeyError:
             raise snowml_exceptions.SnowflakeMLException(
                 error_code=error_codes.INVALID_DATA,
-                original_exception=ValueError(f"Data Validation Error: feature {ft_name} does not exist in data."),
+                original_exception=ValueError(
+                    f"Data Validation Error: feature {ft_name} does not exist in data. "
+                    f"Available columns are {data.columns}."
+                ),
             )
 
         if data_col.isnull().any():
@@ -244,10 +250,15 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureSpec],
         df_col_dtype = data_col.dtype
 
         if isinstance(feature, core.FeatureGroupSpec):
-            raise snowml_exceptions.SnowflakeMLException(
-                error_code=error_codes.NOT_IMPLEMENTED,
-                original_exception=NotImplementedError("FeatureGroupSpec is not supported."),
-            )
+            if df_col_dtype != np.dtype("O"):
+                raise snowml_exceptions.SnowflakeMLException(
+                    error_code=error_codes.INVALID_DATA,
+                    original_exception=ValueError(
+                        f"Data Validation Error in feature group {ft_name}: "
+                        + f"It needs to be a dictionary or list of dictionary, but get {df_col_dtype}."
+                    ),
+                )
+            continue
 
         assert isinstance(feature, core.FeatureSpec)  # assert for mypy.
         ft_type = feature._dtype
@@ -437,7 +448,6 @@ def _validate_snowpark_data(
         strict: Enable strict validation, this includes value range based validation.
 
     Raises:
-        SnowflakeMLException: NotImplementedError: FeatureGroupSpec is not supported.
         SnowflakeMLException: ValueError: Raised when confronting invalid feature.
         SnowflakeMLException: ValueError: Raised when a feature cannot be found.
 
@@ -467,10 +477,15 @@ def _validate_snowpark_data(
             if field.name == ft_name:
                 found = True
                 if isinstance(feature, core.FeatureGroupSpec):
-                    raise snowml_exceptions.SnowflakeMLException(
-                        error_code=error_codes.NOT_IMPLEMENTED,
-                        original_exception=NotImplementedError("FeatureGroupSpec is not supported."),
-                    )
+                    if not isinstance(field.datatype, (spt.ArrayType, spt.StructType, spt.VariantType)):
+                        errors[identifier_rule].append(
+                            ValueError(
+                                f"Data Validation Error in feature group {feature.name}: "
+                                + f"Feature expects {feature.as_snowpark_type()},"
+                                + f" while {field.name} has type {field.datatype}."
+                            ),
+                        )
+                    continue
                 assert isinstance(feature, core.FeatureSpec)  # mypy
                 ft_type = feature._dtype
                 field_data_type = field.datatype
@@ -644,11 +659,14 @@ def _validate_snowpark_type_feature(
     )
 
 
-def _convert_local_data_to_df(data: model_types.SupportedLocalDataType) -> pd.DataFrame:
+def _convert_local_data_to_df(
+    data: model_types.SupportedLocalDataType, ensure_serializable: bool = False
+) -> pd.DataFrame:
     """Convert local data to pandas DataFrame or Snowpark DataFrame
 
     Args:
         data: The provided data.
+        ensure_serializable: Ensure the data is serializable. Defaults to False.
 
     Raises:
         SnowflakeMLException: NotImplementedError: Raised when data cannot be handled by any data handler.
@@ -660,7 +678,7 @@ def _convert_local_data_to_df(data: model_types.SupportedLocalDataType) -> pd.DataFrame:
     for handler in _LOCAL_DATA_HANDLERS:
         if handler.can_handle(data):
             handler.validate(data)
-            df = handler.convert_to_df(data, ensure_serializable=False)
+            df = handler.convert_to_df(data, ensure_serializable=ensure_serializable)
             break
     if df is None:
         raise snowml_exceptions.SnowflakeMLException(
```
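With `PyTorchTensorHandler`, `TensorflowTensorHandler`, and `XGBoostDMatrixHandler` now registered in `_LOCAL_DATA_HANDLERS`, signature inference no longer requires wrapping a single tensor in a list. A hedged sketch against the public `infer_signature` helper; the data is illustrative, and before 1.8.0 these tensors would have needed to be passed as `[features]`/`[labels]`:

```python
import torch

from snowflake.ml.model import model_signature

features = torch.rand(100, 8)
labels = torch.rand(100, 1)

# Bare tensors are matched by the newly registered single-tensor handlers.
sig = model_signature.infer_signature(input_data=features, output_data=labels)
print(sig)
```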
snowflake/ml/model/type_hints.py
CHANGED

```diff
@@ -26,7 +26,15 @@ if TYPE_CHECKING:
     from snowflake.ml.modeling.framework import base  # noqa: F401
 
 
-_SupportedBuiltins = Union[
+_SupportedBuiltins = Union[
+    int,
+    float,
+    bool,
+    str,
+    bytes,
+    Dict[str, Union["_SupportedBuiltins", "_SupportedBuiltinsList"]],
+    "_SupportedBuiltinsList",
+]
 _SupportedNumpyDtype = Union[
     "np.int8",
     "np.int16",
@@ -48,7 +56,7 @@ _SupportedBuiltinsList = Sequence[_SupportedBuiltins]
 _SupportedArrayLike = Union[_SupportedNumpyArray, "torch.Tensor", "tensorflow.Tensor", "tensorflow.Variable"]
 
 SupportedLocalDataType = Union[
-    "pd.DataFrame",
+    "pd.DataFrame", _SupportedArrayLike, Sequence[_SupportedArrayLike], _SupportedBuiltinsList
 ]
 
 SupportedDataType = Union[SupportedLocalDataType, "snowflake.snowpark.DataFrame"]
@@ -177,16 +185,19 @@ class SNOWModelSaveOptions(BaseModelSaveOption):
 class PyTorchSaveOptions(BaseModelSaveOption):
     target_methods: NotRequired[Sequence[str]]
     cuda_version: NotRequired[str]
+    multiple_inputs: NotRequired[bool]
 
 
 class TorchScriptSaveOptions(BaseModelSaveOption):
     target_methods: NotRequired[Sequence[str]]
     cuda_version: NotRequired[str]
+    multiple_inputs: NotRequired[bool]
 
 
 class TensorflowSaveOptions(BaseModelSaveOption):
     target_methods: NotRequired[Sequence[str]]
     cuda_version: NotRequired[str]
+    multiple_inputs: NotRequired[bool]
 
 
 class MLFlowSaveOptions(BaseModelSaveOption):
```
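`multiple_inputs` becomes an accepted save option for PyTorch, TorchScript, and TensorFlow models, matching the handler changes listed in the file summary at the top of this diff. A hedged sketch of passing it through `log_model` options; the registry setup, session, and module are illustrative:

```python
from snowflake.ml.registry import Registry

reg = Registry(session=session)  # assumes an existing Snowpark session

mv = reg.log_model(
    torch_module,  # an in-memory torch.nn.Module defined elsewhere
    model_name="my_torch_model",
    version_name="v1",
    # Maps to PyTorchSaveOptions.multiple_inputs above.
    options={"multiple_inputs": True},
)
```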
snowflake/ml/modeling/metrics/ranking.py
CHANGED

```diff
@@ -98,6 +98,7 @@ def precision_recall_curve(
         packages=[
             f"cloudpickle=={cloudpickle.__version__}",
             f"scikit-learn=={sklearn_release[0]}.{sklearn_release[1]}.*",
+            f"numpy=={np.__version__}",
             "snowflake-snowpark-python",
         ],
         statement_params=statement_params,
@@ -245,6 +246,7 @@ def roc_auc_score(
         packages=[
             f"cloudpickle=={cloudpickle.__version__}",
             f"scikit-learn=={sklearn_release[0]}.{sklearn_release[1]}.*",
+            f"numpy=={np.__version__}",
             "snowflake-snowpark-python",
         ],
         statement_params=statement_params,
@@ -348,6 +350,7 @@ def roc_curve(
         packages=[
             f"cloudpickle=={cloudpickle.__version__}",
             f"scikit-learn=={sklearn_release[0]}.{sklearn_release[1]}.*",
+            f"numpy=={np.__version__}",
             "snowflake-snowpark-python",
         ],
         statement_params=statement_params,
```
snowflake/ml/modeling/metrics/regression.py
CHANGED

```diff
@@ -83,6 +83,7 @@ def d2_absolute_error_score(
         packages=[
             f"cloudpickle=={cloudpickle.__version__}",
             f"scikit-learn=={sklearn_release[0]}.{sklearn_release[1]}.*",
+            f"numpy=={np.__version__}",
             "snowflake-snowpark-python",
         ],
         statement_params=statement_params,
@@ -180,6 +181,7 @@ def d2_pinball_score(
         packages=[
             f"cloudpickle=={cloudpickle.__version__}",
             f"scikit-learn=={sklearn_release[0]}.{sklearn_release[1]}.*",
+            f"numpy=={np.__version__}",
             "snowflake-snowpark-python",
         ],
         statement_params=statement_params,
@@ -295,6 +297,7 @@ def explained_variance_score(
         packages=[
             f"cloudpickle=={cloudpickle.__version__}",
             f"scikit-learn=={sklearn_release[0]}.{sklearn_release[1]}.*",
+            f"numpy=={np.__version__}",
             "snowflake-snowpark-python",
         ],
         statement_params=statement_params,
```
snowflake/ml/modeling/preprocessing/k_bins_discretizer.py
CHANGED

```diff
@@ -341,7 +341,7 @@ class KBinsDiscretizer(base.BaseTransformer):
             is_permanent=False,
             name=udf_name,
             replace=True,
-            packages=["numpy"],
+            packages=[f"numpy=={np.__version__}"],
            session=dataset._session,
            statement_params=telemetry.get_statement_params(base.PROJECT, base.SUBPROJECT, self.__class__.__name__),
        )
```