snowflake-ml-python 1.8.6__py3-none-any.whl → 1.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/mixins.py +61 -0
- snowflake/ml/jobs/_utils/constants.py +1 -1
- snowflake/ml/jobs/_utils/interop_utils.py +63 -4
- snowflake/ml/jobs/_utils/payload_utils.py +6 -5
- snowflake/ml/jobs/_utils/query_helper.py +9 -0
- snowflake/ml/jobs/_utils/spec_utils.py +6 -4
- snowflake/ml/jobs/decorators.py +18 -25
- snowflake/ml/jobs/job.py +179 -58
- snowflake/ml/jobs/manager.py +194 -145
- snowflake/ml/model/_client/ops/model_ops.py +12 -3
- snowflake/ml/model/_client/ops/service_ops.py +4 -2
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -10
- snowflake/ml/model/_packager/model_env/model_env.py +35 -27
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
- snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
- snowflake/ml/model/target_platform.py +11 -0
- snowflake/ml/model/task.py +9 -0
- snowflake/ml/model/type_hints.py +5 -13
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
- snowflake/ml/registry/_manager/model_manager.py +30 -15
- snowflake/ml/registry/registry.py +119 -42
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/METADATA +52 -16
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/RECORD +30 -26
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/top_level.txt +0 -0
@@ -7,6 +7,7 @@ from typing import Optional, cast
|
|
7
7
|
import yaml
|
8
8
|
|
9
9
|
from snowflake.ml._internal import env_utils
|
10
|
+
from snowflake.ml._internal.exceptions import error_codes, exceptions
|
10
11
|
from snowflake.ml.data import data_source
|
11
12
|
from snowflake.ml.model import type_hints
|
12
13
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
@@ -53,17 +54,44 @@ class ModelManifest:
|
|
53
54
|
if options is None:
|
54
55
|
options = {}
|
55
56
|
|
57
|
+
has_pip_requirements = len(model_meta.env.pip_requirements) > 0
|
58
|
+
only_spcs = (
|
59
|
+
target_platforms
|
60
|
+
and len(target_platforms) == 1
|
61
|
+
and target_platforms[0] == type_hints.TargetPlatform.SNOWPARK_CONTAINER_SERVICES
|
62
|
+
)
|
63
|
+
|
56
64
|
if "relax_version" not in options:
|
57
|
-
|
58
|
-
(
|
59
|
-
"`relax_version`
|
60
|
-
"
|
61
|
-
"
|
62
|
-
)
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
65
|
+
if has_pip_requirements or only_spcs:
|
66
|
+
logger.info(
|
67
|
+
"Setting `relax_version=False` as this model will run in Snowpark Container Services "
|
68
|
+
"or in Warehouse with a specified artifact_repository_map where exact version "
|
69
|
+
" specifications will be honored."
|
70
|
+
)
|
71
|
+
relax_version = False
|
72
|
+
else:
|
73
|
+
warnings.warn(
|
74
|
+
(
|
75
|
+
"`relax_version` is not set and therefore defaulted to True. Dependency version constraints"
|
76
|
+
" relaxed from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility,"
|
77
|
+
" reproducibility, etc., set `options={'relax_version': False}` when logging the model."
|
78
|
+
),
|
79
|
+
category=UserWarning,
|
80
|
+
stacklevel=2,
|
81
|
+
)
|
82
|
+
relax_version = True
|
83
|
+
options["relax_version"] = relax_version
|
84
|
+
else:
|
85
|
+
relax_version = options.get("relax_version", True)
|
86
|
+
if relax_version and (has_pip_requirements or only_spcs):
|
87
|
+
raise exceptions.SnowflakeMLException(
|
88
|
+
error_code=error_codes.INVALID_ARGUMENT,
|
89
|
+
original_exception=ValueError(
|
90
|
+
"Setting `relax_version=True` is only allowed for models to be run in Warehouse with "
|
91
|
+
"Snowflake Conda Channel dependencies. It cannot be used with pip requirements or when "
|
92
|
+
"targeting only Snowpark Container Services."
|
93
|
+
),
|
94
|
+
)
|
67
95
|
|
68
96
|
runtime_to_use = model_runtime.ModelRuntime(
|
69
97
|
name=self._DEFAULT_RUNTIME_NAME,
|
@@ -9,6 +9,7 @@ from packaging import requirements, version
|
|
9
9
|
|
10
10
|
from snowflake.ml import version as snowml_version
|
11
11
|
from snowflake.ml._internal import env as snowml_env, env_utils
|
12
|
+
from snowflake.ml.model import type_hints as model_types
|
12
13
|
from snowflake.ml.model._packager.model_meta import model_meta_schema
|
13
14
|
|
14
15
|
# requirement: Full version requirement where name is conda package name.
|
@@ -30,6 +31,7 @@ class ModelEnv:
|
|
30
31
|
conda_env_rel_path: Optional[str] = None,
|
31
32
|
pip_requirements_rel_path: Optional[str] = None,
|
32
33
|
prefer_pip: bool = False,
|
34
|
+
target_platforms: Optional[list[model_types.TargetPlatform]] = None,
|
33
35
|
) -> None:
|
34
36
|
if conda_env_rel_path is None:
|
35
37
|
conda_env_rel_path = os.path.join(_DEFAULT_ENV_DIR, _DEFAULT_CONDA_ENV_FILENAME)
|
@@ -45,6 +47,8 @@ class ModelEnv:
|
|
45
47
|
self._python_version: version.Version = version.parse(snowml_env.PYTHON_VERSION)
|
46
48
|
self._cuda_version: Optional[version.Version] = None
|
47
49
|
self._snowpark_ml_version: version.Version = version.parse(snowml_version.VERSION)
|
50
|
+
self._target_platforms = target_platforms
|
51
|
+
self._warnings_shown: set[str] = set()
|
48
52
|
|
49
53
|
@property
|
50
54
|
def conda_dependencies(self) -> list[str]:
|
@@ -116,6 +120,17 @@ class ModelEnv:
|
|
116
120
|
if snowpark_ml_version:
|
117
121
|
self._snowpark_ml_version = version.parse(snowpark_ml_version)
|
118
122
|
|
123
|
+
@property
def targets_warehouse(self) -> bool:
    """Tell whether Snowflake Warehouse is among this environment's target platforms.

    When no target platforms were given (``None``), Warehouse is assumed to be a target.
    """
    platforms = self._target_platforms
    if platforms is None:
        return True
    return model_types.TargetPlatform.WAREHOUSE in platforms
