snowflake-ml-python 1.9.2__py3-none-any.whl → 1.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,6 +14,7 @@ from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
14
14
  from snowflake.ml.model._model_composer import model_composer
15
15
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
16
16
  from snowflake.ml.model._packager.model_meta import model_meta
17
+ from snowflake.ml.registry._manager import model_parameter_reconciler
17
18
  from snowflake.snowpark import exceptions as snowpark_exceptions, session
18
19
  from snowflake.snowpark._internal import utils as snowpark_utils
19
20
 
@@ -46,6 +47,7 @@ class ModelManager:
46
47
  *,
47
48
  model: Union[type_hints.SupportedModelType, model_version_impl.ModelVersion],
48
49
  model_name: str,
50
+ progress_status: type_hints.ProgressStatus,
49
51
  version_name: Optional[str] = None,
50
52
  comment: Optional[str] = None,
51
53
  metrics: Optional[dict[str, Any]] = None,
@@ -64,7 +66,6 @@ class ModelManager:
64
66
  experiment_info: Optional["ExperimentInfo"] = None,
65
67
  options: Optional[type_hints.ModelSaveOption] = None,
66
68
  statement_params: Optional[dict[str, Any]] = None,
67
- progress_status: Optional[Any] = None,
68
69
  ) -> model_version_impl.ModelVersion:
69
70
 
70
71
  database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
@@ -158,6 +159,7 @@ class ModelManager:
158
159
  *,
159
160
  model_name: str,
160
161
  version_name: str,
162
+ progress_status: type_hints.ProgressStatus,
161
163
  comment: Optional[str] = None,
162
164
  metrics: Optional[dict[str, Any]] = None,
163
165
  conda_dependencies: Optional[list[str]] = None,
@@ -175,7 +177,6 @@ class ModelManager:
175
177
  experiment_info: Optional["ExperimentInfo"] = None,
176
178
  options: Optional[type_hints.ModelSaveOption] = None,
177
179
  statement_params: Optional[dict[str, Any]] = None,
178
- progress_status: Optional[Any] = None,
179
180
  ) -> model_version_impl.ModelVersion:
180
181
  database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
181
182
  version_name_id = sql_identifier.SqlIdentifier(version_name)
@@ -250,27 +251,27 @@ class ModelManager:
250
251
  )
251
252
  platforms = [target_platform.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
252
253
 
253
- if artifact_repository_map:
254
- for channel, artifact_repository_name in artifact_repository_map.items():
255
- db_id, schema_id, repo_id = sql_identifier.parse_fully_qualified_name(artifact_repository_name)
254
+ reconciler = model_parameter_reconciler.ModelParameterReconciler(
255
+ database_name=self._database_name,
256
+ schema_name=self._schema_name,
257
+ conda_dependencies=conda_dependencies,
258
+ pip_requirements=pip_requirements,
259
+ target_platforms=target_platforms,
260
+ artifact_repository_map=artifact_repository_map,
261
+ options=options,
262
+ )
256
263
 
257
- artifact_repository_map[channel] = sql_identifier.get_fully_qualified_name(
258
- db_id,
259
- schema_id,
260
- repo_id,
261
- self._database_name,
262
- self._schema_name,
263
- )
264
+ model_params = reconciler.reconcile()
265
+
266
+ # Use reconciled parameters
267
+ artifact_repository_map = model_params.artifact_repository_map
268
+ save_location = model_params.save_location
264
269
 
265
270
  logger.info("Start packaging and uploading your model. It might take some time based on the size of the model.")
266
- if progress_status:
267
- progress_status.update("packaging model...")
268
- progress_status.increment()
269
-
270
- # Extract save_location from options if present
271
- save_location = None
272
- if options and "save_location" in options:
273
- save_location = options.get("save_location")
271
+ progress_status.update("packaging model...")
272
+ progress_status.increment()
273
+
274
+ if save_location:
274
275
  logger.info(f"Model will be saved to local directory: {save_location}")
275
276
 
276
277
  mc = model_composer.ModelComposer(
@@ -280,9 +281,8 @@ class ModelManager:
280
281
  save_location=save_location,
281
282
  )
282
283
 
283
- if progress_status:
284
- progress_status.update("creating model manifest...")
285
- progress_status.increment()
284
+ progress_status.update("creating model manifest...")
285
+ progress_status.increment()
286
286
 
287
287
  model_metadata: model_meta.ModelMetadata = mc.save(
288
288
  name=model_name_id.resolved(),
@@ -303,9 +303,8 @@ class ModelManager:
303
303
  experiment_info=experiment_info,
304
304
  )
305
305
 
306
- if progress_status:
307
- progress_status.update("uploading model files...")
308
- progress_status.increment()
306
+ progress_status.update("uploading model files...")
307
+ progress_status.increment()
309
308
  statement_params = telemetry.add_statement_params_custom_tags(
310
309
  statement_params, model_metadata.telemetry_metadata()
311
310
  )
@@ -313,10 +312,8 @@ class ModelManager:
313
312
  statement_params, {"model_version_name": version_name_id}
314
313
  )
