snowflake-ml-python 1.7.5__py3-none-any.whl → 1.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. snowflake/cortex/_complete.py +58 -3
  2. snowflake/ml/_internal/file_utils.py +18 -4
  3. snowflake/ml/_internal/platform_capabilities.py +3 -0
  4. snowflake/ml/_internal/telemetry.py +4 -0
  5. snowflake/ml/fileset/fileset.py +0 -1
  6. snowflake/ml/jobs/_utils/constants.py +25 -1
  7. snowflake/ml/jobs/_utils/payload_utils.py +94 -20
  8. snowflake/ml/jobs/_utils/spec_utils.py +95 -31
  9. snowflake/ml/jobs/decorators.py +7 -0
  10. snowflake/ml/jobs/manager.py +20 -0
  11. snowflake/ml/model/_client/model/model_version_impl.py +1 -1
  12. snowflake/ml/model/_client/ops/model_ops.py +113 -17
  13. snowflake/ml/model/_client/ops/service_ops.py +16 -5
  14. snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
  15. snowflake/ml/model/_client/sql/model_version.py +58 -0
  16. snowflake/ml/model/_client/sql/service.py +10 -2
  17. snowflake/ml/model/_model_composer/model_composer.py +50 -3
  18. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +5 -2
  19. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
  20. snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
  21. snowflake/ml/model/_packager/model_env/model_env.py +4 -1
  22. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +28 -24
  23. snowflake/ml/model/_packager/model_handlers/keras.py +1 -5
  24. snowflake/ml/model/_packager/model_handlers/pytorch.py +50 -20
  25. snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -8
  26. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -2
  27. snowflake/ml/model/_packager/model_handlers/tensorflow.py +46 -26
  28. snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
  29. snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
  30. snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
  31. snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
  32. snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +1 -2
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +5 -1
  35. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +14 -0
  36. snowflake/ml/model/_packager/model_packager.py +3 -5
  37. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
  38. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -0
  39. snowflake/ml/model/_signatures/builtins_handler.py +20 -9
  40. snowflake/ml/model/_signatures/core.py +52 -31
  41. snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
  42. snowflake/ml/model/_signatures/numpy_handler.py +9 -17
  43. snowflake/ml/model/_signatures/pandas_handler.py +19 -30
  44. snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
  45. snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
  46. snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
  47. snowflake/ml/model/_signatures/utils.py +120 -8
  48. snowflake/ml/model/custom_model.py +13 -4
  49. snowflake/ml/model/model_signature.py +31 -13
  50. snowflake/ml/model/type_hints.py +13 -2
  51. snowflake/ml/modeling/_internal/estimator_utils.py +5 -1
  52. snowflake/ml/modeling/metrics/ranking.py +3 -0
  53. snowflake/ml/modeling/metrics/regression.py +3 -0
  54. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
  55. snowflake/ml/registry/_manager/model_manager.py +55 -7
  56. snowflake/ml/registry/registry.py +59 -1
  57. snowflake/ml/version.py +1 -1
  58. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/METADATA +308 -12
  59. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/RECORD +62 -58
  60. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/WHEEL +1 -1
  61. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info/licenses}/LICENSE.txt +0 -0
  62. {snowflake_ml_python-1.7.5.dist-info → snowflake_ml_python-1.8.1.dist-info}/top_level.txt +0 -0
