snowflake-ml-python 1.7.4__py3-none-any.whl → 1.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +58 -3
- snowflake/ml/_internal/env_utils.py +64 -21
- snowflake/ml/_internal/file_utils.py +18 -4
- snowflake/ml/_internal/platform_capabilities.py +3 -0
- snowflake/ml/_internal/relax_version_strategy.py +16 -0
- snowflake/ml/_internal/telemetry.py +25 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -1
- snowflake/ml/feature_store/feature_store.py +18 -0
- snowflake/ml/feature_store/feature_view.py +46 -1
- snowflake/ml/fileset/fileset.py +0 -1
- snowflake/ml/jobs/_utils/constants.py +31 -1
- snowflake/ml/jobs/_utils/payload_utils.py +232 -72
- snowflake/ml/jobs/_utils/spec_utils.py +78 -38
- snowflake/ml/jobs/decorators.py +8 -25
- snowflake/ml/jobs/job.py +4 -4
- snowflake/ml/jobs/manager.py +5 -0
- snowflake/ml/model/_client/model/model_version_impl.py +1 -1
- snowflake/ml/model/_client/ops/model_ops.py +107 -14
- snowflake/ml/model/_client/ops/service_ops.py +1 -1
- snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
- snowflake/ml/model/_client/sql/model_version.py +58 -0
- snowflake/ml/model/_client/sql/service.py +8 -2
- snowflake/ml/model/_model_composer/model_composer.py +50 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
- snowflake/ml/model/_packager/model_env/model_env.py +49 -29
- snowflake/ml/model/_packager/model_handlers/_utils.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +44 -24
- snowflake/ml/model/_packager/model_handlers/keras.py +226 -0
- snowflake/ml/model/_packager/model_handlers/pytorch.py +51 -20
- snowflake/ml/model/_packager/model_handlers/sklearn.py +25 -3
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +73 -21
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -72
- snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
- snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
- snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2023_12_01.py +48 -0
- snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
- snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +6 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +16 -0
- snowflake/ml/model/_packager/model_packager.py +3 -5
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +8 -1
- snowflake/ml/model/_packager/model_task/model_task_utils.py +5 -1
- snowflake/ml/model/_signatures/builtins_handler.py +20 -9
- snowflake/ml/model/_signatures/core.py +54 -33
- snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
- snowflake/ml/model/_signatures/numpy_handler.py +12 -20
- snowflake/ml/model/_signatures/pandas_handler.py +28 -37
- snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
- snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
- snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
- snowflake/ml/model/_signatures/utils.py +120 -8
- snowflake/ml/model/custom_model.py +13 -4
- snowflake/ml/model/model_signature.py +39 -13
- snowflake/ml/model/type_hints.py +28 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +14 -1
- snowflake/ml/modeling/metrics/ranking.py +3 -0
- snowflake/ml/modeling/metrics/regression.py +3 -0
- snowflake/ml/modeling/pipeline/pipeline.py +18 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +2 -2
- snowflake/ml/registry/_manager/model_manager.py +55 -7
- snowflake/ml/registry/registry.py +52 -4
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/METADATA +336 -27
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/RECORD +73 -66
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info/licenses}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/top_level.txt +0 -0
@@ -8,10 +8,14 @@ from snowflake.ml._internal import type_utils
|
|
8
8
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
9
9
|
from snowflake.ml.model._packager.model_env import model_env
|
10
10
|
from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
|
11
|
-
from snowflake.ml.model._packager.model_handlers_migrator import
|
11
|
+
from snowflake.ml.model._packager.model_handlers_migrator import (
|
12
|
+
base_migrator,
|
13
|
+
torchscript_migrator_2023_12_01,
|
14
|
+
)
|
12
15
|
from snowflake.ml.model._packager.model_meta import (
|
13
16
|
model_blob_meta,
|
14
17
|
model_meta as model_meta_api,
|
18
|
+
model_meta_schema,
|
15
19
|
)
|
16
20
|
from snowflake.ml.model._signatures import (
|
17
21
|
pytorch_handler,
|
@@ -30,9 +34,11 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
|
30
34
|
"""
|
31
35
|
|
32
36
|
HANDLER_TYPE = "torchscript"
|
33
|
-
HANDLER_VERSION = "
|
34
|
-
_MIN_SNOWPARK_ML_VERSION = "1.0
|
35
|
-
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
|
37
|
+
HANDLER_VERSION = "2025-03-01"
|
38
|
+
_MIN_SNOWPARK_ML_VERSION = "1.8.0"
|
39
|
+
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {
|
40
|
+
"2023-12-01": torchscript_migrator_2023_12_01.TorchScriptHandlerMigrator20231201
|
41
|
+
}
|
36
42
|
|
37
43
|
MODEL_BLOB_FILE_OR_DIR = "model.pt"
|
38
44
|
DEFAULT_TARGET_METHODS = ["forward"]
|
@@ -81,22 +87,32 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
|
81
87
|
default_target_methods=cls.DEFAULT_TARGET_METHODS,
|
82
88
|
)
|
83
89
|
|
90
|
+
multiple_inputs = kwargs.get("multiple_inputs", False)
|
91
|
+
|
84
92
|
def get_prediction(
|
85
93
|
target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
|
86
94
|
) -> model_types.SupportedLocalDataType:
|
87
|
-
if
|
88
|
-
|
89
|
-
|
90
|
-
|
95
|
+
if multiple_inputs:
|
96
|
+
if not pytorch_handler.SeqOfPyTorchTensorHandler.can_handle(sample_input_data):
|
97
|
+
sample_input_data = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(
|
98
|
+
model_signature._convert_local_data_to_df(sample_input_data)
|
99
|
+
)
|
100
|
+
else:
|
101
|
+
if not pytorch_handler.PyTorchTensorHandler.can_handle(sample_input_data):
|
102
|
+
sample_input_data = pytorch_handler.PyTorchTensorHandler.convert_from_df(
|
103
|
+
model_signature._convert_local_data_to_df(sample_input_data)
|
104
|
+
)
|
91
105
|
|
92
106
|
model.eval()
|
93
107
|
target_method = getattr(model, target_method_name, None)
|
94
108
|
assert callable(target_method)
|
95
109
|
with torch.no_grad():
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
110
|
+
if multiple_inputs:
|
111
|
+
predictions_df = target_method(*sample_input_data)
|
112
|
+
if not isinstance(predictions_df, tuple):
|
113
|
+
predictions_df = [predictions_df]
|
114
|
+
else:
|
115
|
+
predictions_df = target_method(sample_input_data)
|
100
116
|
|
101
117
|
return predictions_df
|
102
118
|
|
@@ -117,6 +133,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
|
117
133
|
model_type=cls.HANDLER_TYPE,
|
118
134
|
handler_version=cls.HANDLER_VERSION,
|
119
135
|
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
136
|
+
options=model_meta_schema.TorchScriptModelBlobOptions(multiple_inputs=multiple_inputs),
|
120
137
|
)
|
121
138
|
model_meta.models[name] = base_meta
|
122
139
|
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
@@ -170,6 +187,10 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
|
170
187
|
signature: model_signature.ModelSignature,
|
171
188
|
target_method: str,
|
172
189
|
) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
|
190
|
+
multiple_inputs = cast(
|
191
|
+
model_meta_schema.TorchScriptModelBlobOptions, model_meta.models[model_meta.name].options
|
192
|
+
)["multiple_inputs"]
|
193
|
+
|
173
194
|
@custom_model.inference_api
|
174
195
|
def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
|
175
196
|
if X.isnull().any(axis=None):
|
@@ -179,19 +200,27 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
|
|
179
200
|
|
180
201
|
raw_model.eval()
|
181
202
|
|
182
|
-
|
203
|
+
if multiple_inputs:
|
204
|
+
st = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(X, signature.inputs)
|
183
205
|
|
184
|
-
|
185
|
-
|
206
|
+
if kwargs.get("use_gpu", False):
|
207
|
+
st = [element.cuda() for element in st]
|
186
208
|
|
187
|
-
|
188
|
-
|
209
|
+
with torch.no_grad():
|
210
|
+
res = getattr(raw_model, target_method)(*st)
|
189
211
|
|
190
|
-
|
191
|
-
|
212
|
+
if not isinstance(res, tuple):
|
213
|
+
res = [res]
|
214
|
+
else:
|
215
|
+
t = pytorch_handler.PyTorchTensorHandler.convert_from_df(X, signature.inputs)
|
216
|
+
if kwargs.get("use_gpu", False):
|
217
|
+
t = t.cuda()
|
192
218
|
|
219
|
+
with torch.no_grad():
|
220
|
+
res = getattr(raw_model, target_method)(t)
|
193
221
|
return model_signature_utils.rename_pandas_df(
|
194
|
-
|
222
|
+
model_signature._convert_local_data_to_df(res, ensure_serializable=True),
|
223
|
+
features=signature.outputs,
|
195
224
|
)
|
196
225
|
|
197
226
|
return fn
|
@@ -99,10 +99,10 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
99
99
|
def get_prediction(
|
100
100
|
target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
|
101
101
|
) -> model_types.SupportedLocalDataType:
|
102
|
-
if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray)):
|
102
|
+
if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray, xgboost.DMatrix)):
|
103
103
|
sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
|
104
104
|
|
105
|
-
if isinstance(model, xgboost.Booster):
|
105
|
+
if isinstance(model, xgboost.Booster) and not isinstance(sample_input_data, xgboost.DMatrix):
|
106
106
|
sample_input_data = xgboost.DMatrix(sample_input_data)
|
107
107
|
|
108
108
|
target_method = getattr(model, target_method_name, None)
|
@@ -0,0 +1,20 @@
|
|
1
|
+
from typing import cast
|
2
|
+
|
3
|
+
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
4
|
+
from snowflake.ml.model._packager.model_meta import (
|
5
|
+
model_meta as model_meta_api,
|
6
|
+
model_meta_schema,
|
7
|
+
)
|
8
|
+
|
9
|
+
|
10
|
+
class PyTorchHandlerMigrator20231201(base_migrator.BaseModelHandlerMigrator):
|
11
|
+
source_version = "2023-12-01"
|
12
|
+
target_version = "2025-03-01"
|
13
|
+
|
14
|
+
@staticmethod
|
15
|
+
def upgrade(name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str) -> None:
|
16
|
+
|
17
|
+
model_blob_metadata = model_meta.models[name]
|
18
|
+
model_blob_options = cast(model_meta_schema.PyTorchModelBlobOptions, model_blob_metadata.options)
|
19
|
+
model_blob_options["multiple_inputs"] = True
|
20
|
+
model_meta.models[name].options = model_blob_options
|
@@ -0,0 +1,48 @@
|
|
1
|
+
from typing import cast
|
2
|
+
|
3
|
+
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
4
|
+
from snowflake.ml.model._packager.model_meta import (
|
5
|
+
model_meta as model_meta_api,
|
6
|
+
model_meta_schema,
|
7
|
+
)
|
8
|
+
|
9
|
+
|
10
|
+
class TensorflowHandlerMigrator20231201(base_migrator.BaseModelHandlerMigrator):
|
11
|
+
source_version = "2023-12-01"
|
12
|
+
target_version = "2025-01-01"
|
13
|
+
|
14
|
+
@staticmethod
|
15
|
+
def upgrade(name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str) -> None:
|
16
|
+
|
17
|
+
model_blob_metadata = model_meta.models[name]
|
18
|
+
model_blob_options = cast(model_meta_schema.TensorflowModelBlobOptions, model_blob_metadata.options)
|
19
|
+
# To migrate code <= 1.7.0, default to keras model
|
20
|
+
is_old_model = "save_format" not in model_blob_options and "is_keras_model" not in model_blob_options
|
21
|
+
# To migrate code form 1.7.1, default to False.
|
22
|
+
is_keras_model = model_blob_options.get("is_keras_model", False)
|
23
|
+
# To migrate code from 1.7.2, default to tf, has options keras, keras_tf, cloudpickle, tf
|
24
|
+
#
|
25
|
+
# if is_keras_model or is_tf_keras_model:
|
26
|
+
# if is_keras_functional_or_sequential_model:
|
27
|
+
# save_format = "keras"
|
28
|
+
# elif keras_version.major == 2 or is_tf_keras_model:
|
29
|
+
# save_format = "keras_tf"
|
30
|
+
# else:
|
31
|
+
# save_format = "cloudpickle"
|
32
|
+
# else:
|
33
|
+
# save_format = "tf"
|
34
|
+
#
|
35
|
+
save_format = model_blob_options.get("save_format", "tf")
|
36
|
+
|
37
|
+
if save_format == "keras" or is_keras_model or is_old_model:
|
38
|
+
save_format = "keras_tf"
|
39
|
+
elif save_format == "cloudpickle":
|
40
|
+
# Given the old logic, this could only happen if the original model is a keras model, and keras is 3.x
|
41
|
+
# However, in this case, keras.Model does not extends from tensorflow.Module
|
42
|
+
# So actually TensorflowHandler will not be triggered, we could safely error this out.
|
43
|
+
raise NotImplementedError(
|
44
|
+
"Unable to upgrade keras 3.x model saved by old handler. This is not supposed to happen"
|
45
|
+
)
|
46
|
+
|
47
|
+
model_blob_options["save_format"] = save_format
|
48
|
+
model_meta.models[name].options = model_blob_options
|
@@ -0,0 +1,19 @@
|
|
1
|
+
from typing import cast
|
2
|
+
|
3
|
+
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
4
|
+
from snowflake.ml.model._packager.model_meta import (
|
5
|
+
model_meta as model_meta_api,
|
6
|
+
model_meta_schema,
|
7
|
+
)
|
8
|
+
|
9
|
+
|
10
|
+
class TensorflowHandlerMigrator20250101(base_migrator.BaseModelHandlerMigrator):
|
11
|
+
source_version = "2025-01-01"
|
12
|
+
target_version = "2025-03-01"
|
13
|
+
|
14
|
+
@staticmethod
|
15
|
+
def upgrade(name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str) -> None:
|
16
|
+
model_blob_metadata = model_meta.models[name]
|
17
|
+
model_blob_options = cast(model_meta_schema.TensorflowModelBlobOptions, model_blob_metadata.options)
|
18
|
+
model_blob_options["multiple_inputs"] = True
|
19
|
+
model_meta.models[name].options = model_blob_options
|
@@ -0,0 +1,20 @@
|
|
1
|
+
from typing import cast
|
2
|
+
|
3
|
+
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
4
|
+
from snowflake.ml.model._packager.model_meta import (
|
5
|
+
model_meta as model_meta_api,
|
6
|
+
model_meta_schema,
|
7
|
+
)
|
8
|
+
|
9
|
+
|
10
|
+
class TorchScriptHandlerMigrator20231201(base_migrator.BaseModelHandlerMigrator):
|
11
|
+
source_version = "2023-12-01"
|
12
|
+
target_version = "2025-03-01"
|
13
|
+
|
14
|
+
@staticmethod
|
15
|
+
def upgrade(name: str, model_meta: model_meta_api.ModelMetadata, model_blobs_dir_path: str) -> None:
|
16
|
+
|
17
|
+
model_blob_metadata = model_meta.models[name]
|
18
|
+
model_blob_options = cast(model_meta_schema.PyTorchModelBlobOptions, model_blob_metadata.options)
|
19
|
+
model_blob_options["multiple_inputs"] = True
|
20
|
+
model_meta.models[name].options = model_blob_options
|
@@ -48,6 +48,7 @@ def create_model_metadata(
|
|
48
48
|
ext_modules: Optional[List[ModuleType]] = None,
|
49
49
|
conda_dependencies: Optional[List[str]] = None,
|
50
50
|
pip_requirements: Optional[List[str]] = None,
|
51
|
+
artifact_repository_map: Optional[Dict[str, str]] = None,
|
51
52
|
python_version: Optional[str] = None,
|
52
53
|
task: model_types.Task = model_types.Task.UNKNOWN,
|
53
54
|
**kwargs: Any,
|
@@ -67,6 +68,7 @@ def create_model_metadata(
|
|
67
68
|
ext_modules: List of names of modules that need to be pickled with the model. Defaults to None.
|
68
69
|
conda_dependencies: List of conda requirements for running the model. Defaults to None.
|
69
70
|
pip_requirements: List of pip Python packages requirements for running the model. Defaults to None.
|
71
|
+
artifact_repository_map: A dict mapping from package channel to artifact repository name.
|
70
72
|
python_version: A string of python version where model is run. Used for user override. If specified as None,
|
71
73
|
current version would be captured. Defaults to None.
|
72
74
|
task: The task of the Model Version. It is an enum class Task with values TABULAR_REGRESSION,
|
@@ -102,6 +104,7 @@ def create_model_metadata(
|
|
102
104
|
env = _create_env_for_model_metadata(
|
103
105
|
conda_dependencies=conda_dependencies,
|
104
106
|
pip_requirements=pip_requirements,
|
107
|
+
artifact_repository_map=artifact_repository_map,
|
105
108
|
python_version=python_version,
|
106
109
|
embed_local_ml_library=embed_local_ml_library,
|
107
110
|
)
|
@@ -151,6 +154,7 @@ def _create_env_for_model_metadata(
|
|
151
154
|
*,
|
152
155
|
conda_dependencies: Optional[List[str]] = None,
|
153
156
|
pip_requirements: Optional[List[str]] = None,
|
157
|
+
artifact_repository_map: Optional[Dict[str, str]] = None,
|
154
158
|
python_version: Optional[str] = None,
|
155
159
|
embed_local_ml_library: bool = False,
|
156
160
|
) -> model_env.ModelEnv:
|
@@ -159,6 +163,7 @@ def _create_env_for_model_metadata(
|
|
159
163
|
# Mypy doesn't like getter and setter have different types. See python/mypy #3004
|
160
164
|
env.conda_dependencies = conda_dependencies # type: ignore[assignment]
|
161
165
|
env.pip_requirements = pip_requirements # type: ignore[assignment]
|
166
|
+
env.artifact_repository_map = artifact_repository_map
|
162
167
|
env.python_version = python_version # type: ignore[assignment]
|
163
168
|
env.snowpark_ml_version = snowml_env.VERSION
|
164
169
|
|
@@ -331,7 +336,6 @@ class ModelMetadata:
|
|
331
336
|
"function_properties": self.function_properties,
|
332
337
|
}
|
333
338
|
)
|
334
|
-
|
335
339
|
with open(model_yaml_path, "w", encoding="utf-8") as out:
|
336
340
|
yaml.SafeDumper.ignore_aliases = lambda *args: True # type: ignore[method-assign]
|
337
341
|
yaml.safe_dump(model_dict, stream=out, default_flow_style=False)
|
@@ -352,7 +356,7 @@ class ModelMetadata:
|
|
352
356
|
version.parse(loaded_meta_min_snowpark_ml_version) > version.parse(snowml_env.VERSION)
|
353
357
|
):
|
354
358
|
raise RuntimeError(
|
355
|
-
f"The minimal version required to load the model is {loaded_meta_min_snowpark_ml_version},"
|
359
|
+
f"The minimal version required to load the model is {loaded_meta_min_snowpark_ml_version}, "
|
356
360
|
f"while current version of Snowpark ML library is {snowml_env.VERSION}."
|
357
361
|
)
|
358
362
|
return model_meta_schema.ModelMetadataDict(
|
@@ -18,6 +18,7 @@ class FunctionProperties(Enum):
|
|
18
18
|
class ModelRuntimeDependenciesDict(TypedDict):
|
19
19
|
conda: Required[str]
|
20
20
|
pip: Required[str]
|
21
|
+
artifact_repository_map: NotRequired[Optional[Dict[str, str]]]
|
21
22
|
|
22
23
|
|
23
24
|
class ModelRuntimeDict(TypedDict):
|
@@ -28,6 +29,7 @@ class ModelRuntimeDict(TypedDict):
|
|
28
29
|
class ModelEnvDict(TypedDict):
|
29
30
|
conda: Required[str]
|
30
31
|
pip: Required[str]
|
32
|
+
artifact_repository_map: NotRequired[Optional[Dict[str, str]]]
|
31
33
|
python_version: Required[str]
|
32
34
|
cuda_version: NotRequired[Optional[str]]
|
33
35
|
snowpark_ml_version: Required[str]
|
@@ -44,6 +46,9 @@ class CatBoostModelBlobOptions(BaseModelBlobOptions):
|
|
44
46
|
class HuggingFacePipelineModelBlobOptions(BaseModelBlobOptions):
|
45
47
|
task: Required[str]
|
46
48
|
batch_size: Required[int]
|
49
|
+
has_tokenizer: NotRequired[bool]
|
50
|
+
has_feature_extractor: NotRequired[bool]
|
51
|
+
has_image_preprocessor: NotRequired[bool]
|
47
52
|
|
48
53
|
|
49
54
|
class LightGBMModelBlobOptions(BaseModelBlobOptions):
|
@@ -58,8 +63,17 @@ class XgboostModelBlobOptions(BaseModelBlobOptions):
|
|
58
63
|
xgb_estimator_type: Required[str]
|
59
64
|
|
60
65
|
|
66
|
+
class PyTorchModelBlobOptions(BaseModelBlobOptions):
|
67
|
+
multiple_inputs: Required[bool]
|
68
|
+
|
69
|
+
|
70
|
+
class TorchScriptModelBlobOptions(BaseModelBlobOptions):
|
71
|
+
multiple_inputs: Required[bool]
|
72
|
+
|
73
|
+
|
61
74
|
class TensorflowModelBlobOptions(BaseModelBlobOptions):
|
62
75
|
save_format: Required[str]
|
76
|
+
multiple_inputs: Required[bool]
|
63
77
|
|
64
78
|
|
65
79
|
class SentenceTransformersModelBlobOptions(BaseModelBlobOptions):
|
@@ -71,6 +85,8 @@ ModelBlobOptions = Union[
|
|
71
85
|
HuggingFacePipelineModelBlobOptions,
|
72
86
|
MLFlowModelBlobOptions,
|
73
87
|
XgboostModelBlobOptions,
|
88
|
+
PyTorchModelBlobOptions,
|
89
|
+
TorchScriptModelBlobOptions,
|
74
90
|
TensorflowModelBlobOptions,
|
75
91
|
SentenceTransformersModelBlobOptions,
|
76
92
|
]
|
@@ -43,13 +43,13 @@ class ModelPackager:
|
|
43
43
|
metadata: Optional[Dict[str, str]] = None,
|
44
44
|
conda_dependencies: Optional[List[str]] = None,
|
45
45
|
pip_requirements: Optional[List[str]] = None,
|
46
|
+
artifact_repository_map: Optional[Dict[str, str]] = None,
|
46
47
|
python_version: Optional[str] = None,
|
47
48
|
ext_modules: Optional[List[ModuleType]] = None,
|
48
49
|
code_paths: Optional[List[str]] = None,
|
49
|
-
options:
|
50
|
+
options: model_types.ModelSaveOption,
|
50
51
|
task: model_types.Task = model_types.Task.UNKNOWN,
|
51
52
|
) -> model_meta.ModelMetadata:
|
52
|
-
|
53
53
|
if (signatures is None) and (sample_input_data is None) and not model_handler.is_auto_signature_model(model):
|
54
54
|
raise snowml_exceptions.SnowflakeMLException(
|
55
55
|
error_code=error_codes.INVALID_ARGUMENT,
|
@@ -58,9 +58,6 @@ class ModelPackager:
|
|
58
58
|
),
|
59
59
|
)
|
60
60
|
|
61
|
-
if not options:
|
62
|
-
options = model_types.BaseModelSaveOption()
|
63
|
-
|
64
61
|
handler = model_handler.find_handler(model)
|
65
62
|
if handler is None:
|
66
63
|
raise snowml_exceptions.SnowflakeMLException(
|
@@ -77,6 +74,7 @@ class ModelPackager:
|
|
77
74
|
ext_modules=ext_modules,
|
78
75
|
conda_dependencies=conda_dependencies,
|
79
76
|
pip_requirements=pip_requirements,
|
77
|
+
artifact_repository_map=artifact_repository_map,
|
80
78
|
python_version=python_version,
|
81
79
|
task=task,
|
82
80
|
**options,
|
@@ -1,2 +1 @@
|
|
1
|
-
REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<5', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.
|
2
|
-
ALL_REQUIREMENTS=['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<5', 'cachetools>=3.1.1,<6', 'catboost>=1.2.0, <2', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'huggingface_hub<0.26', 'importlib_resources>=6.1.1, <7', 'lightgbm>=4.1.0, <5', 'mlflow>=2.16.0, <3', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pytorch>=2.0.1,<2.3.0', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'sentence-transformers>=2.2.2,<3', 'sentencepiece>=0.1.95,<1', 'shap>=0.46.0,<1', 'snowflake-connector-python>=3.5.0,<4', 'snowflake-snowpark-python>=1.17.0,<2,!=1.26.0', 'sqlparse>=0.4,<1', 'tensorflow>=2.12.0,<3', 'tokenizers>=0.10,<1', 'torchdata>=0.4,<1', 'transformers>=4.32.1,<5', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
|
1
|
+
REQUIREMENTS = ['absl-py>=0.15,<2', 'aiohttp!=4.0.0a0, !=4.0.0a1', 'anyio>=3.5.0,<5', 'cachetools>=3.1.1,<6', 'cloudpickle>=2.0.0', 'cryptography', 'fsspec>=2024.6.1,<2026', 'importlib_resources>=6.1.1, <7', 'numpy>=1.23,<2', 'packaging>=20.9,<25', 'pandas>=1.0.0,<3', 'pyarrow', 'pyjwt>=2.0.0, <3', 'pytimeparse>=1.1.8,<2', 'pyyaml>=6.0,<7', 'requests', 'retrying>=1.3.3,<2', 's3fs>=2024.6.1,<2026', 'scikit-learn>=1.4,<1.6', 'scipy>=1.9,<2', 'snowflake-connector-python>=3.12.0,<4', 'snowflake-snowpark-python>=1.17.0,<2,!=1.26.0', 'sqlparse>=0.4,<1', 'typing-extensions>=4.1.0,<5', 'xgboost>=1.7.3,<3']
|
@@ -45,6 +45,7 @@ class ModelRuntime:
|
|
45
45
|
self.name = name
|
46
46
|
self.runtime_env = copy.deepcopy(env)
|
47
47
|
self.imports = imports or []
|
48
|
+
self.is_gpu = is_gpu
|
48
49
|
|
49
50
|
if loading_from_file:
|
50
51
|
return
|
@@ -88,13 +89,18 @@ class ModelRuntime:
|
|
88
89
|
self.runtime_env.conda_env_rel_path = self.runtime_rel_path / self.runtime_env.conda_env_rel_path
|
89
90
|
self.runtime_env.pip_requirements_rel_path = self.runtime_rel_path / self.runtime_env.pip_requirements_rel_path
|
90
91
|
|
91
|
-
env_dict = self.runtime_env.save_as_dict(
|
92
|
+
env_dict = self.runtime_env.save_as_dict(
|
93
|
+
packager_path, default_channel_override=default_channel_override, is_gpu=self.is_gpu
|
94
|
+
)
|
92
95
|
|
93
96
|
return model_meta_schema.ModelRuntimeDict(
|
94
97
|
imports=list(map(str, self.imports)),
|
95
98
|
dependencies=model_meta_schema.ModelRuntimeDependenciesDict(
|
96
99
|
conda=env_dict["conda"],
|
97
100
|
pip=env_dict["pip"],
|
101
|
+
artifact_repository_map=env_dict["artifact_repository_map"]
|
102
|
+
if env_dict.get("artifact_repository_map") is not None
|
103
|
+
else {},
|
98
104
|
),
|
99
105
|
)
|
100
106
|
|
@@ -109,6 +115,7 @@ class ModelRuntime:
|
|
109
115
|
env.python_version = meta_env.python_version
|
110
116
|
env.cuda_version = meta_env.cuda_version
|
111
117
|
env.snowpark_ml_version = meta_env.snowpark_ml_version
|
118
|
+
env.artifact_repository_map = meta_env.artifact_repository_map
|
112
119
|
|
113
120
|
conda_env_rel_path = pathlib.PurePosixPath(loaded_dict["dependencies"]["conda"])
|
114
121
|
pip_requirements_rel_path = pathlib.PurePosixPath(loaded_dict["dependencies"]["pip"])
|
@@ -24,7 +24,11 @@ def get_task_skl(model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pi
|
|
24
24
|
from sklearn.base import is_classifier, is_regressor
|
25
25
|
|
26
26
|
if type_utils.LazyType("sklearn.pipeline.Pipeline").isinstance(model):
|
27
|
-
|
27
|
+
if hasattr(model, "predict_proba") or hasattr(model, "predict"):
|
28
|
+
model = model.steps[-1][1] # type: ignore[attr-defined]
|
29
|
+
return _get_model_task(model)
|
30
|
+
else:
|
31
|
+
return type_hints.Task.UNKNOWN
|
28
32
|
if is_regressor(model):
|
29
33
|
return type_hints.Task.TABULAR_REGRESSION
|
30
34
|
if is_classifier(model):
|
@@ -14,21 +14,32 @@ from snowflake.ml.model._signatures import base_handler, core, pandas_handler
|
|
14
14
|
|
15
15
|
|
16
16
|
class ListOfBuiltinHandler(base_handler.BaseDataHandler[model_types._SupportedBuiltinsList]):
|
17
|
+
@staticmethod
|
18
|
+
def _can_handle_element(
|
19
|
+
element: model_types._SupportedBuiltins,
|
20
|
+
) -> TypeGuard[model_types._SupportedBuiltins]:
|
21
|
+
if isinstance(element, abc.Sequence) and not isinstance(element, str):
|
22
|
+
for sub_element in element:
|
23
|
+
if not ListOfBuiltinHandler._can_handle_element(sub_element):
|
24
|
+
return False
|
25
|
+
return True
|
26
|
+
elif isinstance(element, abc.Mapping):
|
27
|
+
for key, value in element.items():
|
28
|
+
if not isinstance(key, str):
|
29
|
+
return False
|
30
|
+
if not ListOfBuiltinHandler._can_handle_element(value):
|
31
|
+
return False
|
32
|
+
return True
|
33
|
+
else:
|
34
|
+
return isinstance(element, (int, float, bool, str, datetime.datetime))
|
35
|
+
|
17
36
|
@staticmethod
|
18
37
|
def can_handle(data: model_types.SupportedDataType) -> TypeGuard[model_types._SupportedBuiltinsList]:
|
19
38
|
if not isinstance(data, abc.Sequence) or isinstance(data, str):
|
20
39
|
return False
|
21
40
|
if len(data) == 0:
|
22
41
|
return False
|
23
|
-
|
24
|
-
for element in data:
|
25
|
-
# String is a Sequence but we take them as an whole
|
26
|
-
if isinstance(element, abc.Sequence) and not isinstance(element, str):
|
27
|
-
can_handle = ListOfBuiltinHandler.can_handle(element)
|
28
|
-
elif not isinstance(element, (int, float, bool, str, datetime.datetime)):
|
29
|
-
can_handle = False
|
30
|
-
break
|
31
|
-
return can_handle
|
42
|
+
return ListOfBuiltinHandler._can_handle_element(data)
|
32
43
|
|
33
44
|
@staticmethod
|
34
45
|
def count(data: model_types._SupportedBuiltinsList) -> int:
|