snowflake-ml-python 1.7.0__py3-none-any.whl → 1.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_complete.py +107 -64
- snowflake/cortex/_finetune.py +273 -0
- snowflake/cortex/_sse_client.py +91 -28
- snowflake/cortex/_util.py +30 -1
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/data/__init__.py +5 -0
- snowflake/ml/model/_client/model/model_version_impl.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +51 -30
- snowflake/ml/model/_client/ops/service_ops.py +13 -2
- snowflake/ml/model/_client/sql/model.py +0 -14
- snowflake/ml/model/_client/sql/service.py +25 -1
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
- snowflake/ml/model/_packager/model_env/model_env.py +12 -0
- snowflake/ml/model/_packager/model_handlers/_utils.py +1 -1
- snowflake/ml/model/_packager/model_handlers/catboost.py +1 -1
- snowflake/ml/model/_packager/model_handlers/custom.py +3 -1
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +2 -1
- snowflake/ml/model/_packager/model_handlers/sklearn.py +48 -1
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -1
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +23 -6
- snowflake/ml/model/_packager/model_handlers/torchscript.py +14 -14
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +2 -3
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +5 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -10
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -9
- snowflake/ml/model/_packager/model_task/model_task_utils.py +1 -1
- snowflake/ml/model/_signatures/core.py +63 -16
- snowflake/ml/model/_signatures/pandas_handler.py +71 -27
- snowflake/ml/model/_signatures/pytorch_handler.py +2 -2
- snowflake/ml/model/_signatures/snowpark_handler.py +2 -1
- snowflake/ml/model/_signatures/tensorflow_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +4 -0
- snowflake/ml/model/model_signature.py +38 -9
- snowflake/ml/model/type_hints.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +2 -4
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +2 -4
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +158 -1045
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +106 -230
- snowflake/ml/monitoring/entities/model_monitor_config.py +10 -10
- snowflake/ml/monitoring/model_monitor.py +7 -96
- snowflake/ml/registry/registry.py +17 -29
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/METADATA +31 -5
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/RECORD +48 -47
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/WHEEL +1 -1
- snowflake/ml/monitoring/entities/model_monitor_interval.py +0 -46
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/top_level.txt +0 -0
@@ -1,59 +1,20 @@
|
|
1
|
+
import json
|
1
2
|
from typing import Any, Dict, List, Optional
|
2
3
|
|
3
4
|
from snowflake import snowpark
|
4
|
-
from snowflake.ml._internal import
|
5
|
-
from snowflake.ml._internal.utils import db_utils, sql_identifier
|
5
|
+
from snowflake.ml._internal.utils import sql_identifier
|
6
6
|
from snowflake.ml.model import type_hints
|
7
7
|
from snowflake.ml.model._client.model import model_version_impl
|
8
|
-
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
9
8
|
from snowflake.ml.monitoring import model_monitor
|
10
9
|
from snowflake.ml.monitoring._client import model_monitor_sql_client
|
11
|
-
from snowflake.ml.monitoring.entities import
|
12
|
-
model_monitor_config,
|
13
|
-
model_monitor_interval,
|
14
|
-
)
|
10
|
+
from snowflake.ml.monitoring.entities import model_monitor_config
|
15
11
|
from snowflake.snowpark import session
|
16
12
|
|
17
13
|
|
18
|
-
def _validate_name_constraints(model_version: model_version_impl.ModelVersion) -> None:
|
19
|
-
system_table_prefixes = [
|
20
|
-
model_monitor_sql_client._SNOWML_MONITORING_TABLE_NAME_PREFIX,
|
21
|
-
model_monitor_sql_client._SNOWML_MONITORING_ACCURACY_TABLE_NAME_PREFIX,
|
22
|
-
]
|
23
|
-
|
24
|
-
max_allowed_model_name_and_version_length = (
|
25
|
-
db_utils.MAX_IDENTIFIER_LENGTH - max(len(prefix) for prefix in system_table_prefixes) - 1
|
26
|
-
) # -1 includes '_' between model_name + model_version
|
27
|
-
if len(model_version.model_name) + len(model_version.version_name) > max_allowed_model_name_and_version_length:
|
28
|
-
error_msg = f"Model name and version name exceeds maximum length of {max_allowed_model_name_and_version_length}"
|
29
|
-
raise ValueError(error_msg)
|
30
|
-
|
31
|
-
|
32
14
|
class ModelMonitorManager:
|
33
|
-
"""Class to manage internal operations for Model Monitor workflows."""
|
34
|
-
|
35
|
-
@staticmethod
|
36
|
-
def setup(session: session.Session, database_name: str, schema_name: str) -> None:
|
37
|
-
"""Static method to set up schema for Model Monitoring resources.
|
38
|
-
|
39
|
-
Args:
|
40
|
-
session: The Snowpark Session to connect with Snowflake.
|
41
|
-
database_name: The name of the database. If None, the current database of the session
|
42
|
-
will be used. Defaults to None.
|
43
|
-
schema_name: The name of the schema. If None, the current schema of the session
|
44
|
-
will be used. If there is no active schema, the PUBLIC schema will be used. Defaults to None.
|
45
|
-
"""
|
46
|
-
statement_params = telemetry.get_statement_params(
|
47
|
-
project=telemetry.TelemetryProject.MLOPS.value,
|
48
|
-
subproject=telemetry.TelemetrySubProject.MONITORING.value,
|
49
|
-
)
|
50
|
-
database_name_id = sql_identifier.SqlIdentifier(database_name)
|
51
|
-
schema_name_id = sql_identifier.SqlIdentifier(schema_name)
|
52
|
-
model_monitor_sql_client.ModelMonitorSQLClient.initialize_monitoring_schema(
|
53
|
-
session, database_name_id, schema_name_id, statement_params=statement_params
|
54
|
-
)
|
15
|
+
"""Class to manage internal operations for Model Monitor workflows."""
|
55
16
|
|
56
|
-
def
|
17
|
+
def _validate_task_from_model_version(
|
57
18
|
self,
|
58
19
|
model_version: model_version_impl.ModelVersion,
|
59
20
|
) -> type_hints.Task:
|
@@ -68,7 +29,6 @@ class ModelMonitorManager:
|
|
68
29
|
database_name: sql_identifier.SqlIdentifier,
|
69
30
|
schema_name: sql_identifier.SqlIdentifier,
|
70
31
|
*,
|
71
|
-
create_if_not_exists: bool = False,
|
72
32
|
statement_params: Optional[Dict[str, Any]] = None,
|
73
33
|
) -> None:
|
74
34
|
"""
|
@@ -79,233 +39,156 @@ class ModelMonitorManager:
|
|
79
39
|
session: The Snowpark Session to connect with Snowflake.
|
80
40
|
database_name: The name of the database.
|
81
41
|
schema_name: The name of the schema.
|
82
|
-
create_if_not_exists: Flag whether to initialize resources in the schema needed for Model Monitoring.
|
83
42
|
statement_params: Optional set of statement params.
|
84
|
-
|
85
|
-
Raises:
|
86
|
-
ValueError: When there is no specified or active database in the session.
|
87
43
|
"""
|
88
44
|
self._database_name = database_name
|
89
45
|
self._schema_name = schema_name
|
90
46
|
self.statement_params = statement_params
|
47
|
+
|
91
48
|
self._model_monitor_client = model_monitor_sql_client.ModelMonitorSQLClient(
|
92
49
|
session,
|
93
50
|
database_name=self._database_name,
|
94
51
|
schema_name=self._schema_name,
|
95
52
|
)
|
96
|
-
if create_if_not_exists:
|
97
|
-
model_monitor_sql_client.ModelMonitorSQLClient.initialize_monitoring_schema(
|
98
|
-
session, self._database_name, self._schema_name, self.statement_params
|
99
|
-
)
|
100
|
-
elif not self._model_monitor_client._validate_is_initialized():
|
101
|
-
raise ValueError(
|
102
|
-
"Monitoring has not been setup. Set create_if_not_exists or call ModelMonitorManager.setup"
|
103
|
-
)
|
104
53
|
|
105
|
-
def
|
54
|
+
def _validate_model_function_from_model_version(
|
106
55
|
self, function: str, model_version: model_version_impl.ModelVersion
|
107
|
-
) ->
|
56
|
+
) -> None:
|
108
57
|
functions = model_version.show_functions()
|
109
58
|
for f in functions:
|
110
59
|
if f["target_method"] == function:
|
111
|
-
return
|
60
|
+
return
|
112
61
|
existing_target_methods = {f["target_method"] for f in functions}
|
113
62
|
raise ValueError(
|
114
63
|
f"Function with name {function} does not exist in the given model version. "
|
115
64
|
f"Found: {existing_target_methods}."
|
116
65
|
)
|
117
66
|
|
118
|
-
def
|
119
|
-
|
120
|
-
table_config: model_monitor_config.ModelMonitorTableConfig,
|
121
|
-
model_monitor_config: model_monitor_config.ModelMonitorConfig,
|
122
|
-
) -> None:
|
123
|
-
"""Validate provided config for model monitor.
|
124
|
-
|
125
|
-
Args:
|
126
|
-
table_config: Config for model monitor tables.
|
127
|
-
model_monitor_config: Config for ModelMonitor.
|
128
|
-
|
129
|
-
Raises:
|
130
|
-
ValueError: If warehouse provided does not exist.
|
131
|
-
"""
|
132
|
-
|
133
|
-
# Validate naming will not exceed 255 chars
|
134
|
-
_validate_name_constraints(model_monitor_config.model_version)
|
135
|
-
|
136
|
-
if len(table_config.prediction_columns) != len(table_config.label_columns):
|
137
|
-
raise ValueError("Prediction and Label column names must be of the same length.")
|
138
|
-
# output and ground cols are list to keep interface extensible.
|
139
|
-
# for prpr only one label and one output col will be supported
|
140
|
-
if len(table_config.prediction_columns) != 1 or len(table_config.label_columns) != 1:
|
141
|
-
raise ValueError("Multiple Output columns are not supported in monitoring")
|
142
|
-
|
143
|
-
# Validate warehouse exists.
|
144
|
-
warehouse_name_id = sql_identifier.SqlIdentifier(model_monitor_config.background_compute_warehouse_name)
|
145
|
-
self._model_monitor_client.validate_monitor_warehouse(warehouse_name_id, statement_params=self.statement_params)
|
146
|
-
|
147
|
-
# Validate refresh interval.
|
148
|
-
try:
|
149
|
-
num_units, time_units = model_monitor_config.refresh_interval.strip().split(" ")
|
150
|
-
int(num_units) # try to cast
|
151
|
-
if time_units.lower() not in {"seconds", "minutes", "hours", "days"}:
|
152
|
-
raise ValueError(
|
153
|
-
"""Invalid time unit in refresh interval. Provide '<num> <seconds | minutes | hours | days>'.
|
154
|
-
See https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table#required-parameters for more info."""
|
155
|
-
)
|
156
|
-
except Exception as e: # TODO: Link to DT page.
|
157
|
-
raise ValueError(
|
158
|
-
f"""Failed to parse refresh interval with exception {e}.
|
159
|
-
Provide '<num> <seconds | minutes | hours | days>'.
|
160
|
-
See https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table#required-parameters for more info."""
|
161
|
-
)
|
67
|
+
def _build_column_list_from_input(self, columns: Optional[List[str]]) -> List[sql_identifier.SqlIdentifier]:
|
68
|
+
return [sql_identifier.SqlIdentifier(column_name) for column_name in columns] if columns else []
|
162
69
|
|
163
70
|
def add_monitor(
|
164
71
|
self,
|
165
72
|
name: str,
|
166
|
-
|
73
|
+
source_config: model_monitor_config.ModelMonitorSourceConfig,
|
167
74
|
model_monitor_config: model_monitor_config.ModelMonitorConfig,
|
168
|
-
*,
|
169
|
-
add_dashboard_udtfs: bool = False,
|
170
75
|
) -> model_monitor.ModelMonitor:
|
171
76
|
"""Add a new Model Monitor.
|
172
77
|
|
173
78
|
Args:
|
174
79
|
name: Name of Model Monitor to create.
|
175
|
-
|
80
|
+
source_config: Configuration options for the source table used in ModelMonitor.
|
176
81
|
model_monitor_config: Configuration options of ModelMonitor.
|
177
|
-
add_dashboard_udtfs: Add UDTFs useful for creating a dashboard.
|
178
82
|
|
179
83
|
Returns:
|
180
84
|
The newly added ModelMonitor object.
|
181
85
|
"""
|
182
|
-
|
183
|
-
self.
|
184
|
-
|
86
|
+
warehouse_name_id = sql_identifier.SqlIdentifier(model_monitor_config.background_compute_warehouse_name)
|
87
|
+
self._model_monitor_client.validate_monitor_warehouse(warehouse_name_id, statement_params=self.statement_params)
|
88
|
+
self._validate_model_function_from_model_version(
|
185
89
|
model_monitor_config.model_function_name, model_monitor_config.model_version
|
186
90
|
)
|
187
|
-
|
188
|
-
|
91
|
+
self._validate_task_from_model_version(model_monitor_config.model_version)
|
92
|
+
monitor_database_name_id, monitor_schema_name_id, monitor_name_id = sql_identifier.parse_fully_qualified_name(
|
93
|
+
name
|
94
|
+
)
|
95
|
+
source_database_name_id, source_schema_name_id, source_name_id = sql_identifier.parse_fully_qualified_name(
|
96
|
+
source_config.source
|
97
|
+
)
|
98
|
+
baseline_database_name_id, baseline_schema_name_id, baseline_name_id = (
|
99
|
+
sql_identifier.parse_fully_qualified_name(source_config.baseline)
|
100
|
+
if source_config.baseline
|
101
|
+
else (None, None, None)
|
189
102
|
)
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
103
|
+
model_database_name_id, model_schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(
|
104
|
+
model_monitor_config.model_version.fully_qualified_model_name
|
105
|
+
)
|
106
|
+
|
107
|
+
prediction_score_columns = self._build_column_list_from_input(source_config.prediction_score_columns)
|
108
|
+
prediction_class_columns = self._build_column_list_from_input(source_config.prediction_class_columns)
|
109
|
+
actual_score_columns = self._build_column_list_from_input(source_config.actual_score_columns)
|
110
|
+
actual_class_columns = self._build_column_list_from_input(source_config.actual_class_columns)
|
111
|
+
|
112
|
+
id_columns = [sql_identifier.SqlIdentifier(column_name) for column_name in source_config.id_columns]
|
113
|
+
ts_column = sql_identifier.SqlIdentifier(source_config.timestamp_column)
|
198
114
|
|
199
115
|
# Validate source table
|
200
|
-
self._model_monitor_client.
|
201
|
-
|
116
|
+
self._model_monitor_client.validate_source(
|
117
|
+
source_database=source_database_name_id,
|
118
|
+
source_schema=source_schema_name_id,
|
119
|
+
source=source_name_id,
|
202
120
|
timestamp_column=ts_column,
|
203
|
-
|
204
|
-
|
121
|
+
prediction_score_columns=prediction_score_columns,
|
122
|
+
prediction_class_columns=prediction_class_columns,
|
123
|
+
actual_score_columns=actual_score_columns,
|
124
|
+
actual_class_columns=actual_class_columns,
|
205
125
|
id_columns=id_columns,
|
206
|
-
model_function=model_function,
|
207
126
|
)
|
208
127
|
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
128
|
+
self._model_monitor_client.create_model_monitor(
|
129
|
+
monitor_database=monitor_database_name_id,
|
130
|
+
monitor_schema=monitor_schema_name_id,
|
131
|
+
monitor_name=monitor_name_id,
|
132
|
+
source_database=source_database_name_id,
|
133
|
+
source_schema=source_schema_name_id,
|
134
|
+
source=source_name_id,
|
135
|
+
model_database=model_database_name_id,
|
136
|
+
model_schema=model_schema_name_id,
|
137
|
+
model_name=model_name_id,
|
217
138
|
version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
|
218
139
|
function_name=model_monitor_config.model_function_name,
|
140
|
+
warehouse_name=warehouse_name_id,
|
219
141
|
timestamp_column=ts_column,
|
220
|
-
prediction_columns=prediction_columns,
|
221
|
-
label_columns=label_columns,
|
222
142
|
id_columns=id_columns,
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
self._model_monitor_client.create_dynamic_tables_for_monitor(
|
229
|
-
model_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.model_name),
|
230
|
-
model_version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
|
231
|
-
task=task,
|
232
|
-
source_table_name=source_table_name_id,
|
233
|
-
refresh_interval=monitor_refresh_interval,
|
143
|
+
prediction_score_columns=prediction_score_columns,
|
144
|
+
prediction_class_columns=prediction_class_columns,
|
145
|
+
actual_score_columns=actual_score_columns,
|
146
|
+
actual_class_columns=actual_class_columns,
|
147
|
+
refresh_interval=model_monitor_config.refresh_interval,
|
234
148
|
aggregation_window=model_monitor_config.aggregation_window,
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
prediction_columns=prediction_columns,
|
239
|
-
label_columns=label_columns,
|
240
|
-
score_type=score_type,
|
241
|
-
)
|
242
|
-
|
243
|
-
# Initialize baseline table.
|
244
|
-
self._model_monitor_client.initialize_baseline_table(
|
245
|
-
model_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.model_name),
|
246
|
-
version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
|
247
|
-
source_table_name=table_config.source_table,
|
248
|
-
columns_to_drop=[ts_column, *id_columns],
|
149
|
+
baseline_database=baseline_database_name_id,
|
150
|
+
baseline_schema=baseline_schema_name_id,
|
151
|
+
baseline=baseline_name_id,
|
249
152
|
statement_params=self.statement_params,
|
250
153
|
)
|
251
|
-
|
252
|
-
# Add udtfs helpful for dashboard queries.
|
253
|
-
# TODO(apgupta) Make this true by default.
|
254
|
-
if add_dashboard_udtfs:
|
255
|
-
self._model_monitor_client.add_dashboard_udtfs(
|
256
|
-
name_id,
|
257
|
-
model_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.model_name),
|
258
|
-
model_version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
|
259
|
-
task=task,
|
260
|
-
score_type=score_type,
|
261
|
-
output_columns=prediction_columns,
|
262
|
-
ground_truth_columns=label_columns,
|
263
|
-
)
|
264
|
-
|
265
154
|
return model_monitor.ModelMonitor._ref(
|
266
155
|
model_monitor_client=self._model_monitor_client,
|
267
|
-
name=
|
268
|
-
fully_qualified_model_name=model_monitor_config.model_version.fully_qualified_model_name,
|
269
|
-
version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
|
270
|
-
function_name=sql_identifier.SqlIdentifier(model_monitor_config.model_function_name),
|
271
|
-
prediction_columns=prediction_columns,
|
272
|
-
label_columns=label_columns,
|
156
|
+
name=monitor_name_id,
|
273
157
|
)
|
274
158
|
|
275
159
|
def get_monitor_by_model_version(
|
276
160
|
self, model_version: model_version_impl.ModelVersion
|
277
161
|
) -> model_monitor.ModelMonitor:
|
278
|
-
|
279
|
-
version_name = sql_identifier.SqlIdentifier(model_version.version_name)
|
280
|
-
if self._model_monitor_client.validate_existence(fq_model_name, version_name, self.statement_params):
|
281
|
-
model_db, model_schema, model_name = sql_identifier.parse_fully_qualified_name(fq_model_name)
|
282
|
-
if model_db is None or model_schema is None:
|
283
|
-
raise ValueError("Failed to parse model name")
|
284
|
-
|
285
|
-
model_monitor_params: model_monitor_sql_client._ModelMonitorParams = (
|
286
|
-
self._model_monitor_client.get_model_monitor_by_model_version(
|
287
|
-
model_db=model_db,
|
288
|
-
model_schema=model_schema,
|
289
|
-
model_name=model_name,
|
290
|
-
version_name=version_name,
|
291
|
-
statement_params=self.statement_params,
|
292
|
-
)
|
293
|
-
)
|
294
|
-
return model_monitor.ModelMonitor._ref(
|
295
|
-
model_monitor_client=self._model_monitor_client,
|
296
|
-
name=sql_identifier.SqlIdentifier(model_monitor_params["monitor_name"]),
|
297
|
-
fully_qualified_model_name=fq_model_name,
|
298
|
-
version_name=version_name,
|
299
|
-
function_name=sql_identifier.SqlIdentifier(model_monitor_params["function_name"]),
|
300
|
-
prediction_columns=model_monitor_params["prediction_columns"],
|
301
|
-
label_columns=model_monitor_params["label_columns"],
|
302
|
-
)
|
162
|
+
"""Get a Model Monitor by Model Version.
|
303
163
|
|
304
|
-
|
305
|
-
|
306
|
-
|
164
|
+
Args:
|
165
|
+
model_version: ModelVersion to retrieve Model Monitor for.
|
166
|
+
|
167
|
+
Returns:
|
168
|
+
The fetched ModelMonitor.
|
169
|
+
|
170
|
+
Raises:
|
171
|
+
ValueError: If model monitor is not found.
|
172
|
+
"""
|
173
|
+
rows = self._model_monitor_client.show_model_monitors(statement_params=self.statement_params)
|
174
|
+
|
175
|
+
def model_match_fn(model_details: Dict[str, str]) -> bool:
|
176
|
+
return (
|
177
|
+
model_details[model_monitor_sql_client.MODEL_JSON_MODEL_NAME_FIELD] == model_version.model_name
|
178
|
+
and model_details[model_monitor_sql_client.MODEL_JSON_VERSION_NAME_FIELD] == model_version.version_name
|
307
179
|
)
|
308
180
|
|
181
|
+
rows = [row for row in rows if model_match_fn(json.loads(row[model_monitor_sql_client.MODEL_JSON_COL_NAME]))]
|
182
|
+
if len(rows) == 0:
|
183
|
+
raise ValueError("Unable to find model monitor for the given model version.")
|
184
|
+
if len(rows) > 1:
|
185
|
+
raise ValueError("Found multiple model monitors for the given model version.")
|
186
|
+
|
187
|
+
return model_monitor.ModelMonitor._ref(
|
188
|
+
model_monitor_client=self._model_monitor_client,
|
189
|
+
name=sql_identifier.SqlIdentifier(rows[0]["name"]),
|
190
|
+
)
|
191
|
+
|
309
192
|
def get_monitor(self, name: str) -> model_monitor.ModelMonitor:
|
310
193
|
"""Get a Model Monitor from the Registry
|
311
194
|
|
@@ -318,25 +201,18 @@ See https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table#require
|
|
318
201
|
Returns:
|
319
202
|
The fetched ModelMonitor.
|
320
203
|
"""
|
321
|
-
|
204
|
+
database_name_id, schema_name_id, monitor_name_id = sql_identifier.parse_fully_qualified_name(name)
|
322
205
|
|
323
206
|
if not self._model_monitor_client.validate_existence_by_name(
|
324
|
-
|
207
|
+
database_name=database_name_id,
|
208
|
+
schema_name=schema_name_id,
|
209
|
+
monitor_name=monitor_name_id,
|
325
210
|
statement_params=self.statement_params,
|
326
211
|
):
|
327
212
|
raise ValueError(f"Unable to find model monitor '{name}'")
|
328
|
-
model_monitor_params: model_monitor_sql_client._ModelMonitorParams = (
|
329
|
-
self._model_monitor_client.get_model_monitor_by_name(name_id, statement_params=self.statement_params)
|
330
|
-
)
|
331
|
-
|
332
213
|
return model_monitor.ModelMonitor._ref(
|
333
214
|
model_monitor_client=self._model_monitor_client,
|
334
|
-
name=
|
335
|
-
fully_qualified_model_name=model_monitor_params["fully_qualified_model_name"],
|
336
|
-
version_name=sql_identifier.SqlIdentifier(model_monitor_params["version_name"]),
|
337
|
-
function_name=sql_identifier.SqlIdentifier(model_monitor_params["function_name"]),
|
338
|
-
prediction_columns=model_monitor_params["prediction_columns"],
|
339
|
-
label_columns=model_monitor_params["label_columns"],
|
215
|
+
name=monitor_name_id,
|
340
216
|
)
|
341
217
|
|
342
218
|
def show_model_monitors(self) -> List[snowpark.Row]:
|
@@ -345,7 +221,7 @@ See https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table#require
|
|
345
221
|
Returns:
|
346
222
|
List of snowpark.Row containing metadata for each model monitor.
|
347
223
|
"""
|
348
|
-
return self._model_monitor_client.
|
224
|
+
return self._model_monitor_client.show_model_monitors(statement_params=self.statement_params)
|
349
225
|
|
350
226
|
def delete_monitor(self, name: str) -> None:
|
351
227
|
"""Delete a Model Monitor from the Registry
|
@@ -353,10 +229,10 @@ See https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table#require
|
|
353
229
|
Args:
|
354
230
|
name: Name of the Model Monitor to delete.
|
355
231
|
"""
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
232
|
+
database_name_id, schema_name_id, monitor_name_id = sql_identifier.parse_fully_qualified_name(name)
|
233
|
+
self._model_monitor_client.drop_model_monitor(
|
234
|
+
database_name=database_name_id,
|
235
|
+
schema_name=schema_name_id,
|
236
|
+
monitor_name=monitor_name_id,
|
237
|
+
statement_params=self.statement_params,
|
238
|
+
)
|
@@ -1,17 +1,19 @@
|
|
1
1
|
from dataclasses import dataclass
|
2
|
-
from typing import List
|
2
|
+
from typing import List, Optional
|
3
3
|
|
4
4
|
from snowflake.ml.model._client.model import model_version_impl
|
5
|
-
from snowflake.ml.monitoring.entities import model_monitor_interval
|
6
5
|
|
7
6
|
|
8
7
|
@dataclass
|
9
|
-
class
|
10
|
-
|
8
|
+
class ModelMonitorSourceConfig:
|
9
|
+
source: str
|
11
10
|
timestamp_column: str
|
12
|
-
prediction_columns: List[str]
|
13
|
-
label_columns: List[str]
|
14
11
|
id_columns: List[str]
|
12
|
+
prediction_score_columns: Optional[List[str]] = None
|
13
|
+
prediction_class_columns: Optional[List[str]] = None
|
14
|
+
actual_score_columns: Optional[List[str]] = None
|
15
|
+
actual_class_columns: Optional[List[str]] = None
|
16
|
+
baseline: Optional[str] = None
|
15
17
|
|
16
18
|
|
17
19
|
@dataclass
|
@@ -22,7 +24,5 @@ class ModelMonitorConfig:
|
|
22
24
|
model_function_name: str
|
23
25
|
background_compute_warehouse_name: str
|
24
26
|
# TODO: Add support for pythonic notion of time.
|
25
|
-
refresh_interval: str =
|
26
|
-
aggregation_window:
|
27
|
-
model_monitor_interval.ModelMonitorAggregationWindow.WINDOW_1_DAY
|
28
|
-
)
|
27
|
+
refresh_interval: str = "1 hour"
|
28
|
+
aggregation_window: str = "1 day"
|
@@ -1,8 +1,3 @@
|
|
1
|
-
from typing import List, Union
|
2
|
-
|
3
|
-
import pandas as pd
|
4
|
-
|
5
|
-
from snowflake import snowpark
|
6
1
|
from snowflake.ml._internal import telemetry
|
7
2
|
from snowflake.ml._internal.utils import sql_identifier
|
8
3
|
from snowflake.ml.monitoring._client import model_monitor_sql_client
|
@@ -13,11 +8,11 @@ class ModelMonitor:
|
|
13
8
|
|
14
9
|
name: sql_identifier.SqlIdentifier
|
15
10
|
_model_monitor_client: model_monitor_sql_client.ModelMonitorSQLClient
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
11
|
+
|
12
|
+
statement_params = telemetry.get_statement_params(
|
13
|
+
telemetry.TelemetryProject.MLOPS.value,
|
14
|
+
telemetry.TelemetrySubProject.MONITORING.value,
|
15
|
+
)
|
21
16
|
|
22
17
|
def __init__(self) -> None:
|
23
18
|
raise RuntimeError("ModelMonitor's initializer is not meant to be used.")
|
@@ -27,100 +22,16 @@ class ModelMonitor:
|
|
27
22
|
cls,
|
28
23
|
model_monitor_client: model_monitor_sql_client.ModelMonitorSQLClient,
|
29
24
|
name: sql_identifier.SqlIdentifier,
|
30
|
-
*,
|
31
|
-
fully_qualified_model_name: str,
|
32
|
-
version_name: sql_identifier.SqlIdentifier,
|
33
|
-
function_name: sql_identifier.SqlIdentifier,
|
34
|
-
prediction_columns: List[sql_identifier.SqlIdentifier],
|
35
|
-
label_columns: List[sql_identifier.SqlIdentifier],
|
36
25
|
) -> "ModelMonitor":
|
37
26
|
self: "ModelMonitor" = object.__new__(cls)
|
38
27
|
self.name = name
|
39
28
|
self._model_monitor_client = model_monitor_client
|
40
|
-
self._fully_qualified_model_name = fully_qualified_model_name
|
41
|
-
self._version_name = version_name
|
42
|
-
self._function_name = function_name
|
43
|
-
self._prediction_columns = prediction_columns
|
44
|
-
self._label_columns = label_columns
|
45
29
|
return self
|
46
30
|
|
47
|
-
@telemetry.send_api_usage_telemetry(
|
48
|
-
project=telemetry.TelemetryProject.MLOPS.value,
|
49
|
-
subproject=telemetry.TelemetrySubProject.MONITORING.value,
|
50
|
-
)
|
51
|
-
def set_baseline(self, baseline_df: Union[pd.DataFrame, snowpark.DataFrame]) -> None:
|
52
|
-
"""
|
53
|
-
The baseline dataframe is compared with the monitored data once monitoring is enabled.
|
54
|
-
The columns of the dataframe should match the columns of the source table that the
|
55
|
-
ModelMonitor was configured with. Calling this method overwrites any existing baseline split data.
|
56
|
-
|
57
|
-
Args:
|
58
|
-
baseline_df: Snowpark dataframe containing baseline data.
|
59
|
-
|
60
|
-
Raises:
|
61
|
-
ValueError: baseline_df does not contain prediction or label columns
|
62
|
-
"""
|
63
|
-
statement_params = telemetry.get_statement_params(
|
64
|
-
project=telemetry.TelemetryProject.MLOPS.value,
|
65
|
-
subproject=telemetry.TelemetrySubProject.MONITORING.value,
|
66
|
-
)
|
67
|
-
|
68
|
-
if isinstance(baseline_df, pd.DataFrame):
|
69
|
-
baseline_df = self._model_monitor_client._sql_client._session.create_dataframe(baseline_df)
|
70
|
-
|
71
|
-
column_names_identifiers: List[sql_identifier.SqlIdentifier] = [
|
72
|
-
sql_identifier.SqlIdentifier(column_name) for column_name in baseline_df.columns
|
73
|
-
]
|
74
|
-
prediction_cols_not_found = any(
|
75
|
-
[prediction_col not in column_names_identifiers for prediction_col in self._prediction_columns]
|
76
|
-
)
|
77
|
-
label_cols_not_found = any(
|
78
|
-
[label_col.identifier() not in column_names_identifiers for label_col in self._label_columns]
|
79
|
-
)
|
80
|
-
|
81
|
-
if prediction_cols_not_found:
|
82
|
-
raise ValueError(
|
83
|
-
"Specified prediction columns were not found in the baseline dataframe. "
|
84
|
-
f"Columns provided were: {column_names_identifiers}. "
|
85
|
-
f"Configured prediction columns were: {self._prediction_columns}."
|
86
|
-
)
|
87
|
-
if label_cols_not_found:
|
88
|
-
raise ValueError(
|
89
|
-
"Specified label columns were not found in the baseline dataframe."
|
90
|
-
f"Columns provided in the baseline dataframe were: {column_names_identifiers}."
|
91
|
-
f"Configured label columns were: {self._label_columns}."
|
92
|
-
)
|
93
|
-
|
94
|
-
# Create the table by materializing the df
|
95
|
-
self._model_monitor_client.materialize_baseline_dataframe(
|
96
|
-
baseline_df,
|
97
|
-
self._fully_qualified_model_name,
|
98
|
-
self._version_name,
|
99
|
-
statement_params=statement_params,
|
100
|
-
)
|
101
|
-
|
102
31
|
def suspend(self) -> None:
|
103
32
|
"""Suspend pipeline for ModelMonitor"""
|
104
|
-
statement_params
|
105
|
-
telemetry.TelemetryProject.MLOPS.value,
|
106
|
-
telemetry.TelemetrySubProject.MONITORING.value,
|
107
|
-
)
|
108
|
-
_, _, model_name = sql_identifier.parse_fully_qualified_name(self._fully_qualified_model_name)
|
109
|
-
self._model_monitor_client.suspend_monitor_dynamic_tables(
|
110
|
-
model_name=model_name,
|
111
|
-
version_name=self._version_name,
|
112
|
-
statement_params=statement_params,
|
113
|
-
)
|
33
|
+
self._model_monitor_client.suspend_monitor(self.name, statement_params=self.statement_params)
|
114
34
|
|
115
35
|
def resume(self) -> None:
|
116
36
|
"""Resume pipeline for ModelMonitor"""
|
117
|
-
statement_params
|
118
|
-
telemetry.TelemetryProject.MLOPS.value,
|
119
|
-
telemetry.TelemetrySubProject.MONITORING.value,
|
120
|
-
)
|
121
|
-
_, _, model_name = sql_identifier.parse_fully_qualified_name(self._fully_qualified_model_name)
|
122
|
-
self._model_monitor_client.resume_monitor_dynamic_tables(
|
123
|
-
model_name=model_name,
|
124
|
-
version_name=self._version_name,
|
125
|
-
statement_params=statement_params,
|
126
|
-
)
|
37
|
+
self._model_monitor_client.resume_monitor(self.name, statement_params=self.statement_params)
|