snowflake-ml-python 1.8.6__py3-none-any.whl → 1.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +44 -3
- snowflake/ml/_internal/platform_capabilities.py +52 -2
- snowflake/ml/_internal/type_utils.py +1 -1
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/mixins.py +71 -0
- snowflake/ml/_internal/utils/service_logger.py +4 -2
- snowflake/ml/data/_internal/arrow_ingestor.py +11 -1
- snowflake/ml/data/data_connector.py +43 -2
- snowflake/ml/data/data_ingestor.py +8 -0
- snowflake/ml/data/torch_utils.py +1 -1
- snowflake/ml/dataset/dataset.py +3 -2
- snowflake/ml/dataset/dataset_reader.py +22 -6
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
- snowflake/ml/experiment/_entities/__init__.py +4 -0
- snowflake/ml/experiment/_entities/experiment.py +10 -0
- snowflake/ml/experiment/_entities/run.py +62 -0
- snowflake/ml/experiment/_entities/run_metadata.py +68 -0
- snowflake/ml/experiment/_experiment_info.py +63 -0
- snowflake/ml/experiment/experiment_tracking.py +319 -0
- snowflake/ml/jobs/_utils/constants.py +1 -1
- snowflake/ml/jobs/_utils/interop_utils.py +63 -4
- snowflake/ml/jobs/_utils/payload_utils.py +5 -3
- snowflake/ml/jobs/_utils/query_helper.py +20 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +5 -1
- snowflake/ml/jobs/_utils/spec_utils.py +21 -4
- snowflake/ml/jobs/decorators.py +18 -25
- snowflake/ml/jobs/job.py +137 -37
- snowflake/ml/jobs/manager.py +228 -153
- snowflake/ml/lineage/lineage_node.py +2 -2
- snowflake/ml/model/_client/model/model_version_impl.py +16 -4
- snowflake/ml/model/_client/ops/model_ops.py +12 -3
- snowflake/ml/model/_client/ops/service_ops.py +324 -138
- snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
- snowflake/ml/model/_model_composer/model_composer.py +6 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +55 -13
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_packager/model_env/model_env.py +35 -27
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
- snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
- snowflake/ml/model/event_handler.py +117 -0
- snowflake/ml/model/model_signature.py +9 -9
- snowflake/ml/model/models/huggingface_pipeline.py +170 -1
- snowflake/ml/model/target_platform.py +11 -0
- snowflake/ml/model/task.py +9 -0
- snowflake/ml/model/type_hints.py +5 -13
- snowflake/ml/modeling/framework/base.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +14 -14
- snowflake/ml/modeling/metrics/correlation.py +19 -8
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
- snowflake/ml/modeling/metrics/ranking.py +6 -6
- snowflake/ml/modeling/metrics/regression.py +9 -9
- snowflake/ml/monitoring/explain_visualize.py +12 -5
- snowflake/ml/registry/_manager/model_manager.py +47 -15
- snowflake/ml/registry/registry.py +109 -64
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/METADATA +118 -18
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/RECORD +65 -53
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/metrics/regression.py
CHANGED

```diff
@@ -29,7 +29,7 @@ def d2_absolute_error_score(
     y_pred_col_names: Union[str, list[str]],
     sample_weight_col_name: Optional[str] = None,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
-) -> Union[float, npt.NDArray[np.float_]]:
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     :math:`D^2` regression score function, \
         fraction of absolute error explained.
@@ -111,7 +111,7 @@ def d2_absolute_error_score(

     kwargs = telemetry.get_sproc_statement_params_kwargs(d2_absolute_error_score_anon_sproc, statement_params)
     result_object = result.deserialize(session, d2_absolute_error_score_anon_sproc(session, **kwargs))
-    score: Union[float, npt.NDArray[np.float_]] = result_object
+    score: Union[float, npt.NDArray[np.float64]] = result_object
     return score


@@ -124,7 +124,7 @@ def d2_pinball_score(
     sample_weight_col_name: Optional[str] = None,
     alpha: float = 0.5,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
-) -> Union[float, npt.NDArray[np.float_]]:
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     :math:`D^2` regression score function, fraction of pinball loss explained.

@@ -211,7 +211,7 @@ def d2_pinball_score(
     kwargs = telemetry.get_sproc_statement_params_kwargs(d2_pinball_score_anon_sproc, statement_params)
     result_object = result.deserialize(session, d2_pinball_score_anon_sproc(session, **kwargs))

-    score: Union[float, npt.NDArray[np.float_]] = result_object
+    score: Union[float, npt.NDArray[np.float64]] = result_object
     return score


@@ -224,7 +224,7 @@ def explained_variance_score(
     sample_weight_col_name: Optional[str] = None,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
     force_finite: bool = True,
-) -> Union[float, npt.NDArray[np.float_]]:
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     Explained variance regression score function.

@@ -326,7 +326,7 @@ def explained_variance_score(

     kwargs = telemetry.get_sproc_statement_params_kwargs(explained_variance_score_anon_sproc, statement_params)
     result_object = result.deserialize(session, explained_variance_score_anon_sproc(session, **kwargs))
-    score: Union[float, npt.NDArray[np.float_]] = result_object
+    score: Union[float, npt.NDArray[np.float64]] = result_object
     return score


@@ -338,7 +338,7 @@ def mean_absolute_error(
     y_pred_col_names: Union[str, list[str]],
     sample_weight_col_name: Optional[str] = None,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
-) -> Union[float, npt.NDArray[np.float_]]:
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     Mean absolute error regression loss.

@@ -411,7 +411,7 @@ def mean_absolute_percentage_error(
     y_pred_col_names: Union[str, list[str]],
     sample_weight_col_name: Optional[str] = None,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
-) -> Union[float, npt.NDArray[np.float_]]:
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     Mean absolute percentage error (MAPE) regression loss.

@@ -495,7 +495,7 @@ def mean_squared_error(
     sample_weight_col_name: Optional[str] = None,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
     squared: bool = True,
-) -> Union[float, npt.NDArray[np.float_]]:
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     Mean squared error regression loss.
```
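Every change in this file is the same one-line fix: the return-type annotation swaps `np.float_` (truncated in this diff view; the fill above is inferred from the `+` side) for `np.float64`. `np.float_` was a NumPy 1.x alias removed in NumPy 2.0, so the new spelling keeps these annotations importable under both major versions. A minimal illustration of the same change outside this package (the function below is hypothetical):

```python
from typing import Union

import numpy as np
import numpy.typing as npt


# Before: `-> Union[float, npt.NDArray[np.float_]]` fails at import time
# under NumPy 2.x, because the np.float_ alias was removed.
def mean_of(values: npt.ArrayLike) -> Union[float, npt.NDArray[np.float64]]:
    # np.float64 exists in both NumPy 1.x and 2.x.
    return float(np.asarray(values, dtype=np.float64).mean())
```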
snowflake/ml/monitoring/explain_visualize.py
CHANGED

```diff
@@ -264,6 +264,7 @@ def plot_force(
 def plot_influence_sensitivity(
     shap_values: type_hints.SupportedDataType,
     feature_values: type_hints.SupportedDataType,
+    infer_is_categorical: bool = True,
     figsize: tuple[float, float] = DEFAULT_FIGSIZE,
 ) -> Any:
     """
@@ -274,6 +275,8 @@ def plot_influence_sensitivity(
     Args:
         shap_values: pandas Series or 2D array containing the SHAP values for a specific feature
         feature_values: pandas Series or 2D array containing the feature values for the same feature
+        infer_is_categorical: If True, the function will infer if the feature is categorical
+            based on the number of unique values.
         figsize: tuple of (width, height) for the plot

     Returns:
@@ -294,7 +297,7 @@ def plot_influence_sensitivity(
     elif feature_values_df.shape[0] != shap_values_df.shape[0]:
         raise ValueError("Feature values and SHAP values must have the same number of rows.")

-    scatter = _create_scatter_plot(feature_values, shap_values, figsize)
+    scatter = _create_scatter_plot(feature_values, shap_values, infer_is_categorical, figsize)
     return st.altair_chart(scatter) if use_streamlit else scatter


@@ -322,11 +325,13 @@ def _prepare_feature_values_for_streamlit(
     return feature_values, shap_values, st


-def _create_scatter_plot(
+def _create_scatter_plot(
+    feature_values: pd.Series, shap_values: pd.Series, infer_is_categorical: bool, figsize: tuple[float, float]
+) -> alt.Chart:
     unique_vals = np.sort(np.unique(feature_values.values))
     max_points_per_unique_value = float(np.max(np.bincount(np.searchsorted(unique_vals, feature_values.values))))
     points_per_value = len(feature_values.values) / len(unique_vals)
-    is_categorical = float(max(max_points_per_unique_value, points_per_value)) > 10
+    is_categorical = float(max(max_points_per_unique_value, points_per_value)) > 10 if infer_is_categorical else False

     kwargs = (
         {
@@ -403,9 +408,11 @@ def plot_violin(
         .transform_density(density="shap_value", groupby=["feature_name"], as_=["shap_value", "density"])
         .mark_area(orient="vertical")
         .encode(
-            y=alt.Y("density:Q", title=None).stack("center").impute(None).axis(labels=False, grid=False, ticks=
+            y=alt.Y("density:Q", title=None).stack("center").impute(None).axis(labels=False, grid=False, ticks=False),
             x=alt.X("shap_value:Q", title="SHAP Value"),
-            row=alt.Row(
+            row=alt.Row(
+                "feature_name:N", sort=column_sort_order, header=alt.Header(labelAngle=0, labelAlign="left")
+            ).spacing(0),
             color=alt.Color("feature_name:N", legend=None),
             tooltip=["feature_name", "shap_value"],
         )
```
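The new `infer_is_categorical` flag exposes the heuristic above (a feature is treated as categorical when values repeat heavily, more than roughly 10 points per unique value) as an opt-out. A usage sketch with made-up data, assuming the public `plot_influence_sensitivity` entry point:

```python
import pandas as pd

from snowflake.ml.monitoring import explain_visualize

shap_vals = pd.Series([0.12, -0.30, 0.05, 0.41, -0.07])
feat_vals = pd.Series([3.2, 1.1, 4.8, 2.9, 3.2])

# Default (True) keeps the old behavior: categorical vs. continuous is
# inferred from unique-value counts. False forces a continuous scatter plot.
chart = explain_visualize.plot_influence_sensitivity(
    shap_vals, feat_vals, infer_is_categorical=False
)
```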
snowflake/ml/registry/_manager/model_manager.py
CHANGED

```diff
@@ -1,5 +1,5 @@
 from types import ModuleType
-from typing import Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union

 import pandas as pd
 from absl.logging import logging
@@ -8,7 +8,7 @@ from snowflake.ml._internal import env, platform_capabilities, telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions
 from snowflake.ml._internal.human_readable_id import hrid_generator
 from snowflake.ml._internal.utils import sql_identifier
-from snowflake.ml.model import model_signature,
+from snowflake.ml.model import model_signature, target_platform, task, type_hints
 from snowflake.ml.model._client.model import model_impl, model_version_impl
 from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
 from snowflake.ml.model._model_composer import model_composer
@@ -17,6 +17,9 @@ from snowflake.ml.model._packager.model_meta import model_meta
 from snowflake.snowpark import exceptions as snowpark_exceptions, session
 from snowflake.snowpark._internal import utils as snowpark_utils

+if TYPE_CHECKING:
+    from snowflake.ml.experiment._experiment_info import ExperimentInfo
+
 logger = logging.getLogger(__name__)


@@ -41,7 +44,7 @@ class ModelManager:
     def log_model(
         self,
         *,
-        model: Union[
+        model: Union[type_hints.SupportedModelType, model_version_impl.ModelVersion],
         model_name: str,
         version_name: Optional[str] = None,
         comment: Optional[str] = None,
@@ -50,16 +53,18 @@ class ModelManager:
         pip_requirements: Optional[list[str]] = None,
         artifact_repository_map: Optional[dict[str, str]] = None,
         resource_constraint: Optional[dict[str, str]] = None,
-        target_platforms: Optional[list[
+        target_platforms: Optional[list[type_hints.SupportedTargetPlatformType]] = None,
         python_version: Optional[str] = None,
         signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
-        sample_input_data: Optional[
+        sample_input_data: Optional[type_hints.SupportedDataType] = None,
         user_files: Optional[dict[str, list[str]]] = None,
         code_paths: Optional[list[str]] = None,
         ext_modules: Optional[list[ModuleType]] = None,
-        task:
-
+        task: type_hints.Task = task.Task.UNKNOWN,
+        experiment_info: Optional["ExperimentInfo"] = None,
+        options: Optional[type_hints.ModelSaveOption] = None,
         statement_params: Optional[dict[str, Any]] = None,
+        progress_status: Optional[Any] = None,
     ) -> model_version_impl.ModelVersion:

         database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
@@ -141,13 +146,15 @@ class ModelManager:
             code_paths=code_paths,
             ext_modules=ext_modules,
             task=task,
+            experiment_info=experiment_info,
             options=options,
             statement_params=statement_params,
+            progress_status=progress_status,
         )

     def _log_model(
         self,
-        model:
+        model: type_hints.SupportedModelType,
         *,
         model_name: str,
         version_name: str,
@@ -157,16 +164,18 @@ class ModelManager:
         pip_requirements: Optional[list[str]] = None,
         artifact_repository_map: Optional[dict[str, str]] = None,
         resource_constraint: Optional[dict[str, str]] = None,
-        target_platforms: Optional[list[
+        target_platforms: Optional[list[type_hints.SupportedTargetPlatformType]] = None,
        python_version: Optional[str] = None,
         signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
-        sample_input_data: Optional[
+        sample_input_data: Optional[type_hints.SupportedDataType] = None,
         user_files: Optional[dict[str, list[str]]] = None,
         code_paths: Optional[list[str]] = None,
         ext_modules: Optional[list[ModuleType]] = None,
-        task:
-
+        task: type_hints.Task = task.Task.UNKNOWN,
+        experiment_info: Optional["ExperimentInfo"] = None,
+        options: Optional[type_hints.ModelSaveOption] = None,
         statement_params: Optional[dict[str, Any]] = None,
+        progress_status: Optional[Any] = None,
     ) -> model_version_impl.ModelVersion:
         database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
         version_name_id = sql_identifier.SqlIdentifier(version_name)
@@ -215,7 +224,7 @@ class ModelManager:
         # User specified target platforms are defaulted to None and will not show up in the generated manifest.
         if target_platforms:
             # Convert any string target platforms to TargetPlatform objects
-            platforms = [
+            platforms = [type_hints.TargetPlatform(platform) for platform in target_platforms]
         else:
             # Default the target platform to warehouse if not specified and any table function exists
             if options and (
@@ -231,7 +240,7 @@ class ModelManager:
                     "Logging a partitioned model with a table function without specifying `target_platforms`. "
                     'Default to `target_platforms=["WAREHOUSE"]`.'
                 )
-                platforms = [
+                platforms = [target_platform.TargetPlatform.WAREHOUSE]

             # Default the target platform to SPCS if not specified when running in ML runtime
             if not platforms and env.IN_ML_RUNTIME:
@@ -239,7 +248,7 @@ class ModelManager:
                     "Logging the model on Container Runtime for ML without specifying `target_platforms`. "
                     'Default to `target_platforms=["SNOWPARK_CONTAINER_SERVICES"]`.'
                 )
-                platforms = [
+                platforms = [target_platform.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]

         if artifact_repository_map:
             for channel, artifact_repository_name in artifact_repository_map.items():
```
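The three replacements above make the platform-resolution order explicit: user-supplied values are coerced through the `TargetPlatform` enum, then context-dependent defaults apply (warehouse for partitioned table-function models, SPCS on the ML runtime). The coercion works because value-based enums accept both their value strings and existing members; a standalone sketch (the enum here is a stand-in for `type_hints.TargetPlatform`):

```python
from enum import Enum


class TargetPlatform(Enum):
    WAREHOUSE = "WAREHOUSE"
    SNOWPARK_CONTAINER_SERVICES = "SNOWPARK_CONTAINER_SERVICES"


# Enum(value) performs a value lookup, and Enum(member) returns the member
# unchanged, so lists mixing strings and members normalize cleanly.
mixed = ["WAREHOUSE", TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
platforms = [TargetPlatform(p) for p in mixed]
assert platforms == [
    TargetPlatform.WAREHOUSE,
    TargetPlatform.SNOWPARK_CONTAINER_SERVICES,
]
```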
```diff
@@ -254,6 +263,9 @@ class ModelManager:
         )

         logger.info("Start packaging and uploading your model. It might take some time based on the size of the model.")
+        if progress_status:
+            progress_status.update("packaging model...")
+            progress_status.increment()

         # Extract save_location from options if present
         save_location = None
@@ -267,6 +279,11 @@ class ModelManager:
             statement_params=statement_params,
             save_location=save_location,
         )
+
+        if progress_status:
+            progress_status.update("creating model manifest...")
+            progress_status.increment()
+
         model_metadata: model_meta.ModelMetadata = mc.save(
             name=model_name_id.resolved(),
             model=model,
@@ -283,7 +300,12 @@ class ModelManager:
             ext_modules=ext_modules,
             options=options,
             task=task,
+            experiment_info=experiment_info,
         )
+
+        if progress_status:
+            progress_status.update("uploading model files...")
+            progress_status.increment()
         statement_params = telemetry.add_statement_params_custom_tags(
             statement_params, model_metadata.telemetry_metadata()
         )
@@ -292,6 +314,9 @@ class ModelManager:
         )

         logger.info("Start creating MODEL object for you in the Snowflake.")
+        if progress_status:
+            progress_status.update("creating model object in Snowflake...")
+            progress_status.increment()

         self._model_ops.create_from_stage(
             composed_model=mc,
@@ -318,6 +343,10 @@ class ModelManager:
             version_name=version_name_id,
         )

+        if progress_status:
+            progress_status.update("setting model metadata...")
+            progress_status.increment()
+
         if comment:
             mv.comment = comment

@@ -331,6 +360,9 @@ class ModelManager:
             statement_params=statement_params,
         )

+        if progress_status:
+            progress_status.update("model logged successfully!")
+
         return mv

     def get_model(
```
snowflake/ml/registry/registry.py
CHANGED

```diff
@@ -6,12 +6,15 @@ import pandas as pd

 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
-from snowflake.ml._internal.utils import sql_identifier
+from snowflake.ml._internal.utils import query_result_checker, sql_identifier
 from snowflake.ml.model import (
     Model,
     ModelVersion,
+    event_handler,
     model_signature,
-
+    target_platform,
+    task,
+    type_hints,
 )
 from snowflake.ml.model._client.model import model_version_impl
 from snowflake.ml.monitoring import model_monitor
@@ -75,20 +78,30 @@ class Registry:
             else sql_identifier.SqlIdentifier("PUBLIC")
         )

-
-
-
+        database_results = (
+            query_result_checker.SqlResultValidator(
+                session, f"""SHOW DATABASES LIKE '{self._database_name.resolved()}';"""
+            )
+            .has_column("name", allow_empty=True)
+            .validate()
+        )

-
+        db_names = [row["name"] for row in database_results]
+        if not self._database_name.resolved() in db_names:
             raise ValueError(f"Database {self._database_name} does not exist.")

-
-
-
-
-
+        schema_results = (
+            query_result_checker.SqlResultValidator(
+                session,
+                f"""SHOW SCHEMAS LIKE '{self._schema_name.resolved()}'
+                IN DATABASE {self._database_name.identifier()};""",
+            )
+            .has_column("name", allow_empty=True)
+            .validate()
+        )

-
+        schema_names = [row["name"] for row in schema_results]
+        if not self._schema_name.resolved() in schema_names:
             raise ValueError(f"Schema {self._schema_name} does not exist.")

         self._model_manager = model_manager.ModelManager(
```
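`Registry.__init__` now fails fast when the target database or schema is missing, using `SqlResultValidator` to run `SHOW DATABASES` / `SHOW SCHEMAS` and assert the result has a `name` column. A rough plain-Snowpark equivalent of the database check (sketch only, without the result-shape validation the real code performs):

```python
from snowflake.snowpark import Session


def assert_database_exists(session: Session, db_name: str) -> None:
    # SHOW DATABASES returns one row per match, each with a "name" column.
    rows = session.sql(f"SHOW DATABASES LIKE '{db_name}'").collect()
    if db_name not in [row["name"] for row in rows]:
        raise ValueError(f"Database {db_name} does not exist.")
```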
```diff
@@ -119,7 +132,7 @@ class Registry:
     @overload
     def log_model(
         self,
-        model:
+        model: type_hints.SupportedModelType,
         *,
         model_name: str,
         version_name: Optional[str] = None,
@@ -129,15 +142,15 @@ class Registry:
         pip_requirements: Optional[list[str]] = None,
         artifact_repository_map: Optional[dict[str, str]] = None,
         resource_constraint: Optional[dict[str, str]] = None,
-        target_platforms: Optional[list[
+        target_platforms: Optional[list[Union[target_platform.TargetPlatform, str]]] = None,
         python_version: Optional[str] = None,
         signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
-        sample_input_data: Optional[
+        sample_input_data: Optional[type_hints.SupportedDataType] = None,
         user_files: Optional[dict[str, list[str]]] = None,
         code_paths: Optional[list[str]] = None,
         ext_modules: Optional[list[ModuleType]] = None,
-        task:
-        options: Optional[
+        task: task.Task = task.Task.UNKNOWN,
+        options: Optional[type_hints.ModelSaveOption] = None,
     ) -> ModelVersion:
         """
         Log a model with various parameters and metadata, or a ModelVersion object.
@@ -159,12 +172,12 @@ class Registry:
                 to specify a dependency. It is a recommended way to specify your dependencies using conda. When channel
                 is not specified, Snowflake Anaconda Channel will be used. Defaults to None.
             pip_requirements: List of Pip package specifications. Defaults to None.
-                Models
-
-
-
+                Models running in a Snowflake Warehouse must also specify a pip artifact repository (see
+                `artifact_repository_map`). Otherwise, models with pip requirements are runnable only in Snowpark
+                Container Services. See
+                https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container for more.
             artifact_repository_map: Specifies a mapping of package channels or platforms to custom artifact
-                repositories. Defaults to None. Currently, the mapping applies only to
+                repositories. Defaults to None. Currently, the mapping applies only to Warehouse execution.
                 Note : This feature is currently in Public Preview.
                 Format: {channel_name: artifact_repository_name}, where:
                 - channel_name: Currently must be 'pip'.
@@ -172,10 +185,13 @@ class Registry:
                 `snowflake.snowpark.pypi_shared_repository`.
             resource_constraint: Mapping of resource constraint keys and values, e.g. {"architecture": "x86"}.
             target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
-                "WAREHOUSE" and "SNOWPARK_CONTAINER_SERVICES":
-                - ["WAREHOUSE"]
-                - ["SNOWPARK_CONTAINER_SERVICES"]
-
+                "WAREHOUSE" and "SNOWPARK_CONTAINER_SERVICES", or a target platform constant:
+                - ["WAREHOUSE"] or snowflake.ml.model.target_platform.WAREHOUSE_ONLY (Warehouse only)
+                - ["SNOWPARK_CONTAINER_SERVICES"] or
+                  snowflake.ml.model.target_platform.SNOWPARK_CONTAINER_SERVICES_ONLY
+                  (Snowpark Container Services only)
+                - ["WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"] or
+                  snowflake.ml.model.target_platform.BOTH_WAREHOUSE_AND_SNOWPARK_CONTAINER_SERVICES (Both)
                 Defaults to None. When None, the target platforms will be both.
             python_version: Python version in which the model is run. Defaults to None.
             signatures: Model data signatures for inputs and outputs for various target methods. If it is None,
```
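A usage sketch for the expanded `target_platforms` inputs documented above, assuming an existing Snowpark `session` and a fitted supported model (e.g. a scikit-learn estimator):

```python
from snowflake.ml.model import target_platform
from snowflake.ml.registry import Registry

registry = Registry(session=session)

mv = registry.log_model(
    model,
    model_name="MY_MODEL",
    version_name="V1",
    # The constant is interchangeable with the plain list ["WAREHOUSE"].
    target_platforms=target_platform.WAREHOUSE_ONLY,
)
```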
```diff
@@ -206,7 +222,8 @@ class Registry:
             - target_methods: List of target methods to register when logging the model.
                 This option is not used in MLFlow models. Defaults to None, in which case the model handler's
                 default target methods will be used.
-            - save_location:
+            - save_location: Local directory to save the the serialized model files first before
+                uploading to Snowflake. This is useful when default tmp directory is not writable.
             - method_options: Per-method saving options. This dictionary has method names as keys and dictionary
                 values with the desired options.

@@ -263,7 +280,7 @@ class Registry:
     )
     def log_model(
         self,
-        model: Union[
+        model: Union[type_hints.SupportedModelType, ModelVersion],
         *,
         model_name: str,
         version_name: Optional[str] = None,
@@ -273,15 +290,15 @@ class Registry:
         pip_requirements: Optional[list[str]] = None,
         artifact_repository_map: Optional[dict[str, str]] = None,
         resource_constraint: Optional[dict[str, str]] = None,
-        target_platforms: Optional[list[
+        target_platforms: Optional[list[Union[target_platform.TargetPlatform, str]]] = None,
         python_version: Optional[str] = None,
         signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
-        sample_input_data: Optional[
+        sample_input_data: Optional[type_hints.SupportedDataType] = None,
         user_files: Optional[dict[str, list[str]]] = None,
         code_paths: Optional[list[str]] = None,
         ext_modules: Optional[list[ModuleType]] = None,
-        task:
-        options: Optional[
+        task: task.Task = task.Task.UNKNOWN,
+        options: Optional[type_hints.ModelSaveOption] = None,
     ) -> ModelVersion:
         """
         Log a model with various parameters and metadata, or a ModelVersion object.
@@ -303,12 +320,12 @@ class Registry:
                 to specify a dependency. It is a recommended way to specify your dependencies using conda. When channel
                 is not specified, Snowflake Anaconda Channel will be used. Defaults to None.
             pip_requirements: List of Pip package specifications. Defaults to None.
-                Models
-
-
-
+                Models running in a Snowflake Warehouse must also specify a pip artifact repository (see
+                `artifact_repository_map`). Otherwise, models with pip requirements are runnable only in Snowpark
+                Container Services. See
+                https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container for more.
             artifact_repository_map: Specifies a mapping of package channels or platforms to custom artifact
-                repositories. Defaults to None. Currently, the mapping applies only to
+                repositories. Defaults to None. Currently, the mapping applies only to Warehouse execution.
                 Note : This feature is currently in Public Preview.
                 Format: {channel_name: artifact_repository_name}, where:
                 - channel_name: Currently must be 'pip'.
@@ -316,10 +333,13 @@ class Registry:
                 `snowflake.snowpark.pypi_shared_repository`.
             resource_constraint: Mapping of resource constraint keys and values, e.g. {"architecture": "x86"}.
             target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
-                "WAREHOUSE" and "SNOWPARK_CONTAINER_SERVICES":
-                - ["WAREHOUSE"]
-                - ["SNOWPARK_CONTAINER_SERVICES"]
-
+                "WAREHOUSE" and "SNOWPARK_CONTAINER_SERVICES", or a target platform constant:
+                - ["WAREHOUSE"] or snowflake.ml.model.target_platform.WAREHOUSE_ONLY (Warehouse only)
+                - ["SNOWPARK_CONTAINER_SERVICES"] or
+                  snowflake.ml.model.target_platform.SNOWPARK_CONTAINER_SERVICES_ONLY
+                  (Snowpark Container Services only)
+                - ["WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"] or
+                  snowflake.ml.model.target_platform.BOTH_WAREHOUSE_AND_SNOWPARK_CONTAINER_SERVICES (Both)
                 Defaults to None. When None, the target platforms will be both.
             python_version: Python version in which the model is run. Defaults to None.
             signatures: Model data signatures for inputs and outputs for various target methods. If it is None,
@@ -366,6 +386,7 @@ class Registry:

         Raises:
             ValueError: If extra arguments are specified ModelVersion is provided.
+            Exception: If the model logging fails.

         Returns:
             ModelVersion: ModelVersion object corresponding to the model just logged.
@@ -418,10 +439,10 @@ class Registry:
                 raise ValueError(
                     "When calling log_model with a ModelVersion, only model_name and version_name may be specified."
                 )
-            if task is not
+            if task is not type_hints.Task.UNKNOWN:
                 raise ValueError("`task` cannot be specified when calling log_model with a ModelVersion.")

-        if pip_requirements and not artifact_repository_map:
+        if pip_requirements and not artifact_repository_map and self._targets_warehouse(target_platforms):
             warnings.warn(
                 "Models logged specifying `pip_requirements` cannot be executed in a Snowflake Warehouse "
                 "without specifying `artifact_repository_map`. This model can be run in Snowpark Container "
@@ -429,27 +450,42 @@ class Registry:
                 category=UserWarning,
                 stacklevel=1,
             )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        registry_event_handler = event_handler.ModelEventHandler()
+        with registry_event_handler.status("Logging model", total=6) as status:
+            # Step 1: Validation and setup
+            status.update("validating model and dependencies...")
+            status.increment()
+
+            # Perform the actual model logging
+            try:
+                result = self._model_manager.log_model(
+                    model=model,
+                    model_name=model_name,
+                    version_name=version_name,
+                    comment=comment,
+                    metrics=metrics,
+                    conda_dependencies=conda_dependencies,
+                    pip_requirements=pip_requirements,
+                    artifact_repository_map=artifact_repository_map,
+                    resource_constraint=resource_constraint,
+                    target_platforms=target_platforms,
+                    python_version=python_version,
+                    signatures=signatures,
+                    sample_input_data=sample_input_data,
+                    user_files=user_files,
+                    code_paths=code_paths,
+                    ext_modules=ext_modules,
+                    task=task,
+                    options=options,
+                    statement_params=statement_params,
+                    progress_status=status,
+                )
+                status.update(label="Model logged successfully.", state="complete", expanded=False)
+                return result
+            except Exception as e:
+                status.update(label="Model logging failed.", state="error", expanded=False)
+                raise e

     @telemetry.send_api_usage_telemetry(
         project=_TELEMETRY_PROJECT,
```
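The `status.update(label=..., state=..., expanded=...)` calls mirror the signature of Streamlit's `st.status` widget, which suggests (an inference from this diff, not documented here) that `ModelEventHandler` renders interactive progress when a UI is present and falls back to plain output otherwise. For callers, the visible change is that `log_model` now reports six coarse steps and marks the status as errored before re-raising on failure; a compressed sketch of the control flow, with a hypothetical workload:

```python
handler = event_handler.ModelEventHandler()
with handler.status("Logging model", total=6) as status:
    try:
        result = do_work(progress_status=status)  # hypothetical workload
        status.update(label="Model logged successfully.", state="complete", expanded=False)
    except Exception:
        status.update(label="Model logging failed.", state="error", expanded=False)
        raise
```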
```diff
@@ -626,3 +662,12 @@ class Registry:
         if not self.enable_monitoring:
             raise ValueError(_MODEL_MONITORING_DISABLED_ERROR)
         self._model_monitor_manager.delete_monitor(name)
+
+    @staticmethod
+    def _targets_warehouse(target_platforms: Optional[list[type_hints.SupportedTargetPlatformType]]) -> bool:
+        """Returns True if warehouse is a target platform (None defaults to True)."""
+        return (
+            target_platforms is None
+            or type_hints.TargetPlatform.WAREHOUSE in target_platforms
+            or "WAREHOUSE" in target_platforms
+        )
```
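`_targets_warehouse` feeds the tightened `pip_requirements` warning earlier in this file: the warning now fires only when the warehouse is actually a target. Expected behavior, shown as illustrative assertions against the static helper:

```python
# None means "default to both platforms", so the warehouse is included:
assert Registry._targets_warehouse(None) is True
assert Registry._targets_warehouse(["WAREHOUSE"]) is True
assert Registry._targets_warehouse([type_hints.TargetPlatform.WAREHOUSE]) is True
assert Registry._targets_warehouse(["SNOWPARK_CONTAINER_SERVICES"]) is False
```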
snowflake/ml/version.py
CHANGED
```diff
@@ -1,2 +1,2 @@
 # This is parsed by regex in conda recipe meta file. Make sure not to break it.
-VERSION = "1.8.6"
+VERSION = "1.9.1"
```