315
314
 
316
- logger.info("Start creating MODEL object for you in the Snowflake.")
317
- if progress_status:
318
- progress_status.update("creating model object in Snowflake...")
319
- progress_status.increment()
315
+ progress_status.update("creating model object in Snowflake...")
316
+ progress_status.increment()
320
317
 
321
318
  self._model_ops.create_from_stage(
322
319
  composed_model=mc,
@@ -343,9 +340,8 @@ class ModelManager:
343
340
  version_name=version_name_id,
344
341
  )
345
342
 
346
- if progress_status:
347
- progress_status.update("setting model metadata...")
348
- progress_status.increment()
343
+ progress_status.update("setting model metadata...")
344
+ progress_status.increment()
349
345
 
350
346
  if comment:
351
347
  mv.comment = comment
@@ -360,8 +356,7 @@ class ModelManager:
360
356
  statement_params=statement_params,
361
357
  )
362
358
 
363
- if progress_status:
364
- progress_status.update("model logged successfully!")
359
+ progress_status.update("model logged successfully!")
365
360
 
366
361
  return mv
367
362
 
@@ -0,0 +1,105 @@
1
+ import warnings
2
+ from dataclasses import dataclass
3
+ from typing import Optional
4
+
5
+ from snowflake.ml._internal.utils import sql_identifier
6
+ from snowflake.ml.model import type_hints as model_types
7
+
8
+
9
@dataclass
class ReconciledParameters:
    """Holds the log_model parameters after reconciliation.

    When produced by ``ModelParameterReconciler.reconcile()``:
    - only ``artifact_repository_map`` and ``save_location`` are derived values;
    - every other field is the caller's input, passed through unchanged.

    All fields default to ``None``.
    """

    # Conda dependencies; passed through unchanged by reconcile().
    conda_dependencies: Optional[list[str]] = None
    # pip requirements; passed through unchanged by reconcile().
    pip_requirements: Optional[list[str]] = None
    # Requested target platforms; passed through unchanged by reconcile().
    target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None
    # Channel -> artifact repository name. reconcile() rewrites each name via
    # sql_identifier.get_fully_qualified_name, using the reconciler's database/schema
    # (presumably as fallbacks for partially qualified names -- confirm in sql_identifier).
    # reconcile() sets this to None when no map, or an empty map, was supplied.
    artifact_repository_map: Optional[dict[str, str]] = None
    # The caller's save options, unchanged (still contains "save_location" if it was set).
    options: Optional[model_types.ModelSaveOption] = None
    # options["save_location"] as extracted by reconcile(), or None when there are no
    # options or the key is absent.
    save_location: Optional[str] = None
