snowflake-ml-python 1.10.0__py3-none-any.whl → 1.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/experiment/callback/keras.py +63 -0
- snowflake/ml/experiment/callback/lightgbm.py +5 -1
- snowflake/ml/experiment/callback/xgboost.py +5 -1
- snowflake/ml/jobs/_utils/__init__.py +0 -0
- snowflake/ml/jobs/_utils/constants.py +4 -1
- snowflake/ml/jobs/_utils/payload_utils.py +42 -14
- snowflake/ml/jobs/_utils/query_helper.py +5 -1
- snowflake/ml/jobs/_utils/runtime_env_utils.py +63 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +2 -2
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +3 -3
- snowflake/ml/jobs/_utils/spec_utils.py +41 -8
- snowflake/ml/jobs/_utils/stage_utils.py +22 -9
- snowflake/ml/jobs/_utils/types.py +5 -7
- snowflake/ml/jobs/job.py +1 -1
- snowflake/ml/jobs/manager.py +1 -13
- snowflake/ml/model/_client/model/model_version_impl.py +166 -10
- snowflake/ml/model/_client/ops/service_ops.py +63 -28
- snowflake/ml/model/_client/service/model_deployment_spec.py +103 -27
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +11 -5
- snowflake/ml/model/_model_composer/model_composer.py +1 -70
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +2 -43
- snowflake/ml/model/inference_engine.py +5 -0
- snowflake/ml/model/models/huggingface_pipeline.py +4 -3
- snowflake/ml/registry/_manager/model_manager.py +7 -35
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +194 -5
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.11.0.dist-info}/METADATA +23 -4
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.11.0.dist-info}/RECORD +31 -27
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.11.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.11.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.11.0.dist-info}/top_level.txt +0 -0
|
@@ -10,10 +10,15 @@ class Model(BaseModel):
|
|
|
10
10
|
version: str
|
|
11
11
|
|
|
12
12
|
|
|
13
|
+
class InferenceEngineSpec(BaseModel):
|
|
14
|
+
inference_engine_name: str
|
|
15
|
+
inference_engine_args: Optional[list[str]] = None
|
|
16
|
+
|
|
17
|
+
|
|
13
18
|
class ImageBuild(BaseModel):
|
|
14
|
-
compute_pool: str
|
|
15
|
-
image_repo: str
|
|
16
|
-
force_rebuild: bool
|
|
19
|
+
compute_pool: Optional[str] = None
|
|
20
|
+
image_repo: Optional[str] = None
|
|
21
|
+
force_rebuild: Optional[bool] = None
|
|
17
22
|
external_access_integrations: Optional[list[str]] = None
|
|
18
23
|
|
|
19
24
|
|
|
@@ -27,6 +32,7 @@ class Service(BaseModel):
|
|
|
27
32
|
gpu: Optional[str] = None
|
|
28
33
|
num_workers: Optional[int] = None
|
|
29
34
|
max_batch_rows: Optional[int] = None
|
|
35
|
+
inference_engine_spec: Optional[InferenceEngineSpec] = None
|
|
30
36
|
|
|
31
37
|
|
|
32
38
|
class Job(BaseModel):
|
|
@@ -68,13 +74,13 @@ class ModelLogging(BaseModel):
|
|
|
68
74
|
|
|
69
75
|
class ModelServiceDeploymentSpec(BaseModel):
|
|
70
76
|
models: list[Model]
|
|
71
|
-
image_build: ImageBuild
|
|
77
|
+
image_build: Optional[ImageBuild] = None
|
|
72
78
|
service: Service
|
|
73
79
|
model_loggings: Optional[list[ModelLogging]] = None
|
|
74
80
|
|
|
75
81
|
|
|
76
82
|
class ModelJobDeploymentSpec(BaseModel):
|
|
77
83
|
models: list[Model]
|
|
78
|
-
image_build: ImageBuild
|
|
84
|
+
image_build: Optional[ImageBuild] = None
|
|
79
85
|
job: Job
|
|
80
86
|
model_loggings: Optional[list[ModelLogging]] = None
|
|
@@ -1,17 +1,12 @@
|
|
|
1
1
|
import pathlib
|
|
2
2
|
import tempfile
|
|
3
3
|
import uuid
|
|
4
|
-
import warnings
|
|
5
4
|
from types import ModuleType
|
|
6
5
|
from typing import TYPE_CHECKING, Any, Optional, Union
|
|
7
6
|
from urllib import parse
|
|
8
7
|
|
|
9
|
-
from absl import logging
|
|
10
|
-
from packaging import requirements
|
|
11
|
-
|
|
12
8
|
from snowflake import snowpark
|
|
13
|
-
from snowflake.ml import
|
|
14
|
-
from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
|
|
9
|
+
from snowflake.ml._internal import file_utils
|
|
15
10
|
from snowflake.ml._internal.lineage import lineage_utils
|
|
16
11
|
from snowflake.ml.data import data_source
|
|
17
12
|
from snowflake.ml.model import model_signature, type_hints as model_types
|
|
@@ -19,7 +14,6 @@ from snowflake.ml.model._model_composer.model_manifest import model_manifest
|
|
|
19
14
|
from snowflake.ml.model._packager import model_packager
|
|
20
15
|
from snowflake.ml.model._packager.model_meta import model_meta
|
|
21
16
|
from snowflake.snowpark import Session
|
|
22
|
-
from snowflake.snowpark._internal import utils as snowpark_utils
|
|
23
17
|
|
|
24
18
|
if TYPE_CHECKING:
|
|
25
19
|
from snowflake.ml.experiment._experiment_info import ExperimentInfo
|
|
@@ -142,73 +136,10 @@ class ModelComposer:
|
|
|
142
136
|
experiment_info: Optional["ExperimentInfo"] = None,
|
|
143
137
|
options: Optional[model_types.ModelSaveOption] = None,
|
|
144
138
|
) -> model_meta.ModelMetadata:
|
|
145
|
-
# set enable_explainability=False if the model is not runnable in WH or the target platforms include SPCS
|
|
146
|
-
conda_dep_dict = env_utils.validate_conda_dependency_string_list(
|
|
147
|
-
conda_dependencies if conda_dependencies else []
|
|
148
|
-
)
|
|
149
|
-
|
|
150
|
-
enable_explainability = None
|
|
151
|
-
|
|
152
|
-
if options:
|
|
153
|
-
enable_explainability = options.get("enable_explainability", None)
|
|
154
|
-
|
|
155
|
-
# skip everything if user said False explicitly
|
|
156
|
-
if enable_explainability is None or enable_explainability is True:
|
|
157
|
-
is_warehouse_runnable = (
|
|
158
|
-
not conda_dep_dict
|
|
159
|
-
or all(
|
|
160
|
-
chan == env_utils.DEFAULT_CHANNEL_NAME or chan == env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
|
|
161
|
-
for chan in conda_dep_dict
|
|
162
|
-
)
|
|
163
|
-
) and (not pip_requirements)
|
|
164
|
-
|
|
165
|
-
only_spcs = (
|
|
166
|
-
target_platforms
|
|
167
|
-
and len(target_platforms) == 1
|
|
168
|
-
and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
|
|
169
|
-
)
|
|
170
|
-
if only_spcs or (not is_warehouse_runnable):
|
|
171
|
-
# if only SPCS and user asked for explainability we fail
|
|
172
|
-
if enable_explainability is True:
|
|
173
|
-
raise ValueError(
|
|
174
|
-
"`enable_explainability` cannot be set to True when the model is not runnable in WH "
|
|
175
|
-
"or the target platforms include SPCS."
|
|
176
|
-
)
|
|
177
|
-
elif not options: # explicitly set flag to false in these cases if not specified
|
|
178
|
-
options = model_types.BaseModelSaveOption()
|
|
179
|
-
options["enable_explainability"] = False
|
|
180
|
-
elif (
|
|
181
|
-
target_platforms
|
|
182
|
-
and len(target_platforms) > 1
|
|
183
|
-
and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
|
|
184
|
-
): # if both then only available for WH
|
|
185
|
-
if enable_explainability is True:
|
|
186
|
-
warnings.warn(
|
|
187
|
-
("Explain function will only be available for model deployed to warehouse."),
|
|
188
|
-
category=UserWarning,
|
|
189
|
-
stacklevel=2,
|
|
190
|
-
)
|
|
191
139
|
|
|
192
140
|
if not options:
|
|
193
141
|
options = model_types.BaseModelSaveOption()
|
|
194
142
|
|
|
195
|
-
if not snowpark_utils.is_in_stored_procedure() and target_platforms != [ # type: ignore[no-untyped-call]
|
|
196
|
-
model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES # no information schema check for SPCS-only models
|
|
197
|
-
]:
|
|
198
|
-
snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
|
|
199
|
-
self.session,
|
|
200
|
-
reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_version.VERSION}")],
|
|
201
|
-
python_version=python_version or snowml_env.PYTHON_VERSION,
|
|
202
|
-
statement_params=self._statement_params,
|
|
203
|
-
).get(env_utils.SNOWPARK_ML_PKG_NAME, [])
|
|
204
|
-
|
|
205
|
-
if len(snowml_matched_versions) < 1 and options.get("embed_local_ml_library", False) is False:
|
|
206
|
-
logging.info(
|
|
207
|
-
f"Local snowflake-ml-python library has version {snowml_version.VERSION},"
|
|
208
|
-
" which is not available in the Snowflake server, embedding local ML library automatically."
|
|
209
|
-
)
|
|
210
|
-
options["embed_local_ml_library"] = True
|
|
211
|
-
|
|
212
143
|
model_metadata: model_meta.ModelMetadata = self.packager.save(
|
|
213
144
|
name=name,
|
|
214
145
|
model=model,
|
|
@@ -1,13 +1,11 @@
|
|
|
1
1
|
import collections
|
|
2
2
|
import logging
|
|
3
3
|
import pathlib
|
|
4
|
-
import warnings
|
|
5
4
|
from typing import TYPE_CHECKING, Optional, cast
|
|
6
5
|
|
|
7
6
|
import yaml
|
|
8
7
|
|
|
9
8
|
from snowflake.ml._internal import env_utils
|
|
10
|
-
from snowflake.ml._internal.exceptions import error_codes, exceptions
|
|
11
9
|
from snowflake.ml.data import data_source
|
|
12
10
|
from snowflake.ml.model import type_hints
|
|
13
11
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
|
@@ -55,47 +53,8 @@ class ModelManifest:
|
|
|
55
53
|
experiment_info: Optional["ExperimentInfo"] = None,
|
|
56
54
|
target_platforms: Optional[list[type_hints.TargetPlatform]] = None,
|
|
57
55
|
) -> None:
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
has_pip_requirements = len(model_meta.env.pip_requirements) > 0
|
|
62
|
-
only_spcs = (
|
|
63
|
-
target_platforms
|
|
64
|
-
and len(target_platforms) == 1
|
|
65
|
-
and target_platforms[0] == type_hints.TargetPlatform.SNOWPARK_CONTAINER_SERVICES
|
|
66
|
-
)
|
|
67
|
-
|
|
68
|
-
if "relax_version" not in options:
|
|
69
|
-
if has_pip_requirements or only_spcs:
|
|
70
|
-
logger.info(
|
|
71
|
-
"Setting `relax_version=False` as this model will run in Snowpark Container Services "
|
|
72
|
-
"or in Warehouse with a specified artifact_repository_map where exact version "
|
|
73
|
-
" specifications will be honored."
|
|
74
|
-
)
|
|
75
|
-
relax_version = False
|
|
76
|
-
else:
|
|
77
|
-
warnings.warn(
|
|
78
|
-
(
|
|
79
|
-
"`relax_version` is not set and therefore defaulted to True. Dependency version constraints"
|
|
80
|
-
" relaxed from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility,"
|
|
81
|
-
" reproducibility, etc., set `options={'relax_version': False}` when logging the model."
|
|
82
|
-
),
|
|
83
|
-
category=UserWarning,
|
|
84
|
-
stacklevel=2,
|
|
85
|
-
)
|
|
86
|
-
relax_version = True
|
|
87
|
-
options["relax_version"] = relax_version
|
|
88
|
-
else:
|
|
89
|
-
relax_version = options.get("relax_version", True)
|
|
90
|
-
if relax_version and (has_pip_requirements or only_spcs):
|
|
91
|
-
raise exceptions.SnowflakeMLException(
|
|
92
|
-
error_code=error_codes.INVALID_ARGUMENT,
|
|
93
|
-
original_exception=ValueError(
|
|
94
|
-
"Setting `relax_version=True` is only allowed for models to be run in Warehouse with "
|
|
95
|
-
"Snowflake Conda Channel dependencies. It cannot be used with pip requirements or when "
|
|
96
|
-
"targeting only Snowpark Container Services."
|
|
97
|
-
),
|
|
98
|
-
)
|
|
56
|
+
assert options is not None, "ModelParameterReconciler should have set options with relax_version"
|
|
57
|
+
relax_version = options["relax_version"]
|
|
99
58
|
|
|
100
59
|
runtime_to_use = model_runtime.ModelRuntime(
|
|
101
60
|
name=self._DEFAULT_RUNTIME_NAME,
|
|
@@ -258,7 +258,7 @@ class HuggingFacePipelineModel:
|
|
|
258
258
|
# model_version_impl.create_service parameters
|
|
259
259
|
service_name: str,
|
|
260
260
|
service_compute_pool: str,
|
|
261
|
-
image_repo: str,
|
|
261
|
+
image_repo: Optional[str] = None,
|
|
262
262
|
image_build_compute_pool: Optional[str] = None,
|
|
263
263
|
ingress_enabled: bool = False,
|
|
264
264
|
max_instances: int = 1,
|
|
@@ -282,7 +282,8 @@ class HuggingFacePipelineModel:
|
|
|
282
282
|
comment: Comment for the model. Defaults to None.
|
|
283
283
|
service_name: The name of the service to create.
|
|
284
284
|
service_compute_pool: The compute pool for the service.
|
|
285
|
-
image_repo: The name of the image repository.
|
|
285
|
+
image_repo: The name of the image repository. This can be None, in that case a default hidden image
|
|
286
|
+
repository will be used.
|
|
286
287
|
image_build_compute_pool: The name of the compute pool used to build the model inference image. It uses
|
|
287
288
|
the service compute pool if None.
|
|
288
289
|
ingress_enabled: Whether ingress is enabled. Defaults to False.
|
|
@@ -356,7 +357,7 @@ class HuggingFacePipelineModel:
|
|
|
356
357
|
else sql_identifier.SqlIdentifier(service_compute_pool)
|
|
357
358
|
),
|
|
358
359
|
service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
|
|
359
|
-
|
|
360
|
+
image_repo_name=image_repo,
|
|
360
361
|
ingress_enabled=ingress_enabled,
|
|
361
362
|
max_instances=max_instances,
|
|
362
363
|
cpu_requests=cpu_requests,
|
|
@@ -4,15 +4,14 @@ from typing import TYPE_CHECKING, Any, Optional, Union
|
|
|
4
4
|
import pandas as pd
|
|
5
5
|
from absl.logging import logging
|
|
6
6
|
|
|
7
|
-
from snowflake.ml._internal import
|
|
7
|
+
from snowflake.ml._internal import platform_capabilities, telemetry
|
|
8
8
|
from snowflake.ml._internal.exceptions import error_codes, exceptions
|
|
9
9
|
from snowflake.ml._internal.human_readable_id import hrid_generator
|
|
10
10
|
from snowflake.ml._internal.utils import sql_identifier
|
|
11
|
-
from snowflake.ml.model import model_signature,
|
|
11
|
+
from snowflake.ml.model import model_signature, task, type_hints
|
|
12
12
|
from snowflake.ml.model._client.model import model_impl, model_version_impl
|
|
13
13
|
from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
|
|
14
14
|
from snowflake.ml.model._model_composer import model_composer
|
|
15
|
-
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
|
16
15
|
from snowflake.ml.model._packager.model_meta import model_meta
|
|
17
16
|
from snowflake.ml.registry._manager import model_parameter_reconciler
|
|
18
17
|
from snowflake.snowpark import exceptions as snowpark_exceptions, session
|
|
@@ -221,37 +220,8 @@ class ModelManager:
|
|
|
221
220
|
statement_params=statement_params,
|
|
222
221
|
)
|
|
223
222
|
|
|
224
|
-
platforms = None
|
|
225
|
-
# User specified target platforms are defaulted to None and will not show up in the generated manifest.
|
|
226
|
-
if target_platforms:
|
|
227
|
-
# Convert any string target platforms to TargetPlatform objects
|
|
228
|
-
platforms = [type_hints.TargetPlatform(platform) for platform in target_platforms]
|
|
229
|
-
else:
|
|
230
|
-
# Default the target platform to warehouse if not specified and any table function exists
|
|
231
|
-
if options and (
|
|
232
|
-
options.get("function_type") == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
|
|
233
|
-
or (
|
|
234
|
-
any(
|
|
235
|
-
opt.get("function_type") == "TABLE_FUNCTION"
|
|
236
|
-
for opt in options.get("method_options", {}).values()
|
|
237
|
-
)
|
|
238
|
-
)
|
|
239
|
-
):
|
|
240
|
-
logger.info(
|
|
241
|
-
"Logging a partitioned model with a table function without specifying `target_platforms`. "
|
|
242
|
-
'Default to `target_platforms=["WAREHOUSE"]`.'
|
|
243
|
-
)
|
|
244
|
-
platforms = [target_platform.TargetPlatform.WAREHOUSE]
|
|
245
|
-
|
|
246
|
-
# Default the target platform to SPCS if not specified when running in ML runtime
|
|
247
|
-
if not platforms and env.IN_ML_RUNTIME:
|
|
248
|
-
logger.info(
|
|
249
|
-
"Logging the model on Container Runtime for ML without specifying `target_platforms`. "
|
|
250
|
-
'Default to `target_platforms=["SNOWPARK_CONTAINER_SERVICES"]`.'
|
|
251
|
-
)
|
|
252
|
-
platforms = [target_platform.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
|
|
253
|
-
|
|
254
223
|
reconciler = model_parameter_reconciler.ModelParameterReconciler(
|
|
224
|
+
session=self._model_ops._session,
|
|
255
225
|
database_name=self._database_name,
|
|
256
226
|
schema_name=self._schema_name,
|
|
257
227
|
conda_dependencies=conda_dependencies,
|
|
@@ -259,6 +229,8 @@ class ModelManager:
|
|
|
259
229
|
target_platforms=target_platforms,
|
|
260
230
|
artifact_repository_map=artifact_repository_map,
|
|
261
231
|
options=options,
|
|
232
|
+
python_version=python_version,
|
|
233
|
+
statement_params=statement_params,
|
|
262
234
|
)
|
|
263
235
|
|
|
264
236
|
model_params = reconciler.reconcile()
|
|
@@ -293,12 +265,12 @@ class ModelManager:
|
|
|
293
265
|
pip_requirements=pip_requirements,
|
|
294
266
|
artifact_repository_map=artifact_repository_map,
|
|
295
267
|
resource_constraint=resource_constraint,
|
|
296
|
-
target_platforms=
|
|
268
|
+
target_platforms=model_params.target_platforms,
|
|
297
269
|
python_version=python_version,
|
|
298
270
|
user_files=user_files,
|
|
299
271
|
code_paths=code_paths,
|
|
300
272
|
ext_modules=ext_modules,
|
|
301
|
-
options=options,
|
|
273
|
+
options=model_params.options,
|
|
302
274
|
task=task,
|
|
303
275
|
experiment_info=experiment_info,
|
|
304
276
|
)
|
|
@@ -1,9 +1,20 @@
|
|
|
1
1
|
import warnings
|
|
2
2
|
from dataclasses import dataclass
|
|
3
|
-
from typing import Optional
|
|
3
|
+
from typing import Any, Optional
|
|
4
4
|
|
|
5
|
+
from absl.logging import logging
|
|
6
|
+
from packaging import requirements
|
|
7
|
+
|
|
8
|
+
from snowflake.ml import version as snowml_version
|
|
9
|
+
from snowflake.ml._internal import env, env as snowml_env, env_utils
|
|
10
|
+
from snowflake.ml._internal.exceptions import error_codes, exceptions
|
|
5
11
|
from snowflake.ml._internal.utils import sql_identifier
|
|
6
|
-
from snowflake.ml.model import type_hints as model_types
|
|
12
|
+
from snowflake.ml.model import target_platform, type_hints as model_types
|
|
13
|
+
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
|
14
|
+
from snowflake.snowpark import Session
|
|
15
|
+
from snowflake.snowpark._internal import utils as snowpark_utils
|
|
16
|
+
|
|
17
|
+
logger = logging.getLogger(__name__)
|
|
7
18
|
|
|
8
19
|
|
|
9
20
|
@dataclass
|
|
@@ -12,7 +23,7 @@ class ReconciledParameters:
|
|
|
12
23
|
|
|
13
24
|
conda_dependencies: Optional[list[str]] = None
|
|
14
25
|
pip_requirements: Optional[list[str]] = None
|
|
15
|
-
target_platforms: Optional[list[model_types.
|
|
26
|
+
target_platforms: Optional[list[model_types.TargetPlatform]] = None
|
|
16
27
|
artifact_repository_map: Optional[dict[str, str]] = None
|
|
17
28
|
options: Optional[model_types.ModelSaveOption] = None
|
|
18
29
|
save_location: Optional[str] = None
|
|
@@ -23,6 +34,7 @@ class ModelParameterReconciler:
|
|
|
23
34
|
|
|
24
35
|
def __init__(
|
|
25
36
|
self,
|
|
37
|
+
session: Session,
|
|
26
38
|
database_name: sql_identifier.SqlIdentifier,
|
|
27
39
|
schema_name: sql_identifier.SqlIdentifier,
|
|
28
40
|
conda_dependencies: Optional[list[str]] = None,
|
|
@@ -30,7 +42,10 @@ class ModelParameterReconciler:
|
|
|
30
42
|
target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
|
|
31
43
|
artifact_repository_map: Optional[dict[str, str]] = None,
|
|
32
44
|
options: Optional[model_types.ModelSaveOption] = None,
|
|
45
|
+
python_version: Optional[str] = None,
|
|
46
|
+
statement_params: Optional[dict[str, str]] = None,
|
|
33
47
|
) -> None:
|
|
48
|
+
self._session = session
|
|
34
49
|
self._database_name = database_name
|
|
35
50
|
self._schema_name = schema_name
|
|
36
51
|
self._conda_dependencies = conda_dependencies
|
|
@@ -38,20 +53,27 @@ class ModelParameterReconciler:
|
|
|
38
53
|
self._target_platforms = target_platforms
|
|
39
54
|
self._artifact_repository_map = artifact_repository_map
|
|
40
55
|
self._options = options
|
|
56
|
+
self._python_version = python_version
|
|
57
|
+
self._statement_params = statement_params
|
|
41
58
|
|
|
42
59
|
def reconcile(self) -> ReconciledParameters:
|
|
43
60
|
"""Perform all parameter reconciliation and return clean parameters."""
|
|
61
|
+
|
|
44
62
|
reconciled_artifact_repository_map = self._reconcile_artifact_repository_map()
|
|
45
63
|
reconciled_save_location = self._extract_save_location()
|
|
46
64
|
|
|
47
65
|
self._validate_pip_requirements_warehouse_compatibility(reconciled_artifact_repository_map)
|
|
48
66
|
|
|
67
|
+
reconciled_target_platforms = self._reconcile_target_platforms()
|
|
68
|
+
reconciled_options = self._reconcile_explainability_options(reconciled_target_platforms)
|
|
69
|
+
reconciled_options = self._reconcile_relax_version(reconciled_options, reconciled_target_platforms)
|
|
70
|
+
|
|
49
71
|
return ReconciledParameters(
|
|
50
72
|
conda_dependencies=self._conda_dependencies,
|
|
51
73
|
pip_requirements=self._pip_requirements,
|
|
52
|
-
target_platforms=
|
|
74
|
+
target_platforms=reconciled_target_platforms,
|
|
53
75
|
artifact_repository_map=reconciled_artifact_repository_map,
|
|
54
|
-
options=
|
|
76
|
+
options=reconciled_options,
|
|
55
77
|
save_location=reconciled_save_location,
|
|
56
78
|
)
|
|
57
79
|
|
|
@@ -82,6 +104,45 @@ class ModelParameterReconciler:
|
|
|
82
104
|
|
|
83
105
|
return None
|
|
84
106
|
|
|
107
|
+
def _reconcile_target_platforms(self) -> Optional[list[model_types.TargetPlatform]]:
|
|
108
|
+
"""Reconcile target platforms with proper defaulting logic."""
|
|
109
|
+
# User specified target platforms are defaulted to None and will not show up in the generated manifest.
|
|
110
|
+
if self._target_platforms:
|
|
111
|
+
# Convert any string target platforms to TargetPlatform objects
|
|
112
|
+
return [model_types.TargetPlatform(platform) for platform in self._target_platforms]
|
|
113
|
+
|
|
114
|
+
# Default the target platform to warehouse if not specified and any table function exists
|
|
115
|
+
if self._has_table_function():
|
|
116
|
+
logger.info(
|
|
117
|
+
"Logging a partitioned model with a table function without specifying `target_platforms`. "
|
|
118
|
+
'Default to `target_platforms=["WAREHOUSE"]`.'
|
|
119
|
+
)
|
|
120
|
+
return [target_platform.TargetPlatform.WAREHOUSE]
|
|
121
|
+
|
|
122
|
+
# Default the target platform to SPCS if not specified when running in ML runtime
|
|
123
|
+
if env.IN_ML_RUNTIME:
|
|
124
|
+
logger.info(
|
|
125
|
+
"Logging the model on Container Runtime for ML without specifying `target_platforms`. "
|
|
126
|
+
'Default to `target_platforms=["SNOWPARK_CONTAINER_SERVICES"]`.'
|
|
127
|
+
)
|
|
128
|
+
return [target_platform.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
|
|
129
|
+
|
|
130
|
+
return None
|
|
131
|
+
|
|
132
|
+
def _has_table_function(self) -> bool:
|
|
133
|
+
"""Check if any table function exists in options."""
|
|
134
|
+
if self._options is None:
|
|
135
|
+
return False
|
|
136
|
+
|
|
137
|
+
if self._options.get("function_type") == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
|
|
138
|
+
return True
|
|
139
|
+
|
|
140
|
+
for opt in self._options.get("method_options", {}).values():
|
|
141
|
+
if opt.get("function_type") == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
|
|
142
|
+
return True
|
|
143
|
+
|
|
144
|
+
return False
|
|
145
|
+
|
|
85
146
|
def _validate_pip_requirements_warehouse_compatibility(
|
|
86
147
|
self, artifact_repository_map: Optional[dict[str, str]]
|
|
87
148
|
) -> None:
|
|
@@ -103,3 +164,131 @@ class ModelParameterReconciler:
|
|
|
103
164
|
or model_types.TargetPlatform.WAREHOUSE in target_platforms
|
|
104
165
|
or "WAREHOUSE" in target_platforms
|
|
105
166
|
)
|
|
167
|
+
|
|
168
|
+
def _reconcile_explainability_options(
|
|
169
|
+
self, target_platforms: Optional[list[model_types.TargetPlatform]]
|
|
170
|
+
) -> model_types.ModelSaveOption:
|
|
171
|
+
"""Reconcile explainability settings and embed_local_ml_library based on warehouse runnability."""
|
|
172
|
+
options = self._options.copy() if self._options else model_types.BaseModelSaveOption()
|
|
173
|
+
|
|
174
|
+
conda_dep_dict = env_utils.validate_conda_dependency_string_list(self._conda_dependencies or [])
|
|
175
|
+
|
|
176
|
+
enable_explainability = options.get("enable_explainability", None)
|
|
177
|
+
|
|
178
|
+
# Handle case where user explicitly disabled explainability
|
|
179
|
+
if enable_explainability is False:
|
|
180
|
+
return self._handle_embed_local_ml_library(options, target_platforms)
|
|
181
|
+
|
|
182
|
+
target_platform_set = set(target_platforms) if target_platforms else set()
|
|
183
|
+
|
|
184
|
+
is_warehouse_runnable = self._is_warehouse_runnable(conda_dep_dict)
|
|
185
|
+
only_spcs = target_platform_set == set(target_platform.SNOWPARK_CONTAINER_SERVICES_ONLY)
|
|
186
|
+
has_both_platforms = target_platform_set == set(target_platform.BOTH_WAREHOUSE_AND_SNOWPARK_CONTAINER_SERVICES)
|
|
187
|
+
|
|
188
|
+
# Handle case where user explicitly requested explainability
|
|
189
|
+
if enable_explainability:
|
|
190
|
+
if only_spcs or not is_warehouse_runnable:
|
|
191
|
+
raise ValueError(
|
|
192
|
+
"`enable_explainability` cannot be set to True when the model is not runnable in WH "
|
|
193
|
+
"or the target platforms include SPCS."
|
|
194
|
+
)
|
|
195
|
+
elif has_both_platforms:
|
|
196
|
+
warnings.warn(
|
|
197
|
+
("Explain function will only be available for model deployed to warehouse."),
|
|
198
|
+
category=UserWarning,
|
|
199
|
+
stacklevel=2,
|
|
200
|
+
)
|
|
201
|
+
|
|
202
|
+
# Handle case where explainability is not specified (None) - set default behavior
|
|
203
|
+
if enable_explainability is None:
|
|
204
|
+
if only_spcs or not is_warehouse_runnable:
|
|
205
|
+
options["enable_explainability"] = False
|
|
206
|
+
|
|
207
|
+
return self._handle_embed_local_ml_library(options, target_platforms)
|
|
208
|
+
|
|
209
|
+
def _handle_embed_local_ml_library(
|
|
210
|
+
self, options: model_types.ModelSaveOption, target_platforms: Optional[list[model_types.TargetPlatform]]
|
|
211
|
+
) -> model_types.ModelSaveOption:
|
|
212
|
+
"""Handle embed_local_ml_library logic."""
|
|
213
|
+
if not snowpark_utils.is_in_stored_procedure() and target_platforms != [ # type: ignore[no-untyped-call]
|
|
214
|
+
model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES # no information schema check for SPCS-only models
|
|
215
|
+
]:
|
|
216
|
+
snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
|
|
217
|
+
self._session,
|
|
218
|
+
reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_version.VERSION}")],
|
|
219
|
+
python_version=self._python_version or snowml_env.PYTHON_VERSION,
|
|
220
|
+
statement_params=self._statement_params,
|
|
221
|
+
).get(env_utils.SNOWPARK_ML_PKG_NAME, [])
|
|
222
|
+
|
|
223
|
+
if len(snowml_matched_versions) < 1 and not options.get("embed_local_ml_library", False):
|
|
224
|
+
logging.info(
|
|
225
|
+
f"Local snowflake-ml-python library has version {snowml_version.VERSION},"
|
|
226
|
+
" which is not available in the Snowflake server, embedding local ML library automatically."
|
|
227
|
+
)
|
|
228
|
+
options["embed_local_ml_library"] = True
|
|
229
|
+
|
|
230
|
+
return options
|
|
231
|
+
|
|
232
|
+
def _is_warehouse_runnable(self, conda_dep_dict: dict[str, list[Any]]) -> bool:
|
|
233
|
+
"""Check if model can run in warehouse based on conda channels and pip requirements."""
|
|
234
|
+
# If pip requirements are present but no artifact repository map, model cannot run in warehouse
|
|
235
|
+
if self._pip_requirements and not self._artifact_repository_map:
|
|
236
|
+
return False
|
|
237
|
+
|
|
238
|
+
# If no conda dependencies, model can run in warehouse
|
|
239
|
+
if not conda_dep_dict:
|
|
240
|
+
return True
|
|
241
|
+
|
|
242
|
+
# Check if all conda channels are warehouse-compatible
|
|
243
|
+
warehouse_compatible_channels = {env_utils.DEFAULT_CHANNEL_NAME, env_utils.SNOWFLAKE_CONDA_CHANNEL_URL}
|
|
244
|
+
for channel in conda_dep_dict:
|
|
245
|
+
if channel not in warehouse_compatible_channels:
|
|
246
|
+
return False
|
|
247
|
+
|
|
248
|
+
return True
|
|
249
|
+
|
|
250
|
+
def _reconcile_relax_version(
|
|
251
|
+
self,
|
|
252
|
+
options: model_types.ModelSaveOption,
|
|
253
|
+
target_platforms: Optional[list[model_types.TargetPlatform]],
|
|
254
|
+
) -> model_types.ModelSaveOption:
|
|
255
|
+
"""Reconcile relax_version setting based on pip requirements and target platforms."""
|
|
256
|
+
target_platform_set = set(target_platforms) if target_platforms else set()
|
|
257
|
+
has_pip_requirements = bool(self._pip_requirements)
|
|
258
|
+
only_spcs = target_platform_set == set(target_platform.SNOWPARK_CONTAINER_SERVICES_ONLY)
|
|
259
|
+
|
|
260
|
+
if "relax_version" not in options:
|
|
261
|
+
if has_pip_requirements or only_spcs:
|
|
262
|
+
logger.info(
|
|
263
|
+
"Setting `relax_version=False` as this model will run in Snowpark Container Services "
|
|
264
|
+
"or in Warehouse with a specified artifact_repository_map where exact version "
|
|
265
|
+
" specifications will be honored."
|
|
266
|
+
)
|
|
267
|
+
relax_version = False
|
|
268
|
+
else:
|
|
269
|
+
warnings.warn(
|
|
270
|
+
(
|
|
271
|
+
"`relax_version` is not set and therefore defaulted to True. Dependency version constraints"
|
|
272
|
+
" relaxed from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility,"
|
|
273
|
+
" reproducibility, etc., set `options={'relax_version': False}` when logging the model."
|
|
274
|
+
),
|
|
275
|
+
category=UserWarning,
|
|
276
|
+
stacklevel=2,
|
|
277
|
+
)
|
|
278
|
+
relax_version = True
|
|
279
|
+
options["relax_version"] = relax_version
|
|
280
|
+
return options
|
|
281
|
+
|
|
282
|
+
# Handle case where relax_version is already set
|
|
283
|
+
relax_version = options["relax_version"]
|
|
284
|
+
if relax_version and (has_pip_requirements or only_spcs):
|
|
285
|
+
raise exceptions.SnowflakeMLException(
|
|
286
|
+
error_code=error_codes.INVALID_ARGUMENT,
|
|
287
|
+
original_exception=ValueError(
|
|
288
|
+
"Setting `relax_version=True` is only allowed for models to be run in Warehouse with "
|
|
289
|
+
"Snowflake Conda Channel dependencies. It cannot be used with pip requirements or when "
|
|
290
|
+
"targeting only Snowpark Container Services."
|
|
291
|
+
),
|
|
292
|
+
)
|
|
293
|
+
|
|
294
|
+
return options
|
snowflake/ml/version.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
1
|
# This is parsed by regex in conda recipe meta file. Make sure not to break it.
|
|
2
|
-
VERSION = "1.
|
|
2
|
+
VERSION = "1.11.0"
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: snowflake-ml-python
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.11.0
|
|
4
4
|
Summary: The machine learning client library that is used for interacting with Snowflake to build machine learning solutions.
|
|
5
5
|
Author-email: "Snowflake, Inc" <support@snowflake.com>
|
|
6
6
|
License:
|
|
@@ -410,12 +410,32 @@ NOTE: Version 1.7.0 is used as example here. Please choose the the latest versio
|
|
|
410
410
|
|
|
411
411
|
# Release History
|
|
412
412
|
|
|
413
|
-
## 1.
|
|
413
|
+
## 1.11.0
|
|
414
414
|
|
|
415
415
|
### Bug Fixes
|
|
416
416
|
|
|
417
|
+
* ML Job: Fix `Error: Unable to retrieve head IP address` when not all instances start within the timeout.
|
|
418
|
+
* ML Job: Fix `TypeError: SnowflakeCursor.execute() got an unexpected keyword argument '_force_qmark_paramstyle'`
|
|
419
|
+
when running inside Stored Procedures.
|
|
420
|
+
|
|
421
|
+
### Behavior Changes
|
|
422
|
+
|
|
423
|
+
### New Features
|
|
424
|
+
|
|
425
|
+
* `ModelVersion.create_service()`: Made the `image_repo` argument optional. When it is
|

