snowflake-ml-python 1.7.0__py3-none-any.whl → 1.7.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_complete.py +107 -64
- snowflake/cortex/_finetune.py +273 -0
- snowflake/cortex/_sse_client.py +91 -28
- snowflake/cortex/_util.py +30 -1
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/jwt_generator.py +141 -0
- snowflake/ml/data/__init__.py +5 -0
- snowflake/ml/model/_client/model/model_version_impl.py +26 -12
- snowflake/ml/model/_client/ops/model_ops.py +51 -30
- snowflake/ml/model/_client/ops/service_ops.py +25 -9
- snowflake/ml/model/_client/sql/model.py +0 -14
- snowflake/ml/model/_client/sql/service.py +25 -1
- snowflake/ml/model/_client/sql/stage.py +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
- snowflake/ml/model/_packager/model_env/model_env.py +12 -0
- snowflake/ml/model/_packager/model_handlers/_utils.py +1 -1
- snowflake/ml/model/_packager/model_handlers/catboost.py +1 -1
- snowflake/ml/model/_packager/model_handlers/custom.py +3 -1
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +2 -1
- snowflake/ml/model/_packager/model_handlers/sklearn.py +50 -1
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -1
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +23 -6
- snowflake/ml/model/_packager/model_handlers/torchscript.py +14 -14
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +2 -3
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +5 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -10
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -9
- snowflake/ml/model/_packager/model_task/model_task_utils.py +1 -1
- snowflake/ml/model/_signatures/core.py +63 -16
- snowflake/ml/model/_signatures/pandas_handler.py +71 -27
- snowflake/ml/model/_signatures/pytorch_handler.py +2 -2
- snowflake/ml/model/_signatures/snowpark_handler.py +2 -1
- snowflake/ml/model/_signatures/tensorflow_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +4 -1
- snowflake/ml/model/model_signature.py +38 -9
- snowflake/ml/model/type_hints.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +2 -4
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +2 -4
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +148 -1200
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +114 -238
- snowflake/ml/monitoring/entities/model_monitor_config.py +38 -12
- snowflake/ml/monitoring/model_monitor.py +12 -86
- snowflake/ml/registry/registry.py +28 -40
- snowflake/ml/utils/authentication.py +75 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/METADATA +116 -52
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/RECORD +51 -49
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/WHEEL +1 -1
- snowflake/ml/monitoring/entities/model_monitor_interval.py +0 -46
- snowflake/ml/monitoring/entities/output_score_type.py +0 -90
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/top_level.txt +0 -0
@@ -1,127 +1,23 @@
|
|
1
|
-
import
|
2
|
-
import string
|
3
|
-
import textwrap
|
4
|
-
import typing
|
5
|
-
from collections import Counter
|
6
|
-
from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, TypedDict
|
7
|
-
|
8
|
-
from importlib_resources import files
|
9
|
-
from typing_extensions import Required
|
1
|
+
from typing import Any, Dict, List, Mapping, Optional
|
10
2
|
|
11
3
|
from snowflake import snowpark
|
12
|
-
from snowflake.connector import errors
|
13
4
|
from snowflake.ml._internal.utils import (
|
14
5
|
db_utils,
|
15
|
-
formatting,
|
16
6
|
query_result_checker,
|
17
7
|
sql_identifier,
|
18
8
|
table_manager,
|
19
9
|
)
|
20
|
-
from snowflake.ml.model import type_hints
|
21
10
|
from snowflake.ml.model._client.sql import _base
|
22
|
-
from snowflake.
|
23
|
-
from snowflake.ml.monitoring.entities import model_monitor_interval, output_score_type
|
24
|
-
from snowflake.ml.monitoring.entities.model_monitor_interval import (
|
25
|
-
ModelMonitorAggregationWindow,
|
26
|
-
ModelMonitorRefreshInterval,
|
27
|
-
)
|
28
|
-
from snowflake.snowpark import DataFrame, exceptions, session, types
|
29
|
-
from snowflake.snowpark._internal import type_utils
|
30
|
-
|
31
|
-
SNOWML_MONITORING_METADATA_TABLE_NAME = "_SYSTEM_MONITORING_METADATA"
|
32
|
-
_SNOWML_MONITORING_TABLE_NAME_PREFIX = "_SNOWML_OBS_MONITORING_"
|
33
|
-
_SNOWML_MONITORING_ACCURACY_TABLE_NAME_PREFIX = "_SNOWML_OBS_ACCURACY_"
|
34
|
-
|
35
|
-
MONITOR_NAME_COL_NAME = "MONITOR_NAME"
|
36
|
-
SOURCE_TABLE_NAME_COL_NAME = "SOURCE_TABLE_NAME"
|
37
|
-
FQ_MODEL_NAME_COL_NAME = "FULLY_QUALIFIED_MODEL_NAME"
|
38
|
-
VERSION_NAME_COL_NAME = "MODEL_VERSION_NAME"
|
39
|
-
FUNCTION_NAME_COL_NAME = "FUNCTION_NAME"
|
40
|
-
TASK_COL_NAME = "TASK"
|
41
|
-
MONITORING_ENABLED_COL_NAME = "IS_ENABLED"
|
42
|
-
TIMESTAMP_COL_NAME_COL_NAME = "TIMESTAMP_COLUMN_NAME"
|
43
|
-
PREDICTION_COL_NAMES_COL_NAME = "PREDICTION_COLUMN_NAMES"
|
44
|
-
LABEL_COL_NAMES_COL_NAME = "LABEL_COLUMN_NAMES"
|
45
|
-
ID_COL_NAMES_COL_NAME = "ID_COLUMN_NAMES"
|
46
|
-
|
47
|
-
_DASHBOARD_UDTFS_COMMON_LIST = ["record_count"]
|
48
|
-
_DASHBOARD_UDTFS_REGRESSION_LIST = ["rmse"]
|
49
|
-
|
50
|
-
|
51
|
-
def _initialize_monitoring_metadata_tables(
|
52
|
-
session: session.Session,
|
53
|
-
database_name: sql_identifier.SqlIdentifier,
|
54
|
-
schema_name: sql_identifier.SqlIdentifier,
|
55
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
56
|
-
) -> None:
|
57
|
-
"""Create tables necessary for Model Monitoring in provided schema.
|
58
|
-
|
59
|
-
Args:
|
60
|
-
session: Active Snowpark session.
|
61
|
-
database_name: The database in which to setup resources for Model Monitoring.
|
62
|
-
schema_name: The schema in which to setup resources for Model Monitoring.
|
63
|
-
statement_params: Optional statement params for queries.
|
64
|
-
"""
|
65
|
-
table_manager.create_single_table(
|
66
|
-
session,
|
67
|
-
database_name,
|
68
|
-
schema_name,
|
69
|
-
SNOWML_MONITORING_METADATA_TABLE_NAME,
|
70
|
-
[
|
71
|
-
(MONITOR_NAME_COL_NAME, "VARCHAR"),
|
72
|
-
(SOURCE_TABLE_NAME_COL_NAME, "VARCHAR"),
|
73
|
-
(FQ_MODEL_NAME_COL_NAME, "VARCHAR"),
|
74
|
-
(VERSION_NAME_COL_NAME, "VARCHAR"),
|
75
|
-
(FUNCTION_NAME_COL_NAME, "VARCHAR"),
|
76
|
-
(TASK_COL_NAME, "VARCHAR"),
|
77
|
-
(MONITORING_ENABLED_COL_NAME, "BOOLEAN"),
|
78
|
-
(TIMESTAMP_COL_NAME_COL_NAME, "VARCHAR"),
|
79
|
-
(PREDICTION_COL_NAMES_COL_NAME, "ARRAY"),
|
80
|
-
(LABEL_COL_NAMES_COL_NAME, "ARRAY"),
|
81
|
-
(ID_COL_NAMES_COL_NAME, "ARRAY"),
|
82
|
-
],
|
83
|
-
statement_params=statement_params,
|
84
|
-
)
|
85
|
-
|
86
|
-
|
87
|
-
def _create_baseline_table_name(model_name: str, version_name: str) -> str:
|
88
|
-
return f"_SNOWML_OBS_BASELINE_{model_name}_{version_name}"
|
89
|
-
|
90
|
-
|
91
|
-
def _infer_numeric_categoric_feature_column_names(
|
92
|
-
*,
|
93
|
-
source_table_schema: Mapping[str, types.DataType],
|
94
|
-
timestamp_column: sql_identifier.SqlIdentifier,
|
95
|
-
id_columns: List[sql_identifier.SqlIdentifier],
|
96
|
-
prediction_columns: List[sql_identifier.SqlIdentifier],
|
97
|
-
label_columns: List[sql_identifier.SqlIdentifier],
|
98
|
-
) -> Tuple[List[sql_identifier.SqlIdentifier], List[sql_identifier.SqlIdentifier]]:
|
99
|
-
cols_to_remove = {timestamp_column, *id_columns, *prediction_columns, *label_columns}
|
100
|
-
cols_to_consider = [
|
101
|
-
(col_name, source_table_schema[col_name]) for col_name in source_table_schema if col_name not in cols_to_remove
|
102
|
-
]
|
103
|
-
numeric_cols = [
|
104
|
-
sql_identifier.SqlIdentifier(column[0])
|
105
|
-
for column in cols_to_consider
|
106
|
-
if isinstance(column[1], types._NumericType)
|
107
|
-
]
|
108
|
-
categorical_cols = [
|
109
|
-
sql_identifier.SqlIdentifier(column[0])
|
110
|
-
for column in cols_to_consider
|
111
|
-
if isinstance(column[1], types.StringType) or isinstance(column[1], types.BooleanType)
|
112
|
-
]
|
113
|
-
return (numeric_cols, categorical_cols)
|
11
|
+
from snowflake.snowpark import session, types
|
114
12
|
|
13
|
+
MODEL_JSON_COL_NAME = "model"
|
14
|
+
MODEL_JSON_MODEL_NAME_FIELD = "model_name"
|
15
|
+
MODEL_JSON_VERSION_NAME_FIELD = "version_name"
|
115
16
|
|
116
|
-
class _ModelMonitorParams(TypedDict):
|
117
|
-
"""Class to transfer model monitor parameters to the ModelMonitor class."""
|
118
17
|
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
function_name: Required[str]
|
123
|
-
prediction_columns: Required[List[sql_identifier.SqlIdentifier]]
|
124
|
-
label_columns: Required[List[sql_identifier.SqlIdentifier]]
|
18
|
+
def _build_sql_list_from_columns(columns: List[sql_identifier.SqlIdentifier]) -> str:
|
19
|
+
sql_list = ", ".join([f"'{column}'" for column in columns])
|
20
|
+
return f"({sql_list})"
|
125
21
|
|
126
22
|
|
127
23
|
class ModelMonitorSQLClient:
|
@@ -143,101 +39,116 @@ class ModelMonitorSQLClient:
|
|
143
39
|
self._database_name = database_name
|
144
40
|
self._schema_name = schema_name
|
145
41
|
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
database_name
|
150
|
-
schema_name: sql_identifier.SqlIdentifier,
|
151
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
152
|
-
) -> None:
|
153
|
-
"""Initialize tables for tracking metadata associated with model monitoring.
|
154
|
-
|
155
|
-
Args:
|
156
|
-
session: The Snowpark Session to connect with Snowflake.
|
157
|
-
database_name: The database in which to setup resources for Model Monitoring.
|
158
|
-
schema_name: The schema in which to setup resources for Model Monitoring.
|
159
|
-
statement_params: Optional set of statement_params to include with query.
|
160
|
-
"""
|
161
|
-
# Create metadata management tables
|
162
|
-
_initialize_monitoring_metadata_tables(session, database_name, schema_name, statement_params)
|
163
|
-
|
164
|
-
def _validate_is_initialized(self) -> bool:
|
165
|
-
"""Validates whether monitoring metadata has been initialized.
|
166
|
-
|
167
|
-
Returns:
|
168
|
-
boolean to indicate whether tables have been initialized.
|
169
|
-
"""
|
170
|
-
try:
|
171
|
-
return table_manager.validate_table_exist(
|
172
|
-
self._sql_client._session,
|
173
|
-
SNOWML_MONITORING_METADATA_TABLE_NAME,
|
174
|
-
f"{self._database_name}.{self._schema_name}",
|
175
|
-
)
|
176
|
-
except exceptions.SnowparkSQLException:
|
177
|
-
return False
|
42
|
+
def _infer_qualified_schema(
|
43
|
+
self, database_name: Optional[sql_identifier.SqlIdentifier], schema_name: Optional[sql_identifier.SqlIdentifier]
|
44
|
+
) -> str:
|
45
|
+
return f"{database_name or self._database_name}.{schema_name or self._schema_name}"
|
178
46
|
|
179
|
-
def
|
47
|
+
def create_model_monitor(
|
180
48
|
self,
|
49
|
+
*,
|
50
|
+
monitor_database: Optional[sql_identifier.SqlIdentifier],
|
51
|
+
monitor_schema: Optional[sql_identifier.SqlIdentifier],
|
52
|
+
monitor_name: sql_identifier.SqlIdentifier,
|
53
|
+
source_database: Optional[sql_identifier.SqlIdentifier],
|
54
|
+
source_schema: Optional[sql_identifier.SqlIdentifier],
|
55
|
+
source: sql_identifier.SqlIdentifier,
|
56
|
+
model_database: Optional[sql_identifier.SqlIdentifier],
|
57
|
+
model_schema: Optional[sql_identifier.SqlIdentifier],
|
58
|
+
model_name: sql_identifier.SqlIdentifier,
|
59
|
+
version_name: sql_identifier.SqlIdentifier,
|
60
|
+
function_name: str,
|
61
|
+
warehouse_name: sql_identifier.SqlIdentifier,
|
181
62
|
timestamp_column: sql_identifier.SqlIdentifier,
|
182
63
|
id_columns: List[sql_identifier.SqlIdentifier],
|
183
|
-
|
184
|
-
|
64
|
+
prediction_score_columns: List[sql_identifier.SqlIdentifier],
|
65
|
+
prediction_class_columns: List[sql_identifier.SqlIdentifier],
|
66
|
+
actual_score_columns: List[sql_identifier.SqlIdentifier],
|
67
|
+
actual_class_columns: List[sql_identifier.SqlIdentifier],
|
68
|
+
refresh_interval: str,
|
69
|
+
aggregation_window: str,
|
70
|
+
baseline_database: Optional[sql_identifier.SqlIdentifier] = None,
|
71
|
+
baseline_schema: Optional[sql_identifier.SqlIdentifier] = None,
|
72
|
+
baseline: Optional[sql_identifier.SqlIdentifier] = None,
|
73
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
185
74
|
) -> None:
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
75
|
+
baseline_sql = ""
|
76
|
+
if baseline:
|
77
|
+
baseline_sql = f"BASELINE='{self._infer_qualified_schema(baseline_database, baseline_schema)}.{baseline}'"
|
78
|
+
query_result_checker.SqlResultValidator(
|
79
|
+
self._sql_client._session,
|
80
|
+
f"""
|
81
|
+
CREATE MODEL MONITOR {self._infer_qualified_schema(monitor_database, monitor_schema)}.{monitor_name}
|
82
|
+
WITH
|
83
|
+
MODEL='{self._infer_qualified_schema(model_database, model_schema)}.{model_name}'
|
84
|
+
VERSION='{version_name}'
|
85
|
+
FUNCTION='{function_name}'
|
86
|
+
WAREHOUSE='{warehouse_name}'
|
87
|
+
SOURCE='{self._infer_qualified_schema(source_database, source_schema)}.{source}'
|
88
|
+
ID_COLUMNS={_build_sql_list_from_columns(id_columns)}
|
89
|
+
PREDICTION_SCORE_COLUMNS={_build_sql_list_from_columns(prediction_score_columns)}
|
90
|
+
PREDICTION_CLASS_COLUMNS={_build_sql_list_from_columns(prediction_class_columns)}
|
91
|
+
ACTUAL_SCORE_COLUMNS={_build_sql_list_from_columns(actual_score_columns)}
|
92
|
+
ACTUAL_CLASS_COLUMNS={_build_sql_list_from_columns(actual_class_columns)}
|
93
|
+
TIMESTAMP_COLUMN='{timestamp_column}'
|
94
|
+
REFRESH_INTERVAL='{refresh_interval}'
|
95
|
+
AGGREGATION_WINDOW='{aggregation_window}'
|
96
|
+
{baseline_sql}""",
|
97
|
+
statement_params=statement_params,
|
98
|
+
).has_column("status").has_dimensions(1, 1).validate()
|
191
99
|
|
192
|
-
def
|
100
|
+
def drop_model_monitor(
|
193
101
|
self,
|
102
|
+
*,
|
103
|
+
database_name: Optional[sql_identifier.SqlIdentifier] = None,
|
104
|
+
schema_name: Optional[sql_identifier.SqlIdentifier] = None,
|
194
105
|
monitor_name: sql_identifier.SqlIdentifier,
|
195
106
|
statement_params: Optional[Dict[str, Any]] = None,
|
196
|
-
) ->
|
197
|
-
|
107
|
+
) -> None:
|
108
|
+
search_database_name = database_name or self._database_name
|
109
|
+
search_schema_name = schema_name or self._schema_name
|
110
|
+
query_result_checker.SqlResultValidator(
|
111
|
+
self._sql_client._session,
|
112
|
+
f"DROP MODEL MONITOR {search_database_name}.{search_schema_name}.{monitor_name}",
|
113
|
+
statement_params=statement_params,
|
114
|
+
).validate()
|
115
|
+
|
116
|
+
def show_model_monitors(
|
117
|
+
self,
|
118
|
+
*,
|
119
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
120
|
+
) -> List[snowpark.Row]:
|
121
|
+
fully_qualified_schema_name = ".".join([self._database_name.identifier(), self._schema_name.identifier()])
|
122
|
+
return (
|
198
123
|
query_result_checker.SqlResultValidator(
|
199
124
|
self._sql_client._session,
|
200
|
-
f"
|
201
|
-
FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}
|
202
|
-
WHERE {MONITOR_NAME_COL_NAME} = '{monitor_name}'""",
|
125
|
+
f"SHOW MODEL MONITORS IN {fully_qualified_schema_name}",
|
203
126
|
statement_params=statement_params,
|
204
127
|
)
|
205
|
-
.has_column(
|
206
|
-
.has_column(VERSION_NAME_COL_NAME, allow_empty=True)
|
128
|
+
.has_column("name", allow_empty=True)
|
207
129
|
.validate()
|
208
130
|
)
|
209
|
-
return len(res) >= 1
|
210
131
|
|
211
|
-
def
|
132
|
+
def validate_existence_by_name(
|
212
133
|
self,
|
213
|
-
|
214
|
-
|
134
|
+
*,
|
135
|
+
database_name: Optional[sql_identifier.SqlIdentifier] = None,
|
136
|
+
schema_name: Optional[sql_identifier.SqlIdentifier] = None,
|
137
|
+
monitor_name: sql_identifier.SqlIdentifier,
|
215
138
|
statement_params: Optional[Dict[str, Any]] = None,
|
216
139
|
) -> bool:
|
217
|
-
|
218
|
-
|
219
|
-
Args:
|
220
|
-
fully_qualified_model_name: Fully qualified name of model.
|
221
|
-
version_name: Name of model version.
|
222
|
-
statement_params: Optional set of statement_params to include with query.
|
223
|
-
|
224
|
-
Returns:
|
225
|
-
Boolean indicating whether monitor exists on model version.
|
226
|
-
"""
|
140
|
+
search_database_name = database_name or self._database_name
|
141
|
+
search_schema_name = schema_name or self._schema_name
|
227
142
|
res = (
|
228
143
|
query_result_checker.SqlResultValidator(
|
229
144
|
self._sql_client._session,
|
230
|
-
f"
|
231
|
-
FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}
|
232
|
-
WHERE {FQ_MODEL_NAME_COL_NAME} = '{fully_qualified_model_name}'
|
233
|
-
AND {VERSION_NAME_COL_NAME} = '{version_name}'""",
|
145
|
+
f"SHOW MODEL MONITORS LIKE '{monitor_name.resolved()}' IN {search_database_name}.{search_schema_name}",
|
234
146
|
statement_params=statement_params,
|
235
147
|
)
|
236
|
-
.has_column(
|
237
|
-
.has_column(VERSION_NAME_COL_NAME, allow_empty=True)
|
148
|
+
.has_column("name", allow_empty=True)
|
238
149
|
.validate()
|
239
150
|
)
|
240
|
-
return len(res)
|
151
|
+
return len(res) == 1
|
241
152
|
|
242
153
|
def validate_monitor_warehouse(
|
243
154
|
self,
|
@@ -261,1075 +172,112 @@ class ModelMonitorSQLClient:
|
|
261
172
|
):
|
262
173
|
raise ValueError(f"Warehouse '{warehouse_name}' not found.")
|
263
174
|
|
264
|
-
def
|
265
|
-
self,
|
266
|
-
monitor_name: sql_identifier.SqlIdentifier,
|
267
|
-
model_name: sql_identifier.SqlIdentifier,
|
268
|
-
model_version_name: sql_identifier.SqlIdentifier,
|
269
|
-
task: type_hints.Task,
|
270
|
-
score_type: output_score_type.OutputScoreType,
|
271
|
-
output_columns: List[sql_identifier.SqlIdentifier],
|
272
|
-
ground_truth_columns: List[sql_identifier.SqlIdentifier],
|
273
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
274
|
-
) -> None:
|
275
|
-
udtf_name_query_map = self._create_dashboard_udtf_queries(
|
276
|
-
monitor_name,
|
277
|
-
model_name,
|
278
|
-
model_version_name,
|
279
|
-
task,
|
280
|
-
score_type,
|
281
|
-
output_columns,
|
282
|
-
ground_truth_columns,
|
283
|
-
)
|
284
|
-
for udtf_query in udtf_name_query_map.values():
|
285
|
-
query_result_checker.SqlResultValidator(
|
286
|
-
self._sql_client._session,
|
287
|
-
f"""{udtf_query}""",
|
288
|
-
statement_params=statement_params,
|
289
|
-
).validate()
|
290
|
-
|
291
|
-
def get_monitoring_table_fully_qualified_name(
|
292
|
-
self,
|
293
|
-
model_name: sql_identifier.SqlIdentifier,
|
294
|
-
model_version_name: sql_identifier.SqlIdentifier,
|
295
|
-
) -> str:
|
296
|
-
table_name = f"{_SNOWML_MONITORING_TABLE_NAME_PREFIX}_{model_name}_{model_version_name}"
|
297
|
-
return table_manager.get_fully_qualified_table_name(self._database_name, self._schema_name, table_name)
|
298
|
-
|
299
|
-
def get_accuracy_monitoring_table_fully_qualified_name(
|
300
|
-
self,
|
301
|
-
model_name: sql_identifier.SqlIdentifier,
|
302
|
-
model_version_name: sql_identifier.SqlIdentifier,
|
303
|
-
) -> str:
|
304
|
-
table_name = f"{_SNOWML_MONITORING_ACCURACY_TABLE_NAME_PREFIX}_{model_name}_{model_version_name}"
|
305
|
-
return table_manager.get_fully_qualified_table_name(self._database_name, self._schema_name, table_name)
|
306
|
-
|
307
|
-
def _create_dashboard_udtf_queries(
|
308
|
-
self,
|
309
|
-
monitor_name: sql_identifier.SqlIdentifier,
|
310
|
-
model_name: sql_identifier.SqlIdentifier,
|
311
|
-
model_version_name: sql_identifier.SqlIdentifier,
|
312
|
-
task: type_hints.Task,
|
313
|
-
score_type: output_score_type.OutputScoreType,
|
314
|
-
output_columns: List[sql_identifier.SqlIdentifier],
|
315
|
-
ground_truth_columns: List[sql_identifier.SqlIdentifier],
|
316
|
-
) -> Mapping[str, str]:
|
317
|
-
query_files = files("snowflake.ml.monitoring._client")
|
318
|
-
# TODO(apgupta): Expand list of queries based on model objective and score type.
|
319
|
-
queries_list = []
|
320
|
-
queries_list.extend(_DASHBOARD_UDTFS_COMMON_LIST)
|
321
|
-
if task == type_hints.Task.TABULAR_REGRESSION:
|
322
|
-
queries_list.extend(_DASHBOARD_UDTFS_REGRESSION_LIST)
|
323
|
-
var_map = {
|
324
|
-
"MODEL_MONITOR_NAME": monitor_name,
|
325
|
-
"MONITORING_TABLE": self.get_monitoring_table_fully_qualified_name(model_name, model_version_name),
|
326
|
-
"MONITORING_PRED_LABEL_JOINED_TABLE": self.get_accuracy_monitoring_table_fully_qualified_name(
|
327
|
-
model_name, model_version_name
|
328
|
-
),
|
329
|
-
"OUTPUT_COLUMN_NAME": output_columns[0],
|
330
|
-
"GROUND_TRUTH_COLUMN_NAME": ground_truth_columns[0],
|
331
|
-
}
|
332
|
-
|
333
|
-
udf_name_query_map = {}
|
334
|
-
for q in queries_list:
|
335
|
-
q_template = query_files.joinpath(f"queries/{q}.ssql").read_text()
|
336
|
-
q_actual = string.Template(q_template).substitute(var_map)
|
337
|
-
udf_name_query_map[q] = q_actual
|
338
|
-
return udf_name_query_map
|
339
|
-
|
340
|
-
def _validate_columns_exist_in_source_table(
|
175
|
+
def _validate_columns_exist_in_source(
|
341
176
|
self,
|
342
177
|
*,
|
343
|
-
|
344
|
-
source_table_name: sql_identifier.SqlIdentifier,
|
178
|
+
source_column_schema: Mapping[str, types.DataType],
|
345
179
|
timestamp_column: sql_identifier.SqlIdentifier,
|
346
|
-
|
347
|
-
|
180
|
+
prediction_score_columns: List[sql_identifier.SqlIdentifier],
|
181
|
+
prediction_class_columns: List[sql_identifier.SqlIdentifier],
|
182
|
+
actual_score_columns: List[sql_identifier.SqlIdentifier],
|
183
|
+
actual_class_columns: List[sql_identifier.SqlIdentifier],
|
348
184
|
id_columns: List[sql_identifier.SqlIdentifier],
|
349
185
|
) -> None:
|
350
186
|
"""Ensures all columns exist in the source table.
|
351
187
|
|
352
188
|
Args:
|
353
|
-
|
354
|
-
source_table_name: Name of the table with model data to monitor.
|
189
|
+
source_column_schema: Dictionary of column names and types in the source.
|
355
190
|
timestamp_column: Name of the timestamp column.
|
356
|
-
|
357
|
-
|
191
|
+
prediction_score_columns: List of prediction score column names.
|
192
|
+
prediction_class_columns: List of prediction class names.
|
193
|
+
actual_score_columns: List of actual score column names.
|
194
|
+
actual_class_columns: List of actual class column names.
|
358
195
|
id_columns: List of id column names.
|
359
196
|
|
360
197
|
Raises:
|
361
|
-
ValueError: If any of the columns do not exist in the source
|
198
|
+
ValueError: If any of the columns do not exist in the source.
|
362
199
|
"""
|
363
200
|
|
364
|
-
if timestamp_column not in
|
365
|
-
raise ValueError(f"Timestamp column {timestamp_column} does not exist in
|
201
|
+
if timestamp_column not in source_column_schema:
|
202
|
+
raise ValueError(f"Timestamp column {timestamp_column} does not exist in source.")
|
366
203
|
|
367
|
-
if not all([column_name in
|
368
|
-
raise ValueError(f"Prediction column(s): {
|
369
|
-
if not all([column_name in
|
370
|
-
raise ValueError(f"
|
371
|
-
if not all([column_name in
|
372
|
-
raise ValueError(f"
|
204
|
+
if not all([column_name in source_column_schema for column_name in prediction_score_columns]):
|
205
|
+
raise ValueError(f"Prediction Score column(s): {prediction_score_columns} do not exist in source.")
|
206
|
+
if not all([column_name in source_column_schema for column_name in prediction_class_columns]):
|
207
|
+
raise ValueError(f"Prediction Class column(s): {prediction_class_columns} do not exist in source.")
|
208
|
+
if not all([column_name in source_column_schema for column_name in actual_score_columns]):
|
209
|
+
raise ValueError(f"Actual Score column(s): {actual_score_columns} do not exist in source.")
|
373
210
|
|
374
|
-
|
375
|
-
|
376
|
-
) -> None:
|
377
|
-
"""Ensures columns have the same type.
|
211
|
+
if not all([column_name in source_column_schema for column_name in actual_class_columns]):
|
212
|
+
raise ValueError(f"Actual Class column(s): {actual_class_columns} do not exist in source.")
|
378
213
|
|
379
|
-
|
380
|
-
|
381
|
-
timestamp_column: Name of the timestamp column.
|
382
|
-
|
383
|
-
Raises:
|
384
|
-
ValueError: If the timestamp column is not of type TimestampType.
|
385
|
-
"""
|
386
|
-
if not isinstance(table_schema[timestamp_column], types.TimestampType):
|
387
|
-
raise ValueError(
|
388
|
-
f"Timestamp column: {timestamp_column} must be TimestampType. "
|
389
|
-
f"Found: {table_schema[timestamp_column]}"
|
390
|
-
)
|
391
|
-
|
392
|
-
def _validate_id_columns_types(
|
393
|
-
self, table_schema: Mapping[str, types.DataType], id_columns: List[sql_identifier.SqlIdentifier]
|
394
|
-
) -> None:
|
395
|
-
"""Ensures id columns have the correct type.
|
396
|
-
|
397
|
-
Args:
|
398
|
-
table_schema: Dictionary of column names and types in the source table.
|
399
|
-
id_columns: List of id column names.
|
214
|
+
if not all([column_name in source_column_schema for column_name in id_columns]):
|
215
|
+
raise ValueError(f"ID column(s): {id_columns} do not exist in source.")
|
400
216
|
|
401
|
-
|
402
|
-
ValueError: If the id column is not of type StringType.
|
403
|
-
"""
|
404
|
-
id_column_types = list({table_schema[column_name] for column_name in id_columns})
|
405
|
-
all_id_columns_string = all([isinstance(column_type, types.StringType) for column_type in id_column_types])
|
406
|
-
if not all_id_columns_string:
|
407
|
-
raise ValueError(f"Id columns must all be StringType. Found: {id_column_types}")
|
408
|
-
|
409
|
-
def _validate_prediction_columns_types(
|
410
|
-
self, table_schema: Mapping[str, types.DataType], prediction_columns: List[sql_identifier.SqlIdentifier]
|
411
|
-
) -> None:
|
412
|
-
"""Ensures prediction columns have the same type.
|
413
|
-
|
414
|
-
Args:
|
415
|
-
table_schema: Dictionary of column names and types in the source table.
|
416
|
-
prediction_columns: List of prediction column names.
|
417
|
-
|
418
|
-
Raises:
|
419
|
-
ValueError: If the prediction columns do not share the same type.
|
420
|
-
"""
|
421
|
-
|
422
|
-
prediction_column_types = {table_schema[column_name] for column_name in prediction_columns}
|
423
|
-
if len(prediction_column_types) > 1:
|
424
|
-
raise ValueError(f"Prediction column types must be the same. Found: {prediction_column_types}")
|
425
|
-
|
426
|
-
def _validate_label_columns_types(
|
427
|
-
self,
|
428
|
-
table_schema: Mapping[str, types.DataType],
|
429
|
-
label_columns: List[sql_identifier.SqlIdentifier],
|
430
|
-
) -> None:
|
431
|
-
"""Ensures label columns have the same type, and the correct type for the score type.
|
432
|
-
|
433
|
-
Args:
|
434
|
-
table_schema: Dictionary of column names and types in the source table.
|
435
|
-
label_columns: List of label column names.
|
436
|
-
|
437
|
-
Raises:
|
438
|
-
ValueError: If the label columns do not share the same type.
|
439
|
-
"""
|
440
|
-
label_column_types = {table_schema[column_name] for column_name in label_columns}
|
441
|
-
if len(label_column_types) > 1:
|
442
|
-
raise ValueError(f"Label column types must be the same. Found: {label_column_types}")
|
443
|
-
|
444
|
-
def _validate_column_types(
|
445
|
-
self,
|
446
|
-
*,
|
447
|
-
table_schema: Mapping[str, types.DataType],
|
448
|
-
timestamp_column: sql_identifier.SqlIdentifier,
|
449
|
-
id_columns: List[sql_identifier.SqlIdentifier],
|
450
|
-
prediction_columns: List[sql_identifier.SqlIdentifier],
|
451
|
-
label_columns: List[sql_identifier.SqlIdentifier],
|
452
|
-
) -> None:
|
453
|
-
"""Ensures columns have the expected type.
|
454
|
-
|
455
|
-
Args:
|
456
|
-
table_schema: Dictionary of column names and types in the source table.
|
457
|
-
timestamp_column: Name of the timestamp column.
|
458
|
-
id_columns: List of id column names.
|
459
|
-
prediction_columns: List of prediction column names.
|
460
|
-
label_columns: List of label column names.
|
461
|
-
"""
|
462
|
-
self._validate_timestamp_column_type(table_schema, timestamp_column)
|
463
|
-
self._validate_id_columns_types(table_schema, id_columns)
|
464
|
-
self._validate_prediction_columns_types(table_schema, prediction_columns)
|
465
|
-
self._validate_label_columns_types(table_schema, label_columns)
|
466
|
-
# TODO(SNOW-1646693): Validate label makes sense with model task
|
467
|
-
|
468
|
-
def _validate_source_table_features_shape(
|
469
|
-
self,
|
470
|
-
table_schema: Mapping[str, types.DataType],
|
471
|
-
special_columns: Set[sql_identifier.SqlIdentifier],
|
472
|
-
model_function: model_manifest_schema.ModelFunctionInfo,
|
473
|
-
) -> None:
|
474
|
-
table_schema_without_special_columns = {
|
475
|
-
k: v for k, v in table_schema.items() if sql_identifier.SqlIdentifier(k) not in special_columns
|
476
|
-
}
|
477
|
-
schema_column_types_to_count: typing.Counter[types.DataType] = Counter()
|
478
|
-
for column_type in table_schema_without_special_columns.values():
|
479
|
-
schema_column_types_to_count[column_type] += 1
|
480
|
-
|
481
|
-
inputs = model_function["signature"].inputs
|
482
|
-
function_input_types = [input.as_snowpark_type() for input in inputs]
|
483
|
-
function_input_types_to_count: typing.Counter[types.DataType] = Counter()
|
484
|
-
for function_input_type in function_input_types:
|
485
|
-
function_input_types_to_count[function_input_type] += 1
|
486
|
-
|
487
|
-
if function_input_types_to_count != schema_column_types_to_count:
|
488
|
-
raise ValueError(
|
489
|
-
"Model function input types do not match the source table input columns types. "
|
490
|
-
f"Model function expected: {inputs} but got {table_schema_without_special_columns}"
|
491
|
-
)
|
492
|
-
|
493
|
-
def get_model_monitor_by_name(
|
494
|
-
self,
|
495
|
-
monitor_name: sql_identifier.SqlIdentifier,
|
496
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
497
|
-
) -> _ModelMonitorParams:
|
498
|
-
"""Fetch metadata for a Model Monitor by name.
|
499
|
-
|
500
|
-
Args:
|
501
|
-
monitor_name: Name of ModelMonitor to fetch.
|
502
|
-
statement_params: Optional set of statement_params to include with query.
|
503
|
-
|
504
|
-
Returns:
|
505
|
-
_ModelMonitorParams dict with Name of monitor, fully qualified model name,
|
506
|
-
model version name, model function name, prediction_col, label_col.
|
507
|
-
|
508
|
-
Raises:
|
509
|
-
ValueError: If multiple ModelMonitors exist with the same name.
|
510
|
-
"""
|
511
|
-
try:
|
512
|
-
res = (
|
513
|
-
query_result_checker.SqlResultValidator(
|
514
|
-
self._sql_client._session,
|
515
|
-
f"""SELECT {FQ_MODEL_NAME_COL_NAME}, {VERSION_NAME_COL_NAME}, {FUNCTION_NAME_COL_NAME},
|
516
|
-
{PREDICTION_COL_NAMES_COL_NAME}, {LABEL_COL_NAMES_COL_NAME}
|
517
|
-
FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}
|
518
|
-
WHERE {MONITOR_NAME_COL_NAME} = '{monitor_name}'""",
|
519
|
-
statement_params=statement_params,
|
520
|
-
)
|
521
|
-
.has_column(FQ_MODEL_NAME_COL_NAME)
|
522
|
-
.has_column(VERSION_NAME_COL_NAME)
|
523
|
-
.has_column(FUNCTION_NAME_COL_NAME)
|
524
|
-
.has_column(PREDICTION_COL_NAMES_COL_NAME)
|
525
|
-
.has_column(LABEL_COL_NAMES_COL_NAME)
|
526
|
-
.validate()
|
527
|
-
)
|
528
|
-
except errors.DataError:
|
529
|
-
raise ValueError(f"Failed to find any monitor with name '{monitor_name}'")
|
530
|
-
|
531
|
-
if len(res) > 1:
|
532
|
-
raise ValueError(f"Invalid state. Multiple Monitors exist with name '{monitor_name}'")
|
533
|
-
|
534
|
-
return _ModelMonitorParams(
|
535
|
-
monitor_name=str(monitor_name),
|
536
|
-
fully_qualified_model_name=res[0][FQ_MODEL_NAME_COL_NAME],
|
537
|
-
version_name=res[0][VERSION_NAME_COL_NAME],
|
538
|
-
function_name=res[0][FUNCTION_NAME_COL_NAME],
|
539
|
-
prediction_columns=[
|
540
|
-
sql_identifier.SqlIdentifier(prediction_column)
|
541
|
-
for prediction_column in json.loads(res[0][PREDICTION_COL_NAMES_COL_NAME])
|
542
|
-
],
|
543
|
-
label_columns=[
|
544
|
-
sql_identifier.SqlIdentifier(label_column)
|
545
|
-
for label_column in json.loads(res[0][LABEL_COL_NAMES_COL_NAME])
|
546
|
-
],
|
547
|
-
)
|
548
|
-
|
549
|
-
def get_model_monitor_by_model_version(
|
217
|
+
def validate_source(
|
550
218
|
self,
|
551
219
|
*,
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
version_name: sql_identifier.SqlIdentifier,
|
556
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
557
|
-
) -> _ModelMonitorParams:
|
558
|
-
"""Fetch metadata for a Model Monitor by model version.
|
559
|
-
|
560
|
-
Args:
|
561
|
-
model_db: Database of model.
|
562
|
-
model_schema: Schema of model.
|
563
|
-
model_name: Model name.
|
564
|
-
version_name: Model version name
|
565
|
-
statement_params: Optional set of statement_params to include with queries.
|
566
|
-
|
567
|
-
Returns:
|
568
|
-
_ModelMonitorParams dict with Name of monitor, fully qualified model name,
|
569
|
-
model version name, model function name, prediction_col, label_col.
|
570
|
-
|
571
|
-
Raises:
|
572
|
-
ValueError: If multiple ModelMonitors exist with the same name.
|
573
|
-
"""
|
574
|
-
res = (
|
575
|
-
query_result_checker.SqlResultValidator(
|
576
|
-
self._sql_client._session,
|
577
|
-
f"""SELECT {MONITOR_NAME_COL_NAME}, {FQ_MODEL_NAME_COL_NAME},
|
578
|
-
{VERSION_NAME_COL_NAME}, {FUNCTION_NAME_COL_NAME}
|
579
|
-
FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}
|
580
|
-
WHERE {FQ_MODEL_NAME_COL_NAME} = '{model_db}.{model_schema}.{model_name}'
|
581
|
-
AND {VERSION_NAME_COL_NAME} = '{version_name}'""",
|
582
|
-
statement_params=statement_params,
|
583
|
-
)
|
584
|
-
.has_column(MONITOR_NAME_COL_NAME)
|
585
|
-
.has_column(FQ_MODEL_NAME_COL_NAME)
|
586
|
-
.has_column(VERSION_NAME_COL_NAME)
|
587
|
-
.has_column(FUNCTION_NAME_COL_NAME)
|
588
|
-
.validate()
|
589
|
-
)
|
590
|
-
if len(res) > 1:
|
591
|
-
raise ValueError(
|
592
|
-
f"Invalid state. Multiple Monitors exist for model: '{model_name}' and version: '{version_name}'"
|
593
|
-
)
|
594
|
-
return _ModelMonitorParams(
|
595
|
-
monitor_name=res[0][MONITOR_NAME_COL_NAME],
|
596
|
-
fully_qualified_model_name=res[0][FQ_MODEL_NAME_COL_NAME],
|
597
|
-
version_name=res[0][VERSION_NAME_COL_NAME],
|
598
|
-
function_name=res[0][FUNCTION_NAME_COL_NAME],
|
599
|
-
prediction_columns=[
|
600
|
-
sql_identifier.SqlIdentifier(prediction_column)
|
601
|
-
for prediction_column in json.loads(res[0][PREDICTION_COL_NAMES_COL_NAME])
|
602
|
-
],
|
603
|
-
label_columns=[
|
604
|
-
sql_identifier.SqlIdentifier(label_column)
|
605
|
-
for label_column in json.loads(res[0][LABEL_COL_NAMES_COL_NAME])
|
606
|
-
],
|
607
|
-
)
|
608
|
-
|
609
|
-
def get_score_type(
|
610
|
-
self,
|
611
|
-
task: type_hints.Task,
|
612
|
-
source_table_name: sql_identifier.SqlIdentifier,
|
613
|
-
prediction_columns: List[sql_identifier.SqlIdentifier],
|
614
|
-
) -> output_score_type.OutputScoreType:
|
615
|
-
"""Infer score type given model task and prediction table columns.
|
616
|
-
|
617
|
-
Args:
|
618
|
-
task: Model task
|
619
|
-
source_table_name: Source data table containing model outputs.
|
620
|
-
prediction_columns: columns in source data table corresponding to model outputs.
|
621
|
-
|
622
|
-
Returns:
|
623
|
-
OutputScoreType for model.
|
624
|
-
"""
|
625
|
-
table_schema: Mapping[str, types.DataType] = table_manager.get_table_schema_types(
|
626
|
-
self._sql_client._session,
|
627
|
-
self._database_name,
|
628
|
-
self._schema_name,
|
629
|
-
source_table_name,
|
630
|
-
)
|
631
|
-
return output_score_type.OutputScoreType.deduce_score_type(table_schema, prediction_columns, task)
|
632
|
-
|
633
|
-
def validate_source_table(
|
634
|
-
self,
|
635
|
-
source_table_name: sql_identifier.SqlIdentifier,
|
220
|
+
source_database: Optional[sql_identifier.SqlIdentifier],
|
221
|
+
source_schema: Optional[sql_identifier.SqlIdentifier],
|
222
|
+
source: sql_identifier.SqlIdentifier,
|
636
223
|
timestamp_column: sql_identifier.SqlIdentifier,
|
637
|
-
|
638
|
-
|
224
|
+
prediction_score_columns: List[sql_identifier.SqlIdentifier],
|
225
|
+
prediction_class_columns: List[sql_identifier.SqlIdentifier],
|
226
|
+
actual_score_columns: List[sql_identifier.SqlIdentifier],
|
227
|
+
actual_class_columns: List[sql_identifier.SqlIdentifier],
|
639
228
|
id_columns: List[sql_identifier.SqlIdentifier],
|
640
|
-
model_function: model_manifest_schema.ModelFunctionInfo,
|
641
229
|
) -> None:
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
f"{self._database_name}.{self._schema_name}",
|
647
|
-
):
|
648
|
-
raise ValueError(
|
649
|
-
f"Table {source_table_name} does not exist in schema {self._database_name}.{self._schema_name}."
|
650
|
-
)
|
651
|
-
table_schema: Mapping[str, types.DataType] = table_manager.get_table_schema_types(
|
230
|
+
source_database = source_database or self._database_name
|
231
|
+
source_schema = source_schema or self._schema_name
|
232
|
+
# Get Schema of the source. Implicitly validates that the source exists.
|
233
|
+
source_column_schema: Mapping[str, types.DataType] = table_manager.get_table_schema_types(
|
652
234
|
self._sql_client._session,
|
653
|
-
|
654
|
-
|
655
|
-
|
235
|
+
source_database,
|
236
|
+
source_schema,
|
237
|
+
source,
|
656
238
|
)
|
657
|
-
self.
|
658
|
-
|
659
|
-
source_table_name=source_table_name,
|
239
|
+
self._validate_columns_exist_in_source(
|
240
|
+
source_column_schema=source_column_schema,
|
660
241
|
timestamp_column=timestamp_column,
|
661
|
-
|
662
|
-
|
242
|
+
prediction_score_columns=prediction_score_columns,
|
243
|
+
prediction_class_columns=prediction_class_columns,
|
244
|
+
actual_score_columns=actual_score_columns,
|
245
|
+
actual_class_columns=actual_class_columns,
|
663
246
|
id_columns=id_columns,
|
664
247
|
)
|
665
|
-
self._validate_column_types(
|
666
|
-
table_schema=table_schema,
|
667
|
-
timestamp_column=timestamp_column,
|
668
|
-
id_columns=id_columns,
|
669
|
-
prediction_columns=prediction_columns,
|
670
|
-
label_columns=label_columns,
|
671
|
-
)
|
672
|
-
self._validate_source_table_features_shape(
|
673
|
-
table_schema=table_schema,
|
674
|
-
special_columns={timestamp_column, *id_columns, *prediction_columns, *label_columns},
|
675
|
-
model_function=model_function,
|
676
|
-
)
|
677
248
|
|
678
|
-
def
|
679
|
-
self,
|
680
|
-
name: str,
|
681
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
682
|
-
) -> None:
|
683
|
-
"""Delete the row in the metadata table corresponding to the given monitor name.
|
684
|
-
|
685
|
-
Args:
|
686
|
-
name: Name of the model monitor whose metadata should be deleted.
|
687
|
-
statement_params: Optional set of statement_params to include with query.
|
688
|
-
"""
|
689
|
-
self._sql_client._session.sql(
|
690
|
-
f"""DELETE FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}
|
691
|
-
WHERE {MONITOR_NAME_COL_NAME} = '{name}'""",
|
692
|
-
).collect(statement_params=statement_params)
|
693
|
-
|
694
|
-
def delete_baseline_table(
|
695
|
-
self,
|
696
|
-
fully_qualified_model_name: str,
|
697
|
-
version_name: str,
|
698
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
699
|
-
) -> None:
|
700
|
-
"""Delete the baseline table corresponding to a particular model and version.
|
701
|
-
|
702
|
-
Args:
|
703
|
-
fully_qualified_model_name: Fully qualified name of the model.
|
704
|
-
version_name: Name of the model version.
|
705
|
-
statement_params: Optional set of statement_params to include with query.
|
706
|
-
"""
|
707
|
-
table_name = _create_baseline_table_name(fully_qualified_model_name, version_name)
|
708
|
-
self._sql_client._session.sql(
|
709
|
-
f"""DROP TABLE IF EXISTS {self._database_name}.{self._schema_name}.{table_name}"""
|
710
|
-
).collect(statement_params=statement_params)
|
711
|
-
|
712
|
-
def delete_dynamic_tables(
|
713
|
-
self,
|
714
|
-
fully_qualified_model_name: str,
|
715
|
-
version_name: str,
|
716
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
717
|
-
) -> None:
|
718
|
-
"""Delete the dynamic tables corresponding to a particular model and version.
|
719
|
-
|
720
|
-
Args:
|
721
|
-
fully_qualified_model_name: Fully qualified name of the model.
|
722
|
-
version_name: Name of the model version.
|
723
|
-
statement_params: Optional set of statement_params to include with query.
|
724
|
-
"""
|
725
|
-
_, _, model_name = sql_identifier.parse_fully_qualified_name(fully_qualified_model_name)
|
726
|
-
model_id = sql_identifier.SqlIdentifier(model_name)
|
727
|
-
version_id = sql_identifier.SqlIdentifier(version_name)
|
728
|
-
monitoring_table_name = self.get_monitoring_table_fully_qualified_name(model_id, version_id)
|
729
|
-
self._sql_client._session.sql(f"""DROP DYNAMIC TABLE IF EXISTS {monitoring_table_name}""").collect(
|
730
|
-
statement_params=statement_params
|
731
|
-
)
|
732
|
-
accuracy_table_name = self.get_accuracy_monitoring_table_fully_qualified_name(model_id, version_id)
|
733
|
-
self._sql_client._session.sql(f"""DROP DYNAMIC TABLE IF EXISTS {accuracy_table_name}""").collect(
|
734
|
-
statement_params=statement_params
|
735
|
-
)
|
736
|
-
|
737
|
-
def create_monitor_on_model_version(
|
738
|
-
self,
|
739
|
-
monitor_name: sql_identifier.SqlIdentifier,
|
740
|
-
source_table_name: sql_identifier.SqlIdentifier,
|
741
|
-
fully_qualified_model_name: str,
|
742
|
-
version_name: sql_identifier.SqlIdentifier,
|
743
|
-
function_name: str,
|
744
|
-
timestamp_column: sql_identifier.SqlIdentifier,
|
745
|
-
prediction_columns: List[sql_identifier.SqlIdentifier],
|
746
|
-
label_columns: List[sql_identifier.SqlIdentifier],
|
747
|
-
id_columns: List[sql_identifier.SqlIdentifier],
|
748
|
-
task: type_hints.Task,
|
749
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
750
|
-
) -> None:
|
751
|
-
"""
|
752
|
-
Creates a ModelMonitor on a Model Version from the Snowflake Model Registry. Creates public schema for metadata.
|
753
|
-
|
754
|
-
Args:
|
755
|
-
monitor_name: Name of monitor object to create.
|
756
|
-
source_table_name: Name of source data table to monitor.
|
757
|
-
fully_qualified_model_name: fully qualified name of model to monitor '<db>.<schema>.<model_name>'.
|
758
|
-
version_name: model version name to monitor.
|
759
|
-
function_name: function_name to monitor in model version.
|
760
|
-
timestamp_column: timestamp column name.
|
761
|
-
prediction_columns: list of prediction column names.
|
762
|
-
label_columns: list of label column names.
|
763
|
-
id_columns: list of id column names.
|
764
|
-
task: Task of the model, e.g. TABULAR_REGRESSION.
|
765
|
-
statement_params: Optional dict of statement_params to include with queries.
|
766
|
-
|
767
|
-
Raises:
|
768
|
-
ValueError: If model version is already monitored.
|
769
|
-
"""
|
770
|
-
# Validate monitor does not already exist on model version.
|
771
|
-
if self.validate_existence(fully_qualified_model_name, version_name, statement_params):
|
772
|
-
raise ValueError(f"Model {fully_qualified_model_name} Version {version_name} is already monitored!")
|
773
|
-
|
774
|
-
if self.validate_existence_by_name(monitor_name, statement_params):
|
775
|
-
raise ValueError(f"Model Monitor with name '{monitor_name}' already exists!")
|
776
|
-
|
777
|
-
prediction_columns_for_select = formatting.format_value_for_select(prediction_columns)
|
778
|
-
label_columns_for_select = formatting.format_value_for_select(label_columns)
|
779
|
-
id_columns_for_select = formatting.format_value_for_select(id_columns)
|
780
|
-
query_result_checker.SqlResultValidator(
|
781
|
-
self._sql_client._session,
|
782
|
-
textwrap.dedent(
|
783
|
-
f"""INSERT INTO {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}
|
784
|
-
({MONITOR_NAME_COL_NAME}, {SOURCE_TABLE_NAME_COL_NAME}, {FQ_MODEL_NAME_COL_NAME},
|
785
|
-
{VERSION_NAME_COL_NAME}, {FUNCTION_NAME_COL_NAME}, {TASK_COL_NAME},
|
786
|
-
{MONITORING_ENABLED_COL_NAME}, {TIMESTAMP_COL_NAME_COL_NAME},
|
787
|
-
{PREDICTION_COL_NAMES_COL_NAME}, {LABEL_COL_NAMES_COL_NAME},
|
788
|
-
{ID_COL_NAMES_COL_NAME})
|
789
|
-
SELECT '{monitor_name}', '{source_table_name}', '{fully_qualified_model_name}',
|
790
|
-
'{version_name}', '{function_name}', '{task.value}', TRUE, '{timestamp_column}',
|
791
|
-
{prediction_columns_for_select}, {label_columns_for_select}, {id_columns_for_select}"""
|
792
|
-
),
|
793
|
-
statement_params=statement_params,
|
794
|
-
).insertion_success(expected_num_rows=1).validate()
|
795
|
-
|
796
|
-
def initialize_baseline_table(
|
797
|
-
self,
|
798
|
-
model_name: sql_identifier.SqlIdentifier,
|
799
|
-
version_name: sql_identifier.SqlIdentifier,
|
800
|
-
source_table_name: str,
|
801
|
-
columns_to_drop: Optional[List[sql_identifier.SqlIdentifier]] = None,
|
802
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
803
|
-
) -> None:
|
804
|
-
"""
|
805
|
-
Initializes the baseline table for a Model Version. Creates schema for baseline data using the source table.
|
806
|
-
|
807
|
-
Args:
|
808
|
-
model_name: name of model to monitor.
|
809
|
-
version_name: model version name to monitor.
|
810
|
-
source_table_name: name of the user's table containing their model data.
|
811
|
-
columns_to_drop: special columns in the source table to be excluded from baseline tables.
|
812
|
-
statement_params: Optional dict of statement_params to include with queries.
|
813
|
-
"""
|
814
|
-
table_schema = table_manager.get_table_schema_types(
|
815
|
-
self._sql_client._session,
|
816
|
-
database=self._database_name,
|
817
|
-
schema=self._schema_name,
|
818
|
-
table_name=source_table_name,
|
819
|
-
)
|
820
|
-
|
821
|
-
if columns_to_drop is None:
|
822
|
-
columns_to_drop = []
|
823
|
-
|
824
|
-
table_manager.create_single_table(
|
825
|
-
self._sql_client._session,
|
826
|
-
self._database_name,
|
827
|
-
self._schema_name,
|
828
|
-
_create_baseline_table_name(model_name, version_name),
|
829
|
-
[
|
830
|
-
(k, type_utils.convert_sp_to_sf_type(v))
|
831
|
-
for k, v in table_schema.items()
|
832
|
-
if sql_identifier.SqlIdentifier(k) not in columns_to_drop
|
833
|
-
],
|
834
|
-
statement_params=statement_params,
|
835
|
-
)
|
836
|
-
|
837
|
-
def get_all_model_monitor_metadata(
|
838
|
-
self,
|
839
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
840
|
-
) -> List[snowpark.Row]:
|
841
|
-
"""Get the metadata for all model monitors in the given schema.
|
842
|
-
|
843
|
-
Args:
|
844
|
-
statement_params: Optional dict of statement_params to include with queries.
|
845
|
-
|
846
|
-
Returns:
|
847
|
-
List of snowpark.Row containing metadata for each model monitor.
|
848
|
-
"""
|
849
|
-
return query_result_checker.SqlResultValidator(
|
850
|
-
self._sql_client._session,
|
851
|
-
f"""SELECT *
|
852
|
-
FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}""",
|
853
|
-
statement_params=statement_params,
|
854
|
-
).validate()
|
855
|
-
|
856
|
-
def materialize_baseline_dataframe(
|
857
|
-
self,
|
858
|
-
baseline_df: DataFrame,
|
859
|
-
fully_qualified_model_name: str,
|
860
|
-
model_version_name: sql_identifier.SqlIdentifier,
|
861
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
862
|
-
) -> None:
|
863
|
-
"""
|
864
|
-
Materialize baseline dataframe to a permanent snowflake table. This method
|
865
|
-
truncates (overwrite without dropping) any existing data in the baseline table.
|
866
|
-
|
867
|
-
Args:
|
868
|
-
baseline_df: dataframe containing baseline data that monitored data will be compared against.
|
869
|
-
fully_qualified_model_name: name of the model.
|
870
|
-
model_version_name: model version name to monitor.
|
871
|
-
statement_params: Optional dict of statement_params to include with queries.
|
872
|
-
|
873
|
-
Raises:
|
874
|
-
ValueError: If no baseline table was initialized.
|
875
|
-
"""
|
876
|
-
|
877
|
-
_, _, model_name = sql_identifier.parse_fully_qualified_name(fully_qualified_model_name)
|
878
|
-
baseline_table_name = _create_baseline_table_name(model_name, model_version_name)
|
879
|
-
|
880
|
-
baseline_table_exists = db_utils.db_object_exists(
|
881
|
-
self._sql_client._session,
|
882
|
-
db_utils.SnowflakeDbObjectType.TABLE,
|
883
|
-
sql_identifier.SqlIdentifier(baseline_table_name),
|
884
|
-
database_name=self._database_name,
|
885
|
-
schema_name=self._schema_name,
|
886
|
-
statement_params=statement_params,
|
887
|
-
)
|
888
|
-
if not baseline_table_exists:
|
889
|
-
raise ValueError(
|
890
|
-
f"Baseline table '{baseline_table_name}' does not exist for model: "
|
891
|
-
f"'{model_name}' and model_version: '{model_version_name}'"
|
892
|
-
)
|
893
|
-
|
894
|
-
fully_qualified_baseline_table_name = [self._database_name, self._schema_name, baseline_table_name]
|
895
|
-
|
896
|
-
try:
|
897
|
-
# Truncate overwrites by clearing the rows in the table, instead of dropping the table.
|
898
|
-
# This lets us keep the schema to validate the baseline_df against.
|
899
|
-
baseline_df.write.mode("truncate").save_as_table(
|
900
|
-
fully_qualified_baseline_table_name, statement_params=statement_params
|
901
|
-
)
|
902
|
-
except exceptions.SnowparkSQLException as e:
|
903
|
-
raise ValueError(
|
904
|
-
f"""Failed to save baseline dataframe.
|
905
|
-
Ensure that the baseline dataframe columns match those provided in your monitored table: {e}"""
|
906
|
-
)
|
907
|
-
|
908
|
-
def _alter_monitor_dynamic_tables(
|
249
|
+
def _alter_monitor(
|
909
250
|
self,
|
910
251
|
operation: str,
|
911
|
-
|
912
|
-
version_name: sql_identifier.SqlIdentifier,
|
252
|
+
monitor_name: sql_identifier.SqlIdentifier,
|
913
253
|
statement_params: Optional[Dict[str, Any]] = None,
|
914
254
|
) -> None:
|
915
255
|
if operation not in {"SUSPEND", "RESUME"}:
|
916
256
|
raise ValueError(f"Operation {operation} not supported for altering Dynamic Tables")
|
917
|
-
fq_monitor_dt_name = self.get_monitoring_table_fully_qualified_name(model_name, version_name)
|
918
257
|
query_result_checker.SqlResultValidator(
|
919
258
|
self._sql_client._session,
|
920
|
-
f"""ALTER
|
259
|
+
f"""ALTER MODEL MONITOR {self._database_name}.{self._schema_name}.{monitor_name} {operation}""",
|
921
260
|
statement_params=statement_params,
|
922
261
|
).has_column("status").has_dimensions(1, 1).validate()
|
923
262
|
|
924
|
-
|
925
|
-
query_result_checker.SqlResultValidator(
|
926
|
-
self._sql_client._session,
|
927
|
-
f"""ALTER DYNAMIC TABLE {fq_accuracy_dt_name} {operation}""",
|
928
|
-
statement_params=statement_params,
|
929
|
-
).has_column("status").has_dimensions(1, 1).validate()
|
930
|
-
|
931
|
-
def suspend_monitor_dynamic_tables(
|
263
|
+
def suspend_monitor(
|
932
264
|
self,
|
933
|
-
|
934
|
-
version_name: sql_identifier.SqlIdentifier,
|
265
|
+
monitor_name: sql_identifier.SqlIdentifier,
|
935
266
|
statement_params: Optional[Dict[str, Any]] = None,
|
936
267
|
) -> None:
|
937
|
-
self.
|
268
|
+
self._alter_monitor(
|
938
269
|
operation="SUSPEND",
|
939
|
-
|
940
|
-
version_name=version_name,
|
270
|
+
monitor_name=monitor_name,
|
941
271
|
statement_params=statement_params,
|
942
272
|
)
|
943
273
|
|
944
|
-
def
|
274
|
+
def resume_monitor(
|
945
275
|
self,
|
946
|
-
|
947
|
-
version_name: sql_identifier.SqlIdentifier,
|
276
|
+
monitor_name: sql_identifier.SqlIdentifier,
|
948
277
|
statement_params: Optional[Dict[str, Any]] = None,
|
949
278
|
) -> None:
|
950
|
-
self.
|
279
|
+
self._alter_monitor(
|
951
280
|
operation="RESUME",
|
952
|
-
|
953
|
-
version_name=version_name,
|
281
|
+
monitor_name=monitor_name,
|
954
282
|
statement_params=statement_params,
|
955
283
|
)
|
956
|
-
|
957
|
-
def create_dynamic_tables_for_monitor(
|
958
|
-
self,
|
959
|
-
*,
|
960
|
-
model_name: sql_identifier.SqlIdentifier,
|
961
|
-
model_version_name: sql_identifier.SqlIdentifier,
|
962
|
-
task: type_hints.Task,
|
963
|
-
source_table_name: sql_identifier.SqlIdentifier,
|
964
|
-
refresh_interval: model_monitor_interval.ModelMonitorRefreshInterval,
|
965
|
-
aggregation_window: model_monitor_interval.ModelMonitorAggregationWindow,
|
966
|
-
warehouse_name: sql_identifier.SqlIdentifier,
|
967
|
-
timestamp_column: sql_identifier.SqlIdentifier,
|
968
|
-
id_columns: List[sql_identifier.SqlIdentifier],
|
969
|
-
prediction_columns: List[sql_identifier.SqlIdentifier],
|
970
|
-
label_columns: List[sql_identifier.SqlIdentifier],
|
971
|
-
score_type: output_score_type.OutputScoreType,
|
972
|
-
) -> None:
|
973
|
-
table_schema: Mapping[str, types.DataType] = table_manager.get_table_schema_types(
|
974
|
-
self._sql_client._session,
|
975
|
-
self._database_name,
|
976
|
-
self._schema_name,
|
977
|
-
source_table_name,
|
978
|
-
)
|
979
|
-
(numeric_features_names, categorical_feature_names) = _infer_numeric_categoric_feature_column_names(
|
980
|
-
source_table_schema=table_schema,
|
981
|
-
timestamp_column=timestamp_column,
|
982
|
-
id_columns=id_columns,
|
983
|
-
prediction_columns=prediction_columns,
|
984
|
-
label_columns=label_columns,
|
985
|
-
)
|
986
|
-
features_dynamic_table_query = self._monitoring_dynamic_table_query(
|
987
|
-
model_name=model_name,
|
988
|
-
model_version_name=model_version_name,
|
989
|
-
source_table_name=source_table_name,
|
990
|
-
refresh_interval=refresh_interval,
|
991
|
-
aggregate_window=aggregation_window,
|
992
|
-
warehouse_name=warehouse_name,
|
993
|
-
timestamp_column=timestamp_column,
|
994
|
-
numeric_features=numeric_features_names,
|
995
|
-
categoric_features=categorical_feature_names,
|
996
|
-
prediction_columns=prediction_columns,
|
997
|
-
label_columns=label_columns,
|
998
|
-
)
|
999
|
-
query_result_checker.SqlResultValidator(self._sql_client._session, features_dynamic_table_query).has_column(
|
1000
|
-
"status"
|
1001
|
-
).has_dimensions(1, 1).validate()
|
1002
|
-
|
1003
|
-
label_pred_join_table_query = self._monitoring_accuracy_table_query(
|
1004
|
-
model_name=model_name,
|
1005
|
-
model_version_name=model_version_name,
|
1006
|
-
task=task,
|
1007
|
-
source_table_name=source_table_name,
|
1008
|
-
refresh_interval=refresh_interval,
|
1009
|
-
aggregate_window=aggregation_window,
|
1010
|
-
warehouse_name=warehouse_name,
|
1011
|
-
timestamp_column=timestamp_column,
|
1012
|
-
prediction_columns=prediction_columns,
|
1013
|
-
label_columns=label_columns,
|
1014
|
-
score_type=score_type,
|
1015
|
-
)
|
1016
|
-
query_result_checker.SqlResultValidator(self._sql_client._session, label_pred_join_table_query).has_column(
|
1017
|
-
"status"
|
1018
|
-
).has_dimensions(1, 1).validate()
|
1019
|
-
|
1020
|
-
def _monitoring_dynamic_table_query(
|
1021
|
-
self,
|
1022
|
-
*,
|
1023
|
-
model_name: sql_identifier.SqlIdentifier,
|
1024
|
-
model_version_name: sql_identifier.SqlIdentifier,
|
1025
|
-
source_table_name: sql_identifier.SqlIdentifier,
|
1026
|
-
refresh_interval: ModelMonitorRefreshInterval,
|
1027
|
-
aggregate_window: ModelMonitorAggregationWindow,
|
1028
|
-
warehouse_name: sql_identifier.SqlIdentifier,
|
1029
|
-
timestamp_column: sql_identifier.SqlIdentifier,
|
1030
|
-
numeric_features: List[sql_identifier.SqlIdentifier],
|
1031
|
-
categoric_features: List[sql_identifier.SqlIdentifier],
|
1032
|
-
prediction_columns: List[sql_identifier.SqlIdentifier],
|
1033
|
-
label_columns: List[sql_identifier.SqlIdentifier],
|
1034
|
-
) -> str:
|
1035
|
-
"""
|
1036
|
-
Generates a dynamic table query for Observability - Monitoring.
|
1037
|
-
|
1038
|
-
Args:
|
1039
|
-
model_name: Model name to monitor.
|
1040
|
-
model_version_name: Model version name to monitor.
|
1041
|
-
source_table_name: Name of source data table to monitor.
|
1042
|
-
refresh_interval: Refresh interval in minutes.
|
1043
|
-
aggregate_window: Aggregate window minutes.
|
1044
|
-
warehouse_name: Warehouse name to use for dynamic table.
|
1045
|
-
timestamp_column: Timestamp column name.
|
1046
|
-
numeric_features: List of numeric features to capture.
|
1047
|
-
categoric_features: List of categoric features to capture.
|
1048
|
-
prediction_columns: List of columns that contain model inference outputs.
|
1049
|
-
label_columns: List of columns that contain ground truth values.
|
1050
|
-
|
1051
|
-
Raises:
|
1052
|
-
ValueError: If multiple output/ground truth columns are specified. MultiClass models are not yet supported.
|
1053
|
-
|
1054
|
-
Returns:
|
1055
|
-
Dynamic table query.
|
1056
|
-
"""
|
1057
|
-
# output and ground cols are list to keep interface extensible.
|
1058
|
-
# for prpr only one label and one output col will be supported
|
1059
|
-
if len(prediction_columns) != 1 or len(label_columns) != 1:
|
1060
|
-
raise ValueError("Multiple Output columns are not supported in monitoring")
|
1061
|
-
|
1062
|
-
monitoring_dt_name = self.get_monitoring_table_fully_qualified_name(model_name, model_version_name)
|
1063
|
-
|
1064
|
-
feature_cols_query_list = []
|
1065
|
-
for feature in numeric_features + prediction_columns + label_columns:
|
1066
|
-
feature_cols_query_list.append(
|
1067
|
-
"""
|
1068
|
-
OBJECT_CONSTRUCT(
|
1069
|
-
'sketch', APPROX_PERCENTILE_ACCUMULATE({col}),
|
1070
|
-
'count', count_if({col} is not null),
|
1071
|
-
'count_null', count_if({col} is null),
|
1072
|
-
'min', min({col}),
|
1073
|
-
'max', max({col}),
|
1074
|
-
'sum', sum({col})
|
1075
|
-
) AS {col}""".format(
|
1076
|
-
col=feature
|
1077
|
-
)
|
1078
|
-
)
|
1079
|
-
|
1080
|
-
for col in categoric_features:
|
1081
|
-
feature_cols_query_list.append(
|
1082
|
-
f"""
|
1083
|
-
{self._database_name}.{self._schema_name}.OBJECT_SUM(to_varchar({col})) AS {col}"""
|
1084
|
-
)
|
1085
|
-
feature_cols_query = ",".join(feature_cols_query_list)
|
1086
|
-
|
1087
|
-
return f"""
|
1088
|
-
CREATE DYNAMIC TABLE IF NOT EXISTS {monitoring_dt_name}
|
1089
|
-
TARGET_LAG = '{refresh_interval.minutes} minutes'
|
1090
|
-
WAREHOUSE = {warehouse_name}
|
1091
|
-
REFRESH_MODE = AUTO
|
1092
|
-
INITIALIZE = ON_CREATE
|
1093
|
-
AS
|
1094
|
-
SELECT
|
1095
|
-
TIME_SLICE({timestamp_column}, {aggregate_window.minutes}, 'MINUTE') timestamp,{feature_cols_query}
|
1096
|
-
FROM
|
1097
|
-
{source_table_name}
|
1098
|
-
GROUP BY
|
1099
|
-
1
|
1100
|
-
"""
|
1101
|
-
|
1102
|
-
def _monitoring_accuracy_table_query(
|
1103
|
-
self,
|
1104
|
-
*,
|
1105
|
-
model_name: sql_identifier.SqlIdentifier,
|
1106
|
-
model_version_name: sql_identifier.SqlIdentifier,
|
1107
|
-
task: type_hints.Task,
|
1108
|
-
source_table_name: sql_identifier.SqlIdentifier,
|
1109
|
-
refresh_interval: ModelMonitorRefreshInterval,
|
1110
|
-
aggregate_window: ModelMonitorAggregationWindow,
|
1111
|
-
warehouse_name: sql_identifier.SqlIdentifier,
|
1112
|
-
timestamp_column: sql_identifier.SqlIdentifier,
|
1113
|
-
prediction_columns: List[sql_identifier.SqlIdentifier],
|
1114
|
-
label_columns: List[sql_identifier.SqlIdentifier],
|
1115
|
-
score_type: output_score_type.OutputScoreType,
|
1116
|
-
) -> str:
|
1117
|
-
# output and ground cols are list to keep interface extensible.
|
1118
|
-
# for prpr only one label and one output col will be supported
|
1119
|
-
if len(prediction_columns) != 1 or len(label_columns) != 1:
|
1120
|
-
raise ValueError("Multiple Output columns are not supported in monitoring")
|
1121
|
-
if task == type_hints.Task.TABULAR_BINARY_CLASSIFICATION:
|
1122
|
-
return self._monitoring_classification_accuracy_table_query(
|
1123
|
-
model_name=model_name,
|
1124
|
-
model_version_name=model_version_name,
|
1125
|
-
source_table_name=source_table_name,
|
1126
|
-
refresh_interval=refresh_interval,
|
1127
|
-
aggregate_window=aggregate_window,
|
1128
|
-
warehouse_name=warehouse_name,
|
1129
|
-
timestamp_column=timestamp_column,
|
1130
|
-
prediction_columns=prediction_columns,
|
1131
|
-
label_columns=label_columns,
|
1132
|
-
score_type=score_type,
|
1133
|
-
)
|
1134
|
-
else:
|
1135
|
-
return self._monitoring_regression_accuracy_table_query(
|
1136
|
-
model_name=model_name,
|
1137
|
-
model_version_name=model_version_name,
|
1138
|
-
source_table_name=source_table_name,
|
1139
|
-
refresh_interval=refresh_interval,
|
1140
|
-
aggregate_window=aggregate_window,
|
1141
|
-
warehouse_name=warehouse_name,
|
1142
|
-
timestamp_column=timestamp_column,
|
1143
|
-
prediction_columns=prediction_columns,
|
1144
|
-
label_columns=label_columns,
|
1145
|
-
)
|
1146
|
-
|
1147
|
-
def _monitoring_regression_accuracy_table_query(
|
1148
|
-
self,
|
1149
|
-
*,
|
1150
|
-
model_name: sql_identifier.SqlIdentifier,
|
1151
|
-
model_version_name: sql_identifier.SqlIdentifier,
|
1152
|
-
source_table_name: sql_identifier.SqlIdentifier,
|
1153
|
-
refresh_interval: ModelMonitorRefreshInterval,
|
1154
|
-
aggregate_window: ModelMonitorAggregationWindow,
|
1155
|
-
warehouse_name: sql_identifier.SqlIdentifier,
|
1156
|
-
timestamp_column: sql_identifier.SqlIdentifier,
|
1157
|
-
prediction_columns: List[sql_identifier.SqlIdentifier],
|
1158
|
-
label_columns: List[sql_identifier.SqlIdentifier],
|
1159
|
-
) -> str:
|
1160
|
-
"""
|
1161
|
-
Generates a dynamic table query for Monitoring - regression model accuracy.
|
1162
|
-
|
1163
|
-
Args:
|
1164
|
-
model_name: Model name to monitor.
|
1165
|
-
model_version_name: Model version name to monitor.
|
1166
|
-
source_table_name: Name of source data table to monitor.
|
1167
|
-
refresh_interval: Refresh interval in minutes.
|
1168
|
-
aggregate_window: Aggregate window minutes.
|
1169
|
-
warehouse_name: Warehouse name to use for dynamic table.
|
1170
|
-
timestamp_column: Timestamp column name.
|
1171
|
-
prediction_columns: List of output columns.
|
1172
|
-
label_columns: List of ground truth columns.
|
1173
|
-
|
1174
|
-
Returns:
|
1175
|
-
Dynamic table query.
|
1176
|
-
|
1177
|
-
Raises:
|
1178
|
-
ValueError: If output columns are not same as ground truth columns.
|
1179
|
-
|
1180
|
-
"""
|
1181
|
-
|
1182
|
-
if len(prediction_columns) != len(label_columns):
|
1183
|
-
raise ValueError(f"Mismatch in output & ground truth columns: {prediction_columns} != {label_columns}")
|
1184
|
-
|
1185
|
-
monitoring_dt_name = self.get_accuracy_monitoring_table_fully_qualified_name(model_name, model_version_name)
|
1186
|
-
|
1187
|
-
output_cols_query_list = []
|
1188
|
-
|
1189
|
-
output_cols_query_list.append(
|
1190
|
-
f"""
|
1191
|
-
OBJECT_CONSTRUCT(
|
1192
|
-
'sum_difference_label_pred', sum({prediction_columns[0]} - {label_columns[0]}),
|
1193
|
-
'sum_log_difference_square_label_pred',
|
1194
|
-
sum(
|
1195
|
-
case
|
1196
|
-
when {prediction_columns[0]} > -1 and {label_columns[0]} > -1
|
1197
|
-
then pow(ln({prediction_columns[0]} + 1) - ln({label_columns[0]} + 1),2)
|
1198
|
-
else null
|
1199
|
-
END
|
1200
|
-
),
|
1201
|
-
'sum_difference_squares_label_pred',
|
1202
|
-
sum(
|
1203
|
-
pow(
|
1204
|
-
{prediction_columns[0]} - {label_columns[0]},
|
1205
|
-
2
|
1206
|
-
)
|
1207
|
-
),
|
1208
|
-
'sum_absolute_regression_labels', sum(abs({label_columns[0]})),
|
1209
|
-
'sum_absolute_percentage_error',
|
1210
|
-
sum(
|
1211
|
-
abs(
|
1212
|
-
div0null(
|
1213
|
-
({prediction_columns[0]} - {label_columns[0]}),
|
1214
|
-
{label_columns[0]}
|
1215
|
-
)
|
1216
|
-
)
|
1217
|
-
),
|
1218
|
-
'sum_absolute_difference_label_pred',
|
1219
|
-
sum(
|
1220
|
-
abs({prediction_columns[0]} - {label_columns[0]})
|
1221
|
-
),
|
1222
|
-
'sum_prediction', sum({prediction_columns[0]}),
|
1223
|
-
'sum_label', sum({label_columns[0]}),
|
1224
|
-
'count', count(*)
|
1225
|
-
) AS AGGREGATE_METRICS,
|
1226
|
-
APPROX_PERCENTILE_ACCUMULATE({prediction_columns[0]}) prediction_sketch,
|
1227
|
-
APPROX_PERCENTILE_ACCUMULATE({label_columns[0]}) label_sketch"""
|
1228
|
-
)
|
1229
|
-
output_cols_query = ", ".join(output_cols_query_list)
|
1230
|
-
|
1231
|
-
return f"""
|
1232
|
-
CREATE DYNAMIC TABLE IF NOT EXISTS {monitoring_dt_name}
|
1233
|
-
TARGET_LAG = '{refresh_interval.minutes} minutes'
|
1234
|
-
WAREHOUSE = {warehouse_name}
|
1235
|
-
REFRESH_MODE = AUTO
|
1236
|
-
INITIALIZE = ON_CREATE
|
1237
|
-
AS
|
1238
|
-
SELECT
|
1239
|
-
TIME_SLICE({timestamp_column}, {aggregate_window.minutes}, 'MINUTE') timestamp,
|
1240
|
-
'class_regression' label_class,{output_cols_query}
|
1241
|
-
FROM
|
1242
|
-
{source_table_name}
|
1243
|
-
GROUP BY
|
1244
|
-
1
|
1245
|
-
"""
|
1246
|
-
|
1247
|
-
def _monitoring_classification_accuracy_table_query(
    self,
    *,
    model_name: sql_identifier.SqlIdentifier,
    model_version_name: sql_identifier.SqlIdentifier,
    source_table_name: sql_identifier.SqlIdentifier,
    refresh_interval: ModelMonitorRefreshInterval,
    aggregate_window: ModelMonitorAggregationWindow,
    warehouse_name: sql_identifier.SqlIdentifier,
    timestamp_column: sql_identifier.SqlIdentifier,
    prediction_columns: List[sql_identifier.SqlIdentifier],
    label_columns: List[sql_identifier.SqlIdentifier],
    score_type: output_score_type.OutputScoreType,
) -> str:
    """Build the CREATE DYNAMIC TABLE statement for classification accuracy monitoring.

    The generated SQL first tags every source row as ``class_positive`` (label
    equals 1) or ``class_negative`` (anything else) inside a ``filtered_data``
    CTE, then aggregates the rows per time slice and per label class.

    Args:
        model_name: Passed to the table-name helper together with the version.
        model_version_name: Passed to the table-name helper together with the model.
        source_table_name: Table the CTE selects from.
        refresh_interval: Its ``minutes`` attribute becomes the TARGET_LAG.
        aggregate_window: Its ``minutes`` attribute is the TIME_SLICE width.
        warehouse_name: Warehouse named in the dynamic table definition.
        timestamp_column: Column aliased to ``timestamp`` and then time-sliced.
        prediction_columns: Prediction columns; only the first entry is read.
        label_columns: Label columns; only the first entry is read.
        score_type: ``PROBITS`` adds a ``sum_log_loss`` metric; any other value
            adds ``tp`` / ``tn`` / ``fp`` / ``fn`` counts instead.

    Returns:
        The SQL text of the CREATE DYNAMIC TABLE statement.
    """
    monitoring_dt_name = self.get_accuracy_monitoring_table_fully_qualified_name(model_name, model_version_name)
    prediction_col = prediction_columns[0]
    label_col = label_columns[0]

    # Row-level projection: raw prediction/label plus the derived class tag.
    labelled_projection = f"""
{prediction_col},
{label_col},
CASE
WHEN {label_col} = 1 THEN 'class_positive'
ELSE 'class_negative'
END AS label_class"""

    cte_query = f"""
WITH filtered_data AS (
SELECT
{timestamp_column} AS timestamp,{labelled_projection}
FROM
{source_table_name}
)"""

    # Score-type specific entries spliced into OBJECT_CONSTRUCT; each variant
    # ends with a trailing comma because 'count' always follows it.
    if score_type == output_score_type.OutputScoreType.PROBITS:
        score_metrics = f"""
'sum_log_loss',
CASE
WHEN label_class = 'class_positive' THEN sum(-ln({prediction_col}))
ELSE sum(-ln(1 - {prediction_col}))
END,"""
    else:
        score_metrics = f"""
'tp', count_if({label_col} = 1 AND {prediction_col} = 1),
'tn', count_if({label_col} = 0 AND {prediction_col} = 0),
'fp', count_if({label_col} = 0 AND {prediction_col} = 1),
'fn', count_if({label_col} = 1 AND {prediction_col} = 0),"""

    aggregate_projection = f"""
label_class,
OBJECT_CONSTRUCT(
'sum_prediction', sum({prediction_col}),
'sum_label', sum({label_col}),{score_metrics}
'count', count(*)
) AS AGGREGATE_METRICS,
APPROX_PERCENTILE_ACCUMULATE({prediction_col}) prediction_sketch,
APPROX_PERCENTILE_ACCUMULATE({label_col}) label_sketch"""

    # GROUP BY 1, 2 refers to the time-sliced timestamp and label_class.
    return f"""
CREATE DYNAMIC TABLE IF NOT EXISTS {monitoring_dt_name}
TARGET_LAG = '{refresh_interval.minutes} minutes'
WAREHOUSE = {warehouse_name}
REFRESH_MODE = AUTO
INITIALIZE = ON_CREATE
AS{cte_query}
select
time_slice(timestamp, {aggregate_window.minutes}, 'MINUTE') timestamp,{aggregate_projection}
FROM
filtered_data
group by
1,
2
"""
|