snowflake-ml-python 1.7.0__py3-none-any.whl → 1.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_complete.py +107 -64
  3. snowflake/cortex/_finetune.py +273 -0
  4. snowflake/cortex/_sse_client.py +91 -28
  5. snowflake/cortex/_util.py +30 -1
  6. snowflake/ml/_internal/type_utils.py +3 -3
  7. snowflake/ml/_internal/utils/jwt_generator.py +141 -0
  8. snowflake/ml/data/__init__.py +5 -0
  9. snowflake/ml/model/_client/model/model_version_impl.py +26 -12
  10. snowflake/ml/model/_client/ops/model_ops.py +51 -30
  11. snowflake/ml/model/_client/ops/service_ops.py +25 -9
  12. snowflake/ml/model/_client/sql/model.py +0 -14
  13. snowflake/ml/model/_client/sql/service.py +25 -1
  14. snowflake/ml/model/_client/sql/stage.py +1 -1
  15. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
  16. snowflake/ml/model/_packager/model_env/model_env.py +12 -0
  17. snowflake/ml/model/_packager/model_handlers/_utils.py +1 -1
  18. snowflake/ml/model/_packager/model_handlers/catboost.py +1 -1
  19. snowflake/ml/model/_packager/model_handlers/custom.py +3 -1
  20. snowflake/ml/model/_packager/model_handlers/lightgbm.py +2 -1
  21. snowflake/ml/model/_packager/model_handlers/sklearn.py +50 -1
  22. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -1
  23. snowflake/ml/model/_packager/model_handlers/tensorflow.py +23 -6
  24. snowflake/ml/model/_packager/model_handlers/torchscript.py +14 -14
  25. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +2 -3
  26. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +5 -0
  27. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -10
  28. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -9
  29. snowflake/ml/model/_packager/model_task/model_task_utils.py +1 -1
  30. snowflake/ml/model/_signatures/core.py +63 -16
  31. snowflake/ml/model/_signatures/pandas_handler.py +71 -27
  32. snowflake/ml/model/_signatures/pytorch_handler.py +2 -2
  33. snowflake/ml/model/_signatures/snowpark_handler.py +2 -1
  34. snowflake/ml/model/_signatures/tensorflow_handler.py +2 -2
  35. snowflake/ml/model/_signatures/utils.py +4 -1
  36. snowflake/ml/model/model_signature.py +38 -9
  37. snowflake/ml/model/type_hints.py +1 -1
  38. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +2 -4
  39. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +2 -4
  40. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +148 -1200
  41. snowflake/ml/monitoring/_manager/model_monitor_manager.py +114 -238
  42. snowflake/ml/monitoring/entities/model_monitor_config.py +38 -12
  43. snowflake/ml/monitoring/model_monitor.py +12 -86
  44. snowflake/ml/registry/registry.py +28 -40
  45. snowflake/ml/utils/authentication.py +75 -0
  46. snowflake/ml/version.py +1 -1
  47. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/METADATA +116 -52
  48. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/RECORD +51 -49
  49. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/WHEEL +1 -1
  50. snowflake/ml/monitoring/entities/model_monitor_interval.py +0 -46
  51. snowflake/ml/monitoring/entities/output_score_type.py +0 -90
  52. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/LICENSE.txt +0 -0
  53. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.2.dist-info}/top_level.txt +0 -0
snowflake/ml/monitoring/_manager/model_monitor_manager.py
@@ -1,66 +1,18 @@
1
+ import json
1
2
  from typing import Any, Dict, List, Optional
2
3
 
3
4
  from snowflake import snowpark
4
- from snowflake.ml._internal import telemetry
5
- from snowflake.ml._internal.utils import db_utils, sql_identifier
5
+ from snowflake.ml._internal.utils import sql_identifier
6
6
  from snowflake.ml.model import type_hints
7
7
  from snowflake.ml.model._client.model import model_version_impl
8
- from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
9
8
  from snowflake.ml.monitoring import model_monitor
10
9
  from snowflake.ml.monitoring._client import model_monitor_sql_client
11
- from snowflake.ml.monitoring.entities import (
12
- model_monitor_config,
13
- model_monitor_interval,
14
- )
10
+ from snowflake.ml.monitoring.entities import model_monitor_config
15
11
  from snowflake.snowpark import session
