snowflake-ml-python 1.7.0__py3-none-any.whl → 1.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_complete.py +107 -64
  3. snowflake/cortex/_finetune.py +273 -0
  4. snowflake/cortex/_sse_client.py +91 -28
  5. snowflake/cortex/_util.py +30 -1
  6. snowflake/ml/_internal/type_utils.py +3 -3
  7. snowflake/ml/data/__init__.py +5 -0
  8. snowflake/ml/model/_client/model/model_version_impl.py +7 -7
  9. snowflake/ml/model/_client/ops/model_ops.py +51 -30
  10. snowflake/ml/model/_client/ops/service_ops.py +13 -2
  11. snowflake/ml/model/_client/sql/model.py +0 -14
  12. snowflake/ml/model/_client/sql/service.py +25 -1
  13. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
  14. snowflake/ml/model/_packager/model_env/model_env.py +12 -0
  15. snowflake/ml/model/_packager/model_handlers/_utils.py +1 -1
  16. snowflake/ml/model/_packager/model_handlers/catboost.py +1 -1
  17. snowflake/ml/model/_packager/model_handlers/custom.py +3 -1
  18. snowflake/ml/model/_packager/model_handlers/lightgbm.py +2 -1
  19. snowflake/ml/model/_packager/model_handlers/sklearn.py +48 -1
  20. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +1 -1
  21. snowflake/ml/model/_packager/model_handlers/tensorflow.py +23 -6
  22. snowflake/ml/model/_packager/model_handlers/torchscript.py +14 -14
  23. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +2 -3
  24. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +5 -0
  25. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -10
  26. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -9
  27. snowflake/ml/model/_packager/model_task/model_task_utils.py +1 -1
  28. snowflake/ml/model/_signatures/core.py +63 -16
  29. snowflake/ml/model/_signatures/pandas_handler.py +71 -27
  30. snowflake/ml/model/_signatures/pytorch_handler.py +2 -2
  31. snowflake/ml/model/_signatures/snowpark_handler.py +2 -1
  32. snowflake/ml/model/_signatures/tensorflow_handler.py +2 -2
  33. snowflake/ml/model/_signatures/utils.py +4 -0
  34. snowflake/ml/model/model_signature.py +38 -9
  35. snowflake/ml/model/type_hints.py +1 -1
  36. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +2 -4
  37. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +2 -4
  38. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +158 -1045
  39. snowflake/ml/monitoring/_manager/model_monitor_manager.py +106 -230
  40. snowflake/ml/monitoring/entities/model_monitor_config.py +10 -10
  41. snowflake/ml/monitoring/model_monitor.py +7 -96
  42. snowflake/ml/registry/registry.py +17 -29
  43. snowflake/ml/version.py +1 -1
  44. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/METADATA +31 -5
  45. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/RECORD +48 -47
  46. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/WHEEL +1 -1
  47. snowflake/ml/monitoring/entities/model_monitor_interval.py +0 -46
  48. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/LICENSE.txt +0 -0
  49. {snowflake_ml_python-1.7.0.dist-info → snowflake_ml_python-1.7.1.dist-info}/top_level.txt +0 -0
@@ -1,59 +1,20 @@
1
+ import json
1
2
  from typing import Any, Dict, List, Optional
2
3
 
3
4
  from snowflake import snowpark
4
- from snowflake.ml._internal import telemetry
5
- from snowflake.ml._internal.utils import db_utils, sql_identifier
5
+ from snowflake.ml._internal.utils import sql_identifier
6
6
  from snowflake.ml.model import type_hints
7
7
  from snowflake.ml.model._client.model import model_version_impl
8
- from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
9
8
  from snowflake.ml.monitoring import model_monitor
10
9
  from snowflake.ml.monitoring._client import model_monitor_sql_client
11
- from snowflake.ml.monitoring.entities import (
12
- model_monitor_config,
13
- model_monitor_interval,
14
- )
10
+ from snowflake.ml.monitoring.entities import model_monitor_config
15
11
  from snowflake.snowpark import session
16
12
 
17
13
 
