snowflake-ml-python 1.9.2__py3-none-any.whl → 1.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/utils/service_logger.py +31 -17
- snowflake/ml/experiment/callback/keras.py +63 -0
- snowflake/ml/experiment/callback/lightgbm.py +59 -0
- snowflake/ml/experiment/callback/xgboost.py +67 -0
- snowflake/ml/experiment/utils.py +14 -0
- snowflake/ml/jobs/_utils/__init__.py +0 -0
- snowflake/ml/jobs/_utils/constants.py +4 -1
- snowflake/ml/jobs/_utils/payload_utils.py +55 -21
- snowflake/ml/jobs/_utils/query_helper.py +5 -1
- snowflake/ml/jobs/_utils/runtime_env_utils.py +63 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +2 -2
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +5 -5
- snowflake/ml/jobs/_utils/spec_utils.py +41 -8
- snowflake/ml/jobs/_utils/stage_utils.py +22 -9
- snowflake/ml/jobs/_utils/types.py +5 -7
- snowflake/ml/jobs/job.py +1 -1
- snowflake/ml/jobs/manager.py +1 -13
- snowflake/ml/model/_client/model/model_version_impl.py +219 -55
- snowflake/ml/model/_client/ops/service_ops.py +230 -30
- snowflake/ml/model/_client/service/model_deployment_spec.py +103 -27
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +11 -5
- snowflake/ml/model/_model_composer/model_composer.py +1 -70
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +2 -43
- snowflake/ml/model/event_handler.py +87 -18
- snowflake/ml/model/inference_engine.py +5 -0
- snowflake/ml/model/models/huggingface_pipeline.py +74 -51
- snowflake/ml/model/type_hints.py +26 -1
- snowflake/ml/registry/_manager/model_manager.py +37 -70
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +294 -0
- snowflake/ml/registry/registry.py +0 -19
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/METADATA +523 -491
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/RECORD +36 -29
- snowflake/ml/experiment/callback.py +0 -121
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.9.2.dist-info → snowflake_ml_python-1.11.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_model_composer/model_composer.py
CHANGED
@@ -1,17 +1,12 @@
 import pathlib
 import tempfile
 import uuid
-import warnings
 from types import ModuleType
 from typing import TYPE_CHECKING, Any, Optional, Union
 from urllib import parse

-from absl import logging
-from packaging import requirements
-
 from snowflake import snowpark
-from snowflake.ml import version as snowml_version
-from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
+from snowflake.ml._internal import file_utils
 from snowflake.ml._internal.lineage import lineage_utils
 from snowflake.ml.data import data_source
 from snowflake.ml.model import model_signature, type_hints as model_types
@@ -19,7 +14,6 @@ from snowflake.ml.model._model_composer.model_manifest import model_manifest
 from snowflake.ml.model._packager import model_packager
 from snowflake.ml.model._packager.model_meta import model_meta
 from snowflake.snowpark import Session
-from snowflake.snowpark._internal import utils as snowpark_utils

 if TYPE_CHECKING:
     from snowflake.ml.experiment._experiment_info import ExperimentInfo
@@ -142,73 +136,10 @@ class ModelComposer:
         experiment_info: Optional["ExperimentInfo"] = None,
         options: Optional[model_types.ModelSaveOption] = None,
     ) -> model_meta.ModelMetadata:
-        # set enable_explainability=False if the model is not runnable in WH or the target platforms include SPCS
-        conda_dep_dict = env_utils.validate_conda_dependency_string_list(
-            conda_dependencies if conda_dependencies else []
-        )
-
-        enable_explainability = None
-
-        if options:
-            enable_explainability = options.get("enable_explainability", None)
-
-        # skip everything if user said False explicitly
-        if enable_explainability is None or enable_explainability is True:
-            is_warehouse_runnable = (
-                not conda_dep_dict
-                or all(
-                    chan == env_utils.DEFAULT_CHANNEL_NAME or chan == env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
-                    for chan in conda_dep_dict
-                )
-            ) and (not pip_requirements)
-
-            only_spcs = (
-                target_platforms
-                and len(target_platforms) == 1
-                and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
-            )
-            if only_spcs or (not is_warehouse_runnable):
-                # if only SPCS and user asked for explainability we fail
-                if enable_explainability is True:
-                    raise ValueError(
-                        "`enable_explainability` cannot be set to True when the model is not runnable in WH "
-                        "or the target platforms include SPCS."
-                    )
-                elif not options:  # explicitly set flag to false in these cases if not specified
-                    options = model_types.BaseModelSaveOption()
-                    options["enable_explainability"] = False
-            elif (
-                target_platforms
-                and len(target_platforms) > 1
-                and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
-            ):  # if both then only available for WH
-                if enable_explainability is True:
-                    warnings.warn(
-                        ("Explain function will only be available for model deployed to warehouse."),
-                        category=UserWarning,
-                        stacklevel=2,
-                    )

         if not options:
             options = model_types.BaseModelSaveOption()

-        if not snowpark_utils.is_in_stored_procedure() and target_platforms != [  # type: ignore[no-untyped-call]
-            model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES  # no information schema check for SPCS-only models
-        ]:
-            snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
-                self.session,
-                reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_version.VERSION}")],
-                python_version=python_version or snowml_env.PYTHON_VERSION,
-                statement_params=self._statement_params,
-            ).get(env_utils.SNOWPARK_ML_PKG_NAME, [])
-
-            if len(snowml_matched_versions) < 1 and options.get("embed_local_ml_library", False) is False:
-                logging.info(
-                    f"Local snowflake-ml-python library has version {snowml_version.VERSION},"
-                    " which is not available in the Snowflake server, embedding local ML library automatically."
-                )
-                options["embed_local_ml_library"] = True
-
         model_metadata: model_meta.ModelMetadata = self.packager.save(
             name=name,
             model=model,
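Note: after this change `ModelComposer.save` no longer decides `enable_explainability` or `embed_local_ml_library` inline; per the file list, that logic is reconciled upstream (new `model_parameter_reconciler.py`). Below is a minimal sketch of setting the flag explicitly from the caller side instead of relying on the removed implicit default; the Snowpark `session` and `my_model` objects are placeholders, not part of this diff.

```python
# Sketch: pass enable_explainability explicitly via save options.
# Assumes an already-created Snowpark `session` and a supported model object `my_model`.
from snowflake.ml.registry import Registry

registry = Registry(session=session)  # `session`: existing Snowpark session (assumed)
mv = registry.log_model(
    model=my_model,                    # placeholder: any supported model type
    model_name="MY_MODEL",
    version_name="V1",
    target_platforms=["SNOWPARK_CONTAINER_SERVICES"],
    options={"enable_explainability": False},  # explicit, instead of the removed auto-default
)
```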
snowflake/ml/model/_model_composer/model_manifest/model_manifest.py
CHANGED
@@ -1,13 +1,11 @@
 import collections
 import logging
 import pathlib
-import warnings
 from typing import TYPE_CHECKING, Optional, cast

 import yaml

 from snowflake.ml._internal import env_utils
-from snowflake.ml._internal.exceptions import error_codes, exceptions
 from snowflake.ml.data import data_source
 from snowflake.ml.model import type_hints
 from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
@@ -55,47 +53,8 @@ class ModelManifest:
         experiment_info: Optional["ExperimentInfo"] = None,
         target_platforms: Optional[list[type_hints.TargetPlatform]] = None,
     ) -> None:
-        has_pip_requirements = len(model_meta.env.pip_requirements) > 0
-        only_spcs = (
-            target_platforms
-            and len(target_platforms) == 1
-            and target_platforms[0] == type_hints.TargetPlatform.SNOWPARK_CONTAINER_SERVICES
-        )
-
-        if "relax_version" not in options:
-            if has_pip_requirements or only_spcs:
-                logger.info(
-                    "Setting `relax_version=False` as this model will run in Snowpark Container Services "
-                    "or in Warehouse with a specified artifact_repository_map where exact version "
-                    " specifications will be honored."
-                )
-                relax_version = False
-            else:
-                warnings.warn(
-                    (
-                        "`relax_version` is not set and therefore defaulted to True. Dependency version constraints"
-                        " relaxed from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility,"
-                        " reproducibility, etc., set `options={'relax_version': False}` when logging the model."
-                    ),
-                    category=UserWarning,
-                    stacklevel=2,
-                )
-                relax_version = True
-            options["relax_version"] = relax_version
-        else:
-            relax_version = options.get("relax_version", True)
-            if relax_version and (has_pip_requirements or only_spcs):
-                raise exceptions.SnowflakeMLException(
-                    error_code=error_codes.INVALID_ARGUMENT,
-                    original_exception=ValueError(
-                        "Setting `relax_version=True` is only allowed for models to be run in Warehouse with "
-                        "Snowflake Conda Channel dependencies. It cannot be used with pip requirements or when "
-                        "targeting only Snowpark Container Services."
-                    ),
-                )
+        assert options is not None, "ModelParameterReconciler should have set options with relax_version"
+        relax_version = options["relax_version"]

         runtime_to_use = model_runtime.ModelRuntime(
             name=self._DEFAULT_RUNTIME_NAME,
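Note: `ModelManifest.save` now asserts that `relax_version` was resolved upstream. The standalone sketch below mirrors the defaulting/validation rule that the removed block implemented; it illustrates the behavior and is not the library's API.

```python
# Illustrative mirror of the removed relax_version logic in ModelManifest.save;
# the new ModelParameterReconciler presumably centralizes an equivalent check.
from typing import Optional


def resolve_relax_version(
    requested: Optional[bool],
    has_pip_requirements: bool,
    only_spcs: bool,
) -> bool:
    """Default to exact pins for pip/SPCS-only models, relaxed ranges otherwise."""
    if requested is None:
        return not (has_pip_requirements or only_spcs)
    if requested and (has_pip_requirements or only_spcs):
        raise ValueError(
            "relax_version=True is only allowed for Warehouse models using "
            "Snowflake Conda Channel dependencies."
        )
    return requested


assert resolve_relax_version(None, has_pip_requirements=True, only_spcs=False) is False
assert resolve_relax_version(None, has_pip_requirements=False, only_spcs=False) is True
```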
snowflake/ml/model/event_handler.py
CHANGED
@@ -23,12 +23,24 @@ class _TqdmStatusContext:
         if state == "complete":
             self._progress_bar.update(self._progress_bar.total - self._progress_bar.n)
             self._progress_bar.set_description(label)
+        elif state == "error":
+            # For error state, use the label as-is and mark with ERROR prefix
+            # Don't update progress bar position for errors - leave it where it was
+            self._progress_bar.set_description(f"❌ ERROR: {label}")
         else:
+            combined_desc = f"{self._label}: {label}" if label != self._label else self._label
+            self._progress_bar.set_description(combined_desc)

-    def increment(self
+    def increment(self) -> None:
         """Increment the progress bar."""
-        self._progress_bar.update(
+        self._progress_bar.update(1)
+
+    def complete(self) -> None:
+        """Complete the progress bar to full state."""
+        if self._total:
+            remaining = self._total - self._progress_bar.n
+            if remaining > 0:
+                self._progress_bar.update(remaining)


 class _StreamlitStatusContext:
@@ -39,6 +51,7 @@ class _StreamlitStatusContext:
         self._streamlit = streamlit_module
         self._total = total
         self._current = 0
+        self._current_label = label
         self._progress_bar = None

     def __enter__(self) -> "_StreamlitStatusContext":
@@ -49,26 +62,70 @@ class _StreamlitStatusContext:
         return self

     def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        # Only update to complete if there was no exception
+        if exc_type is None:
+            self._status_container.update(state="complete")

     def update(self, label: str, *, state: str = "running", expanded: bool = True) -> None:
         """Update the status label."""
-        if state
+        if state == "complete" or state == "error":
+            # For completion/error, use the message as-is and update main status
+            self._status_container.update(label=label, state=state, expanded=expanded)
+            self._current_label = label
+
+            # For error state, update progress bar text but preserve position
+            if state == "error" and self._total is not None and self._progress_bar is not None:
+                self._progress_bar.progress(
+                    self._current / self._total if self._total > 0 else 0,
+                    text=f"ERROR - ({self._current}/{self._total})",
+                )
+        else:
+            combined_label = f"{self._label}: {label}" if label != self._label else self._label
+            self._status_container.update(label=combined_label, state=state, expanded=expanded)
+            self._current_label = label
+            if self._total is not None and self._progress_bar is not None:
+                progress_value = self._current / self._total if self._total > 0 else 0
+                self._progress_bar.progress(progress_value, text=f"({self._current}/{self._total})")
+
+    def increment(self) -> None:
         """Increment the progress."""
         if self._total is not None:
-            self._current = min(self._current +
+            self._current = min(self._current + 1, self._total)
             if self._progress_bar is not None:
                 progress_value = self._current / self._total if self._total > 0 else 0
-                self._progress_bar.progress(progress_value, text=f"{self._current}/{self._total}")
+                self._progress_bar.progress(progress_value, text=f"({self._current}/{self._total})")
+
+    def complete(self) -> None:
+        """Complete the progress bar to full state."""
+        if self._total is not None:
+            self._current = self._total
+            if self._progress_bar is not None:
+                self._progress_bar.progress(1.0, text=f"({self._current}/{self._total})")
+
+
+class _NoOpStatusContext:
+    """A no-op context manager for when status updates should be disabled."""
+
+    def __init__(self, label: str) -> None:
+        self._label = label
+
+    def __enter__(self) -> "_NoOpStatusContext":
+        return self
+
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        pass
+
+    def update(self, label: str, *, state: str = "running", expanded: bool = True) -> None:
+        """No-op update method."""
+        pass
+
+    def increment(self) -> None:
+        """No-op increment method."""
+        pass
+
+    def complete(self) -> None:
+        """No-op complete method."""
+        pass


 class ModelEventHandler:
@@ -99,7 +156,15 @@ class ModelEventHandler:
         else:
             self._tqdm.tqdm.write(message)

-    def status(
+    def status(
+        self,
+        label: str,
+        *,
+        state: str = "running",
+        expanded: bool = True,
+        total: Optional[int] = None,
+        block: bool = True,
+    ) -> Any:
         """Context manager that provides status updates with optional enhanced display capabilities.

         Args:
@@ -107,10 +172,14 @@ class ModelEventHandler:
             state: The initial state ("running", "complete", "error")
             expanded: Whether to show expanded view (streamlit only)
             total: Total number of steps for progress tracking (optional)
+            block: Whether to show progress updates (no-op if False)

         Returns:
-            Status context (Streamlit or
+            Status context (Streamlit, Tqdm, or NoOp based on availability and block parameter)
         """
+        if not block:
+            return _NoOpStatusContext(label)
+
         if self._streamlit is not None:
             return _StreamlitStatusContext(label, self._streamlit, total)
         else:
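Note: the event handler gains an error state, a `complete()` method, a `_NoOpStatusContext`, and a `block` flag on `ModelEventHandler.status`. The sketch below drives the status context the same way the new HuggingFace service path does (update/increment/complete); the work steps themselves are placeholders, not library calls.

```python
# Usage sketch based on the status() signature and context methods added in this diff.
from snowflake.ml.model import event_handler

handler = event_handler.ModelEventHandler()
steps = ["validate inputs", "build image", "deploy service"]  # illustrative steps

with handler.status("Deploying model", total=len(steps), block=True) as status:
    try:
        for step in steps:
            status.update(step)      # rendered as "Deploying model: <step>" per the diff
            status.increment()       # advances the tqdm/streamlit progress bar
        status.complete()            # fill the bar to 100%
        status.update(label="Deployment finished", state="complete", expanded=False)
    except Exception:
        status.update(label="Deployment failed", state="error", expanded=False)
        raise
```

Passing `block=False` returns the `_NoOpStatusContext`, so the same calling code runs silently.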
snowflake/ml/model/models/huggingface_pipeline.py
CHANGED
@@ -258,7 +258,7 @@ class HuggingFacePipelineModel:
         # model_version_impl.create_service parameters
         service_name: str,
         service_compute_pool: str,
-        image_repo: str,
+        image_repo: Optional[str] = None,
         image_build_compute_pool: Optional[str] = None,
         ingress_enabled: bool = False,
         max_instances: int = 1,
@@ -282,7 +282,8 @@ class HuggingFacePipelineModel:
             comment: Comment for the model. Defaults to None.
             service_name: The name of the service to create.
             service_compute_pool: The compute pool for the service.
-            image_repo: The name of the image repository.
+            image_repo: The name of the image repository. This can be None, in that case a default hidden image
+                repository will be used.
             image_build_compute_pool: The name of the compute pool used to build the model inference image. It uses
                 the service compute pool if None.
             ingress_enabled: Whether ingress is enabled. Defaults to False.
@@ -299,6 +300,7 @@ class HuggingFacePipelineModel:
         Raises:
             ValueError: if database and schema name is not provided and session doesn't have a
                 database and schema name.
+            exceptions.SnowparkSQLException: if service already exists.

         Returns:
             The service ID or an async job object.
@@ -327,7 +329,6 @@ class HuggingFacePipelineModel:
         version_name = name_generator.generate()[1]

         service_db_id, service_schema_id, service_id = sql_identifier.parse_fully_qualified_name(service_name)
-        image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)

         service_operator = service_ops.ServiceOperator(
             session=session,
@@ -336,51 +337,73 @@ class HuggingFacePipelineModel:
         )
         logger.info(f"A service job is going to register the hf model as: {model_name}.{version_name}")

+        from snowflake.ml.model import event_handler
+        from snowflake.snowpark import exceptions
+
+        hf_event_handler = event_handler.ModelEventHandler()
+        with hf_event_handler.status("Creating HuggingFace model service", total=6, block=block) as status:
+            try:
+                result = service_operator.create_service(
+                    database_name=database_name_id,
+                    schema_name=schema_name_id,
+                    model_name=model_name_id,
+                    version_name=sql_identifier.SqlIdentifier(version_name),
+                    service_database_name=service_db_id,
+                    service_schema_name=service_schema_id,
+                    service_name=service_id,
+                    image_build_compute_pool_name=(
+                        sql_identifier.SqlIdentifier(image_build_compute_pool)
+                        if image_build_compute_pool
+                        else sql_identifier.SqlIdentifier(service_compute_pool)
+                    ),
+                    service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
+                    image_repo_name=image_repo,
+                    ingress_enabled=ingress_enabled,
+                    max_instances=max_instances,
+                    cpu_requests=cpu_requests,
+                    memory_requests=memory_requests,
+                    gpu_requests=gpu_requests,
+                    num_workers=num_workers,
+                    max_batch_rows=max_batch_rows,
+                    force_rebuild=force_rebuild,
+                    build_external_access_integrations=(
+                        None
+                        if build_external_access_integrations is None
+                        else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
+                    ),
+                    block=block,
+                    progress_status=status,
+                    statement_params=statement_params,
+                    # hf model
+                    hf_model_args=service_ops.HFModelArgs(
+                        hf_model_name=self.model,
+                        hf_task=self.task,
+                        hf_tokenizer=self.tokenizer,
+                        hf_revision=self.revision,
+                        hf_token=self.token,
+                        hf_trust_remote_code=bool(self.trust_remote_code),
+                        hf_model_kwargs=self.model_kwargs,
+                        pip_requirements=pip_requirements,
+                        conda_dependencies=conda_dependencies,
+                        comment=comment,
+                        # TODO: remove warehouse in the next release
+                        warehouse=session.get_current_warehouse(),
+                    ),
+                )
+                status.update(label="HuggingFace model service created successfully", state="complete", expanded=False)
+                return result
+            except exceptions.SnowparkSQLException as e:
+                # Check if the error is because the service already exists
+                if "already exists" in str(e).lower() or "100132" in str(
+                    e
+                ):  # 100132 is Snowflake error code for object already exists
+                    # Update progress to show service already exists (preserve exception behavior)
+                    status.update("service already exists")
+                    status.complete()  # Complete progress to full state
+                    status.update(label="Service already exists", state="error", expanded=False)
+                    # Re-raise the exception to preserve existing API behavior
+                    raise
+                else:
+                    # Re-raise other SQL exceptions
+                    status.update(label="Service creation failed", state="error", expanded=False)
+                    raise
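Note: `image_repo` is now optional on `HuggingFacePipelineModel.create_service`, and service creation is wrapped in the event-handler status context, re-raising `SnowparkSQLException` when the service already exists. The call below is schematic only: the constructor arguments, the `session` object, and the full set of required keyword arguments are assumptions, since only part of the signature appears in these hunks.

```python
# Schematic sketch (assumptions marked): omit image_repo to use the default hidden repository.
from snowflake.ml.model.models import huggingface_pipeline
from snowflake.snowpark import exceptions

# Assumed constructor arguments; they mirror transformers.pipeline-style task/model names.
hf_model = huggingface_pipeline.HuggingFacePipelineModel(
    task="text-generation",
    model="openai-community/gpt2",
)

try:
    service = hf_model.create_service(
        session=session,                            # existing Snowpark session (assumed)
        service_name="MY_DB.MY_SCHEMA.GPT2_SERVICE",  # fully qualified names are parsed per this diff
        service_compute_pool="MY_GPU_POOL",
        # image_repo omitted: falls back to the default hidden repository
        ingress_enabled=True,
        max_instances=1,
        block=True,
    )
except exceptions.SnowparkSQLException:
    # Per the new docstring, raised (and re-raised) when the service already exists.
    raise
```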
snowflake/ml/model/type_hints.py
CHANGED
@@ -1,5 +1,14 @@
 # mypy: disable-error-code="import"
-from typing import
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Literal,
+    Protocol,
+    Sequence,
+    TypedDict,
+    TypeVar,
+    Union,
+)

 import numpy.typing as npt
 from typing_extensions import NotRequired
@@ -326,4 +335,20 @@ ModelLoadOption = Union[
 SupportedTargetPlatformType = Union[TargetPlatform, str]


+class ProgressStatus(Protocol):
+    """Protocol for tracking progress during long-running operations."""
+
+    def update(self, message: str, *, state: str = "running", expanded: bool = True, **kwargs: Any) -> None:
+        """Update the progress status with a new message."""
+        ...
+
+    def increment(self) -> None:
+        """Increment the progress by one step."""
+        ...
+
+    def complete(self) -> None:
+        """Complete the progress bar to full state."""
+        ...
+
+
 __all__ = ["TargetPlatform", "Task"]
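Note: `ProgressStatus` is a structural `Protocol`, so any object exposing `update`, `increment`, and `complete` conforms; the status contexts in `event_handler.py` satisfy it and are passed to `service_ops` as `progress_status`. Below is a minimal conforming implementation as a sketch; the logging-based class is illustrative and not part of the library.

```python
# Minimal sketch of an object satisfying the new ProgressStatus protocol.
import logging
from typing import Any

from snowflake.ml.model import type_hints

logger = logging.getLogger(__name__)


class LoggingProgress:
    """Reports progress to a standard logger; structurally matches ProgressStatus."""

    def __init__(self, total: int) -> None:
        self._total = total
        self._done = 0

    def update(self, message: str, *, state: str = "running", expanded: bool = True, **kwargs: Any) -> None:
        logger.info("[%s] %s", state, message)

    def increment(self) -> None:
        self._done = min(self._done + 1, self._total)
        logger.info("progress %d/%d", self._done, self._total)

    def complete(self) -> None:
        self._done = self._total
        logger.info("progress %d/%d", self._done, self._total)


# The protocol is structural, so this annotation type-checks without inheritance:
status: type_hints.ProgressStatus = LoggingProgress(total=3)
status.update("starting")
status.increment()
status.complete()
```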