16
12
 
17
13
 
18
- def _validate_name_constraints(model_version: model_version_impl.ModelVersion) -> None:
19
- system_table_prefixes = [
20
- model_monitor_sql_client._SNOWML_MONITORING_TABLE_NAME_PREFIX,
21
- model_monitor_sql_client._SNOWML_MONITORING_ACCURACY_TABLE_NAME_PREFIX,
22
- ]
23
-
24
- max_allowed_model_name_and_version_length = (
25
- db_utils.MAX_IDENTIFIER_LENGTH - max(len(prefix) for prefix in system_table_prefixes) - 1
26
- ) # -1 includes '_' between model_name + model_version
27
- if len(model_version.model_name) + len(model_version.version_name) > max_allowed_model_name_and_version_length:
28
- error_msg = f"Model name and version name exceeds maximum length of {max_allowed_model_name_and_version_length}"
29
- raise ValueError(error_msg)
30
-
31
-
32
14
  class ModelMonitorManager:
33
- """Class to manage internal operations for Model Monitor workflows.""" # TODO: Move to Registry.
34
-
35
- @staticmethod
36
- def setup(session: session.Session, database_name: str, schema_name: str) -> None:
37
- """Static method to set up schema for Model Monitoring resources.
38
-
39
- Args:
40
- session: The Snowpark Session to connect with Snowflake.
41
- database_name: The name of the database. If None, the current database of the session
42
- will be used. Defaults to None.
43
- schema_name: The name of the schema. If None, the current schema of the session
44
- will be used. If there is no active schema, the PUBLIC schema will be used. Defaults to None.
45
- """
46
- statement_params = telemetry.get_statement_params(
47
- project=telemetry.TelemetryProject.MLOPS.value,
48
- subproject=telemetry.TelemetrySubProject.MONITORING.value,
49
- )
50
- database_name_id = sql_identifier.SqlIdentifier(database_name)
51
- schema_name_id = sql_identifier.SqlIdentifier(schema_name)
52
- model_monitor_sql_client.ModelMonitorSQLClient.initialize_monitoring_schema(
53
- session, database_name_id, schema_name_id, statement_params=statement_params
54
- )
55
-
56
- def _fetch_task_from_model_version(
57
- self,
58
- model_version: model_version_impl.ModelVersion,
59
- ) -> type_hints.Task:
60
- task = model_version.get_model_task()
61
- if task == type_hints.Task.UNKNOWN:
62
- raise ValueError("Registry model must be logged with task in order to be monitored.")
63
- return task
15
+ """Class to manage internal operations for Model Monitor workflows."""
64
16
 
65
17
  def __init__(
66
18
  self,
@@ -68,7 +20,6 @@ class ModelMonitorManager:
68
20
  database_name: sql_identifier.SqlIdentifier,
69
21
  schema_name: sql_identifier.SqlIdentifier,
70
22
  *,
71
- create_if_not_exists: bool = False,
72
23
  statement_params: Optional[Dict[str, Any]] = None,
73
24
  ) -> None:
74
25
  """
@@ -79,233 +30,165 @@ class ModelMonitorManager:
79
30
  session: The Snowpark Session to connect with Snowflake.
80
31
  database_name: The name of the database.
81
32
  schema_name: The name of the schema.
82
- create_if_not_exists: Flag whether to initialize resources in the schema needed for Model Monitoring.
83
33
  statement_params: Optional set of statement params.
84
-
85
- Raises:
86
- ValueError: When there is no specified or active database in the session.
87
34
  """
88
35
  self._database_name = database_name
89
36
  self._schema_name = schema_name
90
37
  self.statement_params = statement_params
38
+
91
39
  self._model_monitor_client = model_monitor_sql_client.ModelMonitorSQLClient(
92
40
  session,
93
41
  database_name=self._database_name,
94
42
  schema_name=self._schema_name,
95
43
  )
96
- if create_if_not_exists:
97
- model_monitor_sql_client.ModelMonitorSQLClient.initialize_monitoring_schema(
98
- session, self._database_name, self._schema_name, self.statement_params
99
- )
100
- elif not self._model_monitor_client._validate_is_initialized():
101
- raise ValueError(
102
- "Monitoring has not been setup. Set create_if_not_exists or call ModelMonitorManager.setup"
103
- )
104
44
 
