snowflake-ml-python 1.7.0__py3-none-any.whl → 1.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_complete.py +107 -64
- snowflake/cortex/_finetune.py +273 -0
- snowflake/cortex/_sse_client.py +91 -28
- snowflake/cortex/_util.py +30 -1
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/data/__init__.py +5 -0
- snowflake/ml/model/_client/model/model_version_impl.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +51 -30
- snowflake/ml/model/_client/ops/service_ops.py +13 -2
- snowflake/ml/model/_client/sql/model.py +0 -14
- snowflake/ml/model/_client/sql/service.py +25 -1
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
- snowflake/ml/model/_packager/model_env/model_env.py +12 -0
- snowflake/ml/model/_packager/model_handlers/_utils.py +1 -1
- snowflake/ml/model/_packager/model_handlers/catboost.py +1 -1
- snowflake/ml/model/_packager/model_handlers/custom.py +3 -1
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +2 -1
- snowflake/ml/model/_packager/model_handlers/sklearn.py +48 -1
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -1
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +23 -6
- snowflake/ml/model/_packager/model_handlers/torchscript.py +14 -14
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +2 -3
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +5 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -10
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -9
- snowflake/ml/model/_packager/model_task/model_task_utils.py +1 -1
- snowflake/ml/model/_signatures/core.py +63 -16
- snowflake/ml/model/_signatures/pandas_handler.py +71 -27
- snowflake/ml/model/_signatures/pytorch_handler.py +2 -2
- snowflake/ml/model/_signatures/snowpark_handler.py +2 -1
- snowflake/ml/model/_signatures/tensorflow_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +4 -0
- snowflake/ml/model/model_signature.py +38 -9
- snowflake/ml/model/type_hints.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +2 -4
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +2 -4
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +158 -1045
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +106 -230
- snowflake/ml/monitoring/entities/model_monitor_config.py +10 -10
- snowflake/ml/monitoring/model_monitor.py +7 -96
- snowflake/ml/registry/registry.py +17 -29
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/METADATA +31 -5
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/RECORD +48 -47
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/WHEEL +1 -1
- snowflake/ml/monitoring/entities/model_monitor_interval.py +0 -46
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/top_level.txt +0 -0
@@ -1,36 +1,23 @@
|
|
1
|
-
import json
|
2
|
-
import string
|
3
|
-
import textwrap
|
4
1
|
import typing
|
5
2
|
from collections import Counter
|
6
|
-
from typing import Any, Dict, List, Mapping, Optional, Set
|
7
|
-
|
8
|
-
from importlib_resources import files
|
9
|
-
from typing_extensions import Required
|
3
|
+
from typing import Any, Dict, List, Mapping, Optional, Set
|
10
4
|
|
11
5
|
from snowflake import snowpark
|
12
|
-
from snowflake.connector import errors
|
13
6
|
from snowflake.ml._internal.utils import (
|
14
7
|
db_utils,
|
15
|
-
formatting,
|
16
8
|
query_result_checker,
|
17
9
|
sql_identifier,
|
18
10
|
table_manager,
|
19
11
|
)
|
20
|
-
from snowflake.ml.model import type_hints
|
21
12
|
from snowflake.ml.model._client.sql import _base
|
22
13
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
23
|
-
from snowflake.
|
24
|
-
from snowflake.ml.monitoring.entities.model_monitor_interval import (
|
25
|
-
ModelMonitorAggregationWindow,
|
26
|
-
ModelMonitorRefreshInterval,
|
27
|
-
)
|
28
|
-
from snowflake.snowpark import DataFrame, exceptions, session, types
|
29
|
-
from snowflake.snowpark._internal import type_utils
|
14
|
+
from snowflake.snowpark import session, types
|
30
15
|
|
31
16
|
SNOWML_MONITORING_METADATA_TABLE_NAME = "_SYSTEM_MONITORING_METADATA"
|
32
|
-
|
33
|
-
|
17
|
+
|
18
|
+
MODEL_JSON_COL_NAME = "model"
|
19
|
+
MODEL_JSON_MODEL_NAME_FIELD = "model_name"
|
20
|
+
MODEL_JSON_VERSION_NAME_FIELD = "version_name"
|
34
21
|
|
35
22
|
MONITOR_NAME_COL_NAME = "MONITOR_NAME"
|
36
23
|
SOURCE_TABLE_NAME_COL_NAME = "SOURCE_TABLE_NAME"
|
@@ -44,84 +31,10 @@ PREDICTION_COL_NAMES_COL_NAME = "PREDICTION_COLUMN_NAMES"
|
|
44
31
|
LABEL_COL_NAMES_COL_NAME = "LABEL_COLUMN_NAMES"
|
45
32
|
ID_COL_NAMES_COL_NAME = "ID_COLUMN_NAMES"
|
46
33
|
|
47
|
-
_DASHBOARD_UDTFS_COMMON_LIST = ["record_count"]
|
48
|
-
_DASHBOARD_UDTFS_REGRESSION_LIST = ["rmse"]
|
49
|
-
|
50
|
-
|
51
|
-
def _initialize_monitoring_metadata_tables(
|
52
|
-
session: session.Session,
|
53
|
-
database_name: sql_identifier.SqlIdentifier,
|
54
|
-
schema_name: sql_identifier.SqlIdentifier,
|
55
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
56
|
-
) -> None:
|
57
|
-
"""Create tables necessary for Model Monitoring in provided schema.
|
58
|
-
|
59
|
-
Args:
|
60
|
-
session: Active Snowpark session.
|
61
|
-
database_name: The database in which to setup resources for Model Monitoring.
|
62
|
-
schema_name: The schema in which to setup resources for Model Monitoring.
|
63
|
-
statement_params: Optional statement params for queries.
|
64
|
-
"""
|
65
|
-
table_manager.create_single_table(
|
66
|
-
session,
|
67
|
-
database_name,
|
68
|
-
schema_name,
|
69
|
-
SNOWML_MONITORING_METADATA_TABLE_NAME,
|
70
|
-
[
|
71
|
-
(MONITOR_NAME_COL_NAME, "VARCHAR"),
|
72
|
-
(SOURCE_TABLE_NAME_COL_NAME, "VARCHAR"),
|
73
|
-
(FQ_MODEL_NAME_COL_NAME, "VARCHAR"),
|
74
|
-
(VERSION_NAME_COL_NAME, "VARCHAR"),
|
75
|
-
(FUNCTION_NAME_COL_NAME, "VARCHAR"),
|
76
|
-
(TASK_COL_NAME, "VARCHAR"),
|
77
|
-
(MONITORING_ENABLED_COL_NAME, "BOOLEAN"),
|
78
|
-
(TIMESTAMP_COL_NAME_COL_NAME, "VARCHAR"),
|
79
|
-
(PREDICTION_COL_NAMES_COL_NAME, "ARRAY"),
|
80
|
-
(LABEL_COL_NAMES_COL_NAME, "ARRAY"),
|
81
|
-
(ID_COL_NAMES_COL_NAME, "ARRAY"),
|
82
|
-
],
|
83
|
-
statement_params=statement_params,
|
84
|
-
)
|
85
|
-
|
86
|
-
|
87
|
-
def _create_baseline_table_name(model_name: str, version_name: str) -> str:
|
88
|
-
return f"_SNOWML_OBS_BASELINE_{model_name}_{version_name}"
|
89
|
-
|
90
|
-
|
91
|
-
def _infer_numeric_categoric_feature_column_names(
|
92
|
-
*,
|
93
|
-
source_table_schema: Mapping[str, types.DataType],
|
94
|
-
timestamp_column: sql_identifier.SqlIdentifier,
|
95
|
-
id_columns: List[sql_identifier.SqlIdentifier],
|
96
|
-
prediction_columns: List[sql_identifier.SqlIdentifier],
|
97
|
-
label_columns: List[sql_identifier.SqlIdentifier],
|
98
|
-
) -> Tuple[List[sql_identifier.SqlIdentifier], List[sql_identifier.SqlIdentifier]]:
|
99
|
-
cols_to_remove = {timestamp_column, *id_columns, *prediction_columns, *label_columns}
|
100
|
-
cols_to_consider = [
|
101
|
-
(col_name, source_table_schema[col_name]) for col_name in source_table_schema if col_name not in cols_to_remove
|
102
|
-
]
|
103
|
-
numeric_cols = [
|
104
|
-
sql_identifier.SqlIdentifier(column[0])
|
105
|
-
for column in cols_to_consider
|
106
|
-
if isinstance(column[1], types._NumericType)
|
107
|
-
]
|
108
|
-
categorical_cols = [
|
109
|
-
sql_identifier.SqlIdentifier(column[0])
|
110
|
-
for column in cols_to_consider
|
111
|
-
if isinstance(column[1], types.StringType) or isinstance(column[1], types.BooleanType)
|
112
|
-
]
|
113
|
-
return (numeric_cols, categorical_cols)
|
114
|
-
|
115
|
-
|
116
|
-
class _ModelMonitorParams(TypedDict):
|
117
|
-
"""Class to transfer model monitor parameters to the ModelMonitor class."""
|
118
34
|
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
function_name: Required[str]
|
123
|
-
prediction_columns: Required[List[sql_identifier.SqlIdentifier]]
|
124
|
-
label_columns: Required[List[sql_identifier.SqlIdentifier]]
|
35
|
+
def _build_sql_list_from_columns(columns: List[sql_identifier.SqlIdentifier]) -> str:
|
36
|
+
sql_list = ", ".join([f"'{column}'" for column in columns])
|
37
|
+
return f"({sql_list})"
|
125
38
|
|
126
39
|
|
127
40
|
class ModelMonitorSQLClient:
|
@@ -143,38 +56,95 @@ class ModelMonitorSQLClient:
|
|
143
56
|
self._database_name = database_name
|
144
57
|
self._schema_name = schema_name
|
145
58
|
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
database_name
|
150
|
-
|
59
|
+
def _infer_qualified_schema(
|
60
|
+
self, database_name: Optional[sql_identifier.SqlIdentifier], schema_name: Optional[sql_identifier.SqlIdentifier]
|
61
|
+
) -> str:
|
62
|
+
return f"{database_name or self._database_name}.{schema_name or self._schema_name}"
|
63
|
+
|
64
|
+
def create_model_monitor(
|
65
|
+
self,
|
66
|
+
*,
|
67
|
+
monitor_database: Optional[sql_identifier.SqlIdentifier],
|
68
|
+
monitor_schema: Optional[sql_identifier.SqlIdentifier],
|
69
|
+
monitor_name: sql_identifier.SqlIdentifier,
|
70
|
+
source_database: Optional[sql_identifier.SqlIdentifier],
|
71
|
+
source_schema: Optional[sql_identifier.SqlIdentifier],
|
72
|
+
source: sql_identifier.SqlIdentifier,
|
73
|
+
model_database: Optional[sql_identifier.SqlIdentifier],
|
74
|
+
model_schema: Optional[sql_identifier.SqlIdentifier],
|
75
|
+
model_name: sql_identifier.SqlIdentifier,
|
76
|
+
version_name: sql_identifier.SqlIdentifier,
|
77
|
+
function_name: str,
|
78
|
+
warehouse_name: sql_identifier.SqlIdentifier,
|
79
|
+
timestamp_column: sql_identifier.SqlIdentifier,
|
80
|
+
id_columns: List[sql_identifier.SqlIdentifier],
|
81
|
+
prediction_score_columns: List[sql_identifier.SqlIdentifier],
|
82
|
+
prediction_class_columns: List[sql_identifier.SqlIdentifier],
|
83
|
+
actual_score_columns: List[sql_identifier.SqlIdentifier],
|
84
|
+
actual_class_columns: List[sql_identifier.SqlIdentifier],
|
85
|
+
refresh_interval: str,
|
86
|
+
aggregation_window: str,
|
87
|
+
baseline_database: Optional[sql_identifier.SqlIdentifier] = None,
|
88
|
+
baseline_schema: Optional[sql_identifier.SqlIdentifier] = None,
|
89
|
+
baseline: Optional[sql_identifier.SqlIdentifier] = None,
|
151
90
|
statement_params: Optional[Dict[str, Any]] = None,
|
152
91
|
) -> None:
|
153
|
-
""
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
92
|
+
baseline_sql = ""
|
93
|
+
if baseline:
|
94
|
+
baseline_sql = f"BASELINE='{self._infer_qualified_schema(baseline_database, baseline_schema)}.{baseline}'"
|
95
|
+
query_result_checker.SqlResultValidator(
|
96
|
+
self._sql_client._session,
|
97
|
+
f"""
|
98
|
+
CREATE MODEL MONITOR {self._infer_qualified_schema(monitor_database, monitor_schema)}.{monitor_name}
|
99
|
+
WITH
|
100
|
+
MODEL='{self._infer_qualified_schema(model_database, model_schema)}.{model_name}'
|
101
|
+
VERSION='{version_name}'
|
102
|
+
FUNCTION='{function_name}'
|
103
|
+
WAREHOUSE='{warehouse_name}'
|
104
|
+
SOURCE='{self._infer_qualified_schema(source_database, source_schema)}.{source}'
|
105
|
+
ID_COLUMNS={_build_sql_list_from_columns(id_columns)}
|
106
|
+
PREDICTION_SCORE_COLUMNS={_build_sql_list_from_columns(prediction_score_columns)}
|
107
|
+
PREDICTION_CLASS_COLUMNS={_build_sql_list_from_columns(prediction_class_columns)}
|
108
|
+
ACTUAL_SCORE_COLUMNS={_build_sql_list_from_columns(actual_score_columns)}
|
109
|
+
ACTUAL_CLASS_COLUMNS={_build_sql_list_from_columns(actual_class_columns)}
|
110
|
+
TIMESTAMP_COLUMN='{timestamp_column}'
|
111
|
+
REFRESH_INTERVAL='{refresh_interval}'
|
112
|
+
AGGREGATION_WINDOW='{aggregation_window}'
|
113
|
+
{baseline_sql}""",
|
114
|
+
statement_params=statement_params,
|
115
|
+
).has_column("status").has_dimensions(1, 1).validate()
|
163
116
|
|
164
|
-
def
|
165
|
-
|
117
|
+
def drop_model_monitor(
|
118
|
+
self,
|
119
|
+
*,
|
120
|
+
database_name: Optional[sql_identifier.SqlIdentifier] = None,
|
121
|
+
schema_name: Optional[sql_identifier.SqlIdentifier] = None,
|
122
|
+
monitor_name: sql_identifier.SqlIdentifier,
|
123
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
124
|
+
) -> None:
|
125
|
+
search_database_name = database_name or self._database_name
|
126
|
+
search_schema_name = schema_name or self._schema_name
|
127
|
+
query_result_checker.SqlResultValidator(
|
128
|
+
self._sql_client._session,
|
129
|
+
f"DROP MODEL MONITOR {search_database_name}.{search_schema_name}.{monitor_name}",
|
130
|
+
statement_params=statement_params,
|
131
|
+
).validate()
|
166
132
|
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
133
|
+
def show_model_monitors(
|
134
|
+
self,
|
135
|
+
*,
|
136
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
137
|
+
) -> List[snowpark.Row]:
|
138
|
+
fully_qualified_schema_name = ".".join([self._database_name.identifier(), self._schema_name.identifier()])
|
139
|
+
return (
|
140
|
+
query_result_checker.SqlResultValidator(
|
172
141
|
self._sql_client._session,
|
173
|
-
|
174
|
-
|
142
|
+
f"SHOW MODEL MONITORS IN {fully_qualified_schema_name}",
|
143
|
+
statement_params=statement_params,
|
175
144
|
)
|
176
|
-
|
177
|
-
|
145
|
+
.has_column("name", allow_empty=True)
|
146
|
+
.validate()
|
147
|
+
)
|
178
148
|
|
179
149
|
def _validate_unique_columns(
|
180
150
|
self,
|
@@ -191,53 +161,24 @@ class ModelMonitorSQLClient:
|
|
191
161
|
|
192
162
|
def validate_existence_by_name(
|
193
163
|
self,
|
164
|
+
*,
|
165
|
+
database_name: Optional[sql_identifier.SqlIdentifier] = None,
|
166
|
+
schema_name: Optional[sql_identifier.SqlIdentifier] = None,
|
194
167
|
monitor_name: sql_identifier.SqlIdentifier,
|
195
168
|
statement_params: Optional[Dict[str, Any]] = None,
|
196
169
|
) -> bool:
|
170
|
+
search_database_name = database_name or self._database_name
|
171
|
+
search_schema_name = schema_name or self._schema_name
|
197
172
|
res = (
|
198
173
|
query_result_checker.SqlResultValidator(
|
199
174
|
self._sql_client._session,
|
200
|
-
f"
|
201
|
-
FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}
|
202
|
-
WHERE {MONITOR_NAME_COL_NAME} = '{monitor_name}'""",
|
203
|
-
statement_params=statement_params,
|
204
|
-
)
|
205
|
-
.has_column(FQ_MODEL_NAME_COL_NAME, allow_empty=True)
|
206
|
-
.has_column(VERSION_NAME_COL_NAME, allow_empty=True)
|
207
|
-
.validate()
|
208
|
-
)
|
209
|
-
return len(res) >= 1
|
210
|
-
|
211
|
-
def validate_existence(
|
212
|
-
self,
|
213
|
-
fully_qualified_model_name: str,
|
214
|
-
version_name: sql_identifier.SqlIdentifier,
|
215
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
216
|
-
) -> bool:
|
217
|
-
"""Validate existence of a ModelMonitor on a Model Version.
|
218
|
-
|
219
|
-
Args:
|
220
|
-
fully_qualified_model_name: Fully qualified name of model.
|
221
|
-
version_name: Name of model version.
|
222
|
-
statement_params: Optional set of statement_params to include with query.
|
223
|
-
|
224
|
-
Returns:
|
225
|
-
Boolean indicating whether monitor exists on model version.
|
226
|
-
"""
|
227
|
-
res = (
|
228
|
-
query_result_checker.SqlResultValidator(
|
229
|
-
self._sql_client._session,
|
230
|
-
f"""SELECT {FQ_MODEL_NAME_COL_NAME}, {VERSION_NAME_COL_NAME}
|
231
|
-
FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}
|
232
|
-
WHERE {FQ_MODEL_NAME_COL_NAME} = '{fully_qualified_model_name}'
|
233
|
-
AND {VERSION_NAME_COL_NAME} = '{version_name}'""",
|
175
|
+
f"SHOW MODEL MONITORS LIKE '{monitor_name.resolved()}' IN {search_database_name}.{search_schema_name}",
|
234
176
|
statement_params=statement_params,
|
235
177
|
)
|
236
|
-
.has_column(
|
237
|
-
.has_column(VERSION_NAME_COL_NAME, allow_empty=True)
|
178
|
+
.has_column("name", allow_empty=True)
|
238
179
|
.validate()
|
239
180
|
)
|
240
|
-
return len(res)
|
181
|
+
return len(res) == 1
|
241
182
|
|
242
183
|
def validate_monitor_warehouse(
|
243
184
|
self,
|
@@ -261,115 +202,47 @@ class ModelMonitorSQLClient:
|
|
261
202
|
):
|
262
203
|
raise ValueError(f"Warehouse '{warehouse_name}' not found.")
|
263
204
|
|
264
|
-
def
|
265
|
-
self,
|
266
|
-
monitor_name: sql_identifier.SqlIdentifier,
|
267
|
-
model_name: sql_identifier.SqlIdentifier,
|
268
|
-
model_version_name: sql_identifier.SqlIdentifier,
|
269
|
-
task: type_hints.Task,
|
270
|
-
score_type: output_score_type.OutputScoreType,
|
271
|
-
output_columns: List[sql_identifier.SqlIdentifier],
|
272
|
-
ground_truth_columns: List[sql_identifier.SqlIdentifier],
|
273
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
274
|
-
) -> None:
|
275
|
-
udtf_name_query_map = self._create_dashboard_udtf_queries(
|
276
|
-
monitor_name,
|
277
|
-
model_name,
|
278
|
-
model_version_name,
|
279
|
-
task,
|
280
|
-
score_type,
|
281
|
-
output_columns,
|
282
|
-
ground_truth_columns,
|
283
|
-
)
|
284
|
-
for udtf_query in udtf_name_query_map.values():
|
285
|
-
query_result_checker.SqlResultValidator(
|
286
|
-
self._sql_client._session,
|
287
|
-
f"""{udtf_query}""",
|
288
|
-
statement_params=statement_params,
|
289
|
-
).validate()
|
290
|
-
|
291
|
-
def get_monitoring_table_fully_qualified_name(
|
292
|
-
self,
|
293
|
-
model_name: sql_identifier.SqlIdentifier,
|
294
|
-
model_version_name: sql_identifier.SqlIdentifier,
|
295
|
-
) -> str:
|
296
|
-
table_name = f"{_SNOWML_MONITORING_TABLE_NAME_PREFIX}_{model_name}_{model_version_name}"
|
297
|
-
return table_manager.get_fully_qualified_table_name(self._database_name, self._schema_name, table_name)
|
298
|
-
|
299
|
-
def get_accuracy_monitoring_table_fully_qualified_name(
|
300
|
-
self,
|
301
|
-
model_name: sql_identifier.SqlIdentifier,
|
302
|
-
model_version_name: sql_identifier.SqlIdentifier,
|
303
|
-
) -> str:
|
304
|
-
table_name = f"{_SNOWML_MONITORING_ACCURACY_TABLE_NAME_PREFIX}_{model_name}_{model_version_name}"
|
305
|
-
return table_manager.get_fully_qualified_table_name(self._database_name, self._schema_name, table_name)
|
306
|
-
|
307
|
-
def _create_dashboard_udtf_queries(
|
308
|
-
self,
|
309
|
-
monitor_name: sql_identifier.SqlIdentifier,
|
310
|
-
model_name: sql_identifier.SqlIdentifier,
|
311
|
-
model_version_name: sql_identifier.SqlIdentifier,
|
312
|
-
task: type_hints.Task,
|
313
|
-
score_type: output_score_type.OutputScoreType,
|
314
|
-
output_columns: List[sql_identifier.SqlIdentifier],
|
315
|
-
ground_truth_columns: List[sql_identifier.SqlIdentifier],
|
316
|
-
) -> Mapping[str, str]:
|
317
|
-
query_files = files("snowflake.ml.monitoring._client")
|
318
|
-
# TODO(apgupta): Expand list of queries based on model objective and score type.
|
319
|
-
queries_list = []
|
320
|
-
queries_list.extend(_DASHBOARD_UDTFS_COMMON_LIST)
|
321
|
-
if task == type_hints.Task.TABULAR_REGRESSION:
|
322
|
-
queries_list.extend(_DASHBOARD_UDTFS_REGRESSION_LIST)
|
323
|
-
var_map = {
|
324
|
-
"MODEL_MONITOR_NAME": monitor_name,
|
325
|
-
"MONITORING_TABLE": self.get_monitoring_table_fully_qualified_name(model_name, model_version_name),
|
326
|
-
"MONITORING_PRED_LABEL_JOINED_TABLE": self.get_accuracy_monitoring_table_fully_qualified_name(
|
327
|
-
model_name, model_version_name
|
328
|
-
),
|
329
|
-
"OUTPUT_COLUMN_NAME": output_columns[0],
|
330
|
-
"GROUND_TRUTH_COLUMN_NAME": ground_truth_columns[0],
|
331
|
-
}
|
332
|
-
|
333
|
-
udf_name_query_map = {}
|
334
|
-
for q in queries_list:
|
335
|
-
q_template = query_files.joinpath(f"queries/{q}.ssql").read_text()
|
336
|
-
q_actual = string.Template(q_template).substitute(var_map)
|
337
|
-
udf_name_query_map[q] = q_actual
|
338
|
-
return udf_name_query_map
|
339
|
-
|
340
|
-
def _validate_columns_exist_in_source_table(
|
205
|
+
def _validate_columns_exist_in_source(
|
341
206
|
self,
|
342
207
|
*,
|
343
|
-
|
344
|
-
source_table_name: sql_identifier.SqlIdentifier,
|
208
|
+
source_column_schema: Mapping[str, types.DataType],
|
345
209
|
timestamp_column: sql_identifier.SqlIdentifier,
|
346
|
-
|
347
|
-
|
210
|
+
prediction_score_columns: List[sql_identifier.SqlIdentifier],
|
211
|
+
prediction_class_columns: List[sql_identifier.SqlIdentifier],
|
212
|
+
actual_score_columns: List[sql_identifier.SqlIdentifier],
|
213
|
+
actual_class_columns: List[sql_identifier.SqlIdentifier],
|
348
214
|
id_columns: List[sql_identifier.SqlIdentifier],
|
349
215
|
) -> None:
|
350
216
|
"""Ensures all columns exist in the source table.
|
351
217
|
|
352
218
|
Args:
|
353
|
-
|
354
|
-
source_table_name: Name of the table with model data to monitor.
|
219
|
+
source_column_schema: Dictionary of column names and types in the source.
|
355
220
|
timestamp_column: Name of the timestamp column.
|
356
|
-
|
357
|
-
|
221
|
+
prediction_score_columns: List of prediction score column names.
|
222
|
+
prediction_class_columns: List of prediction class names.
|
223
|
+
actual_score_columns: List of actual score column names.
|
224
|
+
actual_class_columns: List of actual class column names.
|
358
225
|
id_columns: List of id column names.
|
359
226
|
|
360
227
|
Raises:
|
361
|
-
ValueError: If any of the columns do not exist in the source
|
228
|
+
ValueError: If any of the columns do not exist in the source.
|
362
229
|
"""
|
363
230
|
|
364
|
-
if timestamp_column not in
|
365
|
-
raise ValueError(f"Timestamp column {timestamp_column} does not exist in
|
231
|
+
if timestamp_column not in source_column_schema:
|
232
|
+
raise ValueError(f"Timestamp column {timestamp_column} does not exist in source.")
|
366
233
|
|
367
|
-
if not all([column_name in
|
368
|
-
raise ValueError(f"Prediction column(s): {
|
369
|
-
if not all([column_name in
|
370
|
-
raise ValueError(f"
|
371
|
-
if not all([column_name in
|
372
|
-
raise ValueError(f"
|
234
|
+
if not all([column_name in source_column_schema for column_name in prediction_score_columns]):
|
235
|
+
raise ValueError(f"Prediction Score column(s): {prediction_score_columns} do not exist in source.")
|
236
|
+
if not all([column_name in source_column_schema for column_name in prediction_class_columns]):
|
237
|
+
raise ValueError(f"Prediction Class column(s): {prediction_class_columns} do not exist in source.")
|
238
|
+
if not all([column_name in source_column_schema for column_name in actual_score_columns]):
|
239
|
+
raise ValueError(f"Actual Score column(s): {actual_score_columns} do not exist in source.")
|
240
|
+
|
241
|
+
if not all([column_name in source_column_schema for column_name in actual_class_columns]):
|
242
|
+
raise ValueError(f"Actual Class column(s): {actual_class_columns} do not exist in source.")
|
243
|
+
|
244
|
+
if not all([column_name in source_column_schema for column_name in id_columns]):
|
245
|
+
raise ValueError(f"ID column(s): {id_columns} do not exist in source.")
|
373
246
|
|
374
247
|
def _validate_timestamp_column_type(
|
375
248
|
self, table_schema: Mapping[str, types.DataType], timestamp_column: sql_identifier.SqlIdentifier
|
@@ -490,190 +363,37 @@ class ModelMonitorSQLClient:
|
|
490
363
|
f"Model function expected: {inputs} but got {table_schema_without_special_columns}"
|
491
364
|
)
|
492
365
|
|
493
|
-
def
|
494
|
-
self,
|
495
|
-
monitor_name: sql_identifier.SqlIdentifier,
|
496
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
497
|
-
) -> _ModelMonitorParams:
|
498
|
-
"""Fetch metadata for a Model Monitor by name.
|
499
|
-
|
500
|
-
Args:
|
501
|
-
monitor_name: Name of ModelMonitor to fetch.
|
502
|
-
statement_params: Optional set of statement_params to include with query.
|
503
|
-
|
504
|
-
Returns:
|
505
|
-
_ModelMonitorParams dict with Name of monitor, fully qualified model name,
|
506
|
-
model version name, model function name, prediction_col, label_col.
|
507
|
-
|
508
|
-
Raises:
|
509
|
-
ValueError: If multiple ModelMonitors exist with the same name.
|
510
|
-
"""
|
511
|
-
try:
|
512
|
-
res = (
|
513
|
-
query_result_checker.SqlResultValidator(
|
514
|
-
self._sql_client._session,
|
515
|
-
f"""SELECT {FQ_MODEL_NAME_COL_NAME}, {VERSION_NAME_COL_NAME}, {FUNCTION_NAME_COL_NAME},
|
516
|
-
{PREDICTION_COL_NAMES_COL_NAME}, {LABEL_COL_NAMES_COL_NAME}
|
517
|
-
FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}
|
518
|
-
WHERE {MONITOR_NAME_COL_NAME} = '{monitor_name}'""",
|
519
|
-
statement_params=statement_params,
|
520
|
-
)
|
521
|
-
.has_column(FQ_MODEL_NAME_COL_NAME)
|
522
|
-
.has_column(VERSION_NAME_COL_NAME)
|
523
|
-
.has_column(FUNCTION_NAME_COL_NAME)
|
524
|
-
.has_column(PREDICTION_COL_NAMES_COL_NAME)
|
525
|
-
.has_column(LABEL_COL_NAMES_COL_NAME)
|
526
|
-
.validate()
|
527
|
-
)
|
528
|
-
except errors.DataError:
|
529
|
-
raise ValueError(f"Failed to find any monitor with name '{monitor_name}'")
|
530
|
-
|
531
|
-
if len(res) > 1:
|
532
|
-
raise ValueError(f"Invalid state. Multiple Monitors exist with name '{monitor_name}'")
|
533
|
-
|
534
|
-
return _ModelMonitorParams(
|
535
|
-
monitor_name=str(monitor_name),
|
536
|
-
fully_qualified_model_name=res[0][FQ_MODEL_NAME_COL_NAME],
|
537
|
-
version_name=res[0][VERSION_NAME_COL_NAME],
|
538
|
-
function_name=res[0][FUNCTION_NAME_COL_NAME],
|
539
|
-
prediction_columns=[
|
540
|
-
sql_identifier.SqlIdentifier(prediction_column)
|
541
|
-
for prediction_column in json.loads(res[0][PREDICTION_COL_NAMES_COL_NAME])
|
542
|
-
],
|
543
|
-
label_columns=[
|
544
|
-
sql_identifier.SqlIdentifier(label_column)
|
545
|
-
for label_column in json.loads(res[0][LABEL_COL_NAMES_COL_NAME])
|
546
|
-
],
|
547
|
-
)
|
548
|
-
|
549
|
-
def get_model_monitor_by_model_version(
|
366
|
+
def validate_source(
|
550
367
|
self,
|
551
368
|
*,
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
version_name: sql_identifier.SqlIdentifier,
|
556
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
557
|
-
) -> _ModelMonitorParams:
|
558
|
-
"""Fetch metadata for a Model Monitor by model version.
|
559
|
-
|
560
|
-
Args:
|
561
|
-
model_db: Database of model.
|
562
|
-
model_schema: Schema of model.
|
563
|
-
model_name: Model name.
|
564
|
-
version_name: Model version name
|
565
|
-
statement_params: Optional set of statement_params to include with queries.
|
566
|
-
|
567
|
-
Returns:
|
568
|
-
_ModelMonitorParams dict with Name of monitor, fully qualified model name,
|
569
|
-
model version name, model function name, prediction_col, label_col.
|
570
|
-
|
571
|
-
Raises:
|
572
|
-
ValueError: If multiple ModelMonitors exist with the same name.
|
573
|
-
"""
|
574
|
-
res = (
|
575
|
-
query_result_checker.SqlResultValidator(
|
576
|
-
self._sql_client._session,
|
577
|
-
f"""SELECT {MONITOR_NAME_COL_NAME}, {FQ_MODEL_NAME_COL_NAME},
|
578
|
-
{VERSION_NAME_COL_NAME}, {FUNCTION_NAME_COL_NAME}
|
579
|
-
FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}
|
580
|
-
WHERE {FQ_MODEL_NAME_COL_NAME} = '{model_db}.{model_schema}.{model_name}'
|
581
|
-
AND {VERSION_NAME_COL_NAME} = '{version_name}'""",
|
582
|
-
statement_params=statement_params,
|
583
|
-
)
|
584
|
-
.has_column(MONITOR_NAME_COL_NAME)
|
585
|
-
.has_column(FQ_MODEL_NAME_COL_NAME)
|
586
|
-
.has_column(VERSION_NAME_COL_NAME)
|
587
|
-
.has_column(FUNCTION_NAME_COL_NAME)
|
588
|
-
.validate()
|
589
|
-
)
|
590
|
-
if len(res) > 1:
|
591
|
-
raise ValueError(
|
592
|
-
f"Invalid state. Multiple Monitors exist for model: '{model_name}' and version: '{version_name}'"
|
593
|
-
)
|
594
|
-
return _ModelMonitorParams(
|
595
|
-
monitor_name=res[0][MONITOR_NAME_COL_NAME],
|
596
|
-
fully_qualified_model_name=res[0][FQ_MODEL_NAME_COL_NAME],
|
597
|
-
version_name=res[0][VERSION_NAME_COL_NAME],
|
598
|
-
function_name=res[0][FUNCTION_NAME_COL_NAME],
|
599
|
-
prediction_columns=[
|
600
|
-
sql_identifier.SqlIdentifier(prediction_column)
|
601
|
-
for prediction_column in json.loads(res[0][PREDICTION_COL_NAMES_COL_NAME])
|
602
|
-
],
|
603
|
-
label_columns=[
|
604
|
-
sql_identifier.SqlIdentifier(label_column)
|
605
|
-
for label_column in json.loads(res[0][LABEL_COL_NAMES_COL_NAME])
|
606
|
-
],
|
607
|
-
)
|
608
|
-
|
609
|
-
def get_score_type(
|
610
|
-
self,
|
611
|
-
task: type_hints.Task,
|
612
|
-
source_table_name: sql_identifier.SqlIdentifier,
|
613
|
-
prediction_columns: List[sql_identifier.SqlIdentifier],
|
614
|
-
) -> output_score_type.OutputScoreType:
|
615
|
-
"""Infer score type given model task and prediction table columns.
|
616
|
-
|
617
|
-
Args:
|
618
|
-
task: Model task
|
619
|
-
source_table_name: Source data table containing model outputs.
|
620
|
-
prediction_columns: columns in source data table corresponding to model outputs.
|
621
|
-
|
622
|
-
Returns:
|
623
|
-
OutputScoreType for model.
|
624
|
-
"""
|
625
|
-
table_schema: Mapping[str, types.DataType] = table_manager.get_table_schema_types(
|
626
|
-
self._sql_client._session,
|
627
|
-
self._database_name,
|
628
|
-
self._schema_name,
|
629
|
-
source_table_name,
|
630
|
-
)
|
631
|
-
return output_score_type.OutputScoreType.deduce_score_type(table_schema, prediction_columns, task)
|
632
|
-
|
633
|
-
def validate_source_table(
|
634
|
-
self,
|
635
|
-
source_table_name: sql_identifier.SqlIdentifier,
|
369
|
+
source_database: Optional[sql_identifier.SqlIdentifier],
|
370
|
+
source_schema: Optional[sql_identifier.SqlIdentifier],
|
371
|
+
source: sql_identifier.SqlIdentifier,
|
636
372
|
timestamp_column: sql_identifier.SqlIdentifier,
|
637
|
-
|
638
|
-
|
373
|
+
prediction_score_columns: List[sql_identifier.SqlIdentifier],
|
374
|
+
prediction_class_columns: List[sql_identifier.SqlIdentifier],
|
375
|
+
actual_score_columns: List[sql_identifier.SqlIdentifier],
|
376
|
+
actual_class_columns: List[sql_identifier.SqlIdentifier],
|
639
377
|
id_columns: List[sql_identifier.SqlIdentifier],
|
640
|
-
model_function: model_manifest_schema.ModelFunctionInfo,
|
641
378
|
) -> None:
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
f"{self._database_name}.{self._schema_name}",
|
647
|
-
):
|
648
|
-
raise ValueError(
|
649
|
-
f"Table {source_table_name} does not exist in schema {self._database_name}.{self._schema_name}."
|
650
|
-
)
|
651
|
-
table_schema: Mapping[str, types.DataType] = table_manager.get_table_schema_types(
|
379
|
+
source_database = source_database or self._database_name
|
380
|
+
source_schema = source_schema or self._schema_name
|
381
|
+
# Get Schema of the source. Implicitly validates that the source exists.
|
382
|
+
source_column_schema: Mapping[str, types.DataType] = table_manager.get_table_schema_types(
|
652
383
|
self._sql_client._session,
|
653
|
-
|
654
|
-
|
655
|
-
|
384
|
+
source_database,
|
385
|
+
source_schema,
|
386
|
+
source,
|
656
387
|
)
|
657
|
-
self.
|
658
|
-
|
659
|
-
source_table_name=source_table_name,
|
388
|
+
self._validate_columns_exist_in_source(
|
389
|
+
source_column_schema=source_column_schema,
|
660
390
|
timestamp_column=timestamp_column,
|
661
|
-
|
662
|
-
|
391
|
+
prediction_score_columns=prediction_score_columns,
|
392
|
+
prediction_class_columns=prediction_class_columns,
|
393
|
+
actual_score_columns=actual_score_columns,
|
394
|
+
actual_class_columns=actual_class_columns,
|
663
395
|
id_columns=id_columns,
|
664
396
|
)
|
665
|
-
self._validate_column_types(
|
666
|
-
table_schema=table_schema,
|
667
|
-
timestamp_column=timestamp_column,
|
668
|
-
id_columns=id_columns,
|
669
|
-
prediction_columns=prediction_columns,
|
670
|
-
label_columns=label_columns,
|
671
|
-
)
|
672
|
-
self._validate_source_table_features_shape(
|
673
|
-
table_schema=table_schema,
|
674
|
-
special_columns={timestamp_column, *id_columns, *prediction_columns, *label_columns},
|
675
|
-
model_function=model_function,
|
676
|
-
)
|
677
397
|
|
678
398
|
def delete_monitor_metadata(
|
679
399
|
self,
|
@@ -691,645 +411,38 @@ class ModelMonitorSQLClient:
|
|
691
411
|
WHERE {MONITOR_NAME_COL_NAME} = '{name}'""",
|
692
412
|
).collect(statement_params=statement_params)
|
693
413
|
|
694
|
-
def
|
695
|
-
self,
|
696
|
-
fully_qualified_model_name: str,
|
697
|
-
version_name: str,
|
698
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
699
|
-
) -> None:
|
700
|
-
"""Delete the baseline table corresponding to a particular model and version.
|
701
|
-
|
702
|
-
Args:
|
703
|
-
fully_qualified_model_name: Fully qualified name of the model.
|
704
|
-
version_name: Name of the model version.
|
705
|
-
statement_params: Optional set of statement_params to include with query.
|
706
|
-
"""
|
707
|
-
table_name = _create_baseline_table_name(fully_qualified_model_name, version_name)
|
708
|
-
self._sql_client._session.sql(
|
709
|
-
f"""DROP TABLE IF EXISTS {self._database_name}.{self._schema_name}.{table_name}"""
|
710
|
-
).collect(statement_params=statement_params)
|
711
|
-
|
712
|
-
def delete_dynamic_tables(
|
713
|
-
self,
|
714
|
-
fully_qualified_model_name: str,
|
715
|
-
version_name: str,
|
716
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
717
|
-
) -> None:
|
718
|
-
"""Delete the dynamic tables corresponding to a particular model and version.
|
719
|
-
|
720
|
-
Args:
|
721
|
-
fully_qualified_model_name: Fully qualified name of the model.
|
722
|
-
version_name: Name of the model version.
|
723
|
-
statement_params: Optional set of statement_params to include with query.
|
724
|
-
"""
|
725
|
-
_, _, model_name = sql_identifier.parse_fully_qualified_name(fully_qualified_model_name)
|
726
|
-
model_id = sql_identifier.SqlIdentifier(model_name)
|
727
|
-
version_id = sql_identifier.SqlIdentifier(version_name)
|
728
|
-
monitoring_table_name = self.get_monitoring_table_fully_qualified_name(model_id, version_id)
|
729
|
-
self._sql_client._session.sql(f"""DROP DYNAMIC TABLE IF EXISTS {monitoring_table_name}""").collect(
|
730
|
-
statement_params=statement_params
|
731
|
-
)
|
732
|
-
accuracy_table_name = self.get_accuracy_monitoring_table_fully_qualified_name(model_id, version_id)
|
733
|
-
self._sql_client._session.sql(f"""DROP DYNAMIC TABLE IF EXISTS {accuracy_table_name}""").collect(
|
734
|
-
statement_params=statement_params
|
735
|
-
)
|
736
|
-
|
737
|
-
def create_monitor_on_model_version(
|
738
|
-
self,
|
739
|
-
monitor_name: sql_identifier.SqlIdentifier,
|
740
|
-
source_table_name: sql_identifier.SqlIdentifier,
|
741
|
-
fully_qualified_model_name: str,
|
742
|
-
version_name: sql_identifier.SqlIdentifier,
|
743
|
-
function_name: str,
|
744
|
-
timestamp_column: sql_identifier.SqlIdentifier,
|
745
|
-
prediction_columns: List[sql_identifier.SqlIdentifier],
|
746
|
-
label_columns: List[sql_identifier.SqlIdentifier],
|
747
|
-
id_columns: List[sql_identifier.SqlIdentifier],
|
748
|
-
task: type_hints.Task,
|
749
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
750
|
-
) -> None:
|
751
|
-
"""
|
752
|
-
Creates a ModelMonitor on a Model Version from the Snowflake Model Registry. Creates public schema for metadata.
|
753
|
-
|
754
|
-
Args:
|
755
|
-
monitor_name: Name of monitor object to create.
|
756
|
-
source_table_name: Name of source data table to monitor.
|
757
|
-
fully_qualified_model_name: fully qualified name of model to monitor '<db>.<schema>.<model_name>'.
|
758
|
-
version_name: model version name to monitor.
|
759
|
-
function_name: function_name to monitor in model version.
|
760
|
-
timestamp_column: timestamp column name.
|
761
|
-
prediction_columns: list of prediction column names.
|
762
|
-
label_columns: list of label column names.
|
763
|
-
id_columns: list of id column names.
|
764
|
-
task: Task of the model, e.g. TABULAR_REGRESSION.
|
765
|
-
statement_params: Optional dict of statement_params to include with queries.
|
766
|
-
|
767
|
-
Raises:
|
768
|
-
ValueError: If model version is already monitored.
|
769
|
-
"""
|
770
|
-
# Validate monitor does not already exist on model version.
|
771
|
-
if self.validate_existence(fully_qualified_model_name, version_name, statement_params):
|
772
|
-
raise ValueError(f"Model {fully_qualified_model_name} Version {version_name} is already monitored!")
|
773
|
-
|
774
|
-
if self.validate_existence_by_name(monitor_name, statement_params):
|
775
|
-
raise ValueError(f"Model Monitor with name '{monitor_name}' already exists!")
|
776
|
-
|
777
|
-
prediction_columns_for_select = formatting.format_value_for_select(prediction_columns)
|
778
|
-
label_columns_for_select = formatting.format_value_for_select(label_columns)
|
779
|
-
id_columns_for_select = formatting.format_value_for_select(id_columns)
|
780
|
-
query_result_checker.SqlResultValidator(
|
781
|
-
self._sql_client._session,
|
782
|
-
textwrap.dedent(
|
783
|
-
f"""INSERT INTO {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}
|
784
|
-
({MONITOR_NAME_COL_NAME}, {SOURCE_TABLE_NAME_COL_NAME}, {FQ_MODEL_NAME_COL_NAME},
|
785
|
-
{VERSION_NAME_COL_NAME}, {FUNCTION_NAME_COL_NAME}, {TASK_COL_NAME},
|
786
|
-
{MONITORING_ENABLED_COL_NAME}, {TIMESTAMP_COL_NAME_COL_NAME},
|
787
|
-
{PREDICTION_COL_NAMES_COL_NAME}, {LABEL_COL_NAMES_COL_NAME},
|
788
|
-
{ID_COL_NAMES_COL_NAME})
|
789
|
-
SELECT '{monitor_name}', '{source_table_name}', '{fully_qualified_model_name}',
|
790
|
-
'{version_name}', '{function_name}', '{task.value}', TRUE, '{timestamp_column}',
|
791
|
-
{prediction_columns_for_select}, {label_columns_for_select}, {id_columns_for_select}"""
|
792
|
-
),
|
793
|
-
statement_params=statement_params,
|
794
|
-
).insertion_success(expected_num_rows=1).validate()
|
795
|
-
|
796
|
-
def initialize_baseline_table(
|
797
|
-
self,
|
798
|
-
model_name: sql_identifier.SqlIdentifier,
|
799
|
-
version_name: sql_identifier.SqlIdentifier,
|
800
|
-
source_table_name: str,
|
801
|
-
columns_to_drop: Optional[List[sql_identifier.SqlIdentifier]] = None,
|
802
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
803
|
-
) -> None:
|
804
|
-
"""
|
805
|
-
Initializes the baseline table for a Model Version. Creates schema for baseline data using the source table.
|
806
|
-
|
807
|
-
Args:
|
808
|
-
model_name: name of model to monitor.
|
809
|
-
version_name: model version name to monitor.
|
810
|
-
source_table_name: name of the user's table containing their model data.
|
811
|
-
columns_to_drop: special columns in the source table to be excluded from baseline tables.
|
812
|
-
statement_params: Optional dict of statement_params to include with queries.
|
813
|
-
"""
|
814
|
-
table_schema = table_manager.get_table_schema_types(
|
815
|
-
self._sql_client._session,
|
816
|
-
database=self._database_name,
|
817
|
-
schema=self._schema_name,
|
818
|
-
table_name=source_table_name,
|
819
|
-
)
|
820
|
-
|
821
|
-
if columns_to_drop is None:
|
822
|
-
columns_to_drop = []
|
823
|
-
|
824
|
-
table_manager.create_single_table(
|
825
|
-
self._sql_client._session,
|
826
|
-
self._database_name,
|
827
|
-
self._schema_name,
|
828
|
-
_create_baseline_table_name(model_name, version_name),
|
829
|
-
[
|
830
|
-
(k, type_utils.convert_sp_to_sf_type(v))
|
831
|
-
for k, v in table_schema.items()
|
832
|
-
if sql_identifier.SqlIdentifier(k) not in columns_to_drop
|
833
|
-
],
|
834
|
-
statement_params=statement_params,
|
835
|
-
)
|
836
|
-
|
837
|
-
def get_all_model_monitor_metadata(
|
838
|
-
self,
|
839
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
840
|
-
) -> List[snowpark.Row]:
|
841
|
-
"""Get the metadata for all model monitors in the given schema.
|
842
|
-
|
843
|
-
Args:
|
844
|
-
statement_params: Optional dict of statement_params to include with queries.
|
845
|
-
|
846
|
-
Returns:
|
847
|
-
List of snowpark.Row containing metadata for each model monitor.
|
848
|
-
"""
|
849
|
-
return query_result_checker.SqlResultValidator(
|
850
|
-
self._sql_client._session,
|
851
|
-
f"""SELECT *
|
852
|
-
FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}""",
|
853
|
-
statement_params=statement_params,
|
854
|
-
).validate()
|
855
|
-
|
856
|
-
def materialize_baseline_dataframe(
|
857
|
-
self,
|
858
|
-
baseline_df: DataFrame,
|
859
|
-
fully_qualified_model_name: str,
|
860
|
-
model_version_name: sql_identifier.SqlIdentifier,
|
861
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
862
|
-
) -> None:
|
863
|
-
"""
|
864
|
-
Materialize baseline dataframe to a permanent snowflake table. This method
|
865
|
-
truncates (overwrite without dropping) any existing data in the baseline table.
|
866
|
-
|
867
|
-
Args:
|
868
|
-
baseline_df: dataframe containing baseline data that monitored data will be compared against.
|
869
|
-
fully_qualified_model_name: name of the model.
|
870
|
-
model_version_name: model version name to monitor.
|
871
|
-
statement_params: Optional dict of statement_params to include with queries.
|
872
|
-
|
873
|
-
Raises:
|
874
|
-
ValueError: If no baseline table was initialized.
|
875
|
-
"""
|
876
|
-
|
877
|
-
_, _, model_name = sql_identifier.parse_fully_qualified_name(fully_qualified_model_name)
|
878
|
-
baseline_table_name = _create_baseline_table_name(model_name, model_version_name)
|
879
|
-
|
880
|
-
baseline_table_exists = db_utils.db_object_exists(
|
881
|
-
self._sql_client._session,
|
882
|
-
db_utils.SnowflakeDbObjectType.TABLE,
|
883
|
-
sql_identifier.SqlIdentifier(baseline_table_name),
|
884
|
-
database_name=self._database_name,
|
885
|
-
schema_name=self._schema_name,
|
886
|
-
statement_params=statement_params,
|
887
|
-
)
|
888
|
-
if not baseline_table_exists:
|
889
|
-
raise ValueError(
|
890
|
-
f"Baseline table '{baseline_table_name}' does not exist for model: "
|
891
|
-
f"'{model_name}' and model_version: '{model_version_name}'"
|
892
|
-
)
|
893
|
-
|
894
|
-
fully_qualified_baseline_table_name = [self._database_name, self._schema_name, baseline_table_name]
|
895
|
-
|
896
|
-
try:
|
897
|
-
# Truncate overwrites by clearing the rows in the table, instead of dropping the table.
|
898
|
-
# This lets us keep the schema to validate the baseline_df against.
|
899
|
-
baseline_df.write.mode("truncate").save_as_table(
|
900
|
-
fully_qualified_baseline_table_name, statement_params=statement_params
|
901
|
-
)
|
902
|
-
except exceptions.SnowparkSQLException as e:
|
903
|
-
raise ValueError(
|
904
|
-
f"""Failed to save baseline dataframe.
|
905
|
-
Ensure that the baseline dataframe columns match those provided in your monitored table: {e}"""
|
906
|
-
)
|
907
|
-
|
908
|
-
def _alter_monitor_dynamic_tables(
|
414
|
+
def _alter_monitor(
|
909
415
|
self,
|
910
416
|
operation: str,
|
911
|
-
|
912
|
-
version_name: sql_identifier.SqlIdentifier,
|
417
|
+
monitor_name: sql_identifier.SqlIdentifier,
|
913
418
|
statement_params: Optional[Dict[str, Any]] = None,
|
914
419
|
) -> None:
|
915
420
|
if operation not in {"SUSPEND", "RESUME"}:
|
916
421
|
raise ValueError(f"Operation {operation} not supported for altering Dynamic Tables")
|
917
|
-
fq_monitor_dt_name = self.get_monitoring_table_fully_qualified_name(model_name, version_name)
|
918
|
-
query_result_checker.SqlResultValidator(
|
919
|
-
self._sql_client._session,
|
920
|
-
f"""ALTER DYNAMIC TABLE {fq_monitor_dt_name} {operation}""",
|
921
|
-
statement_params=statement_params,
|
922
|
-
).has_column("status").has_dimensions(1, 1).validate()
|
923
|
-
|
924
|
-
fq_accuracy_dt_name = self.get_accuracy_monitoring_table_fully_qualified_name(model_name, version_name)
|
925
422
|
query_result_checker.SqlResultValidator(
|
926
423
|
self._sql_client._session,
|
927
|
-
f"""ALTER
|
424
|
+
f"""ALTER MODEL MONITOR {self._database_name}.{self._schema_name}.{monitor_name} {operation}""",
|
928
425
|
statement_params=statement_params,
|
929
426
|
).has_column("status").has_dimensions(1, 1).validate()
|
930
427
|
|
931
|
-
def
|
428
|
+
def suspend_monitor(
|
932
429
|
self,
|
933
|
-
|
934
|
-
version_name: sql_identifier.SqlIdentifier,
|
430
|
+
monitor_name: sql_identifier.SqlIdentifier,
|
935
431
|
statement_params: Optional[Dict[str, Any]] = None,
|
936
432
|
) -> None:
|
937
|
-
self.
|
433
|
+
self._alter_monitor(
|
938
434
|
operation="SUSPEND",
|
939
|
-
|
940
|
-
version_name=version_name,
|
435
|
+
monitor_name=monitor_name,
|
941
436
|
statement_params=statement_params,
|
942
437
|
)
|
943
438
|
|
944
|
-
def
|
439
|
+
def resume_monitor(
|
945
440
|
self,
|
946
|
-
|
947
|
-
version_name: sql_identifier.SqlIdentifier,
|
441
|
+
monitor_name: sql_identifier.SqlIdentifier,
|
948
442
|
statement_params: Optional[Dict[str, Any]] = None,
|
949
443
|
) -> None:
|
950
|
-
self.
|
444
|
+
self._alter_monitor(
|
951
445
|
operation="RESUME",
|
952
|
-
|
953
|
-
version_name=version_name,
|
446
|
+
monitor_name=monitor_name,
|
954
447
|
statement_params=statement_params,
|
955
448
|
)
|
956
|
-
|
957
|
-
def create_dynamic_tables_for_monitor(
|
958
|
-
self,
|
959
|
-
*,
|
960
|
-
model_name: sql_identifier.SqlIdentifier,
|
961
|
-
model_version_name: sql_identifier.SqlIdentifier,
|
962
|
-
task: type_hints.Task,
|
963
|
-
source_table_name: sql_identifier.SqlIdentifier,
|
964
|
-
refresh_interval: model_monitor_interval.ModelMonitorRefreshInterval,
|
965
|
-
aggregation_window: model_monitor_interval.ModelMonitorAggregationWindow,
|
966
|
-
warehouse_name: sql_identifier.SqlIdentifier,
|
967
|
-
timestamp_column: sql_identifier.SqlIdentifier,
|
968
|
-
id_columns: List[sql_identifier.SqlIdentifier],
|
969
|
-
prediction_columns: List[sql_identifier.SqlIdentifier],
|
970
|
-
label_columns: List[sql_identifier.SqlIdentifier],
|
971
|
-
score_type: output_score_type.OutputScoreType,
|
972
|
-
) -> None:
|
973
|
-
table_schema: Mapping[str, types.DataType] = table_manager.get_table_schema_types(
|
974
|
-
self._sql_client._session,
|
975
|
-
self._database_name,
|
976
|
-
self._schema_name,
|
977
|
-
source_table_name,
|
978
|
-
)
|
979
|
-
(numeric_features_names, categorical_feature_names) = _infer_numeric_categoric_feature_column_names(
|
980
|
-
source_table_schema=table_schema,
|
981
|
-
timestamp_column=timestamp_column,
|
982
|
-
id_columns=id_columns,
|
983
|
-
prediction_columns=prediction_columns,
|
984
|
-
label_columns=label_columns,
|
985
|
-
)
|
986
|
-
features_dynamic_table_query = self._monitoring_dynamic_table_query(
|
987
|
-
model_name=model_name,
|
988
|
-
model_version_name=model_version_name,
|
989
|
-
source_table_name=source_table_name,
|
990
|
-
refresh_interval=refresh_interval,
|
991
|
-
aggregate_window=aggregation_window,
|
992
|
-
warehouse_name=warehouse_name,
|
993
|
-
timestamp_column=timestamp_column,
|
994
|
-
numeric_features=numeric_features_names,
|
995
|
-
categoric_features=categorical_feature_names,
|
996
|
-
prediction_columns=prediction_columns,
|
997
|
-
label_columns=label_columns,
|
998
|
-
)
|
999
|
-
query_result_checker.SqlResultValidator(self._sql_client._session, features_dynamic_table_query).has_column(
|
1000
|
-
"status"
|
1001
|
-
).has_dimensions(1, 1).validate()
|
1002
|
-
|
1003
|
-
label_pred_join_table_query = self._monitoring_accuracy_table_query(
|
1004
|
-
model_name=model_name,
|
1005
|
-
model_version_name=model_version_name,
|
1006
|
-
task=task,
|
1007
|
-
source_table_name=source_table_name,
|
1008
|
-
refresh_interval=refresh_interval,
|
1009
|
-
aggregate_window=aggregation_window,
|
1010
|
-
warehouse_name=warehouse_name,
|
1011
|
-
timestamp_column=timestamp_column,
|
1012
|
-
prediction_columns=prediction_columns,
|
1013
|
-
label_columns=label_columns,
|
1014
|
-
score_type=score_type,
|
1015
|
-
)
|
1016
|
-
query_result_checker.SqlResultValidator(self._sql_client._session, label_pred_join_table_query).has_column(
|
1017
|
-
"status"
|
1018
|
-
).has_dimensions(1, 1).validate()
|
1019
|
-
|
1020
|
-
def _monitoring_dynamic_table_query(
|
1021
|
-
self,
|
1022
|
-
*,
|
1023
|
-
model_name: sql_identifier.SqlIdentifier,
|
1024
|
-
model_version_name: sql_identifier.SqlIdentifier,
|
1025
|
-
source_table_name: sql_identifier.SqlIdentifier,
|
1026
|
-
refresh_interval: ModelMonitorRefreshInterval,
|
1027
|
-
aggregate_window: ModelMonitorAggregationWindow,
|
1028
|
-
warehouse_name: sql_identifier.SqlIdentifier,
|
1029
|
-
timestamp_column: sql_identifier.SqlIdentifier,
|
1030
|
-
numeric_features: List[sql_identifier.SqlIdentifier],
|
1031
|
-
categoric_features: List[sql_identifier.SqlIdentifier],
|
1032
|
-
prediction_columns: List[sql_identifier.SqlIdentifier],
|
1033
|
-
label_columns: List[sql_identifier.SqlIdentifier],
|
1034
|
-
) -> str:
|
1035
|
-
"""
|
1036
|
-
Generates a dynamic table query for Observability - Monitoring.
|
1037
|
-
|
1038
|
-
Args:
|
1039
|
-
model_name: Model name to monitor.
|
1040
|
-
model_version_name: Model version name to monitor.
|
1041
|
-
source_table_name: Name of source data table to monitor.
|
1042
|
-
refresh_interval: Refresh interval in minutes.
|
1043
|
-
aggregate_window: Aggregate window minutes.
|
1044
|
-
warehouse_name: Warehouse name to use for dynamic table.
|
1045
|
-
timestamp_column: Timestamp column name.
|
1046
|
-
numeric_features: List of numeric features to capture.
|
1047
|
-
categoric_features: List of categoric features to capture.
|
1048
|
-
prediction_columns: List of columns that contain model inference outputs.
|
1049
|
-
label_columns: List of columns that contain ground truth values.
|
1050
|
-
|
1051
|
-
Raises:
|
1052
|
-
ValueError: If multiple output/ground truth columns are specified. MultiClass models are not yet supported.
|
1053
|
-
|
1054
|
-
Returns:
|
1055
|
-
Dynamic table query.
|
1056
|
-
"""
|
1057
|
-
# output and ground cols are list to keep interface extensible.
|
1058
|
-
# for prpr only one label and one output col will be supported
|
1059
|
-
if len(prediction_columns) != 1 or len(label_columns) != 1:
|
1060
|
-
raise ValueError("Multiple Output columns are not supported in monitoring")
|
1061
|
-
|
1062
|
-
monitoring_dt_name = self.get_monitoring_table_fully_qualified_name(model_name, model_version_name)
|
1063
|
-
|
1064
|
-
feature_cols_query_list = []
|
1065
|
-
for feature in numeric_features + prediction_columns + label_columns:
|
1066
|
-
feature_cols_query_list.append(
|
1067
|
-
"""
|
1068
|
-
OBJECT_CONSTRUCT(
|
1069
|
-
'sketch', APPROX_PERCENTILE_ACCUMULATE({col}),
|
1070
|
-
'count', count_if({col} is not null),
|
1071
|
-
'count_null', count_if({col} is null),
|
1072
|
-
'min', min({col}),
|
1073
|
-
'max', max({col}),
|
1074
|
-
'sum', sum({col})
|
1075
|
-
) AS {col}""".format(
|
1076
|
-
col=feature
|
1077
|
-
)
|
1078
|
-
)
|
1079
|
-
|
1080
|
-
for col in categoric_features:
|
1081
|
-
feature_cols_query_list.append(
|
1082
|
-
f"""
|
1083
|
-
{self._database_name}.{self._schema_name}.OBJECT_SUM(to_varchar({col})) AS {col}"""
|
1084
|
-
)
|
1085
|
-
feature_cols_query = ",".join(feature_cols_query_list)
|
1086
|
-
|
1087
|
-
return f"""
|
1088
|
-
CREATE DYNAMIC TABLE IF NOT EXISTS {monitoring_dt_name}
|
1089
|
-
TARGET_LAG = '{refresh_interval.minutes} minutes'
|
1090
|
-
WAREHOUSE = {warehouse_name}
|
1091
|
-
REFRESH_MODE = AUTO
|
1092
|
-
INITIALIZE = ON_CREATE
|
1093
|
-
AS
|
1094
|
-
SELECT
|
1095
|
-
TIME_SLICE({timestamp_column}, {aggregate_window.minutes}, 'MINUTE') timestamp,{feature_cols_query}
|
1096
|
-
FROM
|
1097
|
-
{source_table_name}
|
1098
|
-
GROUP BY
|
1099
|
-
1
|
1100
|
-
"""
|
1101
|
-
|
1102
|
-
def _monitoring_accuracy_table_query(
|
1103
|
-
self,
|
1104
|
-
*,
|
1105
|
-
model_name: sql_identifier.SqlIdentifier,
|
1106
|
-
model_version_name: sql_identifier.SqlIdentifier,
|
1107
|
-
task: type_hints.Task,
|
1108
|
-
source_table_name: sql_identifier.SqlIdentifier,
|
1109
|
-
refresh_interval: ModelMonitorRefreshInterval,
|
1110
|
-
aggregate_window: ModelMonitorAggregationWindow,
|
1111
|
-
warehouse_name: sql_identifier.SqlIdentifier,
|
1112
|
-
timestamp_column: sql_identifier.SqlIdentifier,
|
1113
|
-
prediction_columns: List[sql_identifier.SqlIdentifier],
|
1114
|
-
label_columns: List[sql_identifier.SqlIdentifier],
|
1115
|
-
score_type: output_score_type.OutputScoreType,
|
1116
|
-
) -> str:
|
1117
|
-
# output and ground cols are list to keep interface extensible.
|
1118
|
-
# for prpr only one label and one output col will be supported
|
1119
|
-
if len(prediction_columns) != 1 or len(label_columns) != 1:
|
1120
|
-
raise ValueError("Multiple Output columns are not supported in monitoring")
|
1121
|
-
if task == type_hints.Task.TABULAR_BINARY_CLASSIFICATION:
|
1122
|
-
return self._monitoring_classification_accuracy_table_query(
|
1123
|
-
model_name=model_name,
|
1124
|
-
model_version_name=model_version_name,
|
1125
|
-
source_table_name=source_table_name,
|
1126
|
-
refresh_interval=refresh_interval,
|
1127
|
-
aggregate_window=aggregate_window,
|
1128
|
-
warehouse_name=warehouse_name,
|
1129
|
-
timestamp_column=timestamp_column,
|
1130
|
-
prediction_columns=prediction_columns,
|
1131
|
-
label_columns=label_columns,
|
1132
|
-
score_type=score_type,
|
1133
|
-
)
|
1134
|
-
else:
|
1135
|
-
return self._monitoring_regression_accuracy_table_query(
|
1136
|
-
model_name=model_name,
|
1137
|
-
model_version_name=model_version_name,
|
1138
|
-
source_table_name=source_table_name,
|
1139
|
-
refresh_interval=refresh_interval,
|
1140
|
-
aggregate_window=aggregate_window,
|
1141
|
-
warehouse_name=warehouse_name,
|
1142
|
-
timestamp_column=timestamp_column,
|
1143
|
-
prediction_columns=prediction_columns,
|
1144
|
-
label_columns=label_columns,
|
1145
|
-
)
|
1146
|
-
|
1147
|
-
def _monitoring_regression_accuracy_table_query(
|
1148
|
-
self,
|
1149
|
-
*,
|
1150
|
-
model_name: sql_identifier.SqlIdentifier,
|
1151
|
-
model_version_name: sql_identifier.SqlIdentifier,
|
1152
|
-
source_table_name: sql_identifier.SqlIdentifier,
|
1153
|
-
refresh_interval: ModelMonitorRefreshInterval,
|
1154
|
-
aggregate_window: ModelMonitorAggregationWindow,
|
1155
|
-
warehouse_name: sql_identifier.SqlIdentifier,
|
1156
|
-
timestamp_column: sql_identifier.SqlIdentifier,
|
1157
|
-
prediction_columns: List[sql_identifier.SqlIdentifier],
|
1158
|
-
label_columns: List[sql_identifier.SqlIdentifier],
|
1159
|
-
) -> str:
|
1160
|
-
"""
|
1161
|
-
Generates a dynamic table query for Monitoring - regression model accuracy.
|
1162
|
-
|
1163
|
-
Args:
|
1164
|
-
model_name: Model name to monitor.
|
1165
|
-
model_version_name: Model version name to monitor.
|
1166
|
-
source_table_name: Name of source data table to monitor.
|
1167
|
-
refresh_interval: Refresh interval in minutes.
|
1168
|
-
aggregate_window: Aggregate window minutes.
|
1169
|
-
warehouse_name: Warehouse name to use for dynamic table.
|
1170
|
-
timestamp_column: Timestamp column name.
|
1171
|
-
prediction_columns: List of output columns.
|
1172
|
-
label_columns: List of ground truth columns.
|
1173
|
-
|
1174
|
-
Returns:
|
1175
|
-
Dynamic table query.
|
1176
|
-
|
1177
|
-
Raises:
|
1178
|
-
ValueError: If output columns are not same as ground truth columns.
|
1179
|
-
|
1180
|
-
"""
|
1181
|
-
|
1182
|
-
if len(prediction_columns) != len(label_columns):
|
1183
|
-
raise ValueError(f"Mismatch in output & ground truth columns: {prediction_columns} != {label_columns}")
|
1184
|
-
|
1185
|
-
monitoring_dt_name = self.get_accuracy_monitoring_table_fully_qualified_name(model_name, model_version_name)
|
1186
|
-
|
1187
|
-
output_cols_query_list = []
|
1188
|
-
|
1189
|
-
output_cols_query_list.append(
|
1190
|
-
f"""
|
1191
|
-
OBJECT_CONSTRUCT(
|
1192
|
-
'sum_difference_label_pred', sum({prediction_columns[0]} - {label_columns[0]}),
|
1193
|
-
'sum_log_difference_square_label_pred',
|
1194
|
-
sum(
|
1195
|
-
case
|
1196
|
-
when {prediction_columns[0]} > -1 and {label_columns[0]} > -1
|
1197
|
-
then pow(ln({prediction_columns[0]} + 1) - ln({label_columns[0]} + 1),2)
|
1198
|
-
else null
|
1199
|
-
END
|
1200
|
-
),
|
1201
|
-
'sum_difference_squares_label_pred',
|
1202
|
-
sum(
|
1203
|
-
pow(
|
1204
|
-
{prediction_columns[0]} - {label_columns[0]},
|
1205
|
-
2
|
1206
|
-
)
|
1207
|
-
),
|
1208
|
-
'sum_absolute_regression_labels', sum(abs({label_columns[0]})),
|
1209
|
-
'sum_absolute_percentage_error',
|
1210
|
-
sum(
|
1211
|
-
abs(
|
1212
|
-
div0null(
|
1213
|
-
({prediction_columns[0]} - {label_columns[0]}),
|
1214
|
-
{label_columns[0]}
|
1215
|
-
)
|
1216
|
-
)
|
1217
|
-
),
|
1218
|
-
'sum_absolute_difference_label_pred',
|
1219
|
-
sum(
|
1220
|
-
abs({prediction_columns[0]} - {label_columns[0]})
|
1221
|
-
),
|
1222
|
-
'sum_prediction', sum({prediction_columns[0]}),
|
1223
|
-
'sum_label', sum({label_columns[0]}),
|
1224
|
-
'count', count(*)
|
1225
|
-
) AS AGGREGATE_METRICS,
|
1226
|
-
APPROX_PERCENTILE_ACCUMULATE({prediction_columns[0]}) prediction_sketch,
|
1227
|
-
APPROX_PERCENTILE_ACCUMULATE({label_columns[0]}) label_sketch"""
|
1228
|
-
)
|
1229
|
-
output_cols_query = ", ".join(output_cols_query_list)
|
1230
|
-
|
1231
|
-
return f"""
|
1232
|
-
CREATE DYNAMIC TABLE IF NOT EXISTS {monitoring_dt_name}
|
1233
|
-
TARGET_LAG = '{refresh_interval.minutes} minutes'
|
1234
|
-
WAREHOUSE = {warehouse_name}
|
1235
|
-
REFRESH_MODE = AUTO
|
1236
|
-
INITIALIZE = ON_CREATE
|
1237
|
-
AS
|
1238
|
-
SELECT
|
1239
|
-
TIME_SLICE({timestamp_column}, {aggregate_window.minutes}, 'MINUTE') timestamp,
|
1240
|
-
'class_regression' label_class,{output_cols_query}
|
1241
|
-
FROM
|
1242
|
-
{source_table_name}
|
1243
|
-
GROUP BY
|
1244
|
-
1
|
1245
|
-
"""
|
1246
|
-
|
1247
|
-
def _monitoring_classification_accuracy_table_query(
|
1248
|
-
self,
|
1249
|
-
*,
|
1250
|
-
model_name: sql_identifier.SqlIdentifier,
|
1251
|
-
model_version_name: sql_identifier.SqlIdentifier,
|
1252
|
-
source_table_name: sql_identifier.SqlIdentifier,
|
1253
|
-
refresh_interval: ModelMonitorRefreshInterval,
|
1254
|
-
aggregate_window: ModelMonitorAggregationWindow,
|
1255
|
-
warehouse_name: sql_identifier.SqlIdentifier,
|
1256
|
-
timestamp_column: sql_identifier.SqlIdentifier,
|
1257
|
-
prediction_columns: List[sql_identifier.SqlIdentifier],
|
1258
|
-
label_columns: List[sql_identifier.SqlIdentifier],
|
1259
|
-
score_type: output_score_type.OutputScoreType,
|
1260
|
-
) -> str:
|
1261
|
-
monitoring_dt_name = self.get_accuracy_monitoring_table_fully_qualified_name(model_name, model_version_name)
|
1262
|
-
|
1263
|
-
# Initialize the select clause components
|
1264
|
-
select_clauses = []
|
1265
|
-
|
1266
|
-
select_clauses.append(
|
1267
|
-
f"""
|
1268
|
-
{prediction_columns[0]},
|
1269
|
-
{label_columns[0]},
|
1270
|
-
CASE
|
1271
|
-
WHEN {label_columns[0]} = 1 THEN 'class_positive'
|
1272
|
-
ELSE 'class_negative'
|
1273
|
-
END AS label_class"""
|
1274
|
-
)
|
1275
|
-
|
1276
|
-
# Join all the select clauses into a single string
|
1277
|
-
select_clause = f"{timestamp_column} AS timestamp," + ",".join(select_clauses)
|
1278
|
-
|
1279
|
-
# Create the final CTE query
|
1280
|
-
cte_query = f"""
|
1281
|
-
WITH filtered_data AS (
|
1282
|
-
SELECT
|
1283
|
-
{select_clause}
|
1284
|
-
FROM
|
1285
|
-
{source_table_name}
|
1286
|
-
)"""
|
1287
|
-
|
1288
|
-
# Initialize the select clause components
|
1289
|
-
select_clauses = []
|
1290
|
-
|
1291
|
-
score_type_agg_clause = ""
|
1292
|
-
if score_type == output_score_type.OutputScoreType.PROBITS:
|
1293
|
-
score_type_agg_clause = f"""
|
1294
|
-
'sum_log_loss',
|
1295
|
-
CASE
|
1296
|
-
WHEN label_class = 'class_positive' THEN sum(-ln({prediction_columns[0]}))
|
1297
|
-
ELSE sum(-ln(1 - {prediction_columns[0]}))
|
1298
|
-
END,"""
|
1299
|
-
else:
|
1300
|
-
score_type_agg_clause = f"""
|
1301
|
-
'tp', count_if({label_columns[0]} = 1 AND {prediction_columns[0]} = 1),
|
1302
|
-
'tn', count_if({label_columns[0]} = 0 AND {prediction_columns[0]} = 0),
|
1303
|
-
'fp', count_if({label_columns[0]} = 0 AND {prediction_columns[0]} = 1),
|
1304
|
-
'fn', count_if({label_columns[0]} = 1 AND {prediction_columns[0]} = 0),"""
|
1305
|
-
|
1306
|
-
select_clauses.append(
|
1307
|
-
f"""
|
1308
|
-
label_class,
|
1309
|
-
OBJECT_CONSTRUCT(
|
1310
|
-
'sum_prediction', sum({prediction_columns[0]}),
|
1311
|
-
'sum_label', sum({label_columns[0]}),{score_type_agg_clause}
|
1312
|
-
'count', count(*)
|
1313
|
-
) AS AGGREGATE_METRICS,
|
1314
|
-
APPROX_PERCENTILE_ACCUMULATE({prediction_columns[0]}) prediction_sketch,
|
1315
|
-
APPROX_PERCENTILE_ACCUMULATE({label_columns[0]}) label_sketch"""
|
1316
|
-
)
|
1317
|
-
|
1318
|
-
# Join all the select clauses into a single string
|
1319
|
-
select_clause = ",\n".join(select_clauses)
|
1320
|
-
|
1321
|
-
return f"""
|
1322
|
-
CREATE DYNAMIC TABLE IF NOT EXISTS {monitoring_dt_name}
|
1323
|
-
TARGET_LAG = '{refresh_interval.minutes} minutes'
|
1324
|
-
WAREHOUSE = {warehouse_name}
|
1325
|
-
REFRESH_MODE = AUTO
|
1326
|
-
INITIALIZE = ON_CREATE
|
1327
|
-
AS{cte_query}
|
1328
|
-
select
|
1329
|
-
time_slice(timestamp, {aggregate_window.minutes}, 'MINUTE') timestamp,{select_clause}
|
1330
|
-
FROM
|
1331
|
-
filtered_data
|
1332
|
-
group by
|
1333
|
-
1,
|
1334
|
-
2
|
1335
|
-
"""
|