426
|
+
omitted, a default image repository is used; this default repository is
|
|
427
|
+
being rolled out in server version 9.22+.
|
|
428
|
+
* Experiment Tracking (PrPr): Automatically log the model, metrics, and parameters while training Keras models with
|
|
429
|
+
`snowflake.ml.experiment.callback.keras.SnowflakeKerasCallback`.
|
|
430
|
+
|
|
431
|
+
## 1.10.0
|
|
432
|
+
|
|
417
433
|
### Behavior Changes
|
|
418
434
|
|
|
435
|
+
* Experiment Tracking (PrPr): The import paths for the auto-logging callbacks have changed to
|
|
436
|
+
`snowflake.ml.experiment.callback.xgboost.SnowflakeXgboostCallback` and
|
|
437
|
+
`snowflake.ml.experiment.callback.lightgbm.SnowflakeLightgbmCallback`.
|
|
438
|
+
|
|
419
439
|
### New Features
|
|
420
440
|
|
|
421
441
|
* Registry: add progress bars for `ModelVersion.create_service` and `ModelVersion.log_model`.
|
|
@@ -436,13 +456,13 @@ NOTE: Version 1.7.0 is used as example here. Please choose the the latest versio
|
|
|
436
456
|
|
|
437
457
|
```python
|
|
438
458
|
from snowflake.ml.experiment import ExperimentTracking
|
|
459
|
+
from snowflake.ml.experiment.callback import SnowflakeXgboostCallback, SnowflakeLightgbmCallback
|
|
439
460
|
|
|
440
461
|
exp = ExperimentTracking(session=sp_session, database_name="ML", schema_name="PUBLIC")
|
|
441
462
|
|
|
442
463
|
exp.set_experiment("MY_EXPERIMENT")
|
|
443
464
|
|
|
444
465
|
# XGBoost
|
|
445
|
-
from snowflake.ml.experiment.callback.xgboost import SnowflakeXgboostCallback
|
|
446
466
|
callback = SnowflakeXgboostCallback(
|
|
447
467
|
exp, log_model=True, log_metrics=True, log_params=True, model_name="model_name", model_signature=sig
|
|
448
468
|
)
|
|
@@ -451,7 +471,6 @@ with exp.start_run():
|
|
|
451
471
|
model.fit(X, y, eval_set=[(X_test, y_test)])
|
|
452
472
|
|
|
453
473
|
# LightGBM
|
|
454
|
-
from snowflake.ml.experiment.callback.lightgbm import SnowflakeLightgbmCallback
|
|
455
474
|
callback = SnowflakeLightgbmCallback(
|
|
456
475
|
exp, log_model=True, log_metrics=True, log_params=True, model_name="model_name", model_signature=sig
|
|
457
476
|
)
|