18
- def _validate_name_constraints(model_version: model_version_impl.ModelVersion) -> None:
19
- system_table_prefixes = [
20
- model_monitor_sql_client._SNOWML_MONITORING_TABLE_NAME_PREFIX,
21
- model_monitor_sql_client._SNOWML_MONITORING_ACCURACY_TABLE_NAME_PREFIX,
22
- ]
23
-
24
- max_allowed_model_name_and_version_length = (
25
- db_utils.MAX_IDENTIFIER_LENGTH - max(len(prefix) for prefix in system_table_prefixes) - 1
26
- ) # -1 includes '_' between model_name + model_version
27
- if len(model_version.model_name) + len(model_version.version_name) > max_allowed_model_name_and_version_length:
28
- error_msg = f"Model name and version name exceeds maximum length of {max_allowed_model_name_and_version_length}"
29
- raise ValueError(error_msg)
30
-
31
-
32
14
  class ModelMonitorManager:
33
- """Class to manage internal operations for Model Monitor workflows.""" # TODO: Move to Registry.
34
-
35
- @staticmethod
36
- def setup(session: session.Session, database_name: str, schema_name: str) -> None:
37
- """Static method to set up schema for Model Monitoring resources.
38
-
39
- Args:
40
- session: The Snowpark Session to connect with Snowflake.
41
- database_name: The name of the database. If None, the current database of the session
42
- will be used. Defaults to None.
43
- schema_name: The name of the schema. If None, the current schema of the session
44
- will be used. If there is no active schema, the PUBLIC schema will be used. Defaults to None.
45
- """
46
- statement_params = telemetry.get_statement_params(
47
- project=telemetry.TelemetryProject.MLOPS.value,
48
- subproject=telemetry.TelemetrySubProject.MONITORING.value,
49
- )
50
- database_name_id = sql_identifier.SqlIdentifier(database_name)
51
- schema_name_id = sql_identifier.SqlIdentifier(schema_name)
52
- model_monitor_sql_client.ModelMonitorSQLClient.initialize_monitoring_schema(
53
- session, database_name_id, schema_name_id, statement_params=statement_params
54
- )
15
+ """Class to manage internal operations for Model Monitor workflows."""
55
16
 
56
- def _fetch_task_from_model_version(
17
+ def _validate_task_from_model_version(
57
18
  self,
58
19
  model_version: model_version_impl.ModelVersion,
59
20
  ) -> type_hints.Task:
@@ -68,7 +29,6 @@ class ModelMonitorManager:
68
29
  database_name: sql_identifier.SqlIdentifier,
69
30
  schema_name: sql_identifier.SqlIdentifier,
70
31
  *,
71
- create_if_not_exists: bool = False,
72
32
  statement_params: Optional[Dict[str, Any]] = None,
73
33
  ) -> None:
74
34
  """
@@ -79,233 +39,156 @@ class ModelMonitorManager:
79
39
  session: The Snowpark Session to connect with Snowflake.
80
40
  database_name: The name of the database.
81
41
  schema_name: The name of the schema.
82
- create_if_not_exists: Flag whether to initialize resources in the schema needed for Model Monitoring.
83
42
  statement_params: Optional set of statement params.
84
-
85
- Raises:
86
- ValueError: When there is no specified or active database in the session.
87
43
  """
88
44
  self._database_name = database_name
89
45
  self._schema_name = schema_name
90
46
  self.statement_params = statement_params
47
+
91
48
  self._model_monitor_client = model_monitor_sql_client.ModelMonitorSQLClient(
92
49
  session,
93
50
  database_name=self._database_name,
94
51
  schema_name=self._schema_name,
95
52
  )
96
- if create_if_not_exists:
97
- model_monitor_sql_client.ModelMonitorSQLClient.initialize_monitoring_schema(
98
- session, self._database_name, self._schema_name, self.statement_params
99
- )
100
- elif not self._model_monitor_client._validate_is_initialized():
101
- raise ValueError(
102
- "Monitoring has not been setup. Set create_if_not_exists or call ModelMonitorManager.setup"
103
- )
104
53
 
105
- def _get_and_validate_model_function_from_model_version(
54
+ def _validate_model_function_from_model_version(
106
55
  self, function: str, model_version: model_version_impl.ModelVersion
107
- ) -> model_manifest_schema.ModelFunctionInfo:
56
+ ) -> None:
108
57
  functions = model_version.show_functions()
