snowflake-ml-python 1.9.2__py3-none-any.whl → 1.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/utils/service_logger.py +31 -17
- snowflake/ml/experiment/callback/keras.py +63 -0
- snowflake/ml/experiment/callback/lightgbm.py +59 -0
- snowflake/ml/experiment/callback/xgboost.py +67 -0
- snowflake/ml/experiment/utils.py +14 -0
- snowflake/ml/jobs/_utils/__init__.py +0 -0
- snowflake/ml/jobs/_utils/constants.py +4 -1
- snowflake/ml/jobs/_utils/payload_utils.py +55 -21
- snowflake/ml/jobs/_utils/query_helper.py +5 -1
- snowflake/ml/jobs/_utils/runtime_env_utils.py +63 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +2 -2
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +5 -5
- snowflake/ml/jobs/_utils/spec_utils.py +41 -8
- snowflake/ml/jobs/_utils/stage_utils.py +22 -9
- snowflake/ml/jobs/_utils/types.py +5 -7
- snowflake/ml/jobs/job.py +1 -1
- snowflake/ml/jobs/manager.py +1 -13
- snowflake/ml/model/_client/model/model_version_impl.py +219 -55
- snowflake/ml/model/_client/ops/service_ops.py +230 -30
- snowflake/ml/model/_client/service/model_deployment_spec.py +103 -27
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +11 -5
- snowflake/ml/model/_model_composer/model_composer.py +1 -70
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +2 -43
- snowflake/ml/model/event_handler.py +87 -18
- snowflake/ml/model/inference_engine.py +5 -0
- snowflake/ml/model/models/huggingface_pipeline.py +74 -51
- snowflake/ml/model/type_hints.py +26 -1
- snowflake/ml/registry/_manager/model_manager.py +37 -70
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +294 -0
- snowflake/ml/registry/registry.py +0 -19
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/METADATA +523 -491
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/RECORD +36 -29
- snowflake/ml/experiment/callback.py +0 -121
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/top_level.txt +0 -0
|
@@ -4,16 +4,16 @@ from typing import TYPE_CHECKING, Any, Optional, Union
|
|
|
4
4
|
import pandas as pd
|
|
5
5
|
from absl.logging import logging
|
|
6
6
|
|
|
7
|
-
from snowflake.ml._internal import
|
|
7
|
+
from snowflake.ml._internal import platform_capabilities, telemetry
|
|
8
8
|
from snowflake.ml._internal.exceptions import error_codes, exceptions
|
|
9
9
|
from snowflake.ml._internal.human_readable_id import hrid_generator
|
|
10
10
|
from snowflake.ml._internal.utils import sql_identifier
|
|
11
|
-
from snowflake.ml.model import model_signature,
|
|
11
|
+
from snowflake.ml.model import model_signature, task, type_hints
|
|
12
12
|
from snowflake.ml.model._client.model import model_impl, model_version_impl
|
|
13
13
|
from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
|
|
14
14
|
from snowflake.ml.model._model_composer import model_composer
|
|
15
|
-
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
|
16
15
|
from snowflake.ml.model._packager.model_meta import model_meta
|
|
16
|
+
from snowflake.ml.registry._manager import model_parameter_reconciler
|
|
17
17
|
from snowflake.snowpark import exceptions as snowpark_exceptions, session
|
|
18
18
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
|
19
19
|
|
|
@@ -46,6 +46,7 @@ class ModelManager:
|
|
|
46
46
|
*,
|
|
47
47
|
model: Union[type_hints.SupportedModelType, model_version_impl.ModelVersion],
|
|
48
48
|
model_name: str,
|
|
49
|
+
progress_status: type_hints.ProgressStatus,
|
|
49
50
|
version_name: Optional[str] = None,
|
|
50
51
|
comment: Optional[str] = None,
|
|
51
52
|
metrics: Optional[dict[str, Any]] = None,
|
|
@@ -64,7 +65,6 @@ class ModelManager:
|
|
|
64
65
|
experiment_info: Optional["ExperimentInfo"] = None,
|
|
65
66
|
options: Optional[type_hints.ModelSaveOption] = None,
|
|
66
67
|
statement_params: Optional[dict[str, Any]] = None,
|
|
67
|
-
progress_status: Optional[Any] = None,
|
|
68
68
|
) -> model_version_impl.ModelVersion:
|
|
69
69
|
|
|
70
70
|
database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
|
|
@@ -158,6 +158,7 @@ class ModelManager:
|
|
|
158
158
|
*,
|
|
159
159
|
model_name: str,
|
|
160
160
|
version_name: str,
|
|
161
|
+
progress_status: type_hints.ProgressStatus,
|
|
161
162
|
comment: Optional[str] = None,
|
|
162
163
|
metrics: Optional[dict[str, Any]] = None,
|
|
163
164
|
conda_dependencies: Optional[list[str]] = None,
|
|
@@ -175,7 +176,6 @@ class ModelManager:
|
|
|
175
176
|
experiment_info: Optional["ExperimentInfo"] = None,
|
|
176
177
|
options: Optional[type_hints.ModelSaveOption] = None,
|
|
177
178
|
statement_params: Optional[dict[str, Any]] = None,
|
|
178
|
-
progress_status: Optional[Any] = None,
|
|
179
179
|
) -> model_version_impl.ModelVersion:
|
|
180
180
|
database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
|
|
181
181
|
version_name_id = sql_identifier.SqlIdentifier(version_name)
|
|
@@ -220,57 +220,30 @@ class ModelManager:
|
|
|
220
220
|
statement_params=statement_params,
|
|
221
221
|
)
|
|
222
222
|
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
for opt in options.get("method_options", {}).values()
|
|
236
|
-
)
|
|
237
|
-
)
|
|
238
|
-
):
|
|
239
|
-
logger.info(
|
|
240
|
-
"Logging a partitioned model with a table function without specifying `target_platforms`. "
|
|
241
|
-
'Default to `target_platforms=["WAREHOUSE"]`.'
|
|
242
|
-
)
|
|
243
|
-
platforms = [target_platform.TargetPlatform.WAREHOUSE]
|
|
223
|
+
reconciler = model_parameter_reconciler.ModelParameterReconciler(
|
|
224
|
+
session=self._model_ops._session,
|
|
225
|
+
database_name=self._database_name,
|
|
226
|
+
schema_name=self._schema_name,
|
|
227
|
+
conda_dependencies=conda_dependencies,
|
|
228
|
+
pip_requirements=pip_requirements,
|
|
229
|
+
target_platforms=target_platforms,
|
|
230
|
+
artifact_repository_map=artifact_repository_map,
|
|
231
|
+
options=options,
|
|
232
|
+
python_version=python_version,
|
|
233
|
+
statement_params=statement_params,
|
|
234
|
+
)
|
|
244
235
|
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
)
|
|
251
|
-
platforms = [target_platform.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
|
|
252
|
-
|
|
253
|
-
if artifact_repository_map:
|
|
254
|
-
for channel, artifact_repository_name in artifact_repository_map.items():
|
|
255
|
-
db_id, schema_id, repo_id = sql_identifier.parse_fully_qualified_name(artifact_repository_name)
|
|
256
|
-
|
|
257
|
-
artifact_repository_map[channel] = sql_identifier.get_fully_qualified_name(
|
|
258
|
-
db_id,
|
|
259
|
-
schema_id,
|
|
260
|
-
repo_id,
|
|
261
|
-
self._database_name,
|
|
262
|
-
self._schema_name,
|
|
263
|
-
)
|
|
236
|
+
model_params = reconciler.reconcile()
|
|
237
|
+
|
|
238
|
+
# Use reconciled parameters
|
|
239
|
+
artifact_repository_map = model_params.artifact_repository_map
|
|
240
|
+
save_location = model_params.save_location
|
|
264
241
|
|
|
265
242
|
logger.info("Start packaging and uploading your model. It might take some time based on the size of the model.")
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
# Extract save_location from options if present
|
|
271
|
-
save_location = None
|
|
272
|
-
if options and "save_location" in options:
|
|
273
|
-
save_location = options.get("save_location")
|
|
243
|
+
progress_status.update("packaging model...")
|
|
244
|
+
progress_status.increment()
|
|
245
|
+
|
|
246
|
+
if save_location:
|
|
274
247
|
logger.info(f"Model will be saved to local directory: {save_location}")
|
|
275
248
|
|
|
276
249
|
mc = model_composer.ModelComposer(
|
|
@@ -280,9 +253,8 @@ class ModelManager:
|
|
|
280
253
|
save_location=save_location,
|
|
281
254
|
)
|
|
282
255
|
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
progress_status.increment()
|
|
256
|
+
progress_status.update("creating model manifest...")
|
|
257
|
+
progress_status.increment()
|
|
286
258
|
|
|
287
259
|
model_metadata: model_meta.ModelMetadata = mc.save(
|
|
288
260
|
name=model_name_id.resolved(),
|
|
@@ -293,19 +265,18 @@ class ModelManager:
|
|
|
293
265
|
pip_requirements=pip_requirements,
|
|
294
266
|
artifact_repository_map=artifact_repository_map,
|
|
295
267
|
resource_constraint=resource_constraint,
|
|
296
|
-
target_platforms=
|
|
268
|
+
target_platforms=model_params.target_platforms,
|
|
297
269
|
python_version=python_version,
|
|
298
270
|
user_files=user_files,
|
|
299
271
|
code_paths=code_paths,
|
|
300
272
|
ext_modules=ext_modules,
|
|
301
|
-
options=options,
|
|
273
|
+
options=model_params.options,
|
|
302
274
|
task=task,
|
|
303
275
|
experiment_info=experiment_info,
|
|
304
276
|
)
|
|
305
277
|
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
progress_status.increment()
|
|
278
|
+
progress_status.update("uploading model files...")
|
|
279
|
+
progress_status.increment()
|
|
309
280
|
statement_params = telemetry.add_statement_params_custom_tags(
|
|
310
281
|
statement_params, model_metadata.telemetry_metadata()
|
|
311
282
|
)
|
|
@@ -313,10 +284,8 @@ class ModelManager:
|
|
|
313
284
|
statement_params, {"model_version_name": version_name_id}
|
|
314
285
|
)
|
|
315
286
|
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
progress_status.update("creating model object in Snowflake...")
|
|
319
|
-
progress_status.increment()
|
|
287
|
+
progress_status.update("creating model object in Snowflake...")
|
|
288
|
+
progress_status.increment()
|
|
320
289
|
|
|
321
290
|
self._model_ops.create_from_stage(
|
|
322
291
|
composed_model=mc,
|
|
@@ -343,9 +312,8 @@ class ModelManager:
|
|
|
343
312
|
version_name=version_name_id,
|
|
344
313
|
)
|
|
345
314
|
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
progress_status.increment()
|
|
315
|
+
progress_status.update("setting model metadata...")
|
|
316
|
+
progress_status.increment()
|
|
349
317
|
|
|
350
318
|
if comment:
|
|
351
319
|
mv.comment = comment
|
|
@@ -360,8 +328,7 @@ class ModelManager:
|
|
|
360
328
|
statement_params=statement_params,
|
|
361
329
|
)
|
|
362
330
|
|
|
363
|
-
|
|
364
|
-
progress_status.update("model logged successfully!")
|
|
331
|
+
progress_status.update("model logged successfully!")
|
|
365
332
|
|
|
366
333
|
return mv
|
|
367
334
|
|
|
@@ -0,0 +1,294 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Any, Optional
|
|
4
|
+
|
|
5
|
+
from absl.logging import logging
|
|
6
|
+
from packaging import requirements
|
|
7
|
+
|
|
8
|
+
from snowflake.ml import version as snowml_version
|
|
9
|
+
from snowflake.ml._internal import env, env as snowml_env, env_utils
|
|
10
|
+
from snowflake.ml._internal.exceptions import error_codes, exceptions
|
|
11
|
+
from snowflake.ml._internal.utils import sql_identifier
|
|
12
|
+
from snowflake.ml.model import target_platform, type_hints as model_types
|
|
13
|
+
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
|
14
|
+
from snowflake.snowpark import Session
|
|
15
|
+
from snowflake.snowpark._internal import utils as snowpark_utils
|
|
16
|
+
|
|
17
|
+
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class ReconciledParameters:
    """Holds the reconciled and validated ``log_model`` parameters.

    Instances are produced by ``ModelParameterReconciler.reconcile``. Every field defaults to ``None``.
    """

    # Passed through unchanged from the reconciler's input.
    conda_dependencies: Optional[list[str]] = None
    # Passed through unchanged from the reconciler's input.
    pip_requirements: Optional[list[str]] = None
    # User-supplied platforms converted to ``TargetPlatform`` members, a default chosen by the reconciler,
    # or ``None`` when nothing was specified and no defaulting rule applied.
    target_platforms: Optional[list[model_types.TargetPlatform]] = None
    # Channel name -> fully qualified artifact repository name; ``None`` when no (or an empty) map was given.
    artifact_repository_map: Optional[dict[str, str]] = None
    # A copy of the caller's save options with ``enable_explainability``, ``embed_local_ml_library`` and
    # ``relax_version`` possibly filled in by the reconciler.
    options: Optional[model_types.ModelSaveOption] = None
    # Value of ``options["save_location"]`` when present.
    save_location: Optional[str] = None


