snowflake-ml-python 1.7.0__py3-none-any.whl → 1.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_complete.py +107 -64
  3. snowflake/cortex/_finetune.py +273 -0
  4. snowflake/cortex/_sse_client.py +91 -28
  5. snowflake/cortex/_util.py +30 -1
  6. snowflake/ml/_internal/type_utils.py +3 -3
  7. snowflake/ml/data/__init__.py +5 -0
  8. snowflake/ml/model/_client/model/model_version_impl.py +7 -7
  9. snowflake/ml/model/_client/ops/model_ops.py +51 -30
  10. snowflake/ml/model/_client/ops/service_ops.py +13 -2
  11. snowflake/ml/model/_client/sql/model.py +0 -14
  12. snowflake/ml/model/_client/sql/service.py +25 -1
  13. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
  14. snowflake/ml/model/_packager/model_env/model_env.py +12 -0
  15. snowflake/ml/model/_packager/model_handlers/_utils.py +1 -1
  16. snowflake/ml/model/_packager/model_handlers/catboost.py +1 -1
  17. snowflake/ml/model/_packager/model_handlers/custom.py +3 -1
  18. snowflake/ml/model/_packager/model_handlers/lightgbm.py +2 -1
  19. snowflake/ml/model/_packager/model_handlers/sklearn.py +48 -1
  20. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -1
  21. snowflake/ml/model/_packager/model_handlers/tensorflow.py +23 -6
  22. snowflake/ml/model/_packager/model_handlers/torchscript.py +14 -14
  23. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +2 -3
  24. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +5 -0
  25. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -10
  26. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -9
  27. snowflake/ml/model/_packager/model_task/model_task_utils.py +1 -1
  28. snowflake/ml/model/_signatures/core.py +63 -16
  29. snowflake/ml/model/_signatures/pandas_handler.py +71 -27
  30. snowflake/ml/model/_signatures/pytorch_handler.py +2 -2
  31. snowflake/ml/model/_signatures/snowpark_handler.py +2 -1
  32. snowflake/ml/model/_signatures/tensorflow_handler.py +2 -2
  33. snowflake/ml/model/_signatures/utils.py +4 -0
  34. snowflake/ml/model/model_signature.py +38 -9
  35. snowflake/ml/model/type_hints.py +1 -1
  36. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +2 -4
  37. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +2 -4
  38. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +158 -1045
  39. snowflake/ml/monitoring/_manager/model_monitor_manager.py +106 -230
  40. snowflake/ml/monitoring/entities/model_monitor_config.py +10 -10
  41. snowflake/ml/monitoring/model_monitor.py +7 -96
  42. snowflake/ml/registry/registry.py +17 -29
  43. snowflake/ml/version.py +1 -1
  44. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/METADATA +31 -5
  45. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/RECORD +48 -47
  46. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/WHEEL +1 -1
  47. snowflake/ml/monitoring/entities/model_monitor_interval.py +0 -46
  48. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/LICENSE.txt +0 -0
  49. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/top_level.txt +0 -0
@@ -1,36 +1,23 @@
1
- import json
2
- import string
3
- import textwrap
4
1
  import typing
5
2
  from collections import Counter
6
- from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, TypedDict
7
-
8
- from importlib_resources import files
9
- from typing_extensions import Required
3
+ from typing import Any, Dict, List, Mapping, Optional, Set
10
4
 
11
5
  from snowflake import snowpark
12
- from snowflake.connector import errors
13
6
  from snowflake.ml._internal.utils import (
14
7
  db_utils,
15
- formatting,
16
8
  query_result_checker,
17
9
  sql_identifier,
18
10
  table_manager,
19
11
  )
20
- from snowflake.ml.model import type_hints
21
12
  from snowflake.ml.model._client.sql import _base
22
13
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
23
- from snowflake.ml.monitoring.entities import model_monitor_interval, output_score_type
24
- from snowflake.ml.monitoring.entities.model_monitor_interval import (
25
- ModelMonitorAggregationWindow,
26
- ModelMonitorRefreshInterval,
27
- )
28
- from snowflake.snowpark import DataFrame, exceptions, session, types
29
- from snowflake.snowpark._internal import type_utils
14
+ from snowflake.snowpark import session, types
30
15
 
31
16
  SNOWML_MONITORING_METADATA_TABLE_NAME = "_SYSTEM_MONITORING_METADATA"
32
- _SNOWML_MONITORING_TABLE_NAME_PREFIX = "_SNOWML_OBS_MONITORING_"
33
- _SNOWML_MONITORING_ACCURACY_TABLE_NAME_PREFIX = "_SNOWML_OBS_ACCURACY_"
17
+
18
+ MODEL_JSON_COL_NAME = "model"
19
+ MODEL_JSON_MODEL_NAME_FIELD = "model_name"
20
+ MODEL_JSON_VERSION_NAME_FIELD = "version_name"
34
21
 
35
22
  MONITOR_NAME_COL_NAME = "MONITOR_NAME"
36
23
  SOURCE_TABLE_NAME_COL_NAME = "SOURCE_TABLE_NAME"
@@ -44,84 +31,10 @@ PREDICTION_COL_NAMES_COL_NAME = "PREDICTION_COLUMN_NAMES"
44
31
  LABEL_COL_NAMES_COL_NAME = "LABEL_COLUMN_NAMES"
45
32
  ID_COL_NAMES_COL_NAME = "ID_COLUMN_NAMES"
46
33
 
47
- _DASHBOARD_UDTFS_COMMON_LIST = ["record_count"]
48
- _DASHBOARD_UDTFS_REGRESSION_LIST = ["rmse"]
49
-
50
-
51
- def _initialize_monitoring_metadata_tables(
52
- session: session.Session,
53
- database_name: sql_identifier.SqlIdentifier,
54
- schema_name: sql_identifier.SqlIdentifier,
55
- statement_params: Optional[Dict[str, Any]] = None,
56
- ) -> None:
57
- """Create tables necessary for Model Monitoring in provided schema.
58
-
59
- Args:
60
- session: Active Snowpark session.
61
- database_name: The database in which to setup resources for Model Monitoring.
62
- schema_name: The schema in which to setup resources for Model Monitoring.
63
- statement_params: Optional statement params for queries.
64
- """
65
- table_manager.create_single_table(
66
- session,
67
- database_name,
68
- schema_name,
69
- SNOWML_MONITORING_METADATA_TABLE_NAME,
70
- [
71
- (MONITOR_NAME_COL_NAME, "VARCHAR"),
72
- (SOURCE_TABLE_NAME_COL_NAME, "VARCHAR"),
73
- (FQ_MODEL_NAME_COL_NAME, "VARCHAR"),
74
- (VERSION_NAME_COL_NAME, "VARCHAR"),
75
- (FUNCTION_NAME_COL_NAME, "VARCHAR"),
76
- (TASK_COL_NAME, "VARCHAR"),
77
- (MONITORING_ENABLED_COL_NAME, "BOOLEAN"),
78
- (TIMESTAMP_COL_NAME_COL_NAME, "VARCHAR"),
79
- (PREDICTION_COL_NAMES_COL_NAME, "ARRAY"),
80
- (LABEL_COL_NAMES_COL_NAME, "ARRAY"),
81
- (ID_COL_NAMES_COL_NAME, "ARRAY"),
82
- ],
83
- statement_params=statement_params,
84
- )
85
-
86
-
87
- def _create_baseline_table_name(model_name: str, version_name: str) -> str:
88
- return f"_SNOWML_OBS_BASELINE_{model_name}_{version_name}"
89
-
90
-
91
- def _infer_numeric_categoric_feature_column_names(
92
- *,
93
- source_table_schema: Mapping[str, types.DataType],
94
- timestamp_column: sql_identifier.SqlIdentifier,
95
- id_columns: List[sql_identifier.SqlIdentifier],
96
- prediction_columns: List[sql_identifier.SqlIdentifier],
97
- label_columns: List[sql_identifier.SqlIdentifier],
98
- ) -> Tuple[List[sql_identifier.SqlIdentifier], List[sql_identifier.SqlIdentifier]]:
99
- cols_to_remove = {timestamp_column, *id_columns, *prediction_columns, *label_columns}
100
- cols_to_consider = [
101
- (col_name, source_table_schema[col_name]) for col_name in source_table_schema if col_name not in cols_to_remove
102
- ]
103
- numeric_cols = [
104
- sql_identifier.SqlIdentifier(column[0])
105
- for column in cols_to_consider
106
- if isinstance(column[1], types._NumericType)
107
- ]
108
- categorical_cols = [
109
- sql_identifier.SqlIdentifier(column[0])
110
- for column in cols_to_consider
111
- if isinstance(column[1], types.StringType) or isinstance(column[1], types.BooleanType)
112
- ]
113
- return (numeric_cols, categorical_cols)
114
-
115
-
116
- class _ModelMonitorParams(TypedDict):
117
- """Class to transfer model monitor parameters to the ModelMonitor class."""
118
34
 