19
+
20
+
21
class ModelParameterReconciler:
    """Collects the post-processing of ``log_model`` parameters in one place.

    Construct it with the raw caller-supplied arguments, plus the database and schema
    used when qualifying artifact repository names. Then call ``reconcile()`` to obtain
    a ``ReconciledParameters`` instance.
    """

    def __init__(
        self,
        database_name: sql_identifier.SqlIdentifier,
        schema_name: sql_identifier.SqlIdentifier,
        conda_dependencies: Optional[list[str]] = None,
        pip_requirements: Optional[list[str]] = None,
        target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
        artifact_repository_map: Optional[dict[str, str]] = None,
        options: Optional[model_types.ModelSaveOption] = None,
    ) -> None:
        """Store the raw parameters; nothing is processed until ``reconcile()``.

        Args:
            database_name: Database passed to ``get_fully_qualified_name`` for repository names.
            schema_name: Schema passed to ``get_fully_qualified_name`` for repository names.
            conda_dependencies: Conda dependencies, passed through unchanged.
            pip_requirements: pip requirements, passed through unchanged.
            target_platforms: Target platforms. ``None`` counts as targeting the warehouse.
            artifact_repository_map: Channel -> artifact repository name. Not mutated.
            options: Model save options. May contain ``"save_location"``.
        """
        # Context used to qualify artifact repository names.
        self._database_name = database_name
        self._schema_name = schema_name

        # Raw caller inputs, kept verbatim until reconcile().
        self._conda_dependencies = conda_dependencies
        self._pip_requirements = pip_requirements
        self._target_platforms = target_platforms
        self._artifact_repository_map = artifact_repository_map
        self._options = options

    def reconcile(self) -> ReconciledParameters:
        """Derive the final parameter set and emit any compatibility warning.

        Returns:
            A ``ReconciledParameters`` with a qualified repository map and the extracted
            save location. The other fields are the constructor inputs, unchanged.
        """
        qualified_repo_map = self._reconcile_artifact_repository_map()
        location = self._extract_save_location()

        # May emit a UserWarning; it does not alter any of the parameters.
        self._validate_pip_requirements_warehouse_compatibility(qualified_repo_map)

        return ReconciledParameters(
            artifact_repository_map=qualified_repo_map,
            save_location=location,
            conda_dependencies=self._conda_dependencies,
            pip_requirements=self._pip_requirements,
            target_platforms=self._target_platforms,
            options=self._options,
        )

    def _reconcile_artifact_repository_map(self) -> Optional[dict[str, str]]:
        """Return a new map whose repository names are fully qualified.

        Returns:
            ``None`` when no map (or an empty map) was supplied; otherwise a fresh dict.
            The caller's dict is left untouched.
        """
        source_map = self._artifact_repository_map
        if not source_map:
            return None

        def qualify(repo_name: str) -> str:
            # Split the user-provided name, then rebuild it with this reconciler's
            # database/schema supplied to get_fully_qualified_name.
            db_part, schema_part, repo_part = sql_identifier.parse_fully_qualified_name(repo_name)
            return sql_identifier.get_fully_qualified_name(
                db_part,
                schema_part,
                repo_part,
                self._database_name,
                self._schema_name,
            )

        return {channel: qualify(repo_name) for channel, repo_name in source_map.items()}

    def _extract_save_location(self) -> Optional[str]:
        """Return ``options["save_location"]``, or ``None`` when unavailable."""
        if not self._options:
            return None
        return self._options.get("save_location")

    def _validate_pip_requirements_warehouse_compatibility(
        self, artifact_repository_map: Optional[dict[str, str]]
    ) -> None:
        """Warn when pip requirements are given for a warehouse target without a repository map.

        Args:
            artifact_repository_map: The reconciled map. Any truthy map suppresses the warning.
        """
        # Nothing to warn about without pip requirements, or when a repository map exists.
        if not self._pip_requirements or artifact_repository_map:
            return
        if not self._targets_warehouse(self._target_platforms):
            return
        warnings.warn(
            "Models logged specifying `pip_requirements` cannot be executed in a Snowflake Warehouse "
            "without specifying `artifact_repository_map`. This model can be run in Snowpark Container "
            "Services. See https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container.",
            category=UserWarning,
            stacklevel=1,
        )

    @staticmethod
    def _targets_warehouse(target_platforms: Optional[list[model_types.SupportedTargetPlatformType]]) -> bool:
        """Report whether the warehouse is among the target platforms.

        Args:
            target_platforms: Platforms as enum members or strings. ``None`` means unspecified.

        Returns:
            ``True`` for ``None``, or when the enum member or the string ``"WAREHOUSE"`` is present.
        """
        if target_platforms is None:
            # Unspecified platforms are treated as including the warehouse.
            return True
        return model_types.TargetPlatform.WAREHOUSE in target_platforms or "WAREHOUSE" in target_platforms
@@ -1,4 +1,3 @@
1
- import warnings
2
1
  from types import ModuleType
3
2
  from typing import Any, Optional, Union, overload
4
3
 
@@ -442,15 +441,6 @@ class Registry:
442
441
  if task is not type_hints.Task.UNKNOWN:
443
442
  raise ValueError("`task` cannot be specified when calling log_model with a ModelVersion.")
444
443
 
445
- if pip_requirements and not artifact_repository_map and self._targets_warehouse(target_platforms):
446
- warnings.warn(
447
- "Models logged specifying `pip_requirements` cannot be executed in a Snowflake Warehouse "
448
- "without specifying `artifact_repository_map`. This model can be run in Snowpark Container "
449
- "Services. See https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container.",
450
- category=UserWarning,
451
- stacklevel=1,
452
- )
453
-
454
444
  registry_event_handler = event_handler.ModelEventHandler()
455
445
  with registry_event_handler.status("Logging model", total=6) as status:
456
446
  # Step 1: Validation and setup
@@ -662,12 +652,3 @@ class Registry:
662
652
  if not self.enable_monitoring:
663
653
  raise ValueError(_MODEL_MONITORING_DISABLED_ERROR)
664
654
  self._model_monitor_manager.delete_monitor(name)
665
-
666
- @staticmethod
667
- def _targets_warehouse(target_platforms: Optional[list[type_hints.SupportedTargetPlatformType]]) -> bool:
668
- """Returns True if warehouse is a target platform (None defaults to True)."""
669
- return (
670
- target_platforms is None
671
- or type_hints.TargetPlatform.WAREHOUSE in target_platforms
672
- or "WAREHOUSE" in target_platforms
673
- )
snowflake/ml/version.py CHANGED
@@ -1,2 +1,2 @@
1
1
  # This is parsed by regex in conda recipe meta file. Make sure not to break it.
2
- VERSION = "1.9.2"
2
+ VERSION = "1.10.0"