snowflake-ml-python 1.9.2__py3-none-any.whl → 1.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37):
  1. snowflake/ml/_internal/utils/service_logger.py +31 -17
  2. snowflake/ml/experiment/callback/keras.py +63 -0
  3. snowflake/ml/experiment/callback/lightgbm.py +59 -0
  4. snowflake/ml/experiment/callback/xgboost.py +67 -0
  5. snowflake/ml/experiment/utils.py +14 -0
  6. snowflake/ml/jobs/_utils/__init__.py +0 -0
  7. snowflake/ml/jobs/_utils/constants.py +4 -1
  8. snowflake/ml/jobs/_utils/payload_utils.py +55 -21
  9. snowflake/ml/jobs/_utils/query_helper.py +5 -1
  10. snowflake/ml/jobs/_utils/runtime_env_utils.py +63 -0
  11. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +2 -2
  12. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +5 -5
  13. snowflake/ml/jobs/_utils/spec_utils.py +41 -8
  14. snowflake/ml/jobs/_utils/stage_utils.py +22 -9
  15. snowflake/ml/jobs/_utils/types.py +5 -7
  16. snowflake/ml/jobs/job.py +1 -1
  17. snowflake/ml/jobs/manager.py +1 -13
  18. snowflake/ml/model/_client/model/model_version_impl.py +219 -55
  19. snowflake/ml/model/_client/ops/service_ops.py +230 -30
  20. snowflake/ml/model/_client/service/model_deployment_spec.py +103 -27
  21. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +11 -5
  22. snowflake/ml/model/_model_composer/model_composer.py +1 -70
  23. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +2 -43
  24. snowflake/ml/model/event_handler.py +87 -18
  25. snowflake/ml/model/inference_engine.py +5 -0
  26. snowflake/ml/model/models/huggingface_pipeline.py +74 -51
  27. snowflake/ml/model/type_hints.py +26 -1
  28. snowflake/ml/registry/_manager/model_manager.py +37 -70
  29. snowflake/ml/registry/_manager/model_parameter_reconciler.py +294 -0
  30. snowflake/ml/registry/registry.py +0 -19
  31. snowflake/ml/version.py +1 -1
  32. {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/METADATA +523 -491
  33. {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/RECORD +36 -29
  34. snowflake/ml/experiment/callback.py +0 -121
  35. {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/WHEEL +0 -0
  36. {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/licenses/LICENSE.txt +0 -0
  37. {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/top_level.txt +0 -0
@@ -4,16 +4,16 @@ from typing import TYPE_CHECKING, Any, Optional, Union
4
4
  import pandas as pd
5
5
  from absl.logging import logging
6
6
 
7
- from snowflake.ml._internal import env, platform_capabilities, telemetry
7
+ from snowflake.ml._internal import platform_capabilities, telemetry
8
8
  from snowflake.ml._internal.exceptions import error_codes, exceptions
9
9
  from snowflake.ml._internal.human_readable_id import hrid_generator
10
10
  from snowflake.ml._internal.utils import sql_identifier
11
- from snowflake.ml.model import model_signature, target_platform, task, type_hints
11
+ from snowflake.ml.model import model_signature, task, type_hints
12
12
  from snowflake.ml.model._client.model import model_impl, model_version_impl
13
13
  from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
14
14
  from snowflake.ml.model._model_composer import model_composer
15
- from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
16
15
  from snowflake.ml.model._packager.model_meta import model_meta
16
+ from snowflake.ml.registry._manager import model_parameter_reconciler
17
17
  from snowflake.snowpark import exceptions as snowpark_exceptions, session
18
18
  from snowflake.snowpark._internal import utils as snowpark_utils
19
19
 
@@ -46,6 +46,7 @@ class ModelManager:
46
46
  *,
47
47
  model: Union[type_hints.SupportedModelType, model_version_impl.ModelVersion],
48
48
  model_name: str,
49
+ progress_status: type_hints.ProgressStatus,
49
50
  version_name: Optional[str] = None,
50
51
  comment: Optional[str] = None,
51
52
  metrics: Optional[dict[str, Any]] = None,
@@ -64,7 +65,6 @@ class ModelManager:
64
65
  experiment_info: Optional["ExperimentInfo"] = None,
65
66
  options: Optional[type_hints.ModelSaveOption] = None,
66
67
  statement_params: Optional[dict[str, Any]] = None,
67
- progress_status: Optional[Any] = None,
68
68
  ) -> model_version_impl.ModelVersion:
69
69
 
70
70
  database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
@@ -158,6 +158,7 @@ class ModelManager:
158
158
  *,
159
159
  model_name: str,
160
160
  version_name: str,
161
+ progress_status: type_hints.ProgressStatus,
161
162
  comment: Optional[str] = None,
162
163
  metrics: Optional[dict[str, Any]] = None,
163
164
  conda_dependencies: Optional[list[str]] = None,
@@ -175,7 +176,6 @@ class ModelManager:
175
176
  experiment_info: Optional["ExperimentInfo"] = None,
176
177
  options: Optional[type_hints.ModelSaveOption] = None,
177
178
  statement_params: Optional[dict[str, Any]] = None,
178
- progress_status: Optional[Any] = None,
179
179
  ) -> model_version_impl.ModelVersion:
180
180
  database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
181
181
  version_name_id = sql_identifier.SqlIdentifier(version_name)
@@ -220,57 +220,30 @@ class ModelManager:
220
220
  statement_params=statement_params,
221
221
  )
222
222
 
223
- platforms = None
224
- # User specified target platforms are defaulted to None and will not show up in the generated manifest.
225
- if target_platforms:
226
- # Convert any string target platforms to TargetPlatform objects
227
- platforms = [type_hints.TargetPlatform(platform) for platform in target_platforms]
228
- else:
229
- # Default the target platform to warehouse if not specified and any table function exists
230
- if options and (
231
- options.get("function_type") == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
232
- or (
233
- any(
234
- opt.get("function_type") == "TABLE_FUNCTION"
235
- for opt in options.get("method_options", {}).values()
236
- )
237
- )
238
- ):
239
- logger.info(
240
- "Logging a partitioned model with a table function without specifying `target_platforms`. "
241
- 'Default to `target_platforms=["WAREHOUSE"]`.'
242
- )
243
- platforms = [target_platform.TargetPlatform.WAREHOUSE]
223
+ reconciler = model_parameter_reconciler.ModelParameterReconciler(
224
+ session=self._model_ops._session,
225
+ database_name=self._database_name,
226
+ schema_name=self._schema_name,
227
+ conda_dependencies=conda_dependencies,
228
+ pip_requirements=pip_requirements,
229
+ target_platforms=target_platforms,
230
+ artifact_repository_map=artifact_repository_map,
231
+ options=options,
232
+ python_version=python_version,
233
+ statement_params=statement_params,
234
+ )
244
235
 
245
- # Default the target platform to SPCS if not specified when running in ML runtime
246
- if not platforms and env.IN_ML_RUNTIME:
247
- logger.info(
248
- "Logging the model on Container Runtime for ML without specifying `target_platforms`. "
249
- 'Default to `target_platforms=["SNOWPARK_CONTAINER_SERVICES"]`.'
250
- )
251
- platforms = [target_platform.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
252
-
253
- if artifact_repository_map:
254
- for channel, artifact_repository_name in artifact_repository_map.items():
255
- db_id, schema_id, repo_id = sql_identifier.parse_fully_qualified_name(artifact_repository_name)
256
-
257
- artifact_repository_map[channel] = sql_identifier.get_fully_qualified_name(
258
- db_id,
259
- schema_id,
260
- repo_id,
261
- self._database_name,
262
- self._schema_name,
263
- )
236
+ model_params = reconciler.reconcile()
237
+
238
+ # Use reconciled parameters
239
+ artifact_repository_map = model_params.artifact_repository_map
240
+ save_location = model_params.save_location
264
241
 
265
242
  logger.info("Start packaging and uploading your model. It might take some time based on the size of the model.")
266
- if progress_status:
267
- progress_status.update("packaging model...")
268
- progress_status.increment()
269
-
270
- # Extract save_location from options if present
271
- save_location = None
272
- if options and "save_location" in options:
273
- save_location = options.get("save_location")
243
+ progress_status.update("packaging model...")
244
+ progress_status.increment()
245
+
246
+ if save_location:
274
247
  logger.info(f"Model will be saved to local directory: {save_location}")
275
248
 
276
249
  mc = model_composer.ModelComposer(
@@ -280,9 +253,8 @@ class ModelManager:
280
253
  save_location=save_location,
281
254
  )
282
255
 
283
- if progress_status:
284
- progress_status.update("creating model manifest...")
285
- progress_status.increment()
256
+ progress_status.update("creating model manifest...")
257
+ progress_status.increment()
286
258
 
287
259
  model_metadata: model_meta.ModelMetadata = mc.save(
288
260
  name=model_name_id.resolved(),
@@ -293,19 +265,18 @@ class ModelManager:
293
265
  pip_requirements=pip_requirements,
294
266
  artifact_repository_map=artifact_repository_map,
295
267
  resource_constraint=resource_constraint,
296
- target_platforms=platforms,
268
+ target_platforms=model_params.target_platforms,
297
269
  python_version=python_version,
298
270
  user_files=user_files,
299
271
  code_paths=code_paths,
300
272
  ext_modules=ext_modules,
301
- options=options,
273
+ options=model_params.options,
302
274
  task=task,
303
275
  experiment_info=experiment_info,
304
276
  )
305
277
 
306
- if progress_status:
307
- progress_status.update("uploading model files...")
308
- progress_status.increment()
278
+ progress_status.update("uploading model files...")
279
+ progress_status.increment()
309
280
  statement_params = telemetry.add_statement_params_custom_tags(
310
281
  statement_params, model_metadata.telemetry_metadata()
311
282
  )
@@ -313,10 +284,8 @@ class ModelManager:
313
284
  statement_params, {"model_version_name": version_name_id}
314
285
  )
315
286
 
316
- logger.info("Start creating MODEL object for you in the Snowflake.")
317
- if progress_status:
318
- progress_status.update("creating model object in Snowflake...")
319
- progress_status.increment()
287
+ progress_status.update("creating model object in Snowflake...")
288
+ progress_status.increment()
320
289
 
321
290
  self._model_ops.create_from_stage(
322
291
  composed_model=mc,
@@ -343,9 +312,8 @@ class ModelManager:
343
312
  version_name=version_name_id,
344
313
  )
345
314
 
346
- if progress_status:
347
- progress_status.update("setting model metadata...")
348
- progress_status.increment()
315
+ progress_status.update("setting model metadata...")
316
+ progress_status.increment()
349
317
 
350
318
  if comment:
351
319
  mv.comment = comment
@@ -360,8 +328,7 @@ class ModelManager:
360
328
  statement_params=statement_params,
361
329
  )
362
330
 
363
- if progress_status:
364
- progress_status.update("model logged successfully!")
331
+ progress_status.update("model logged successfully!")
365
332
 
366
333
  return mv
367
334
 
@@ -0,0 +1,294 @@
1
+ import warnings
2
+ from dataclasses import dataclass
3
+ from typing import Any, Optional
4
+
5
+ from absl.logging import logging
6
+ from packaging import requirements
7
+
8
+ from snowflake.ml import version as snowml_version
9
+ from snowflake.ml._internal import env, env as snowml_env, env_utils
10
+ from snowflake.ml._internal.exceptions import error_codes, exceptions
11
+ from snowflake.ml._internal.utils import sql_identifier
12
+ from snowflake.ml.model import target_platform, type_hints as model_types
13
+ from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
14
+ from snowflake.snowpark import Session
15
+ from snowflake.snowpark._internal import utils as snowpark_utils
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
@dataclass
class ReconciledParameters:
    """Container for the ``log_model`` arguments once they have been validated and defaulted.

    Instances are built by ``ModelParameterReconciler.reconcile``. Every field is optional and
    defaults to ``None``.
    """

    # Conda dependency specifiers, forwarded exactly as the caller supplied them.
    conda_dependencies: Optional[list[str]] = None
    # pip requirement specifiers, forwarded exactly as the caller supplied them.
    pip_requirements: Optional[list[str]] = None
    # Explicit or defaulted target platforms; None means they are left out of the generated manifest.
    target_platforms: Optional[list[model_types.TargetPlatform]] = None
    # Channel name -> artifact repository name rewritten into fully qualified form.
    artifact_repository_map: Optional[dict[str, str]] = None
    # Copy of the save options with explainability / embedding / relax_version decisions filled in.
    options: Optional[model_types.ModelSaveOption] = None
    # Local directory to save the model to, read from options["save_location"] when present.
    save_location: Optional[str] = None
+
31
+
32
class ModelParameterReconciler:
    """Centralizes all complex log_model parameter validation, transformation, and reconciliation logic.

    The reconciler is constructed with the raw arguments given to ``log_model``. Calling
    :meth:`reconcile` returns a ``ReconciledParameters`` with:

    * artifact repository names rewritten into fully qualified form,
    * target platforms converted to ``TargetPlatform`` members or defaulted,
    * a copy of the save options with ``enable_explainability``, ``embed_local_ml_library``
      and ``relax_version`` resolved.
    """

    def __init__(
        self,
        session: Session,
        database_name: sql_identifier.SqlIdentifier,
        schema_name: sql_identifier.SqlIdentifier,
        conda_dependencies: Optional[list[str]] = None,
        pip_requirements: Optional[list[str]] = None,
        target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
        artifact_repository_map: Optional[dict[str, str]] = None,
        options: Optional[model_types.ModelSaveOption] = None,
        python_version: Optional[str] = None,
        statement_params: Optional[dict[str, Any]] = None,
    ) -> None:
        """Store the raw log_model arguments for later reconciliation.

        Args:
            session: Snowpark session, used for the information-schema package lookup.
            database_name: Registry database used to qualify artifact repository names.
            schema_name: Registry schema used to qualify artifact repository names.
            conda_dependencies: Conda dependency specifiers supplied by the user.
            pip_requirements: pip requirement specifiers supplied by the user.
            target_platforms: User-specified target platforms (strings or ``TargetPlatform``).
            artifact_repository_map: Channel -> artifact repository name mapping.
            options: Model save options. They are copied, never mutated in place.
            python_version: Target Python version; falls back to the local interpreter version.
            statement_params: Statement parameters forwarded to Snowflake queries.
        """
        self._session = session
        self._database_name = database_name
        self._schema_name = schema_name
        self._conda_dependencies = conda_dependencies
        self._pip_requirements = pip_requirements
        self._target_platforms = target_platforms
        self._artifact_repository_map = artifact_repository_map
        self._options = options
        self._python_version = python_version
        self._statement_params = statement_params

    def reconcile(self) -> ReconciledParameters:
        """Perform all parameter reconciliation and return clean parameters.

        Returns:
            The reconciled parameters. Conda and pip dependencies are passed through unchanged.

        Raises:
            ValueError: If ``enable_explainability=True`` is requested for a model that cannot
                run in a warehouse or that targets only Snowpark Container Services.
            SnowflakeMLException: If ``relax_version=True`` is combined with pip requirements
                or with an SPCS-only target.
        """
        reconciled_artifact_repository_map = self._reconcile_artifact_repository_map()
        reconciled_save_location = self._extract_save_location()

        self._validate_pip_requirements_warehouse_compatibility(reconciled_artifact_repository_map)

        reconciled_target_platforms = self._reconcile_target_platforms()
        reconciled_options = self._reconcile_explainability_options(reconciled_target_platforms)
        reconciled_options = self._reconcile_relax_version(reconciled_options, reconciled_target_platforms)

        return ReconciledParameters(
            conda_dependencies=self._conda_dependencies,
            pip_requirements=self._pip_requirements,
            target_platforms=reconciled_target_platforms,
            artifact_repository_map=reconciled_artifact_repository_map,
            options=reconciled_options,
            save_location=reconciled_save_location,
        )

    def _reconcile_artifact_repository_map(self) -> Optional[dict[str, str]]:
        """Transform artifact_repository_map to use fully qualified names.

        Returns:
            A new mapping with each repository name rebuilt via
            ``sql_identifier.get_fully_qualified_name``. The registry database/schema are
            passed as fallbacks, presumably for names that omit them. Returns ``None`` when
            no (or an empty) map was given.
        """
        if not self._artifact_repository_map:
            return None

        transformed_map = {}
        for channel, artifact_repository_name in self._artifact_repository_map.items():
            db_id, schema_id, repo_id = sql_identifier.parse_fully_qualified_name(artifact_repository_name)
            transformed_map[channel] = sql_identifier.get_fully_qualified_name(
                db_id,
                schema_id,
                repo_id,
                self._database_name,
                self._schema_name,
            )
        return transformed_map

    def _extract_save_location(self) -> Optional[str]:
        """Extract save_location from options, or return None when it is absent."""
        if self._options and "save_location" in self._options:
            return self._options.get("save_location")
        return None

    def _reconcile_target_platforms(self) -> Optional[list[model_types.TargetPlatform]]:
        """Reconcile target platforms with proper defaulting logic.

        Precedence:
        1. Explicit user platforms, converted to ``TargetPlatform``.
        2. ``WAREHOUSE`` when any method is a table function.
        3. ``SNOWPARK_CONTAINER_SERVICES`` when running inside Container Runtime for ML.
        4. Otherwise ``None``.
        """
        # User specified target platforms are defaulted to None and will not show up in the generated manifest.
        if self._target_platforms:
            # Convert any string target platforms to TargetPlatform objects
            return [model_types.TargetPlatform(platform) for platform in self._target_platforms]

        # Default the target platform to warehouse if not specified and any table function exists
        if self._has_table_function():
            logger.info(
                "Logging a partitioned model with a table function without specifying `target_platforms`. "
                'Default to `target_platforms=["WAREHOUSE"]`.'
            )
            return [target_platform.TargetPlatform.WAREHOUSE]

        # Default the target platform to SPCS if not specified when running in ML runtime
        if env.IN_ML_RUNTIME:
            logger.info(
                "Logging the model on Container Runtime for ML without specifying `target_platforms`. "
                'Default to `target_platforms=["SNOWPARK_CONTAINER_SERVICES"]`.'
            )
            return [target_platform.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]

        return None

    def _has_table_function(self) -> bool:
        """Check whether the global or any per-method option requests a table function."""
        if self._options is None:
            return False

        table_function = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
        if self._options.get("function_type") == table_function:
            return True

        # `or {}` also guards against an explicit `method_options=None`, which `.get(key, {})` would not.
        method_options = self._options.get("method_options") or {}
        return any(opt.get("function_type") == table_function for opt in method_options.values())

    def _validate_pip_requirements_warehouse_compatibility(
        self, artifact_repository_map: Optional[dict[str, str]]
    ) -> None:
        """Warn when pip requirements without an artifact repository map target a warehouse."""
        if self._pip_requirements and not artifact_repository_map and self._targets_warehouse(self._target_platforms):
            warnings.warn(
                "Models logged specifying `pip_requirements` cannot be executed in a Snowflake Warehouse "
                "without specifying `artifact_repository_map`. This model can be run in Snowpark Container "
                "Services. See https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container.",
                category=UserWarning,
                stacklevel=1,
            )

    @staticmethod
    def _targets_warehouse(target_platforms: Optional[list[model_types.SupportedTargetPlatformType]]) -> bool:
        """Returns True if warehouse is a target platform (None defaults to True)."""
        return (
            target_platforms is None
            or model_types.TargetPlatform.WAREHOUSE in target_platforms
            or "WAREHOUSE" in target_platforms
        )

    def _reconcile_explainability_options(
        self, target_platforms: Optional[list[model_types.TargetPlatform]]
    ) -> model_types.ModelSaveOption:
        """Reconcile explainability settings and embed_local_ml_library based on warehouse runnability.

        Args:
            target_platforms: Already-reconciled target platforms.

        Returns:
            A shallow copy of the user options (or a fresh ``BaseModelSaveOption``) with
            ``enable_explainability`` defaulted to False when unsupported.

        Raises:
            ValueError: If explainability is explicitly requested for a model that is not
                warehouse-runnable or that targets only SPCS.
        """
        # Copy so that the caller's options dict is never mutated.
        options = self._options.copy() if self._options else model_types.BaseModelSaveOption()

        conda_dep_dict = env_utils.validate_conda_dependency_string_list(self._conda_dependencies or [])

        enable_explainability = options.get("enable_explainability", None)

        # Handle case where user explicitly disabled explainability
        if enable_explainability is False:
            return self._handle_embed_local_ml_library(options, target_platforms)

        target_platform_set = set(target_platforms) if target_platforms else set()

        is_warehouse_runnable = self._is_warehouse_runnable(conda_dep_dict)
        only_spcs = target_platform_set == set(target_platform.SNOWPARK_CONTAINER_SERVICES_ONLY)
        has_both_platforms = target_platform_set == set(target_platform.BOTH_WAREHOUSE_AND_SNOWPARK_CONTAINER_SERVICES)

        # Handle case where user explicitly requested explainability
        if enable_explainability:
            if only_spcs or not is_warehouse_runnable:
                raise ValueError(
                    "`enable_explainability` cannot be set to True when the model is not runnable in WH "
                    "or the target platforms only include SPCS."
                )
            elif has_both_platforms:
                warnings.warn(
                    ("Explain function will only be available for model deployed to warehouse."),
                    category=UserWarning,
                    stacklevel=2,
                )

        # Handle case where explainability is not specified (None) - set default behavior
        if enable_explainability is None:
            if only_spcs or not is_warehouse_runnable:
                options["enable_explainability"] = False

        return self._handle_embed_local_ml_library(options, target_platforms)

    def _handle_embed_local_ml_library(
        self, options: model_types.ModelSaveOption, target_platforms: Optional[list[model_types.TargetPlatform]]
    ) -> model_types.ModelSaveOption:
        """Set ``embed_local_ml_library=True`` when the local snowflake-ml-python version is not on the server.

        The information-schema lookup is skipped inside stored procedures and for SPCS-only models.

        Returns:
            The same ``options`` object, possibly updated in place.
        """
        if not snowpark_utils.is_in_stored_procedure() and target_platforms != [  # type: ignore[no-untyped-call]
            model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES  # no information schema check for SPCS-only models
        ]:
            snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
                self._session,
                reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_version.VERSION}")],
                python_version=self._python_version or snowml_env.PYTHON_VERSION,
                statement_params=self._statement_params,
            ).get(env_utils.SNOWPARK_ML_PKG_NAME, [])

            if len(snowml_matched_versions) < 1 and not options.get("embed_local_ml_library", False):
                # Use the module logger like every other method here (was the bare absl `logging` module).
                logger.info(
                    f"Local snowflake-ml-python library has version {snowml_version.VERSION},"
                    " which is not available in the Snowflake server, embedding local ML library automatically."
                )
                options["embed_local_ml_library"] = True

        return options

    def _is_warehouse_runnable(self, conda_dep_dict: dict[str, list[Any]]) -> bool:
        """Check if model can run in warehouse based on conda channels and pip requirements.

        Args:
            conda_dep_dict: Conda dependencies grouped by channel, as returned by
                ``env_utils.validate_conda_dependency_string_list``.

        Returns:
            False when pip requirements are given without an artifact repository map, or when any
            conda channel is neither the default channel nor the Snowflake conda channel.
        """
        # If pip requirements are present but no artifact repository map, model cannot run in warehouse
        if self._pip_requirements and not self._artifact_repository_map:
            return False

        # Check if all conda channels are warehouse-compatible (vacuously True with no conda deps)
        warehouse_compatible_channels = {env_utils.DEFAULT_CHANNEL_NAME, env_utils.SNOWFLAKE_CONDA_CHANNEL_URL}
        return all(channel in warehouse_compatible_channels for channel in conda_dep_dict)

    def _reconcile_relax_version(
        self,
        options: model_types.ModelSaveOption,
        target_platforms: Optional[list[model_types.TargetPlatform]],
    ) -> model_types.ModelSaveOption:
        """Reconcile relax_version setting based on pip requirements and target platforms.

        When ``relax_version`` is unset, it defaults to False for models with pip requirements or
        SPCS-only targets, and to True (with a warning) otherwise.

        Returns:
            The same ``options`` object with ``relax_version`` populated.

        Raises:
            SnowflakeMLException: If ``relax_version=True`` was set explicitly together with pip
                requirements or an SPCS-only target.
        """
        target_platform_set = set(target_platforms) if target_platforms else set()
        has_pip_requirements = bool(self._pip_requirements)
        only_spcs = target_platform_set == set(target_platform.SNOWPARK_CONTAINER_SERVICES_ONLY)

        if "relax_version" not in options:
            if has_pip_requirements or only_spcs:
                logger.info(
                    "Setting `relax_version=False` as this model will run in Snowpark Container Services "
                    "or in Warehouse with a specified artifact_repository_map where exact version "
                    "specifications will be honored."
                )
                relax_version = False
            else:
                warnings.warn(
                    (
                        "`relax_version` is not set and therefore defaulted to True. Dependency version constraints"
                        " relaxed from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility,"
                        " reproducibility, etc., set `options={'relax_version': False}` when logging the model."
                    ),
                    category=UserWarning,
                    stacklevel=2,
                )
                relax_version = True
            options["relax_version"] = relax_version
            return options

        # Handle case where relax_version is already set
        relax_version = options["relax_version"]
        if relax_version and (has_pip_requirements or only_spcs):
            raise exceptions.SnowflakeMLException(
                error_code=error_codes.INVALID_ARGUMENT,
                original_exception=ValueError(
                    "Setting `relax_version=True` is only allowed for models to be run in Warehouse with "
                    "Snowflake Conda Channel dependencies. It cannot be used with pip requirements or when "
                    "targeting only Snowpark Container Services."
                ),
            )

        return options