109
58
  for f in functions:
110
59
  if f["target_method"] == function:
111
- return f
60
+ return
112
61
  existing_target_methods = {f["target_method"] for f in functions}
113
62
  raise ValueError(
114
63
  f"Function with name {function} does not exist in the given model version. "
115
64
  f"Found: {existing_target_methods}."
116
65
  )
117
66
 
118
- def _validate_monitor_config_or_raise(
119
- self,
120
- table_config: model_monitor_config.ModelMonitorTableConfig,
121
- model_monitor_config: model_monitor_config.ModelMonitorConfig,
122
- ) -> None:
123
- """Validate provided config for model monitor.
124
-
125
- Args:
126
- table_config: Config for model monitor tables.
127
- model_monitor_config: Config for ModelMonitor.
128
-
129
- Raises:
130
- ValueError: If warehouse provided does not exist.
131
- """
132
-
133
- # Validate naming will not exceed 255 chars
134
- _validate_name_constraints(model_monitor_config.model_version)
135
-
136
- if len(table_config.prediction_columns) != len(table_config.label_columns):
137
- raise ValueError("Prediction and Label column names must be of the same length.")
138
- # output and ground cols are list to keep interface extensible.
139
- # for prpr only one label and one output col will be supported
140
- if len(table_config.prediction_columns) != 1 or len(table_config.label_columns) != 1:
141
- raise ValueError("Multiple Output columns are not supported in monitoring")
142
-
143
- # Validate warehouse exists.
144
- warehouse_name_id = sql_identifier.SqlIdentifier(model_monitor_config.background_compute_warehouse_name)
145
- self._model_monitor_client.validate_monitor_warehouse(warehouse_name_id, statement_params=self.statement_params)
146
-
147
- # Validate refresh interval.
148
- try:
149
- num_units, time_units = model_monitor_config.refresh_interval.strip().split(" ")
150
- int(num_units) # try to cast
151
- if time_units.lower() not in {"seconds", "minutes", "hours", "days"}:
152
- raise ValueError(
153
- """Invalid time unit in refresh interval. Provide '<num> <seconds | minutes | hours | days>'.
154
- See https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table#required-parameters for more info."""
155
- )
156
- except Exception as e: # TODO: Link to DT page.
157
- raise ValueError(
158
- f"""Failed to parse refresh interval with exception {e}.
159
- Provide '<num> <seconds | minutes | hours | days>'.
160
- See https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table#required-parameters for more info."""
161
- )
67
+ def _build_column_list_from_input(self, columns: Optional[List[str]]) -> List[sql_identifier.SqlIdentifier]:
68
+ return [sql_identifier.SqlIdentifier(column_name) for column_name in columns] if columns else []
162
69
 
163
70
  def add_monitor(
164
71
  self,
165
72
  name: str,
166
- table_config: model_monitor_config.ModelMonitorTableConfig,
73
+ source_config: model_monitor_config.ModelMonitorSourceConfig,
167
74
  model_monitor_config: model_monitor_config.ModelMonitorConfig,
168
- *,
169
- add_dashboard_udtfs: bool = False,
170
75
  ) -> model_monitor.ModelMonitor:
171
76
  """Add a new Model Monitor.
172
77
 
173
78
  Args:
174
79
  name: Name of Model Monitor to create.
175
- table_config: Configuration options for the source table used in ModelMonitor.
80
+ source_config: Configuration options for the source table used in ModelMonitor.
176
81
  model_monitor_config: Configuration options of ModelMonitor.
177
- add_dashboard_udtfs: Add UDTFs useful for creating a dashboard.
178
82
 
179
83
  Returns:
180
84
  The newly added ModelMonitor object.
181
85
  """