105
- def _get_and_validate_model_function_from_model_version(
45
+ def _validate_task_from_model_version(
46
+ self,
47
+ model_version: model_version_impl.ModelVersion,
48
+ ) -> type_hints.Task:
49
+ task = model_version.get_model_task()
50
+ if task == type_hints.Task.UNKNOWN:
51
+ raise ValueError("Registry model must be logged with task in order to be monitored.")
52
+ return task
53
+
54
+ def _validate_model_function_from_model_version(
106
55
  self, function: str, model_version: model_version_impl.ModelVersion
107
- ) -> model_manifest_schema.ModelFunctionInfo:
56
+ ) -> None:
108
57
  functions = model_version.show_functions()
109
58
  for f in functions:
110
59
  if f["target_method"] == function:
111
- return f
60
+ return
112
61
  existing_target_methods = {f["target_method"] for f in functions}
113
62
  raise ValueError(
114
63
  f"Function with name {function} does not exist in the given model version. "
115
64
  f"Found: {existing_target_methods}."
116
65
  )
117
66
 
118
- def _validate_monitor_config_or_raise(
119
- self,
120
- table_config: model_monitor_config.ModelMonitorTableConfig,
121
- model_monitor_config: model_monitor_config.ModelMonitorConfig,
122
- ) -> None:
123
- """Validate provided config for model monitor.
124
-
125
- Args:
126
- table_config: Config for model monitor tables.
127
- model_monitor_config: Config for ModelMonitor.
128
-
129
- Raises:
130
- ValueError: If warehouse provided does not exist.
131
- """
132
-
133
- # Validate naming will not exceed 255 chars
134
- _validate_name_constraints(model_monitor_config.model_version)
135
-
136
- if len(table_config.prediction_columns) != len(table_config.label_columns):
137
- raise ValueError("Prediction and Label column names must be of the same length.")
138
- # output and ground cols are list to keep interface extensible.
139
- # for prpr only one label and one output col will be supported
140
- if len(table_config.prediction_columns) != 1 or len(table_config.label_columns) != 1:
141
- raise ValueError("Multiple Output columns are not supported in monitoring")
142
-
143
- # Validate warehouse exists.
144
- warehouse_name_id = sql_identifier.SqlIdentifier(model_monitor_config.background_compute_warehouse_name)
145
- self._model_monitor_client.validate_monitor_warehouse(warehouse_name_id, statement_params=self.statement_params)
146
-
147
- # Validate refresh interval.
148
- try:
149
- num_units, time_units = model_monitor_config.refresh_interval.strip().split(" ")
150
- int(num_units) # try to cast
151
- if time_units.lower() not in {"seconds", "minutes", "hours", "days"}:
152
- raise ValueError(
153
- """Invalid time unit in refresh interval. Provide '<num> <seconds | minutes | hours | days>'.
154
- See https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table#required-parameters for more info."""
155
- )
156
- except Exception as e: # TODO: Link to DT page.
157
- raise ValueError(
158
- f"""Failed to parse refresh interval with exception {e}.
159
- Provide '<num> <seconds | minutes | hours | days>'.
160
- See https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table#required-parameters for more info."""
161
- )
67
+ def _build_column_list_from_input(self, columns: Optional[List[str]]) -> List[sql_identifier.SqlIdentifier]:
68
+ return [sql_identifier.SqlIdentifier(column_name) for column_name in columns] if columns else []
162
69
 
163
70
  def add_monitor(
164
71
  self,
165
72
  name: str,
166
- table_config: model_monitor_config.ModelMonitorTableConfig,
73
+ source_config: model_monitor_config.ModelMonitorSourceConfig,
167
74
  model_monitor_config: model_monitor_config.ModelMonitorConfig,
168
- *,
169
- add_dashboard_udtfs: bool = False,
170
75
  ) -> model_monitor.ModelMonitor:
171
76
  """Add a new Model Monitor.
172
77
 
173
78
  Args:
174
79
  name: Name of Model Monitor to create.
175
- table_config: Configuration options for the source table used in ModelMonitor.
80
+ source_config: Configuration options for the source table used in ModelMonitor.
176
81
  model_monitor_config: Configuration options of ModelMonitor.
177
- add_dashboard_udtfs: Add UDTFs useful for creating a dashboard.
178
82
 
179
83
  Returns:
180
84
  The newly added ModelMonitor object.
181
85
  """
182
- # Validates configuration or raise.
183
- self._validate_monitor_config_or_raise(table_config, model_monitor_config)
184
- model_function = self._get_and_validate_model_function_from_model_version(
86
+ warehouse_name_id = sql_identifier.SqlIdentifier(model_monitor_config.background_compute_warehouse_name)
87
+ self._model_monitor_client.validate_monitor_warehouse(warehouse_name_id, statement_params=self.statement_params)
88
+ self._validate_model_function_from_model_version(
185
89
  model_monitor_config.model_function_name, model_monitor_config.model_version
186
90
  )
187
- monitor_refresh_interval = model_monitor_interval.ModelMonitorRefreshInterval(
188
- model_monitor_config.refresh_interval
91
+ self._validate_task_from_model_version(model_monitor_config.model_version)
92
+ monitor_database_name_id, monitor_schema_name_id, monitor_name_id = sql_identifier.parse_fully_qualified_name(
93
+ name
189
94
  )
190
- name_id = sql_identifier.SqlIdentifier(name)
191
- source_table_name_id = sql_identifier.SqlIdentifier(table_config.source_table)
192
- prediction_columns = [
193
- sql_identifier.SqlIdentifier(column_name) for column_name in table_config.prediction_columns
194
- ]
195
- label_columns = [sql_identifier.SqlIdentifier(column_name) for column_name in table_config.label_columns]
196
- id_columns = [sql_identifier.SqlIdentifier(column_name) for column_name in table_config.id_columns]
197
- ts_column = sql_identifier.SqlIdentifier(table_config.timestamp_column)
95
+ source_database_name_id, source_schema_name_id, source_name_id = sql_identifier.parse_fully_qualified_name(
96
+ source_config.source
97
+ )
98
+ baseline_database_name_id, baseline_schema_name_id, baseline_name_id = (
99
+ sql_identifier.parse_fully_qualified_name(source_config.baseline)
100
+ if source_config.baseline
101
+ else (None, None, None)
102
+ )
103
+ model_database_name_id, model_schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(
104
+ model_monitor_config.model_version.fully_qualified_model_name
105
+ )
106
+
107
+ prediction_score_columns = self._build_column_list_from_input(source_config.prediction_score_columns)
108
+ prediction_class_columns = self._build_column_list_from_input(source_config.prediction_class_columns)
109
+ actual_score_columns = self._build_column_list_from_input(source_config.actual_score_columns)
110
+ actual_class_columns = self._build_column_list_from_input(source_config.actual_class_columns)
111
+
112
+ id_columns = [sql_identifier.SqlIdentifier(column_name) for column_name in source_config.id_columns]
113
+ ts_column = sql_identifier.SqlIdentifier(source_config.timestamp_column)
198
114
 
199
115
  # Validate source table
200
- self._model_monitor_client.validate_source_table(
201
- source_table_name=source_table_name_id,
116
+ self._model_monitor_client.validate_source(
117
+ source_database=source_database_name_id,
118
+ source_schema=source_schema_name_id,
119
+ source=source_name_id,
202
120
  timestamp_column=ts_column,
203
- prediction_columns=prediction_columns,
204
- label_columns=label_columns,
121
+ prediction_score_columns=prediction_score_columns,
122
+ prediction_class_columns=prediction_class_columns,
123
+ actual_score_columns=actual_score_columns,
124
+ actual_class_columns=actual_class_columns,
205
125
  id_columns=id_columns,
206
- model_function=model_function,
207
126
  )
208
127
 
209
- task = self._fetch_task_from_model_version(model_version=model_monitor_config.model_version)
210
- score_type = self._model_monitor_client.get_score_type(task, source_table_name_id, prediction_columns)
211
-
212
- # Insert monitoring metadata for new model version.
213
- self._model_monitor_client.create_monitor_on_model_version(
214
- monitor_name=name_id,
215
- source_table_name=source_table_name_id,
216
- fully_qualified_model_name=model_monitor_config.model_version.fully_qualified_model_name,
128
+ self._model_monitor_client.create_model_monitor(
129
+ monitor_database=monitor_database_name_id,
130
+ monitor_schema=monitor_schema_name_id,
131
+ monitor_name=monitor_name_id,
132
+ source_database=source_database_name_id,
133
+ source_schema=source_schema_name_id,
134
+ source=source_name_id,
135
+ model_database=model_database_name_id,
136
+ model_schema=model_schema_name_id,
137
+ model_name=model_name_id,
217
138
  version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
218
139
  function_name=model_monitor_config.model_function_name,
140
+ warehouse_name=warehouse_name_id,
219
141
  timestamp_column=ts_column,
220
- prediction_columns=prediction_columns,
221
- label_columns=label_columns,
222
142
  id_columns=id_columns,
223
- task=task,
224
- statement_params=self.statement_params,
225
- )
226
-
227
- # Create Dynamic tables for model monitor.
228
- self._model_monitor_client.create_dynamic_tables_for_monitor(
229
- model_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.model_name),
230
- model_version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
231
- task=task,
232
- source_table_name=source_table_name_id,
233
- refresh_interval=monitor_refresh_interval,
143
+ prediction_score_columns=prediction_score_columns,
144
+ prediction_class_columns=prediction_class_columns,
145
+ actual_score_columns=actual_score_columns,
146
+ actual_class_columns=actual_class_columns,
147
+ refresh_interval=model_monitor_config.refresh_interval,
234
148
  aggregation_window=model_monitor_config.aggregation_window,
235
- warehouse_name=sql_identifier.SqlIdentifier(model_monitor_config.background_compute_warehouse_name),
236
- timestamp_column=sql_identifier.SqlIdentifier(table_config.timestamp_column),
237
- id_columns=id_columns,
238
- prediction_columns=prediction_columns,
239
- label_columns=label_columns,
240
- score_type=score_type,
241
- )
242
-
243
- # Initialize baseline table.
244
- self._model_monitor_client.initialize_baseline_table(
245
- model_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.model_name),
246
- version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
247
- source_table_name=table_config.source_table,
248
- columns_to_drop=[ts_column, *id_columns],
149
+ baseline_database=baseline_database_name_id,
150
+ baseline_schema=baseline_schema_name_id,
151
+ baseline=baseline_name_id,
249
152
  statement_params=self.statement_params,
250
153
  )
251
-
252
- # Add udtfs helpful for dashboard queries.
253
- # TODO(apgupta) Make this true by default.
254
- if add_dashboard_udtfs:
255
- self._model_monitor_client.add_dashboard_udtfs(
256
- name_id,
257
- model_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.model_name),
258
- model_version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
259
- task=task,
260
- score_type=score_type,
261
- output_columns=prediction_columns,
262
- ground_truth_columns=label_columns,
263
- )
264
-
265
154
  return model_monitor.ModelMonitor._ref(
266
155
  model_monitor_client=self._model_monitor_client,
267
- name=name_id,
268
- fully_qualified_model_name=model_monitor_config.model_version.fully_qualified_model_name,
269
- version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
270
- function_name=sql_identifier.SqlIdentifier(model_monitor_config.model_function_name),
271
- prediction_columns=prediction_columns,
272
- label_columns=label_columns,
156
+ name=monitor_name_id,
273
157
  )
274
158
 
275
159
  def get_monitor_by_model_version(
276
160
  self, model_version: model_version_impl.ModelVersion
277
161
  ) -> model_monitor.ModelMonitor:
278
- fq_model_name = model_version.fully_qualified_model_name
279
- version_name = sql_identifier.SqlIdentifier(model_version.version_name)
280
- if self._model_monitor_client.validate_existence(fq_model_name, version_name, self.statement_params):
281
- model_db, model_schema, model_name = sql_identifier.parse_fully_qualified_name(fq_model_name)
282
- if model_db is None or model_schema is None:
283
- raise ValueError("Failed to parse model name")
284
-
285
- model_monitor_params: model_monitor_sql_client._ModelMonitorParams = (
286
- self._model_monitor_client.get_model_monitor_by_model_version(
287
- model_db=model_db,
288
- model_schema=model_schema,
289
- model_name=model_name,
290
- version_name=version_name,
291
- statement_params=self.statement_params,
292
- )
293
- )
294
- return model_monitor.ModelMonitor._ref(
295
- model_monitor_client=self._model_monitor_client,
296
- name=sql_identifier.SqlIdentifier(model_monitor_params["monitor_name"]),
297
- fully_qualified_model_name=fq_model_name,
298
- version_name=version_name,
299
- function_name=sql_identifier.SqlIdentifier(model_monitor_params["function_name"]),
300
- prediction_columns=model_monitor_params["prediction_columns"],
301
- label_columns=model_monitor_params["label_columns"],
302
- )
162
+ """Get a Model Monitor by Model Version.
303
163
 
304
- else:
305
- raise ValueError(
306
- f"ModelMonitor not found for model version {model_version.model_name} - {model_version.version_name}"
164
+ Args:
165
+ model_version: ModelVersion to retrieve Model Monitor for.
166
+
167
+ Returns:
168
+ The fetched ModelMonitor.
169
+
170
+ Raises:
171
+ ValueError: If model monitor is not found.
172
+ """
173
+ rows = self._model_monitor_client.show_model_monitors(statement_params=self.statement_params)
174
+
175
+ def model_match_fn(model_details: Dict[str, str]) -> bool:
176
+ return (
177
+ model_details[model_monitor_sql_client.MODEL_JSON_MODEL_NAME_FIELD] == model_version.model_name
178
+ and model_details[model_monitor_sql_client.MODEL_JSON_VERSION_NAME_FIELD] == model_version.version_name
307
179
  )
308
180
 
181
+ rows = [row for row in rows if model_match_fn(json.loads(row[model_monitor_sql_client.MODEL_JSON_COL_NAME]))]
182
+ if len(rows) == 0:
183
+ raise ValueError("Unable to find model monitor for the given model version.")
184
+ if len(rows) > 1:
185
+ raise ValueError("Found multiple model monitors for the given model version.")
186
+
187
+ return model_monitor.ModelMonitor._ref(
188
+ model_monitor_client=self._model_monitor_client,
189
+ name=sql_identifier.SqlIdentifier(rows[0]["name"]),
190
+ )
191
+
309
192
  def get_monitor(self, name: str) -> model_monitor.ModelMonitor:
310
193
  """Get a Model Monitor from the Registry
311
194
 
@@ -318,25 +201,18 @@ See https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table#require
318
201
  Returns:
319
202
  The fetched ModelMonitor.
320
203
  """
321
- name_id = sql_identifier.SqlIdentifier(name)
204
+ database_name_id, schema_name_id, monitor_name_id = sql_identifier.parse_fully_qualified_name(name)
322
205
 
323
206
  if not self._model_monitor_client.validate_existence_by_name(
324
- monitor_name=name_id,
207
+ database_name=database_name_id,
208
+ schema_name=schema_name_id,
209
+ monitor_name=monitor_name_id,
325
210
  statement_params=self.statement_params,
326
211
  ):
327
212
  raise ValueError(f"Unable to find model monitor '{name}'")
328
- model_monitor_params: model_monitor_sql_client._ModelMonitorParams = (
329
- self._model_monitor_client.get_model_monitor_by_name(name_id, statement_params=self.statement_params)
330
- )
331
-
332
213
  return model_monitor.ModelMonitor._ref(
333
214
  model_monitor_client=self._model_monitor_client,
334
- name=name_id,
335
- fully_qualified_model_name=model_monitor_params["fully_qualified_model_name"],
336
- version_name=sql_identifier.SqlIdentifier(model_monitor_params["version_name"]),
337
- function_name=sql_identifier.SqlIdentifier(model_monitor_params["function_name"]),
338
- prediction_columns=model_monitor_params["prediction_columns"],
339
- label_columns=model_monitor_params["label_columns"],
215
+ name=monitor_name_id,
340
216
  )
341
217
 
342
218
  def show_model_monitors(self) -> List[snowpark.Row]:
@@ -345,7 +221,7 @@ See https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table#require
345
221
  Returns:
346
222
  List of snowpark.Row containing metadata for each model monitor.
347
223
  """
348
- return self._model_monitor_client.get_all_model_monitor_metadata()
224
+ return self._model_monitor_client.show_model_monitors(statement_params=self.statement_params)
349
225
 
350
226
  def delete_monitor(self, name: str) -> None:
351
227
  """Delete a Model Monitor from the Registry
@@ -353,10 +229,10 @@ See https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table#require
353
229
  Args:
354
230
  name: Name of the Model Monitor to delete.
355
231
  """
356
- name_id = sql_identifier.SqlIdentifier(name)
357
- monitor_params = self._model_monitor_client.get_model_monitor_by_name(name_id)
358
- _, _, model = sql_identifier.parse_fully_qualified_name(monitor_params["fully_qualified_model_name"])
359
- version = sql_identifier.SqlIdentifier(monitor_params["version_name"])
360
- self._model_monitor_client.delete_monitor_metadata(name_id)
361
- self._model_monitor_client.delete_baseline_table(model, version)
362
- self._model_monitor_client.delete_dynamic_tables(model, version)
232
+ database_name_id, schema_name_id, monitor_name_id = sql_identifier.parse_fully_qualified_name(name)
233
+ self._model_monitor_client.drop_model_monitor(
234
+ database_name=database_name_id,
235
+ schema_name=schema_name_id,
236
+ monitor_name=monitor_name_id,
237
+ statement_params=self.statement_params,
238
+ )
snowflake/ml/monitoring/entities/model_monitor_config.py
@@ -1,28 +1,54 @@
1
1
  from dataclasses import dataclass
2
- from typing import List
2
+ from typing import List, Optional
3
3
 
4
4
  from snowflake.ml.model._client.model import model_version_impl
5
- from snowflake.ml.monitoring.entities import model_monitor_interval
6
5
 
7
6
 
8
7
  @dataclass
9
- class ModelMonitorTableConfig:
10
- source_table: str
8
+ class ModelMonitorSourceConfig:
9
+ """Configuration for the source of data to be monitored."""
10
+
11
+ source: str
12
+ """Name of table or view containing monitoring data."""
13
+
11
14
  timestamp_column: str
12
- prediction_columns: List[str]
13
- label_columns: List[str]
15
+ """Name of column in the source containing timestamp."""
16
+
14
17
  id_columns: List[str]
18
+ """List of columns in the source containing unique identifiers."""
19
+
20
+ prediction_score_columns: Optional[List[str]] = None
21
+ """List of columns in the source containing prediction scores.
22
+ Can be regression scores for regression models and probability scores for classification models."""
23
+
24
+ prediction_class_columns: Optional[List[str]] = None
25
+ """List of columns in the source containing prediction classes for classification models."""
26
+
27
+ actual_score_columns: Optional[List[str]] = None
28
+ """List of columns in the source containing actual scores."""
29
+
30
+ actual_class_columns: Optional[List[str]] = None
31
+ """List of columns in the source containing actual classes for classification models."""
32
+
33
+ baseline: Optional[str] = None
34
+ """Name of table containing the baseline data."""
15
35
 
16
36
 
17
37
  @dataclass
18
38
  class ModelMonitorConfig:
39
+ """Configuration for the Model Monitor."""
40
+
19
41
  model_version: model_version_impl.ModelVersion
42
+ """Model version to monitor."""
20
43
 
21
- # Python model function name
22
44
  model_function_name: str
45
+ """Function name in the model to monitor."""
46
+
23
47
  background_compute_warehouse_name: str
24
- # TODO: Add support for pythonic notion of time.
25
- refresh_interval: str = model_monitor_interval.ModelMonitorRefreshInterval.DAILY
26
- aggregation_window: model_monitor_interval.ModelMonitorAggregationWindow = (
27
- model_monitor_interval.ModelMonitorAggregationWindow.WINDOW_1_DAY
28
- )
48
+ """Name of the warehouse to use for background compute."""
49
+
50
+ refresh_interval: str = "1 hour"
51
+ """Interval at which to refresh the monitoring data."""
52
+
53
+ aggregation_window: str = "1 day"
54
+ """Window for aggregating monitoring data."""