|
127
|
+
|
128
|
+
def _warn_once(self, message: str, stacklevel: int = 2) -> None:
|
129
|
+
"""Show warning only once per ModelEnv instance."""
|
130
|
+
if message not in self._warnings_shown:
|
131
|
+
warnings.warn(message, category=UserWarning, stacklevel=stacklevel)
|
132
|
+
self._warnings_shown.add(message)
|
133
|
+
|
119
134
|
def include_if_absent(
|
120
135
|
self,
|
121
136
|
pkgs: list[ModelDependency],
|
@@ -130,14 +145,14 @@ class ModelEnv:
|
|
130
145
|
"""
|
131
146
|
if (self.pip_requirements or self.prefer_pip) and not self.conda_dependencies and pkgs:
|
132
147
|
pip_pkg_reqs: list[str] = []
|
133
|
-
|
134
|
-
(
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
148
|
+
if self.targets_warehouse:
|
149
|
+
self._warn_once(
|
150
|
+
(
|
151
|
+
"Dependencies specified from pip requirements."
|
152
|
+
" This may prevent model deploying to Snowflake Warehouse."
|
153
|
+
),
|
154
|
+
stacklevel=2,
|
155
|
+
)
|
141
156
|
for conda_req_str, pip_name in pkgs:
|
142
157
|
_, conda_req = env_utils._validate_conda_dependency_string(conda_req_str)
|
143
158
|
pip_req = requirements.Requirement(f"{pip_name}{conda_req.specifier}")
|
@@ -162,16 +177,15 @@ class ModelEnv:
|
|
162
177
|
req_to_add.name = conda_req.name
|
163
178
|
else:
|
164
179
|
req_to_add = conda_req
|
165
|
-
show_warning_message = conda_req_channel == env_utils.DEFAULT_CHANNEL_NAME
|
180
|
+
show_warning_message = conda_req_channel == env_utils.DEFAULT_CHANNEL_NAME and self.targets_warehouse
|
166
181
|
|
167
182
|
if any(added_pip_req.name == pip_name for added_pip_req in self._pip_requirements):
|
168
183
|
if show_warning_message:
|
169
|
-
|
184
|
+
self._warn_once(
|
170
185
|
(
|
171
186
|
f"Basic dependency {req_to_add.name} specified from pip requirements."
|
172
187
|
" This may prevent model deploying to Snowflake Warehouse."
|
173
188
|
),
|
174
|
-
category=UserWarning,
|
175
189
|
stacklevel=2,
|
176
190
|
)
|
177
191
|
continue
|
@@ -182,12 +196,11 @@ class ModelEnv:
|
|
182
196
|
pass
|
183
197
|
except env_utils.DuplicateDependencyInMultipleChannelsError:
|
184
198
|
if show_warning_message:
|
185
|
-
|
199
|
+
self._warn_once(
|
186
200
|
(
|
187
201
|
f"Basic dependency {req_to_add.name} specified from non-Snowflake channel."
|
188
202
|
+ " This may prevent model deploying to Snowflake Warehouse."
|
189
203
|
),
|
190
|
-
category=UserWarning,
|
191
204
|
stacklevel=2,
|
192
205
|
)
|
193
206
|
|
@@ -272,22 +285,20 @@ class ModelEnv:
|
|
272
285
|
)
|
273
286
|
|
274
287
|
for channel, channel_dependencies in conda_dependencies_dict.items():
|
275
|
-
if channel != env_utils.DEFAULT_CHANNEL_NAME:
|
276
|
-
|
288
|
+
if channel != env_utils.DEFAULT_CHANNEL_NAME and self.targets_warehouse:
|
289
|
+
self._warn_once(
|
277
290
|
(
|
278
291
|
"Found dependencies specified in the conda file from non-Snowflake channel."
|
279
292
|
" This may prevent model deploying to Snowflake Warehouse."
|
280
293
|
),
|
281
|
-
category=UserWarning,
|
282
294
|
stacklevel=2,
|
283
295
|
)
|
284
|
-
if len(channel_dependencies) == 0 and channel not in self._conda_dependencies:
|
285
|
-
|
296
|
+
if len(channel_dependencies) == 0 and channel not in self._conda_dependencies and self.targets_warehouse:
|
297
|
+
self._warn_once(
|
286
298
|
(
|
287
299
|
f"Found additional conda channel {channel} specified in the conda file."
|
288
300
|
" This may prevent model deploying to Snowflake Warehouse."
|
289
301
|
),
|
290
|
-
category=UserWarning,
|
291
302
|
stacklevel=2,
|
292
303
|
)
|
293
304
|
self._conda_dependencies[channel] = []
|
@@ -298,22 +309,20 @@ class ModelEnv:
|
|
298
309
|
except env_utils.DuplicateDependencyError:
|
299
310
|
pass
|
300
311
|
except env_utils.DuplicateDependencyInMultipleChannelsError:
|
301
|
-
|
312
|
+
self._warn_once(
|
302
313
|
(
|
303
314
|
f"Dependency {channel_dependency.name} appeared in multiple channels as conda dependency."
|
304
315
|
" This may be unintentional."
|
305
316
|
),
|
306
|
-
category=UserWarning,
|
307
317
|
stacklevel=2,
|
308
318
|
)
|
309
319
|
|
310
|
-
if pip_requirements_list:
|
311
|
-
|
320
|
+
if pip_requirements_list and self.targets_warehouse:
|
321
|
+
self._warn_once(
|
312
322
|
(
|
313
323
|
"Found dependencies specified as pip requirements."
|
314
324
|
" This may prevent model deploying to Snowflake Warehouse."
|
315
325
|
),
|
316
|
-
category=UserWarning,
|
317
326
|
stacklevel=2,
|
318
327
|
)
|
319
328
|
for pip_dependency in pip_requirements_list:
|
@@ -333,13 +342,12 @@ class ModelEnv:
|
|
333
342
|
def load_from_pip_file(self, pip_requirements_path: pathlib.Path) -> None:
|
334
343
|
pip_requirements_list = env_utils.load_requirements_file(pip_requirements_path)
|
335
344
|
|
336
|
-
if pip_requirements_list:
|
337
|
-
|
345
|
+
if pip_requirements_list and self.targets_warehouse:
|
346
|
+
self._warn_once(
|
338
347
|
(
|
339
348
|
"Found dependencies specified as pip requirements."
|
340
349
|
" This may prevent model deploying to Snowflake Warehouse."
|
341
350
|
),
|
342
|
-
category=UserWarning,
|
343
351
|
stacklevel=2,
|
344
352
|
)
|
345
353
|
for pip_dependency in pip_requirements_list:
|
@@ -167,7 +167,11 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
167
167
|
model_blob_metadata = model_blobs_metadata[name]
|
168
168
|
model_blob_filename = model_blob_metadata.path
|
169
169
|
with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f:
|
170
|
-
m = torch.load(
|
170
|
+
m = torch.load(
|
171
|
+
f,
|
172
|
+
map_location="cuda" if kwargs.get("use_gpu", False) else "cpu",
|
173
|
+
weights_only=False,
|
174
|
+
)
|
171
175
|
assert isinstance(m, torch.nn.Module)
|
172
176
|
|
173
177
|
return m
|
@@ -110,6 +110,7 @@ def create_model_metadata(
|
|
110
110
|
python_version=python_version,
|
111
111
|
embed_local_ml_library=embed_local_ml_library,
|
112
112
|
prefer_pip=prefer_pip,
|
113
|
+
target_platforms=target_platforms,
|
113
114
|
)
|
114
115
|
|
115
116
|
if embed_local_ml_library:
|
@@ -162,8 +163,9 @@ def _create_env_for_model_metadata(
|
|
162
163
|
python_version: Optional[str] = None,
|
163
164
|
embed_local_ml_library: bool = False,
|
164
165
|
prefer_pip: bool = False,
|
166
|
+
target_platforms: Optional[list[model_types.TargetPlatform]] = None,
|
165
167
|
) -> model_env.ModelEnv:
|
166
|
-
env = model_env.ModelEnv(prefer_pip=prefer_pip)
|
168
|
+
env = model_env.ModelEnv(prefer_pip=prefer_pip, target_platforms=target_platforms)
|
167
169
|
|
168
170
|
# Mypy doesn't like getter and setter have different types. See python/mypy #3004
|
169
171
|
env.conda_dependencies = conda_dependencies # type: ignore[assignment]
|
@@ -60,12 +60,19 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
|
|
60
60
|
data: snowflake.snowpark.DataFrame,
|
61
61
|
ensure_serializable: bool = True,
|
62
62
|
features: Optional[Sequence[core.BaseFeatureSpec]] = None,
|
63
|
+
statement_params: Optional[dict[str, Any]] = None,
|
63
64
|
) -> pd.DataFrame:
|
64
65
|
# This method do things on top of to_pandas, to make sure the local dataframe got is in correct shape.
|
65
66
|
dtype_map = {}
|
67
|
+
|
66
68
|
if features:
|
69
|
+
quoted_identifiers_ignore_case = SnowparkDataFrameHandler._is_quoted_identifiers_ignore_case_enabled(
|
70
|
+
data.session, statement_params
|
71
|
+
)
|
67
72
|
for feature in features:
|
68
|
-
|
73
|
+
feature_name = feature.name.upper() if quoted_identifiers_ignore_case else feature.name
|
74
|
+
dtype_map[feature_name] = feature.as_dtype()
|
75
|
+
|
69
76
|
df_local = data.to_pandas()
|
70
77
|
|
71
78
|
# This is because Array will become string (Even though the correct schema is set)
|
@@ -93,6 +100,7 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
|
|
93
100
|
df: pd.DataFrame,
|
94
101
|
keep_order: bool = False,
|
95
102
|
features: Optional[Sequence[core.BaseFeatureSpec]] = None,
|
103
|
+
statement_params: Optional[dict[str, Any]] = None,
|
96
104
|
) -> snowflake.snowpark.DataFrame:
|
97
105
|
# This method is necessary to create the Snowpark Dataframe in correct schema.
|
98
106
|
# However, in this case, the order could not be preserved. Thus, a _ID column has to be added,
|
@@ -100,6 +108,12 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
|
|
100
108
|
# Although in this case, the column with array type can get correct ARRAY type, however, the element
|
101
109
|
# type is not preserved, and will become string type. This affect the implementation of convert_from_df.
|
102
110
|
df = pandas_handler.PandasDataFrameHandler.convert_to_df(df)
|
111
|
+
quoted_identifiers_ignore_case = SnowparkDataFrameHandler._is_quoted_identifiers_ignore_case_enabled(
|
112
|
+
session, statement_params
|
113
|
+
)
|
114
|
+
if quoted_identifiers_ignore_case:
|
115
|
+
df.columns = [str(col).upper() for col in df.columns]
|
116
|
+
|
103
117
|
df_cols = df.columns
|
104
118
|
if df_cols.dtype != np.object_:
|
105
119
|
raise snowml_exceptions.SnowflakeMLException(
|
@@ -116,9 +130,47 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
|
|
116
130
|
column_names = []
|
117
131
|
columns = []
|
118
132
|
for feature in features:
|
119
|
-
|
120
|
-
|
133
|
+
feature_name = identifier.get_inferred_name(feature.name)
|
134
|
+
if quoted_identifiers_ignore_case:
|
135
|
+
feature_name = feature_name.upper()
|
136
|
+
column_names.append(feature_name)
|
137
|
+
columns.append(F.col(feature_name).cast(feature.as_snowpark_type()))
|
121
138
|
|
122
139
|
sp_df = sp_df.with_columns(column_names, columns)
|
123
140
|
|
124
141
|
return sp_df
|
142
|
+
|
143
|
+
@staticmethod
def _is_quoted_identifiers_ignore_case_enabled(
    session: snowflake.snowpark.Session, statement_params: Optional[dict[str, Any]] = None
) -> bool:
    """
    Report whether the QUOTED_IDENTIFIERS_IGNORE_CASE parameter is turned on.

    Statement parameters win: if one of their keys matches the parameter name
    (case-insensitively), its value decides the answer. Otherwise the session is queried.

    Args:
        session: Snowpark session whose parameter value is queried as a fallback
        statement_params: Optional statement parameters consulted before the session

    Returns:
        bool: True when the parameter's value reads as "true" (case-insensitively), else False.
            Also False when the session query fails (e.g., inside stored procedures).
    """

    def _reads_true(raw: Any) -> bool:
        return str(raw).lower() == "true"

    for param_name, param_value in (statement_params or {}).items():
        if param_name.upper() == "QUOTED_IDENTIFIERS_IGNORE_CASE":
            return _reads_true(param_value)

    try:
        rows = session.sql(
            "SHOW PARAMETERS LIKE 'QUOTED_IDENTIFIERS_IGNORE_CASE' IN SESSION",
            _emit_ast=False,
        ).collect(_emit_ast=False)
        return _reads_true(rows[0].value)
    except Exception:
        # Best effort: some environments (e.g. stored procedures) cannot run this query,
        # so fall back to the default, case-sensitive behavior.
        return False
|
@@ -0,0 +1,11 @@
|
|
1
|
+
from enum import Enum


class TargetPlatform(Enum):
    """Platforms a logged model can be targeted to run on."""

    # Snowflake Warehouse.
    WAREHOUSE = "WAREHOUSE"
    # Snowpark Container Services (SPCS).
    SNOWPARK_CONTAINER_SERVICES = "SNOWPARK_CONTAINER_SERVICES"


# Convenience lists of platforms (e.g. for a ``target_platforms`` argument).
# They are ordinary mutable module-level lists: treat them as read-only.
WAREHOUSE_ONLY: list[TargetPlatform] = [TargetPlatform.WAREHOUSE]
SNOWPARK_CONTAINER_SERVICES_ONLY: list[TargetPlatform] = [TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
BOTH_WAREHOUSE_AND_SNOWPARK_CONTAINER_SERVICES: list[TargetPlatform] = [
    *WAREHOUSE_ONLY,
    *SNOWPARK_CONTAINER_SERVICES_ONLY,
]
|
@@ -0,0 +1,9 @@
|
|
1
|
+
from enum import Enum


class Task(Enum):
    """Kinds of machine-learning task a model can be labelled with.

    ``UNKNOWN`` serves as the "not specified" value.
    """

    UNKNOWN = "UNKNOWN"
    # Tabular tasks.
    TABULAR_BINARY_CLASSIFICATION = "TABULAR_BINARY_CLASSIFICATION"
    TABULAR_MULTI_CLASSIFICATION = "TABULAR_MULTI_CLASSIFICATION"
    TABULAR_REGRESSION = "TABULAR_REGRESSION"
    TABULAR_RANKING = "TABULAR_RANKING"
|
snowflake/ml/model/type_hints.py
CHANGED
@@ -1,10 +1,12 @@
|
|
1
1
|
# mypy: disable-error-code="import"
|
2
|
-
from enum import Enum
|
3
2
|
from typing import TYPE_CHECKING, Literal, Sequence, TypedDict, TypeVar, Union
|
4
3
|
|
5
4
|
import numpy.typing as npt
|
6
5
|
from typing_extensions import NotRequired
|
7
6
|
|
7
|
+
from snowflake.ml.model.target_platform import TargetPlatform
|
8
|
+
from snowflake.ml.model.task import Task
|
9
|
+
|
8
10
|
if TYPE_CHECKING:
|
9
11
|
import catboost
|
10
12
|
import keras
|
@@ -321,17 +323,7 @@ ModelLoadOption = Union[
|
|
321
323
|
]
|
322
324
|
|
323
325
|
|
324
|
-
|
325
|
-
UNKNOWN = "UNKNOWN"
|
326
|
-
TABULAR_BINARY_CLASSIFICATION = "TABULAR_BINARY_CLASSIFICATION"
|
327
|
-
TABULAR_MULTI_CLASSIFICATION = "TABULAR_MULTI_CLASSIFICATION"
|
328
|
-
TABULAR_REGRESSION = "TABULAR_REGRESSION"
|
329
|
-
TABULAR_RANKING = "TABULAR_RANKING"
|
330
|
-
|
331
|
-
|
332
|
-
class TargetPlatform(Enum):
|
333
|
-
WAREHOUSE = "WAREHOUSE"
|
334
|
-
SNOWPARK_CONTAINER_SERVICES = "SNOWPARK_CONTAINER_SERVICES"
|
326
|
+
SupportedTargetPlatformType = Union[TargetPlatform, str]
|
335
327
|
|
336
328
|
|
337
|
-
|
329
|
+
__all__ = ["TargetPlatform", "Task"]
|
@@ -60,6 +60,7 @@ def register_accumulator_udtf(*, session: Session, statement_params: dict[str, A
|
|
60
60
|
),
|
61
61
|
input_types=[T.BinaryType()],
|
62
62
|
packages=[f"numpy=={np.__version__}", f"cloudpickle=={cloudpickle.__version__}"],
|
63
|
+
imports=[], # Prevents unnecessary import resolution.
|
63
64
|
name=accumulator,
|
64
65
|
is_permanent=False,
|
65
66
|
replace=True,
|
@@ -175,6 +176,7 @@ def register_sharded_dot_sum_computer(*, session: Session, statement_params: dic
|
|
175
176
|
),
|
176
177
|
input_types=[T.ArrayType(), T.IntegerType(), T.IntegerType()],
|
177
178
|
packages=[f"numpy=={np.__version__}", f"cloudpickle=={cloudpickle.__version__}"],
|
179
|
+
imports=[], # Prevents unnecessary import resolution.
|
178
180
|
name=sharded_dot_and_sum_computer,
|
179
181
|
is_permanent=False,
|
180
182
|
replace=True,
|
@@ -1,5 +1,5 @@
|
|
1
1
|
from types import ModuleType
|
2
|
-
from typing import Any, Optional, Union
|
2
|
+
from typing import Any, Optional, Protocol, Union
|
3
3
|
|
4
4
|
import pandas as pd
|
5
5
|
from absl.logging import logging
|
@@ -8,7 +8,7 @@ from snowflake.ml._internal import env, platform_capabilities, telemetry
|
|
8
8
|
from snowflake.ml._internal.exceptions import error_codes, exceptions
|
9
9
|
from snowflake.ml._internal.human_readable_id import hrid_generator
|
10
10
|
from snowflake.ml._internal.utils import sql_identifier
|
11
|
-
from snowflake.ml.model import model_signature,
|
11
|
+
from snowflake.ml.model import model_signature, target_platform, task, type_hints
|
12
12
|
from snowflake.ml.model._client.model import model_impl, model_version_impl
|
13
13
|
from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
|
14
14
|
from snowflake.ml.model._model_composer import model_composer
|
@@ -20,6 +20,14 @@ from snowflake.snowpark._internal import utils as snowpark_utils
|
|
20
20
|
logger = logging.getLogger(__name__)
|
21
21
|
|
22
22
|
|
23
|
+
class EventHandler(Protocol):
    """Structural interface for objects that receive progress updates during model operations.

    Any object with a compatible ``update(message)`` method satisfies this protocol; no explicit
    inheritance is needed. ``log_model`` is one caller that reports its progress steps this way.
    """

    def update(self, message: str) -> None:
        """Receive one human-readable progress message."""
|
29
|
+
|
30
|
+
|
23
31
|
class ModelManager:
|
24
32
|
def __init__(
|
25
33
|
self,
|
@@ -41,7 +49,7 @@ class ModelManager:
|
|
41
49
|
def log_model(
|
42
50
|
self,
|
43
51
|
*,
|
44
|
-
model: Union[
|
52
|
+
model: Union[type_hints.SupportedModelType, model_version_impl.ModelVersion],
|
45
53
|
model_name: str,
|
46
54
|
version_name: Optional[str] = None,
|
47
55
|
comment: Optional[str] = None,
|
@@ -50,16 +58,17 @@ class ModelManager:
|
|
50
58
|
pip_requirements: Optional[list[str]] = None,
|
51
59
|
artifact_repository_map: Optional[dict[str, str]] = None,
|
52
60
|
resource_constraint: Optional[dict[str, str]] = None,
|
53
|
-
target_platforms: Optional[list[
|
61
|
+
target_platforms: Optional[list[type_hints.SupportedTargetPlatformType]] = None,
|
54
62
|
python_version: Optional[str] = None,
|
55
63
|
signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
|
56
|
-
sample_input_data: Optional[
|
64
|
+
sample_input_data: Optional[type_hints.SupportedDataType] = None,
|
57
65
|
user_files: Optional[dict[str, list[str]]] = None,
|
58
66
|
code_paths: Optional[list[str]] = None,
|
59
67
|
ext_modules: Optional[list[ModuleType]] = None,
|
60
|
-
task:
|
61
|
-
options: Optional[
|
68
|
+
task: type_hints.Task = task.Task.UNKNOWN,
|
69
|
+
options: Optional[type_hints.ModelSaveOption] = None,
|
62
70
|
statement_params: Optional[dict[str, Any]] = None,
|
71
|
+
event_handler: EventHandler,
|
63
72
|
) -> model_version_impl.ModelVersion:
|
64
73
|
|
65
74
|
database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
|
@@ -143,11 +152,12 @@ class ModelManager:
|
|
143
152
|
task=task,
|
144
153
|
options=options,
|
145
154
|
statement_params=statement_params,
|
155
|
+
event_handler=event_handler,
|
146
156
|
)
|
147
157
|
|
148
158
|
def _log_model(
|
149
159
|
self,
|
150
|
-
model:
|
160
|
+
model: type_hints.SupportedModelType,
|
151
161
|
*,
|
152
162
|
model_name: str,
|
153
163
|
version_name: str,
|
@@ -157,16 +167,17 @@ class ModelManager:
|
|
157
167
|
pip_requirements: Optional[list[str]] = None,
|
158
168
|
artifact_repository_map: Optional[dict[str, str]] = None,
|
159
169
|
resource_constraint: Optional[dict[str, str]] = None,
|
160
|
-
target_platforms: Optional[list[
|
170
|
+
target_platforms: Optional[list[type_hints.SupportedTargetPlatformType]] = None,
|
161
171
|
python_version: Optional[str] = None,
|
162
172
|
signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
|
163
|
-
sample_input_data: Optional[
|
173
|
+
sample_input_data: Optional[type_hints.SupportedDataType] = None,
|
164
174
|
user_files: Optional[dict[str, list[str]]] = None,
|
165
175
|
code_paths: Optional[list[str]] = None,
|
166
176
|
ext_modules: Optional[list[ModuleType]] = None,
|
167
|
-
task:
|
168
|
-
options: Optional[
|
177
|
+
task: type_hints.Task = task.Task.UNKNOWN,
|
178
|
+
options: Optional[type_hints.ModelSaveOption] = None,
|
169
179
|
statement_params: Optional[dict[str, Any]] = None,
|
180
|
+
event_handler: EventHandler,
|
170
181
|
) -> model_version_impl.ModelVersion:
|
171
182
|
database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
|
172
183
|
version_name_id = sql_identifier.SqlIdentifier(version_name)
|
@@ -215,7 +226,7 @@ class ModelManager:
|
|
215
226
|
# User specified target platforms are defaulted to None and will not show up in the generated manifest.
|
216
227
|
if target_platforms:
|
217
228
|
# Convert any string target platforms to TargetPlatform objects
|
218
|
-
platforms = [
|
229
|
+
platforms = [type_hints.TargetPlatform(platform) for platform in target_platforms]
|
219
230
|
else:
|
220
231
|
# Default the target platform to warehouse if not specified and any table function exists
|
221
232
|
if options and (
|
@@ -231,7 +242,7 @@ class ModelManager:
|
|
231
242
|
"Logging a partitioned model with a table function without specifying `target_platforms`. "
|
232
243
|
'Default to `target_platforms=["WAREHOUSE"]`.'
|
233
244
|
)
|
234
|
-
platforms = [
|
245
|
+
platforms = [target_platform.TargetPlatform.WAREHOUSE]
|
235
246
|
|
236
247
|
# Default the target platform to SPCS if not specified when running in ML runtime
|
237
248
|
if not platforms and env.IN_ML_RUNTIME:
|
@@ -239,7 +250,7 @@ class ModelManager:
|
|
239
250
|
"Logging the model on Container Runtime for ML without specifying `target_platforms`. "
|
240
251
|
'Default to `target_platforms=["SNOWPARK_CONTAINER_SERVICES"]`.'
|
241
252
|
)
|
242
|
-
platforms = [
|
253
|
+
platforms = [target_platform.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
|
243
254
|
|
244
255
|
if artifact_repository_map:
|
245
256
|
for channel, artifact_repository_name in artifact_repository_map.items():
|
@@ -254,6 +265,7 @@ class ModelManager:
|
|
254
265
|
)
|
255
266
|
|
256
267
|
logger.info("Start packaging and uploading your model. It might take some time based on the size of the model.")
|
268
|
+
event_handler.update("📦 Packaging model...")
|
257
269
|
|
258
270
|
# Extract save_location from options if present
|
259
271
|
save_location = None
|
@@ -292,6 +304,7 @@ class ModelManager:
|
|
292
304
|
)
|
293
305
|
|
294
306
|
logger.info("Start creating MODEL object for you in the Snowflake.")
|
307
|
+
event_handler.update("🏗️ Creating model object in Snowflake...")
|
295
308
|
|
296
309
|
self._model_ops.create_from_stage(
|
297
310
|
composed_model=mc,
|
@@ -331,6 +344,8 @@ class ModelManager:
|
|
331
344
|
statement_params=statement_params,
|
332
345
|
)
|
333
346
|
|
347
|
+
event_handler.update("✅ Model logged successfully!")
|
348
|
+
|
334
349
|
return mv
|
335
350
|
|
336
351
|
def get_model(
|