snowflake-ml-python 1.7.4__py3-none-any.whl → 1.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +58 -3
- snowflake/ml/_internal/env_utils.py +64 -21
- snowflake/ml/_internal/file_utils.py +18 -4
- snowflake/ml/_internal/platform_capabilities.py +3 -0
- snowflake/ml/_internal/relax_version_strategy.py +16 -0
- snowflake/ml/_internal/telemetry.py +25 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -1
- snowflake/ml/feature_store/feature_store.py +18 -0
- snowflake/ml/feature_store/feature_view.py +46 -1
- snowflake/ml/fileset/fileset.py +0 -1
- snowflake/ml/jobs/_utils/constants.py +31 -1
- snowflake/ml/jobs/_utils/payload_utils.py +232 -72
- snowflake/ml/jobs/_utils/spec_utils.py +78 -38
- snowflake/ml/jobs/decorators.py +8 -25
- snowflake/ml/jobs/job.py +4 -4
- snowflake/ml/jobs/manager.py +5 -0
- snowflake/ml/model/_client/model/model_version_impl.py +1 -1
- snowflake/ml/model/_client/ops/model_ops.py +107 -14
- snowflake/ml/model/_client/ops/service_ops.py +1 -1
- snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
- snowflake/ml/model/_client/sql/model_version.py +58 -0
- snowflake/ml/model/_client/sql/service.py +8 -2
- snowflake/ml/model/_model_composer/model_composer.py +50 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
- snowflake/ml/model/_packager/model_env/model_env.py +49 -29
- snowflake/ml/model/_packager/model_handlers/_utils.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +44 -24
- snowflake/ml/model/_packager/model_handlers/keras.py +226 -0
- snowflake/ml/model/_packager/model_handlers/pytorch.py +51 -20
- snowflake/ml/model/_packager/model_handlers/sklearn.py +25 -3
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +73 -21
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -72
- snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
- snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
- snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2023_12_01.py +48 -0
- snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
- snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +6 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +16 -0
- snowflake/ml/model/_packager/model_packager.py +3 -5
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +8 -1
- snowflake/ml/model/_packager/model_task/model_task_utils.py +5 -1
- snowflake/ml/model/_signatures/builtins_handler.py +20 -9
- snowflake/ml/model/_signatures/core.py +54 -33
- snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
- snowflake/ml/model/_signatures/numpy_handler.py +12 -20
- snowflake/ml/model/_signatures/pandas_handler.py +28 -37
- snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
- snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
- snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
- snowflake/ml/model/_signatures/utils.py +120 -8
- snowflake/ml/model/custom_model.py +13 -4
- snowflake/ml/model/model_signature.py +39 -13
- snowflake/ml/model/type_hints.py +28 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +14 -1
- snowflake/ml/modeling/metrics/ranking.py +3 -0
- snowflake/ml/modeling/metrics/regression.py +3 -0
- snowflake/ml/modeling/pipeline/pipeline.py +18 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +2 -2
- snowflake/ml/registry/_manager/model_manager.py +55 -7
- snowflake/ml/registry/registry.py +52 -4
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/METADATA +336 -27
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/RECORD +73 -66
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info/licenses}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/top_level.txt +0 -0
@@ -3,7 +3,7 @@ import itertools
|
|
3
3
|
import os
|
4
4
|
import pathlib
|
5
5
|
import warnings
|
6
|
-
from typing import DefaultDict, List, Optional
|
6
|
+
from typing import DefaultDict, Dict, List, Optional
|
7
7
|
|
8
8
|
from packaging import requirements, version
|
9
9
|
|
@@ -36,6 +36,7 @@ class ModelEnv:
|
|
36
36
|
pip_requirements_rel_path = os.path.join(_DEFAULT_ENV_DIR, _DEFAULT_PIP_REQUIREMENTS_FILENAME)
|
37
37
|
self.conda_env_rel_path = pathlib.PurePosixPath(pathlib.Path(conda_env_rel_path).as_posix())
|
38
38
|
self.pip_requirements_rel_path = pathlib.PurePosixPath(pathlib.Path(pip_requirements_rel_path).as_posix())
|
39
|
+
self.artifact_repository_map: Optional[Dict[str, str]] = None
|
39
40
|
self._conda_dependencies: DefaultDict[str, List[requirements.Requirement]] = collections.defaultdict(list)
|
40
41
|
self._pip_requirements: List[requirements.Requirement] = []
|
41
42
|
self._python_version: version.Version = version.parse(snowml_env.PYTHON_VERSION)
|
@@ -113,7 +114,33 @@ class ModelEnv:
|
|
113
114
|
self._snowpark_ml_version = version.parse(snowpark_ml_version)
|
114
115
|
|
115
116
|
def include_if_absent(self, pkgs: List[ModelDependency], check_local_version: bool = False) -> None:
|
116
|
-
"""Append requirements into model env if absent.
|
117
|
+
"""Append requirements into model env if absent. Depending on the environment, requirements may be added
|
118
|
+
to either the pip requirements or conda dependencies.
|
119
|
+
|
120
|
+
Args:
|
121
|
+
pkgs: A list of ModelDependency namedtuple to be appended.
|
122
|
+
check_local_version: Flag to indicate if it is required to pin to local version. Defaults to False.
|
123
|
+
"""
|
124
|
+
if self.pip_requirements and not self.conda_dependencies and pkgs:
|
125
|
+
pip_pkg_reqs: List[str] = []
|
126
|
+
warnings.warn(
|
127
|
+
(
|
128
|
+
"Dependencies specified from pip requirements."
|
129
|
+
" This may prevent model deploying to Snowflake Warehouse."
|
130
|
+
),
|
131
|
+
category=UserWarning,
|
132
|
+
stacklevel=2,
|
133
|
+
)
|
134
|
+
for conda_req_str, pip_name in pkgs:
|
135
|
+
_, conda_req = env_utils._validate_conda_dependency_string(conda_req_str)
|
136
|
+
pip_req = requirements.Requirement(f"{pip_name}{conda_req.specifier}")
|
137
|
+
pip_pkg_reqs.append(str(pip_req))
|
138
|
+
self._include_if_absent_pip(pip_pkg_reqs, check_local_version)
|
139
|
+
else:
|
140
|
+
self._include_if_absent_conda(pkgs, check_local_version)
|
141
|
+
|
142
|
+
def _include_if_absent_conda(self, pkgs: List[ModelDependency], check_local_version: bool = False) -> None:
|
143
|
+
"""Append requirements into model env conda dependencies if absent.
|
117
144
|
|
118
145
|
Args:
|
119
146
|
pkgs: A list of ModelDependency namedtuple to be appended.
|
@@ -134,8 +161,8 @@ class ModelEnv:
|
|
134
161
|
if show_warning_message:
|
135
162
|
warnings.warn(
|
136
163
|
(
|
137
|
-
f"Basic dependency {req_to_add.name} specified from
|
138
|
-
|
164
|
+
f"Basic dependency {req_to_add.name} specified from pip requirements."
|
165
|
+
" This may prevent model deploying to Snowflake Warehouse."
|
139
166
|
),
|
140
167
|
category=UserWarning,
|
141
168
|
stacklevel=2,
|
@@ -157,11 +184,11 @@ class ModelEnv:
|
|
157
184
|
stacklevel=2,
|
158
185
|
)
|
159
186
|
|
160
|
-
def
|
161
|
-
"""Append pip requirements into model env if absent.
|
187
|
+
def _include_if_absent_pip(self, pkgs: List[str], check_local_version: bool = False) -> None:
|
188
|
+
"""Append pip requirements into model env pip requirements if absent.
|
162
189
|
|
163
190
|
Args:
|
164
|
-
pkgs: A list of
|
191
|
+
pkgs: A list of strings to be appended to pip environment.
|
165
192
|
check_local_version: Flag to indicate if it is required to pin to local version. Defaults to False.
|
166
193
|
"""
|
167
194
|
|
@@ -187,25 +214,6 @@ class ModelEnv:
|
|
187
214
|
self._conda_dependencies[channel].remove(spec)
|
188
215
|
|
189
216
|
def generate_env_for_cuda(self) -> None:
|
190
|
-
if self.cuda_version is None:
|
191
|
-
return
|
192
|
-
|
193
|
-
cuda_spec = env_utils.find_dep_spec(
|
194
|
-
self._conda_dependencies, self._pip_requirements, conda_pkg_name="cuda", remove_spec=False
|
195
|
-
)
|
196
|
-
if cuda_spec and not cuda_spec.specifier.contains(self.cuda_version):
|
197
|
-
raise ValueError(
|
198
|
-
"The CUDA requirement you specified in your conda dependencies or pip requirements is"
|
199
|
-
" conflicting with CUDA version required. Please do not specify CUDA dependency using conda"
|
200
|
-
" dependencies or pip requirements."
|
201
|
-
)
|
202
|
-
|
203
|
-
if not cuda_spec:
|
204
|
-
self.include_if_absent(
|
205
|
-
[ModelDependency(requirement=f"nvidia::cuda=={self.cuda_version}.*", pip_name="cuda")],
|
206
|
-
check_local_version=False,
|
207
|
-
)
|
208
|
-
|
209
217
|
xgboost_spec = env_utils.find_dep_spec(
|
210
218
|
self._conda_dependencies, self._pip_requirements, conda_pkg_name="xgboost", remove_spec=True
|
211
219
|
)
|
@@ -236,7 +244,7 @@ class ModelEnv:
|
|
236
244
|
check_local_version=False,
|
237
245
|
)
|
238
246
|
|
239
|
-
self.
|
247
|
+
self._include_if_absent_pip(["bitsandbytes>=0.41.0"], check_local_version=False)
|
240
248
|
|
241
249
|
def relax_version(self) -> None:
|
242
250
|
"""Relax the version requirements for both conda dependencies and pip requirements.
|
@@ -252,7 +260,9 @@ class ModelEnv:
|
|
252
260
|
self._pip_requirements = list(map(env_utils.relax_requirement_version, self._pip_requirements))
|
253
261
|
|
254
262
|
def load_from_conda_file(self, conda_env_path: pathlib.Path) -> None:
|
255
|
-
conda_dependencies_dict, pip_requirements_list, python_version = env_utils.load_conda_env_file(
|
263
|
+
conda_dependencies_dict, pip_requirements_list, python_version, cuda_version = env_utils.load_conda_env_file(
|
264
|
+
conda_env_path
|
265
|
+
)
|
256
266
|
|
257
267
|
for channel, channel_dependencies in conda_dependencies_dict.items():
|
258
268
|
if channel != env_utils.DEFAULT_CHANNEL_NAME:
|
@@ -310,6 +320,9 @@ class ModelEnv:
|
|
310
320
|
if python_version:
|
311
321
|
self.python_version = python_version
|
312
322
|
|
323
|
+
if cuda_version:
|
324
|
+
self.cuda_version = cuda_version
|
325
|
+
|
313
326
|
def load_from_pip_file(self, pip_requirements_path: pathlib.Path) -> None:
|
314
327
|
pip_requirements_list = env_utils.load_requirements_file(pip_requirements_path)
|
315
328
|
|
@@ -333,6 +346,7 @@ class ModelEnv:
|
|
333
346
|
def load_from_dict(self, base_dir: pathlib.Path, env_dict: model_meta_schema.ModelEnvDict) -> None:
|
334
347
|
self.conda_env_rel_path = pathlib.PurePosixPath(env_dict["conda"])
|
335
348
|
self.pip_requirements_rel_path = pathlib.PurePosixPath(env_dict["pip"])
|
349
|
+
self.artifact_repository_map = env_dict.get("artifact_repository_map", None)
|
336
350
|
|
337
351
|
self.load_from_conda_file(base_dir / self.conda_env_rel_path)
|
338
352
|
self.load_from_pip_file(base_dir / self.pip_requirements_rel_path)
|
@@ -342,12 +356,17 @@ class ModelEnv:
|
|
342
356
|
self.snowpark_ml_version = env_dict["snowpark_ml_version"]
|
343
357
|
|
344
358
|
def save_as_dict(
|
345
|
-
self,
|
359
|
+
self,
|
360
|
+
base_dir: pathlib.Path,
|
361
|
+
default_channel_override: str = env_utils.SNOWFLAKE_CONDA_CHANNEL_URL,
|
362
|
+
is_gpu: Optional[bool] = False,
|
346
363
|
) -> model_meta_schema.ModelEnvDict:
|
364
|
+
cuda_version = self.cuda_version if is_gpu else None
|
347
365
|
env_utils.save_conda_env_file(
|
348
366
|
pathlib.Path(base_dir / self.conda_env_rel_path),
|
349
367
|
self._conda_dependencies,
|
350
368
|
self.python_version,
|
369
|
+
cuda_version,
|
351
370
|
default_channel_override=default_channel_override,
|
352
371
|
)
|
353
372
|
env_utils.save_requirements_file(
|
@@ -356,6 +375,7 @@ class ModelEnv:
|
|
356
375
|
return {
|
357
376
|
"conda": self.conda_env_rel_path.as_posix(),
|
358
377
|
"pip": self.pip_requirements_rel_path.as_posix(),
|
378
|
+
"artifact_repository_map": self.artifact_repository_map if self.artifact_repository_map is not None else {},
|
359
379
|
"python_version": self.python_version,
|
360
380
|
"cuda_version": self.cuda_version,
|
361
381
|
"snowpark_ml_version": self.snowpark_ml_version,
|
@@ -39,7 +39,7 @@ def _is_callable(model: model_types.SupportedModelType, method_name: str) -> boo
|
|
39
39
|
|
40
40
|
|
41
41
|
def get_truncated_sample_data(
|
42
|
-
sample_input_data: model_types.SupportedDataType, length: int = 100
|
42
|
+
sample_input_data: model_types.SupportedDataType, length: int = 100, is_for_modeling_model: bool = False
|
43
43
|
) -> model_types.SupportedLocalDataType:
|
44
44
|
trunc_sample_input = model_signature._truncate_data(sample_input_data, length=length)
|
45
45
|
local_sample_input: model_types.SupportedLocalDataType = None
|
@@ -47,6 +47,8 @@ def get_truncated_sample_data(
|
|
47
47
|
# Added because of Any from missing stubs.
|
48
48
|
trunc_sample_input = cast(SnowparkDataFrame, trunc_sample_input)
|
49
49
|
local_sample_input = snowpark_handler.SnowparkDataFrameHandler.convert_to_df(trunc_sample_input)
|
50
|
+
if is_for_modeling_model:
|
51
|
+
local_sample_input.columns = trunc_sample_input.columns
|
50
52
|
else:
|
51
53
|
local_sample_input = trunc_sample_input
|
52
54
|
return local_sample_input
|
@@ -58,13 +60,15 @@ def validate_signature(
|
|
58
60
|
target_methods: Iterable[str],
|
59
61
|
sample_input_data: Optional[model_types.SupportedDataType],
|
60
62
|
get_prediction_fn: Callable[[str, model_types.SupportedLocalDataType], model_types.SupportedLocalDataType],
|
63
|
+
is_for_modeling_model: bool = False,
|
61
64
|
) -> model_meta.ModelMetadata:
|
62
65
|
if model_meta.signatures:
|
63
66
|
validate_target_methods(model, list(model_meta.signatures.keys()))
|
64
67
|
if sample_input_data is not None:
|
65
|
-
local_sample_input = get_truncated_sample_data(
|
68
|
+
local_sample_input = get_truncated_sample_data(
|
69
|
+
sample_input_data, is_for_modeling_model=is_for_modeling_model
|
70
|
+
)
|
66
71
|
for target_method in model_meta.signatures.keys():
|
67
|
-
|
68
72
|
model_signature_inst = model_meta.signatures.get(target_method)
|
69
73
|
if model_signature_inst is not None:
|
70
74
|
# strict validation the input signature
|
@@ -77,7 +81,7 @@ def validate_signature(
|
|
77
81
|
assert (
|
78
82
|
sample_input_data is not None
|
79
83
|
), "Model signature and sample input are None at the same time. This should not happen with local model."
|
80
|
-
local_sample_input = get_truncated_sample_data(sample_input_data)
|
84
|
+
local_sample_input = get_truncated_sample_data(sample_input_data, is_for_modeling_model=is_for_modeling_model)
|
81
85
|
for target_method in target_methods:
|
82
86
|
predictions_df = get_prediction_fn(target_method, local_sample_input)
|
83
87
|
sig = model_signature.infer_signature(
|
@@ -30,10 +30,7 @@ from snowflake.ml.model._packager.model_meta import (
|
|
30
30
|
model_meta as model_meta_api,
|
31
31
|
model_meta_schema,
|
32
32
|
)
|
33
|
-
from snowflake.ml.model._signatures import
|
34
|
-
builtins_handler,
|
35
|
-
utils as model_signature_utils,
|
36
|
-
)
|
33
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
37
34
|
from snowflake.ml.model.models import huggingface_pipeline
|
38
35
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
39
36
|
|
@@ -66,16 +63,16 @@ def get_requirements_from_task(task: str, spcs_only: bool = False) -> List[model
|
|
66
63
|
return []
|
67
64
|
|
68
65
|
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
66
|
+
def sanitize_output(data: Any) -> Any:
|
67
|
+
if isinstance(data, np.number):
|
68
|
+
return data.item()
|
69
|
+
if isinstance(data, np.ndarray):
|
70
|
+
return sanitize_output(data.tolist())
|
71
|
+
if isinstance(data, list):
|
72
|
+
return [sanitize_output(x) for x in data]
|
73
|
+
if isinstance(data, dict):
|
74
|
+
return {k: sanitize_output(v) for k, v in data.items()}
|
75
|
+
return data
|
79
76
|
|
80
77
|
|
81
78
|
@final
|
@@ -146,6 +143,10 @@ class HuggingFacePipelineHandler(
|
|
146
143
|
framework = getattr(model, "framework", None)
|
147
144
|
batch_size = getattr(model, "batch_size", None)
|
148
145
|
|
146
|
+
has_tokenizer = getattr(model, "tokenizer", None) is not None
|
147
|
+
has_feature_extractor = getattr(model, "feature_extractor", None) is not None
|
148
|
+
has_image_preprocessor = getattr(model, "image_preprocessor", None) is not None
|
149
|
+
|
149
150
|
if type_utils.LazyType("transformers.Pipeline").isinstance(model):
|
150
151
|
params = {
|
151
152
|
**model._preprocess_params, # type:ignore[attr-defined]
|
@@ -234,6 +235,9 @@ class HuggingFacePipelineHandler(
|
|
234
235
|
{
|
235
236
|
"task": task,
|
236
237
|
"batch_size": batch_size if batch_size is not None else 1,
|
238
|
+
"has_tokenizer": has_tokenizer,
|
239
|
+
"has_feature_extractor": has_feature_extractor,
|
240
|
+
"has_image_preprocessor": has_image_preprocessor,
|
237
241
|
}
|
238
242
|
),
|
239
243
|
)
|
@@ -308,6 +312,14 @@ class HuggingFacePipelineHandler(
|
|
308
312
|
if os.path.isdir(model_blob_file_or_dir_path):
|
309
313
|
import transformers
|
310
314
|
|
315
|
+
additional_pipeline_params = {}
|
316
|
+
if model_blob_options.get("has_tokenizer", False):
|
317
|
+
additional_pipeline_params["tokenizer"] = model_blob_file_or_dir_path
|
318
|
+
if model_blob_options.get("has_feature_extractor", False):
|
319
|
+
additional_pipeline_params["feature_extractor"] = model_blob_file_or_dir_path
|
320
|
+
if model_blob_options.get("has_image_preprocessor", False):
|
321
|
+
additional_pipeline_params["image_preprocessor"] = model_blob_file_or_dir_path
|
322
|
+
|
311
323
|
with open(
|
312
324
|
os.path.join(
|
313
325
|
model_blob_file_or_dir_path,
|
@@ -324,6 +336,7 @@ class HuggingFacePipelineHandler(
|
|
324
336
|
model=model_blob_file_or_dir_path,
|
325
337
|
trust_remote_code=True,
|
326
338
|
torch_dtype="auto",
|
339
|
+
**additional_pipeline_params,
|
327
340
|
**device_config,
|
328
341
|
)
|
329
342
|
|
@@ -394,13 +407,17 @@ class HuggingFacePipelineHandler(
|
|
394
407
|
)
|
395
408
|
for conv_data in X.to_dict("records")
|
396
409
|
]
|
397
|
-
elif len(signature.inputs) == 1:
|
398
|
-
input_data = X.to_dict("list")[signature.inputs[0].name]
|
399
410
|
else:
|
400
411
|
if isinstance(raw_model, transformers.TableQuestionAnsweringPipeline):
|
401
412
|
X["table"] = X["table"].apply(json.loads)
|
402
413
|
|
403
|
-
|
414
|
+
# Most pipelines if it is expecting more than one arguments,
|
415
|
+
# it is expecting a list of dict, where each dict has keys corresponding to the argument.
|
416
|
+
if len(signature.inputs) > 1:
|
417
|
+
input_data = X.to_dict("records")
|
418
|
+
# If it is only expecting one argument, Then it is expecting a list of something.
|
419
|
+
else:
|
420
|
+
input_data = X[signature.inputs[0].name].to_list()
|
404
421
|
temp_res = getattr(raw_model, target_method)(input_data)
|
405
422
|
|
406
423
|
# Some huggingface pipeline will omit the outer list when there is only 1 input.
|
@@ -423,7 +440,6 @@ class HuggingFacePipelineHandler(
|
|
423
440
|
),
|
424
441
|
)
|
425
442
|
and X.shape[0] == 1
|
426
|
-
and isinstance(temp_res[0], dict)
|
427
443
|
)
|
428
444
|
):
|
429
445
|
temp_res = [temp_res]
|
@@ -437,14 +453,18 @@ class HuggingFacePipelineHandler(
|
|
437
453
|
temp_res = [[conv.generated_responses] for conv in temp_res]
|
438
454
|
|
439
455
|
# To concat those who outputs a list with one input.
|
440
|
-
if
|
441
|
-
|
442
|
-
|
456
|
+
if isinstance(temp_res[0], list):
|
457
|
+
if isinstance(temp_res[0][0], dict):
|
458
|
+
res = pd.DataFrame({0: temp_res})
|
459
|
+
else:
|
460
|
+
res = pd.DataFrame(temp_res)
|
461
|
+
else:
|
443
462
|
res = pd.DataFrame(temp_res)
|
444
|
-
|
445
|
-
|
463
|
+
|
464
|
+
if hasattr(res, "map"):
|
465
|
+
res = res.map(sanitize_output)
|
446
466
|
else:
|
447
|
-
|
467
|
+
res = res.applymap(sanitize_output)
|
448
468
|
|
449
469
|
return model_signature_utils.rename_pandas_df(data=res, features=signature.outputs)
|
450
470
|
|
@@ -0,0 +1,226 @@
|
|
1
|
+
import os
|
2
|
+
from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
|
3
|
+
|
4
|
+
import cloudpickle
|
5
|
+
import numpy as np
|
6
|
+
import pandas as pd
|
7
|
+
from packaging import version
|
8
|
+
from typing_extensions import TypeGuard, Unpack
|
9
|
+
|
10
|
+
from snowflake.ml._internal import type_utils
|
11
|
+
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
12
|
+
from snowflake.ml.model._packager.model_env import model_env
|
13
|
+
from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
|
14
|
+
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
15
|
+
from snowflake.ml.model._packager.model_meta import (
|
16
|
+
model_blob_meta,
|
17
|
+
model_meta as model_meta_api,
|
18
|
+
)
|
19
|
+
from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
|
20
|
+
|
21
|
+
if TYPE_CHECKING:
|
22
|
+
import keras
|
23
|
+
|
24
|
+
|
25
|
+
@final
|
26
|
+
class KerasHandler(_base.BaseModelHandler["keras.Model"]):
|
27
|
+
"""Handler for Keras v3 model.
|
28
|
+
|
29
|
+
Currently keras.Model based classes are supported.
|
30
|
+
"""
|
31
|
+
|
32
|
+
HANDLER_TYPE = "keras"
|
33
|
+
HANDLER_VERSION = "2025-01-01"
|
34
|
+
_MIN_SNOWPARK_ML_VERSION = "1.7.5"
|
35
|
+
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
|
36
|
+
|
37
|
+
MODEL_BLOB_FILE_OR_DIR = "model.keras"
|
38
|
+
CUSTOM_OBJECT_SAVE_PATH = "custom_objects.pkl"
|
39
|
+
DEFAULT_TARGET_METHODS = ["predict"]
|
40
|
+
|
41
|
+
@classmethod
|
42
|
+
def can_handle(
|
43
|
+
cls,
|
44
|
+
model: model_types.SupportedModelType,
|
45
|
+
) -> TypeGuard["keras.Model"]:
|
46
|
+
if not type_utils.LazyType("keras.Model").isinstance(model):
|
47
|
+
return False
|
48
|
+
import keras
|
49
|
+
|
50
|
+
return version.parse(keras.__version__) >= version.parse("3.0.0")
|
51
|
+
|
52
|
+
@classmethod
|
53
|
+
def cast_model(
|
54
|
+
cls,
|
55
|
+
model: model_types.SupportedModelType,
|
56
|
+
) -> "keras.Model":
|
57
|
+
import keras
|
58
|
+
|
59
|
+
assert isinstance(model, keras.Model)
|
60
|
+
|
61
|
+
return cast(keras.Model, model)
|
62
|
+
|
63
|
+
@classmethod
|
64
|
+
def save_model(
|
65
|
+
cls,
|
66
|
+
name: str,
|
67
|
+
model: "keras.Model",
|
68
|
+
model_meta: model_meta_api.ModelMetadata,
|
69
|
+
model_blobs_dir_path: str,
|
70
|
+
sample_input_data: Optional[model_types.SupportedDataType] = None,
|
71
|
+
is_sub_model: Optional[bool] = False,
|
72
|
+
**kwargs: Unpack[model_types.TensorflowSaveOptions],
|
73
|
+
) -> None:
|
74
|
+
enable_explainability = kwargs.get("enable_explainability", False)
|
75
|
+
if enable_explainability:
|
76
|
+
raise NotImplementedError("Explainability is not supported for Tensorflow model.")
|
77
|
+
|
78
|
+
import keras
|
79
|
+
|
80
|
+
assert isinstance(model, keras.Model)
|
81
|
+
|
82
|
+
if not is_sub_model:
|
83
|
+
target_methods = handlers_utils.get_target_methods(
|
84
|
+
model=model,
|
85
|
+
target_methods=kwargs.pop("target_methods", None),
|
86
|
+
default_target_methods=cls.DEFAULT_TARGET_METHODS,
|
87
|
+
)
|
88
|
+
|
89
|
+
def get_prediction(
|
90
|
+
target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
|
91
|
+
) -> model_types.SupportedLocalDataType:
|
92
|
+
target_method = getattr(model, target_method_name, None)
|
93
|
+
assert callable(target_method)
|
94
|
+
predictions_df = target_method(sample_input_data)
|
95
|
+
|
96
|
+
if (
|
97
|
+
type_utils.LazyType("tensorflow.Tensor").isinstance(predictions_df)
|
98
|
+
or type_utils.LazyType("tensorflow.Variable").isinstance(predictions_df)
|
99
|
+
or type_utils.LazyType("torch.Tensor").isinstance(predictions_df)
|
100
|
+
):
|
101
|
+
predictions_df = [predictions_df]
|
102
|
+
|
103
|
+
return predictions_df
|
104
|
+
|
105
|
+
model_meta = handlers_utils.validate_signature(
|
106
|
+
model=model,
|
107
|
+
model_meta=model_meta,
|
108
|
+
target_methods=target_methods,
|
109
|
+
sample_input_data=sample_input_data,
|
110
|
+
get_prediction_fn=get_prediction,
|
111
|
+
)
|
112
|
+
|
113
|
+
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
114
|
+
os.makedirs(model_blob_path, exist_ok=True)
|
115
|
+
save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
|
116
|
+
model.save(save_path)
|
117
|
+
|
118
|
+
custom_object_save_path = os.path.join(model_blob_path, cls.CUSTOM_OBJECT_SAVE_PATH)
|
119
|
+
custom_objects = keras.saving.get_custom_objects()
|
120
|
+
with open(custom_object_save_path, "wb") as f:
|
121
|
+
cloudpickle.dump(custom_objects, f)
|
122
|
+
|
123
|
+
base_meta = model_blob_meta.ModelBlobMeta(
|
124
|
+
name=name,
|
125
|
+
model_type=cls.HANDLER_TYPE,
|
126
|
+
handler_version=cls.HANDLER_VERSION,
|
127
|
+
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
128
|
+
)
|
129
|
+
model_meta.models[name] = base_meta
|
130
|
+
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
131
|
+
|
132
|
+
dependencies = [
|
133
|
+
model_env.ModelDependency(requirement="keras>=3", pip_name="keras"),
|
134
|
+
]
|
135
|
+
keras_backend = keras.backend.backend()
|
136
|
+
if keras_backend == "tensorflow":
|
137
|
+
dependencies.append(model_env.ModelDependency(requirement="tensorflow", pip_name="tensorflow"))
|
138
|
+
elif keras_backend == "torch":
|
139
|
+
dependencies.append(model_env.ModelDependency(requirement="pytorch", pip_name="torch"))
|
140
|
+
elif keras_backend == "jax":
|
141
|
+
dependencies.append(model_env.ModelDependency(requirement="jax", pip_name="jax"))
|
142
|
+
else:
|
143
|
+
raise ValueError(f"Unsupported backend {keras_backend}")
|
144
|
+
|
145
|
+
model_meta.env.include_if_absent(
|
146
|
+
dependencies,
|
147
|
+
check_local_version=True,
|
148
|
+
)
|
149
|
+
model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
|
150
|
+
|
151
|
+
@classmethod
|
152
|
+
def load_model(
|
153
|
+
cls,
|
154
|
+
name: str,
|
155
|
+
model_meta: model_meta_api.ModelMetadata,
|
156
|
+
model_blobs_dir_path: str,
|
157
|
+
**kwargs: Unpack[model_types.TensorflowLoadOptions],
|
158
|
+
) -> "keras.Model":
|
159
|
+
import keras
|
160
|
+
|
161
|
+
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
162
|
+
model_blobs_metadata = model_meta.models
|
163
|
+
model_blob_metadata = model_blobs_metadata[name]
|
164
|
+
model_blob_filename = model_blob_metadata.path
|
165
|
+
|
166
|
+
custom_object_save_path = os.path.join(model_blob_path, cls.CUSTOM_OBJECT_SAVE_PATH)
|
167
|
+
with open(custom_object_save_path, "rb") as f:
|
168
|
+
custom_objects = cloudpickle.load(f)
|
169
|
+
load_path = os.path.join(model_blob_path, model_blob_filename)
|
170
|
+
m = keras.models.load_model(load_path, custom_objects=custom_objects, safe_mode=False)
|
171
|
+
|
172
|
+
return cast(keras.Model, m)
|
173
|
+
|
174
|
+
@classmethod
|
175
|
+
def convert_as_custom_model(
|
176
|
+
cls,
|
177
|
+
raw_model: "keras.Model",
|
178
|
+
model_meta: model_meta_api.ModelMetadata,
|
179
|
+
background_data: Optional[pd.DataFrame] = None,
|
180
|
+
**kwargs: Unpack[model_types.TensorflowLoadOptions],
|
181
|
+
) -> custom_model.CustomModel:
|
182
|
+
|
183
|
+
from snowflake.ml.model import custom_model
|
184
|
+
|
185
|
+
def _create_custom_model(
|
186
|
+
raw_model: "keras.Model",
|
187
|
+
model_meta: model_meta_api.ModelMetadata,
|
188
|
+
) -> Type[custom_model.CustomModel]:
|
189
|
+
def fn_factory(
|
190
|
+
raw_model: "keras.Model",
|
191
|
+
signature: model_signature.ModelSignature,
|
192
|
+
target_method: str,
|
193
|
+
) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
|
194
|
+
dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs}
|
195
|
+
|
196
|
+
@custom_model.inference_api
|
197
|
+
def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
|
198
|
+
res = getattr(raw_model, target_method)(X.astype(dtype_map), verbose=0)
|
199
|
+
|
200
|
+
if isinstance(res, list) and len(res) > 0 and isinstance(res[0], np.ndarray):
|
201
|
+
# In case of multi-output estimators, predict_proba(), decision_function(), etc., functions
|
202
|
+
# return a list of ndarrays. We need to deal them separately
|
203
|
+
df = numpy_handler.SeqOfNumpyArrayHandler.convert_to_df(res)
|
204
|
+
else:
|
205
|
+
df = pd.DataFrame(res)
|
206
|
+
|
207
|
+
return model_signature_utils.rename_pandas_df(df, signature.outputs)
|
208
|
+
|
209
|
+
return fn
|
210
|
+
|
211
|
+
type_method_dict = {}
|
212
|
+
for target_method_name, sig in model_meta.signatures.items():
|
213
|
+
type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
|
214
|
+
|
215
|
+
_KerasModel = type(
|
216
|
+
"_KerasModel",
|
217
|
+
(custom_model.CustomModel,),
|
218
|
+
type_method_dict,
|
219
|
+
)
|
220
|
+
|
221
|
+
return _KerasModel
|
222
|
+
|
223
|
+
_KerasModel = _create_custom_model(raw_model, model_meta)
|
224
|
+
keras_model = _KerasModel(custom_model.ModelContext())
|
225
|
+
|
226
|
+
return keras_model
|