snowflake-ml-python 1.7.1__py3-none-any.whl → 1.7.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +16 -8
- snowflake/cortex/_classify_text.py +12 -1
- snowflake/cortex/_complete.py +82 -13
- snowflake/cortex/_embed_text_1024.py +9 -2
- snowflake/cortex/_embed_text_768.py +9 -2
- snowflake/cortex/_extract_answer.py +9 -2
- snowflake/cortex/_sentiment.py +9 -2
- snowflake/cortex/_summarize.py +9 -2
- snowflake/cortex/_translate.py +9 -2
- snowflake/ml/_internal/env_utils.py +7 -52
- snowflake/ml/_internal/utils/identifier.py +4 -2
- snowflake/ml/_internal/utils/jwt_generator.py +141 -0
- snowflake/ml/data/__init__.py +3 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +4 -4
- snowflake/ml/data/data_connector.py +53 -11
- snowflake/ml/data/data_ingestor.py +2 -1
- snowflake/ml/data/torch_utils.py +18 -5
- snowflake/ml/feature_store/examples/example_helper.py +2 -1
- snowflake/ml/fileset/fileset.py +18 -18
- snowflake/ml/model/_client/model/model_version_impl.py +24 -8
- snowflake/ml/model/_client/ops/model_ops.py +2 -6
- snowflake/ml/model/_client/ops/service_ops.py +12 -7
- snowflake/ml/model/_client/sql/model_version.py +11 -0
- snowflake/ml/model/_client/sql/stage.py +1 -1
- snowflake/ml/model/_model_composer/model_composer.py +8 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_method/constants.py +1 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -0
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +9 -1
- snowflake/ml/model/_model_composer/model_user_file/model_user_file.py +27 -0
- snowflake/ml/model/_packager/model_handlers/_utils.py +27 -2
- snowflake/ml/model/_packager/model_handlers/catboost.py +3 -3
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +5 -1
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +55 -20
- snowflake/ml/model/_packager/model_handlers/sklearn.py +10 -9
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +66 -28
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -17
- snowflake/ml/model/_packager/model_handlers/xgboost.py +3 -3
- snowflake/ml/model/_packager/model_meta/model_meta.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
- snowflake/ml/model/_packager/model_task/model_task_utils.py +3 -2
- snowflake/ml/model/_signatures/pandas_handler.py +1 -1
- snowflake/ml/model/_signatures/snowpark_handler.py +8 -2
- snowflake/ml/model/_signatures/utils.py +0 -1
- snowflake/ml/model/type_hints.py +1 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +0 -8
- snowflake/ml/modeling/_internal/model_transformer_builder.py +0 -13
- snowflake/ml/modeling/pipeline/pipeline.py +6 -176
- snowflake/ml/modeling/xgboost/xgb_classifier.py +161 -88
- snowflake/ml/modeling/xgboost/xgb_regressor.py +160 -85
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +160 -85
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +160 -85
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +5 -170
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +9 -9
- snowflake/ml/monitoring/entities/model_monitor_config.py +28 -2
- snowflake/ml/monitoring/model_monitor.py +26 -11
- snowflake/ml/registry/_manager/model_manager.py +70 -33
- snowflake/ml/registry/registry.py +53 -34
- snowflake/ml/utils/authentication.py +75 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/METADATA +120 -53
- {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/RECORD +71 -74
- {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/utils/retryable_http.py +0 -39
- snowflake/ml/fileset/parquet_parser.py +0 -170
- snowflake/ml/fileset/tf_dataset.py +0 -88
- snowflake/ml/fileset/torch_datapipe.py +0 -57
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +0 -151
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer.py +0 -66
- snowflake/ml/monitoring/entities/output_score_type.py +0 -90
- {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/top_level.txt +0 -0
@@ -6,23 +6,49 @@ from snowflake.ml.model._client.model import model_version_impl
|
|
6
6
|
|
7
7
|
@dataclass
|
8
8
|
class ModelMonitorSourceConfig:
|
9
|
+
"""Configuration for the source of data to be monitored."""
|
10
|
+
|
9
11
|
source: str
|
12
|
+
"""Name of table or view containing monitoring data."""
|
13
|
+
|
10
14
|
timestamp_column: str
|
15
|
+
"""Name of column in the source containing timestamp."""
|
16
|
+
|
11
17
|
id_columns: List[str]
|
18
|
+
"""List of columns in the source containing unique identifiers."""
|
19
|
+
|
12
20
|
prediction_score_columns: Optional[List[str]] = None
|
21
|
+
"""List of columns in the source containing prediction scores.
|
22
|
+
Can be regression scores for regression models and probability scores for classification models."""
|
23
|
+
|
13
24
|
prediction_class_columns: Optional[List[str]] = None
|
25
|
+
"""List of columns in the source containing prediction classes for classification models."""
|
26
|
+
|
14
27
|
actual_score_columns: Optional[List[str]] = None
|
28
|
+
"""List of columns in the source containing actual scores."""
|
29
|
+
|
15
30
|
actual_class_columns: Optional[List[str]] = None
|
31
|
+
"""List of columns in the source containing actual classes for classification models."""
|
32
|
+
|
16
33
|
baseline: Optional[str] = None
|
34
|
+
"""Name of table containing the baseline data."""
|
17
35
|
|
18
36
|
|
19
37
|
@dataclass
|
20
38
|
class ModelMonitorConfig:
|
39
|
+
"""Configuration for the Model Monitor."""
|
40
|
+
|
21
41
|
model_version: model_version_impl.ModelVersion
|
42
|
+
"""Model version to monitor."""
|
22
43
|
|
23
|
-
# Python model function name
|
24
44
|
model_function_name: str
|
45
|
+
"""Function name in the model to monitor."""
|
46
|
+
|
25
47
|
background_compute_warehouse_name: str
|
26
|
-
|
48
|
+
"""Name of the warehouse to use for background compute."""
|
49
|
+
|
27
50
|
refresh_interval: str = "1 hour"
|
51
|
+
"""Interval at which to refresh the monitoring data."""
|
52
|
+
|
28
53
|
aggregation_window: str = "1 day"
|
54
|
+
"""Window for aggregating monitoring data."""
|
@@ -1,5 +1,7 @@
|
|
1
|
+
from snowflake import snowpark
|
1
2
|
from snowflake.ml._internal import telemetry
|
2
3
|
from snowflake.ml._internal.utils import sql_identifier
|
4
|
+
from snowflake.ml.monitoring import model_monitor_version
|
3
5
|
from snowflake.ml.monitoring._client import model_monitor_sql_client
|
4
6
|
|
5
7
|
|
@@ -9,13 +11,8 @@ class ModelMonitor:
|
|
9
11
|
name: sql_identifier.SqlIdentifier
|
10
12
|
_model_monitor_client: model_monitor_sql_client.ModelMonitorSQLClient
|
11
13
|
|
12
|
-
statement_params = telemetry.get_statement_params(
|
13
|
-
telemetry.TelemetryProject.MLOPS.value,
|
14
|
-
telemetry.TelemetrySubProject.MONITORING.value,
|
15
|
-
)
|
16
|
-
|
17
14
|
def __init__(self) -> None:
|
18
|
-
raise RuntimeError("
|
15
|
+
raise RuntimeError("Model Monitor's initializer is not meant to be used.")
|
19
16
|
|
20
17
|
@classmethod
|
21
18
|
def _ref(
|
@@ -28,10 +25,28 @@ class ModelMonitor:
|
|
28
25
|
self._model_monitor_client = model_monitor_client
|
29
26
|
return self
|
30
27
|
|
28
|
+
@telemetry.send_api_usage_telemetry(
|
29
|
+
project=telemetry.TelemetryProject.MLOPS.value,
|
30
|
+
subproject=telemetry.TelemetrySubProject.MONITORING.value,
|
31
|
+
)
|
32
|
+
@snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
|
31
33
|
def suspend(self) -> None:
|
32
|
-
"""Suspend
|
33
|
-
|
34
|
-
|
34
|
+
"""Suspend the Model Monitor"""
|
35
|
+
statement_params = telemetry.get_statement_params(
|
36
|
+
telemetry.TelemetryProject.MLOPS.value,
|
37
|
+
telemetry.TelemetrySubProject.MONITORING.value,
|
38
|
+
)
|
39
|
+
self._model_monitor_client.suspend_monitor(self.name, statement_params=statement_params)
|
40
|
+
|
41
|
+
@telemetry.send_api_usage_telemetry(
|
42
|
+
project=telemetry.TelemetryProject.MLOPS.value,
|
43
|
+
subproject=telemetry.TelemetrySubProject.MONITORING.value,
|
44
|
+
)
|
45
|
+
@snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
|
35
46
|
def resume(self) -> None:
|
36
|
-
"""Resume
|
37
|
-
|
47
|
+
"""Resume the Model Monitor"""
|
48
|
+
statement_params = telemetry.get_statement_params(
|
49
|
+
telemetry.TelemetryProject.MLOPS.value,
|
50
|
+
telemetry.TelemetrySubProject.MONITORING.value,
|
51
|
+
)
|
52
|
+
self._model_monitor_client.resume_monitor(self.name, statement_params=statement_params)
|
@@ -1,13 +1,13 @@
|
|
1
1
|
from types import ModuleType
|
2
|
-
from typing import Any, Dict, List, Optional, Union
|
2
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
3
3
|
|
4
4
|
import pandas as pd
|
5
5
|
from absl.logging import logging
|
6
|
-
from packaging import version
|
7
6
|
|
8
7
|
from snowflake.ml._internal import telemetry
|
8
|
+
from snowflake.ml._internal.exceptions import error_codes, exceptions
|
9
9
|
from snowflake.ml._internal.human_readable_id import hrid_generator
|
10
|
-
from snowflake.ml._internal.utils import
|
10
|
+
from snowflake.ml._internal.utils import sql_identifier
|
11
11
|
from snowflake.ml.model import model_signature, type_hints as model_types
|
12
12
|
from snowflake.ml.model._client.model import model_impl, model_version_impl
|
13
13
|
from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
|
@@ -50,14 +50,40 @@ class ModelManager:
|
|
50
50
|
python_version: Optional[str] = None,
|
51
51
|
signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
|
52
52
|
sample_input_data: Optional[model_types.SupportedDataType] = None,
|
53
|
+
user_files: Optional[Dict[str, List[str]]] = None,
|
53
54
|
code_paths: Optional[List[str]] = None,
|
54
55
|
ext_modules: Optional[List[ModuleType]] = None,
|
55
56
|
task: model_types.Task = model_types.Task.UNKNOWN,
|
56
57
|
options: Optional[model_types.ModelSaveOption] = None,
|
57
58
|
statement_params: Optional[Dict[str, Any]] = None,
|
58
59
|
) -> model_version_impl.ModelVersion:
|
59
|
-
|
60
|
-
|
60
|
+
|
61
|
+
database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
|
62
|
+
|
63
|
+
model_exists = self._model_ops.validate_existence(
|
64
|
+
database_name=database_name_id,
|
65
|
+
schema_name=schema_name_id,
|
66
|
+
model_name=model_name_id,
|
67
|
+
statement_params=statement_params,
|
68
|
+
)
|
69
|
+
|
70
|
+
if version_name is None:
|
71
|
+
if model_exists:
|
72
|
+
versions = self._model_ops.list_models_or_versions(
|
73
|
+
database_name=database_name_id,
|
74
|
+
schema_name=schema_name_id,
|
75
|
+
model_name=model_name_id,
|
76
|
+
statement_params=statement_params,
|
77
|
+
)
|
78
|
+
for _ in range(1000):
|
79
|
+
hrid = self._hrid_generator.generate()[1]
|
80
|
+
if sql_identifier.SqlIdentifier(hrid) not in versions:
|
81
|
+
version_name = hrid
|
82
|
+
break
|
83
|
+
if version_name is None:
|
84
|
+
raise RuntimeError("Random version name generation failed.")
|
85
|
+
else:
|
86
|
+
version_name = self._hrid_generator.generate()[1]
|
61
87
|
|
62
88
|
if isinstance(model, model_version_impl.ModelVersion):
|
63
89
|
(
|
@@ -75,10 +101,24 @@ class ModelManager:
|
|
75
101
|
schema_name=None,
|
76
102
|
model_name=sql_identifier.SqlIdentifier(model_name),
|
77
103
|
version_name=sql_identifier.SqlIdentifier(version_name),
|
104
|
+
model_exists=model_exists,
|
78
105
|
statement_params=statement_params,
|
79
106
|
)
|
80
107
|
return self.get_model(model_name=model_name, statement_params=statement_params).version(version_name)
|
81
108
|
|
109
|
+
version_name_id = sql_identifier.SqlIdentifier(version_name)
|
110
|
+
if model_exists and self._model_ops.validate_existence(
|
111
|
+
database_name=database_name_id,
|
112
|
+
schema_name=schema_name_id,
|
113
|
+
model_name=model_name_id,
|
114
|
+
version_name=version_name_id,
|
115
|
+
statement_params=statement_params,
|
116
|
+
):
|
117
|
+
raise ValueError(
|
118
|
+
f"Model {model_name} version {version_name} already existed. "
|
119
|
+
+ "To auto-generate `version_name`, skip that argument."
|
120
|
+
)
|
121
|
+
|
82
122
|
return self._log_model(
|
83
123
|
model=model,
|
84
124
|
model_name=model_name,
|
@@ -91,6 +131,7 @@ class ModelManager:
|
|
91
131
|
python_version=python_version,
|
92
132
|
signatures=signatures,
|
93
133
|
sample_input_data=sample_input_data,
|
134
|
+
user_files=user_files,
|
94
135
|
code_paths=code_paths,
|
95
136
|
ext_modules=ext_modules,
|
96
137
|
task=task,
|
@@ -103,7 +144,7 @@ class ModelManager:
|
|
103
144
|
model: model_types.SupportedModelType,
|
104
145
|
*,
|
105
146
|
model_name: str,
|
106
|
-
version_name:
|
147
|
+
version_name: str,
|
107
148
|
comment: Optional[str] = None,
|
108
149
|
metrics: Optional[Dict[str, Any]] = None,
|
109
150
|
conda_dependencies: Optional[List[str]] = None,
|
@@ -112,6 +153,7 @@ class ModelManager:
|
|
112
153
|
python_version: Optional[str] = None,
|
113
154
|
signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
|
114
155
|
sample_input_data: Optional[model_types.SupportedDataType] = None,
|
156
|
+
user_files: Optional[Dict[str, List[str]]] = None,
|
115
157
|
code_paths: Optional[List[str]] = None,
|
116
158
|
ext_modules: Optional[List[ModuleType]] = None,
|
117
159
|
task: model_types.Task = model_types.Task.UNKNOWN,
|
@@ -119,28 +161,8 @@ class ModelManager:
|
|
119
161
|
statement_params: Optional[Dict[str, Any]] = None,
|
120
162
|
) -> model_version_impl.ModelVersion:
|
121
163
|
database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
|
122
|
-
|
123
|
-
if not version_name:
|
124
|
-
version_name = self._hrid_generator.generate()[1]
|
125
164
|
version_name_id = sql_identifier.SqlIdentifier(version_name)
|
126
165
|
|
127
|
-
if self._model_ops.validate_existence(
|
128
|
-
database_name=database_name_id,
|
129
|
-
schema_name=schema_name_id,
|
130
|
-
model_name=model_name_id,
|
131
|
-
statement_params=statement_params,
|
132
|
-
) and self._model_ops.validate_existence(
|
133
|
-
database_name=database_name_id,
|
134
|
-
schema_name=schema_name_id,
|
135
|
-
model_name=model_name_id,
|
136
|
-
version_name=version_name_id,
|
137
|
-
statement_params=statement_params,
|
138
|
-
):
|
139
|
-
raise ValueError(
|
140
|
-
f"Model {model_name} version {version_name} already existed. "
|
141
|
-
+ "To auto-generate `version_name`, skip that argument."
|
142
|
-
)
|
143
|
-
|
144
166
|
stage_path = self._model_ops.prepare_model_stage_path(
|
145
167
|
database_name=database_name_id,
|
146
168
|
schema_name=schema_name_id,
|
@@ -148,13 +170,10 @@ class ModelManager:
|
|
148
170
|
)
|
149
171
|
|
150
172
|
platforms = None
|
151
|
-
# TODO(jbahk): Remove the version check after Snowflake 8.40.0 release
|
152
173
|
# User specified target platforms are defaulted to None and will not show up in the generated manifest.
|
153
|
-
|
154
|
-
if snowflake_env.get_current_snowflake_version(self._model_ops._session) >= version.parse("8.40.0"):
|
174
|
+
if target_platforms:
|
155
175
|
# Convert any string target platforms to TargetPlatform objects
|
156
|
-
|
157
|
-
platforms = [model_types.TargetPlatform(platform) for platform in target_platforms]
|
176
|
+
platforms = [model_types.TargetPlatform(platform) for platform in target_platforms]
|
158
177
|
|
159
178
|
logger.info("Start packaging and uploading your model. It might take some time based on the size of the model.")
|
160
179
|
|
@@ -170,6 +189,7 @@ class ModelManager:
|
|
170
189
|
pip_requirements=pip_requirements,
|
171
190
|
target_platforms=platforms,
|
172
191
|
python_version=python_version,
|
192
|
+
user_files=user_files,
|
173
193
|
code_paths=code_paths,
|
174
194
|
ext_modules=ext_modules,
|
175
195
|
options=options,
|
@@ -229,7 +249,7 @@ class ModelManager:
|
|
229
249
|
*,
|
230
250
|
statement_params: Optional[Dict[str, Any]] = None,
|
231
251
|
) -> model_impl.Model:
|
232
|
-
database_name_id, schema_name_id, model_name_id =
|
252
|
+
database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
|
233
253
|
if self._model_ops.validate_existence(
|
234
254
|
database_name=database_name_id,
|
235
255
|
schema_name=schema_name_id,
|
@@ -289,7 +309,7 @@ class ModelManager:
|
|
289
309
|
*,
|
290
310
|
statement_params: Optional[Dict[str, Any]] = None,
|
291
311
|
) -> None:
|
292
|
-
database_name_id, schema_name_id, model_name_id =
|
312
|
+
database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
|
293
313
|
|
294
314
|
self._model_ops.delete_model_or_version(
|
295
315
|
database_name=database_name_id,
|
@@ -297,3 +317,20 @@ class ModelManager:
|
|
297
317
|
model_name=model_name_id,
|
298
318
|
statement_params=statement_params,
|
299
319
|
)
|
320
|
+
|
321
|
+
def _parse_fully_qualified_name(
|
322
|
+
self, model_name: str
|
323
|
+
) -> Tuple[
|
324
|
+
Optional[sql_identifier.SqlIdentifier], Optional[sql_identifier.SqlIdentifier], sql_identifier.SqlIdentifier
|
325
|
+
]:
|
326
|
+
try:
|
327
|
+
return sql_identifier.parse_fully_qualified_name(model_name)
|
328
|
+
except ValueError:
|
329
|
+
raise exceptions.SnowflakeMLException(
|
330
|
+
error_code=error_codes.INVALID_ARGUMENT,
|
331
|
+
original_exception=ValueError(
|
332
|
+
f"The model_name `{model_name}` cannot be parsed as a SQL identifier. Alphanumeric characters and "
|
333
|
+
"underscores are permitted. See https://docs.snowflake.com/en/sql-reference/identifiers-syntax for "
|
334
|
+
"more information."
|
335
|
+
),
|
336
|
+
)
|
@@ -117,41 +117,49 @@ class Registry:
|
|
117
117
|
options: Optional[model_types.ModelSaveOption] = None,
|
118
118
|
) -> ModelVersion:
|
119
119
|
"""
|
120
|
-
Log a model with various parameters and metadata.
|
120
|
+
Log a model with various parameters and metadata, or a ModelVersion object.
|
121
121
|
|
122
122
|
Args:
|
123
|
-
model:
|
124
|
-
|
125
|
-
|
126
|
-
|
123
|
+
model: Supported model or ModelVersion object.
|
124
|
+
- Supported model: Model object of supported types such as Scikit-learn, XGBoost, LightGBM, Snowpark ML,
|
125
|
+
PyTorch, TorchScript, Tensorflow, Tensorflow Keras, MLFlow, HuggingFace Pipeline, Sentence Transformers,
|
126
|
+
or Custom Model.
|
127
|
+
- ModelVersion: Source ModelVersion object used to create the new ModelVersion object.
|
128
|
+
model_name: Name to identify the model. This must be a valid Snowflake SQL Identifier. Alphanumeric
|
129
|
+
characters and underscores are permitted.
|
130
|
+
See https://docs.snowflake.com/en/sql-reference/identifiers-syntax for more.
|
127
131
|
version_name: Version identifier for the model. Combination of model_name and version_name must be unique.
|
128
132
|
If not specified, a random name will be generated.
|
129
133
|
comment: Comment associated with the model version. Defaults to None.
|
130
134
|
metrics: A JSON serializable dictionary containing metrics linked to the model version. Defaults to None.
|
131
|
-
signatures: Model data signatures for inputs and outputs for various target methods. If it is None,
|
132
|
-
sample_input_data would be used to infer the signatures for those models that cannot automatically
|
133
|
-
infer the signature. Defaults to None.
|
134
|
-
sample_input_data: Sample input data to infer model signatures from.
|
135
|
-
It would also be used as background data in explanation and to capture data lineage. Defaults to None.
|
136
135
|
conda_dependencies: List of Conda package specifications. Use "[channel::]package [operator version]" syntax
|
137
136
|
to specify a dependency. It is a recommended way to specify your dependencies using conda. When channel
|
138
137
|
is not specified, Snowflake Anaconda Channel will be used. Defaults to None.
|
139
138
|
pip_requirements: List of Pip package specifications. Defaults to None.
|
140
|
-
|
141
|
-
|
139
|
+
Models with pip requirements are currently only runnable in Snowpark Container Services.
|
140
|
+
See https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container for more.
|
141
|
+
Models with pip requirements specified will not be executable in Snowflake Warehouse where all
|
142
|
+
dependencies must be retrieved from Snowflake Anaconda Channel.
|
142
143
|
target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
|
143
144
|
{"WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"}. Defaults to None.
|
144
145
|
python_version: Python version in which the model is run. Defaults to None.
|
146
|
+
signatures: Model data signatures for inputs and outputs for various target methods. If it is None,
|
147
|
+
sample_input_data would be used to infer the signatures for those models that cannot automatically
|
148
|
+
infer the signature. If not None, sample_input_data should not be specified. Defaults to None.
|
149
|
+
sample_input_data: Sample input data to infer model signatures from.
|
150
|
+
It would also be used as background data in explanation and to capture data lineage. Defaults to None.
|
145
151
|
code_paths: List of directories containing code to import. Defaults to None.
|
146
152
|
ext_modules: List of external modules to pickle with the model object.
|
147
153
|
Only supported when logging the following types of model:
|
148
154
|
Scikit-learn, Snowpark ML, PyTorch, TorchScript and Custom Model. Defaults to None.
|
149
155
|
options (Dict[str, Any], optional): Additional model saving options.
|
156
|
+
|
150
157
|
Model Saving Options include:
|
158
|
+
|
151
159
|
- embed_local_ml_library: Embed local Snowpark ML into the code directory or folder.
|
152
160
|
Override to True if the local Snowpark ML version is not available in the Snowflake Anaconda
|
153
161
|
Channel. Otherwise, defaults to False
|
154
|
-
- relax_version: Whether
|
162
|
+
- relax_version: Whether to relax the version constraints of the dependencies when running in the
|
155
163
|
Warehouse. It detects any ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to True.
|
156
164
|
- function_type: Set the method function type globally. To set method function types individually see
|
157
165
|
function_type in model_options.
|
@@ -163,7 +171,10 @@ class Registry:
|
|
163
171
|
- max_batch_size: Maximum batch size that the method could accept in the Snowflake Warehouse.
|
164
172
|
Defaults to None, determined automatically by Snowflake.
|
165
173
|
- function_type: One of supported model method function types (FUNCTION or TABLE_FUNCTION).
|
174
|
+
Returns:
|
175
|
+
ModelVersion: ModelVersion object corresponding to the model just logged.
|
166
176
|
"""
|
177
|
+
|
167
178
|
...
|
168
179
|
|
169
180
|
@overload
|
@@ -214,6 +225,7 @@ class Registry:
|
|
214
225
|
python_version: Optional[str] = None,
|
215
226
|
signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
|
216
227
|
sample_input_data: Optional[model_types.SupportedDataType] = None,
|
228
|
+
user_files: Optional[Dict[str, List[str]]] = None,
|
217
229
|
code_paths: Optional[List[str]] = None,
|
218
230
|
ext_modules: Optional[List[ModuleType]] = None,
|
219
231
|
task: model_types.Task = model_types.Task.UNKNOWN,
|
@@ -228,25 +240,31 @@ class Registry:
|
|
228
240
|
PyTorch, TorchScript, Tensorflow, Tensorflow Keras, MLFlow, HuggingFace Pipeline, Sentence Transformers,
|
229
241
|
or Custom Model.
|
230
242
|
- ModelVersion: Source ModelVersion object used to create the new ModelVersion object.
|
231
|
-
model_name: Name to identify the model.
|
243
|
+
model_name: Name to identify the model. This must be a valid Snowflake SQL Identifier. Alphanumeric
|
244
|
+
characters and underscores are permitted.
|
245
|
+
See https://docs.snowflake.com/en/sql-reference/identifiers-syntax for more.
|
232
246
|
version_name: Version identifier for the model. Combination of model_name and version_name must be unique.
|
233
247
|
If not specified, a random name will be generated.
|
234
248
|
comment: Comment associated with the model version. Defaults to None.
|
235
249
|
metrics: A JSON serializable dictionary containing metrics linked to the model version. Defaults to None.
|
236
|
-
signatures: Model data signatures for inputs and outputs for various target methods. If it is None,
|
237
|
-
sample_input_data would be used to infer the signatures for those models that cannot automatically
|
238
|
-
infer the signature. If not None, sample_input_data should not be specified. Defaults to None.
|
239
|
-
sample_input_data: Sample input data to infer model signatures from.
|
240
|
-
It would also be used as background data in explanation and to capture data lineage. Defaults to None.
|
241
250
|
conda_dependencies: List of Conda package specifications. Use "[channel::]package [operator version]" syntax
|
242
251
|
to specify a dependency. It is a recommended way to specify your dependencies using conda. When channel
|
243
252
|
is not specified, Snowflake Anaconda Channel will be used. Defaults to None.
|
244
253
|
pip_requirements: List of Pip package specifications. Defaults to None.
|
245
|
-
|
246
|
-
|
254
|
+
Models with pip requirements are currently only runnable in Snowpark Container Services.
|
255
|
+
See https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container for more.
|
256
|
+
Models with pip requirements specified will not be executable in Snowflake Warehouse where all
|
257
|
+
dependencies must be retrieved from Snowflake Anaconda Channel.
|
247
258
|
target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
|
248
259
|
{"WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"}. Defaults to None.
|
249
260
|
python_version: Python version in which the model is run. Defaults to None.
|
261
|
+
signatures: Model data signatures for inputs and outputs for various target methods. If it is None,
|
262
|
+
sample_input_data would be used to infer the signatures for those models that cannot automatically
|
263
|
+
infer the signature. If not None, sample_input_data should not be specified. Defaults to None.
|
264
|
+
sample_input_data: Sample input data to infer model signatures from.
|
265
|
+
It would also be used as background data in explanation and to capture data lineage. Defaults to None.
|
266
|
+
user_files: Dictionary where the keys are subdirectories, and values are lists of local file name
|
267
|
+
strings. The local file name strings can include wildcards (? or *) for matching multiple files.
|
250
268
|
code_paths: List of directories containing code to import. Defaults to None.
|
251
269
|
ext_modules: List of external modules to pickle with the model object.
|
252
270
|
Only supported when logging the following types of model:
|
@@ -261,7 +279,7 @@ class Registry:
|
|
261
279
|
- embed_local_ml_library: Embed local Snowpark ML into the code directory or folder.
|
262
280
|
Override to True if the local Snowpark ML version is not available in the Snowflake Anaconda
|
263
281
|
Channel. Otherwise, defaults to False
|
264
|
-
- relax_version: Whether
|
282
|
+
- relax_version: Whether to relax the version constraints of the dependencies when running in the
|
265
283
|
Warehouse. It detects any ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to True.
|
266
284
|
- function_type: Set the method function type globally. To set method function types individually see
|
267
285
|
function_type in model_options.
|
@@ -301,6 +319,7 @@ class Registry:
|
|
301
319
|
python_version=python_version,
|
302
320
|
signatures=signatures,
|
303
321
|
sample_input_data=sample_input_data,
|
322
|
+
user_files=user_files,
|
304
323
|
code_paths=code_paths,
|
305
324
|
ext_modules=ext_modules,
|
306
325
|
task=task,
|
@@ -388,15 +407,15 @@ class Registry:
|
|
388
407
|
source_config: model_monitor_config.ModelMonitorSourceConfig,
|
389
408
|
model_monitor_config: model_monitor_config.ModelMonitorConfig,
|
390
409
|
) -> model_monitor.ModelMonitor:
|
391
|
-
"""Add a Model Monitor to the Registry
|
410
|
+
"""Add a Model Monitor to the Registry.
|
392
411
|
|
393
412
|
Args:
|
394
|
-
name: Name of Model Monitor to create
|
395
|
-
source_config: Configuration options of table for
|
396
|
-
model_monitor_config: Configuration options of
|
413
|
+
name: Name of Model Monitor to create.
|
414
|
+
source_config: Configuration options of table for Model Monitor.
|
415
|
+
model_monitor_config: Configuration options of Model Monitor.
|
397
416
|
|
398
417
|
Returns:
|
399
|
-
The newly added
|
418
|
+
The newly added Model Monitor object.
|
400
419
|
|
401
420
|
Raises:
|
402
421
|
ValueError: If monitoring is not enabled in the Registry.
|
@@ -407,16 +426,16 @@ class Registry:
|
|
407
426
|
|
408
427
|
@overload
|
409
428
|
def get_monitor(self, model_version: model_version_impl.ModelVersion) -> model_monitor.ModelMonitor:
|
410
|
-
"""Get a Model Monitor on a
|
429
|
+
"""Get a Model Monitor on a Model Version from the Registry.
|
411
430
|
|
412
431
|
Args:
|
413
|
-
model_version:
|
432
|
+
model_version: Model Version for which to retrieve the Model Monitor.
|
414
433
|
"""
|
415
434
|
...
|
416
435
|
|
417
436
|
@overload
|
418
437
|
def get_monitor(self, name: str) -> model_monitor.ModelMonitor:
|
419
|
-
"""Get a Model Monitor from the Registry
|
438
|
+
"""Get a Model Monitor by name from the Registry.
|
420
439
|
|
421
440
|
Args:
|
422
441
|
name: Name of Model Monitor to retrieve.
|
@@ -431,14 +450,14 @@ class Registry:
|
|
431
450
|
def get_monitor(
|
432
451
|
self, *, name: Optional[str] = None, model_version: Optional[model_version_impl.ModelVersion] = None
|
433
452
|
) -> model_monitor.ModelMonitor:
|
434
|
-
"""Get a Model Monitor from the Registry
|
453
|
+
"""Get a Model Monitor from the Registry.
|
435
454
|
|
436
455
|
Args:
|
437
456
|
name: Name of Model Monitor to retrieve.
|
438
|
-
model_version:
|
457
|
+
model_version: Model Version for which to retrieve the Model Monitor.
|
439
458
|
|
440
459
|
Returns:
|
441
|
-
The fetched
|
460
|
+
The fetched Model Monitor.
|
442
461
|
|
443
462
|
Raises:
|
444
463
|
ValueError: If monitoring is not enabled in the Registry.
|
@@ -476,7 +495,7 @@ class Registry:
|
|
476
495
|
)
|
477
496
|
@snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
|
478
497
|
def delete_monitor(self, name: str) -> None:
|
479
|
-
"""Delete a Model Monitor from the Registry
|
498
|
+
"""Delete a Model Monitor by name from the Registry.
|
480
499
|
|
481
500
|
Args:
|
482
501
|
name: Name of the Model Monitor to delete.
|
@@ -0,0 +1,75 @@
|
|
1
|
+
import http
|
2
|
+
import logging
|
3
|
+
from datetime import timedelta
|
4
|
+
from typing import Dict, Optional
|
5
|
+
|
6
|
+
import requests
|
7
|
+
from cryptography.hazmat.primitives.asymmetric import types
|
8
|
+
from requests import auth
|
9
|
+
|
10
|
+
from snowflake.ml._internal.utils import jwt_generator
|
11
|
+
|
12
|
+
logger = logging.getLogger(__name__)
|
13
|
+
_JWT_TOKEN_CACHE: Dict[str, Dict[int, str]] = {}
|
14
|
+
|
15
|
+
|
16
|
+
def get_jwt_token_generator(
|
17
|
+
account: str,
|
18
|
+
user: str,
|
19
|
+
private_key: types.PRIVATE_KEY_TYPES,
|
20
|
+
lifetime: Optional[timedelta] = None,
|
21
|
+
renewal_delay: Optional[timedelta] = None,
|
22
|
+
) -> jwt_generator.JWTGenerator:
|
23
|
+
return jwt_generator.JWTGenerator(account, user, private_key, lifetime=lifetime, renewal_delay=renewal_delay)
|
24
|
+
|
25
|
+
|
26
|
+
def _get_snowflake_token_by_jwt(
|
27
|
+
jwt_token_generator: jwt_generator.JWTGenerator,
|
28
|
+
account: Optional[str] = None,
|
29
|
+
role: Optional[str] = None,
|
30
|
+
endpoint: Optional[str] = None,
|
31
|
+
snowflake_account_url: Optional[str] = None,
|
32
|
+
) -> str:
|
33
|
+
scope_role = f"session:role:{role}" if role is not None else None
|
34
|
+
scope = " ".join(filter(None, [scope_role, endpoint]))
|
35
|
+
data = {
|
36
|
+
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
|
37
|
+
"scope": scope or None,
|
38
|
+
"assertion": jwt_token_generator.get_token(),
|
39
|
+
}
|
40
|
+
account = account or jwt_token_generator.account
|
41
|
+
url = f"https://{account}.snowflakecomputing.com/oauth/token"
|
42
|
+
if snowflake_account_url:
|
43
|
+
url = f"{snowflake_account_url}/oauth/token"
|
44
|
+
|
45
|
+
cache_key = hash(frozenset(data.items()))
|
46
|
+
if url in _JWT_TOKEN_CACHE:
|
47
|
+
if cache_key in _JWT_TOKEN_CACHE[url]:
|
48
|
+
return _JWT_TOKEN_CACHE[url][cache_key]
|
49
|
+
else:
|
50
|
+
_JWT_TOKEN_CACHE[url] = {}
|
51
|
+
|
52
|
+
response = requests.post(url, data=data)
|
53
|
+
if response.status_code != http.HTTPStatus.OK:
|
54
|
+
raise RuntimeError(f"Failed to get snowflake token: {response.status_code} {response.content!r}")
|
55
|
+
auth_token = response.text
|
56
|
+
_JWT_TOKEN_CACHE[url][cache_key] = auth_token
|
57
|
+
return auth_token
|
58
|
+
|
59
|
+
|
60
|
+
class SnowflakeJWTTokenAuth(auth.AuthBase):
|
61
|
+
def __init__(
|
62
|
+
self,
|
63
|
+
jwt_token_generator: jwt_generator.JWTGenerator,
|
64
|
+
account: Optional[str] = None,
|
65
|
+
role: Optional[str] = None,
|
66
|
+
endpoint: Optional[str] = None,
|
67
|
+
snowflake_account_url: Optional[str] = None,
|
68
|
+
) -> None:
|
69
|
+
self.snowflake_token = _get_snowflake_token_by_jwt(
|
70
|
+
jwt_token_generator, account, role, endpoint, snowflake_account_url
|
71
|
+
)
|
72
|
+
|
73
|
+
def __call__(self, r: requests.PreparedRequest) -> requests.PreparedRequest:
|
74
|
+
r.headers["Authorization"] = f'Snowflake Token="{self.snowflake_token}"'
|
75
|
+
return r
|
snowflake/ml/version.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
VERSION="1.7.
|
1
|
+
VERSION="1.7.3"
|