class ModelParameterReconciler:
    """Centralizes ``log_model`` parameter validation, transformation, and reconciliation logic.

    The reconciler is constructed with the raw user-facing arguments. Calling :meth:`reconcile` returns a
    :class:`ReconciledParameters` in which defaults have been applied. Incompatible combinations are either
    rejected with an exception or reported with a warning.
    """

    def __init__(
        self,
        session: Session,
        database_name: sql_identifier.SqlIdentifier,
        schema_name: sql_identifier.SqlIdentifier,
        conda_dependencies: Optional[list[str]] = None,
        pip_requirements: Optional[list[str]] = None,
        target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
        artifact_repository_map: Optional[dict[str, str]] = None,
        options: Optional[model_types.ModelSaveOption] = None,
        python_version: Optional[str] = None,
        statement_params: Optional[dict[str, Any]] = None,
    ) -> None:
        """Store the raw parameters. No validation happens until :meth:`reconcile` is called.

        Args:
            session: Snowpark session. It is used to look up which snowflake-ml-python versions are
                available in the information schema.
            database_name: Database used to qualify artifact repository names that omit one.
            schema_name: Schema used to qualify artifact repository names that omit one.
            conda_dependencies: User-specified conda dependency strings.
            pip_requirements: User-specified pip requirement strings.
            target_platforms: User-requested target platforms, as ``TargetPlatform`` members or their
                string values.
            artifact_repository_map: Mapping of channel name to artifact repository name. The name may be
                partially qualified.
            options: Model save options. Reconciliation works on a shallow copy, so this mapping itself
                is not mutated.
            python_version: Python version used for the information-schema lookup. Falls back to the
                local Python version when omitted.
            statement_params: Statement parameters forwarded to the information-schema query.
        """
        self._session = session
        self._database_name = database_name
        self._schema_name = schema_name
        self._conda_dependencies = conda_dependencies
        self._pip_requirements = pip_requirements
        self._target_platforms = target_platforms
        self._artifact_repository_map = artifact_repository_map
        self._options = options
        self._python_version = python_version
        self._statement_params = statement_params

    def reconcile(self) -> ReconciledParameters:
        """Perform all parameter reconciliation and return the cleaned parameters.

        Returns:
            The reconciled parameters. ``conda_dependencies`` and ``pip_requirements`` are passed through
            unchanged.

        Raises:
            ValueError: If ``enable_explainability=True`` conflicts with the model's runnability in a
                warehouse or with SPCS-only target platforms.
            SnowflakeMLException: If ``relax_version=True`` is combined with pip requirements or with
                SPCS-only target platforms.

        Errors raised while validating ``conda_dependencies`` (by ``env_utils``) also propagate.
        """
        reconciled_artifact_repository_map = self._reconcile_artifact_repository_map()
        reconciled_save_location = self._extract_save_location()

        # This check looks at the user-supplied target platforms (before defaulting). A warning is
        # emitted when pip requirements are given without an artifact repository map and the model
        # might target a warehouse.
        self._validate_pip_requirements_warehouse_compatibility(reconciled_artifact_repository_map)

        reconciled_target_platforms = self._reconcile_target_platforms()
        reconciled_options = self._reconcile_explainability_options(reconciled_target_platforms)
        reconciled_options = self._reconcile_relax_version(reconciled_options, reconciled_target_platforms)

        return ReconciledParameters(
            conda_dependencies=self._conda_dependencies,
            pip_requirements=self._pip_requirements,
            target_platforms=reconciled_target_platforms,
            artifact_repository_map=reconciled_artifact_repository_map,
            options=reconciled_options,
            save_location=reconciled_save_location,
        )

    def _reconcile_artifact_repository_map(self) -> Optional[dict[str, str]]:
        """Rewrite every artifact repository name as a fully qualified name.

        Names that omit the database or schema are qualified with the reconciler's database and schema.

        Returns:
            A new channel -> fully qualified name mapping, or ``None`` if no map (or an empty map) was given.
        """
        if not self._artifact_repository_map:
            return None

        transformed_map = {}
        for channel, artifact_repository_name in self._artifact_repository_map.items():
            db_id, schema_id, repo_id = sql_identifier.parse_fully_qualified_name(artifact_repository_name)
            transformed_map[channel] = sql_identifier.get_fully_qualified_name(
                db_id,
                schema_id,
                repo_id,
                self._database_name,
                self._schema_name,
            )
        return transformed_map

    def _extract_save_location(self) -> Optional[str]:
        """Return ``options["save_location"]``, or ``None`` if options or the key are absent."""
        if not self._options:
            return None
        return self._options.get("save_location")

    def _reconcile_target_platforms(self) -> Optional[list[model_types.TargetPlatform]]:
        """Normalize user-specified target platforms, or choose a default when none were given.

        Returns:
            The platforms resolved by the first rule that applies:
            - The user's platforms converted to ``TargetPlatform`` members, if any were given.
            - ``[WAREHOUSE]`` if any method is logged as a table function.
            - ``[SNOWPARK_CONTAINER_SERVICES]`` when running inside the Container Runtime for ML.
            - ``None`` otherwise.
        """
        if self._target_platforms:
            # Accepts both enum members and their string values.
            return [model_types.TargetPlatform(platform) for platform in self._target_platforms]

        if self._has_table_function():
            logger.info(
                "Logging a partitioned model with a table function without specifying `target_platforms`. "
                'Default to `target_platforms=["WAREHOUSE"]`.'
            )
            return [target_platform.TargetPlatform.WAREHOUSE]

        if env.IN_ML_RUNTIME:
            logger.info(
                "Logging the model on Container Runtime for ML without specifying `target_platforms`. "
                'Default to `target_platforms=["SNOWPARK_CONTAINER_SERVICES"]`.'
            )
            return [target_platform.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]

        # Left unset: ``None`` is not written to the generated manifest.
        return None

    def _has_table_function(self) -> bool:
        """Return True if the global or any per-method ``function_type`` is a table function."""
        if self._options is None:
            return False

        table_function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
        if self._options.get("function_type") == table_function_type:
            return True

        # ``or {}`` also tolerates an explicit ``method_options=None``.
        method_options = self._options.get("method_options") or {}
        return any(opt.get("function_type") == table_function_type for opt in method_options.values())

    def _validate_pip_requirements_warehouse_compatibility(
        self, artifact_repository_map: Optional[dict[str, str]]
    ) -> None:
        """Warn when pip requirements may prevent running the model in a warehouse.

        Args:
            artifact_repository_map: The reconciled artifact repository map. When it is present, pip
                packages can be resolved in the warehouse, so no warning is issued.
        """
        if self._pip_requirements and not artifact_repository_map and self._targets_warehouse(self._target_platforms):
            warnings.warn(
                "Models logged specifying `pip_requirements` cannot be executed in a Snowflake Warehouse "
                "without specifying `artifact_repository_map`. This model can be run in Snowpark Container "
                "Services. See https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container.",
                category=UserWarning,
                stacklevel=1,
            )

    @staticmethod
    def _targets_warehouse(target_platforms: Optional[list[model_types.SupportedTargetPlatformType]]) -> bool:
        """Returns True if warehouse is a target platform (None defaults to True)."""
        return (
            target_platforms is None
            or model_types.TargetPlatform.WAREHOUSE in target_platforms
            or "WAREHOUSE" in target_platforms
        )

    def _reconcile_explainability_options(
        self, target_platforms: Optional[list[model_types.TargetPlatform]]
    ) -> model_types.ModelSaveOption:
        """Resolve ``enable_explainability`` and then ``embed_local_ml_library``.

        Args:
            target_platforms: The reconciled target platforms.

        Returns:
            A shallow copy of the save options with the resolved settings applied.

        Raises:
            ValueError: If explainability is explicitly requested but the model targets only SPCS or cannot
                run in a warehouse.
        """
        options = self._options.copy() if self._options else model_types.BaseModelSaveOption()

        # Computed before any early return so that malformed conda dependency strings are always
        # rejected, even when explainability is disabled.
        conda_dep_dict = env_utils.validate_conda_dependency_string_list(self._conda_dependencies or [])

        enable_explainability = options.get("enable_explainability", None)

        # The user explicitly disabled explainability: nothing to reconcile here.
        if enable_explainability is False:
            return self._handle_embed_local_ml_library(options, target_platforms)

        target_platform_set = set(target_platforms) if target_platforms else set()
        is_warehouse_runnable = self._is_warehouse_runnable(conda_dep_dict)
        only_spcs = target_platform_set == set(target_platform.SNOWPARK_CONTAINER_SERVICES_ONLY)
        has_both_platforms = target_platform_set == set(target_platform.BOTH_WAREHOUSE_AND_SNOWPARK_CONTAINER_SERVICES)

        if enable_explainability:
            # Explicitly requested: reject combinations where explain can never run. When both
            # platforms are targeted, this is allowed but limited to the warehouse.
            if only_spcs or not is_warehouse_runnable:
                raise ValueError(
                    "`enable_explainability` cannot be set to True when the model is not runnable in WH "
                    "or the target platforms include only SPCS."
                )
            elif has_both_platforms:
                warnings.warn(
                    ("Explain function will only be available for model deployed to warehouse."),
                    category=UserWarning,
                    stacklevel=2,
                )

        if enable_explainability is None and (only_spcs or not is_warehouse_runnable):
            # Not specified: default to disabled where explain cannot run. Otherwise leave the key unset.
            options["enable_explainability"] = False

        return self._handle_embed_local_ml_library(options, target_platforms)

    def _handle_embed_local_ml_library(
        self, options: model_types.ModelSaveOption, target_platforms: Optional[list[model_types.TargetPlatform]]
    ) -> model_types.ModelSaveOption:
        """Set ``embed_local_ml_library=True`` when the local snowflake-ml-python version is unavailable server-side.

        The information-schema lookup is skipped inside a stored procedure and for SPCS-only models.

        Args:
            options: Save options to update. This mapping is mutated in place.
            target_platforms: The reconciled target platforms.

        Returns:
            ``options``, possibly with ``embed_local_ml_library`` set to True.
        """
        in_stored_procedure = snowpark_utils.is_in_stored_procedure()  # type: ignore[no-untyped-call]
        # No information schema check for SPCS-only models.
        spcs_only = target_platforms == [model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
        if in_stored_procedure or spcs_only:
            return options

        snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
            self._session,
            reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_version.VERSION}")],
            python_version=self._python_version or snowml_env.PYTHON_VERSION,
            statement_params=self._statement_params,
        ).get(env_utils.SNOWPARK_ML_PKG_NAME, [])

        if not snowml_matched_versions and not options.get("embed_local_ml_library", False):
            # Use the module logger, consistent with the rest of this class, rather than the root logger.
            logger.info(
                f"Local snowflake-ml-python library has version {snowml_version.VERSION},"
                " which is not available in the Snowflake server, embedding local ML library automatically."
            )
            options["embed_local_ml_library"] = True

        return options

    def _is_warehouse_runnable(self, conda_dep_dict: dict[str, list[Any]]) -> bool:
        """Return True if the model's dependencies allow it to run in a warehouse.

        Args:
            conda_dep_dict: Validated conda dependencies keyed by channel name.

        Returns:
            False if pip requirements are given without an artifact repository map, or if any conda
            channel is neither the default channel nor the Snowflake conda channel. True otherwise.
        """
        if self._pip_requirements and not self._artifact_repository_map:
            return False

        warehouse_compatible_channels = {env_utils.DEFAULT_CHANNEL_NAME, env_utils.SNOWFLAKE_CONDA_CHANNEL_URL}
        # Vacuously True when there are no conda dependencies.
        return all(channel in warehouse_compatible_channels for channel in conda_dep_dict)

    def _reconcile_relax_version(
        self,
        options: model_types.ModelSaveOption,
        target_platforms: Optional[list[model_types.TargetPlatform]],
    ) -> model_types.ModelSaveOption:
        """Default or validate ``relax_version`` based on pip requirements and target platforms.

        Args:
            options: Save options to update. This mapping is mutated in place.
            target_platforms: The reconciled target platforms.

        Returns:
            ``options`` with ``relax_version`` set.

        Raises:
            SnowflakeMLException: If ``relax_version=True`` is combined with pip requirements or with
                SPCS-only target platforms.
        """
        target_platform_set = set(target_platforms) if target_platforms else set()
        has_pip_requirements = bool(self._pip_requirements)
        only_spcs = target_platform_set == set(target_platform.SNOWPARK_CONTAINER_SERVICES_ONLY)
        needs_exact_versions = has_pip_requirements or only_spcs

        if "relax_version" not in options:
            if needs_exact_versions:
                logger.info(
                    "Setting `relax_version=False` as this model will run in Snowpark Container Services "
                    "or in Warehouse with a specified artifact_repository_map where exact version "
                    "specifications will be honored."
                )
                options["relax_version"] = False
            else:
                warnings.warn(
                    (
                        "`relax_version` is not set and therefore defaulted to True. Dependency version constraints"
                        " relaxed from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility,"
                        " reproducibility, etc., set `options={'relax_version': False}` when logging the model."
                    ),
                    category=UserWarning,
                    stacklevel=2,
                )
                options["relax_version"] = True
            return options

        # relax_version was set explicitly by the user: only validate it.
        if options["relax_version"] and needs_exact_versions:
            raise exceptions.SnowflakeMLException(
                error_code=error_codes.INVALID_ARGUMENT,
                original_exception=ValueError(
                    "Setting `relax_version=True` is only allowed for models to be run in Warehouse with "
                    "Snowflake Conda Channel dependencies. It cannot be used with pip requirements or when "
                    "targeting only Snowpark Container Services."
                ),
            )

        return options