@@ -1,4 +1,3 @@
1
- import warnings
2
1
  from types import ModuleType
3
2
  from typing import Any, Optional, Union, overload
4
3
 
@@ -442,15 +441,6 @@ class Registry:
442
441
  if task is not type_hints.Task.UNKNOWN:
443
442
  raise ValueError("`task` cannot be specified when calling log_model with a ModelVersion.")
444
443
 
445
- if pip_requirements and not artifact_repository_map and self._targets_warehouse(target_platforms):
446
- warnings.warn(
447
- "Models logged specifying `pip_requirements` cannot be executed in a Snowflake Warehouse "
448
- "without specifying `artifact_repository_map`. This model can be run in Snowpark Container "
449
- "Services. See https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container.",
450
- category=UserWarning,
451
- stacklevel=1,
452
- )
453
-
454
444
  registry_event_handler = event_handler.ModelEventHandler()
455
445
  with registry_event_handler.status("Logging model", total=6) as status:
456
446
  # Step 1: Validation and setup
@@ -662,12 +652,3 @@ class Registry:
662
652
  if not self.enable_monitoring:
663
653
  raise ValueError(_MODEL_MONITORING_DISABLED_ERROR)
664
654
  self._model_monitor_manager.delete_monitor(name)
665
-
666
- @staticmethod
667
- def _targets_warehouse(target_platforms: Optional[list[type_hints.SupportedTargetPlatformType]]) -> bool:
668
- """Returns True if warehouse is a target platform (None defaults to True)."""
669
- return (
670
- target_platforms is None
671
- or type_hints.TargetPlatform.WAREHOUSE in target_platforms
672
- or "WAREHOUSE" in target_platforms
673
- )
snowflake/ml/version.py CHANGED
@@ -1,2 +1,2 @@
1
1
  # This is parsed by regex in conda recipe meta file. Make sure not to break it.
2
- VERSION = "1.9.2"
2
+ VERSION = "1.11.0"