snowflake-ml-python 1.7.5__py3-none-any.whl → 1.8.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- snowflake/cortex/_complete.py +58 -3
- snowflake/ml/_internal/file_utils.py +18 -4
- snowflake/ml/_internal/platform_capabilities.py +3 -0
- snowflake/ml/_internal/telemetry.py +4 -0
- snowflake/ml/fileset/fileset.py +0 -1
- snowflake/ml/jobs/_utils/constants.py +25 -1
- snowflake/ml/jobs/_utils/payload_utils.py +94 -20
- snowflake/ml/jobs/_utils/spec_utils.py +95 -31
- snowflake/ml/jobs/decorators.py +7 -0
- snowflake/ml/jobs/manager.py +20 -0
- snowflake/ml/model/_client/model/model_version_impl.py +1 -1
- snowflake/ml/model/_client/ops/model_ops.py +113 -17
- snowflake/ml/model/_client/ops/service_ops.py +16 -5
- snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
- snowflake/ml/model/_client/sql/model_version.py +58 -0
- snowflake/ml/model/_client/sql/service.py +10 -2
- snowflake/ml/model/_model_composer/model_composer.py +50 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +5 -2
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
- snowflake/ml/model/_packager/model_env/model_env.py +4 -1
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +28 -24
- snowflake/ml/model/_packager/model_handlers/keras.py +1 -5
- snowflake/ml/model/_packager/model_handlers/pytorch.py +50 -20
- snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -8
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -2
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +46 -26
- snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
- snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
- snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
- snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +1 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +5 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -0
- snowflake/ml/model/_packager/model_packager.py +3 -5
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -0
- snowflake/ml/model/_signatures/builtins_handler.py +20 -9
- snowflake/ml/model/_signatures/core.py +52 -31
- snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
- snowflake/ml/model/_signatures/numpy_handler.py +9 -17
- snowflake/ml/model/_signatures/pandas_handler.py +19 -30
- snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
- snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
- snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
- snowflake/ml/model/_signatures/utils.py +120 -8
- snowflake/ml/model/custom_model.py +13 -4
- snowflake/ml/model/model_signature.py +31 -13
- snowflake/ml/model/type_hints.py +13 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +5 -1
- snowflake/ml/modeling/metrics/ranking.py +3 -0
- snowflake/ml/modeling/metrics/regression.py +3 -0
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
- snowflake/ml/registry/_manager/model_manager.py +55 -7
- snowflake/ml/registry/registry.py +59 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/METADATA +308 -12
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/RECORD +62 -58
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info/licenses}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/top_level.txt +0 -0
snowflake/ml/model/model_signature.py
CHANGED
@@ -32,6 +32,7 @@ from snowflake.ml.model._signatures import (
     base_handler,
     builtins_handler as builtins_handler,
     core,
+    dmatrix_handler,
     numpy_handler,
     pandas_handler,
     pytorch_handler,
@@ -52,8 +53,11 @@ _LOCAL_DATA_HANDLERS: List[Type[base_handler.BaseDataHandler[Any]]] = [
     numpy_handler.NumpyArrayHandler,
     builtins_handler.ListOfBuiltinHandler,
     numpy_handler.SeqOfNumpyArrayHandler,
+    pytorch_handler.PyTorchTensorHandler,
     pytorch_handler.SeqOfPyTorchTensorHandler,
+    tensorflow_handler.TensorflowTensorHandler,
     tensorflow_handler.SeqOfTensorflowTensorHandler,
+    dmatrix_handler.XGBoostDMatrixHandler,
 ]
 _ALL_DATA_HANDLERS = _LOCAL_DATA_HANDLERS + [snowpark_handler.SnowparkDataFrameHandler]
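The new registrations cover bare torch.Tensor and tensorflow.Tensor values (previously only sequences of tensors) and xgboost.DMatrix. A hedged sketch of what this should enable through the public model_signature.infer_signature helper, which is not part of this diff and is assumed here with input_data/output_data keywords:

import numpy as np
import torch
import xgboost as xgb

from snowflake.ml.model import model_signature

# A bare tensor (not wrapped in a list) can now be used for signature inference.
sig = model_signature.infer_signature(
    input_data=torch.tensor([[1.0, 2.0], [3.0, 4.0]]),
    output_data=torch.tensor([0.0, 1.0]),
)

# DMatrix inputs are routed through the new XGBoostDMatrixHandler.
dmatrix = xgb.DMatrix(np.array([[1.0, 2.0], [3.0, 4.0]]))
sig2 = model_signature.infer_signature(input_data=dmatrix, output_data=np.array([0.0, 1.0]))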
@@ -218,7 +222,6 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
         strict: Enable strict validation, this includes value range based validation

     Raises:
-        SnowflakeMLException: NotImplementedError: FeatureGroupSpec is not supported.
         SnowflakeMLException: ValueError: Raised when a feature cannot be found.
         SnowflakeMLException: ValueError: Raised when feature is scalar but confront list element.
         SnowflakeMLException: ValueError: Raised when feature type is not aligned in list element.
@@ -236,7 +239,10 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
         except KeyError:
             raise snowml_exceptions.SnowflakeMLException(
                 error_code=error_codes.INVALID_DATA,
-                original_exception=ValueError(
+                original_exception=ValueError(
+                    f"Data Validation Error: feature {ft_name} does not exist in data. "
+                    f"Available columns are {data.columns}."
+                ),
             )

         if data_col.isnull().any():
@@ -244,10 +250,15 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
         df_col_dtype = data_col.dtype

         if isinstance(feature, core.FeatureGroupSpec):
-
-
-
-
+            if df_col_dtype != np.dtype("O"):
+                raise snowml_exceptions.SnowflakeMLException(
+                    error_code=error_codes.INVALID_DATA,
+                    original_exception=ValueError(
+                        f"Data Validation Error in feature group {ft_name}: "
+                        + f"It needs to be a dictionary or list of dictionary, but get {df_col_dtype}."
+                    ),
+                )
+            continue

         assert isinstance(feature, core.FeatureSpec)  # assert for mypy.
         ft_type = feature._dtype
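The new dtype gate relies on the fact that pandas stores dicts (and lists of dicts) in object-typed columns; a minimal pure-pandas illustration:

import numpy as np
import pandas as pd

# Dict-valued columns carry the object dtype, which the new check accepts.
good = pd.DataFrame({"group": [{"a": 1, "b": 2.0}, {"a": 3, "b": 4.0}]})
assert good["group"].dtype == np.dtype("O")

# A plain numeric column would now fail validation against a FeatureGroupSpec.
bad = pd.DataFrame({"group": [1, 2]})
assert bad["group"].dtype != np.dtype("O")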
@@ -437,7 +448,6 @@ def _validate_snowpark_data(
         strict: Enable strict validation, this includes value range based validation.

     Raises:
-        SnowflakeMLException: NotImplementedError: FeatureGroupSpec is not supported.
         SnowflakeMLException: ValueError: Raised when confronting invalid feature.
         SnowflakeMLException: ValueError: Raised when a feature cannot be found.

@@ -467,10 +477,15 @@ def _validate_snowpark_data(
             if field.name == ft_name:
                 found = True
                 if isinstance(feature, core.FeatureGroupSpec):
-
-
-
-
+                    if not isinstance(field.datatype, (spt.ArrayType, spt.StructType, spt.VariantType)):
+                        errors[identifier_rule].append(
+                            ValueError(
+                                f"Data Validation Error in feature group {feature.name}: "
+                                + f"Feature expects {feature.as_snowpark_type()},"
+                                + f" while {field.name} has type {field.datatype}."
+                            ),
+                        )
+                    continue
                 assert isinstance(feature, core.FeatureSpec)  # mypy
                 ft_type = feature._dtype
                 field_data_type = field.datatype
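For Snowpark data, the analogous gate is an isinstance test against composite Snowflake types; a small standalone sketch exercising the same check with snowflake.snowpark.types directly (illustrative only):

from snowflake.snowpark.types import ArrayType, LongType, StructType, VariantType

# Grouped features must land in a composite Snowflake type; scalars are rejected.
for dt in (ArrayType(), StructType(), VariantType(), LongType()):
    accepted = isinstance(dt, (ArrayType, StructType, VariantType))
    print(f"{type(dt).__name__}: {'accepted' if accepted else 'rejected'}")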
@@ -644,11 +659,14 @@ def _validate_snowpark_type_feature(
     )


-def _convert_local_data_to_df(
+def _convert_local_data_to_df(
+    data: model_types.SupportedLocalDataType, ensure_serializable: bool = False
+) -> pd.DataFrame:
     """Convert local data to pandas DataFrame or Snowpark DataFrame

     Args:
         data: The provided data.
+        ensure_serializable: Ensure the data is serializable. Defaults to False.

     Raises:
         SnowflakeMLException: NotImplementedError: Raised when data cannot be handled by any data handler.
@@ -660,7 +678,7 @@ def _convert_local_data_to_df(data: model_types.SupportedLocalDataType) -> pd.Da
     for handler in _LOCAL_DATA_HANDLERS:
         if handler.can_handle(data):
             handler.validate(data)
-            df = handler.convert_to_df(data, ensure_serializable=
+            df = handler.convert_to_df(data, ensure_serializable=ensure_serializable)
             break
     if df is None:
         raise snowml_exceptions.SnowflakeMLException(
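With the new signature, the ensure_serializable flag actually reaches the selected handler instead of being fixed. A hedged sketch of calling this internal helper (private API, subject to change):

from snowflake.ml.model import model_signature

# List-of-builtins input is routed through ListOfBuiltinHandler.
df = model_signature._convert_local_data_to_df([[1, 2], [3, 4]], ensure_serializable=True)
print(df.dtypes)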
snowflake/ml/model/type_hints.py
CHANGED
@@ -26,7 +26,15 @@ if TYPE_CHECKING:
     from snowflake.ml.modeling.framework import base  # noqa: F401


-_SupportedBuiltins = Union[
+_SupportedBuiltins = Union[
+    int,
+    float,
+    bool,
+    str,
+    bytes,
+    Dict[str, Union["_SupportedBuiltins", "_SupportedBuiltinsList"]],
+    "_SupportedBuiltinsList",
+]
 _SupportedNumpyDtype = Union[
     "np.int8",
     "np.int16",
@@ -48,7 +56,7 @@ _SupportedBuiltinsList = Sequence[_SupportedBuiltins]
 _SupportedArrayLike = Union[_SupportedNumpyArray, "torch.Tensor", "tensorflow.Tensor", "tensorflow.Variable"]

 SupportedLocalDataType = Union[
-    "pd.DataFrame",
+    "pd.DataFrame", _SupportedArrayLike, Sequence[_SupportedArrayLike], _SupportedBuiltinsList
 ]

 SupportedDataType = Union[SupportedLocalDataType, "snowflake.snowpark.DataFrame"]
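Together with the recursive _SupportedBuiltins definition above, nested builtin payloads are now typed as valid local data, e.g. usable as sample_input_data when logging a model; illustrative values only:

# Valid SupportedLocalDataType values now include nested dict/list builtins:
sample_input_data = [
    {"text": "hello", "token_ids": [1, 2, 3], "score": 0.9},
    {"text": "world", "token_ids": [4, 5], "score": 0.1},
]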
@@ -177,16 +185,19 @@ class SNOWModelSaveOptions(BaseModelSaveOption):
 class PyTorchSaveOptions(BaseModelSaveOption):
     target_methods: NotRequired[Sequence[str]]
     cuda_version: NotRequired[str]
+    multiple_inputs: NotRequired[bool]


 class TorchScriptSaveOptions(BaseModelSaveOption):
     target_methods: NotRequired[Sequence[str]]
     cuda_version: NotRequired[str]
+    multiple_inputs: NotRequired[bool]


 class TensorflowSaveOptions(BaseModelSaveOption):
     target_methods: NotRequired[Sequence[str]]
     cuda_version: NotRequired[str]
+    multiple_inputs: NotRequired[bool]


 class MLFlowSaveOptions(BaseModelSaveOption):
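A hedged sketch of opting into the new flag when logging a PyTorch model whose forward() takes several tensors; every name here is hypothetical, and registry, torch_model, and the sample tensors are assumed to exist:

mv = registry.log_model(
    model=torch_model,                       # a torch.nn.Module with forward(x, mask)
    model_name="MY_TORCH_MODEL",
    version_name="V1",
    sample_input_data=[tensor_a, tensor_b],  # one entry per forward() argument
    options={"multiple_inputs": True},
)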
snowflake/ml/modeling/_internal/estimator_utils.py
CHANGED
@@ -130,7 +130,11 @@ def is_single_node(session: Session) -> bool:
     warehouse_name = session.get_current_warehouse()
     if warehouse_name:
         warehouse_name = warehouse_name.replace('"', "")
-
+        df_list = session.sql(f"SHOW WAREHOUSES like '{warehouse_name}';")['"type"', '"size"'].collect()
+        # If no warehouse data is found, default to True (single node)
+        if not df_list:
+            return True
+        df = df_list[0]
         # filter out the conditions when it is single node
         single_node: bool = (df[0] == "SNOWPARK-OPTIMIZED" and df[1] == "Medium") or (
             df[0] == "STANDARD" and df[1] == "X-Small"
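The decision implied by the condition above, restated as a standalone truth table derived directly from it:

# (warehouse type, size) -> single node, per the condition above.
cases = [
    ("SNOWPARK-OPTIMIZED", "Medium", True),
    ("STANDARD", "X-Small", True),
    ("STANDARD", "Large", False),
    ("SNOWPARK-OPTIMIZED", "Large", False),
]
for wh_type, wh_size, expected in cases:
    single_node = (wh_type == "SNOWPARK-OPTIMIZED" and wh_size == "Medium") or (
        wh_type == "STANDARD" and wh_size == "X-Small"
    )
    assert single_node is expected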
snowflake/ml/modeling/metrics/ranking.py
CHANGED
@@ -98,6 +98,7 @@ def precision_recall_curve(
         packages=[
             f"cloudpickle=={cloudpickle.__version__}",
             f"scikit-learn=={sklearn_release[0]}.{sklearn_release[1]}.*",
+            f"numpy=={np.__version__}",
             "snowflake-snowpark-python",
         ],
         statement_params=statement_params,
@@ -245,6 +246,7 @@ def roc_auc_score(
         packages=[
             f"cloudpickle=={cloudpickle.__version__}",
             f"scikit-learn=={sklearn_release[0]}.{sklearn_release[1]}.*",
+            f"numpy=={np.__version__}",
             "snowflake-snowpark-python",
         ],
         statement_params=statement_params,
@@ -348,6 +350,7 @@ def roc_curve(
         packages=[
             f"cloudpickle=={cloudpickle.__version__}",
             f"scikit-learn=={sklearn_release[0]}.{sklearn_release[1]}.*",
+            f"numpy=={np.__version__}",
             "snowflake-snowpark-python",
         ],
         statement_params=statement_params,
snowflake/ml/modeling/metrics/regression.py
CHANGED
@@ -83,6 +83,7 @@ def d2_absolute_error_score(
         packages=[
             f"cloudpickle=={cloudpickle.__version__}",
             f"scikit-learn=={sklearn_release[0]}.{sklearn_release[1]}.*",
+            f"numpy=={np.__version__}",
             "snowflake-snowpark-python",
         ],
         statement_params=statement_params,
@@ -180,6 +181,7 @@ def d2_pinball_score(
         packages=[
             f"cloudpickle=={cloudpickle.__version__}",
             f"scikit-learn=={sklearn_release[0]}.{sklearn_release[1]}.*",
+            f"numpy=={np.__version__}",
             "snowflake-snowpark-python",
         ],
         statement_params=statement_params,
@@ -295,6 +297,7 @@ def explained_variance_score(
         packages=[
             f"cloudpickle=={cloudpickle.__version__}",
             f"scikit-learn=={sklearn_release[0]}.{sklearn_release[1]}.*",
+            f"numpy=={np.__version__}",
             "snowflake-snowpark-python",
         ],
         statement_params=statement_params,
snowflake/ml/modeling/preprocessing/k_bins_discretizer.py
CHANGED
@@ -341,7 +341,7 @@ class KBinsDiscretizer(base.BaseTransformer):
             is_permanent=False,
             name=udf_name,
             replace=True,
-            packages=["numpy"],
+            packages=[f"numpy=={np.__version__}"],
             session=dataset._session,
             statement_params=telemetry.get_statement_params(base.PROJECT, base.SUBPROJECT, self.__class__.__name__),
         )
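The same pattern recurs across the metrics and preprocessing hunks: instead of an unpinned "numpy", the server-side package list pins numpy to the locally installed version, so pickled payloads deserialize against matching binaries. The pinning expression in isolation:

import numpy as np

# Pin the server-side numpy to the client's version, e.g. "numpy==1.26.4".
packages = [f"numpy=={np.__version__}"]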
snowflake/ml/registry/_manager/model_manager.py
CHANGED
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import pandas as pd
 from absl.logging import logging

-from snowflake.ml._internal import telemetry
+from snowflake.ml._internal import platform_capabilities, telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions
 from snowflake.ml._internal.human_readable_id import hrid_generator
 from snowflake.ml._internal.utils import sql_identifier
@@ -13,7 +13,7 @@ from snowflake.ml.model._client.model import model_impl, model_version_impl
 from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
 from snowflake.ml.model._model_composer import model_composer
 from snowflake.ml.model._packager.model_meta import model_meta
-from snowflake.snowpark import session
+from snowflake.snowpark import exceptions as snowpark_exceptions, session

 logger = logging.getLogger(__name__)

@@ -46,6 +46,7 @@ class ModelManager:
         metrics: Optional[Dict[str, Any]] = None,
         conda_dependencies: Optional[List[str]] = None,
         pip_requirements: Optional[List[str]] = None,
+        artifact_repository_map: Optional[Dict[str, str]] = None,
         target_platforms: Optional[List[model_types.SupportedTargetPlatformType]] = None,
         python_version: Optional[str] = None,
         signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
@@ -127,6 +128,7 @@ class ModelManager:
             metrics=metrics,
             conda_dependencies=conda_dependencies,
             pip_requirements=pip_requirements,
+            artifact_repository_map=artifact_repository_map,
             target_platforms=target_platforms,
             python_version=python_version,
             signatures=signatures,
@@ -149,6 +151,7 @@ class ModelManager:
         metrics: Optional[Dict[str, Any]] = None,
         conda_dependencies: Optional[List[str]] = None,
         pip_requirements: Optional[List[str]] = None,
+        artifact_repository_map: Optional[Dict[str, str]] = None,
         target_platforms: Optional[List[model_types.SupportedTargetPlatformType]] = None,
         python_version: Optional[str] = None,
         signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
@@ -163,11 +166,42 @@ class ModelManager:
         database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
         version_name_id = sql_identifier.SqlIdentifier(version_name)

-
-
-
-
-
+        use_live_commit = platform_capabilities.PlatformCapabilities.get_instance().is_live_commit_enabled()
+        if use_live_commit:
+            logger.info("Using live commit model version")
+        else:
+            logger.info("Using non-live commit model version")
+
+        if use_live_commit:
+            # This step creates the live model version, and the files can be written directly to the stage
+            # after this.
+            try:
+                self._model_ops.add_or_create_live_version(
+                    database_name=database_name_id,
+                    schema_name=schema_name_id,
+                    model_name=model_name_id,
+                    version_name=version_name_id,
+                    statement_params=statement_params,
+                )
+            except (AssertionError, snowpark_exceptions.SnowparkSQLException) as e:
+                logger.info(f"Failed to create live model version: {e}, falling back to regular model version creation")
+                use_live_commit = False
+
+        if use_live_commit:
+            # using model version's stage path to write files directly to the stage
+            stage_path = self._model_ops.get_model_version_stage_path(
+                database_name=database_name_id,
+                schema_name=schema_name_id,
+                model_name=model_name_id,
+                version_name=version_name_id,
+            )
+        else:
+            # using a temp path to write files and then upload to the model version's stage
+            stage_path = self._model_ops.prepare_model_temp_stage_path(
+                database_name=database_name_id,
+                schema_name=schema_name_id,
+                statement_params=statement_params,
+            )

         platforms = None
         # User specified target platforms are defaulted to None and will not show up in the generated manifest.
@@ -175,6 +209,18 @@ class ModelManager:
         # Convert any string target platforms to TargetPlatform objects
         platforms = [model_types.TargetPlatform(platform) for platform in target_platforms]

+        if artifact_repository_map:
+            for channel, artifact_repository_name in artifact_repository_map.items():
+                db_id, schema_id, repo_id = sql_identifier.parse_fully_qualified_name(artifact_repository_name)
+
+                artifact_repository_map[channel] = sql_identifier.get_fully_qualified_name(
+                    db_id,
+                    schema_id,
+                    repo_id,
+                    self._database_name,
+                    self._schema_name,
+                )
+
         logger.info("Start packaging and uploading your model. It might take some time based on the size of the model.")

         mc = model_composer.ModelComposer(
@@ -187,6 +233,7 @@ class ModelManager:
             sample_input_data=sample_input_data,
             conda_dependencies=conda_dependencies,
             pip_requirements=pip_requirements,
+            artifact_repository_map=artifact_repository_map,
             target_platforms=platforms,
             python_version=python_version,
             user_files=user_files,
@@ -211,6 +258,7 @@ class ModelManager:
             model_name=model_name_id,
             version_name=version_name_id,
             statement_params=statement_params,
+            use_live_commit=use_live_commit,
         )

         mv = model_version_impl.ModelVersion._ref(
snowflake/ml/registry/registry.py
CHANGED
@@ -108,12 +108,15 @@ class Registry:
         metrics: Optional[Dict[str, Any]] = None,
         conda_dependencies: Optional[List[str]] = None,
         pip_requirements: Optional[List[str]] = None,
+        artifact_repository_map: Optional[Dict[str, str]] = None,
         target_platforms: Optional[List[model_types.SupportedTargetPlatformType]] = None,
         python_version: Optional[str] = None,
         signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
         sample_input_data: Optional[model_types.SupportedDataType] = None,
+        user_files: Optional[Dict[str, List[str]]] = None,
         code_paths: Optional[List[str]] = None,
         ext_modules: Optional[List[ModuleType]] = None,
+        task: model_types.Task = model_types.Task.UNKNOWN,
         options: Optional[model_types.ModelSaveOption] = None,
     ) -> ModelVersion:
         """
@@ -140,6 +143,13 @@ class Registry:
                 See https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container for more.
                 Models with pip requirements specified will not be executable in Snowflake Warehouse where all
                 dependencies must be retrieved from Snowflake Anaconda Channel.
+            artifact_repository_map: Specifies a mapping of package channels or platforms to custom artifact
+                repositories. Defaults to None. Currently, the mapping applies only to warehouse execution.
+                Note : This feature is currently in Private Preview; please contact your Snowflake account team
+                to enable it.
+                Format: {channel_name: artifact_repository_name}, where:
+                - channel_name: The name of the Conda package channel (e.g., 'condaforge') or 'pip' for pip packages.
+                - artifact_repository_name: The name or URL of the repository to fetch packages from.
             target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
                 {"WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"}. Defaults to None.
             python_version: Python version in which the model is run. Defaults to None.
@@ -148,10 +158,15 @@ class Registry:
                 infer the signature. If not None, sample_input_data should not be specified. Defaults to None.
             sample_input_data: Sample input data to infer model signatures from.
                 It would also be used as background data in explanation and to capture data lineage. Defaults to None.
+            user_files: Dictionary where the keys are subdirectories, and values are lists of local file name
+                strings. The local file name strings can include wildcards (? or *) for matching multiple files.
             code_paths: List of directories containing code to import. Defaults to None.
             ext_modules: List of external modules to pickle with the model object.
                 Only supported when logging the following types of model:
                 Scikit-learn, Snowpark ML, PyTorch, TorchScript and Custom Model. Defaults to None.
+            task: The task of the Model Version. It is an enum class Task with values TABULAR_REGRESSION,
+                TABULAR_BINARY_CLASSIFICATION, TABULAR_MULTI_CLASSIFICATION, TABULAR_RANKING, or UNKNOWN. By default,
+                it is set to Task.UNKNOWN and may be overridden by inferring from the Model Object.
             options (Dict[str, Any], optional): Additional model saving options.

             Model Saving Options include:
@@ -163,6 +178,9 @@ class Registry:
                 Warehouse. It detects any ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to True.
             - function_type: Set the method function type globally. To set method function types individually see
               function_type in model_options.
+            - target_methods: List of target methods to register when logging the model.
+              This option is not used in MLFlow models. Defaults to None, in which case the model handler's
+              default target methods will be used.
             - method_options: Per-method saving options. This dictionary has method names as keys and dictionary
               values with the desired options.

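A hedged sketch combining the new log_model arguments documented above; the registry and model objects, the model name, and the repository path are hypothetical, the Task import path is assumed, and artifact repositories are Private Preview and must be enabled for the account:

from snowflake.ml.model import type_hints as model_types

mv = registry.log_model(
    model=sklearn_model,
    model_name="MY_MODEL",
    version_name="V1",
    artifact_repository_map={"pip": "MY_DB.MY_SCHEMA.MY_PIP_REPO"},
    user_files={"docs": ["README.md", "img/*.png"]},   # subdir -> local files, wildcards allowed
    task=model_types.Task.TABULAR_REGRESSION,
)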
@@ -210,6 +228,7 @@ class Registry:
             "metrics",
             "conda_dependencies",
             "pip_requirements",
+            "artifact_repository_map",
             "target_platforms",
             "python_version",
             "signatures",
@@ -225,6 +244,7 @@ class Registry:
         metrics: Optional[Dict[str, Any]] = None,
         conda_dependencies: Optional[List[str]] = None,
         pip_requirements: Optional[List[str]] = None,
+        artifact_repository_map: Optional[Dict[str, str]] = None,
         target_platforms: Optional[List[model_types.SupportedTargetPlatformType]] = None,
         python_version: Optional[str] = None,
         signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
@@ -259,6 +279,13 @@ class Registry:
                 See https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container for more.
                 Models with pip requirements specified will not be executable in Snowflake Warehouse where all
                 dependencies must be retrieved from Snowflake Anaconda Channel.
+            artifact_repository_map: Specifies a mapping of package channels or platforms to custom artifact
+                repositories. Defaults to None. Currently, the mapping applies only to warehouse execution.
+                Note : This feature is currently in Private Preview; please contact your Snowflake account team to
+                enable it.
+                Format: {channel_name: artifact_repository_name}, where:
+                - channel_name: The name of the Conda package channel (e.g., 'condaforge') or 'pip' for pip packages.
+                - artifact_repository_name: The name or URL of the repository to fetch packages from.
             target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
                 {"WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"}. Defaults to None.
             python_version: Python version in which the model is run. Defaults to None.
@@ -287,6 +314,9 @@ class Registry:
                 Warehouse. It detects any ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to True.
             - function_type: Set the method function type globally. To set method function types individually see
               function_type in model_options.
+            - target_methods: List of target methods to register when logging the model.
+              This option is not used in MLFlow models. Defaults to None, in which case the model handler's
+              default target methods will be used.
             - method_options: Per-method saving options. This dictionary has method names as keys and dictionary
               values with the desired options. See the example below.

@@ -300,6 +330,9 @@ class Registry:
                 Defaults to None, determined automatically by Snowflake.
             - function_type: One of supported model method function types (FUNCTION or TABLE_FUNCTION).

+        Raises:
+            ValueError: If extra arguments are specified ModelVersion is provided.
+
         Returns:
             ModelVersion: ModelVersion object corresponding to the model just logged.

@@ -322,13 +355,37 @@ class Registry:
             registry.log_model(
                 model=model,
                 model_name="my_model",
-
+                options={"method_options": method_options},
             )
         """
         statement_params = telemetry.get_statement_params(
             project=_TELEMETRY_PROJECT,
             subproject=_MODEL_TELEMETRY_SUBPROJECT,
         )
+        if isinstance(model, ModelVersion):
+            # check that no arguments are provided other than the ones for copy model.
+            invalid_args = [
+                comment,
+                conda_dependencies,
+                pip_requirements,
+                artifact_repository_map,
+                target_platforms,
+                python_version,
+                signatures,
+                sample_input_data,
+                user_files,
+                code_paths,
+                ext_modules,
+                options,
+            ]
+            for arg in invalid_args:
+                if arg is not None:
+                    raise ValueError(
+                        "When calling log_model with a ModelVersion, only model_name and version_name may be specified."
+                    )
+            if task is not model_types.Task.UNKNOWN:
+                raise ValueError("`task` cannot be specified when calling log_model with a ModelVersion.")

         if pip_requirements:
             warnings.warn(
                 "Models logged specifying `pip_requirements` can not be executed "
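With this guard, passing a ModelVersion to log_model performs a copy and accepts only the naming arguments; a hedged sketch, with hypothetical model and version names and the get_model/version accessors assumed from the public Registry API:

mv = registry.get_model("MY_MODEL").version("V1")

# Copy: only naming arguments are allowed alongside the ModelVersion.
copy = registry.log_model(model=mv, model_name="MY_MODEL_COPY", version_name="V1")

# Anything else raises ValueError, e.g.:
# registry.log_model(model=mv, model_name="X", options={"relax_version": False})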
@@ -345,6 +402,7 @@ class Registry:
             metrics=metrics,
             conda_dependencies=conda_dependencies,
             pip_requirements=pip_requirements,
+            artifact_repository_map=artifact_repository_map,
             target_platforms=target_platforms,
             python_version=python_version,
             signatures=signatures,
snowflake/ml/version.py
CHANGED
@@ -1 +1 @@
-VERSION="1.7.5"
+VERSION="1.8.1"