|
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
import warnings
|
|
2
1
|
from types import ModuleType
|
|
3
2
|
from typing import Any, Optional, Union, overload
|
|
4
3
|
|
|
@@ -442,15 +441,6 @@ class Registry:
|
|
|
442
441
|
if task is not type_hints.Task.UNKNOWN:
|
|
443
442
|
raise ValueError("`task` cannot be specified when calling log_model with a ModelVersion.")
|
|
444
443
|
|
|
445
|
-
if pip_requirements and not artifact_repository_map and self._targets_warehouse(target_platforms):
|
|
446
|
-
warnings.warn(
|
|
447
|
-
"Models logged specifying `pip_requirements` cannot be executed in a Snowflake Warehouse "
|
|
448
|
-
"without specifying `artifact_repository_map`. This model can be run in Snowpark Container "
|
|
449
|
-
"Services. See https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container.",
|
|
450
|
-
category=UserWarning,
|
|
451
|
-
stacklevel=1,
|
|
452
|
-
)
|
|
453
|
-
|
|
454
444
|
registry_event_handler = event_handler.ModelEventHandler()
|
|
455
445
|
with registry_event_handler.status("Logging model", total=6) as status:
|
|
456
446
|
# Step 1: Validation and setup
|
|
@@ -662,12 +652,3 @@ class Registry:
|
|
|
662
652
|
if not self.enable_monitoring:
|
|
663
653
|
raise ValueError(_MODEL_MONITORING_DISABLED_ERROR)
|
|
664
654
|
self._model_monitor_manager.delete_monitor(name)
|
|
665
|
-
|
|
666
|
-
@staticmethod
|
|
667
|
-
def _targets_warehouse(target_platforms: Optional[list[type_hints.SupportedTargetPlatformType]]) -> bool:
|
|
668
|
-
"""Returns True if warehouse is a target platform (None defaults to True)."""
|
|
669
|
-
return (
|
|
670
|
-
target_platforms is None
|
|
671
|
-
or type_hints.TargetPlatform.WAREHOUSE in target_platforms
|
|
672
|
-
or "WAREHOUSE" in target_platforms
|
|
673
|
-
)
|
snowflake/ml/version.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
1
|
# This is parsed by regex in conda recipe meta file. Make sure not to break it.
|
|
2
|
-
VERSION = "1.9.2"
|
|
2
|
+
VERSION = "1.11.0"
|