182
- # Validates configuration or raise.
183
- self._validate_monitor_config_or_raise(table_config, model_monitor_config)
184
- model_function = self._get_and_validate_model_function_from_model_version(
86
+ warehouse_name_id = sql_identifier.SqlIdentifier(model_monitor_config.background_compute_warehouse_name)
87
+ self._model_monitor_client.validate_monitor_warehouse(warehouse_name_id, statement_params=self.statement_params)
88
+ self._validate_model_function_from_model_version(
185
89
  model_monitor_config.model_function_name, model_monitor_config.model_version
186
90
  )
187
- monitor_refresh_interval = model_monitor_interval.ModelMonitorRefreshInterval(
188
- model_monitor_config.refresh_interval
91
+ self._validate_task_from_model_version(model_monitor_config.model_version)
92
+ monitor_database_name_id, monitor_schema_name_id, monitor_name_id = sql_identifier.parse_fully_qualified_name(
93
+ name
94
+ )
95
+ source_database_name_id, source_schema_name_id, source_name_id = sql_identifier.parse_fully_qualified_name(
96
+ source_config.source
97
+ )
98
+ baseline_database_name_id, baseline_schema_name_id, baseline_name_id = (
99
+ sql_identifier.parse_fully_qualified_name(source_config.baseline)
100
+ if source_config.baseline
101
+ else (None, None, None)
189
102
  )
190
- name_id = sql_identifier.SqlIdentifier(name)
191
- source_table_name_id = sql_identifier.SqlIdentifier(table_config.source_table)
192
- prediction_columns = [
193
- sql_identifier.SqlIdentifier(column_name) for column_name in table_config.prediction_columns
194
- ]
195
- label_columns = [sql_identifier.SqlIdentifier(column_name) for column_name in table_config.label_columns]
196
- id_columns = [sql_identifier.SqlIdentifier(column_name) for column_name in table_config.id_columns]
197
- ts_column = sql_identifier.SqlIdentifier(table_config.timestamp_column)
103
+ model_database_name_id, model_schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(
104
+ model_monitor_config.model_version.fully_qualified_model_name
105
+ )
106
+
107
+ prediction_score_columns = self._build_column_list_from_input(source_config.prediction_score_columns)
108
+ prediction_class_columns = self._build_column_list_from_input(source_config.prediction_class_columns)
109
+ actual_score_columns = self._build_column_list_from_input(source_config.actual_score_columns)
110
+ actual_class_columns = self._build_column_list_from_input(source_config.actual_class_columns)
111
+
112
+ id_columns = [sql_identifier.SqlIdentifier(column_name) for column_name in source_config.id_columns]
113
+ ts_column = sql_identifier.SqlIdentifier(source_config.timestamp_column)
198
114
 
199
115
  # Validate source table
200
- self._model_monitor_client.validate_source_table(
201
- source_table_name=source_table_name_id,
116
+ self._model_monitor_client.validate_source(
117
+ source_database=source_database_name_id,
118
+ source_schema=source_schema_name_id,
119
+ source=source_name_id,
202
120
  timestamp_column=ts_column,
203
- prediction_columns=prediction_columns,
204
- label_columns=label_columns,
121
+ prediction_score_columns=prediction_score_columns,
122
+ prediction_class_columns=prediction_class_columns,
123
+ actual_score_columns=actual_score_columns,
124
+ actual_class_columns=actual_class_columns,
205
125
  id_columns=id_columns,
206
- model_function=model_function,
207
126
  )
208
127
 
209
- task = self._fetch_task_from_model_version(model_version=model_monitor_config.model_version)
210
- score_type = self._model_monitor_client.get_score_type(task, source_table_name_id, prediction_columns)
211
-
212
- # Insert monitoring metadata for new model version.
213
- self._model_monitor_client.create_monitor_on_model_version(
214
- monitor_name=name_id,
215
- source_table_name=source_table_name_id,
216
- fully_qualified_model_name=model_monitor_config.model_version.fully_qualified_model_name,
128
+ self._model_monitor_client.create_model_monitor(
129
+ monitor_database=monitor_database_name_id,
130
+ monitor_schema=monitor_schema_name_id,
131
+ monitor_name=monitor_name_id,
132
+ source_database=source_database_name_id,
133
+ source_schema=source_schema_name_id,
134
+ source=source_name_id,
135
+ model_database=model_database_name_id,
136
+ model_schema=model_schema_name_id,
137
+ model_name=model_name_id,
217
138
  version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
218
139
  function_name=model_monitor_config.model_function_name,
140
+ warehouse_name=warehouse_name_id,
219
141
  timestamp_column=ts_column,
220
- prediction_columns=prediction_columns,
221
- label_columns=label_columns,
222
142
  id_columns=id_columns,
223
- task=task,
224
- statement_params=self.statement_params,
225
- )
226
-
227
- # Create Dynamic tables for model monitor.
228
- self._model_monitor_client.create_dynamic_tables_for_monitor(
229
- model_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.model_name),
230
- model_version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
231
- task=task,
232
- source_table_name=source_table_name_id,
233
- refresh_interval=monitor_refresh_interval,
143
+ prediction_score_columns=prediction_score_columns,
144
+ prediction_class_columns=prediction_class_columns,
145
+ actual_score_columns=actual_score_columns,
146
+ actual_class_columns=actual_class_columns,
147
+ refresh_interval=model_monitor_config.refresh_interval,
234
148
  aggregation_window=model_monitor_config.aggregation_window,
235
- warehouse_name=sql_identifier.SqlIdentifier(model_monitor_config.background_compute_warehouse_name),
236
- timestamp_column=sql_identifier.SqlIdentifier(table_config.timestamp_column),
237
- id_columns=id_columns,
238
- prediction_columns=prediction_columns,
239
- label_columns=label_columns,
240
- score_type=score_type,
241
- )
242
-
243
- # Initialize baseline table.
244
- self._model_monitor_client.initialize_baseline_table(
245
- model_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.model_name),
246
- version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
247
- source_table_name=table_config.source_table,
248
- columns_to_drop=[ts_column, *id_columns],
149
+ baseline_database=baseline_database_name_id,
150
+ baseline_schema=baseline_schema_name_id,
151
+ baseline=baseline_name_id,
249
152
  statement_params=self.statement_params,
250
153
  )
251
-
252
- # Add udtfs helpful for dashboard queries.
253
- # TODO(apgupta) Make this true by default.
254
- if add_dashboard_udtfs:
255
- self._model_monitor_client.add_dashboard_udtfs(
256
- name_id,
257
- model_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.model_name),
258
- model_version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
259
- task=task,
260
- score_type=score_type,
261
- output_columns=prediction_columns,
262
- ground_truth_columns=label_columns,
263
- )
264
-
265
154
  return model_monitor.ModelMonitor._ref(
266
155
  model_monitor_client=self._model_monitor_client,
267
- name=name_id,
268
- fully_qualified_model_name=model_monitor_config.model_version.fully_qualified_model_name,
269
- version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
270
- function_name=sql_identifier.SqlIdentifier(model_monitor_config.model_function_name),
271
- prediction_columns=prediction_columns,
272
- label_columns=label_columns,
156
+ name=monitor_name_id,
273
157
  )
274
158
 
275
159
  def get_monitor_by_model_version(
276
160
  self, model_version: model_version_impl.ModelVersion
277
161
  ) -> model_monitor.ModelMonitor:
278
- fq_model_name = model_version.fully_qualified_model_name
279
- version_name = sql_identifier.SqlIdentifier(model_version.version_name)
280
- if self._model_monitor_client.validate_existence(fq_model_name, version_name, self.statement_params):
281
- model_db, model_schema, model_name = sql_identifier.parse_fully_qualified_name(fq_model_name)
282
- if model_db is None or model_schema is None:
283
- raise ValueError("Failed to parse model name")
284
-
285
- model_monitor_params: model_monitor_sql_client._ModelMonitorParams = (
286
- self._model_monitor_client.get_model_monitor_by_model_version(
287
- model_db=model_db,
288
- model_schema=model_schema,
289
- model_name=model_name,
290
- version_name=version_name,
291
- statement_params=self.statement_params,
292
- )
293
- )
294
- return model_monitor.ModelMonitor._ref(
295
- model_monitor_client=self._model_monitor_client,
296
- name=sql_identifier.SqlIdentifier(model_monitor_params["monitor_name"]),
297
- fully_qualified_model_name=fq_model_name,
298
- version_name=version_name,
299
- function_name=sql_identifier.SqlIdentifier(model_monitor_params["function_name"]),
300
- prediction_columns=model_monitor_params["prediction_columns"],
301
- label_columns=model_monitor_params["label_columns"],
302
- )
162
+ """Get a Model Monitor by Model Version.
303
163
 
304
- else:
305
- raise ValueError(
306
- f"ModelMonitor not found for model version {model_version.model_name} - {model_version.version_name}"
164
+ Args:
165
+ model_version: ModelVersion to retrieve Model Monitor for.
166
+
167
+ Returns:
168
+ The fetched ModelMonitor.
169
+
170
+ Raises:
171
+ ValueError: If model monitor is not found.
172
+ """
173
+ rows = self._model_monitor_client.show_model_monitors(statement_params=self.statement_params)
174
+
175
+ def model_match_fn(model_details: Dict[str, str]) -> bool:
176
+ return (
177
+ model_details[model_monitor_sql_client.MODEL_JSON_MODEL_NAME_FIELD] == model_version.model_name
178
+ and model_details[model_monitor_sql_client.MODEL_JSON_VERSION_NAME_FIELD] == model_version.version_name
307
179
  )
308
180
 
181
+ rows = [row for row in rows if model_match_fn(json.loads(row[model_monitor_sql_client.MODEL_JSON_COL_NAME]))]
182
+ if len(rows) == 0:
183
+ raise ValueError("Unable to find model monitor for the given model version.")
184
+ if len(rows) > 1:
185
+ raise ValueError("Found multiple model monitors for the given model version.")
186
+
187
+ return model_monitor.ModelMonitor._ref(
188
+ model_monitor_client=self._model_monitor_client,
189
+ name=sql_identifier.SqlIdentifier(rows[0]["name"]),
190
+ )
191
+
309
192
  def get_monitor(self, name: str) -> model_monitor.ModelMonitor:
310
193
  """Get a Model Monitor from the Registry
311
194
 
@@ -318,25 +201,18 @@ See https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table#require
318
201
  Returns:
319
202
  The fetched ModelMonitor.
320
203
  """
321
- name_id = sql_identifier.SqlIdentifier(name)
204
+ database_name_id, schema_name_id, monitor_name_id = sql_identifier.parse_fully_qualified_name(name)
322
205
 
323
206
  if not self._model_monitor_client.validate_existence_by_name(
324
- monitor_name=name_id,
207
+ database_name=database_name_id,
208
+ schema_name=schema_name_id,
209
+ monitor_name=monitor_name_id,
325
210
  statement_params=self.statement_params,
326
211
  ):
327
212
  raise ValueError(f"Unable to find model monitor '{name}'")
328
- model_monitor_params: model_monitor_sql_client._ModelMonitorParams = (
329
- self._model_monitor_client.get_model_monitor_by_name(name_id, statement_params=self.statement_params)
330
- )
331
-
332
213
  return model_monitor.ModelMonitor._ref(
333
214
  model_monitor_client=self._model_monitor_client,
334
- name=name_id,
335
- fully_qualified_model_name=model_monitor_params["fully_qualified_model_name"],
336
- version_name=sql_identifier.SqlIdentifier(model_monitor_params["version_name"]),
337
- function_name=sql_identifier.SqlIdentifier(model_monitor_params["function_name"]),
338
- prediction_columns=model_monitor_params["prediction_columns"],
339
- label_columns=model_monitor_params["label_columns"],
215
+ name=monitor_name_id,
340
216
  )
341
217
 
342
218
  def show_model_monitors(self) -> List[snowpark.Row]:
@@ -345,7 +221,7 @@ See https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table#require
345
221
  Returns:
346
222
  List of snowpark.Row containing metadata for each model monitor.
347
223
  """
348
- return self._model_monitor_client.get_all_model_monitor_metadata()
224
+ return self._model_monitor_client.show_model_monitors(statement_params=self.statement_params)
349
225
 
350
226
  def delete_monitor(self, name: str) -> None:
351
227
  """Delete a Model Monitor from the Registry
@@ -353,10 +229,10 @@ See https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table#require
353
229
  Args:
354
230
  name: Name of the Model Monitor to delete.
355
231
  """
356
- name_id = sql_identifier.SqlIdentifier(name)
357
- monitor_params = self._model_monitor_client.get_model_monitor_by_name(name_id)
358
- _, _, model = sql_identifier.parse_fully_qualified_name(monitor_params["fully_qualified_model_name"])
359
- version = sql_identifier.SqlIdentifier(monitor_params["version_name"])
360
- self._model_monitor_client.delete_monitor_metadata(name_id)
361
- self._model_monitor_client.delete_baseline_table(model, version)
362
- self._model_monitor_client.delete_dynamic_tables(model, version)
232
+ database_name_id, schema_name_id, monitor_name_id = sql_identifier.parse_fully_qualified_name(name)
233
+ self._model_monitor_client.drop_model_monitor(
234
+ database_name=database_name_id,
235
+ schema_name=schema_name_id,
236
+ monitor_name=monitor_name_id,
237
+ statement_params=self.statement_params,
238
+ )
@@ -1,17 +1,19 @@
1
1
  from dataclasses import dataclass
2
- from typing import List
2
+ from typing import List, Optional
3
3
 
4
4
  from snowflake.ml.model._client.model import model_version_impl
5
- from snowflake.ml.monitoring.entities import model_monitor_interval
6
5
 
7
6
 
8
7
  @dataclass
9
- class ModelMonitorTableConfig:
10
- source_table: str
8
+ class ModelMonitorSourceConfig:
9
+ source: str
11
10
  timestamp_column: str
12
- prediction_columns: List[str]
13
- label_columns: List[str]
14
11
  id_columns: List[str]
12
+ prediction_score_columns: Optional[List[str]] = None
13
+ prediction_class_columns: Optional[List[str]] = None
14
+ actual_score_columns: Optional[List[str]] = None
15
+ actual_class_columns: Optional[List[str]] = None
16
+ baseline: Optional[str] = None
15
17
 
16
18
 
17
19
  @dataclass
@@ -22,7 +24,5 @@ class ModelMonitorConfig:
22
24
  model_function_name: str
23
25
  background_compute_warehouse_name: str
24
26
  # TODO: Add support for pythonic notion of time.
25
- refresh_interval: str = model_monitor_interval.ModelMonitorRefreshInterval.DAILY
26
- aggregation_window: model_monitor_interval.ModelMonitorAggregationWindow = (
27
- model_monitor_interval.ModelMonitorAggregationWindow.WINDOW_1_DAY
28
- )
27
+ refresh_interval: str = "1 hour"
28
+ aggregation_window: str = "1 day"
@@ -1,8 +1,3 @@
1
- from typing import List, Union
2
-
3
- import pandas as pd
4
-
5
- from snowflake import snowpark
6
1
  from snowflake.ml._internal import telemetry
7
2
  from snowflake.ml._internal.utils import sql_identifier
8
3
  from snowflake.ml.monitoring._client import model_monitor_sql_client
@@ -13,11 +8,11 @@ class ModelMonitor:
13
8
 
14
9
  name: sql_identifier.SqlIdentifier
15
10
  _model_monitor_client: model_monitor_sql_client.ModelMonitorSQLClient
16
- _fully_qualified_model_name: str
17
- _version_name: sql_identifier.SqlIdentifier
18
- _function_name: sql_identifier.SqlIdentifier
19
- _prediction_columns: List[sql_identifier.SqlIdentifier]
20
- _label_columns: List[sql_identifier.SqlIdentifier]
11
+
12
+ statement_params = telemetry.get_statement_params(
13
+ telemetry.TelemetryProject.MLOPS.value,
14
+ telemetry.TelemetrySubProject.MONITORING.value,
15
+ )
21
16
 
22
17
  def __init__(self) -> None:
23
18
  raise RuntimeError("ModelMonitor's initializer is not meant to be used.")
@@ -27,100 +22,16 @@ class ModelMonitor:
27
22
  cls,
28
23
  model_monitor_client: model_monitor_sql_client.ModelMonitorSQLClient,
29
24
  name: sql_identifier.SqlIdentifier,
30
- *,
31
- fully_qualified_model_name: str,
32
- version_name: sql_identifier.SqlIdentifier,
33
- function_name: sql_identifier.SqlIdentifier,
34
- prediction_columns: List[sql_identifier.SqlIdentifier],
35
- label_columns: List[sql_identifier.SqlIdentifier],
36
25
  ) -> "ModelMonitor":
37
26
  self: "ModelMonitor" = object.__new__(cls)
38
27
  self.name = name
39
28
  self._model_monitor_client = model_monitor_client
40
- self._fully_qualified_model_name = fully_qualified_model_name
41
- self._version_name = version_name
42
- self._function_name = function_name
43
- self._prediction_columns = prediction_columns
44
- self._label_columns = label_columns
45
29
  return self
46
30
 
47
- @telemetry.send_api_usage_telemetry(
48
- project=telemetry.TelemetryProject.MLOPS.value,
49
- subproject=telemetry.TelemetrySubProject.MONITORING.value,
50
- )
51
- def set_baseline(self, baseline_df: Union[pd.DataFrame, snowpark.DataFrame]) -> None:
52
- """
53
- The baseline dataframe is compared with the monitored data once monitoring is enabled.
54
- The columns of the dataframe should match the columns of the source table that the
55
- ModelMonitor was configured with. Calling this method overwrites any existing baseline split data.
56
-
57
- Args:
58
- baseline_df: Snowpark dataframe containing baseline data.
59
-
60
- Raises:
61
- ValueError: baseline_df does not contain prediction or label columns
62
- """
63
- statement_params = telemetry.get_statement_params(
64
- project=telemetry.TelemetryProject.MLOPS.value,
65
- subproject=telemetry.TelemetrySubProject.MONITORING.value,
66
- )
67
-
68
- if isinstance(baseline_df, pd.DataFrame):
69
- baseline_df = self._model_monitor_client._sql_client._session.create_dataframe(baseline_df)
70
-
71
- column_names_identifiers: List[sql_identifier.SqlIdentifier] = [
72
- sql_identifier.SqlIdentifier(column_name) for column_name in baseline_df.columns
73
- ]
74
- prediction_cols_not_found = any(
75
- [prediction_col not in column_names_identifiers for prediction_col in self._prediction_columns]
76
- )
77
- label_cols_not_found = any(
78
- [label_col.identifier() not in column_names_identifiers for label_col in self._label_columns]
79
- )
80
-
81
- if prediction_cols_not_found:
82
- raise ValueError(
83
- "Specified prediction columns were not found in the baseline dataframe. "
84
- f"Columns provided were: {column_names_identifiers}. "
85
- f"Configured prediction columns were: {self._prediction_columns}."
86
- )
87
- if label_cols_not_found:
88
- raise ValueError(
89
- "Specified label columns were not found in the baseline dataframe."
90
- f"Columns provided in the baseline dataframe were: {column_names_identifiers}."
91
- f"Configured label columns were: {self._label_columns}."
92
- )
93
-
94
- # Create the table by materializing the df
95
- self._model_monitor_client.materialize_baseline_dataframe(
96
- baseline_df,
97
- self._fully_qualified_model_name,
98
- self._version_name,
99
- statement_params=statement_params,
100
- )
101
-
102
31
  def suspend(self) -> None:
103
32
  """Suspend pipeline for ModelMonitor"""
104
- statement_params = telemetry.get_statement_params(
105
- telemetry.TelemetryProject.MLOPS.value,
106
- telemetry.TelemetrySubProject.MONITORING.value,
107
- )
108
- _, _, model_name = sql_identifier.parse_fully_qualified_name(self._fully_qualified_model_name)
109
- self._model_monitor_client.suspend_monitor_dynamic_tables(
110
- model_name=model_name,
111
- version_name=self._version_name,
112
- statement_params=statement_params,
113
- )
33
+ self._model_monitor_client.suspend_monitor(self.name, statement_params=self.statement_params)
114
34
 
115
35
  def resume(self) -> None:
116
36
  """Resume pipeline for ModelMonitor"""
117
- statement_params = telemetry.get_statement_params(
118
- telemetry.TelemetryProject.MLOPS.value,
119
- telemetry.TelemetrySubProject.MONITORING.value,
120
- )
121
- _, _, model_name = sql_identifier.parse_fully_qualified_name(self._fully_qualified_model_name)
122
- self._model_monitor_client.resume_monitor_dynamic_tables(
123
- model_name=model_name,
124
- version_name=self._version_name,
125
- statement_params=statement_params,
126
- )
37
+ self._model_monitor_client.resume_monitor(self.name, statement_params=self.statement_params)