@@ -32,6 +32,7 @@ from snowflake.ml.model._signatures import (
32
32
  base_handler,
33
33
  builtins_handler as builtins_handler,
34
34
  core,
35
+ dmatrix_handler,
35
36
  numpy_handler,
36
37
  pandas_handler,
37
38
  pytorch_handler,
@@ -52,8 +53,11 @@ _LOCAL_DATA_HANDLERS: List[Type[base_handler.BaseDataHandler[Any]]] = [
52
53
  numpy_handler.NumpyArrayHandler,
53
54
  builtins_handler.ListOfBuiltinHandler,
54
55
  numpy_handler.SeqOfNumpyArrayHandler,
56
+ pytorch_handler.PyTorchTensorHandler,
55
57
  pytorch_handler.SeqOfPyTorchTensorHandler,
58
+ tensorflow_handler.TensorflowTensorHandler,
56
59
  tensorflow_handler.SeqOfTensorflowTensorHandler,
60
+ dmatrix_handler.XGBoostDMatrixHandler,
57
61
  ]
58
62
  _ALL_DATA_HANDLERS = _LOCAL_DATA_HANDLERS + [snowpark_handler.SnowparkDataFrameHandler]
59
63
 
@@ -218,7 +222,6 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
218
222
  strict: Enable strict validation, this includes value range based validation
219
223
 
220
224
  Raises:
221
- SnowflakeMLException: NotImplementedError: FeatureGroupSpec is not supported.
222
225
  SnowflakeMLException: ValueError: Raised when a feature cannot be found.
223
226
  SnowflakeMLException: ValueError: Raised when feature is scalar but confront list element.
224
227
  SnowflakeMLException: ValueError: Raised when feature type is not aligned in list element.
@@ -236,7 +239,10 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
236
239
  except KeyError:
237
240
  raise snowml_exceptions.SnowflakeMLException(
238
241
  error_code=error_codes.INVALID_DATA,
239
- original_exception=ValueError(f"Data Validation Error: feature {ft_name} does not exist in data."),
242
+ original_exception=ValueError(
243
+ f"Data Validation Error: feature {ft_name} does not exist in data. "
244
+ f"Available columns are {data.columns}."
245
+ ),
240
246
  )
241
247
 
242
248
  if data_col.isnull().any():
@@ -244,10 +250,15 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
244
250
  df_col_dtype = data_col.dtype
245
251
 
246
252
  if isinstance(feature, core.FeatureGroupSpec):
247
- raise snowml_exceptions.SnowflakeMLException(
248
- error_code=error_codes.NOT_IMPLEMENTED,
249
- original_exception=NotImplementedError("FeatureGroupSpec is not supported."),
250
- )
253
+ if df_col_dtype != np.dtype("O"):
254
+ raise snowml_exceptions.SnowflakeMLException(
255
+ error_code=error_codes.INVALID_DATA,
256
+ original_exception=ValueError(
257
+ f"Data Validation Error in feature group {ft_name}: "
258
+ + f"It needs to be a dictionary or list of dictionary, but get {df_col_dtype}."
259
+ ),
260
+ )
261
+ continue
251
262
 
252
263
  assert isinstance(feature, core.FeatureSpec) # assert for mypy.
253
264
  ft_type = feature._dtype
@@ -437,7 +448,6 @@ def _validate_snowpark_data(
437
448
  strict: Enable strict validation, this includes value range based validation.
438
449
 
439
450
  Raises:
440
- SnowflakeMLException: NotImplementedError: FeatureGroupSpec is not supported.
441
451
  SnowflakeMLException: ValueError: Raised when confronting invalid feature.
442
452
  SnowflakeMLException: ValueError: Raised when a feature cannot be found.
443
453
 
@@ -467,10 +477,15 @@ def _validate_snowpark_data(
467
477
  if field.name == ft_name:
468
478
  found = True
469
479
  if isinstance(feature, core.FeatureGroupSpec):
470
- raise snowml_exceptions.SnowflakeMLException(
471
- error_code=error_codes.NOT_IMPLEMENTED,
472
- original_exception=NotImplementedError("FeatureGroupSpec is not supported."),
473
- )
480
+ if not isinstance(field.datatype, (spt.ArrayType, spt.StructType, spt.VariantType)):
481
+ errors[identifier_rule].append(
482
+ ValueError(
483
+ f"Data Validation Error in feature group {feature.name}: "
484
+ + f"Feature expects {feature.as_snowpark_type()},"
485
+ + f" while {field.name} has type {field.datatype}."
486
+ ),
487
+ )
488
+ continue
474
489
  assert isinstance(feature, core.FeatureSpec) # mypy
475
490
  ft_type = feature._dtype
476
491
  field_data_type = field.datatype
@@ -644,11 +659,14 @@ def _validate_snowpark_type_feature(
644
659
  )
645
660
 
646
661
 
647
- def _convert_local_data_to_df(data: model_types.SupportedLocalDataType) -> pd.DataFrame:
662
+ def _convert_local_data_to_df(
663
+ data: model_types.SupportedLocalDataType, ensure_serializable: bool = False
664
+ ) -> pd.DataFrame:
648
665
  """Convert local data to pandas DataFrame or Snowpark DataFrame
649
666
 
650
667
  Args:
651
668
  data: The provided data.
669
+ ensure_serializable: Ensure the data is serializable. Defaults to False.
652
670
 
653
671
  Raises:
654
672
  SnowflakeMLException: NotImplementedError: Raised when data cannot be handled by any data handler.
@@ -660,7 +678,7 @@ def _convert_local_data_to_df(data: model_types.SupportedLocalDataType) -> pd.Da
660
678
  for handler in _LOCAL_DATA_HANDLERS:
661
679
  if handler.can_handle(data):
662
680
  handler.validate(data)
663
- df = handler.convert_to_df(data, ensure_serializable=False)
681
+ df = handler.convert_to_df(data, ensure_serializable=ensure_serializable)
664
682
  break
665
683
  if df is None:
666
684
  raise snowml_exceptions.SnowflakeMLException(
@@ -26,7 +26,15 @@ if TYPE_CHECKING:
26
26
  from snowflake.ml.modeling.framework import base # noqa: F401
27
27
 
28
28
 
29
- _SupportedBuiltins = Union[int, float, bool, str, bytes, "_SupportedBuiltinsList"]
29
+ _SupportedBuiltins = Union[
30
+ int,
31
+ float,
32
+ bool,
33
+ str,
34
+ bytes,
35
+ Dict[str, Union["_SupportedBuiltins", "_SupportedBuiltinsList"]],
36
+ "_SupportedBuiltinsList",
37
+ ]
30
38
  _SupportedNumpyDtype = Union[
31
39
  "np.int8",
32
40
  "np.int16",
@@ -48,7 +56,7 @@ _SupportedBuiltinsList = Sequence[_SupportedBuiltins]
48
56
  _SupportedArrayLike = Union[_SupportedNumpyArray, "torch.Tensor", "tensorflow.Tensor", "tensorflow.Variable"]
49
57
 
50
58
  SupportedLocalDataType = Union[
51
- "pd.DataFrame", _SupportedNumpyArray, Sequence[_SupportedArrayLike], _SupportedBuiltinsList
59
+ "pd.DataFrame", _SupportedArrayLike, Sequence[_SupportedArrayLike], _SupportedBuiltinsList
52
60
  ]
53
61
 
54
62
  SupportedDataType = Union[SupportedLocalDataType, "snowflake.snowpark.DataFrame"]
@@ -177,16 +185,19 @@ class SNOWModelSaveOptions(BaseModelSaveOption):
177
185
  class PyTorchSaveOptions(BaseModelSaveOption):
178
186
  target_methods: NotRequired[Sequence[str]]
179
187
  cuda_version: NotRequired[str]
188
+ multiple_inputs: NotRequired[bool]
180
189
 
181
190
 
182
191
  class TorchScriptSaveOptions(BaseModelSaveOption):
183
192
  target_methods: NotRequired[Sequence[str]]
184
193
  cuda_version: NotRequired[str]
194
+ multiple_inputs: NotRequired[bool]
185
195
 
186
196
 
187
197
  class TensorflowSaveOptions(BaseModelSaveOption):
188
198
  target_methods: NotRequired[Sequence[str]]
189
199
  cuda_version: NotRequired[str]
200
+ multiple_inputs: NotRequired[bool]
190
201
 
191
202
 
192
203
  class MLFlowSaveOptions(BaseModelSaveOption):
@@ -130,7 +130,11 @@ def is_single_node(session: Session) -> bool:
130
130
  warehouse_name = session.get_current_warehouse()
131
131
  if warehouse_name:
132
132
  warehouse_name = warehouse_name.replace('"', "")
133
- df = session.sql(f"SHOW WAREHOUSES like '{warehouse_name}';")['"type"', '"size"'].collect()[0]
133
+ df_list = session.sql(f"SHOW WAREHOUSES like '{warehouse_name}';")['"type"', '"size"'].collect()
134
+ # If no warehouse data is found, default to True (single node)
135
+ if not df_list:
136
+ return True
137
+ df = df_list[0]
134
138
  # filter out the conditions when it is single node
135
139
  single_node: bool = (df[0] == "SNOWPARK-OPTIMIZED" and df[1] == "Medium") or (
136
140
  df[0] == "STANDARD" and df[1] == "X-Small"
@@ -98,6 +98,7 @@ def precision_recall_curve(
98
98
  packages=[
99
99
  f"cloudpickle=={cloudpickle.__version__}",
100
100
  f"scikit-learn=={sklearn_release[0]}.{sklearn_release[1]}.*",
101
+ f"numpy=={np.__version__}",
101
102
  "snowflake-snowpark-python",
102
103
  ],
103
104
  statement_params=statement_params,
@@ -245,6 +246,7 @@ def roc_auc_score(
245
246
  packages=[
246
247
  f"cloudpickle=={cloudpickle.__version__}",
247
248
  f"scikit-learn=={sklearn_release[0]}.{sklearn_release[1]}.*",
249
+ f"numpy=={np.__version__}",
248
250
  "snowflake-snowpark-python",
249
251
  ],
250
252
  statement_params=statement_params,
@@ -348,6 +350,7 @@ def roc_curve(
348
350
  packages=[
349
351
  f"cloudpickle=={cloudpickle.__version__}",
350
352
  f"scikit-learn=={sklearn_release[0]}.{sklearn_release[1]}.*",
353
+ f"numpy=={np.__version__}",
351
354
  "snowflake-snowpark-python",
352
355
  ],
353
356
  statement_params=statement_params,
@@ -83,6 +83,7 @@ def d2_absolute_error_score(
83
83
  packages=[
84
84
  f"cloudpickle=={cloudpickle.__version__}",
85
85
  f"scikit-learn=={sklearn_release[0]}.{sklearn_release[1]}.*",
86
+ f"numpy=={np.__version__}",
86
87
  "snowflake-snowpark-python",
87
88
  ],
88
89
  statement_params=statement_params,
@@ -180,6 +181,7 @@ def d2_pinball_score(
180
181
  packages=[
181
182
  f"cloudpickle=={cloudpickle.__version__}",
182
183
  f"scikit-learn=={sklearn_release[0]}.{sklearn_release[1]}.*",
184
+ f"numpy=={np.__version__}",
183
185
  "snowflake-snowpark-python",
184
186
  ],
185
187
  statement_params=statement_params,
@@ -295,6 +297,7 @@ def explained_variance_score(
295
297
  packages=[
296
298
  f"cloudpickle=={cloudpickle.__version__}",
297
299
  f"scikit-learn=={sklearn_release[0]}.{sklearn_release[1]}.*",
300
+ f"numpy=={np.__version__}",
298
301
  "snowflake-snowpark-python",
299
302
  ],
300
303
  statement_params=statement_params,
@@ -341,7 +341,7 @@ class KBinsDiscretizer(base.BaseTransformer):
341
341
  is_permanent=False,
342
342
  name=udf_name,
343
343
  replace=True,
344
- packages=["numpy"],
344
+ packages=[f"numpy=={np.__version__}"],
345
345
  session=dataset._session,
346
346
  statement_params=telemetry.get_statement_params(base.PROJECT, base.SUBPROJECT, self.__class__.__name__),
347
347
  )
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
4
4
  import pandas as pd
5
5
  from absl.logging import logging
6
6
 
7
- from snowflake.ml._internal import telemetry
7
+ from snowflake.ml._internal import platform_capabilities, telemetry
8
8
  from snowflake.ml._internal.exceptions import error_codes, exceptions
9
9
  from snowflake.ml._internal.human_readable_id import hrid_generator
10
10
  from snowflake.ml._internal.utils import sql_identifier
@@ -13,7 +13,7 @@ from snowflake.ml.model._client.model import model_impl, model_version_impl
13
13
  from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
14
14
  from snowflake.ml.model._model_composer import model_composer
15
15
  from snowflake.ml.model._packager.model_meta import model_meta
16
- from snowflake.snowpark import session
16
+ from snowflake.snowpark import exceptions as snowpark_exceptions, session
17
17
 
18
18
  logger = logging.getLogger(__name__)
19
19
 
@@ -46,6 +46,7 @@ class ModelManager:
46
46
  metrics: Optional[Dict[str, Any]] = None,
47
47
  conda_dependencies: Optional[List[str]] = None,
48
48
  pip_requirements: Optional[List[str]] = None,
49
+ artifact_repository_map: Optional[Dict[str, str]] = None,
49
50
  target_platforms: Optional[List[model_types.SupportedTargetPlatformType]] = None,
50
51
  python_version: Optional[str] = None,
51
52
  signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
@@ -127,6 +128,7 @@ class ModelManager:
127
128
  metrics=metrics,
128
129
  conda_dependencies=conda_dependencies,
129
130
  pip_requirements=pip_requirements,
131
+ artifact_repository_map=artifact_repository_map,
130
132
  target_platforms=target_platforms,
131
133
  python_version=python_version,
132
134
  signatures=signatures,
@@ -149,6 +151,7 @@ class ModelManager:
149
151
  metrics: Optional[Dict[str, Any]] = None,
150
152
  conda_dependencies: Optional[List[str]] = None,
151
153
  pip_requirements: Optional[List[str]] = None,
154
+ artifact_repository_map: Optional[Dict[str, str]] = None,
152
155
  target_platforms: Optional[List[model_types.SupportedTargetPlatformType]] = None,
153
156
  python_version: Optional[str] = None,
154
157
  signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
@@ -163,11 +166,42 @@ class ModelManager:
163
166
  database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
164
167
  version_name_id = sql_identifier.SqlIdentifier(version_name)
165
168
 
166
- stage_path = self._model_ops.prepare_model_stage_path(
167
- database_name=database_name_id,
168
- schema_name=schema_name_id,
169
- statement_params=statement_params,
170
- )
169
+ use_live_commit = platform_capabilities.PlatformCapabilities.get_instance().is_live_commit_enabled()
170
+ if use_live_commit:
171
+ logger.info("Using live commit model version")
172
+ else:
173
+ logger.info("Using non-live commit model version")
174
+
175
+ if use_live_commit:
176
+ # This step creates the live model version, and the files can be written directly to the stage
177
+ # after this.
178
+ try:
179
+ self._model_ops.add_or_create_live_version(
180
+ database_name=database_name_id,
181
+ schema_name=schema_name_id,
182
+ model_name=model_name_id,
183
+ version_name=version_name_id,
184
+ statement_params=statement_params,
185
+ )
186
+ except (AssertionError, snowpark_exceptions.SnowparkSQLException) as e:
187
+ logger.info(f"Failed to create live model version: {e}, falling back to regular model version creation")
188
+ use_live_commit = False
189
+
190
+ if use_live_commit:
191
+ # using model version's stage path to write files directly to the stage
192
+ stage_path = self._model_ops.get_model_version_stage_path(
193
+ database_name=database_name_id,
194
+ schema_name=schema_name_id,
195
+ model_name=model_name_id,
196
+ version_name=version_name_id,
197
+ )
198
+ else:
199
+ # using a temp path to write files and then upload to the model version's stage
200
+ stage_path = self._model_ops.prepare_model_temp_stage_path(
201
+ database_name=database_name_id,
202
+ schema_name=schema_name_id,
203
+ statement_params=statement_params,
204
+ )
171
205
 
172
206
  platforms = None
173
207
  # User specified target platforms are defaulted to None and will not show up in the generated manifest.
@@ -175,6 +209,18 @@ class ModelManager:
175
209
  # Convert any string target platforms to TargetPlatform objects
176
210
  platforms = [model_types.TargetPlatform(platform) for platform in target_platforms]
177
211
 
212
+ if artifact_repository_map:
213
+ for channel, artifact_repository_name in artifact_repository_map.items():
214
+ db_id, schema_id, repo_id = sql_identifier.parse_fully_qualified_name(artifact_repository_name)
215
+
216
+ artifact_repository_map[channel] = sql_identifier.get_fully_qualified_name(
217
+ db_id,
218
+ schema_id,
219
+ repo_id,
220
+ self._database_name,
221
+ self._schema_name,
222
+ )
223
+
178
224
  logger.info("Start packaging and uploading your model. It might take some time based on the size of the model.")
179
225
 
180
226
  mc = model_composer.ModelComposer(
@@ -187,6 +233,7 @@ class ModelManager:
187
233
  sample_input_data=sample_input_data,
188
234
  conda_dependencies=conda_dependencies,
189
235
  pip_requirements=pip_requirements,
236
+ artifact_repository_map=artifact_repository_map,
190
237
  target_platforms=platforms,
191
238
  python_version=python_version,
192
239
  user_files=user_files,
@@ -211,6 +258,7 @@ class ModelManager:
211
258
  model_name=model_name_id,
212
259
  version_name=version_name_id,
213
260
  statement_params=statement_params,
261
+ use_live_commit=use_live_commit,
214
262
  )
215
263
 
216
264
  mv = model_version_impl.ModelVersion._ref(
@@ -108,12 +108,15 @@ class Registry:
108
108
  metrics: Optional[Dict[str, Any]] = None,
109
109
  conda_dependencies: Optional[List[str]] = None,
110
110
  pip_requirements: Optional[List[str]] = None,
111
+ artifact_repository_map: Optional[Dict[str, str]] = None,
111
112
  target_platforms: Optional[List[model_types.SupportedTargetPlatformType]] = None,
112
113
  python_version: Optional[str] = None,
113
114
  signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
114
115
  sample_input_data: Optional[model_types.SupportedDataType] = None,
116
+ user_files: Optional[Dict[str, List[str]]] = None,
115
117
  code_paths: Optional[List[str]] = None,
116
118
  ext_modules: Optional[List[ModuleType]] = None,
119
+ task: model_types.Task = model_types.Task.UNKNOWN,
117
120
  options: Optional[model_types.ModelSaveOption] = None,
118
121
  ) -> ModelVersion:
119
122
  """
@@ -140,6 +143,13 @@ class Registry:
140
143
  See https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container for more.
141
144
  Models with pip requirements specified will not be executable in Snowflake Warehouse where all
142
145
  dependencies must be retrieved from Snowflake Anaconda Channel.
146
+ artifact_repository_map: Specifies a mapping of package channels or platforms to custom artifact
147
+ repositories. Defaults to None. Currently, the mapping applies only to warehouse execution.
148
+ Note : This feature is currently in Private Preview; please contact your Snowflake account team
149
+ to enable it.
150
+ Format: {channel_name: artifact_repository_name}, where:
151
+ - channel_name: The name of the Conda package channel (e.g., 'condaforge') or 'pip' for pip packages.
152
+ - artifact_repository_name: The name or URL of the repository to fetch packages from.
143
153
  target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
144
154
  {"WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"}. Defaults to None.
145
155
  python_version: Python version in which the model is run. Defaults to None.
@@ -148,10 +158,15 @@ class Registry:
148
158
  infer the signature. If not None, sample_input_data should not be specified. Defaults to None.
149
159
  sample_input_data: Sample input data to infer model signatures from.
150
160
  It would also be used as background data in explanation and to capture data lineage. Defaults to None.
161
+ user_files: Dictionary where the keys are subdirectories, and values are lists of local file name
162
+ strings. The local file name strings can include wildcards (? or *) for matching multiple files.
151
163
  code_paths: List of directories containing code to import. Defaults to None.
152
164
  ext_modules: List of external modules to pickle with the model object.
153
165
  Only supported when logging the following types of model:
154
166
  Scikit-learn, Snowpark ML, PyTorch, TorchScript and Custom Model. Defaults to None.
167
+ task: The task of the Model Version. It is an enum class Task with values TABULAR_REGRESSION,
168
+ TABULAR_BINARY_CLASSIFICATION, TABULAR_MULTI_CLASSIFICATION, TABULAR_RANKING, or UNKNOWN. By default,
169
+ it is set to Task.UNKNOWN and may be overridden by inferring from the Model Object.
155
170
  options (Dict[str, Any], optional): Additional model saving options.
156
171
 
157
172
  Model Saving Options include:
@@ -163,6 +178,9 @@ class Registry:
163
178
  Warehouse. It detects any ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to True.
164
179
  - function_type: Set the method function type globally. To set method function types individually see
165
180
  function_type in model_options.
181
+ - target_methods: List of target methods to register when logging the model.
182
+ This option is not used in MLFlow models. Defaults to None, in which case the model handler's
183
+ default target methods will be used.
166
184
  - method_options: Per-method saving options. This dictionary has method names as keys and dictionary
167
185
  values with the desired options.
168
186
 
@@ -210,6 +228,7 @@ class Registry:
210
228
  "metrics",
211
229
  "conda_dependencies",
212
230
  "pip_requirements",
231
+ "artifact_repository_map",
213
232
  "target_platforms",
214
233
  "python_version",
215
234
  "signatures",
@@ -225,6 +244,7 @@ class Registry:
225
244
  metrics: Optional[Dict[str, Any]] = None,
226
245
  conda_dependencies: Optional[List[str]] = None,
227
246
  pip_requirements: Optional[List[str]] = None,
247
+ artifact_repository_map: Optional[Dict[str, str]] = None,
228
248
  target_platforms: Optional[List[model_types.SupportedTargetPlatformType]] = None,
229
249
  python_version: Optional[str] = None,
230
250
  signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
@@ -259,6 +279,13 @@ class Registry:
259
279
  See https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container for more.
260
280
  Models with pip requirements specified will not be executable in Snowflake Warehouse where all
261
281
  dependencies must be retrieved from Snowflake Anaconda Channel.
282
+ artifact_repository_map: Specifies a mapping of package channels or platforms to custom artifact
283
+ repositories. Defaults to None. Currently, the mapping applies only to warehouse execution.
284
+ Note : This feature is currently in Private Preview; please contact your Snowflake account team to
285
+ enable it.
286
+ Format: {channel_name: artifact_repository_name}, where:
287
+ - channel_name: The name of the Conda package channel (e.g., 'condaforge') or 'pip' for pip packages.
288
+ - artifact_repository_name: The name or URL of the repository to fetch packages from.
262
289
  target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
263
290
  {"WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"}. Defaults to None.
264
291
  python_version: Python version in which the model is run. Defaults to None.
@@ -287,6 +314,9 @@ class Registry:
287
314
  Warehouse. It detects any ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to True.
288
315
  - function_type: Set the method function type globally. To set method function types individually see
289
316
  function_type in model_options.
317
+ - target_methods: List of target methods to register when logging the model.
318
+ This option is not used in MLFlow models. Defaults to None, in which case the model handler's
319
+ default target methods will be used.
290
320
  - method_options: Per-method saving options. This dictionary has method names as keys and dictionary
291
321
  values with the desired options. See the example below.
292
322
 
@@ -300,6 +330,9 @@ class Registry:
300
330
  Defaults to None, determined automatically by Snowflake.
301
331
  - function_type: One of supported model method function types (FUNCTION or TABLE_FUNCTION).
302
332
 
333
+ Raises:
334
+ ValueError: If extra arguments are specified ModelVersion is provided.
335
+
303
336
  Returns:
304
337
  ModelVersion: ModelVersion object corresponding to the model just logged.
305
338
 
@@ -322,13 +355,37 @@ class Registry:
322
355
  registry.log_model(
323
356
  model=model,
324
357
  model_name="my_model",
325
- method_options=method_options,
358
+ options={"method_options": method_options},
326
359
  )
327
360
  """
328
361
  statement_params = telemetry.get_statement_params(
329
362
  project=_TELEMETRY_PROJECT,
330
363
  subproject=_MODEL_TELEMETRY_SUBPROJECT,
331
364
  )
365
+ if isinstance(model, ModelVersion):
366
+ # check that no arguments are provided other than the ones for copy model.
367
+ invalid_args = [
368
+ comment,
369
+ conda_dependencies,
370
+ pip_requirements,
371
+ artifact_repository_map,
372
+ target_platforms,
373
+ python_version,
374
+ signatures,
375
+ sample_input_data,
376
+ user_files,
377
+ code_paths,
378
+ ext_modules,
379
+ options,
380
+ ]
381
+ for arg in invalid_args:
382
+ if arg is not None:
383
+ raise ValueError(
384
+ "When calling log_model with a ModelVersion, only model_name and version_name may be specified."
385
+ )
386
+ if task is not model_types.Task.UNKNOWN:
387
+ raise ValueError("`task` cannot be specified when calling log_model with a ModelVersion.")
388
+
332
389
  if pip_requirements:
333
390
  warnings.warn(
334
391
  "Models logged specifying `pip_requirements` can not be executed "
@@ -345,6 +402,7 @@ class Registry:
345
402
  metrics=metrics,
346
403
  conda_dependencies=conda_dependencies,
347
404
  pip_requirements=pip_requirements,
405
+ artifact_repository_map=artifact_repository_map,
348
406
  target_platforms=target_platforms,
349
407
  python_version=python_version,
350
408
  signatures=signatures,
snowflake/ml/version.py CHANGED
@@ -1 +1 @@
1
- VERSION="1.7.5"
1
+ VERSION="1.8.1"