119
- monitor_name: Required[str]
120
- fully_qualified_model_name: Required[str]
121
- version_name: Required[str]
122
- function_name: Required[str]
123
- prediction_columns: Required[List[sql_identifier.SqlIdentifier]]
124
- label_columns: Required[List[sql_identifier.SqlIdentifier]]
35
+ def _build_sql_list_from_columns(columns: List[sql_identifier.SqlIdentifier]) -> str:
36
+ sql_list = ", ".join([f"'{column}'" for column in columns])
37
+ return f"({sql_list})"
125
38
 
126
39
 
127
40
  class ModelMonitorSQLClient:
@@ -143,38 +56,95 @@ class ModelMonitorSQLClient:
143
56
  self._database_name = database_name
144
57
  self._schema_name = schema_name
145
58
 
146
- @staticmethod
147
- def initialize_monitoring_schema(
148
- session: session.Session,
149
- database_name: sql_identifier.SqlIdentifier,
150
- schema_name: sql_identifier.SqlIdentifier,
59
+ def _infer_qualified_schema(
60
+ self, database_name: Optional[sql_identifier.SqlIdentifier], schema_name: Optional[sql_identifier.SqlIdentifier]
61
+ ) -> str:
62
+ return f"{database_name or self._database_name}.{schema_name or self._schema_name}"
63
+
64
+ def create_model_monitor(
65
+ self,
66
+ *,
67
+ monitor_database: Optional[sql_identifier.SqlIdentifier],
68
+ monitor_schema: Optional[sql_identifier.SqlIdentifier],
69
+ monitor_name: sql_identifier.SqlIdentifier,
70
+ source_database: Optional[sql_identifier.SqlIdentifier],
71
+ source_schema: Optional[sql_identifier.SqlIdentifier],
72
+ source: sql_identifier.SqlIdentifier,
73
+ model_database: Optional[sql_identifier.SqlIdentifier],
74
+ model_schema: Optional[sql_identifier.SqlIdentifier],
75
+ model_name: sql_identifier.SqlIdentifier,
76
+ version_name: sql_identifier.SqlIdentifier,
77
+ function_name: str,
78
+ warehouse_name: sql_identifier.SqlIdentifier,
79
+ timestamp_column: sql_identifier.SqlIdentifier,
80
+ id_columns: List[sql_identifier.SqlIdentifier],
81
+ prediction_score_columns: List[sql_identifier.SqlIdentifier],
82
+ prediction_class_columns: List[sql_identifier.SqlIdentifier],
83
+ actual_score_columns: List[sql_identifier.SqlIdentifier],
84
+ actual_class_columns: List[sql_identifier.SqlIdentifier],
85
+ refresh_interval: str,
86
+ aggregation_window: str,
87
+ baseline_database: Optional[sql_identifier.SqlIdentifier] = None,
88
+ baseline_schema: Optional[sql_identifier.SqlIdentifier] = None,
89
+ baseline: Optional[sql_identifier.SqlIdentifier] = None,
151
90
  statement_params: Optional[Dict[str, Any]] = None,
152
91
  ) -> None:
153
- """Initialize tables for tracking metadata associated with model monitoring.
154
-
155
- Args:
156
- session: The Snowpark Session to connect with Snowflake.
157
- database_name: The database in which to setup resources for Model Monitoring.
158
- schema_name: The schema in which to setup resources for Model Monitoring.
159
- statement_params: Optional set of statement_params to include with query.
160
- """
161
- # Create metadata management tables
162
- _initialize_monitoring_metadata_tables(session, database_name, schema_name, statement_params)
92
+ baseline_sql = ""
93
+ if baseline:
94
+ baseline_sql = f"BASELINE='{self._infer_qualified_schema(baseline_database, baseline_schema)}.{baseline}'"
95
+ query_result_checker.SqlResultValidator(
96
+ self._sql_client._session,
97
+ f"""
98
+ CREATE MODEL MONITOR {self._infer_qualified_schema(monitor_database, monitor_schema)}.{monitor_name}
99
+ WITH
100
+ MODEL='{self._infer_qualified_schema(model_database, model_schema)}.{model_name}'
101
+ VERSION='{version_name}'
102
+ FUNCTION='{function_name}'
103
+ WAREHOUSE='{warehouse_name}'
104
+ SOURCE='{self._infer_qualified_schema(source_database, source_schema)}.{source}'
105
+ ID_COLUMNS={_build_sql_list_from_columns(id_columns)}
106
+ PREDICTION_SCORE_COLUMNS={_build_sql_list_from_columns(prediction_score_columns)}
107
+ PREDICTION_CLASS_COLUMNS={_build_sql_list_from_columns(prediction_class_columns)}
108
+ ACTUAL_SCORE_COLUMNS={_build_sql_list_from_columns(actual_score_columns)}
109
+ ACTUAL_CLASS_COLUMNS={_build_sql_list_from_columns(actual_class_columns)}
110
+ TIMESTAMP_COLUMN='{timestamp_column}'
111
+ REFRESH_INTERVAL='{refresh_interval}'
112
+ AGGREGATION_WINDOW='{aggregation_window}'
113
+ {baseline_sql}""",
114
+ statement_params=statement_params,
115
+ ).has_column("status").has_dimensions(1, 1).validate()
163
116
 
164
- def _validate_is_initialized(self) -> bool:
165
- """Validates whether monitoring metadata has been initialized.
117
+ def drop_model_monitor(
118
+ self,
119
+ *,
120
+ database_name: Optional[sql_identifier.SqlIdentifier] = None,
121
+ schema_name: Optional[sql_identifier.SqlIdentifier] = None,
122
+ monitor_name: sql_identifier.SqlIdentifier,
123
+ statement_params: Optional[Dict[str, Any]] = None,
124
+ ) -> None:
125
+ search_database_name = database_name or self._database_name
126
+ search_schema_name = schema_name or self._schema_name
127
+ query_result_checker.SqlResultValidator(
128
+ self._sql_client._session,
129
+ f"DROP MODEL MONITOR {search_database_name}.{search_schema_name}.{monitor_name}",
130
+ statement_params=statement_params,
131
+ ).validate()
166
132
 
167
- Returns:
168
- boolean to indicate whether tables have been initialized.
169
- """
170
- try:
171
- return table_manager.validate_table_exist(
133
+ def show_model_monitors(
134
+ self,
135
+ *,
136
+ statement_params: Optional[Dict[str, Any]] = None,
137
+ ) -> List[snowpark.Row]:
138
+ fully_qualified_schema_name = ".".join([self._database_name.identifier(), self._schema_name.identifier()])
139
+ return (
140
+ query_result_checker.SqlResultValidator(
172
141
  self._sql_client._session,
173
- SNOWML_MONITORING_METADATA_TABLE_NAME,
174
- f"{self._database_name}.{self._schema_name}",
142
+ f"SHOW MODEL MONITORS IN {fully_qualified_schema_name}",
143
+ statement_params=statement_params,
175
144
  )
176
- except exceptions.SnowparkSQLException:
177
- return False
145
+ .has_column("name", allow_empty=True)
146
+ .validate()
147
+ )
178
148
 
179
149
  def _validate_unique_columns(
180
150
  self,
@@ -191,53 +161,24 @@ class ModelMonitorSQLClient:
191
161
 
192
162
  def validate_existence_by_name(
193
163
  self,
164
+ *,
165
+ database_name: Optional[sql_identifier.SqlIdentifier] = None,
166
+ schema_name: Optional[sql_identifier.SqlIdentifier] = None,
194
167
  monitor_name: sql_identifier.SqlIdentifier,
195
168
  statement_params: Optional[Dict[str, Any]] = None,
196
169
  ) -> bool:
170
+ search_database_name = database_name or self._database_name
171
+ search_schema_name = schema_name or self._schema_name
197
172
  res = (
198
173
  query_result_checker.SqlResultValidator(
199
174
  self._sql_client._session,
200
- f"""SELECT {FQ_MODEL_NAME_COL_NAME}, {VERSION_NAME_COL_NAME}
201
- FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}
202
- WHERE {MONITOR_NAME_COL_NAME} = '{monitor_name}'""",
203
- statement_params=statement_params,
204
- )
205
- .has_column(FQ_MODEL_NAME_COL_NAME, allow_empty=True)
206
- .has_column(VERSION_NAME_COL_NAME, allow_empty=True)
207
- .validate()
208
- )
209
- return len(res) >= 1
210
-
211
- def validate_existence(
212
- self,
213
- fully_qualified_model_name: str,
214
- version_name: sql_identifier.SqlIdentifier,
215
- statement_params: Optional[Dict[str, Any]] = None,
216
- ) -> bool:
217
- """Validate existence of a ModelMonitor on a Model Version.
218
-
219
- Args:
220
- fully_qualified_model_name: Fully qualified name of model.
221
- version_name: Name of model version.
222
- statement_params: Optional set of statement_params to include with query.
223
-
224
- Returns:
225
- Boolean indicating whether monitor exists on model version.
226
- """
227
- res = (
228
- query_result_checker.SqlResultValidator(
229
- self._sql_client._session,
230
- f"""SELECT {FQ_MODEL_NAME_COL_NAME}, {VERSION_NAME_COL_NAME}
231
- FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}
232
- WHERE {FQ_MODEL_NAME_COL_NAME} = '{fully_qualified_model_name}'
233
- AND {VERSION_NAME_COL_NAME} = '{version_name}'""",
175
+ f"SHOW MODEL MONITORS LIKE '{monitor_name.resolved()}' IN {search_database_name}.{search_schema_name}",
234
176
  statement_params=statement_params,
235
177
  )
236
- .has_column(FQ_MODEL_NAME_COL_NAME, allow_empty=True)
237
- .has_column(VERSION_NAME_COL_NAME, allow_empty=True)
178
+ .has_column("name", allow_empty=True)
238
179
  .validate()
239
180
  )
240
- return len(res) >= 1
181
+ return len(res) == 1
241
182
 
242
183
  def validate_monitor_warehouse(
243
184
  self,
@@ -261,115 +202,47 @@ class ModelMonitorSQLClient:
261
202
  ):
262
203
  raise ValueError(f"Warehouse '{warehouse_name}' not found.")
263
204
 
264
- def add_dashboard_udtfs(
265
- self,
266
- monitor_name: sql_identifier.SqlIdentifier,
267
- model_name: sql_identifier.SqlIdentifier,
268
- model_version_name: sql_identifier.SqlIdentifier,
269
- task: type_hints.Task,
270
- score_type: output_score_type.OutputScoreType,
271
- output_columns: List[sql_identifier.SqlIdentifier],
272
- ground_truth_columns: List[sql_identifier.SqlIdentifier],
273
- statement_params: Optional[Dict[str, Any]] = None,
274
- ) -> None:
275
- udtf_name_query_map = self._create_dashboard_udtf_queries(
276
- monitor_name,
277
- model_name,
278
- model_version_name,
279
- task,
280
- score_type,
281
- output_columns,
282
- ground_truth_columns,
283
- )
284
- for udtf_query in udtf_name_query_map.values():
285
- query_result_checker.SqlResultValidator(
286
- self._sql_client._session,
287
- f"""{udtf_query}""",
288
- statement_params=statement_params,
289
- ).validate()
290
-
291
- def get_monitoring_table_fully_qualified_name(
292
- self,
293
- model_name: sql_identifier.SqlIdentifier,
294
- model_version_name: sql_identifier.SqlIdentifier,
295
- ) -> str:
296
- table_name = f"{_SNOWML_MONITORING_TABLE_NAME_PREFIX}_{model_name}_{model_version_name}"
297
- return table_manager.get_fully_qualified_table_name(self._database_name, self._schema_name, table_name)
298
-
299
- def get_accuracy_monitoring_table_fully_qualified_name(
300
- self,
301
- model_name: sql_identifier.SqlIdentifier,
302
- model_version_name: sql_identifier.SqlIdentifier,
303
- ) -> str:
304
- table_name = f"{_SNOWML_MONITORING_ACCURACY_TABLE_NAME_PREFIX}_{model_name}_{model_version_name}"
305
- return table_manager.get_fully_qualified_table_name(self._database_name, self._schema_name, table_name)
306
-
307
- def _create_dashboard_udtf_queries(
308
- self,
309
- monitor_name: sql_identifier.SqlIdentifier,
310
- model_name: sql_identifier.SqlIdentifier,
311
- model_version_name: sql_identifier.SqlIdentifier,
312
- task: type_hints.Task,
313
- score_type: output_score_type.OutputScoreType,
314
- output_columns: List[sql_identifier.SqlIdentifier],
315
- ground_truth_columns: List[sql_identifier.SqlIdentifier],
316
- ) -> Mapping[str, str]:
317
- query_files = files("snowflake.ml.monitoring._client")
318
- # TODO(apgupta): Expand list of queries based on model objective and score type.
319
- queries_list = []
320
- queries_list.extend(_DASHBOARD_UDTFS_COMMON_LIST)
321
- if task == type_hints.Task.TABULAR_REGRESSION:
322
- queries_list.extend(_DASHBOARD_UDTFS_REGRESSION_LIST)
323
- var_map = {
324
- "MODEL_MONITOR_NAME": monitor_name,
325
- "MONITORING_TABLE": self.get_monitoring_table_fully_qualified_name(model_name, model_version_name),
326
- "MONITORING_PRED_LABEL_JOINED_TABLE": self.get_accuracy_monitoring_table_fully_qualified_name(
327
- model_name, model_version_name
328
- ),
329
- "OUTPUT_COLUMN_NAME": output_columns[0],
330
- "GROUND_TRUTH_COLUMN_NAME": ground_truth_columns[0],
331
- }
332
-
333
- udf_name_query_map = {}
334
- for q in queries_list:
335
- q_template = query_files.joinpath(f"queries/{q}.ssql").read_text()
336
- q_actual = string.Template(q_template).substitute(var_map)
337
- udf_name_query_map[q] = q_actual
338
- return udf_name_query_map
339
-
340
- def _validate_columns_exist_in_source_table(
205
+ def _validate_columns_exist_in_source(
341
206
  self,
342
207
  *,
343
- table_schema: Mapping[str, types.DataType],
344
- source_table_name: sql_identifier.SqlIdentifier,
208
+ source_column_schema: Mapping[str, types.DataType],
345
209
  timestamp_column: sql_identifier.SqlIdentifier,
346
- prediction_columns: List[sql_identifier.SqlIdentifier],
347
- label_columns: List[sql_identifier.SqlIdentifier],
210
+ prediction_score_columns: List[sql_identifier.SqlIdentifier],
211
+ prediction_class_columns: List[sql_identifier.SqlIdentifier],
212
+ actual_score_columns: List[sql_identifier.SqlIdentifier],
213
+ actual_class_columns: List[sql_identifier.SqlIdentifier],
348
214
  id_columns: List[sql_identifier.SqlIdentifier],
349
215
  ) -> None:
350
216
  """Ensures all columns exist in the source table.
351
217
 
352
218
  Args:
353
- table_schema: Dictionary of column names and types in the source table.
354
- source_table_name: Name of the table with model data to monitor.
219
+ source_column_schema: Dictionary of column names and types in the source.
355
220
  timestamp_column: Name of the timestamp column.
356
- prediction_columns: List of prediction column names.
357
- label_columns: List of label column names.
221
+ prediction_score_columns: List of prediction score column names.
222
+ prediction_class_columns: List of prediction class names.
223
+ actual_score_columns: List of actual score column names.
224
+ actual_class_columns: List of actual class column names.
358
225
  id_columns: List of id column names.
359
226
 
360
227
  Raises:
361
- ValueError: If any of the columns do not exist in the source table.
228
+ ValueError: If any of the columns do not exist in the source.
362
229
  """
363
230
 
364
- if timestamp_column not in table_schema:
365
- raise ValueError(f"Timestamp column {timestamp_column} does not exist in table {source_table_name}.")
231
+ if timestamp_column not in source_column_schema:
232
+ raise ValueError(f"Timestamp column {timestamp_column} does not exist in source.")
366
233
 
367
- if not all([column_name in table_schema for column_name in prediction_columns]):
368
- raise ValueError(f"Prediction column(s): {prediction_columns} do not exist in table {source_table_name}.")
369
- if not all([column_name in table_schema for column_name in label_columns]):
370
- raise ValueError(f"Label column(s): {label_columns} do not exist in table {source_table_name}.")
371
- if not all([column_name in table_schema for column_name in id_columns]):
372
- raise ValueError(f"ID column(s): {id_columns} do not exist in table {source_table_name}.")
234
+ if not all([column_name in source_column_schema for column_name in prediction_score_columns]):
235
+ raise ValueError(f"Prediction Score column(s): {prediction_score_columns} do not exist in source.")
236
+ if not all([column_name in source_column_schema for column_name in prediction_class_columns]):
237
+ raise ValueError(f"Prediction Class column(s): {prediction_class_columns} do not exist in source.")
238
+ if not all([column_name in source_column_schema for column_name in actual_score_columns]):
239
+ raise ValueError(f"Actual Score column(s): {actual_score_columns} do not exist in source.")
240
+
241
+ if not all([column_name in source_column_schema for column_name in actual_class_columns]):
242
+ raise ValueError(f"Actual Class column(s): {actual_class_columns} do not exist in source.")
243
+
244
+ if not all([column_name in source_column_schema for column_name in id_columns]):
245
+ raise ValueError(f"ID column(s): {id_columns} do not exist in source.")
373
246
 
374
247
  def _validate_timestamp_column_type(
375
248
  self, table_schema: Mapping[str, types.DataType], timestamp_column: sql_identifier.SqlIdentifier
@@ -490,190 +363,37 @@ class ModelMonitorSQLClient:
490
363
  f"Model function expected: {inputs} but got {table_schema_without_special_columns}"
491
364
  )
492
365
 
493
- def get_model_monitor_by_name(
494
- self,
495
- monitor_name: sql_identifier.SqlIdentifier,
496
- statement_params: Optional[Dict[str, Any]] = None,
497
- ) -> _ModelMonitorParams:
498
- """Fetch metadata for a Model Monitor by name.
499
-
500
- Args:
501
- monitor_name: Name of ModelMonitor to fetch.
502
- statement_params: Optional set of statement_params to include with query.
503
-
504
- Returns:
505
- _ModelMonitorParams dict with Name of monitor, fully qualified model name,
506
- model version name, model function name, prediction_col, label_col.
507
-
508
- Raises:
509
- ValueError: If multiple ModelMonitors exist with the same name.
510
- """
511
- try:
512
- res = (
513
- query_result_checker.SqlResultValidator(
514
- self._sql_client._session,
515
- f"""SELECT {FQ_MODEL_NAME_COL_NAME}, {VERSION_NAME_COL_NAME}, {FUNCTION_NAME_COL_NAME},
516
- {PREDICTION_COL_NAMES_COL_NAME}, {LABEL_COL_NAMES_COL_NAME}
517
- FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}
518
- WHERE {MONITOR_NAME_COL_NAME} = '{monitor_name}'""",
519
- statement_params=statement_params,
520
- )
521
- .has_column(FQ_MODEL_NAME_COL_NAME)
522
- .has_column(VERSION_NAME_COL_NAME)
523
- .has_column(FUNCTION_NAME_COL_NAME)
524
- .has_column(PREDICTION_COL_NAMES_COL_NAME)
525
- .has_column(LABEL_COL_NAMES_COL_NAME)
526
- .validate()
527
- )
528
- except errors.DataError:
529
- raise ValueError(f"Failed to find any monitor with name '{monitor_name}'")
530
-
531
- if len(res) > 1:
532
- raise ValueError(f"Invalid state. Multiple Monitors exist with name '{monitor_name}'")
533
-
534
- return _ModelMonitorParams(
535
- monitor_name=str(monitor_name),
536
- fully_qualified_model_name=res[0][FQ_MODEL_NAME_COL_NAME],
537
- version_name=res[0][VERSION_NAME_COL_NAME],
538
- function_name=res[0][FUNCTION_NAME_COL_NAME],
539
- prediction_columns=[
540
- sql_identifier.SqlIdentifier(prediction_column)
541
- for prediction_column in json.loads(res[0][PREDICTION_COL_NAMES_COL_NAME])
542
- ],
543
- label_columns=[
544
- sql_identifier.SqlIdentifier(label_column)
545
- for label_column in json.loads(res[0][LABEL_COL_NAMES_COL_NAME])
546
- ],
547
- )
548
-
549
- def get_model_monitor_by_model_version(
366
+ def validate_source(
550
367
  self,
551
368
  *,
552
- model_db: sql_identifier.SqlIdentifier,
553
- model_schema: sql_identifier.SqlIdentifier,
554
- model_name: sql_identifier.SqlIdentifier,
555
- version_name: sql_identifier.SqlIdentifier,
556
- statement_params: Optional[Dict[str, Any]] = None,
557
- ) -> _ModelMonitorParams:
558
- """Fetch metadata for a Model Monitor by model version.
559
-
560
- Args:
561
- model_db: Database of model.
562
- model_schema: Schema of model.
563
- model_name: Model name.
564
- version_name: Model version name
565
- statement_params: Optional set of statement_params to include with queries.
566
-
567
- Returns:
568
- _ModelMonitorParams dict with Name of monitor, fully qualified model name,
569
- model version name, model function name, prediction_col, label_col.
570
-
571
- Raises:
572
- ValueError: If multiple ModelMonitors exist with the same name.
573
- """
574
- res = (
575
- query_result_checker.SqlResultValidator(
576
- self._sql_client._session,
577
- f"""SELECT {MONITOR_NAME_COL_NAME}, {FQ_MODEL_NAME_COL_NAME},
578
- {VERSION_NAME_COL_NAME}, {FUNCTION_NAME_COL_NAME}
579
- FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}
580
- WHERE {FQ_MODEL_NAME_COL_NAME} = '{model_db}.{model_schema}.{model_name}'
581
- AND {VERSION_NAME_COL_NAME} = '{version_name}'""",
582
- statement_params=statement_params,
583
- )
584
- .has_column(MONITOR_NAME_COL_NAME)
585
- .has_column(FQ_MODEL_NAME_COL_NAME)
586
- .has_column(VERSION_NAME_COL_NAME)
587
- .has_column(FUNCTION_NAME_COL_NAME)
588
- .validate()
589
- )
590
- if len(res) > 1:
591
- raise ValueError(
592
- f"Invalid state. Multiple Monitors exist for model: '{model_name}' and version: '{version_name}'"
593
- )
594
- return _ModelMonitorParams(
595
- monitor_name=res[0][MONITOR_NAME_COL_NAME],
596
- fully_qualified_model_name=res[0][FQ_MODEL_NAME_COL_NAME],
597
- version_name=res[0][VERSION_NAME_COL_NAME],
598
- function_name=res[0][FUNCTION_NAME_COL_NAME],
599
- prediction_columns=[
600
- sql_identifier.SqlIdentifier(prediction_column)
601
- for prediction_column in json.loads(res[0][PREDICTION_COL_NAMES_COL_NAME])
602
- ],
603
- label_columns=[
604
- sql_identifier.SqlIdentifier(label_column)
605
- for label_column in json.loads(res[0][LABEL_COL_NAMES_COL_NAME])
606
- ],
607
- )
608
-
609
- def get_score_type(
610
- self,
611
- task: type_hints.Task,
612
- source_table_name: sql_identifier.SqlIdentifier,
613
- prediction_columns: List[sql_identifier.SqlIdentifier],
614
- ) -> output_score_type.OutputScoreType:
615
- """Infer score type given model task and prediction table columns.
616
-
617
- Args:
618
- task: Model task
619
- source_table_name: Source data table containing model outputs.
620
- prediction_columns: columns in source data table corresponding to model outputs.
621
-
622
- Returns:
623
- OutputScoreType for model.
624
- """
625
- table_schema: Mapping[str, types.DataType] = table_manager.get_table_schema_types(
626
- self._sql_client._session,
627
- self._database_name,
628
- self._schema_name,
629
- source_table_name,
630
- )
631
- return output_score_type.OutputScoreType.deduce_score_type(table_schema, prediction_columns, task)
632
-
633
- def validate_source_table(
634
- self,
635
- source_table_name: sql_identifier.SqlIdentifier,
369
+ source_database: Optional[sql_identifier.SqlIdentifier],
370
+ source_schema: Optional[sql_identifier.SqlIdentifier],
371
+ source: sql_identifier.SqlIdentifier,
636
372
  timestamp_column: sql_identifier.SqlIdentifier,
637
- prediction_columns: List[sql_identifier.SqlIdentifier],
638
- label_columns: List[sql_identifier.SqlIdentifier],
373
+ prediction_score_columns: List[sql_identifier.SqlIdentifier],
374
+ prediction_class_columns: List[sql_identifier.SqlIdentifier],
375
+ actual_score_columns: List[sql_identifier.SqlIdentifier],
376
+ actual_class_columns: List[sql_identifier.SqlIdentifier],
639
377
  id_columns: List[sql_identifier.SqlIdentifier],
640
- model_function: model_manifest_schema.ModelFunctionInfo,
641
378
  ) -> None:
642
- # Validate source table exists
643
- if not table_manager.validate_table_exist(
644
- self._sql_client._session,
645
- source_table_name,
646
- f"{self._database_name}.{self._schema_name}",
647
- ):
648
- raise ValueError(
649
- f"Table {source_table_name} does not exist in schema {self._database_name}.{self._schema_name}."
650
- )
651
- table_schema: Mapping[str, types.DataType] = table_manager.get_table_schema_types(
379
+ source_database = source_database or self._database_name
380
+ source_schema = source_schema or self._schema_name
381
+ # Get Schema of the source. Implicitly validates that the source exists.
382
+ source_column_schema: Mapping[str, types.DataType] = table_manager.get_table_schema_types(
652
383
  self._sql_client._session,
653
- self._database_name,
654
- self._schema_name,
655
- source_table_name,
384
+ source_database,
385
+ source_schema,
386
+ source,
656
387
  )
657
- self._validate_columns_exist_in_source_table(
658
- table_schema=table_schema,
659
- source_table_name=source_table_name,
388
+ self._validate_columns_exist_in_source(
389
+ source_column_schema=source_column_schema,
660
390
  timestamp_column=timestamp_column,
661
- prediction_columns=prediction_columns,
662
- label_columns=label_columns,
391
+ prediction_score_columns=prediction_score_columns,
392
+ prediction_class_columns=prediction_class_columns,
393
+ actual_score_columns=actual_score_columns,
394
+ actual_class_columns=actual_class_columns,
663
395
  id_columns=id_columns,
664
396
  )
665
- self._validate_column_types(
666
- table_schema=table_schema,
667
- timestamp_column=timestamp_column,
668
- id_columns=id_columns,
669
- prediction_columns=prediction_columns,
670
- label_columns=label_columns,
671
- )
672
- self._validate_source_table_features_shape(
673
- table_schema=table_schema,
674
- special_columns={timestamp_column, *id_columns, *prediction_columns, *label_columns},
675
- model_function=model_function,
676
- )
677
397
 
678
398
  def delete_monitor_metadata(
679
399
  self,
@@ -691,645 +411,38 @@ class ModelMonitorSQLClient:
691
411
  WHERE {MONITOR_NAME_COL_NAME} = '{name}'""",
692
412
  ).collect(statement_params=statement_params)
693
413
 
694
- def delete_baseline_table(
695
- self,
696
- fully_qualified_model_name: str,
697
- version_name: str,
698
- statement_params: Optional[Dict[str, Any]] = None,
699
- ) -> None:
700
- """Delete the baseline table corresponding to a particular model and version.
701
-
702
- Args:
703
- fully_qualified_model_name: Fully qualified name of the model.
704
- version_name: Name of the model version.
705
- statement_params: Optional set of statement_params to include with query.
706
- """
707
- table_name = _create_baseline_table_name(fully_qualified_model_name, version_name)
708
- self._sql_client._session.sql(
709
- f"""DROP TABLE IF EXISTS {self._database_name}.{self._schema_name}.{table_name}"""
710
- ).collect(statement_params=statement_params)
711
-
712
- def delete_dynamic_tables(
713
- self,
714
- fully_qualified_model_name: str,
715
- version_name: str,
716
- statement_params: Optional[Dict[str, Any]] = None,
717
- ) -> None:
718
- """Delete the dynamic tables corresponding to a particular model and version.
719
-
720
- Args:
721
- fully_qualified_model_name: Fully qualified name of the model.
722
- version_name: Name of the model version.
723
- statement_params: Optional set of statement_params to include with query.
724
- """
725
- _, _, model_name = sql_identifier.parse_fully_qualified_name(fully_qualified_model_name)
726
- model_id = sql_identifier.SqlIdentifier(model_name)
727
- version_id = sql_identifier.SqlIdentifier(version_name)
728
- monitoring_table_name = self.get_monitoring_table_fully_qualified_name(model_id, version_id)
729
- self._sql_client._session.sql(f"""DROP DYNAMIC TABLE IF EXISTS {monitoring_table_name}""").collect(
730
- statement_params=statement_params
731
- )
732
- accuracy_table_name = self.get_accuracy_monitoring_table_fully_qualified_name(model_id, version_id)
733
- self._sql_client._session.sql(f"""DROP DYNAMIC TABLE IF EXISTS {accuracy_table_name}""").collect(
734
- statement_params=statement_params
735
- )
736
-
737
- def create_monitor_on_model_version(
738
- self,
739
- monitor_name: sql_identifier.SqlIdentifier,
740
- source_table_name: sql_identifier.SqlIdentifier,
741
- fully_qualified_model_name: str,
742
- version_name: sql_identifier.SqlIdentifier,
743
- function_name: str,
744
- timestamp_column: sql_identifier.SqlIdentifier,
745
- prediction_columns: List[sql_identifier.SqlIdentifier],
746
- label_columns: List[sql_identifier.SqlIdentifier],
747
- id_columns: List[sql_identifier.SqlIdentifier],
748
- task: type_hints.Task,
749
- statement_params: Optional[Dict[str, Any]] = None,
750
- ) -> None:
751
- """
752
- Creates a ModelMonitor on a Model Version from the Snowflake Model Registry. Creates public schema for metadata.
753
-
754
- Args:
755
- monitor_name: Name of monitor object to create.
756
- source_table_name: Name of source data table to monitor.
757
- fully_qualified_model_name: fully qualified name of model to monitor '<db>.<schema>.<model_name>'.
758
- version_name: model version name to monitor.
759
- function_name: function_name to monitor in model version.
760
- timestamp_column: timestamp column name.
761
- prediction_columns: list of prediction column names.
762
- label_columns: list of label column names.
763
- id_columns: list of id column names.
764
- task: Task of the model, e.g. TABULAR_REGRESSION.
765
- statement_params: Optional dict of statement_params to include with queries.
766
-
767
- Raises:
768
- ValueError: If model version is already monitored.
769
- """
770
- # Validate monitor does not already exist on model version.
771
- if self.validate_existence(fully_qualified_model_name, version_name, statement_params):
772
- raise ValueError(f"Model {fully_qualified_model_name} Version {version_name} is already monitored!")
773
-
774
- if self.validate_existence_by_name(monitor_name, statement_params):
775
- raise ValueError(f"Model Monitor with name '{monitor_name}' already exists!")
776
-
777
- prediction_columns_for_select = formatting.format_value_for_select(prediction_columns)
778
- label_columns_for_select = formatting.format_value_for_select(label_columns)
779
- id_columns_for_select = formatting.format_value_for_select(id_columns)
780
- query_result_checker.SqlResultValidator(
781
- self._sql_client._session,
782
- textwrap.dedent(
783
- f"""INSERT INTO {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}
784
- ({MONITOR_NAME_COL_NAME}, {SOURCE_TABLE_NAME_COL_NAME}, {FQ_MODEL_NAME_COL_NAME},
785
- {VERSION_NAME_COL_NAME}, {FUNCTION_NAME_COL_NAME}, {TASK_COL_NAME},
786
- {MONITORING_ENABLED_COL_NAME}, {TIMESTAMP_COL_NAME_COL_NAME},
787
- {PREDICTION_COL_NAMES_COL_NAME}, {LABEL_COL_NAMES_COL_NAME},
788
- {ID_COL_NAMES_COL_NAME})
789
- SELECT '{monitor_name}', '{source_table_name}', '{fully_qualified_model_name}',
790
- '{version_name}', '{function_name}', '{task.value}', TRUE, '{timestamp_column}',
791
- {prediction_columns_for_select}, {label_columns_for_select}, {id_columns_for_select}"""
792
- ),
793
- statement_params=statement_params,
794
- ).insertion_success(expected_num_rows=1).validate()
795
-
796
- def initialize_baseline_table(
797
- self,
798
- model_name: sql_identifier.SqlIdentifier,
799
- version_name: sql_identifier.SqlIdentifier,
800
- source_table_name: str,
801
- columns_to_drop: Optional[List[sql_identifier.SqlIdentifier]] = None,
802
- statement_params: Optional[Dict[str, Any]] = None,
803
- ) -> None:
804
- """
805
- Initializes the baseline table for a Model Version. Creates schema for baseline data using the source table.
806
-
807
- Args:
808
- model_name: name of model to monitor.
809
- version_name: model version name to monitor.
810
- source_table_name: name of the user's table containing their model data.
811
- columns_to_drop: special columns in the source table to be excluded from baseline tables.
812
- statement_params: Optional dict of statement_params to include with queries.
813
- """
814
- table_schema = table_manager.get_table_schema_types(
815
- self._sql_client._session,
816
- database=self._database_name,
817
- schema=self._schema_name,
818
- table_name=source_table_name,
819
- )
820
-
821
- if columns_to_drop is None:
822
- columns_to_drop = []
823
-
824
- table_manager.create_single_table(
825
- self._sql_client._session,
826
- self._database_name,
827
- self._schema_name,
828
- _create_baseline_table_name(model_name, version_name),
829
- [
830
- (k, type_utils.convert_sp_to_sf_type(v))
831
- for k, v in table_schema.items()
832
- if sql_identifier.SqlIdentifier(k) not in columns_to_drop
833
- ],
834
- statement_params=statement_params,
835
- )
836
-
837
- def get_all_model_monitor_metadata(
838
- self,
839
- statement_params: Optional[Dict[str, Any]] = None,
840
- ) -> List[snowpark.Row]:
841
- """Get the metadata for all model monitors in the given schema.
842
-
843
- Args:
844
- statement_params: Optional dict of statement_params to include with queries.
845
-
846
- Returns:
847
- List of snowpark.Row containing metadata for each model monitor.
848
- """
849
- return query_result_checker.SqlResultValidator(
850
- self._sql_client._session,
851
- f"""SELECT *
852
- FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}""",
853
- statement_params=statement_params,
854
- ).validate()
855
-
856
- def materialize_baseline_dataframe(
857
- self,
858
- baseline_df: DataFrame,
859
- fully_qualified_model_name: str,
860
- model_version_name: sql_identifier.SqlIdentifier,
861
- statement_params: Optional[Dict[str, Any]] = None,
862
- ) -> None:
863
- """
864
- Materialize baseline dataframe to a permanent snowflake table. This method
865
- truncates (overwrite without dropping) any existing data in the baseline table.
866
-
867
- Args:
868
- baseline_df: dataframe containing baseline data that monitored data will be compared against.
869
- fully_qualified_model_name: name of the model.
870
- model_version_name: model version name to monitor.
871
- statement_params: Optional dict of statement_params to include with queries.
872
-
873
- Raises:
874
- ValueError: If no baseline table was initialized.
875
- """
876
-
877
- _, _, model_name = sql_identifier.parse_fully_qualified_name(fully_qualified_model_name)
878
- baseline_table_name = _create_baseline_table_name(model_name, model_version_name)
879
-
880
- baseline_table_exists = db_utils.db_object_exists(
881
- self._sql_client._session,
882
- db_utils.SnowflakeDbObjectType.TABLE,
883
- sql_identifier.SqlIdentifier(baseline_table_name),
884
- database_name=self._database_name,
885
- schema_name=self._schema_name,
886
- statement_params=statement_params,
887
- )
888
- if not baseline_table_exists:
889
- raise ValueError(
890
- f"Baseline table '{baseline_table_name}' does not exist for model: "
891
- f"'{model_name}' and model_version: '{model_version_name}'"
892
- )
893
-
894
- fully_qualified_baseline_table_name = [self._database_name, self._schema_name, baseline_table_name]
895
-
896
- try:
897
- # Truncate overwrites by clearing the rows in the table, instead of dropping the table.
898
- # This lets us keep the schema to validate the baseline_df against.
899
- baseline_df.write.mode("truncate").save_as_table(
900
- fully_qualified_baseline_table_name, statement_params=statement_params
901
- )
902
- except exceptions.SnowparkSQLException as e:
903
- raise ValueError(
904
- f"""Failed to save baseline dataframe.
905
- Ensure that the baseline dataframe columns match those provided in your monitored table: {e}"""
906
- )
907
-
908
- def _alter_monitor_dynamic_tables(
414
+ def _alter_monitor(
909
415
  self,
910
416
  operation: str,
911
- model_name: sql_identifier.SqlIdentifier,
912
- version_name: sql_identifier.SqlIdentifier,
417
+ monitor_name: sql_identifier.SqlIdentifier,
913
418
  statement_params: Optional[Dict[str, Any]] = None,
914
419
  ) -> None:
915
420
  if operation not in {"SUSPEND", "RESUME"}:
916
421
  raise ValueError(f"Operation {operation} not supported for altering Dynamic Tables")
917
- fq_monitor_dt_name = self.get_monitoring_table_fully_qualified_name(model_name, version_name)
918
- query_result_checker.SqlResultValidator(
919
- self._sql_client._session,
920
- f"""ALTER DYNAMIC TABLE {fq_monitor_dt_name} {operation}""",
921
- statement_params=statement_params,
922
- ).has_column("status").has_dimensions(1, 1).validate()
923
-
924
- fq_accuracy_dt_name = self.get_accuracy_monitoring_table_fully_qualified_name(model_name, version_name)
925
422
  query_result_checker.SqlResultValidator(
926
423
  self._sql_client._session,
927
- f"""ALTER DYNAMIC TABLE {fq_accuracy_dt_name} {operation}""",
424
+ f"""ALTER MODEL MONITOR {self._database_name}.{self._schema_name}.{monitor_name} {operation}""",
928
425
  statement_params=statement_params,
929
426
  ).has_column("status").has_dimensions(1, 1).validate()
930
427
 
931
- def suspend_monitor_dynamic_tables(
428
+ def suspend_monitor(
932
429
  self,
933
- model_name: sql_identifier.SqlIdentifier,
934
- version_name: sql_identifier.SqlIdentifier,
430
+ monitor_name: sql_identifier.SqlIdentifier,
935
431
  statement_params: Optional[Dict[str, Any]] = None,
936
432
  ) -> None:
937
- self._alter_monitor_dynamic_tables(
433
+ self._alter_monitor(
938
434
  operation="SUSPEND",
939
- model_name=model_name,
940
- version_name=version_name,
435
+ monitor_name=monitor_name,
941
436
  statement_params=statement_params,
942
437
  )
943
438
 
944
- def resume_monitor_dynamic_tables(
439
+ def resume_monitor(
945
440
  self,
946
- model_name: sql_identifier.SqlIdentifier,
947
- version_name: sql_identifier.SqlIdentifier,
441
+ monitor_name: sql_identifier.SqlIdentifier,
948
442
  statement_params: Optional[Dict[str, Any]] = None,
949
443
  ) -> None:
950
- self._alter_monitor_dynamic_tables(
444
+ self._alter_monitor(
951
445
  operation="RESUME",
952
- model_name=model_name,
953
- version_name=version_name,
446
+ monitor_name=monitor_name,
954
447
  statement_params=statement_params,
955
448
  )
956
-
957
- def create_dynamic_tables_for_monitor(
958
- self,
959
- *,
960
- model_name: sql_identifier.SqlIdentifier,
961
- model_version_name: sql_identifier.SqlIdentifier,
962
- task: type_hints.Task,
963
- source_table_name: sql_identifier.SqlIdentifier,
964
- refresh_interval: model_monitor_interval.ModelMonitorRefreshInterval,
965
- aggregation_window: model_monitor_interval.ModelMonitorAggregationWindow,
966
- warehouse_name: sql_identifier.SqlIdentifier,
967
- timestamp_column: sql_identifier.SqlIdentifier,
968
- id_columns: List[sql_identifier.SqlIdentifier],
969
- prediction_columns: List[sql_identifier.SqlIdentifier],
970
- label_columns: List[sql_identifier.SqlIdentifier],
971
- score_type: output_score_type.OutputScoreType,
972
- ) -> None:
973
- table_schema: Mapping[str, types.DataType] = table_manager.get_table_schema_types(
974
- self._sql_client._session,
975
- self._database_name,
976
- self._schema_name,
977
- source_table_name,
978
- )
979
- (numeric_features_names, categorical_feature_names) = _infer_numeric_categoric_feature_column_names(
980
- source_table_schema=table_schema,
981
- timestamp_column=timestamp_column,
982
- id_columns=id_columns,
983
- prediction_columns=prediction_columns,
984
- label_columns=label_columns,
985
- )
986
- features_dynamic_table_query = self._monitoring_dynamic_table_query(
987
- model_name=model_name,
988
- model_version_name=model_version_name,
989
- source_table_name=source_table_name,
990
- refresh_interval=refresh_interval,
991
- aggregate_window=aggregation_window,
992
- warehouse_name=warehouse_name,
993
- timestamp_column=timestamp_column,
994
- numeric_features=numeric_features_names,
995
- categoric_features=categorical_feature_names,
996
- prediction_columns=prediction_columns,
997
- label_columns=label_columns,
998
- )
999
- query_result_checker.SqlResultValidator(self._sql_client._session, features_dynamic_table_query).has_column(
1000
- "status"
1001
- ).has_dimensions(1, 1).validate()
1002
-
1003
- label_pred_join_table_query = self._monitoring_accuracy_table_query(
1004
- model_name=model_name,
1005
- model_version_name=model_version_name,
1006
- task=task,
1007
- source_table_name=source_table_name,
1008
- refresh_interval=refresh_interval,
1009
- aggregate_window=aggregation_window,
1010
- warehouse_name=warehouse_name,
1011
- timestamp_column=timestamp_column,
1012
- prediction_columns=prediction_columns,
1013
- label_columns=label_columns,
1014
- score_type=score_type,
1015
- )
1016
- query_result_checker.SqlResultValidator(self._sql_client._session, label_pred_join_table_query).has_column(
1017
- "status"
1018
- ).has_dimensions(1, 1).validate()
1019
-
1020
- def _monitoring_dynamic_table_query(
1021
- self,
1022
- *,
1023
- model_name: sql_identifier.SqlIdentifier,
1024
- model_version_name: sql_identifier.SqlIdentifier,
1025
- source_table_name: sql_identifier.SqlIdentifier,
1026
- refresh_interval: ModelMonitorRefreshInterval,
1027
- aggregate_window: ModelMonitorAggregationWindow,
1028
- warehouse_name: sql_identifier.SqlIdentifier,
1029
- timestamp_column: sql_identifier.SqlIdentifier,
1030
- numeric_features: List[sql_identifier.SqlIdentifier],
1031
- categoric_features: List[sql_identifier.SqlIdentifier],
1032
- prediction_columns: List[sql_identifier.SqlIdentifier],
1033
- label_columns: List[sql_identifier.SqlIdentifier],
1034
- ) -> str:
1035
- """
1036
- Generates a dynamic table query for Observability - Monitoring.
1037
-
1038
- Args:
1039
- model_name: Model name to monitor.
1040
- model_version_name: Model version name to monitor.
1041
- source_table_name: Name of source data table to monitor.
1042
- refresh_interval: Refresh interval in minutes.
1043
- aggregate_window: Aggregate window minutes.
1044
- warehouse_name: Warehouse name to use for dynamic table.
1045
- timestamp_column: Timestamp column name.
1046
- numeric_features: List of numeric features to capture.
1047
- categoric_features: List of categoric features to capture.
1048
- prediction_columns: List of columns that contain model inference outputs.
1049
- label_columns: List of columns that contain ground truth values.
1050
-
1051
- Raises:
1052
- ValueError: If multiple output/ground truth columns are specified. MultiClass models are not yet supported.
1053
-
1054
- Returns:
1055
- Dynamic table query.
1056
- """
1057
- # output and ground cols are list to keep interface extensible.
1058
- # for prpr only one label and one output col will be supported
1059
- if len(prediction_columns) != 1 or len(label_columns) != 1:
1060
- raise ValueError("Multiple Output columns are not supported in monitoring")
1061
-
1062
- monitoring_dt_name = self.get_monitoring_table_fully_qualified_name(model_name, model_version_name)
1063
-
1064
- feature_cols_query_list = []
1065
- for feature in numeric_features + prediction_columns + label_columns:
1066
- feature_cols_query_list.append(
1067
- """
1068
- OBJECT_CONSTRUCT(
1069
- 'sketch', APPROX_PERCENTILE_ACCUMULATE({col}),
1070
- 'count', count_if({col} is not null),
1071
- 'count_null', count_if({col} is null),
1072
- 'min', min({col}),
1073
- 'max', max({col}),
1074
- 'sum', sum({col})
1075
- ) AS {col}""".format(
1076
- col=feature
1077
- )
1078
- )
1079
-
1080
- for col in categoric_features:
1081
- feature_cols_query_list.append(
1082
- f"""
1083
- {self._database_name}.{self._schema_name}.OBJECT_SUM(to_varchar({col})) AS {col}"""
1084
- )
1085
- feature_cols_query = ",".join(feature_cols_query_list)
1086
-
1087
- return f"""
1088
- CREATE DYNAMIC TABLE IF NOT EXISTS {monitoring_dt_name}
1089
- TARGET_LAG = '{refresh_interval.minutes} minutes'
1090
- WAREHOUSE = {warehouse_name}
1091
- REFRESH_MODE = AUTO
1092
- INITIALIZE = ON_CREATE
1093
- AS
1094
- SELECT
1095
- TIME_SLICE({timestamp_column}, {aggregate_window.minutes}, 'MINUTE') timestamp,{feature_cols_query}
1096
- FROM
1097
- {source_table_name}
1098
- GROUP BY
1099
- 1
1100
- """
1101
-
1102
- def _monitoring_accuracy_table_query(
1103
- self,
1104
- *,
1105
- model_name: sql_identifier.SqlIdentifier,
1106
- model_version_name: sql_identifier.SqlIdentifier,
1107
- task: type_hints.Task,
1108
- source_table_name: sql_identifier.SqlIdentifier,
1109
- refresh_interval: ModelMonitorRefreshInterval,
1110
- aggregate_window: ModelMonitorAggregationWindow,
1111
- warehouse_name: sql_identifier.SqlIdentifier,
1112
- timestamp_column: sql_identifier.SqlIdentifier,
1113
- prediction_columns: List[sql_identifier.SqlIdentifier],
1114
- label_columns: List[sql_identifier.SqlIdentifier],
1115
- score_type: output_score_type.OutputScoreType,
1116
- ) -> str:
1117
- # output and ground cols are list to keep interface extensible.
1118
- # for prpr only one label and one output col will be supported
1119
- if len(prediction_columns) != 1 or len(label_columns) != 1:
1120
- raise ValueError("Multiple Output columns are not supported in monitoring")
1121
- if task == type_hints.Task.TABULAR_BINARY_CLASSIFICATION:
1122
- return self._monitoring_classification_accuracy_table_query(
1123
- model_name=model_name,
1124
- model_version_name=model_version_name,
1125
- source_table_name=source_table_name,
1126
- refresh_interval=refresh_interval,
1127
- aggregate_window=aggregate_window,
1128
- warehouse_name=warehouse_name,
1129
- timestamp_column=timestamp_column,
1130
- prediction_columns=prediction_columns,
1131
- label_columns=label_columns,
1132
- score_type=score_type,
1133
- )
1134
- else:
1135
- return self._monitoring_regression_accuracy_table_query(
1136
- model_name=model_name,
1137
- model_version_name=model_version_name,
1138
- source_table_name=source_table_name,
1139
- refresh_interval=refresh_interval,
1140
- aggregate_window=aggregate_window,
1141
- warehouse_name=warehouse_name,
1142
- timestamp_column=timestamp_column,
1143
- prediction_columns=prediction_columns,
1144
- label_columns=label_columns,
1145
- )
1146
-
1147
- def _monitoring_regression_accuracy_table_query(
1148
- self,
1149
- *,
1150
- model_name: sql_identifier.SqlIdentifier,
1151
- model_version_name: sql_identifier.SqlIdentifier,
1152
- source_table_name: sql_identifier.SqlIdentifier,
1153
- refresh_interval: ModelMonitorRefreshInterval,
1154
- aggregate_window: ModelMonitorAggregationWindow,
1155
- warehouse_name: sql_identifier.SqlIdentifier,
1156
- timestamp_column: sql_identifier.SqlIdentifier,
1157
- prediction_columns: List[sql_identifier.SqlIdentifier],
1158
- label_columns: List[sql_identifier.SqlIdentifier],
1159
- ) -> str:
1160
- """
1161
- Generates a dynamic table query for Monitoring - regression model accuracy.
1162
-
1163
- Args:
1164
- model_name: Model name to monitor.
1165
- model_version_name: Model version name to monitor.
1166
- source_table_name: Name of source data table to monitor.
1167
- refresh_interval: Refresh interval in minutes.
1168
- aggregate_window: Aggregate window minutes.
1169
- warehouse_name: Warehouse name to use for dynamic table.
1170
- timestamp_column: Timestamp column name.
1171
- prediction_columns: List of output columns.
1172
- label_columns: List of ground truth columns.
1173
-
1174
- Returns:
1175
- Dynamic table query.
1176
-
1177
- Raises:
1178
- ValueError: If output columns are not same as ground truth columns.
1179
-
1180
- """
1181
-
1182
- if len(prediction_columns) != len(label_columns):
1183
- raise ValueError(f"Mismatch in output & ground truth columns: {prediction_columns} != {label_columns}")
1184
-
1185
- monitoring_dt_name = self.get_accuracy_monitoring_table_fully_qualified_name(model_name, model_version_name)
1186
-
1187
- output_cols_query_list = []
1188
-
1189
- output_cols_query_list.append(
1190
- f"""
1191
- OBJECT_CONSTRUCT(
1192
- 'sum_difference_label_pred', sum({prediction_columns[0]} - {label_columns[0]}),
1193
- 'sum_log_difference_square_label_pred',
1194
- sum(
1195
- case
1196
- when {prediction_columns[0]} > -1 and {label_columns[0]} > -1
1197
- then pow(ln({prediction_columns[0]} + 1) - ln({label_columns[0]} + 1),2)
1198
- else null
1199
- END
1200
- ),
1201
- 'sum_difference_squares_label_pred',
1202
- sum(
1203
- pow(
1204
- {prediction_columns[0]} - {label_columns[0]},
1205
- 2
1206
- )
1207
- ),
1208
- 'sum_absolute_regression_labels', sum(abs({label_columns[0]})),
1209
- 'sum_absolute_percentage_error',
1210
- sum(
1211
- abs(
1212
- div0null(
1213
- ({prediction_columns[0]} - {label_columns[0]}),
1214
- {label_columns[0]}
1215
- )
1216
- )
1217
- ),
1218
- 'sum_absolute_difference_label_pred',
1219
- sum(
1220
- abs({prediction_columns[0]} - {label_columns[0]})
1221
- ),
1222
- 'sum_prediction', sum({prediction_columns[0]}),
1223
- 'sum_label', sum({label_columns[0]}),
1224
- 'count', count(*)
1225
- ) AS AGGREGATE_METRICS,
1226
- APPROX_PERCENTILE_ACCUMULATE({prediction_columns[0]}) prediction_sketch,
1227
- APPROX_PERCENTILE_ACCUMULATE({label_columns[0]}) label_sketch"""
1228
- )
1229
- output_cols_query = ", ".join(output_cols_query_list)
1230
-
1231
- return f"""
1232
- CREATE DYNAMIC TABLE IF NOT EXISTS {monitoring_dt_name}
1233
- TARGET_LAG = '{refresh_interval.minutes} minutes'
1234
- WAREHOUSE = {warehouse_name}
1235
- REFRESH_MODE = AUTO
1236
- INITIALIZE = ON_CREATE
1237
- AS
1238
- SELECT
1239
- TIME_SLICE({timestamp_column}, {aggregate_window.minutes}, 'MINUTE') timestamp,
1240
- 'class_regression' label_class,{output_cols_query}
1241
- FROM
1242
- {source_table_name}
1243
- GROUP BY
1244
- 1
1245
- """
1246
-
1247
- def _monitoring_classification_accuracy_table_query(
1248
- self,
1249
- *,
1250
- model_name: sql_identifier.SqlIdentifier,
1251
- model_version_name: sql_identifier.SqlIdentifier,
1252
- source_table_name: sql_identifier.SqlIdentifier,
1253
- refresh_interval: ModelMonitorRefreshInterval,
1254
- aggregate_window: ModelMonitorAggregationWindow,
1255
- warehouse_name: sql_identifier.SqlIdentifier,
1256
- timestamp_column: sql_identifier.SqlIdentifier,
1257
- prediction_columns: List[sql_identifier.SqlIdentifier],
1258
- label_columns: List[sql_identifier.SqlIdentifier],
1259
- score_type: output_score_type.OutputScoreType,
1260
- ) -> str:
1261
- monitoring_dt_name = self.get_accuracy_monitoring_table_fully_qualified_name(model_name, model_version_name)
1262
-
1263
- # Initialize the select clause components
1264
- select_clauses = []
1265
-
1266
- select_clauses.append(
1267
- f"""
1268
- {prediction_columns[0]},
1269
- {label_columns[0]},
1270
- CASE
1271
- WHEN {label_columns[0]} = 1 THEN 'class_positive'
1272
- ELSE 'class_negative'
1273
- END AS label_class"""
1274
- )
1275
-
1276
- # Join all the select clauses into a single string
1277
- select_clause = f"{timestamp_column} AS timestamp," + ",".join(select_clauses)
1278
-
1279
- # Create the final CTE query
1280
- cte_query = f"""
1281
- WITH filtered_data AS (
1282
- SELECT
1283
- {select_clause}
1284
- FROM
1285
- {source_table_name}
1286
- )"""
1287
-
1288
- # Initialize the select clause components
1289
- select_clauses = []
1290
-
1291
- score_type_agg_clause = ""
1292
- if score_type == output_score_type.OutputScoreType.PROBITS:
1293
- score_type_agg_clause = f"""
1294
- 'sum_log_loss',
1295
- CASE
1296
- WHEN label_class = 'class_positive' THEN sum(-ln({prediction_columns[0]}))
1297
- ELSE sum(-ln(1 - {prediction_columns[0]}))
1298
- END,"""
1299
- else:
1300
- score_type_agg_clause = f"""
1301
- 'tp', count_if({label_columns[0]} = 1 AND {prediction_columns[0]} = 1),
1302
- 'tn', count_if({label_columns[0]} = 0 AND {prediction_columns[0]} = 0),
1303
- 'fp', count_if({label_columns[0]} = 0 AND {prediction_columns[0]} = 1),
1304
- 'fn', count_if({label_columns[0]} = 1 AND {prediction_columns[0]} = 0),"""
1305
-
1306
- select_clauses.append(
1307
- f"""
1308
- label_class,
1309
- OBJECT_CONSTRUCT(
1310
- 'sum_prediction', sum({prediction_columns[0]}),
1311
- 'sum_label', sum({label_columns[0]}),{score_type_agg_clause}
1312
- 'count', count(*)
1313
- ) AS AGGREGATE_METRICS,
1314
- APPROX_PERCENTILE_ACCUMULATE({prediction_columns[0]}) prediction_sketch,
1315
- APPROX_PERCENTILE_ACCUMULATE({label_columns[0]}) label_sketch"""
1316
- )
1317
-
1318
- # Join all the select clauses into a single string
1319
- select_clause = ",\n".join(select_clauses)
1320
-
1321
- return f"""
1322
- CREATE DYNAMIC TABLE IF NOT EXISTS {monitoring_dt_name}
1323
- TARGET_LAG = '{refresh_interval.minutes} minutes'
1324
- WAREHOUSE = {warehouse_name}
1325
- REFRESH_MODE = AUTO
1326
- INITIALIZE = ON_CREATE
1327
- AS{cte_query}
1328
- select
1329
- time_slice(timestamp, {aggregate_window.minutes}, 'MINUTE') timestamp,{select_clause}
1330
- FROM
1331
- filtered_data
1332
- group by
1333
- 1,
1334
- 2
1335
- """